package com.gzzm.lobster.llm.stream;

import com.gzzm.lobster.llm.LlmResponse;
import com.gzzm.lobster.llm.ToolCall;

import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

/**
 * LlmStreamAggregator — aggregates a sequence of {@link LlmStreamEvent}s into a single
 * {@link LlmResponse}.
 *
 * <p>Handles, uniformly for every adapter:
 * <ul>
 *   <li>text concatenation;</li>
 *   <li>merging tool-call name / arguments deltas into one call per stream index — fragments
 *       that omit the tool-call id are attached to the call most recently seen at the same
 *       index;</li>
 *   <li>token usage — the latest positive prompt / completion count wins, so it works whether
 *       usage arrives on the final stream frame or in a dedicated frame;</li>
 *   <li>finish reason;</li>
 *   <li>error capture — so the adapter can decide whether the failure is retryable.</li>
 * </ul>
 *
 * <p>Adapters should not construct {@code LlmResponse} directly; they feed every event into this
 * class and then call {@link #build()}, so upper layers always receive a semantically consistent
 * {@code LlmResponse} regardless of the backend (OpenAI / Ollama / MCP).
 *
 * <p>Not thread-safe — create a new instance for every {@code chatStream} call.
 */
public final class LlmStreamAggregator {

    private final String modelId;

    private final StringBuilder text = new StringBuilder();

    /** Tool calls in first-seen order; {@link #build()} emits them in this order. */
    private final List<ToolCallBuilder> toolCalls = new ArrayList<>();
    /** Lookup by provider-assigned tool-call id. */
    private final Map<String, ToolCallBuilder> toolCallsById = new LinkedHashMap<>();
    /**
     * Lookup by stream index ({@code "#" + index}) to the call most recently seen at that index.
     * Needed because some providers (e.g. OpenAI chat-completions streaming) send the id only on
     * the first fragment of a call and identify the remaining fragments by index alone.
     */
    private final Map<String, ToolCallBuilder> toolCallsByIndex = new LinkedHashMap<>();

    private String finishReason;
    private int promptTokens = 0;
    private int completionTokens = 0;
    private Throwable error;
    private boolean errorRetryable;

    /**
     * @param modelId model identifier copied into the built {@link LlmResponse}
     */
    public LlmStreamAggregator(String modelId) {
        this.modelId = modelId;
    }

    /**
     * Feeds one event into the aggregate. {@code null} events are ignored.
     *
     * <p>Not idempotent: feeding the same delta event twice appends it twice — callers must not
     * replay events.
     *
     * @param event the next stream event, may be {@code null}
     */
    public void accept(LlmStreamEvent event) {
        if (event == null) return;
        if (event instanceof LlmStreamEvent.TextDelta) {
            LlmStreamEvent.TextDelta e = (LlmStreamEvent.TextDelta) event;
            if (e.delta != null) text.append(e.delta);
        } else if (event instanceof LlmStreamEvent.ToolCallDelta) {
            onToolCallDelta((LlmStreamEvent.ToolCallDelta) event);
        } else if (event instanceof LlmStreamEvent.ToolCallCompleted) {
            onToolCallCompleted((LlmStreamEvent.ToolCallCompleted) event);
        } else if (event instanceof LlmStreamEvent.Usage) {
            LlmStreamEvent.Usage e = (LlmStreamEvent.Usage) event;
            // Latest positive value wins; zero means "not reported in this frame".
            if (e.promptTokens > 0) promptTokens = e.promptTokens;
            if (e.completionTokens > 0) completionTokens = e.completionTokens;
        } else if (event instanceof LlmStreamEvent.Finish) {
            LlmStreamEvent.Finish e = (LlmStreamEvent.Finish) event;
            if (e.finishReason != null) finishReason = e.finishReason;
        } else if (event instanceof LlmStreamEvent.Error) {
            LlmStreamEvent.Error e = (LlmStreamEvent.Error) event;
            // Last error wins.
            error = e.cause;
            errorRetryable = e.retryable;
        }
    }

