"""统一转写流水线服务。

职责概览：
- 对单个音频 chunk 执行 ASR、说话人匹配和时间偏移叠加。
- 提供 HTTP 文件路径下的分段读取能力，供 /transcribe 顺序消费。
"""

import tempfile
from typing import Dict, List

from pydub import AudioSegment

from ..utils.audio_utils import split_audio_to_chunks
from ..utils.speaker_id import SpeakerRegistry
from ..transcribe import transcribe_audio_file
from .speaker_matcher import audio_segment_to_pcm_bytes, identify_speaker_id_from_pcm

_SENT_END = set(list("。.!?！？"))


def transcribe_segment(segment: AudioSegment, chunk_seconds: int) -> Dict:
    """Transcribe an in-memory audio segment by round-tripping through a temp WAV file.

    Going through a file path lets this reuse the error tolerance and timestamp
    logic of ``transcribe_audio_file``.

    Args:
        segment: Audio to transcribe.
        chunk_seconds: Passed through unchanged to ``transcribe_audio_file``.

    Returns:
        Whatever ``transcribe_audio_file(path, chunk_seconds, fail_on_empty=False)``
        returns.
    """
    import os

    # Create the temp file and close our descriptor right away. pydub then reopens
    # the path for writing, which is not possible on Windows while a
    # NamedTemporaryFile handle is still open.
    fd, path = tempfile.mkstemp(suffix=".wav")
    os.close(fd)
    try:
        # Export inside the try so a failed export still removes the temp file.
        exported = segment.export(path, format="wav")
        # pydub's export() returns the file object it wrote to. Close it so the
        # data is flushed and the handle released before ASR reads the path.
        close = getattr(exported, "close", None)
        if callable(close):
            close()
        return transcribe_audio_file(path, chunk_seconds, fail_on_empty=False)
    finally:
        try:
            os.remove(path)
        except OSError:
            # Best-effort cleanup: a leftover temp file must not mask the result
            # or the original exception.
            pass


# Intra-sentence punctuation (commas, enumeration commas, semicolons, colons, dashes)
# that closes a sub-phrase; sentence-final punctuation in _SENT_END closes one too.
_INTERNAL_SPLIT_PUNCTS = frozenset("，,、;；:：-—")

# Tolerance in seconds when deciding whether a token lies inside a sentence span.
_TOKEN_TOLERANCE_SEC = 1e-3

# When a computed audio span is empty, speaker matching uses a short window instead:
# start this many ms before the anchor and take _FALLBACK_WINDOW_MS of audio.
_FALLBACK_BACKOFF_MS = 20
_FALLBACK_WINDOW_MS = 40


def _to_seconds(value, default: float) -> float:
    """Return *value* as float seconds, or *default* if it is falsy or unparsable."""
    if not value:
        return default
    try:
        return float(value)
    except (TypeError, ValueError):
        return default


def _span_bounds(item: Dict) -> tuple:
    """Return ``(start, end)`` seconds of a timestamp dict; ``end`` falls back to ``start``."""
    start = _to_seconds(item.get("start"), 0.0)
    return start, _to_seconds(item.get("end"), start)


def _parse_tokens(token_ts: List[Dict]) -> List[tuple]:
    """Convert raw token dicts into ``(start, end, text)`` tuples, parsing times once."""
    parsed = []
    for tok in token_ts:
        t_start, t_end = _span_bounds(tok)
        parsed.append((t_start, t_end, tok.get("text") or ""))
    return parsed


def _infer_sentence_end(item: Dict, text: str) -> bool:
    """Return the explicit ``sentence_end`` flag, or infer it from trailing punctuation."""
    flag = item.get("sentence_end")
    if flag is None:
        return bool(text and text[-1] in _SENT_END)
    return bool(flag)


def _split_on_punctuation(tokens: List[tuple]) -> List[List[tuple]]:
    """Split parsed tokens into groups, each closed by a single-char split/end punctuation."""
    groups: List[List[tuple]] = []
    current: List[tuple] = []
    for tok in tokens:
        current.append(tok)
        tok_text = tok[2]
        if len(tok_text) == 1 and (tok_text in _INTERNAL_SPLIT_PUNCTS or tok_text in _SENT_END):
            groups.append(current)
            current = []
    if current:
        groups.append(current)
    return groups


def _slice_for_matching(segment: AudioSegment, start_sec: float, end_sec: float) -> AudioSegment:
    """Cut ``[start_sec, end_sec]`` out of *segment*, never returning empty audio
    unless *segment* itself is empty.

    A zero-length, inverted or out-of-range span is replaced by a short window
    anchored at the (clamped) start position.
    """
    seg_len = len(segment)
    start_ms = max(0, int(round(start_sec * 1000)))
    end_ms = min(seg_len, int(round(end_sec * 1000)))
    if end_ms <= start_ms:
        # Clamp the anchor into the segment first; otherwise a start far past the
        # end would still produce an empty slice.
        anchor = min(start_ms, seg_len)
        start_ms = max(0, anchor - _FALLBACK_BACKOFF_MS)
        end_ms = min(seg_len, start_ms + _FALLBACK_WINDOW_MS)
    return segment[start_ms:end_ms]


