참고
Click here to download the full example code
배포를 위한 비전 트랜스포머(Vision Transformer) 모델 최적화하기¶
Authors : Jeff Tang, Geeta Chauhan 번역 : 김태영
비전 트랜스포머(Vision Transformer)는 자연어 처리 분야에서 소개된 최고 수준의 결과를 달성한 최신의 어텐션 기반(attention-based) 트랜스포머 모델을 컴퓨터 비전 분야에 적용을 한 모델입니다. FaceBook에서 발표한 Data-efficient Image Transformers는 DeiT 이미지 분류를 위해 ImageNet 데이터셋을 통해 훈련된 비전 트랜스포머 모델입니다.
이번 튜토리얼에서는, DeiT가 무엇인지 그리고 어떻게 사용하는지 다룰 것입니다. 그 다음 스크립팅, 양자화, 최적화, 그리고 iOS와 안드로이드 앱 안에서 모델을 사용하는 전체적인 단계를 수행해 볼 것입니다. 또한, 양자화와 최적화가 된 모델과 양자화와 최적화가 되지 않은 모델을 비교해 볼 것이며, 단계를 수행해 가면서 양자화와 최적화를 적용한 모델이 얼마나 이점을 가지는지 볼 것입니다.
DeiT란 무엇인가¶
합성곱 신경망(CNNs)은 2012년 딥러닝이 시작된 이후 이미지 분류를 수행할 때 주요한 모델이였습니다. 그러나 합성곱 신경망은 일반적으로 최첨단의 결과를 달성하기 위해 훈련에 수억 개의 이미지가 필요했습니다. DeiT는 훈련에 더 적은 데이터와 컴퓨팅 자원을 필요로 하는 비전 트랜스포머 모델이며, 최신 CNN 모델과 이미지 분류를 수행하는데 경쟁을 합니다. 이는 DeiT의 두 가지 주요 구성 요소에 의해 가능하게 되었습니다.
훨씬 더 큰 데이터 세트에 대한 훈련을 시뮬레이션하는 데이터 증강(augmentation)
트랜스포머 네트워크에 CNN의 출력값을 그대로 증류(distillation)하여 학습할 수 있도록 하는 기법
DeiT는 제한된 데이터와 자원을 활용하여 컴퓨터 비전 태스크(task)에 트랜스포머 모델을 성공적으로 적용할 수 있음을 보여줍니다. DeiT의 좀 더 자세한 내용을 원한다면, 저장소 와 논문 을 참고하시길 바랍니다.
DeiT를 활용한 이미지 분류¶
DeiT를 사용하여 이미지를 분류하는 방법에 대한 자세한 정보는 DeiT 저장소에 README를 참고하시길 바랍니다. 빠른 테스트를 위해서, 먼저 필요한 패키지들을 설치합니다:
pip install torch torchvision timm pandas requests
Google Colab에서는 아래와 같이 실행합니다:
# !pip install timm pandas requests
그런 다음 아래 스크립트를 실행합니다:
from PIL import Image
import torch
import timm
import requests
import torchvision.transforms as transforms
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
print(torch.__version__)
# Pytorch 버전은 1.8.0 이어야 합니다.
model = torch.hub.load('facebookresearch/deit:main', 'deit_base_patch16_224', pretrained=True)
model.eval()
transform = transforms.Compose([
transforms.Resize(256, interpolation=3),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD),
])
img = Image.open(requests.get("https://raw.githubusercontent.com/pytorch/ios-demo-app/master/HelloWorld/HelloWorld/HelloWorld/image.png", stream=True).raw)
img = transform(img)[None,]
out = model(img)
clsidx = torch.argmax(out)
print(clsidx.item())
1.13.1+cu117
Downloading: "https://github.com/facebookresearch/deit/zipball/main" to /root/.cache/torch/hub/main.zip
Downloading: "https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth" to /root/.cache/torch/hub/checkpoints/deit_base_patch16_224-b5f2ef4d.pth
0%| | 0.00/330M [00:00<?, ?B/s]
0%| | 24.0k/330M [00:00<30:31, 189kB/s]
0%| | 56.0k/330M [00:00<25:31, 226kB/s]
0%| | 120k/330M [00:00<16:25, 351kB/s]
0%| | 256k/330M [00:00<09:03, 636kB/s]
0%| | 544k/330M [00:00<04:43, 1.22MB/s]
0%| | 1.09M/330M [00:00<02:28, 2.33MB/s]
1%| | 2.20M/330M [00:00<01:16, 4.50MB/s]
1%|1 | 4.42M/330M [00:01<00:39, 8.74MB/s]
3%|2 | 8.27M/330M [00:01<00:21, 15.4MB/s]
4%|3 | 12.0M/330M [00:01<00:15, 21.4MB/s]
4%|4 | 14.1M/330M [00:01<00:15, 21.3MB/s]
5%|5 | 17.4M/330M [00:01<00:14, 23.4MB/s]
6%|6 | 21.4M/330M [00:01<00:11, 28.4MB/s]
7%|7 | 24.2M/330M [00:01<00:11, 27.4MB/s]
8%|8 | 27.5M/330M [00:01<00:10, 29.1MB/s]
9%|9 | 30.3M/330M [00:01<00:10, 28.9MB/s]
10%|# | 33.5M/330M [00:02<00:10, 30.1MB/s]
11%|#1 | 36.4M/330M [00:02<00:10, 29.7MB/s]
12%|#1 | 39.3M/330M [00:02<00:10, 28.1MB/s]
13%|#3 | 43.0M/330M [00:02<00:10, 28.7MB/s]
14%|#4 | 46.8M/330M [00:02<00:10, 29.1MB/s]
15%|#5 | 50.7M/330M [00:02<00:09, 29.5MB/s]
17%|#6 | 54.6M/330M [00:02<00:09, 29.6MB/s]
18%|#7 | 58.5M/330M [00:02<00:09, 29.8MB/s]
19%|#8 | 62.3M/330M [00:03<00:09, 30.1MB/s]
20%|#9 | 65.9M/330M [00:03<00:09, 29.3MB/s]
21%|##1 | 69.8M/330M [00:03<00:09, 29.8MB/s]
22%|##2 | 73.7M/330M [00:03<00:08, 30.0MB/s]
23%|##3 | 77.6M/330M [00:03<00:08, 29.9MB/s]
25%|##4 | 81.5M/330M [00:03<00:08, 30.2MB/s]
26%|##5 | 85.4M/330M [00:03<00:08, 30.4MB/s]
27%|##7 | 89.3M/330M [00:04<00:08, 30.5MB/s]
28%|##8 | 93.2M/330M [00:04<00:08, 30.5MB/s]
29%|##9 | 96.8M/330M [00:04<00:08, 29.9MB/s]
30%|### | 101M/330M [00:04<00:08, 30.0MB/s]
32%|###1 | 105M/330M [00:04<00:07, 30.1MB/s]
33%|###2 | 108M/330M [00:04<00:07, 30.2MB/s]
34%|###4 | 112M/330M [00:04<00:07, 30.2MB/s]
35%|###5 | 116M/330M [00:04<00:07, 30.2MB/s]
36%|###6 | 120M/330M [00:05<00:07, 30.0MB/s]
38%|###7 | 124M/330M [00:05<00:07, 30.3MB/s]
39%|###8 | 128M/330M [00:05<00:07, 30.2MB/s]
40%|###9 | 132M/330M [00:05<00:06, 29.9MB/s]
41%|####1 | 136M/330M [00:05<00:06, 30.6MB/s]
42%|####2 | 139M/330M [00:05<00:06, 30.5MB/s]
43%|####3 | 143M/330M [00:05<00:06, 30.3MB/s]
45%|####4 | 147M/330M [00:06<00:06, 30.3MB/s]
46%|####5 | 151M/330M [00:06<00:06, 30.4MB/s]
47%|####6 | 155M/330M [00:06<00:06, 30.0MB/s]
48%|####8 | 159M/330M [00:06<00:05, 30.1MB/s]
49%|####9 | 163M/330M [00:06<00:05, 30.0MB/s]
50%|##### | 166M/330M [00:06<00:05, 31.0MB/s]
51%|#####1 | 170M/330M [00:06<00:04, 34.3MB/s]
52%|#####2 | 173M/330M [00:06<00:05, 32.3MB/s]
53%|#####3 | 176M/330M [00:07<00:05, 30.2MB/s]
54%|#####4 | 179M/330M [00:07<00:05, 29.1MB/s]
55%|#####5 | 182M/330M [00:07<00:05, 28.0MB/s]
56%|#####6 | 186M/330M [00:07<00:05, 28.4MB/s]
58%|#####7 | 190M/330M [00:07<00:05, 29.3MB/s]
59%|#####8 | 194M/330M [00:07<00:04, 29.5MB/s]
60%|#####9 | 198M/330M [00:07<00:04, 29.7MB/s]
61%|######1 | 202M/330M [00:07<00:04, 30.0MB/s]
62%|######2 | 206M/330M [00:08<00:04, 30.0MB/s]
63%|######3 | 209M/330M [00:08<00:04, 30.1MB/s]
65%|######4 | 213M/330M [00:08<00:04, 30.1MB/s]
66%|######5 | 217M/330M [00:08<00:03, 30.0MB/s]
67%|######6 | 221M/330M [00:08<00:03, 30.2MB/s]
68%|######7 | 225M/330M [00:08<00:03, 29.5MB/s]
69%|######9 | 228M/330M [00:08<00:03, 29.4MB/s]
70%|####### | 232M/330M [00:08<00:03, 29.7MB/s]
71%|#######1 | 236M/330M [00:09<00:03, 29.0MB/s]
73%|#######2 | 240M/330M [00:09<00:03, 29.0MB/s]
74%|#######3 | 244M/330M [00:09<00:03, 29.3MB/s]
75%|#######4 | 248M/330M [00:09<00:02, 29.7MB/s]
76%|#######6 | 251M/330M [00:09<00:02, 29.3MB/s]
77%|#######7 | 255M/330M [00:09<00:02, 29.1MB/s]
78%|#######8 | 259M/330M [00:09<00:02, 29.6MB/s]
80%|#######9 | 263M/330M [00:10<00:02, 29.7MB/s]
81%|######## | 267M/330M [00:10<00:02, 29.8MB/s]
82%|########2 | 271M/330M [00:10<00:02, 30.0MB/s]
83%|########3 | 275M/330M [00:10<00:01, 30.2MB/s]
84%|########4 | 279M/330M [00:10<00:01, 30.2MB/s]
86%|########5 | 283M/330M [00:10<00:01, 30.2MB/s]
87%|########6 | 286M/330M [00:10<00:01, 30.4MB/s]
88%|########7 | 290M/330M [00:11<00:01, 30.3MB/s]
89%|########9 | 294M/330M [00:11<00:01, 30.3MB/s]
90%|######### | 298M/330M [00:11<00:01, 29.8MB/s]
91%|#########1| 302M/330M [00:11<00:00, 30.0MB/s]
93%|#########2| 306M/330M [00:11<00:00, 29.7MB/s]
94%|#########3| 310M/330M [00:11<00:00, 29.8MB/s]
95%|#########4| 314M/330M [00:11<00:00, 29.9MB/s]
96%|#########6| 318M/330M [00:11<00:00, 29.8MB/s]
97%|#########7| 321M/330M [00:12<00:00, 29.9MB/s]
99%|#########8| 325M/330M [00:12<00:00, 29.8MB/s]
100%|#########9| 329M/330M [00:12<00:00, 30.0MB/s]
100%|##########| 330M/330M [00:12<00:00, 27.9MB/s]
/root/.local/lib/python3.9/site-packages/torchvision/transforms/transforms.py:329: UserWarning:
Argument 'interpolation' of type int is deprecated since 0.13 and will be removed in 0.15. Please use InterpolationMode enum.
269
ImageNet 목록에 따라 라벨(labels) 파일 클래스 인덱스의 출력은 269여야 하며, 이는 ‘timber wolf, grey wolf, gray wolf, Canis lupus’에 매핑됩니다.
이제 DeiT 모델을 사용하여 이미지들을 분류할 수 있음을 확인했습니다. iOS 및 Android 앱에서 실행할 수 있도록 모델을 수정하는 방법을 살펴보겠습니다.
DeiT 스크립팅¶
모바일에서 이 모델을 사용하려면, 우리는 첫번째로 모델 스크립팅이 필요합니다. 전체적인 개요는 스크립트 그리고 최적화 레시피 에서 확인할 수 있습니다. 아래 코드를 실행하여 이전 단계에서 사용한 DeiT 모델을 모바일에서 실행할 수 있는 TorchScript 형식으로 변환합니다.
model = torch.hub.load('facebookresearch/deit:main', 'deit_base_patch16_224', pretrained=True)
model.eval()
scripted_model = torch.jit.script(model)
scripted_model.save("fbdeit_scripted.pt")
Using cache found in /root/.cache/torch/hub/facebookresearch_deit_main
약 346MB 크기의 스크립팅된 모델 파일 fbdeit_scripted.pt가 생성됩니다.
DeiT 양자화¶
추론 정확도를 거의 동일하게 유지하면서 훈련된 모델 크기를 크게 줄이기 위해 모델에 양자화를 적용할 수 있습니다. DeiT에서 사용된 트랜스포머 모델 덕분에, 모델에 동적 양자화를 쉽게 적용할 수 있습니다. 왜나하면 동적 양자화는 LSTM 모델과 트랜스포머 모델에서 가장 잘 적용되기 때문입니다. (자세한 내용은 여기 를 참고하세요.)
아래의 코드를 실행시켜 봅시다.
# 서버 추론을 위해 'fbgemm'을, 모바일 추론을 위해 'qnnpack'을 사용해 봅시다.
backend = "fbgemm" # 이 주피터 노트북에서는 양자화된 모델의 더 느린 추론 속도를 일으키는 qnnpack으로 대체되었습니다.
model.qconfig = torch.quantization.get_default_qconfig(backend)
torch.backends.quantized.engine = backend
quantized_model = torch.quantization.quantize_dynamic(model, qconfig_spec={torch.nn.Linear}, dtype=torch.qint8)
scripted_quantized_model = torch.jit.script(quantized_model)
scripted_quantized_model.save("fbdeit_scripted_quantized.pt")
/root/.local/lib/python3.9/site-packages/torch/ao/quantization/observer.py:214: UserWarning:
Please use quant_min and quant_max to specify the range for observers. reduce_range will be deprecated in a future release of PyTorch.
fbdeit_quantized_scripted.pt 모델의 스크립팅과 양자화가 적용된 버전이 만들어졌습니다. 모델의 크기는 단지 89MB 입니다. 양자화가 적용되지 않은 모델의 크기인 346MB보다 74%나 감소했습니다!
동일한 추론 결과를 만들기 위해 scripted_quantized_model
을
사용해 봅시다.
out = scripted_quantized_model(img)
clsidx = torch.argmax(out)
print(clsidx.item())
# 동일한 출력 결과인 269가 출력 되어야 합니다.
269
DeiT 최적화¶
모바일에 스크립트 되고 양자화된 모델을 사용하기 위한 마지막 단계는 최적화입니다.
from torch.utils.mobile_optimizer import optimize_for_mobile
optimized_scripted_quantized_model = optimize_for_mobile(scripted_quantized_model)
optimized_scripted_quantized_model.save("fbdeit_optimized_scripted_quantized.pt")
생성된 fbdeit_optimized_scripted_quantized.pt 파일은 양자화되고 스크립트되지만 최적화되지 않은 모델과 크기가 거의 같습니다. 추론 결과는 동일하게 유지됩니다.
out = optimized_scripted_quantized_model(img)
clsidx = torch.argmax(out)
print(clsidx.item())
# 다시 한번, 동일한 출력 결과인 269가 출력 되어야 합니다.
269
라이트 인터프리터(Lite interpreter) 사용¶
라이트 인터프리터를 사용하면 얼마나 모델의 사이즈가 작아지고, 추론 시간이 짧아지는지 결과를 확인해 봅시다. 이제 좀 더 가벼운 버전의 모델을 만들어 봅시다.
optimized_scripted_quantized_model._save_for_lite_interpreter("fbdeit_optimized_scripted_quantized_lite.ptl")
ptl = torch.jit.load("fbdeit_optimized_scripted_quantized_lite.ptl")
가벼운 모델의 크기는 그렇지 않은 버전의 모델 크기와 비슷하지만, 모바일에서 가벼운 버전을 실행하면 추론 속도가 빨라질 것으로 예상됩니다.
추론 속도 비교¶
네 가지 모델(원본 모델, 스크립트된 모델, 스크립트와 양자화를 적용한 모델, 스크립트와 양자화를 적용한 후 최적화한 모델)의 추론 속도가 어떻게 다른지 확인해 봅시다.
아래의 코드를 실행해 봅시다.
with torch.autograd.profiler.profile(use_cuda=False) as prof1:
out = model(img)
with torch.autograd.profiler.profile(use_cuda=False) as prof2:
out = scripted_model(img)
with torch.autograd.profiler.profile(use_cuda=False) as prof3:
out = scripted_quantized_model(img)
with torch.autograd.profiler.profile(use_cuda=False) as prof4:
out = optimized_scripted_quantized_model(img)
with torch.autograd.profiler.profile(use_cuda=False) as prof5:
out = ptl(img)
print("original model: {:.2f}ms".format(prof1.self_cpu_time_total/1000))
print("scripted model: {:.2f}ms".format(prof2.self_cpu_time_total/1000))
print("scripted & quantized model: {:.2f}ms".format(prof3.self_cpu_time_total/1000))
print("scripted & quantized & optimized model: {:.2f}ms".format(prof4.self_cpu_time_total/1000))
print("lite model: {:.2f}ms".format(prof5.self_cpu_time_total/1000))
original model: 146.71ms
scripted model: 115.71ms
scripted & quantized model: 122.75ms
scripted & quantized & optimized model: 134.22ms
lite model: 129.58ms
Google Colab에서 실행 시킨 결과는 다음과 같습니다.
original model: 1236.69ms
scripted model: 1226.72ms
scripted & quantized model: 593.19ms
scripted & quantized & optimized model: 598.01ms
lite model: 600.72ms
다음 결과는 각 모델이 소요한 추론 시간과 원본 모델에 대한 각 모델의 감소율을 요약한 것입니다.
import pandas as pd
import numpy as np
df = pd.DataFrame({'Model': ['original model','scripted model', 'scripted & quantized model', 'scripted & quantized & optimized model', 'lite model']})
df = pd.concat([df, pd.DataFrame([
["{:.2f}ms".format(prof1.self_cpu_time_total/1000), "0%"],
["{:.2f}ms".format(prof2.self_cpu_time_total/1000),
"{:.2f}%".format((prof1.self_cpu_time_total-prof2.self_cpu_time_total)/prof1.self_cpu_time_total*100)],
["{:.2f}ms".format(prof3.self_cpu_time_total/1000),
"{:.2f}%".format((prof1.self_cpu_time_total-prof3.self_cpu_time_total)/prof1.self_cpu_time_total*100)],
["{:.2f}ms".format(prof4.self_cpu_time_total/1000),
"{:.2f}%".format((prof1.self_cpu_time_total-prof4.self_cpu_time_total)/prof1.self_cpu_time_total*100)],
["{:.2f}ms".format(prof5.self_cpu_time_total/1000),
"{:.2f}%".format((prof1.self_cpu_time_total-prof5.self_cpu_time_total)/prof1.self_cpu_time_total*100)]],
columns=['Inference Time', 'Reduction'])], axis=1)
print(df)
"""
Model Inference Time Reduction
0 original model 1236.69ms 0%
1 scripted model 1226.72ms 0.81%
2 scripted & quantized model 593.19ms 52.03%
3 scripted & quantized & optimized model 598.01ms 51.64%
4 lite model 600.72ms 51.43%
"""
Model Inference Time Reduction
0 original model 146.71ms 0%
1 scripted model 115.71ms 21.13%
2 scripted & quantized model 122.75ms 16.33%
3 scripted & quantized & optimized model 134.22ms 8.51%
4 lite model 129.58ms 11.68%
'\n Model Inference Time Reduction\n0\toriginal model 1236.69ms 0%\n1\tscripted model 1226.72ms 0.81%\n2\tscripted & quantized model 593.19ms 52.03%\n3\tscripted & quantized & optimized model 598.01ms 51.64%\n4\tlite model 600.72ms 51.43%\n'