package com.gzzm.lobster.api.admin;

import com.gzzm.lobster.audit.ModelCallLog;
import com.gzzm.lobster.audit.ModelCallLogDao;
import com.gzzm.lobster.common.JsonUtil;
import com.gzzm.lobster.common.LobsterException;
import com.gzzm.lobster.config.LobsterConfig;
import com.gzzm.lobster.identity.AdminGuard;
import com.gzzm.lobster.storage.ContentStore;
import com.gzzm.lobster.thread.ThreadDao;
import com.gzzm.lobster.thread.ThreadRoom;
import com.gzzm.lobster.tool.mcp.McpCallLog;
import com.gzzm.lobster.tool.mcp.McpCallLogDao;
import net.cyan.arachne.HttpMethod;
import net.cyan.arachne.annotation.Service;
import net.cyan.nest.annotation.Inject;

import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

/**
 * AdminCallLogApi — admin queries for model call logs plus the full LLM I/O trace.
 *
 * <p>Pairs with {@link com.gzzm.lobster.llm.LlmRuntime}: each LLM call produces one row in
 * {@code AI_MODEL_CALL_LOG}; when {@code LobsterConfig.llmTraceEnabled} is on, {@code traceRef}
 * points at the full JSON (messages + tools + response) stored in the {@link ContentStore}.
 *
 * <p>Every endpoint calls {@link AdminGuard#requireAdmin()} before doing anything else.
 */
@Service
public class AdminCallLogApi {

    @Inject private ModelCallLogDao callLogDao;
    @Inject private McpCallLogDao mcpCallLogDao;
    @Inject private ThreadDao threadDao;
    @Inject private ContentStore contentStore;

    /**
     * Pages through model call summaries.
     *
     * <p>Filter precedence: a non-empty {@code runId} wins over {@code threadId}; with neither,
     * all calls are listed.
     *
     * @param runId    optional run filter
     * @param threadId optional thread filter (ignored when {@code runId} is non-empty)
     * @param offset   zero-based offset; {@code null} or negative means 0
     * @param limit    page size; {@code null} uses the configured default, otherwise clamped to
     *                 {@code [1, LobsterConfig.getListMaxPageSize()]}
     * @return map with keys {@code items}, {@code total}, {@code offset}, {@code limit}
     */
    @Service(url = "/ai/api/admin/calls/list", method = HttpMethod.all)
    public Map<String, Object> list(String runId, String threadId,
                                    Integer offset, Integer limit) throws Exception {
        AdminGuard.requireAdmin();
        int o = clampOffset(offset);
        int l = clampLimit(limit);

        List<ModelCallLog> rows;
        long total;
        // Counts go through safeLong (as in threads()) so a null count yields 0 instead of an
        // auto-unboxing NullPointerException.
        if (hasText(runId)) {
            rows = callLogDao.listByRun(runId, o, l);
            total = safeLong(callLogDao.countByRun(runId));
        } else if (hasText(threadId)) {
            rows = callLogDao.listByThread(threadId, o, l);
            total = safeLong(callLogDao.countByThread(threadId));
        } else {
            rows = callLogDao.listAll(o, l);
            total = safeLong(callLogDao.countAll());
        }

        List<Map<String, Object>> items = new ArrayList<>();
        if (rows != null) {
            for (ModelCallLog c : rows) items.add(toSummary(c));
        }
        return page(items, total, o, l);
    }

    /**
     * Pages through thread summaries, or looks up a single thread when {@code threadId} is given.
     *
     * <p>In single-thread mode {@code total} is 0 or 1 and the thread is only included on the
     * first page ({@code offset == 0}).
     *
     * @param threadId optional exact thread id
     * @param offset   zero-based offset; {@code null} or negative means 0
     * @param limit    page size; {@code null} uses the configured default, otherwise clamped
     * @return map with keys {@code items}, {@code total}, {@code offset}, {@code limit}
     */
    @Service(url = "/ai/api/admin/calls/threads", method = HttpMethod.all)
    public Map<String, Object> threads(String threadId, Integer offset, Integer limit) throws Exception {
        AdminGuard.requireAdmin();
        int o = clampOffset(offset);
        int l = clampLimit(limit);

        List<Map<String, Object>> items = new ArrayList<>();
        long total;
        if (hasText(threadId)) {
            ThreadRoom thread = threadDao.getThread(threadId);
            total = thread == null ? 0 : 1;
            // l > 0 guards against a non-positive configured default page size.
            if (thread != null && o == 0 && l > 0) items.add(toThreadSummary(thread));
        } else {
            List<ThreadRoom> rows = threadDao.listAll(o, l);
            total = safeLong(threadDao.countAll());
            if (rows != null) {
                for (ThreadRoom t : rows) items.add(toThreadSummary(t));
            }
        }
        return page(items, total, o, l);
    }

