# Owner(s): ["module: onnx"]
"""Test torch.onnx.ops."""

from __future__ import annotations

import onnx_ir.passes.common as common_passes
from onnxscript import ir

import torch
from torch.onnx.ops import _impl, _symbolic_impl
from torch.testing._internal import common_utils


class SchemaTest(common_utils.TestCase):
    def test_symbolic_has_correct_schema(self):
        torch.library.opcheck(
            _symbolic_impl._symbolic,
            ([torch.tensor(1)], "CustomOp", 1),
            dict(
                shape=[
                    1,
                ],
                attr_keys=["key"],
                attr_types=["i"],
                attr_pos=[(0, 1)],
                attr_ints=[1],
                attr_floats=[1.0],
                attr_strs=["attr"],
                metadata_props_keys=["meta_key"],
                metadata_props_values=["meta_value"],
                domain="custom_domain",
                version=42,
            ),
        )

        # Empty inputs
        torch.library.opcheck(
            _symbolic_impl._symbolic,
            ([], "CustomOp", 1),
            dict(
                shape=[
                    1,
                ],
                attr_keys=[],
                attr_types=[],
                attr_pos=[],
                attr_ints=[],
                attr_floats=[],
                attr_strs=[],
                metadata_props_keys=[],
                metadata_props_values=[],
            ),
        )

    def test_symbolic_multi_out_has_correct_schema(self):
        torch.library.opcheck(
            _symbolic_impl._symbolic_multi_out,
            ([torch.tensor(1)], "CustomMultiOutOp", [1, 2, 10]),
            dict(
                shapes=[[1, 2], [42], []],
                attr_keys=["key"],
                attr_types=["i"],
                attr_pos=[(0, 1)],
                attr_ints=[1],
                attr_floats=[1.0],
                attr_strs=["attr"],
                metadata_props_keys=["meta_key"],
                metadata_props_values=["meta_value"],
                domain="",
                version=1,
            ),
        )

        # Empty inputs
        torch.library.opcheck(
            _symbolic_impl._symbolic_multi_out,
            ([], "CustomMultiOutOp", []),
            dict(
                shapes=[],
                attr_keys=[],
                attr_types=[],
                attr_pos=[],
                attr_ints=[],
                attr_floats=[],
                attr_strs=[],
                metadata_props_keys=[],
                metadata_props_values=[],
            ),
        )


