{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Matrix multiplication with fused bias using the cuDNN frontend\n",
    "This notebook shows how to run a matmul with a fused bias add using the cuDNN frontend (FE) Python API."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/NVIDIA/cudnn-frontend/blob/main/samples/python/01_matmul_bias.ipynb)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Prerequisites for running on Colab\n",
    "This notebook requires an NVIDIA H100 GPU or newer. If `nvidia-smi` fails, go to Runtime -> Change runtime type -> Hardware accelerator and confirm that a GPU is selected."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# get_ipython().system('nvidia-smi')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "If you are running on Colab, install the cuDNN Python interface first."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# get_ipython().system('pip install nvidia-cudnn-cu12')\n",
    "# get_ipython().system('pip install nvidia-cudnn-frontend')\n",
    "# get_ipython().system('pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu121')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### General Setup\n",
    "This example calls cuDNN through PyTorch tensors. In general, any tensor that supports DLPack should work.\n",
    "\n",
    "The cuDNN handle is a per-device handle used to initialize the cuDNN context.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "import cudnn\n",
    "import torch\n",
    "\n",
    "# per-device handle that holds the cudnn context\n",
    "handle = cudnn.create_handle()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Create input tensors and calculate reference"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "batch, m, n, k = 16, 128, 128, 512\n",
    "\n",
    "input_type = torch.float16\n",
    "\n",
    "# input tensors\n",
    "a = torch.randn(batch, m, k, dtype=input_type, device=\"cuda\")\n",
    "b = torch.randn(batch, k, n, dtype=input_type, device=\"cuda\")\n",
    "# bias, broadcast across the batch dimension\n",
    "B = torch.randn(1, m, n, dtype=input_type, device=\"cuda\")\n",
    "\n",
    "# reference output\n",
    "c_ref = torch.matmul(a, b) + B\n",
    "\n",
    "# placeholder for the cudnn output (overwritten by graph.execute)\n",
    "c = torch.empty_like(c_ref)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Create cudnn graph and tensors"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "graph = cudnn.pygraph(\n",
    "    intermediate_data_type=cudnn.data_type.FLOAT,\n",
    "    compute_data_type=cudnn.data_type.FLOAT,\n",
    ")\n",
    "\n",
    "a_cudnn_tensor = graph.tensor_like(a)\n",
    "b_cudnn_tensor = graph.tensor_like(b)\n",
    "bias_cudnn_tensor = graph.tensor_like(B)\n",
    "\n",
    "c_intermediate = graph.matmul(name=\"matmul\", A=a_cudnn_tensor, B=b_cudnn_tensor)\n",
    "\n",
    "c_cudnn_tensor = graph.bias(name=\"bias\", input=c_intermediate, bias=bias_cudnn_tensor)\n",
    "\n",
    "c_cudnn_tensor.set_name(\"c\").set_output(True).set_data_type(cudnn.data_type.HALF)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Build the graph"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "graph.validate()\n",
    "graph.build_operation_graph()\n",
    "graph.create_execution_plans([cudnn.heur_mode.A, cudnn.heur_mode.FALLBACK])\n",
    "graph.check_support()\n",
    "graph.build_plans()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Execute the graph"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "variant_pack = {\n",
    "    a_cudnn_tensor: a,\n",
    "    b_cudnn_tensor: b,\n",
    "    c_cudnn_tensor: c,\n",
    "    bias_cudnn_tensor: B,\n",
    "}\n",
    "\n",
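    "# scratch memory required by the selected execution plan\n",
    "workspace = torch.empty(graph.get_workspace_size(), device=\"cuda\", dtype=torch.uint8)\n",
    "graph.execute(variant_pack, workspace)\n",
    "# wait for the GPU to finish before reading c\n",
    "torch.cuda.synchronize()"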
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# compare the cudnn result against the PyTorch reference\n",
    "torch.testing.assert_close(c, c_ref, rtol=5e-3, atol=5e-3)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "build_thunder",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
