package com.gzzm.lobster.quota;

import com.gzzm.lobster.common.LobsterException;

import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Function;

/**
 * Concurrency guard that caps how many runs may be active at once, per user and per thread.
 *
 * <p>Active-run counts are kept in one {@link AtomicInteger} per user id and one per thread id.
 * A slot is reserved with a bounded compare-and-set, so a counter never rises above the limit
 * passed to {@link #acquire}, and is given back when the returned {@link Acquisition} is closed.
 *
 * <p>Known limitation: counters are never removed from the maps, so each map retains one entry
 * for every distinct id it has seen for as long as this guard instance lives.
 */
public class ConcurrencyGuard {

    /**
     * Factory for the zero-initialised counter of an id that is seen for the first time.
     * Kept as an anonymous class (no lambda) to match the language level used by this file.
     */
    private static final Function<String, AtomicInteger> NEW_COUNTER =
            new Function<String, AtomicInteger>() {
                @Override
                public AtomicInteger apply(String key) {
                    return new AtomicInteger();
                }
            };

    /** Number of currently held run slots, keyed by user id. */
    private final ConcurrentHashMap<String, AtomicInteger> userActive = new ConcurrentHashMap<>();

    /** Number of currently held run slots, keyed by thread id. */
    private final ConcurrentHashMap<String, AtomicInteger> threadActive = new ConcurrentHashMap<>();

    /**
     * Reserves one run slot for the given user and thread.
     *
     * <p>The user limit is checked first, then the thread limit. If the thread limit rejects the
     * request, the user slot that was just reserved is released again before the exception is
     * thrown, so a failed call leaves both counters unchanged. During that short window the user
     * slot is held, so a concurrent call for the same user may observe it as taken.
     *
     * @param userId    id of the user starting the run; must not be {@code null}
     * @param threadId  id of the thread the run belongs to; must not be {@code null}
     * @param userMax   maximum number of concurrently active runs allowed for the user
     * @param threadMax maximum number of concurrently active runs allowed for the thread
     * @return a handle that releases both slots when closed
     * @throws LobsterException with code {@code quota.user.concurrent} when the user is already at
     *         {@code userMax}, or {@code quota.thread.concurrent} when the thread is already at
     *         {@code threadMax}
     * @throws NullPointerException if {@code userId} or {@code threadId} is {@code null}
     *         ({@link ConcurrentHashMap} does not accept null keys)
     */
    public Acquisition acquire(String userId, String threadId, int userMax, int threadMax) {
        AtomicInteger userCount = userActive.computeIfAbsent(userId, NEW_COUNTER);
        AtomicInteger threadCount = threadActive.computeIfAbsent(threadId, NEW_COUNTER);

        if (!tryReserve(userCount, userMax)) {
            throw new LobsterException("quota.user.concurrent",
                    "User " + userId + " exceeded concurrent-run limit (" + userMax + ")");
        }
        if (!tryReserve(threadCount, threadMax)) {
            // Give back the user slot taken above; the thread counter was never incremented.
            userCount.decrementAndGet();
            throw new LobsterException("quota.thread.concurrent",
                    "Thread " + threadId + " exceeded concurrent-run limit (" + threadMax + ")");
        }
        return new Acquisition(userCount, threadCount);
    }

    /**
     * Atomically increments {@code counter} only while it is below {@code max}.
     *
     * <p>Unlike "increment, then roll back on overshoot", the counter never exceeds {@code max},
     * so two callers racing for the last free slot cannot both be rejected.
     *
     * @return {@code true} if a slot was reserved, {@code false} if the counter is already at
     *         (or above) {@code max}
     */
    private static boolean tryReserve(AtomicInteger counter, int max) {
        while (true) {
            int current = counter.get();
            if (current >= max) {
                return false;
            }
            if (counter.compareAndSet(current, current + 1)) {
                return true;
            }
            // Lost a race with another update; re-read and try again.
        }
    }

    /**
     * Handle for one reserved run slot. Closing it releases the user slot and the thread slot.
     *
     * <p>{@link #close()} is idempotent and may be called from several threads; the counters are
     * decremented exactly once.
     */
    public static final class Acquisition implements AutoCloseable {
        private final AtomicInteger uc;
        private final AtomicInteger tc;

        /** Flipped to {@code true} by the first {@link #close()} call; later calls are no-ops. */
        private final AtomicBoolean released = new AtomicBoolean(false);

        Acquisition(AtomicInteger uc, AtomicInteger tc) {
            this.uc = uc;
            this.tc = tc;
        }

        /** Releases the user and thread slots held by this acquisition, at most once. */
        @Override
        public void close() {
            // compareAndSet guarantees a single winner even when close() races with itself.
            if (released.compareAndSet(false, true)) {
                uc.decrementAndGet();
                tc.decrementAndGet();
            }
        }
    }
}
