# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import math
from collections.abc import Iterable, Mapping, Sequence
from typing import Any, Literal, Optional, TypedDict, Union

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import (BatchFeature, Phi4MultimodalAudioConfig,
                          Phi4MultimodalConfig, Phi4MultimodalFeatureExtractor,
                          Phi4MultimodalImageProcessorFast)
from transformers import Phi4MultimodalProcessor as Phi4MMProcessor
from transformers.models.phi4_multimodal.modeling_phi4_multimodal import (
    Phi4MultimodalAudioConvModule, Phi4MultimodalAudioNemoConvSubsampling,
    Phi4MultimodalAudioRelativeAttentionBias, adaptive_enc_mask, unfold_tensor)

from vllm.config import VllmConfig
from vllm.distributed import (divide, get_tensor_model_parallel_rank,
                              get_tensor_model_parallel_world_size)
from vllm.model_executor.layers.activation import MulAndSilu, get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               MergedColumnParallelLinear,
                                               QKVParallelLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
                                    MultiModalKwargs, NestedTensors)
from vllm.multimodal.parse import (AudioProcessorItems, ImageEmbeddingItems,
                                   ImageProcessorItems, ImageSize,
                                   MultiModalDataItems, MultiModalDataParser)
from vllm.multimodal.processing import (BaseMultiModalProcessor,
                                        BaseProcessingInfo, PromptReplacement,
                                        PromptUpdate)
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors
from vllm.utils import is_list_of

from .idefics2_vision_model import Idefics2VisionTransformer
from .interfaces import MultiModalEmbeddings, SupportsLoRA, SupportsMultiModal
from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
                    init_vllm_registered_model, maybe_prefix,
                    merge_multimodal_embeddings)

# <|endoftext10|> (see vocab.json in hf model)
_IMAGE_PLACEHOLDER_TOKEN_ID = 200010
# <|endoftext11|>
_AUDIO_PLACEHOLDER_TOKEN_ID = 200011

_AUDIO_MAX_SOUNDFILE_SIZE = 241_000


def _get_padding_size(orig_width: int, orig_height: int, target_height: int,
                      target_width: int):
    ratio_width = target_width / orig_width
    ratio_height = target_height / orig_height

    if ratio_width < ratio_height:
        padding_width = 0
        padding_height = target_height - int(orig_height * ratio_width)
    else:
        padding_width = target_width - int(orig_width * ratio_height)
        padding_height = 0
    return padding_height, padding_width


class Phi4MMProjector(nn.Module):

    def __init__(self, input_size: int, hidden_size: int):
        super().__init__()
        self.up = ColumnParallelLinear(input_size, hidden_size)
        self.down = RowParallelLinear(hidden_size, hidden_size)
        self.act = get_act_fn("gelu")

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x, _ = self.up(x)
        x = self.act(x)
        x, _ = self.down(x)
        return x


