{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# When running this notebook in Google Colab,\n# see https://tutorials.pytorch.kr/beginner/colab.\n%matplotlib inline" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "\n# (beta) Building a Convolution/Batch Norm Fuser in FX\n\n**Author**: [Horace He](https://github.com/chillee)\n\n**Translator:** [\uc624\ucc2c\ud76c](https://github.com/kozeldark)\n\nIn this tutorial, we use FX, a toolkit for composable function transformations in PyTorch, to do the following:\n\n1) Find convolution/batch norm patterns in the data dependencies.\n2) For each pattern found in step 1, fold the batch norm statistics into the convolution weights.\n\nNote that this optimization only applies to models in inference mode (i.e. `model.eval()`).\n\nWe will build the fuser found at the following link:\nhttps://github.com/pytorch/pytorch/blob/orig/release/1.8/torch/fx/experimental/fuser.py\n" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "First, let's get a few imports out of the way (the code below relies on them).\n\n" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "from typing import Type, Dict, Any, Tuple, Iterable\nimport copy\nimport torch.fx as fx\nimport torch\nimport torch.nn as nn" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "For this tutorial, we create a model consisting of convolutions and batch norms.\nThis model has some tricky components:\nsome of the convolution/batch norm patterns are hidden inside a `Sequential`,\nand one of the batch norms is wrapped in another module.\n\n" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "class WrappedBatchNorm(nn.Module):\n    def __init__(self):\n        super().__init__()\n        self.mod = nn.BatchNorm2d(1)\n    def forward(self, x):\n        return self.mod(x)\n\nclass M(nn.Module):\n    def __init__(self):\n        super().__init__()\n        self.conv1 = nn.Conv2d(1, 1, 1)\n        self.bn1 = nn.BatchNorm2d(1)\n        self.conv2 = nn.Conv2d(1, 1, 1)\n        self.nested = nn.Sequential(\n            nn.BatchNorm2d(1),\n            nn.Conv2d(1, 1, 1),\n        )\n        self.wrapped = WrappedBatchNorm()\n\n    def forward(self, x):\n        x = self.conv1(x)\n        x = self.bn1(x)\n        x = self.conv2(x)\n        x = self.nested(x)\n        x = self.wrapped(x)\n        return x\n\nmodel = M()\n\nmodel.eval()" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## Fusing Convolution with Batch Norm\nOne of the primary challenges in automatically fusing convolution and batch norm in PyTorch is that\nPyTorch does not provide an easy way to access the computational graph.\nFX solves this problem by symbolically tracing the actual operations called, so we can follow the computation\nthrough `forward` calls that are nested inside `Sequential` modules or wrapped in user-defined modules.\n\n" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "traced_model = torch.fx.symbolic_trace(model)\nprint(traced_model.graph)" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "This gives us a graph representation of our model.\nNote that both the module hidden inside the `Sequential` and the wrapped module have been inlined into the graph.\nThis is the default level of abstraction, but the pass writer can configure it.\nMore information is available in the FX overview:\nhttps://pytorch.org/docs/master/fx.html#module-torch.fx\n\n" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## Fusing Convolution with Batch Norm\nUnlike some other fusions, fusing convolution with batch norm does not require any new operators.\nInstead, because batch norm during inference consists of a pointwise add and multiply,\nthese operations can be \"baked\" into the weights of the preceding convolution.\nThis lets us remove the batch norm from the model entirely!\nFor further details, see:\nhttps://nenadmarkus.com/p/fusing-batchnorm-and-conv/\nFor clarity, the code here is copied from:\nhttps://github.com/pytorch/pytorch/blob/orig/release/1.8/torch/nn/utils/fusion.py\n\n" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "def fuse_conv_bn_eval(conv, bn):\n    \"\"\"\n    Given a convolution module 'A' and a batch norm module 'B', returns a convolution\n    module 'C' such that C(x) == B(A(x)) in inference mode.\n    \"\"\"\n    assert(not (conv.training or bn.training)), \"Fusion only for eval!\"\n    fused_conv = copy.deepcopy(conv)\n\n    fused_conv.weight, fused_conv.bias = \\\n        fuse_conv_bn_weights(fused_conv.weight, fused_conv.bias,\n                             bn.running_mean, bn.running_var, bn.eps, bn.weight, bn.bias)\n\n    return fused_conv\n\ndef fuse_conv_bn_weights(conv_w, conv_b, bn_rm, bn_rv, bn_eps, bn_w, bn_b):\n    if conv_b is None:\n        conv_b = torch.zeros_like(bn_rm)\n    if bn_w is None:\n        bn_w = torch.ones_like(bn_rm)\n    if bn_b is None:\n        bn_b = 
torch.zeros_like(bn_rm)\n    bn_var_rsqrt = torch.rsqrt(bn_rv + bn_eps)\n\n    conv_w = conv_w * (bn_w * bn_var_rsqrt).reshape([-1] + [1] * (len(conv_w.shape) - 1))\n    conv_b = (conv_b - bn_rm) * bn_var_rsqrt * bn_w + bn_b\n\n    return torch.nn.Parameter(conv_w), torch.nn.Parameter(conv_b)" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## FX Fusion Pass\nNow that we have both the computational graph and a method for fusing convolution and batch norm,\nall that remains is to iterate over the FX graph and apply the desired fusions.\n\n" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "def _parent_name(target : str) -> Tuple[str, str]:\n    \"\"\"\n    Splits a qualified name ( ``qualname`` ) into its parent path and last atom.\n    For example, `foo.bar.baz` -> (`foo.bar`, `baz`)\n    \"\"\"\n    *parent, name = target.rsplit('.', 1)\n    return parent[0] if parent else '', name\n\ndef replace_node_module(node: fx.Node, modules: Dict[str, Any], new_module: torch.nn.Module):\n    assert(isinstance(node.target, str))\n    parent_name, name = _parent_name(node.target)\n    setattr(modules[parent_name], name, new_module)\n\n\ndef fuse(model: torch.nn.Module) -> torch.nn.Module:\n    model = copy.deepcopy(model)\n    # The first step of most FX passes is to symbolically trace the model\n    # to obtain a `GraphModule`.\n    # This is a representation of the original model that is functionally identical to it,\n    # except that we now also have a graph representation of the forward pass.\n    fx_model: fx.GraphModule = fx.symbolic_trace(model)\n    modules = dict(fx_model.named_modules())\n\n    # The primary representations for working with FX are the `Graph` and the `Node`.\n    # Each `GraphModule` has an associated `Graph`.\n    # This `Graph` is also what generates `GraphModule.code`.\n    # The `Graph` itself is represented as a list of `Node` objects.\n    # Thus, to iterate over all operations in the graph, we iterate over each `Node` in the `Graph`.\n    for node in fx_model.graph.nodes:\n        # The FX IR contains several types of nodes, which generally represent\n        # call sites to modules, functions, or methods.\n        # The type of a node is determined by `Node.op`.\n        if node.op != 'call_module': # If the current node does not call a module, we can ignore it.\n            continue\n        # Guard: only look up the input when it is itself a module call;\n        # otherwise `modules[node.args[0].target]` below would raise a KeyError.\n        if not node.args or not isinstance(node.args[0], fx.Node) or node.args[0].op != 'call_module':\n            continue\n        # For call sites, `Node.target` represents the module/function/method being called.\n        # Here, we check `Node.target` to see if it is a batch norm module, and then\n        # check `Node.args[0].target` to see if the input `Node` is a convolution.\n        if type(modules[node.target]) is nn.BatchNorm2d and type(modules[node.args[0].target]) is nn.Conv2d:\n            if len(node.args[0].users) > 1: # The convolution output is used by other nodes.\n                continue\n            conv = modules[node.args[0].target]\n            bn = modules[node.target]\n            fused_conv = fuse_conv_bn_eval(conv, bn)\n            replace_node_module(node.args[0], modules, fused_conv)\n            # Since we have folded the batch norm into the convolution,\n            # we need to replace all uses of the batch norm with the convolution.\n            node.replace_all_uses_with(node.args[0])\n            # Now that all uses of the batch norm have been replaced,\n            # we can safely remove the batch norm.\n            fx_model.graph.erase_node(node)\n    fx_model.graph.lint()\n    # After modifying the graph, we need to recompile it to keep the generated code in sync.\n    fx_model.recompile()\n    return fx_model" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "
We made a few simplifications here for demonstration purposes, such as matching only 2D convolutions.\nFor a more usable pass, see the following link:\nhttps://github.com/pytorch/pytorch/blob/master/torch/fx/experimental/fuser.py