# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Iterable
from typing import Optional

import torch
from torch import nn
from transformers import PretrainedConfig

from vllm.attention import Attention, AttentionType
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import (divide, get_tensor_model_parallel_rank,
                              get_tensor_model_parallel_world_size,
                              tensor_model_parallel_all_reduce)
from vllm.model_executor.layers.activation import (get_act_and_mul_fn,
                                                   get_act_fn)
from vllm.model_executor.layers.fused_moe.fused_moe import (
    fused_topk, torch_vllm_outplace_fused_experts)
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               MergedColumnParallelLinear,
                                               QKVParallelLinear,
                                               ReplicatedLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
    VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.interfaces import (SupportsQuant,
                                                   default_pooling_type)
from vllm.model_executor.models.utils import WeightsMapper
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors


class BertWithRopeEmbedding(nn.Module):
    """Input embedding stage for BERT variants that use rotary positions.

    Looks up word embeddings, optionally adds token-type embeddings, and
    applies LayerNorm. No position embedding is added in this module; the
    config must declare a ``"rope"`` or ``"rotary"``
    ``position_embedding_type``.
    """

    def __init__(self, config: PretrainedConfig):
        super().__init__()

        supported_types = ("rope", "rotary")
        if config.position_embedding_type not in supported_types:
            raise ValueError("Only 'rotary'('rope') position_embedding_type"
                             " is supported")

        width = config.hidden_size
        self.word_embeddings = VocabParallelEmbedding(config.vocab_size,
                                                      width)

        # A non-positive type vocabulary disables token-type embeddings.
        has_token_types = config.type_vocab_size > 0
        self.token_type_embeddings = (VocabParallelEmbedding(
            config.type_vocab_size, width) if has_token_types else None)

        self.LayerNorm = nn.LayerNorm(width, eps=config.layer_norm_eps)

    def forward(
        self,
        input_ids: torch.Tensor,
        token_type_ids: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Embed ``input_ids`` and normalize the result.

        Args:
            input_ids: Token ids to look up.
            token_type_ids: Optional segment ids with the same shape as
                ``input_ids``. When omitted (and token-type embeddings are
                enabled) every token is assigned type 0.

        Returns:
            The layer-normalized embeddings.
        """
        embeddings = self.word_embeddings(input_ids)

        if self.token_type_embeddings is None:
            return self.LayerNorm(embeddings)

        if token_type_ids is None:
            token_type_ids = torch.zeros(input_ids.size(),
                                         dtype=torch.long,
                                         device=embeddings.device)

        # In-place add: ``embeddings`` is the tensor produced just above.
        embeddings += self.token_type_embeddings(token_type_ids)
        return self.LayerNorm(embeddings)


class BertWithRopeAttention(nn.Module):
    """Encoder-only self-attention with rotary position embeddings.

    Uses a fused QKV projection, applies the rotary embedding built from
    ``rotary_kwargs`` to queries and keys, runs ``Attention`` with
    ``AttentionType.ENCODER_ONLY``, and projects the result back to
    ``hidden_size``. Attention heads are split evenly across tensor-parallel
    ranks; the number of KV heads equals the number of query heads.

    Args:
        hidden_size: Model width; must be divisible by
            ``num_attention_heads``.
        num_attention_heads: Total head count across all tensor-parallel
            ranks; must be divisible by the tensor-parallel world size.
        cache_config: Forwarded to ``Attention``.
        quant_config: Forwarded to the linear layers and ``Attention``.
        bias: Whether the QKV and output projections carry a bias.
        rotary_kwargs: Keyword arguments unpacked into ``get_rope``.
            Required, despite the ``None`` default.
        prefix: Module-name prefix used for the sub-layers.

    Raises:
        ValueError: If ``rotary_kwargs`` is ``None``.
    """

    def __init__(
        self,
        hidden_size: int,
        num_attention_heads: int,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        bias: bool = True,
        rotary_kwargs: Optional[dict] = None,
        prefix: str = "",
    ):
        super().__init__()

        # ``rotary_kwargs`` is unpacked into get_rope() below, so a missing
        # value can never work. Fail early with a readable message instead
        # of the opaque "argument after ** must be a mapping" TypeError.
        if rotary_kwargs is None:
            raise ValueError(
                "rotary_kwargs must be provided to build the rotary "
                "embedding of BertWithRopeAttention")

        self.hidden_size = hidden_size
        tp_size = get_tensor_model_parallel_world_size()

        self.total_num_heads = num_attention_heads
        assert self.total_num_heads % tp_size == 0

        self.num_heads = self.total_num_heads // tp_size
        # No grouped-query attention here: one KV head per query head.
        self.total_num_kv_heads = self.total_num_heads
        self.head_dim = self.hidden_size // self.total_num_heads
        assert self.head_dim * self.total_num_heads == self.hidden_size

        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)

        # Per-rank widths of the q / k / v slices of the fused projection.
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5

        self.qkv_proj = QKVParallelLinear(
            hidden_size=self.hidden_size,
            head_size=self.head_dim,
            total_num_heads=self.total_num_heads,
            total_num_kv_heads=self.total_num_kv_heads,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj")

        self.rotary_emb = get_rope(**rotary_kwargs)

        self.attn = Attention(num_heads=self.num_heads,
                              head_size=self.head_dim,
                              scale=self.scaling,
                              num_kv_heads=self.num_kv_heads,
                              cache_config=cache_config,
                              quant_config=quant_config,
                              prefix=f"{prefix}.attn",
                              attn_type=AttentionType.ENCODER_ONLY)

        self.out_proj = RowParallelLinear(input_size=hidden_size,
                                          output_size=hidden_size,
                                          bias=bias,
                                          quant_config=quant_config,
                                          prefix=f"{prefix}.dense")

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Apply rotary self-attention to ``hidden_states``.

        Args:
            positions: Positions passed to the rotary embedding.
            hidden_states: Input activations fed to the QKV projection.

        Returns:
            The output-projected attention result.
        """
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v)
        output, _ = self.out_proj(attn_output)
        return output


