package com.gzzm.lobster.llm;

import com.gzzm.lobster.audit.ModelCallLog;
import com.gzzm.lobster.audit.ModelCallLogDao;
import com.gzzm.lobster.common.IdGenerator;
import com.gzzm.lobster.common.JsonUtil;
import com.gzzm.lobster.common.LobsterException;
import com.gzzm.lobster.config.LobsterConfig;
import com.gzzm.lobster.storage.ContentStore;
import com.gzzm.lobster.storage.FileSystemContentStore;
import com.gzzm.platform.commons.Tools;
import net.cyan.commons.util.Provider;
import net.cyan.nest.annotation.Inject;

import java.io.PrintWriter;
import java.io.StringWriter;
import java.util.ArrayList;
import java.util.Date;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

/**
 * LLM runtime. Responsible for:
 *
 * <ul>
 *   <li>choosing the primary model from the route;</li>
 *   <li>cascading through the fallback chain when a model call fails;</li>
 *   <li>recording an audit log row (and, when enabled, a full I/O trace) per attempt.</li>
 * </ul>
 *
 * <p>Auditing and tracing are strictly best-effort: a failure while writing them is logged and
 * never changes the outcome of the model call itself.
 */
public class LlmRuntime {

    /** Upper bound (in UTF-16 chars) for the error message stored on the audit row. */
    private static final int MAX_ERROR_MESSAGE_CHARS = 400;

    /** Maximum number of chained causes rendered by {@link #describeCause(Throwable)}. */
    private static final int MAX_CAUSE_DEPTH = 5;

    @Inject
    private LobsterAdapterFactory adapterFactory;

    @Inject
    private ModelCallLogDao modelCallLogDao;


    /**
     * Injecting {@link ContentStore} by interface did not take effect under the current nest
     * configuration (even with a {@code <bean class=interface imp=implementation>} binding in
     * nest.xml), so the concrete class is injected directly — nest's automatic instantiation of a
     * concrete class through its no-arg constructor is the reliable path. To switch to another
     * implementation (MinIO or similar) later:
     *   1) write a {@code MinioContentStore implements ContentStore}
     *   2) change this to {@code @Inject MinioContentStore} (or hold it in a {@link ContentStore} variable)
     */
    @Inject
    private FileSystemContentStore contentStore;

    /** Injected but not read anywhere in this class; traces go through {@link #contentStore}. */
    @Inject
    private Provider<ContentStore> contentStoreProvider;

    /**
     * Calls the LLM with built-in fallback and auditing.
     *
     * <p>Models are tried in chain order (primary first, then the route's fallbacks). Every attempt
     * is audited. A hard cancel (see {@link RunCancelledException#isHardCancel(Throwable)}) stops
     * the chain immediately, matching {@link #chatStream}: fallback exists to survive upstream
     * failures, not to keep retrying after the run was cancelled.
     *
     * @param request  call context copied onto the audit row
     * @param route    routing result supplying the primary model and optional fallbacks
     * @param messages conversation sent to the model
     * @param tools    tool specifications offered to the model
     * @return the response of the first model that succeeds
     * @throws LobsterException with code {@code llm.exhausted} (cause = last failure) when no
     *                          model produced a response
     */
    public LlmResponse chat(LlmCallRequest request, ModelRouteResult route,
                            List<LobsterMessage> messages, List<ToolSpec> tools) {
        List<ModelProfile> chain = buildChain(route);
        String primaryId = route.getPrimary().getModelId();
        Throwable last = null;
        for (ModelProfile profile : chain) {
            long start = System.currentTimeMillis();
            boolean downgraded = !profile.getModelId().equals(primaryId);
            String downgradeReason = downgraded ? "fallback from " + primaryId : null;
            LlmResponse response;
            try {
                LobsterLlmAdapter adapter = adapterFactory.get(profile);
                response = adapter.chat(messages, tools);
            } catch (Throwable t) {
                last = t;
                saveAudit(request, profile, downgraded, downgradeReason, null, start, t,
                        messages, tools, null, false);
                if (RunCancelledException.isHardCancel(t)) {
                    // Hard cancel: do not burn the remaining models on a run that was cancelled.
                    break;
                }
                try {
                    Tools.log("[LlmRuntime] model " + profile.getModelId() + " failed, trying fallback: " + t.getMessage());
                } catch (Throwable ignore) { /* logging must not break the fallback loop */ }
                continue;
            }
            // Audit the success outside the try above so that an audit problem can never be
            // mistaken for a model failure and trigger a needless fallback call.
            saveAudit(request, profile, downgraded, downgradeReason, response, start, null,
                    messages, tools, null, false);
            return response;
        }
        throw new LobsterException("llm.exhausted", "All model fallbacks exhausted: "
                + (last == null ? "unknown" : last.getMessage()), last);
    }

