package com.gzzm.lobster.context;

import com.gzzm.lobster.common.MessageRole;
import com.gzzm.lobster.common.TokenEstimator;
import com.gzzm.lobster.llm.LobsterMessage;
import com.gzzm.lobster.llm.ModelRouteResult;
import com.gzzm.lobster.llm.ToolCall;
import com.gzzm.lobster.thread.ThreadRoom;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Set;

/**
 * Send-view context compaction policy.
 *
 * <p>Tool results are preserved exactly while the transcript fits the model budget.
 * Width limiting and ref markers are only introduced when compaction is needed.
 * Compaction runs in two stages:
 * <ol>
 *   <li>tool results (except those of the width-exempt tools) are width-limited through the
 *       configured {@link ToolResultWidthPolicy};</li>
 *   <li>if that is still over budget, every message before the most recent
 *       {@code keepRecentTurns} user turns is folded into a single system summary, except for
 *       the most recent exempt tool exchanges, which are carried over verbatim.</li>
 * </ol>
 * The result of the second stage is not re-checked against the budget.
 */
public final class ContextCompactionPolicy {

    /**
     * Tool results that should not be width-limited even during compaction.
     */
    private static final Set<String> WIDTH_EXEMPT_TOOLS = Collections.unmodifiableSet(
            new LinkedHashSet<String>(Arrays.asList("use_skill", "oa_search_knowledge")));

    private static final int EXEMPT_KEEP_RECENT_DEFAULT = 2;

    /** Applied to non-exempt tool results in the first compaction stage. */
    private final ToolResultWidthPolicy widthPolicy;
    /** Number of most recent user turns that are never collapsed into the summary. */
    private final int keepRecentTurns;
    /** How many of the most recent exempt tool results before the cut survive verbatim (>= 0). */
    private final int exemptKeepRecent;
    /** Primary summarizer; may be null, in which case {@code BulletSummarizer} is used. */
    private final SummarizerService summarizer;
    /** Passed through to the summarizer; may be null. */
    private final ThreadRoom thread;
    /** Passed through to the summarizer; may be null. */
    private final ModelRouteResult route;

    /**
     * Creates a policy without a primary summarizer (the bullet-summary fallback is always used)
     * that keeps the default number of recent exempt tool results.
     *
     * @param widthPolicy     width limiter for tool results
     * @param keepRecentTurns number of most recent user turns left untouched when collapsing
     */
    public ContextCompactionPolicy(ToolResultWidthPolicy widthPolicy, int keepRecentTurns) {
        this(widthPolicy, keepRecentTurns, EXEMPT_KEEP_RECENT_DEFAULT, null, null, null);
    }

    /**
     * Creates a fully configured policy.
     *
     * @param widthPolicy      width limiter for tool results
     * @param keepRecentTurns  number of most recent user turns left untouched when collapsing
     * @param exemptKeepRecent number of most recent exempt tool results kept verbatim when
     *                         collapsing; negative values are treated as 0
     * @param summarizer       primary summarizer, or null to always use the bullet fallback
     * @param thread           thread handed to the summarizer; may be null
     * @param route            model route handed to the summarizer; may be null
     */
    public ContextCompactionPolicy(ToolResultWidthPolicy widthPolicy, int keepRecentTurns,
                                   int exemptKeepRecent,
                                   SummarizerService summarizer, ThreadRoom thread, ModelRouteResult route) {
        this.widthPolicy = widthPolicy;
        this.keepRecentTurns = keepRecentTurns;
        this.exemptKeepRecent = Math.max(0, exemptKeepRecent);
        this.summarizer = summarizer;
        this.thread = thread;
        this.route = route;
    }

