# Owner(s): ["oncall: distributed"]

# To run:
# TORCH_SYMMMEM=NVSHMEM python test/distributed/test_nvshmem.py


import torch
import torch.distributed as dist
import torch.distributed._symmetric_memory as symm_mem
import torch.distributed._symmetric_memory._nvshmem_triton as nvshmem
from torch._inductor.runtime.triton_compat import tl, triton
from torch.testing._internal.common_distributed import MultiProcContinousTest
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    parametrize,
    run_tests,
    skip_but_pass_in_sandcastle_if,
    skipIfRocm,
)
from torch.testing._internal.inductor_utils import requires_triton


# Decorator
def requires_nvshmem():
    """Return a test decorator gated on NVSHMEM availability.

    The result is `skip_but_pass_in_sandcastle_if` configured with the negated
    outcome of `symm_mem.is_nvshmem_available()`, evaluated when this function
    is called, and a fixed explanatory message.
    """
    nvshmem_missing = not symm_mem.is_nvshmem_available()
    reason = "test_nvshmem requires NVSHMEM, skipping tests"
    return skip_but_pass_in_sandcastle_if(nvshmem_missing, reason)


# The device type is named in one place so that the tests below are written in
# a device-agnostic way.
device_type = "cuda"
# Module returned by `torch.get_device_module` for `device_type`; the tests use
# it to select the current device.
device_module = torch.get_device_module(device_type)