    /**
     * Streaming variant of {@link #chat}. Falls back to the next model only while nothing has been
     * pushed to the caller as content yet.
     *
     * <p>V2 cancel semantics (2026-04): once a
     * {@link RunCancelledException#isHardCancel(Throwable) hard cancel} (USER / TIMEOUT / BUDGET)
     * is detected it is reported to {@code handler} immediately and the fallback chain is
     * <b>not</b> continued — model downgrade only exists to survive upstream failures and must not
     * keep retrying when the caller's intent is explicit.
     *
     * <p>This method does not throw for adapter failures; the final outcome is delivered through
     * {@code handler.onComplete} / {@code handler.onError}.
     *
     * @param request  call context copied onto the audit row
     * @param route    routing result supplying the primary model and optional fallbacks
     * @param messages conversation sent to the model
     * @param tools    tool specifications offered to the model
     * @param handler  receiver of the stream events
     */
    public void chatStream(LlmCallRequest request, ModelRouteResult route,
                           List<LobsterMessage> messages, List<ToolSpec> tools,
                           StreamingResponseHandler handler) {
        List<ModelProfile> chain = buildChain(route);
        String primaryId = route.getPrimary().getModelId();
        Throwable last = null;
        for (ModelProfile profile : chain) {
            long start = System.currentTimeMillis();
            boolean downgraded = !profile.getModelId().equals(primaryId);
            String downgradeReason = downgraded ? "fallback from " + primaryId : null;
            // Created before the try so the catch branch can see what was already streamed/audited.
            AuditingHandler wrap = new AuditingHandler(handler, request, profile, downgraded,
                    downgradeReason, start, this, messages, tools);
            try {
                LobsterLlmAdapter adapter = adapterFactory.get(profile);
                adapter.chatStream(messages, tools, wrap);
                if (!wrap.failed) return;
                last = wrap.error;
            } catch (Throwable t) {
                if (wrap.terminated && !wrap.failed) {
                    // onComplete / a mid-stream onError was already forwarded to the handler;
                    // falling back or notifying again would replay a finished stream.
                    try {
                        Tools.log("[LlmRuntime] stream adapter threw after terminal callback, ignored: "
                                + describeCause(t), t);
                    } catch (Throwable ignore) { /* logging is best-effort */ }
                    return;
                }
                if (!wrap.terminated) {
                    // The adapter threw without reporting through onError: audit it here.
                    saveAudit(request, profile, downgraded, downgradeReason, null, start, t,
                            messages, tools, wrap.capturedText(), true);
                }
                // If the adapter reported an error via onError and then threw, the reported error
                // (already audited by the wrapper) is the authoritative one.
                last = wrap.failed ? wrap.error : t;
                if (wrap.sawDelta) {
                    // Content was already pushed to the caller — never fall back mid-stream.
                    handler.onError(last);
                    return;
                }
            }
            // Hard cancel never goes through the fallback chain.
            if (RunCancelledException.isHardCancel(last)) {
                handler.onError(last);
                return;
            }
        }
        handler.onError(last == null ? new LobsterException("llm.exhausted", "stream exhausted") : last);
    }

