This document is relevant for: Inf1, Inf2, Trn1, Trn2, Trn3

ML framework support on AWS Neuron SDK

AWS Neuron integrates with popular machine learning frameworks so that you can run your existing models on AWS Inferentia and Trainium with minimal code changes. The frameworks listed below cover both inference and training workloads.

Frameworks

PyTorch on AWS Neuron

PyTorch integration for both inference and training on Neuron hardware.

  • TorchNeuron Native - Native PyTorch backend with eager execution and torch.compile

  • PyTorch NeuronX (torch-neuronx) - Inf2, Trn1, Trn2 (inference & training)

  • See: Native PyTorch for AWS Trainium
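As a rough illustration of the torch-neuronx path listed above, the sketch below compiles a model ahead of time with `torch_neuronx.trace` when the package is installed, and otherwise returns the model unchanged. The helper names `neuron_available` and `trace_for_neuron` are hypothetical and not part of the Neuron SDK, and the sketch assumes `model` is a `torch.nn.Module` whenever torch-neuronx is present.

```python
import importlib.util


def neuron_available() -> bool:
    """Return True if the torch_neuronx package can be found in this environment."""
    return importlib.util.find_spec("torch_neuronx") is not None


def trace_for_neuron(model, example_input):
    """Compile `model` for Neuron if torch-neuronx is installed.

    Hypothetical helper: when the package is missing, the original
    model is returned untouched so the calling code still works.
    """
    if not neuron_available():
        return model

    # Imported lazily so this module also loads on machines without Neuron.
    import torch_neuronx

    model.eval()  # assumes a torch.nn.Module; tracing targets inference
    return torch_neuronx.trace(model, example_input)
```

On a Neuron instance the returned object can be called like the original model. Elsewhere the function is a no-op, which keeps the same script usable for local development.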

JAX on AWS Neuron

Beta release

Experimental JAX support with Neuron Kernel Interface (NKI) integration.

  • JAX NeuronX - Neuron hardware support

  • Research and development focus

  • Status: Beta - active

Note

Looking for TensorFlow, MXNet, or torch-neuron (Inf1) documentation? These frameworks have been archived. See Archived AWS Neuron SDK documentation for legacy framework documentation.
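The package split described above (torch-neuronx for Inf2, Trn1, and Trn2, with the archived torch-neuron covering Inf1) can be captured as a small lookup. This is only a sketch: the function name `pytorch_package_for` is hypothetical, the mapping contains only what this page states, and treating Trn1n like Trn1 is an assumption based on the "Trn1/Trn1n" column of the compatibility matrix.

```python
# Mapping taken from the framework list and the note on this page.
PYTORCH_PACKAGE_BY_FAMILY = {
    "inf1": "torch-neuron",    # archived, see the legacy documentation
    "inf2": "torch-neuronx",
    "trn1": "torch-neuronx",
    "trn1n": "torch-neuronx",  # assumption: grouped with Trn1 as in the matrix
    "trn2": "torch-neuronx",
}


def pytorch_package_for(instance_type: str) -> str:
    """Return the PyTorch Neuron package name for an EC2 instance type.

    The family is the part before the first dot, e.g. "trn1" in
    "trn1.32xlarge". Unknown families raise ValueError.
    """
    family = instance_type.lower().split(".", 1)[0]
    try:
        return PYTORCH_PACKAGE_BY_FAMILY[family]
    except KeyError:
        raise ValueError(f"no PyTorch Neuron package listed for {family!r}") from None


if __name__ == "__main__":
    for name in ("inf1.xlarge", "inf2.xlarge", "trn1.32xlarge"):
        print(name, "->", pytorch_package_for(name))
```

Raising an error for unlisted families, rather than guessing a default, keeps the helper from silently recommending a package this page does not mention.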

Hardware compatibility matrix

Columns: Framework, Inf2, Trn1/Trn1n, Trn2, Inference, Training

  • torch-neuronx - Inf2, Trn1, Trn2 (inference & training)

  • JAX NeuronX - N/A, N/A
