• Tutorials >
  • Flask를 사용하여 Python에서 PyTorch를 REST API로 배포하기
Shortcuts

Flask를 사용하여 Python에서 PyTorch를 REST API로 배포하기

Author: Avinash Sajjanshetty

번역: 박정환

이 튜토리얼에서는 Flask를 사용하여 PyTorch 모델을 배포하고 모델 추론(inference)을 할 수 있는 REST API를 만들어보겠습니다. 미리 훈련된 DenseNet 121 모델을 배포하여 이미지를 인식해보겠습니다.

여기서 사용한 모든 코드는 MIT 라이선스로 배포되며, GitHub 에서 확인하실 수 있습니다.

이것은 PyTorch 모델을 상용(production)으로 배포하는 튜토리얼 시리즈의 첫번째 편입니다. Flask를 여기에 소개된 것처럼 사용하는 것이 PyTorch 모델을 제공하는 가장 쉬운 방법이지만, 고성능을 요구하는 때에는 적합하지 않습니다. 그에 대해서는:

API 정의

먼저 API 엔드포인트(endpoint)의 요청(request)와 응답(response)을 정의하는 것부터 시작해보겠습니다. 새로 만들 API 엔드포인트는 이미지가 포함된 file 매개변수를 HTTP POST로 /predict 에 요청합니다. 응답은 JSON 형태이며 다음과 같은 예측 결과를 포함합니다:

{"class_id": "n02124075", "class_name": "Egyptian_cat"}

의존성(Dependencies)

아래 명령어를 실행하여 필요한 패키지들을 설치합니다:

$ pip install Flask==2.0.1 torchvision==0.10.0

간단한 웹 서버

Flask의 문서를 참고하여 아래와 같은 코드로 간단한 웹 서버를 구성합니다.

from flask import Flask
app = Flask(__name__)


@app.route('/')
def hello():
    return 'Hello World!'

또한, ImageNet 분류 ID와 이름을 포함하는 JSON을 회신하도록 응답 형식을 변경하겠습니다. 이제 app.py 는 아래와 같이 변경되었습니다:

from flask import Flask, jsonify
app = Flask(__name__)

@app.route('/predict', methods=['POST'])
def predict():
    return jsonify({'class_id': 'IMAGE_NET_XXX', 'class_name': 'Cat'})

추론(Inference)

다음 섹션에서는 추론 코드 작성에 집중하겠습니다. 먼저 이미지를 DenseNet에 공급(feed)할 수 있도록 준비하는 방법을 살펴본 뒤, 모델로부터 예측 결과를 얻는 방법을 살펴보겠습니다.

이미지 준비하기

DenseNet 모델은 224 x 224의 3채널 RGB 이미지를 필요로 합니다. 또한 이미지 텐서를 평균 및 표준편차 값으로 정규화합니다. 자세한 내용은 여기 를 참고하세요.

torchvision 라이브러리의 transforms 를 사용하여 변환 파이프라인 (transform pipeline)을 구축합니다. Transforms와 관련한 더 자세한 내용은 여기 에서 읽어볼 수 있습니다.

import io

import torchvision.transforms as transforms
from PIL import Image

def transform_image(image_bytes):
    my_transforms = transforms.Compose([transforms.Resize(255),
                                        transforms.CenterCrop(224),
                                        transforms.ToTensor(),
                                        transforms.Normalize(
                                            [0.485, 0.456, 0.406],
                                            [0.229, 0.224, 0.225])])
    image = Image.open(io.BytesIO(image_bytes))
    return my_transforms(image).unsqueeze(0)

위 메소드는 이미지를 byte 단위로 읽은 후, 일련의 변환을 적용하고 Tensor를 반환합니다. 위 메소드를 테스트하기 위해서는 이미지를 byte 모드로 읽은 후 Tensor를 반환하는지 확인하면 됩니다. (먼저 ../_static/img/sample_file.jpeg 을 컴퓨터 상의 실제 경로로 바꿔야 합니다.)

with open("../_static/img/sample_file.jpeg", 'rb') as f:
    image_bytes = f.read()
    tensor = transform_image(image_bytes=image_bytes)
    print(tensor)