class SymbolicOpsTest(common_utils.TestCase):
    def test_symbolic_accepts_valid_inputs(self):
        output = torch.onnx.ops.symbolic(
            "custom_domain::CustomOp",
            (torch.tensor(1),),
            dict(
                int_key=1,
                float_key=1.0,
                str_key="attr",
                bool_key=True,
                list_int_key=[1, 2],
                list_float_key=[1.0, 2.0],
                list_str_key=["attr1", "attr2"],
                list_bool_key=[True, False],
            ),
            dtype=torch.float32,
            shape=[1, 2, 3],
            version=1,
            metadata_props={"meta_key": "meta_value"},
        )
        self.assertEqual(output.shape, torch.Size([1, 2, 3]))
        self.assertEqual(output.dtype, torch.float32)
        self.assertEqual(output.device, torch.device("cpu"))

    def test_symbolic_accepts_valid_inputs_empty_shape(self):
        output = torch.onnx.ops.symbolic(
            "custom_domain::CustomOp",
            (torch.tensor(1),),
            dtype=torch.float32,
            shape=[],
        )
        self.assertEqual(output.shape, torch.Size([]))

    def test_symbolic_accepts_valid_inputs_integer_types(self):
        output = torch.onnx.ops.symbolic(
            "custom_domain::CustomOp",
            (torch.tensor(1),),
            dtype=1,  # 1 is float32 in ONNX
            shape=[42],
        )
        self.assertEqual(output.dtype, torch.float32)

    def test_symbolic_accepts_valid_inputs_int4_type(self):
        output = torch.onnx.ops.symbolic(
            "custom_domain::CustomOp",
            (torch.tensor(1),),
            dtype=22,  # 22 is INT4 in ONNX
            shape=[42],
        )
        # We use torch uint8 for int4
        self.assertEqual(output.dtype, torch.uint8)

    def test_symbolic_is_exportable(self):
        class Model(torch.nn.Module):
            def forward(self, x: torch.Tensor):
                return torch.onnx.ops.symbolic(
                    "custom_domain::CustomOp",
                    (x, None),
                    dict(
                        int_key=1,
                        float_key=1.0,
                        str_key="attr",
                        bool_key=True,
                        list_int_key=[1, 2],
                        list_float_key=[1.0, 2.0],
                        list_str_key=["attr1", "attr2"],
                        list_bool_key=[True, False],
                    ),
                    dtype=x.dtype,
                    shape=[1, 2, 3],
                    version=1,
                    metadata_props={"meta_key": "meta_value"},
                )

        onnx_program = torch.onnx.export(
            Model(), (torch.tensor(1),), dynamo=True, verbose=False
        )
        assert onnx_program is not None
        node = onnx_program.model.graph.node(0)
        self.assertEqual(node.op_type, "CustomOp")
        self.assertEqual(node.domain, "custom_domain")
        attributes = node.attributes
        self.assertEqual(
            attributes,
            dict(
                int_key=ir.AttrInt64("int_key", 1),
                float_key=ir.AttrFloat32("float_key", 1.0),
                str_key=ir.AttrString("str_key", "attr"),
                bool_key=ir.AttrInt64("bool_key", 1),
                list_int_key=ir.AttrInt64s("list_int_key", [1, 2]),
                list_float_key=ir.AttrFloat32s("list_float_key", [1.0, 2.0]),
                list_str_key=ir.AttrStrings("list_str_key", ["attr1", "attr2"]),
                list_bool_key=ir.AttrInt64s("list_bool_key", [1, 0]),
            ),
        )
        self.assertEqual(node.metadata_props["meta_key"], "meta_value")
        outputs = node.outputs
        self.assertEqual(list(outputs[0].shape), [1, 2, 3])
        self.assertEqual(outputs[0].dtype, ir.DataType.INT64)

    def test_symbolic_preserves_dynamic_shapes(self):
        class Model(torch.nn.Module):
            def forward(self, x: torch.Tensor, y: torch.Tensor):
                return torch.onnx.ops.symbolic(
                    "custom_domain::CustomOp",
                    (x, y),
                    dtype=x.dtype,
                    shape=[*x.shape, *y.shape],
                    version=1,
                )

        onnx_program = torch.onnx.export(
            Model(),
            (torch.zeros(2, 3), torch.zeros(1, 2)),
            dynamic_shapes=({0: "batch"}, {1: "something_else"}),
            dynamo=True,
            verbose=False,
        )
        assert onnx_program is not None
        node = onnx_program.model.graph.node(0)
        self.assertEqual(node.op_type, "CustomOp")
        self.assertEqual(node.domain, "custom_domain")
        inputs = onnx_program.model.graph.inputs
        self.assertEqual(str(inputs[0].shape[0]), "batch")
        self.assertEqual(inputs[0].shape[1], 3)
        self.assertEqual(inputs[1].shape[0], 1)
        self.assertEqual(str(inputs[1].shape[1]), "something_else")
        outputs = node.outputs
        self.assertEqual(str(outputs[0].shape[0]), "batch")
        self.assertEqual(outputs[0].shape[1], 3)
        self.assertEqual(outputs[0].shape[2], 1)
        self.assertEqual(str(outputs[0].shape[3]), "something_else")
        self.assertEqual(outputs[0].dtype, ir.DataType.FLOAT)

    def test_symbolic_multi_out_accepts_valid_inputs(self):
        outputs = torch.onnx.ops.symbolic_multi_out(
            "custom_domain::CustomMultiOutOp",
            (torch.tensor(1),),
            dict(
                int_key=1,
                float_key=1.0,
                str_key="attr",
                bool_key=True,
                list_int_key=[1, 2],
                list_float_key=[1.0, 2.0],
                list_str_key=["attr1", "attr2"],
                list_bool_key=[True, False],
            ),
            dtypes=(
                1,  # 1 is float32 in ONNX
                torch.int32,
                torch.float8_e4m3fn,
            ),
            shapes=([1, 2], [42], []),
            version=1,
            metadata_props={"meta_key": "meta_value"},
        )
        self.assertEqual(len(outputs), 3)
        self.assertEqual(outputs[0].shape, torch.Size([1, 2]))
        self.assertEqual(outputs[0].dtype, torch.float32)
        self.assertEqual(outputs[1].shape, torch.Size([42]))
        self.assertEqual(outputs[1].dtype, torch.int32)
        self.assertEqual(outputs[2].shape, torch.Size([]))
        self.assertEqual(outputs[2].dtype, torch.float8_e4m3fn)
        self.assertEqual(outputs[0].device, torch.device("cpu"))
        self.assertEqual(outputs[1].device, torch.device("cpu"))
        self.assertEqual(outputs[2].device, torch.device("cpu"))

    def test_symbolic_multi_out_accepts_valid_inputs_empty_shape(self):
        outputs = torch.onnx.ops.symbolic_multi_out(
            "custom_domain::CustomOp",
            (torch.tensor(1),),
            dtypes=(torch.float32,),
            shapes=[[]],
        )
        self.assertEqual(outputs[0].shape, torch.Size([]))

    def test_symbolic_multi_out_accepts_valid_inputs_integer_types(self):
        outputs = torch.onnx.ops.symbolic_multi_out(
            "custom_domain::CustomOp",
            (torch.tensor(1),),
            dtypes=(1,),  # 1 is float32 in ONNX
            shapes=[[42]],
        )
        self.assertEqual(outputs[0].dtype, torch.float32)

    def test_symbolic_multi_out_accepts_valid_inputs_int4_type(self):
        outputs = torch.onnx.ops.symbolic_multi_out(
            "custom_domain::CustomOp",
            (torch.tensor(1),),
            dtypes=(22,),  # 22 is INT4 in ONNX
            shapes=[[42]],
        )
        # We use torch uint8 for int4
        self.assertEqual(outputs[0].dtype, torch.uint8)

    def test_symbolic_multi_out_is_exportable(self):
        class Model(torch.nn.Module):
            def forward(self, x: torch.Tensor):
                return torch.onnx.ops.symbolic_multi_out(
                    "custom_domain::CustomOp",
                    (x, None),
                    dict(
                        int_key=1,
                        float_key=1.0,
                        str_key="attr",
                        bool_key=True,
                        list_int_key=[1, 2],
                        list_float_key=[1.0, 2.0],
                        list_str_key=["attr1", "attr2"],
                        list_bool_key=[True, False],
                    ),
                    dtypes=(torch.float32, torch.int32, torch.float8_e4m3fn),
                    shapes=([1, 2], [42], []),
                    version=1,
                    metadata_props={"meta_key": "meta_value"},
                )

        onnx_program = torch.onnx.export(
            Model(), (torch.tensor(1),), dynamo=True, verbose=False
        )
        assert onnx_program is not None
        node = onnx_program.model.graph.node(0)
        self.assertEqual(node.op_type, "CustomOp")
        self.assertEqual(node.domain, "custom_domain")
        attributes = node.attributes
        self.assertEqual(
            attributes,
            dict(
                int_key=ir.AttrInt64("int_key", 1),
                float_key=ir.AttrFloat32("float_key", 1.0),
                str_key=ir.AttrString("str_key", "attr"),
                bool_key=ir.AttrInt64("bool_key", 1),
                list_int_key=ir.AttrInt64s("list_int_key", [1, 2]),
                list_float_key=ir.AttrFloat32s("list_float_key", [1.0, 2.0]),
                list_str_key=ir.AttrStrings("list_str_key", ["attr1", "attr2"]),
                list_bool_key=ir.AttrInt64s("list_bool_key", [1, 0]),
            ),
        )
        self.assertEqual(node.metadata_props["meta_key"], "meta_value")
        outputs = node.outputs
        self.assertEqual(list(outputs[0].shape), [1, 2])
        self.assertEqual(outputs[0].dtype, ir.DataType.FLOAT)
        self.assertEqual(list(outputs[1].shape), [42])
        self.assertEqual(outputs[1].dtype, ir.DataType.INT32)
        self.assertEqual(list(outputs[2].shape), [])
        self.assertEqual(outputs[2].dtype, ir.DataType.FLOAT8E4M3FN)

    def test_symbolic_multi_out_preserves_dynamic_shapes(self):
        class Model(torch.nn.Module):
            def forward(self, x: torch.Tensor, y: torch.Tensor):
                return torch.onnx.ops.symbolic_multi_out(
                    "custom_domain::CustomOp",
                    (x, y),
                    dtypes=(x.dtype, 22),  # 22 is INT4
                    shapes=[[*x.shape, *y.shape], [42]],
                    version=1,
                )

        onnx_program = torch.onnx.export(
            Model(),
            (torch.zeros(2, 3), torch.zeros(1, 2)),
            dynamic_shapes=({0: "batch"}, {1: "something_else"}),
            dynamo=True,
            verbose=False,
        )
        assert onnx_program is not None
        node = onnx_program.model.graph.node(0)
        self.assertEqual(node.op_type, "CustomOp")
        self.assertEqual(node.domain, "custom_domain")
        inputs = onnx_program.model.graph.inputs
        self.assertEqual(str(inputs[0].shape[0]), "batch")
        self.assertEqual(inputs[0].shape[1], 3)
        self.assertEqual(inputs[1].shape[0], 1)
        self.assertEqual(str(inputs[1].shape[1]), "something_else")
        outputs = node.outputs
        self.assertEqual(str(outputs[0].shape[0]), "batch")
        self.assertEqual(outputs[0].shape[1], 3)
        self.assertEqual(outputs[0].shape[2], 1)
        self.assertEqual(str(outputs[0].shape[3]), "something_else")
        self.assertEqual(outputs[0].dtype, ir.DataType.FLOAT)
        self.assertEqual(list(outputs[1].shape), [42])
        self.assertEqual(outputs[1].dtype, ir.DataType.INT4)

    def test_symbolic_multi_out_raises_when_dtypes_and_shapes_differ(self):
        with self.assertRaises(RuntimeError):
            torch.onnx.ops.symbolic_multi_out(
                "custom_domain::CustomMultiOutOp",
                (torch.tensor(1),),
                dict(
                    int_key=1,
                    float_key=1.0,
                    str_key="attr",
                    bool_key=True,
                    list_int_key=[1, 2],
                    list_float_key=[1.0, 2.0],
                    list_str_key=["attr1", "attr2"],
                    list_bool_key=[True, False],
                ),
                dtypes=(torch.float32, torch.int32),
                shapes=([1, 2], [42], []),
                version=1,
                metadata_props={"meta_key": "meta_value"},
            )

        with self.assertRaises(RuntimeError):
            torch.onnx.ops.symbolic_multi_out(
                "custom_domain::CustomMultiOutOp",
                (torch.tensor(1),),
                dict(
                    int_key=1,
                    float_key=1.0,
                    str_key="attr",
                    bool_key=True,
                    list_int_key=[1, 2],
                    list_float_key=[1.0, 2.0],
                    list_str_key=["attr1", "attr2"],
                    list_bool_key=[True, False],
                ),
                dtypes=(torch.float32,),
                shapes=([1, 2], [42]),
                version=1,
                metadata_props={"meta_key": "meta_value"},
            )


