package com.gzzm.lobster.context;

import com.gzzm.lobster.common.MessageRole;
import com.gzzm.lobster.llm.LobsterMessage;
import org.junit.Test;

import java.util.ArrayList;
import java.util.List;

import static org.junit.Assert.*;

/**
 * Tests for {@link ContextCompactionPolicy#projectTranscript}, covering when tool results are
 * width-truncated, which tools are exempt from truncation, and that over-budget transcripts are
 * compacted.
 */
public class ContextCompactionPolicyTest {

    /** A tool result is passed through unchanged when the token budget is generous. */
    @Test
    public void toolResultsAreNotTruncatedWhenUnderBudget() {
        ToolResultWidthPolicy width = new ToolResultWidthPolicy(100);
        ContextCompactionPolicy cp = new ContextCompactionPolicy(width, 3);
        String huge = repeat("A", 5000);
        List<LobsterMessage> transcript = new ArrayList<>();
        transcript.add(LobsterMessage.user("question"));
        transcript.add(LobsterMessage.tool("id1", "read_file", huge));

        List<LobsterMessage> projected = cp.projectTranscript(transcript, 1_000_000);

        LobsterMessage tool = findTool(projected, "read_file");
        assertNotNull(tool);
        assertEquals(huge, tool.getContent());
    }

    /**
     * Under a tight budget a large tool result is shortened, points the model at
     * {@code read_externalized_content}, and keeps its full-content reference.
     */
    @Test
    public void largeToolResultsAreTruncatedOnlyWhenOverBudget() {
        ToolResultWidthPolicy width = new ToolResultWidthPolicy(100);
        ContextCompactionPolicy cp = new ContextCompactionPolicy(width, 3);
        String huge = repeat("A", 5000);
        List<LobsterMessage> transcript = new ArrayList<>();
        transcript.add(LobsterMessage.user("question"));
        transcript.add(LobsterMessage.toolWithRef("id1", "read_file", huge, "message/ref.txt"));

        List<LobsterMessage> projected = cp.projectTranscript(transcript, 100);

        LobsterMessage tool = findTool(projected, "read_file");
        assertNotNull(tool);
        assertTrue(tool.getContent().length() < huge.length());
        assertTrue(tool.getContent().contains("read_externalized_content"));
        assertEquals("message/ref.txt", tool.getFullContentRef());
    }

    /** {@code use_skill} results are exempt from width truncation even when over budget. */
    @Test
    public void useSkillToolResultIsNotWidthTruncated() {
        ToolResultWidthPolicy width = new ToolResultWidthPolicy(100);
        ContextCompactionPolicy cp = new ContextCompactionPolicy(width, 3);
        String guidance = repeat("S", 5000);
        List<LobsterMessage> transcript = new ArrayList<>();
        transcript.add(LobsterMessage.user("need guidance"));
        transcript.add(LobsterMessage.tool("skill-call", "use_skill", guidance));

        List<LobsterMessage> projected = cp.projectTranscript(transcript, 100);

        LobsterMessage skillResult = findTool(projected, "use_skill");
        assertNotNull(skillResult);
        assertEquals(guidance, skillResult.getContent());
    }

    /** {@code oa_search_knowledge} results are exempt from width truncation even when over budget. */
    @Test
    public void oaSearchKnowledgeToolResultIsNotWidthTruncated() {
        ToolResultWidthPolicy width = new ToolResultWidthPolicy(100);
        ContextCompactionPolicy cp = new ContextCompactionPolicy(width, 3);
        String hits = repeat("K", 5000);
        List<LobsterMessage> transcript = new ArrayList<>();
        transcript.add(LobsterMessage.user("search kb"));
        transcript.add(LobsterMessage.tool("kb-call", "oa_search_knowledge", hits));

        List<LobsterMessage> projected = cp.projectTranscript(transcript, 100);

        LobsterMessage searchResult = findTool(projected, "oa_search_knowledge");
        assertNotNull(searchResult);
        assertEquals(hits, searchResult.getContent());
    }

    /**
     * An over-budget conversation is compacted, but the most recent turn must still be present.
     * The second constructor argument ({@code 1}) is presumably the number of recent turns to
     * retain — TODO confirm against {@link ContextCompactionPolicy}.
     */
    @Test
    public void keepsRecentTurnsOnOverBudget() {
        ToolResultWidthPolicy width = new ToolResultWidthPolicy(500);
        ContextCompactionPolicy cp = new ContextCompactionPolicy(width, 1);
        List<LobsterMessage> transcript = new ArrayList<>();
        for (int i = 0; i < 10; i++) {
            transcript.add(LobsterMessage.user("user" + i + repeat("u", 200)));
            transcript.add(LobsterMessage.assistant("reply" + i + repeat("r", 200)));
        }

        List<LobsterMessage> projected = cp.projectTranscript(transcript, 50);

        assertTrue("should compact: " + projected.size() + " vs " + transcript.size(),
                projected.size() < transcript.size());
        // Shrinking alone would also pass if everything were dropped; the test name promises the
        // latest turn survives, so check that explicitly. "user9" is unique among user0..user9.
        assertNotNull("most recent user turn should survive compaction",
                findByContentPrefix(projected, "user9"));
    }

    /** Returns the first tool-role message produced by {@code toolName}, or {@code null}. */
    private LobsterMessage findTool(List<LobsterMessage> messages, String toolName) {
        for (LobsterMessage m : messages) {
            if (m.getRole() == MessageRole.tool && toolName.equals(m.getToolName())) {
                return m;
            }
        }
        return null;
    }

    /** Returns the first message whose non-null content starts with {@code prefix}, or {@code null}. */
    private LobsterMessage findByContentPrefix(List<LobsterMessage> messages, String prefix) {
        for (LobsterMessage m : messages) {
            String content = m.getContent();
            if (content != null && content.startsWith(prefix)) {
                return m;
            }
        }
        return null;
    }

    /** Concatenates {@code s} with itself {@code n} times. */
    private String repeat(String s, int n) {
        StringBuilder sb = new StringBuilder(s.length() * n);
        for (int i = 0; i < n; i++) {
            sb.append(s);
        }
        return sb.toString();
    }
}
