package com.gzzm.lobster.llm.adapter;

import com.gzzm.lobster.common.JsonUtil;
import com.gzzm.lobster.common.LobsterException;
import com.gzzm.lobster.common.MessageRole;
import com.gzzm.lobster.common.TokenEstimator;
import com.gzzm.lobster.llm.*;

import java.io.BufferedReader;
import java.io.DataOutputStream;
import java.io.InputStreamReader;
import java.net.HttpURLConnection;
import java.net.URL;
import java.nio.charset.StandardCharsets;
import java.util.*;

/**
 * OllamaAdapter —— Ollama /api/chat 协议适配 / Adapter for Ollama /api/chat.
 *
 * <p>兜底策略（设计文档 §7.2）：
 * <ul>
 * <li>检测到 tools 参数时，自动降级 {@code stream:false}（Ollama 流式工具调用不稳定）</li>
 * <li>无 tool_choice 时通过 system prompt 引导</li>
 * <li>NDJSON 解析统一转换为内部 delta 事件</li>
 * </ul>
 *
 * <p>V2 改造（2026-04）：原实现把整个 NDJSON 响应体全量读完再 {@code split("\n")}，
 * 导致流式/取消形同虚设 —— {@code handler.isCancelled()} 只能在所有帧都已到达
 * 之后才被检查，cancel 根本进不了 HTTP 循环。本版改为 {@link BufferedReader#readLine()}
 * 边读边解析，每拿到一行立刻 check cancel，cancel 命中立刻 {@code conn.disconnect()}。
 */
public class OllamaAdapter implements LobsterLlmAdapter {

    // Model configuration: endpoint, model id, timeouts, tool/thinking capability flags.
    private final ModelProfile profile;

    /**
     * @param profile model profile supplying the Ollama endpoint, model id, timeouts
     *                and capability flags; stored as-is (not defensively copied).
     */
    public OllamaAdapter(ModelProfile profile) {
        this.profile = profile;
    }

    /** Returns the model profile this adapter was constructed with. */
    @Override
    public ModelProfile profile() {
        return profile;
    }

    /**
     * Synchronous chat: POST a non-streaming ({@code stream:false}) payload and
     * parse the aggregated response body into an {@link LlmResponse}.
     */
    @Override
    public LlmResponse chat(List<LobsterMessage> messages, List<ToolSpec> tools) {
        return parseAggregated(doHttpBuffered(buildPayload(messages, tools, false)));
    }

    /**
     * Streaming chat entry point.
     *
     * <p>When tools are present, Ollama's streamed tool calling is unreliable
     * (design doc §7.2), so this degrades to a synchronous {@link #chat} call and
     * replays the answer to the handler in ~20 pseudo-delta slices, checking
     * {@code isCancelled()} between slices. Without tools it streams NDJSON for
     * real via {@link #streamNdjson}.
     *
     * <p>Fix vs. previous revision: the fallback path now guards against a
     * {@code null} assistant text or tool-call list, so a tool-only or empty
     * completion cannot NPE before reaching the handler.
     */
    @Override
    public void chatStream(List<LobsterMessage> messages, List<ToolSpec> tools,
                           StreamingResponseHandler handler) {
        // Tools present: Ollama streaming tool-calls are unstable -> sync call + sliced replay.
        if (tools != null && !tools.isEmpty()) {
            try {
                LlmResponse r = chat(messages, tools);
                String text = r.getAssistantText() == null ? "" : r.getAssistantText();
                int step = Math.max(1, text.length() / 20);
                for (int i = 0; i < text.length(); i += step) {
                    if (handler.isCancelled()) {
                        handler.onError(new LobsterException("llm.cancelled", "cancelled by upper layer"));
                        return;
                    }
                    handler.onDelta(text.substring(i, Math.min(text.length(), i + step)));
                }
                if (r.getToolCalls() != null) {
                    for (ToolCall tc : r.getToolCalls()) handler.onToolCall(tc);
                }
                handler.onComplete(r);
            } catch (Throwable t) {
                handler.onError(t);
            }
            return;
        }

        // No tools: genuine line-by-line NDJSON streaming.
        streamNdjson(messages, tools, handler);
    }