    /** Builds the attempt order: the route's primary model followed by its fallbacks, if any. */
    private List<ModelProfile> buildChain(ModelRouteResult route) {
        List<ModelProfile> chain = new ArrayList<>();
        chain.add(route.getPrimary());
        if (route.getFallbacks() != null) chain.addAll(route.getFallbacks());
        return chain;
    }

    /**
     * Writes one audit row for a single model attempt and, when tracing is enabled, the full I/O
     * trace referenced from that row.
     *
     * <p>Best-effort by contract: any failure is logged and swallowed, because this method is
     * invoked from inside the fallback loop and from streaming callbacks, where an escaping
     * exception would be misread as a model failure.
     *
     * @param response              model response, or {@code null} on failure
     * @param start                 attempt start time in epoch millis, used to compute the duration
     * @param error                 failure, or {@code null} on success
     * @param capturedStreamingText text accumulated from streaming deltas, may be {@code null}
     * @param streaming             whether the attempt was a streaming call
     */
    void saveAudit(LlmCallRequest request, ModelProfile profile, boolean downgraded,
                   String downgradeReason, LlmResponse response, long start, Throwable error,
                   List<LobsterMessage> messages, List<ToolSpec> tools,
                   String capturedStreamingText, boolean streaming) {
        try {
            ModelCallLog log = new ModelCallLog();
            log.setCallId(IdGenerator.callId());
            log.setRunId(request.getRunId());
            log.setThreadId(request.getThreadId());
            log.setUserId(request.getUserId());
            log.setOrgId(request.getOrgId());
            log.setAgentId(request.getAgentId());
            log.setModelId(profile.getModelId());
            log.setProtocol(profile.getProtocol() == null ? null : profile.getProtocol().name());
            log.setDowngraded(downgraded);
            log.setDowngradeReason(downgradeReason);
            // Input tokens: prefer the value the adapter reported; when absent (or 0) fall back
            // to an estimate computed from the messages.
            int inTokens = response == null ? 0 : response.getInputTokens();
            if (inTokens <= 0 && messages != null) inTokens = estimateInputTokens(messages);
            log.setInputTokens(inTokens);
            log.setOutputTokens(response == null ? 0 : response.getOutputTokens());
            log.setDurationMs((int) (System.currentTimeMillis() - start));
            log.setStatus(error == null ? "ok" : "error");
            if (error != null) log.setErrorMessage(safeMsg(error));
            log.setCreateTime(new Date());

            // Write the trace first to obtain its ref, then persist the row; if the trace fails
            // the ref stays empty and the metadata is still saved.
            if (LobsterConfig.isLlmTraceEnabled()) {
                String ref = writeTrace(log, profile, messages, tools, response, capturedStreamingText, error, streaming);
                if (ref != null) log.setTraceRef(ref);
            }

            currentDao().save(log);
        } catch (Throwable t) {
            try { Tools.log("[LlmRuntime] audit save failed", t); } catch (Throwable ignore) { /* ignore */ }
        }
    }

