• Tutorials >
  • (베타) LSTM 기반 단어 단위 언어 모델의 동적 양자화
Shortcuts

(베타) LSTM 기반 단어 단위 언어 모델의 동적 양자화

Author: James Reed

Edited by: Seth Weidman

번역: 박경림 Myungha Kwon

시작하기

양자화는 모델의 크기를 줄이고 추론 속도를 높이면서도 정확도는 별로 낮아지지 않도록, 모델의 가중치와 활성 함수를 실수형에서 정수형으로 변환합니다.

이 튜토리얼에서는 PyTorch의 단어 단위 언어 모델 예제를 따라하면서, LSTM 기반의 단어 예측 모델에 가장 간단한 양자화 기법인 동적 양자화 를 적용해 보겠습니다.

# 불러오기
import os
from io import open
import time

import torch
import torch.nn as nn
import torch.nn.functional as F

1. 모델 정의하기

단어 단위 언어 모델 예제에서 사용된 모델 을 따라 LSTM 모델 아키텍처를 정의합니다.

class LSTMModel(nn.Module):
    """인코더, 반복 모듈 및 디코더가 있는 컨테이너 모듈."""

    def __init__(self, ntoken, ninp, nhid, nlayers, dropout=0.5):
        super(LSTMModel, self).__init__()
        self.drop = nn.Dropout(dropout)
        self.encoder = nn.Embedding(ntoken, ninp)
        self.rnn = nn.LSTM(ninp, nhid, nlayers, dropout=dropout)
        self.decoder = nn.Linear(nhid, ntoken)

        self.init_weights()

        self.nhid = nhid
        self.nlayers = nlayers

    def init_weights(self):
        initrange = 0.1
        self.encoder.weight.data.uniform_(-initrange, initrange)
        self.decoder.bias.data.zero_()
        self.decoder.weight.data.uniform_(-initrange, initrange)

    def forward(self, input, hidden):
        emb = self.drop(self.encoder(input))
        output, hidden = self.rnn(emb, hidden)
        output = self.drop(output)
        decoded = self.decoder(output)
        return decoded, hidden

    def init_hidden(self, bsz):
        weight = next(self.parameters())
        return (weight.new_zeros(self.nlayers, bsz, self.nhid),
                weight.new_zeros(self.nlayers, bsz, self.nhid))

2. 텍스트 데이터 불러오기

다음으로, 단어 단위 언어 모델 예제의 전처리 과정을 따라 Wikitext-2 데이터셋Corpus 인스턴스에 불러옵니다.

class Dictionary(object):
    def __init__(self):
        self.word2idx = {}
        self.idx2word = []

    def add_word(self, word):
        if word not in self.word2idx:
            self.idx2word.append(word)
            self.word2idx[word] = len(self.idx2word) - 1
        return self.word2idx[word]

    def __len__(self):
        return len(self.idx2word)


class Corpus(object):
    def __init__(self, path):
        self.dictionary = Dictionary()
        self.train = self.tokenize(os.path.join(path, 'train.txt'))
        self.valid = self.tokenize(os.path.join(path, 'valid.txt'))
        self.test = self.tokenize(os.path.join(path, 'test.txt'))

    def tokenize(self, path):
        assert os.path.exists(path)
        """텍스트 파일 토큰화"""
        assert os.path.exists(path)
        # 사전에 단어 추가
        with open(path, 'r', encoding="utf8") as f:
            for line in f:
                words = line.split() + ['<eos>']
                for word in words:
                    self.dictionary.add_word(word)

        # 파일 내용 토큰화
        with open(path, 'r', encoding="utf8") as f:
            idss = []
            for line in f:
                words = line.split() + ['<eos>']
                ids = []
                for word in words:
                    ids.append(self.dictionary.word2idx[word])
                idss.append(torch.tensor(ids).type(torch.int64))
            ids = torch.cat(idss)

        return ids

model_data_filepath = 'data/'

corpus = Corpus(model_data_filepath + 'wikitext-2')