    /**
     * Projects the stored transcript into the view that is sent to the model.
     *
     * @param transcript   full transcript; may be null
     * @param budgetTokens token budget for the projected view
     * @return a new empty list for null/empty input; the very same {@code transcript} instance
     *         when it already fits; otherwise a new, compacted list (which may still exceed the
     *         budget if collapsing was not possible or not sufficient)
     */
    public List<LobsterMessage> projectTranscript(List<LobsterMessage> transcript, int budgetTokens) {
        if (transcript == null || transcript.isEmpty()) return new ArrayList<>();

        if (estimateTotal(transcript) <= budgetTokens) return transcript;

        List<LobsterMessage> widthLimited = limitToolResultWidth(transcript);
        if (estimateTotal(widthLimited) <= budgetTokens) return widthLimited;

        return collapseOldMessages(widthLimited);
    }

    /**
     * Estimates the token count of the given messages. Only message content is counted; tool
     * call metadata on assistant messages is not included.
     */
    private int estimateTotal(List<LobsterMessage> messages) {
        int total = 0;
        for (LobsterMessage m : messages) total += TokenEstimator.estimate(m.getContent());
        return total;
    }

    /**
     * Returns a copy of the transcript in which every non-exempt tool result with content is
     * passed through the width policy. A full-content ref, when present, is carried over to
     * the rebuilt message. All other messages are reused as-is.
     */
    private List<LobsterMessage> limitToolResultWidth(List<LobsterMessage> transcript) {
        List<LobsterMessage> widthLimited = new ArrayList<>(transcript.size());
        for (LobsterMessage m : transcript) {
            if (m.getRole() == MessageRole.tool && m.getContent() != null) {
                if (WIDTH_EXEMPT_TOOLS.contains(m.getToolName())) {
                    widthLimited.add(m);
                } else {
                    String limited = widthPolicy.limit(m.getContent(), m.getFullContentRef());
                    if (m.hasFullContentRef()) {
                        widthLimited.add(LobsterMessage.toolWithRef(m.getToolCallId(), m.getToolName(),
                                limited, m.getFullContentRef()));
                    } else {
                        widthLimited.add(LobsterMessage.tool(m.getToolCallId(), m.getToolName(), limited));
                    }
                }
            } else {
                widthLimited.add(m);
            }
        }
        return widthLimited;
    }

    /**
     * Replaces everything before the most recent {@code keepRecentTurns} user turns with one
     * system summary message. The most recent exempt tool exchanges from that prefix follow the
     * summary verbatim, and then the untouched recent turns.
     *
     * @return {@code msgs} unchanged when there is no prefix to collapse
     */
    private List<LobsterMessage> collapseOldMessages(List<LobsterMessage> msgs) {
        int cutIndex = findCutIndex(msgs);
        if (cutIndex <= 0) return msgs;

        List<LobsterMessage> prefix = msgs.subList(0, cutIndex);
        boolean[] keep = markExemptExchanges(prefix);

        List<LobsterMessage> exemptKept = new ArrayList<>();
        List<LobsterMessage> toBeSummarized = new ArrayList<>();
        for (int i = 0; i < prefix.size(); i++) {
            if (keep[i]) {
                exemptKept.add(prefix.get(i));
            } else {
                toBeSummarized.add(prefix.get(i));
            }
        }

        List<LobsterMessage> out = new ArrayList<>();
        // Nothing left to fold (the whole prefix was kept verbatim): skip an empty summary.
        if (!toBeSummarized.isEmpty()) {
            out.add(LobsterMessage.system(summarize(toBeSummarized)));
        }
        out.addAll(exemptKept);
        out.addAll(msgs.subList(cutIndex, msgs.size()));
        return out;
    }

    /**
     * Collects the tool-call ids of the {@code exemptKeepRecent} most recent results of
     * width-exempt tools in {@code prefix}, scanning backwards from its end.
     */
    private Set<String> selectRecentExemptCallIds(List<LobsterMessage> prefix) {
        Set<String> ids = new LinkedHashSet<>();
        for (int i = prefix.size() - 1; i >= 0 && ids.size() < exemptKeepRecent; i--) {
            LobsterMessage m = prefix.get(i);
            if (m.getRole() == MessageRole.tool
                    && WIDTH_EXEMPT_TOOLS.contains(m.getToolName())
                    && m.getToolCallId() != null) {
                ids.add(m.getToolCallId());
            }
        }
        return ids;
    }

