package com.gzzm.lobster.llm;

import com.gzzm.lobster.common.LobsterException;
import com.gzzm.lobster.common.ModelServiceTier;
import com.gzzm.lobster.config.AgentProfile;
import com.gzzm.lobster.config.AgentProfileDao;
import com.gzzm.platform.commons.Tools;
import net.cyan.nest.annotation.Inject;

import java.util.ArrayList;
import java.util.List;
import java.util.Locale;
import java.util.Objects;

/**
 * ModelRouter — model routing service.
 *
 * <p>Selects a primary {@link ModelProfile} for a {@link ModelSelectionContext}, together with an
 * ordered fallback chain. Every candidate, primary or fallback, must pass the context's
 * capability filter. The profile must be enabled and, when the context requires multimodal
 * input, it must also be marked multimodal.
 *
 * <p>Primary selection order (the first passing candidate wins):
 * <ol>
 *   <li>the summary tier, only when the task type is {@code "summary"};</li>
 *   <li>the requested agent's fast model (if a fast response is required), then its premium
 *       model (if long context is required or the task type looks heavy), then its default
 *       model;</li>
 *   <li>the default model of the organisation's default agent;</li>
 *   <li>the first passing profile of the standard tier;</li>
 *   <li>the first passing profile among all enabled profiles.</li>
 * </ol>
 *
 * <p>Fallbacks are the other passing profiles of the primary's provider, followed by passing
 * profiles of the fallback tier. The primary is excluded and duplicates are dropped. Skill
 * declarations and runtime load are not consulted here.
 */
public class ModelRouter {

    @Inject
    private ModelProfileDao modelProfileDao;

    @Inject
    private AgentProfileDao agentProfileDao;

    /**
     * thunwind DAO cross-thread guard (see feedback_thunwind_dao_thread_binding).
     *
     * <p>Looks the DAO up via {@code Tools.getBean} on every call. It falls back to the injected
     * field when the lookup throws or returns null.
     */
    private ModelProfileDao modelProfileDao() {
        try {
            ModelProfileDao d = Tools.getBean(ModelProfileDao.class);
            if (d != null) return d;
        } catch (Throwable ignored) {
            // Best effort: fall back to the injected instance below.
        }
        return modelProfileDao;
    }

    /** Same cross-thread guard as {@link #modelProfileDao()}, for {@link AgentProfileDao}. */
    private AgentProfileDao agentProfileDao() {
        try {
            AgentProfileDao d = Tools.getBean(AgentProfileDao.class);
            if (d != null) return d;
        } catch (Throwable ignored) {
            // Best effort: fall back to the injected instance below.
        }
        return agentProfileDao;
    }

    /**
     * Selects the primary model and the fallback chain for {@code ctx}.
     *
     * @param ctx selection context; must not be null
     * @return the primary profile, its capability-filtered fallbacks and a short explanation
     * @throws LobsterException with code {@code llm.route} when {@code ctx} is null, when no
     *         profile qualifies, or when a lookup fails (the original exception is kept as the
     *         cause). Thrown with code {@code llm.route.no_multimodal} when multimodal input is
     *         required but no enabled multimodal profile exists.
     */
    public ModelRouteResult route(ModelSelectionContext ctx) {
        if (ctx == null) {
            throw new LobsterException("llm.route", "Route failed: selection context is null");
        }
        try {
            ModelProfile primary = selectPrimary(ctx);
            if (primary == null) {
                throw new LobsterException("llm.route", "No enabled model profile found");
            }
            // Fallbacks go through the same capability filter. Under a multimodal requirement a
            // text-only model must never enter the chain. Otherwise LlmRuntime would actually send
            // the request (including image_url) to that profile, and the server would either
            // reject it with a 400 or silently drop the image.
            List<ModelProfile> fallbacks = loadFallbacks(primary, ctx);
            return new ModelRouteResult(primary, fallbacks, explain(primary, ctx));
        } catch (LobsterException e) {
            throw e;
        } catch (Exception e) {
            throw new LobsterException("llm.route", "Route failed: " + e.getMessage(), e);
        }
    }

