package com.gzzm.lobster.runtime;

import com.gzzm.lobster.common.*;
import com.gzzm.lobster.config.LobsterConfig;
import com.gzzm.lobster.context.ContextAssembler;
import com.gzzm.lobster.context.ContextAssembly;
import com.gzzm.lobster.guardrails.ClaimConsistencyChecker;
import com.gzzm.lobster.guardrails.ContentFilter;
import com.gzzm.lobster.guardrails.InternalInfoSanitizer;
import com.gzzm.lobster.identity.UserContext;
import com.gzzm.lobster.llm.*;
import com.gzzm.lobster.quota.ConcurrencyGuard;
import com.gzzm.lobster.run.Run;
import com.gzzm.lobster.run.RunDao;
import com.gzzm.lobster.run.RunStreamEventService;
import com.gzzm.lobster.thread.ThreadMessage;
import com.gzzm.lobster.thread.ThreadRoom;
import com.gzzm.lobster.thread.ThreadService;
import com.gzzm.lobster.tool.*;
import org.junit.Before;
import org.junit.Test;
import org.mockito.ArgumentMatchers;

import java.lang.reflect.Field;
import java.util.*;
import java.util.concurrent.atomic.AtomicInteger;

import static org.junit.Assert.*;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyBoolean;
import static org.mockito.ArgumentMatchers.anyInt;
import static org.mockito.ArgumentMatchers.anyLong;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.Mockito.*;

/**
 * End-to-end integration tests (Mockito) for the AgentRuntime main loop.
 *
 * <p>Covers:
 * <ol>
 *   <li>LLM returns no tool_calls → exitReason=normal</li>
 *   <li>LLM emits tool_calls → tool executes → LLM called again → no tool_calls → finish</li>
 *   <li>Tool returns pending → current run ends (exitReason=pending)</li>
 *   <li>Three consecutive identical tool + args calls → loop_detected</li>
 *   <li>cancel takes real effect at the LLM boundary → exitReason=cancelled</li>
 * </ol>
 *
 * <p>AgentRuntime now obtains responses via {@code llmRuntime.chatStream(..., handler)};
 * the tests use {@code doAnswer} so the mock synchronously calls back
 * handler.onComplete to push the canned responses.
 */
public class AgentRuntimeFlowTest {

    // System under test; all collaborators below are injected via reflection in setup().
    private AgentRuntime runtime;
    private ThreadService threadService;
    private ContextAssembler contextAssembler;
    private ToolExecutorDispatcher dispatcher;
    private ToolRegistry toolRegistry;
    private LlmRuntime llmRuntime;
    private ModelRouter modelRouter;
    private RunDao runDao;
    private RunStreamEventService runStreamEventService;
    private com.gzzm.lobster.config.AgentProfileDao agentProfileDao;
    private com.gzzm.lobster.audit.AuditService auditService;
    private com.gzzm.lobster.plan.PlanService planService;
    private ContentFilter contentFilter;
    private InternalInfoSanitizer sanitizer;
    private ConcurrencyGuard concurrencyGuard;

    // Shared fixtures: the conversation thread and the acting user for every test.
    private ThreadRoom thread;
    private UserContext user;

