"""轻量级内存说话人注册表，用于处理 PCM 切片。

- 在内存中维护说话人 embedding 缓存。
- 给定 PCM 切片，如果相似度超过阈值则返回已有说话人 ID；
    否则分配新的 ID 并缓存。
- 设计为被 `gzzm.app` 导入以支持下游的说话人分离逻辑。

注意：当前使用一个简单的 embedding（对帧求均值和标准差）作为占位实现。
如有更好的说话人模型，请替换 `_extract_embedding` 以提取真实 embedding。
"""

import threading
from typing import List, Optional, Tuple, Union

import uuid
import os
from pathlib import Path
import json

import numpy as np
import torch
from .device_utils import choose_device

# 兼容性处理：在某些环境中 torchaudio wheel 可能缺少 list_audio_backends
try:  # pragma: no cover
    import torchaudio

    if not hasattr(torchaudio, "list_audio_backends"):
        def _dummy_backends():
            return []

        torchaudio.list_audio_backends = _dummy_backends  # type: ignore[attr-defined]
except Exception:
    torchaudio = None  # type: ignore

try:
    from speechbrain.inference import SpeakerRecognition
except Exception:
    # Backward compatibility for SpeechBrain < 1.0
    try:
        from speechbrain.pretrained import SpeakerRecognition
    except Exception as exc:  # pragma: no cover - optional dependency
        SpeakerRecognition = None
        _IMPORT_ERROR = exc
    else:
        _IMPORT_ERROR = None
else:
    _IMPORT_ERROR = None


# 简单的余弦相似度辅助函数

def _cosine_sim(a: np.ndarray, b: np.ndarray) -> float:
    denom = (np.linalg.norm(a) * np.linalg.norm(b)) + 1e-8
    return float(np.dot(a, b) / denom)


def _pcm_bytes_to_mono_float32(data: bytes, sample_rate: int, channels: int) -> Tuple[np.ndarray, int]:
    if len(data) % 2 != 0:
        raise ValueError("PCM bytes length must be even for int16")
    if sample_rate <= 0:
        raise ValueError("Invalid sample_rate")
    if channels <= 0:
        raise ValueError("Invalid channels")

    pcm = np.frombuffer(data, dtype=np.int16)
    if channels > 1:
        if pcm.size % channels != 0:
            raise ValueError("PCM size not divisible by channels")
        pcm = pcm.reshape(-1, channels).mean(axis=1).astype(np.int16)
    wav = pcm.astype(np.float32) / 32768.0
    return wav, sample_rate


def _resample_linear(wav: np.ndarray, src_sr: int, dst_sr: int = 16000) -> np.ndarray:
    if wav.size == 0:
        return np.zeros((0,), dtype=np.float32)
    if src_sr == dst_sr:
        return wav.astype(np.float32, copy=False)
    dur = wav.shape[0] / float(src_sr)
    out_len = max(1, int(round(dur * dst_sr)))
    x_old = np.linspace(0.0, dur, num=wav.shape[0], endpoint=False)
    x_new = np.linspace(0.0, dur, num=out_len, endpoint=False)
    return np.interp(x_new, x_old, wav).astype(np.float32)


_MODEL: Optional[SpeakerRecognition] = None
_MODEL_LOCK = threading.Lock()

# 读取 JSON 配置（src/gzzm_config.json），失败时使用默认值
_config_path = Path(__file__).resolve().parents[1] / "gzzm_config.json"
try:
    with _config_path.open("r", encoding="utf-8") as _f:
        _GZZM_CONFIG = json.load(_f)
except Exception:
    _GZZM_CONFIG = {}

# 默认为配置中指定的 device，优先使用 model_device_speaker，最后回退到自动检测（优先 NPU -> CUDA -> CPU）
_cfg_dev = str(_GZZM_CONFIG.get("model_device_speaker") or "").strip()
if _cfg_dev:
    _MODEL_DEVICE: str = _cfg_dev
else:
    _MODEL_DEVICE: str = choose_device("")
