This document is relevant for: Inf1, Inf2, Trn1, Trn2

ML framework support on AWS Neuron SDK

AWS Neuron integrates with popular machine learning frameworks, so you can run existing models on AWS Inferentia and AWS Trainium with minimal code changes. The sections below summarize which frameworks are supported for inference and training on each Neuron instance family.

PyTorch on AWS Neuron

Supports PyTorch 2.8

PyTorch integration for inference on all Neuron hardware, and for training through PyTorch NeuronX.

  • PyTorch NeuronX - Inf2, Trn1, Trn2 (inference & training)

  • PyTorch Neuron - Inf1 (inference only)

  • Native PyTorch API compatibility

  • Distributed training support

  • Advanced profiling and debugging tools
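The split between the two PyTorch packages can be expressed as a small lookup. The sketch below is a hypothetical helper, not part of the Neuron SDK. It encodes only the rules listed above: PyTorch NeuronX for Inf2, Trn1/Trn1n, and Trn2, and PyTorch Neuron for Inf1. The `framework_supports_training` flag reflects the framework-level capability stated on this page. It does not guarantee that training is supported on every individual instance family. The pip distribution names `torch-neuronx` and `torch-neuron` and their import names are the published package names. The function and type names are illustrative.

```python
from typing import NamedTuple


class NeuronPyTorchPackage(NamedTuple):
    """Hypothetical record describing one Neuron PyTorch package."""
    pip_name: str                      # distribution name used with `pip install`
    import_name: str                   # module name used in `import ...`
    framework_supports_training: bool  # framework-level capability, per this page


# PyTorch NeuronX: inference and training. PyTorch Neuron: inference only.
_NEURONX = NeuronPyTorchPackage("torch-neuronx", "torch_neuronx", True)
_NEURON = NeuronPyTorchPackage("torch-neuron", "torch_neuron", False)

# Instance family -> package, following the PyTorch list above.
_BY_FAMILY = {
    "inf1": _NEURON,
    "inf2": _NEURONX,
    "trn1": _NEURONX,
    "trn1n": _NEURONX,
    "trn2": _NEURONX,
}


def pytorch_package_for(instance_family: str) -> NeuronPyTorchPackage:
    """Return the Neuron PyTorch package for an instance family such as 'Trn2'."""
    key = instance_family.strip().lower()
    try:
        return _BY_FAMILY[key]
    except KeyError:
        raise ValueError(f"unsupported instance family: {instance_family!r}") from None


if __name__ == "__main__":
    for family in ("Inf1", "Inf2", "Trn1n", "Trn2"):
        pkg = pytorch_package_for(family)
        print(f"{family}: pip install {pkg.pip_name} "
              f"(framework supports training: {pkg.framework_supports_training})")
```

Keeping the mapping in a single dictionary means any future change to the supported instance families touches only one place.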

JAX on AWS Neuron

Beta release

Experimental JAX support with Neuron Kernel Interface (NKI) integration.

  • JAX NeuronX - Neuron hardware support

  • NKI JAX interface

  • Research and development focus

  • Custom kernel development

  • Status: Beta - active development

Hardware compatibility matrix

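Framework          | Inf1 | Inf2 | Trn1/Trn1n | Trn2 | Inference | Training
PyTorch NeuronX    |  ✗   |  ✓   |     ✓      |  ✓   |     ✓     |    ✓
PyTorch Neuron     |  ✓   |  ✗   |     ✗      |  ✗   |     ✓     |    ✗
TensorFlow NeuronX |      |      |            |      |           |
TensorFlow Neuron  |      |      |            |      |           |
MXNet Neuron       |      |      |            |      |           |
JAX NeuronX 🚧     |      |      |            |      |           |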