@instantiate_parametrized_tests
@requires_nvshmem()
class NVSHMEMSymmetricMemoryTest(MultiProcContinousTest):
    def _init_device(self) -> None:
        """Select this rank's device and perform one tiny allocation on it."""
        dev = self.device
        # TODO: relieve this (seems to hang if without)
        device_module.set_device(dev)
        # NOTE: required for nvshmem allocation
        torch.empty(1, device=dev)

    @property
    def device(self) -> torch.device:
        """Device of type `device_type` whose index is this process's rank."""
        rank_index = self.rank
        return torch.device(device_type, rank_index)

    @skipIfRocm
    def test_alloc(self) -> None:
        """Allocate and rendezvous symmetric tensors, once inside a helper and once directly."""
        self._init_device()

        world_group = dist.group.WORLD.group_name
        symm_mem.enable_symm_mem_for_group(world_group)

        elem_dtype = torch.float
        elem_count = 1024

        def alloc_in_inner_scope():
            # The tensor is local to this helper, so the only reference to it
            # goes away when the helper returns.
            scoped = symm_mem.empty(elem_count, dtype=elem_dtype, device=self.device)
            symm_mem.rendezvous(scoped, group=world_group)

        alloc_in_inner_scope()

        # Second allocation, made after the helper above has returned.
        outer = symm_mem.empty(elem_count, dtype=elem_dtype, device=self.device)
        symm_mem.rendezvous(outer, group=world_group)

    @skipIfRocm
    def test_nvshmem_put(self) -> None:
        """Rank 0 calls `nvshmem_put(tensor, 1)`; rank 1 then expects its buffer to hold zeros.

        Every rank fills its buffer with its own rank id, so zeros on rank 1
        are rank 0's fill value.
        """
        self._init_device()
        world_group = dist.group.WORLD.group_name
        symm_mem.enable_symm_mem_for_group(world_group)

        dtype = torch.float
        numel = 1024
        buf = symm_mem.empty(numel, dtype=dtype, device=self.device)
        buf.fill_(self.rank)
        symm_mem.rendezvous(buf, group=world_group)

        if self.rank == 0:
            torch.ops.symm_mem.nvshmem_put(buf, 1)

        # Every rank takes part in the barrier: rank 0 after issuing the put,
        # rank 1 before checking, all other ranks just to keep it matched.
        # handle.wait_signal(src_rank=0)
        # TODO: remove after we have wait_signal
        dist.barrier()

        if self.rank == 1:
            zeros = torch.zeros(numel, dtype=dtype, device=self.device)
            torch.testing.assert_close(buf, zeros)

    @skipIfRocm
    def test_nvshmem_all_to_all(self) -> None:
        """Run `nvshmem_all_to_all` with equal chunks and check the received layout.

        Each rank sends a buffer filled with its own rank id; the output is
        expected to be `numel_per_peer` copies of `i` for each peer `i`, in
        peer order.
        """
        self._init_device()

        world_group = dist.group.WORLD.group_name
        symm_mem.enable_symm_mem_for_group(world_group)

        dtype = torch.float
        per_peer = 10
        total = self.world_size * per_peer
        send = symm_mem.empty(total, dtype=dtype, device=self.device).fill_(self.rank)
        recv = symm_mem.empty(total, dtype=dtype, device=self.device).fill_(-1)

        symm_mem.rendezvous(send, group=world_group)
        symm_mem.rendezvous(recv, group=world_group)
        torch.ops.symm_mem.nvshmem_all_to_all(send, recv, world_group)

        # Build the expected output one peer-sized piece at a time.
        pieces = []
        for peer_id in range(self.world_size):
            piece = torch.empty(per_peer, dtype=dtype, device=self.device)
            piece.fill_(peer_id)
            pieces.append(piece)
        expected = torch.cat(pieces)
        torch.testing.assert_close(recv, expected)

    @skipIfRocm
    def test_all_to_all_vdev(self) -> None:
        """Check `all_to_all_vdev` against `dist.all_to_all_single` with uneven splits.

        Each rank draws a random split size per peer, runs the op on symmetric
        buffers, and verifies the three rows of `in_out_splits` (input splits,
        output splits, output offsets) as well as the received data.
        """
        # Seed per rank, as `test_all_to_all_vdev_2d` does, so that the random
        # splits are reproducible and are not the same draw on every rank.
        torch.manual_seed(42 + self.rank)
        self._init_device()

        group_name = dist.group.WORLD.group_name
        symm_mem.enable_symm_mem_for_group(group_name)

        dtype = torch.float
        # Number of elements for a peer is random between [0, k)
        k = 10
        inp_splits = torch.randint(k, (self.world_size,), device=self.device)
        inp_numel = inp_splits.sum().item()
        # Exchange input splits to get output splits
        out_splits = torch.zeros_like(inp_splits)
        dist.all_to_all_single(out_splits, inp_splits)
        out_numel = out_splits.sum().item()

        # Max number of input elements (must be a constant across ranks for symmetric memory allocation)
        max_inp_numel = k * self.world_size
        # Max number of output elements (must be a constant across ranks for symmetric memory allocation)
        overflow_factor = self.world_size  # worst case: one rank receives all data
        max_out_numel = max_inp_numel * overflow_factor

        inp = symm_mem.empty(max_inp_numel, dtype=dtype, device=self.device).fill_(
            self.rank
        )
        out = symm_mem.empty(max_out_numel, dtype=dtype, device=self.device).fill_(-1)
        # 3 rows: input splits, output splits, output offsets.
        # All values start at -1 (as in `test_all_to_all_vdev_2d`) so that the
        # checks below fail if the op leaves a row unwritten; in particular the
        # "first offset is 0" check cannot pass on untouched memory.
        in_out_splits = symm_mem.empty(
            (3, self.world_size), dtype=torch.int64, device=self.device
        ).fill_(-1)
        # Row 0 is input splits
        in_out_splits[0].copy_(inp_splits)

        torch.ops.symm_mem.all_to_all_vdev(inp, out, in_out_splits, group_name)

        # Check input splits (row 0) -- should not change
        torch.testing.assert_close(in_out_splits[0], inp_splits)

        # Check output splits (row 1)
        torch.testing.assert_close(in_out_splits[1], out_splits)

        # Check output offsets (row 2)
        out_offsets = torch.cumsum(out_splits, dim=0)  # inclusive scan
        # output offsets from `all_to_all_vdev` is exclusive scan
        self.assertEqual(in_out_splits[2][0], 0)
        torch.testing.assert_close(in_out_splits[2][1:], out_offsets[:-1])

        # Check data: the reference result comes from the regular collective
        # using the same splits.
        expected = torch.empty(out_numel, dtype=dtype, device=self.device)
        dist.all_to_all_single(
            expected, inp[:inp_numel], out_splits.tolist(), inp_splits.tolist()
        )
        torch.testing.assert_close(out[:out_numel], expected)

    @skipIfRocm
    @parametrize("align", [1, 8, 16])  # `major_align` of output
    def test_all_to_all_vdev_2d(self, align: int) -> None:
        """Check `all_to_all_vdev_2d` for a given `major_align` of the output.

        Verifies that row 0 of `in_out_splits` (input splits) is unchanged,
        that row 1 holds the output splits in expert-major order, that row 2
        holds the exclusive scan of the alignment-padded output splits, and
        that every received chunk matches a reference built with
        `dist.all_to_all_single`.
        """
        torch.manual_seed(42 + self.rank)
        self._init_device()

        world_group = dist.group.WORLD.group_name
        symm_mem.enable_symm_mem_for_group(world_group)

        dtype = torch.float
        world = self.world_size
        # Number of experts hosted by each rank
        experts_per_rank = 8
        num_splits = experts_per_rank * world

        # The size of each split is drawn from [0, max_split)
        max_split = 10
        send_splits = torch.randint(
            max_split, (num_splits,), dtype=torch.int64, device=self.device
        )

        # Swap send splits with the peers to learn the receive splits
        recv_splits = torch.zeros_like(send_splits)
        dist.all_to_all_single(recv_splits, send_splits)
        # Transpose because the op shuffles from rank-major to expert-major
        recv_splits_by_expert = recv_splits.reshape(world, experts_per_rank).t()

        # Element counts actually exchanged
        send_total = send_splits.sum().item()
        recv_total = recv_splits.sum().item()
        # Buffer capacities must be the same constant on every rank for
        # symmetric memory allocation; in the worst case a single rank
        # receives all the data, hence the factor of `world`.
        send_capacity = max_split * num_splits
        recv_capacity = send_capacity * world

        inp = symm_mem.empty(send_capacity, dtype=dtype, device=self.device).fill_(
            self.rank
        )
        out = symm_mem.empty(recv_capacity, dtype=dtype, device=self.device).fill_(-1)
        # Rows: input splits, output splits, output offsets.
        # Everything starts at -1 so that unwritten entries are detectable.
        splits_and_offsets = symm_mem.empty(
            (3, num_splits), dtype=torch.int64, device=self.device
        ).fill_(-1)
        splits_and_offsets[0].copy_(send_splits)

        torch.ops.symm_mem.all_to_all_vdev_2d(
            inp, out, splits_and_offsets, world_group, major_align=align
        )
        got_splits = splits_and_offsets[1]
        got_offsets = splits_and_offsets[2]

        # Row 0 (input splits) must be left untouched
        torch.testing.assert_close(splits_and_offsets[0], send_splits)

        # Row 1 (output splits) is expert-major
        torch.testing.assert_close(got_splits, recv_splits_by_expert.reshape(-1))

        # Row 2 (output offsets): pad each expert's total up to a multiple of
        # `align`, letting the expert's last split absorb the padding.
        padded_rows = recv_splits_by_expert.tolist()
        for row in padded_rows:
            expert_total = sum(row)
            aligned_total = (expert_total + align - 1) // align * align
            # An empty expert still occupies `align` elements (bc cutlass
            # currently does not support empty bins)
            aligned_total = max(aligned_total, align)
            row[-1] += aligned_total - expert_total

        padded_splits = torch.tensor(padded_rows, device=self.device).reshape(-1)
        # `all_to_all_vdev_2d` returns an exclusive scan, i.e. the inclusive
        # scan minus each element's own split.
        expected_offsets = torch.cumsum(padded_splits, dim=0) - padded_splits
        torch.testing.assert_close(got_offsets, expected_offsets)

        # Reference data from the regular collective, using per-rank totals
        expected = torch.empty(recv_total, dtype=dtype, device=self.device)
        send_per_rank = send_splits.reshape(world, experts_per_rank).sum(1)
        recv_per_rank = recv_splits.reshape(world, experts_per_rank).sum(1)
        dist.all_to_all_single(
            expected,
            inp[:send_total],
            recv_per_rank.tolist(),
            send_per_rank.tolist(),
        )
        # `expected` is rank-major; pick its chunks in expert-major order
        chunk_ends = torch.cumsum(recv_splits, dim=0)  # inclusive scan
        expert_major_ids = [
            src * experts_per_rank + expert
            for expert in range(experts_per_rank)
            for src in range(world)
        ]
        expected_chunks = [
            expected[chunk_ends[cid] - recv_splits[cid] : chunk_ends[cid]]
            for cid in expert_major_ids
        ]

        # Compare chunk by chunk, locating each received chunk through the
        # offsets and splits reported by the op.
        for begin, length, want in zip(
            got_offsets.tolist(), got_splits.tolist(), expected_chunks
        ):
            torch.testing.assert_close(out[begin : begin + length], want)

    @skipIfRocm
    @requires_triton()
    def test_triton_put(self) -> None:
        """Rank 0 launches a Triton kernel calling `nvshmem.putmem_block` with rank 1 as peer.

        After a barrier, rank 1 checks that its `out` buffer holds the value
        rank 0 filled its `inp` buffer with.
        """

        # Triton kernel wrapping the nvshmem device-side put
        @triton.jit
        def put_kernel(dst_ptr, src_ptr, numel: tl.constexpr, peer: tl.constexpr):
            nvshmem.putmem_block(dst_ptr, src_ptr, numel, peer)

        torch.manual_seed(42 + self.rank)
        self._init_device()

        # Enable NVSHMEM for Triton
        extern_libs = nvshmem.enable_triton()

        world_group = dist.group.WORLD.group_name
        symm_mem.enable_symm_mem_for_group(world_group)
        my_rank = self.rank

        dtype = torch.int8
        payload_bytes = 8
        count = payload_bytes // dtype.itemsize

        fill_value = 5
        src = symm_mem.empty(count, dtype=dtype, device=self.device).fill_(fill_value)
        dst = symm_mem.empty(count, dtype=dtype, device=self.device).fill_(-1)
        src_hdl = symm_mem.rendezvous(src, group=world_group)
        dst_hdl = symm_mem.rendezvous(dst, group=world_group)

        other = 1 - my_rank
        if my_rank == 0:
            # Both pointers are this rank's own buffer addresses; the target
            # rank is passed separately through `peer`.
            put_kernel[(1, 1, 1)](
                dst_hdl.buffer_ptrs[my_rank],
                src_hdl.buffer_ptrs[my_rank],
                numel=count,
                peer=other,
                extern_libs=extern_libs,
            )

        dist.barrier()
        if my_rank == 1:
            want = fill_value * torch.ones(count, dtype=dtype, device=self.device)
            torch.testing.assert_close(dst, want)

    @skipIfRocm
    @requires_triton()
    def test_triton_get(self) -> None:
        """Rank 1 launches a Triton kernel calling `nvshmem.getmem_block` with rank 0 as peer.

        Rank 0 fills `inp` with `val` and every other rank fills it with -1;
        after the get, rank 1 expects its `out` buffer to hold `val`.
        """

        # A Triton kernel that calls nvshmem device side API for GET
        @triton.jit
        def get_kernel(
            dst_ptr,
            src_ptr,
            numel: tl.constexpr,
            peer: tl.constexpr,
        ):
            nvshmem.getmem_block(dst_ptr, src_ptr, numel, peer)

        torch.manual_seed(42 + self.rank)
        self._init_device()

        nvshmem_lib = nvshmem.enable_triton()
        group_name = dist.group.WORLD.group_name
        symm_mem.enable_symm_mem_for_group(group_name)
        rank = self.rank
        msg_size_bytes = 8
        dtype = torch.int8
        numel = msg_size_bytes // dtype.itemsize
        val = 7
        inp = symm_mem.empty(numel, dtype=dtype, device=self.device).fill_(
            val if rank == 0 else -1
        )
        out = symm_mem.empty(numel, dtype=dtype, device=self.device).fill_(-1)
        inp_hdl = symm_mem.rendezvous(inp, group=group_name)
        out_hdl = symm_mem.rendezvous(out, group=group_name)
        # Ensure setup is complete on all ranks before the remote read
        dist.barrier()
        peer = 1 - rank
        if rank == 1:
            # Rank 1 gets data from rank 0. Both pointers are this rank's own
            # buffer addresses; the source rank is passed through `peer`.
            dst_ptr = out_hdl.buffer_ptrs[rank]
            src_ptr = inp_hdl.buffer_ptrs[rank]
            get_kernel[(1, 1, 1)](
                dst_ptr,
                src_ptr,
                numel=numel,
                peer=peer,
                extern_libs=nvshmem_lib,
            )
            torch.testing.assert_close(
                out, val * torch.ones(numel, dtype=dtype, device=self.device)
            )
        # Final barrier so that rank 0, which owns the buffer being read, does
        # not leave the test before rank 1 has finished reading and checking.
        dist.barrier()

    @skipIfRocm
    @requires_triton()
    def test_triton_get_ring(self) -> None:
        """Every rank gets from its left neighbour via `nvshmem.getmem_block`.

        Each rank fills `inp` with its own rank id, so after the get each rank
        expects `out` to hold the id of rank `(rank - 1) % world_size`.
        """

        # A Triton kernel that calls nvshmem device side API for GET
        # with ring topology
        @triton.jit
        def get_kernel(
            dst_ptr,
            src_ptr,
            numel: tl.constexpr,
            peer: tl.constexpr,
        ):
            nvshmem.getmem_block(dst_ptr, src_ptr, numel, peer)

        torch.manual_seed(42 + self.rank)
        self._init_device()

        nvshmem_lib = nvshmem.enable_triton()
        group_name = dist.group.WORLD.group_name
        symm_mem.enable_symm_mem_for_group(group_name)
        rank = self.rank
        world_size = dist.get_world_size()
        msg_size_bytes = 8
        dtype = torch.int8
        numel = msg_size_bytes // dtype.itemsize

        # Each rank fills its input buffer with its own rank value
        inp = symm_mem.empty(numel, dtype=dtype, device=self.device).fill_(rank)
        out = symm_mem.empty(numel, dtype=dtype, device=self.device).fill_(-1)
        inp_hdl = symm_mem.rendezvous(inp, group=group_name)
        out_hdl = symm_mem.rendezvous(out, group=group_name)
        # Ensure setup is complete on all ranks before any remote read
        dist.barrier()

        # Ring topology: each rank gets data from the rank to its left
        # rank 0 gets from rank (world_size-1), rank 1 gets from rank 0, etc.
        peer = (rank - 1) % world_size

        # All ranks execute the get operation
        dst_ptr = out_hdl.buffer_ptrs[rank]
        src_ptr = inp_hdl.buffer_ptrs[rank]
        get_kernel[(1, 1, 1)](
            dst_ptr,
            src_ptr,
            numel=numel,
            peer=peer,
            extern_libs=nvshmem_lib,
        )

        expected_value = peer
        torch.testing.assert_close(
            out, expected_value * torch.ones(numel, dtype=dtype, device=self.device)
        )
        # Final barrier: every rank's `inp` is read by its right neighbour, so
        # no rank should leave the test until all ranks have finished checking.
        dist.barrier()

    @skipIfRocm
    @requires_triton()
    def test_triton_put_signal_set(self) -> None:
        """Rank 0 puts data with a SET signal toward rank 1; rank 1 waits on the signal.

        Rank 0 launches `nvshmem.putmem_signal_block`; rank 1 launches
        `nvshmem.signal_wait_until` on its own signal pad and afterwards checks
        both the received data and the flag value.
        """

        # A Triton kernel that calls nvshmem device side API for PUT with SIGNAL
        @triton.jit
        def put_signal_kernel(
            dst_ptr,
            src_ptr,
            numel: tl.constexpr,
            sig_ptr,
            signal_val: tl.constexpr,
            sig_op: tl.constexpr,
            peer: tl.constexpr,
        ):
            nvshmem.putmem_signal_block(
                dst_ptr, src_ptr, numel, sig_ptr, signal_val, sig_op, peer
            )

        # Kernel for waiting on the signal locally (Rank 1).
        @triton.jit
        def signal_wait_until_kernel(
            sig_ptr, cmp_op: tl.constexpr, cmp_val: tl.constexpr
        ):
            nvshmem.signal_wait_until(sig_ptr, cmp_op, cmp_val)

        torch.manual_seed(42 + self.rank)
        self._init_device()

        nvshmem_lib = nvshmem.enable_triton()

        group_name = dist.group.WORLD.group_name
        symm_mem.enable_symm_mem_for_group(group_name)
        rank = self.rank

        msg_size_bytes = 8
        dtype = torch.int8
        numel = msg_size_bytes // dtype.itemsize

        # Data buffers
        val = 11
        inp = symm_mem.empty(numel, dtype=dtype, device=self.device).fill_(val)
        out = symm_mem.empty(numel, dtype=dtype, device=self.device).fill_(-1)
        inp_hdl = symm_mem.rendezvous(inp, group=group_name)
        out_hdl = symm_mem.rendezvous(out, group=group_name)

        # Use the signal pad attached to the output symmetric memory handle
        # as the flag buffer for signaling completion.
        flag = out_hdl.get_signal_pad(rank, (1,), dtype=torch.int64).fill_(0)

        peer = 1 - rank
        NVSHMEM_SIGNAL_SET = 0  # value defined by NVSHMEM for atomic set
        SIGNAL_VAL = 1  # Signal completion value
        NVSHMEM_CMP_EQ = 0  # compare equal for signal wait until

        # Ensure setup is complete on all ranks before proceeding. Without
        # this, rank 1's resets of `out` and `flag` above are not ordered
        # against rank 0's remote put and signal.
        dist.barrier()

        if rank == 0:
            # Rank 0 puts into Rank 1
            dst_ptr = out_hdl.buffer_ptrs[peer]
            src_ptr = inp_hdl.buffer_ptrs[rank]
            sig_ptr = out_hdl.signal_pad_ptrs[peer]
            put_signal_kernel[(1, 1, 1)](
                dst_ptr,
                src_ptr,
                numel=numel,
                sig_ptr=sig_ptr,
                signal_val=SIGNAL_VAL,
                sig_op=NVSHMEM_SIGNAL_SET,
                peer=peer,
                extern_libs=nvshmem_lib,
            )

        if rank == 1:
            # Wait until signal flag is set by Rank 0
            sig_ptr_local = out_hdl.signal_pad_ptrs[rank]
            signal_wait_until_kernel[(1, 1, 1)](
                sig_ptr_local,
                cmp_op=NVSHMEM_CMP_EQ,
                cmp_val=SIGNAL_VAL,
                extern_libs=nvshmem_lib,
            )
            # After wait completes, verify data and flag contents
            torch.testing.assert_close(
                out, val * torch.ones(numel, dtype=dtype, device=self.device)
            )
            torch.testing.assert_close(
                flag, torch.tensor([SIGNAL_VAL], dtype=torch.int64, device=self.device)
            )

        # Final barrier to ensure the test does not exit before assertions complete
        dist.barrier()

    @skipIfRocm
    @requires_triton()
    def test_triton_put_signal_add(self) -> None:
        """Rank 0 puts data with an ADD signal toward rank 1; rank 1 waits on the signal.

        The flag is reset to 0 before the exchange, so after `SIGNAL_VAL` has
        been added once rank 1 expects the flag to equal `SIGNAL_VAL`.
        """

        # A Triton kernel that calls nvshmem device side API for PUT with SIGNAL
        @triton.jit
        def put_signal_kernel(
            dst_ptr,
            src_ptr,
            numel: tl.constexpr,
            sig_ptr,
            signal_val: tl.constexpr,
            sig_op: tl.constexpr,
            peer: tl.constexpr,
        ):
            nvshmem.putmem_signal_block(
                dst_ptr, src_ptr, numel, sig_ptr, signal_val, sig_op, peer
            )

        # Kernel for waiting on the signal locally (Rank 1).
        @triton.jit
        def signal_wait_until_kernel(
            sig_ptr, cmp_op: tl.constexpr, cmp_val: tl.constexpr
        ):
            nvshmem.signal_wait_until(sig_ptr, cmp_op, cmp_val)

        torch.manual_seed(42 + self.rank)
        self._init_device()

        nvshmem_lib = nvshmem.enable_triton()

        group_name = dist.group.WORLD.group_name
        symm_mem.enable_symm_mem_for_group(group_name)
        rank = self.rank

        msg_size_bytes = 8
        dtype = torch.int8
        numel = msg_size_bytes // dtype.itemsize

        # Data buffers
        val = 11
        inp = symm_mem.empty(numel, dtype=dtype, device=self.device).fill_(val)
        out = symm_mem.empty(numel, dtype=dtype, device=self.device).fill_(-1)
        inp_hdl = symm_mem.rendezvous(inp, group=group_name)
        out_hdl = symm_mem.rendezvous(out, group=group_name)

        # Use the signal pad attached to the output symmetric memory handle
        # as the flag buffer for signaling completion.
        flag = out_hdl.get_signal_pad(rank, (1,), dtype=torch.int64).fill_(0)

        peer = 1 - rank
        NVSHMEM_SIGNAL_ADD = 5  # atomic add operation
        # Amount added to the flag. The flag starts at 0, so it should read
        # exactly this value once the single signal has been applied.
        SIGNAL_VAL = 16
        NVSHMEM_CMP_EQ = 0  # compare equal for signal wait until

        # Ensure setup is complete on all ranks before proceeding. Without
        # this, rank 1's resets of `out` and `flag` above are not ordered
        # against rank 0's remote put and signal.
        dist.barrier()

        if rank == 0:
            # Rank 0 puts into Rank 1
            dst_ptr = out_hdl.buffer_ptrs[peer]
            src_ptr = inp_hdl.buffer_ptrs[rank]
            sig_ptr = out_hdl.signal_pad_ptrs[peer]
            put_signal_kernel[(1, 1, 1)](
                dst_ptr,
                src_ptr,
                numel=numel,
                sig_ptr=sig_ptr,
                signal_val=SIGNAL_VAL,
                sig_op=NVSHMEM_SIGNAL_ADD,
                peer=peer,
                extern_libs=nvshmem_lib,
            )

        if rank == 1:
            # Wait until the flag reaches the value added by Rank 0
            sig_ptr_local = out_hdl.signal_pad_ptrs[rank]
            signal_wait_until_kernel[(1, 1, 1)](
                sig_ptr_local,
                cmp_op=NVSHMEM_CMP_EQ,
                cmp_val=SIGNAL_VAL,
                extern_libs=nvshmem_lib,
            )
            # After wait completes, verify data and flag contents
            torch.testing.assert_close(
                out, val * torch.ones(numel, dtype=dtype, device=self.device)
            )
            torch.testing.assert_close(
                flag, torch.tensor([SIGNAL_VAL], dtype=torch.int64, device=self.device)
            )

        # Final barrier to ensure the test does not exit before assertions complete
        dist.barrier()

    @skipIfRocm
    @requires_triton()
    def test_triton_wait_until(self) -> None:
        """Rank 1 puts data and then a flag toward rank 0; rank 0 waits on the flag.

        Rank 0 launches `nvshmem.wait_until` on its own signal pad and, once the
        kernel has been launched, checks that `out` holds the value rank 1
        filled its `inp` buffer with.
        """

        # A Triton kernel that calls nvshmem device side API for PUT
        @triton.jit
        def put_kernel(
            dst_ptr,
            src_ptr,
            numel: tl.constexpr,
            peer: tl.constexpr,
        ):
            nvshmem.putmem_block(dst_ptr, src_ptr, numel, peer)

        # A Triton kernel that calls nvshmem device side API for WAIT_UNTIL
        @triton.jit
        def wait_until_kernel(
            ivar_ptr,
            cmp_op: tl.constexpr,
            cmp_val: tl.constexpr,
        ):
            nvshmem.wait_until(ivar_ptr, cmp_op, cmp_val)

        torch.manual_seed(42 + self.rank)
        self._init_device()
        nvshmem_lib = nvshmem.enable_triton()
        group_name = dist.group.WORLD.group_name
        symm_mem.enable_symm_mem_for_group(group_name)
        rank = self.rank

        # Data buffers
        msg_size_bytes = 8
        dtype = torch.int8
        numel = msg_size_bytes // dtype.itemsize
        val = 13
        flag_val = 21
        inp = symm_mem.empty(numel, dtype=dtype, device=self.device).fill_(val)
        out = symm_mem.empty(numel, dtype=dtype, device=self.device).fill_(-1)
        inp_hdl = symm_mem.rendezvous(inp, group=group_name)
        out_hdl = symm_mem.rendezvous(out, group=group_name)
        # Ensure setup is complete on all ranks before proceeding
        dist.barrier()

        peer = 1 - rank
        NVSHMEM_CMP_EQ = 0  # from nvshmem.h

        if rank == 0:
            # Rank 0 waits for the flag to be set by Rank 1, then checks the data
            ivar_ptr = out_hdl.signal_pad_ptrs[rank]
            wait_until_kernel[(1, 1, 1)](
                ivar_ptr,
                cmp_op=NVSHMEM_CMP_EQ,
                cmp_val=flag_val,
                extern_libs=nvshmem_lib,
            )
            torch.testing.assert_close(
                out, val * torch.ones(numel, dtype=dtype, device=self.device)
            )

        if rank == 1:
            # Rank 1 puts data into Rank 0's output buffer. The pointers are
            # this rank's own addresses; the target is selected by `peer`.
            dst_ptr = out_hdl.buffer_ptrs[rank]
            src_ptr = inp_hdl.buffer_ptrs[rank]
            put_kernel[(1, 1, 1)](
                dst_ptr,
                src_ptr,
                numel=numel,
                peer=peer,
                extern_libs=nvshmem_lib,
            )
            # Rank 1 sets the flag on Rank 0
            # We use a temporary tensor for the value to put.
            flag_update_val = torch.tensor(
                [flag_val], dtype=torch.int64, device=self.device
            )
            dst_ptr = out_hdl.signal_pad_ptrs[rank]
            src_ptr = flag_update_val.data_ptr()
            # NOTE(review): `numel=1` is passed for an int64 flag. If
            # `putmem_block` counts bytes, only the lowest byte of the flag is
            # transferred -- confirm the unit against the NVSHMEM wrapper.
            put_kernel[(1, 1, 1)](
                dst_ptr,
                src_ptr,
                numel=1,
                peer=peer,
                extern_libs=nvshmem_lib,
            )

        # Final barrier to ensure no rank leaves the test (and releases its
        # buffers) before rank 0 has finished waiting and checking.
        dist.barrier()

    @skipIfRocm
    @requires_triton()
    def test_triton_signal_wait_until(self) -> None:
        """Producer/consumer handshake built on `putmem_signal_block` and `signal_wait_until`.

        Rank 0 (producer) puts its data toward rank 1 together with a SET
        signal; rank 1 (consumer) waits on its signal pad and then checks both
        the received data and the flag value.
        """

        # Consumer-side kernel: wait on a signal variable until the compare
        # condition is met.
        @triton.jit
        def signal_wait_until_kernel(
            sig_ptr,
            cmp_op: tl.constexpr,
            cmp_val: tl.constexpr,
        ):
            nvshmem.signal_wait_until(sig_ptr, cmp_op, cmp_val)

        # Producer-side kernel: put the data and signal completion.
        @triton.jit
        def put_and_signal_kernel(
            dst_ptr,
            src_ptr,
            numel: tl.constexpr,
            sig_ptr,
            signal_val: tl.constexpr,
            sig_op: tl.constexpr,
            peer: tl.constexpr,
        ):
            nvshmem.putmem_signal_block(
                dst_ptr, src_ptr, numel, sig_ptr, signal_val, sig_op, peer
            )

        self._init_device()
        # Enable NVSHMEM for Triton
        extern_libs = nvshmem.enable_triton()
        world_group = dist.group.WORLD.group_name
        symm_mem.enable_symm_mem_for_group(world_group)
        my_rank = self.rank
        other = 1 - my_rank

        # NVSHMEM constants from documentation
        cmp_eq = 0  # equal comparison
        signal_set = 0  # atomic set operation

        # Message configuration
        dtype = torch.int8
        payload_bytes = 8
        count = payload_bytes // dtype.itemsize
        payload = 123  # arbitrary test value
        done = 1  # flag value that marks completion

        # Source buffer (sent by the producer, rank 0)
        src = symm_mem.empty(count, dtype=dtype, device=self.device).fill_(payload)
        src_hdl = symm_mem.rendezvous(src, group=world_group)
        # Destination buffer (checked by the consumer, rank 1)
        dst = symm_mem.empty(count, dtype=dtype, device=self.device).fill_(-1)
        dst_hdl = symm_mem.rendezvous(dst, group=world_group)
        # The destination handle's signal pad serves as the flag
        flag_dtype = torch.int64
        flag = dst_hdl.get_signal_pad(my_rank, (1,), dtype=flag_dtype).fill_(0)
        # All ranks finish their setup before anything is sent
        dist.barrier()

        if my_rank == 0:
            # Producer: put into rank 1's destination buffer, then set its flag
            put_and_signal_kernel[(1, 1, 1)](
                dst_hdl.buffer_ptrs[other],
                src_hdl.buffer_ptrs[my_rank],
                count,
                dst_hdl.signal_pad_ptrs[other],
                signal_val=done,
                sig_op=signal_set,
                peer=other,
                extern_libs=extern_libs,
            )
        elif my_rank == 1:
            # Consumer: wait on the local flag via `signal_wait_until`
            signal_wait_until_kernel[(1, 1, 1)](
                dst_hdl.signal_pad_ptrs[my_rank],
                cmp_op=cmp_eq,
                cmp_val=done,
                extern_libs=extern_libs,
            )
            # Once the wait has been issued, check the data and the flag
            want_data = payload * torch.ones(count, dtype=dtype, device=self.device)
            torch.testing.assert_close(dst, want_data)
            want_flag = torch.tensor([done], dtype=flag_dtype, device=self.device)
            torch.testing.assert_close(flag, want_flag)
        # Keep every rank in the test until the assertions have completed
        dist.barrier()

    @skipIfRocm
    @requires_triton()
    def test_triton_fence(self) -> None:
        """
        Exercise `nvshmem.fence` between dependent puts.

        Rank 0 launches a single kernel that puts into the first output
        buffer, fences, puts into the second output buffer, fences again, and
        finally puts the completion flag, all with Rank 1 as the peer. Rank 1
        launches a `wait_until` kernel on that flag and then checks both
        output buffers and the flag value. The flag is written after the last
        fence, so the test treats its arrival as implying that both earlier
        puts were delivered in order.
        """

        # Producer kernel: two puts, each followed by a fence, then the flag.
        @triton.jit
        def put_with_fence_kernel(
            dst_ptr1,
            dst_ptr2,
            src_ptr1,
            src_ptr2,
            flag_ptr,
            flag_src_ptr,
            numel: tl.constexpr,
            peer: tl.constexpr,
        ):
            # Put number one
            nvshmem.putmem_block(dst_ptr1, src_ptr1, numel, peer)
            # Order put number one before whatever follows.
            nvshmem.fence()
            # Put number two
            nvshmem.putmem_block(dst_ptr2, src_ptr2, numel, peer)
            # Order put number two before the flag update.
            nvshmem.fence()
            # Completion flag (single int64)
            nvshmem.putmem_block(flag_ptr, flag_src_ptr, 1, peer)

        # Consumer kernel: Rank 1 waits until the flag has the expected value.
        @triton.jit
        def wait_until_kernel(
            ivar_ptr,
            cmp_op: tl.constexpr,
            cmp_val: tl.constexpr,
        ):
            nvshmem.wait_until(ivar_ptr, cmp_op, cmp_val)

        torch.manual_seed(42 + self.rank)
        self._init_device()
        extern_libs = nvshmem.enable_triton()
        world_group = dist.group.WORLD.group_name
        symm_mem.enable_symm_mem_for_group(world_group)
        my_rank = self.rank
        other = 1 - my_rank

        # Message configuration
        dtype = torch.int8
        payload_bytes = 8
        count = payload_bytes // dtype.itemsize
        first_val, second_val = 10, 20
        done = 1

        def alloc_filled(fill_value):
            # One symmetric buffer of `count` elements set to `fill_value`.
            return symm_mem.empty(count, dtype=dtype, device=self.device).fill_(
                fill_value
            )

        # Symmetric buffers: two sources, then two destinations
        src1 = alloc_filled(first_val)
        src2 = alloc_filled(second_val)
        dst1 = alloc_filled(-1)
        dst2 = alloc_filled(-1)
        # Rendezvous in the same order as the allocations
        src1_hdl, src2_hdl, dst1_hdl, dst2_hdl = (
            symm_mem.rendezvous(buf, group=world_group)
            for buf in (src1, src2, dst1, dst2)
        )

        # The flag lives in the signal pad of the second destination buffer.
        flag = dst2_hdl.get_signal_pad(my_rank, (1,), dtype=torch.int64).fill_(0)
        flag_src = torch.tensor([done], dtype=torch.int64, device=self.device)
        cmp_eq = 0  # compare equal
        dist.barrier()

        if my_rank == 0:
            # All pointers are this rank's own addresses; `peer` picks the target.
            put_with_fence_kernel[(1, 1, 1)](
                dst1_hdl.buffer_ptrs[my_rank],
                dst2_hdl.buffer_ptrs[my_rank],
                src1_hdl.buffer_ptrs[my_rank],
                src2_hdl.buffer_ptrs[my_rank],
                dst2_hdl.signal_pad_ptrs[my_rank],
                flag_src.data_ptr(),
                count,
                peer=other,
                extern_libs=extern_libs,
            )
        elif my_rank == 1:
            # Block on the flag written by Rank 0.
            wait_until_kernel[(1, 1, 1)](
                dst2_hdl.signal_pad_ptrs[my_rank],
                cmp_op=cmp_eq,
                cmp_val=done,
                extern_libs=extern_libs,
            )

            # Both destination buffers must hold their respective values.
            for got, want_val in ((dst1, first_val), (dst2, second_val)):
                want = want_val * torch.ones(count, dtype=dtype, device=self.device)
                torch.testing.assert_close(got, want)
            want_flag = torch.tensor([done], dtype=torch.int64, device=self.device)
            torch.testing.assert_close(flag, want_flag)
        dist.barrier()

    @skipIfRocm
    @requires_triton()
    def test_triton_quiet(self) -> None:
        """Check that ``nvshmem.quiet`` completes a put before a flag is sent.

        Rank 1 puts a data buffer to its peer, calls ``nvshmem.quiet()`` and
        then puts a completion flag into the peer's signal pad. Rank 0 waits
        until its flag equals ``flag_val`` and then asserts that its ``out``
        buffer holds the value rank 1 sent.

        The peer is computed as ``1 - rank``, so the test assumes exactly two
        ranks.
        """

        # A Triton kernel that uses nvshmem_quiet to ensure completion
        @triton.jit
        def put_with_quiet_kernel(
            dst_ptr,
            src_ptr,
            flag_dst_ptr,
            flag_src_ptr,
            numel: tl.constexpr,
            peer: tl.constexpr,
        ):
            # Put data
            nvshmem.putmem_block(dst_ptr, src_ptr, numel, peer)
            # Call quiet to ensure put is complete
            nvshmem.quiet()
            # Only after quiet, set the completion flag
            # This ensures the data put is complete before flag is set
            # NOTE(review): the size argument is 1; if it is a byte count (as
            # for nvshmemx_putmem_block), only the low byte of the int64 flag
            # is transferred, which relies on the destination flag having been
            # zeroed beforehand -- confirm.
            nvshmem.putmem_block(flag_dst_ptr, flag_src_ptr, 1, peer)

        # Kernel for the receiving rank: block until the flag word compares
        # as requested against ``cmp_val``.
        @triton.jit
        def wait_until_kernel(
            ivar_ptr,
            cmp_op: tl.constexpr,
            cmp_val: tl.constexpr,
        ):
            nvshmem.wait_until(ivar_ptr, cmp_op, cmp_val)

        torch.manual_seed(42 + self.rank)
        self._init_device()
        # Enable NVSHMEM for Triton
        nvshmem_lib = nvshmem.enable_triton()
        group_name = dist.group.WORLD.group_name
        symm_mem.enable_symm_mem_for_group(group_name)
        rank = self.rank
        peer = 1 - rank
        msg_size_bytes = 8
        dtype = torch.int8
        numel = msg_size_bytes // dtype.itemsize
        # Data buffers
        val = 15
        inp = symm_mem.empty(numel, dtype=dtype, device=self.device).fill_(val)
        out = symm_mem.empty(numel, dtype=dtype, device=self.device).fill_(-1)
        inp_hdl = symm_mem.rendezvous(inp, group=group_name)
        out_hdl = symm_mem.rendezvous(out, group=group_name)
        # Use signal pad as completion flag
        flag_val = 42
        NVSHMEM_CMP_EQ = 0  # compare equal

        # Reset the local flag so the equality wait below cannot be affected
        # by whatever the signal pad held before this test.
        out_hdl.get_signal_pad(rank, (1,), dtype=torch.int64).fill_(0)

        # Make sure both ranks have initialized their buffers and flag before
        # any remote write is issued.
        dist.barrier()
        if rank == 0:
            # Rank 0 waits for flag from Rank 1
            ivar_ptr = out_hdl.signal_pad_ptrs[rank]
            wait_until_kernel[(1, 1, 1)](
                ivar_ptr,
                cmp_op=NVSHMEM_CMP_EQ,
                cmp_val=flag_val,
                extern_libs=nvshmem_lib,
            )
            # After flag is set, data should be complete due to quiet
            torch.testing.assert_close(
                out, val * torch.ones(numel, dtype=dtype, device=self.device)
            )
        elif rank == 1:
            # Rank 1 puts data and flag with quiet in between
            dst_ptr = out_hdl.buffer_ptrs[rank]
            src_ptr = inp_hdl.buffer_ptrs[rank]
            flag_dst_ptr = out_hdl.signal_pad_ptrs[rank]
            # Create a tensor for the flag value
            flag_update_val = torch.tensor(
                [flag_val], dtype=torch.int64, device=self.device
            )
            flag_src_ptr = flag_update_val.data_ptr()
            put_with_quiet_kernel[(1, 1, 1)](
                dst_ptr,
                src_ptr,
                flag_dst_ptr,
                flag_src_ptr,
                numel=numel,
                peer=peer,
                extern_libs=nvshmem_lib,
            )
        # Keep both ranks inside the test until the exchange has finished, so
        # neither rank moves on while its peer is still using the buffers.
        dist.barrier()


# Script entry point: run the tests in this file when it is executed directly
# (see the "To run" note at the top of the file).
if __name__ == "__main__":
    run_tests()