    /**
     * Builds the runtime with all collaborators mocked (except a few real, stateless
     * guardrail helpers) and stubs the happy-path DAO/guardrail behavior that every
     * test relies on: content filter passes, all "owned" run updates report success,
     * and getRun returns a running Run owned by this runtime's workerId.
     */
    @Before
    public void setup() throws Exception {
        runtime = new AgentRuntime();
        threadService = mock(ThreadService.class);
        contextAssembler = mock(ContextAssembler.class);
        dispatcher = mock(ToolExecutorDispatcher.class);
        toolRegistry = mock(ToolRegistry.class);
        llmRuntime = mock(LlmRuntime.class);
        modelRouter = mock(ModelRouter.class);
        runDao = mock(RunDao.class);
        runStreamEventService = mock(RunStreamEventService.class);
        agentProfileDao = mock(com.gzzm.lobster.config.AgentProfileDao.class);
        auditService = mock(com.gzzm.lobster.audit.AuditService.class);
        planService = mock(com.gzzm.lobster.plan.PlanService.class);
        // Real (non-mock) instances: these are stateless checks cheap enough to run for real.
        ClaimConsistencyChecker claimChecker = new ClaimConsistencyChecker();
        contentFilter = mock(ContentFilter.class);
        sanitizer = new InternalInfoSanitizer();
        concurrencyGuard = new ConcurrencyGuard();

        inject("threadService", threadService);
        inject("contextAssembler", contextAssembler);
        inject("toolExecutorDispatcher", dispatcher);
        inject("toolRegistry", toolRegistry);
        inject("llmRuntime", llmRuntime);
        inject("modelRouter", modelRouter);
        inject("runDao", runDao);
        inject("runStreamEventService", runStreamEventService);
        inject("agentProfileDao", agentProfileDao);
        inject("auditService", auditService);
        inject("planService", planService);
        inject("claimChecker", claimChecker);
        inject("contentFilter", contentFilter);
        inject("sanitizer", sanitizer);
        inject("concurrencyGuard", concurrencyGuard);

        thread = new ThreadRoom();
        thread.setThreadId("th_unit");
        thread.setUserId("u_unit");
        thread.setOrgId("org");
        user = new UserContext("u_unit", "ext", "dept", "org", "张三", Collections.<String>emptySet());

        when(threadService.appendMessage(any(), any(), any(), any(), any(), any(), any(), any()))
                .thenReturn(new ThreadMessage());
        // AgentRuntime calls the 9-arg assemble(thread, currentUserInput, budgetTokens, route,
        //                                       kbEnabled, kbMode, kbScopeIds, user, runSnapshot).
        // Match everything loosely with anyXxx; an empty ContextAssembly is sufficient
        // (these tests do not care about the send-view content handed to the LLM).
        when(contextAssembler.assemble(any(ThreadRoom.class), any(), anyInt(), any(),
                anyBoolean(), any(), any(), any(), any()))
                .thenReturn(new ContextAssembly(new ArrayList<LobsterMessage>(), null, null, 0));
        when(toolRegistry.toToolSpecs(any())).thenReturn(new ArrayList<ToolSpec>());
        ModelProfile mp = new ModelProfile();
        mp.setModelId("test-model");
        mp.setProvider(ModelProvider.openai_compatible);
        mp.setProtocol(ModelProtocol.chat_completions);
        when(modelRouter.route(any(ModelSelectionContext.class)))
                .thenReturn(new ModelRouteResult(mp, Collections.<ModelProfile>emptyList(), "test"));
        when(contentFilter.check(anyString(), anyString()))
                .thenReturn(ContentFilter.FilterResult.pass());
        when(runStreamEventService.maxSeq(anyString())).thenReturn(0L);
        when(runStreamEventService.record(anyString(), anyString(), anyString(),
                anyLong(), any(StreamEvent.class))).thenAnswer(inv -> inv.getArgument(4));
        // The Run returned by getRun must carry this runtime's workerId so that the
        // "owned" update paths below (progressOwned etc.) are taken.
        final String workerId = runtimeWorkerId();
        when(runDao.getRun(anyString())).thenAnswer(inv -> {
            Run run = new Run();
            run.setRunId(inv.getArgument(0));
            run.setThreadId(thread.getThreadId());
            run.setUserId(user.getUserId());
            run.setStatus(RunStatus.running);
            run.setWorkerId(workerId);
            run.setHeartbeatAt(new Date());
            return run;
        });
        // All ownership-guarded DAO writes succeed (return 1 affected row).
        when(runDao.progressOwned(anyInt(), any(Date.class), anyString(), any(RunStatus.class), anyString()))
                .thenReturn(1);
        when(runDao.updateRequestPayloadOwned(anyString(), anyString(), any(RunStatus.class), anyString()))
                .thenReturn(1);
        when(runDao.updateModelOwned(anyString(), anyString(), any(RunStatus.class), anyString()))
                .thenReturn(1);
        when(runDao.finishOwned(any(RunStatus.class), any(Date.class), any(RunExitReason.class),
                anyInt(), any(), any(), anyString(), any(RunStatus.class), anyString()))
                .thenReturn(1);
        when(runDao.suspendOwned(any(RunStatus.class), any(RunExitReason.class), anyInt(),
                any(Date.class), anyString(), any(RunStatus.class), anyString()))
                .thenReturn(1);
        when(runDao.heartbeatOwned(any(Date.class), anyString(), any(RunStatus.class), anyString()))
                .thenReturn(1);
    }

    // ---- helpers ----

