{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This notebook shows how to compute a layernorm forward operation using cuDNN.\n",
    "\n",
    "$$\\text{LayerNorm}(x) = \\frac{x-\\mu}{\\sqrt{\\sigma^2 + \\epsilon}}\\cdot\\gamma+\\beta$$\n",
    "\n",
    "Where $\\mu = E[x]$ and $\\sigma^2 = Var[x]$ are taken over all inputs in a batch."
   ]
  },
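  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the formula concrete, here is a minimal pure-Python sketch for a single row. It assumes the statistics are taken over that row's features using the biased (population) variance, as `torch.nn.functional.layer_norm` does. The helper name `layer_norm_row` is illustrative and is not part of cuDNN or PyTorch."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "\n",
    "\n",
    "def layer_norm_row(values, gamma, beta, eps=1e-3):\n",
    "    # Mean and biased variance over the features of one row.\n",
    "    n = len(values)\n",
    "    mu = sum(values) / n\n",
    "    var = sum((v - mu) ** 2 for v in values) / n\n",
    "    # Inverse standard deviation, 1 / sqrt(var + eps).\n",
    "    inv_std = 1.0 / math.sqrt(var + eps)\n",
    "    normalized = [(v - mu) * inv_std * g + b for v, g, b in zip(values, gamma, beta)]\n",
    "    return normalized, mu, inv_std\n",
    "\n",
    "\n",
    "row = [1.0, 2.0, 3.0, 4.0]\n",
    "y_row, mu_row, inv_std_row = layer_norm_row(row, gamma=[1.0] * 4, beta=[0.0] * 4)\n",
    "print(y_row, mu_row, inv_std_row)"
   ]
  },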
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/NVIDIA/cudnn-frontend/blob/main/samples/python/01_matmul_bias.ipynb)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Prerequisites and Setup\n",
    "This notebook requires an NVIDIA GPU. If `nvidia-smi` fails, go to Runtime -> Change runtime type -> Hardware accelerator and confirm a GPU is selected."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# get_ipython().system('nvidia-smi')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "If running on Colab, you will need to install the cudnn python interface."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# get_ipython().system('pip install nvidia-cudnn-cu12')\n",
    "# get_ipython().system('pip install nvidia-cudnn-frontend')\n",
    "# get_ipython().system('pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu121')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### General Setup\n",
    "Create a cudnn handle, which is a per device handle used to initialize cudnn context."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import cudnn\n",
    "import torch\n",
    "import sys\n",
    "\n",
    "torch.manual_seed(1)\n",
    "handle = cudnn.create_handle()\n",
    "\n",
    "print(\"Running with cudnn backend version:\", cudnn.backend_version())\n",
    "\n",
    "assert torch.cuda.is_available()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### LayerNorm Training\n",
    "Problem Sizes:\n",
    "- Batch Size: 4\n",
    "- Sequence Size: 1024\n",
    "- Embedding Dimension: 768"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "batch, seq_size, embedding_dim = 4, 1024, 768\n",
    "\n",
    "input_type = torch.float16\n",
    "\n",
    "# Epsilon is a small number to prevent division by 0.\n",
    "epsilon_value = 1e-3"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Create input tensor GPU buffers. We use PyTorch to allocate GPU tensors so we can reuse them easily when we calculate reference outputs."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# allocate input tensor memory, initialize them to random numbers\n",
    "x_gpu = torch.randn(\n",
    "    batch * seq_size,\n",
    "    embedding_dim,\n",
    "    1,\n",
    "    1,\n",
    "    dtype=input_type,\n",
    "    requires_grad=True,\n",
    "    device=\"cuda\",\n",
    ").to(memory_format=torch.channels_last)\n",
    "scale_gpu = torch.randn(\n",
    "    1, embedding_dim, 1, 1, dtype=input_type, requires_grad=True, device=\"cuda\"\n",
    ").to(memory_format=torch.channels_last)\n",
    "bias_gpu = torch.randn(\n",
    "    1, embedding_dim, 1, 1, dtype=input_type, requires_grad=True, device=\"cuda\"\n",
    ").to(memory_format=torch.channels_last)\n",
    "\n",
    "# Epsilon must be a scalar value on the cpu. \n",
    "epsilon_cpu = torch.full(\n",
    "    (1, 1, 1, 1), epsilon_value, dtype=torch.float32, requires_grad=False, device=\"cpu\"\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Compute reference ouputs and allocate output tensor GPU buffers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create the reference computation outputs here before the cuDNN computation, in order to use .empty_like() to create our output buffers\n",
    "out_expected = torch.nn.functional.layer_norm(\n",
    "    x_gpu,\n",
    "    [embedding_dim, 1, 1],\n",
    "    weight=scale_gpu.squeeze(0),\n",
    "    bias=bias_gpu.squeeze(0),\n",
    "    eps=epsilon_value,\n",
    ")\n",
    "\n",
    "mean_expected = x_gpu.to(torch.float32).mean(dim=(1, 2, 3), keepdim=True)\n",
    "\n",
    "inv_var_expected = torch.rsqrt(\n",
    "    torch.var(x_gpu.to(torch.float32), dim=(1, 2, 3), keepdim=True) + epsilon_value\n",
    ")\n",
    "\n",
    "# Allocate output tensor memory using PyTorch\n",
    "# PyTorch has calculated their shapes already, so we can simply use .empty_like()\n",
    "out_gpu = torch.empty_like(out_expected)\n",
    "mean_gpu = torch.empty_like(mean_expected)\n",
    "inv_var_gpu = torch.empty_like(inv_var_expected)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Create cuDNN graph\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Here we assign UIDs for tensors. UIDs are a unique identifier that will allow us to provide a mapping from tensors created from cuDNN graph api calls, such as `graph.tensor_like()`, to the underlying device memory that will be used to store these tensors. Virtual tensors don't require explicit memory allocated for them, but non-vritual tensors like inputs or outputs will need to have UIDs assigned to them. \n",
    "\n",
    "Alternatively, one can use handles directly in the mapping, however using UIDs can be more convinient for caching of cuDNN graphs.\n",
    "\n",
    "For each of our inputs {X, Scale, Bias, Epsilon} and our outputs {Out, Mean, Inverse Variance}, we allocate a UID. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from enum import Enum\n",
    "\n",
    "class UID(Enum):\n",
    "    X = 0\n",
    "    SCALE = 1\n",
    "    BIAS = 2\n",
    "    EPSILON = 3\n",
    "    OUT = 4\n",
    "    MEAN = 5\n",
    "    INV_VAR = 6"
   ]
  },
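  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The caching remark above can be sketched without a GPU. This is a hypothetical illustration, not the cuDNN frontend's own caching mechanism: `get_cached_graph`, `make_demo_variant_pack`, `fake_build`, and the `DEMO_*` constants are made-up names, and the builder returns a plain dict where a real one would construct and build a `cudnn.pygraph`. The point is that a variant pack keyed by integer UIDs can be assembled on a cache hit without keeping the tensor handles created during graph construction."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "DEMO_X_UID, DEMO_OUT_UID = 0, 4\n",
    "\n",
    "# Hypothetical cache: problem description -> built graph object.\n",
    "_demo_graph_cache = {}\n",
    "\n",
    "\n",
    "def get_cached_graph(shape, dtype_name, build_fn):\n",
    "    # Build once per (shape, dtype) key and reuse the result afterwards.\n",
    "    key = (tuple(shape), dtype_name)\n",
    "    if key not in _demo_graph_cache:\n",
    "        _demo_graph_cache[key] = build_fn(shape, dtype_name)\n",
    "    return _demo_graph_cache[key]\n",
    "\n",
    "\n",
    "def make_demo_variant_pack(x_buffer, out_buffer):\n",
    "    # Integer UIDs stay the same across cache hits, so no tensor handles are needed.\n",
    "    return {DEMO_X_UID: x_buffer, DEMO_OUT_UID: out_buffer}\n",
    "\n",
    "\n",
    "build_calls = []\n",
    "\n",
    "\n",
    "def fake_build(shape, dtype_name):\n",
    "    # Stand-in for constructing and building a real graph.\n",
    "    build_calls.append((tuple(shape), dtype_name))\n",
    "    return {'shape': tuple(shape), 'dtype': dtype_name}\n",
    "\n",
    "\n",
    "first = get_cached_graph((4096, 768, 1, 1), 'float16', fake_build)\n",
    "second = get_cached_graph((4096, 768, 1, 1), 'float16', fake_build)\n",
    "print(first is second, len(build_calls))\n",
    "print(make_demo_variant_pack('x_buf', 'out_buf'))"
   ]
  },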
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create the cuDNN graph.\n",
    "graph = cudnn.pygraph(\n",
    "    handle=handle,\n",
    "    intermediate_data_type=cudnn.data_type.FLOAT,\n",
    "    compute_data_type=cudnn.data_type.FLOAT,\n",
    ")\n",
    "\n",
    "# Create tensor handles with the graph API, assign UIDs.\n",
    "x = graph.tensor_like(x_gpu.detach()).set_name(\"X\").set_uid(UID.X.value)\n",
    "scale = graph.tensor_like(scale_gpu.detach()).set_name(\"scale\").set_uid(UID.SCALE.value)\n",
    "bias = graph.tensor_like(bias_gpu.detach()).set_name(\"bias\").set_uid(UID.BIAS.value)\n",
    "epsilon = graph.tensor_like(epsilon_cpu).set_name(\"epsilon\").set_uid(UID.EPSILON.value)\n",
    "\n",
    "# Add a layernorm operation\n",
    "(out, mean, inv_var) = graph.layernorm(\n",
    "    name=\"layernorm\",\n",
    "    input=x,\n",
    "    norm_forward_phase=cudnn.norm_forward_phase.TRAINING,\n",
    "    scale=scale,\n",
    "    bias=bias,\n",
    "    epsilon=epsilon, \n",
    ")\n",
    "\n",
    "# Enable all outputs, by default outputs are disabled\n",
    "out.set_name(\"output\").set_output(True).set_data_type(out_expected.dtype).set_uid(\n",
    "    UID.OUT.value\n",
    ")\n",
    "mean.set_name(\"mean\").set_output(True).set_data_type(mean_expected.dtype).set_uid(\n",
    "    UID.MEAN.value\n",
    ")\n",
    "inv_var.set_name(\"inv_var\").set_output(True).set_data_type(\n",
    "    inv_var_expected.dtype\n",
    ").set_uid(UID.INV_VAR.value)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Build the graph"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the graph\n",
    "graph.build([cudnn.heur_mode.A, cudnn.heur_mode.FALLBACK])\n",
    "\n",
    "# To run this block more than once, we need to re-run the previous block to get a new graph.\n",
    "# The same instance of a graph can not be built twice."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Execute the graph"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "After validating and building a cuDNN graph,  we can now execute it. To do this, we have to provide input and output buffers. We do this by using the previously allocated UIDs to associate between tensor handles generated from the graph API, and their underlying memory. \n",
    "\n",
    "The desired input values need to be stored in these buffers before the `graph.execute` call. Because we have done a reference computation, we can simply reuse the buffers we have allocated via PyTorch.\n",
    "\n",
    "Note that the EPISLON UID expects a cpu buffer, "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Mapping of (UIDs -> memory)\n",
    "variant_pack = {\n",
    "    UID.X.value: x_gpu,\n",
    "    UID.SCALE.value: scale_gpu,\n",
    "    UID.BIAS.value: bias_gpu,\n",
    "    UID.EPSILON.value: epsilon_cpu,\n",
    "    UID.OUT.value: out_gpu,\n",
    "    UID.MEAN.value: mean_gpu,\n",
    "    UID.INV_VAR.value: inv_var_gpu,\n",
    "}\n",
    "\n",
    "workspace = torch.empty(graph.get_workspace_size(), device=\"cuda\", dtype=torch.uint8)\n",
    "graph.execute(variant_pack, workspace)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Test cuDNN's output against PyTorch's and check correctness"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "torch.cuda.synchronize()\n",
    "\n",
    "# compare to reference output\n",
    "torch.testing.assert_close(out_gpu, out_expected, rtol=5e-3, atol=5e-3)\n",
    "torch.testing.assert_close(inv_var_gpu, inv_var_expected, rtol=5e-3, atol=5e-3)\n",
    "torch.testing.assert_close(mean_gpu, mean_expected, rtol=5e-3, atol=5e-3)"
   ]
  },
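  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For reference, `torch.testing.assert_close` accepts an element when `|actual - expected| <= atol + rtol * |expected|`. Below is a scalar sketch of that rule with the tolerances used above; `within_tolerance` is an illustrative helper, not a PyTorch function."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def within_tolerance(actual, expected, rtol=5e-3, atol=5e-3):\n",
    "    # Combined absolute and relative tolerance check on one scalar.\n",
    "    return abs(actual - expected) <= atol + rtol * abs(expected)\n",
    "\n",
    "\n",
    "print(within_tolerance(1.004, 1.0))\n",
    "print(within_tolerance(1.02, 1.0))"
   ]
  },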
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Perform Cleanup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "cudnn.destroy_handle(handle)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
