{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "editable": true,
    "slideshow": {
     "slide_type": ""
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "from functools import partial\n",
    "\n",
    "import cutlass\n",
    "import cutlass.cute as cute\n",
    "from cutlass.cute.runtime import from_dlpack"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Tutorial: Elementwise Add Kernel in CuTe DSL\n",
    "\n",
    "This tutorial demonstrates how to implement a simple elementwise\n",
    "addition kernel using the CuTe DSL (Domain Specific Language).\n",
    "\n",
    "\n",
    "\n",
    "Elementwise Addition\n",
    "---------------------\n",
    "\n",
    "Elementwise addition is a fundamental operation in linear algebra.\n",
    "Given two tensors of the same shape, the operation performs element-wise\n",
    "addition to produce a result tensor of the same shape.\n",
    "\n",
    "For two 2D tensors $A$ and $B$ of shape $(M, N)$,\n",
    "the elementwise addition operation $C = A + B$ is defined as:\n",
    "\n",
    "$\n",
    "   C_{i,j} = A_{i,j} + B_{i,j}\n",
    "$\n",
    "\n",
    "where:\n",
    "\n",
    "- $i \\in [0, M-1]$ represents the row index\n",
    "- $j \\in [0, N-1]$ represents the column index\n",
    "- $A_{i,j}$, $B_{i,j}$, and $C_{i,j}$ are the elements at position $(i,j)$ \n",
    "  in tensors $A$, $B$, and $C$ respectively\n",
    "\n",
    "This operation is performed independently for each element position,\n",
    "making it highly parallelizable and well-suited for GPU implementation.\n",
    "\n",
    "Naive Elementwise Add Kernel\n",
    "-----------------------------\n",
    "\n",
    "Let's start with a naive implementation that loads each element from\n",
    "$A$ and $B$, adds them, and stores the result back to $C$."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "@cute.kernel\n",
    "def naive_elementwise_add_kernel(\n",
    "    gA: cute.Tensor,\n",
    "    gB: cute.Tensor,\n",
    "    gC: cute.Tensor,\n",
    "):\n",
    "    \"\"\"Add ``gA`` and ``gB`` element by element into ``gC``, one element per thread.\n",
    "\n",
    "    The flat global thread index is decomposed into a (row, column) coordinate\n",
    "    using the 2D shape of ``gA``; that coordinate indexes all three tensors.\n",
    "    No bounds check is performed, so the launch is expected to provide\n",
    "    no more threads than ``gA`` has elements.\n",
    "    \"\"\"\n",
    "    tid_x, _, _ = cute.arch.thread_idx()\n",
    "    blk_x, _, _ = cute.arch.block_idx()\n",
    "    blk_dim_x, _, _ = cute.arch.block_dim()\n",
    "\n",
    "    # Flat index of this thread across the whole 1D grid\n",
    "    global_tid = blk_x * blk_dim_x + tid_x\n",
    "\n",
    "    # Decompose the flat index into a (row, column) coordinate\n",
    "    _, num_cols = gA.shape\n",
    "    row = global_tid // num_cols\n",
    "    col = global_tid % num_cols\n",
    "\n",
    "    # The tensor layout turns the coordinate into an address; add and store\n",
    "    gC[row, col] = gA[row, col] + gB[row, col]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Structure of the Kernel\n",
    "\n",
    "The naive kernel simply maps each thread to one element with a 1-to-1 mapping.\n",
    "In this kernel, we don't use CuTe layout algebra but only use basic\n",
    "addressing to index the tensor.\n",
    "\n",
    "We can launch the kernel with the following JIT function:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "@cute.jit\n",
    "def naive_elementwise_add(\n",
    "    mA: cute.Tensor,\n",
    "    mB: cute.Tensor,\n",
    "    mC: cute.Tensor\n",
    "):\n",
    "    \"\"\"Launch ``naive_elementwise_add_kernel`` with one thread per element of ``mA``.\n",
    "\n",
    "    The grid size is computed with floor division, so the element count of\n",
    "    ``mA`` must be a multiple of the 256-thread block size for every element\n",
    "    to be covered.\n",
    "    \"\"\"\n",
    "    threads = 256\n",
    "    rows, cols = mA.shape\n",
    "    num_blocks = (rows * cols) // threads\n",
    "\n",
    "    naive_elementwise_add_kernel(mA, mB, mC).launch(\n",
    "        grid=(num_blocks, 1, 1),\n",
    "        block=(threads, 1, 1),\n",
    "    )\n",
    "\n",
    "M = N = 2048\n",
    "\n",
    "# Two random fp16 operands and a zero-filled result buffer on the GPU\n",
    "a, b = (torch.randn(M, N, device=\"cuda\", dtype=torch.float16) for _ in range(2))\n",
    "c = torch.zeros(M, N, device=\"cuda\", dtype=torch.float16)\n",
    "\n",
    "# Hand the torch tensors to CuTe through from_dlpack\n",
    "a_, b_, c_ = (from_dlpack(t, assumed_align=16) for t in (a, b, c))\n",
    "\n",
    "# Compile once for these arguments, then run the compiled function\n",
    "naive_elementwise_add_ = cute.compile(naive_elementwise_add, a_, b_, c_)\n",
    "naive_elementwise_add_(a_, b_, c_)\n",
    "\n",
    "# The result must match PyTorch's own addition\n",
    "torch.testing.assert_close(c, a + b)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Benchmark performance\n",
    "\n",
    "Here's a utility function to benchmark our kernel implementations:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def benchmark(callable, *, num_warmups, num_iterations, tensors=None):\n",
    "    \"\"\"Time ``callable`` with CUDA events and print mean latency and throughput.\n",
    "\n",
    "    Args:\n",
    "        callable: Zero-argument callable invoked for every warm-up and timed run.\n",
    "        num_warmups: Number of untimed calls made before the measurement starts.\n",
    "        num_iterations: Number of timed calls; must be positive.\n",
    "        tensors: Optional iterable of torch tensors that one call reads or writes.\n",
    "            Their combined size in bytes is used for the throughput figure. If\n",
    "            omitted, the traffic is taken to be three tensors the size of the\n",
    "            notebook-level tensor ``a``.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: If ``num_iterations`` is not positive.\n",
    "    \"\"\"\n",
    "    if num_iterations <= 0:\n",
    "        raise ValueError(f\"num_iterations must be positive, got {num_iterations}\")\n",
    "\n",
    "    # Bytes moved per call, derived from the real element size instead of a literal 2\n",
    "    if tensors is None:\n",
    "        bytes_per_call = 3 * a.numel() * a.element_size()\n",
    "    else:\n",
    "        bytes_per_call = sum(t.numel() * t.element_size() for t in tensors)\n",
    "\n",
    "    start_event = torch.cuda.Event(enable_timing=True)\n",
    "    end_event = torch.cuda.Event(enable_timing=True)\n",
    "\n",
    "    # Drain previously queued GPU work so it is not attributed to this measurement\n",
    "    torch.cuda.synchronize()\n",
    "\n",
    "    for _ in range(num_warmups):\n",
    "        callable()\n",
    "\n",
    "    start_event.record(stream=torch.cuda.current_stream())\n",
    "    for _ in range(num_iterations):\n",
    "        callable()\n",
    "    end_event.record(stream=torch.cuda.current_stream())\n",
    "    # Both events must have completed before elapsed_time can be read\n",
    "    torch.cuda.synchronize()\n",
    "\n",
    "    elapsed_time = start_event.elapsed_time(end_event)  # milliseconds\n",
    "    avg_time = elapsed_time / num_iterations\n",
    "\n",
    "    print(f\"Average execution time: {avg_time:.4f} ms\")\n",
    "    print(f\"Throughput: {bytes_per_call / (avg_time / 1000) / 1e9:.2f} GB/s\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Average execution time: 0.0385 ms\n",
      "Throughput: 653.44 GB/s\n"
     ]
    }
   ],
   "source": [
    "# Time the compiled naive kernel: 5 untimed warm-up calls, then 100 timed calls\n",
    "benchmark(partial(naive_elementwise_add_, a_, b_, c_), num_warmups=5, num_iterations=100)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Performance Analysis\n",
    "\n",
    "While our naive implementation maps thread indices to contiguous tensor\n",
    "dimensions for coalesced memory access, it doesn't have enough\n",
    "in-flight load & store operations to hide memory latency.\n",
    "\n",
    "According to Little's Law:\n",
    "\n",
    "$ L = \\lambda \\times W $\n",
    "\n",
    "Where:\n",
    "- $L$ is the average number of items in a system\n",
    "- $\\lambda$ is the average arrival rate of items (bandwidth)\n",
    "- $W$ is the average time an item spends in the system (latency)\n",
    "\n",
    "For our elementwise addition kernel:\n",
    "\n",
    "1. $L$: The number of load & store operations in-flight\n",
    "2. $\\lambda$ (Bandwidth): Data transfer rate between memory and compute units\n",
    "3. $W$ (Latency): Round-trip delay of memory requests\n",
    "\n",
    "For memory-bound operations like elementwise addition, performance is\n",
    "limited by the number of in-flight load & store operations.\n",
    "\n",
    "## Vectorized Load and Store\n",
    "\n",
    "To improve performance according to Little's Law, we need to increase the number\n",
    "of in-flight requests. We can do this by increasing the number of bytes handled\n",
    "in each load & store operation per thread through vectorized memory access.\n",
    "\n",
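    "Ampere GPUs support up to 128 bits per load/store instruction. In this example each thread\n",
    "loads 4 contiguous elements of a row per vectorized operation; with the 16-bit (`float16`)\n",
    "elements used here that is a 64-bit access, while 32-bit elements would fill the full 128 bits.\n",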
    "CuTe tiling operations make this vectorization straightforward.\n",
    "\n",
    "Using ``tiled_tensor = cute.zipped_divide(tensor, tiler)``, we can partition the input\n",
    "``tensor`` into groups of ``tiler`` blocks. For vectorization, we specify ``tiler``\n",
    "as the block of data each thread accesses (4 contiguous elements in the same row, or ``(1,4)``).\n",
    "Different threads can then access different blocks by indexing into the 2nd mode of ``tiled_tensor``.\n",
    "\n",
    "```python\n",
    "mA : cute.Tensor                           # (2048,2048):(2048,1)\n",
    "gA = cute.zipped_divide(mA, tiler=(1, 4))  # tiled/vectorized => ((1,4),(2048,512)):((0,1),(2048,4))\n",
    "```\n",
    "\n",
    "$\n",
    "    \\begin{array}{ccccc}\n",
    "    & ((1,4) & , & (2048,512)) & : ((0,1),(2048,4)) \\\\\n",
    "    & \\underbrace{\\phantom{(1,4)}}_{tiler} & & \\underbrace{\\phantom{(2048,512)}}_{threads} & \\\\\n",
    "    & \\text{\\scriptsize per-thread} & & \\text{\\scriptsize num of tiles}\n",
    "    \\end{array}\n",
    "$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "@cute.kernel\n",
    "def vectorized_elementwise_add_kernel(\n",
    "    gA: cute.Tensor,\n",
    "    gB: cute.Tensor,\n",
    "    gC: cute.Tensor,\n",
    "):\n",
    "    \"\"\"Add one tile of ``gA`` and ``gB`` per thread and store it into ``gC``.\n",
    "\n",
    "    The second mode of each tensor's shape is treated as a 2D grid of tiles.\n",
    "    Every thread turns its flat global index into a (row, column) position in\n",
    "    that grid, loads the whole tile found there from ``gA`` and ``gB`` and\n",
    "    writes their sum to the same tile of ``gC``. No bounds check is performed.\n",
    "    \"\"\"\n",
    "    tid_x, _, _ = cute.arch.thread_idx()\n",
    "    blk_x, _, _ = cute.arch.block_idx()\n",
    "    blk_dim_x, _, _ = cute.arch.block_dim()\n",
    "\n",
    "    global_tid = blk_x * blk_dim_x + tid_x\n",
    "\n",
    "    # Mode 1 of the shape is the tile grid that the threads are spread over\n",
    "    _, tiles_per_row = gA.shape[1]\n",
    "    tile_row = global_tid // tiles_per_row\n",
    "    tile_col = global_tid % tiles_per_row\n",
    "\n",
    "    # `None` keeps the whole first mode, i.e. every element of the selected tile\n",
    "    tile_coord = (None, (tile_row, tile_col))\n",
    "    a_val = gA[tile_coord].load()\n",
    "    b_val = gB[tile_coord].load()\n",
    "    print(f\"[DSL INFO] sliced gA = {gA[tile_coord]}\")\n",
    "    print(f\"[DSL INFO] sliced gB = {gB[tile_coord]}\")\n",
    "\n",
    "    # Vector add, then write the whole tile back\n",
    "    gC[tile_coord] = a_val + b_val"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This vectorized kernel follows a similar structure to its naive non-vectorized counterpart,\n",
    "with one key difference: the tensor slicing pattern. By using `(None, (mi, ni))` as the slice indices,\n",
    "we can extract a `(1,4)` sub-tensor from `gA`, `gB` and `gC` like \n",
    "\n",
    "```python\n",
    "gA[(None, (mi, ni))]\n",
    "\n",
    "```\n",
    "\n",
    "Then tensor data can be loaded into vector via the `.load()` method.\n",
    "\n",
    "\n",
    "```\n",
    "                                         slice\n",
    "    ((1,4),(2048,512)):((0,1),(2048,4))   ==>  ((1,4)):((0,1))\n",
    "       ^     ^    ^\n",
    "       |     |    |\n",
    "     (None, (mi,  ni))\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[DSL INFO] Tiled Tensors:\n",
      "[DSL INFO]   gA = tensor<ptr<f16, gmem, align<16>> o ((1,4),(2048,512)):((0,1),(2048,4))>\n",
      "[DSL INFO]   gB = tensor<ptr<f16, gmem, align<16>> o ((1,4),(2048,512)):((0,1),(2048,4))>\n",
      "[DSL INFO]   gC = tensor<ptr<f16, gmem, align<16>> o ((1,4),(2048,512)):((0,1),(2048,4))>\n",
      "[DSL INFO] sliced gA = tensor<ptr<f16, gmem, align<8>> o ((1,4)):((0,1))>\n",
      "[DSL INFO] sliced gB = tensor<ptr<f16, gmem, align<8>> o ((1,4)):((0,1))>\n"
     ]
    }
   ],
   "source": [
    "@cute.jit\n",
    "def vectorized_elementwise_add(\n",
    "    mA: cute.Tensor,\n",
    "    mB: cute.Tensor,\n",
    "    mC: cute.Tensor\n",
    "):\n",
    "    \"\"\"Tile the operands into (1, 4) blocks and launch the vectorized kernel.\n",
    "\n",
    "    One thread is launched per tile. The grid size uses floor division, so the\n",
    "    number of tiles must be a multiple of the 256-thread block size for every\n",
    "    tile to be processed.\n",
    "    \"\"\"\n",
    "    block_threads = 256\n",
    "    vec_tiler = (1, 4)\n",
    "\n",
    "    # After zipped_divide, mode 0 addresses elements within a tile, mode 1 the tiles\n",
    "    gA = cute.zipped_divide(mA, vec_tiler)\n",
    "    gB = cute.zipped_divide(mB, vec_tiler)\n",
    "    gC = cute.zipped_divide(mC, vec_tiler)\n",
    "\n",
    "    print(\"[DSL INFO] Tiled Tensors:\")\n",
    "    print(f\"[DSL INFO]   gA = {gA}\")\n",
    "    print(f\"[DSL INFO]   gB = {gB}\")\n",
    "    print(f\"[DSL INFO]   gC = {gC}\")\n",
    "\n",
    "    num_tiles = cute.size(gC, mode=[1])\n",
    "    vectorized_elementwise_add_kernel(gA, gB, gC).launch(\n",
    "        grid=(num_tiles // block_threads, 1, 1),\n",
    "        block=(block_threads, 1, 1),\n",
    "    )\n",
    "\n",
    "# Fresh random operands and a zeroed output for the vectorized kernel\n",
    "a, b = (torch.randn(M, N, device=\"cuda\", dtype=torch.float16) for _ in range(2))\n",
    "c = torch.zeros(M, N, device=\"cuda\", dtype=torch.float16)\n",
    "\n",
    "a_, b_, c_ = (from_dlpack(t, assumed_align=16) for t in (a, b, c))\n",
    "\n",
    "# Compile for these arguments, then execute once\n",
    "compiled_func = cute.compile(vectorized_elementwise_add, a_, b_, c_)\n",
    "compiled_func(a_, b_, c_)\n",
    "\n",
    "# Compare against PyTorch's addition\n",
    "torch.testing.assert_close(c, a + b)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Average execution time: 0.0202 ms\n",
      "Throughput: 1244.98 GB/s\n"
     ]
    }
   ],
   "source": [
    "# Time the compiled vectorized kernel: 5 untimed warm-up calls, then 100 timed calls\n",
    "benchmark(partial(compiled_func, a_, b_, c_), num_warmups=5, num_iterations=100)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## TV Layout\n",
    "\n",
    "Both the naive and vectorized kernels follow a common pattern to map thread indices\n",
    "to physical addresses:\n",
    "\n",
    "Step 1: Map thread index to logical M/N coordinates\n",
    "\n",
    "```python\n",
    "    mi = thread_idx // n\n",
    "    ni = thread_idx % n\n",
    "```\n",
    "\n",
    "Step 2: Map logical M/N coordinates to physical addresses using the tensor layout\n",
    "\n",
    "```python\n",
    "    gA[(None, (mi, ni))].load()\n",
    "```\n",
    "\n",
    "CuTe uses TV layout to represent this mapping from thread index and value index\n",
    "(i.e., the 4 elements loaded per thread) to the logical coordinate space of a tensor.\n",
    "By configuring different TV layouts, we can experiment with different memory access\n",
    "patterns with minimal code changes.\n",
    "\n",
    "The following example demonstrates two levels of tiling: at the thread-block level\n",
    "and at the thread level.\n",
    "\n",
    "For thread-block level tiling, each input & output tensor is first divided\n",
    "into a group of ``(TileM, TileN)`` sub-tensors at the host side.\n",
    "\n",
    "Inside the GPU kernel, we provide the thread-block index to the 2nd mode of the tiled tensor\n",
    "(``gA[((None, None), bidx)]``), which returns a thread-block local view of\n",
    "a single ``(TileM, TileN)`` sub-tensor.\n",
    "\n",
    "For thread level tiling, we compose the sub-tensor (which maps from logical coordinates\n",
    "to physical addresses) with the TV layout (which maps from thread & value indices to\n",
    "logical coordinates). This gives us a tiled sub-tensor that maps from thread & value\n",
    "indices directly to physical addresses.\n",
    "\n",
    "We then provide the thread index to the tiled sub-tensor (``tidfrgA[(tidx, None)]``)\n",
    "to get a thread-local view of the data each thread accesses. Note that the thread index\n",
    "is now in the 1st mode, as the tiled sub-tensor puts the thread mode before the value mode."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "@cute.kernel\n",
    "def elementwise_add_kernel(\n",
    "    gA: cute.Tensor,\n",
    "    gB: cute.Tensor,\n",
    "    gC: cute.Tensor,\n",
    "    tv_layout: cute.Layout\n",
    "):\n",
    "    \"\"\"Add the tile owned by this thread block, partitioned by a TV layout.\n",
    "\n",
    "    Each block selects one tile through the second mode of the tiled tensors.\n",
    "    Composing that tile with ``tv_layout`` re-indexes it by (thread, value), so\n",
    "    a thread obtains all of its values by slicing with its thread index.\n",
    "    \"\"\"\n",
    "    tid_x, _, _ = cute.arch.thread_idx()\n",
    "    blk_x, _, _ = cute.arch.block_idx()\n",
    "\n",
    "    # ---- thread-block level view ----\n",
    "    # (None, None) keeps the whole tile; the block index picks which tile\n",
    "    tile_coord = ((None, None), blk_x)\n",
    "\n",
    "    tile_a = gA[tile_coord]  # (TileM, TileN) -> physical address\n",
    "    tile_b = gB[tile_coord]  # (TileM, TileN) -> physical address\n",
    "    tile_c = gC[tile_coord]  # (TileM, TileN) -> physical address\n",
    "\n",
    "    # ---- compose with the TV layout ----\n",
    "    # tile:             (TileM, TileN) -> physical address\n",
    "    # tv_layout:        (tid, vid)     -> (TileM, TileN)\n",
    "    # tile o tv_layout: (tid, vid)     -> physical address\n",
    "    tv_a = cute.composition(tile_a, tv_layout)\n",
    "    tv_b = cute.composition(tile_b, tv_layout)\n",
    "    tv_c = cute.composition(tile_c, tv_layout)\n",
    "\n",
    "    print(\"Composed with TV layout:\")\n",
    "    print(f\"  tidfrgA: {tv_a.type}\")\n",
    "\n",
    "    # ---- thread level view ----\n",
    "    # Fix the thread index; `None` keeps all of that thread's values\n",
    "    value_coord = (tid_x, None)\n",
    "\n",
    "    frag_a = tv_a[value_coord]  # (V) -> physical address\n",
    "    frag_b = tv_b[value_coord]  # (V) -> physical address\n",
    "    frag_c = tv_c[value_coord]  # (V) -> physical address\n",
    "\n",
    "    frag_c[None] = frag_a.load() + frag_b.load()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "If we take a closer look at the layout of zipped divided input tensor `gA`:\n",
    "\n",
    "```\n",
    "Tiled to Thread Block:\n",
    "\n",
    "    ((16,256),(128,8))  : ((2048,1),(32768,256))\n",
    "     ~~~~~~~~  ~~~~~~      ~~~~~~~~\n",
    "        |        |            |\n",
    "        |        |            |\n",
    "        |        `------------------------> Number of Thread Blocks\n",
    "        |                     |\n",
    "        |                     |\n",
    "        `--------------------'\n",
    "                  |\n",
    "                  V\n",
    "             Thread Block\n",
    "                 Tile\n",
    "\n",
    "Sliced to Thread-Block local sub-tensor (a (16, 256) tile):  gA[((None, None), bidx)]\n",
    "\n",
    "    (16,256)   :  (2048,1)\n",
    "     ~~~~~~        ~~~~~~\n",
    "        |             |        Tiled/Composed with TV Layout\n",
    "        |             |    \n",
    "        |             |    o   ((32,4),(8,4)):((128,4),(16,1))\n",
    "        V             V         \n",
    "~~~~~~~~~~~~~~~     ~~~~~~~~~~~~~~~~~~~ \n",
    "((32,4), (8,4))  :  ((8,8192),(1,2048))\n",
    "    |      |\n",
    "    |      `--------> per thread fragment\n",
    "    |\n",
    "Thread Block\n",
    "  Shape\n",
    "\n",
    "Sliced to Thread local sub-tensor (a (4,8) tile):  tidfrgA[(tidx, None)]\n",
    "\n",
    "```\n",
    "\n",
    "The host code below shows the construction of the TV layout. By composing\n",
    "a thread layout of ``(4,32):(32,1)`` (32 threads read contiguous elements on the row dimension,\n",
    "then 4 warps read different rows) with a value layout of ``(4,8):(8,1)`` (each thread reads\n",
    "8 contiguous elements on the row dimension across 4 contiguous rows),\n",
    "we obtain the TV layout shown in the figure above."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Tiler: (16, 256)\n",
      "TV Layout: ((32,4),(8,4)):((128,4),(16,1))\n",
      "Tiled Input Tensors:\n",
      "  gA: !cute.memref<f16, gmem, align<16>, \"((16,256),(128,8)):((2048,1),(32768,256))\">\n",
      "  gB: !cute.memref<f16, gmem, align<16>, \"((16,256),(128,8)):((2048,1),(32768,256))\">\n",
      "  gC: !cute.memref<f16, gmem, align<16>, \"((16,256),(128,8)):((2048,1),(32768,256))\">\n",
      "Composed with TV layout:\n",
      "  tidfrgA: !cute.memref<f16, gmem, align<16>, \"((32,4),(8,4)):((8,8192),(1,2048))\">\n"
     ]
    }
   ],
   "source": [
    "@cute.jit\n",
    "def elementwise_add(\n",
    "    mA: cute.Tensor,\n",
    "    mB: cute.Tensor,\n",
    "    mC: cute.Tensor,\n",
    "):\n",
    "    \"\"\"Tile the operands with a thread-value layout and launch ``elementwise_add_kernel``.\n",
    "\n",
    "    A (4, 32) thread layout and a (4, 8) value layout are combined by\n",
    "    ``cute.make_layout_tv`` into a tile shape and a TV layout. One thread block\n",
    "    is launched per tile, with one thread per entry of the TV layout's thread mode.\n",
    "    \"\"\"\n",
    "    # mA layout: (M, N):(N, 1)\n",
    "    # The TV layout maps (thread index, value index) to a coordinate in a (16, 256) tile:\n",
    "    #  - consecutive thread indices run along mode-1, the contiguous mode of the input,\n",
    "    #    so that loads and stores are coalesced\n",
    "    #  - each thread handles 8 contiguous elements in each of 4 rows\n",
    "    thread_layout = cute.make_layout((4, 32), stride=(32, 1))\n",
    "    value_layout = cute.make_layout((4, 8), stride=(8, 1))\n",
    "    tiler_mn, tv_layout = cute.make_layout_tv(thread_layout, value_layout)\n",
    "    print(f\"Tiler: {tiler_mn}\")\n",
    "    print(f\"TV Layout: {tv_layout}\")\n",
    "\n",
    "    # Each result has shape ((TileM, TileN), (RestM, RestN))\n",
    "    gA = cute.zipped_divide(mA, tiler_mn)\n",
    "    gB = cute.zipped_divide(mB, tiler_mn)\n",
    "    gC = cute.zipped_divide(mC, tiler_mn)\n",
    "\n",
    "    print(\"Tiled Input Tensors:\")\n",
    "    print(f\"  gA: {gA.type}\")\n",
    "    print(f\"  gB: {gB.type}\")\n",
    "    print(f\"  gC: {gC.type}\")\n",
    "\n",
    "    # One block per tile, one thread per entry in the thread mode of the TV layout.\n",
    "    # The launch is asynchronous; async token(s) can also be specified as dependencies.\n",
    "    num_tiles = cute.size(gC, mode=[1])\n",
    "    threads_per_block = cute.size(tv_layout, mode=[0])\n",
    "    elementwise_add_kernel(gA, gB, gC, tv_layout).launch(\n",
    "        grid=[num_tiles, 1, 1],\n",
    "        block=[threads_per_block, 1, 1],\n",
    "    )\n",
    "\n",
    "# New random operands and a zeroed output for the TV-layout kernel\n",
    "a, b = (torch.randn(M, N, device=\"cuda\", dtype=torch.float16) for _ in range(2))\n",
    "c = torch.zeros(M, N, device=\"cuda\", dtype=torch.float16)\n",
    "\n",
    "a_, b_, c_ = (from_dlpack(t, assumed_align=16) for t in (a, b, c))\n",
    "\n",
    "# Compile for these arguments, then execute once\n",
    "elementwise_add_ = cute.compile(elementwise_add, a_, b_, c_)\n",
    "elementwise_add_(a_, b_, c_)\n",
    "\n",
    "# Compare against PyTorch's addition\n",
    "torch.testing.assert_close(c, a + b)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Average execution time: 0.0222 ms\n",
      "Throughput: 1133.58 GB/s\n"
     ]
    }
   ],
   "source": [
    "# Time the compiled TV-layout kernel: 5 untimed warm-up calls, then 200 timed calls\n",
    "benchmark(partial(elementwise_add_, a_, b_, c_), num_warmups=5, num_iterations=200)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Using Lambda Function\n",
    "\n",
    "CuTe DSL is built on top of Python, so it can leverage Python meta-programming to generate flexible kernels.\n",
    "For example, we can write a kernel template that takes a custom binary operation as a parameter and generates a kernel for that operation.\n",
    "\n",
    "\n",
    "```python\n",
    "@cute.jit\n",
    "def elementwise_op(\n",
    "    op: cutlass.Constexpr,\n",
    "    mA: cute.Tensor,\n",
    "    mB: cute.Tensor,\n",
    "    mC: cute.Tensor\n",
    "):\n",
    "    ...\n",
    "\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Tiler: (16, 256)\n",
      "TV Layout: ((32,4),(8,4)):((128,4),(16,1))\n",
      "Tiled Input Tensors:\n",
      "  gA: !cute.memref<f16, gmem, align<16>, \"((16,256),(128,8)):((2048,1),(32768,256))\">\n",
      "  gB: !cute.memref<f16, gmem, align<16>, \"((16,256),(128,8)):((2048,1),(32768,256))\">\n",
      "  gC: !cute.memref<f16, gmem, align<16>, \"((16,256),(128,8)):((2048,1),(32768,256))\">\n",
      "Composed with TV layout:\n",
      "  tidfrgA: !cute.memref<f16, gmem, align<16>, \"((32,4),(8,4)):((8,8192),(1,2048))\">\n"
     ]
    }
   ],
   "source": [
    "@cute.kernel\n",
    "def elementwise_apply_kernel(\n",
    "    op: cutlass.Constexpr,    # must be a const expr so that code for it is generated at compile time\n",
    "    gA: cute.Tensor,\n",
    "    gB: cute.Tensor,\n",
    "    gC: cute.Tensor,\n",
    "    tv_layout: cute.Layout\n",
    "):\n",
    "    \"\"\"Apply the binary ``op`` to this block's tile of ``gA`` and ``gB``, writing to ``gC``.\n",
    "\n",
    "    Each block selects one tile through the second mode of the tiled tensors.\n",
    "    The tile is composed with ``tv_layout`` so that it is indexed by\n",
    "    (thread, value); every thread then loads its values from both inputs,\n",
    "    calls ``op`` on them and stores the result.\n",
    "    \"\"\"\n",
    "    tid_x, _, _ = cute.arch.thread_idx()\n",
    "    blk_x, _, _ = cute.arch.block_idx()\n",
    "\n",
    "    # (None, None) keeps the whole tile; the block index picks which tile\n",
    "    tile_coord = ((None, None), blk_x)\n",
    "\n",
    "    tile_a = gA[tile_coord]  # (TileM, TileN) -> physical address\n",
    "    tile_b = gB[tile_coord]  # (TileM, TileN) -> physical address\n",
    "    tile_c = gC[tile_coord]  # (TileM, TileN) -> physical address\n",
    "\n",
    "    # (tid, vid) -> physical address\n",
    "    tv_a = cute.composition(tile_a, tv_layout)\n",
    "    tv_b = cute.composition(tile_b, tv_layout)\n",
    "    tv_c = cute.composition(tile_c, tv_layout)\n",
    "\n",
    "    print(\"Composed with TV layout:\")\n",
    "    print(f\"  tidfrgA: {tv_a.type}\")\n",
    "\n",
    "    # Fix the thread index; `None` keeps all of that thread's values\n",
    "    value_coord = (tid_x, None)\n",
    "\n",
    "    frag_a = tv_a[value_coord]  # (V) -> physical address\n",
    "    frag_b = tv_b[value_coord]  # (V) -> physical address\n",
    "    frag_c = tv_c[value_coord]  # (V) -> physical address\n",
    "\n",
    "    # Apply the caller-supplied operation to the loaded values\n",
    "    frag_c[None] = op(frag_a.load(), frag_b.load())\n",
    "\n",
    "\n",
    "@cute.jit\n",
    "def elementwise_op(\n",
    "    op: cutlass.Constexpr,\n",
    "    mA: cute.Tensor,\n",
    "    mB: cute.Tensor,\n",
    "    mC: cute.Tensor,\n",
    "):\n",
    "    \"\"\"Tile the operands with a thread-value layout and launch ``elementwise_apply_kernel``.\n",
    "\n",
    "    ``op`` is forwarded unchanged to the kernel, which calls it on the values\n",
    "    each thread loads. A (4, 32) thread layout and a (4, 8) value layout are\n",
    "    combined by ``cute.make_layout_tv`` into a tile shape and a TV layout; one\n",
    "    thread block is launched per tile.\n",
    "    \"\"\"\n",
    "    # mA layout: (M, N):(N, 1)\n",
    "    # The TV layout maps (thread index, value index) to a coordinate in a (16, 256) tile:\n",
    "    #  - consecutive thread indices run along mode-1, the contiguous mode of the input,\n",
    "    #    so that loads and stores are coalesced\n",
    "    #  - each thread handles 8 contiguous elements in each of 4 rows\n",
    "    thread_layout = cute.make_layout((4, 32), stride=(32, 1))\n",
    "    value_layout = cute.make_layout((4, 8), stride=(8, 1))\n",
    "    tiler_mn, tv_layout = cute.make_layout_tv(thread_layout, value_layout)\n",
    "    print(f\"Tiler: {tiler_mn}\")\n",
    "    print(f\"TV Layout: {tv_layout}\")\n",
    "\n",
    "    # Each result has shape ((TileM, TileN), (RestM, RestN))\n",
    "    gA = cute.zipped_divide(mA, tiler_mn)\n",
    "    gB = cute.zipped_divide(mB, tiler_mn)\n",
    "    gC = cute.zipped_divide(mC, tiler_mn)\n",
    "\n",
    "    print(\"Tiled Input Tensors:\")\n",
    "    print(f\"  gA: {gA.type}\")\n",
    "    print(f\"  gB: {gB.type}\")\n",
    "    print(f\"  gC: {gC.type}\")\n",
    "\n",
    "    # One block per tile, one thread per entry in the thread mode of the TV layout.\n",
    "    # The launch is asynchronous; async token(s) can also be specified as dependencies.\n",
    "    num_tiles = cute.size(gC, mode=[1])\n",
    "    threads_per_block = cute.size(tv_layout, mode=[0])\n",
    "    elementwise_apply_kernel(op, gA, gB, gC, tv_layout).launch(\n",
    "        grid=[num_tiles, 1, 1],\n",
    "        block=[threads_per_block, 1, 1],\n",
    "    )\n",
    "\n",
    "# Fresh random operands and a zeroed output for the generic kernel\n",
    "a, b = (torch.randn(M, N, device=\"cuda\", dtype=torch.float16) for _ in range(2))\n",
    "c = torch.zeros(M, N, device=\"cuda\", dtype=torch.float16)\n",
    "\n",
    "a_, b_, c_ = (from_dlpack(t, assumed_align=16) for t in (a, b, c))\n",
    "\n",
    "from operator import mul\n",
    "\n",
    "# Called directly (no separate cute.compile step), with `mul` as the elementwise op\n",
    "elementwise_op(mul, a_, b_, c_)\n",
    "\n",
    "# operator.mul on the torch tensors gives the elementwise reference product\n",
    "torch.testing.assert_close(c, mul(a, b))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Custom operators can be more complex. For example, here's a function that performs\n",
    "multiplication followed by ReLU:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Tiler: (16, 256)\n",
      "TV Layout: ((32,4),(8,4)):((128,4),(16,1))\n",
      "Tiled Input Tensors:\n",
      "  gA: !cute.memref<f16, gmem, align<16>, \"((16,256),(128,8)):((2048,1),(32768,256))\">\n",
      "  gB: !cute.memref<f16, gmem, align<16>, \"((16,256),(128,8)):((2048,1),(32768,256))\">\n",
      "  gC: !cute.memref<f16, gmem, align<16>, \"((16,256),(128,8)):((2048,1),(32768,256))\">\n",
      "Composed with TV layout:\n",
      "  tidfrgA: !cute.memref<f16, gmem, align<16>, \"((32,4),(8,4)):((8,8192),(1,2048))\">\n"
     ]
    }
   ],
   "source": [
    "def mul_relu(a, b):\n",
    "    \"\"\"Multiply ``a`` and ``b``, then replace every non-positive product with 0.\"\"\"\n",
    "    product = a * b\n",
    "    is_positive = product > 0\n",
    "    return cute.where(is_positive, product, cute.full_like(product, 0))\n",
    "\n",
    "\n",
    "# mul_relu relies on cute.where, so the PyTorch reference is written separately with torch.relu\n",
    "def mul_relu_ref(a, b):\n",
    "    \"\"\"PyTorch reference for ``mul_relu``: ReLU of the elementwise product.\"\"\"\n",
    "    return torch.relu(a * b)\n",
    "\n",
    "\n",
    "# Run the generic kernel with the custom mul+ReLU op; a_, b_, c_ (and a, b, c) come from the previous cell\n",
    "elementwise_op(mul_relu, a_, b_, c_)\n",
    "\n",
    "# verify correctness against the PyTorch reference\n",
    "torch.testing.assert_close(c, mul_relu_ref(a, b))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.5"
  },
  "widgets": {
   "application/vnd.jupyter.widget-state+json": {
    "state": {},
    "version_major": 2,
    "version_minor": 0
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
