package com.gzzm.lobster.runtime;

import java.util.Map;
import java.util.LinkedHashMap;
import java.util.concurrent.BlockingDeque;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.LinkedBlockingDeque;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;

import com.gzzm.lobster.common.StreamEventType;
import com.gzzm.platform.commons.Tools;

/**
 * Forwards one run's stream events to all currently attached SSE clients.
 *
 * <p>Network writes must not run on the AgentRuntime worker thread. Each
 * attached client gets its own bounded queue drained by a dedicated
 * single-threaded writer, so a slow browser can only delay its own
 * observation stream, never the backend run lifecycle. A client whose queue
 * overflows ({@code CLIENT_QUEUE_CAPACITY} undelivered events) is dropped and
 * asked to close its observation stream so it can reconnect.
 *
 * <p>On the writer thread, consecutive text deltas of the same kind (and same
 * thread/run, and for file-content deltas the same tool call) are coalesced
 * into a single event, waiting at most {@code TEXT_BATCH_WAIT_MS} for more
 * deltas and stopping once {@code MAX_MERGED_TEXT_CHARS} characters are
 * collected.
 */
public final class MultiplexStreamEmitter implements StreamEmitter {

    /** Undelivered events buffered per client before that client is dropped. */
    private static final int CLIENT_QUEUE_CAPACITY = 8192;
    /** Text merging stops once the coalesced delta reaches this many characters. */
    private static final int MAX_MERGED_TEXT_CHARS = 4096;
    /** Longest time the writer waits for further deltas to merge into one event. */
    private static final long TEXT_BATCH_WAIT_MS = 40L;
    /** Queue poll timeout, so an idle writer periodically re-checks its closed flag. */
    private static final long POLL_WAIT_MS = 200L;
    /** How long a graceful close waits for the writer to flush before interrupting it. */
    private static final long CLOSE_WAIT_MS = 1000L;

    /** Attached clients, keyed by the downstream emitter they write to. */
    private final Map<StreamEmitter, ClientSink> sinks = new ConcurrentHashMap<>();

    /**
     * Attaches a client. Adding an emitter that is already attached is a no-op.
     *
     * <p>The client's writer is started only after this client has won the
     * registration. So a concurrent duplicate never spawns a thread, and a
     * writer that fails to start removes its client again (and closes the
     * observation stream) instead of leaving a dead entry in the map.
     *
     * @param sink the client's emitter; ignored when {@code null}
     */
    public void add(StreamEmitter sink) {
        if (sink == null || sinks.containsKey(sink)) return;
        ClientSink created = new ClientSink(sink, this);
        if (sinks.putIfAbsent(sink, created) != null) {
            // Lost a race with a concurrent add(); its writer was never started.
            created.closeNow();
            return;
        }
        created.start();
    }

    /**
     * Queues an event for a single attached client only.
     *
     * @param sink  the target client; ignored when {@code null} or not attached
     * @param event the event; ignored when {@code null}
     */
    public void emitTo(StreamEmitter sink, StreamEvent event) {
        if (sink == null || event == null) return;
        ClientSink client = sinks.get(sink);
        if (client != null) client.enqueue(event);
    }

    /**
     * Detaches a client gracefully. Events already queued are still flushed.
     * The caller blocks for up to {@code CLOSE_WAIT_MS}, after which the
     * writer is interrupted. No blocking happens when this is called from the
     * client's own writer thread.
     *
     * @param sink the client to detach; ignored when {@code null} or not attached
     */
    public void remove(StreamEmitter sink) {
        if (sink == null) return;
        ClientSink client = sinks.remove(sink);
        if (client != null) client.close();
    }

    /** Queues the event for every currently attached client; never blocks. */
    @Override
    public void emit(StreamEvent event) {
        for (ClientSink sink : sinks.values()) {
            sink.enqueue(event);
        }
    }

    /** The multiplexer itself never breaks; individual clients are dropped instead. */
    @Override
    public boolean isBroken() {
        return false;
    }

    /**
     * Unregisters {@code client} (only if it is still the registered entry for
     * {@code sink}) and stops its writer immediately.
     */
    private void removeBroken(StreamEmitter sink, ClientSink client) {
        if (sink == null || client == null) return;
        if (sinks.remove(sink, client)) client.closeNow();
    }

