"""模型加载与全局单例管理。

职责概览：
- 按配置和运行环境选择 ASR 设备（cpu/cuda/npu）。
- 初始化并缓存全局 Qwen3ASR 模型实例。
- 对外提供 `init_model` 与 `get_model` 供接口层复用。
"""

import os
import time
import logging
from typing import Optional

import transformers
import torch
from qwen_asr import Qwen3ASRModel
import json
from pathlib import Path
from .utils.device_utils import choose_device
from .utils.speaker_id import preload_speaker_model

# Location of the external JSON configuration (src/gzzm_config.json).
_config_path_top = Path(__file__).resolve().parents[1] / "gzzm_config.json"
# Best-effort load: any read or parse failure falls back to an empty config.
try:
    _GZZM_CONFIG_TOP = json.loads(_config_path_top.read_text(encoding="utf-8"))
except Exception:
    _GZZM_CONFIG_TOP = {}

# Log level comes from the config (default INFO); unknown level names fall
# back to INFO via getattr's default.
LOG_LEVEL = str(_GZZM_CONFIG_TOP.get("log_level", "INFO")).upper()
_LOG_LEVEL_VALUE = getattr(logging, LOG_LEVEL, logging.INFO)
logging.basicConfig(level=_LOG_LEVEL_VALUE)
logger = logging.getLogger(__name__)
logger.setLevel(_LOG_LEVEL_VALUE)


# Cached global model instance; populated by init_model().
MODEL: Optional[Qwen3ASRModel] = None
# One-shot flags guarding speaker-model preload and the startup banner.
_SPEAKER_MODEL_PRELOADED = False
_STARTUP_BANNER_PRINTED = False


def _read_startup_banner() -> str:
    banner_path = Path(__file__).resolve().parent / "brand" / "gzzm_ascii.txt"
    try:
        return banner_path.read_text(encoding="utf-8").rstrip("\n")
    except Exception:
        return ""


def _render_startup_banner(total_seconds: float) -> None:
    """Log the ASCII-art startup banner plus a startup-time footer.

    No-op when the banner file is missing or unreadable. Color is controlled
    by the JSON config keys `gzzm_banner_color` (on/off switch) and
    `gzzm_banner_hex` (24-bit truecolor hex such as ``#3a96dd``).
    """
    banner = _read_startup_banner()
    if not banner:
        return
    # Reuse _parse_bool so the truthy-string rules stay consistent module-wide.
    # default=False matches the old inline check for an explicit JSON null.
    color_enabled = _parse_bool(_GZZM_CONFIG_TOP.get("gzzm_banner_color", "1"), default=False)
    footer = f"startup_time: {total_seconds:.2f}s"
    if not color_enabled:
        logger.info("\n%s\n%s", banner, footer)
        return
    # 24-bit ANSI truecolor; the hex code comes from the JSON config key
    # `gzzm_banner_hex` (not from an environment variable).
    hex_code = str(_GZZM_CONFIG_TOP.get("gzzm_banner_hex", "#3a96dd") or "#3a96dd").strip()
    try:
        if hex_code.startswith("#"):
            hex_code = hex_code[1:]
        if len(hex_code) != 6:
            raise ValueError("invalid hex")
        r = int(hex_code[0:2], 16)
        g = int(hex_code[2:4], 16)
        b = int(hex_code[4:6], 16)
        color_escape = f"\033[38;2;{r};{g};{b}m"
    except Exception:
        # Fall back to bright cyan for maximum terminal compatibility.
        color_escape = "\033[96m"
    reset = "\033[0m"
    # Colorize only the ASCII art; the footer keeps the default color.
    logger.info("%s", f"{color_escape}\n{banner}{reset}\n{footer}")


def _parse_bool(value: object, default: bool) -> bool:
    if value is None:
        return default
    return str(value).strip().lower() in ("1", "true", "yes", "on")


def _load_gzzm_config() -> dict:
    """Best-effort read of the external JSON config (src/gzzm_config.json).

    Returns an empty dict on any read/parse failure so callers can rely on
    `.get(...)` defaults.
    """
    config_path = Path(__file__).resolve().parents[1] / "gzzm_config.json"
    try:
        with config_path.open("r", encoding="utf-8") as f:
            return json.load(f)
    except Exception:
        return {}


def _resolve_speaker_device(config: dict) -> str:
    """Resolve the speaker-ID model device from the `model_device_speaker` key.

    Only the JSON config is consulted (environment variables are intentionally
    ignored). The primary device chosen by choose_device() is expanded to an
    explicit device index.
    """
    primary = choose_device(str(config.get("model_device_speaker") or "").strip())
    if primary == "npu":
        return "npu:0"
    if primary == "cuda":
        return "cuda:0"
    return "cpu"