    /**
     * Marks which messages of {@code prefix} belong to a recent exempt tool exchange and must be
     * carried over verbatim instead of being summarized.
     *
     * <p>Exchanges are kept whole. An assistant message is kept when it issued one of the
     * selected exempt calls. In that case the results of all of its tool calls are kept too, not
     * only the exempt one, so the kept assistant message is not left with unanswered tool calls.
     * A tool result is kept only when its issuing assistant message is kept, which avoids
     * orphaned tool messages. Both situations produce a send view that chat APIs reject.
     *
     * @return one flag per prefix message; {@code true} means "keep verbatim"
     */
    private boolean[] markExemptExchanges(List<LobsterMessage> prefix) {
        boolean[] keep = new boolean[prefix.size()];
        Set<String> selectedIds = selectRecentExemptCallIds(prefix);
        if (selectedIds.isEmpty()) return keep;

        Set<String> keptCallIds = new LinkedHashSet<>();
        for (int i = 0; i < prefix.size(); i++) {
            LobsterMessage m = prefix.get(i);
            if (m.getRole() != MessageRole.assistant || !m.hasToolCalls()) continue;
            if (issuesAnyCall(m, selectedIds)) {
                keep[i] = true;
                // Remember every call of this message so sibling results stay paired with it.
                for (ToolCall tc : m.getToolCalls()) {
                    if (tc.getId() != null) keptCallIds.add(tc.getId());
                }
            }
        }

        for (int i = 0; i < prefix.size(); i++) {
            LobsterMessage m = prefix.get(i);
            if (m.getRole() == MessageRole.tool && keptCallIds.contains(m.getToolCallId())) {
                keep[i] = true;
            }
        }
        return keep;
    }

    /** Returns whether any tool call of {@code assistant} has an id contained in {@code callIds}. */
    private static boolean issuesAnyCall(LobsterMessage assistant, Set<String> callIds) {
        for (ToolCall tc : assistant.getToolCalls()) {
            if (tc.getId() != null && callIds.contains(tc.getId())) return true;
        }
        return false;
    }

    /**
     * Summarizes {@code messages} with the configured summarizer. Falls back to
     * {@code BulletSummarizer} when none is configured, when it throws, or when it returns
     * null/empty text.
     */
    private String summarize(List<LobsterMessage> messages) {
        String summaryText = null;
        if (summarizer != null) {
            try {
                summaryText = summarizer.summarize(messages, thread, route);
            } catch (Exception e) {
                // Best effort: any summarizer failure falls back to the bullet summary below.
                // VM errors are deliberately not caught here.
                if (e instanceof InterruptedException) {
                    Thread.currentThread().interrupt();
                }
                summaryText = null;
            }
        }
        if (summaryText == null || summaryText.isEmpty()) {
            summaryText = new BulletSummarizer().summarize(messages, thread, route);
        }
        return summaryText;
    }

    /**
     * Returns the index of the user message that starts the {@code keepRecentTurns}-th most
     * recent user turn. Returns 0 (nothing to collapse) when the transcript has fewer user
     * messages than that. Because the comparison happens after the first user message is
     * counted, a {@code keepRecentTurns} of 0 or less behaves like 1.
     */
    private int findCutIndex(List<LobsterMessage> msgs) {
        int turns = 0;
        int idx = msgs.size() - 1;
        while (idx >= 0) {
            if (msgs.get(idx).getRole() == MessageRole.user) {
                turns++;
                if (turns >= keepRecentTurns) return idx;
            }
            idx--;
        }
        return 0;
    }
}