    /** One attached client: a bounded event queue plus the writer that drains it. */
    private static final class ClientSink {
        private final StreamEmitter sink;
        private final MultiplexStreamEmitter owner;
        private final ExecutorService writer;
        private final BlockingDeque<StreamEvent> queue = new LinkedBlockingDeque<>(CLIENT_QUEUE_CAPACITY);
        private final AtomicBoolean closed = new AtomicBoolean(false);
        /** The writer's thread, recorded so close() can avoid waiting on itself. */
        private volatile Thread writerThread;
        private volatile boolean overflowLogged;

        ClientSink(StreamEmitter sink, MultiplexStreamEmitter owner) {
            this.sink = sink;
            this.owner = owner;
            // The executor creates its worker thread lazily, on the first submit in start().
            this.writer = Executors.newSingleThreadExecutor(r -> {
                Thread t = new Thread(r, "lobster-sse-client");
                t.setDaemon(true);
                writerThread = t;
                return t;
            });
        }

        /**
         * Starts the writer that drains this client's queue. Call at most once,
         * after the client has been registered with its owner.
         */
        void start() {
            if (closed.get()) return;
            try {
                writer.submit(this::pump);
            } catch (Throwable t) {
                // Typically thread exhaustion, or a remove() racing with add(),
                // in which case the client is already closed and detached.
                if (!closed.get()) failObservation();
            }
        }

        /** Queues an event without blocking; an overflowing queue drops this client. */
        synchronized void enqueue(final StreamEvent event) {
            if (event == null || closed.get()) return;
            if (queue.offerLast(event)) return;
            handleOverflow(event);
        }

        private void handleOverflow(StreamEvent event) {
            if (!overflowLogged) {
                overflowLogged = true;
                try { Tools.log("[MultiplexStreamEmitter] SSE client queue overflow; closing observation stream for reconnect"); }
                catch (Throwable ignore) { /* logging is best-effort */ }
            }
            failObservation();
        }

        /**
         * Closes this client, discards its queued events, unregisters it, and
         * asks the downstream emitter to close its observation stream.
         *
         * <p>NOTE(review): when reached via {@link #enqueue} overflow, this runs on
         * the producer thread while holding this sink's monitor. Confirm that
         * {@code closeObservation()} does not perform blocking network I/O.
         */
        private void failObservation() {
            closed.set(true);
            queue.clear();
            owner.removeBroken(sink, this);
            try { sink.closeObservation(); } catch (Throwable ignore) { /* best-effort */ }
        }

        /**
         * Writer loop. Runs until the client is closed and every queued or
         * held-back event has been written. It stops early if the downstream
         * emitter reports itself broken, on interruption, or on any failure
         * (which closes the observation stream).
         */
        private void pump() {
            StreamEvent pending = null;
            try {
                while (!closed.get() || pending != null || !queue.isEmpty()) {
                    StreamEvent event = pending;
                    pending = null;
                    if (event == null) {
                        event = queue.pollFirst(POLL_WAIT_MS, TimeUnit.MILLISECONDS);
                    }
                    if (event == null) continue;
                    StreamEvent toWrite = event;
                    if (isMergeableText(event)) {
                        MergeResult merged = mergeTextBatch(event);
                        toWrite = merged.event;
                        // An event that could not be merged is written next, preserving order.
                        pending = merged.pending;
                    }
                    sink.emit(toWrite);
                    if (sink.isBroken()) {
                        owner.removeBroken(sink, this);
                        return;
                    }
                }
            } catch (InterruptedException ie) {
                Thread.currentThread().interrupt();
            } catch (Throwable failure) {
                failObservation();
            }
        }

        /**
         * Graceful close. Stops accepting events and lets the writer flush the
         * queue for up to {@code CLOSE_WAIT_MS} before interrupting it. When
         * called from the writer thread itself, it returns without waiting.
         */
        void close() {
            closed.set(true);
            try { writer.shutdown(); } catch (Throwable ignore) { /* best-effort */ }
            if (Thread.currentThread() == writerThread) return;
            try {
                if (!writer.awaitTermination(CLOSE_WAIT_MS, TimeUnit.MILLISECONDS)) {
                    writer.shutdownNow();
                }
            } catch (Throwable ignore) {
                try { writer.shutdownNow(); } catch (Throwable ignored) { /* best-effort */ }
            }
        }

        /** Immediate close: stops accepting events and interrupts the writer. */
        void closeNow() {
            closed.set(true);
            try { writer.shutdownNow(); } catch (Throwable ignore) { /* best-effort */ }
        }

