"""
设备检测工具（支持 CUDA 与 Ascend NPU）

提供统一的设备选择/映射接口，以便在代码中优雅地支持 `cuda` / `npu` / `cpu`。
"""
from typing import Optional
import torch


def has_npu() -> bool:
    """检测当前运行时是否存在华为 Ascend NPU（torch.npu）。"""
    try:
        npu_mod = getattr(torch, "npu", None)
        is_avail = getattr(npu_mod, "is_available", None)
        return bool(is_avail() if callable(is_avail) else False)
    except Exception:
        return False


def has_cuda() -> bool:
    """检测是否存在 CUDA（NVIDIA GPU）。"""
    try:
        return bool(torch.cuda.is_available())
    except Exception:
        return False


def choose_device(device_hint: Optional[str] = None) -> str:
    """
    基于 hint 与运行时检测选择设备，返回 'npu' | 'cuda' | 'cpu'.

    选择逻辑：优先尊重显式 hint（支持 'npu', 'cuda', 'cpu'，自动处理带索引的 'cuda:0'/'npu:0'），
    否则按顺序优先选择 NPU -> CUDA -> CPU。
    """
    hint = (device_hint or "").strip().lower()
    if hint:
        if hint.startswith("cpu"):
            return "cpu"
        if hint.startswith("cuda"):
            return "cuda" if has_cuda() else ("npu" if has_npu() else "cpu")
        if hint.startswith("npu"):
            return "npu" if has_npu() else ("cuda" if has_cuda() else "cpu")

    # 无显式 hint，则优先使用 NPU，其次 CUDA，最后 CPU
    if has_npu():
        return "npu"
    if has_cuda():
        return "cuda"
    return "cpu"


def device_map_for_model(device: Optional[str]) -> str:
    """根据选择的设备返回适合传给模型的 device_map 字符串。"""
    if not device:
        return "cpu"
    device = device.strip().lower()
    # 保留可能含索引的情况（cuda:0 / npu:0）
    return device


def torch_dtype_for_device(device: Optional[str]):
    """返回建议用于加载模型的 torch dtype。"""
    if not device or device == "cpu":
        return torch.float32
    # 对加速器（CUDA/NPU）默认使用 float16，必要时可在配置中覆盖
    return torch.float16


__all__ = ["has_npu", "has_cuda", "choose_device", "device_map_for_model", "torch_dtype_for_device"]
