# Owner(s): ["module: onnx"]
"""Test op correctness by comparing with PyTorch results.

Usage:

    pytest test_ops.py

    To run tests on a specific operator (e.g. torch.ceil):

    pytest test_ops.py -k ceil

    To run tests on a nn operator (e.g. nn.functional.scaled_dot_product_attention):

    pytest test_ops.py -k nn_functional_scaled_dot_product_attention

## Environment variables

1. Set environment variable `CATCH_ORT_SEGFAULT=1` to catch segmentation faults
in onnxruntime by running the inference sessions in a separate process.

2. Set `CREATE_REPRODUCTION_REPORT=1` to create markdown files for reproduction of
errors.
"""

from __future__ import annotations

import os
from typing import Callable, Optional, TYPE_CHECKING

import error_reproduction
import numpy as np

import onnx
import onnxruntime as ort
import onnxscript
import ops_test_common
import ops_test_data
import parameterized

import torch
from torch.testing._internal import common_device_type, common_utils
from torch.utils import _pytree as pytree


if TYPE_CHECKING:
    import unittest
    from collections.abc import Sequence

    from torch.testing._internal.opinfo import core as opinfo_core

# All dtypes will be tested on the generated symbolic functions.
# complex64 will be flattened to float32.
TESTED_DTYPES = (
    torch.float16,
    torch.float32,
    # Uncomment below item when we really need testing it
    # torch.bfloat16,
    # torch.float64,
    torch.bool,
    # torch.int8,
    # torch.int16,
    torch.int32,
    torch.int64,
    # torch.uint8,
)
# NOTE: torch.complex32 is experimental in torch
COMPLEX_TYPES = (torch.complex64,)


def dtypes_except(*dtypes: torch.dtype) -> Sequence[torch.dtype]:
    """Returns all dtypes except the ones specified."""
    return tuple(dtype for dtype in TESTED_DTYPES if dtype not in dtypes)


def _should_skip_xfail_test_sample(
    op_name: str, sample, dtype: torch.dtype, device_type: str
) -> tuple[Optional[str], Optional[str]]:
    """Returns a reason if a test sample should be skipped."""
    if op_name not in ops_test_data.OP_WITH_SKIPPED_XFAIL_SUBTESTS:
        return None, None
    for decorator_meta in ops_test_data.SKIP_XFAIL_SUBTESTS:
        # Linear search on ops_test_data.SKIP_XFAIL_SUBTESTS. That's fine because the list is small.
        if decorator_meta.op_name == op_name:
            assert decorator_meta.matcher is not None, "Matcher must be defined"
            if not decorator_meta.enabled_if:
                # Do not skip the test if the decorator meta is not enabled
                continue
            if decorator_meta.dtypes is not None and dtype not in decorator_meta.dtypes:
                # Not applicable for this dtype
                continue
            if (
                decorator_meta.device_type is not None
                and decorator_meta.device_type != device_type
            ):
                # Not applicable for this device_type
                continue
            if decorator_meta.matcher(sample):
                return decorator_meta.test_behavior, decorator_meta.reason
    return None, None


class TestFunctionValidity(common_utils.TestCase):
    @parameterized.parameterized.expand(
        [
            (info.op.name, info)
            for info in ops_test_data.TESTED_TORCHLIB_OPS
            if isinstance(info.op, onnxscript.OnnxFunction)
        ],
        skip_on_empty=True,
    )
    def test_script_function_passes_checker(
        self, _, torchlib_op_info: ops_test_data.TorchLibOpInfo
    ):
        function_proto = torchlib_op_info.op.to_function_proto()
        onnx.checker.check_function(function_proto)  # type: ignore[attr-defined]