        /**
         * Appends to {@code first} the text of directly following mergeable
         * deltas. It stops at the batch deadline, at the size limit, or at the
         * first event that cannot be merged; that event is returned as pending.
         * The merged event carries the highest eventSeq seen.
         */
        private MergeResult mergeTextBatch(StreamEvent first) throws InterruptedException {
            StringBuilder text = new StringBuilder(textDelta(first));
            StreamEvent pending = null;
            long eventSeq = eventSeq(first);
            // Monotonic clock: a wall-clock step backwards must not stretch the wait.
            long deadline = System.nanoTime() + TimeUnit.MILLISECONDS.toNanos(TEXT_BATCH_WAIT_MS);
            while (text.length() < MAX_MERGED_TEXT_CHARS) {
                long waitNanos = deadline - System.nanoTime();
                if (waitNanos <= 0L) break;
                StreamEvent next = queue.pollFirst(waitNanos, TimeUnit.NANOSECONDS);
                if (next == null) break;
                if (!canMerge(first, next)) {
                    pending = next;
                    break;
                }
                eventSeq = Math.max(eventSeq, eventSeq(next));
                text.append(textDelta(next));
            }
            return new MergeResult(withTextDelta(first, text.toString(), eventSeq), pending);
        }

        /** True for the incremental-text event types that may be coalesced. */
        private boolean isMergeableText(StreamEvent event) {
            if (event == null) return false;
            StreamEventType type = event.getType();
            return type == StreamEventType.assistant_text
                    || type == StreamEventType.assistant_thinking
                    || type == StreamEventType.write_file_content_delta;
        }

        /**
         * True when {@code next} continues the same stream as {@code first}:
         * same type, threadId and runId, and for file-content deltas also the
         * same toolCallId and toolIndex.
         */
        private boolean canMerge(StreamEvent first, StreamEvent next) {
            if (first == null || next == null) return false;
            if (first.getType() != next.getType()) return false;
            if (!samePayload(first, next, "threadId")) return false;
            if (!samePayload(first, next, "runId")) return false;
            if (first.getType() == StreamEventType.write_file_content_delta) {
                return samePayload(first, next, "toolCallId")
                        && samePayload(first, next, "toolIndex");
            }
            return true;
        }

        private boolean samePayload(StreamEvent left, StreamEvent right, String key) {
            Object a = payloadValue(left, key);
            Object b = payloadValue(right, key);
            return a == null ? b == null : a.equals(b);
        }

        /** Returns the payload entry for {@code key}, or null if the event or its payload is missing. */
        private static Object payloadValue(StreamEvent event, String key) {
            if (event == null || event.getPayload() == null) return null;
            return event.getPayload().get(key);
        }

        /** The event's text delta as a string; empty when absent. */
        private String textDelta(StreamEvent event) {
            String key = textKey(event);
            Object raw = key == null ? null : payloadValue(event, key);
            return raw == null ? "" : String.valueOf(raw);
        }

        /**
         * Copies {@code event} with its text delta replaced by {@code text}.
         * The eventSeq is replaced too when it is positive.
         */
        private StreamEvent withTextDelta(StreamEvent event, String text, long eventSeq) {
            String key = textKey(event);
            if (key == null) return event;
            Map<String, Object> payload = new LinkedHashMap<>();
            if (event.getPayload() != null) payload.putAll(event.getPayload());
            payload.put(key, text);
            if (eventSeq > 0L) payload.put("eventSeq", eventSeq);
            return StreamEvent.of(event.getType(), payload);
        }

        /** The event's numeric eventSeq, or 0 when missing or unparsable. */
        private long eventSeq(StreamEvent event) {
            Object raw = payloadValue(event, "eventSeq");
            if (raw == null) return 0L;
            if (raw instanceof Number) return ((Number) raw).longValue();
            try { return Long.parseLong(String.valueOf(raw)); }
            catch (NumberFormatException ignore) { return 0L; }
        }

        /** Payload key holding the incremental text for this event type. */
        private String textKey(StreamEvent event) {
            if (event == null) return null;
            if (event.getType() == StreamEventType.write_file_content_delta) return "contentDelta";
            return "delta";
        }
    }

    /** A merged event plus the first following event that could not be merged into it. */
    private static final class MergeResult {
        final StreamEvent event;
        final StreamEvent pending;

        MergeResult(StreamEvent event, StreamEvent pending) {
            this.event = event;
            this.pending = pending;
        }
    }
}