class BertWithRopeGatedMLP(nn.Module):
    """Gated feed-forward block.

    A single merged projection produces the gate and up halves (each
    ``intermediate_size`` wide), the activation-and-multiply function returned
    by ``get_act_and_mul_fn(hidden_act)`` combines them, and ``down_proj``
    maps the result back to ``hidden_size``.
    """

    def __init__(self,
                 hidden_size: int,
                 intermediate_size: int,
                 hidden_act: str,
                 bias: bool = True,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = ""):
        super().__init__()
        self.act_fn = get_act_and_mul_fn(hidden_act)

        # Two equally sized output shards: one for gate, one for up.
        merged_output_sizes = [intermediate_size, intermediate_size]
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size,
            merged_output_sizes,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.gate_up_proj",
        )
        self.down_proj = RowParallelLinear(
            input_size=intermediate_size,
            output_size=hidden_size,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.down_proj",
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Run merged projection -> gated activation -> down projection."""
        # Each linear layer returns a pair; only the first item is used.
        fused, _ = self.gate_up_proj(hidden_states)
        activated = self.act_fn(fused)
        projected, _ = self.down_proj(activated)
        return projected


class BertWithRopeMLP(nn.Module):
    """Plain (ungated) feed-forward block.

    Computes ``down_proj(act_fn(up_proj(x)))`` where ``act_fn`` is the
    activation returned by ``get_act_fn(hidden_act)``.
    """

    def __init__(self,
                 hidden_size: int,
                 intermediate_size: int,
                 hidden_act: str,
                 bias: bool = True,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = ""):
        super().__init__()
        self.act_fn = get_act_fn(hidden_act)
        self.up_proj = ColumnParallelLinear(
            input_size=hidden_size,
            output_size=intermediate_size,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.up_proj",
        )
        self.down_proj = RowParallelLinear(
            input_size=intermediate_size,
            output_size=hidden_size,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.down_proj",
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Run up projection -> activation -> down projection."""
        # Each linear layer returns a pair; only the first item is used.
        expanded, _ = self.up_proj(hidden_states)
        activated = self.act_fn(expanded)
        projected, _ = self.down_proj(activated)
        return projected


class NomicMoE(nn.Module):
    """Mixture-of-experts feed-forward layer for Nomic MoE checkpoints.

    Holds a replicated router, two stacked expert weight tensors (``w1`` and
    ``w2``) whose intermediate dimension is divided across tensor-parallel
    ranks, and a bias of size ``hidden_size`` that is added to the combined
    expert output.
    """

    def __init__(
        self,
        num_experts: int,
        top_k: int,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        params_dtype: Optional[torch.dtype] = None,
        tp_size: Optional[int] = None,
    ):
        super().__init__()

        self.tp_size = tp_size or get_tensor_model_parallel_world_size()
        self.num_total_experts = num_experts
        self.top_k = top_k
        self.hidden_size = hidden_size
        self.total_intermediate_size = intermediate_size
        # Per-rank slice of the intermediate dimension.
        self.intermediate_size = divide(intermediate_size, self.tp_size)
        self.hidden_act = hidden_act
        self.params_dtype = (torch.get_default_dtype()
                             if params_dtype is None else params_dtype)

        self.router = ReplicatedLinear(self.hidden_size,
                                       self.num_total_experts,
                                       bias=False)

        def new_expert_weight(dim1: int, dim2: int) -> nn.Parameter:
            # Uninitialized (num_experts, dim1, dim2) tensor; filled in by
            # ``weight_loader``.
            return nn.Parameter(
                torch.empty(self.num_total_experts,
                            dim1,
                            dim2,
                            device=current_platform.device_type,
                            dtype=self.params_dtype))

        self.w1 = new_expert_weight(self.intermediate_size, self.hidden_size)
        self.w2 = new_expert_weight(self.hidden_size, self.intermediate_size)
        self.bias = nn.Parameter(torch.zeros(self.hidden_size))

        for expert_weight in (self.w1, self.w2):
            set_weight_attrs(expert_weight,
                             {"weight_loader": self.weight_loader})

    def weight_loader(
        self,
        param: nn.Parameter,
        loaded_weight: torch.Tensor,
        weight_name: str,
    ):
        """Copy this rank's shard of a fused expert weight into ``param``.

        Args:
            param: Destination parameter (``w1`` or ``w2``).
            loaded_weight: Checkpoint tensor for all experts.
            weight_name: Checkpoint name; its ``"w1"`` / ``"w2"`` suffix
                selects the layout handling.
        """
        # NOTE: Nomic-MoE has fused experts weights with shape
        # (num_experts * intermediate_size, hidden_size)
        tp_rank = get_tensor_model_parallel_rank()
        begin = tp_rank * self.intermediate_size
        shard = slice(begin, begin + self.intermediate_size)

        is_w1 = weight_name.endswith("w1")
        is_w2 = weight_name.endswith("w2")
        if is_w1 or is_w2:
            # Un-fuse the expert dimension, then keep this rank's rows of
            # the intermediate dimension.
            loaded_weight = loaded_weight.reshape(
                self.num_total_experts,
                self.total_intermediate_size,
                self.hidden_size,
            )[:, shard]
            if is_w2:
                # w2 is stored as (experts, hidden, intermediate).
                loaded_weight = loaded_weight.transpose(1, 2)

        param.data.copy_(loaded_weight)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Route tokens to their top-k experts and combine the outputs.

        Args:
            hidden_states: 2-D activations of shape (num_tokens, hidden).

        Returns:
            Tensor of the same shape as the input, with ``bias`` added.
        """
        num_tokens, hidden_size = hidden_states.shape
        hidden_states = hidden_states.view(-1, self.hidden_size)
        # router_logits: (num_tokens, n_experts)
        router_logits, _ = self.router(hidden_states)
        # FIXME(Isotr0py): This implementation is too tricky,
        # we should use FusedMoE instead in the future
        # after supporting ungated activation for it.
        routing_weights, expert_ids, _ = fused_topk(hidden_states,
                                                    router_logits,
                                                    self.top_k,
                                                    renormalize=False)
        expert_output = torch_vllm_outplace_fused_experts(
            hidden_states=hidden_states,
            w1=self.w1,
            w2=self.w2,
            topk_weights=routing_weights,
            topk_ids=expert_ids,
            activation=self.hidden_act,
            is_act_and_mul=False,
        )

        # Each rank only holds a slice of the intermediate dimension, so the
        # partial results are reduced across ranks.
        if self.tp_size > 1:
            expert_output = tensor_model_parallel_all_reduce(expert_output)

        return expert_output.view(num_tokens, hidden_size) + self.bias


class BertWithRopeBlock(nn.Module):
    """One post-LayerNorm transformer encoder block.

    Computes ``attn_ln(x + attn(x))`` followed by ``mlp_ln(h + mlp(h))``.
    The feed-forward part is a ``NomicMoE`` when ``moe`` is true, a gated
    MLP when ``config.hidden_act`` is ``"silu"`` or ``"geglu"``, and a plain
    MLP otherwise.
    """

    def __init__(self,
                 config: PretrainedConfig,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None,
                 moe: bool = False,
                 bias: bool = True,
                 rotary_kwargs: Optional[dict] = None,
                 prefix: str = ""):
        super().__init__()
        self.attn = BertWithRopeAttention(
            hidden_size=config.hidden_size,
            num_attention_heads=config.num_attention_heads,
            cache_config=cache_config,
            quant_config=quant_config,
            bias=bias,
            rotary_kwargs=rotary_kwargs,
            prefix=f"{prefix}.attention",
        )

        if moe:
            self.mlp = NomicMoE(
                num_experts=config.num_experts,
                top_k=config.moe_top_k,
                hidden_size=config.hidden_size,
                intermediate_size=config.intermediate_size,
                hidden_act=config.hidden_act,
            )
        elif config.hidden_act in ("silu", "geglu"):
            self.mlp = BertWithRopeGatedMLP(
                hidden_size=config.hidden_size,
                intermediate_size=config.intermediate_size,
                hidden_act=config.hidden_act,
                bias=bias,
                quant_config=quant_config,
                prefix=f"{prefix}.mlp",
            )
        else:
            self.mlp = BertWithRopeMLP(
                hidden_size=config.hidden_size,
                intermediate_size=config.intermediate_size,
                hidden_act=config.hidden_act,
                bias=bias,
                quant_config=quant_config,
                prefix=f"{prefix}.mlp",
            )

        norm_eps = config.layer_norm_eps
        self.attn_ln = nn.LayerNorm(config.hidden_size, eps=norm_eps)
        self.mlp_ln = nn.LayerNorm(config.hidden_size, eps=norm_eps)

    def forward(self, positions: torch.Tensor, hidden_states: torch.Tensor):
        """Apply attention and feed-forward sub-layers with residuals."""
        # Residual connection first, LayerNorm afterwards (post-LN).
        attended = self.attn(positions, hidden_states)
        normed = self.attn_ln(hidden_states + attended)
        return self.mlp_ln(normed + self.mlp(normed))


class BertWithRopeEncoder(nn.Module):
    """Stack of ``config.num_hidden_layers`` ``BertWithRopeBlock`` layers.

    When the HF config defines a positive ``moe_every_n_layers``, every layer
    whose index satisfies ``layer_idx % moe_every_n_layers == 1`` uses the
    MoE feed-forward; all other layers use a dense MLP.
    """

    def __init__(self,
                 vllm_config: VllmConfig,
                 bias: bool = True,
                 rotary_kwargs: Optional[dict] = None,
                 prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config
        every_n = getattr(config, "moe_every_n_layers", 0)

        blocks = []
        for layer_idx in range(config.num_hidden_layers):
            # The ``every_n > 0`` check short-circuits, so the modulo is
            # never evaluated with a zero divisor.
            uses_moe = every_n > 0 and (layer_idx % every_n == 1)
            blocks.append(
                BertWithRopeBlock(config=config,
                                  cache_config=cache_config,
                                  quant_config=quant_config,
                                  bias=bias,
                                  moe=uses_moe,
                                  rotary_kwargs=rotary_kwargs,
                                  prefix=f"{prefix}.layer.{layer_idx}"))
        self.layers = nn.ModuleList(blocks)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Feed ``hidden_states`` through every block in order."""
        output = hidden_states
        for block in self.layers:
            output = block(positions, output)
        return output


@support_torch_compile
@default_pooling_type("CLS")
class BertWithRope(nn.Module, SupportsQuant):
    """BERT-style encoder with rotary position embeddings.

    Combines ``BertWithRopeEmbedding`` and ``BertWithRopeEncoder``. The
    rotary configuration is read from ``hf_config.rotary_kwargs`` and the
    linear-layer bias flag from ``hf_config.bias`` (default ``True``).
    Subclasses adapt checkpoint naming through ``hf_to_vllm_mapper``.
    """

    hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={"model.": ""})

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        self.vllm_config = vllm_config
        self.config = vllm_config.model_config.hf_config
        self.embeddings = BertWithRopeEmbedding(self.config)
        self.encoder = BertWithRopeEncoder(
            vllm_config=vllm_config,
            bias=getattr(self.config, "bias", True),
            rotary_kwargs=self.config.rotary_kwargs,
            prefix=f"{prefix}.encoder")

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Encode a batch of tokens.

        Args:
            input_ids: Token ids; ignored when ``inputs_embeds`` is given.
            positions: Positions passed to every encoder layer.
            intermediate_tensors: Accepted for interface compatibility;
                not used by this method.
            inputs_embeds: Optional precomputed embeddings that bypass the
                embedding module entirely.
            token_type_ids: Optional segment ids for the embedding module.

        Returns:
            The hidden states produced by the last encoder layer.
        """
        if inputs_embeds is not None:
            hidden_states = inputs_embeds
        else:
            hidden_states = self.embeddings(input_ids=input_ids,
                                            token_type_ids=token_type_ids)
        return self.encoder(positions, hidden_states)

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Load checkpoint tensors into this module's parameters.

        Names are first rewritten with ``hf_to_vllm_mapper``. Weights whose
        name contains ``"pooler"`` are ignored. For gated activations the
        separate ``gate_proj`` / ``up_proj`` tensors are loaded as shards 0
        and 1 of the merged ``gate_up_proj`` parameter. Bias tensors with no
        matching parameter are skipped.

        Args:
            weights: Iterable of ``(checkpoint_name, tensor)`` pairs.

        Returns:
            The set of parameter names that were loaded.

        Raises:
            KeyError: If a (non-skipped) name has no matching parameter.
        """
        weights = self.hf_to_vllm_mapper.apply(weights)

        if self.config.hidden_act in ["silu", "geglu"]:
            stacked_params_mapping = [
                # (param_name, shard_name, shard_id)
                ("gate_up_proj", "gate_proj", 0),
                ("gate_up_proj", "up_proj", 1),
            ]
        else:
            stacked_params_mapping = []

        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()
        for name, loaded_weight in weights:
            if "pooler" in name:
                continue

            # Resolve the stacked-parameter mapping exactly once. Stopping
            # at the first match matters: the rewritten name contains
            # "up_proj" (inside "gate_up_proj"), so re-entering the mapping
            # loop with it would rewrite the name a second time.
            shard_id = None
            for (param_name, weight_name,
                 candidate_shard_id) in stacked_params_mapping:
                if weight_name in name:
                    name = name.replace(weight_name, param_name)
                    shard_id = candidate_shard_id
                    break

            # Skip loading extra bias for GPTQ models.
            if name.endswith(".bias") and name not in params_dict:
                continue

            param = params_dict[name]
            if shard_id is not None:
                # Merged parameters provide their own shard-aware loader.
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
            else:
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                if name.endswith((".w1", ".w2")):
                    # Nomic-MoE has fused experts weights
                    weight_loader(param, loaded_weight, name)
                else:
                    weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params