def run_test_output_match(
    test_suite: unittest.TestCase,
    device: str,
    dtype: torch.dtype,
    op: opinfo_core.OpInfo,
    function_executor: Callable,
    tested_op_mapping: dict[
        str,
        ops_test_data.TorchLibOpInfo,
    ],
):
    """Base test method for testing each opset, used by instantiate_device_type_tests.

    Args:
        test_suite: The test class instance.
        device: The PyTorch device. instantiate_device_type_tests provides this.
        dtype: The PyTorch dtype. instantiate_device_type_tests provides this.
        op: The OpInfo instance. instantiate_device_type_tests provides this.
        function_executor: The function executor. This is a function that takes
            a function and its arguments and returns the output of the function.
        tested_op_mapping: The mapping of op name to the tested op.
    """
    samples = op.sample_inputs(
        device,
        dtype,
        requires_grad=False,
    )

    torchlib_op_info = tested_op_mapping[op.name]
    # Obtain the input_wrangler that manipulates the OpInfo inputs
    # to match the aten operator signature
    # An example is nn.functional.upsample_nearest2d, which has a different signature
    # than the aten operator upsample_nearest2d
    onnx_function = torchlib_op_info.op
    input_wrangler = torchlib_op_info.input_wrangler
    if (
        not ops_test_common.dtype_op_schema_compatible(dtype, onnx_function.op_schema)
        and dtype not in COMPLEX_TYPES
    ):
        test_suite.skipTest(
            f"dtype '{dtype}' is not supported by the op '{op.name}'. "
            f"Type constraints: {onnx_function.op_schema.type_constraints}"
        )

    # Obtain the tolerance for the op
    rtol, atol = torchlib_op_info.get_tolerance(dtype)
    for i, cpu_sample in enumerate(samples):
        inputs = (cpu_sample.input, *cpu_sample.args)
        # Provide the repr to subtest because tensors are not serializable in parallel test runs
        with test_suite.subTest(
            sample_num=i,
            inputs=repr(
                [
                    f"Tensor<{inp.shape}, dtype={inp.dtype}>"
                    if isinstance(inp, torch.Tensor)
                    else inp
                    for inp in inputs
                ]
            ),
            kwargs=repr(cpu_sample.kwargs),
        ):
            try:
                device_type = cpu_sample.args[0].device.type
            except (AttributeError, IndexError):
                device_type = "cpu"
            test_behavior, reason = _should_skip_xfail_test_sample(
                op.name, cpu_sample, dtype, device_type
            )

            with ops_test_common.normal_xfail_skip_test_behaviors(
                test_behavior, reason
            ):
                input_onnx = [
                    ops_test_common.convert_tensor_to_numpy(x) for x in inputs
                ]
                kwargs_onnx = ops_test_common.convert_kwargs_for_onnx(cpu_sample.kwargs)
                if input_wrangler:
                    input_onnx, kwargs_onnx = input_wrangler(input_onnx, kwargs_onnx)
                torch_output = op(*inputs, **cpu_sample.kwargs)

                if isinstance(torch_output, torch.Tensor) and torch.is_complex(
                    torch_output
                ):
                    torch_output = torch.view_as_real(torch_output.resolve_conj())

                reference_torch_outputs, _ = pytree.tree_flatten(torch_output)
                if (
                    op.name.startswith("split")
                    or op.name.startswith("chunk")
                    or op.name.startswith("unbind")
                    or op.name
                    in {
                        "atleast_1d_Sequence",
                        "atleast_2d_Sequence",
                        "atleast_3d_Sequence",
                    }
                ):
                    # Hack for handling split, chunk and unbind which relies on SplitToSequence op.
                    # Split returns a Sequence that should be treats as a single
                    # value. So we wrap it into a tuple.
                    # TODO(justinchuby): Find a more general solution
                    reference_torch_outputs = [reference_torch_outputs]

                test_name = test_suite.id()
                function_output, model_proto = function_executor(
                    test_name,
                    reference_torch_outputs,
                    opset_version=torchlib_op_info.opset_introduced,
                )(onnx_function, input_onnx, kwargs_onnx)
                # Finally we re-flatten everything
                # TODO: add pytree structure comparison.
                flattened_torch_outputs, _ = pytree.tree_flatten(torch_output)
                flattened_function_outputs, _ = pytree.tree_flatten(function_output)

                assert flattened_torch_outputs
                assert len(flattened_torch_outputs) == len(flattened_function_outputs)

                for j, (torch_output, function_output) in enumerate(
                    zip(flattened_torch_outputs, flattened_function_outputs)
                ):
                    actual = torch.tensor(function_output)
                    expected = (
                        torch_output
                        if isinstance(torch_output, torch.Tensor)
                        else torch.tensor(torch_output)
                    )

                    if (
                        op.name in ops_test_data.NONDETERMINISTIC_OPS
                        or j in ops_test_data.COMPARE_SHAPE_ONLY_OPS[op.name]
                    ):
                        # Check shape and dtype only for ops that are known to be
                        # nondeterministic
                        test_suite.assertEqual(actual.shape, expected.shape)
                        test_suite.assertEqual(actual.dtype, expected.dtype)
                        continue

                    # Use torch.testing as opposed to np.testing to ensure dtypes and shapes match
                    try:
                        torch.testing.assert_close(
                            actual,
                            expected,
                            rtol=rtol,
                            atol=atol,
                            equal_nan=True,
                            check_device=False,
                        )
                    except AssertionError as e:
                        if (
                            os.environ.get("CREATE_REPRODUCTION_REPORT") == "1"
                            and test_behavior is None
                        ):
                            error_reproduction.create_mismatch_report(
                                test_name,
                                i,
                                model_proto,
                                inputs,
                                cpu_sample.kwargs,
                                actual,
                                expected,
                                e,
                                __file__,
                            )
                        if len(flattened_torch_outputs) > 1:
                            raise AssertionError(f"Output {j} mismatch") from e
                        raise


