This document is relevant for: Trn1, Trn1n
Simple MLP train script
Save the following contents as mlp_train.py:
import os
import time
import torch
from model import MLP

from torchvision.datasets import mnist
from torch.utils.data import DataLoader
from torchvision.transforms import ToTensor

# XLA imports
import torch_xla.core.xla_model as xm

# Global constants
EPOCHS = 4
WARMUP_STEPS = 2
BATCH_SIZE = 32

# Load MNIST train dataset
train_dataset = mnist.MNIST(root='./MNIST_DATA_train',
                            train=True, download=True, transform=ToTensor())

def main():
    # Prepare data loader
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE)

    # Fix the random number generator seeds for reproducibility
    torch.manual_seed(0)

    # XLA: Specify XLA device (defaults to a NeuronCore on Trn1 instance)
    device = 'xla'

    # Move model to device and declare optimizer and loss function
    model = MLP().to(device)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    loss_fn = torch.nn.NLLLoss()

    # Run the training loop
    print('----------Training ---------------')
    model.train()
    for epoch in range(EPOCHS):
        start = time.time()
        for idx, (train_x, train_label) in enumerate(train_loader):
            optimizer.zero_grad()
            train_x = train_x.view(train_x.size(0), -1)
            train_x = train_x.to(device)
            train_label = train_label.to(device)
            output = model(train_x)
            loss = loss_fn(output, train_label)
            loss.backward()
            optimizer.step()
            xm.mark_step()  # XLA: collect ops and run them in XLA runtime
            if idx < WARMUP_STEPS:  # skip warmup iterations
                start = time.time()

    # Compute statistics for the last epoch
    interval = idx - WARMUP_STEPS  # skip warmup iterations
    throughput = interval / (time.time() - start)
    print("Train throughput (iter/sec): {}".format(throughput))
    print("Final loss is {:0.4f}".format(loss.detach().to('cpu')))

    # Save checkpoint for evaluation
    os.makedirs("checkpoints", exist_ok=True)
    checkpoint = {'state_dict': model.state_dict()}
    # XLA: use xm.save instead of torch.save to ensure states are moved back to cpu
    # This can prevent "XRT memory handle not found" at end of test.py execution
    xm.save(checkpoint, 'checkpoints/checkpoint.pt')

    print('----------End Training ---------------')

if __name__ == '__main__':
    main()
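Once both mlp_train.py and model.py (below) are saved, the script can be run directly with python mlp_train.py on a Trn1 instance with the torch-neuronx environment activated. The first few iterations are slower while the XLA compiler traces and compiles the graph, which is why WARMUP_STEPS iterations are excluded from the throughput measurement. Because xm.save moves all tensors back to the CPU before writing, the resulting checkpoint can be reloaded and evaluated with plain PyTorch. The sketch below is illustrative, not part of the tutorial: the file name eval_mlp.py, the MNIST_DATA_test directory, and the accuracy loop are assumptions.

# eval_mlp.py -- illustrative sketch, not part of the tutorial
import torch
from torch.utils.data import DataLoader
from torchvision.datasets import mnist
from torchvision.transforms import ToTensor

from model import MLP

# Load the checkpoint written by mlp_train.py; xm.save already moved tensors to CPU
checkpoint = torch.load('checkpoints/checkpoint.pt')
model = MLP()
model.load_state_dict(checkpoint['state_dict'])
model.eval()

# MNIST test split (directory name is an assumption)
test_dataset = mnist.MNIST(root='./MNIST_DATA_test', train=False,
                           download=True, transform=ToTensor())
test_loader = DataLoader(test_dataset, batch_size=32)

correct = total = 0
with torch.no_grad():
    for test_x, test_label in test_loader:
        test_x = test_x.view(test_x.size(0), -1)  # flatten 28x28 images, as in training
        pred = model(test_x).argmax(dim=1)
        correct += (pred == test_label).sum().item()
        total += test_label.size(0)

print('Test accuracy: {:.2f}%'.format(100.0 * correct / total))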
Save the following contents as model.py:
import torch.nn as nn
import torch.nn.functional as F

# Declare 3-layer MLP for MNIST dataset
class MLP(nn.Module):
    def __init__(self, input_size=28 * 28, output_size=10, layers=[120, 84]):
        super(MLP, self).__init__()
        self.fc1 = nn.Linear(input_size, layers[0])
        self.fc2 = nn.Linear(layers[0], layers[1])
        self.fc3 = nn.Linear(layers[1], output_size)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return F.log_softmax(x, dim=1)
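Since MLP is ordinary PyTorch with no XLA-specific code, a quick CPU-side sanity check can confirm the expected tensor shapes before training on the device. This is a minimal sketch; the dummy batch is illustrative:

import torch
from model import MLP

model = MLP()
x = torch.randn(32, 28 * 28)  # a flattened batch of 32 MNIST-sized inputs
out = model(x)
print(out.shape)  # torch.Size([32, 10]): one log-probability vector per image
# log_softmax outputs exponentiate to a valid probability distribution per row
assert torch.allclose(out.exp().sum(dim=1), torch.ones(32), atol=1e-5)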
This document is relevant for: Trn1, Trn1n