API Reference Guide for JAX NeuronX

This document is relevant for: Inf2, Trn1, Trn2

  • JAX NeuronX Environment Variables

By AWS

© Copyright 2025, Amazon.com.