{
 "cells": [
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "5d24a692",
   "metadata": {},
   "source": [
    "# Example of using the epilogue visitor in the CUTLASS Python interface\n",
    "This notebook walks through a basic example of using the CUTLASS Python interface to declare, compile, and run GEMMs with custom epilogues defined through the CUTLASS Epilogue Visitor.\n",
    "\n",
    "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/NVIDIA/cutlass/blob/main/examples/python/04_epilogue_visitor.ipynb)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3a800e79",
   "metadata": {},
   "source": [
    "## Prerequisites for running on Colab\n",
    "This notebook requires an NVIDIA GPU. If `nvidia-smi` fails, go to Runtime -> Change runtime type -> Hardware accelerator and confirm a GPU is selected."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9cfff2c8",
   "metadata": {},
   "outputs": [],
   "source": [
    "!#nvidia-smi"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "06706f00",
   "metadata": {},
   "source": [
    "If running on Colab, you will need to install the CUTLASS Python interface. To do so, uncomment the following line and run the cell:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "491a7314",
   "metadata": {},
   "outputs": [],
   "source": [
    "!#pip install nvidia-cutlass"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "962324fd",
   "metadata": {},
   "source": [
    "## General setup\n",
    "We first import the packages needed for the example and construct the input and output tensors."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "63a70a3c",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import cutlass\n",
    "from cutlass.epilogue import relu\n",
    "from cutlass import Tensor as FakeTensor\n",
    "from cutlass.utils.profiler import CUDAEventProfiler\n",
    "\n",
    "# This controls whether the C++ GEMM declaration will be printed at each step. Set to `False` to\n",
    "# omit this information.\n",
    "print_module = True\n",
    "\n",
    "# The Epilogue Visitor feature currently only works for SM80 and 90\n",
    "from cutlass.backend.utils.device import device_cc\n",
    "if device_cc() not in [80, 90]:\n",
    "    import sys\n",
    "    sys.exit()\n",
    "\n",
    "m = 16384\n",
    "n = m\n",
    "k = 512\n",
    "\n",
    "type_A = torch.float16\n",
    "type_B = torch.float16\n",
    "type_C = torch.float16\n",
    "type_D = torch.float16\n",
    "\n",
    "torch.manual_seed(2023)\n",
    "scope_min = -4\n",
    "scope_max = 4\n",
    "tensor_A = torch.ceil(torch.empty(size=(m, k), dtype=type_A, device=\"cuda\").uniform_(scope_min, scope_max))\n",
    "tensor_B = torch.ceil(torch.empty(size=(k, n), dtype=type_B, device=\"cuda\").uniform_(scope_min, scope_max))\n",
    "tensor_C = torch.ceil(torch.empty(size=(m, n), dtype=type_C, device=\"cuda\").uniform_(scope_min, scope_max))\n",
    "tensor_D = torch.zeros_like(tensor_C)\n",
    "\n",
    "plan = cutlass.op.Gemm(element=torch.float16, layout=cutlass.LayoutType.RowMajor, element_accumulator=torch.float32)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1eb0d95b",
   "metadata": {},
   "source": [
    "## Define the epilogue visitor functor\n",
    "The epilogue functor is defined by a plain Python function together with a set of example tensors for its inputs and outputs. The example below illustrates a complex epilogue whose dataflow forms a directed acyclic graph (`F` is used twice). The epilogue takes source tensors of different ranks: `alpha` and `beta` are scalars, `bias` is a column vector that is broadcast, and `C` and `aux` are matrices. It combines basic arithmetic operations with built-in callable functions such as `relu`, and it produces multiple outputs, `D` and `F`. Note that there are some restrictions on syntax:\n",
    "* Each named variable must be assigned exactly once and defined before it is used.\n",
    "* Reserved names: `accum`, `C`, and `D` are reserved for the accumulator, `tensor_C`, and `tensor_D`.\n",
    "* Each return value must be a named variable.\n",
    "\n",
    "The example tensors are given as a dictionary with tensor names as keys and reference tensors as values. A reference tensor can be a `float`, `torch.Tensor`, `numpy.ndarray`, or our `FakeTensor`. The reference tensors provide the shape and data type information of the inputs and outputs of the epilogue.\n",
    "\n",
    "The epilogue can be generated simply through `cutlass.epilogue.trace(<epilogue function>, <example_tensors>)`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8d257833",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Define epilogue visitor\n",
    "def example_epilogue(accum, alpha, C, beta, aux, bias):\n",
    "    F = alpha * accum + (beta * C + aux)\n",
    "    E = relu(F + 1) + bias\n",
    "    D = E + F\n",
    "    return D, F\n",
    "\n",
    "# Construct inputs and outputs\n",
    "alpha = 0.5\n",
    "beta = 0.5\n",
    "aux = torch.ceil(torch.empty(size=(m, n), dtype=type_C, device=\"cuda\").uniform_(scope_min, scope_max))\n",
    "bias = torch.ceil(torch.empty(size=(m, 1), dtype=type_C, device=\"cuda\").uniform_(scope_min, scope_max))\n",
    "tensor_F = torch.zeros_like(tensor_D)\n",
    "examples_tensors = {\n",
    "    \"accum\": FakeTensor(element=torch.float32, shape=(m, n), layout_tag=cutlass.LayoutType.RowMajor),\n",
    "    \"alpha\": alpha,\n",
    "    \"C\": tensor_C,\n",
    "    \"beta\": beta,\n",
    "    \"aux\": aux,\n",
    "    \"bias\": bias,\n",
    "    \"D\": tensor_D,\n",
    "    \"F\": tensor_F\n",
    "}\n",
    "\n",
    "# Trace the epilogue visitor\n",
    "epilogue_visitor = cutlass.epilogue.trace(example_epilogue, examples_tensors)"
   ]
  },
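  {
   "cell_type": "markdown",
   "id": "e7b41c05",
   "metadata": {},
   "source": [
    "As a sketch of what `example_epilogue` computes, the function can be written element-wise as below. This assumes `accum` is the float32 GEMM accumulator $\\sum_k A_{ik} B_{kj}$ and that the `(m, 1)`-shaped `bias` is broadcast across the columns of each row; the index notation is introduced here for illustration only.\n",
    "\n",
    "$$\n",
    "\\begin{aligned}\n",
    "F_{ij} &= \\alpha \\, \\mathrm{accum}_{ij} + \\left(\\beta \\, C_{ij} + \\mathrm{aux}_{ij}\\right) \\\\\n",
    "E_{ij} &= \\max\\left(F_{ij} + 1,\\, 0\\right) + \\mathrm{bias}_{i} \\\\\n",
    "D_{ij} &= E_{ij} + F_{ij}\n",
    "\\end{aligned}\n",
    "$$\n",
    "\n",
    "Because $F$ feeds both $E$ and $D$, the computation is a directed acyclic graph rather than a simple chain."
   ]
  },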
  {
   "cell_type": "markdown",
   "id": "54961694",
   "metadata": {},
   "source": [
    "## Run a GEMM with the epilogue visitor functor\n",
    "The `epilogue_visitor` can be used by setting the plan's `epilogue_visitor` field. The arguments for the epilogue visitor are provided as a `dict` through the `visitor_args` keyword argument."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5fe49443",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Runtime arguments for the epilogue visitor, keyed by tensor name\n",
    "visitor_args = {\n",
    "    \"alpha\": alpha, \"C\": tensor_C, \"beta\": beta,\n",
    "    \"aux\": aux, \"bias\": bias, \"D\": tensor_D, \"F\": tensor_F\n",
    "}\n",
    "\n",
    "# Attach the traced epilogue to the plan, then run the fused GEMM\n",
    "plan.epilogue_visitor = epilogue_visitor\n",
    "plan.run(\n",
    "    tensor_A, tensor_B, tensor_C, tensor_D,\n",
    "    visitor_args=visitor_args, print_module=print_module)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "455d0a37",
   "metadata": {},
   "source": [
    "The epilogue function `example_epilogue` can also serve as a reference function when called on PyTorch tensors. We can now verify the results with:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e32e7798",
   "metadata": {},
   "outputs": [],
   "source": [
    "class TorchReference(torch.nn.Module):\n",
    "    def forward(self, A, B, alpha, C, beta, aux, bias):\n",
    "        accum = torch.matmul(A, B)\n",
    "        return example_epilogue(accum, alpha, C, beta, aux, bias)\n",
    "\n",
    "torch_reference = TorchReference()\n",
    "tensor_D_ref, tensor_F_ref = torch_reference(tensor_A, tensor_B, alpha, tensor_C, beta, aux, bias)\n",
    "\n",
    "assert torch.equal(tensor_D, tensor_D_ref)\n",
    "assert torch.equal(tensor_F, tensor_F_ref)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b69e441f",
   "metadata": {},
   "source": [
    "The performance of the CUTLASS fused kernel can be profiled with:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8db92150",
   "metadata": {},
   "outputs": [],
   "source": [
    "warmup_iterations = 10\n",
    "profile_iterations = 50\n",
    "# Profile the CUTLASS fused kernel\n",
    "duration = CUDAEventProfiler(\n",
    "    plan, warmup_iterations, profile_iterations,\n",
    "    tensor_A, tensor_B, tensor_C, tensor_D,\n",
    "    visitor_args=visitor_args)()\n",
    "\n",
    "print(f\"CUTLASS duration: {duration:.2f} ms\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
