{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# When running this notebook in Google Colab,\n# refer to https://tutorials.pytorch.kr/beginner/colab.\n%matplotlib inline" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "\n# Neural Tangent Kernels\n\nThe neural tangent kernel (NTK) is a kernel that describes\n[how a neural network evolves during training](https://en.wikipedia.org/wiki/Neural_tangent_kernel).\nThere has been a lot of research around it [in recent years](https://arxiv.org/abs/1806.07572).\nThis tutorial, inspired by the implementation of [NTKs in JAX](https://github.com/google/neural-tangents)\n(see [Fast Finite Width Neural Tangent Kernel](https://arxiv.org/abs/2206.08720) for details),\ndemonstrates how to easily compute this quantity using torch.func,\ncomposable function transforms for PyTorch.\n\n

#### Note

This tutorial requires PyTorch 2.0.0 or later.
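
One way to confirm this requirement before running the rest of the notebook is sketched below. The helper name `meets_minimum` is hypothetical (it is not part of PyTorch), and the check assumes that `torch.__version__` begins with a numeric `major.minor` prefix:

```python
def meets_minimum(version, minimum=(2, 0)):
    # Hypothetical helper, not a PyTorch API.
    # Drop any local build tag such as '+cu118', then compare (major, minor).
    release = version.split('+')[0]
    major, minor = (int(part) for part in release.split('.')[:2])
    return (major, minor) >= minimum


try:
    import torch
except ImportError:
    print('PyTorch is not installed')
else:
    # torch.func is only available when this prints True.
    print(torch.__version__, meets_minimum(torch.__version__))
```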

\n\n## Setup\n\nFirst, some setup. Let's define a simple CNN whose NTK we want to compute.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "import torch\nimport torch.nn as nn\nfrom torch.func import functional_call, vmap, vjp, jvp, jacrev\n\n# Use the GPU when one is available; otherwise fall back to the CPU.\ndevice = 'cuda' if torch.cuda.is_available() else 'cpu'\n\nclass CNN(nn.Module):\n    def __init__(self):\n        super().__init__()\n        self.conv1 = nn.Conv2d(3, 32, (3, 3))\n        self.conv2 = nn.Conv2d(32, 32, (3, 3))\n        self.conv3 = nn.Conv2d(32, 32, (3, 3))\n        # Three unpadded 3x3 convolutions shrink a 32x32 input to 26x26,\n        # so the flattened feature size is 32 * 26 * 26 = 21632.\n        self.fc = nn.Linear(21632, 10)\n\n    def forward(self, x):\n        x = self.conv1(x)\n        x = x.relu()\n        x = self.conv2(x)\n        x = x.relu()\n        x = self.conv3(x)\n        x = x.flatten(1)\n        x = self.fc(x)\n        return x" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Next, let's generate some random data:\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "x_train = torch.randn(20, 3, 32, 32, device=device)\nx_test = torch.randn(5, 3, 32, 32, device=device)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Create a function version of the model\n\ntorch.func transforms operate on functions. In particular, to compute the NTK,\nwe will need a function that accepts the parameters of the model and a single\ninput (as opposed to a batch of inputs!) and returns a single output.\n\nWe'll use torch.func.functional_call, which allows us to call an nn.Module\nusing different parameters/buffers, to build a function that takes the\nparameters as an explicit argument.\n\nKeep in mind that the model was originally written to accept a batch of input\ndata points. In our CNN example, there are no inter-batch operations. That\nis, each data point in the batch is independent of other data points. 
With\nthis assumption in mind, we can easily generate a function that evaluates the\nmodel on a single data point:\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "net = CNN().to(device)\n\n# Detaching the parameters because we won't be calling Tensor.backward().\nparams = {k: v.detach() for k, v in net.named_parameters()}\n\ndef fnet_single(params, x):\n    # Add a batch dimension of size 1 for the model, then remove it from the output.\n    return functional_call(net, params, (x.unsqueeze(0),)).squeeze(0)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Compute the NTK: method 1 (Jacobian contraction)\n\nWe're ready to compute the empirical NTK. The empirical NTK for two data\npoints $x_1$ and $x_2$ is defined as the matrix product between the Jacobian\nof the model (with respect to its parameters) evaluated at $x_1$ and the\ntranspose of the Jacobian evaluated at\n$x_2$:\n\n\\begin{align}J_{net}(x_1) J_{net}^T(x_2)\\end{align}\n\nIn the batched case, where $x_1$ and $x_2$ are each a batch of data points,\nwe want the matrix product between the Jacobians\nof all combinations of data points from $x_1$ and $x_2$.\n\nThe first method consists of doing just that: computing the two Jacobians\nand contracting them. 
Here's how to compute the NTK in the batched case:\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "def empirical_ntk_jacobian_contraction(fnet_single, params, x1, x2):\n    # Compute J(x1). jacrev returns one Jacobian per parameter tensor;\n    # after vmap, each has shape [N, O, *param_shape].\n    jac1 = vmap(jacrev(fnet_single), (None, 0))(params, x1)\n    jac1 = jac1.values()\n    # Flatten the parameter dimensions: [N, O, P_k]\n    jac1 = [j.flatten(2) for j in jac1]\n\n    # Compute J(x2)\n    jac2 = vmap(jacrev(fnet_single), (None, 0))(params, x2)\n    jac2 = jac2.values()\n    jac2 = [j.flatten(2) for j in jac2]\n\n    # Compute J(x1) @ J(x2).T per parameter tensor, then sum the contributions\n    result = torch.stack([torch.einsum('Naf,Mbf->NMab', j1, j2) for j1, j2 in zip(jac1, jac2)])\n    result = result.sum(0)\n    return result\n\nresult = empirical_ntk_jacobian_contraction(fnet_single, params, x_train, x_test)\nprint(result.shape)  # torch.Size([20, 5, 10, 10])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In some cases, you may only want the diagonal or the trace of this quantity,\nespecially if you know beforehand that the network architecture results in an\nNTK where the non-diagonal elements can be approximated by zero. 
It's easy to\nadjust the above function to do that:\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "def empirical_ntk_jacobian_contraction(fnet_single, params, x1, x2, compute='full'):\n    # Compute J(x1)\n    jac1 = vmap(jacrev(fnet_single), (None, 0))(params, x1)\n    jac1 = jac1.values()\n    jac1 = [j.flatten(2) for j in jac1]\n\n    # Compute J(x2)\n    jac2 = vmap(jacrev(fnet_single), (None, 0))(params, x2)\n    jac2 = jac2.values()\n    jac2 = [j.flatten(2) for j in jac2]\n\n    # Compute J(x1) @ J(x2).T\n    if compute == 'full':\n        einsum_expr = 'Naf,Mbf->NMab'\n    elif compute == 'trace':\n        einsum_expr = 'Naf,Maf->NM'\n    elif compute == 'diagonal':\n        einsum_expr = 'Naf,Maf->NMa'\n    else:\n        raise ValueError(f'Unknown compute mode: {compute}')\n\n    result = torch.stack([torch.einsum(einsum_expr, j1, j2) for j1, j2 in zip(jac1, jac2)])\n    result = result.sum(0)\n    return result\n\nresult = empirical_ntk_jacobian_contraction(fnet_single, params, x_train, x_test, 'trace')\nprint(result.shape)  # torch.Size([20, 5])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The asymptotic time complexity of this method is $N O [FP]$ (time to\ncompute the Jacobians) + $N^2 O^2 P$ (time to contract the Jacobians),\nwhere $N$ is the batch size of $x_1$ and $x_2$, $O$\nis the model's output size, $P$ is the total number of parameters, and\n$[FP]$ is the cost of a single forward pass through the model. 
See\nsection 3.2 in\n[Fast Finite Width Neural Tangent Kernel](https://arxiv.org/abs/2206.08720)\nfor details.\n\n## Compute the NTK: method 2 (NTK-vector products)\n\nThe next method we will discuss is a way to compute the NTK using NTK-vector\nproducts.\n\nThis method reformulates the NTK as a stack of NTK-vector products applied to\ncolumns of an identity matrix $I_O$ of size $O\\times O$\n(where $O$ is the output size of the model):\n\n\\begin{align}J_{net}(x_1) J_{net}^T(x_2) = J_{net}(x_1) J_{net}^T(x_2) I_{O} = \\left[J_{net}(x_1) \\left[J_{net}^T(x_2) e_o\\right]\\right]_{o=1}^{O},\\end{align}\n\nwhere $e_o\\in \\mathbb{R}^O$ are column vectors of the identity matrix\n$I_O$.\n\n- Let $\\textrm{vjp}_o = J_{net}^T(x_2) e_o$. We can use\n  a vector-Jacobian product to compute this.\n- Now, consider $J_{net}(x_1) \\textrm{vjp}_o$. This is a\n  Jacobian-vector product!\n- Finally, we can run the above computation in parallel over all\n  columns $e_o$ of $I_O$ using vmap.\n\nThis suggests that we can use a combination of reverse-mode AD (to compute\nthe vector-Jacobian product) and forward-mode AD (to compute the\nJacobian-vector product) to compute the NTK.\n\nLet's code that up:\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "def empirical_ntk_ntk_vps(func, params, x1, x2, compute='full'):\n    def get_ntk(x1, x2):\n        def func_x1(params):\n            return func(params, x1)\n\n        def func_x2(params):\n            return func(params, x2)\n\n        output, vjp_fn = vjp(func_x1, params)\n\n        def get_ntk_slice(vec):\n            # This computes vec @ J(x1), i.e. J(x1).T @ vec\n            # vec is some unit vector (a single slice of the Identity matrix)\n            vjps = vjp_fn(vec)\n            # This computes J(x2) @ vjps, which is the row of\n            # J(x1) @ J(x2).T selected by vec\n            _, jvps = jvp(func_x2, (params,), vjps)\n            return jvps\n\n        # Here's our identity matrix\n        basis = torch.eye(output.numel(), dtype=output.dtype, device=output.device).view(output.numel(), -1)\n        return vmap(get_ntk_slice)(basis)\n\n    # get_ntk(x1, x2) computes the NTK for a single pair of data points x1, x2.\n    # Since the x1, x2 inputs to empirical_ntk_ntk_vps are batched,\n    # we actually wish to compute the NTK between every pair of data points\n    # between {x1} and {x2}. That's what the vmaps here do.\n    result = vmap(vmap(get_ntk, (None, 0)), (0, None))(x1, x2)\n\n    if compute == 'full':\n        return result\n    if compute == 'trace':\n        return torch.einsum('NMKK->NM', result)\n    if compute == 'diagonal':\n        return torch.einsum('NMKK->NMK', result)\n    raise ValueError(f'Unknown compute mode: {compute}')\n\nresult_from_jacobian_contraction = empirical_ntk_jacobian_contraction(fnet_single, params, x_test, x_train)\nresult_from_ntk_vps = empirical_ntk_ntk_vps(fnet_single, params, x_test, x_train)\nassert torch.allclose(result_from_jacobian_contraction, result_from_ntk_vps, atol=1e-5)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Our code for empirical_ntk_ntk_vps looks like a direct translation from\nthe math above! This showcases the power of function transforms: good luck\ntrying to write an efficient version of the above using only\ntorch.autograd.grad.\n\nThe asymptotic time complexity of this method is $N^2 O [FP]$, where\n$N$ is the batch size of $x_1$ and $x_2$, $O$ is the\nmodel's output size, and $[FP]$ is the cost of a single forward pass\nthrough the model. Hence this method performs more forward passes through the\nnetwork than method 1, Jacobian contraction ($N^2 O$ instead of\n$N O$), but avoids the contraction cost altogether (no $N^2 O^2 P$\nterm, where $P$ is the total number of the model's parameters). Therefore,\nthis method is preferable when $O P$ is large relative to $[FP]$,\nsuch as for fully-connected (not convolutional) models with many outputs $O$.\nMemory-wise, both methods should be comparable. 
See section 3.3 in\n[Fast Finite Width Neural Tangent Kernel](https://arxiv.org/abs/2206.08720)\nfor details.\n\n" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.16" } }, "nbformat": 4, "nbformat_minor": 0 }