class TestOutputConsistencyFullGraph(common_utils.TestCase):
    """Test output consistency between exported ONNX op run as a graph and PyTorch eager mode.

    This is a parameterized test suite.
    """

    def setUp(self) -> None:
        torch.manual_seed(42)
        np.random.seed(42)
        ort.set_seed(42)

    @ops_test_common.add_decorate_info(
        ops_test_data.OPS_DB,
        "TestOutputConsistencyFullGraph",
        "test_output_match_opinfo_",
        skip_or_xfails=ops_test_data.EXPECTED_SKIPS_OR_FAILS,
    )
    @common_device_type.ops(  # type: ignore[misc]
        [
            info
            for info in ops_test_data.OPS_DB
            if info.name in ops_test_data.TESTED_OPS
        ],
        allowed_dtypes=TESTED_DTYPES,
    )
    def test_output_match_opinfo_(
        self, device: str, dtype: torch.dtype, op: opinfo_core.OpInfo
    ):
        # Base test method for testing each op by running the full ONNX graph.
        run_test_output_match(
            self,
            device,
            dtype,
            op,
            ops_test_common.graph_executor,
            ops_test_data.TORCHLIB_OPINFO_MAPPING,
        )

    @ops_test_common.add_decorate_info(
        ops_test_data.OPS_DB,
        "TestOutputConsistencyFullGraph",
        "test_complex_output_match_opinfo_",
        skip_or_xfails=ops_test_data.EXPECTED_SKIPS_OR_FAILS,
    )
    @common_device_type.ops(  # type: ignore[misc]
        [
            info
            for info in ops_test_data.OPS_DB
            if info.name in ops_test_data.COMPLEX_FUNCTION_MAPPING
        ],
        allowed_dtypes=COMPLEX_TYPES,
    )
    def test_complex_output_match_opinfo_(
        self, device: str, dtype: torch.dtype, op: opinfo_core.OpInfo
    ):
        """Base test method for testing each op by running the full ONNX graph."""
        run_test_output_match(
            self,
            device,
            dtype,
            op,
            ops_test_common.graph_executor,
            ops_test_data.COMPLEX_FUNCTION_MAPPING,
        )


common_device_type.instantiate_device_type_tests(
    TestOutputConsistencyFullGraph, globals(), only_for=["cpu"]
)

if __name__ == "__main__":
    common_utils.run_tests()
