# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from collections import OrderedDict
from typing import NamedTuple, Optional
from unittest.mock import patch

import pytest
from huggingface_hub.utils import HfHubHTTPError
from torch import nn

from vllm.lora.utils import (get_adapter_absolute_path,
                             parse_fine_tuned_lora_name, replace_submodule)
from vllm.model_executor.models.utils import WeightsMapper


class LoRANameParserTestConfig(NamedTuple):
    """One test case for ``parse_fine_tuned_lora_name``.

    ``name`` and ``weights_mapper`` are the arguments passed to the parser;
    ``module_name``, ``is_lora_a`` and ``is_bias`` together form the tuple
    that the parser's return value is compared against.
    """
    # LoRA tensor name handed to the parser as its first argument.
    name: str
    # Expected first element of the parsed result.
    module_name: str
    # Expected second element of the parsed result.
    is_lora_a: bool
    # Expected third element of the parsed result.
    is_bias: bool
    # Optional mapper forwarded as the parser's second argument.
    weights_mapper: Optional[WeightsMapper] = None


def test_parse_fine_tuned_lora_name_valid():
    """Check that supported LoRA tensor names parse to the expected tuple.

    Every case pairs a tensor name with the
    ``(module_name, is_lora_a, is_bias)`` tuple that
    ``parse_fine_tuned_lora_name`` must return for it. The last four cases
    additionally pass a ``WeightsMapper`` built with an
    ``orig_to_new_prefix`` mapping.
    """

    def prefix_mapper():
        # Build a fresh mapper for each case that needs one.
        return WeightsMapper(
            orig_to_new_prefix={"model.": "language_model.model."})

    cases = [
        LoRANameParserTestConfig(
            name="base_model.model.lm_head.lora_A.weight",
            module_name="lm_head",
            is_lora_a=True,
            is_bias=False,
        ),
        LoRANameParserTestConfig(
            name="base_model.model.lm_head.lora_B.weight",
            module_name="lm_head",
            is_lora_a=False,
            is_bias=False,
        ),
        LoRANameParserTestConfig(
            name="base_model.model.model.embed_tokens.lora_embedding_A",
            module_name="model.embed_tokens",
            is_lora_a=True,
            is_bias=False,
        ),
        LoRANameParserTestConfig(
            name="base_model.model.model.embed_tokens.lora_embedding_B",
            module_name="model.embed_tokens",
            is_lora_a=False,
            is_bias=False,
        ),
        LoRANameParserTestConfig(
            name="base_model.model.model.layers.9.mlp.down_proj.lora_A.weight",
            module_name="model.layers.9.mlp.down_proj",
            is_lora_a=True,
            is_bias=False,
        ),
        LoRANameParserTestConfig(
            name="base_model.model.model.layers.9.mlp.down_proj.lora_B.weight",
            module_name="model.layers.9.mlp.down_proj",
            is_lora_a=False,
            is_bias=False,
        ),
        LoRANameParserTestConfig(
            name="language_model.layers.9.mlp.down_proj.lora_A.weight",
            module_name="language_model.layers.9.mlp.down_proj",
            is_lora_a=True,
            is_bias=False,
        ),
        LoRANameParserTestConfig(
            name="language_model.layers.9.mlp.down_proj.lora_B.weight",
            module_name="language_model.layers.9.mlp.down_proj",
            is_lora_a=False,
            is_bias=False,
        ),
        # Cases that also supply a WeightsMapper.
        LoRANameParserTestConfig(
            name="base_model.model.model.layers.9.mlp.down_proj.lora_A.weight",
            module_name="language_model.model.layers.9.mlp.down_proj",
            is_lora_a=True,
            is_bias=False,
            weights_mapper=prefix_mapper(),
        ),
        LoRANameParserTestConfig(
            name="base_model.model.model.layers.9.mlp.down_proj.lora_B.weight",
            module_name="language_model.model.layers.9.mlp.down_proj",
            is_lora_a=False,
            is_bias=False,
            weights_mapper=prefix_mapper(),
        ),
        LoRANameParserTestConfig(
            name="model.layers.9.mlp.down_proj.lora_A.weight",
            module_name="language_model.model.layers.9.mlp.down_proj",
            is_lora_a=True,
            is_bias=False,
            weights_mapper=prefix_mapper(),
        ),
        LoRANameParserTestConfig(
            name="model.layers.9.mlp.down_proj.lora_B.weight",
            module_name="language_model.model.layers.9.mlp.down_proj",
            is_lora_a=False,
            is_bias=False,
            weights_mapper=prefix_mapper(),
        ),
    ]
    for case in cases:
        expected = (case.module_name, case.is_lora_a, case.is_bias)
        parsed = parse_fine_tuned_lora_name(case.name, case.weights_mapper)
        assert expected == parsed


