NeuronX Distributed
This document is relevant for: Inf2, Trn1, Trn1n
NeuronX Distributed is a package that supports various distributed training and inference mechanisms on Neuron devices. It provides XLA-friendly implementations of some of the more popular distributed training and inference techniques. As models grow, fitting them on a single device becomes impossible, so model sharding techniques are needed to partition a model across multiple devices. As part of this library, we enable support for the Tensor Parallel sharding technique, with support for other distributed techniques to be added in the future.
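To make the sharding idea concrete, here is a minimal pure-Python sketch of tensor parallelism for a single linear layer y = Wx, where each simulated device is just a list entry. It illustrates the technique only and is not the neuronx-distributed API: all helper names are hypothetical, and the concatenation and summation steps stand in for the all-gather and all-reduce collectives a real implementation would run across Neuron devices. The two variants correspond to what Megatron-style libraries call column-parallel (output features split) and row-parallel (input features split) linear layers.

```python
from typing import List

Matrix = List[List[float]]
Vector = List[float]


def matvec(w: Matrix, x: Vector) -> Vector:
    """Dense y = W x, with W shaped (out_features, in_features)."""
    return [sum(wij * xj for wij, xj in zip(row, x)) for row in w]


def column_parallel_matvec(w: Matrix, x: Vector, world_size: int) -> Vector:
    """Split W along out_features; every simulated device sees the full input x."""
    assert len(w) % world_size == 0, "out_features must divide evenly across devices"
    step = len(w) // world_size
    # Each device computes its own slice of the output.
    partial_outputs = [matvec(w[r * step:(r + 1) * step], x) for r in range(world_size)]
    # Stand-in for an all-gather: concatenate the output slices in rank order.
    return [value for part in partial_outputs for value in part]


def row_parallel_matvec(w: Matrix, x: Vector, world_size: int) -> Vector:
    """Split W along in_features; each simulated device sees only its slice of x."""
    assert len(x) % world_size == 0, "in_features must divide evenly across devices"
    step = len(x) // world_size
    partial_sums = []
    for r in range(world_size):
        w_shard = [row[r * step:(r + 1) * step] for row in w]
        x_shard = x[r * step:(r + 1) * step]
        partial_sums.append(matvec(w_shard, x_shard))
    # Stand-in for an all-reduce: sum the partial results element-wise.
    return [sum(values) for values in zip(*partial_sums)]


if __name__ == "__main__":
    W = [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]]
    x = [1, 0, 2, 1]
    print(matvec(W, x))
    print(column_parallel_matvec(W, x, world_size=2))
    print(row_parallel_matvec(W, x, world_size=2))
```

Both sharded versions reproduce the dense result while each device stores only 1/world_size of the weight matrix, which is why sharding lets a model that cannot fit on one device run across several. Chaining a column-parallel layer into a row-parallel layer is a common pairing because the first layer's split output can feed the second layer's split input directly.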
Setup (neuronx-distributed)
Install PyTorch Neuron on Trn1 to create a PyTorch environment. We recommend working in a Python virtual environment to avoid package installation issues.
You can install the neuronx-distributed package using the following command:
python -m pip install neuronx_distributed --extra-index-url https://pip.repos.neuron.amazonaws.com
Make sure the installed transformers package version is 4.26.0.
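The setup notes can be verified with a small pre-flight script. This is a minimal sketch, assuming Python 3.8+ (for importlib.metadata); it reports whether the interpreter runs inside a virtual environment and whether the installed transformers version matches 4.26.0. The helper names are hypothetical and are not part of neuronx-distributed.

```python
import sys
from importlib.metadata import PackageNotFoundError, version

# Version pin from this page's setup notes; adjust if your release notes differ.
REQUIRED_VERSIONS = {"transformers": "4.26.0"}


def in_virtualenv() -> bool:
    """True when running inside a venv/virtualenv (sys.prefix differs from the base install)."""
    return sys.prefix != getattr(sys, "base_prefix", sys.prefix)


def version_problems(required: dict) -> list:
    """Return human-readable mismatches between installed and required package versions."""
    problems = []
    for package, wanted in required.items():
        try:
            installed = version(package)
        except PackageNotFoundError:
            problems.append(f"{package} is not installed")
            continue
        if installed != wanted:
            problems.append(f"{package}=={installed}, expected {wanted}")
    return problems


if __name__ == "__main__":
    if not in_virtualenv():
        print("warning: not running inside a Python virtual environment")
    for problem in version_problems(REQUIRED_VERSIONS):
        print("warning:", problem)
```

An exact string comparison is used on purpose: the setup note asks for one specific transformers release rather than a minimum version.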
App Notes (neuronx-distributed)
API Reference Guide (neuronx-distributed)
Developer Guide (neuronx-distributed)
Developer guide for Tensor Parallelism (neuronx-distributed)
Developer guide for Pipeline Parallelism (neuronx-distributed)
Developer guide for Activation Memory reduction (neuronx-distributed)
Developer guide for Neuron-PT-Lightning (neuronx-distributed)
Developer guide for model and optimizer wrapper (neuronx-distributed)
Developer guide for save/load checkpoint (neuronx-distributed)
Developer guide for Neuronx-Distributed Inference (neuronx-distributed)
Tutorials (neuronx-distributed)
Training GPT-NeoX 6.9B with Tensor Parallelism and ZeRO-1 Optimizer (neuronx-distributed)
Training GPT-NeoX 20B with Tensor Parallelism and ZeRO-1 Optimizer (neuronx-distributed)
Training Llama2 7B with Tensor Parallelism and ZeRO-1 Optimizer (neuronx-distributed)
Training Llama-2-13B/70B with Tensor Parallelism and Pipeline Parallelism (neuronx-distributed)
T5 inference tutorial
Llama inference tutorial
Misc (neuronx-distributed)