def _select_device_settings(selected_device: str) -> tuple:
    """Map a normalized device name to (device_map, dtype, forced_aligner_kwargs)."""
    if selected_device == "cpu":
        logger.info("未检测到可用 CUDA/NPU 或已指定 CPU，使用 CPU 加载模型（可能较慢）")
        return "cpu", torch.float32, dict(dtype=torch.float32, device_map="cpu")
    if selected_device == "cuda":
        # NOTE(review): the forced aligner runs in bfloat16 while the ASR model
        # uses float16 on CUDA — presumably intentional; confirm if needed.
        return "cuda", torch.float16, dict(dtype=torch.bfloat16, device_map="cuda:0")
    # npu: aligner stays in float32 on NPU.
    return "npu", torch.float16, dict(dtype=torch.float32, device_map="npu:0")


def _preload_speaker_once(speaker_device: str) -> None:
    """Preload the speaker-ID model at most once per process.

    Failures are logged and tolerated: the speaker model then falls back to
    lazy loading on first use.
    """
    global _SPEAKER_MODEL_PRELOADED
    if _SPEAKER_MODEL_PRELOADED:
        return
    try:
        preload_speaker_model(device=speaker_device)
        _SPEAKER_MODEL_PRELOADED = True
        logger.info("Speaker model preloaded on startup, device=%s", speaker_device)
    except ImportError:
        logger.warning("Speaker model preload skipped: speechbrain is not available")
    except Exception as e:
        logger.warning("Speaker model preload failed, fallback to lazy load: %s", e)


def _print_startup_banner_once(init_started_at: float) -> None:
    """Print the startup banner once, after the speaker model has preloaded."""
    global _STARTUP_BANNER_PRINTED
    if _SPEAKER_MODEL_PRELOADED and not _STARTUP_BANNER_PRINTED:
        _render_startup_banner(time.perf_counter() - init_started_at)
        _STARTUP_BANNER_PRINTED = True


def init_model() -> Qwen3ASRModel:
    """Load the Qwen3 ASR model into the module global `MODEL` and return it.

    Idempotent: repeated calls return the cached instance. Side effects (each
    at most once per process): preloading the speaker-ID model when enabled,
    and printing the startup banner after a successful preload.
    """
    global MODEL

    init_started_at = time.perf_counter()

    # Re-read the JSON config on every call so the first successful load sees
    # the latest on-disk values.
    config = _load_gzzm_config()

    preload_speaker_enabled = _parse_bool(config.get("preload_speaker_model", True), default=True)
    speaker_device = _resolve_speaker_device(config)

    if MODEL is not None:
        # Model already loaded: still honor the preload/banner side effects in
        # case an earlier call skipped or failed them.
        if preload_speaker_enabled:
            _preload_speaker_once(speaker_device)
        _print_startup_banner_once(init_started_at)
        return MODEL

    model_name_asr = str(config.get("model_name_asr", "Qwen/Qwen3-ASR-0.6B")).strip()
    forced_aligner_model = str(config.get("forced_aligner_model", "Qwen/Qwen3-ForcedAligner-0.6B")).strip()

    logger.info("【开始加载 ASR 模型】model=%s aligner=%s", model_name_asr, forced_aligner_model)
    start_load = time.time()

    # The ASR device comes from the JSON key `model_device_asr` only, and is
    # normalized to cpu / cuda / npu by choose_device().
    selected_device = choose_device(str(config.get("model_device_asr") or "").strip().lower())

    # Tunable inference parameters.
    max_inference_batch_size = int(config.get("max_inference_batch_size", 16))
    max_new_tokens = int(config.get("max_new_tokens", 256))

    device_map, dtype, forced_aligner_kwargs = _select_device_settings(selected_device)

    MODEL = Qwen3ASRModel.from_pretrained(
        model_name_asr,
        cache_dir=config.get("hf_cache_dir", "/app/hf_cache"),
        torch_dtype=dtype,
        device_map=device_map,
        low_cpu_mem_usage=True,
        use_safetensors=True,
        max_inference_batch_size=max_inference_batch_size,
        max_new_tokens=max_new_tokens,
        forced_aligner=forced_aligner_model,
        forced_aligner_kwargs=forced_aligner_kwargs,
    )

    # Silence transformers' own logging so request logs stay readable.
    transformers.logging.set_verbosity_error()

    load_time = time.time() - start_load
    logger.info("✅ 模型加载完成！耗时: %.2f 秒", load_time)
    logger.info("forced_aligner: %s", getattr(MODEL, "forced_aligner", None))

    if preload_speaker_enabled:
        _preload_speaker_once(speaker_device)
    _print_startup_banner_once(init_started_at)

    return MODEL


def get_model() -> Optional[Qwen3ASRModel]:
    """Return the cached ASR model instance, or None if init_model() has not run."""
    return MODEL