    /**
     * Serializes the complete I/O of one LLM call (messages, tool list, response, error) to JSON,
     * stores it through the content store under category {@code "llm-trace"} and returns the
     * content ref, which the caller records via {@link ModelCallLog#setTraceRef(String)}.
     *
     * @return the content ref, or {@code null} when the store is missing or the write failed
     */
    private String writeTrace(ModelCallLog log, ModelProfile profile,
                              List<LobsterMessage> messages, List<ToolSpec> tools,
                              LlmResponse response, String capturedStreamingText,
                              Throwable error, boolean streaming) {
        try {
            // Checked up front so no payload is built when there is nowhere to write it.
            if (contentStore == null) {
                Tools.log("[LlmRuntime] trace skipped: FileSystemContentStore not injected");
                return null;
            }

            Map<String, Object> out = new LinkedHashMap<>();
            out.put("callId", log.getCallId());
            out.put("runId", log.getRunId());
            out.put("threadId", log.getThreadId());
            out.put("userId", log.getUserId());
            out.put("orgId", log.getOrgId());
            out.put("agentId", log.getAgentId());
            out.put("modelId", log.getModelId());
            out.put("provider", profile.getProvider() == null ? null : profile.getProvider().name());
            out.put("protocol", profile.getProtocol() == null ? null : profile.getProtocol().name());
            out.put("endpoint", profile.getEndpoint());
            out.put("streaming", streaming);
            out.put("downgraded", log.getDowngraded());
            out.put("downgradeReason", log.getDowngradeReason());
            out.put("durationMs", log.getDurationMs());
            out.put("status", log.getStatus());
            out.put("inputTokens", log.getInputTokens());
            out.put("outputTokens", log.getOutputTokens());
            out.put("promptCache", promptCacheToMap(response, log.getInputTokens()));
            out.put("createTime", log.getCreateTime());

            // request
            Map<String, Object> req = new LinkedHashMap<>();
            req.put("messages", messagesToList(messages));
            req.put("tools", toolsToList(tools));
            out.put("request", req);

            // response
            Map<String, Object> resp = new LinkedHashMap<>();
            if (response != null) {
                resp.put("finishReason", response.getFinishReason());
                resp.put("assistantText", truncate(response.getAssistantText()));
                if (response.getReasoningContent() != null && !response.getReasoningContent().isEmpty()) {
                    // Keep the reasoning text of thinking-mode models in the trace so an admin can
                    // investigate why the model acted the way it did.
                    resp.put("reasoningContent", truncate(response.getReasoningContent()));
                }
                resp.put("toolCalls", toolCallsToList(response.getToolCalls()));
                resp.put("rawText", truncate(response.getRawText()));
                // A separate map instance on purpose: the same data also sits at the top level.
                resp.put("promptCache", promptCacheToMap(response, log.getInputTokens()));
            }
            if (capturedStreamingText != null && !capturedStreamingText.isEmpty()) {
                // Text accumulated from onDelta while streaming (most useful when the stream
                // broke, was cancelled or failed).
                resp.put("streamedText", truncate(capturedStreamingText));
            }
            if (error != null) {
                resp.put("error", errorToMap(error));
            }
            out.put("response", resp);

            String json = JsonUtil.toJson(out);
            String uid = log.getUserId() == null ? "system" : log.getUserId();
            // Key point: the trace is a best-effort write and must not be affected by the calling
            // thread's interrupt status. On the cancel path (streamFuture.cancel(true) -> worker
            // thread interrupted -> adapter onError -> AuditingHandler.onError -> saveAudit ->
            // writeTrace) any later NIO channel call on the same thread would throw
            // ClosedByInterruptException and lose the trace. Temporarily clear the interrupt flag
            // for the I/O and restore it afterwards so the cancel signal is preserved.
            boolean wasInterrupted = Thread.interrupted();
            String ref;
            try {
                ref = contentStore.write("llm-trace", uid, json, "json");
            } finally {
                if (wasInterrupted) Thread.currentThread().interrupt();
            }
            if (ref == null || ref.isEmpty()) {
                Tools.log("[LlmRuntime] trace write returned empty ref for callId=" + log.getCallId());
                return null;
            }
            return ref;
        } catch (Throwable t) {
            // Log the cause chain explicitly: t.getMessage() alone tends to hide the underlying
            // IOException keyword (AccessDenied / NoSpace / ClosedByInterrupt ...), and the outer
            // wrapper message is not enough to locate the problem.
            try { Tools.log("[LlmRuntime] trace write failed: " + describeCause(t), t); } catch (Throwable ignore) { /* ignore */ }
            return null;
        }
    }

