# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import tempfile
from pathlib import Path

import numpy as np
import numpy.typing as npt
import pytest
from PIL import Image

from vllm.assets.base import get_vllm_public_assets
from vllm.assets.video import video_to_ndarrays, video_to_pil_images_list
from vllm.multimodal.image import ImageMediaIO
from vllm.multimodal.video import (VIDEO_LOADER_REGISTRY, VideoLoader,
                                   VideoMediaIO)

from .utils import cosine_similarity, create_video_from_image, normalize_image

# Length of the leading axis of the fake arrays returned by the stub loaders.
NUM_FRAMES = 10
# Both fake outputs share one shape; the contents are independent random
# draws so the two arrays can be told apart by value.
_FAKE_VIDEO_SHAPE = (NUM_FRAMES, 1280, 720, 3)
FAKE_OUTPUT_1 = np.random.rand(*_FAKE_VIDEO_SHAPE)
FAKE_OUTPUT_2 = np.random.rand(*_FAKE_VIDEO_SHAPE)


@VIDEO_LOADER_REGISTRY.register("test_video_loader_1")
class TestVideoLoader1(VideoLoader):
    """Stub loader registered under the name ``test_video_loader_1``."""

    @classmethod
    def load_bytes(
        cls,
        data: bytes,
        num_frames: int = -1,
    ) -> npt.NDArray:
        """Ignore the arguments and return ``FAKE_OUTPUT_1``."""
        return FAKE_OUTPUT_1


@VIDEO_LOADER_REGISTRY.register("test_video_loader_2")
class TestVideoLoader2(VideoLoader):
    """Stub loader registered under the name ``test_video_loader_2``."""

    @classmethod
    def load_bytes(
        cls,
        data: bytes,
        num_frames: int = -1,
    ) -> npt.NDArray:
        """Ignore the arguments and return ``FAKE_OUTPUT_2``."""
        return FAKE_OUTPUT_2


def test_video_loader_registry():
    """Each registered stub loader can be looked up by name.

    The loader fetched from ``VIDEO_LOADER_REGISTRY`` must return the fake
    array that belongs to the name it was registered under.
    """
    expected_by_name = (
        ("test_video_loader_1", FAKE_OUTPUT_1),
        ("test_video_loader_2", FAKE_OUTPUT_2),
    )
    for loader_name, expected in expected_by_name:
        loader = VIDEO_LOADER_REGISTRY.load(loader_name)
        np.testing.assert_array_equal(loader.load_bytes(b"test"), expected)


def test_video_loader_type_doesnt_exist():
    """Loading a name that was never registered raises ``AssertionError``."""
    unregistered_name = "non_existing_video_loader"
    with pytest.raises(AssertionError):
        VIDEO_LOADER_REGISTRY.load(unregistered_name)


@VIDEO_LOADER_REGISTRY.register("assert_10_frames_1_fps")
class Assert10Frames1FPSVideoLoader(VideoLoader):
    """Stub loader that checks the ``num_frames`` and ``fps`` it is given.

    Registered as ``assert_10_frames_1_fps``.
    """

    @classmethod
    def load_bytes(
        cls,
        data: bytes,
        num_frames: int = -1,
        fps: float = -1.0,
        **kwargs,
    ) -> npt.NDArray:
        """Return ``FAKE_OUTPUT_2`` when called with 10 frames at 1.0 fps.

        Raises:
            AssertionError: with message ``"bad num_frames"`` or ``"bad fps"``
                when the corresponding argument has any other value.
        """
        # num_frames is validated before fps, so when both are wrong the
        # reported message is "bad num_frames".
        assert num_frames == 10, "bad num_frames"
        assert fps == 1.0, "bad fps"
        return FAKE_OUTPUT_2


def test_video_media_io_kwargs(monkeypatch: pytest.MonkeyPatch):
    """Exercise ``VideoMediaIO`` keyword arguments with the asserting loader.

    ``VLLM_VIDEO_LOADER_BACKEND`` is pointed at ``assert_10_frames_1_fps`` for
    the duration of the test, so ``load_bytes`` is expected to succeed only
    for ``num_frames=10`` and ``fps=1.0``.
    """
    with monkeypatch.context() as patched:
        patched.setenv("VLLM_VIDEO_LOADER_BACKEND", "assert_10_frames_1_fps")
        image_io = ImageMediaIO()

        def load_with(**loader_kwargs):
            # Build a fresh VideoMediaIO per case and run the load.
            video_io = VideoMediaIO(image_io, **loader_kwargs)
            return video_io.load_bytes(b"test")

        # Verify that different args pass/fail assertions as expected.
        load_with(num_frames=10, fps=1.0)
        # An additional keyword must not make the load fail.
        load_with(num_frames=10, fps=1.0, not_used="not_used")

        failing_cases = (
            ({}, "bad num_frames"),
            ({"num_frames": 9, "fps": 1.0}, "bad num_frames"),
            ({"num_frames": 10, "fps": 2.0}, "bad fps"),
        )
        for loader_kwargs, expected_message in failing_cases:
            with pytest.raises(AssertionError, match=expected_message):
                load_with(**loader_kwargs)


@pytest.mark.parametrize("is_color", [True, False])
@pytest.mark.parametrize("fourcc, ext", [("mp4v", "mp4"), ("XVID", "avi")])
def test_opencv_video_io_colorspace(is_color: bool, fourcc: str, ext: str):
    """
    Test all functions that use OpenCV for video I/O return RGB format.
    Both RGB and grayscale videos are tested.

    A video is written from the ``stop_sign.jpg`` asset (requesting two
    frames) with the given ``fourcc`` and ``ext``, then read back through
    ``video_to_ndarrays``, ``video_to_pil_images_list`` and
    ``VideoMediaIO.load_file``. Every decoded frame is compared with the
    source image: fewer than 0.1% of the similarity values may be NaN and
    their mean must exceed 0.99.
    """
    image_path = get_vllm_public_assets(filename="stop_sign.jpg",
                                        s3_prefix="vision_model_images")
    # Image.open is lazy. Copying/converting inside the context loads the
    # pixel data, and leaving the context closes the underlying file.
    with Image.open(image_path) as source:
        image = source.copy() if is_color else source.convert("L")

    with tempfile.TemporaryDirectory() as tmpdir:
        if not is_color:
            # The video is built from the grayscale image saved here.
            image_path = f"{tmpdir}/test_grayscale_image.png"
            image.save(image_path)
            # Convert to gray RGB for comparison
            image = image.convert("RGB")
        video_path = f"{tmpdir}/test_RGB_video.{ext}"
        create_video_from_image(
            image_path,
            video_path,
            num_frames=2,
            is_color=is_color,
            fourcc=fourcc,
        )

        def assert_frames_match(frames, reader: str) -> None:
            # Without this guard a reader that decodes nothing would skip the
            # loop below and the test would pass without checking anything.
            assert len(frames) > 0, f"{reader} returned no frames"
            for frame in frames:
                sim = cosine_similarity(normalize_image(np.array(frame)),
                                        normalize_image(np.array(image)))
                assert np.sum(np.isnan(sim)) / sim.size < 0.001
                assert np.nanmean(sim) > 0.99

        assert_frames_match(video_to_ndarrays(video_path),
                            "video_to_ndarrays")
        assert_frames_match(video_to_pil_images_list(video_path),
                            "video_to_pil_images_list")
        io_frames, _ = VideoMediaIO(ImageMediaIO()).load_file(Path(video_path))
        assert_frames_match(io_frames, "VideoMediaIO.load_file")