class Phi4MMImageEmbedding(nn.Module):
    """Image embedding."""

    def __init__(self, config: Phi4MultimodalConfig):
        super().__init__()
        self.config = config
        self.layer_idx = config.vision_config.feature_layer
        self.crop_size = config.vision_config.crop_size
        self.image_dim_out = config.vision_config.hidden_size

        n_patches = (config.vision_config.image_size //
                     config.vision_config.patch_size)
        if n_patches % 2 != 0:
            self.img_processor_padding = nn.ReflectionPad2d((0, 1, 0, 1))
            n_patches += 1
        self.num_img_tokens = (n_patches // 2)**2

        num_hidden_layers = (config.vision_config.num_hidden_layers +
                             self.layer_idx +
                             1 if self.layer_idx < 0 else self.layer_idx + 1)
        self.img_processor = Idefics2VisionTransformer(
            config.vision_config,
            require_post_norm=False,
            num_hidden_layers_override=num_hidden_layers)
        self.image_token_compression = nn.AvgPool2d(kernel_size=2, stride=2)
        self.img_projection = Phi4MMProjector(self.image_dim_out,
                                              config.hidden_size)
        self.global_img_feature_extensor = nn.Parameter(
            torch.zeros([1, 1, self.image_dim_out]))
        self.sub_img_feature_extensor = nn.Parameter(
            torch.zeros([1, 1, 1, self.image_dim_out]))

    def get_img_features(
        self,
        img_embeds: torch.FloatTensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.FloatTensor:
        img_feature = self.img_processor(img_embeds,
                                         patch_attention_mask=attention_mask)

        patch_feature = img_feature
        # reshape to 2D tensor
        width = int(math.sqrt(patch_feature.size(1)))
        patch_feature = patch_feature.view(-1, width, width,
                                           patch_feature.size(-1))
        # convert to NCHW
        patch_feature = patch_feature.permute(0, 3, 1, 2)
        if getattr(self, "img_processor_padding", None) is not None:
            patch_feature = self.img_processor_padding(patch_feature)
        patch_feature = self.image_token_compression(patch_feature)
        # convert to NHWC
        patch_feature = patch_feature.permute(0, 2, 3, 1)
        patch_feature = patch_feature.view(
            -1,
            patch_feature.size(1) * patch_feature.size(2),
            patch_feature.size(-1))
        return patch_feature

    def forward(
        self,
        image_pixel_values: torch.FloatTensor,
        image_sizes: Optional[torch.Tensor] = None,
        image_attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.FloatTensor:
        image_pixel_values = image_pixel_values.to(
            self.img_processor.embeddings.patch_embedding.weight.dtype)

        target_device = self.img_projection.up.bias.device
        target_dtype = self.img_projection.up.bias.dtype

        batch_size = image_pixel_values.shape[0]

        img_features = self.get_img_features(
            image_pixel_values.flatten(0, 1),
            attention_mask=image_attention_mask.flatten(0, 1).to(
                dtype=bool, device=target_device),
        )
        base_feat_size = int(np.sqrt(img_features.shape[1]))
        img_features = img_features.view(batch_size, -1, base_feat_size**2,
                                         self.image_dim_out)
        image_sizes = image_sizes.view(-1, 2)

        output_imgs = []
        for idx in range(batch_size):
            height, width = image_sizes[idx]
            height_ratio = height // self.crop_size
            width_ratio = width // self.crop_size
            area_ratio = height_ratio * width_ratio

            global_img = img_features[idx, :1]
            global_img = global_img.reshape(1, base_feat_size, base_feat_size,
                                            self.image_dim_out).contiguous()
            temporary_extensor = self.sub_img_feature_extensor.repeat(
                1, base_feat_size, 1, 1)
            global_img = torch.cat([global_img, temporary_extensor],
                                   dim=2).reshape(1, -1, self.image_dim_out)

            sub_img = img_features[idx, 1:]
            sub_img = sub_img[:area_ratio]
            sub_img = (sub_img.reshape(
                height_ratio, width_ratio, base_feat_size, base_feat_size,
                self.image_dim_out).transpose(1, 2).reshape(
                    1, height_ratio * base_feat_size,
                    width_ratio * base_feat_size,
                    self.image_dim_out).contiguous())

            if image_attention_mask is not None:
                reshaped_image_attention_mask = (
                    image_attention_mask[idx, 1:area_ratio + 1,
                                         0::2, 0::2].reshape(
                                             height_ratio, width_ratio,
                                             base_feat_size,
                                             base_feat_size).transpose(
                                                 1, 2).reshape(
                                                     1, height_ratio *
                                                     base_feat_size,
                                                     width_ratio *
                                                     base_feat_size))
                useful_height = int(
                    reshaped_image_attention_mask[0, :, 0].sum().item())
                useful_width = int(
                    reshaped_image_attention_mask[0, 0, :].sum().item())
                sub_img = sub_img[:, :useful_height, :useful_width]
                temporary_extensor = self.sub_img_feature_extensor.repeat(
                    1, useful_height, 1, 1)
            else:
                temporary_extensor = self.sub_img_feature_extensor.repeat(
                    1, height_ratio * base_feat_size, 1, 1)

            sub_img = torch.cat([sub_img, temporary_extensor],
                                dim=2).reshape(1, -1, self.image_dim_out)

            # Merge global and sub
            output_imgs.append(
                torch.cat(
                    [sub_img, self.global_img_feature_extensor, global_img],
                    dim=1))

        img_set_tensor = []
        for output_img in output_imgs:
            output_img = output_img.to(device=target_device,
                                       dtype=target_dtype)
            img_feature_proj = self.img_projection(output_img)
            img_set_tensor.append(img_feature_proj.flatten(0, 1))

        return img_set_tensor


class Phi4MultimodalAudioMLP(nn.Module):

    def __init__(
        self,
        config: Phi4MultimodalAudioConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.layer_norm = nn.LayerNorm(config.hidden_size)
        self.act_fn = MulAndSilu()
        self.gate_up_proj = MergedColumnParallelLinear(
            config.hidden_size, [config.intermediate_size] * 2,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.gate_up_proj")
        self.down_proj = RowParallelLinear(config.intermediate_size,
                                           config.hidden_size,
                                           bias=True,
                                           quant_config=quant_config,
                                           prefix=f"{prefix}.down_proj")

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = self.layer_norm(hidden_states)
        hidden_states, _ = self.gate_up_proj(hidden_states)
        hidden_states = self.act_fn(hidden_states)
        hidden_states, _ = self.down_proj(hidden_states)
        return hidden_states


class Phi4MultimodalAudioAttention(nn.Module):

    def __init__(
        self,
        config: Phi4MultimodalAudioConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.total_num_heads = config.num_attention_heads
        self.head_dim = self.embed_dim // self.total_num_heads
        if self.head_dim * self.total_num_heads != self.embed_dim:
            raise ValueError(
                "embed_dim must be divisible by num_heads "
                f"(got `embed_dim`: {self.embed_dim} and `num_heads`:"
                f" {self.num_heads}).")
        self.scale = self.head_dim**-0.5

        self.qkv_proj = QKVParallelLinear(
            hidden_size=self.embed_dim,
            head_size=self.head_dim,
            total_num_heads=self.total_num_heads,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )

        self.o_proj = RowParallelLinear(
            input_size=self.embed_dim,
            output_size=self.embed_dim,
            quant_config=quant_config,
            prefix=f"{prefix}.out_proj",
        )

        self.tp_size = get_tensor_model_parallel_world_size()
        self.tp_rank = get_tensor_model_parallel_rank()
        self.num_heads = divide(self.total_num_heads, self.tp_size)

    def split_attn_mask(self, attention_mask: torch.Tensor) -> torch.Tensor:
        start_idx = self.num_heads * self.tp_rank
        end_idx = self.num_heads * (self.tp_rank + 1)
        return attention_mask[:, start_idx:end_idx]

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
    ) -> torch.Tensor:
        qkv_states, _ = self.qkv_proj(hidden_states)
        query, key, value = qkv_states.chunk(3, dim=-1)

        bsz, seq_len, _ = query.size()
        query = query.view(bsz, seq_len, self.num_heads, self.head_dim)
        key = key.view(bsz, seq_len, self.num_heads, self.head_dim)
        value = value.view(bsz, seq_len, self.num_heads, self.head_dim)
        query, key, value = (x.transpose(1, 2) for x in (query, key, value))

        attention_mask = self.split_attn_mask(attention_mask)
        out = F.scaled_dot_product_attention(
            query,
            key,
            value,
            scale=self.scale,
            attn_mask=attention_mask,
        )
        out = out.transpose(1, 2).reshape(bsz, seq_len, -1)

        attn_output, _ = self.o_proj(out)

        return attn_output


class Phi4MultimodalAudioConformerEncoderLayer(nn.Module):

    def __init__(self, config: Phi4MultimodalAudioConfig):
        super().__init__()

        self.feed_forward_in = Phi4MultimodalAudioMLP(config)
        self.self_attn = Phi4MultimodalAudioAttention(config)
        self.conv = Phi4MultimodalAudioConvModule(config)
        self.feed_forward_out = Phi4MultimodalAudioMLP(config)
        self.layer_norm_att = nn.LayerNorm(config.hidden_size)
        self.layer_norm = nn.LayerNorm(config.hidden_size)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
    ) -> torch.Tensor:
        residual = hidden_states + 0.5 * self.feed_forward_in(hidden_states)
        hidden_states = self.layer_norm_att(residual)

        hidden_states = residual + self.self_attn(hidden_states,
                                                  attention_mask)
        hidden_states = hidden_states + self.conv(hidden_states)
        hidden_states = hidden_states + 0.5 * self.feed_forward_out(
            hidden_states)

        out = self.layer_norm(hidden_states)

        return out


class Phi4MMAudioMeanVarianceNormLayer(nn.Module):
    """Mean/variance normalization layer.

    Will subtract mean and multiply input by inverted standard deviation.
    Typically used as a very first layer in a model.

    Args:
        input_size: int
            layer input size.
    """

    def __init__(self, config: Phi4MultimodalAudioConfig):
        super().__init__()
        self.global_mean = nn.Parameter(torch.zeros(config.input_size))
        self.global_invstd = nn.Parameter(torch.ones(config.input_size))

    def forward(self, input_: torch.Tensor) -> torch.Tensor:
        """MeanVarianceNormLayer Forward

        Args:
            input_: torch.Tensor
                input tensor.
        """
        return (input_ - self.global_mean) * self.global_invstd


class Phi4MultimodalAudioModel(nn.Module):

    def __init__(self, config: Phi4MultimodalAudioConfig):
        super().__init__()
        self.config = config

        self.encoder_embedding = Phi4MMAudioMeanVarianceNormLayer(config)
        self.embed = Phi4MultimodalAudioNemoConvSubsampling(config)
        self.relative_attention_bias_layer = (
            Phi4MultimodalAudioRelativeAttentionBias(config))
        self.encoders = nn.ModuleList([
            Phi4MultimodalAudioConformerEncoderLayer(config)
            for _ in range(config.num_blocks)
        ])

    def _streaming_mask(
        self,
        seq_len: int,
        batch_size: int,
        chunk_size: int,
        left_chunk: int,
    ):
        # Create mask matrix for streaming
        # S stores start index. if chunksize is 18, s is [0,18,36,....]
        chunk_start_idx = np.arange(0, seq_len, chunk_size)

        enc_streaming_mask = (adaptive_enc_mask(
            seq_len, chunk_start_idx,
            left_window=left_chunk).unsqueeze(0).expand([batch_size, -1, -1]))
        return enc_streaming_mask

    def forward_embeddings(
        self,
        hidden_states: torch.Tensor,
        masks: torch.Tensor,
    ):
        """Forwarding the inputs through the top embedding layers"""
        seq_len = math.ceil(hidden_states.shape[1] /
                            self.config.time_reduction)
        if seq_len <= 0:
            raise ValueError(
                f"Sequence length after time reduction is invalid: {seq_len}."
                "Your input feature is too short.")

        batch_size = hidden_states.shape[0]

        enc_streaming_mask = self._streaming_mask(seq_len, batch_size,
                                                  self.config.chunk_size,
                                                  self.config.left_chunk)
        enc_streaming_mask = enc_streaming_mask.to(hidden_states.device)

        hidden_states, masks = self.embed(hidden_states, masks)

        streaming_mask = enc_streaming_mask
        if streaming_mask is not None and masks is not None:
            hs_mask = masks & streaming_mask
        elif masks is not None:
            hs_mask = masks
        else:
            hs_mask = streaming_mask

        return hidden_states, hs_mask, masks

    def calculate_hs_mask(self, hidden_states: torch.Tensor,
                          device: torch.device, mask: torch.Tensor):
        max_audio_length = hidden_states.shape[1]
        batch_size = hidden_states.shape[0]
        enc_streaming_mask = self._streaming_mask(max_audio_length, batch_size,
                                                  self.config.chunk_size,
                                                  self.config.left_chunk)
        enc_streaming_mask = enc_streaming_mask.to(device)
        if mask is None:
            return enc_streaming_mask

        feature_lens = mask.sum(1)
        padding_length = feature_lens
        pad_mask = torch.arange(0, max_audio_length, device=device).expand(
            padding_length.size(0), -1) < padding_length.unsqueeze(1)
        pad_mask = pad_mask.unsqueeze(1)
        pad_mask = pad_mask & enc_streaming_mask
        return pad_mask

    def forward(self,
                hidden_states: torch.Tensor,
                mask: Optional[torch.Tensor] = None):
        hidden_states = self.encoder_embedding(hidden_states)
        hidden_states, hs_mask, mask = self.forward_embeddings(
            hidden_states, mask)

        unfolded = False
        bs, seq_len, _ = hidden_states.shape
        max_seq_len = 500  # maximum position for absolute positional encoding
        if seq_len > max_seq_len:
            # audio sequence is longer than max_seq_len,
            # unfold it into chunks of max_seq_len
            unfolded = True
            # the unfold op will drop residual frames,
            # pad it to the multiple of max_seq_len
            if seq_len % max_seq_len > 0:
                chunk_pad_size = max_seq_len - (seq_len % max_seq_len)
            else:
                chunk_pad_size = 0
            if chunk_pad_size > 0:
                hidden_states_pad = F.pad(hidden_states,
                                          (0, 0, 0, chunk_pad_size),
                                          "constant", 0)
                hidden_states = hidden_states_pad.to(hidden_states.device)

            hidden_states = unfold_tensor(hidden_states, max_seq_len)
            masks_unfold = None
            if mask is not None:
                # revise hs_mask here because the previous calculated hs_mask
                # did not consider extra pad
                subsampled_pad_mask = mask.squeeze(
                    1)  # [bz, subsampled_unmask_seq_len]
                extra_padded_subsamlped_pad_mask = F.pad(
                    subsampled_pad_mask, (0, chunk_pad_size), "constant",
                    False)  # extra padding to the pad mask
                extra_padded_subsamlped_pad_mask = (
                    extra_padded_subsamlped_pad_mask.unsqueeze(-1).float())
                masks_unfold = unfold_tensor(
                    extra_padded_subsamlped_pad_mask, max_seq_len
                )  # unfold the pad mask like we did to the input tensor
                masks_unfold = masks_unfold.squeeze(
                    -1).bool()  # unfold op does not support bool tensor
            hs_mask = self.calculate_hs_mask(
                hidden_states, hidden_states.device, masks_unfold
            )  # calculate hs_mask based on the unfolded pad mask

        relative_attention_bias = self.relative_attention_bias_layer(
            hidden_states)
        attention_mask = hs_mask.unsqueeze(1) + relative_attention_bias

        for layer in self.encoders:
            hidden_states = layer(hidden_states, attention_mask)

        if unfolded:
            embed_dim = hidden_states.shape[-1]
            hidden_states = hidden_states.reshape(bs, -1, embed_dim)
            # if we ever padded before unfolding, we need to remove the padding
            if chunk_pad_size > 0:
                hidden_states = hidden_states[:, :-chunk_pad_size, :]

        return hidden_states


class Phi4MMAudioEmbedding(nn.Module):

    def __init__(self, config: Phi4MultimodalConfig):
        super().__init__()
        self.config = config
        self.layer_idx = config.audio_config.feature_layer

        self.encoder = Phi4MultimodalAudioModel(config.audio_config)

        audio_config = config.audio_config
        proj_input_size = (audio_config.hidden_size *
                           audio_config.downsample_rate)
        self.vision_speech_projection = Phi4MMProjector(
            proj_input_size, config.hidden_size)
        self.speech_projection = Phi4MMProjector(proj_input_size,
                                                 config.hidden_size)

    def get_projection(
        self,
        audio_projection_mode: Literal["speech", "vision"],
    ) -> Phi4MMProjector:
        if audio_projection_mode == "speech":
            return self.speech_projection
        elif audio_projection_mode == "vision":
            return self.vision_speech_projection

    def forward(
        self,
        audio_input_features: torch.FloatTensor,
        audio_embed_sizes=None,
        audio_attention_mask=None,
        audio_projection_mode="speech",
    ) -> torch.FloatTensor:

        audio_projection = self.get_projection(audio_projection_mode)

        target_device = audio_projection.up.bias.device
        target_dtype = audio_projection.up.bias.dtype

        audio_input_features = audio_input_features.to(device=target_device,
                                                       dtype=target_dtype)

        audio_encoder_hidden_states = self.encoder(audio_input_features,
                                                   audio_attention_mask)
        audio_embeds = audio_projection(audio_encoder_hidden_states)

        return audio_embeds.flatten(0, 1)

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()

        for name, loaded_weight in weights:
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params


class Phi4MMImagePixelInputs(TypedDict):
    type: Literal["pixel_values"]
    data: Union[torch.Tensor, list[torch.Tensor]]
    """
    Shape:
    `(batch_size * num_images, 1 + num_patches, num_channels, height, width)`

    Note that `num_patches` may be different per batch and image,
    in which case the data is passed as a list instead of a batched tensor.
    """

    image_sizes: torch.Tensor
    """
    Shape: `(batch_size * num_images, 2)`

    This should be in `(height, width)` format.
    """

    num_img_tokens: list[int]
    """Shape: `(batch_size * num_images)`"""

    image_attention_mask: torch.Tensor
    """Shape: `(batch_size * num_images, H_mask, W_mask)`"""


class Phi4MMImageEmbeddingInputs(TypedDict):
    type: Literal["image_embeds"]
    data: Union[torch.Tensor, list[torch.Tensor]]
    """Shape: `(batch_size * num_images, image_feature_size, hidden_size)`

    `hidden_size` must match the hidden size of language model backbone.
    """


class Phi4MMAudioFeatureInputs(TypedDict):
    type: Literal["audio_features"]
    data: Union[torch.Tensor, list[torch.Tensor]]
    """Shape: `(batch_size * num_audios, 80, M)"""


class Phi4MMAudioEmbeddingInputs(TypedDict):
    type: Literal["audio_embeds"]
    data: NestedTensors
    """Shape: `(batch_size, num_audios, audio_feature_size, hidden_size)"""


Phi4MMImageInput = Union[Phi4MMImagePixelInputs, Phi4MMImageEmbeddingInputs]
Phi4MMAudioInputs = Union[Phi4MMAudioFeatureInputs, Phi4MMAudioEmbeddingInputs]


def cat_with_pad(tensors, dim, padding_value=0):
    """
    cat along dim, while pad to max for all other dims
    """
    ndim = tensors[0].dim()
    assert all(
        t.dim() == ndim for t in
        tensors[1:]), "All tensors must have the same number of dimensions"

    out_size = [max(t.shape[i] for t in tensors) for i in range(ndim)]
    out_size[dim] = sum(t.shape[dim] for t in tensors)
    output = tensors[0].new_full(out_size, padding_value)

    index = 0
    for t in tensors:
        # Create a slice list where every dimension except dim is full slice
        slices = [slice(0, t.shape[d]) for d in range(ndim)]
        # Update only the concat dimension slice
        slices[dim] = slice(index, index + t.shape[dim])

        output[slices] = t
        index += t.shape[dim]

    return output


class Phi4MMProcessingInfo(BaseProcessingInfo):

    def get_hf_config(self) -> Phi4MultimodalConfig:
        return self.ctx.get_hf_config(Phi4MultimodalConfig)

    def get_hf_processor(self, **kwargs: object) -> Phi4MMProcessor:
        return self.ctx.get_hf_processor(Phi4MMProcessor, **kwargs)

    def get_feature_extractor(
            self, **kwargs: object) -> Phi4MultimodalFeatureExtractor:
        return self.get_hf_processor(**kwargs).audio_processor

    def get_image_processor(
        self,
        processor: Optional[Phi4MMProcessor] = None,
    ) -> Phi4MultimodalImageProcessorFast:
        if processor is None:
            processor = self.get_hf_processor()
        return processor.image_processor

    def get_dynamic_hd(
        self,
        processor: Optional[Phi4MMProcessor] = None,
    ) -> int:
        return self.get_image_processor(processor).dynamic_hd

    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
        return {"audio": None, "image": None}

    def _find_target_aspect_ratio(
        self,
        orig_width: int,
        orig_height: int,
        image_size: int,
        max_num: int,
        min_num: int,
    ):
        w_crop_num = math.ceil(orig_width / float(image_size))
        h_crop_num = math.ceil(orig_height / float(image_size))
        if w_crop_num * h_crop_num > max_num:
            aspect_ratio = orig_width / orig_height

            # calculate the existing image aspect ratio
            target_ratios = set((i, j) for i in range(1, max_num + 1)
                                for j in range(1, max_num + 1)
                                if i * j <= max_num and i * j >= min_num)
            target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])

            # find the closest aspect ratio to the target
            image_processor = self.get_image_processor()
            target_aspect_ratio = image_processor.find_closest_aspect_ratio(
                aspect_ratio,
                target_ratios,
                orig_width,
                orig_height,
            )

            # calculate the target width and height
            target_width = image_size * target_aspect_ratio[0]
            target_height = image_size * target_aspect_ratio[1]
        else:
            target_width = image_size * w_crop_num
            target_height = image_size * h_crop_num
            target_aspect_ratio = (w_crop_num, h_crop_num)
        return target_aspect_ratio, target_height, target_width

    def _compute_num_image_tokens(
        self,
        orig_width: int,
        orig_height: int,
        dynamic_hd_size: int,
        vit_image_size: int,
        vit_patch_size: int,
        token_compression_factor: int = 2,
    ):
        """
        compute the number of tokens an image is expected to take up considering
        the image encoder architecture and exclude output features containing 
        only padding pixels

        for siglip, vit_image_size=448, vit_patch_size=14, so output will be 
        32x32 feature map
        NOTE right now, Phi4MM uses hard-coded token_compression_factor=2
        """
        assert vit_image_size % vit_patch_size == 0, (
            "vit_image_size must be divisible by vit_patch_size")
        assert (vit_image_size // vit_patch_size %
                token_compression_factor == 0), (
                    "vit_image_size // vit_patch_size must be divisible by "
                    "token_compression_factor")

        target_aspect_ratio, target_height, target_width = (
            self._find_target_aspect_ratio(orig_width,
                                           orig_height,
                                           vit_image_size,
                                           dynamic_hd_size,
                                           min_num=1))
        assert target_aspect_ratio[0] * vit_image_size == target_width, (
            f"{target_aspect_ratio[0]} * {vit_image_size} != {target_width}")
        assert target_aspect_ratio[1] * vit_image_size == target_height, (
            f"{target_aspect_ratio[1]} * {vit_image_size} != {target_height}")
        assert (target_height % vit_image_size == 0
                and target_width % vit_image_size == 0)

        padding_height, padding_width = _get_padding_size(
            orig_width, orig_height, target_height, target_width)
        assert padding_width == 0 or padding_height == 0, \
            "padding_width or padding_height must be 0"

        target_feat_width = target_width // vit_patch_size
        target_feat_height = target_height // vit_patch_size
        if padding_width >= vit_patch_size:
            assert padding_height == 0, "padding_height not 0"
            non_pad_feat_width = target_feat_width - math.floor(
                padding_width / vit_patch_size)
            non_pad_feat_height = target_feat_height
        elif padding_height >= vit_patch_size:
            assert padding_width == 0, "padding_width not 0"
            non_pad_feat_height = target_feat_height - math.floor(
                padding_height / vit_patch_size)
            non_pad_feat_width = target_feat_width
        else:
            # small padding shorter than a vit patch
            non_pad_feat_width = target_feat_width
            non_pad_feat_height = target_feat_height

        feat_width = non_pad_feat_width // token_compression_factor
        feat_height = non_pad_feat_height // token_compression_factor
        # NOTE it's possible that the non-padding feature is not divisible
        if non_pad_feat_width % token_compression_factor != 0:
            feat_width += 1
        if non_pad_feat_height % token_compression_factor != 0:
            feat_height += 1
        num_hd_patch_tokens = feat_width * feat_height
        num_hd_newline_tokens = feat_height
        vit_feature_size = vit_image_size // vit_patch_size
        num_global_image_tokens = (vit_feature_size //
                                   token_compression_factor)**2
        num_sep_tokens = 1
        num_global_image_newline_tokens = \
            vit_feature_size // token_compression_factor

        return (num_global_image_tokens + num_sep_tokens +
                num_hd_patch_tokens + num_hd_newline_tokens +
                num_global_image_newline_tokens)

    def get_num_image_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
        processor: Optional[Phi4MMProcessor] = None,
    ) -> int:
        hf_config = self.get_hf_config()
        vision_config = hf_config.vision_config
        vit_image_size = vision_config.image_size
        vit_patch_size = vision_config.patch_size

        dynamic_hd_size = self.get_dynamic_hd(processor=processor)

        # we use default `token_compression_factor=2`,
        # since it's not in HF vision config.
        image_num_tokens = self._compute_num_image_tokens(
            image_width,
            image_height,
            dynamic_hd_size=dynamic_hd_size,
            vit_image_size=vit_image_size,
            vit_patch_size=vit_patch_size,
        )

        return image_num_tokens

    def get_image_size_with_most_features(
        self,
        processor: Optional[Phi4MMProcessor] = None,
    ) -> ImageSize:
        vit_image_size = self.get_hf_config().vision_config.image_size

        max_side = vit_image_size * self.get_dynamic_hd(processor=processor)
        return ImageSize(height=max_side, width=vit_image_size)

    def get_audio_num_frames(self, audio_len: int, sr: float) -> int:
        """
        Compute the output size of the `extract_features` method.

        Args:
            audio_len (int): Length of the input waveform in samples.
            sr (float): Sampling rate of the waveform, either 16000 or 8000.

        Returns:
            tuple (int, int): Output size as (T, D), where:
                T: Number of time frames.
                D: Number of Mel filterbank bins (80).
        """

        # Resample to 16000 or 8000 if needed
        if sr > 16000:
            audio_len //= sr // 16000
        elif 8000 <= sr < 16000:
            # We'll resample to 16K from 8K
            audio_len *= 2
        elif sr < 8000:
            raise RuntimeError(f"Unsupported sample rate {sr}")

        # Spectrogram parameters for 16 kHz
        win_length = 400  # Frame length in samples
        hop_length = 160  # Frame shift in samples

        # Calculate number of frames (T)
        num_frames = (audio_len - win_length) // hop_length + 1
        if num_frames < 1:
            raise ValueError("Waveform too short for given parameters.")

        # Return time frames (T)
        return num_frames

    def _compute_audio_embed_size(self, audio_frames: int) -> int:
        """
        Compute the size of audio embeddings from the number of audio frames.
        """
        # `_compute_audio_embed_size` in audio_processor use torch for
        # computation, therefore we re-implement it to use pythonic
        # numeric computation to avoid extra tensor conversion.
        audio_processor = self.get_feature_extractor()
        audio_compression_rate = audio_processor.audio_compression_rate
        audio_downsample_rate = audio_processor.audio_downsample_rate

        integer = audio_frames // audio_compression_rate
        remainder = audio_frames % audio_compression_rate
        result = integer + int(remainder > 0)

        integer = result // audio_downsample_rate
        remainder = result % audio_downsample_rate
        result = integer + int(remainder > 0)  # qformer compression

        return result


class Phi4MMDummyInputsBuilder(BaseDummyInputsBuilder[Phi4MMProcessingInfo]):

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        num_audios = mm_counts.get("audio", 0)
        num_images = mm_counts.get("image", 0)

        tokenizer = self.info.get_tokenizer()
        image_tokens: str = tokenizer.image_token * num_images
        audio_tokens: str = tokenizer.audio_token * num_audios

        return image_tokens + audio_tokens

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> MultiModalDataDict:
        num_audios = mm_counts.get("audio", 0)
        num_images = mm_counts.get("image", 0)

        target_width, target_height = \
            self.info.get_image_size_with_most_features()

        mm_data = {
            "image":
            self._get_dummy_images(width=target_width,
                                   height=target_height,
                                   num_images=num_images),
            "audio":
            self._get_dummy_audios(length=_AUDIO_MAX_SOUNDFILE_SIZE,
                                   num_audios=num_audios),
        }

        return mm_data


class Phi4MMMultiModalProcessor(BaseMultiModalProcessor[Phi4MMProcessingInfo]):

    def _get_data_parser(self) -> MultiModalDataParser:
        feature_extractor = self.info.get_feature_extractor()
        return MultiModalDataParser(target_sr=feature_extractor.sampling_rate)

    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
        tok_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        if not mm_data:
            prompt_ids = self.info.get_tokenizer().encode(prompt)
            prompt_ids = self._apply_hf_processor_tokens_only(prompt_ids)
            return BatchFeature(dict(input_ids=[prompt_ids]), tensor_type="pt")

        audio_data = mm_data.pop("audios", [])
        if audio_data:
            mm_data['audio'] = audio_data

        processed_outputs = super()._call_hf_processor(prompt, mm_data,
                                                       mm_kwargs, tok_kwargs)

        if "image_pixel_values" in processed_outputs:
            num_img_tokens = [
                self.info.get_num_image_tokens(image_width=img_size[0],
                                               image_height=img_size[1])
                for img_size in processed_outputs["image_sizes"]
            ]
            processed_outputs["num_img_tokens"] = num_img_tokens

        if audio_data:
            audio_features = processed_outputs['audio_input_features']
            sr = self.info.get_feature_extractor(**mm_kwargs).sampling_rate
            feature_sizes = [
                self.info.get_audio_num_frames(len(audio), sr)
                for audio in audio_data
            ]
            processed_outputs['audio_input_features'] = [
                audio_features[idx, :size]
                for idx, size in enumerate(feature_sizes)
            ]

        return processed_outputs

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        return dict(
            image_pixel_values=MultiModalFieldConfig.batched("image"),
            image_attention_mask=MultiModalFieldConfig.batched("image"),
            image_sizes=MultiModalFieldConfig.batched("image"),
            num_img_tokens=MultiModalFieldConfig.batched("image"),
            audio_input_features=MultiModalFieldConfig.batched("audio"),
        )

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, Any],
        out_mm_kwargs: MultiModalKwargs,
    ) -> Sequence[PromptUpdate]:
        tokenizer = self.info.get_tokenizer()
        image_token_id = tokenizer.vocab[tokenizer.image_token]
        audio_token_id = tokenizer.vocab[tokenizer.audio_token]

        hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
        audio_processor = self.info.get_feature_extractor(
            **hf_processor_mm_kwargs)

        def get_image_replacement_phi4mm(item_idx: int):
            images = mm_items.get_items(
                "image", (ImageEmbeddingItems, ImageProcessorItems))

            if isinstance(images, ImageEmbeddingItems):
                num_image_tokens = images.get_feature_size(item_idx)
            else:
                image_size = images.get_image_size(item_idx)
                num_image_tokens = self.info.get_num_image_tokens(
                    image_width=image_size.width,
                    image_height=image_size.height,
                    processor=hf_processor,
                )

            image_tokens = [image_token_id] * num_image_tokens

            return image_tokens

        def get_audio_replacement_phi4mm(item_idx: int):
            audios = mm_items.get_items("audio", AudioProcessorItems)
            # TODO(Isotr0py): support embedding inputs
            audio_len = audios.get_audio_length(item_idx)
            audio_frames = self.info.get_audio_num_frames(
                audio_len, audio_processor.sampling_rate)
            audio_embed_size = self.info._compute_audio_embed_size(
                audio_frames)

            audio_tokens = [audio_token_id] * audio_embed_size

            return audio_tokens

        return [
            PromptReplacement(
                modality="audio",
                target=[audio_token_id],
                replacement=get_audio_replacement_phi4mm,
            ),
            PromptReplacement(
                modality="image",
                target=[image_token_id],
                replacement=get_image_replacement_phi4mm,
            ),
        ]


