# -*- encoding: utf-8 -*-
# @Author: SWHL
# @Contact: liekkaskono@163.com
import math
import random
from pathlib import Path
from typing import List, Optional, Sequence, Tuple, Union

import cv2
from functools import lru_cache
import numpy as np
from omegaconf import OmegaConf
from PIL import Image, ImageDraw, ImageFont

from .download_file import DownloadFile, DownloadFileInput
from .load_image import LoadImage
from .log import logger
from .typings import LangRec

root_dir = Path(__file__).resolve().parent.parent
InputType = Union[str, np.ndarray, bytes, Path, Image.Image]

DEFAULT_FONT_DIR = root_dir / "models"
DEFAULT_FONT_PATH = DEFAULT_FONT_DIR / "FZYTK.TTF"
FONT_YAML_PATH = root_dir / "default_models.yaml"


@lru_cache()
def font_cfg():
    """Return the ``fonts`` section of the YAML file at ``FONT_YAML_PATH``.

    The result is memoised by ``lru_cache``, so the file is parsed only on
    the first call.
    """
    model_cfg = OmegaConf.load(FONT_YAML_PATH)
    return model_cfg.fonts


@lru_cache()
def get_font_path(
    font_path: Optional[Union[str, Path]] = None, lang_type: Optional[LangRec] = None
) -> str:
    """Resolve the font file used to render text.

    Resolution order:

    1. An explicitly supplied ``font_path`` always wins and is returned
       unchanged (as ``str``), regardless of ``lang_type``.
    2. With no ``font_path`` and no ``lang_type``, the default font
       (the ``"ch"`` entry of ``font_cfg()``) is downloaded to
       ``DEFAULT_FONT_PATH``.
    3. With a ``lang_type`` only, the font configured for that language is
       downloaded into ``DEFAULT_FONT_DIR``; if the language has no entry in
       ``font_cfg()``, a warning is logged and the default font is used.

    Results are memoised by ``lru_cache``, so both arguments must be hashable.

    Args:
        font_path: Optional user-provided font file.
        lang_type: Optional recognition language; its ``value`` is used as the
            lookup key in ``font_cfg()``.

    Returns:
        Path of the font file as a string.
    """
    # An explicit font must never be overridden by the default font, even
    # when no language was specified.
    if font_path is not None:
        return str(font_path)

    def _use_default_font() -> str:
        # Built lazily so the config lookup only happens when it is needed.
        default_info = font_cfg()["ch"]
        default_input_params = DownloadFileInput(
            file_url=default_info["path"],
            sha256=default_info["SHA256"],
            save_path=DEFAULT_FONT_PATH,
            logger=logger,
        )
        DownloadFile.run(default_input_params)
        return str(DEFAULT_FONT_PATH)

    if lang_type is None:
        # No language specified: use the default font file.
        return _use_default_font()

    # Accept the enum member (normal case) or an already-resolved key.
    lang_key = getattr(lang_type, "value", lang_type)

    # Language specified without a font file: pick the font by language.
    font_info = font_cfg().get(lang_key, None)
    if font_info is None:
        logger.warning(
            "Font file for %s is not found in the supported font list. Default font file will be used.",
            lang_key,
        )
        return _use_default_font()

    font_url, font_sha256 = font_info["path"], font_info["SHA256"]
    save_font_path = DEFAULT_FONT_DIR / Path(font_url).name
    input_param = DownloadFileInput(
        file_url=font_url,
        sha256=font_sha256,
        save_path=save_font_path,
        logger=logger,
        verbose=False,
    )
    DownloadFile.run(input_param)
    return str(save_font_path)


