{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This notebook shows how to compute instance normalization (InstanceNorm) using the cuDNN Python frontend.\n",
    "\n",
    "$$\\text{InstanceNorm}(x) = \\frac{x - \\mathbb{E}(x)}{\\sqrt{\\mathrm{Var}(x)+\\epsilon}}\\cdot \\gamma + \\beta$$\n"
   ]
  },
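  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a sketch of what the formula above means for a 4D input of shape $(N, C, H, W)$, and assuming the usual instance-norm convention, the statistics are taken over the spatial dimensions separately for each sample $n$ and channel $c$ (the index names and the symbols $\\mu$, $\\sigma^2$ and $y$ are introduced here only for illustration):\n",
    "\n",
    "$$\\mu_{n,c} = \\frac{1}{HW}\\sum_{h=1}^{H}\\sum_{w=1}^{W} x_{n,c,h,w}, \\qquad \\sigma^2_{n,c} = \\frac{1}{HW}\\sum_{h=1}^{H}\\sum_{w=1}^{W} \\left(x_{n,c,h,w} - \\mu_{n,c}\\right)^2$$\n",
    "\n",
    "$$y_{n,c,h,w} = \\frac{x_{n,c,h,w} - \\mu_{n,c}}{\\sqrt{\\sigma^2_{n,c} + \\epsilon}}\\cdot\\gamma_c + \\beta_c$$\n",
    "\n",
    "Under this reading, $\\gamma$ and $\\beta$ hold one value per channel, which matches the $(1, C, 1, 1)$ shape of the scale and bias tensors used later in this notebook."
   ]
  },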
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/NVIDIA/cudnn-frontend/blob/main/samples/python/01_matmul_bias.ipynb)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Prerequisites and Setup\n",
    "This notebook requires an NVIDIA GPU. If `nvidia-smi` fails on Colab, go to Runtime -> Change runtime type -> Hardware accelerator and confirm that a GPU is selected."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# get_ipython().system('nvidia-smi')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "If running on Colab, you will need to install the cuDNN Python frontend and its dependencies."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# get_ipython().system('pip install nvidia-cudnn-cu12')\n",
    "# get_ipython().system('pip install nvidia-cudnn-frontend')\n",
    "# get_ipython().system('pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu121')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### General Setup\n",
    "The cuDNN handle is a per-device handle used to initialize the cuDNN context.\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import cudnn\n",
    "import torch\n",
    "\n",
    "# cuDNN needs a CUDA device, so check for one before creating the handle\n",
    "assert torch.cuda.is_available()\n",
    "\n",
    "torch.manual_seed(0)\n",
    "handle = cudnn.create_handle()\n",
    "\n",
    "print(\"Running with cudnn backend version:\", cudnn.backend_version())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### InstanceNorm Reference Computation"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "\n",
    "#### Problem Sizes"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "N, C, H, W = 16, 32, 64, 64\n",
    "\n",
    "input_type = torch.float16\n",
    "\n",
    "# Epsilon is a small number to prevent division by 0.\n",
    "epsilon_value = 1e-5"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Create input tensor GPU buffers. We use PyTorch to allocate GPU tensors so we can reuse them easily when we calculate reference outputs."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Input tensors in GPU memory, initialized with random values\n",
    "x_gpu = torch.randn(\n",
    "    (N, C, H, W), dtype=input_type, requires_grad=True, device=\"cuda\"\n",
    ").to(memory_format=torch.channels_last)\n",
    "\n",
    "scale_gpu = torch.randn(\n",
    "    (1, C, 1, 1), dtype=input_type, requires_grad=True, device=\"cuda\"\n",
    ").to(memory_format=torch.channels_last)\n",
    "\n",
    "bias_gpu = torch.randn(\n",
    "    (1, C, 1, 1), dtype=input_type, requires_grad=True, device=\"cuda\"\n",
    ").to(memory_format=torch.channels_last)\n",
    "\n",
    "# Epsilon tensor filled with epsilon_value, allocated on the CPU\n",
    "epsilon_cpu = torch.full(\n",
    "    (1, 1, 1, 1), epsilon_value, dtype=torch.float32, requires_grad=False, device=\"cpu\"\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Compute reference outputs and allocate output tensor GPU buffers."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Compute the reference outputs first so that .empty_like() can be used to create the output buffers\n",
    "out_expected = torch.nn.functional.instance_norm(\n",
    "    x_gpu, weight=scale_gpu.view(C), bias=bias_gpu.view(C)\n",
    ")\n",
    "\n",
    "mean_expected = x_gpu.to(torch.float32).mean(dim=(2, 3), keepdim=True)\n",
    "\n",
    "# Instance norm normalizes with the biased (population) variance, so disable Bessel's correction\n",
    "inv_var_expected = torch.rsqrt(\n",
    "    torch.var(x_gpu.to(torch.float32), dim=(2, 3), unbiased=False, keepdim=True)\n",
    "    + epsilon_value\n",
    ")\n",
    "\n",
    "# Allocate output tensor memory using PyTorch.\n",
    "# PyTorch has already computed the shapes, so we can simply use .empty_like().\n",
    "\n",
    "# x_gpu and out_expected have the same dtype and dimensions; they differ only in\n",
    "# grad_fn: <ToCopyBackward0> (x_gpu) vs <ViewBackward0> (out_expected).\n",
    "out_gpu = torch.empty_like(x_gpu)\n",
    "mean_gpu = torch.empty_like(mean_expected)\n",
    "inv_var_gpu = torch.empty_like(inv_var_expected)"
   ]
  },
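  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a sketch of what the reference statistics above represent, assuming the standard instance-norm definition (the symbol $r$ and the index names are introduced here only for illustration), the inverse-variance output holds one value per sample $n$ and channel $c$:\n",
    "\n",
    "$$r_{n,c} = \\frac{1}{\\sqrt{\\mathrm{Var}_{h,w}(x_{n,c,h,w}) + \\epsilon}}$$\n",
    "\n",
    "Both statistics therefore have shape $(N, C, 1, 1)$. Note that `torch.var` applies Bessel's correction by default, which relates to the population variance as\n",
    "\n",
    "$$\\mathrm{Var}_{\\text{unbiased}} = \\frac{HW}{HW - 1}\\,\\mathrm{Var}_{\\text{population}}$$\n",
    "\n",
    "For $H = W = 64$ the factor is $4096/4095$, so the two conventions differ by far less than the tolerances used in the comparison below."
   ]
  },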
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Create cuDNN graph and tensors"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create cuDNN graph\n",
    "graph = cudnn.pygraph(\n",
    "    handle=handle,\n",
    "    intermediate_data_type=cudnn.data_type.FLOAT,\n",
    "    compute_data_type=cudnn.data_type.FLOAT,\n",
    ")\n",
    "\n",
    "# create tensor handles with the graph API\n",
    "x = graph.tensor_like(x_gpu.detach()).set_name(\"X\")\n",
    "scale = graph.tensor_like(scale_gpu.detach()).set_name(\"scale\")\n",
    "bias = graph.tensor_like(bias_gpu.detach()).set_name(\"bias\")\n",
    "epsilon = graph.tensor_like(epsilon_cpu).set_name(\"epsilon\")\n",
    "\n",
    "(out, mean, inv_var) = graph.instancenorm(\n",
    "    name=\"instancenorm\",\n",
    "    input=x,\n",
    "    norm_forward_phase=cudnn.norm_forward_phase.TRAINING,\n",
    "    scale=scale,\n",
    "    bias=bias,\n",
    "    epsilon=epsilon,\n",
    ")\n",
    "\n",
    "# enable all outputs\n",
    "out.set_name(\"output\").set_output(True).set_data_type(out_expected.dtype)\n",
    "mean.set_name(\"mean\").set_output(True).set_data_type(mean_expected.dtype)\n",
    "inv_var.set_name(\"inv_var\").set_output(True).set_data_type(inv_var_expected.dtype);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Build the graph"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the graph\n",
    "graph.build([cudnn.heur_mode.A, cudnn.heur_mode.FALLBACK])\n",
    "\n",
    "# To run this cell more than once, re-run the previous cell first to create a new graph.\n",
    "# The same graph instance should not be built twice."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Execute the graph"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Mapping of (handles -> memory)\n",
    "variant_pack = {\n",
    "    x: x_gpu.detach(),\n",
    "    scale: scale_gpu.detach(),\n",
    "    bias: bias_gpu.detach(),\n",
    "    epsilon: epsilon_cpu,\n",
    "    out: out_gpu,\n",
    "    mean: mean_gpu,\n",
    "    inv_var: inv_var_gpu,\n",
    "}\n",
    "\n",
    "workspace = torch.empty(graph.get_workspace_size(), device=\"cuda\", dtype=torch.uint8)\n",
    "\n",
    "graph.execute(variant_pack, workspace)\n",
    "\n",
    "torch.cuda.synchronize()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Test cuDNN's output against PyTorch's and check correctness"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Compare the cuDNN outputs against the PyTorch reference\n",
    "torch.testing.assert_close(out_gpu, out_expected, rtol=1e-2, atol=1e-2)\n",
    "torch.testing.assert_close(inv_var_gpu, inv_var_expected, rtol=1e-2, atol=1e-2)\n",
    "torch.testing.assert_close(mean_gpu, mean_expected, rtol=1e-2, atol=1e-2)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Clean up by destroying the cuDNN handle."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "cudnn.destroy_handle(handle)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
