{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Introduction to the cuDNN Frontend Python API\n",
    "This notebook introduces the cuDNN frontend (FE) graph Python API and shows how to run a single forward convolution (fprop)."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/NVIDIA/cudnn-frontend/blob/main/samples/python/00_introduction.ipynb)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Prerequisites for running on Colab\n",
    "This notebook requires an NVIDIA GPU. If `nvidia-smi` fails, go to Runtime -> Change runtime type -> Hardware accelerator and confirm a GPU is selected."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# !nvidia-smi"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "If running on Colab, you will need to install the cuDNN Python frontend and its dependencies."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# !pip install nvidia-cudnn-cu12\n",
    "# !pip install nvidia-cudnn-frontend\n",
    "# !pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu121"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### General Setup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import cudnn"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(cudnn.backend_version())\n",
    "\n",
    "handle = cudnn.create_handle()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`backend_version()` returns the cuDNN backend version as an integer, e.g. 90000 for cuDNN 9.0.0.\n",
    "\n",
    "`handle` is a pointer to an opaque structure holding the cuDNN library context. The context must be created with `create_handle()`, and the returned handle must be passed to all subsequent library calls that need it. When you are done, destroy the context with `destroy_handle()`.\n",
    "\n",
    "The context is associated with a single GPU device: the current device at the time `create_handle()` is called. However, multiple contexts can be created on the same GPU device."
   ]
  },
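  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a minimal sketch of the lifecycle described above, we can create and immediately destroy a second, temporary handle. The `handle` created earlier stays alive for the rest of this notebook:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create a second, temporary handle on the current device and destroy it\n",
    "# right away, mirroring the create_handle()/destroy_handle() lifecycle.\n",
    "tmp_handle = cudnn.create_handle()\n",
    "cudnn.destroy_handle(tmp_handle)"
   ]
  },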
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "graph = cudnn.pygraph(\n",
    "    handle=handle,\n",
    "    name=\"cudnn_graph_0\",\n",
    "    io_data_type=cudnn.data_type.HALF,\n",
    "    compute_data_type=cudnn.data_type.FLOAT,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`pygraph` is the (sub)graph that is handed to cuDNN for execution.\n",
    "\n",
    "Each component in the graph takes an optional `name` for later reference.\n",
    "\n",
    "The optional `io_data_type` sets the default data type of the graph's input and output tensors; it can be overridden by the data type of an individual tensor.\n",
    "\n",
    "The optional `compute_data_type` sets the default data type in which computation happens; it can be overridden by the compute data type of an individual operation."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "X = graph.tensor(\n",
    "    name=\"X\",\n",
    "    dim=[8, 64, 56, 56],\n",
    "    stride=[56 * 56 * 64, 1, 56 * 64, 64],\n",
    "    data_type=cudnn.data_type.HALF,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "W = graph.tensor(name=\"W\", dim=[32, 64, 3, 3], stride=[3 * 3 * 64, 1, 3 * 64, 64])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`graph.tensor` creates an entry edge (input) to the graph. The main attributes of the tensor class are `dim`, `stride`, and `data_type`.\n",
    "Other attributes include `is_virtual` (mainly used for interior nodes of the graph) and `is_pass_by_value` (for scalar tensors passed by value).\n",
    "\n",
    "Note that the \"W\" tensor above did not specify a `data_type`; it is deduced from the `pygraph.io_data_type` set above."
   ]
  },
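  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a hypothetical sketch of the scalar-tensor attribute mentioned above (kept commented out so it does not add an unused input to the graph we are building):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# A scalar tensor passed by value; \"alpha\" is a made-up name for illustration.\n",
    "# alpha = graph.tensor(\n",
    "#     name=\"alpha\",\n",
    "#     dim=[1, 1, 1, 1],\n",
    "#     stride=[1, 1, 1, 1],\n",
    "#     is_pass_by_value=True,\n",
    "#     data_type=cudnn.data_type.FLOAT,\n",
    "# )"
   ]
  },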
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "Y = graph.conv_fprop(\n",
    "    X,\n",
    "    W,\n",
    "    padding=[1, 1],\n",
    "    stride=[1, 1],\n",
    "    dilation=[1, 1],\n",
    "    compute_data_type=cudnn.data_type.FLOAT,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Perform a forward convolution of X with W, with padding [1, 1].\n",
    "\n",
    "Other parameters include `compute_data_type`, `stride`, and `dilation`. See `help(cudnn.pygraph.conv_fprop)` for details."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "Y.set_output(True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "By default, the output of any operation is virtual (it has no device pointer associated with it), since it may simply feed the next operation in the graph. To terminate the graph, i.e. to mark the tensor as non-virtual, we set it as an output."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "graph.build([cudnn.heur_mode.A])\n",
    "# print(graph)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The following happens in the above call:\n",
    "- validation of inputs and outputs, and output shape deduction,\n",
    "- lowering into the cuDNN dialect,\n",
    "- a heuristics query to determine which execution plan to run,\n",
    "- runtime compilation of the plan, if needed.\n",
    "\n",
    "In the following notebooks, we will see this function split into its constituent phases for finer control over each one.\n",
    "\n",
    "Use `print` to inspect the graph after shape and data type deduction."
   ]
  },
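  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a preview of later notebooks, the single `build` call above roughly corresponds to the sequence of calls below (shown commented out since the graph has already been built; verify the exact names against your installed cudnn FE version):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# graph.validate()\n",
    "# graph.build_operation_graph()\n",
    "# graph.create_execution_plans([cudnn.heur_mode.A])\n",
    "# graph.check_support()\n",
    "# graph.build_plans()"
   ]
  },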
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "\n",
    "X_gpu = torch.randn(\n",
    "    8, 64, 56, 56, requires_grad=False, device=\"cuda\", dtype=torch.float16\n",
    ").to(memory_format=torch.channels_last)\n",
    "W_gpu = torch.randn(\n",
    "    32, 64, 3, 3, requires_grad=False, device=\"cuda\", dtype=torch.float16\n",
    ").to(memory_format=torch.channels_last)\n",
    "Y_gpu = torch.zeros(\n",
    "    8, 32, 56, 56, requires_grad=False, device=\"cuda\", dtype=torch.float16\n",
    ").to(memory_format=torch.channels_last)\n",
    "workspace = torch.empty(graph.get_workspace_size(), device=\"cuda\", dtype=torch.uint8)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Here we use PyTorch to create the GPU tensors. Note that cuDNN FE accepts any tensor that implements the DLPack interface."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "graph.execute({X: X_gpu, W: W_gpu, Y: Y_gpu}, workspace, handle=handle)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The `execute` call launches the kernel(s) on the GPU device, reading inputs from and writing outputs to the provided device pointers."
   ]
  }
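  ,
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a sanity check, we can compare the result against a PyTorch reference convolution (the tolerances below are a loose assumption for half-precision inputs) and then free the cuDNN context as described earlier:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reference convolution with the same padding/stride/dilation as the graph.\n",
    "Y_ref = torch.nn.functional.conv2d(X_gpu, W_gpu, padding=1, stride=1, dilation=1)\n",
    "torch.testing.assert_close(Y_gpu, Y_ref, atol=1e-2, rtol=1e-2)\n",
    "\n",
    "# Destroy the cuDNN context now that we are done with it.\n",
    "cudnn.destroy_handle(handle)"
   ]
  }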
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "fe_develop",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