    /**
     * Returns the full metadata of one model call (without the trace body).
     *
     * @param callId call id taken from the URL path
     * @return summary fields plus org/agent/downgrade/error details and {@code traceRef}
     * @throws LobsterException {@code admin.call.not_found} when no such call exists
     */
    @Service(url = "/ai/api/admin/calls/{$0}", method = HttpMethod.all)
    public Map<String, Object> get(String callId) throws Exception {
        AdminGuard.requireAdmin();
        ModelCallLog c = callLogDao.getCall(callId);
        if (c == null) throw new LobsterException("admin.call.not_found", "Call not found: " + callId);
        return toFull(c);
    }

    /**
     * Reads the full trace JSON (messages / tools / response / error stack) of one call.
     *
     * <p>Returns a wrapper map holding the parsed trace object so the frontend can render it
     * directly. If the JSON cannot be parsed, the raw string is returned and
     * {@code parseFailed=true} is set. MCP tool calls recorded for the same run are attached
     * under {@code mcpCalls}.
     *
     * @param callId call id taken from the URL path
     * @return map with {@code callId}, {@code traceRef}, {@code trace}, optional
     *         {@code parseFailed}, and {@code mcpCalls}
     * @throws LobsterException {@code admin.call.not_found}, {@code admin.call.no_trace},
     *         {@code admin.call.trace_store_unavailable} or {@code admin.call.trace_missing}
     */
    @Service(url = "/ai/api/admin/calls/{$0}/trace", method = HttpMethod.all)
    public Map<String, Object> trace(String callId) throws Exception {
        AdminGuard.requireAdmin();
        ModelCallLog c = callLogDao.getCall(callId);
        if (c == null) throw new LobsterException("admin.call.not_found", "Call not found: " + callId);
        String traceRef = c.getTraceRef();
        if (!hasText(traceRef)) {
            // Troubleshooting hints for a missing traceRef:
            //  1) the row predates the trace feature (check createTime);
            //  2) the table lacks the TRACEREF column (thunwind does not necessarily ALTER an
            //     existing table; run `DESC AI_MODEL_CALL_LOG` and ALTER or DROP/recreate it);
            //  3) ContentStore was not injected (grep catalina.out for "[LlmRuntime] trace skipped");
            //  4) writing the trace threw and was swallowed (grep "[LlmRuntime] trace write failed").
            throw new LobsterException("admin.call.no_trace",
                    "callId " + callId + " 未记录 trace。可能原因："
                            + " ①该调用发生于 trace 特性上线前；"
                            + " ②AI_MODEL_CALL_LOG 表缺 TRACEREF 列（需 ALTER TABLE 或 DROP 重建）；"
                            + " ③ContentStore bean 未注入（检查 nest.xml + 重启 Tomcat）；"
                            + " ④写 trace 异常（查 catalina.out 搜 LlmRuntime）。"
                            + " createTime=" + c.getCreateTime());
        }
        // Injection can fail (see hint 3 above); report it instead of throwing a bare NPE.
        if (contentStore == null) {
            throw new LobsterException("admin.call.trace_store_unavailable",
                    "ContentStore not injected; cannot read trace: " + traceRef);
        }
        String json = contentStore.read(traceRef);
        if (json == null) {
            throw new LobsterException("admin.call.trace_missing",
                    "Trace ref exists but content not readable: " + traceRef);
        }

        Map<String, Object> out = new LinkedHashMap<>();
        out.put("callId", c.getCallId());
        out.put("traceRef", traceRef);
        // Return the parsed object; fall back to the raw string if parsing fails.
        try {
            Object parsed = JsonUtil.fromJson(json, Object.class);
            out.put("trace", parsed);
        } catch (Throwable ignored) {
            // Best-effort rendering: any parse failure (including pathological nesting) should
            // still let the admin see the raw trace text.
            out.put("trace", json);
            out.put("parseFailed", true);
        }
        out.put("mcpCalls", mcpCallsForRun(c.getRunId()));
        return out;
    }

    /** Builds the list-view projection of a model call log row. */
    private static Map<String, Object> toSummary(ModelCallLog c) {
        Map<String, Object> m = new LinkedHashMap<>();
        m.put("callId", c.getCallId());
        m.put("runId", c.getRunId());
        m.put("threadId", c.getThreadId());
        m.put("userId", c.getUserId());
        m.put("modelId", c.getModelId());
        m.put("protocol", c.getProtocol());
        m.put("status", c.getStatus());
        m.put("downgraded", c.getDowngraded());
        m.put("inputTokens", c.getInputTokens());
        m.put("outputTokens", c.getOutputTokens());
        m.put("durationMs", c.getDurationMs());
        m.put("hasTrace", hasText(c.getTraceRef()));
        m.put("createTime", c.getCreateTime());
        return m;
    }

    /** Builds the detail-view projection: the summary plus org/agent/error fields and traceRef. */
    private static Map<String, Object> toFull(ModelCallLog c) {
        Map<String, Object> m = toSummary(c);
        m.put("orgId", c.getOrgId());
        m.put("agentId", c.getAgentId());
        m.put("downgradeReason", c.getDowngradeReason());
        m.put("errorMessage", c.getErrorMessage());
        m.put("traceRef", c.getTraceRef());
        return m;
    }