    /**
     * Renders a throwable and up to {@value #MAX_CAUSE_DEPTH} levels of its cause chain on a
     * single line ({@code Outer: msg <- Inner: msg}) so the root cause is visible in one log line.
     */
    private static String describeCause(Throwable t) {
        StringBuilder sb = new StringBuilder();
        Throwable cur = t;
        // The depth bound also protects against cyclic cause chains.
        for (int depth = 0; cur != null && depth < MAX_CAUSE_DEPTH; depth++) {
            if (sb.length() > 0) sb.append(" <- ");
            sb.append(cur.getClass().getSimpleName()).append(": ").append(cur.getMessage());
            cur = cur.getCause();
        }
        return sb.toString();
    }

    /** Converts messages into plain maps for the trace JSON; null entries are skipped. */
    private static List<Map<String, Object>> messagesToList(List<LobsterMessage> messages) {
        List<Map<String, Object>> list = new ArrayList<>();
        if (messages == null) return list;
        for (LobsterMessage m : messages) {
            if (m == null) continue; // one bad entry must not cost the whole trace
            Map<String, Object> row = new LinkedHashMap<>();
            row.put("role", m.getRole() == null ? null : m.getRole().name());
            row.put("content", truncate(m.getContent()));
            if (m.getToolCallId() != null) row.put("toolCallId", m.getToolCallId());
            if (m.getToolName() != null) row.put("toolName", m.getToolName());
            if (m.getToolCalls() != null && !m.getToolCalls().isEmpty()) {
                row.put("toolCalls", toolCallsToList(m.getToolCalls()));
            }
            if (m.getImageUrls() != null && !m.getImageUrls().isEmpty()) {
                row.put("imageUrls", m.getImageUrls());
            }
            list.add(row);
        }
        return list;
    }

    /** Converts tool specifications into plain maps for the trace JSON; null entries are skipped. */
    private static List<Map<String, Object>> toolsToList(List<ToolSpec> tools) {
        List<Map<String, Object>> list = new ArrayList<>();
        if (tools == null) return list;
        for (ToolSpec t : tools) {
            if (t == null) continue;
            Map<String, Object> row = new LinkedHashMap<>();
            row.put("name", t.getName());
            row.put("description", t.getDescription());
            row.put("parametersSchema", t.getParametersSchema());
            list.add(row);
        }
        return list;
    }

    /** Converts tool calls into plain maps for the trace JSON; null entries are skipped. */
    private static List<Map<String, Object>> toolCallsToList(List<ToolCall> calls) {
        List<Map<String, Object>> list = new ArrayList<>();
        if (calls == null) return list;
        for (ToolCall c : calls) {
            if (c == null) continue;
            Map<String, Object> row = new LinkedHashMap<>();
            row.put("id", c.getId());
            row.put("name", c.getName());
            row.put("arguments", c.getArgumentsJson());
            list.add(row);
        }
        return list;
    }

    /** Captures class name, message and full stack trace of an error for the trace JSON. */
    private static Map<String, Object> errorToMap(Throwable t) {
        Map<String, Object> m = new LinkedHashMap<>();
        m.put("class", t.getClass().getName());
        m.put("message", t.getMessage());
        StringWriter sw = new StringWriter();
        t.printStackTrace(new PrintWriter(sw));
        m.put("stackTrace", sw.toString());
        return m;
    }

    /**
     * Builds the prompt-cache section of the trace: hit / miss / total token counts and hit rate.
     *
     * <p>When the response reports hit tokens but no miss tokens, the miss count is derived from
     * {@code inputTokens}; if {@code inputTokens} is smaller than the hit count the figures are
     * inconsistent and {@code hitRate} is emitted as {@code null}.
     *
     * @param response    model response, may be {@code null} (all counts are then 0)
     * @param inputTokens input token count recorded on the audit row
     */
    private static Map<String, Object> promptCacheToMap(LlmResponse response, int inputTokens) {
        Map<String, Object> m = new LinkedHashMap<>();
        int hit = response == null ? 0 : response.getPromptCacheHitTokens();
        int miss = response == null ? 0 : response.getPromptCacheMissTokens();
        int total = hit + miss;
        boolean rateKnown = total > 0 && (miss > 0 || hit == 0);
        if (hit > 0 && miss <= 0) {
            if (inputTokens > hit) {
                miss = inputTokens - hit;
                total = inputTokens;
                rateKnown = true;
            } else if (inputTokens == hit) {
                total = inputTokens;
                rateKnown = true;
            } else {
                rateKnown = false;
            }
        }
        m.put("hitTokens", hit);
        m.put("missTokens", miss);
        m.put("totalTokens", total);
        if (rateKnown && total > 0) {
            m.put("hitRate", ((double) hit) / total);
        } else {
            m.put("hitRate", null);
        }
        return m;
    }

