package com.gzzm.lobster.quota;

import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong;
import java.util.function.Function;

/**
 * Minimal per-key circuit-breaker registry.
 *
 * <p>Each key (for example a model id, tool name, or MCP server) gets its own breaker running the
 * usual three-state machine:
 * <ul>
 *   <li>{@code CLOSED}: calls are allowed. Consecutive failures are counted, and reaching
 *       {@code failureThreshold} trips the breaker to {@code OPEN}.</li>
 *   <li>{@code OPEN}: calls are rejected until {@code openMillis} has elapsed. The first
 *       {@link #allow(String)} after that moves the breaker to {@code HALF_OPEN} and admits a
 *       single probe call.</li>
 *   <li>{@code HALF_OPEN}: the probe's outcome decides the next state. A success closes the
 *       breaker; a failure re-opens it for another cooldown. Other callers are rejected while the
 *       probe is in flight. If no outcome is recorded within {@code openMillis}, another probe is
 *       admitted, so a lost probe cannot wedge the breaker.</li>
 * </ul>
 *
 * <p>Callers should follow every admitted call with {@link #recordSuccess(String)} or
 * {@link #recordFailure(String)}. A success closes the breaker from any state. Failures recorded
 * while already {@code OPEN} (from calls admitted earlier) do not extend the cooldown.
 *
 * <p>Timing uses {@link System#nanoTime()}, so wall-clock adjustments do not affect the cooldown.
 * All methods are thread-safe.
 */
public class CircuitBreakerRegistry {

    /** Breaker states; see the class documentation for the transitions. */
    public enum State { CLOSED, OPEN, HALF_OPEN }

    /** Shared factory so {@code computeIfAbsent} does not allocate a new function per call. */
    private static final Function<String, Breaker> NEW_BREAKER = new Function<String, Breaker>() {
        @Override public Breaker apply(String key) { return new Breaker(); }
    };

    private final ConcurrentHashMap<String, Breaker> breakers = new ConcurrentHashMap<>();

    private final int failureThreshold;
    private final long openNanos;

    /** Creates a registry that trips after 5 consecutive failures and stays open for 30 seconds. */
    public CircuitBreakerRegistry() {
        this(5, 30_000L);
    }

    /**
     * Creates a registry with explicit limits.
     *
     * @param failureThreshold consecutive failures that trip a closed breaker; must be at least 1
     * @param openMillis cooldown in milliseconds before a probe is admitted; must not be negative
     * @throws IllegalArgumentException if either argument is out of range
     */
    public CircuitBreakerRegistry(int failureThreshold, long openMillis) {
        if (failureThreshold < 1) {
            throw new IllegalArgumentException("failureThreshold must be >= 1: " + failureThreshold);
        }
        if (openMillis < 0) {
            throw new IllegalArgumentException("openMillis must be >= 0: " + openMillis);
        }
        this.failureThreshold = failureThreshold;
        // toNanos saturates at Long.MAX_VALUE, so huge cooldowns cannot overflow.
        this.openNanos = TimeUnit.MILLISECONDS.toNanos(openMillis);
    }

    /**
     * Decides whether a call for {@code key} may proceed.
     *
     * @param key breaker key
     * @return {@code true} if the breaker is closed, or if this caller is granted the half-open probe
     */
    public boolean allow(String key) {
        Breaker b = breakers.get(key);
        // A key that has never failed has no breaker and is implicitly CLOSED; don't create entries.
        if (b == null || b.state() == State.CLOSED) {
            return true;
        }
        return b.tryAcquire(openNanos);
    }

    /**
     * Records a successful call: resets the consecutive-failure count and closes the breaker.
     * Keys without a breaker are ignored.
     *
     * @param key breaker key
     */
    public void recordSuccess(String key) {
        Breaker b = breakers.get(key);
        if (b != null) {
            b.onSuccess();
        }
    }

    /**
     * Records a failed call. A closed breaker trips once the consecutive-failure count reaches the
     * threshold; a half-open breaker re-opens immediately.
     *
     * @param key breaker key
     */
    public void recordFailure(String key) {
        breakers.computeIfAbsent(key, NEW_BREAKER).onFailure(failureThreshold);
    }

    /**
     * Returns the current state for {@code key}, or {@code CLOSED} if the key has no breaker.
     * An {@code OPEN} breaker whose cooldown has elapsed still reports {@code OPEN} until the next
     * {@link #allow(String)} call.
     *
     * @param key breaker key
     * @return the breaker state
     */
    public State stateOf(String key) {
        Breaker b = breakers.get(key);
        return b == null ? State.CLOSED : b.state();
    }

    /**
     * Per-key breaker state. All transitions are serialized on the instance monitor. The
     * {@code state} field is also volatile so {@code allow} can take the CLOSED fast path without
     * locking.
     */
    private static final class Breaker {
        private volatile State state = State.CLOSED;
        /** Consecutive failures since the last success (saturating). Guarded by {@code this}. */
        private int failures;
        /** nanoTime of the last OPEN transition or probe grant. Guarded by {@code this}. */
        private long since;

        State state() {
            return state;
        }

        /** Grants the single half-open probe once the current window has elapsed. */
        synchronized boolean tryAcquire(long openNanos) {
            if (state == State.CLOSED) {
                return true; // closed concurrently by a success
            }
            long now = System.nanoTime();
            if (now - since < openNanos) {
                return false; // still cooling down, or a probe is already in flight
            }
            state = State.HALF_OPEN;
            since = now; // restart the window so only this caller gets the probe
            return true;
        }

        synchronized void onSuccess() {
            failures = 0;
            state = State.CLOSED;
        }

        synchronized void onFailure(int threshold) {
            if (failures < Integer.MAX_VALUE) {
                failures++;
            }
            boolean trip = state == State.HALF_OPEN
                    || (state == State.CLOSED && failures >= threshold);
            if (trip) {
                state = State.OPEN;
                since = System.nanoTime();
            }
            // Already OPEN: late failures from earlier calls deliberately do not extend the cooldown.
        }
    }
}