def test_parse_fine_tuned_lora_name_invalid():
    """Check that unsupported tensor names are rejected with ValueError.

    Each name must make ``parse_fine_tuned_lora_name`` raise a
    ``ValueError`` whose message matches "unsupported LoRA weight".
    """
    unsupported_names = {
        "base_model.weight",
        "base_model.model.weight",
    }
    for bad_name in unsupported_names:
        with pytest.raises(ValueError, match="unsupported LoRA weight"):
            parse_fine_tuned_lora_name(bad_name)


def test_replace_submodule():
    """Check that ``replace_submodule`` swaps top-level and nested modules.

    After each call, the module registered under the given dotted name in
    ``model.named_modules()`` must be the replacement instance.
    """
    # Keyword arguments keep their order, so the layers are created and
    # registered in the order written here.
    model = nn.Sequential(
        OrderedDict(
            dense1=nn.Linear(764, 100),
            act1=nn.ReLU(),
            dense2=nn.Linear(100, 50),
            seq1=nn.Sequential(
                OrderedDict(
                    dense1=nn.Linear(100, 10),
                    dense2=nn.Linear(10, 50),
                )),
            act2=nn.ReLU(),
            output=nn.Linear(50, 10),
            outact=nn.Sigmoid(),
        ))

    # Replace a direct child of the top-level container.
    new_activation = nn.Sigmoid()
    replace_submodule(model, "act1", new_activation)
    modules_by_name = dict(model.named_modules())
    assert modules_by_name["act1"] is new_activation

    # Replace a child of the nested container, addressed by dotted name.
    new_linear = nn.Linear(1, 5)
    replace_submodule(model, "seq1.dense2", new_linear)
    modules_by_name = dict(model.named_modules())
    assert modules_by_name["seq1.dense2"] is new_linear


# Unit tests for get_adapter_absolute_path
@patch('os.path.isabs')
def test_get_adapter_absolute_path_absolute(mock_isabs):
    """A path that ``os.path.isabs`` reports as absolute comes back as-is."""
    mock_isabs.return_value = True
    lora_path = '/absolute/path/to/lora'
    resolved = get_adapter_absolute_path(lora_path)
    assert resolved == lora_path


@patch('os.path.expanduser')
def test_get_adapter_absolute_path_expanduser(mock_expanduser):
    """A ``~``-prefixed path resolves to the mocked expansion result.

    ``os.path.expanduser`` is patched, so the expected value is whatever
    the mock is configured to return.
    """
    expanded = '/home/user/relative/path/to/lora'
    mock_expanduser.return_value = expanded
    # Input starting with ~, which requires home-directory expansion.
    resolved = get_adapter_absolute_path('~/relative/path/to/lora')
    assert resolved == expanded


@patch('os.path.exists')
@patch('os.path.abspath')
def test_get_adapter_absolute_path_local_existing(mock_abspath, mock_exist):
    """An existing relative path resolves to the mocked ``abspath`` value.

    ``os.path.exists`` is patched to report the path as present and
    ``os.path.abspath`` is patched to supply the expected result.
    """
    expected = '/absolute/path/to/lora'
    mock_abspath.return_value = expected
    mock_exist.return_value = True
    # Relative input that the patched filesystem check says exists locally.
    resolved = get_adapter_absolute_path('relative/path/to/lora')
    assert resolved == expected


@patch('huggingface_hub.snapshot_download')
@patch('os.path.exists')
def test_get_adapter_absolute_path_huggingface(mock_exist,
                                               mock_snapshot_download):
    """A non-local ``org/repo`` id resolves to the mocked snapshot path.

    ``os.path.exists`` is patched to report the path as missing, and
    ``huggingface_hub.snapshot_download`` is patched to supply the
    expected result.
    """
    snapshot_dir = '/mock/snapshot/path'
    mock_snapshot_download.return_value = snapshot_dir
    mock_exist.return_value = False
    # Input in the form of a Hugging Face model identifier.
    resolved = get_adapter_absolute_path('org/repo')
    assert resolved == snapshot_dir


@patch('huggingface_hub.snapshot_download')
@patch('os.path.exists')
def test_get_adapter_absolute_path_huggingface_error(mock_exist,
                                                     mock_snapshot_download):
    """A failed snapshot download yields the original path unchanged.

    ``os.path.exists`` is patched to report the path as missing, and the
    patched ``huggingface_hub.snapshot_download`` raises
    ``HfHubHTTPError``.
    """
    download_failure = HfHubHTTPError("failed to query model info")
    mock_snapshot_download.side_effect = download_failure
    mock_exist.return_value = False
    # Input in the form of a Hugging Face model identifier.
    repo_id = 'org/repo'
    resolved = get_adapter_absolute_path(repo_id)
    assert resolved == repo_id