예측(Prediction)

미리 학습된 DenseNet 121 모델을 사용하여 이미지 분류(class)를 예측합니다. torchvision 라이브러리의 모델을 사용하여 모델을 읽어오고 추론을 합니다. 이 예제에서는 미리 학습된 모델을 사용하지만, 직접 만든 모델에 대해서도 이와 동일한 방법을 사용할 수 있습니다. 모델을 읽어오는 것은 이 튜토리얼 을 참고하세요.

from torchvision import models

# 이미 학습된 가중치를 사용하기 위해 `weights` 에 `IMAGENET1K_V1` 값을 전달합니다:
model = models.densenet121(weights='IMAGENET1K_V1')
# 모델을 추론에만 사용할 것이므로, `eval` 모드로 변경합니다:
model.eval()


def get_prediction(image_bytes):
    tensor = transform_image(image_bytes=image_bytes)
    outputs = model.forward(tensor)
    _, y_hat = outputs.max(1)
    return y_hat

y_hat Tensor는 예측된 분류 ID의 인덱스를 포함합니다. 하지만 사람이 읽을 수 있는 분류명이 있어야 하기 때문에, 이를 위해 이름과 분류 ID를 매핑하는 것이 필요합니다. 이 파일 을 다운로드 받아 imagenet_class_index.json 이라는 이름으로 저장 후, 저장한 곳의 위치를 기억해두세요. (또는, 이 튜토리얼과 똑같이 진행하는 경우에는 tutorials/_static 에 저장하세요.) 이 파일은 ImageNet 분류 ID와 ImageNet 분류명의 쌍을 포함하고 있습니다. 이제 이 JSON 파일을 불러와 예측 결과의 인덱스에 해당하는 분류명을 가져오겠습니다.

import json

imagenet_class_index = json.load(open('../_static/imagenet_class_index.json'))

def get_prediction(image_bytes):
    tensor = transform_image(image_bytes=image_bytes)
    outputs = model.forward(tensor)
    _, y_hat = outputs.max(1)
    predicted_idx = str(y_hat.item())
    return imagenet_class_index[predicted_idx]

imagenet_class_index 사전(dictionary)을 사용하기 전에, imagenet_class_index 의 키 값이 문자열이므로 Tensor의 값도 문자열로 변환해야 합니다. 위 메소드를 테스트해보겠습니다:

with open("../_static/img/sample_file.jpeg", 'rb') as f:
    image_bytes = f.read()
    print(get_prediction(image_bytes=image_bytes))

다음과 같은 응답을 받게 될 것입니다:

['n02124075', 'Egyptian_cat']

배열의 첫번째 항목은 ImageNet 분류 ID이고, 두번째 항목은 사람이 읽을 수 있는 이름입니다.

모델을 API 서버에 통합하기

마지막으로 위에서 만든 Flask API 서버에 모델을 추가하겠습니다. API 서버는 이미지 파일을 받는 것을 가정하고 있으므로, 요청으로부터 파일을 읽도록 predict 메소드를 수정해야 합니다:

from flask import request

@app.route('/predict', methods=['POST'])
def predict():
    if request.method == 'POST':
        # we will get the file from the request
        file = request.files['file']
        # convert that to bytes
        img_bytes = file.read()
        class_id, class_name = get_prediction(image_bytes=img_bytes)
        return jsonify({'class_id': class_id, 'class_name': class_name})
import io
import json

from torchvision import models
import torchvision.transforms as transforms
from PIL import Image
from flask import Flask, jsonify, request


app = Flask(__name__)
imagenet_class_index = json.load(open('<PATH/TO/.json/FILE>/imagenet_class_index.json'))
model = models.densenet121(weights='IMAGENET1K_V1')
model.eval()


def transform_image(image_bytes):
    my_transforms = transforms.Compose([transforms.Resize(255),
                                        transforms.CenterCrop(224),
                                        transforms.ToTensor(),
                                        transforms.Normalize(
                                            [0.485, 0.456, 0.406],
                                            [0.229, 0.224, 0.225])])
    image = Image.open(io.BytesIO(image_bytes))
    return my_transforms(image).unsqueeze(0)