    /**
     * Line-by-line NDJSON streaming (no-tools path).
     *
     * <p>Each line of the response is parsed as one JSON frame. Cancellation is
     * checked once per line; when it fires, the connection is {@code disconnect()}ed
     * immediately instead of draining the response, so (vs. the old
     * "read-everything-then-split" version) cancel interrupts within roughly one
     * frame's latency rather than after the HTTP exchange finishes.
     *
     * <p>Terminal-callback contract: exactly one of {@code onComplete} (done frame
     * seen, or stream ended early with finishReason "incomplete") or {@code onError}
     * (non-2xx HTTP, I/O failure, or cancel) is invoked.
     */
    private void streamNdjson(List<LobsterMessage> messages, List<ToolSpec> tools,
                              StreamingResponseHandler handler) {
        Map<String, Object> payload = buildPayload(messages, tools, true);
        HttpURLConnection conn = null;
        StringBuilder fullText = new StringBuilder();
        boolean cancelledByUpper = false;
        boolean completed = false;
        int inTokens = 0, outTokens = 0;
        try {
            conn = openConn();
            writeBody(conn, payload);

            int code = conn.getResponseCode();
            if (code < 200 || code >= 300) {
                String err = readAll(conn.getErrorStream());
                handler.onError(new LobsterException("llm.http", "HTTP " + code + " -> " + err));
                return;
            }

            try (BufferedReader br = new BufferedReader(new InputStreamReader(
                    conn.getInputStream(), StandardCharsets.UTF_8))) {
                String line;
                while ((line = br.readLine()) != null) {
                    // Per-line cancel check; disconnect right away so no further frames are read.
                    if (handler.isCancelled()) {
                        cancelledByUpper = true;
                        try { conn.disconnect(); } catch (Throwable ignore) { /* ignore */ }
                        break;
                    }
                    if (line.isEmpty()) continue;

                    Map<String, Object> frame;
                    try { frame = JsonUtil.fromJsonToMap(line); }
                    catch (Throwable t) { continue; /* tolerate malformed ndjson frame */ }

                    // Delta text rides on frame.message.content.
                    Object msg = frame.get("message");
                    if (msg instanceof Map) {
                        Object content = ((Map<?, ?>) msg).get("content");
                        if (content != null) {
                            String delta = String.valueOf(content);
                            if (!delta.isEmpty()) {
                                fullText.append(delta);
                                handler.onDelta(delta);
                            }
                        }
                    }

                    // Terminal frame: token counts are only present on the done frame.
                    if (Boolean.TRUE.equals(frame.get("done"))) {
                        inTokens = asInt(frame.get("prompt_eval_count"));
                        outTokens = asInt(frame.get("eval_count"));
                        completed = true;
                        break;
                    }
                }
            }
        } catch (Throwable t) {
            // I/O errors racing a cancel-triggered disconnect are expected: suppress
            // onError so the single cancelled-callback below fires instead.
            if (!cancelledByUpper) handler.onError(t);
            return;
        } finally {
            if (conn != null) try { conn.disconnect(); } catch (Throwable ignore) { /* ignore */ }
        }

        if (cancelledByUpper) {
            handler.onError(new LobsterException("llm.cancelled", "cancelled by upper layer"));
            return;
        }
        String finalText = fullText.toString();
        // NOTE(review): the input-token fallback estimates from the OUTPUT text (the
        // serialized prompt is not at hand here) -- a rough proxy; confirm this is
        // acceptable for accounting before relying on it.
        if (inTokens <= 0) inTokens = TokenEstimator.estimate(finalText);
        if (outTokens <= 0) outTokens = TokenEstimator.estimate(finalText);
        LlmResponse r = new LlmResponse(finalText, Collections.<ToolCall>emptyList(),
                inTokens, outTokens, completed ? "stop" : "incomplete",
                profile.getModelId(), finalText);
        handler.onComplete(r);
    }

    // ---- private helpers ----

    /**
     * Assemble the /api/chat request body.
     *
     * <p>Tools are attached only when the profile declares native tool calling.
     * Thinking mode maps three ways (Ollama 0.5+ top-level {@code think} flag,
     * used by deepseek-r1 / qwen-qwq style models):
     * on -> {@code think:true}, off -> {@code think:false}, auto -> field omitted.
     */
    private Map<String, Object> buildPayload(List<LobsterMessage> messages, List<ToolSpec> tools, boolean stream) {
        Map<String, Object> body = new LinkedHashMap<>();
        body.put("model", profile.getModelId());
        body.put("messages", toOllamaMessages(messages));
        body.put("stream", stream);

        boolean hasTools = tools != null && !tools.isEmpty();
        if (hasTools && Boolean.TRUE.equals(profile.getNativeToolCalling())) {
            body.put("tools", toOllamaTools(tools));
        }

        ModelThinkingMode mode = profile.resolveThinkingMode();
        if (mode == ModelThinkingMode.on || mode == ModelThinkingMode.off) {
            body.put("think", mode == ModelThinkingMode.on);
        }
        return body;
    }

