# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Sequence
from typing import TYPE_CHECKING, Optional

from vllm.multimodal import MultiModalRegistry
from vllm.multimodal.cache import MultiModalCache, MultiModalCacheItemMetadata
from vllm.multimodal.inputs import MultiModalKwargsItem
from vllm.utils import is_list_of

if TYPE_CHECKING:
    from vllm.config import ModelConfig

# The idea of multimodal input caching is based on having a client and
# a server, where the client executes in the frontend process (=P0) and the
# server in the core process (=P1).
#
# -- P0:
#  - BaseMultiModalProcessor calls MultiModalHasher to get the `mm_hash` of
#    each input multi-modal item (e.g. image),
#  - BaseMultiModalProcessor processes the input items into `mm_kwargs`,
#    which are MultiModalKwargsItem instances that each correspond to an
#    input multi-modal item.
#  - MultiModalInputCacheClient accepts the `mm_kwargs` and corresponding
#    `mm_hash` for each item. It stores the `mm_hash` as keys and the size
#    of `mm_kwargs`, but not the `mm_kwargs` themselves, to avoid taking
#    up additional memory in P0.
#  - The `mm_hash` is always sent to P1.
#  - The corresponding `mm_kwargs` are only sent to P1 if they are not cached
#    in MultiModalInputCacheServer.
#
# -- P1:
#  - If the `mm_hash` is cached (i.e. `mm_kwargs` are not sent from P0),
#    MultiModalInputCacheServer retrieves the corresponding `mm_kwargs`.
#  - If the `mm_hash` is not cached (i.e. `mm_kwargs` are sent from P0),
#    MultiModalInputCacheServer stores `mm_kwargs` under the key `mm_hash`.
#  - Either way, the `mm_hash` and corresponding `mm_kwargs` are sent to
#    the engine for model execution.
#
# Both Client and Server must perform cache update and eviction based on the
# same item size. This ensures that the keys of MultiModalInputCacheClient
# and MultiModalInputCacheServer are mirrored, allowing us to determine in P0
# whether a key is cached in MultiModalInputCacheServer by querying
# MultiModalInputCacheClient without having to communicate with P1.


class MultiModalInputCacheClient:
    """Used by P0 to check whether multi-modal kwargs are cached in P1."""

    def __init__(self, model_config: "ModelConfig",
                 mm_registry: MultiModalRegistry) -> None:
        super().__init__()

        self.enabled = mm_registry.enable_mm_input_cache(model_config)
        self.mm_cache = MultiModalCache.get_lru_cache(
            model_config.get_mm_input_cache_gb(),
            MultiModalCacheItemMetadata,
        )

    def get_and_update(
        self,
        mm_kwargs: Sequence[MultiModalKwargsItem],
        mm_hashes: list[str],
    ) -> list[Optional[MultiModalKwargsItem]]:
        if not self.enabled:
            return list(mm_kwargs)

        assert len(mm_kwargs) == len(mm_hashes)

        out_mm_items = list[Optional[MultiModalKwargsItem]]()
        for mm_item, mm_hash in zip(mm_kwargs, mm_hashes):
            if self.mm_cache.get(mm_hash) is not None:
                out_mm_items.append(None)
            else:
                self.mm_cache[mm_hash] = \
                    MultiModalCacheItemMetadata.wraps(mm_item)
                out_mm_items.append(mm_item)

        return out_mm_items

    def reset(self) -> None:
        self.mm_cache.clear()


class MultiModalInputCacheServer:
    """Used by P1 to avoid requiring past multi-modal kwargs from P0."""

    def __init__(self, model_config: "ModelConfig",
                 mm_registry: MultiModalRegistry) -> None:
        super().__init__()

        self.enabled = mm_registry.enable_mm_input_cache(model_config)
        self.mm_cache = MultiModalCache.get_lru_cache(
            model_config.get_mm_input_cache_gb(),
            MultiModalKwargsItem,
        )

    def get_and_update(
        self,
        mm_kwargs: Sequence[Optional[MultiModalKwargsItem]],
        mm_hashes: list[str],
    ) -> list[MultiModalKwargsItem]:
        if not self.enabled:
            mm_kwargs_lst = list(mm_kwargs)
            assert is_list_of(mm_kwargs_lst, MultiModalKwargsItem)
            return mm_kwargs_lst

        assert len(mm_kwargs) == len(mm_hashes)

        out_mm_items = list[MultiModalKwargsItem]()
        for mm_item, mm_hash in zip(mm_kwargs, mm_hashes):
            if mm_item is None:
                out_mm_items.append(self.mm_cache[mm_hash])
            else:
                self.mm_cache[mm_hash] = mm_item
                out_mm_items.append(mm_item)

        return out_mm_items

    def reset(self) -> None:
        self.mm_cache.clear()
