Example training code for an MNIST classifier built with PyTorch Lightning; all hyperparameters are supplied on the command line.

import argparse
import logging
import os
import sys
import time
import warnings

import torch
from IPython.utils import io
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning import loggers as pl_loggers
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torchmetrics.functional.classification.accuracy import multiclass_accuracy
from torchvision import transforms
from torchvision.datasets import MNIST

# Keep per-trial console output quiet: suppress all Python warnings and
# restrict pytorch_lightning's own logger to errors only.
warnings.filterwarnings("ignore")  # Disable data logger warnings
logging.getLogger("pytorch_lightning").setLevel(logging.ERROR)  # Disable GPU/TPU prints

def parse_args():
    """Parse the per-trial hyperparameters from the command line.

    Every flag is required so that each training trial is fully specified.

    Returns:
        argparse.Namespace with ``log_path``, ``hidden_size_1``,
        ``hidden_size_2``, ``learning_rate``, ``epochs``, ``dropout``
        and ``batch_size``.
    """
    parser = argparse.ArgumentParser(description="train mnist")
    parser.add_argument(
        "--log_path", type=str, required=True, help="dir to place tensorboard logs from all trials"
    )
    parser.add_argument("--hidden_size_1", type=int, required=True, help="hidden size layer 1")
    parser.add_argument("--hidden_size_2", type=int, required=True, help="hidden size layer 2")
    parser.add_argument("--learning_rate", type=float, required=True, help="learning rate")
    parser.add_argument("--epochs", type=int, required=True, help="number of epochs")
    parser.add_argument("--dropout", type=float, required=True, help="dropout probability")
    parser.add_argument("--batch_size", type=int, required=True, help="batch size")
    return parser.parse_args()

# Parse CLI hyperparameters once at import time; MnistModel reads them below.
args = parse_args()

# Root directory for the MNIST download/cache (defaults to the working dir).
PATH_DATASETS = os.environ.get("PATH_DATASETS", ".")

class MnistModel(LightningModule):
    """LightningModule that trains a small MLP classifier on MNIST.

    Architecture and optimization hyperparameters are read from the
    module-level ``args`` namespace parsed at import time. A hidden-layer
    size of 0 skips that layer entirely, so the model can range from a
    linear classifier up to a two-hidden-layer MLP.
    """

    def __init__(self):
        # Required: Lightning's base class must be initialized before any
        # attribute assignment on the module.
        super().__init__()

        # Tunable parameters
        self.hidden_size_1 = args.hidden_size_1
        self.hidden_size_2 = args.hidden_size_2
        self.learning_rate = args.learning_rate
        self.dropout = args.dropout
        self.batch_size = args.batch_size

        # Set class attributes
        self.data_dir = PATH_DATASETS

        # Hardcode some dataset specific attributes
        self.num_classes = 10
        self.dims = (1, 28, 28)
        channels, width, height = self.dims
        # ToTensor converts the PIL image to a float tensor, which Normalize
        # requires; the constants are the standard MNIST mean/std.
        self.transform = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Normalize((0.1307,), (0.3081,)),
            ]
        )

        # Create a PyTorch model: each active hidden layer is
        # Linear -> ReLU -> Dropout; a size of 0 skips the layer.
        layers = [nn.Flatten()]
        width = channels * width * height
        hidden_layers = [self.hidden_size_1, self.hidden_size_2]
        num_params = 0
        for hidden_size in hidden_layers:
            if hidden_size > 0:
                layers.append(nn.Linear(width, hidden_size))
                layers.append(nn.ReLU())
                layers.append(nn.Dropout(self.dropout))
                # Count weight parameters only (biases excluded).
                num_params += width * hidden_size
                width = hidden_size
        layers.append(nn.Linear(width, self.num_classes))
        num_params += width * self.num_classes

        # Save the model and parameter counts
        self.num_params = num_params
        self.model = nn.Sequential(*layers)  # No need to use Relu for the last layer

    def forward(self, x):
        """Return per-class log-probabilities for a batch of images."""
        x = self.model(x)
        return F.log_softmax(x, dim=1)

    def training_step(self, batch, batch_idx):
        """Compute the NLL training loss for one batch."""
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)
        return loss

    def validation_step(self, batch, batch_idx):
        """Compute validation loss and log multiclass accuracy as "val_acc"."""
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)
        preds = torch.argmax(logits, dim=1)
        acc = multiclass_accuracy(preds, y, num_classes=self.num_classes)
        self.log("val_acc", acc, prog_bar=False)
        return loss

    def configure_optimizers(self):
        """Use Adam with the CLI-supplied learning rate."""
        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
        return optimizer

    def prepare_data(self):
        """Download MNIST (train and test splits) once per node."""
        MNIST(self.data_dir, train=True, download=True)
        MNIST(self.data_dir, train=False, download=True)

    def setup(self, stage=None):
        """Instantiate the train/val datasets with the normalization transform."""
        self.mnist_train = MNIST(self.data_dir, train=True, transform=self.transform)
        self.mnist_val = MNIST(self.data_dir, train=False, transform=self.transform)

    def train_dataloader(self):
        """Return the training DataLoader with the CLI-supplied batch size."""
        return DataLoader(self.mnist_train, batch_size=self.batch_size)

    def val_dataloader(self):
        """Return the validation DataLoader (MNIST test split)."""
        return DataLoader(self.mnist_val, batch_size=self.batch_size)

def run_training_job():
    """Run one training trial: fit the model, then log train time,
    validation accuracy, and parameter count to TensorBoard and stdout.
    """
    mnist_model = MnistModel()

    # Initialize a trainer (don't log anything since things get so slow...)
    trainer = Trainer(
        logger=False,
        max_epochs=args.epochs,
        enable_progress_bar=False,
        deterministic=True,  # Do we want a bit of noise?
    )

    # Separate metrics logger so the trial's scalars still reach TensorBoard.
    logger = pl_loggers.TensorBoardLogger(args.log_path)

    print(f"Logging to path: {args.log_path}.")

    # Train the model and log time ⚡
    start = time.time()
    trainer.fit(model=mnist_model)
    end = time.time()
    train_time = end - start
    logger.log_metrics({"train_time": end - start})

    # Compute the validation accuracy once and log the score
    with io.capture_output() as captured:  # suppress the printed validation table
        val_accuracy = trainer.validate()[0]["val_acc"]
    logger.log_metrics({"val_acc": val_accuracy})

    # Log the number of model parameters
    num_params = trainer.model.num_params
    logger.log_metrics({"num_params": num_params})

    # Print outputs
    print(f"train time: {train_time}, val acc: {val_accuracy}, num_params: {num_params}")

if __name__ == "__main__":
    # Script entry point: run a single training trial with the CLI arguments.
    run_training_job()

Total running time of the script: ( 0 minutes 0.000 seconds)

Gallery generated by Sphinx-Gallery

더 궁금하시거나 개선할 내용이 있으신가요? 커뮤니티에 참여해보세요!

이 튜토리얼이 어떠셨나요? 평가해주시면 이후 개선에 참고하겠습니다! :)

© Copyright 2018-2024, PyTorch & 파이토치 한국 사용자 모임(PyTorch Korea User Group).

Built with Sphinx using a theme provided by Read the Docs.

PyTorchKorea @ GitHub

파이토치 한국 사용자 모임을 GitHub에서 만나보세요.

GitHub로 이동

한국어 튜토리얼

한국어로 번역 중인 PyTorch 튜토리얼입니다.

튜토리얼로 이동


다른 사용자들과 의견을 나누고, 도와주세요!

커뮤니티로 이동