(beta) Dynamic Quantization on an LSTM Word Language Model¶
Author: James Reed
Edited by: Seth Weidman
Translated by: 박경림 Myungha Kwon
Introduction¶
Quantization converts a model's weights and activations from floating point to integer, which shrinks the model size and speeds up inference with only a small hit to accuracy.
In this tutorial, we follow the word language model example from the PyTorch examples and apply the simplest form of quantization, dynamic quantization, to an LSTM-based next-word prediction model. (A small numeric illustration of the float-to-int mapping follows the imports below.)
# imports
import os
from io import open
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
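Before defining the model, here is a minimal illustration, not part of the original example, of the float-to-int mapping that quantization relies on: a tensor is stored as int8 values together with a scale and a zero point, roughly q ≈ round(x / scale) + zero_point. The scale and zero_point values below are arbitrary and chosen only for demonstration.
# Minimal illustration of per-tensor affine quantization; scale and zero_point
# are arbitrary demonstration values, not ones computed from real data.
x = torch.randn(4)
x_q = torch.quantize_per_tensor(x, scale=0.05, zero_point=0, dtype=torch.qint8)
print(x_q.int_repr())    # the int8 values actually stored
print(x_q.dequantize())  # the floats recovered from the int8 representation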
1. Define the model¶
Here we define the LSTM model architecture, following the model from the word language model example.
class LSTMModel(nn.Module):
    """Container module with an encoder, a recurrent module, and a decoder."""

    def __init__(self, ntoken, ninp, nhid, nlayers, dropout=0.5):
        super(LSTMModel, self).__init__()
        self.drop = nn.Dropout(dropout)
        self.encoder = nn.Embedding(ntoken, ninp)
        self.rnn = nn.LSTM(ninp, nhid, nlayers, dropout=dropout)
        self.decoder = nn.Linear(nhid, ntoken)

        self.init_weights()

        self.nhid = nhid
        self.nlayers = nlayers

    def init_weights(self):
        initrange = 0.1
        self.encoder.weight.data.uniform_(-initrange, initrange)
        self.decoder.bias.data.zero_()
        self.decoder.weight.data.uniform_(-initrange, initrange)

    def forward(self, input, hidden):
        emb = self.drop(self.encoder(input))
        output, hidden = self.rnn(emb, hidden)
        output = self.drop(output)
        decoded = self.decoder(output)
        return decoded, hidden

    def init_hidden(self, bsz):
        weight = next(self.parameters())
        return (weight.new_zeros(self.nlayers, bsz, self.nhid),
                weight.new_zeros(self.nlayers, bsz, self.nhid))
2. Load in the text data¶
Next, we load the Wikitext-2 dataset into a Corpus instance, again following the preprocessing from the word language model example.
class Dictionary(object):
    def __init__(self):
        self.word2idx = {}
        self.idx2word = []

    def add_word(self, word):
        if word not in self.word2idx:
            self.idx2word.append(word)
            self.word2idx[word] = len(self.idx2word) - 1
        return self.word2idx[word]

    def __len__(self):
        return len(self.idx2word)


class Corpus(object):
    def __init__(self, path):
        self.dictionary = Dictionary()
        self.train = self.tokenize(os.path.join(path, 'train.txt'))
        self.valid = self.tokenize(os.path.join(path, 'valid.txt'))
        self.test = self.tokenize(os.path.join(path, 'test.txt'))

    def tokenize(self, path):
        """Tokenizes a text file."""
        assert os.path.exists(path)
        # Add words to the dictionary
        with open(path, 'r', encoding="utf8") as f:
            for line in f:
                words = line.split() + ['<eos>']
                for word in words:
                    self.dictionary.add_word(word)

        # Tokenize file content
        with open(path, 'r', encoding="utf8") as f:
            idss = []
            for line in f:
                words = line.split() + ['<eos>']
                ids = []
                for word in words:
                    ids.append(self.dictionary.word2idx[word])
                idss.append(torch.tensor(ids).type(torch.int64))
            ids = torch.cat(idss)

        return ids

model_data_filepath = 'data/'

corpus = Corpus(model_data_filepath + 'wikitext-2')
3. Load the pretrained model¶
This is a tutorial on dynamic quantization, a quantization technique that is applied after a model has been trained. Therefore, we simply load pretrained weights into this model architecture; these weights were obtained by training for five epochs using the default settings of the word language model example.
ntokens = len(corpus.dictionary)

model = LSTMModel(
    ntoken = ntokens,
    ninp = 512,
    nhid = 256,
    nlayers = 5,
)

model.load_state_dict(
    torch.load(
        model_data_filepath + 'word_language_model_quantize.pth',
        map_location=torch.device('cpu')
    )
)

model.eval()
print(model)
LSTMModel(
  (drop): Dropout(p=0.5, inplace=False)
  (encoder): Embedding(33278, 512)
  (rnn): LSTM(512, 256, num_layers=5, dropout=0.5)
  (decoder): Linear(in_features=256, out_features=33278, bias=True)
)
Now let's generate some text to check that the pretrained model works properly. As before, we follow this example.
input_ = torch.randint(ntokens, (1, 1), dtype=torch.long)
hidden = model.init_hidden(1)
temperature = 1.0
num_words = 1000

with open(model_data_filepath + 'out.txt', 'w') as outf:
    with torch.no_grad():  # no tracking history
        for i in range(num_words):
            output, hidden = model(input_, hidden)
            word_weights = output.squeeze().div(temperature).exp().cpu()
            word_idx = torch.multinomial(word_weights, 1)[0]
            input_.fill_(word_idx)

            word = corpus.dictionary.idx2word[word_idx]

            outf.write(str(word.encode('utf-8')) + ('\n' if i % 20 == 19 else ' '))

            if i % 100 == 0:
                print('| Generated {}/{} words'.format(i, 1000))

with open(model_data_filepath + 'out.txt', 'r') as outf:
    all_output = outf.read()
    print(all_output)
| Generated 0/1000 words
| Generated 100/1000 words
| Generated 200/1000 words
| Generated 300/1000 words
| Generated 400/1000 words
| Generated 500/1000 words
| Generated 600/1000 words
| Generated 700/1000 words
| Generated 800/1000 words
| Generated 900/1000 words
b'amongst' b'that' b'by' b'turn' b'descended' b'.' b'According' b'to' b'queue' b'Jane' b'pagodas' b',' b'there' b'is' b'no' b'defensive' b'extant' b'goal' b',' b'226'
b'when' b'it' b'was' b'also' b'tended' b'to' b'direct' b'with' b'humans' b'.' b'Kudirka' b'use' b'"' b'Haifa' b'"' b',' b'borrowed' b'with' b'toss' b'on'
b'his' b'house' b'@-@' b'Cathedral' b',' b'transferring' b'a' b'twenty' b'@-@' b'game' b'ha' b',' b'by' b'his' b'second' b'date' b'from' b'his' b'father' b'It'
b'has' b'seven' b'trouble' b'.' b'The' b'only' b'sons' b'of' b'investigators' b'in' b'the' b'film' b'has' b'a' b'wider' b'download' b'of' b'social' b',' b'and'
b'achieves' b'three' b'years' b'to' b'1850s' b'.' b'According' b'the' b'crowd' b'gives' b'the' b'game' b'to' b'accumulate' b'whoever' b"'" b'mass' b'concepts' b'off' b'or'
b'spiritual' b'issues' b'.' b'I' b'choose' b'through' b'time' b'it' b'as' b'their' b'backyard' b'Doctor' b',' b'but' b'serves' b'it' b'a' b'vehicle' b'mathematical' b'chef'
b'which' b'was' b'called' b'very' b',' b'even' b',' b'they' b'think' b'his' b'fear' b',' b'and' b'would' b'<unk>' b'their' b'speech' b'been' b'laid' b'to'
b'those' b',' b'and' b'Ceres' b'escaped' b'winning' b'when' b'birds' b'delivered' b'individuals' b'<unk>' b'prefect' b'I' b'Burkettsville' b'to' b'this' b'draw' b'.' b'This' b'will'
b'get' b'like' b'everything' b'of' b'a' b'fun' b'Indiamen' b',' b'breaking' b'it' b'with' b'a' b'single' b'speed' b'on' b'a' b'2006' b'deal' b',' b'<unk>'
b'up' b'it' b'criticised' b'a' b'different' b'wings' b'from' b'Saint' b'miners' b';' b'they' b'could' b'die' b':' b'The' b'planet' b'of' b'nature' b'was' b'trying'
b'to' b'only' b'attract' b'things' b'to' b'Media.Vision' b'.' b'When' b'there' b',' b'"' b'this' b'might' b'be' b'get' b'to' b'be' b'"' b'.' b'<unk>'
b'considered' b'that' b'"' b'it' b'never' b'does' b'one' b'of' b'these' b'times' b'from' b'not' b'used' b'to' b'be' b'used' b';' b'you' b'has' b'any'
b'Cave' b'from' b'any' b'other' b'eye' b'of' b'nucleolar' b'Bandicoot' b'"' b'.' b'After' b'the' b'crew' b'only' b'adopted' b'the' b'game' b'<unk>' b',' b'Alabama'
b'parted' b'to' b'collapse' b'with' b'prospective' b'solidarity' b'against' b'such' b'or' b'Dame' b',' b'Raffles' b',' b'and' b'<unk>' b'.' b'Yet' b',' b'it' b'seems'
b'to' b'confront' b'the' b'game' b',' b'while' b'movement' b'is' b'often' b'similar' b'to' b'their' b'pieces' b'.' b'In' b'this' b'first' b'story' b'<unk>' b','
b'female' b'states' b'that' b'"' b'ideas' b'from' b'rough' b'movement' b',' b'and' b'else' b',' b'Pet' b'Australia' b'and' b'Shapur' b',' b'initially' b'"' b'.'
b'it' b'is' b'unique' b'that' b'it' b'dispersed' b'in' b'any' b'courtship' b'policy' b'group' b'career' b'while' b'these' b'of' b'those' b'series' b'like' b'it' b'163'
b'%' b'of' b'an' b'year' b'youth' b'has' b'already' b'been' b'killed' b'.' b'They' b'par' b'the' b'play' b'and' b'continue' b'to' b'prepare' b'.' b'As'
b'with' b'this' b'most' b'demise' b',' b'it' b'were' b'run' b'.' b'The' b'women' b'tell' b'<unk>' b'their' b'flowery' b'leave' b',' b'archaic' b'and' b'forbs'
b',' b'including' b'small' b'policies' b',' b'and' b'construction' b'.' b'<eos>' b'The' b'wording' b'of' b'<unk>' b'habitat' b'like' b'they' b'may' b'be' b'noisy' b'sexpunctatus'
b',' b'usually' b'as' b'Better' b'by' b'multiple' b'or' b'inbound' b'their' b'spare' b'may' b'Perth' b'be' b'actually' b'queried' b'.' b'If' b'Ceres' b'are' b'forced'
b'to' b'have' b'more' b'than' b'24' b'times' b'ever' b'(' b'Australia' b')' b',' b'which' b'shortened' b'a' b'large' b'size' b'sand' b'to' b'get' b'about'
b'2' b'million' b'or' b'impossible' b'for' b'<unk>' b'and' b'rock' b'.' b'Common' b'eggs' b',' b'is' b'one' b'of' b'a' b'Russian' b'alarm' b'interesting' b'for'
b'it' b'.' b'These' b'large' b'stories' b'suggest' b'are' b'visible' b',' b'with' b'it' b'combination' b'can' b'"' b'not' b'lose' b'because' b'some' b'remains' b'should'
b'have' b'a' b'average' b'place' b'to' b'have' b'quite' b',' b'it' b'\xe2\x80\x94' b'a' b'solid' b'narrow' b'celebrities' b'"' b'.' b'The' b'subsequent' b'asteroid' b'allows'
b'<unk>' b':' b'1' b'temperatures' b'no' b'a' b'eye' b'is' b'irresponsible' b'and' b'<unk>' b'colour' b'.' b'As' b'an' b'orbit' b',' b'it' b'will' b'muscles'
b',' b'and' b'become' b'often' b'misunderstood' b'.' b'<unk>' b'(' b'<unk>' b')' b'is' b'also' b'in' b'far' b'blood' b'until' b'two' b'hours' b'in' b'this'
b'season' b'before' b'even' b'.' b'<unk>' b'"' b'almost' b'wings' b'about' b'their' b'commandment' b"'" b'feet' b'instead' b'side' b'like' b'They' b',' b'Venus' b','
b'a' b'great' b'major' b'sentiment' b',' b'and' b'Homeric' b'Glee' b'.' b'"' b'used' b'that' b'year' b',' b'it' b'is' b'so' b'overcome' b'by' b'error'
b'such' b'as' b'beams' b',' b'they' b'helped' b'gusting' b'.' b'Instead' b',' b'these' b'nutrients' b'feeds' b'conducting' b'their' b'abdomen' b',' b'they' b'extend' b'on'
b'a' b'variety' b'of' b'rents' b'.' b'As' b'they' b'have' b'a' b'decree' b',' b'RNA' b'broods' b',' b'the' b'name' b'stereotypes' b',' b'1438' b','
b'willingness' b'to' b'feed' b'walks' b'after' b'takeoff' b'.' b'<eos>' b'Each' b'seems' b'consists' b'of' b'<unk>' b'pressure' b'can' b'be' b'heard' b'by' b'either' b'.'
b'If' b'these' b'areas' b',' b'just' b'crops' b'<unk>' b'should' b'be' b'checked' b'and' b'pointed' b'eggs' b'to' b'<unk>' b'able' b'to' b'seek' b'read' b'.'
b'This' b'Beginning' b'just' b'to' b'broad' b'<unk>' b',' b'except' b'technique' b'when' b'they' b'live' b'from' b'geared' b'cow' b'molecules' b'in' b'small' b'areas' b'.'
b'SbF' b'may' b'be' b'seen' b'by' b'some' b'point' b'of' b'spots' b'but' b'therefore' b'survived' b'.' b'However' b',' b'he' b'will' b'not' b'recognise' b'their'
b'eye' b'control' b',' b'so' b'even' b'proceeds' b'it' b'.' b'<eos>' b'explore' b'northwestward' b',' b'by' b'certain' b'testify' b'that' b'can' b'reduce' b'earthstar' b'behaviour'
b'as' b'they' b'look' b'on' b'the' b'"' b'<unk>' b'<unk>' b'is' b'to' b'produce' b'another' b'tree' b'age' b'or' b'Irish' b'or' b'he' b'would' b'be'
b'mature' b'alone' b'down' b'leave' b'.' b',' b'The' b'Omar' b'contains' b'its' b'growing' b'taxonomy' b'.' b'<eos>' b'If' b'most' b'remaining' b'bold' b'females' b'have'
b'recently' b'added' b'leave' b',' b'they' b'increases' b'because' b'they' b'are' b'tiger' b'.' b'In' b'addition' b',' b'his' b'military' b'stumps' b'may' b'have' b'something'
b'Message' b'their' b'same' b'strength' b'or' b'a' b'star' b'or' b'the' b'habitat' b'affectionate' b'.' b'While' b'they' b'believe' b'by' b'RNA' b'pursued' b'hiking' b','
b'they' b'do' b'"' b'their' b'crescent' b'when' b'they' b'find' b'how' b'they' b'can' b'be' b'.' b'<eos>' b'"' b'bigger' b'the' b'archival' b'<unk>' b'is'
b'then' b'my' b'favorite' b'matter' b'but' b'they' b'considered' b'for' b'you' b'when' b'manual' b'.' b'(' b'I' b'will' b'be' b'able' b'to' b'have' b'been'
b'-' b'after' b'it' b'number' b'is' b'written' b'behind' b'columns' b'and' b'you' b'can' b'suggest' b'before' b'can' b'continue' b'<unk>' b',' b'nothing' b'when' b'it'
b'is' b'contacted' b'when' b'he' b'may' b'happen' b'.' b'That' b'is' b'my' b'words' b'than' b'voiced' b'mechanics' b'can' b'be' b'killed' b'yet' b'I' b'once'
b'will' b'we' b'find' b'how' b'your' b'now' b'don' b"'t" b'say' b'"' b'.' b'A' b'Hibari.Ch.' b'is' b'"' b'so' b'it' b'not' b'Logan' b','
b'but' b'understood' b'so' b'enough' b'pushed' b'it' b':' b'...' b'Kurt' b'Bayan' b'was' b'well' b'connected' b'but' b'strongly' b'thinking' b'it' b'is' b'extremely' b'regarded'
b',' b'covering' b'him' b'to' b'look' b'of' b'its' b'drinking' b'.' b'"' b'According' b'to' b'replication' b'unsuitable' b'that' b'the' b'characters' b'must' b'be' b'spreads'
b',' b'they' b'must' b'have' b'strict' b'different' b'Crash' b'Whittaker' b'from' b'eagle' b',' b'rather' b'<unk>' b'good' b'extreme' b',' b'may' b'be' b'attached' b'to'
b'462' b',' b'with' b'multiple' b'doses' b'and' b'strikeout' b',' b'or' b'without' b'a' b'opposite' b'track' b':' b'the' b'control' b'of' b'Ceres' b',' b'as'
b'one' b'of' b'several' b'comprehensive' b'body' b'invertebrate' b'proteins' b',' b'ideas' b'making' b'them' b'cover' b',' b'were' b'turned' b',' b'but' b'instead' b'so' b'invented'
It's no GPT-2, but it looks like the model has started to learn the structure of language!
We're almost ready to demonstrate dynamic quantization. We just need to define a few more helper functions:
bptt = 25
criterion = nn.CrossEntropyLoss()
eval_batch_size = 1

# create test data set
def batchify(data, bsz):
    # Work out how cleanly we can divide the dataset into ``bsz`` parts.
    nbatch = data.size(0) // bsz
    # Trim off any extra elements that wouldn't cleanly fit (remainders).
    data = data.narrow(0, 0, nbatch * bsz)
    # Evenly divide the data across the ``bsz`` batches.
    return data.view(bsz, -1).t().contiguous()

test_data = batchify(corpus.test, eval_batch_size)

# Evaluation functions
def get_batch(source, i):
    seq_len = min(bptt, len(source) - 1 - i)
    data = source[i:i+seq_len]
    target = source[i+1:i+1+seq_len].reshape(-1)
    return data, target

def repackage_hidden(h):
    """Wraps hidden states in new tensors, detached from their gradient history."""
    if isinstance(h, torch.Tensor):
        return h.detach()
    else:
        return tuple(repackage_hidden(v) for v in h)

def evaluate(model_, data_source):
    # Turn on evaluation mode, which disables dropout.
    model_.eval()
    total_loss = 0.
    hidden = model_.init_hidden(eval_batch_size)
    with torch.no_grad():
        for i in range(0, data_source.size(0) - 1, bptt):
            data, targets = get_batch(data_source, i)
            output, hidden = model_(data, hidden)
            hidden = repackage_hidden(hidden)
            output_flat = output.view(-1, ntokens)
            total_loss += len(data) * criterion(output_flat, targets).item()
    return total_loss / (len(data_source) - 1)
4. Test dynamic quantization¶
Finally, we can call torch.quantization.quantize_dynamic on the model! Specifically, we specify that we want:
- the nn.LSTM and nn.Linear modules in our model to be quantized, and
- the weights to be converted to int8 values.
import torch.quantization
quantized_model = torch.quantization.quantize_dynamic(
    model, {nn.LSTM, nn.Linear}, dtype=torch.qint8
)
print(quantized_model)
LSTMModel(
  (drop): Dropout(p=0.5, inplace=False)
  (encoder): Embedding(33278, 512)
  (rnn): DynamicQuantizedLSTM(512, 256, num_layers=5, dropout=0.5)
  (decoder): DynamicQuantizedLinear(in_features=256, out_features=33278, dtype=torch.qint8, qscheme=torch.per_tensor_affine)
)
The model looks the same; how has this benefited us? First, the model size shrinks substantially:
def print_size_of_model(model):
    torch.save(model.state_dict(), "temp.p")
    print('Size (MB):', os.path.getsize("temp.p")/1e6)
    os.remove('temp.p')

print_size_of_model(model)
print_size_of_model(quantized_model)
Size (MB): 113.943637
Size (MB): 79.738057
Second, inference is faster, while the evaluation loss stays the same:
# Note: we set the number of threads to one for a single-threaded comparison,
# since quantized models run single threaded.
torch.set_num_threads(1)

def time_model_evaluation(model, test_data):
    s = time.time()
    loss = evaluate(model, test_data)
    elapsed = time.time() - s
    print('''loss: {0:.3f}\nelapsed time (seconds): {1:.1f}'''.format(loss, elapsed))

time_model_evaluation(model, test_data)
time_model_evaluation(quantized_model, test_data)
loss: 5.167
elapsed time (seconds): 199.4
loss: 5.168
elapsed time (seconds): 131.6
Running this locally on a MacBook Pro, inference takes about 200 seconds without quantization and only about 100 seconds with quantization.
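As a brief follow-up not covered in the original example, a dynamically quantized model can be reused later by saving its state_dict and loading it back into a freshly quantized copy of the same float architecture. The sketch below assumes the LSTMModel class defined above is in scope; quantized_lstm.pth is a placeholder file name used only for this illustration.
# A sketch for persisting the dynamically quantized model; 'quantized_lstm.pth'
# is a placeholder file name used only for this illustration.
torch.save(quantized_model.state_dict(), 'quantized_lstm.pth')

# To restore it later, first quantize a freshly constructed float model in the
# same way, then load the saved quantized parameters into it.
loaded_model = torch.quantization.quantize_dynamic(
    LSTMModel(ntoken=ntokens, ninp=512, nhid=256, nlayers=5),
    {nn.LSTM, nn.Linear}, dtype=torch.qint8
)
loaded_model.load_state_dict(torch.load('quantized_lstm.pth'))
loaded_model.eval()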