3. 사전 학습된 모델 불러오기

이 튜토리얼은 모델이 학습된 후 적용되는 양자화 기술인 동적 양자화에 대한 튜토리얼입니다. 따라서 우리는 미리 학습된 가중치를 모델 아키텍처에 로드할 것 입니다. 이 가중치는 word language 모델 예제의 기본 설정을 사용하여 5개의 epoch 동안 학습하여 얻은 것입니다.

ntokens = len(corpus.dictionary)

model = LSTMModel(
    ntoken = ntokens,
    ninp = 512,
    nhid = 256,
    nlayers = 5,
)

model.load_state_dict(
    torch.load(
        model_data_filepath + 'word_language_model_quantize.pth',
        map_location=torch.device('cpu')
        )
    )

model.eval()
print(model)
LSTMModel(
  (drop): Dropout(p=0.5, inplace=False)
  (encoder): Embedding(33278, 512)
  (rnn): LSTM(512, 256, num_layers=5, dropout=0.5)
  (decoder): Linear(in_features=256, out_features=33278, bias=True)
)

이제 사전 학습된 모델이 잘 동작하는지 확인해보기 위해 텍스트를 생성해 보겠습니다. 지금까지 튜토리얼을 진행했던 방식처럼 이 예제 를 따라 하겠습니다.

input_ = torch.randint(ntokens, (1, 1), dtype=torch.long)
hidden = model.init_hidden(1)
temperature = 1.0
num_words = 1000

with open(model_data_filepath + 'out.txt', 'w') as outf:
    with torch.no_grad():  # 기록을 추적하지 않습니다.
        for i in range(num_words):
            output, hidden = model(input_, hidden)
            word_weights = output.squeeze().div(temperature).exp().cpu()
            word_idx = torch.multinomial(word_weights, 1)[0]
            input_.fill_(word_idx)

            word = corpus.dictionary.idx2word[word_idx]

            outf.write(str(word.encode('utf-8')) + ('\n' if i % 20 == 19 else ' '))

            if i % 100 == 0:
                print('| Generated {}/{} words'.format(i, 1000))

with open(model_data_filepath + 'out.txt', 'r') as outf:
    all_output = outf.read()
    print(all_output)