    /** Convert internal messages to Ollama /api/chat message objects. */
    private List<Map<String, Object>> toOllamaMessages(List<LobsterMessage> messages) {
        List<Map<String, Object>> out = new ArrayList<>(messages.size());
        for (LobsterMessage src : messages) {
            Map<String, Object> entry = new LinkedHashMap<>();
            // Ollama role names mirror OpenAI's.
            entry.put("role", src.getRole().name());
            entry.put("content", src.getContent() == null ? "" : src.getContent());
            if (src.getRole() == MessageRole.tool) {
                entry.put("name", src.getToolName());
            } else if (src.getImageUrls() != null && !src.getImageUrls().isEmpty()) {
                // Multimodal: Ollama /api/chat takes an "images" field of bare base64
                // strings (no data: URL prefix) on user messages -- this is how LLaVA /
                // Qwen-VL style local models receive images.
                // Protocol: https://github.com/ollama/ollama/blob/main/docs/api.md#generate-a-chat-completion
                List<String> base64s = new ArrayList<>(src.getImageUrls().size());
                for (String u : src.getImageUrls()) {
                    if (u == null || u.isEmpty()) continue;
                    // Accept both data:image/...;base64,XXX and bare base64; keep payload only.
                    int comma = u.startsWith("data:") ? u.indexOf(',') : -1;
                    base64s.add(comma > 0 ? u.substring(comma + 1) : u);
                }
                if (!base64s.isEmpty()) entry.put("images", base64s);
            }
            out.add(entry);
        }
        return out;
    }

    /** Wrap tool specs in Ollama's OpenAI-compatible {"type":"function","function":{...}} shape. */
    private List<Map<String, Object>> toOllamaTools(List<ToolSpec> tools) {
        List<Map<String, Object>> wrapped = new ArrayList<>(tools.size());
        for (ToolSpec tool : tools) {
            Map<String, Object> fn = new LinkedHashMap<>();
            fn.put("name", tool.getName());
            fn.put("description", tool.getDescription());
            fn.put("parameters", tool.getParametersSchema());

            Map<String, Object> wrapper = new LinkedHashMap<>();
            wrapper.put("type", "function");
            wrapper.put("function", fn);
            wrapped.add(wrapper);
        }
        return wrapped;
    }

    /**
     * Open a POST connection to {@code <endpoint>/api/chat} with a JSON content type.
     *
     * <p>Read timeout is capped at 60s: once cancel is set, {@code disconnect()} is
     * not guaranteed to wake a blocked {@code readLine()} on Windows, so the read
     * timeout bounds the worst-case wake-up interval.
     */
    private HttpURLConnection openConn() throws Exception {
        String endpoint = profile.getEndpoint();
        if (!endpoint.endsWith("/api/chat")) {
            endpoint = endpoint + (endpoint.endsWith("/") ? "api/chat" : "/api/chat");
        }
        HttpURLConnection conn = (HttpURLConnection) new URL(endpoint).openConnection();

        Integer connectTo = profile.getFirstTokenTimeoutMs();
        conn.setConnectTimeout(connectTo == null ? 10000 : connectTo);

        Integer totalTo = profile.getTotalTimeoutMs();
        conn.setReadTimeout(Math.min(totalTo == null ? 120000 : totalTo, 60000));

        conn.setDoOutput(true);
        conn.setRequestMethod("POST");
        conn.setRequestProperty("Content-Type", "application/json; charset=utf-8");
        return conn;
    }

    /** Serialize the payload as UTF-8 JSON and write it to the request body. */
    private void writeBody(HttpURLConnection conn, Map<String, Object> payload) throws Exception {
        byte[] body = JsonUtil.toJson(payload).getBytes(StandardCharsets.UTF_8);
        try (java.io.OutputStream os = conn.getOutputStream()) {
            os.write(body);
        }
    }

    /**
     * Buffered (non-streaming) HTTP round-trip for {@code chat()}: returns the full
     * response body on 2xx, otherwise throws {@code LobsterException("llm.http", ...)}.
     */
    private String doHttpBuffered(Map<String, Object> payload) {
        HttpURLConnection conn = null;
        try {
            conn = openConn();
            writeBody(conn, payload);
            int status = conn.getResponseCode();
            boolean ok = status >= 200 && status < 300;
            String body = readAll(ok ? conn.getInputStream() : conn.getErrorStream());
            if (!ok) {
                throw new LobsterException("llm.http", "HTTP " + status + " -> " + body);
            }
            return body;
        } catch (LobsterException e) {
            throw e; // already classified -- don't double-wrap
        } catch (Exception e) {
            throw new LobsterException("llm.http", "Ollama call failed: " + e.getMessage(), e);
        } finally {
            if (conn != null) conn.disconnect();
        }
    }