# 初始基准时长（秒）：首次生成 embedding 时优先采样的时长（默认 10 秒）
_INITIAL_BASELINE_SECONDS = float(_GZZM_CONFIG.get("speaker_initial_baseline_seconds", 10.0))


def _get_model(device: str = "cpu") -> SpeakerRecognition:
    if SpeakerRecognition is None:
        raise ImportError("speechbrain is required for speaker embeddings") from _IMPORT_ERROR

    global _MODEL
    global _MODEL_DEVICE
    with _MODEL_LOCK:
        # 如果尚未加载模型或请求的设备与已加载模型设备不同，则重新加载模型到目标设备。
        # 这样可以通过在构造函数中配置 model_device 来指定是否使用 GPU。
        if (_MODEL is None) or (device != _MODEL_DEVICE):
            _MODEL = SpeakerRecognition.from_hparams(
                source="speechbrain/spkrec-ecapa-voxceleb",
                run_opts={"device": device},
            )
            _MODEL_DEVICE = device
        return _MODEL


def _extract_embedding(wav: np.ndarray, sample_rate: int, device: str = "cpu") -> np.ndarray:
    """通过 SpeechBrain 提取 ECAPA embedding（spkrec-ecapa-voxceleb）。"""
    if wav.size == 0:
        return np.zeros(192, dtype=np.float32)

    wav16 = _resample_linear(wav, sample_rate, 16000)
    model = _get_model(device=device)
    tensor = torch.from_numpy(wav16).unsqueeze(0)  # 形状 (1, T)
    with torch.no_grad():
        emb = model.encode_batch(tensor.to(device))  # (1, 1, dim)
    emb_np = emb.squeeze().cpu().numpy().astype(np.float32)
    # L2 归一化
    norm = np.linalg.norm(emb_np) + 1e-8
    return emb_np / norm


def preload_speaker_model(device: Optional[str] = None) -> None:
    """Preload speaker recognition model at startup to avoid first-request latency."""
    _get_model(device=device or _MODEL_DEVICE)