    /** Merges one tool-call fragment into the call it belongs to. */
    private void onToolCallDelta(LlmStreamEvent.ToolCallDelta e) {
        String indexKey = "#" + e.index;
        ToolCallBuilder b = e.toolCallId == null ? null : toolCallsById.get(e.toolCallId);
        if (b == null) {
            ToolCallBuilder atIndex = toolCallsByIndex.get(indexKey);
            // Continue the call at this index when the fragment carries no id, or when earlier
            // fragments carried none and this is the first one to name it. A *different* id at an
            // occupied index starts a new call (adapters that have no real index).
            if (atIndex != null && (e.toolCallId == null || atIndex.id == null)) {
                b = atIndex;
            } else {
                b = newToolCall();
            }
        }
        if (e.toolCallId != null && b.id == null) {
            b.id = e.toolCallId;
            toolCallsById.put(e.toolCallId, b);
        }
        toolCallsByIndex.put(indexKey, b);
        if (e.nameDelta != null && !e.nameDelta.isEmpty()) {
            b.name = (b.name == null ? "" : b.name) + e.nameDelta;
        }
        if (e.argsDelta != null && !e.argsDelta.isEmpty()) {
            b.args.append(e.argsDelta);
        }
    }

    /**
     * Applies a fully assembled tool call. Its name and arguments replace any streamed fragments
     * for the same id. Fields the event leaves {@code null} keep their streamed values.
     */
    private void onToolCallCompleted(LlmStreamEvent.ToolCallCompleted e) {
        ToolCallBuilder b = e.toolCallId == null ? null : toolCallsById.get(e.toolCallId);
        if (b == null) {
            // Without an id there is nothing to match against, so each such event is its own call.
            b = newToolCall();
            if (e.toolCallId != null) toolCallsById.put(e.toolCallId, b);
        }
        b.id = e.toolCallId;
        if (e.name != null) b.name = e.name;
        if (e.argsJson != null) {
            b.args.setLength(0);
            b.args.append(e.argsJson);
        }
    }

    /** Creates a new, empty tool call and records it in first-seen order. */
    private ToolCallBuilder newToolCall() {
        ToolCallBuilder b = new ToolCallBuilder();
        toolCalls.add(b);
        return b;
    }

    /**
     * Builds the final response from everything accepted so far.
     *
     * <p>Tool calls without a provider id get a synthetic {@code call_N} id that does not clash
     * with any id seen in the stream. A missing name becomes {@code ""} and missing arguments
     * become {@code "{}"}. The finish reason defaults to {@code "stop"}. The buffered text is
     * passed both as the content and as the last constructor argument (presumably the raw
     * response — confirm against {@code LlmResponse}).
     *
     * @return the aggregated response
     */
    public LlmResponse build() {
        List<ToolCall> calls = new ArrayList<>(toolCalls.size());
        int synthetic = 0;
        for (ToolCallBuilder b : toolCalls) {
            String id = b.id;
            if (id == null) {
                do {
                    id = "call_" + (synthetic++);
                } while (toolCallsById.containsKey(id));
            }
            String name = b.name == null ? "" : b.name;
            String args = b.args.length() == 0 ? "{}" : b.args.toString();
            calls.add(new ToolCall(id, name, args));
        }
        String content = text.toString();
        return new LlmResponse(
                content,
                calls,
                promptTokens,
                completionTokens,
                finishReason == null ? "stop" : finishReason,
                modelId,
                content
        );
    }

    /** @return the text concatenated so far */
    public String bufferedText() { return text.toString(); }

    /** @return the cause carried by the last error event, or {@code null} if none was seen */
    public Throwable getError() { return error; }

    /** @return the retryable flag carried by the last error event ({@code false} if none) */
    public boolean isErrorRetryable() { return errorRetryable; }

    /** Mutable accumulator for one tool call. */
    private static final class ToolCallBuilder {
        String id;
        String name;
        final StringBuilder args = new StringBuilder();
    }
}
