This document is relevant for: Trn1

Run training in a PyTorch Neuron container#

This tutorial demonstrates how to run a PyTorch container on a Trainium instance.

By the end of this tutorial you will be able to run simple MLP training inside the container.

You will use a trn1.2xlarge instance to test your Docker configuration for Trainium.

To list the Neuron devices available on your instance, run ls /dev/neuron*.

Setup Environment#

  1. Launch a Trn1 Instance
    • Follow the instructions at Launch an Amazon EC2 Instance to launch a Trn1 instance. When choosing the instance type at the EC2 console, make sure to select the correct instance type. For more information about Trn1 instance sizes and pricing, see the Trn1 web page.

    • Select your Amazon Machine Image (AMI) of choice; note that Neuron supports Amazon Linux 2 AMI (HVM) - Kernel 5.10.

    • When launching a Trn1 instance, adjust your primary EBS volume size to a minimum of 512 GB.

    • After launching the instance, follow the instructions in Connect to your instance to connect to it.


    Auto Scaling groups are currently not supported on Trn1; support will be added soon.

    To launch a Trn1 cluster, you can use AWS ParallelCluster; see the example.


    The Neuron driver installed on the Deep Learning AMI (DLAMI) with Conda does not support Trn1.

    If you want to use the DLAMI with Conda, make sure to uninstall aws-neuron-dkms and install aws-neuronx-dkms before using Neuron.
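
    The driver swap can be sketched as below, assuming an Amazon Linux 2 based DLAMI where yum is the package manager (use apt on Ubuntu-based images instead):

    ```shell
    # Remove the legacy driver package and install the Trn1-capable one.
    sudo yum remove -y aws-neuron-dkms
    sudo yum install -y aws-neuronx-dkms
    ```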

  2. Set up the Docker environment according to the tutorial Docker environment setup.

  3. A sample Dockerfile for torch-neuron can be found at Dockerfile for Application Container. This Dockerfile needs the MLP training script found at Simple MLP train script.
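
A minimal sketch of what such a Dockerfile might look like is shown below; the linked sample is the authoritative version. The base image, the Neuron pip repository URL, and the script name mlp_train.py are assumptions for illustration.

```dockerfile
# Hypothetical minimal Dockerfile sketch; see the linked sample for the real one.
FROM ubuntu:20.04

# Install Python and pip.
RUN apt-get update && apt-get install -y python3 python3-pip

# Install torch-neuron from the Neuron pip repository
# (assumption: repository URL as in the Neuron setup documentation).
RUN pip3 install torch-neuron --extra-index-url=https://pip.repos.neuron.amazonaws.com

# Copy the MLP training script into the container (hypothetical file name).
COPY mlp_train.py /opt/ml/mlp_train.py
```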

With both files in a directory, build the image with the following command:

docker build . -t neuron-container:pytorch

Run the following command to start the container:

docker run -it --name pt-cont --net=host --device=/dev/neuron0 neuron-container:pytorch python3 /opt/ml/
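
Each --device flag exposes a single Neuron device node to the container; a trn1.2xlarge exposes one (/dev/neuron0), while larger instance sizes expose more. As a sketch, a flag string covering every node reported by ls /dev/neuron* could be built like this (build_device_flags is a hypothetical helper, not part of Docker):

```shell
# Hypothetical helper: build one --device flag per Neuron device node.
build_device_flags() {
  flags=""
  for dev in "$@"; do
    flags="$flags --device=$dev"
  done
  # Unquoted expansion trims the leading space.
  echo $flags
}

# Example: an instance exposing two device nodes.
build_device_flags /dev/neuron0 /dev/neuron1
```

The result can then be spliced into the run command, e.g. docker run -it --net=host $(build_device_flags /dev/neuron*) neuron-container:pytorch ...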