def _make_entry(offset_sec: float, start: float, end: float, text: str,
                sentence_end: bool, rl, sim) -> Dict:
    """Build one output timestamp entry with absolute (offset-shifted) times."""
    return {
        "start": round(offset_sec + start, 3),
        "end": round(offset_sec + end, 3),
        "text": text,
        "sentence_end": bool(sentence_end),
        "rl": rl,
        "sim": round(sim, 3) if sim is not None else None,
    }


def process_chunk_pipeline(
    segment: AudioSegment,
    offset_sec: float,
    chunk_seconds: int,
    speaker_registry: SpeakerRegistry,
    speaker_device: str,
) -> Dict:
    """Run ASR and speaker matching for one chunk.

    Pipeline:
    1. Transcribe the chunk to get sentence-level ``time_stamps`` and, if the ASR
       provides them, token-level ``time_stamps_tokens``.
    2. Speaker matching:
       - no token timestamps: match the whole chunk once and copy its ``rl``/``sim``
         onto every sentence;
       - with token timestamps: split each sentence at internal punctuation and
         match every sub-phrase separately. A sentence whose tokens cannot be
         located is matched as one span.
    3. Shift all times by ``offset_sec`` so the output uses absolute times.

    Args:
        segment: Audio of this chunk.
        offset_sec: Start of this chunk within the whole recording, in seconds.
        chunk_seconds: Forwarded to ``transcribe_segment``.
        speaker_registry: Registry passed to ``identify_speaker_id_from_pcm``.
        speaker_device: Device string passed to ``identify_speaker_id_from_pcm``.

    Returns:
        Dict with ``language``, ``text`` (stripped), ``time_stamps`` (entries with
        ``start``/``end``/``text``/``sentence_end``/``rl``/``sim``) and
        ``duration_sec``.
    """
    sample_rate = int(getattr(segment, "frame_rate", 16000) or 16000)
    channels = int(getattr(segment, "channels", 1) or 1)

    def match(audio: AudioSegment):
        # Speaker identification on one piece of this chunk's audio.
        return identify_speaker_id_from_pcm(
            pcm_bytes=audio_segment_to_pcm_bytes(audio),
            sample_rate=sample_rate,
            channels=channels,
            registry=speaker_registry,
            device=speaker_device,
        )

    transcribed = transcribe_segment(segment, chunk_seconds)
    language = transcribed.get("language")
    text = (transcribed.get("text") or "").strip()
    chunk_ts = transcribed.get("time_stamps") or []
    tokens = _parse_tokens(transcribed.get("time_stamps_tokens") or [])

    enriched: List[Dict] = []

    if not tokens:
        # No token timings: a single chunk-level match shared by all sentences.
        chunk_rl, chunk_sim = match(segment)
        for sentence in chunk_ts:
            s_start, s_end = _span_bounds(sentence)
            s_text = sentence.get("text", "") or ""
            enriched.append(
                _make_entry(offset_sec, s_start, s_end, s_text,
                            _infer_sentence_end(sentence, s_text), chunk_rl, chunk_sim)
            )
    else:
        for sentence in chunk_ts:
            s_start, s_end = _span_bounds(sentence)
            s_text = sentence.get("text", "") or ""
            s_is_end = _infer_sentence_end(sentence, s_text)

            # Tokens fully inside the sentence span, allowing for float jitter.
            in_sentence = [
                tok
                for tok in tokens
                if tok[0] >= s_start - _TOKEN_TOLERANCE_SEC
                and tok[1] <= s_end + _TOKEN_TOLERANCE_SEC
            ]

            if not in_sentence:
                # No tokens located for this sentence: match the sentence as a whole.
                rl, sim = match(_slice_for_matching(segment, s_start, s_end))
                enriched.append(_make_entry(offset_sec, s_start, s_end, s_text, s_is_end, rl, sim))
                continue

            groups = _split_on_punctuation(in_sentence)
            last_index = len(groups) - 1
            for index, group in enumerate(groups):
                g_start = group[0][0]
                g_end = group[-1][1]
                g_text = "".join(tok[2] for tok in group).strip()
                rl, sim = match(_slice_for_matching(segment, g_start, g_end))
                # Only the final sub-phrase inherits the sentence's end flag.
                enriched.append(
                    _make_entry(offset_sec, g_start, g_end, g_text,
                                index == last_index and s_is_end, rl, sim)
                )

    return {
        "language": language,
        "text": text,
        "time_stamps": enriched,
        "duration_sec": float(len(segment)) / 1000.0,
    }


def split_file_to_segments(file_path: str, chunk_seconds: int) -> List[AudioSegment]:
    """Split an audio file into ``chunk_seconds``-long AudioSegments.

    Used only by the HTTP file endpoint. ``split_audio_to_chunks`` writes the chunks
    to temporary WAV files; each one is read back as an ``AudioSegment``. The chunk
    files are always deleted afterwards, even if loading one of them fails.

    Args:
        file_path: Path of the source audio file.
        chunk_seconds: Chunk length forwarded to ``split_audio_to_chunks``.

    Returns:
        The loaded segments, in the order returned by ``split_audio_to_chunks``.
    """
    import os

    chunk_paths = split_audio_to_chunks(file_path, chunk_seconds)
    try:
        return [AudioSegment.from_file(p) for p in chunk_paths]
    finally:
        for p in chunk_paths:
            try:
                os.remove(p)
            except OSError:
                # Best-effort cleanup: a missing or locked temp chunk must not mask
                # the loaded result or the original exception.
                pass
