"""异步转写任务与可复用转写执行逻辑。"""

import asyncio
import logging
import os
import time
from typing import Any, Callable, Dict, Optional

from pydub import AudioSegment

from .services import (
    SessionChunkQueue,
    filter_time_stamps,
    merge_adjacent_same_speaker,
    preload_persisted_speakers,
    process_chunk_pipeline,
    split_file_to_segments,
)
from .utils.speaker_id import SpeakerRegistry


def run_transcribe_pipeline_for_file(
    *,
    tmp_path: str,
    manual_speaker_ids: Optional[str],
    chunk_seconds: int,
    speaker_device: str,
    build_speaker_registry: Callable[[], SpeakerRegistry],
    detailed_log_enabled: bool = False,
    logger: Optional[logging.Logger] = None,
    log_tag: str = "/transcribe",
) -> Dict[str, Any]:
    """Run the queue-based transcribe flow for one audio file and return the response payload.

    The file is split with ``split_file_to_segments``. The segments are queued and
    processed in order by ``process_chunk_pipeline``. A running time offset is
    passed to each chunk, so its time stamps are positioned relative to the whole
    file.

    Args:
        tmp_path: Path of the audio file to transcribe.
        manual_speaker_ids: Forwarded to ``preload_persisted_speakers`` to seed the
            speaker registry.
        chunk_seconds: Segment length forwarded to the splitter and the chunk pipeline.
        speaker_device: Device identifier forwarded to ``process_chunk_pipeline``.
        build_speaker_registry: Factory for the per-call speaker registry. The
            registry is always cleared before this function returns or raises.
        detailed_log_enabled: When true, log the start and the result of every chunk.
        logger: Logger to use; defaults to this module's logger.
        log_tag: Tag embedded in log lines.

    Returns:
        A dict with:
          - ``language``: the first non-None language reported by a chunk.
          - ``text``: the non-empty chunk texts, joined with spaces.
          - ``time_sec``: elapsed time in seconds, rounded to 2 decimals. The
            speaker preload is not included.
          - ``chunks``: the number of segments.
          - ``time_stamps``: the per-chunk stamps, passed through
            ``merge_adjacent_same_speaker``.

    Raises:
        Whatever the preload, splitter, or chunk pipeline raise; the registry is
        still cleared in that case.
    """
    logger = logger or logging.getLogger(__name__)
    speaker_registry = build_speaker_registry()

    try:
        # Preload inside the try so the registry is cleared even if preloading fails.
        preload_persisted_speakers(speaker_registry, manual_speaker_ids)

        # ``time_sec`` excludes the speaker preload.
        started_at = time.perf_counter()
        queue: SessionChunkQueue[AudioSegment] = SessionChunkQueue()

        segments = split_file_to_segments(tmp_path, chunk_seconds)
        for seg in segments:
            queue.put(seg)

        offset_sec = 0.0
        language = None
        text_parts = []
        enriched_ts = []
        chunk_count = len(segments)
        chunk_index = 0

        while not queue.empty():
            seg = queue.try_get()
            if seg is None:
                break

            # len() of a pydub AudioSegment is its length in milliseconds.
            estimated_duration_sec = float(len(seg)) / 1000.0
            if detailed_log_enabled:
                logger.info(
                    "[detailed][%s] start chunk index=%s range=[%.3f, %.3f]",
                    log_tag,
                    chunk_index,
                    offset_sec,
                    offset_sec + estimated_duration_sec,
                )

            chunk_result = process_chunk_pipeline(
                segment=seg,
                offset_sec=offset_sec,
                chunk_seconds=chunk_seconds,
                speaker_registry=speaker_registry,
                speaker_device=speaker_device,
            )
            if language is None:
                language = chunk_result.get("language")

            txt = (chunk_result.get("text") or "").strip()
            if txt:
                text_parts.append(txt)

            chunk_ts = chunk_result.get("time_stamps") or []
            try:
                chunk_ts = filter_time_stamps(chunk_ts)
            except Exception:
                # Best effort: keep the unfiltered stamps, but leave a trace of the failure.
                logger.warning(
                    "[%s] filter_time_stamps failed for chunk index=%s; keeping unfiltered time stamps",
                    log_tag,
                    chunk_index,
                    exc_info=True,
                )
            if chunk_ts:
                enriched_ts.extend(chunk_ts)

            # Prefer the pipeline's reported duration; fall back to the segment
            # length when it is missing or falsy.
            chunk_duration_sec = float(chunk_result.get("duration_sec", estimated_duration_sec) or estimated_duration_sec)
            if detailed_log_enabled:
                logger.info(
                    "[detailed][%s] result chunk index=%s range=[%.3f, %.3f] result=%s",
                    log_tag,
                    chunk_index,
                    offset_sec,
                    offset_sec + chunk_duration_sec,
                    chunk_result,
                )

            offset_sec += chunk_duration_sec
            chunk_index += 1

        enriched_ts = merge_adjacent_same_speaker(enriched_ts)
        return {
            "language": language,
            "text": " ".join(text_parts).strip(),
            "time_sec": round(time.perf_counter() - started_at, 2),
            "chunks": chunk_count,
            "time_stamps": enriched_ts,
        }
    finally:
        try:
            speaker_registry.clear()
        except Exception:
            # Cleanup must not mask the primary result or exception; just record it.
            logger.warning("[%s] failed to clear speaker registry", log_tag, exc_info=True)