    /**
     * Runs the primary-selection chain described on the class.
     *
     * @return the first passing profile, or null when nothing qualifies and multimodal input is
     *         not required
     * @throws LobsterException {@code llm.route.no_multimodal} when multimodal input is required
     *         and no profile qualifies
     */
    private ModelProfile selectPrimary(ModelSelectionContext ctx) throws Exception {
        // 0. Summary-style tasks (title generation, transcript compaction, ...) prefer the
        //    summary tier. Without a configured summary model we fall through to the regular
        //    chain, so older deployments keep working.
        if ("summary".equals(ctx.getTaskType())) {
            ModelProfile p = firstPassing(modelProfileDao().listByTier(ModelServiceTier.summary), ctx);
            if (p != null) return p;
            // Callers usually also set requiresFastResponse=true, so the agent's fast model is
            // tried first in step 1.
        }
        // 1. Models configured on the requested agent: fast / premium / default.
        if (ctx.getAgentId() != null) {
            AgentProfile agent = agentProfileDao().getAgent(ctx.getAgentId());
            if (agent != null) {
                if (ctx.isRequiresFastResponse() && agent.getFastModelId() != null) {
                    ModelProfile p = modelProfileDao().getProfile(agent.getFastModelId());
                    if (passes(p, ctx)) return p;
                }
                if ((ctx.isRequiresLongContext() || isHeavyTask(ctx.getTaskType()))
                        && agent.getPremiumModelId() != null) {
                    ModelProfile p = modelProfileDao().getProfile(agent.getPremiumModelId());
                    if (passes(p, ctx)) return p;
                }
                if (agent.getDefaultModelId() != null) {
                    ModelProfile p = modelProfileDao().getProfile(agent.getDefaultModelId());
                    if (passes(p, ctx)) return p;
                }
            }
        }
        // 2. Default model of the organisation's default agent.
        AgentProfile orgAgent = agentProfileDao().getDefaultByOrg(ctx.getOrgId());
        if (orgAgent != null && orgAgent.getDefaultModelId() != null) {
            ModelProfile p = modelProfileDao().getProfile(orgAgent.getDefaultModelId());
            if (passes(p, ctx)) return p;
        }
        // 3. Last resort: first standard-tier profile, then any enabled profile. The capability
        //    filter still applies.
        ModelProfile standard = firstPassing(modelProfileDao().listByTier(ModelServiceTier.standard), ctx);
        if (standard != null) return standard;
        ModelProfile any = firstPassing(modelProfileDao().listEnabled(), ctx);
        if (any != null) return any;
        // Nothing found under a multimodal requirement: raise a specific error so the upper
        // layer can tell the user that no image-capable model is configured.
        if (ctx.isRequiresMultimodal()) {
            throw new LobsterException("llm.route.no_multimodal",
                    "No enabled multimodal model profile found; configure ModelProfile.multimodal=true on a model first.");
        }
        return null;
    }

    /**
     * Returns the first element of {@code candidates} that passes {@link #passes}.
     *
     * @return the first passing profile, or null when {@code candidates} is null or none pass
     */
    private ModelProfile firstPassing(List<ModelProfile> candidates, ModelSelectionContext ctx) {
        if (candidates == null) return null;
        for (ModelProfile p : candidates) {
            if (passes(p, ctx)) return p;
        }
        return null;
    }

    /**
     * Checks whether a candidate profile satisfies the context's capability filter.
     *
     * <p>The profile must be non-null and have {@code enabled == TRUE}. If the context requires
     * multimodal input, it must also have {@code multimodal == TRUE}.
     */
    private boolean passes(ModelProfile p, ModelSelectionContext ctx) {
        if (p == null) return false;
        if (!Boolean.TRUE.equals(p.getEnabled())) return false;
        if (ctx.isRequiresMultimodal() && !Boolean.TRUE.equals(p.getMultimodal())) return false;
        return true;
    }

    /**
     * Builds the fallback chain for {@code primary}.
     *
     * <p>The chain holds other profiles of the same provider first, then the fallback tier. Each
     * entry must pass the capability filter (for example requiresMultimodal). The primary is
     * never included and duplicates are dropped.
     */
    private List<ModelProfile> loadFallbacks(ModelProfile primary, ModelSelectionContext ctx) throws Exception {
        List<ModelProfile> fallbacks = new ArrayList<>();
        appendFallbacks(fallbacks, modelProfileDao().listByProvider(primary.getProvider()), primary, ctx);
        appendFallbacks(fallbacks, modelProfileDao().listByTier(ModelServiceTier.fallback), primary, ctx);
        return fallbacks;
    }

    /**
     * Appends each passing candidate to {@code out}, skipping the primary and duplicates.
     *
     * <p>Candidates with the primary's modelId, or with a modelId already in {@code out}, are
     * skipped. A null {@code candidates} list is ignored.
     */
    private void appendFallbacks(List<ModelProfile> out, List<ModelProfile> candidates,
                                 ModelProfile primary, ModelSelectionContext ctx) {
        if (candidates == null) return;
        for (ModelProfile p : candidates) {
            // Filter first: passes() rejects null elements before any getter is called on them.
            if (!passes(p, ctx)) continue;
            String id = p.getModelId();
            if (Objects.equals(id, primary.getModelId())) continue;
            if (contains(out, id)) continue;
            out.add(p);
        }
    }

    /** Returns true if some profile in {@code list} has a modelId equal to {@code id} (null-safe). */
    private boolean contains(List<ModelProfile> list, String id) {
        for (ModelProfile p : list) {
            if (Objects.equals(p.getModelId(), id)) return true;
        }
        return false;
    }

    /**
     * Checks whether a task type looks heavy.
     *
     * <p>The check is a locale-independent, case-insensitive substring match on "report",
     * "multi", "long" or "integration".
     */
    private boolean isHeavyTask(String taskType) {
        if (taskType == null) return false;
        // Locale.ROOT: the default locale may map 'I' to a dotless 'ı' (e.g. Turkish), which
        // would make "INTEGRATION" miss the "integration" marker.
        String t = taskType.toLowerCase(Locale.ROOT);
        return t.contains("report") || t.contains("multi") || t.contains("long") || t.contains("integration");
    }

    /**
     * Logs the routing decision on a best-effort basis and returns a short summary of the
     * chosen primary.
     */
    private String explain(ModelProfile p, ModelSelectionContext ctx) {
        try {
            Tools.log("[ModelRouter] selected model " + p.getModelId() + " for agent=" + ctx.getAgentId() + " task=" + ctx.getTaskType());
        } catch (Throwable ignored) {
            // Logging is best effort and must never break routing.
        }
        return "primary=" + p.getModelId() + "; tier=" + p.getServiceTier();
    }
}