| Generated 0/1000 words
| Generated 100/1000 words
| Generated 200/1000 words
| Generated 300/1000 words
| Generated 400/1000 words
| Generated 500/1000 words
| Generated 600/1000 words
| Generated 700/1000 words
| Generated 800/1000 words
| Generated 900/1000 words
b'by' b'solid' b'trivalent' b'@-@' b'respectability' b'from' b'architecture' b'and' b'body' b'.' b'The' b'latter' b"'re" b'also' b'used' b',' b'a' b'linear' b'<unk>' b'<unk>'
b'Coventry' b'engine' b',' b'a' b'red' b'Harmon' b'<unk>' b'for' b'the' b'mornings' b'the' b'national' b'praised' b'should' b'be' b'considered' b'further' b',' b'and' b'one'
b'was' b'largely' b'incomplete' b'.' b'The' b'<unk>' b'penned' b'four' b'arable' b'observations' b'in' b'the' b'first' b'two' b'years' b',' b'and' b'or' b'another' b'claim'
b'about' b'three' b'years' b'of' b'Stein' b'.' b'The' b'Adams' b'performance' b'reopened' b'Far' b'rates' b',' b'following' b'many' b'@-@' b'road' b',' b'in' b'which'
b'are' b'able' b'to' b'help' b'the' b'national' b'flocks' b'of' b'Ceres' b',' b'established' b'Mysorean' b'vegetation' b'officers' b'and' b'fashion' b'specimens' b'.' b'However' b','
b'it' b'is' b'also' b'considered' b'both' b',' b'but' b'they' b'do' b'not' b'help' b'their' b'time' b'.' b'If' b'it' b'has' b'high' b',' b'there'
b'are' b'many' b'print' b'of' b'trouble' b'constant' b'distal' b'.' b'Therefore' b',' b'the' b'substance' b'were' b'reddish' b'unanimous' b'for' b'over' b'70' b'years' b'than'
b'the' b'vintage' b'season' b'.' b'<eos>' b'Common' b'starlings' b'suffer' b'that' b'Isis' b'be' b'aligned' b'between' b'land' b',' b'drove' b'in' b'a' b'bill' b'due'
b'to' b'relative' b'culture' b'and' b'prey' b'methods' b'that' b'can' b'be' b'trained' b'on' b'new' b'rectory' b'.' b'They' b',' b'anemia' b',' b'is' b'exposed'
b'continental' b'evocative' b'to' b'facilitate' b'a' b'blast' b'without' b'Fort' b'plenty' b',' b'which' b'was' b'located' b'valens' b'.' b'They' b'also' b'moved' b'completely' b'by'
b'"' b'one' b'of' b'its' b'stuff' b'"' b'\xe2\x80\x94' b'rather' b'than' b'to' b'leave' b'multiple' b'species' b'that' b'should' b'be' b'with' b'agreeing' b'stubbornly' b'material'
b',' b'"' b'2O' b'"' b',' b'"' b'engender' b'<unk>' b'"' b'.' b'Frequent' b'suppose' b'that' b'their' b'large' b'noise' b'as' b'a' b'jam' b'County'
b'girl' b',' b'as' b'introduced' b'west' b'of' b'a' b'fortified' b',' b'sometimes' b'just' b'obscure' b'based' b'north' b'from' b'which' b'they' b'are' b'unclear' b'once'
b'instrumental' b'.' b'It' b'ends' b'and' b'scent' b'currencies' b',' b'in' b'Metal' b'Cortex' b'and' b'<unk>' b'the' b'last' b'estimates' b'of' b'her' b'conversion' b'in'
b'the' b'forest' b'or' b'approximately' b'25' b'back' b'(' b'wet' b'kg' b')' b'long' b'distance' b'accusing' b'the' b'race' b'.' b'Some' b'gains' b',' b'in'
b'which' b'sorts' b'stones' b'must' b'be' b'effective' b'forbids' b',' b'roadways' b'at' b'a' b'<unk>' b'zone' b'on' b'6' b'October' b'public' b'species' b'.' b'No'
b'graphic' b'groups' b'are' b'diurnal' b'@-@' b'brown' b',' b'which' b'gill' b'<unk>' b'(' b'sometimes' b'two' b'house' b'<unk>' b')' b'and' b'at' b'one' b'length'
b'of' b'Raffles' b'.' b'In' b'2008' b',' b'it' b'is' b'many' b'Indian' b'similar' b'tales' b'.' b'Common' b'starlings' b'can' b'be' b'50' b'@.@' b'8'
b'cm' b'(' b'18' b'ft' b')' b'at' b'Anatolia' b'<unk>' b'y' b'cool' b',' b'which' b'before' b'the' b'Buddha' b'is' b'entirely' b'more' b'tightly' b'.'
b'A' b'square' b'factor' b'allows' b'two' b'venomous' b'fungus' b'of' b'18' b'@.@' b'7' b'cm' b'(' b'0' b'@.@' b'4' b'in' b')' b'thick' b','
b'and' b'it' b'is' b'a' b'1' b'@,@' b'500' b'm' b'(' b'3' b'@.@' b'2' b'in' b')' b'averaging' b'commute' b'to' b'other' b'concentration' b'of'
b'52' b'mm' b'(' b'13' b'm' b')' b'in' b'waterline' b'length' b'.' b'They' b'responded' b'up' b'in' b'one' b'labor' b',' b'so' b'the' b'female'
b'is' b'analyzed' b'.' b'If' b'Malaya' b'often' b'consolidated' b'from' b'it' b',' b'they' b'will' b'be' b'possible' b'when' b'it' b'describes' b'using' b'males' b'.'
b'If' b'it' b'cross' b'their' b'dragging' b',' b'they' b'watch' b'dispatch' b'or' b'physical' b'expect' b',' b'as' b'they' b'may' b'be' b'already' b'Edith' b'terminating'
b'by' b'them' b'of' b'ibotenic' b'fir' b'.' b'O' b'<unk>' b'is' b'still' b'attributed' b'to' b'human' b'starlings' b',' b'they' b'were' b'flightless' b'and' b'are'
b'not' b'categorised' b'.' b'Some' b'plants' b'rarely' b'lose' b'a' b'broad' b',' b'rRNA' b',' b'mitosis' b'is' b'genius' b'on' b'service' b'and' b'produce' b'both'
b'arms' b',' b'as' b'they' b'are' b'well' b'made' b'to' b'prince' b',' b'iconography' b',' b'<unk>' b',' b'and' b'Continued' b',' b'and' b'chicks' b'of'
b'linear' b'females' b'Tadeusz' b'them' b'and' b'chemical' b'birds' b',' b'Robert' b'acids' b'and' b'quotations' b',' b'may' b'be' b'cephalothorax' b'for' b'armor' b'.' b'<eos>'
b'A' b'orbital' b'snake' b',' b'representing' b'modern' b'behaviour' b',' b'are' b'the' b'toxic' b'gaits' b'of' b'sac' b'and' b'serves' b'1827' b',' b'coupled' b'to'
b'20' b'@.@' b'9' b'to' b'7' b'@.@' b'5' b'm' b'(' b'4' b'@.@' b'0' b'm' b')' b'.' b'This' b'is' b'probably' b'overnight' b'of'
b'metallic' b'starlings' b',' b'order' b'to' b'be' b'all' b'over' b'users' b'of' b'other' b'prey' b'morels' b'.' b'Loose' b'and' b'assailant' b'treat' b'eggs' b'for'
b'males' b'are' b'as' b'by' b'the' b'series' b'via' b'their' b'pounders' b'.' b'As' b'her' b'image' b'falls' b'is' b'to' b'think' b'they' b'were' b'labeled'
b'@-@' b'space' b'when' b'the' b'feathers' b'generated' b'in' b'many' b'opinion' b',' b'they' b'eat' b'eggs' b'.' b'No' b'spiders' b'980' b'assumed' b'<unk>' b'occurring'
b'.' b'<unk>' b'plastic' b'little' b'records' b'eukaryotes' b'will' b'be' b'of' b'their' b'own' b'components' b',' b'or' b'<unk>' b'on' b'sites' b'several' b'range' b'.'
b'A' b'large' b'Reign' b'is' b'attached' b'to' b'<unk>' b'.' b'They' b'may' b'also' b'<unk>' b'females' b',' b'disciplined' b'them' b'Peggy' b'Vaballathus' b'.' b'The'
b'zenith' b'instead' b'of' b'completes' b'anomaly' b'are' b'far' b'to' b'humans' b'such' b'as' b'<unk>' b',' b'excess' b',' b'<unk>' b',' b'medium' b',' b'molecule'
b',' b'<unk>' b',' b'<unk>' b',' b'and' b'sassy' b',' b'and' b'go' b'his' b'firstborn' b'neck' b'.' b'As' b'the' b'way' b'they' b'find' b'him'
b'as' b'they' b'hadn' b"'t" b'do' b'seems' b'for' b'any' b'incident' b'that' b'is' b'first' b'<unk>' b',' b'apparently' b'when' b'their' b'planet' b'are' b'usually'
b'<unk>' b',' b'even' b'that' b',' b'without' b'eating' b'likely' b'varied' b'.' b'Once' b'the' b'hole' b',' b'with' b'lyrics' b',' b'putting' b'to' b'a'
b'fourth' b'with' b'a' b'decrease' b'.' b'The' b'sung' b'chicks' b"'" b'skull' b'or' b'several' b'meters' b'after' b'ribosomes' b'at' b'folding' b'.' b'Since' b'which'
b'will' b'no' b'longer' b'disc' b',' b'there' b'are' b'no' b'elaborate' b'strips' b'about' b'a' b'sarcophagus' b'strut' b'to' b'be' b'.' b'Common' b'birds' b'can'
b'be' b'repentant' b'under' b'both' b'coniferous' b'molecules' b'.' b'<eos>' b'The' b'major' b'starling' b'mainly' b'has' b'their' b'boom' b'but' b'covers' b'in' b'the' b'flesh'
b',' b'but' b'that' b'it' b'has' b'the' b'nuclei' b'.' b'He' b'is' b'first' b'significant' b'with' b'any' b'other' b'body' b'commissions' b'and' b'pile' b'/'
b'to' b'the' b'other' b'pale' b',' b'such' b'as' b'for' b'a' b'mountain' b'averaging' b'over' b'over' b'eight' b'days' b',' b'but' b'Plensa' b'led' b'as'
b'"' b'naturalist' b'<unk>' b'"' b',' b'"' b'Those' b'reminiscent' b'to' b'God' b'"' b'.' b'In' b'<unk>' b',' b'upper' b'species' b'of' b'DNA' b'@-@'
b'green' b'birds' b'participate' b'through' b'acid' b',' b'many' b'or' b'significant' b'iconography' b'were' b'sold' b'for' b'short' b'specimens' b'.' b'This' b'they' b'have' b'attributed'
b'to' b'Heatseekers' b'chicks' b'on' b'far' b',' b'as' b'she' b'preferred' b'chastity' b',' b'leaving' b'them' b'eucalypts' b'and' b'flesh' b'much' b'distortion' b'from' b'do'
b'not' b'feed' b'with' b'their' b'dense' b'objects' b'.' b'John' b'acceptance' b'addressed' b'has' b'successfully' b'frequented' b'habit' b'of' b'nearly' b'600' b'years' b'(' b'always'
b'more' b'believed' b'earlier' b',' b'particularly' b'only' b'is' b'so' b'a' b'average' b')' b'.' b'The' b'disappointment' b'also' b'shake' b'males' b'and' b'extinct' b'like'
b'branched' b'above' b'stones' b'of' b'Cambridge' b'throughout' b'total' b'.' b'Many' b'More' b'are' b'M.' b'Fountain' b',' b'autumn' b',' b'and' b'<unk>' b'from' b'the'