    /**
     * Makes the mocked llmRuntime.chatStream synchronously replay the given responses,
     * one per invocation, via handler callbacks (onDelta / onToolCall / onComplete).
     * Once the sequence is exhausted the LAST response repeats indefinitely — the
     * loop-detection test depends on this clamping behavior.
     */
    private void stubStream(LlmResponse... sequence) {
        final AtomicInteger idx = new AtomicInteger(0);
        doAnswer(inv -> {
            StreamingResponseHandler h = inv.getArgument(4);
            // Clamp to the last element so extra turns keep receiving the same response.
            int i = Math.min(idx.getAndIncrement(), sequence.length - 1);
            LlmResponse r = sequence[i];
            if (r.getAssistantText() != null && !r.getAssistantText().isEmpty()) {
                h.onDelta(r.getAssistantText());
            }
            for (ToolCall tc : r.getToolCalls()) h.onToolCall(tc);
            h.onComplete(r);
            return null;
        }).when(llmRuntime).chatStream(any(), any(), any(), any(), any());
    }

    // ---- tests ----

    /** LLM answers with plain text and no tool_calls → run ends normally after 1 turn. */
    @Test
    public void noToolCallsEndsNormally() throws Exception {
        stubStream(new LlmResponse("已完成总结。", new ArrayList<ToolCall>(),
                10, 5, "stop", "test-model", ""));

        RunResult r = runtime.run(new RunRequest(thread, user, "请总结", "user_input", null, null), null);

        assertEquals(RunExitReason.normal, r.getExitReason());
        assertEquals(1, r.getTurns());
        assertEquals("已完成总结。", r.getLastAssistantText());
    }

    /**
     * A stream that keeps emitting deltas (every 70ms) must NOT trip the 100ms turn
     * timeout — deltas count as activity, so the timeout is idle-based, not total-based.
     * Total stream time (~280ms) exceeds the timeout several times over.
     */
    @Test(timeout = 2000)
    public void activeStreamingDoesNotHitTurnIdleTimeout() throws Exception {
        // NOTE(review): the getter is static but the setter is called on an instance —
        // presumably the setter writes shared/static config state; confirm against LobsterConfig.
        LobsterConfig cfg = new LobsterConfig();
        long previousTimeoutMs = LobsterConfig.getLlmTurnTimeoutMs();
        cfg.setLlmTurnTimeoutMs(100);
        try {
            doAnswer(inv -> {
                StreamingResponseHandler h = inv.getArgument(4);
                StringBuilder text = new StringBuilder();
                for (int i = 0; i < 4; i++) {
                    Thread.sleep(70);
                    String delta = "chunk" + i;
                    text.append(delta);
                    h.onDelta(delta);
                }
                h.onComplete(new LlmResponse(text.toString(), new ArrayList<ToolCall>(),
                        10, 5, "stop", "test-model", ""));
                return null;
            }).when(llmRuntime).chatStream(any(), any(), any(), any(), any());

            RunResult r = runtime.run(new RunRequest(thread, user, "long streaming task",
                    "user_input", null, null), null);

            assertEquals(RunExitReason.normal, r.getExitReason());
            assertEquals("chunk0chunk1chunk2chunk3", r.getLastAssistantText());
        } finally {
            // Restore the global timeout so other tests are unaffected.
            cfg.setLlmTurnTimeoutMs(previousTimeoutMs);
        }
    }

    /** First LLM turn requests a tool, second turn answers — dispatcher called exactly once. */
    @Test
    public void toolCallsLoopUntilNoMore() throws Exception {
        stubStream(
                new LlmResponse("", Arrays.asList(new ToolCall("id1", "read_file", "{}")),
                        10, 5, "tool_calls", "test-model", ""),
                new LlmResponse("已读取完成。", new ArrayList<ToolCall>(),
                        10, 5, "stop", "test-model", ""));
        when(dispatcher.dispatch(any(ToolContext.class), any(ToolCall.class)))
                .thenReturn(ToolResult.ok("file content", null));

        RunResult r = runtime.run(new RunRequest(thread, user, "读文件", "user_input", null, null), null);

        assertEquals(RunExitReason.normal, r.getExitReason());
        assertEquals(2, r.getTurns());
        verify(dispatcher, times(1)).dispatch(any(ToolContext.class), any(ToolCall.class));
    }

    /** A tool returning pending suspends the run with exitReason=pending and the request id. */
    @Test
    public void pendingToolEndsRunWithPendingReason() throws Exception {
        stubStream(new LlmResponse("", Arrays.asList(new ToolCall("id1", "confirm_action", "{}")),
                10, 5, "tool_calls", "test-model", ""));
        when(dispatcher.dispatch(any(ToolContext.class), any(ToolCall.class)))
                .thenReturn(ToolResult.pending("pr_xxx", "等待确认"));

        RunResult r = runtime.run(new RunRequest(thread, user, "覆盖原文件", "user_input", null, null), null);

        assertEquals(RunExitReason.pending, r.getExitReason());
        assertEquals("pr_xxx", r.getPendingRequestId());
    }

