This document is relevant for: Inf1, Inf2, Trn1, Trn2, Trn3
About PyTorch on AWS Neuron
This topic provides an overview of PyTorch support in Neuron for AWS Inf* (Inferentia-based) and Trn* (Trainium-based) ML platforms.
Since 2019, AWS Neuron has evolved its PyTorch support to match the capabilities and architectures of successive generations of AWS ML accelerators. It has delivered three distinct PyTorch implementations, each optimized for different hardware platforms and use cases:
torch-neuron (2019): Graph-based inference for Inferentia (Inf1)
torch-neuronx (2022): XLA-based training and inference for Inferentia2 (Inf2) and Trainium (Trn1/Trn2)
TorchNeuron (2025): Native PyTorch backend for Trainium (Trn2/Trn3) with eager mode and torch.compile
Overview
AWS Neuron’s PyTorch support has evolved through three major implementations, each designed to leverage the unique capabilities of AWS ML accelerators:
torch-neuron (2019-2026): The original PyTorch integration for AWS Inferentia (Inf1), focused on inference workloads with a graph-based compilation approach
torch-neuronx (2022-): An XLA-based PyTorch implementation for AWS Inferentia2 (Inf2) and Trainium (Trn1/Trn2/Trn3), supporting both training and inference with distributed computing capabilities
TorchNeuron (2025-): A native PyTorch backend for Trainium that provides eager mode execution, torch.compile support, and standard PyTorch distributed APIs without requiring XLA
Each implementation represents a significant architectural evolution, reflecting advances in both AWS ML accelerator hardware and PyTorch framework capabilities.
torch-neuron for Inf1
The first Neuron library supporting PyTorch, torch-neuron, was initially released in December 2019 alongside the launch of AWS Inferentia. This implementation introduced PyTorch developers to AWS’s purpose-built ML inference accelerators.
torch-neuron uses a graph-based compilation approach where PyTorch models are traced and compiled into optimized Neuron Executable File Format (NEFF) binaries. The library integrates with PyTorch through custom operators and provides APIs for model compilation (torch.neuron.trace) and execution on Inferentia NeuronCores.
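The trace-and-compile workflow described above can be sketched roughly as follows. This is a minimal sketch, assuming the torch-neuron package is installed on an Inf1 instance. The helper names `compile_for_inf1` and `load_compiled` and the default file name are hypothetical and introduced only for illustration; `torch.neuron.trace` is the API named in this section.

```python
def compile_for_inf1(model, example_input, output_path="model_neuron.pt"):
    """Trace `model` ahead of time for Inferentia and save the compiled artifact.

    Hypothetical helper: the name and the default path are illustrative only.
    """
    # Imported lazily so this module can be loaded on machines without the Neuron SDK.
    import torch
    import torch_neuron  # noqa: F401  (importing it makes torch.neuron available)

    model.eval()  # tracing is meant for inference, so disable dropout and similar layers
    neuron_model = torch.neuron.trace(model, example_inputs=[example_input])
    neuron_model.save(output_path)  # the traced result is saved like a TorchScript module
    return neuron_model


def load_compiled(path="model_neuron.pt"):
    """Load a previously compiled model for execution on NeuronCores."""
    import torch
    import torch_neuron  # noqa: F401  (must be imported before loading so Neuron ops are registered)

    return torch.jit.load(path)
```

Because compilation happens ahead of time, the example input fixes the tensor shapes the compiled model expects. A typical flow would call `compile_for_inf1` once during a build step and `load_compiled` in the serving process.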
Key characteristics of torch-neuron:
Target Platform: AWS Inferentia (Inf1 instances)
Primary Use Case: Inference workloads
Compilation Approach: Ahead-of-time (AOT) graph compilation via torch.neuron.trace
Supported Models: Computer vision models (ResNet, VGG, EfficientNet, YOLO variants), NLP models (BERT, RoBERTa, DistilBERT, MarianMT), and other inference-optimized architectures
Integration Method: Custom PyTorch operators and tracing API
When to choose torch-neuron
Choose torch-neuron when:
Deploying inference workloads on AWS Inferentia (Inf1) instances
Working with models that can be traced and compiled ahead of time
Optimizing for inference latency and throughput on first-generation Inferentia hardware
Requiring compatibility with existing Inf1-based infrastructure
torch-neuronx for Inf2 and Trn1
In October 2022, AWS introduced Inferentia2 and Trainium, second-generation ML accelerators with enhanced capabilities for both training and inference. To support these platforms, Neuron delivered torch-neuronx, a new PyTorch implementation built on PyTorch/XLA.
torch-neuronx represents a significant architectural shift from torch-neuron, leveraging the XLA (Accelerated Linear Algebra) compiler infrastructure to enable both training and inference workloads. This XLA-based approach provides support for dynamic shapes, control flow, distributed training primitives, and advanced parallelism strategies.
Key characteristics of torch-neuronx:
Target Platforms: AWS Inferentia2 (Inf2 instances) and AWS Trainium (Trn1, Trn1n, Trn2, Trn3 instances)
Primary Use Cases: Both training and inference workloads
Compilation Approach: XLA-based compilation with support for dynamic shapes and control flow
Distributed Computing: Native support for data parallelism, tensor parallelism, pipeline parallelism, sequence parallelism, and Zero Redundancy Optimizer (ZeRO)
Training Capabilities: Full support for large-scale model training including LLMs (Llama, GPT, BERT families), with gradient accumulation, mixed precision training, and distributed checkpointing
Inference Capabilities: Support for large language model inference with features like continuous batching, speculative decoding, and quantization
Integration Method: PyTorch/XLA device backend (xla device type)
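To make the XLA device integration concrete, here is a minimal single-device training loop. It is a sketch that assumes torch-neuronx and its PyTorch/XLA dependency are installed on a Trn or Inf2 instance. The function name `train_one_epoch`, the optimizer, and the loss function are illustrative choices, not something this document prescribes, and the PyTorch/XLA helper names may differ between releases.

```python
def train_one_epoch(model, loader, lr=1e-3):
    """Run one pass over `loader` on the XLA device and return the last loss value.

    Hypothetical helper: name, optimizer, and loss are illustrative choices.
    """
    # Imported lazily so this module can be loaded on machines without the Neuron SDK.
    import torch
    import torch_xla.core.xla_model as xm  # PyTorch/XLA, which torch-neuronx builds on

    device = xm.xla_device()  # resolves to the `xla` device type
    model = model.to(device)
    optimizer = torch.optim.SGD(model.parameters(), lr=lr)
    loss_fn = torch.nn.CrossEntropyLoss()

    model.train()
    loss = None
    for inputs, labels in loader:
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        loss = loss_fn(model(inputs), labels)
        loss.backward()
        optimizer.step()
        # Operations are recorded lazily; mark_step() cuts the graph here so that
        # XLA compiles and executes everything recorded for this iteration.
        xm.mark_step()

    # Returns None when the loader yielded no batches.
    return None if loss is None else loss.item()
```

The explicit `mark_step()` call is the main difference from an ordinary PyTorch loop and reflects the lazy, graph-based execution model of PyTorch/XLA.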
The XLA-based architecture enables torch-neuronx to support advanced training techniques and distributed strategies that were not possible with the original torch-neuron implementation. This includes support for frameworks like NeuronX Distributed (NxD) for training and inference, Transformers NeuronX for LLM inference, and integration with popular ML libraries like HuggingFace Transformers and PyTorch Lightning.
When to choose torch-neuronx
Choose torch-neuronx when:
Training models on AWS Trainium (Trn1, Trn1n, Trn2) instances
Running inference on AWS Inferentia2 (Inf2) instances
Requiring distributed training capabilities with tensor parallelism, pipeline parallelism, or data parallelism
Working with large language models or other models requiring multi-device training
Needing dynamic shape support or control flow in your models
Using PyTorch versions 2.5 through 2.9 (XLA-based implementation)
Note: The PyTorch 2.9 release of torch-neuronx is the last one based on PyTorch/XLA. Starting with PyTorch 2.10 support (planned for a future Neuron release), torch-neuronx will transition to the native PyTorch implementation (TorchNeuron).
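The version boundary in this note is easy to get wrong when comparing version strings directly, because "2.10" sorts before "2.9" as text. The following sketch shows one way to check which side of the boundary a PyTorch version falls on. The names `major_minor`, `is_xla_based`, and `FIRST_NATIVE_VERSION` are hypothetical and not part of any Neuron API.

```python
import re

# First PyTorch version for which, per the note above, torch-neuronx
# is planned to use the native backend instead of PyTorch/XLA.
FIRST_NATIVE_VERSION = (2, 10)


def major_minor(version: str) -> tuple:
    """Extract (major, minor) as integers from a version string such as '2.9.1+cpu'."""
    match = re.match(r"(\d+)\.(\d+)", version)
    if match is None:
        raise ValueError(f"unrecognized version string: {version!r}")
    return int(match.group(1)), int(match.group(2))


def is_xla_based(torch_version: str) -> bool:
    """Return True when the given PyTorch version falls on the PyTorch/XLA side."""
    return major_minor(torch_version) < FIRST_NATIVE_VERSION


print(is_xla_based("2.9.1"))   # True
print(is_xla_based("2.10.0"))  # False
print("2.10" < "2.9")          # True: plain string comparison gives the wrong order
```

Comparing integer tuples rather than strings keeps the check correct once the minor version reaches two digits.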
TorchNeuron (Native PyTorch integration)
TorchNeuron, the latest evolution of PyTorch support for Neuron, was announced in December 2025 at AWS re:Invent and shipped its initial version as part of Neuron release 2.27.0. While it retains the same Python package name as its predecessor (torch-neuronx), TorchNeuron is an entirely new native PyTorch backend developed specifically for Trainium platforms.
TorchNeuron represents a fundamental architectural shift from XLA-based compilation to native PyTorch integration through the PrivateUse1 device backend mechanism. This native integration enables PyTorch code to run on Trainium with minimal modifications, supporting both eager mode execution for rapid iteration and torch.compile for production optimization.
Key characteristics of TorchNeuron:
Target Platforms: AWS Trainium (Trn2, Trn3 instances)
Primary Use Cases: Training and inference workloads with native PyTorch workflows
Execution Modes:
    Eager Mode: Immediate operation execution for interactive development and debugging
    torch.compile: Just-in-time (JIT) compilation via TorchDynamo for optimized performance
Distributed APIs: Native support for standard PyTorch distributed primitives:
    Fully Sharded Data Parallel (FSDP)
    Distributed Tensor (DTensor)
    Distributed Data Parallel (DDP)
    Tensor Parallelism (TP)
Integration Method: Native PyTorch backend via PrivateUse1 mechanism (neuron device type)
Ecosystem Compatibility: Works with TorchTitan, HuggingFace Transformers, and other PyTorch ecosystem tools with minimal code changes
Custom Kernels: Integration with Neuron Kernel Interface (NKI) for performance-critical operations
Open Source: Available on GitHub under Apache 2.0 license
TorchNeuron’s native integration eliminates the need for XLA-specific APIs and enables researchers and ML developers to use familiar PyTorch patterns. The eager mode support provides immediate feedback during development, while torch.compile delivers production-grade performance through hardware-specific optimizations.
The implementation includes Adaptive Eager Execution, which applies optimizations like operator fusion while maintaining functional accuracy and debuggability. This approach provides a balance between development velocity and runtime performance.
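As an illustration of what "minimal modifications" can look like in eager mode, the sketch below runs one ordinary PyTorch training step in which the only Neuron-specific element is the device string. It rests on two assumptions: that importing the `torch_neuronx` module registers the `neuron` device type (this document states only the package name and the device type), and that the operations used are supported on the device. The function name `eager_train_step` and the toy model are hypothetical.

```python
def eager_train_step(device_name="neuron"):
    """Run one eager-mode training step on `device_name` and return the loss value.

    Hypothetical helper: the toy model and the random data are for illustration only.
    """
    # Imported lazily so this module can be loaded on machines without PyTorch or Neuron.
    import torch

    if device_name == "neuron":
        # ASSUMPTION: importing the package registers the `neuron` device type.
        import torch_neuronx  # noqa: F401

    device = torch.device(device_name)
    model = torch.nn.Linear(16, 4).to(device)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

    inputs = torch.randn(8, 16, device=device)
    targets = torch.randn(8, 4, device=device)

    # Each operation executes immediately, so intermediate values can be
    # inspected with a debugger or print statements between any two lines.
    optimizer.zero_grad()
    loss = torch.nn.functional.mse_loss(model(inputs), targets)
    loss.backward()
    optimizer.step()
    return loss.item()
```

Calling `eager_train_step("cpu")` runs the same code on any machine with PyTorch installed, which shows that no XLA-specific calls are involved. Wrapping the model with torch.compile is the second execution mode; the backend selection it needs on Trainium is not covered by this sketch.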
When to choose TorchNeuron
Choose TorchNeuron (native PyTorch) when:
Training models on AWS Trainium (Trn2, Trn3) instances with PyTorch 2.10 or later
Requiring eager mode execution for interactive development and debugging
Using standard PyTorch distributed training APIs (FSDP, DTensor, DDP)
Working with PyTorch ecosystem tools like TorchTitan or HuggingFace Transformers
Needing minimal code changes to run existing PyTorch code on Trainium
Leveraging torch.compile for production optimization
Developing custom kernels with Neuron Kernel Interface (NKI)
Migration Note: Starting with PyTorch 2.10 support (planned for a future Neuron release), AWS Neuron will transition from PyTorch/XLA to native PyTorch support via TorchNeuron. Users on PyTorch 2.9 or earlier will need to update their scripts when upgrading to PyTorch 2.10 or later. See Native PyTorch for AWS Trainium for complete migration guidance.
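The platform and workload guidance from the three "Key characteristics" lists can be summarized as a small lookup table. This is an illustrative sketch only: the `Implementation` class, the `IMPLEMENTATIONS` table, and the `candidates` function are hypothetical names, and the platform sets are copied from the lists in this document rather than from any Neuron API.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass(frozen=True)
class Implementation:
    name: str
    platforms: frozenset          # lower-cased instance families
    workloads: tuple              # "training" and/or "inference"
    device_type: Optional[str]    # PyTorch device type, if the integration exposes one


# Values taken from the "Key characteristics" lists in this document.
IMPLEMENTATIONS = (
    Implementation("torch-neuron", frozenset({"inf1"}), ("inference",), None),
    Implementation(
        "torch-neuronx",
        frozenset({"inf2", "trn1", "trn1n", "trn2", "trn3"}),
        ("training", "inference"),
        "xla",
    ),
    Implementation(
        "TorchNeuron",
        frozenset({"trn2", "trn3"}),
        ("training", "inference"),
        "neuron",
    ),
)


def candidates(instance_family: str, workload: str) -> list:
    """Return the names of implementations listed for this platform and workload."""
    family = instance_family.strip().lower()
    return [
        impl.name
        for impl in IMPLEMENTATIONS
        if family in impl.platforms and workload in impl.workloads
    ]


print(candidates("Inf1", "inference"))  # ['torch-neuron']
print(candidates("Inf1", "training"))   # []
print(candidates("Trn2", "training"))   # ['torch-neuronx', 'TorchNeuron']
```

Where two candidates are returned, as for Trn2 and Trn3, the PyTorch version decides between them: 2.9 or earlier uses the PyTorch/XLA implementation, and 2.10 or later uses the native one.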
Read More
Training Resources
Training with torch-neuronx - Training guides and tutorials for Trainium
PyTorch Neuron Programming Guide - Core concepts for training on Neuron
NeuronX Distributed (NxD) Training - Distributed training library for large-scale models
PyTorch Training Tutorials - Step-by-step training examples
Inference Resources
Inference with torch-neuronx - Inference guides for Inf2 and Trn1/Trn2
Inference with torch-neuron - Inference guides for Inf1
NeuronX Distributed Inference (NxDI) - Inference library for large language models
torch-neuron vs torch-neuronx Comparison - Detailed comparison for inference workloads
Architecture and Hardware
AWS Inferentia Architecture - Inf1 hardware architecture
AWS Inferentia2 Architecture - Inf2 hardware architecture
AWS Trainium Architecture - Trn1 hardware architecture
AWS Trainium2 Architecture - Trn2 hardware architecture
AWS Trainium3 Architecture - Trn3 hardware architecture