이 모델이 GPT-2는 아니지만, 언어의 구조를 배우기 시작한 것처럼 보입니다!

동적 양자화를 시연할 준비가 거의 끝났습니다. 몇 가지 helper 함수를 정의하기만 하면 됩니다:

bptt = 25
criterion = nn.CrossEntropyLoss()
eval_batch_size = 1

# 테스트 데이터셋 만들기
def batchify(data, bsz):
    # 데이터셋을 ``bsz`` 부분으로 얼마나 깔끔하게 나눌 수 있는지 계산합니다.
    nbatch = data.size(0) // bsz
    # 깔끔하게 맞지 않는 추가적인 부분(나머지들)을 잘라냅니다.
    data = data.narrow(0, 0, nbatch * bsz)
    # 데이터에 대하여 ``bsz`` 묶음(batch)들로 동등하게 나눕니다.
    return data.view(bsz, -1).t().contiguous()

test_data = batchify(corpus.test, eval_batch_size)

# 평가 함수들
def get_batch(source, i):
    seq_len = min(bptt, len(source) - 1 - i)
    data = source[i:i+seq_len]
    target = source[i+1:i+1+seq_len].reshape(-1)
    return data, target

def repackage_hidden(h):
  """은닉 상태를 변화도 기록에서 제거된 새로운 tensor로 만듭니다."""

  if isinstance(h, torch.Tensor):
      return h.detach()
  else:
      return tuple(repackage_hidden(v) for v in h)

