package com.gzzm.lobster.llm;

import org.junit.Test;

import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.util.Collections;
import java.util.Map;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertNull;

/**
 * Tests for the prompt-cache statistics map produced by {@code LlmRuntime.promptCacheToMap}.
 *
 * <p>The method under test is not publicly accessible, so it is invoked reflectively.
 */
public class LlmRuntimePromptCacheTest {

    /**
     * Reflectively invokes the static {@code LlmRuntime.promptCacheToMap(LlmResponse, int)}.
     *
     * <p>Exceptions thrown by the target method are unwrapped from the reflective
     * {@link InvocationTargetException}. A failure then reports the production exception
     * directly rather than the reflection wrapper.
     *
     * @param response    the response whose cache statistics are converted
     * @param inputTokens the input-token count passed through to the method under test
     * @return the map returned by {@code promptCacheToMap}
     * @throws Exception if the method cannot be located/accessed, or whatever it throws itself
     */
    @SuppressWarnings("unchecked") // the target's return value is cast to Map<String, Object>
    private Map<String, Object> promptCacheToMap(LlmResponse response, int inputTokens) throws Exception {
        Method m = LlmRuntime.class.getDeclaredMethod("promptCacheToMap", LlmResponse.class, int.class);
        m.setAccessible(true);
        try {
            // Null receiver: the target is expected to be a static method.
            return (Map<String, Object>) m.invoke(null, response, inputTokens);
        } catch (InvocationTargetException e) {
            Throwable cause = e.getCause();
            if (cause instanceof Exception) {
                throw (Exception) cause;
            }
            if (cause instanceof Error) {
                throw (Error) cause;
            }
            throw e;
        }
    }

    /**
     * When the response carries a hit count and the caller supplies the input-token total,
     * the miss count is derived as {@code inputTokens - hitTokens}. The hit rate is
     * {@code hitTokens / totalTokens}.
     */
    @Test
    public void derivesMissTokensFromInputTokensWhenProviderOnlyReportsHits() throws Exception {
        // NOTE(review): positional LlmResponse args are assumed to be
        // (content, toolCalls, inputTokens, outputTokens, finishReason, model, ?, ?,
        //  cacheHitTokens, cacheMissTokens) — confirm against the LlmResponse constructor.
        LlmResponse response = new LlmResponse("", Collections.<ToolCall>emptyList(),
                1000, 0, "stop", "test", "", null, 600, 0);

        Map<String, Object> cache = promptCacheToMap(response, 1000);

        assertEquals(600, cache.get("hitTokens"));
        assertEquals(400, cache.get("missTokens"));
        assertEquals(1000, cache.get("totalTokens"));
        Object hitRate = cache.get("hitRate");
        // Guard before unboxing so a missing value fails with a message, not a bare NPE.
        assertNotNull("hitRate should be reported when totalTokens is known", hitRate);
        assertEquals(0.6d, (Double) hitRate, 0.0001d);
    }

    /**
     * When the input-token count passed in is 0 and only a hit count is available, the map
     * reports the hits as the total with zero misses. No hit rate is reported.
     */
    @Test
    public void hitRateIsUnknownWhenOnlyHitTokensAreKnown() throws Exception {
        // NOTE(review): same assumed argument order as above; here the input-token slot is 0.
        LlmResponse response = new LlmResponse("", Collections.<ToolCall>emptyList(),
                0, 0, "stop", "test", "", null, 600, 0);

        Map<String, Object> cache = promptCacheToMap(response, 0);

        assertEquals(600, cache.get("hitTokens"));
        assertEquals(0, cache.get("missTokens"));
        assertEquals(600, cache.get("totalTokens"));
        assertNull(cache.get("hitRate"));
    }
}
