import os
import time
import tempfile
from contextlib import asynccontextmanager
from typing import List, Optional

import torch
from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.responses import JSONResponse
from pydub import AudioSegment
from qwen_asr import Qwen3ASRModel

# Global model handle; populated by the lifespan startup hook, None until then.
MODEL: Optional[Qwen3ASRModel] = None


@asynccontextmanager
async def lifespan(_: FastAPI):
    """FastAPI lifespan hook: load the Qwen3-ASR model onto the GPU at startup.

    Populates the module-level MODEL global before the app starts serving.
    Raises RuntimeError when no CUDA device is available.
    """
    global MODEL
    print("【1 开始加载 Qwen3-ASR-0.6B 模型...】")
    started_at = time.time()
    # Fail fast: inference below is pinned to CUDA, so refuse to start without it.
    if not torch.cuda.is_available():
        raise RuntimeError("未检测到 GPU（CUDA），请确认容器已正确启用 GPU")
    MODEL = Qwen3ASRModel.from_pretrained(
        "Qwen/Qwen3-ASR-0.6B",
        cache_dir="/app/hf_cache",
        torch_dtype=torch.float16,
        device_map="cuda",
        low_cpu_mem_usage=True,
        use_safetensors=True,
        max_inference_batch_size=1,
        max_new_tokens=32,
    )
    load_time = time.time() - started_at
    print(f"✅ 模型加载完成！耗时: {load_time:.2f} 秒\n")
    yield


# Application instance; the lifespan hook above loads the model before serving.
app = FastAPI(title="Qwen3 ASR API", version="1.0.0", lifespan=lifespan)

# Length in seconds of each audio slice fed to the model (env-overridable).
CHUNK_SECONDS = int(os.getenv("ASR_CHUNK_SECONDS", "15"))


def split_audio_to_chunks(file_path: str, chunk_seconds: int) -> List[str]:
    """Slice an audio file into fixed-length WAV segments on disk.

    Each segment is exported to its own temporary file. The caller owns
    the returned paths and must delete them when done.

    Args:
        file_path: Path to any audio file pydub can decode.
        chunk_seconds: Target segment length; values below 1 are clamped to 1.

    Returns:
        Paths of the exported WAV segment files, in order.
    """
    audio = AudioSegment.from_file(file_path)
    step_ms = max(1, chunk_seconds) * 1000
    paths: List[str] = []
    offset = 0
    total_ms = len(audio)
    while offset < total_ms:
        piece = audio[offset : offset + step_ms]
        with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as out:
            piece.export(out.name, format="wav")
            paths.append(out.name)
        offset += step_ms
    return paths


@app.post("/transcribe")
async def transcribe_audio(file: UploadFile = File(...)) -> JSONResponse:
    """Transcribe an uploaded wav/mp3 file with the loaded ASR model.

    The upload is persisted to a temp file, split into fixed-length chunks,
    and each chunk is transcribed sequentially. All temp files (upload and
    chunks) are removed in `finally`, even when an error occurs.

    Raises:
        HTTPException: 503 if the model is not yet loaded, 400 for an
            unsupported file extension, 500 when no text was recognized.
    """
    if MODEL is None:
        raise HTTPException(status_code=503, detail="模型尚未加载完成")

    filename = (file.filename or "").lower()
    # endswith accepts a tuple of suffixes — one call instead of an `or` chain.
    if not filename.endswith((".wav", ".mp3")):
        raise HTTPException(status_code=400, detail="仅支持 wav 或 mp3 文件")

    suffix = ".wav" if filename.endswith(".wav") else ".mp3"
    with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp:
        tmp.write(await file.read())
        tmp_path = tmp.name

    # Initialize before the try so the cleanup in `finally` never has to
    # probe locals() for a name that might not exist yet.
    chunk_paths: List[str] = []
    try:
        start_infer = time.time()
        chunk_paths = split_audio_to_chunks(tmp_path, CHUNK_SECONDS)
        texts: List[str] = []
        language: Optional[str] = None

        for chunk_path in chunk_paths:
            results = MODEL.transcribe(audio=chunk_path, language=None)
            if results:
                # Keep the language detected on the first non-empty result.
                language = language or results[0].language
                texts.append(results[0].text)

        infer_time = time.time() - start_infer

        if not texts:
            raise HTTPException(status_code=500, detail="识别失败")

        return JSONResponse(
            {
                "language": language,
                "text": " ".join(texts).strip(),
                "time_sec": round(infer_time, 2),
                "chunks": len(chunk_paths),
            }
        )
    finally:
        # Best-effort removal of the upload and every chunk file.
        for path in (tmp_path, *chunk_paths):
            try:
                os.remove(path)
            except OSError:
                pass


if __name__ == "__main__":
    # Local dev entry point; in production a separate ASGI server would run this.
    import uvicorn

    # NOTE(review): the import string assumes this module is named
    # "qwen_asr_demo" — confirm it matches the actual filename, otherwise
    # uvicorn will fail to import the app.
    uvicorn.run("qwen_asr_demo:app", host="0.0.0.0", port=8000)