class NomicBertModel(BertWithRope):
    """``BertWithRope`` with the checkpoint-name mapping for Nomic BERT."""
    # for https://huggingface.co/nomic-ai/nomic-bert-2048

    # Substring renames handed to ``hf_to_vllm_mapper.apply`` by the
    # inherited ``load_weights``. Several keys overlap (e.g. "experts.mlp."
    # and "experts."), so the entry order looks significant — keep it as is.
    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_substr={
            "emb_ln": "embeddings.LayerNorm",
            "attn.Wqkv": "attn.qkv_proj",
            "norm1": "attn_ln",
            # Ungated MLP: fc1 -> up_proj.
            "mlp.fc1.": "mlp.up_proj.",
            # Gated MLP: fc11 / fc12 become up_proj / gate_proj, which the
            # inherited ``load_weights`` stacks into ``gate_up_proj``.
            "mlp.fc11": "mlp.up_proj",
            "mlp.fc12": "mlp.gate_proj",
            "mlp.fc2": "mlp.down_proj",
            "norm2": "mlp_ln",
            # MoE mapping
            # Dropping the "experts." prefixes presumably leaves the expert
            # tensors directly under ``mlp`` (w1 / w2 / bias) — verify
            # against the checkpoint layout.
            "experts.mlp.": "",
            "experts.": "",
            "router.layer": "router",
        })


