"""文件转写核心逻辑。

职责概览：
- 对输入音频进行分块转写并拼接文本与时间戳。
- 当单块转写失败、结果为空或疑似截断时，自动降级为更小子块重试。
- 提供可配置的空结果容错模式（fail_on_empty）。
"""

import os
import time
import logging
import tempfile
from typing import List, Optional, Tuple

from .model import get_model
from .utils.audio_utils import split_audio_to_chunks, insert_punctuations_into_segments, merge_sentences_from_tokens
from pydub import AudioSegment
import json
from pathlib import Path

# Load the JSON configuration (src/gzzm_config.json); environment-variable overrides are no longer used.
_config_path = Path(__file__).resolve().parents[1] / "gzzm_config.json"
# Why the config could not be used (reported once the logger exists); None when loading succeeded.
_config_load_error: Optional[str] = None
try:
    with _config_path.open("r", encoding="utf-8") as _f:
        _GZZM_CONFIG = json.load(_f)
except FileNotFoundError:
    # No config file: silently run with the built-in defaults below.
    _GZZM_CONFIG = {}
except Exception as _exc:  # unreadable file or malformed JSON -> defaults, but report it below
    _GZZM_CONFIG = {}
    _config_load_error = f"{type(_exc).__name__}: {_exc}"
if not isinstance(_GZZM_CONFIG, dict):
    # A top-level JSON array/string/number would make every ``.get`` below raise at import time.
    _config_load_error = f"top-level JSON value is {type(_GZZM_CONFIG).__name__}, expected an object"
    _GZZM_CONFIG = {}

LOG_LEVEL = str(_GZZM_CONFIG.get("log_level", "INFO")).upper()
# Only accept numeric level constants: getattr(logging, ...) can also return unrelated module
# attributes (e.g. ``BASIC_FORMAT`` is a str), which logging.basicConfig would reject.
_log_level_value = getattr(logging, LOG_LEVEL, logging.INFO)
if not isinstance(_log_level_value, int):
    _log_level_value = logging.INFO
logging.basicConfig(level=_log_level_value)
logger = logging.getLogger(__name__)
logger.setLevel(_log_level_value)

if _config_load_error is not None:
    # Deferred until here because the logging setup itself depends on the config.
    logger.warning("读取配置文件失败，使用默认配置: path=%s error=%s", _config_path, _config_load_error)


# Partial-result (truncation) detection thresholds; see _should_fallback_for_partial_result.
MIN_CHUNK_SEC_FOR_PARTIAL_FALLBACK = float(_GZZM_CONFIG.get("asr_min_chunk_sec_for_partial_fallback", 20))
PARTIAL_RESULT_COVERAGE_RATIO = float(_GZZM_CONFIG.get("asr_partial_result_coverage_ratio", 0.8))
PARTIAL_RESULT_TAIL_GAP_SEC = float(_GZZM_CONFIG.get("asr_partial_result_tail_gap_sec", 8))


def _extract_transcribe_payload(first) -> Tuple[Optional[str], str, List[dict]]:
    """Normalize transcribe result to (language, text, time_stamps)."""
    language = getattr(first, "language", None) or (first.get("language") if isinstance(first, dict) else None)
    text_val = getattr(first, "text", None) or (first.get("text") if isinstance(first, dict) else "")
    ts_list = getattr(first, "time_stamps", None) or (first.get("time_stamps") if isinstance(first, dict) else None) or []
    return language, text_val, ts_list


def _extract_end_seconds(ts_item) -> float:
    if isinstance(ts_item, dict):
        end_val = ts_item.get("end", ts_item.get("finish", ts_item.get("end_time", 0.0)))
    else:
        end_val = getattr(ts_item, "end", None)
        if end_val is None:
            end_val = getattr(ts_item, "finish", None)
        if end_val is None:
            end_val = getattr(ts_item, "end_time", 0.0)
    try:
        return float(end_val or 0.0)
    except Exception:
        return 0.0


def _should_fallback_for_partial_result(chunk_duration: float, ts_list: List[dict]) -> bool:
    """Decide whether a non-empty chunk result looks truncated and should be retried in smaller pieces.

    Chunks shorter than ``MIN_CHUNK_SEC_FOR_PARTIAL_FALLBACK`` and results without any
    timestamps are accepted as-is. Otherwise the latest timestamp end is compared with the
    chunk length, and the result is rejected when:
    - no timestamp carries a positive end time, or
    - the covered fraction is below ``PARTIAL_RESULT_COVERAGE_RATIO``, or
    - the uncovered tail exceeds ``PARTIAL_RESULT_TAIL_GAP_SEC`` seconds.
    """
    too_short = chunk_duration < MIN_CHUNK_SEC_FOR_PARTIAL_FALLBACK
    if too_short or not ts_list:
        return False

    last_end = max((_extract_end_seconds(item) for item in ts_list), default=0.0)
    if last_end <= 0:
        # Timestamps exist, but none of them has a usable end time.
        return True

    coverage = last_end / max(0.001, chunk_duration)
    if coverage < PARTIAL_RESULT_COVERAGE_RATIO:
        return True
    uncovered_tail = max(0.0, chunk_duration - last_end)
    return uncovered_tail > PARTIAL_RESULT_TAIL_GAP_SEC