    /** Builds the list-view projection of a thread. */
    private static Map<String, Object> toThreadSummary(ThreadRoom t) {
        Map<String, Object> m = new LinkedHashMap<>();
        m.put("threadId", t.getThreadId());
        m.put("title", t.getTitle());
        m.put("userId", t.getUserId());
        m.put("orgId", t.getOrgId());
        m.put("type", t.getType());
        m.put("status", t.getStatus());
        m.put("workspaceId", t.getWorkspaceId());
        m.put("lastActivityAt", t.getLastActivityAt());
        m.put("createTime", t.getCreateTime());
        m.put("updateTime", t.getUpdateTime());
        return m;
    }

    /**
     * Converts a possibly-null DAO count to a primitive {@code long}, treating {@code null} as 0.
     * Accepts any {@link Number} so both boxed and primitive int/long counts can be passed.
     */
    private static long safeLong(Number v) {
        return v == null ? 0L : v.longValue();
    }

    /** Returns true when {@code s} is non-null and non-empty. */
    private static boolean hasText(String s) {
        return s != null && !s.isEmpty();
    }

    /** Normalizes a requested offset: {@code null} or negative becomes 0. */
    private static int clampOffset(Integer offset) {
        return offset == null ? 0 : Math.max(0, offset);
    }

    /**
     * Normalizes a requested page size: {@code null} uses the configured default; otherwise the
     * value is clamped to {@code [1, LobsterConfig.getListMaxPageSize()]}.
     */
    private static int clampLimit(Integer limit) {
        return limit == null
                ? LobsterConfig.getListDefaultPageSize()
                : Math.min(LobsterConfig.getListMaxPageSize(), Math.max(1, limit));
    }

    /** Assembles the common paged-response envelope shared by the list endpoints. */
    private static Map<String, Object> page(List<Map<String, Object>> items, long total,
                                            int offset, int limit) {
        Map<String, Object> out = new LinkedHashMap<>();
        out.put("items", items);
        out.put("total", total);
        out.put("offset", offset);
        out.put("limit", limit);
        return out;
    }

    /**
     * Collects MCP tool-call summaries recorded for {@code runId}.
     *
     * <p>Best-effort: returns an empty list when {@code runId} is empty or the DAO is not
     * injected, and a single {@code {"error": "mcp_call_log_unavailable"}} entry if the lookup
     * fails, so the trace endpoint still succeeds.
     */
    private List<Map<String, Object>> mcpCallsForRun(String runId) {
        List<Map<String, Object>> out = new ArrayList<>();
        if (!hasText(runId) || mcpCallLogDao == null) return out;
        try {
            List<McpCallLog> rows = mcpCallLogDao.listByRun(runId);
            if (rows != null) {
                for (McpCallLog row : rows) out.add(toMcpCallTraceSummary(row));
            }
        } catch (Throwable ignored) {
            // Deliberately degraded: surface a marker entry instead of failing the whole trace.
            Map<String, Object> err = new LinkedHashMap<>();
            err.put("error", "mcp_call_log_unavailable");
            out.add(err);
        }
        return out;
    }

    /** Projects one MCP call log row; request/response summaries are parsed as JSON when possible. */
    private static Map<String, Object> toMcpCallTraceSummary(McpCallLog row) {
        Map<String, Object> m = new LinkedHashMap<>();
        m.put("callId", row.getCallId());
        m.put("serverId", row.getServerId());
        m.put("localToolName", row.getLocalToolName());
        m.put("remoteToolName", row.getRemoteToolName());
        m.put("requestId", row.getRequestId());
        m.put("threadId", row.getThreadId());
        m.put("runId", row.getRunId());
        m.put("toolCallId", row.getToolCallId());
        m.put("status", row.getStatus());
        m.put("durationMs", row.getDurationMs());
        m.put("requestSummary", parseMaybeJson(row.getRequestSummary()));
        m.put("responseSummary", parseMaybeJson(row.getResponseSummary()));
        m.put("requestJsonRef", row.getRequestJsonRef());
        m.put("responseJsonRef", row.getResponseJsonRef());
        m.put("errorCode", row.getErrorCode());
        m.put("errorMessage", row.getErrorMessage());
        m.put("createTime", row.getCreateTime());
        return m;
    }

    /**
     * Parses {@code raw} as JSON if possible; returns {@code null} for null/empty input and the
     * raw string unchanged when parsing fails.
     */
    private static Object parseMaybeJson(String raw) {
        if (!hasText(raw)) return null;
        try {
            return JsonUtil.fromJson(raw, Object.class);
        } catch (Throwable ignored) {
            // Summaries may be truncated or non-JSON text; show them verbatim.
            return raw;
        }
    }
}
