{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This notebook shows how to compute a batchnorm forward operation using cuDNN.\n",
    "\n",
    "$$\\text{BatchNorm}(x) = \\frac{x-\\mu}{\\sqrt{\\sigma^2 + \\epsilon}}\\cdot\\gamma+\\beta$$\n",
    "\n",
    "Where $\\mu = E[x]$ and $\\sigma^2 = Var[x]$ are taken over all inputs in a channel."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/NVIDIA/cudnn-frontend/blob/main/samples/python/01_matmul_bias.ipynb)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Prerequisites and Setup\n",
    "This notebook requires an NVIDIA GPU. If `nvidia-smi` fails, go to Runtime -> Change runtime type -> Hardware accelerator and confirm a GPU is selected."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# get_ipython().system('nvidia-smi')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "If running on Colab, you will need to install the cudnn python interface."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# get_ipython().system('pip install nvidia-cudnn-cu12')\n",
    "# get_ipython().system('pip install nvidia-cudnn-frontend')\n",
    "# get_ipython().system('pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu121')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### General Setup\n",
    "Create a cudnn handle, which is a per device handle used to initialize cudnn context."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Running with cudnn backend version: 90400\n"
     ]
    }
   ],
   "source": [
    "import cudnn\n",
    "import torch\n",
    "import sys\n",
    "\n",
    "torch.manual_seed(1)\n",
    "handle = cudnn.create_handle()\n",
    "\n",
    "print(\"Running with cudnn backend version:\", cudnn.backend_version())\n",
    "\n",
    "assert torch.cuda.is_available()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Batchnorm Training Forward"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# batch size, channel size, height, width\n",
    "n, c, h, w = 4, 16, 56, 56\n",
    "input_type = torch.float16\n",
    "\n",
    "# Epsilon is a small number to prevent division by 0.\n",
    "epsilon_value = 1e-3\n",
    "# Momentum value is used in computing running stats during training where\n",
    "# running_mean_next = (1 - momentum) * running_mean + momentum * local_mean\n",
    "momentum_value = 1e-1"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Create input and output tensor buffers in PyTorch."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# input tensors\n",
    "x_gpu = torch.randn(n, c, h, w, dtype=input_type, device=\"cuda\")\n",
    "x_gpu = x_gpu.to(memory_format=torch.channels_last)\n",
    "scale_gpu = torch.randn(1, c, 1, 1, device=\"cuda\")\n",
    "bias_gpu = torch.randn_like(scale_gpu)\n",
    "running_mean_gpu = torch.randn_like(scale_gpu)\n",
    "running_var_gpu = torch.randn_like(scale_gpu)\n",
    "\n",
    "comparison_gpu = torch.zeros_like(x_gpu, dtype=input_type, device=\"cuda\")\n",
    "\n",
    "epsilon_cpu = torch.full((1, 1, 1, 1), epsilon_value)\n",
    "momentum_cpu = torch.full((1, 1, 1, 1), momentum_value)\n",
    "\n",
    "# output tensors\n",
    "saved_mean_gpu = torch.empty_like(running_mean_gpu, device=\"cuda\")\n",
    "saved_inv_var_gpu = torch.empty_like(running_var_gpu, device=\"cuda\")\n",
    "y_gpu = torch.empty_like(x_gpu, dtype=input_type, device=\"cuda\")\n",
    "mask_gpu = torch.empty_like(x_gpu, dtype=torch.bool, device=\"cuda\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Create cuDNN graph"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "graph = cudnn.pygraph(\n",
    "    io_data_type=cudnn.data_type.HALF,\n",
    "    intermediate_data_type=cudnn.data_type.FLOAT,\n",
    "    compute_data_type=cudnn.data_type.FLOAT,\n",
    "    handle=handle,\n",
    ")\n",
    "\n",
    "x = graph.tensor_like(x_gpu)\n",
    "scale = graph.tensor_like(scale_gpu)\n",
    "bias = graph.tensor_like(bias_gpu)\n",
    "\n",
    "in_running_mean = graph.tensor_like(running_mean_gpu)\n",
    "in_running_var = graph.tensor_like(running_var_gpu)\n",
    "epsilon = graph.tensor_like(epsilon_cpu)\n",
    "momentum = graph.tensor_like(momentum_cpu)\n",
    "comparison = graph.tensor_like(x_gpu)\n",
    "\n",
    "y_before_relu, saved_mean, saved_inv_var, out_running_mean, out_running_var = (\n",
    "    graph.batchnorm(\n",
    "        name=\"BN\",\n",
    "        input=x,\n",
    "        scale=scale,\n",
    "        bias=bias,\n",
    "        in_running_mean=in_running_mean,\n",
    "        in_running_var=in_running_var,\n",
    "        epsilon=epsilon,\n",
    "        momentum=momentum,\n",
    "    )\n",
    ")\n",
    "y = graph.relu(name=\"relu\", input=y_before_relu)\n",
    "mask = graph.cmp_gt(name=\"cmp\", input=y, comparison=comparison)\n",
    "\n",
    "y.set_output(True)\n",
    "saved_mean.set_output(True).set_data_type(cudnn.data_type.FLOAT)\n",
    "saved_inv_var.set_output(True).set_data_type(cudnn.data_type.FLOAT)\n",
    "out_running_mean.set_output(True).set_data_type(cudnn.data_type.FLOAT)\n",
    "out_running_var.set_output(True).set_data_type(cudnn.data_type.FLOAT)\n",
    "mask.set_output(True).set_data_type(cudnn.data_type.BOOLEAN)\n",
    "pass"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Build the graph"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "graph.validate()\n",
    "graph.build_operation_graph()\n",
    "graph.create_execution_plans([cudnn.heur_mode.A, cudnn.heur_mode.FALLBACK])\n",
    "graph.check_support()\n",
    "graph.build_plans()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Execute the graph"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "variant_pack = {\n",
    "    x: x_gpu,\n",
    "    scale: scale_gpu,\n",
    "    bias: bias_gpu,\n",
    "    in_running_mean: running_mean_gpu,\n",
    "    in_running_var: running_var_gpu,\n",
    "    epsilon: epsilon_cpu,\n",
    "    momentum: momentum_cpu,\n",
    "    out_running_mean: running_mean_gpu,\n",
    "    out_running_var: running_var_gpu,\n",
    "    saved_mean: saved_mean_gpu,\n",
    "    saved_inv_var: saved_inv_var_gpu,\n",
    "    y: y_gpu,\n",
    "    comparison: comparison_gpu,\n",
    "    mask: mask_gpu,\n",
    "}\n",
    "workspace = torch.empty(graph.get_workspace_size(), device=\"cuda\", dtype=torch.uint8)\n",
    "graph.execute(\n",
    "    variant_pack,\n",
    "    workspace,\n",
    "    handle=handle,\n",
    ")\n",
    "torch.cuda.synchronize()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Test cuDNN's output against PyTorch's and check correctness"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "x_ref = x_gpu.clone().float()\n",
    "running_mean_ref = running_mean_gpu.clone().float()\n",
    "running_var_ref = running_var_gpu.clone().float()\n",
    "\n",
    "y_before_relu_ref = torch.nn.functional.batch_norm(\n",
    "    x_ref,\n",
    "    running_mean_ref,  # running_mean is both input and output\n",
    "    running_var_ref,  # running_var is both input and output\n",
    "    weight=scale_gpu,\n",
    "    bias=bias_gpu,\n",
    "    training=True,\n",
    "    momentum=momentum_cpu.item(),\n",
    "    eps=epsilon_cpu.item(),\n",
    ")\n",
    "\n",
    "mean_ref = torch.mean(x_ref, dim=(0, 2, 3), keepdim=True)\n",
    "inv_var_ref = torch.var(x_ref, dim=(0, 2, 3), keepdim=True)\n",
    "inv_var_ref = torch.rsqrt(inv_var_ref + epsilon_value)\n",
    "y_ref = torch.relu(y_before_relu_ref)\n",
    "mask_ref = y_ref > 0\n",
    "\n",
    "torch.testing.assert_close(y_ref, y_gpu.float(), atol=1e-3, rtol=1e-3)\n",
    "torch.testing.assert_close(mean_ref, saved_mean_gpu.float(), atol=1e-3, rtol=1e-3)\n",
    "torch.testing.assert_close(inv_var_ref, saved_inv_var_gpu.float(), atol=1e-3, rtol=1e-3)\n",
    "# torch.testing.assert_close(mask_ref, mask_gpu.float(), atol=1e-3, rtol=1e-3)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