def _normalize_ts_item(ts, offset_sec: float) -> dict:
    """Convert one model timestamp entry into a ``{"start", "end", "text"}`` dict shifted by ``offset_sec``.

    ``ts`` may be a dict or an attribute-style object. For each field the first alias whose
    value is not ``None`` wins. A missing start becomes 0.0, a missing end falls back to the
    start, and missing text becomes "". Times are rounded to milliseconds.
    """

    def lookup(keys, default):
        for key in keys:
            value = ts.get(key) if isinstance(ts, dict) else getattr(ts, key, None)
            if value is not None:
                return value
        return default

    start = float(lookup(("start", "begin", "start_time"), 0.0))
    end = float(lookup(("end", "finish", "end_time"), start))
    return {
        "start": round(start + offset_sec, 3),
        "end": round(end + offset_sec, 3),
        "text": lookup(("text", "word"), ""),
    }


def _run_transcribe_chunk_with_fallback(chunk_path: str, fallback_seconds: int = 15) -> Tuple[Optional[str], str, List[dict], float]:
    """Transcribe one chunk, retrying in smaller sub-pieces when the direct attempt is unusable.

    The direct attempt is unusable when it raises, returns nothing or an empty result, or
    returns a result that ``_should_fallback_for_partial_result`` judges truncated. In those
    cases the chunk is cut into ``fallback_seconds``-long pieces (at least 1 s). Each piece
    is transcribed separately, and the texts and offset timestamps are merged.

    Args:
        chunk_path: path of the audio chunk to transcribe.
        fallback_seconds: length of each sub-piece on the fallback path.

    Returns:
        ``(language, text, time_stamps, chunk_duration_sec)``.

        On the direct path ``time_stamps`` are the model's own entries. On the fallback path
        they are ``{"start", "end", "text"}`` dicts relative to the start of this chunk.

        If the fallback recovers nothing but the direct attempt produced a (truncated)
        result, that result is returned rather than an empty one.

    Raises:
        RuntimeError: if the model has not been loaded yet.
    """
    model = get_model()
    if model is None:
        raise RuntimeError("模型尚未加载完成")

    seg = AudioSegment.from_file(chunk_path)
    chunk_duration = len(seg) / 1000.0

    # Direct-attempt result that looked truncated; returned only if the fallback finds nothing.
    partial_result: Optional[Tuple[Optional[str], str, List[dict]]] = None

    try:
        results = model.transcribe(audio=chunk_path, language=None, return_time_stamps=True)
        language, text_val, ts_list = (
            _extract_transcribe_payload(results[0]) if results else (None, "", [])
        )
        text_val = text_val or ""
        if not (text_val or ts_list):
            logger.warning("切片转写结果为空，降级为小分片重试：%s", chunk_path)
        elif _should_fallback_for_partial_result(chunk_duration, ts_list):
            max_end_sec = max((_extract_end_seconds(ts) for ts in ts_list), default=0.0)
            logger.warning(
                "检测到疑似截断结果，降级为小分片重试: chunk=%s duration=%.3fs max_end=%.3fs",
                chunk_path,
                chunk_duration,
                max_end_sec,
            )
            partial_result = (language, text_val, ts_list)
        else:
            return language, text_val, ts_list, chunk_duration
    except Exception:
        logger.exception("模型对切片转写失败，降级为小分片重试：%s", chunk_path)

    fallback_ms = max(1, int(fallback_seconds)) * 1000
    merged_language: Optional[str] = None
    merged_texts: List[str] = []
    merged_ts: List[dict] = []

    for sub_start_ms in range(0, len(seg), fallback_ms):
        sub_seg = seg[sub_start_ms: sub_start_ms + fallback_ms]
        # Only reserve the path here; the export happens inside the try so that a failed export
        # still removes the temp file and does not abort the remaining sub-pieces.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as sub_tmp:
            sub_path = sub_tmp.name
        try:
            sub_seg.export(sub_path, format="wav")
            sub_results = model.transcribe(audio=sub_path, language=None, return_time_stamps=True)
            if not sub_results:
                continue
            sub_lang, sub_text, sub_ts_list = _extract_transcribe_payload(sub_results[0])
            merged_language = merged_language or sub_lang
            if sub_text:
                merged_texts.append(sub_text)
            # Sub-piece timestamps are relative to the sub-piece; shift them to the chunk start.
            sub_offset = sub_start_ms / 1000.0
            merged_ts.extend(_normalize_ts_item(ts, sub_offset) for ts in sub_ts_list)
        except Exception:
            logger.exception("降级子切片仍转写失败：%s", sub_path)
        finally:
            try:
                os.remove(sub_path)
            except OSError:
                pass

    merged_text = "".join(merged_texts)
    if partial_result is not None and not merged_text and not merged_ts:
        # A truncated transcript is more useful than an empty one.
        logger.warning("降级小分片未获得任何结果，保留原始疑似截断结果：%s", chunk_path)
        partial_language, partial_text, partial_ts = partial_result
        return partial_language, partial_text, partial_ts, chunk_duration
    return merged_language, merged_text, merged_ts, chunk_duration


