import asyncio
import base64
import json
import os
import pathlib
from typing import Any, List, Optional

import aiofiles
import cv2
from ultralytics.engine.results import Results
from ultralytics.models import YOLO

from helps.aio_http_client import post_to_ollama
from helps.config import RootSetting
from helps.logger import AppLogger
from services.jobs import BaseJob

from .utils import register_job

logger = AppLogger.get_logger(__name__)


class ImageReadTask:
    """Enumerate and read the image files waiting in the source cache directory."""

    def __init__(self, settings: RootSetting) -> None:
        """Keep the settings; ``settings.image_cache.source_path`` is read lazily."""
        self._settings = settings

    async def _read_cache_files(self):
        """Yield the full path of every regular file in the source directory.

        Yields:
            pathlib.Path: ``source_path / entry`` for each directory entry that
            is a regular file.
        """
        logger.info(f"read_cache_files from {self._settings.image_cache.source_path}")
        source_dir = pathlib.Path(self._settings.image_cache.source_path)
        for entry in os.listdir(source_dir):
            # os.listdir returns bare names; join them with the directory so the
            # is_file() check (and later open/unlink) do not resolve against
            # the current working directory.
            candidate = source_dir / entry
            if candidate.is_file():
                yield candidate

    async def read_cache_files(self):
        """Yield ``(file_path, content)`` for every file in the source directory.

        Yields:
            tuple[pathlib.Path, bytes]: the file's full path and its raw bytes.
        """
        async for file_name in self._read_cache_files():
            file_path = pathlib.Path(file_name)
            logger.info(f"read file {file_path.name}")
            async with aiofiles.open(file_path, "rb") as fd:
                content = await fd.read()
                yield file_path, content


class YoloEDrawingTask:
    """Run prompt-driven YOLOE detection on an image file."""

    def __init__(self, settings: RootSetting) -> None:
        """Keep the settings; the model itself is loaded on first use."""
        self._settings = settings
        self._model_name = "yoloe"
        self._model: Optional[YOLO] = None
        # Created lazily inside inference() so it is bound to the running loop.
        self._inference_lock: Optional[asyncio.Lock] = None

    async def load_model(self) -> None:
        """Load the model whose configured name matches ``self._model_name``.

        Leaves ``self._model`` as ``None`` when no configured model matches.
        """
        for model in self._settings.models:
            if model.name == self._model_name:
                # Constructing the model is blocking, so do it off the event loop.
                self._model = await asyncio.to_thread(YOLO, model.path)
                break

    async def inference(self, names: List[str], file_path: pathlib.Path) -> Any:
        """Detect the given class names in the image at ``file_path``.

        Args:
            names: class names used as the text prompt for the model.
            file_path: path of the image to run prediction on.

        Returns:
            Whatever ``YOLO.predict`` returns for the image.

        Raises:
            RuntimeError: if no model named ``self._model_name`` is configured.
        """
        if self._model is None:
            await self.load_model()

        if self._model is None:
            raise RuntimeError(f"model {self._model_name} not found")

        predict_model: YOLO = self._model

        def _predict() -> Any:
            predict_model.set_classes(names, predict_model.get_text_pe(names))
            return predict_model.predict(source=file_path)

        if self._inference_lock is None:
            self._inference_lock = asyncio.Lock()

        # set_classes mutates the shared model, so keep calls serialized while
        # the blocking work runs in a worker thread instead of on the event loop.
        async with self._inference_lock:
            return await asyncio.to_thread(_predict)


class ImageTargetCacheTask:
    """Write annotated detection images into the target cache directory."""

    def __init__(self, settings: RootSetting) -> None:
        """Remember ``settings.image_cache.target_path`` as the output directory."""
        self._settings = settings
        self._target_path = self._settings.image_cache.target_path

    async def post(self, results: Results, file_path: pathlib.Path) -> None:
        """Render ``results`` and save it as ``<file name>_annotated.png``.

        Args:
            results: the detection result for one image.
            file_path: path of the source image; only its name is used.
        """
        # NOTE: ``results`` is already the per-image Results object. Indexing it
        # again would (per the Ultralytics Results API) select a single
        # detection rather than the image, so plot it directly.
        annotated_img_np = results.plot()
        saved_image_path = (
            pathlib.Path(self._target_path) / f"{file_path.name}_annotated.png"
        )
        # cv2.imwrite reports failure through its return value, not an exception.
        if not cv2.imwrite(str(saved_image_path), annotated_img_np):
            logger.error(f"failed to write annotated image {saved_image_path}")
        # save or upload the image