    /**
     * Drain a stream as UTF-8 text, normalizing line endings to {@code '\n'}
     * (a trailing newline is appended after the last line too).
     * Returns "" for a null stream; on I/O failure, returns whatever was read so far.
     */
    private static String readAll(java.io.InputStream in) {
        if (in == null) return "";
        StringBuilder out = new StringBuilder();
        try (BufferedReader reader = new BufferedReader(new InputStreamReader(in, StandardCharsets.UTF_8))) {
            for (String line = reader.readLine(); line != null; line = reader.readLine()) {
                out.append(line).append('\n');
            }
        } catch (Throwable ignored) {
            // best-effort: a partial body is still useful for error messages
        }
        return out.toString();
    }

    /**
     * Parse a non-streaming Ollama response.
     *
     * <p>Normally a single JSON object {@code {message:{...}, done:true, ...}}; if the
     * server returned multiple NDJSON frames anyway, the last non-empty line wins.
     *
     * <p>Fixes vs. previous revision: the finish reason now honours Ollama's
     * {@code done_reason} field (e.g. "length") instead of hardcoding "stop", and a
     * {@code tool_calls} entry without a {@code function} object is skipped instead
     * of throwing NPE.
     */
    @SuppressWarnings("unchecked")
    private LlmResponse parseAggregated(String raw) {
        String lastLine = null;
        for (String line : raw.split("\n")) {
            if (!line.isEmpty()) lastLine = line;
        }
        if (lastLine == null) {
            return new LlmResponse("", Collections.<ToolCall>emptyList(), 0, 0, "empty", profile.getModelId(), raw);
        }
        Map<String, Object> json = JsonUtil.fromJsonToMap(lastLine);
        Object msgObj = json.get("message");
        String content = "";
        List<ToolCall> toolCalls = new ArrayList<>();
        if (msgObj instanceof Map) {
            Map<String, Object> msg = (Map<String, Object>) msgObj;
            content = msg.get("content") == null ? "" : String.valueOf(msg.get("content"));
            Object tc = msg.get("tool_calls");
            if (tc instanceof List) {
                int i = 0;
                for (Object o : (List<Object>) tc) {
                    Map<String, Object> call = (Map<String, Object>) o;
                    Map<String, Object> fn = (Map<String, Object>) call.get("function");
                    if (fn == null) continue; // malformed entry: no "function" object
                    String name = String.valueOf(fn.get("name"));
                    Object args = fn.get("arguments");
                    // Arguments may arrive as a JSON string or a structured map; normalize to a string.
                    String argStr = args == null ? "{}" : (args instanceof String ? (String) args : JsonUtil.toJson(args));
                    // Ollama usually omits ids; synthesize a stable per-response fallback.
                    String id = call.get("id") == null ? ("ol_" + (i++)) : String.valueOf(call.get("id"));
                    toolCalls.add(new ToolCall(id, name, argStr));
                }
            }
        }
        int inputTokens = asInt(json.get("prompt_eval_count"));
        int outputTokens = asInt(json.get("eval_count"));
        // Counts of 0 are treated as "server didn't report" -> estimate.
        if (inputTokens == 0) inputTokens = TokenEstimator.estimate(raw);
        if (outputTokens == 0) outputTokens = TokenEstimator.estimate(content);
        // Honour Ollama's done_reason when present; "stop" remains the default.
        Object doneReason = json.get("done_reason");
        String finish = (doneReason == null || String.valueOf(doneReason).isEmpty())
                ? "stop" : String.valueOf(doneReason);
        return new LlmResponse(content, toolCalls, inputTokens, outputTokens, finish, profile.getModelId(), raw);
    }

    /**
     * Best-effort int coercion for token-count fields.
     *
     * <p>Accepts any {@link Number} and numeric strings. Generalized vs. the previous
     * revision: surrounding whitespace is tolerated and decimal strings ("42.0", as
     * produced by some JSON mappers) fall back to a double parse. Anything else maps
     * to 0 so callers can fall through to {@code TokenEstimator}. Made static: it
     * reads no instance state (existing unqualified call sites are unaffected).
     */
    private static int asInt(Object o) {
        if (o == null) return 0;
        if (o instanceof Number) return ((Number) o).intValue();
        String s = String.valueOf(o).trim();
        try {
            return Integer.parseInt(s);
        } catch (NumberFormatException e) {
            try {
                return (int) Double.parseDouble(s); // tolerate "123.0"-style strings
            } catch (NumberFormatException e2) {
                return 0;
            }
        }
    }
}
