# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import json
import os
import struct
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass
from typing import Optional, Union

import torch
import zmq
from safetensors.torch import load as safetensors_load
from safetensors.torch import save as safetensors_save

from vllm.config import KVTransferConfig
from vllm.distributed.kv_transfer.kv_pipe.base import KVPipeBase
from vllm.logger import init_logger
from vllm.utils import join_host_port, make_zmq_path, split_host_port

logger = init_logger(__name__)
NONE_INT = -150886311


@dataclass
class MooncakeTransferEngineConfig:
    prefill_url: str
    decode_url: str
    metadata_backend: Union[str, None]
    metadata_server: str
    protocol: str
    device_name: str

    @staticmethod
    def from_file(file_path: str) -> 'MooncakeTransferEngineConfig':
        """Load the config from a JSON file."""
        with open(file_path) as fin:
            config = json.load(fin)
        return MooncakeTransferEngineConfig(
            prefill_url=config.get("prefill_url"),
            decode_url=config.get("decode_url"),
            metadata_backend=config.get("metadata_backend", None),
            metadata_server=config.get("metadata_server"),
            protocol=config.get("protocol", "tcp"),
            device_name=config.get("device_name", ""),
        )

    @staticmethod
    def load_from_env() -> 'MooncakeTransferEngineConfig':
        """Load config from a file specified in the environment variable."""
        config_file_path = os.getenv('MOONCAKE_CONFIG_PATH')
        if config_file_path is None:
            raise ValueError(
                "The environment variable 'MOONCAKE_CONFIG_PATH' is not set.")
        return MooncakeTransferEngineConfig.from_file(config_file_path)


class MooncakeTransferEngine:
    """Handles the transfer of data using mooncake_vllm_adaptor and ZeroMQ."""

    def __init__(self, kv_rank: int, local_rank: int):
        try:
            from mooncake.engine import TransferEngine
        except ImportError as e:
            raise ImportError(
                "Please install mooncake by following the instructions at "
                "https://github.com/kvcache-ai/Mooncake/blob/main/doc/en/build.md "  # noqa: E501
                "to run vLLM with MooncakeConnector.") from e

        self.engine = TransferEngine()
        self.local_rank = local_rank

        try:
            self.config = MooncakeTransferEngineConfig.load_from_env()
            logger.info("Mooncake Configuration loaded successfully.")
        except ValueError as e:
            logger.error(e)
            raise
        except Exception as exc:
            logger.error(
                "An error occurred while loading the configuration: %s", exc)
            raise
        prefill_host, base_prefill_port = split_host_port(
            self.config.prefill_url)
        decode_host, base_decode_port = split_host_port(self.config.decode_url)

        # Avoid ports conflict when running prefill and decode on the same node
        if prefill_host == decode_host and \
                base_prefill_port == base_decode_port:
            base_decode_port = base_decode_port + 100

        prefill_port = base_prefill_port + self.local_rank
        decode_port = base_decode_port + self.local_rank
        self.prefill_url = join_host_port(prefill_host, prefill_port)
        self.decode_url = join_host_port(decode_host, decode_port)

        self.initialize(self.prefill_url if kv_rank == 0 else self.decode_url,
                        self.config.metadata_server, self.config.protocol,
                        self.config.device_name, self.config.metadata_backend)

        self.remote_url = (self.decode_url
                           if kv_rank == 0 else self.prefill_url)

        # Initialize ZeroMQ context and sockets
        self.context = zmq.Context()  # type: ignore[attr-defined]
        self.sender_socket = self.context.socket(zmq.constants.PUSH)
        self.receiver_socket = self.context.socket(zmq.constants.PULL)
        self.sender_ack = self.context.socket(zmq.constants.PULL)
        self.receiver_ack = self.context.socket(zmq.constants.PUSH)

        self.buffer_cleaner = ThreadPoolExecutor(max_workers=1)
        self._setup_metadata_sockets(kv_rank, prefill_host, base_prefill_port,
                                     decode_host, base_decode_port)

    def _setup_metadata_sockets(self, kv_rank: int, p_host: str, p_port: int,
                                d_host: str, d_port: int) -> None:
        """Set up ZeroMQ sockets for sending and receiving data."""
        # Offsets < 8 are left for initialization in case tp and pp are enabled
        p_rank_offset = p_port + 8 + self.local_rank * 2
        d_rank_offset = d_port + 8 + self.local_rank * 2
        if kv_rank == 0:
            self.sender_socket.bind(
                make_zmq_path("tcp", p_host, p_rank_offset + 1))
            self.receiver_socket.connect(
                make_zmq_path("tcp", d_host, d_rank_offset + 1))
            self.sender_ack.connect(
                make_zmq_path("tcp", d_host, d_rank_offset + 2))
            self.receiver_ack.bind(
                make_zmq_path("tcp", p_host, p_rank_offset + 2))
        else:
            self.receiver_socket.connect(
                make_zmq_path("tcp", p_host, p_rank_offset + 1))
            self.sender_socket.bind(
                make_zmq_path("tcp", d_host, d_rank_offset + 1))
            self.receiver_ack.bind(
                make_zmq_path("tcp", d_host, d_rank_offset + 2))
            self.sender_ack.connect(
                make_zmq_path("tcp", p_host, p_rank_offset + 2))

    def initialize(self, local_hostname: str, metadata_server: str,
                   protocol: str, device_name: str,
                   metadata_backend: Union[str, None]) -> None:
        """Initialize the mooncake instance."""
        if metadata_backend is None:
            self.engine.initialize(local_hostname, metadata_server, protocol,
                                   device_name)
        else:
            supported_backend = ["etcd", "redis"]
            metadata_backend = metadata_backend.lower()
            if metadata_backend not in supported_backend:
                raise ValueError(
                    "Mooncake Configuration error. `metadata_backend`"
                    f" should be one of {supported_backend}.")

            self.engine.initialize_ext(local_hostname, metadata_server,
                                       protocol, device_name, metadata_backend)

    def allocate_managed_buffer(self, length: int) -> int:
        """Allocate a managed buffer of the specified length."""
        ret = self.engine.allocate_managed_buffer(length)
        if ret <= 0:
            logger.error("Allocation Return Error")
            raise Exception("Allocation Return Error")
        return ret

    def free_managed_buffer(self, buffer: int, length: int) -> int:
        """Free a previously allocated managed buffer."""
        return self.engine.free_managed_buffer(buffer, length)

    def transfer_sync(self, buffer: int, peer_buffer_address: int,
                      length: int) -> int:
        """Synchronously transfer data to the specified address."""
        ret = self.engine.transfer_sync_read(self.remote_url, buffer,
                                             peer_buffer_address, length)
        if ret < 0:
            logger.error("Transfer Return Error")
            raise Exception("Transfer Return Error")
        return ret

    def write_bytes_to_buffer(self, buffer: int, user_data: bytes,
                              length: int) -> int:
        """Write bytes to the allocated buffer."""
        return self.engine.write_bytes_to_buffer(buffer, user_data, length)

    def read_bytes_from_buffer(self, buffer: int, length: int) -> bytes:
        """Read bytes from the allocated buffer."""
        return self.engine.read_bytes_from_buffer(buffer, length)

    def wait_for_ack(self, src_ptr: int, length: int) -> None:
        """Asynchronously wait for ACK from the receiver."""
        ack = self.sender_ack.recv()
        if ack != b'ACK':
            logger.error("Failed to receive ACK from the receiver")

        self.free_managed_buffer(src_ptr, length)

    def send_bytes(self, user_data: bytes) -> None:
        """Send bytes to the remote process."""
        length = len(user_data)
        src_ptr = self.allocate_managed_buffer(length)
        self.write_bytes_to_buffer(src_ptr, user_data, length)
        self.sender_socket.send_multipart(
            [struct.pack("!Q", src_ptr),
             struct.pack("!Q", length)])
        self.buffer_cleaner.submit(self.wait_for_ack, src_ptr, length)

    def recv_bytes(self) -> bytes:
        """Receive bytes from the remote process."""
        data = self.receiver_socket.recv_multipart()
        src_ptr = struct.unpack("!Q", data[0])[0]
        length = struct.unpack("!Q", data[1])[0]
        dst_ptr = self.allocate_managed_buffer(length)
        self.transfer_sync(dst_ptr, src_ptr, length)
        ret = self.read_bytes_from_buffer(dst_ptr, length)

        # Buffer cleanup
        self.receiver_ack.send(b'ACK')
        self.free_managed_buffer(dst_ptr, length)

        return ret


