package com.gzzm.lobster.runtime;

import com.gzzm.lobster.common.StreamEventType;
import org.junit.Test;

import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;

/**
 * Unit tests for {@link MultiplexStreamEmitter}.
 *
 * <p>The sink is awaited with a timeout instead of being read immediately, because the
 * multiplexer may deliver events asynchronously. NOTE(review): this is inferred from the
 * test's use of a blocking wait; confirm against MultiplexStreamEmitter.
 */
public class MultiplexStreamEmitterTest {

    /** Upper bound on how long the test waits for the multiplexer to deliver events. */
    private static final long DELIVERY_TIMEOUT_MS = 1000L;

    /**
     * Checks that consecutive {@code assistant_text} chunks reach the sink as one merged event.
     * The merged {@code delta} is the concatenation of the chunks, and its {@code eventSeq} is
     * that of the last chunk. The merged event must be followed by the {@code run_ended} event,
     * which the batching must not drop.
     */
    @Test
    public void batchesConsecutiveAssistantTextWithoutDroppingTerminalEvent() throws Exception {
        MultiplexStreamEmitter mux = new MultiplexStreamEmitter();
        CapturingEmitter sink = new CapturingEmitter();
        mux.add(sink);
        try {
            mux.emit(text("a", 10L));
            mux.emit(text("b", 11L));
            mux.emit(text("c", 12L));
            mux.emit(runEnded());

            assertTrue("expected a merged assistant_text event followed by run_ended",
                    sink.awaitEvents(2, DELIVERY_TIMEOUT_MS));
        } finally {
            // Detach even if the wait fails, so the sink is never left registered on the mux.
            mux.remove(sink);
        }

        // Read a copy taken under the sink's lock rather than the live list. Another thread may
        // still be appending to the live list.
        List<StreamEvent> received = sink.snapshot();
        StreamEvent merged = received.get(0);
        assertEquals(StreamEventType.assistant_text, merged.getType());
        assertEquals("abc", merged.getPayload().get("delta"));
        assertEquals(12L, ((Number) merged.getPayload().get("eventSeq")).longValue());
        assertEquals(StreamEventType.run_ended, received.get(1).getType());
    }

    /**
     * Builds an {@code assistant_text} event carrying one text chunk.
     *
     * @param delta    the text chunk
     * @param eventSeq the sequence number stored under {@code "eventSeq"}
     * @return the event, with the shared thread/run identifiers in its payload
     */
    private static StreamEvent text(String delta, long eventSeq) {
        Map<String, Object> payload = basePayload();
        payload.put("delta", delta);
        payload.put("eventSeq", eventSeq);
        return StreamEvent.of(StreamEventType.assistant_text, payload);
    }

    /**
     * Builds a {@code run_ended} event.
     *
     * @return the event, with {@code exitReason = "done"} in its payload
     */
    private static StreamEvent runEnded() {
        Map<String, Object> payload = basePayload();
        payload.put("exitReason", "done");
        return StreamEvent.of(StreamEventType.run_ended, payload);
    }

    /**
     * Creates a fresh mutable payload containing the thread and run identifiers shared by every
     * event in these tests. Insertion order is preserved.
     */
    private static Map<String, Object> basePayload() {
        Map<String, Object> payload = new LinkedHashMap<>();
        payload.put("threadId", "th_1");
        payload.put("runId", "run_1");
        return payload;
    }

    /**
     * A {@link StreamEmitter} that records every event it receives and lets a test block until
     * a given number have arrived.
     *
     * <p>All access to the recorded list goes through this object's monitor. That makes events
     * appended from another thread safely visible to the test.
     */
    private static final class CapturingEmitter implements StreamEmitter {
        private final List<StreamEvent> events = new ArrayList<>();

        @Override
        public synchronized void emit(StreamEvent event) {
            events.add(event);
            notifyAll();
        }

        /**
         * Blocks until at least {@code count} events have been recorded or the timeout elapses.
         *
         * @param count     minimum number of events to wait for
         * @param timeoutMs maximum time to wait, in milliseconds
         * @return {@code true} if at least {@code count} events were recorded in time
         * @throws InterruptedException if the waiting thread is interrupted
         */
        synchronized boolean awaitEvents(int count, long timeoutMs) throws InterruptedException {
            // nanoTime is monotonic; currentTimeMillis can jump when the system clock is adjusted.
            long deadline = System.nanoTime() + timeoutMs * 1_000_000L;
            while (events.size() < count) {
                long remainingNanos = deadline - System.nanoTime();
                if (remainingNanos <= 0) {
                    break;
                }
                // wait(0) means "wait forever", so never pass less than 1 ms. The loop re-checks
                // the condition, which also covers spurious wakeups.
                wait(Math.max(1L, remainingNanos / 1_000_000L));
            }
            return events.size() >= count;
        }

        /** Returns a copy of the events recorded so far, taken under this object's lock. */
        synchronized List<StreamEvent> snapshot() {
            return new ArrayList<>(events);
        }
    }
}