async def run_async_transcribe_task(
    *,
    task_id: str,
    tmp_path: str,
    manual_speaker_ids: Optional[str],
    mark_running: Callable[[str], None],
    mark_succeeded: Callable[[str, Dict[str, Any]], None],
    mark_failed: Callable[[str, str], None],
    chunk_seconds: int,
    speaker_device: str,
    build_speaker_registry: Callable[[], SpeakerRegistry],
    detailed_log_enabled: bool,
    logger: logging.Logger,
) -> None:
    """Background async runner for one uploaded file task.

    Marks the task running, then runs ``run_transcribe_pipeline_for_file`` in a
    worker thread via ``asyncio.to_thread``. The outcome is recorded through
    ``mark_succeeded`` or ``mark_failed``. The temporary upload at ``tmp_path``
    is removed afterwards in every case.

    Args:
        task_id: Identifier passed to the ``mark_*`` status callbacks.
        tmp_path: Uploaded audio file to transcribe; deleted when done.
        manual_speaker_ids: Forwarded to the pipeline.
        mark_running: Called with ``task_id`` before processing starts.
        mark_succeeded: Called with ``task_id`` and the pipeline result on success.
        mark_failed: Called with ``task_id`` and a non-empty reason on failure
            or cancellation.
        chunk_seconds: Forwarded to the pipeline.
        speaker_device: Forwarded to the pipeline.
        build_speaker_registry: Forwarded to the pipeline.
        detailed_log_enabled: Forwarded to the pipeline.
        logger: Logger used here and by the pipeline.

    Errors from the pipeline or callbacks are logged and recorded, not raised.
    ``asyncio.CancelledError`` is recorded as a failure and then re-raised.
    """
    try:
        mark_running(task_id)
        result = await asyncio.to_thread(
            run_transcribe_pipeline_for_file,
            tmp_path=tmp_path,
            manual_speaker_ids=manual_speaker_ids,
            chunk_seconds=chunk_seconds,
            speaker_device=speaker_device,
            build_speaker_registry=build_speaker_registry,
            detailed_log_enabled=detailed_log_enabled,
            logger=logger,
            log_tag="/transcribe/async",
        )
        mark_succeeded(task_id, result)
    except asyncio.CancelledError:
        # CancelledError derives from BaseException (Python 3.8+), so the generic
        # handler below would miss it and the task would stay "running" forever.
        # NOTE(review): cancelling the await does not stop the worker thread.
        logger.warning("Async transcribe cancelled for task=%s", task_id)
        try:
            mark_failed(task_id, "cancelled")
        except Exception:
            logger.exception("Failed to persist failed status for task=%s", task_id)
        raise
    except Exception as exc:
        logger.exception("Async transcribe failed for task=%s", task_id)
        try:
            # Some exceptions stringify to ""; fall back to the class name so the
            # persisted failure reason is never blank.
            mark_failed(task_id, str(exc) or type(exc).__name__)
        except Exception:
            logger.exception("Failed to persist failed status for task=%s", task_id)
    finally:
        try:
            os.remove(tmp_path)
        except FileNotFoundError:
            # Already gone: nothing to clean up.
            pass
        except Exception:
            # Cleanup stays best-effort, but do not hide leaked temp files.
            logger.warning(
                "Failed to remove temp file %s for task=%s", tmp_path, task_id, exc_info=True
            )