class MooncakePipe(KVPipeBase):
    """MooncakeTransferEngine based Pipe implementation."""

    def __init__(self,
                 local_rank: int,
                 config: KVTransferConfig,
                 device: Optional[str] = None):
        """Initialize the mooncake pipe and set related parameters."""
        self.config = config
        self.local_rank = local_rank
        self.kv_rank = self.config.kv_rank
        if device is None:
            self.device = self._select_device(self.config.kv_buffer_device)
        else:
            self.device = self._select_device(device)

        self.transfer_engine = MooncakeTransferEngine(self.kv_rank,
                                                      self.local_rank)
        self.transport_thread: Optional[ThreadPoolExecutor] = None
        self.none_tensor = torch.tensor([NONE_INT], device=self.device)

    def _select_device(self, device: str) -> torch.device:
        """Select available device (CUDA or CPU)."""
        logger.info("Selecting device: %s", device)
        if device == "cuda":
            return torch.device(f"cuda:{self.local_rank}")
        else:
            return torch.device("cpu")

    def tensor_hash(self, tensor: torch.Tensor) -> int:
        """Calculate the hash value of the tensor."""
        return hash(tensor.data_ptr())

    def _send_impl(self, tensor: torch.Tensor) -> None:
        """Implement the tensor sending logic using safetensors."""
        self.transfer_engine.send_bytes(safetensors_save({"tensor": tensor}))

    def _recv_impl(self) -> torch.Tensor:
        """Implement the tensor receiving logic using safetensors."""
        data = self.transfer_engine.recv_bytes()
        return safetensors_load(data)["tensor"].to(self.device)

    def send_tensor(self, tensor: Optional[torch.Tensor]) -> None:
        """Send tensor to the target process."""
        if self.transport_thread is None:
            self.transport_thread = ThreadPoolExecutor(max_workers=1)
        tensor = tensor if tensor is not None else self.none_tensor
        assert (len(tensor.shape) > 0)
        self.transport_thread.submit(self._send_impl, tensor)

    def recv_tensor(self) -> Optional[torch.Tensor]:
        """Receive tensor from other processes."""
        if self.transport_thread is None:
            self.transport_thread = ThreadPoolExecutor(max_workers=1)
        tensor = self.transport_thread.submit(self._recv_impl).result()
        if tensor.numel() == 1 and tensor.item() == NONE_INT:
            return None
        else:
            return tensor

    def close(self) -> None:
        """Cleanup logic when closing the pipe."""
        self.transfer_engine.sender_socket.close()
        self.transfer_engine.receiver_socket.close()
        self.transfer_engine.sender_ack.close()
        self.transfer_engine.receiver_ack.close()
        self.transfer_engine.context.term()  # Terminate the ZMQ context
        logger.info("Closed the transfer engine and cleaned up resources.")