def evaluate(model_, data_source):
    # Dropout을 중지시키는 평가 모드로 실행합니다.
    model_.eval()
    total_loss = 0.
    hidden = model_.init_hidden(eval_batch_size)
    with torch.no_grad():
        for i in range(0, data_source.size(0) - 1, bptt):
            data, targets = get_batch(data_source, i)
            output, hidden = model_(data, hidden)
            hidden = repackage_hidden(hidden)
            output_flat = output.view(-1, ntokens)
            total_loss += len(data) * criterion(output_flat, targets).item()
    return total_loss / (len(data_source) - 1)

4. 동적 양자화 테스트하기

마지막으로 모델에서 torch.quantization.quantize_dynamic 을 호출 할 수 있습니다! 구체적으로,

  • 모델의 nn.LSTMnn.Linear 모듈을 양자화 하도록 명시합니다.

  • 가중치들이 int8 값으로 변환되도록 명시합니다.

import torch.quantization

quantized_model = torch.quantization.quantize_dynamic(
    model, {nn.LSTM, nn.Linear}, dtype=torch.qint8
)
print(quantized_model)
LSTMModel(
  (drop): Dropout(p=0.5, inplace=False)
  (encoder): Embedding(33278, 512)
  (rnn): DynamicQuantizedLSTM(512, 256, num_layers=5, dropout=0.5)
  (decoder): DynamicQuantizedLinear(in_features=256, out_features=33278, dtype=torch.qint8, qscheme=torch.per_tensor_affine)
)

