# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import time
from typing import Optional

import torch
import torch.nn as nn
import torch_xla.core.xla_model as xm
import torch_xla.distributed.spmd as xs

from vllm.config import ModelConfig, VllmConfig
from vllm.distributed.tpu_distributed_utils import get_fqn, shard_model
from vllm.logger import init_logger
from vllm.model_executor.model_loader.default_loader import DefaultModelLoader
from vllm.model_executor.model_loader.utils import (
    initialize_model, process_weights_after_loading, set_default_torch_dtype)

logger = init_logger(__name__)


class TPUModelLoader(DefaultModelLoader):
    """
    A TPU model loader for model loading under SPMD mode.
    """

    def load_model(
        self,
        vllm_config: VllmConfig,
        model_config: ModelConfig,
        mesh: Optional[xs.Mesh] = None,
    ) -> nn.Module:
        # Initialize model and load weights on CPU. Then, during SPMD partition,
        # weights are sharded and transferred to TPUs.
        self.counter_before_loading_weights = time.perf_counter()
        model_config = vllm_config.model_config
        assert model_config.quantization is None, "Quantization not supported"
        target_device = torch.device('cpu')
        with set_default_torch_dtype(model_config.dtype):
            with target_device:
                model = initialize_model(vllm_config=vllm_config)

            load_format = vllm_config.load_config.load_format
            if load_format != "dummy":
                weights_to_load = {
                    name
                    for name, _ in model.named_parameters()
                }
                all_weights = self.get_all_weights(model_config, model)
                loaded_weights = model.load_weights(all_weights)
                self.counter_after_loading_weights = time.perf_counter()
                logger.info(
                    "Loading weights took %.2f seconds",
                    self.counter_after_loading_weights -
                    self.counter_before_loading_weights)
                # We only enable strict check for non-quantized models
                # that have loaded weights tracking currently.
                if model_config.quantization is None and \
                    loaded_weights is not None:
                    weights_not_loaded = weights_to_load - loaded_weights
                    if weights_not_loaded:
                        raise ValueError(
                            "Following weights were not initialized from "
                            f"checkpoint: {weights_not_loaded}")
            else:
                logger.info("Use dummy weight during weight loading.")

            process_weights_after_loading(model, model_config, target_device)

        counter_before_partition = time.perf_counter()
        model = model.eval()
        model = model.to('xla')
        shard_model(model, mesh)
        counter_after_partition = time.perf_counter()
        logger.info("Partition model took %.2f seconds",
                    counter_after_partition - counter_before_partition)

        # Ensure the model is properly loaded.
        self._check_model_is_loaded(mesh, model)

        # Need to torch compile after model sharding are done. Because the
        # compiler hints ('xs.mark_sharding') are torch ops.
        if not model_config.is_multimodal_model:
            model.model = torch.compile(model.model, backend="openxla")
        else:
            model.language_model.model = \
                torch.compile(model.language_model.model, backend="openxla")
        return model

    def _check_model_is_loaded(self, mesh: Optional[xs.Mesh],
                               model: nn.Module) -> None:
        """
        Ensure the model is properly loaded.
        1. All model parameters and buffers are on XLA device.
        2. Non-SPMD friendly layers are replaced as expected.
        """
        device = xm.xla_device()
        device_type = str(device.type)

        # Check parameters
        for name, param in model.named_parameters():
            assert param.device.type == device_type, f"Parameter {name} is on \
                {param.device.type} instead of {device_type}"

        # Check buffers
        for name, buffer in model.named_buffers():
            assert buffer.device.type == device_type, \
                f"Buffer {name} is on {buffer.device.type} instead of \
                    {device_type}"

        for module in model.modules():
            if (mesh is not None) and (get_fqn(module) == 'QKVParallelLinear'):
                raise AssertionError("QKVParallelLinear should be replaced by \
                            XlaQKVParallelLinear under SPMD mode.")