def get_prediction(image_bytes):
    tensor = transform_image(image_bytes=image_bytes)
    outputs = model.forward(tensor)
    _, y_hat = outputs.max(1)
    predicted_idx = str(y_hat.item())
    return imagenet_class_index[predicted_idx]


@app.route('/predict', methods=['POST'])
def predict():
    if request.method == 'POST':
        file = request.files['file']
        img_bytes = file.read()
        class_id, class_name = get_prediction(image_bytes=img_bytes)
        return jsonify({'class_id': class_id, 'class_name': class_name})


if __name__ == '__main__':
    app.run()
FLASK_ENV=development FLASK_APP=app.py flask run
import requests

resp = requests.post("http://localhost:5000/predict",
                     files={"file": open('<PATH/TO/.jpg/FILE>/cat.jpg','rb')})

resp.json() 을 호출하면 다음과 같은 결과를 출력합니다:

{"class_id": "n02124075", "class_name": "Egyptian_cat"}

다음 단계

지금까지 만든 서버는 매우 간단하여 상용 프로그램(production application)으로써 갖춰야할 것들을 모두 갖추지 못했습니다. 따라서, 다음과 같이 개선해볼 수 있습니다:

  • /predict 엔드포인트는 요청 시에 반드시 이미지 파일이 전달되는 것을 가정하고 있습니다. 하지만 모든 요청이 그렇지는 않습니다. 사용자는 다른 매개변수로 이미지를 보내거나, 이미지를 아예 보내지 않을수도 있습니다.

  • 사용자가 이미지가 아닌 유형의 파일을 보낼수도 있습니다. 여기서는 에러를 처리하지 않고 있으므로, 이러한 경우에 서버는 다운(break)됩니다. 예외 전달(exception throe)을 위한 명시적인 에러 핸들링 경로를 추가하면 잘못된 입력을 더 잘 처리할 수 있습니다.

  • 모델은 많은 종류의 이미지 분류를 인식할 수 있지만, 모든 이미지를 인식할 수 있는 것은 아닙니다. 모델이 이미지에서 아무것도 인식하지 못하는 경우를 처리하도록 개선합니다.

  • 위에서는 Flask 서버를 개발 모드에서 실행하였지만, 이는 상용으로 배포하기에는 적당하지 않습니다. Flask 서버를 상용으로 배포하는 것은 이 튜토리얼 을 참고해보세요.

  • 또한 이미지를 가져오는 양식(form)과 예측 결과를 표시하는 페이지를 만들어 UI를 추가할 수도 있습니다. 비슷한 프로젝트의 데모 와 이 데모의 소스 코드 를 참고해보세요.

  • 이 튜토리얼에서는 한 번에 하나의 이미지에 대한 예측 결과를 반환하는 서비스를 만드는 방법만 살펴보았는데요, 한 번에 여러 이미지에 대한 예측 결과를 반환하도록 수정해볼 수 있습니다. 추가로, service-streamer 라이브러리는 자동으로 요청을 큐에 넣은 뒤 모델에 공급(feed)할 수 있는 미니-배치로 샘플링합니다. 이 튜토리얼 을 참고해보세요.

  • 마지막으로 이 문서 상단에 링크된, PyTorch 모델을 배포하는 다른 튜토리얼들을 읽어보는 것을 권장합니다.

Total running time of the script: ( 0 minutes 0.000 seconds)

Gallery generated by Sphinx-Gallery


더 궁금하시거나 개선할 내용이 있으신가요? 커뮤니티에 참여해보세요!


이 튜토리얼이 어떠셨나요? 평가해주시면 이후 개선에 참고하겠습니다! :)

© Copyright 2018-2024, PyTorch & 파이토치 한국 사용자 모임(PyTorch Korea User Group).

Built with Sphinx using a theme provided by Read the Docs.

PyTorchKorea @ GitHub

파이토치 한국 사용자 모임을 GitHub에서 만나보세요.

GitHub로 이동

한국어 튜토리얼

한국어로 번역 중인 PyTorch 튜토리얼입니다.

튜토리얼로 이동

커뮤니티

다른 사용자들과 의견을 나누고, 도와주세요!

커뮤니티로 이동