모델은 동일하게 보입니다. 이것이 어떻게 이득을 주는 것일까요? 첫째, 모델 크기가 상당히 줄어 듭니다:

def print_size_of_model(model):
    torch.save(model.state_dict(), "temp.p")
    print('Size (MB):', os.path.getsize("temp.p")/1e6)
    os.remove('temp.p')

print_size_of_model(model)
print_size_of_model(quantized_model)
Size (MB): 113.943637
Size (MB): 79.738057

두 번째로, 평가 손실값은 같으나 추론(inference) 속도가 빨라졌습니다.

# 메모: 양자화 된 모델은 단일 스레드로 실행되기 때문에 단일 스레드 비교를 위해
# 스레드 수를 1로 설정했습니다.

torch.set_num_threads(1)

def time_model_evaluation(model, test_data):
    s = time.time()
    loss = evaluate(model, test_data)
    elapsed = time.time() - s
    print('''loss: {0:.3f}\nelapsed time (seconds): {1:.1f}'''.format(loss, elapsed))

time_model_evaluation(model, test_data)
time_model_evaluation(quantized_model, test_data)
loss: 5.167
elapsed time (seconds): 118.9
loss: 5.168
elapsed time (seconds): 55.2

MacBook Pro에서 로컬로 실행하는 경우, 양자화 없이는 추론(inference)에 약 200초가 걸리고 양자화를 사용하면 약 100초가 걸립니다.

마치며

동적 양자화는 정확도에 제한적인 영향을 미치면서 모델 크기를 줄이는 쉬운 방법이 될 수 있습니다.

읽어주셔서 감사합니다. 언제나처럼 어떠한 피드백도 환영이니, 의견이 있다면 여기 에 이슈를 남겨 주세요.

Total running time of the script: ( 2 minutes 58.698 seconds)

Gallery generated by Sphinx-Gallery


더 궁금하시거나 개선할 내용이 있으신가요? 커뮤니티에 참여해보세요!


이 튜토리얼이 어떠셨나요? 평가해주시면 이후 개선에 참고하겠습니다! :)

© Copyright 2018-2023, PyTorch & 파이토치 한국 사용자 모임(PyTorch Korea User Group).

Built with Sphinx using a theme provided by Read the Docs.

PyTorchKorea @ GitHub

파이토치 한국 사용자 모임을 GitHub에서 만나보세요.

GitHub로 이동

한국어 튜토리얼

한국어로 번역 중인 PyTorch 튜토리얼입니다.

튜토리얼로 이동

커뮤니티

다른 사용자들과 의견을 나누고, 도와주세요!

커뮤니티로 이동