
Model Freezing in TorchScript

In this tutorial, we introduce the syntax for model freezing in TorchScript. Freezing is the process of inlining PyTorch module parameters and attribute values into the TorchScript internal representation. Parameter and attribute values are treated as final and cannot be modified in the resulting frozen module.

Basic Syntax

Model freezing can be invoked using the API below:

torch.jit.freeze(mod: ScriptModule, preserved_attrs: Optional[List[str]] = None) -> ScriptModule

Note that the input module can be the result of either scripting or tracing. See the Introduction to TorchScript tutorial: https://tutorials.pytorch.kr/beginner/Intro_to_TorchScript_tutorial.html
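
Freezing a traced module looks the same. A minimal sketch, assuming an arbitrary small module and example input (both illustrative, not part of the example below); note that freezing requires the module to be in eval mode:

import torch

# Hypothetical module, just to illustrate tracing followed by freezing.
mod = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.ReLU()).eval()
traced = torch.jit.trace(mod, torch.rand(1, 4))
frozen = torch.jit.freeze(traced)
print(frozen(torch.rand(1, 4)))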

Next, we demonstrate how freezing works using an example:

import torch, time

class Net(torch.nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = torch.nn.Conv2d(1, 32, 3, 1)
        self.conv2 = torch.nn.Conv2d(32, 64, 3, 1)
        self.dropout1 = torch.nn.Dropout2d(0.25)
        self.dropout2 = torch.nn.Dropout2d(0.5)
        self.fc1 = torch.nn.Linear(9216, 128)
        self.fc2 = torch.nn.Linear(128, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = torch.nn.functional.relu(x)
        x = self.conv2(x)
        x = torch.nn.functional.max_pool2d(x, 2)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = torch.nn.functional.relu(x)
        x = self.dropout2(x)
        x = self.fc2(x)
        output = torch.nn.functional.log_softmax(x, dim=1)
        return output

    @torch.jit.export
    def version(self):
        return 1.0

net = torch.jit.script(Net())
net.eval()  # freezing requires the module to be in eval mode
fnet = torch.jit.freeze(net)

print(net.conv1.weight.size())
print(net.conv1.bias)

try:
    print(fnet.conv1.bias)
    # without exception handling, prints:
    # RuntimeError: __torch__.z.___torch_mangle_3.Net does not have a field
    # with name 'conv1'
except RuntimeError:
    print("field 'conv1' is inlined. It does not exist in 'fnet'")

try:
    fnet.version()
    # without exception handling, prints:
    # RuntimeError: __torch__.z.___torch_mangle_3.Net does not have a field
    # with name 'version'
except RuntimeError:
    print("method 'version' is not deleted in fnet. Only 'forward' is preserved")

fnet2 = torch.jit.freeze(net, ["version"])

print(fnet2.version())
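
# A sketch of how to observe the effect of freezing (using the objects
# defined above): in the frozen module's graph, parameter values such as
# conv1.weight appear as inlined constants instead of attribute reads.
# The exact printout varies across PyTorch versions.
print(fnet.graph)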

B = 1
warmup = 1
n_iter = 1000
input = torch.rand(B, 1, 28, 28)

start = time.time()
for i in range(warmup):
    net(input)
end = time.time()
print("Scripted - Warm up time: {0:7.4f}".format(end - start), flush=True)

start = time.time()
for i in range(warmup):
    fnet(input)
end = time.time()
print("Frozen   - Warm up time: {0:7.4f}".format(end - start), flush=True)

start = time.time()
for i in range(n_iter):
    input = torch.rand(B, 1, 28, 28)
    net(input)
end = time.time()
print("Scripted - Inference time: {0:5.2f}".format(end - start), flush=True)

start = time.time()
for i in range(n_iter):
    input = torch.rand(B, 1, 28, 28)
    fnet2(input)  # fnet2's forward is identical to fnet's
end = time.time()
print("Frozen   - Inference time: {0:5.2f}".format(end - start), flush=True)

On my machine, I measured the following times:

  • Scripted - Warm up time: 0.0107

  • Frozen - Warm up time: 0.0048

  • Scripted - Inference time: 1.35

  • Frozen - Inference time: 1.17

In our example, warm-up time measures the first run. The frozen model is about 50% faster than the scripted model here, and on some more complex models we have observed even larger warm-up speedups. Freezing achieves this speedup because it does ahead of time some of the work that TorchScript would otherwise perform during the first few runs.

Inference time measures execution time after the model has warmed up. Although we observed significant variation across runs, the frozen model is typically about 15% faster than the scripted model. For larger inputs the speedup shrinks, because execution becomes dominated by tensor operations.
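
For more stable measurements, the same comparison can be repeated with time.perf_counter and gradients disabled. A minimal sketch, reusing the net and fnet modules defined above (the helper name and iteration count are illustrative):

import time
import torch

def benchmark(model, n_iter=1000, batch=1):
    # Time n_iter forward passes; gradients are not needed for inference.
    x = torch.rand(batch, 1, 28, 28)
    with torch.no_grad():
        start = time.perf_counter()
        for _ in range(n_iter):
            model(x)
        return time.perf_counter() - start

print("Scripted: {0:5.2f}s".format(benchmark(net)))
print("Frozen:   {0:5.2f}s".format(benchmark(fnet)))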

Conclusion

In this tutorial, we learned about model freezing. Freezing is a useful technique for optimizing models for inference, and it can also significantly reduce TorchScript warm-up time.
