This document is relevant for: Trn1

Simple MLP train script

Save the following contents as mlp_train.py:

 1import os
 2import time
 3import torch
 4from model import MLP
 5
 6from torchvision.datasets import mnist
 7from torch.utils.data import DataLoader
 8from torchvision.transforms import ToTensor
 9
10# XLA imports
11import torch_xla.core.xla_model as xm
12
13# Global constants
14EPOCHS = 4
15WARMUP_STEPS = 2
16BATCH_SIZE = 32
17
18# Load MNIST train dataset
19train_dataset = mnist.MNIST(root='./MNIST_DATA_train',
20                            train=True, download=True, transform=ToTensor())
21
22def main():
23    # Prepare data loader
24    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE)
25
26    # Fix the random number generator seeds for reproducibility
27    torch.manual_seed(0)
28
29    # XLA: Specify XLA device (defaults to a NeuronCore on Trn1 instance)
30    device = 'xla'
31
32    # Move model to device and declare optimizer and loss function
33    model = MLP().to(device)
34    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
35    loss_fn = torch.nn.NLLLoss()
36
37    # Run the training loop
38    print('----------Training ---------------')
39    model.train()
40    for epoch in range(EPOCHS):
41        start = time.time()
42        for idx, (train_x, train_label) in enumerate(train_loader):
43            optimizer.zero_grad()
44            train_x = train_x.view(train_x.size(0), -1)
45            train_x = train_x.to(device)
46            train_label = train_label.to(device)
47            output = model(train_x)
48            loss = loss_fn(output, train_label)
49            loss.backward()
50            optimizer.step()
51            xm.mark_step() # XLA: collect ops and run them in XLA runtime
52            if idx < WARMUP_STEPS: # skip warmup iterations
53                start = time.time()
54
55    # Compute statistics for the last epoch
56    interval = idx - WARMUP_STEPS # skip warmup iterations
57    throughput = interval / (time.time() - start)
58    print("Train throughput (iter/sec): {}".format(throughput))
59    print("Final loss is {:0.4f}".format(loss.detach().to('cpu')))
60
61    # Save checkpoint for evaluation
62    os.makedirs("checkpoints", exist_ok=True)
63    checkpoint = {'state_dict': model.state_dict()}
64    # XLA: use xm.save instead of torch.save to ensure states are moved back to cpu
65    # This can prevent "XRT memory handle not found" at end of test.py execution
66    xm.save(checkpoint,'checkpoints/checkpoint.pt')
67
68    print('----------End Training ---------------')
69
70if __name__ == '__main__':
71    main()

Save the following contents as model.py:

 1import torch.nn as nn
 2import torch.nn.functional as F
 3
 4# Declare 3-layer MLP for MNIST dataset
 5class MLP(nn.Module):
 6  def __init__(self, input_size = 28 * 28, output_size = 10, layers = [120, 84]):
 7      super(MLP, self).__init__()
 8      self.fc1 = nn.Linear(input_size, layers[0])
 9      self.fc2 = nn.Linear(layers[0], layers[1])
10      self.fc3 = nn.Linear(layers[1], output_size)
11
12  def forward(self, x):
13      x = F.relu(self.fc1(x))
14      x = F.relu(self.fc2(x))
15      x = self.fc3(x)
16      return F.log_softmax(x, dim=1)

This document is relevant for: Trn1