class OllamaClient:
    """Class-level holder of the Ollama endpoint plus a thin POST helper."""

    # Declared up front so post_to_ollama() can fall back to its defaults even
    # when setup_ollamea_client() has not been called yet.
    _ollama_host: Optional[str] = None
    _ollama_port: Optional[int] = None

    @classmethod
    async def post_to_ollama(cls, sub_path, payload: str, cb_fn) -> Any:
        """POST ``payload`` to ``http://<host>:<port>/<sub_path>``.

        Args:
            sub_path: path appended after the host and port.
            payload: request body, already serialized.
            cb_fn: callback handed through to the HTTP helper.

        Returns:
            Whatever the module-level ``post_to_ollama`` helper returns.
        """
        host = cls._ollama_host or "127.0.0.1"
        port = cls._ollama_port or 13001
        rest_api = f"http://{host}:{port}/{sub_path}"
        return await post_to_ollama(payload, rest_api, cb_fn)

    @classmethod
    def setup_ollamea_client(cls, settings: RootSetting):
        """Store the host and port from ``settings.ollama_client`` on the class."""
        cls._ollama_host = settings.ollama_client.host
        cls._ollama_port = settings.ollama_client.port


@register_job
class SpillDetectionJob(BaseJob):
    """Poll the image cache, ask the LLM about spills and annotate the hits."""

    NAME = "spill_detection"

    def __init__(self, settings: RootSetting) -> None:
        """Create the reader, detector and reporter tasks and set up the Ollama client."""
        super().__init__(settings)
        self._image_reader = ImageReadTask(settings)
        self._yoloe_task = YoloEDrawingTask(settings)
        self._image_reporter = ImageTargetCacheTask(settings)
        OllamaClient.setup_ollamea_client(settings)

    @staticmethod
    def _build_payload(prompt: Any, content: bytes) -> str:
        """Serialize the prompt and the base64-encoded image as a JSON body."""
        return json.dumps(
            {
                "prompt": prompt,
                # Raw image bytes must be *encoded* to base64 text to travel in
                # JSON; decoding them would fail or produce garbage.
                "image": base64.b64encode(content).decode("ascii"),
            }
        )

    def pre_check_reco_image_cb(self, resp) -> bool:
        """Return ``True`` when the validation request answered with HTTP 200."""
        if resp.status != 200:
            logger.error(f"post_to_ollama failed {resp.status}")
            return False
        logger.info("post_to_ollama success")

        # Check the result of ollama return
        return True

    async def pre_check_reco_image(self, content: bytes) -> bool:
        """Ask the LLM whether the image is worth running spill recognition on.

        Args:
            content: raw bytes of the image file.

        Returns:
            The value produced by ``OllamaClient.post_to_ollama`` with
            ``pre_check_reco_image_cb`` as callback.
        """
        logger.info("pre-check the recogonization image for the validation")
        prompt = self._settings.llm_prompt.validate_spill_prompt
        return await OllamaClient.post_to_ollama(
            "",
            self._build_payload(prompt, content),
            self.pre_check_reco_image_cb,
        )

    async def reco_spill_from_image_cb(self, resp) -> List[str]:
        """Extract the ``data`` list from the response, or ``[]`` on failure."""
        import inspect

        if resp.status != 200:
            logger.error(f"post_to_ollama failed {resp.status}")
            return []
        logger.info("post_to_ollama success")

        # Check the result of ollama return
        body = resp.json()
        # Async HTTP clients expose json() as a coroutine; await it when needed
        # instead of calling .get() on the un-awaited coroutine object.
        if inspect.isawaitable(body):
            body = await body
        return body.get("data") or []

    async def reco_spill_from_image(self, content: bytes) -> List[str]:
        """Ask the LLM which spill classes are visible in the image.

        Args:
            content: raw bytes of the image file.

        Returns:
            The value produced by ``OllamaClient.post_to_ollama`` with
            ``reco_spill_from_image_cb`` as callback.
        """
        logger.info("recognize the spills from image")
        prompt = self._settings.llm_prompt.reco_spill_prompt
        return await OllamaClient.post_to_ollama(
            "",
            self._build_payload(prompt, content),
            self.reco_spill_from_image_cb,
        )

    async def start(self) -> None:
        """Process every cached image in an endless loop, pausing 1s between scans.

        Each non-empty file is validated, classified, run through YOLOE and
        annotated; it is deleted afterwards whether or not processing succeeded.
        """
        while True:
            logger.info("SpillDetectionJob is running")
            async for (
                file_path,
                content,
            ) in self._image_reader.read_cache_files():
                logger.info(f"post_to_ollama {file_path.name}")
                if content:
                    try:
                        rs = await self.pre_check_reco_image(content)
                        if not rs:
                            logger.info(f"file {file_path.name} is not a spill image")
                            continue

                        rs = await self.reco_spill_from_image(content)
                        # No recognized classes (empty or missing) means there
                        # is nothing to prompt the detector with.
                        if not rs:
                            logger.info(f"file {file_path.name} is not a spill image")
                            continue
                        result = await self._yoloe_task.inference(rs, file_path)
                        if result:
                            image = result[0]
                            await self._image_reporter.post(image, file_path)
                    except Exception as e:
                        # Top-level boundary: one bad file must not stop the job.
                        logger.error(e)
                    finally:
                        logger.info(f"remove file {file_path.name}")
                        pathlib.Path(file_path).unlink(missing_ok=True)
                else:
                    logger.error(f"file {file_path.name} is empty")

            await asyncio.sleep(1)

    async def stop(self) -> None:
        """No-op; the job currently has nothing to tear down."""
        pass