@MULTIMODAL_REGISTRY.register_processor(
    Phi4MMMultiModalProcessor,
    info=Phi4MMProcessingInfo,
    dummy_inputs=Phi4MMDummyInputsBuilder,
)
class Phi4MultimodalForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal):
    """
    Implements the Phi-4-multimodal-instruct model in vLLM.
    """
    packed_modules_mapping = {
        "qkv_proj": [
            "qkv_proj",
        ],
        "gate_up_proj": [
            "gate_up_proj",
        ],
    }

    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_prefix={
            # Multimodal embedding
            "model.embed_tokens_extend.": "",
            # LLM backbone
            "model.": "language_model.model.",
        },
        orig_to_new_substr={
            # projection
            ".img_projection_": ".img_projection.",
            ".up_proj_for_speech.": ".speech_projection.up.",
            ".up_proj_for_vision_speech.": ".vision_speech_projection.up.",
            ".down_proj_for_speech.": ".speech_projection.down.",
            ".down_proj_for_vision_speech.": ".vision_speech_projection.down.",
        },
    )

    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
        if modality.startswith("image"):
            return "<|image|>"
        if modality.startswith("audio"):
            return "<|audio|>"

        raise ValueError("Only image or audio modality is supported")

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        multimodal_config = vllm_config.model_config.multimodal_config
        self.config = config
        self.multimodal_config = multimodal_config

        # TODO: Optionally initializes these for supporting input embeddings.
        self.image_embed = Phi4MMImageEmbedding(
            config,
            # prefix=maybe_prefix(prefix, "image_embed"),
        )
        self.audio_embed = Phi4MMAudioEmbedding(
            config,
            # prefix=maybe_prefix(prefix, "audio_embed"),
        )

        self.language_model = init_vllm_registered_model(
            vllm_config=vllm_config,
            prefix=maybe_prefix(prefix, "language_model"),
            architectures=["Phi3ForCausalLM"],
        )

        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors)

    def _parse_and_validate_audio_input(
            self, **kwargs: object) -> Optional[Phi4MMAudioInputs]:
        """
        Parse and validate the audio input to the model.  This handles both 
        audio features and audio embeddings, but only the former is used for
        now.

        Args:
            kwargs (object): Keyword arguments.

        Returns:
            Optional[Phi4MMAudioInputs]: Parsed and validated audio inputs.
        """
        audio_features = kwargs.pop("audio_input_features", None)
        audio_embeds = kwargs.pop("audio_embeds", None)

        if audio_features is None and audio_embeds is None:
            return None

        if audio_features is not None:
            if not isinstance(audio_features, (torch.Tensor, list)):
                raise ValueError("Incorrect type of audio features. "
                                 f"Got type: {type(audio_features)}")

            return Phi4MMAudioFeatureInputs(type="audio_features",
                                            data=flatten_bn(audio_features))

        if audio_embeds is not None:
            if not isinstance(audio_embeds, (torch.Tensor, list)):
                raise ValueError("Incorrect type of audio embeds. "
                                 f"Got type: {type(audio_embeds)}")

            return Phi4MMAudioEmbeddingInputs(type="audio_embeds",
                                              data=audio_embeds)

        raise AssertionError("This line should be unreachable.")

    def _process_audio_input(self, audio_input: Phi4MMAudioInputs,
                             audio_projection_mode: str) -> NestedTensors:
        """
        Create the audio embeddings from the audio input, where the audio input
        is pairs of audio features and audio embed lengths.  The audio input is
        created by `input_mapper_for_phi4mm_audio`.

        Args:
            audio_input (Phi4MMAudioInputs): Audio input.

        Returns:
            NestedTensors: Audio embeddings
        """
        if audio_input["type"] == "audio_embeds":
            return audio_input["data"]

        audio_features = audio_input["data"]
        # (e.g. multiple examples) and the second dim is the multi-audio dim
        # (e.g. multiple audios in the same example)

        dtype = next(self.audio_embed.parameters()).dtype
        audio_embeds = [
            self.audio_embed(
                features.unsqueeze(0).to(dtype),
                audio_projection_mode=audio_projection_mode,
            ) for features in audio_features
        ]
        return audio_embeds

    def _parse_and_validate_image_input(
            self, **kwargs: object) -> Optional[Phi4MMImagePixelInputs]:
        image_pixel_values: NestedTensors = kwargs.get("image_pixel_values")
        if image_pixel_values is None:
            return None

        image_sizes = kwargs.get("image_sizes")
        image_attention_mask = kwargs.get("image_attention_mask")
        num_img_tokens = kwargs.get("num_img_tokens")
        assert image_sizes is not None and image_attention_mask is not None\
              and num_img_tokens is not None, "Missing image inputs"

        if is_list_of(image_pixel_values, torch.Tensor):
            assert all(p.dim() == 5
                       for p in image_pixel_values), "Incorrect image inputs"
            # list len is batch_size.
            # each tensor has dimension: num_img_per_example, num_hd_patches,
            # channels, height, width.
            # need to pad along num_hd_patches.
            # mask size num_img_per_prompt, num_hd_patches, feat_h, heat_w.
            image_pixel_values = cat_with_pad(image_pixel_values, dim=0)
        elif isinstance(image_pixel_values, torch.Tensor):
            # dimension: batch_size, num_img_per_example, num_hd_patches,
            # channels, height, width.
            # we flatten first 2 dims to make it a single large batch for
            # SigLIP Encoder.
            assert image_pixel_values.dim() == 6, "Incorrect image inputs"
            image_pixel_values = image_pixel_values.flatten(0, 1)
        else:
            raise ValueError("Incorrect image_pixel_values inputs")

        if isinstance(image_attention_mask, list):
            image_attention_mask = cat_with_pad(image_attention_mask, dim=0)
        elif isinstance(image_attention_mask, torch.Tensor):
            image_attention_mask = image_attention_mask.flatten(0, 1)
        else:
            raise ValueError("Incorrect image_attention_mask inputs")

        if isinstance(image_sizes, list):
            image_sizes = torch.cat(image_sizes, dim=0)
        elif isinstance(image_sizes, torch.Tensor):
            image_sizes = image_sizes.flatten(0, 1)
        else:
            raise ValueError("Incorrect image_attention_mask inputs")

        if isinstance(num_img_tokens, list):
            num_img_tokens = [
                n for num_tensor in num_img_tokens
                for n in num_tensor.tolist()
            ]
        elif isinstance(num_img_tokens, torch.Tensor):
            num_img_tokens = num_img_tokens.flatten(0, 1).tolist()
        else:
            raise ValueError("Incorrect image_attention_mask inputs")

        return Phi4MMImagePixelInputs(
            type="pixel_values",
            data=image_pixel_values,
            image_sizes=image_sizes,
            image_attention_mask=image_attention_mask,
            num_img_tokens=num_img_tokens,
        )

    def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
        modalities = {}

        # Preserve the order of modalities if there are multiple of them
        # from the order of kwargs.
        for input_key in kwargs:
            if input_key in ("image_pixel_values",
                             "image_embeds") and "images" not in modalities:
                modalities["images"] = self._parse_and_validate_image_input(
                    **kwargs)
            if input_key in ("audio_input_features",
                             "audio_embeds") and "audios" not in modalities:
                modalities["audios"] = self._parse_and_validate_audio_input(
                    **kwargs)

        return modalities

    def _process_image_input(
            self, image_input: Phi4MMImagePixelInputs) -> list[torch.Tensor]:
        if image_input["type"] == "image_embeds":
            image_embeds = image_input["image_embeds"].type(self.visual.dtype)
        else:
            dtype = next(self.image_embed.parameters()).dtype
            pixel_values = image_input['data'].to(dtype)
            image_sizes = image_input['image_sizes']
            image_attention_mask = image_input['image_attention_mask']
            image_embeds = self.image_embed(pixel_values, image_sizes,
                                            image_attention_mask)
        return image_embeds

    def get_multimodal_embeddings(
            self, **kwargs: object) -> Optional[MultiModalEmbeddings]:

        modalities = self._parse_and_validate_multimodal_inputs(**kwargs)
        if not modalities:
            return None

        # The result multimodal_embeddings is tuple of tensors, with each
        # tensor correspoending to a multimodal data item (image or video).
        multimodal_embeddings: tuple[torch.Tensor, ...] = ()

        # NOTE: It is important to iterate over the keys in this dictionary
        # to preserve the order of the modalities.
        audio_projection_mode = 'speech'
        for modality in modalities:
            # make sure process images first
            if modality == "images":
                audio_projection_mode = "vision"
                image_input = modalities["images"]
                vision_embeddings = self._process_image_input(image_input)
                multimodal_embeddings += tuple(vision_embeddings)
            if modality == "audios":
                audio_input = modalities["audios"]
                audio_embeddings = self._process_audio_input(
                    audio_input, audio_projection_mode=audio_projection_mode)
                multimodal_embeddings += tuple(audio_embeddings)

        return multimodal_embeddings

    def get_input_embeddings(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
    ) -> torch.Tensor:
        inputs_embeds = self.language_model.get_input_embeddings(input_ids)
        if multimodal_embeddings is not None:
            inputs_embeds = merge_multimodal_embeddings(
                input_ids, inputs_embeds, multimodal_embeddings,
                [_IMAGE_PLACEHOLDER_TOKEN_ID, _AUDIO_PLACEHOLDER_TOKEN_ID])
        return inputs_embeds

    def get_input_embeddings_v0(
        self,
        input_ids: torch.Tensor,
        image_input: Optional[Phi4MMImagePixelInputs] = None,
        audio_input: Optional[Phi4MMAudioFeatureInputs] = None,
    ) -> torch.Tensor:
        audio_projection_mode = 'speech'
        inputs_embeds = self.get_input_embeddings(input_ids)
        if image_input is not None:
            image_embeds = self._process_image_input(image_input)
            inputs_embeds = merge_multimodal_embeddings(
                input_ids,
                inputs_embeds,
                image_embeds,
                placeholder_token_id=_IMAGE_PLACEHOLDER_TOKEN_ID,
            )
            audio_projection_mode = 'vision'

        if audio_input is not None:
            audio_embeds = self._process_audio_input(
                audio_input, audio_projection_mode=audio_projection_mode)
            inputs_embeds = merge_multimodal_embeddings(
                input_ids,
                inputs_embeds,
                audio_embeds,
                placeholder_token_id=_AUDIO_PLACEHOLDER_TOKEN_ID,
            )
        return inputs_embeds

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        **kwargs: object,
    ) -> torch.Tensor:
        if intermediate_tensors is not None:
            inputs_embeds = None

        # NOTE: In v1, inputs_embeds is always generated at model runner from
        # `get_multimodal_embeddings` and `get_input_embeddings`, this
        # condition is only for v0 compatibility.
        elif inputs_embeds is None:
            image_input = self._parse_and_validate_image_input(**kwargs)
            audio_input = self._parse_and_validate_audio_input(**kwargs)

            if image_input is None and audio_input is None:
                inputs_embeds = None
            else:
                inputs_embeds = self.get_input_embeddings_v0(
                    input_ids,
                    image_input=image_input,
                    audio_input=audio_input)
                input_ids = None

        hidden_states = self.language_model(
            input_ids,
            positions,
            intermediate_tensors,
            inputs_embeds=inputs_embeds,
        )

        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        return self.language_model.compute_logits(hidden_states,
                                                  sampling_metadata)

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        loader = AutoWeightsLoader(self)
        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)

    def get_mm_mapping(self) -> MultiModelKeys:
        """
        Get the module prefix in multimodal models
        """
        return MultiModelKeys.from_string_field(
            language_model="language_model.",
            connector=[
                "img_projection", "vision_speech_projection",
                "speech_projection"
            ],
            tower_model=["image_embed", "audio_embed"],
        )

    def get_language_model(self) -> torch.nn.Module:
        return self.language_model
