# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import types
from enum import Enum, auto
from typing import Optional

import torch

from vllm.config import VllmConfig
from vllm.sampling_params import SamplingParams
from vllm.v1.sample.logits_processor import (LOGITSPROCS_GROUP, BatchUpdate,
                                             LogitsProcessor,
                                             MoveDirectionality)

# Models used by the logitsprocs tests.
MODEL_NAME = "facebook/opt-125m"
POOLING_MODEL_NAME = "BAAI/bge-base-en-v1.5"

# Sampling-related constants. DUMMY_LOGITPROC_ARG is the extra_args key
# holding the token id that DummyLogitsProcessor forces.
DUMMY_LOGITPROC_ARG = "target_token"
TEMP_GREEDY = 0.0
MAX_TOKENS = 20

# Names under which the dummy logitproc is exposed: an entrypoint name, a
# fake module name, and the "module:ClassName" string built from it.
DUMMY_LOGITPROC_ENTRYPOINT = "dummy_logitproc"
DUMMY_LOGITPROC_MODULE = "DummyModule"
DUMMY_LOGITPROC_FQCN = ":".join(
    (DUMMY_LOGITPROC_MODULE, "DummyLogitsProcessor"))


class CustomLogitprocSource(Enum):
    """How to source a logitproc for testing purposes.

    Members:
        LOGITPROC_SOURCE_NONE: no custom logitproc at all.
        LOGITPROC_SOURCE_ENTRYPOINT: logitproc provided via an entrypoint.
        LOGITPROC_SOURCE_FQCN: logitproc provided via a fully-qualified
            class name (FQCN).
        LOGITPROC_SOURCE_CLASS: logitproc provided as a class object.
    """

    LOGITPROC_SOURCE_NONE = auto()
    LOGITPROC_SOURCE_ENTRYPOINT = auto()
    LOGITPROC_SOURCE_FQCN = auto()
    LOGITPROC_SOURCE_CLASS = auto()


# Sample prompts.
prompts = [
    "Hello, my name is",
    "The president of the United States is",
    "The capital of France is",
    "The future of AI is",
]


class DummyLogitsProcessor(LogitsProcessor):
    """Fake logit processor to support unit testing and examples.

    A request opts in by setting ``extra_args[DUMMY_LOGITPROC_ARG]``
    ("target_token") to a token id. For every opted-in request, all logits
    in its row are set to ``-inf`` except the target token's, so that token
    is the only one that can be sampled. Other rows are left untouched.
    """

    def __init__(self, vllm_config: "VllmConfig", device: torch.device,
                 is_pin_memory: bool):
        # Batch row index -> target token id, for opted-in requests only.
        self.req_info: dict[int, int] = {}

    def is_argmax_invariant(self) -> bool:
        """Return False: this processor can change the argmax.

        Masking every token except the target can change which token has
        the highest logit, so greedy sampling is affected.
        """
        return False

    def update_state(self, batch_update: Optional[BatchUpdate]):
        """Keep ``req_info`` in sync with persistent-batch changes.

        Args:
            batch_update: Requests removed, added and moved since the last
                call, or None/falsy when nothing changed.
        """
        if not batch_update:
            return

        # Handle removals before additions, so an added request that reuses
        # a freed index is never dropped by a later removal.
        for index in batch_update.removed:
            self.req_info.pop(index, None)

        # Process added requests.
        for index, params, _, _ in batch_update.added:
            assert params is not None
            extra_args = params.extra_args or {}
            # Compare against None: token id 0 is a valid (falsy) target.
            target_token = extra_args.get(DUMMY_LOGITPROC_ARG)
            if target_token is not None:
                self.req_info[index] = target_token
            else:
                # The new request may occupy an index whose previous
                # occupant had a target; drop that stale entry.
                self.req_info.pop(index, None)

        if not self.req_info:
            # No tracked requests, so moves cannot affect any entry.
            return

        # Process moved requests: unidirectional move (a->b) and swap
        # (a<->b). For a->b, b's old entry is discarded because b is
        # overwritten.
        for adx, bdx, direct in batch_update.moved:
            a_val = self.req_info.pop(adx, None)
            b_val = self.req_info.pop(bdx, None)
            if a_val is not None:
                self.req_info[bdx] = a_val
            if direct == MoveDirectionality.SWAP and b_val is not None:
                self.req_info[adx] = b_val

    def apply(self, logits: torch.Tensor) -> torch.Tensor:
        """Mask all but the target token in each opted-in row.

        Args:
            logits: 2-D tensor indexed as ``logits[row, token_id]``, with
                one row per batch index tracked in ``req_info``. Modified
                in place.

        Returns:
            The same ``logits`` tensor.
        """
        if not self.req_info:
            return logits

        rows_list = list(self.req_info)
        rows = torch.tensor(rows_list, dtype=torch.long, device=logits.device)
        cols = torch.tensor([self.req_info[i] for i in rows_list],
                            dtype=torch.long,
                            device=logits.device)
        # Save the target logits before their rows are overwritten.
        values_to_keep = logits[rows, cols].clone()

        # Mask all but the target tokens.
        logits[rows] = float('-inf')
        logits[rows, cols] = values_to_keep

        return logits


# Fake module that exposes DummyLogitsProcessor as an attribute, matching
# DUMMY_LOGITPROC_FQCN ("DummyModule:DummyLogitsProcessor").
# NOTE(review): presumably tests install it in sys.modules so the FQCN
# resolves; confirm against callers.
# The former bare-string "docstring" was a no-op statement. It is now
# passed as the module's real docstring (ModuleType's ``doc`` argument).
dummy_module = types.ModuleType(DUMMY_LOGITPROC_MODULE,
                                "Dummy module with dummy logitproc class")
dummy_module.DummyLogitsProcessor = DummyLogitsProcessor  # type: ignore


class EntryPoint:
    """Stand-in for one installed logitsprocs entrypoint (testing only).

    Exposes ``name`` and ``value`` attributes. Its ``load()`` method returns
    the DummyLogitsProcessor class.
    """

    def __init__(self):
        # Entrypoint name, plus its "module:ClassName" target string.
        self.name, self.value = (DUMMY_LOGITPROC_ENTRYPOINT,
                                 DUMMY_LOGITPROC_FQCN)

    def load(self):
        """Return the logitproc class this entrypoint refers to."""
        target_cls = DummyLogitsProcessor
        return target_cls


class EntryPoints(list):
    """Stand-in collection of entrypoints for logitsprocs testing.

    Contains one EntryPoint when ``group`` equals LOGITSPROCS_GROUP and is
    empty for any other group.
    """

    def __init__(self, group: str):
        matching = []
        if group == LOGITSPROCS_GROUP:
            matching.append(EntryPoint())
        # List contents give len()/iteration/indexing over the entrypoints.
        super().__init__(matching)
        # Names of the contained entrypoints, in order.
        self.names = [entry.name for entry in matching]


def entry_points(group: str) -> "EntryPoints":
    """Fake version of importlib.metadata.entry_points.

    Args:
        group: Entrypoint group to select. It may be passed positionally
            or as ``group=...``.

    Returns:
        An EntryPoints list containing the dummy logitproc entrypoint when
        ``group`` is LOGITSPROCS_GROUP, and an empty one otherwise.
    """
    return EntryPoints(group)
