Fuse Modules Recipe

This recipe demonstrates how to fuse a list of PyTorch modules into a single module and how to run a performance test comparing the fused model with its non-fused version.

Introduction

Before quantization is applied to a model to reduce its size and memory footprint (see Quantization Recipe for details on quantization), a sequence of modules in the model can first be fused into a single module. Fusion is optional, but it may reduce memory accesses, make the model run faster, and improve its accuracy.
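To make the memory-access argument concrete, here is a minimal sketch of the arithmetic behind fusing a convolution with the batch norm that follows it (the conv, bn pattern). It uses plain Python floats for a single output channel at a single position, and it assumes eval-mode batch norm, where mean and var are fixed running statistics. The helper names are hypothetical; PyTorch's own fusion code performs the equivalent computation on whole weight tensors.

```python
import math


def fold_bn_into_conv(w, b, gamma, beta, mean, var, eps=1e-5):
    """Fold one output channel of an eval-mode BatchNorm into the preceding conv."""
    # BatchNorm computes gamma * (y - mean) / sqrt(var + eps) + beta on y = w.x + b.
    # That is affine in x, so both steps collapse into a single w'.x + b'.
    scale = gamma / math.sqrt(var + eps)
    w_folded = [wi * scale for wi in w]
    b_folded = (b - mean) * scale + beta
    return w_folded, b_folded


def conv_then_bn(x, w, b, gamma, beta, mean, var, eps=1e-5):
    # Unfused: one conv output value (a dot product), then a separate BN pass.
    y = sum(wi * xi for wi, xi in zip(w, x)) + b
    return gamma * (y - mean) / math.sqrt(var + eps) + beta


def fused_conv(x, w_folded, b_folded):
    # Fused: a single dot product with the folded parameters.
    return sum(wi * xi for wi, xi in zip(w_folded, x)) + b_folded


# Illustrative numbers: three inputs seen by one output channel at one position.
x = [0.5, -1.0, 2.0]
w, b = [0.2, 0.4, -0.1], 0.0  # the recipe's conv uses bias=False, so b is 0
gamma, beta, mean, var = 1.5, 0.1, 0.3, 0.8

w_f, b_f = fold_bn_into_conv(w, b, gamma, beta, mean, var)
print(conv_then_bn(x, w, b, gamma, beta, mean, var))
print(fused_conv(x, w_f, b_f))
```

Both calls print the same value (up to floating-point rounding), but the fused version reads the data once instead of writing an intermediate result and reading it back for the batch norm.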

Pre-requisites

PyTorch 1.6.0 or 1.7.0

Steps

Follow the steps below to fuse an example model, quantize it, script it, optimize it for mobile, save it and test it with the Android benchmark tool.

1. Define the Example Model

Use the same example model defined in the PyTorch Mobile Performance Recipes:

import torch
from torch.utils.mobile_optimizer import optimize_for_mobile

class AnnotatedConvBnReLUModel(torch.nn.Module):
    def __init__(self):
        super(AnnotatedConvBnReLUModel, self).__init__()
        self.conv = torch.nn.Conv2d(3, 5, 3, bias=False).to(dtype=torch.float)
        self.bn = torch.nn.BatchNorm2d(5).to(dtype=torch.float)
        self.relu = torch.nn.ReLU(inplace=True)
        self.quant = torch.quantization.QuantStub()
        self.dequant = torch.quantization.DeQuantStub()

    def forward(self, x):
        x = x.contiguous(memory_format=torch.channels_last)
        x = self.quant(x)
        x = self.conv(x)
        x = self.bn(x)
        x = self.relu(x)
        x = self.dequant(x)
        return x

2. Generate Two Models with and without fuse_modules

Add the following code below the model definition above and run the script:

model = AnnotatedConvBnReLUModel()
print(model)

def prepare_save(model, fused):
    model.qconfig = torch.quantization.get_default_qconfig('qnnpack')
    torch.quantization.prepare(model, inplace=True)
    torch.quantization.convert(model, inplace=True)
    torchscript_model = torch.jit.script(model)
    torchscript_model_optimized = optimize_for_mobile(torchscript_model)
    torch.jit.save(torchscript_model_optimized, "model.pt" if not fused else "model_fused.pt")

prepare_save(model, False)

model = AnnotatedConvBnReLUModel()
model_fused = torch.quantization.fuse_modules(model, [['bn', 'relu']], inplace=False)
print(model_fused)

prepare_save(model_fused, True)

The module structures of the original model and its fused version are printed as follows:

AnnotatedConvBnReLUModel(
  (conv): Conv2d(3, 5, kernel_size=(3, 3), stride=(1, 1), bias=False)
  (bn): BatchNorm2d(5, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (quant): QuantStub()
  (dequant): DeQuantStub()
)

AnnotatedConvBnReLUModel(
  (conv): Conv2d(3, 5, kernel_size=(3, 3), stride=(1, 1), bias=False)
  (bn): BNReLU2d(
    (0): BatchNorm2d(5, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (1): ReLU(inplace=True)
  )
  (relu): Identity()
  (quant): QuantStub()
  (dequant): DeQuantStub()
)

In the fused model's output, the first item in the list, bn, is replaced with the fused module, and the remaining modules (relu in this example) are replaced with Identity. In addition, the non-fused and fused versions of the model, model.pt and model_fused.pt, are saved.
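The replacement rule can be modeled with a plain dictionary of named children. This is a hypothetical illustration of the rewrite, not PyTorch's implementation: modules are represented as strings, and fuse_by_name and the fuser callable are names invented for this sketch.

```python
def fuse_by_name(modules, names, fuser):
    """Hypothetical model of how fuse_modules rewrites a module's named children."""
    fused = dict(modules)  # work on a copy, like inplace=False
    fused[names[0]] = fuser([modules[n] for n in names])  # first name: fused module
    for n in names[1:]:
        fused[n] = "Identity()"  # remaining names: pass-through placeholders
    return fused


# Children of the recipe's model, represented as plain strings.
model = {
    "conv": "Conv2d",
    "bn": "BatchNorm2d",
    "relu": "ReLU",
    "quant": "QuantStub",
    "dequant": "DeQuantStub",
}
fused = fuse_by_name(model, ["bn", "relu"],
                     lambda mods: "BNReLU2d(" + ", ".join(mods) + ")")
for name, module in fused.items():
    print(f"({name}): {module}")
```

Keeping relu as an Identity rather than deleting it means the unchanged forward method, which still calls self.relu(x), keeps working after fusion.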

3. Build the Android Benchmark Tool

Get the PyTorch source and build the Android benchmark tool as follows:

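git clone --recursive https://github.com/pytorch/pytorch
cd pytorch
git submodule update --init --recursive
BUILD_PYTORCH_MOBILE=1 ANDROID_ABI=arm64-v8a ./scripts/build_android.sh -DBUILD_BINARY=ON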

This will generate the Android benchmark binary speed_benchmark_torch in the build_android/bin folder.

4. Test and Compare the Fused and Non-Fused Models

Connect your Android device, then copy speed_benchmark_torch and the model files and run the benchmark tool on them:

adb push build_android/bin/speed_benchmark_torch /data/local/tmp
adb push model.pt /data/local/tmp
adb push model_fused.pt /data/local/tmp
adb shell "/data/local/tmp/speed_benchmark_torch --model=/data/local/tmp/model.pt --input_dims=1,3,224,224 --input_type=float"
adb shell "/data/local/tmp/speed_benchmark_torch --model=/data/local/tmp/model_fused.pt --input_dims=1,3,224,224 --input_type=float"
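If you rerun the comparison often, the same commands can be generated from Python. This is a sketch under stated assumptions: it only builds the argument lists and does not contact a device (each list could be passed to subprocess.run(cmd, check=True) with a device connected), benchmark_commands is a hypothetical helper, and the file names model.pt and model_fused.pt are assumed to match what prepare_save wrote.

```python
import shlex

DEVICE_DIR = "/data/local/tmp"
# Pushed once; the per-model commands below assume it is already on the device.
PUSH_BINARY = ["adb", "push", "build_android/bin/speed_benchmark_torch", DEVICE_DIR]


def benchmark_commands(model_file, input_dims="1,3,224,224", input_type="float"):
    """Build (but do not run) the adb argument lists for one model file."""
    remote_model = f"{DEVICE_DIR}/{model_file}"
    # adb shell runs this whole string as one command line on the device.
    bench = (f"{DEVICE_DIR}/speed_benchmark_torch --model={remote_model} "
             f"--input_dims={input_dims} --input_type={input_type}")
    return [
        ["adb", "push", model_file, DEVICE_DIR],
        ["adb", "shell", bench],
    ]


print(shlex.join(PUSH_BINARY))
for model_file in ("model.pt", "model_fused.pt"):
    for cmd in benchmark_commands(model_file):
        print(shlex.join(cmd))  # shell-quoted form, for copy-pasting
```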

The output of the last two commands should look similar to:

Main run finished. Microseconds per iter: 6189.07. Iters per second: 161.575


Main run finished. Microseconds per iter: 6216.65. Iters per second: 160.858
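The two figures on each line are reciprocal views of one measurement: iters per second = 1,000,000 / microseconds per iter. The sketch below parses lines in exactly the format shown above and computes the relative latency difference between the two runs. The regular expression assumes this output format, and parse_benchmark is a hypothetical helper name.

```python
import re

LINE_RE = re.compile(
    r"Microseconds per iter: ([\d.]+)\. Iters per second: ([\d.]+)")


def parse_benchmark(line):
    """Return (microseconds_per_iter, iters_per_second) from one result line."""
    match = LINE_RE.search(line)
    if match is None:
        raise ValueError(f"not a benchmark result line: {line!r}")
    return float(match.group(1)), float(match.group(2))


first = parse_benchmark(
    "Main run finished. Microseconds per iter: 6189.07. Iters per second: 161.575")
second = parse_benchmark(
    "Main run finished. Microseconds per iter: 6216.65. Iters per second: 160.858")

# Cross-check: the reported throughput should equal 1e6 / latency.
for us, ips in (first, second):
    print(f"{us} us/iter -> {1e6 / us:.3f} iters/s (reported {ips})")

# Relative difference in latency between the two runs.
relative_diff = abs(second[0] - first[0]) / first[0]
print(f"difference: {relative_diff:.2%}")
```

For the sample numbers the two runs are within about half a percent of each other, which is small enough that run-to-run noise on the device could account for it.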

For this example model, there is not much performance difference between the fused and non-fused models, but similar steps can be used to fuse and prepare a real deep model and measure the performance improvement. Keep in mind that torch.quantization.fuse_modules currently fuses only the following sequences of modules:

  • conv, bn

  • conv, bn, relu

  • conv, relu

  • linear, relu

  • bn, relu

If any other sequence list is provided to the fuse_modules call, it will simply be ignored.
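When preparing a larger model, it can help to find the fusable runs before calling fuse_modules. The following is a hypothetical sketch that assumes the model's layers form a flat, ordered sequence described as (name, kind) pairs, with kinds written as plain strings; the pattern table mirrors the list above, and newer PyTorch releases may support additional patterns.

```python
# Patterns as listed in this recipe (PyTorch 1.6/1.7); treat as an assumption
# for other versions.
SUPPORTED_FUSIONS = {
    ("conv", "bn"),
    ("conv", "bn", "relu"),
    ("conv", "relu"),
    ("linear", "relu"),
    ("bn", "relu"),
}


def is_supported_fusion(kinds):
    """Return True if the sequence of module kinds matches a listed pattern."""
    return tuple(kinds) in SUPPORTED_FUSIONS


def plan_fusions(layers):
    """Greedily group adjacent (name, kind) pairs into fusable name lists.

    Tries the longest pattern first at each position, so conv, bn, relu becomes
    one triple instead of conv, bn followed by a lone relu. The result has the
    list-of-name-lists shape that fuse_modules takes.
    """
    groups, i = [], 0
    while i < len(layers):
        for length in (3, 2):
            window = layers[i:i + length]
            if len(window) == length and is_supported_fusion(k for _, k in window):
                groups.append([name for name, _ in window])
                i += length
                break
        else:  # no pattern starts here; move on
            i += 1
    return groups


layers = [("conv1", "conv"), ("bn1", "bn"), ("relu1", "relu"),
          ("fc", "linear"), ("relu2", "relu"), ("pool", "maxpool")]
print(plan_fusions(layers))
```

Real models do not always execute their children in declaration order, so a plan produced this way still needs to be checked against the model's forward method.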

Learn More

See the official documentation of torch.quantization.fuse_modules for details.

© Copyright 2018-2024, PyTorch & PyTorch Korea User Group.