class VisRes:
    """Draw detection boxes and recognised text onto images.

    Calling an instance dispatches to :meth:`draw_dt_boxes` when no texts are
    given and to :meth:`draw_ocr_box_txt` otherwise.
    """

    load_img = LoadImage()

    def __init__(
        self,
        text_score: float = 0.5,
        lang_type: Optional[LangRec] = None,
        font_path: Optional[str] = None,
    ):
        """Store the visualisation settings.

        Args:
            text_score: Boxes whose score is below this value are skipped by
                :meth:`draw_ocr_box_txt`.
            lang_type: Optional language, forwarded to ``get_font_path``.
            font_path: Optional font file, forwarded to ``get_font_path``.
        """
        self.text_score = text_score
        self._font_path = font_path
        self._lang_type = lang_type

    @property
    def font_path(self):
        """Font file resolved by ``get_font_path`` from the stored settings."""
        return get_font_path(self._font_path, self._lang_type)

    def __call__(
        self,
        img_content: InputType,
        dt_boxes: np.ndarray,
        txts: Optional[Union[List[str], Tuple[str]]] = None,
        scores: Optional[List[float]] = None,
    ) -> np.ndarray:
        """Visualise boxes only (``txts is None``) or boxes together with text."""
        if txts is None:
            # Box-only drawing uses an OpenCV built-in font, so the TrueType
            # font is not resolved on this path.
            return self.draw_dt_boxes(img_content, dt_boxes, scores)

        logger.info("Using %s to visualize results.", self.font_path)
        return self.draw_ocr_box_txt(img_content, dt_boxes, txts, scores)

    def draw_dt_boxes(
        self,
        img_content: InputType,
        dt_boxes: np.ndarray,
        scores: Optional[List[float]] = None,
    ) -> np.ndarray:
        """Outline every box and label it with ``"<index>:<score>"``.

        Args:
            img_content: Image input accepted by ``self.load_img``.
            dt_boxes: Iterable of boxes, each a sequence of ``[x, y]`` points.
            scores: Optional score per box; defaults to ``1.0`` for every box.

        Returns:
            The array returned by ``self.load_img`` with the drawings applied.
        """
        img = self.load_img(img_content)
        if scores is None:
            scores = [1.0] * len(dt_boxes)

        for idx, (box, score) in enumerate(zip(dt_boxes, scores)):
            color = self.get_random_color()

            points = np.array(box)
            # Third positional argument is ``isClosed``.
            cv2.polylines(img, np.int32([points]), 1, color=color, thickness=1)

            # The label is anchored at the first point of the box.
            start_point = round(points[0][0]), round(points[0][1])
            cv2.putText(
                img,
                f"{idx}:{score:.3f}",
                start_point,
                cv2.FONT_HERSHEY_SIMPLEX,
                1,
                color,
                3,
            )
        return img

    def draw_rec_res(
        self,
        imgs: Sequence[InputType],
        txts: Union[List[str], Tuple[str]],
        scores: List[float],
    ) -> np.ndarray:
        """Render each (image, text) pair and stack the results vertically."""
        result_imgs = [
            self.draw_one_rec_res(img, txt, score)
            for img, txt, score in zip(imgs, txts, scores)
        ]
        return self.concat_imgs(result_imgs, direction="vertical")

    def draw_one_rec_res(
        self, img_content: InputType, txt: str, score: float
    ) -> np.ndarray:
        """Return the image (left) next to a white panel holding ``txt`` (right).

        Args:
            img_content: Image input accepted by ``self.load_img``.
            txt: Text rendered on the right-hand panel.
            score: Accepted for interface compatibility; not used here.

        Returns:
            Array of the side-by-side image, twice as wide as the input.
        """
        image = Image.fromarray(self.load_img(img_content))
        h, w = image.height, image.width
        if image.mode == "L":
            image = image.convert("RGB")

        img_right = Image.new("RGB", (w, h), (255, 255, 255))
        draw_right = ImageDraw.Draw(img_right)

        # The text box spans the whole panel.
        box = [[0, 0], [w, 0], [w, h], [0, h]]
        self._draw_text_in_box(draw_right, box, txt, self.font_path)

        img_show = Image.new("RGB", (w * 2, h), (255, 255, 255))
        img_show.paste(image, (0, 0, w, h))
        img_show.paste(img_right, (w, 0, w * 2, h))
        return np.array(img_show)

    def draw_ocr_box_txt(
        self,
        img_content: InputType,
        dt_boxes: np.ndarray,
        txts: Union[List[str], Tuple[str]],
        scores: Optional[List[float]] = None,
    ) -> np.ndarray:
        """Draw filled boxes on the image and the texts on a white panel.

        The left half is the input blended 50/50 with colour-filled boxes; the
        right half shows each box outline with its text. Boxes whose score is
        below ``self.text_score`` are skipped when ``scores`` is given.

        Args:
            img_content: Image input accepted by ``self.load_img``.
            dt_boxes: Iterable of boxes, each holding four ``[x, y]`` points.
            txts: Text for each box, in the same order as ``dt_boxes``.
            scores: Optional score for each box.

        Returns:
            Array of the side-by-side image, twice as wide as the input.
        """
        image = Image.fromarray(self.load_img(img_content))
        h, w = image.height, image.width
        if image.mode == "L":
            image = image.convert("RGB")

        img_left = image.copy()
        img_right = Image.new("RGB", (w, h), (255, 255, 255))

        # A private generator seeded with 0 yields the same colours on every
        # call without reseeding the process-wide ``random`` state.
        rng = random.Random(0)
        # Resolve the font once instead of once per box.
        font_path = self.font_path

        draw_left = ImageDraw.Draw(img_left)
        draw_right = ImageDraw.Draw(img_right)
        for idx, (box, txt) in enumerate(zip(dt_boxes, txts)):
            if scores is not None and float(scores[idx]) < self.text_score:
                continue

            color = self.get_random_color(rng)

            # Flatten the four points to [x0, y0, ..., x3, y3].
            box_list = np.array(box).reshape(8).tolist()
            draw_left.polygon(box_list, fill=color)
            draw_right.polygon(box_list, outline=color)

            self._draw_text_in_box(draw_right, box, txt, font_path)

        img_left = Image.blend(image, img_left, 0.5)
        img_show = Image.new("RGB", (w * 2, h), (255, 255, 255))
        img_show.paste(img_left, (0, 0, w, h))
        img_show.paste(img_right, (w, 0, w * 2, h))
        return np.array(img_show)

    def _draw_text_in_box(self, draw, box, txt: str, font_path: str) -> None:
        """Write ``txt`` in black starting at the first point of ``box``.

        Boxes more than twice as tall as wide get one character per line,
        stacked downwards; all other boxes get a single horizontal line.
        """
        box_height = self.get_box_height(box)
        box_width = self.get_box_width(box)
        origin_x, origin_y = box[0][0], box[0][1]

        if box_height > 2 * box_width:
            # Vertical text: size the font from the box width.
            font_size = max(int(box_width * 0.9), 10)
            font = ImageFont.truetype(font_path, font_size, encoding="utf-8")

            cur_y = origin_y
            for char in txt:
                draw.text((origin_x + 3, cur_y), char, fill=(0, 0, 0), font=font)
                cur_y += self.get_char_size(font, char)
            return

        # Horizontal text: size the font from the box height.
        font_size = max(int(box_height * 0.8), 10)
        font = ImageFont.truetype(font_path, font_size, encoding="utf-8")
        draw.text((origin_x, origin_y), txt, fill=(0, 0, 0), font=font)

    @staticmethod
    def get_random_color(rng: Optional[random.Random] = None) -> Tuple[int, int, int]:
        """Return a random colour as three integers in ``[0, 255]``.

        Args:
            rng: Optional generator to draw from; the module-level ``random``
                functions are used when omitted.
        """
        source = random if rng is None else rng
        return (
            source.randint(0, 255),
            source.randint(0, 255),
            source.randint(0, 255),
        )

    @staticmethod
    def get_box_height(box: List[List[Union[float, int]]]) -> float:
        """Euclidean distance between the first and the fourth point of ``box``."""
        return math.sqrt((box[0][0] - box[3][0]) ** 2 + (box[0][1] - box[3][1]) ** 2)

    @staticmethod
    def get_box_width(box: List[List[Union[float, int]]]) -> float:
        """Euclidean distance between the first and the second point of ``box``."""
        return math.sqrt((box[0][0] - box[1][0]) ** 2 + (box[0][1] - box[1][1]) ** 2)

    @staticmethod
    def get_char_size(font, char_str: str) -> float:
        """Return the vertical step to advance after drawing ``char_str``.

        Works across Pillow versions by probing the available font methods.

        Raises:
            ValueError: If the font offers none of the probed methods.
        """
        # Preferred: the bottom edge of the bounding box, measured from the
        # drawing origin. NOTE(review): this is expected to equal the height
        # reported by the removed ``getsize`` - confirm against Pillow docs.
        if hasattr(font, "getbbox"):
            return font.getbbox(char_str)[3]

        # Older Pillow releases without ``getbbox``.
        if hasattr(font, "getsize"):
            return font.getsize(char_str)[1]

        # Last resort: horizontal advance, only an approximation of the height.
        if hasattr(font, "getlength"):
            return font.getlength(char_str)

        raise ValueError(
            "The Pillow ImageFont instance has not getsize or getlength func."
        )

    @staticmethod
    def concat_imgs(
        imgs: List[np.ndarray], direction: str = "horizontal"
    ) -> np.ndarray:
        """Paste images next to each other on a white RGB canvas.

        Args:
            imgs: Non-empty list of image arrays.
            direction: ``"horizontal"`` (left to right) or ``"vertical"``
                (top to bottom).

        Returns:
            Array of the combined image. The canvas is as large as the sum of
            the sizes along ``direction`` and the maximum size across it.

        Raises:
            ValueError: If ``direction`` is not supported or ``imgs`` is empty.
        """
        if direction not in ("horizontal", "vertical"):
            raise ValueError(f"{direction} is not supported.")
        if len(imgs) == 0:
            raise ValueError("imgs must contain at least one image.")

        horizontal = direction == "horizontal"
        img_list = [Image.fromarray(img) for img in imgs]
        widths = [img.width for img in img_list]
        heights = [img.height for img in img_list]

        if horizontal:
            canvas_size = (sum(widths), max(heights))
        else:
            canvas_size = (max(widths), sum(heights))

        new_image = Image.new("RGB", canvas_size, color="white")

        # Running offset along the concatenation axis.
        offset = 0
        for img in img_list:
            if horizontal:
                new_image.paste(img, (offset, 0))
                offset += img.width
            else:
                new_image.paste(img, (0, offset))
                offset += img.height

        return np.array(new_image)
