{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This notebook shows how to compute an RMS norm using the cuDNN python frontend.\n",
    "\n",
    "$$\\text{RMSNorm}(x) = \\frac{x}{\\sqrt{\\mathbb{E}(x^2) + \\epsilon}}\\cdot\\gamma+\\beta$$"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/NVIDIA/cudnn-frontend/blob/main/samples/python/01_matmul_bias.ipynb)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Prerequisites and Setup\n",
    "This notebook requires an NVIDIA GPU. If `nvidia-smi` fails, go to Runtime -> Change runtime type -> Hardware accelerator and confirm a GPU is selected."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# get_ipython().system('nvidia-smi')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "If running on Colab, you will need to install the cudnn python interface."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# get_ipython().system('pip install nvidia-cudnn-cu12')\n",
    "# get_ipython().system('pip install nvidia-cudnn-frontend')\n",
    "# get_ipython().system('pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu121')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### General Setup\n",
    "The cudnn handle is a per device handle used to initialize cudnn context.\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import cudnn\n",
    "import torch\n",
    "import sys\n",
    "\n",
    "torch.manual_seed(1)\n",
    "handle = cudnn.create_handle()\n",
    "\n",
    "print(\"Running with cudnn backend version:\", cudnn.backend_version())\n",
    "\n",
    "assert torch.cuda.is_available()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### RMSNorm Reference Computation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reference Model:\n",
    "class RMSNorm(torch.nn.Module):\n",
    "    \"\"\"Root Mean Square Layer Normalization.\n",
    "\n",
    "    Derived from https://github.com/bzhangGo/rmsnorm/blob/master/rmsnorm_torch.py. BSD 3-Clause License:\n",
    "    https://github.com/bzhangGo/rmsnorm/blob/master/LICENSE.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, dim: int = -1, eps: float = 1e-5) -> None:\n",
    "        super().__init__()\n",
    "        self.eps = eps\n",
    "        self.dim = dim\n",
    "\n",
    "    def forward(\n",
    "        self, x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor = None\n",
    "    ) -> torch.Tensor:\n",
    "        # NOTE: the original RMSNorm paper implementation is not equivalent\n",
    "        norm_x = torch.mean(x * x, dim=self.dim, keepdim=True)\n",
    "        print(norm_x.shape)\n",
    "        inv_var = torch.rsqrt(norm_x + self.eps)\n",
    "        x_normed = x * inv_var\n",
    "        x_scaled = weight * x_normed\n",
    "        if bias is not None:\n",
    "            x_scaled += bias\n",
    "        return x_scaled, inv_var"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Problem Sizes\n",
    "- Batch Size: 4 \n",
    "- Sequence Length: 1024\n",
    "- Hidden Size: 128"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "batch, seq_length, hidden_size = 4, 1024, 128\n",
    "\n",
    "input_type = torch.float16\n",
    "\n",
    "# Epsilon is a small number to prevent division by 0.\n",
    "epsilon_value = 1e-3"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Create input tensor GPU buffers. We use PyTorch to allocate GPU tensors so we can reuse them easily when we calculate reference outputs."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# input tensor memory, initialize them to random numbers\n",
    "x_gpu = (\n",
    "    2\n",
    "    * torch.randn(\n",
    "        batch * seq_length,\n",
    "        hidden_size,\n",
    "        1,\n",
    "        1,\n",
    "        dtype=input_type,\n",
    "        requires_grad=True,\n",
    "        device=\"cuda\",\n",
    "    ).to(memory_format=torch.channels_last)\n",
    "    - 1.25\n",
    ")\n",
    "\n",
    "scale_gpu = (\n",
    "    3\n",
    "    * torch.randn(\n",
    "        1, hidden_size, 1, 1, dtype=input_type, requires_grad=True, device=\"cuda\"\n",
    "    ).to(memory_format=torch.channels_last)\n",
    "    - 2.75\n",
    ")\n",
    "bias_gpu = torch.randn(\n",
    "    1, hidden_size, 1, 1, dtype=input_type, requires_grad=True, device=\"cuda\"\n",
    ").to(memory_format=torch.channels_last)\n",
    "\n",
    "# set epsilon to epsilon_value, allocate on cpu.\n",
    "epsilon_cpu = torch.full(\n",
    "    (1, 1, 1, 1), epsilon_value, dtype=torch.float32, requires_grad=False, device=\"cpu\"\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Compute reference ouputs and allocate output tensor GPU buffers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# we create the reference computation outputs here so we can use .empty_like() to create our output buffers\n",
    "model = RMSNorm(eps=epsilon_value, dim=(1, 2, 3)).float()\n",
    "out_expected, inv_var_expected = model(x_gpu, scale_gpu, bias_gpu)\n",
    "\n",
    "# allocate output tensor memory using PyTorch\n",
    "# PyTorch has calculated their shapes already, so we can simply use .empty_like()\n",
    "out_gpu = torch.empty_like(out_expected)\n",
    "inv_var_gpu = torch.empty_like(inv_var_expected)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Create cuDNN graph and tensors"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if cudnn.backend_version_string() >= \"9.1.0\":\n",
    "    handle = cudnn.create_handle()\n",
    "\n",
    "    # create cuDNN graph\n",
    "    graph = cudnn.pygraph(\n",
    "        handle=handle,\n",
    "        intermediate_data_type=cudnn.data_type.FLOAT,\n",
    "        compute_data_type=cudnn.data_type.FLOAT,\n",
    "    )\n",
    "\n",
    "    # create tensor handles with the graph API\n",
    "    x = graph.tensor_like(x_gpu.detach()).set_name(\"X\")\n",
    "    scale = graph.tensor_like(scale_gpu.detach()).set_name(\"scale\")\n",
    "    bias = graph.tensor_like(bias_gpu.detach()).set_name(\"bias\")\n",
    "    epsilon = graph.tensor_like(epsilon_cpu).set_name(\"epsilon\")\n",
    "\n",
    "    (out, inv_var) = graph.rmsnorm(\n",
    "        name=\"rmsnorm\",\n",
    "        input=x,\n",
    "        norm_forward_phase=cudnn.norm_forward_phase.TRAINING,\n",
    "        scale=scale,\n",
    "        bias=bias,\n",
    "        epsilon=epsilon,\n",
    "    )\n",
    "\n",
    "    # enable all outputs\n",
    "    out.set_name(\"output\").set_output(True).set_data_type(out_expected.dtype)\n",
    "    inv_var.set_name(\"inv_var\").set_output(True).set_data_type(inv_var_expected.dtype);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Build the graph"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if cudnn.backend_version_string() >= \"9.1.0\":\n",
    "    # Build the graph\n",
    "    graph.build([cudnn.heur_mode.A, cudnn.heur_mode.FALLBACK])\n",
    "\n",
    "# To run this block more than once, we need to re-run the previous block to get a new graph.\n",
    "# The same instance of a graph should not be built twice."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Execute the graph"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if cudnn.backend_version_string() >= \"9.1.0\":\n",
    "    # Mapping of (handles -> memory)\n",
    "    variant_pack = {\n",
    "        x: x_gpu.detach(),\n",
    "        scale: scale_gpu.detach(),\n",
    "        bias: bias_gpu.detach(),\n",
    "        epsilon: epsilon_cpu,\n",
    "        out: out_gpu,\n",
    "        inv_var: inv_var_gpu,\n",
    "    }\n",
    "\n",
    "    workspace = torch.empty(\n",
    "        graph.get_workspace_size(), device=\"cuda\", dtype=torch.uint8\n",
    "    )\n",
    "    graph.execute(variant_pack, workspace)\n",
    "    torch.cuda.synchronize()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Test cuDNN's output against PyTorch's and check correctness"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# reference output\n",
    "if cudnn.backend_version_string() >= \"9.1.0\":\n",
    "\n",
    "    torch.testing.assert_close(out_gpu, out_expected, rtol=5e-3, atol=5e-3)\n",
    "    torch.testing.assert_close(inv_var_gpu, inv_var_expected, rtol=5e-3, atol=5e-3)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### RMSNorm Backwards Pass"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "target = torch.randn_like(out_expected)\n",
    "criterion = torch.nn.MSELoss()  # TODO: What is this?\n",
    "loss = criterion(out_expected, target)\n",
    "\n",
    "out_expected.retain_grad()\n",
    "x_gpu.retain_grad()\n",
    "scale_gpu.retain_grad()\n",
    "bias_gpu.retain_grad()\n",
    "\n",
    "loss.backward()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if cudnn.backend_version_string() >= \"9.1.0\":\n",
    "\n",
    "    bwd_graph = cudnn.pygraph(\n",
    "        handle=handle,\n",
    "        intermediate_data_type=cudnn.data_type.FLOAT,\n",
    "        compute_data_type=cudnn.data_type.FLOAT,\n",
    "    )\n",
    "\n",
    "    d_out = bwd_graph.tensor_like(out_expected.grad)\n",
    "\n",
    "    x_bwd = bwd_graph.tensor_like(x, name=\"x\")\n",
    "    scale_bwd = bwd_graph.tensor_like(scale, name=\"scale\")\n",
    "    inv_var_bwd = bwd_graph.tensor_like(inv_var, name=\"inv_var\")\n",
    "\n",
    "    (d_x, d_scale, d_bias) = bwd_graph.rmsnorm_backward(\n",
    "        name=\"d_rmsnorm\",\n",
    "        grad=d_out,\n",
    "        input=x_bwd,\n",
    "        scale=scale_bwd,\n",
    "        inv_variance=inv_var_bwd,\n",
    "        has_dbias=True,\n",
    "    )\n",
    "\n",
    "    d_x.set_output(True).set_data_type(x_gpu.dtype)\n",
    "    d_scale.set_output(True).set_data_type(x_gpu.dtype)\n",
    "    d_bias.set_output(True).set_data_type(x_gpu.dtype)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if cudnn.backend_version_string() >= \"9.1.0\":\n",
    "    # Build the bwd_graph\n",
    "    bwd_graph.build([cudnn.heur_mode.A, cudnn.heur_mode.FALLBACK])\n",
    "\n",
    "    # Create output buffers for gradients\n",
    "    d_x_gpu = torch.empty_like(x_gpu)\n",
    "    d_scale_gpu = torch.empty_like(scale_gpu)\n",
    "    d_bias_gpu = torch.empty_like(bias_gpu)\n",
    "\n",
    "    workspace = torch.empty(\n",
    "        bwd_graph.get_workspace_size(), device=\"cuda\", dtype=torch.uint8\n",
    "    )\n",
    "\n",
    "    bwd_graph.execute(\n",
    "        {\n",
    "            x_bwd: x_gpu.detach(),\n",
    "            scale_bwd: scale_gpu.detach(),\n",
    "            d_out: out_expected.grad,\n",
    "            inv_var_bwd: inv_var_gpu.detach(),\n",
    "            d_x: d_x_gpu,\n",
    "            d_scale: d_scale_gpu,\n",
    "            d_bias: d_bias_gpu,\n",
    "        },\n",
    "        workspace,\n",
    "        handle=handle,\n",
    "    )"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Compare results and check correctness"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if cudnn.backend_version_string() >= \"9.1.0\":\n",
    "    torch.cuda.synchronize()\n",
    "\n",
    "    torch.testing.assert_close(x_gpu.grad, d_x_gpu, atol=2e-4, rtol=2e-4)\n",
    "    torch.testing.assert_close(scale_gpu.grad, d_scale_gpu, atol=2e-4, rtol=2e-4)\n",
    "    torch.testing.assert_close(bias_gpu.grad, d_bias_gpu, atol=2e-4, rtol=2e-4)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Perform Cleanup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if cudnn.backend_version_string() >= \"9.1.0\":\n",
    "    cudnn.destroy_handle(handle)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
