"""
单文件 ONNX 导出脚本（简化版）

行为：
- 不使用命令行参数；按项目中 `app`/`model` 模式默认加载模型（默认模型 id 为
  "Qwen/Qwen3-ASR-0.6B"），并尝试在模型对象内查找第一个 `torch.nn.Module` 子模块
 （如果模型把推理逻辑封装在子模块里）。
- 导出 ONNX 到当前工作目录下的 `qwen_asr.onnx`。

注意：若模型内部没有直接暴露 `torch.nn.Module`，脚本会失败，此时需要指明要导出的
子模块名或手动调整脚本以匹配模型封装层。
"""

from __future__ import annotations

import logging
import os
from typing import Any, List, Tuple

import torch

from qwen_asr import Qwen3ASRModel

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


def find_torch_modules(obj: Any) -> List[Tuple[str, torch.nn.Module]]:
    found: List[Tuple[str, torch.nn.Module]] = []
    # 先检查对象的直接属性（非递归）
    for name in dir(obj):
        if name.startswith("__"):
            continue
        try:
            attr = getattr(obj, name)
        except Exception:
            continue
        if isinstance(attr, torch.nn.Module):
            found.append((name, attr))

    # 如果 obj 本身是 nn.Module，遍历其子模块并只选取那些真正重写了 forward 的模块
    if isinstance(obj, torch.nn.Module):
        base_forward = getattr(torch.nn.Module, "forward", None)
        for name, sub in obj.named_modules():
            if not name:
                continue
            forward = getattr(sub, "forward", None)
            if forward is None:
                continue
            # bound method -> get underlying function
            func = getattr(forward, "__func__", forward)
            # 如果与基类的 forward 相同，说明没有被重写，跳过
            if func is base_forward:
                continue
            found.append((name, sub))

    # 去重并保持顺序
    seen = set()
    unique: List[Tuple[str, torch.nn.Module]] = []
    for n, m in found:
        if n in seen:
            continue
        seen.add(n)
        unique.append((n, m))
    return unique


def make_dummy(input_len: int, device: torch.device) -> torch.Tensor:
    # 默认构造为单通道波形形状 (1, input_len)
    return torch.randn(1, input_len, dtype=torch.float32, device=device)


def export_module_to_onnx(
    module: torch.nn.Module,
    dummy: torch.Tensor,
    output_path: str = "qwen_asr.onnx",
    opset: int = 16,
    input_name: str = "audio",
    output_name: str = "output",
):
    module.eval()
    module.cpu()
    args = (dummy,)
    dynamic_axes = {input_name: {1: "samples"}, output_name: {1: "seq"}}

    torch.onnx.export(
        module,
        args,
        output_path,
        export_params=True,
        opset_version=opset,
        do_constant_folding=True,
        input_names=[input_name],
        output_names=[output_name],
        dynamic_axes=dynamic_axes,
    )


