# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -*- encoding: utf-8 -*-
# @Author: SWHL
# @Contact: liekkaskono@163.com
import time
from typing import Any, Dict

import numpy as np

from rapidocr.inference_engine.base import get_engine

from .utils import DBPostProcess, DetPreProcess, TextDetOutput

# Line-break threshold used by TextDetector's box sorting: after boxes are
# ordered by the y coordinate of their first point, a gap of at least this
# much between two consecutive boxes starts a new text line.
# Expressed in the same units as the box coordinates (presumably pixels of the
# original image — TODO confirm against DBPostProcess).
_BOX_SORT_Y_THRESHOLD = 10


class TextDetector:
    """Text-detection stage: preprocess, run inference, post-process, sort.

    Calling an instance with an image returns a ``TextDetOutput`` whose boxes
    are arranged in reading order (line by line, left to right within a line)
    and whose scores are kept in the same order as the boxes.
    """

    def __init__(self, cfg: Dict[str, Any]):
        """Build the post-processor and the inference session from ``cfg``.

        Args:
            cfg: Detection configuration. It is read with ``cfg.get(...)`` and
                also with attribute access (``cfg.engine_type``), so it must
                be a mapping-like object supporting both; a plain ``dict``
                would fail on the attribute access.
        """
        # Preprocessing parameters; no defaults are applied here, so missing
        # keys are stored as None and handed to DetPreProcess unchanged.
        self.limit_side_len = cfg.get("limit_side_len")
        self.limit_type = cfg.get("limit_type")
        self.mean = cfg.get("mean")
        self.std = cfg.get("std")
        # Rebuilt on every __call__ because the resize limit depends on the
        # size of the image being processed (see get_preprocess).
        self.preprocess_op = None

        post_process = {
            "thresh": cfg.get("thresh", 0.3),
            "box_thresh": cfg.get("box_thresh", 0.5),
            "max_candidates": cfg.get("max_candidates", 1000),
            "unclip_ratio": cfg.get("unclip_ratio", 1.6),
            "use_dilation": cfg.get("use_dilation", True),
            "score_mode": cfg.get("score_mode", "fast"),
        }
        self.postprocess_op = DBPostProcess(**post_process)

        self.session = get_engine(cfg.engine_type)(cfg)

    def __call__(self, img: np.ndarray) -> TextDetOutput:
        """Detect text boxes in ``img``.

        Args:
            img: Image array. Only ``img.shape[0]`` and ``img.shape[1]`` are
                read here (presumably height and width — TODO confirm); the
                array itself is passed to the preprocessing operator.

        Returns:
            An empty ``TextDetOutput()`` when preprocessing yields ``None`` or
            when no boxes are found. Otherwise a ``TextDetOutput`` built from
            the input image, the sorted boxes, the scores permuted to match
            the sorted boxes, and the elapsed time in seconds.

        Raises:
            ValueError: If ``img`` is ``None``.
        """
        start_time = time.perf_counter()

        if img is None:
            raise ValueError("img is None")

        ori_img_shape = img.shape[0], img.shape[1]
        self.preprocess_op = self.get_preprocess(max(ori_img_shape))
        prepro_img = self.preprocess_op(img)
        if prepro_img is None:
            return TextDetOutput()

        preds = self.session(prepro_img)
        boxes, scores = self.postprocess_op(preds, ori_img_shape)
        if len(boxes) < 1:
            return TextDetOutput()

        # Sort through an explicit permutation so the very same reordering can
        # be applied to the scores; sorting only the boxes would leave each
        # score attached to the wrong box.
        order = self._sorted_box_order(boxes)
        boxes = boxes[order]
        scores = self._reorder_scores(scores, order)

        elapse = time.perf_counter() - start_time
        return TextDetOutput(img, boxes, scores, elapse=elapse)

    def get_preprocess(self, max_wh: int) -> DetPreProcess:
        """Create the preprocessing operator for an image of the given size.

        When ``limit_type`` is ``"min"`` the configured ``limit_side_len`` is
        used as is. For any other ``limit_type`` the configured value is
        ignored and the limit is picked from the longest image side:
        960 below 960, 1500 below 1500, otherwise 2000.

        Args:
            max_wh: The larger of the image's first two dimensions.

        Returns:
            A new ``DetPreProcess`` instance.
        """
        if self.limit_type == "min":
            limit_side_len = self.limit_side_len
        elif max_wh < 960:
            limit_side_len = 960
        elif max_wh < 1500:
            limit_side_len = 1500
        else:
            limit_side_len = 2000
        return DetPreProcess(limit_side_len, self.limit_type, self.mean, self.std)

    @staticmethod
    def sorted_boxes(dt_boxes: np.ndarray) -> np.ndarray:
        """Return ``dt_boxes`` arranged line by line, left to right.

        Boxes are ordered by the y coordinate of their first point and split
        into lines wherever two consecutive boxes are at least
        ``_BOX_SORT_Y_THRESHOLD`` apart; each line is then ordered by the x
        coordinate of the first point.

        NOTE(review): this was previously described as equivalent to an older
        bubble-adjusted sort; that equivalence is not verified here.

        Args:
            dt_boxes: Array indexed as ``[box, point, (x, y)]``.

        Returns:
            The reordered boxes; an empty input is returned unchanged.
        """
        if len(dt_boxes) == 0:
            return dt_boxes
        return dt_boxes[TextDetector._sorted_box_order(dt_boxes)]

    @staticmethod
    def _sorted_box_order(dt_boxes: np.ndarray) -> np.ndarray:
        """Return the index permutation that puts ``dt_boxes`` in reading order.

        Expects a non-empty array indexed as ``[box, point, (x, y)]``.
        """
        # Step 1: stable sort by the first point's y (top to bottom).
        y_coords = dt_boxes[:, 0, 1]
        y_order = np.argsort(y_coords, kind="stable")
        y_sorted = y_coords[y_order]

        # Step 2: a new line starts wherever the gap to the previous box (in
        # y order) reaches the threshold. Only adjacent gaps are compared, so
        # a chain of closely spaced boxes stays on one line even if its first
        # and last boxes are further apart than the threshold.
        line_breaks = (np.diff(y_sorted) >= _BOX_SORT_Y_THRESHOLD).astype(np.int32)
        line_ids = np.concatenate([[0], np.cumsum(line_breaks)])

        # Step 3: within each line, order by the first point's x. lexsort
        # treats its last key as the primary one, so line_ids dominates.
        x_coords = dt_boxes[y_order, 0, 0]
        order_within_y_sorted = np.lexsort((x_coords, line_ids))

        # Compose both permutations into indices of the original array.
        return y_order[order_within_y_sorted]

    @staticmethod
    def _reorder_scores(scores: Any, order: np.ndarray) -> Any:
        """Permute ``scores`` by ``order`` so they stay paired with the boxes.

        Arrays, lists and tuples whose length matches ``order`` are reordered
        (keeping their container type); anything else is returned unchanged
        rather than guessed at.
        """
        if isinstance(scores, np.ndarray):
            return scores[order] if scores.shape[:1] == order.shape else scores
        if isinstance(scores, (list, tuple)) and len(scores) == len(order):
            reordered = [scores[i] for i in order]
            return tuple(reordered) if isinstance(scores, tuple) else reordered
        return scores