    /**
     * stubStream clamps to its last (and here only) response, so every turn re-issues
     * the identical tool call + args — the runtime must detect the loop and bail out.
     */
    @Test
    public void loopDetectionTripsAfterThreeIdenticalCalls() throws Exception {
        stubStream(new LlmResponse("", Arrays.asList(new ToolCall("id", "read_file", "{\"a\":1}")),
                10, 5, "tool_calls", "test-model", ""));
        when(dispatcher.dispatch(any(ToolContext.class), any(ToolCall.class)))
                .thenReturn(ToolResult.ok("data", null));

        RunResult r = runtime.run(new RunRequest(thread, user, "读", "user_input", null, null), null);
        assertEquals(RunExitReason.loop_detected, r.getExitReason());
    }

    /**
     * cancel on an unknown runId must not throw (idempotency safety net for the API),
     * and must not disturb a subsequent, unrelated run.
     */
    @Test
    public void cancelOnUnknownRunIdIsNoop() throws Exception {
        stubStream(new LlmResponse("完成", new ArrayList<ToolCall>(),
                10, 5, "stop", "test-model", ""));
        runtime.cancel("never-existed");  // must not throw
        RunResult r = runtime.run(new RunRequest(thread, user, "x", "user_input", null, null), null);
        assertEquals(RunExitReason.normal, r.getExitReason());
    }

    /**
     * Real cancel interruption: chatStream blocks on a background thread and would only
     * "complete naturally" after 3 seconds; while the run is in flight another thread
     * triggers cancel, and the run must return cancelled in &lt;1s rather than waiting
     * out the full 3 seconds.
     *
     * <p>Simulated adapter behavior: poll handler.isCancelled() every 50ms and call
     * onError once the signal is seen. This is exactly the cancel contract agreed on
     * for OpenAiCompatibleAdapter / OllamaAdapter.
     */
    @Test(timeout = 2500)
    public void cancelInterruptsBlockingStreamQuickly() throws Exception {
        // chatStream mock: simulates a "slow" adapter that checks isCancelled() every 50ms.
        doAnswer(inv -> {
            StreamingResponseHandler h = inv.getArgument(4);
            long start = System.currentTimeMillis();
            while (System.currentTimeMillis() - start < 3000) {
                if (h.isCancelled()) {
                    h.onError(new com.gzzm.lobster.common.LobsterException(
                            "llm.cancelled", "cancelled by upper layer"));
                    return null;
                }
                Thread.sleep(50);
            }
            // Natural completion after 3s (fallback termination if the <1s assertion fails).
            h.onComplete(new LlmResponse("太慢了", new ArrayList<ToolCall>(),
                    0, 0, "stop", "test-model", ""));
            return null;
        }).when(llmRuntime).chatStream(any(), any(), any(), any(), any());

        // Trigger cancel 250ms after the run starts: read cancelFlags via reflection to
        // discover the current run's id (the id is generated inside run() and not exposed).
        new Thread(() -> {
            try { Thread.sleep(250); } catch (InterruptedException ignore) { /* ignore */ }
            try {
                java.lang.reflect.Field f = AgentRuntime.class.getDeclaredField("cancelFlags");
                f.setAccessible(true);
                @SuppressWarnings("unchecked")
                java.util.concurrent.ConcurrentHashMap<String, ?> map =
                        (java.util.concurrent.ConcurrentHashMap<String, ?>) f.get(runtime);
                for (String id : map.keySet()) runtime.cancel(id);
            } catch (Exception ignore) { /* best-effort */ }
        }).start();

        long t0 = System.currentTimeMillis();
        RunResult r = runtime.run(
                new RunRequest(thread, user, "slow task", "user_input", null, null), null);
        long elapsed = System.currentTimeMillis() - t0;

        assertEquals(RunExitReason.cancelled, r.getExitReason());
        assertTrue("cancel should propagate in < 1s, actual " + elapsed + "ms", elapsed < 1000);
    }

    /** Reflectively sets a private AgentRuntime field (poor-man's dependency injection). */
    private void inject(String fieldName, Object value) throws Exception {
        Field f = AgentRuntime.class.getDeclaredField(fieldName);
        f.setAccessible(true);
        f.set(runtime, value);
    }

    /** Reads the runtime's private workerId so mocked Runs can be marked as owned by it. */
    private String runtimeWorkerId() throws Exception {
        Field f = AgentRuntime.class.getDeclaredField("runtimeWorkerId");
        f.setAccessible(true);
        return (String) f.get(runtime);
    }
}