class GteNewModel(BertWithRope):
    """``BertWithRope`` adapted to https://huggingface.co/Alibaba-NLP/new-impl.

    The checkpoint stores the up and gate projections fused in one
    ``mlp.up_gate_proj`` tensor and carries ``classifier`` weights that this
    model does not use; both are handled before delegating to the inherited
    ``load_weights``.
    """

    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_substr={
            "new.": "",
            "layer": "layers",
            "attention.qkv_proj": "attn.qkv_proj",
            "attention.o_proj": "attn.out_proj",
        })

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__(vllm_config=vllm_config, prefix=prefix)

        # GteNewModel only gate_up_proj does not have bias.
        # Hack method learned from vllm/model_executor/models/glm.py
        for block in self.encoder.layers:
            fused_proj = block.mlp.gate_up_proj
            fused_proj.bias = None
            fused_proj.skip_bias_add = True

    def split_up_gate_proj(self, weights: Iterable[tuple[str, torch.Tensor]]):
        """Yield weights, splitting each fused ``mlp.up_gate_proj`` tensor.

        The fused tensor is cut into two equal chunks along dim 0; the first
        chunk is emitted as ``mlp.up_proj`` and the second as
        ``mlp.gate_proj``. Every other weight passes through unchanged.
        """
        fused_key = "mlp.up_gate_proj"
        for name, weight in weights:
            if fused_key not in name:
                yield name, weight
                continue

            up_half, gate_half = weight.chunk(2, dim=0)
            yield name.replace(fused_key, "mlp.up_proj"), up_half
            yield name.replace(fused_key, "mlp.gate_proj"), gate_half

    def ignore_unnecessary_layers(self,
                                  weights: Iterable[tuple[str, torch.Tensor]]):
        """Yield every weight whose name does not start with ``classifier``."""
        for name, weight in weights:
            if not name.startswith("classifier"):
                yield name, weight

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Filter and un-fuse the checkpoint stream, then load it."""
        filtered = self.ignore_unnecessary_layers(weights)
        unfused = self.split_up_gate_proj(filtered)
        return super().load_weights(unfused)


class SnowflakeGteNewModel(GteNewModel):
    """``GteNewModel`` with the mapping for Snowflake arctic-embed."""
    # for Snowflake/snowflake-arctic-embed-m-v2.0

    # Same renames as GteNewModel's mapper minus the "new." -> "" entry.
    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_substr={
            "layer": "layers",
            "attention.qkv_proj": "attn.qkv_proj",
            "attention.o_proj": "attn.out_proj",
        })


class JinaRobertaModel(BertWithRope):
    """``BertWithRope`` adapted to jina-embeddings-v3.

    For https://huggingface.co/jinaai/jina-embeddings-v3. The checkpoint
    stores LoRA-parametrized weights (``*.original`` plus ``*.0.lora_A`` /
    ``*.0.lora_B``); these are merged into plain weight tensors before the
    inherited ``load_weights`` runs.
    """

    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_substr={
            "emb_ln": "embeddings.LayerNorm",
            "mixer.Wqkv": "attn.qkv_proj",
            "mixer.out_proj": "attn.out_proj",
            "norm1": "attn_ln",
            "mlp.fc1.": "mlp.up_proj.",
            "mlp.fc2": "mlp.down_proj",
            "norm2": "mlp_ln",
        })

    @torch.inference_mode()
    def jina_merge_lora_weights(self, weights: Iterable[tuple[str,
                                                              torch.Tensor]]):
        """Fold one LoRA adapter into each parametrized base weight.

        For every entry named ``<prefix>.original`` the merged tensor
        ``original + (B @ A).view(original.shape) * scaling`` is stored under
        ``<prefix>`` with ``".parametrizations"`` removed, and the three
        source entries are dropped. ``scaling`` is
        ``config.lora_alpha / config.lora_rank``. All other entries are
        returned unchanged.

        Args:
            weights: Iterable of ``(name, tensor)`` checkpoint entries.

        Returns:
            A list of ``(name, tensor)`` pairs with LoRA weights merged.

        Raises:
            KeyError: If an ``.original`` entry has no matching
                ``.0.lora_A`` / ``.0.lora_B`` entries.
        """
        # use for jina-embeddings-v3
        # Merge Lora weights into a single weight tensor.
        # This is a temporary solution until we have a better way to handle
        # LoRA-parametrized checkpoints.

        scaling = self.config.lora_alpha / self.config.lora_rank
        device = self.vllm_config.device_config.device

        weights = dict(weights)

        original_suffix = ".original"
        lora_a_suffix = ".0.lora_A"
        lora_b_suffix = ".0.lora_B"

        # text-matching
        # (the last adapter along dim 0 of the LoRA tensors is selected)
        adapter_index = -1

        # Iterate over a snapshot of the keys: entries are added and removed
        # while merging.
        for name in list(weights):
            # The suffix is stripped by length below, so only names that
            # actually END with it may be processed; a plain substring test
            # would mangle any name that merely contains ".original".
            if not name.endswith(original_suffix):
                continue

            base = weights[name]
            dtype = base.dtype
            shape = base.shape
            weight_name = name[:-len(original_suffix)]

            a_key = weight_name + lora_a_suffix
            b_key = weight_name + lora_b_suffix
            lora_a = weights[a_key][adapter_index].to(device).float()
            lora_b = weights[b_key][adapter_index].to(device).float()

            # Embedding tensors use the opposite factor order.
            if "embeddings" in weight_name:
                left, right = lora_a, lora_b
            else:
                left, right = lora_b, lora_a

            merged = (base.to(device) +
                      torch.matmul(left, right).view(shape) * scaling)
            # Back to CPU and to the base weight's original dtype.
            merged = merged.cpu().to(dtype)

            weights[weight_name.replace(".parametrizations", "")] = merged

            del weights[name], weights[a_key], weights[b_key]

        return list(weights.items())

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Merge LoRA weights, then load through the inherited loader."""
        weights = self.jina_merge_lora_weights(weights)
        return super().load_weights(weights)