class NativeOnnxOpsTest(common_utils.TestCase):
    def export(self, model, args=(), kwargs=None, **options) -> torch.onnx.ONNXProgram:
        """Export ``model`` with the dynamo exporter (fallback off) and run ``CheckerPass`` on the result.

        Extra ``options`` are forwarded to ``torch.onnx.export``.
        """
        fixed_options = {
            "kwargs": kwargs,
            "dynamo": True,
            "fallback": False,
            "verbose": False,
        }
        # Passing both mappings keeps the "duplicate keyword" TypeError if a
        # caller tries to override one of the fixed options.
        program = torch.onnx.export(model, args, **fixed_options, **options)
        assert program is not None
        run_checker = common_passes.CheckerPass()
        run_checker(program.model)
        return program

    def test_onnx_ops_can_be_decomposed_to_aten(self):
        input_data = torch.rand(2, 3, 4, 8)
        position_ids_data = torch.randint(0, 50, (2, 3)).long()
        sin_cache_data = torch.rand(50, 4)
        cos_cache_data = torch.rand(50, 4)

        class Model(torch.nn.Module):
            def forward(
                self, input_data, cos_cache_data, sin_cache_data, position_ids_data
            ):
                return torch.onnx.ops.rotary_embedding(
                    input_data,
                    cos_cache_data,
                    sin_cache_data,
                    position_ids_data,
                    interleaved=True,
                )

        model = Model()

        ep = torch.export.export(
            model,
            (input_data, cos_cache_data, sin_cache_data, position_ids_data),
        )
        self.assertIn(
            "onnx.RotaryEmbedding.opset23",
            [str(node.target) for node in ep.graph.nodes],
        )
        # The program can be decomposed into aten ops so it is fully compatible with the PyTorch ecosystem
        aten_decomped = ep.run_decompositions(torch.onnx.ops.aten_decompositions())
        self.assertNotIn(
            "onnx.RotaryEmbedding.opset23",
            [str(node.target) for node in aten_decomped.graph.nodes],
        )
        torch.testing.assert_close(
            aten_decomped.module()(
                input_data, cos_cache_data, sin_cache_data, position_ids_data
            ),
            model(input_data, cos_cache_data, sin_cache_data, position_ids_data),
        )

    def test_rotary_embedding_opcheck(self):
        input_data = torch.rand(2, 3, 4, 8)
        position_ids_data = torch.randint(0, 50, (2, 3)).long()
        sin_cache_data = torch.rand(50, 4)
        cos_cache_data = torch.rand(50, 4)

        torch.library.opcheck(
            _impl.rotary_embedding_23,
            (input_data, cos_cache_data, sin_cache_data, position_ids_data),
        )

    def test_rotary_embedding(self):
        """Call rotary embedding eagerly, then export it at opset 23 with dynamic batch dims."""
        x = torch.rand(2, 3, 4, 8)
        position_ids = torch.randint(0, 50, (2, 3)).long()
        sin_cache = torch.rand(50, 4)
        cos_cache = torch.rand(50, 4)
        example_args = (x, cos_cache, sin_cache, position_ids)

        # Eager mode is supported. Autograd is also supported so users can choose to use the op
        # in development and production
        eager_result = torch.onnx.ops.rotary_embedding(*example_args)
        self.assertEqual(eager_result.shape, x.shape)

        class Model(torch.nn.Module):
            def forward(
                self, input_data, cos_cache_data, sin_cache_data, position_ids_data
            ):
                return torch.onnx.ops.rotary_embedding(
                    input_data,
                    cos_cache_data,
                    sin_cache_data,
                    position_ids_data,
                    interleaved=True,
                )

        # Dynamic shapes are supported
        # Keys match the parameter names of Model.forward.
        dynamic_batch = {0: torch.export.Dim.DYNAMIC}
        dynamic_shapes = {
            "input_data": dynamic_batch,
            "cos_cache_data": None,
            "sin_cache_data": None,
            "position_ids_data": {0: torch.export.Dim.DYNAMIC},
        }

        onnx_program = self.export(
            Model(),
            example_args,
            dynamic_shapes=dynamic_shapes,
            opset_version=23,
        )
        exported_model = onnx_program.model
        self.assertEqual(exported_model.opset_imports[""], 23)
        self.assertEqual("RotaryEmbedding", exported_model.graph.node(0).op_type)

    def test_attention_basic(self):
        """Test basic attention functionality."""
        batch_size, q_seq_len, kv_seq_len = 2, 4, 6
        q_num_heads, kv_num_heads = 8, 8
        head_size = 64

        Q = torch.rand(batch_size, q_num_heads, q_seq_len, head_size)
        K = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size)
        V = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size)

        # Test eager mode
        torch.library.opcheck(_impl.attention_23, (Q, K, V))
        output, present_key, present_value, qk_output = torch.onnx.ops.attention(
            Q, K, V
        )

        self.assertEqual(output.shape, (batch_size, q_num_heads, q_seq_len, head_size))
        self.assertEqual(present_key.shape, K.shape)
        self.assertEqual(present_value.shape, V.shape)
        self.assertEqual(
            qk_output.shape, (batch_size, q_num_heads, q_seq_len, kv_seq_len)
        )

    def test_attention_3d_inputs(self):
        """Test attention with 3D inputs (requires num_heads parameters)."""
        batch_size, q_seq_len, kv_seq_len = 2, 4, 6
        q_num_heads, kv_num_heads = 8, 8
        head_size = 64

        Q = torch.rand(batch_size, q_seq_len, q_num_heads * head_size)
        K = torch.rand(batch_size, kv_seq_len, kv_num_heads * head_size)
        V = torch.rand(batch_size, kv_seq_len, kv_num_heads * head_size)

        torch.library.opcheck(
            _impl.attention_23,
            (Q, K, V),
            dict(q_num_heads=q_num_heads, kv_num_heads=kv_num_heads),
        )
        output, present_key, present_value, qk_output = torch.onnx.ops.attention(
            Q, K, V, q_num_heads=q_num_heads, kv_num_heads=kv_num_heads
        )

        # Output should be reshaped back to 3D
        self.assertEqual(output.shape, (batch_size, q_seq_len, q_num_heads * head_size))
        self.assertEqual(
            present_key.shape, (batch_size, kv_num_heads, kv_seq_len, head_size)
        )
        self.assertEqual(
            present_value.shape, (batch_size, kv_num_heads, kv_seq_len, head_size)
        )

    def test_attention_gqa(self):
        """Test Group Query Attention (GQA)."""
        batch_size, q_seq_len, kv_seq_len = 2, 4, 6
        q_num_heads, kv_num_heads = 8, 4  # GQA: q_num_heads % kv_num_heads = 0
        head_size = 64

        Q = torch.rand(batch_size, q_num_heads, q_seq_len, head_size)
        K = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size)
        V = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size)

        torch.library.opcheck(_impl.attention_23, (Q, K, V))
        output, present_key, present_value, qk_output = torch.onnx.ops.attention(
            Q, K, V
        )
        expected = torch.nn.functional.scaled_dot_product_attention(
            Q, K, V, None, enable_gqa=True
        )

        self.assertEqual(output.shape, (batch_size, q_num_heads, q_seq_len, head_size))
        self.assertEqual(present_key.shape, K.shape)
        self.assertEqual(present_value.shape, V.shape)
        torch.testing.assert_close(output, expected)

    def test_attention_mqa(self):
        """Test Multi-Query Attention (MQA)."""
        batch_size, q_seq_len, kv_seq_len = 2, 4, 6
        q_num_heads, kv_num_heads = 8, 1  # MQA: kv_num_heads = 1
        head_size = 64

        Q = torch.rand(batch_size, q_num_heads, q_seq_len, head_size)
        K = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size)
        V = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size)

        torch.library.opcheck(_impl.attention_23, (Q, K, V))
        output, present_key, present_value, qk_output = torch.onnx.ops.attention(
            Q, K, V
        )
        expected = torch.nn.functional.scaled_dot_product_attention(
            Q, K, V, None, enable_gqa=True
        )

        self.assertEqual(output.shape, (batch_size, q_num_heads, q_seq_len, head_size))
        torch.testing.assert_close(output, expected)

    def test_attention_with_2d_mask(self):
        """Test attention with 2D attention mask (q_seq_len, kv_seq_len)."""
        batch_size, q_seq_len, kv_seq_len = 2, 4, 6
        q_num_heads, kv_num_heads = 8, 8
        head_size = 64

        Q = torch.rand(batch_size, q_num_heads, q_seq_len, head_size)
        K = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size)
        V = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size)

        # Test with boolean mask
        bool_mask = torch.randint(0, 2, (q_seq_len, kv_seq_len), dtype=torch.bool)
        torch.library.opcheck(_impl.attention_23, (Q, K, V), dict(attn_mask=bool_mask))
        output_bool, _, _, _ = torch.onnx.ops.attention(Q, K, V, attn_mask=bool_mask)

        # Test with float mask
        float_mask = torch.randn(q_seq_len, kv_seq_len)
        torch.library.opcheck(_impl.attention_23, (Q, K, V), dict(attn_mask=float_mask))
        output_float, _, _, _ = torch.onnx.ops.attention(Q, K, V, attn_mask=float_mask)

        self.assertEqual(
            output_bool.shape, (batch_size, q_num_heads, q_seq_len, head_size)
        )
        self.assertEqual(
            output_float.shape, (batch_size, q_num_heads, q_seq_len, head_size)
        )

    def test_attention_with_4d_mask(self):
        """Test attention with 4D attention mask (batch_size, num_heads, q_seq_len, kv_seq_len)."""
        batch_size, q_seq_len, kv_seq_len = 2, 4, 6
        q_num_heads, kv_num_heads = 8, 8
        head_size = 64

        Q = torch.rand(batch_size, q_num_heads, q_seq_len, head_size)
        K = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size)
        V = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size)

        # Test with boolean mask
        bool_mask = torch.randint(
            0, 2, (batch_size, q_num_heads, q_seq_len, kv_seq_len), dtype=torch.bool
        )
        torch.library.opcheck(_impl.attention_23, (Q, K, V), dict(attn_mask=bool_mask))
        output_bool, _, _, _ = torch.onnx.ops.attention(Q, K, V, attn_mask=bool_mask)

        # Test with float mask
        float_mask = torch.randn(batch_size, q_num_heads, q_seq_len, kv_seq_len)
        torch.library.opcheck(_impl.attention_23, (Q, K, V), dict(attn_mask=float_mask))
        output_float, _, _, _ = torch.onnx.ops.attention(Q, K, V, attn_mask=float_mask)

        self.assertEqual(
            output_bool.shape, (batch_size, q_num_heads, q_seq_len, head_size)
        )
        self.assertEqual(
            output_float.shape, (batch_size, q_num_heads, q_seq_len, head_size)
        )

    def test_attention_with_zero_float_mask(self):
        """Test attention with zero float mask."""
        batch_size, q_seq_len, kv_seq_len = 2, 4, 6
        q_num_heads, kv_num_heads = 8, 8
        head_size = 64

        Q = torch.rand(batch_size, q_num_heads, q_seq_len, head_size)
        K = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size)
        V = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size)

        zero_mask = torch.zeros(q_seq_len, kv_seq_len)
        torch.library.opcheck(_impl.attention_23, (Q, K, V), dict(attn_mask=zero_mask))
        output, _, _, _ = torch.onnx.ops.attention(Q, K, V, attn_mask=zero_mask)

        self.assertEqual(output.shape, (batch_size, q_num_heads, q_seq_len, head_size))

    def test_attention_with_causal_mask_pattern(self):
        """Test attention with lower triangular causal mask pattern."""
        batch_size, q_seq_len, kv_seq_len = 2, 4, 4  # Square for causal
        q_num_heads, kv_num_heads = 8, 8
        head_size = 64

        Q = torch.rand(batch_size, q_num_heads, q_seq_len, head_size)
        K = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size)
        V = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size)

        # Create a lower triangular causal mask
        causal_mask = torch.tril(torch.ones(q_seq_len, kv_seq_len, dtype=torch.bool))
        torch.library.opcheck(
            _impl.attention_23, (Q, K, V), dict(attn_mask=causal_mask)
        )
        output, _, _, _ = torch.onnx.ops.attention(Q, K, V, attn_mask=causal_mask)

        self.assertEqual(output.shape, (batch_size, q_num_heads, q_seq_len, head_size))

    def test_attention_with_gqa_and_mask(self):
        """Test attention with GQA and different mask shapes."""
        batch_size, q_seq_len, kv_seq_len = 2, 4, 6
        q_num_heads, kv_num_heads = 8, 4  # GQA
        head_size = 64

        Q = torch.rand(batch_size, q_num_heads, q_seq_len, head_size)
        K = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size)
        V = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size)

        # Test 2D mask with GQA
        mask_2d = torch.randint(0, 2, (q_seq_len, kv_seq_len), dtype=torch.bool)
        torch.library.opcheck(_impl.attention_23, (Q, K, V), dict(attn_mask=mask_2d))
        output_2d, _, _, _ = torch.onnx.ops.attention(Q, K, V, attn_mask=mask_2d)

        # Test 4D mask with GQA (note: using q_num_heads for mask heads)
        mask_4d = torch.randint(
            0, 2, (batch_size, q_num_heads, q_seq_len, kv_seq_len), dtype=torch.bool
        )
        torch.library.opcheck(_impl.attention_23, (Q, K, V), dict(attn_mask=mask_4d))
        output_4d, _, _, _ = torch.onnx.ops.attention(Q, K, V, attn_mask=mask_4d)

        self.assertEqual(
            output_2d.shape, (batch_size, q_num_heads, q_seq_len, head_size)
        )
        self.assertEqual(
            output_4d.shape, (batch_size, q_num_heads, q_seq_len, head_size)
        )

    def test_attention_with_large_negative_float_mask(self):
        """Test attention with large negative values in float mask."""
        batch_size, q_seq_len, kv_seq_len = 2, 4, 6
        q_num_heads, kv_num_heads = 8, 8
        head_size = 64

        Q = torch.rand(batch_size, q_num_heads, q_seq_len, head_size)
        K = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size)
        V = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size)

        # Create mask with large negative values (similar to -inf masking)
        float_mask = torch.full((q_seq_len, kv_seq_len), -1e9)
        # Allow some positions
        float_mask[:, :3] = 0.0

        torch.library.opcheck(_impl.attention_23, (Q, K, V), dict(attn_mask=float_mask))
        output, _, _, _ = torch.onnx.ops.attention(Q, K, V, attn_mask=float_mask)

        self.assertEqual(output.shape, (batch_size, q_num_heads, q_seq_len, head_size))

    def test_attention_causal(self):
        """Test causal attention."""
        batch_size, q_seq_len, kv_seq_len = 2, 4, 4  # Square for causal
        q_num_heads, kv_num_heads = 8, 8
        head_size = 64

        Q = torch.rand(batch_size, q_num_heads, q_seq_len, head_size)
        K = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size)
        V = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size)

        torch.library.opcheck(_impl.attention_23, (Q, K, V), dict(is_causal=True))
        output, _, _, _ = torch.onnx.ops.attention(Q, K, V, is_causal=True)

        self.assertEqual(output.shape, (batch_size, q_num_heads, q_seq_len, head_size))

    def test_attention_with_past_kv(self):
        """Test attention with past key/value caches.

        Checks the attention output shape and that the returned present
        key/value tensors have a sequence length of
        ``past_seq_len + kv_seq_len``.
        """
        batch_size, q_seq_len, kv_seq_len, past_seq_len = 2, 4, 6, 3
        q_num_heads, kv_num_heads = 8, 8
        head_size = 64

        Q = torch.rand(batch_size, q_num_heads, q_seq_len, head_size)
        K = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size)
        V = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size)
        past_key = torch.rand(batch_size, kv_num_heads, past_seq_len, head_size)
        past_value = torch.rand(batch_size, kv_num_heads, past_seq_len, head_size)

        torch.library.opcheck(
            _impl.attention_23,
            (Q, K, V),
            dict(past_key=past_key, past_value=past_value),
        )
        output, present_key, present_value, _ = torch.onnx.ops.attention(
            Q, K, V, past_key=past_key, past_value=past_value
        )

        # The attention output keeps the query layout even when a cache is given.
        self.assertEqual(output.shape, (batch_size, q_num_heads, q_seq_len, head_size))

        # Present key/value should include past + current
        expected_total_seq_len = past_seq_len + kv_seq_len
        self.assertEqual(
            present_key.shape,
            (batch_size, kv_num_heads, expected_total_seq_len, head_size),
        )
        self.assertEqual(
            present_value.shape,
            (batch_size, kv_num_heads, expected_total_seq_len, head_size),
        )

    def test_attention_with_softcap(self):
        """Run eager attention with ``softcap=30.0`` and check the output shape."""
        batch_size, q_seq_len, kv_seq_len = 2, 4, 6
        q_num_heads, kv_num_heads = 8, 8
        head_size = 64
        q_shape = (batch_size, q_num_heads, q_seq_len, head_size)
        kv_shape = (batch_size, kv_num_heads, kv_seq_len, head_size)

        Q = torch.rand(q_shape)
        K = torch.rand(kv_shape)
        V = torch.rand(kv_shape)

        softcap_kwargs = {"softcap": 30.0}
        torch.library.opcheck(_impl.attention_23, (Q, K, V), softcap_kwargs)
        output, _, _, _ = torch.onnx.ops.attention(Q, K, V, **softcap_kwargs)

        self.assertEqual(output.shape, q_shape)

    def test_attention_qk_output_modes(self):
        """Exercise every ``qk_matmul_output_mode`` value from 0 through 3."""
        batch_size, q_seq_len, kv_seq_len = 2, 4, 6
        q_num_heads, kv_num_heads = 8, 8
        head_size = 64
        q_shape = (batch_size, q_num_heads, q_seq_len, head_size)
        kv_shape = (batch_size, kv_num_heads, kv_seq_len, head_size)

        Q = torch.rand(q_shape)
        K = torch.rand(kv_shape)
        V = torch.rand(kv_shape)

        # Both expected shapes are the same for every mode.
        expected_qk_shape = (batch_size, q_num_heads, q_seq_len, kv_seq_len)

        for mode in range(4):
            mode_kwargs = {"qk_matmul_output_mode": mode}
            torch.library.opcheck(_impl.attention_23, (Q, K, V), mode_kwargs)
            output, _, _, qk_output = torch.onnx.ops.attention(
                Q, K, V, **mode_kwargs
            )

            self.assertEqual(output.shape, q_shape)
            self.assertEqual(qk_output.shape, expected_qk_shape)

    def test_attention_custom_scale(self):
        """Run eager attention with an explicit ``scale`` of 0.25."""
        batch_size, q_seq_len, kv_seq_len = 2, 4, 6
        q_num_heads, kv_num_heads = 8, 8
        head_size = 64
        q_shape = (batch_size, q_num_heads, q_seq_len, head_size)
        kv_shape = (batch_size, kv_num_heads, kv_seq_len, head_size)

        Q = torch.rand(q_shape)
        K = torch.rand(kv_shape)
        V = torch.rand(kv_shape)

        scale_kwargs = {"scale": 0.25}
        torch.library.opcheck(_impl.attention_23, (Q, K, V), scale_kwargs)
        output, _, _, _ = torch.onnx.ops.attention(Q, K, V, **scale_kwargs)

        self.assertEqual(output.shape, q_shape)

    def test_attention_export(self):
        """Export a plain attention call at opset 23 and check the first node."""
        batch_size, q_seq_len, kv_seq_len = 2, 4, 6
        q_num_heads, kv_num_heads = 8, 8
        head_size = 64
        q_dims = [batch_size, q_num_heads, q_seq_len, head_size]
        kv_dims = [batch_size, kv_num_heads, kv_seq_len, head_size]

        Q = torch.rand(q_dims)
        K = torch.rand(kv_dims)
        V = torch.rand(kv_dims)

        class AttentionModel(torch.nn.Module):
            def forward(self, Q, K, V):
                # Only the first of the four op outputs is returned.
                output, _, _, _ = torch.onnx.ops.attention(Q, K, V)
                return output

        onnx_program = self.export(AttentionModel(), (Q, K, V), opset_version=23)
        exported_model = onnx_program.model

        self.assertEqual(exported_model.opset_imports[""], 23)
        self.assertEqual("Attention", exported_model.graph.node(0).op_type)

    def test_attention_export_with_dynamic_shapes(self):
        """Export attention with named dynamic batch and sequence dimensions."""
        batch_size, q_seq_len, kv_seq_len = 2, 4, 6
        q_num_heads, kv_num_heads = 8, 8
        head_size = 64

        Q = torch.rand([batch_size, q_num_heads, q_seq_len, head_size])
        K = torch.rand([batch_size, kv_num_heads, kv_seq_len, head_size])
        V = torch.rand([batch_size, kv_num_heads, kv_seq_len, head_size])

        class AttentionModel(torch.nn.Module):
            def forward(self, Q, K, V):
                # Only the first of the four op outputs is returned.
                output, _, _, _ = torch.onnx.ops.attention(Q, K, V)
                return output

        # Dims 0 and 2 of each input are dynamic; K and V share their names.
        dynamic_shapes = {
            "Q": {0: "batch", 2: "q_seq_len"},
            "K": {0: "batch", 2: "kv_seq_len"},
            "V": {0: "batch", 2: "kv_seq_len"},
        }

        onnx_program = self.export(
            AttentionModel(),
            (Q, K, V),
            dynamic_shapes=dynamic_shapes,
            opset_version=23,
        )
        exported_model = onnx_program.model

        self.assertEqual(exported_model.opset_imports[""], 23)
        self.assertEqual("Attention", exported_model.graph.node(0).op_type)
        node = exported_model.graph.node(0)

        # Only Q, K and V are wired in (no optional inputs).
        self.assertEqual(len(node.inputs), 3)
        expected_input_shapes = [
            ["batch", q_num_heads, "q_seq_len", head_size],
            ["batch", kv_num_heads, "kv_seq_len", head_size],
            ["batch", kv_num_heads, "kv_seq_len", head_size],
        ]
        for index, dims in enumerate(expected_input_shapes):
            self.assertEqual(node.inputs[index].shape, dims)

        # With every option left at its default, no attribute is expected.
        self.assertEqual(len(node.attributes), 0)

    def test_attention_3d_export(self):
        """Export attention whose Q, K and V are 3D tensors."""
        batch_size, q_seq_len, kv_seq_len = 2, 4, 6
        q_num_heads, kv_num_heads = 8, 8
        head_size = 64
        q_dims = [batch_size, q_seq_len, q_num_heads * head_size]
        kv_dims = [batch_size, kv_seq_len, kv_num_heads * head_size]

        Q = torch.rand(q_dims)
        K = torch.rand(kv_dims)
        V = torch.rand(kv_dims)

        class AttentionModel(torch.nn.Module):
            def forward(self, Q, K, V):
                # The head counts are passed explicitly alongside the 3D inputs.
                results = torch.onnx.ops.attention(
                    Q, K, V, q_num_heads=q_num_heads, kv_num_heads=kv_num_heads
                )
                output, _, _, _ = results
                return output

        onnx_program = self.export(AttentionModel(), (Q, K, V), opset_version=23)
        exported_model = onnx_program.model

        self.assertEqual(exported_model.opset_imports[""], 23)
        self.assertEqual("Attention", exported_model.graph.node(0).op_type)

    def test_attention_decomposition(self):
        """Check the exported attention op decomposes and still matches eager."""
        batch_size, q_seq_len, kv_seq_len = 2, 4, 6
        q_num_heads, kv_num_heads = 8, 8
        head_size = 64
        q_dims = [batch_size, q_num_heads, q_seq_len, head_size]
        kv_dims = [batch_size, kv_num_heads, kv_seq_len, head_size]

        Q = torch.rand(q_dims)
        K = torch.rand(kv_dims)
        V = torch.rand(kv_dims)

        class AttentionModel(torch.nn.Module):
            def forward(self, Q, K, V):
                # Only the first of the four op outputs is returned.
                output, _, _, _ = torch.onnx.ops.attention(Q, K, V)
                return output

        attention_target = "onnx.Attention.opset23"

        def target_names(program):
            # String form of every node target in the program's graph.
            return [str(node.target) for node in program.graph.nodes]

        model = AttentionModel()

        # Before decomposition the graph contains the ONNX attention op.
        ep = torch.export.export(model, (Q, K, V))
        self.assertIn(attention_target, target_names(ep))

        # After decomposition with the ONNX table the op is gone.
        aten_decomped = ep.run_decompositions(torch.onnx.ops.aten_decompositions())
        self.assertNotIn(attention_target, target_names(aten_decomped))

        # The decomposed program must agree with the eager model.
        torch.testing.assert_close(aten_decomped.module()(Q, K, V), model(Q, K, V))

    def test_attention_export_with_past_key_value(self):
        """Export with past_key/past_value and check the optional input order."""
        batch_size, q_seq_len, kv_seq_len, past_seq_len = 2, 4, 6, 3
        q_num_heads, kv_num_heads = 8, 8
        head_size = 64
        q_dims = [batch_size, q_num_heads, q_seq_len, head_size]
        kv_dims = [batch_size, kv_num_heads, kv_seq_len, head_size]
        past_dims = [batch_size, kv_num_heads, past_seq_len, head_size]

        Q = torch.rand(q_dims)
        K = torch.rand(kv_dims)
        V = torch.rand(kv_dims)
        past_key = torch.rand(past_dims)
        past_value = torch.rand(past_dims)

        class Model(torch.nn.Module):
            def forward(self, Q, K, V, past_key, past_value):
                # Keywords are deliberately given out of order: past_key
                # first, then a None mask, then past_value.
                output, _, _, _ = torch.onnx.ops.attention(
                    Q,
                    K,
                    V,
                    past_key=past_key,
                    attn_mask=None,
                    past_value=past_value,
                )
                return output

        call_args = (Q, K, V, past_key, past_value)
        onnx_program = self.export(Model(), call_args, opset_version=23)

        node = onnx_program.model.graph.node(0)
        self.assertEqual(node.op_type, "Attention")

        # Six input slots: Q, K, V, attn_mask, past_key, past_value.
        self.assertEqual(len(node.inputs), 6)
        for index, dims in ((0, q_dims), (1, kv_dims), (2, kv_dims)):
            self.assertEqual(node.inputs[index].shape, dims)
        # The mask slot is present but empty.
        self.assertIsNone(node.inputs[3])
        for index in (4, 5):
            self.assertEqual(node.inputs[index].shape, past_dims)

    def test_attention_export_with_all_optional_inputs(self):
        """Export with attn_mask, past_key and past_value all supplied."""
        batch_size, q_seq_len, kv_seq_len, past_seq_len = 2, 4, 6, 3
        q_num_heads, kv_num_heads = 8, 8
        head_size = 64
        q_dims = [batch_size, q_num_heads, q_seq_len, head_size]
        kv_dims = [batch_size, kv_num_heads, kv_seq_len, head_size]
        # The mask's last dimension spans the past plus the current keys.
        mask_dims = [1, 1, q_seq_len, kv_seq_len + past_seq_len]
        past_dims = [batch_size, kv_num_heads, past_seq_len, head_size]

        Q = torch.rand(q_dims)
        K = torch.rand(kv_dims)
        V = torch.rand(kv_dims)
        attn_mask = torch.randint(0, 2, mask_dims, dtype=torch.bool)
        past_key = torch.rand(past_dims)
        past_value = torch.rand(past_dims)

        class FullAttentionModel(torch.nn.Module):
            def forward(self, Q, K, V, attn_mask, past_key, past_value):
                results = torch.onnx.ops.attention(
                    Q,
                    K,
                    V,
                    attn_mask=attn_mask,
                    past_key=past_key,
                    past_value=past_value,
                )
                output, _, _, _ = results
                return output

        call_args = (Q, K, V, attn_mask, past_key, past_value)
        onnx_program = self.export(
            FullAttentionModel(), call_args, opset_version=23
        )

        node = onnx_program.model.graph.node(0)
        self.assertEqual(node.op_type, "Attention")

        # Six inputs, in order: Q, K, V, attn_mask, past_key, past_value.
        self.assertEqual(len(node.inputs), 6)
        expected_shapes = [q_dims, kv_dims, kv_dims, mask_dims, past_dims, past_dims]
        for index, dims in enumerate(expected_shapes):
            self.assertEqual(node.inputs[index].shape, dims)

    def test_attention_export_3d_with_num_heads_attributes(self):
        """Export 3D inputs and check the head-count attributes on the node."""
        batch_size, q_seq_len, kv_seq_len = 2, 4, 6
        # Fewer key/value heads than query heads (GQA).
        q_num_heads, kv_num_heads = 8, 4
        head_size = 64
        q_dims = [batch_size, q_seq_len, q_num_heads * head_size]
        kv_dims = [batch_size, kv_seq_len, kv_num_heads * head_size]

        Q = torch.rand(q_dims)
        K = torch.rand(kv_dims)
        V = torch.rand(kv_dims)

        class Attention3DModel(torch.nn.Module):
            def forward(self, Q, K, V):
                results = torch.onnx.ops.attention(
                    Q, K, V, q_num_heads=q_num_heads, kv_num_heads=kv_num_heads
                )
                output, _, _, _ = results
                return output

        onnx_program = self.export(Attention3DModel(), (Q, K, V), opset_version=23)

        node = onnx_program.model.graph.node(0)
        self.assertEqual(node.op_type, "Attention")

        # The three inputs keep their 3D shapes.
        for index, dims in enumerate([q_dims, kv_dims, kv_dims]):
            self.assertEqual(node.inputs[index].shape, dims)

        # Both head counts must be recorded as attributes with the given values.
        attrs = node.attributes
        expected_heads = {"q_num_heads": q_num_heads, "kv_num_heads": kv_num_heads}
        for name in expected_heads:
            self.assertIn(name, attrs)
        for name, count in expected_heads.items():
            self.assertEqual(attrs[name].value, count)

    def test_attention_export_with_all_attributes(self):
        """Export with five non-default options and check each node attribute."""
        batch_size, q_seq_len, kv_seq_len = 2, 4, 6
        q_num_heads, kv_num_heads = 8, 8
        head_size = 64
        q_dims = [batch_size, q_num_heads, q_seq_len, head_size]
        kv_dims = [batch_size, kv_num_heads, kv_seq_len, head_size]

        Q = torch.rand(q_dims)
        K = torch.rand(kv_dims)
        V = torch.rand(kv_dims)

        class FullAttributesModel(torch.nn.Module):
            def forward(self, Q, K, V):
                results = torch.onnx.ops.attention(
                    Q,
                    K,
                    V,
                    is_causal=True,
                    qk_matmul_output_mode=2,
                    scale=0.25,
                    softcap=30.0,
                    softmax_precision=1,  # FLOAT
                )
                output, _, _, _ = results
                return output

        onnx_program = self.export(
            FullAttributesModel(), (Q, K, V), opset_version=23
        )

        node = onnx_program.model.graph.node(0)
        self.assertEqual(node.op_type, "Attention")

        attrs = node.attributes
        expected_attrs = {
            "is_causal": 1,  # True is expected as an int
            "qk_matmul_output_mode": 2,
            "scale": 0.25,
            "softcap": 30.0,
            "softmax_precision": 1,
        }

        # Every attribute must be present before any value is inspected.
        for name in expected_attrs:
            self.assertIn(name, attrs)

        # Float attributes are compared approximately, ints exactly.
        for name, want in expected_attrs.items():
            got = attrs[name].value
            if isinstance(want, float):
                self.assertAlmostEqual(got, want, places=6)
            else:
                self.assertEqual(got, want)

    def test_attention_export_with_different_mask_shapes(self):
        """Test export with boolean attention masks of several shapes.

        Three masks are exported in turn: a 2D mask, a 4D mask with a
        singleton head dimension, and a 4D mask with one slice per query
        head. Each must appear as the fourth input of the exported node with
        its shape unchanged. Every shape runs in its own sub-test so a
        failure names the offending mask.
        """
        batch_size, q_seq_len, kv_seq_len = 2, 4, 6
        q_num_heads, kv_num_heads = 8, 8
        head_size = 64

        Q = torch.rand(batch_size, q_num_heads, q_seq_len, head_size)
        K = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size)
        V = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size)

        # One model serves every case; only the mask tensor differs.
        class MaskModel(torch.nn.Module):
            def forward(self, Q, K, V, mask):
                output, _, _, _ = torch.onnx.ops.attention(Q, K, V, attn_mask=mask)
                return output

        mask_shapes = {
            "2d": [q_seq_len, kv_seq_len],
            "4d_single_head": [batch_size, 1, q_seq_len, kv_seq_len],
            "4d_per_head": [batch_size, q_num_heads, q_seq_len, kv_seq_len],
        }

        for label, mask_shape in mask_shapes.items():
            with self.subTest(mask=label):
                mask = torch.randint(0, 2, mask_shape, dtype=torch.bool)
                onnx_program = self.export(
                    MaskModel(), (Q, K, V, mask), opset_version=23
                )

                node = onnx_program.model.graph.node(0)
                self.assertEqual(node.inputs[3].shape, mask_shape)

    def test_attention_export_with_float_mask(self):
        """Export with a float attention mask and check its shape and dtype."""
        batch_size, q_seq_len, kv_seq_len = 2, 4, 6
        q_num_heads, kv_num_heads = 8, 8
        head_size = 64
        q_dims = [batch_size, q_num_heads, q_seq_len, head_size]
        kv_dims = [batch_size, kv_num_heads, kv_seq_len, head_size]
        mask_dims = [q_seq_len, kv_seq_len]

        Q = torch.rand(q_dims)
        K = torch.rand(kv_dims)
        V = torch.rand(kv_dims)
        float_mask = torch.randn(mask_dims)

        class FloatMaskModel(torch.nn.Module):
            def forward(self, Q, K, V, mask):
                results = torch.onnx.ops.attention(Q, K, V, attn_mask=mask)
                output, _, _, _ = results
                return output

        call_args = (Q, K, V, float_mask)
        onnx_program = self.export(FloatMaskModel(), call_args, opset_version=23)

        node = onnx_program.model.graph.node(0)
        self.assertEqual(node.op_type, "Attention")

        # The mask is the fourth input and must stay float in the ONNX model.
        mask_input = node.inputs[3]
        self.assertEqual(mask_input.shape, mask_dims)
        self.assertEqual(mask_input.dtype, ir.DataType.FLOAT)

    def test_attention_export_qk_output_modes(self):
        """Export once per ``qk_matmul_output_mode`` value from 0 through 3."""
        batch_size, q_seq_len, kv_seq_len = 2, 4, 6
        q_num_heads, kv_num_heads = 8, 8
        head_size = 64
        q_dims = [batch_size, q_num_heads, q_seq_len, head_size]
        kv_dims = [batch_size, kv_num_heads, kv_seq_len, head_size]

        Q = torch.rand(q_dims)
        K = torch.rand(kv_dims)
        V = torch.rand(kv_dims)

        class QKOutputModel(torch.nn.Module):
            def __init__(self, qk_mode):
                super().__init__()
                self.qk_mode = qk_mode

            def forward(self, Q, K, V):
                results = torch.onnx.ops.attention(
                    Q, K, V, qk_matmul_output_mode=self.qk_mode
                )
                output, _, _, qk_output = results
                return output, qk_output

        for mode in range(4):
            onnx_program = self.export(
                QKOutputModel(mode), (Q, K, V), opset_version=23
            )

            node = onnx_program.model.graph.node(0)
            self.assertEqual(node.op_type, "Attention")

            # The attribute is only checked for the non-zero modes.
            attrs = node.attributes
            if mode != 0:
                self.assertIn("qk_matmul_output_mode", attrs)
                self.assertEqual(attrs["qk_matmul_output_mode"].value, mode)

            # Four outputs: output, present_key, present_value, qk_output.
            self.assertEqual(len(node.outputs), 4)

    def test_attention_export_mqa(self):
        """Export Multi-Query Attention (MQA), where K and V have one head."""
        batch_size, q_seq_len, kv_seq_len = 2, 4, 6
        # A single key/value head serves all eight query heads (MQA).
        q_num_heads, kv_num_heads = 8, 1
        head_size = 64
        q_dims = [batch_size, q_num_heads, q_seq_len, head_size]
        kv_dims = [batch_size, kv_num_heads, kv_seq_len, head_size]

        Q = torch.rand(q_dims)
        K = torch.rand(kv_dims)
        V = torch.rand(kv_dims)

        class MQAModel(torch.nn.Module):
            def forward(self, Q, K, V):
                results = torch.onnx.ops.attention(Q, K, V)
                output, _, _, _ = results
                return output

        onnx_program = self.export(MQAModel(), (Q, K, V), opset_version=23)

        node = onnx_program.model.graph.node(0)
        self.assertEqual(node.op_type, "Attention")

        # The exported inputs keep the MQA shapes (kv_num_heads = 1 for K, V).
        for index, dims in enumerate([q_dims, kv_dims, kv_dims]):
            self.assertEqual(node.inputs[index].shape, dims)

    def test_attention_export_with_softmax_precision(self):
        """Test export with different softmax precision values.

        Each precision is exported in its own sub-test, labelled with the
        precision name, so a failure identifies the value that broke.
        """
        batch_size, q_seq_len, kv_seq_len = 2, 4, 6
        q_num_heads, kv_num_heads = 8, 8
        head_size = 64

        Q = torch.rand(batch_size, q_num_heads, q_seq_len, head_size)
        K = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size)
        V = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size)

        # (value passed as softmax_precision, label used in sub-test output)
        precision_types = [
            (1, "FLOAT"),
            (10, "FLOAT16"),
            (11, "DOUBLE"),
            (16, "BFLOAT16"),
        ]

        # Defined once; the precision is supplied per instance.
        class SoftmaxPrecisionModel(torch.nn.Module):
            def __init__(self, precision):
                super().__init__()
                self.precision = precision

            def forward(self, Q, K, V):
                output, _, _, _ = torch.onnx.ops.attention(
                    Q, K, V, softmax_precision=self.precision
                )
                return output

        for precision_val, precision_name in precision_types:
            with self.subTest(softmax_precision=precision_name):
                model = SoftmaxPrecisionModel(precision_val)
                onnx_program = self.export(model, (Q, K, V), opset_version=23)

                node = onnx_program.model.graph.node(0)
                self.assertEqual(node.op_type, "Attention")

                # Verify softmax_precision attribute
                attrs = node.attributes
                self.assertIn("softmax_precision", attrs)
                self.assertEqual(attrs["softmax_precision"].value, precision_val)

    def test_attention_export_gqa(self):
        """Export GQA attention and check the shapes of all four node outputs."""
        batch_size, q_seq_len, kv_seq_len = 2, 4, 6
        # Fewer key/value heads than query heads (GQA).
        q_num_heads, kv_num_heads = 8, 4
        head_size = 64
        q_dims = [batch_size, q_num_heads, q_seq_len, head_size]
        kv_dims = [batch_size, kv_num_heads, kv_seq_len, head_size]

        Q = torch.rand(q_dims)
        K = torch.rand(kv_dims)
        V = torch.rand(kv_dims)

        class AttentionOutputsModel(torch.nn.Module):
            def forward(self, Q, K, V):
                # Return the op's whole result rather than a single tensor.
                return torch.onnx.ops.attention(Q, K, V)

        onnx_program = self.export(
            AttentionOutputsModel(), (Q, K, V), opset_version=23
        )

        node = onnx_program.model.graph.node(0)
        self.assertEqual(node.op_type, "Attention")

        outputs = node.outputs
        self.assertEqual(len(outputs), 4)

        # In order: output, present_key, present_value, qk_output.
        expected_output_shapes = [
            q_dims,
            kv_dims,
            kv_dims,
            [batch_size, q_num_heads, q_seq_len, kv_seq_len],
        ]
        for index, dims in enumerate(expected_output_shapes):
            self.assertEqual(outputs[index].shape, dims)


# When this file is executed directly as a script, hand control to the
# run_tests entry point of torch's common test utilities.
if __name__ == "__main__":
    common_utils.run_tests()