def transcribe_audio_file(tmp_path: str, chunk_seconds: int, fail_on_empty: bool = True) -> dict:
    """Transcribe the audio file at ``tmp_path`` with the global model, chunk by chunk.

    The file is split into chunks with ``split_audio_to_chunks``. Each chunk goes through
    ``_run_transcribe_chunk_with_fallback``, and its timestamps are shifted by the summed
    duration of the preceding chunks.

    Args:
        tmp_path: input audio file. It is owned by the caller and is never deleted here.
        chunk_seconds: chunk length passed to ``split_audio_to_chunks``. It also sizes the
            fallback sub-pieces: ``chunk_seconds // 4`` clamped to 5..20, or 15 when
            ``chunk_seconds <= 0``.
        fail_on_empty: when True, raise if no chunk produced text; otherwise return an
            empty payload.

    Returns:
        A dict compatible with the previous JSONResponse payload, with keys ``language``,
        ``text``, ``time_sec``, ``chunks``, ``time_stamps`` and ``time_stamps_tokens``.

    Raises:
        RuntimeError: if the model is not loaded, or if no text was recognized and
            ``fail_on_empty`` is True.

    Chunk files returned by the splitter are removed before returning.
    """
    if get_model() is None:
        raise RuntimeError("模型尚未加载完成")

    def field(ts, keys, default):
        """Return the first non-None value of ``keys`` on ``ts`` (dict key or attribute), else ``default``."""
        for key in keys:
            value = ts.get(key) if isinstance(ts, dict) else getattr(ts, key, None)
            if value is not None:
                return value
        return default

    start_infer = time.time()
    chunk_paths = split_audio_to_chunks(tmp_path, chunk_seconds)
    logger.debug("切分后得到 %d 个 chunk: %s", len(chunk_paths), chunk_paths)
    texts: List[str] = []
    language: Optional[str] = None

    # Absolute position (seconds) of the current chunk within the input file.
    start_offset = 0.0
    all_time_stamps: List[dict] = []

    try:
        # Loop-invariant size of the sub-pieces used when a chunk must be re-transcribed.
        fallback_seconds = max(5, min(20, chunk_seconds // 4)) if chunk_seconds > 0 else 15

        for chunk_path in chunk_paths:
            chunk_lang, chunk_text, chunk_ts_list, chunk_duration = _run_transcribe_chunk_with_fallback(
                chunk_path,
                fallback_seconds=fallback_seconds,
            )

            language = language or chunk_lang
            if chunk_text:
                texts.append(chunk_text)

            for ts in chunk_ts_list:
                # None-valued fields fall through to the next alias for dicts and objects alike,
                # so one incomplete entry cannot abort the whole request with a TypeError.
                start = float(field(ts, ("start", "begin", "start_time"), 0.0))
                end = float(field(ts, ("end", "finish", "end_time"), start))
                all_time_stamps.append({
                    "start": round(start + start_offset, 3),
                    "end": round(end + start_offset, 3),
                    "text": field(ts, ("text", "word"), ""),
                })

            start_offset += chunk_duration

        infer_time = time.time() - start_infer

        if not texts:
            if fail_on_empty:
                # Include extra context to help diagnose intermittent failures.
                raise RuntimeError(f"识别失败: 未从任何切片获得文本 (chunks={len(chunk_paths)}, paths={chunk_paths})")
            logger.warning("本次转写未获得文本，返回空结果: chunks=%d", len(chunk_paths))
            return {
                "language": language,
                "text": "",
                "time_sec": round(infer_time, 2),
                "chunks": len(chunk_paths),
                "time_stamps": [],
                "time_stamps_tokens": [],
            }

        # Punctuation alignment works on the separator-free concatenation of chunk texts.
        full_text = "".join(texts)
        time_stamps_tokens = insert_punctuations_into_segments(all_time_stamps, full_text)
        merged_time_stamps = merge_sentences_from_tokens(time_stamps_tokens)

        return {
            "language": language,
            # NOTE: chunk texts are space-joined here (unlike full_text above); kept for payload compatibility.
            "text": " ".join(texts).strip(),
            "time_sec": round(infer_time, 2),
            "chunks": len(chunk_paths),
            "time_stamps": merged_time_stamps,
            "time_stamps_tokens": time_stamps_tokens,
        }
    finally:
        # Best-effort removal of the chunk files. The caller-owned input is skipped in case the
        # splitter hands back ``tmp_path`` itself (not visible here; a cheap guard).
        caller_file = os.path.abspath(tmp_path)
        for p in chunk_paths:
            try:
                if os.path.abspath(p) != caller_file:
                    os.remove(p)
            except OSError:
                pass
