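"""gzzm 服务主入口。

职责概览：
- 暴露转写接口：HTTP 同步 `/transcribe`、HTTP 异步 `/transcribe/async`（提交任务）与
  `/transcribe/async/{task_id}`（查询任务状态与结果），以及 WebSocket 流式 `/ws/transcribe`。
- 暴露声纹注册接口 `/speaker/register`：持久化一个说话人声纹并返回其 UUID。
- 在请求入口做参数校验、会话级资源初始化（SpeakerRegistry、队列）和结果汇总。
- 调用 services 中的统一流水线完成 ASR、二次切分、说话人匹配与时间戳合并。
"""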

import os
import tempfile
import logging
from contextlib import asynccontextmanager
from typing import Optional
import json
from pathlib import Path
import asyncio
import io
import uuid

from fastapi import FastAPI, File, Form, HTTPException, Request, UploadFile, WebSocket, WebSocketDisconnect
from fastapi.responses import JSONResponse
from starlette.websockets import WebSocketState
from pydub import AudioSegment

from .utils.device_utils import choose_device
from .model import get_model, init_model
from .utils.speaker_id import SpeakerRegistry
from .async_transcribe import run_async_transcribe_task, run_transcribe_pipeline_for_file
from .transcribe import transcribe_audio_file
from .services import (
    QueueChunk,
    SessionChunkQueue,
    merge_adjacent_same_speaker,
    preload_persisted_speakers,
    process_chunk_pipeline,
    filter_time_stamps,
)
from .services.async_task_store import (
    create_task_record,
    get_task_record,
    init_task_db,
    mark_task_failed,
    mark_task_running,
    mark_task_succeeded,
)

LOG_LEVEL = "INFO"
logging.basicConfig(level=getattr(logging, LOG_LEVEL, logging.INFO))
logger = logging.getLogger(__name__)
logger.setLevel(getattr(logging, LOG_LEVEL, logging.INFO))

# gzzm_config.json is looked up in the parent of the directory containing this module.
_config_path = Path(__file__).resolve().parents[1] / "gzzm_config.json"
try:
    with _config_path.open("r", encoding="utf-8") as _f:
        _GZZM_CONFIG = json.load(_f)
except FileNotFoundError:
    # A missing config file is a supported setup: every setting has a default.
    _GZZM_CONFIG = {}
except Exception:
    # A present but unreadable/malformed file used to fall back to defaults
    # silently; keep the fallback but make the misconfiguration visible.
    logger.warning("Failed to load %s; using default settings", _config_path, exc_info=True)
    _GZZM_CONFIG = {}
if not isinstance(_GZZM_CONFIG, dict):
    # Module constants below read settings with .get(), which needs an object.
    logger.warning("%s must contain a JSON object; using default settings", _config_path)
    _GZZM_CONFIG = {}

# Chunking knobs for HTTP and WS are configured independently:
# - chunk_seconds: slice length used by the HTTP /transcribe endpoints.
# - ws_chunk_seconds: internal chunk length handed to the model for each WS
#   batch (defaults to chunk_seconds).
# - ws_emit_seconds: how much streamed WS audio is buffered before one
#   chunk_result is produced and pushed to the client.
TRANSCRIBE_CHUNK_SECONDS = int(_GZZM_CONFIG.get("chunk_seconds", 60))
WS_CHUNK_SECONDS = int(
    _GZZM_CONFIG.get("ws_chunk_seconds", TRANSCRIBE_CHUNK_SECONDS)
)
WS_EMIT_SECONDS = float(_GZZM_CONFIG.get("ws_emit_seconds", 2.0))
# Accepts JSON booleans as well as the strings "1"/"true"/"yes"/"on".
DETAILED_LOG_ENABLED = (
    str(_GZZM_CONFIG.get("enable_detailed_log", False)).strip().lower()
    in {"1", "true", "yes", "on"}
)
# Speaker-matching settings are not frozen here: they are re-read from
# gzzm_config.json by _get_runtime_speaker_settings() whenever a registry is
# built. Only the default SPEAKER_DEVICE is captured once at import time.

def _load_gzzm_config(path: Optional[Path] = None) -> dict:
    """Read the gzzm JSON config fresh from disk.

    Args:
        path: Config file to read. Defaults to the module-level ``_config_path``
            (``gzzm_config.json``).

    Returns:
        The parsed JSON object, or ``{}`` when the file is missing, unreadable,
        malformed, or does not contain a JSON object. Only a missing file is
        silent; the other fallbacks are logged so misconfiguration is visible.
    """
    config_path = _config_path if path is None else path
    try:
        with config_path.open("r", encoding="utf-8") as fh:
            data = json.load(fh)
    except FileNotFoundError:
        return {}
    except Exception:
        logger.warning("Failed to load %s; using default settings", config_path, exc_info=True)
        return {}
    if not isinstance(data, dict):
        # Callers read settings with .get(); a list/number/string would crash them.
        logger.warning("%s must contain a JSON object; using default settings", config_path)
        return {}
    return data


def _get_runtime_speaker_settings():
    """Return (similarity_threshold, soft_margin, disable_soft_reuse, speaker_device).

    All values come from ``gzzm_config.json`` only; environment variables are
    not consulted. The file is re-read on every call, so callers can invoke
    this per request and pick up config edits without a restart.

    Config keys:
        speaker_sim_threshold: float, default 0.5.
        speaker_sim_soft_margin: float, default 0.35.
        speaker_disable_soft_reuse: bool flag, falling back to the key
            ``disable_soft_reuse_for_persisted``, default False. JSON booleans,
            numbers, and the strings "1"/"true"/"yes"/"on" (case-insensitive)
            are accepted.
        model_device_speaker: device preference passed to ``choose_device``.

    Returns:
        ``(float, float, bool, str)``, where the device is ``"npu:0"``,
        ``"cuda:0"`` or ``"cpu"`` depending on what ``choose_device`` returns.
    """
    cfg = _load_gzzm_config()
    sim_thr = float(cfg.get("speaker_sim_threshold", 0.5))
    sim_soft = float(cfg.get("speaker_sim_soft_margin", 0.35))

    raw_disable_soft = cfg.get("speaker_disable_soft_reuse", cfg.get("disable_soft_reuse_for_persisted", False))
    if isinstance(raw_disable_soft, str):
        # bool("false") is True, so string flags must be parsed explicitly,
        # the same way enable_detailed_log is parsed at module level.
        disable_soft = raw_disable_soft.strip().lower() in ("1", "true", "yes", "on")
    else:
        disable_soft = bool(raw_disable_soft)

    # The speaker-model device comes only from the JSON key `model_device_speaker`;
    # choose_device() maps that preference to a backend name (see utils.device_utils).
    cfg_speaker_device = str(cfg.get("model_device_speaker") or "").strip()
    primary = choose_device(cfg_speaker_device)
    if primary == "npu":
        speaker_device = "npu:0"
    elif primary == "cuda":
        speaker_device = "cuda:0"
    else:
        speaker_device = "cpu"

    return sim_thr, sim_soft, disable_soft, speaker_device


@asynccontextmanager
async def lifespan(_: FastAPI):
    """FastAPI lifespan hook: initialise shared resources before serving.

    Calls ``init_model()`` and then ``init_task_db()`` once at startup. Nothing
    runs after the ``yield``, so shutdown performs no explicit teardown.
    """
    init_model()
    init_task_db()
    yield


# Import-time snapshot of the speaker-model device (last element of the
# runtime settings tuple). Request handlers pass this value as the speaker
# device, while each SpeakerRegistry is still built from freshly read settings.
SPEAKER_DEVICE = _get_runtime_speaker_settings()[-1]


app = FastAPI(title="Qwen3 ASR API", version="1.0.0", lifespan=lifespan)


def _build_speaker_registry() -> SpeakerRegistry:
    """Create a new SpeakerRegistry configured from the current gzzm_config.json.

    Settings are re-read on every call, so each request/session gets the latest
    threshold, soft margin, soft-reuse flag and speaker device.
    """
    threshold, margin, no_soft_reuse, device = _get_runtime_speaker_settings()
    registry_kwargs = {
        "similarity_threshold": threshold,
        "soft_margin": margin,
        "model_device": device,
        "disable_soft_reuse_for_persisted": no_soft_reuse,
    }
    return SpeakerRegistry(**registry_kwargs)


def _resolve_manual_speaker_ids(request: Request, speaker_ids_camel: Optional[str], speaker_ids_snake: Optional[str]) -> Optional[str]:
    """Pick the caller-supplied speaker id list from form fields or query string.

    Precedence: form ``speakerIds``, form ``speaker_ids``, query ``speakerIds``,
    query ``speaker_ids``. Empty values are skipped. The query string is only
    consulted when both form values are empty, and None is returned when
    nothing non-empty is found.
    """
    for form_value in (speaker_ids_camel, speaker_ids_snake):
        if form_value:
            return form_value

    params = request.query_params
    return params.get("speakerIds") or params.get("speaker_ids")


async def transcribe_audio_file_handler(file: UploadFile = File(...)) -> JSONResponse:
    """Legacy helper: pure ASR on an uploaded wav/mp3, without speaker identification.

    It has no route decorator in this module. The upload is spooled to a
    temporary file, which is always removed afterwards.

    Raises:
        HTTPException: 503 if the model is not initialised, 400 for a file
            that is not .wav/.mp3.
    """
    model = get_model()
    if model is None:
        raise HTTPException(status_code=503, detail="Model is not initialized")

    filename = (file.filename or "").lower()
    if not filename.endswith((".wav", ".mp3")):
        raise HTTPException(status_code=400, detail="Only wav or mp3 files are supported")

    suffix = ".wav" if filename.endswith(".wav") else ".mp3"
    with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp:
        tmp.write(await file.read())
        tmp_path = tmp.name

    try:
        # transcribe_audio_file is a blocking call over the whole file; running
        # it in a worker thread keeps the event loop free for other requests.
        result = await asyncio.to_thread(transcribe_audio_file, tmp_path, TRANSCRIBE_CHUNK_SECONDS)
        return JSONResponse(result)
    finally:
        try:
            os.remove(tmp_path)
        except OSError:
            pass


@app.post("/transcribe")
async def transcribe_audio(
    request: Request,
    file: UploadFile = File(...),
    speakerIds: Optional[str] = Form(None),
    speaker_ids: Optional[str] = Form(None),
) -> JSONResponse:
    """Queue-based /transcribe pipeline: preload speakers -> chunk queue -> ASR -> second split -> speaker id.

    The upload (.wav or .mp3 only) is spooled to a temporary file that is
    always removed afterwards. Manual speaker ids come from the form
    (``speakerIds``/``speaker_ids``) or, failing that, the query string.

    Raises:
        HTTPException: 503 if the model is not initialised, 400 for an
            unsupported file type, 500 if the pipeline fails (HTTPExceptions
            raised by the pipeline are passed through unchanged).
    """
    model = get_model()
    if model is None:
        raise HTTPException(status_code=503, detail="Model is not initialized")

    filename = (file.filename or "").lower()
    if not filename.endswith((".wav", ".mp3")):
        raise HTTPException(status_code=400, detail="Only wav or mp3 files are supported")

    suffix = ".wav" if filename.endswith(".wav") else ".mp3"
    with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp:
        tmp.write(await file.read())
        tmp_path = tmp.name

    manual_speaker_ids = _resolve_manual_speaker_ids(request, speakerIds, speaker_ids)

    logger.info(
        "/transcribe manual speakerIds resolved: %s (form speakerIds=%s, form speaker_ids=%s)",
        manual_speaker_ids,
        speakerIds,
        speaker_ids,
    )

    try:
        # The whole-file pipeline is synchronous. Run it in a worker thread
        # (as the WS handler does for process_chunk_pipeline) so a long upload
        # does not freeze the event loop and every other request/WS session.
        payload = await asyncio.to_thread(
            run_transcribe_pipeline_for_file,
            tmp_path=tmp_path,
            manual_speaker_ids=manual_speaker_ids,
            chunk_seconds=TRANSCRIBE_CHUNK_SECONDS,
            speaker_device=SPEAKER_DEVICE,
            build_speaker_registry=_build_speaker_registry,
            detailed_log_enabled=DETAILED_LOG_ENABLED,
            logger=logger,
            log_tag="/transcribe",
        )
        return JSONResponse(payload)
    except HTTPException:
        raise
    except Exception as e:
        logger.exception("Transcribe failed")
        raise HTTPException(status_code=500, detail=str(e))
    finally:
        try:
            os.remove(tmp_path)
        except OSError:
            pass


# Strong references to in-flight background transcription tasks. The event
# loop keeps only weak references to tasks, so a task whose handle is dropped
# can be garbage-collected before it finishes. Each task removes itself when done.
_BACKGROUND_TASKS: set[asyncio.Task] = set()


@app.post("/transcribe/async")
async def transcribe_audio_async(
    request: Request,
    file: UploadFile = File(...),
    speakerIds: Optional[str] = Form(None),
    speaker_ids: Optional[str] = Form(None),
) -> JSONResponse:
    """Submit an async transcribe task and return its UUID immediately.

    The upload (.wav or .mp3 only) is spooled to a temporary file, a task
    record is created, and ``run_async_transcribe_task`` is scheduled on the
    event loop. The response carries ``task_id``, ``status="queued"`` and the
    ``query_path`` to poll.

    Raises:
        HTTPException: 503 if the model is not initialised, 400 for an
            unsupported file type, 500 if the task record cannot be created
            (the temporary file is removed in that case).
    """
    model = get_model()
    if model is None:
        raise HTTPException(status_code=503, detail="Model is not initialized")

    filename = (file.filename or "").lower()
    if not filename.endswith((".wav", ".mp3")):
        raise HTTPException(status_code=400, detail="Only wav or mp3 files are supported")

    suffix = ".wav" if filename.endswith(".wav") else ".mp3"
    with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp:
        tmp.write(await file.read())
        tmp_path = tmp.name

    manual_speaker_ids = _resolve_manual_speaker_ids(request, speakerIds, speaker_ids)
    logger.info(
        "/transcribe/async manual speakerIds resolved: %s (form speakerIds=%s, form speaker_ids=%s)",
        manual_speaker_ids,
        speakerIds,
        speaker_ids,
    )

    task_id = str(uuid.uuid4())
    try:
        create_task_record(task_id, file.filename or "", manual_speaker_ids)
    except Exception as e:
        try:
            os.remove(tmp_path)
        except OSError:
            pass
        logger.exception("Failed to create async transcribe task record")
        raise HTTPException(status_code=500, detail=f"failed to create task: {e}")

    task = asyncio.create_task(
        run_async_transcribe_task(
            task_id=task_id,
            tmp_path=tmp_path,
            manual_speaker_ids=manual_speaker_ids,
            mark_running=mark_task_running,
            mark_succeeded=mark_task_succeeded,
            mark_failed=mark_task_failed,
            chunk_seconds=TRANSCRIBE_CHUNK_SECONDS,
            speaker_device=SPEAKER_DEVICE,
            build_speaker_registry=_build_speaker_registry,
            detailed_log_enabled=DETAILED_LOG_ENABLED,
            logger=logger,
        )
    )
    # Keep the task alive until it completes (see _BACKGROUND_TASKS).
    _BACKGROUND_TASKS.add(task)
    task.add_done_callback(_BACKGROUND_TASKS.discard)

    return JSONResponse(
        {
            "task_id": task_id,
            "status": "queued",
            "query_path": f"/transcribe/async/{task_id}",
        }
    )


@app.get("/transcribe/async/{task_id}")
async def get_transcribe_task(task_id: str) -> JSONResponse:
    """Return the stored record (status and result) of an async transcribe task.

    Raises:
        HTTPException: 404 when no record exists for ``task_id``.
    """
    record = get_task_record(task_id)
    if record is not None:
        return JSONResponse(record)
    raise HTTPException(status_code=404, detail="task not found")


@app.websocket("/ws/transcribe")
async def websocket_transcribe(websocket: WebSocket):
    """Queue-based streaming transcription over a WebSocket.

    Protocol:
        1. The first message must be a JSON object. Recognised keys:
           ``format`` (only ``"pcm_s16le"``), ``sample_rate`` (default 16000,
           must be > 0), ``channels`` (default 1, clamped to >= 1),
           ``sample_width`` (default 2, must be 2) and ``speakerIds`` /
           ``speaker_ids`` (persisted speakers to preload).
        2. Binary frames carry raw PCM. Each time ``WS_EMIT_SECONDS`` of audio
           (at least 0.2 s) has been buffered, it is queued, transcribed with
           speaker labels and answered with a ``chunk_result`` message
           (``chunk_error`` if processing fails).
        3. The text frame ``"eof"`` flushes the buffered tail and is answered
           with a ``final`` message; the session stays open.

    The client owns the session lifetime. The server closes the socket only
    for an unusable setup (1013 model not ready, 1003 invalid init payload) or
    for processing/send failures and internal errors (1011).
    """
    await websocket.accept()

    model = get_model()
    if model is None:
        logger.warning("WS closed before processing: model is not initialized")
        await websocket.close(code=1013, reason="Model is not initialized")
        return

    try:
        init = await websocket.receive_json()
    except Exception:
        logger.warning("WS closed due to invalid init payload: JSON is required")
        await websocket.close(code=1003, reason="Invalid init payload, JSON is required")
        return

    # receive_json() returns any JSON value, but every field below is read with
    # .get(). Reject non-objects here instead of crashing outside the session's
    # error handling.
    if not isinstance(init, dict):
        logger.warning("WS closed due to invalid init payload: JSON object is required")
        await websocket.close(code=1003, reason="Invalid init payload, JSON object is required")
        return

    fmt = (init.get("format", "pcm_s16le") or "pcm_s16le")
    try:
        sample_rate = int(init.get("sample_rate", 16000))
        channels = max(1, int(init.get("channels", 1)))
        sample_width = int(init.get("sample_width", 2))
    except (TypeError, ValueError):
        logger.warning("WS closed due to invalid init payload: non-numeric audio parameters")
        await websocket.close(code=1003, reason="Invalid init payload, numeric audio parameters are required")
        return

    if fmt != "pcm_s16le":
        logger.warning("WS closed due to unsupported format: %s", fmt)
        await websocket.close(code=1003, reason="Only pcm_s16le is supported")
        return

    # pcm_s16le means 2-byte samples. Any other width would make
    # AudioSegment.from_raw misinterpret the stream.
    if sample_rate <= 0 or sample_width != 2:
        logger.warning(
            "WS closed due to invalid audio parameters: sample_rate=%s sample_width=%s",
            sample_rate,
            sample_width,
        )
        await websocket.close(code=1003, reason="pcm_s16le requires sample_rate > 0 and sample_width=2")
        return

    frame_bytes = channels * sample_width
    bytes_per_sec = max(1, sample_rate * frame_bytes)
    # Each time emit_seconds of audio has accumulated, one chunk is processed
    # and pushed, giving the client a steady stream of results.
    emit_seconds = max(0.2, float(WS_EMIT_SECONDS))
    # Round the chunk size down to whole PCM frames. A chunk that splits a
    # frame is rejected by pydub's from_raw and would also misalign every
    # later chunk sliced from the same buffer.
    chunk_bytes = int(round(emit_seconds * float(bytes_per_sec)))
    chunk_bytes = max(frame_bytes, chunk_bytes - chunk_bytes % frame_bytes)

    queue: SessionChunkQueue[QueueChunk] = SessionChunkQueue()
    pending = bytearray()
    speaker_registry = _build_speaker_registry()
    manual_ws_speaker_ids = init.get("speakerIds") if "speakerIds" in init else init.get("speaker_ids")
    logger.info("/ws/transcribe manual speakerIds: %s", manual_ws_speaker_ids)

    seq = 0
    index = 0
    offset_sec = 0.0

    def _format_exc(e: Exception) -> str:
        """Return the exception message, or its repr when the message is empty."""
        text = str(e).strip()
        if text:
            return text
        return repr(e)

    async def safe_close(code: int, reason: str) -> None:
        """Close the socket unless already disconnected; close errors are logged, not raised."""
        if getattr(websocket, "application_state", None) == WebSocketState.DISCONNECTED:
            logger.info("WS close skipped because application_state=DISCONNECTED: code=%s reason=%s", code, reason)
            return
        try:
            await websocket.close(code=code, reason=reason)
            logger.warning("WS closed by server: code=%s reason=%s", code, reason)
        except Exception as close_error:
            logger.info(
                "WS close skipped/failed: code=%s reason=%s error_type=%s error_repr=%r",
                code,
                reason,
                type(close_error).__name__,
                close_error,
            )

    async def safe_send_json(payload: dict) -> bool:
        """Send ``payload`` as JSON; return False (after logging) if the send fails."""
        try:
            await websocket.send_json(payload)
            return True
        except WebSocketDisconnect as e:
            logger.info(
                "WS send skipped because client disconnected: type=%s index=%s code=%s reason=%s client_state=%s app_state=%s",
                payload.get("type"),
                payload.get("index"),
                e.code,
                getattr(e, "reason", None),
                getattr(websocket, "client_state", None),
                getattr(websocket, "application_state", None),
            )
            return False
        except Exception as e:
            logger.warning(
                "WS send failed type=%s index=%s error_type=%s error=%s client_state=%s app_state=%s",
                payload.get("type"),
                payload.get("index"),
                type(e).__name__,
                _format_exc(e),
                getattr(websocket, "client_state", None),
                getattr(websocket, "application_state", None),
            )
            return False

    async def process_queue() -> tuple[bool, Optional[int], Optional[str]]:
        """Drain the session queue: transcribe each PCM chunk, label speakers, push chunk_result.

        Returns:
            ``(ok, close_code, close_reason)``. ``ok`` is False when the
            session must stop. ``close_code`` is None when the client is
            already gone, in which case no close frame should be sent.
        """
        nonlocal index, offset_sec
        loop = asyncio.get_running_loop()
        while not queue.empty():
            item = queue.try_get()
            if item is None:
                break

            if DETAILED_LOG_ENABLED:
                logger.info(
                    "[detailed][/ws/transcribe] start chunk index=%s seq=%s range=[%.3f, %.3f]",
                    index,
                    item.seq,
                    offset_sec,
                    offset_sec + item.duration_sec,
                )

            segment = AudioSegment.from_raw(
                io.BytesIO(item.payload),
                sample_width=sample_width,
                frame_rate=sample_rate,
                channels=channels,
            )

            try:
                # The chunk pipeline is a blocking call; run it in the default
                # executor so the event loop keeps serving other sessions.
                chunk_result = await loop.run_in_executor(
                    None,
                    process_chunk_pipeline,
                    segment,
                    offset_sec,
                    WS_CHUNK_SECONDS,
                    speaker_registry,
                    SPEAKER_DEVICE,
                )
            except Exception as e:
                error_text = _format_exc(e)
                sent = await safe_send_json({"type": "chunk_error", "index": index, "error": error_text})
                if not sent:
                    return False, 1011, f"chunk_error send failed at index={index}"
                return False, 1011, f"chunk processing failed at index={index}: {error_text}"

            chunk_text = (chunk_result.get("text") or "").strip()

            chunk_ts = chunk_result.get("time_stamps") or []
            # Best-effort filter: drop short punctuation-only fragments whose rl
            # is null (logging is config-controlled). On failure, keep the
            # unfiltered list.
            try:
                chunk_ts = filter_time_stamps(chunk_ts)
            except Exception:
                logger.debug("filter_time_stamps failed; keeping unfiltered time stamps", exc_info=True)
            # Merge adjacent segments of the same speaker within this chunk.
            chunk_ts = merge_adjacent_same_speaker(chunk_ts)

            duration_sec = float(chunk_result.get("duration_sec", item.duration_sec) or item.duration_sec)
            if DETAILED_LOG_ENABLED:
                logger.info(
                    "[detailed][/ws/transcribe] result chunk index=%s seq=%s range=[%.3f, %.3f] result=%s",
                    index,
                    item.seq,
                    offset_sec,
                    offset_sec + duration_sec,
                    chunk_result,
                )

            ok = await safe_send_json(
                {
                    "type": "chunk_result",
                    "index": index,
                    "text": chunk_text,
                    "time_stamps": chunk_ts,
                    "chunk_duration": round(duration_sec, 3),
                    "total_duration": round(offset_sec + duration_sec, 3),
                }
            )
            if not ok:
                # A client that already left is not a server-side error; let
                # the outer loop end the session normally.
                if getattr(websocket, "application_state", None) == WebSocketState.DISCONNECTED:
                    return False, None, "client disconnected while sending chunk_result"
                return False, 1011, f"chunk_result send failed at index={index}"

            offset_sec += duration_sec
            index += 1

        return True, None, None

    try:
        # Preload inside the try so a failure is logged, the socket is closed
        # with 1011, and the registry is still cleared in `finally`.
        preload_persisted_speakers(
            speaker_registry,
            manual_ws_speaker_ids,
        )

        while True:
            msg = await websocket.receive()
            if msg.get("type") == "websocket.disconnect":
                logger.info(
                    "WebSocket disconnected by client: code=%s reason=%s detail=%s",
                    msg.get("code"),
                    msg.get("reason"),
                    msg,
                )
                break

            if msg.get("type") == "websocket.receive" and msg.get("bytes") is not None:
                data = msg.get("bytes")
                pending.extend(data)

                # Once the threshold is reached, cut fixed-size chunks in
                # arrival order so the queue is processed FIFO.
                while len(pending) >= chunk_bytes:
                    payload = bytes(pending[:chunk_bytes])
                    del pending[:chunk_bytes]
                    queue.put(
                        QueueChunk(
                            seq=seq,
                            payload=payload,
                            duration_sec=float(len(payload)) / float(bytes_per_sec),
                        )
                    )
                    seq += 1

                ok, close_code, close_reason = await process_queue()
                if not ok:
                    if close_code is not None and close_reason:
                        await safe_close(close_code, close_reason)
                    break

            elif msg.get("type") == "websocket.receive" and msg.get("text"):
                # "eof" only marks the end of the current audio segment: flush
                # the buffered tail and keep the session open.
                if msg.get("text") == "eof":
                    # Flush whole frames only. A dangling partial frame cannot
                    # be decoded, so it stays buffered for the next segment.
                    flush_len = len(pending) - len(pending) % frame_bytes
                    if flush_len:
                        payload = bytes(pending[:flush_len])
                        del pending[:flush_len]
                        queue.put(
                            QueueChunk(
                                seq=seq,
                                payload=payload,
                                duration_sec=float(len(payload)) / float(bytes_per_sec),
                            )
                        )
                        seq += 1

                    ok, close_code, close_reason = await process_queue()
                    if not ok:
                        logger.warning("WS session stop after eof flush because send/process failed")
                        if close_code is not None and close_reason:
                            await safe_close(close_code, close_reason)
                        break

                    await safe_send_json(
                        {
                            "type": "final",
                            "index": index - 1 if index > 0 else 0,
                            "total_duration": round(offset_sec, 3),
                        }
                    )
    except WebSocketDisconnect as e:
        logger.info("WebSocket client disconnected (exception): code=%s reason=%s", e.code, getattr(e, "reason", None))
    except Exception:
        logger.exception("WS handling failed")
        try:
            logger.warning("WS closing with internal error code=1011")
            await websocket.close(code=1011, reason="Internal error")
        except Exception:
            pass
    finally:
        try:
            speaker_registry.clear()
        except Exception:
            pass


@app.post("/speaker/register")
async def speaker_register(
    file: UploadFile = File(...),
    sample_rate: int = 16000,
    channels: int = 1,
    sample_width: int = 2,
    format: str = "pcm_s16le",
) -> JSONResponse:
    """Persist one speaker embedding and return its UUID.

    Input handling:
        * ``format`` of ``wav``/``mp3`` (or a ``.wav``/``.mp3`` filename):
          decoded with pydub and converted to 16-bit samples. The decoded
          audio's own frame rate and channel count describe the PCM, so the
          ``sample_rate``/``channels`` parameters are ignored for this branch.
        * ``format`` of ``pcm``/``pcm_s16le`` (or a ``.pcm`` filename): the
          upload is used as-is and described by ``sample_rate``/``channels``.
          ``sample_width`` must be 2.

    Returns:
        JSON ``{"uuid": <value returned by SpeakerRegistry.persist_from_pcm>}``.

    Raises:
        HTTPException: 400 for an empty, undecodable or unsupported upload or
            invalid PCM parameters; 501 when the embedding backend raises
            ImportError; 500 for any other persistence failure.
    """
    data = await file.read()
    if not data:
        raise HTTPException(status_code=400, detail="empty audio upload")

    fname = (file.filename or "").lower()
    if format in ("wav", "mp3") or fname.endswith((".wav", ".mp3")):
        try:
            seg = AudioSegment.from_file(io.BytesIO(data))
            # Decoded audio keeps the file's own rate/channels/width (e.g. a
            # 44.1 kHz stereo mp3). Describe the raw bytes with the decoder's
            # metadata, not the query defaults, and force 16-bit samples to
            # match what the PCM branch accepts.
            seg = seg.set_sample_width(2)
            pcm_bytes = seg.raw_data
            sample_rate = seg.frame_rate
            channels = seg.channels
        except Exception as e:
            raise HTTPException(status_code=400, detail=f"failed to decode audio: {e}") from e
    elif format in ("pcm", "pcm_s16le") or fname.endswith(".pcm"):
        if sample_width != 2:
            raise HTTPException(status_code=400, detail="only 16-bit PCM (sample_width=2) is supported")
        if sample_rate <= 0 or channels <= 0:
            raise HTTPException(status_code=400, detail="sample_rate and channels must be positive")
        pcm_bytes = data
    else:
        raise HTTPException(status_code=400, detail="unsupported audio format; provide wav/mp3 or pcm_s16le")

    reg = _build_speaker_registry()
    try:
        # persist_from_pcm is a blocking call; keep it off the event loop.
        uid = await asyncio.to_thread(
            reg.persist_from_pcm,
            pcm_bytes,
            sample_rate=sample_rate,
            channels=channels,
            device=SPEAKER_DEVICE,
        )
    except ImportError:
        raise HTTPException(status_code=501, detail="speechbrain not available for embedding extraction")
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

    return JSONResponse({"uuid": uid})