class SpeakerRegistry:
    def __init__(
        self,
        similarity_threshold: float = 0.75,
        soft_margin: float = 0.3,
        *,
        model_device: str = _MODEL_DEVICE,
        # 如果为 True，则对从磁盘加载的持久化 embedding 禁用软复用（更严格）
        # 若为 None，则从 JSON 配置字段 "speaker_disable_soft_reuse" 读取
        # （兼容旧键 "disable_soft_reuse_for_persisted"）
        disable_soft_reuse_for_persisted: Optional[bool] = None,
        energy_threshold: float = 1e-3,
        min_duration_ms: int = 400,
        agg_alpha: float = 0.6,
        update_on_match: bool = True,
    ) -> None:
        """内存中说话人注册表构造函数。

        新增参数说明（仅列出新增参数的含义与推荐档位）：

        参数（新增）:
            energy_threshold: 当片段的 RMS 能量低于该阈值时视为静音并跳过。值越大越严格（更多片段被视为静音）。
            min_duration_ms: 只有时长 >= 该值（毫秒）的片段才会被用于提取 embedding。增大可减少短切片带来的噪声。
            agg_alpha: 命中后使用移动平均聚合缓存 embedding 时新 embedding 的权重，范围 0.0-1.0。
            update_on_match: 命中后是否更新缓存 embedding（启用可逐步聚合说话人特征）。

        推荐档位（示例，按实际数据/场景调整）:
            - 高速度（低延迟/更少模型调用）:
                energy_threshold=5e-3, min_duration_ms=800, agg_alpha=0.6, update_on_match=False
            - 均衡（默认，实时场景常用）:
                energy_threshold=1e-3, min_duration_ms=400, agg_alpha=0.6, update_on_match=True
            - 高精确度（更敏感，适合离线或批处理）:
                energy_threshold=5e-4, min_duration_ms=250, agg_alpha=0.3, update_on_match=True

        使用建议:
            - 实时低延迟服务优先选择“高速度”或“均衡”。
            - 做离线评估或要求较高识别质量时使用“高精确度”。
            - 在引入这些值后，请在目标环境下用小样本集合验证并微调阈值。
        """
        """内存中说话人注册表。

        similarity_threshold: 将两个 embedding 视为同一说话人的余弦相似度阈值。
        soft_margin: 在相似度略低于阈值时允许复用最佳匹配，减少在噪音或短切片场景下的 ID 分裂。
        """
        self._embs: List[np.ndarray] = []
        # _ids 可以存储 int（本地顺序 id）或 str（持久化的 uuid）
        self._ids: List[Union[int, str]] = []
        self._next_id: int = 1
        self._lock = threading.Lock()
        self._thr = similarity_threshold
        self._soft = max(0.0, soft_margin)
        # 对持久化加载的 embedding 禁用软复用可以避免单一已知 persisted id 吞并未知说话人
        # 使传入参数优先于 JSON 配置
        if disable_soft_reuse_for_persisted is None:
            try:
                cfg_path = Path(__file__).resolve().parents[1] / "gzzm_config.json"
                with cfg_path.open("r", encoding="utf-8") as _f:
                    _cfg = json.load(_f)
            except Exception:
                _cfg = {}
            disable_soft_reuse_for_persisted = bool(
                _cfg.get("speaker_disable_soft_reuse", _cfg.get("disable_soft_reuse_for_persisted", False))
            )
        self._disable_soft_for_persisted = bool(disable_soft_reuse_for_persisted)

        # 简易静音/无效片段过滤参数
        # 当帧均方根能量 (RMS) 小于 energy_threshold 时视为静音
        self._energy_thr = float(energy_threshold)
        # 仅对时长 >= min_duration_ms 的片段提取 embedding
        self._min_dur_ms = int(min_duration_ms)

        # 命中后是否使用移动平均更新缓存 embedding（减小短切片噪声影响）
        # agg_alpha 权重越大越偏向新 embedding（0.0-1.0）
        self._agg_alpha = float(agg_alpha)
        self._update_on_match = bool(update_on_match)

        # 用于控制说话人模型加载的 device（"cpu" 或 "cuda" 等）
        self._model_device = str(model_device)

        # 每个 id 对应的累计有效时长（毫秒），用于基于时长的聚合与增强策略
        self._durations: List[float] = []
        # 初始基准时长（毫秒），用于判断“第一次是否足够长”
        try:
            self._initial_baseline_ms = float(_GZZM_CONFIG.get("speaker_initial_baseline_seconds", 10.0)) * 1000.0
        except Exception:
            self._initial_baseline_ms = 10000.0

        # 持久化：用于存储 numpy embedding 的目录（仅从 JSON 配置读取）
        store = str(_GZZM_CONFIG.get("speaker_store_dir", _GZZM_CONFIG.get("speaker_store", "speaker_store")))
        self._store_dir = Path(store)
        try:
            self._store_dir.mkdir(parents=True, exist_ok=True)
        except Exception:
            # 仅尽力而为（失败不抛出）
            pass
        # 从持久化存储加载到内存的 id 集合（uuid 字符串）
        self._persisted_ids = set()

    def register_or_match(
        self,
        pcm_bytes: bytes,
        sample_rate: int,
        channels: int = 1,
        device: Optional[str] = None,
        return_sim: bool = False,
    ) -> int | tuple[int, float]:
        """如果相似度超过阈值则返回已有说话人 ID，否则新增并返回新 ID。

        参数：
            pcm_bytes: 原始 PCM int16 字节流。
            sample_rate: PCM 的采样率。
            channels: 通道数。
            device: 用于说话人模型的 torch 设备字符串（"cpu" | "cuda" | "cuda:0" ...）。
            return_sim: 如果为 True，同时返回最佳相似度分数。
        """
        wav, sr = _pcm_bytes_to_mono_float32(pcm_bytes, sample_rate, channels)

        # 使用实例级 model_device 作为默认设备，除非显式传入 device 参数
        device = device or self._model_device

        # 简易静音与最小时长过滤：先计算 RMS 与时长，短/静音片段不参与注册
        if wav.size == 0:
            rms = 0.0
            dur_ms = 0.0
        else:
            rms = float(np.sqrt(np.mean(np.square(wav), axis=0)))
            dur_ms = float(wav.shape[0]) / float(sr) * 1000.0

        if (rms < self._energy_thr) or (dur_ms < float(self._min_dur_ms)):
            # 返回 -1 表示被视作静音/无效片段（调用方可据此跳过）
            return (-1, 0.0) if return_sim else -1

        emb = _extract_embedding(wav, sr, device=device)
        # 如果模型返回接近零向量，也视作无效并跳过
        if np.linalg.norm(emb) < 1e-6:
            return (-1, 0.0) if return_sim else -1

        with self._lock:
            best_id: Optional[Union[int, str]] = None
            best_sim = -1.0
            best_idx: Optional[int] = None
            for idx, (cached_id, cached_emb) in enumerate(zip(self._ids, self._embs)):
                sim = _cosine_sim(emb, cached_emb)
                if sim > best_sim:
                    best_sim = sim
                    best_id = cached_id
                    best_idx = idx

            # 命中（硬阈值）
            if best_id is not None and best_sim >= self._thr:
                # 命中后使用基于时长的聚合权重更新缓存 embedding（长期片段更重要）
                if self._update_on_match and best_idx is not None:
                    stored_dur = self._durations[best_idx] if best_idx < len(self._durations) else 0.0
                    if stored_dur > 0 and dur_ms > 0:
                        a = float(dur_ms) / (dur_ms + stored_dur)
                        a = max(0.05, min(a, 0.95))
                    else:
                        a = max(0.0, min(1.0, self._agg_alpha))
                    new_emb = (a * emb) + ((1.0 - a) * self._embs[best_idx])
                    norm = np.linalg.norm(new_emb) + 1e-8
                    self._embs[best_idx] = (new_emb / norm).astype(np.float32)
                    # 累加记录时长（毫秒）
                    try:
                        self._durations[best_idx] = (self._durations[best_idx] if best_idx < len(self._durations) else 0.0) + dur_ms
                    except Exception:
                        # 容错，确保长度一致
                        if best_idx < len(self._durations):
                            self._durations[best_idx] = dur_ms
                        else:
                            self._durations.append(dur_ms)
                return (best_id, best_sim) if return_sim else best_id

            # 软复用（收紧策略）
            effective_soft = self._soft * 0.5
            can_soft_reuse = True
            if self._disable_soft_for_persisted and (best_id in self._persisted_ids):
                can_soft_reuse = False

            if best_id is not None and can_soft_reuse and best_sim >= (self._thr - effective_soft):
                if self._update_on_match and best_idx is not None:
                    stored_dur = self._durations[best_idx] if best_idx < len(self._durations) else 0.0
                    if stored_dur > 0 and dur_ms > 0:
                        a = float(dur_ms) / (dur_ms + stored_dur)
                        a = max(0.05, min(a, 0.95))
                    else:
                        a = max(0.0, min(1.0, self._agg_alpha))
                    new_emb = (a * emb) + ((1.0 - a) * self._embs[best_idx])
                    norm = np.linalg.norm(new_emb) + 1e-8
                    self._embs[best_idx] = (new_emb / norm).astype(np.float32)
                    try:
                        self._durations[best_idx] = (self._durations[best_idx] if best_idx < len(self._durations) else 0.0) + dur_ms
                    except Exception:
                        if best_idx < len(self._durations):
                            self._durations[best_idx] = dur_ms
                        else:
                            self._durations.append(dur_ms)
                return (best_id, best_sim) if return_sim else best_id

            # 如果当前最佳候选是先前由很短片段创建，而本次片段很长（>= 初始基准），
            # 则尝试用长片段增强（合并）已有 embedding 而不是直接创建新 id。
            if best_id is not None and best_idx is not None:
                stored_dur = self._durations[best_idx] if best_idx < len(self._durations) else 0.0
                min_enhance_sim = max(0.2, self._thr - (self._soft * 2))
                if dur_ms >= self._initial_baseline_ms and stored_dur < dur_ms and best_sim >= min_enhance_sim:
                    if self._update_on_match:
                        a = float(dur_ms) / (dur_ms + stored_dur) if stored_dur > 0 else 1.0
                        a = max(0.05, min(a, 0.95))
                        new_emb = (a * emb) + ((1.0 - a) * self._embs[best_idx])
                        norm = np.linalg.norm(new_emb) + 1e-8
                        self._embs[best_idx] = (new_emb / norm).astype(np.float32)
                        try:
                            self._durations[best_idx] = (self._durations[best_idx] if best_idx < len(self._durations) else 0.0) + dur_ms
                        except Exception:
                            if best_idx < len(self._durations):
                                self._durations[best_idx] = dur_ms
                            else:
                                self._durations.append(dur_ms)
                    return (best_id, best_sim) if return_sim else best_id

            # 创建本地顺序 id（整数）并记录时长
            new_id = self._next_id
            self._next_id += 1
            self._ids.append(new_id)
            self._embs.append(emb)
            try:
                self._durations.append(dur_ms)
            except Exception:
                self._durations = self._durations + [dur_ms]
            return (new_id, best_sim) if return_sim else new_id

    def persist_embedding(self, emb: np.ndarray) -> str:
        """将 embedding 持久化到磁盘并返回其 UUID 字符串。"""
        uid = str(uuid.uuid4())
        path = self._store_dir / f"{uid}.npy"
        try:
            np.save(path, emb)
        except Exception:
            # 如有需要，将错误传播给调用方
            raise
        return uid

    def load_persisted(self, uid: str) -> bool:
        """通过 UUID 将持久化的 embedding 加载到内存缓存；成功则返回 True。"""
        if not uid:
            return False
        path = self._store_dir / f"{uid}.npy"
        if not path.exists():
            return False
        try:
            emb = np.load(path)
        except Exception:
            return False
        with self._lock:
            # 避免重复加载
            if uid in self._ids:
                return True
            self._ids.append(uid)
            # 确保为 float32 并已归一化
            emb = emb.astype(np.float32)
            norm = np.linalg.norm(emb) + 1e-8
            if norm != 0:
                emb = emb / norm
            self._embs.append(emb)
            self._persisted_ids.add(uid)
            # 对持久化 embedding 赋予可信的初始时长（以便后续聚合权重判断）
            try:
                self._durations.append(self._initial_baseline_ms)
            except Exception:
                self._durations.append(10000.0)
        return True

    def load_persisted_bulk(self, uids: List[str]) -> List[str]:
        """尝试加载多个持久化 id；返回成功加载的 id 列表。"""
        loaded = []
        for uid in uids or []:
            try:
                ok = self.load_persisted(uid)
            except Exception:
                ok = False
            if ok:
                loaded.append(uid)
        return loaded

    def persist_from_pcm(
        self,
        pcm_bytes: bytes,
        sample_rate: int,
        channels: int = 1,
        device: Optional[str] = None,
    ) -> str:
        """从原始 PCM 字节提取 embedding，持久化到磁盘，加载到内存缓存，并返回 UUID。"""
        wav, sr = _pcm_bytes_to_mono_float32(pcm_bytes, sample_rate, channels)
        device = device or self._model_device
        emb = _extract_embedding(wav, sr, device=device)
        uid = self.persist_embedding(emb)
        # 加载到内存缓存以便立即使用
        try:
            self.load_persisted(uid)
        except Exception:
            pass
        return uid

    def clear(self) -> None:
        with self._lock:
            self._embs.clear()
            self._ids.clear()
            self._next_id = 1
            try:
                self._durations.clear()
            except Exception:
                pass

    def stats(self) -> dict:
        with self._lock:
            return {
                "count": len(self._ids),
                "ids": list(self._ids),
                "threshold": self._thr,
            }


__all__ = ["SpeakerRegistry", "preload_speaker_model"]
