This document is relevant for: Inf2, Trn1, Trn1n
Developer guide for Tensor Parallelism#
Training#
For training models with tensor parallelism, one has to make a few changes to the model/training script. Below we walk through the different changes needed to shard a model across devices.
Creating DataLoader:#
When we shard the model across devices using tensor parallelism, all the tensor parallel workers operate on the same batch of data. Hence, to ensure that each tensor parallel worker gets the same data, we make use of the DistributedSampler, as shown in the snippet below:
from torch.utils.data import DataLoader, DistributedSampler

from neuronx_distributed.parallel_layers import parallel_state


def create_pretraining_dataset(
    input_file, max_pred_length, mini_batch_size, worker_init
):
    train_data = pretraining_dataset(
        input_file=input_file, max_pred_length=max_pred_length
    )
    # To distribute the data across the different workers in the world,
    # we use the DistributedSampler. num_replicas should be equal to
    # the data_parallel_world_size. Note: a single data_parallel_rank (e.g.
    # data_parallel_rank=0) can have multiple tensor parallel ranks, and
    # each of these should get the same data.
    train_sampler = DistributedSampler(
        train_data,
        num_replicas=parallel_state.get_data_parallel_world_size(),
        rank=parallel_state.get_data_parallel_rank(),
    )
    train_dataloader = DataLoader(
        train_data,
        sampler=train_sampler,
        batch_size=mini_batch_size,
        num_workers=0,
        worker_init_fn=worker_init,
        drop_last=True,
        pin_memory=True,
    )
    return train_dataloader
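To see why num_replicas must be the data-parallel world size rather than the total number of workers, here is a small back-of-the-envelope sketch (the sizes below are purely illustrative):

# Illustrative only: 32 workers in total, with each model replica sharded
# across 8 devices using tensor parallelism.
world_size = 32
tensor_parallel_size = 8

# Workers that hold shards of the same model replica must see the same data,
# so the number of unique data shards is the data-parallel world size:
data_parallel_world_size = world_size // tensor_parallel_size  # -> 4

# Hence DistributedSampler is built with num_replicas=4 (one per
# data-parallel rank), not num_replicas=32.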
Creating Model:#
One can create models by replacing the large linear layers with ColumnParallel and RowParallel linear layers. In the case of transformers, we have a convenient structure: the attention block has linear projections for Q, K and V, followed by a fully connected layer. Let's take a look at an example for the BERT model. We make the attention module of the BERT model use tensor parallel layers, thereby adding the ability to shard the model across devices.
class ParallelSelfAttention(transformers.models.bert.modeling_bert.BertSelfAttention):
    def __init__(self, config, position_embedding_type=None):
        super().__init__(config, position_embedding_type)
        self.query = ColumnParallelLinear(config.hidden_size,
                                          self.all_head_size,
                                          gather_output=False)
        self.key = ColumnParallelLinear(config.hidden_size,
                                        self.all_head_size,
                                        gather_output=False)
        self.value = ColumnParallelLinear(config.hidden_size,
                                          self.all_head_size,
                                          gather_output=False)

        # Since we shard the number of attention heads across tensor parallel
        # ranks, each rank would have a subset of heads; hence, we update
        # num_attention_heads here.
        tp_size = parallel_state.get_tensor_parallel_size()
        self.num_attention_heads = self.num_attention_heads // tp_size
        self.all_head_size = self.all_head_size // tp_size
As seen above, we just had to swap out the linear layers for ColumnParallelLinear layers; the rest of the attention module's forward method works as is. Note: in the ColumnParallelLinear layers above we are not gathering the output from each rank; in other words, each rank works only on its own shard. Setting gather_output=True would gather the outputs and give you a full-dimension output. However, gathering the output from all ranks introduces an all-gather operation, which can be expensive depending on the size of the tensor. In the case of the attention module, we know that the SelfAttention block is followed by an MLP block, so we replace the linear layer there with a RowParallelLinear as shown below:
class ParallelSelfOutput(transformers.models.bert.modeling_bert.BertSelfOutput):
    def __init__(self, config):
        super().__init__(config)
        self.dense = RowParallelLinear(config.hidden_size,
                                       config.hidden_size,
                                       input_is_parallel=True)
As seen, we just had to replace the dense layer here and pass the input_is_parallel argument. This way, the RowParallelLinear operates on the partitioned input and produces the combined result.
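To make the pairing concrete, here is a rough sketch of the shapes each rank sees when a ColumnParallelLinear with gather_output=False feeds a RowParallelLinear with input_is_parallel=True. It assumes the parallel state has already been initialized with a tensor parallel size of 2, uses an illustrative hidden size of 768, and imports the layers from neuronx_distributed.parallel_layers:

import torch
from neuronx_distributed.parallel_layers import ColumnParallelLinear, RowParallelLinear

hidden = 768  # illustrative
col = ColumnParallelLinear(hidden, hidden, gather_output=False)
row = RowParallelLinear(hidden, hidden, input_is_parallel=True)

x = torch.rand(4, 128, hidden)  # [batch, seq, hidden], replicated on every rank
partial = col(x)                # [4, 128, 384]: this rank's shard of the output
full = row(partial)             # [4, 128, 768]: partial results combined across ranks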
Making just the above two changes lets you partition a good chunk of your model across multiple workers, thereby allowing larger models to be trained on a single instance. Note: the majority of a transformer model's parameters are in these linear layers, so partitioning them helps you scale.
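How you plug these parallel modules into the full model depends on the model class. One possible way for a Hugging Face BERT model is sketched below; the attribute paths follow the structure of transformers' BertModel and are assumptions to verify against your own model:

def shard_bert_attention(model, config):
    # Replace the stock attention sub-modules with the parallel versions
    # defined above, layer by layer.
    for layer in model.bert.encoder.layer:
        layer.attention.self = ParallelSelfAttention(config)
        layer.attention.output = ParallelSelfOutput(config)
    return model

model = transformers.BertForPreTraining(config)
model = shard_bert_attention(model, config)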
Final Training script:#
Once the dataloader and model changes are done, we are ready to build the training script. Good news: you can use the same training loop as before for data-parallel training, and only a few minor tweaks are needed to get it all started.
import torch_xla.core.xla_model as xm

from neuronx_distributed import parallel_layers
from neuronx_distributed.parallel_layers import parallel_state, clip_grad_norm

parallel_state.initialize_model_parallel(tensor_model_parallel_size=2)

dataloader = create_pretraining_dataset(
    input_file, max_pred_length, mini_batch_size, worker_init)

model = YourNewlyBuiltParallelModel(config)

# We have to move the model to device using this API, because when
# we move the model to device using .to(device), the model parameters'
# attributes aren't preserved. This causes some of the tensor parallel
# attributes to be lost. Hence, this API takes care of preserving the
# tensor parallel attributes.
parallel_layers.move_model_to_device(model, device)

for inputs, labels in dataloader:
    output = model(*inputs)
    loss = loss_fn(output, labels)
    loss.backward()
    # Here we use clip_grad_norm from neuronx_distributed as that
    # can handle tensor parallel ranks
    clip_grad_norm(model.parameters(), max_norm)
    # For the optimizer step, we have to pass the data_parallel group
    xm.optimizer_step(
        optimizer,
        groups=parallel_state.get_data_parallel_group(as_list=True)
    )
    optimizer.zero_grad()
    scheduler.step()
A few things to take note of in the above code snippet:
1. We initialize model parallelism with a tensor parallel size of 2. This will shard the model across 2 devices.
2. We use the move_model_to_device API to move the model to device. This is equivalent to doing model.to(device). We need to explicitly call this API since some of the tensor-parallel attributes do not get copied over when we move the model to device using model.to(device).
3. We call clip_grad_norm from parallel_layers. This clip_grad_norm takes care of accumulating the gradient norm across the tensor parallel ranks and producing the correct result.
4. We pass the data_parallel_group to optimizer_step. If we don't pass the group, the default would be all the workers in the world.
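As a quick sanity check that the sharding took effect, you can print the shape of one of the attention projections on each rank. With a tensor parallel size of 2 and a hidden size of 768, each rank should hold only half of the output rows; the snippet and expected shape below are illustrative:

for name, param in model.named_parameters():
    if name.endswith("query.weight"):
        # Expect a per-rank shape of (384, 768) instead of the unsharded (768, 768).
        print(name, tuple(param.shape))
        break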
Saving Model:#
Once training is done, we want to save the model. This can be done easily by calling the save API from neuronx_distributed.parallel_layers. Here is an example:
neuronx_distributed.parallel_layers.save({
    'epoch': epoch,
    'model': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'loss': loss,
    ...
}, PATH)
Note the model key used here; we need to provide the same key during model load.
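For loading, neuronx_distributed.parallel_layers also provides a load API. A minimal sketch is shown below, assuming it accepts the checkpoint path, the module to restore, and the key under which the model state was saved; the keyword names are assumptions to verify against the neuronx-distributed API reference:

# Sketch only: argument names are assumptions; the key must match the
# 'model' key used in the save() call above.
neuronx_distributed.parallel_layers.load(PATH, model, model_key='model')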