{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "b687b169-ec83-493d-a7c5-f8c6cd402ea3",
   "metadata": {},
   "source": [
    "# Neural Tangent Kernels\n",
    "\n",
    "<a href=\"https://colab.research.google.com/github/pytorch/pytorch/blob/master/functorch/notebooks/neural_tangent_kernels.ipynb\">\n",
    "  <img style=\"width: auto\" src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/>\n",
    "</a>\n",
    "\n",
    "The neural tangent kernel (NTK) is a kernel that describes [how a neural network evolves during training](https://en.wikipedia.org/wiki/Neural_tangent_kernel). There has been a lot of research around it [in recent years](https://arxiv.org/abs/1806.07572). This tutorial, inspired by the implementation of [NTKs in JAX](https://github.com/google/neural-tangents) (see [Fast Finite Width Neural Tangent Kernel](https://arxiv.org/abs/2206.08720) for details), demonstrates how to easily compute this quantity using functorch."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "77f41c65-f070-4b60-b3d0-1c8f56ed4f64",
   "metadata": {},
   "source": [
    "## Setup\n",
    "\n",
    "First, some setup. Let's define a simple CNN that we wish to compute the NTK of."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "855fa70b-5b63-4973-94df-41be57ab6ecf",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "from functorch import make_functional, vmap, vjp, jvp, jacrev\n",
    "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
    "\n",
    "class CNN(nn.Module):\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.conv1 = nn.Conv2d(3, 32, (3, 3))\n",
    "        self.conv2 = nn.Conv2d(32, 32, (3, 3))\n",
    "        self.conv3 = nn.Conv2d(32, 32, (3, 3))\n",
    "        self.fc = nn.Linear(21632, 10)\n",
    "        \n",
    "    def forward(self, x):\n",
    "        x = self.conv1(x)\n",
    "        x = x.relu()\n",
    "        x = self.conv2(x)\n",
    "        x = x.relu()\n",
    "        x = self.conv3(x)\n",
    "        x = x.flatten(1)\n",
    "        x = self.fc(x)\n",
    "        return x"
   ]
  },
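  {
   "cell_type": "markdown",
   "id": "3f1c2a7e-8b4d-4c6a-9e21-5d7b0a1c9f01",
   "metadata": {},
   "source": [
    "The `21632` input features of `self.fc` come from the input shape used below ($3 \\times 32 \\times 32$). An unpadded $3 \\times 3$ convolution with stride 1 shrinks each spatial side by 2, so the three convolutions take 32 down to 26, and `conv3` outputs 32 channels. The sketch below simply redoes that arithmetic. `conv_out_size` is a hypothetical helper (the standard output-size formula, assuming dilation 1), not part of functorch."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3f1c2a7e-8b4d-4c6a-9e21-5d7b0a1c9f02",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hypothetical helper: Conv2d output size along one spatial dimension\n",
    "# (standard formula, assuming dilation 1)\n",
    "def conv_out_size(size, kernel=3, stride=1, padding=0):\n",
    "    return (size + 2 * padding - kernel) // stride + 1\n",
    "\n",
    "side = 32  # input images are 3 x 32 x 32\n",
    "for _ in range(3):  # conv1, conv2, conv3: 3x3 kernels, no padding\n",
    "    side = conv_out_size(side)\n",
    "\n",
    "channels = 32  # out_channels of conv3\n",
    "flat_features = channels * side * side\n",
    "print(side, flat_features)  # 26 21632"
   ]
  },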
  {
   "cell_type": "markdown",
   "id": "52c600e9-207a-41ec-93b4-5d940827bda0",
   "metadata": {},
   "source": [
    "Next, let's generate some random data:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "0001a907-f5c9-4532-9ee9-2e94b8487d08",
   "metadata": {},
   "outputs": [],
   "source": [
    "x_train = torch.randn(20, 3, 32, 32, device=device)\n",
    "x_test = torch.randn(5, 3, 32, 32, device=device)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8af210fe-9613-48ee-a96c-d0836458b0f1",
   "metadata": {},
   "source": [
    "## Create a function version of the model\n",
    "\n",
    "functorch transforms operate on functions. In particular, to compute the NTK, we will need a function that accepts the parameters of the model and a single input (as opposed to a batch of inputs!) and returns a single output.\n",
    "\n",
    "We'll use functorch's `make_functional` to accomplish the first step. If your module has buffers, you'll want to use `make_functional_with_buffers` instead."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "e6b4bb59-bdde-46cd-8a28-7fd00a37a387",
   "metadata": {},
   "outputs": [],
   "source": [
    "net = CNN().to(device)\n",
    "fnet, params = make_functional(net)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "319276a4-da45-499a-af47-0677107559b6",
   "metadata": {},
   "source": [
    "Keep in mind that the model was originally written to accept a batch of input data points. In our CNN example, there are no cross-sample operations: each data point in the batch is processed independently of the others. Under this assumption, we can write a function that evaluates the model on a single data point by adding a batch dimension of size 1 before the call and removing it afterwards:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "0b8b4021-eb10-4a50-9d99-3817cb0ce4cc",
   "metadata": {},
   "outputs": [],
   "source": [
    "def fnet_single(params, x):\n",
    "    return fnet(params, x.unsqueeze(0)).squeeze(0)"
   ]
  },
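  {
   "cell_type": "markdown",
   "id": "3f1c2a7e-8b4d-4c6a-9e21-5d7b0a1c9f03",
   "metadata": {},
   "source": [
    "To see why the no-cross-sample assumption matters, here is a small NumPy sketch. NumPy stands in for PyTorch only to keep it self-contained, and `per_sample_model`, `batch_stats_model`, and `single` are hypothetical toy functions. For a model that processes each row independently, running one data point through the batched code with an added batch dimension, as `fnet_single` does, matches the corresponding row of the batched output. A model that normalizes with batch statistics breaks this, so the single-data-point trick would not apply to it."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3f1c2a7e-8b4d-4c6a-9e21-5d7b0a1c9f04",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "rng = np.random.default_rng(0)\n",
    "W = rng.standard_normal((4, 3))\n",
    "X = rng.standard_normal((8, 3))  # a batch of 8 data points\n",
    "\n",
    "def per_sample_model(X):\n",
    "    # Each row is processed independently: no cross-sample operations\n",
    "    return np.maximum(X @ W.T, 0.0)\n",
    "\n",
    "def batch_stats_model(X):\n",
    "    # Normalizes with batch statistics, so every output row depends on the whole batch\n",
    "    H = X @ W.T\n",
    "    return (H - H.mean(axis=0)) / (H.std(axis=0) + 1e-5)\n",
    "\n",
    "def single(model, x):\n",
    "    # Same trick as fnet_single: add a batch dimension, run the model, remove it\n",
    "    return model(x[None])[0]\n",
    "\n",
    "independent_ok = np.allclose(single(per_sample_model, X[0]), per_sample_model(X)[0])\n",
    "batch_stats_ok = np.allclose(single(batch_stats_model, X[0]), batch_stats_model(X)[0])\n",
    "print(independent_ok, batch_stats_ok)  # True False"
   ]
  },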
  {
   "cell_type": "markdown",
   "id": "62bc6b5a-31fa-411e-8069-e6c1f6d05248",
   "metadata": {},
   "source": [
    "## Compute the NTK: method 1 (Jacobian contraction)\n",
    "\n",
    "We're ready to compute the empirical NTK. The empirical NTK for two data points $x_1$ and $x_2$ is defined as the matrix product of the Jacobian of the model evaluated at $x_1$ and the transpose of the Jacobian of the model evaluated at $x_2$:\n",
    "\n",
    "$$J_{net}(x_1) J_{net}^T(x_2)$$\n",
    "\n",
    "When $x_1$ and $x_2$ are batches of data points, we want this matrix product for every pair of data points drawn from $x_1$ and $x_2$.\n",
    "\n",
    "The first method does exactly that: it computes the two Jacobians and contracts them. Here's how to compute the NTK in the batched case:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "99a38a4b-64d3-4e13-bd63-2d71e8dd6840",
   "metadata": {},
   "outputs": [],
   "source": [
    "def empirical_ntk_jacobian_contraction(fnet_single, params, x1, x2):\n",
    "    # Compute J(x1): one Jacobian per parameter tensor, each of shape (N, O, *param_shape)\n",
    "    jac1 = vmap(jacrev(fnet_single), (None, 0))(params, x1)\n",
    "    # Flatten the parameter dimensions: (N, O, P_i)\n",
    "    jac1 = [j.flatten(2) for j in jac1]\n",
    "    \n",
    "    # Compute J(x2)\n",
    "    jac2 = vmap(jacrev(fnet_single), (None, 0))(params, x2)\n",
    "    jac2 = [j.flatten(2) for j in jac2]\n",
    "    \n",
    "    # Compute J(x1) @ J(x2).T: one einsum per parameter tensor, then sum the contributions\n",
    "    result = torch.stack([torch.einsum('Naf,Mbf->NMab', j1, j2) for j1, j2 in zip(jac1, jac2)])\n",
    "    result = result.sum(0)\n",
    "    return result"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "cbf54d2b-c4bc-46bd-9e55-e1471d639a4e",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([20, 5, 10, 10])\n"
     ]
    }
   ],
   "source": [
    "result = empirical_ntk_jacobian_contraction(fnet_single, params, x_train, x_test)\n",
    "print(result.shape)"
   ]
  },
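  {
   "cell_type": "markdown",
   "id": "3f1c2a7e-8b4d-4c6a-9e21-5d7b0a1c9f05",
   "metadata": {},
   "source": [
    "Why does the function sum over a list? `jacrev` returns one Jacobian per parameter tensor, and the full Jacobian is their concatenation along the flattened-parameter axis, so $J_{net}(x_1) J_{net}^T(x_2)$ splits into a sum of per-tensor products. The NumPy sketch below checks this with random stand-in Jacobians; the sizes are arbitrary assumptions, not taken from the CNN."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3f1c2a7e-8b4d-4c6a-9e21-5d7b0a1c9f06",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "rng = np.random.default_rng(0)\n",
    "N, M, O = 4, 3, 2          # batch sizes of x1 and x2, output size\n",
    "param_sizes = [5, 7, 3]    # flattened sizes of three hypothetical parameter tensors\n",
    "\n",
    "# Stand-ins for the flattened per-parameter Jacobians: shapes (N, O, P_i) and (M, O, P_i)\n",
    "jac1 = [rng.standard_normal((N, O, p)) for p in param_sizes]\n",
    "jac2 = [rng.standard_normal((M, O, p)) for p in param_sizes]\n",
    "\n",
    "# Per-parameter contraction, then sum (as empirical_ntk_jacobian_contraction does)\n",
    "ntk_sum = sum(np.einsum('Naf,Mbf->NMab', j1, j2) for j1, j2 in zip(jac1, jac2))\n",
    "\n",
    "# Same quantity from the concatenated Jacobians of shape (N, O, P), P = sum(param_sizes)\n",
    "J1 = np.concatenate(jac1, axis=-1)\n",
    "J2 = np.concatenate(jac2, axis=-1)\n",
    "ntk_full = np.einsum('Naf,Mbf->NMab', J1, J2)\n",
    "\n",
    "print(ntk_full.shape, np.allclose(ntk_sum, ntk_full))  # (4, 3, 2, 2) True"
   ]
  },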
  {
   "cell_type": "markdown",
   "id": "ea844f45-98fb-4cba-8056-644292b968ab",
   "metadata": {},
   "source": [
    "In some cases, you may only want the diagonal or the trace of each $O \\times O$ output block of this quantity, especially if you know beforehand that the network architecture results in an NTK where the off-diagonal elements can be approximated by zero. It's easy to adjust the above function to do that:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "aae760c9-e906-4fda-b490-1126a86b7e96",
   "metadata": {},
   "outputs": [],
   "source": [
    "def empirical_ntk_jacobian_contraction(fnet_single, params, x1, x2, compute='full'):\n",
    "    # Compute J(x1)\n",
    "    jac1 = vmap(jacrev(fnet_single), (None, 0))(params, x1)\n",
    "    jac1 = [j.flatten(2) for j in jac1]\n",
    "    \n",
    "    # Compute J(x2)\n",
    "    jac2 = vmap(jacrev(fnet_single), (None, 0))(params, x2)\n",
    "    jac2 = [j.flatten(2) for j in jac2]\n",
    "    \n",
    "    # Compute J(x1) @ J(x2).T\n",
    "    if compute == 'full':\n",
    "        einsum_expr = 'Naf,Mbf->NMab'\n",
    "    elif compute == 'trace':\n",
    "        einsum_expr = 'Naf,Maf->NM'\n",
    "    elif compute == 'diagonal':\n",
    "        einsum_expr = 'Naf,Maf->NMa'\n",
    "    else:\n",
    "        raise ValueError(f\"compute must be 'full', 'trace', or 'diagonal', got {compute!r}\")\n",
    "        \n",
    "    result = torch.stack([torch.einsum(einsum_expr, j1, j2) for j1, j2 in zip(jac1, jac2)])\n",
    "    result = result.sum(0)\n",
    "    return result"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "42d974f3-1f9d-4953-8677-5ee22cfc67eb",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([20, 5])\n"
     ]
    }
   ],
   "source": [
    "result = empirical_ntk_jacobian_contraction(fnet_single, params, x_train, x_test, 'trace')\n",
    "print(result.shape)"
   ]
  },
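  {
   "cell_type": "markdown",
   "id": "3f1c2a7e-8b4d-4c6a-9e21-5d7b0a1c9f07",
   "metadata": {},
   "source": [
    "The `'trace'` and `'diagonal'` expressions never form the cross-output terms ($a$ different from $b$), which cuts the contraction cost by a factor of $O$; the Jacobians themselves still have to be computed in full. A quick NumPy check with random stand-in Jacobians (arbitrary sizes, an illustrative assumption) confirms that they agree with taking the trace or diagonal over the two output axes of the full `'NMab'` kernel:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3f1c2a7e-8b4d-4c6a-9e21-5d7b0a1c9f08",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "rng = np.random.default_rng(1)\n",
    "N, M, O, P = 4, 3, 5, 6\n",
    "J1 = rng.standard_normal((N, O, P))  # stand-in for J(x1), parameters flattened\n",
    "J2 = rng.standard_normal((M, O, P))  # stand-in for J(x2)\n",
    "\n",
    "full = np.einsum('Naf,Mbf->NMab', J1, J2)  # shape (N, M, O, O)\n",
    "trace = np.einsum('Naf,Maf->NM', J1, J2)   # shape (N, M)\n",
    "diag = np.einsum('Naf,Maf->NMa', J1, J2)   # shape (N, M, O)\n",
    "\n",
    "# Trace and diagonal over the last two (output) axes of the full kernel\n",
    "assert np.allclose(trace, np.trace(full, axis1=2, axis2=3))\n",
    "assert np.allclose(diag, np.diagonal(full, axis1=2, axis2=3))\n",
    "print(trace.shape, diag.shape)  # (4, 3) (4, 3, 5)"
   ]
  },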
  {
   "cell_type": "markdown",
   "id": "6c941e5d-51d7-47b2-80ee-edcd4aee6aaa",
   "metadata": {},
   "source": [
    "The asymptotic time complexity of this method is $N O [FP]$ (time to compute the Jacobians) $+ N^2 O^2 P$ (time to contract the Jacobians), where $N$ is the batch size of $x_1$ and $x_2$, $O$ is the model's output size, $P$ is the total number of parameters, and $[FP]$ is the cost of a single forward pass through the model. See section 3.2 in [Fast Finite Width Neural Tangent Kernel](https://arxiv.org/abs/2206.08720) for details."
   ]
  },
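  {
   "cell_type": "markdown",
   "id": "3f1c2a7e-8b4d-4c6a-9e21-5d7b0a1c9f09",
   "metadata": {},
   "source": [
    "To make $P$ concrete for the CNN above: with the default `bias=True`, a `Conv2d(c_in, c_out, (k, k))` layer has `c_out * c_in * k * k` weights plus `c_out` biases, and a `Linear(d_in, d_out)` layer has `d_out * d_in` weights plus `d_out` biases. The back-of-the-envelope sketch below uses hypothetical helpers (it is not a benchmark) to count them; the total should match `sum(p.numel() for p in params)`. It then evaluates the contraction term for `x_train` (20 points) against `x_test` (5 points). Because the two batches differ in size, it writes that term as $N M O^2 P$ rather than $N^2 O^2 P$."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3f1c2a7e-8b4d-4c6a-9e21-5d7b0a1c9f10",
   "metadata": {},
   "outputs": [],
   "source": [
    "def conv2d_params(c_in, c_out, k):\n",
    "    # weight (c_out, c_in, k, k) plus bias (c_out,)\n",
    "    return c_out * c_in * k * k + c_out\n",
    "\n",
    "def linear_params(d_in, d_out):\n",
    "    # weight (d_out, d_in) plus bias (d_out,)\n",
    "    return d_out * d_in + d_out\n",
    "\n",
    "P = (conv2d_params(3, 32, 3)      # conv1\n",
    "     + conv2d_params(32, 32, 3)   # conv2\n",
    "     + conv2d_params(32, 32, 3)   # conv3\n",
    "     + linear_params(21632, 10))  # fc\n",
    "\n",
    "N, M, O = 20, 5, 10  # len(x_train), len(x_test), number of outputs\n",
    "contraction_macs = N * M * O * O * P  # multiply-adds in the einsum over all P\n",
    "print(P, contraction_macs)  # 235722 2357220000"
   ]
  },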
  {
   "cell_type": "markdown",
   "id": "6c931e5d-51d7-47b2-80ee-ddcd4aee6aaa",
   "metadata": {},
   "source": [
    "## Compute the NTK: method 2 (NTK-vector products)\n",
    "\n",
    "The next method we will discuss is a way to compute the NTK using NTK-vector products.\n",
    "\n",
    "This method reformulates the NTK as a stack of NTK-vector products applied to columns of an identity matrix $I_O$ of size $O\\times O$ (where $O$ is the output size of the model):\n",
    "\n",
    "$$J_{net}(x_1) J_{net}^T(x_2) = J_{net}(x_1) J_{net}^T(x_2) I_{O} = \\left[J_{net}(x_1) \\left[J_{net}^T(x_2) e_o\\right]\\right]_{o=1}^{O},$$\n",
    "where $e_o\\in \\mathbb{R}^O$ are column vectors of the identity matrix $I_O$.\n",
    "\n",
    "- Let $\\textrm{vjp}_o = J_{net}^T(x_2) e_o$. We can use a vector-Jacobian product to compute this.\n",
    "- Now, consider $J_{net}(x_1) \\textrm{vjp}_o$. This is a Jacobian-vector product!\n",
    "- Finally, we can run the above computation in parallel over all columns $e_o$ of $I_O$ using `vmap`.\n",
    "\n",
    "This suggests that we can use a combination of reverse-mode AD (to compute the vector-Jacobian product) and forward-mode AD (to compute the Jacobian-vector product) to compute the NTK.\n",
    "\n",
    "Let's code that up:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "dc4b49d7-3096-45d5-a7a1-7032309a2613",
   "metadata": {},
   "outputs": [],
   "source": [
    "def empirical_ntk_ntk_vps(func, params, x1, x2, compute='full'):\n",
    "    def get_ntk(x1, x2):\n",
    "        def func_x1(params):\n",
    "            return func(params, x1)\n",
    "\n",
    "        def func_x2(params):\n",
    "            return func(params, x2)\n",
    "\n",
    "        output, vjp_fn = vjp(func_x1, params)\n",
    "\n",
    "        def get_ntk_slice(vec):\n",
    "            # `vec` is a unit vector e_o (a single row of the identity matrix)\n",
    "            # vjp_fn computes J(x1).T @ vec, as a tuple matching the structure of params\n",
    "            vjps = vjp_fn(vec)\n",
    "            # jvp then computes J(x2) @ vjps = J(x2) @ J(x1).T @ e_o\n",
    "            _, jvps = jvp(func_x2, (params,), vjps)\n",
    "            return jvps\n",
    "\n",
    "        # Here's our identity matrix\n",
    "        basis = torch.eye(output.numel(), dtype=output.dtype, device=output.device).view(output.numel(), -1)\n",
    "        # vmap stacks the slices along dim 0, so the slice for e_o becomes row o,\n",
    "        # which equals row o of J(x1) @ J(x2).T\n",
    "        return vmap(get_ntk_slice)(basis)\n",
    "        \n",
    "    # get_ntk(x1, x2) computes the NTK for a single pair of data points x1, x2.\n",
    "    # Since the x1, x2 inputs to empirical_ntk_ntk_vps are batched,\n",
    "    # we actually wish to compute the NTK between every pair of data points\n",
    "    # from {x1} and {x2}. That's what the vmaps here do.\n",
    "    result = vmap(vmap(get_ntk, (None, 0)), (0, None))(x1, x2)\n",
    "    \n",
    "    if compute == 'full':\n",
    "        return result\n",
    "    if compute == 'trace':\n",
    "        return torch.einsum('NMKK->NM', result)\n",
    "    if compute == 'diagonal':\n",
    "        return torch.einsum('NMKK->NMK', result)\n",
    "    raise ValueError(f\"compute must be 'full', 'trace', or 'diagonal', got {compute!r}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "f750544f-9e48-47fe-9f9b-e1b8ae49b245",
   "metadata": {},
   "outputs": [],
   "source": [
    "result_from_jacobian_contraction = empirical_ntk_jacobian_contraction(fnet_single, params, x_test, x_train)\n",
    "result_from_ntk_vps = empirical_ntk_ntk_vps(fnet_single, params, x_test, x_train)\n",
    "assert torch.allclose(result_from_jacobian_contraction, result_from_ntk_vps, atol=1e-5)"
   ]
  },
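  {
   "cell_type": "markdown",
   "id": "3f1c2a7e-8b4d-4c6a-9e21-5d7b0a1c9f11",
   "metadata": {},
   "source": [
    "The identity behind method 2 can also be checked with explicit matrices. In this NumPy sketch, `J1` and `J2` are random stand-ins (an illustrative assumption) for $J_{net}(x_1)$ and $J_{net}(x_2)$ at a single pair of data points: multiplying by a transposed Jacobian plays the role of the vector-Jacobian product, and multiplying by a Jacobian plays the role of the Jacobian-vector product. Stacking $J_{net}(x_1) [J_{net}^T(x_2) e_o]$ as columns reproduces the formula above. `empirical_ntk_ntk_vps` instead takes the VJP at `x1` and the JVP at `x2`, and `vmap` stacks its slices along dimension 0 (as rows); the second check shows that this ordering yields the same matrix."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3f1c2a7e-8b4d-4c6a-9e21-5d7b0a1c9f12",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "rng = np.random.default_rng(2)\n",
    "O, P = 4, 9\n",
    "J1 = rng.standard_normal((O, P))  # stand-in for J_net(x1)\n",
    "J2 = rng.standard_normal((O, P))  # stand-in for J_net(x2)\n",
    "ntk = J1 @ J2.T                   # empirical NTK for one pair (x1, x2)\n",
    "\n",
    "eye = np.eye(O)\n",
    "# Formula above: column o is J1 @ (J2.T @ e_o), i.e. a JVP applied to a VJP\n",
    "columns = np.stack([J1 @ (J2.T @ eye[:, o]) for o in range(O)], axis=1)\n",
    "\n",
    "# Ordering used in empirical_ntk_ntk_vps: VJP at x1, JVP at x2, stacked as rows\n",
    "rows = np.stack([J2 @ (J1.T @ eye[o]) for o in range(O)], axis=0)\n",
    "\n",
    "print(np.allclose(columns, ntk), np.allclose(rows, ntk))  # True True"
   ]
  },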
  {
   "cell_type": "markdown",
   "id": "84253466-971d-4475-999c-fe3de6bd25b5",
   "metadata": {},
   "source": [
    "Our code for `empirical_ntk_ntk_vps` is a near-direct translation of the math above. The only difference is that it takes the vector-Jacobian product at $x_1$ and the Jacobian-vector product at $x_2$; because `vmap` stacks the slices along dimension 0 (as rows rather than columns), the result is still $J_{net}(x_1) J_{net}^T(x_2)$, as the `allclose` check confirms. This showcases the power of function transforms: good luck trying to write an efficient version of the above using stock PyTorch.\n",
    "\n",
    "The asymptotic time complexity of this method is $N^2 O [FP]$, where $N$ is the batch size of $x_1$ and $x_2$, $O$ is the model's output size, and $[FP]$ is the cost of a single forward pass through the model. Hence this method performs more forward passes through the network than method 1, Jacobian contraction ($N^2 O$ instead of $N O$), but avoids the contraction cost altogether (no $N^2 O^2 P$ term, where $P$ is the model's total number of parameters). Therefore, this method is preferable when $O P$ is large relative to $[FP]$, such as fully connected (not convolutional) models with many outputs $O$. Memory-wise, both methods should be comparable. See section 3.3 in [Fast Finite Width Neural Tangent Kernel](https://arxiv.org/abs/2206.08720) for details."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