    /**
     * Caps a trace string at the configured maximum length and appends a marker stating how many
     * chars were dropped. {@code null}, short strings, and a non-positive cap pass through as-is.
     */
    private static String truncate(String s) {
        int cap = LobsterConfig.getLlmTraceMaxMessageChars();
        if (s == null || cap <= 0 || s.length() <= cap) return s;
        int cut = surrogateSafeCut(s, cap);
        return s.substring(0, cut) + "...[truncated " + (s.length() - cut) + " chars]";
    }

    /**
     * Returns the largest cut index {@code <= limit} that does not split a UTF-16 surrogate pair,
     * so the kept prefix never ends in a lone high surrogate.
     *
     * @param s     text to cut; must be longer than {@code limit}
     * @param limit desired cut index, {@code 0 < limit < s.length()}
     */
    private static int surrogateSafeCut(String s, int limit) {
        boolean splitsPair = Character.isHighSurrogate(s.charAt(limit - 1))
                && Character.isLowSurrogate(s.charAt(limit));
        return splitsPair ? limit - 1 : limit;
    }

    /**
     * Resolves the DAO to use for the current save.
     *
     * <p>Per the original author's note, the framework's DAO proxies are bound to the thread that
     * created them: the {@link #modelCallLogDao @Inject field} is only usable on the thread where
     * injection happened (typically the HTTP request thread), while streaming audit callbacks run
     * on {@code AgentRuntime}'s executor worker threads and would fail with
     * {@code dao is created in a thread and used in another thread}.
     *
     * <p>So each save asks the container for a DAO via {@code Tools.getBean}; if the lookup throws
     * or returns {@code null} (container unavailable, e.g. unit tests) the injected field is used.
     */
    private ModelCallLogDao currentDao() {
        try {
            ModelCallLogDao dao = Tools.getBean(ModelCallLogDao.class);
            if (dao != null) return dao;
        } catch (Throwable ignore) {
            /* container unavailable — fall back to the injected field */
        }
        return modelCallLogDao;
    }

    /**
     * Produces the error text stored on the audit row: the throwable's message capped at
     * {@value #MAX_ERROR_MESSAGE_CHARS} chars, or its simple class name when it has no message.
     */
    private String safeMsg(Throwable t) {
        String m = t.getMessage();
        if (m == null) return t.getClass().getSimpleName();
        if (m.length() <= MAX_ERROR_MESSAGE_CHARS) return m;
        return m.substring(0, surrogateSafeCut(m, MAX_ERROR_MESSAGE_CHARS));
    }

    /**
     * Estimates the input token count; used when the response carries no usable usage figure.
     *
     * <p>Sums the estimator's result for every message content, tool-call name and tool-call
     * arguments, plus a flat 4 tokens per message for role / separator overhead. Not exact, but
     * enough for monitoring to have a number to show.
     */
    private static int estimateInputTokens(List<LobsterMessage> messages) {
        if (messages == null || messages.isEmpty()) return 0;
        int total = 0;
        for (LobsterMessage m : messages) {
            if (m == null) continue;
            if (m.getContent() != null) {
                total += com.gzzm.lobster.common.TokenEstimator.estimate(m.getContent());
            }
            if (m.getToolCalls() != null) {
                for (ToolCall c : m.getToolCalls()) {
                    if (c == null) continue;
                    // Same null guard as for content: do not hand null to the estimator.
                    if (c.getArgumentsJson() != null) {
                        total += com.gzzm.lobster.common.TokenEstimator.estimate(c.getArgumentsJson());
                    }
                    if (c.getName() != null) {
                        total += com.gzzm.lobster.common.TokenEstimator.estimate(c.getName());
                    }
                }
            }
            total += 4; // role / separator overhead
        }
        return total;
    }