def main() -> None:
    model_id = "Qwen/Qwen3-ASR-0.6B"
    output_path = "qwen_asr.onnx"
    input_len = 16000

    logger.info("加载模型 %s（将以 CPU 方式加载以便导出）...", model_id)
    cache_dir = os.getenv("HF_CACHE_DIR", "/app/hf_cache")
    try:
        os.makedirs(cache_dir, exist_ok=True)
    except Exception:
        # 如果不能创建目录，继续让 from_pretrained 处理缓存路径
        pass

    model = Qwen3ASRModel.from_pretrained(
        model_id,
        cache_dir=cache_dir,
        torch_dtype=torch.float32,
        device_map="cpu",
        low_cpu_mem_usage=True,
        use_safetensors=True,
        max_inference_batch_size=1,
        max_new_tokens=32,
    )

    # 使用通用检测：先尝试查找 model 对象内的 torch.nn.Module 子模块
    found = find_torch_modules(model)

    if not found:
        # 如果没有直接找到，检查 model 的属性里是否包含模块或容器（例如 model.model、model.module）
        candidates_attr = []
        for name in dir(model):
            if name.startswith("__"):
                continue
            try:
                attr = getattr(model, name)
            except Exception:
                continue
            if isinstance(attr, torch.nn.Module):
                candidates_attr.append((name, type(attr).__name__))
            else:
                # 如果属性自身有 named_modules 方法，视为候选容器
                if hasattr(attr, "named_modules"):
                    try:
                        mods = list(attr.named_modules())
                        if mods:
                            candidates_attr.append((name, type(attr).__name__))
                    except Exception:
                        pass

        if not candidates_attr:
            logger.error("未在模型对象或其属性中发现可导出的 torch.nn.Module，请检查封装或实现 wrapper。")
            return

        logger.info("在模型属性中发现以下候选模块（属性名, type）：")
        for idx, (name, typ) in enumerate(candidates_attr, start=1):
            logger.info("%d. %s (%s)", idx, name, typ)

        logger.info("请通过设置环境变量 ONNX_EXPORT_MODULE=<属性名> 指定要导出的属性后重试（例如: export ONNX_EXPORT_MODULE=model）")
        return

    # 如果找到了子模块，列出并让用户通过环境变量选择
    logger.info("模型中发现以下可导出子模块（name, type, overrides_forward）:")
    any_overridden = False
    for idx, (name, sub) in enumerate(found, start=1):
        typ = type(sub).__name__
        forward = getattr(sub, "forward", None)
        base_forward = getattr(torch.nn.Module, "forward", None)
        func = getattr(forward, "__func__", forward)
        overridden = func is not base_forward
        if overridden:
            any_overridden = True
        logger.info("%d. %s (%s) overrides_forward=%s", idx, name, typ, overridden)

    # 支持通过环境变量选择要导出的模块名（优先）
    selected_name = os.getenv("ONNX_EXPORT_MODULE")
    module = None
    module_name = None

    if selected_name:
        for name, sub in found:
            if name == selected_name:
                module = sub
                module_name = name
                break
        if module is None:
            logger.error("通过 ONNX_EXPORT_MODULE=%s 指定的模块未在候选列表中找到。", selected_name)
            return
        logger.info("通过 ONNX_EXPORT_MODULE 选择模块：%s", module_name)
    else:
        # 无环境变量时，默认选择类型为 Qwen3ASRForConditionalGeneration 的模块（如果存在）
        for name, sub in found:
            if type(sub).__name__ == "Qwen3ASRForConditionalGeneration":
                module = sub
                module_name = name
                logger.info("未指定模块，默认选择类型 Qwen3ASRForConditionalGeneration 的子模块：%s", module_name)
                break

        if module is None:
            # 如果没有默认类型，且没有任何子模块重写 forward，则列出第一个候选的内部子模块供选择
            if not any_overridden:
                name0, sub0 = found[0]
                logger.info("注意：候选模块均未重写 forward；将尝试列出 '%s' 的内部子模块。", name0)
                try:
                    inner = list(sub0.named_modules())
                except Exception as e:
                    logger.error("无法列出子模块：%s", e)
                    logger.info("请在运行前通过设置 ONNX_EXPORT_MODULE 指定一个有效的子模块名。")
                    return

                inner_list = []
                for iname, isub in inner:
                    if not iname:
                        continue
                    forward = getattr(isub, "forward", None)
                    func = getattr(forward, "__func__", forward)
                    overridden = func is not getattr(torch.nn.Module, "forward", None)
                    full_name = f"{name0}.{iname}"
                    inner_list.append((full_name, type(isub).__name__, overridden))

                if not inner_list:
                    logger.error("未在内部找到可选子模块。请手动检查模型或实现 wrapper。")
                    return

                logger.info("内部子模块（可作为 ONNX_EXPORT_MODULE 的值）:")
                for idx, (n, t, ov) in enumerate(inner_list, start=1):
                    logger.info("%d. %s (%s) overrides_forward=%s", idx, n, t, ov)

                logger.info("设置示例: export ONNX_EXPORT_MODULE=%s 然后重试", inner_list[0][0])
                return

            logger.info("未指定要导出的子模块。若要导出，请设置环境变量 ONNX_EXPORT_MODULE=<模块名> 并重试。")
            return

    device = torch.device("cpu")
    dummy = make_dummy(input_len, device)

    try:
        export_module_to_onnx(module, dummy, output_path)
        logger.info("导出成功：%s", output_path)
    except Exception as e:
        logger.exception("导出失败：%s", e)


if __name__ == "__main__":
    main()