    /**
     * Stream handler wrapper that audits the attempt and decides whether a failure may still fall
     * back: errors before any content delta are parked in {@link #failed} / {@link #error} for
     * {@code chatStream} to act on; errors after content was pushed go straight to the delegate.
     */
    static final class AuditingHandler implements StreamingResponseHandler {
        private final StreamingResponseHandler delegate;
        private final LlmCallRequest request;
        private final ModelProfile profile;
        private final boolean downgraded;
        private final String downgradeReason;
        private final long start;
        private final LlmRuntime outer;
        private final List<LobsterMessage> messages;
        private final List<ToolSpec> tools;
        private final StringBuilder captured = new StringBuilder();
        /** Set when an error arrived before any content delta; the error was NOT forwarded. */
        boolean failed;
        /** The parked error when {@link #failed} is set. */
        Throwable error;
        /** Set once a content delta has been forwarded; from then on fallback is off the table. */
        boolean sawDelta;
        /** Set once onComplete / onError has been received (and therefore audited). */
        boolean terminated;

        AuditingHandler(StreamingResponseHandler delegate, LlmCallRequest request, ModelProfile profile,
                        boolean downgraded, String downgradeReason, long start, LlmRuntime outer,
                        List<LobsterMessage> messages, List<ToolSpec> tools) {
            this.delegate = delegate;
            this.request = request;
            this.profile = profile;
            this.downgraded = downgraded;
            this.downgradeReason = downgradeReason;
            this.start = start;
            this.outer = outer;
            this.messages = messages;
            this.tools = tools;
        }

        /** Content text accumulated from {@link #onDelta(String)} so far. */
        String capturedText() {
            return captured.toString();
        }

        @Override public void onDelta(String delta) {
            sawDelta = true;
            if (delta != null) captured.append(delta);
            delegate.onDelta(delta);
        }
        @Override public void onReasoningDelta(String delta) {
            // Reasoning fragments do not count towards sawDelta — only real content means
            // "streaming has started, no more downgrade". A failure during the thinking phase may
            // still go through the fallback chain.
            delegate.onReasoningDelta(delta);
        }
        // NOTE(review): tool-call events and write-file deltas are forwarded without setting
        // sawDelta, so a failure after them can still fall back and the next model may emit them
        // again — confirm the delegate tolerates that replay.
        @Override public void onToolCall(ToolCall toolCall) { delegate.onToolCall(toolCall); }
        @Override public void onWriteFileContentDelta(String toolCallId, int toolIndex, String contentDelta) {
            delegate.onWriteFileContentDelta(toolCallId, toolIndex, contentDelta);
        }
        @Override public void onComplete(LlmResponse response) {
            terminated = true;
            outer.saveAudit(request, profile, downgraded, downgradeReason, response, start, null,
                    messages, tools, captured.toString(), true);
            delegate.onComplete(response);
        }
        @Override public void onError(Throwable t) {
            terminated = true;
            outer.saveAudit(request, profile, downgraded, downgradeReason, null, start, t,
                    messages, tools, captured.toString(), true);
            if (sawDelta) {
                // Already streamed content — do not fall back mid-stream, report directly.
                delegate.onError(t);
            } else {
                failed = true;
                error = t;
            }
        }
        /**
         * MUST forward the cancel signal to the adapter.
         *
         * <p>Per the original author's note, AgentRuntime overrides isCancelled() on the delegate;
         * if this wrapper did not forward it, the adapter would always read the interface default
         * ({@code false}), never disconnect on its own, and cancel would effectively not work.
         */
        @Override public boolean isCancelled() {
            return delegate.isCancelled();
        }
    }
}
