This document is relevant for: Inf2, Trn1, Trn1n

Developer guide for Neuron-PT-Lightning (neuronx-distributed)#

Training#

To train models with Neuron-PT-Lightning, users need to make a few changes to their model and training script. This document explains how to train a model using Tensor Parallelism (TP), Data Parallelism (DP), and ZeRO-1.
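TP and DP are combined by splitting the workers into groups. The DP size is what the training code below reads as `trainer.strategy.data_parallel_size`. The pure-Python sketch below assumes a Megatron-style layout: each TP group is a block of consecutive ranks, and each DP group takes one rank from every TP group. The helper `parallel_groups` is hypothetical. The real assignment is owned by neuronx-distributed's `parallel_state`. With ZeRO-1, optimizer states are additionally sharded across the ranks of each DP group.

```python
from typing import List, Tuple


def parallel_groups(world_size: int, tp_size: int) -> Tuple[List[List[int]], List[List[int]]]:
    """Split ranks into TP and DP groups (assumed layout: TP groups are consecutive ranks)."""
    if tp_size <= 0 or world_size % tp_size != 0:
        raise ValueError("world_size must be a positive multiple of tp_size")
    dp_size = world_size // tp_size
    # Each TP group holds one full model replica, sharded across tp_size ranks.
    tp_groups = [list(range(i * tp_size, (i + 1) * tp_size)) for i in range(dp_size)]
    # Each DP group holds the same shard from every replica (stride of tp_size).
    dp_groups = [list(range(j, world_size, tp_size)) for j in range(tp_size)]
    return tp_groups, dp_groups


tp_groups, dp_groups = parallel_groups(world_size=32, tp_size=8)
print(len(tp_groups), len(dp_groups[0]))  # 4 4: four model replicas, so the DP size is 4
```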

First, let’s start with the model changes. Follow the tensor parallel guidance to build the model with tensor parallelism enabled and to set up the training dataset.

Next, let’s walk through how to build the training loop with the Neuron-PT-Lightning APIs.

Configure NeuronLTModule#

NeuronxDistributed provides NeuronLTModule, which extends LightningModule with built-in support for Neuron devices. Users need to inherit from NeuronLTModule:

from neuronx_distributed.lightning import NeuronLTModule

class NeuronLlamaLTModule(NeuronLTModule):
    def training_step(self, batch, batch_idx):
        ...
    ...

Within the NeuronLTModule subclass, users need to override the following methods.

training_step: NeuronLTModule currently supports only manual optimization, so users need to define the forward, backward, and optimization steps themselves.

import torch_xla.core.xla_model as xm
from neuronx_distributed.parallel_layers import parallel_state

# Method of NeuronLlamaLTModule. self.averaged_loss (a tensor, since it is reset with zero_())
# and self.grad_accum_steps must already exist, e.g. set up when the module is constructed.
def training_step(self, batch, batch_idx):
    xm.mark_step() # Isolate forward+backward graph
    for logger in self.trainer.loggers:
        logger.print_step = -1
    self.should_print = False
    outputs = self.model(
        input_ids=batch["input_ids"],
        attention_mask=batch["attention_mask"],
        labels=batch["labels"],
    )
    # Scale the loss so gradients summed over grad_accum_steps micro-batches form an average
    loss = outputs.loss / self.grad_accum_steps
    loss.backward()
    self.averaged_loss += loss.detach()
    xm.mark_step() # Isolate forward+backward graph
    if not self.automatic_optimization and (batch_idx + 1) % self.grad_accum_steps == 0:
        self.should_print = True
        # Dividing by the DP size before REDUCE_SUM yields the mean loss across DP ranks
        loss_div = self.averaged_loss / self.trainer.strategy.data_parallel_size
        loss_reduced = xm.all_reduce(
            xm.REDUCE_SUM,
            loss_div,
            groups=parallel_state.get_data_parallel_group(as_list=True),
        )
        loss_reduced_detached = loss_reduced.detach()
        self.averaged_loss.zero_()
        optimizer = self.optimizers()
        scheduler = self.lr_schedulers()
        optimizer.step()
        optimizer.zero_grad()
        scheduler.step()
        xm.mark_step() # Isolate Optimization step graph

        # Setup items for logging
        self.loss = loss_reduced_detached
    return loss

configure_optimizers: Configures the optimizer and learning-rate scheduler.

from neuronx_distributed.trainer import initialize_parallel_optimizer

def configure_optimizers(self):
    param_groups = self.get_param_groups_by_weight_decay()
    optimizer = initialize_parallel_optimizer(
        self.nxd_config, self.opt_cls, param_groups, **self.opt_kwargs
    )
    optimizer.zero_grad()
    scheduler = self.scheduler_cls(optimizer, *self.scheduler_args, **self.scheduler_kwargs)
    return (
        [optimizer],
        [
            {
                "scheduler": scheduler,
            }
        ],
    )

on_train_batch_end: Customizes behavior at the end of each training batch, such as logging.

def on_train_batch_end(self, *args, **kwargs):
    if self.should_print:
        if not self.automatic_optimization:
            self.log(
                "loss",
                self.loss.detach().cpu().item() if self.loss is not None else torch.zeros(1, device="cpu", requires_grad=False),
                prog_bar=True,
            )
            self.log(
                "global_step",
                self.global_step,
                prog_bar=True,
                on_step=True,
                on_epoch=True,
            )
            for logger in self.trainer.loggers:
                logger.print_step = self.global_step
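The loss bookkeeping in training_step can be checked without Neuron hardware. The pure-Python sketch below is a hypothetical helper, not part of NeuronLTModule. It replays the same arithmetic:

1. Each micro-batch loss is divided by grad_accum_steps and accumulated.
2. Every grad_accum_steps micro-batches, each DP rank divides its accumulated value by the DP size.
3. The ranks' values are summed, standing in for `xm.all_reduce(xm.REDUCE_SUM, ...)`.

The logged value is therefore the mean loss over every micro-batch in the global batch.

```python
from typing import List, Sequence


def logged_losses(per_rank_losses: Sequence[Sequence[float]], grad_accum_steps: int) -> List[float]:
    """Replay training_step's loss bookkeeping; per_rank_losses[r][i] is micro-batch i on DP rank r."""
    dp_size = len(per_rank_losses)
    accum = [0.0] * dp_size  # plays the role of self.averaged_loss on each rank
    logged = []
    for batch_idx in range(len(per_rank_losses[0])):
        for rank in range(dp_size):
            # loss = outputs.loss / self.grad_accum_steps; self.averaged_loss += loss
            accum[rank] += per_rank_losses[rank][batch_idx] / grad_accum_steps
        if (batch_idx + 1) % grad_accum_steps == 0:
            # loss_div = averaged_loss / dp_size on each rank, then REDUCE_SUM over the DP group
            logged.append(sum(a / dp_size for a in accum))
            accum = [0.0] * dp_size  # self.averaged_loss.zero_()
    return logged


print(logged_losses([[1, 3, 5, 7], [3, 5, 7, 9]], grad_accum_steps=2))  # [3.0, 7.0]
```

Micro-batches left over after the last full accumulation window are never logged, which matches the `(batch_idx + 1) % self.grad_accum_steps == 0` condition.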

Note that NeuronLTModule has a built-in get_param_groups_by_weight_decay method for the common use case, shown in the snippet below. Users can also override it with their own param_groups generation.

def get_param_groups_by_weight_decay(self):
    """Get param groups. Customers can override this to have their own way of weight_decay"""
    param_optimizer = list(self.model.named_parameters())
    no_decay = ["bias", "LayerNorm"]  # gamma/beta are in LayerNorm.weight

    optimizer_grouped_parameters = [
        {
            "params": [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
            "weight_decay": 0.01,
        },
        {
            "params": [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
            "weight_decay": 0.0,
        },
    ]
    return optimizer_grouped_parameters
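The default no_decay test is a case-sensitive substring match. Whether it catches the normalization weights therefore depends on how parameters are named. The sketch below replays that rule on a few names that follow Hugging Face's LlamaForCausalLM convention. That naming is an assumption; print your own model's `named_parameters()` to check. `split_by_weight_decay` is a hypothetical helper. Under that naming, "LayerNorm" matches nothing, which is one reason to override get_param_groups_by_weight_decay.

```python
from typing import Iterable, List, Sequence, Tuple


def split_by_weight_decay(names: Iterable[str], no_decay: Sequence[str]) -> Tuple[List[str], List[str]]:
    """Apply the same substring test as get_param_groups_by_weight_decay to parameter names."""
    decay, skip = [], []
    for n in names:
        (skip if any(nd in n for nd in no_decay) else decay).append(n)
    return decay, skip


llama_names = [
    "model.embed_tokens.weight",
    "model.layers.0.self_attn.q_proj.weight",
    "model.layers.0.input_layernorm.weight",
    "model.norm.weight",
]
_, skip_default = split_by_weight_decay(llama_names, ["bias", "LayerNorm"])
print(skip_default)  # []: the case-sensitive "LayerNorm" test misses both norm weights
_, skip_custom = split_by_weight_decay(llama_names, ["bias", "norm"])
print(skip_custom)  # ['model.layers.0.input_layernorm.weight', 'model.norm.weight']
```

A broader pattern such as "norm" can also catch unrelated parameters, so inspect the resulting groups before training.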

Configure DataModule#

Create a LightningDataModule for data loading/sampling

from typing import Callable, Dict, Optional, Tuple

class NeuronLightningDataModule(LightningDataModule):
    def __init__(
        self,
        dataloader_fn: Callable,
        data_dir: str,
        batch_size: int,
        data_args: Tuple = (),
        data_kwargs: Optional[Dict] = None,
    ):
        super().__init__()
        self.dataloader_fn = dataloader_fn
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.data_args = data_args  # no trailing comma: it would wrap the tuple in another tuple
        self.data_kwargs = data_kwargs or {}  # avoid a shared mutable default argument

    def setup(self, stage: str):
        pass

    def train_dataloader(self):
        # dataloader_fn receives (data_dir, batch_size, dp_size, dp_rank, *data_args, **data_kwargs)
        return self.dataloader_fn(
            self.data_dir,
            self.batch_size,
            self.trainer.strategy.data_parallel_size,
            self.trainer.strategy.data_parallel_rank,
            *self.data_args,
            **self.data_kwargs
        )
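train_dataloader calls dataloader_fn positionally as (data_dir, batch_size, data_parallel_size, data_parallel_rank, *data_args). This means `data_args=(flags.seed,)` in the training script lands in the fifth parameter. The sketch below is a hypothetical stand-in for create_llama_pretraining_dataset that illustrates that contract. To stay self-contained, it takes in-memory samples instead of reading data_dir. Each DP rank gets a disjoint shard, and every rank gets the same number of batches. A real implementation would more likely wrap a torch Dataset in a DataLoader with `torch.utils.data.DistributedSampler(dataset, num_replicas=dp_size, rank=dp_rank)`.

```python
import random
from typing import List, Sequence


def toy_dataloader_fn(samples: Sequence, batch_size: int, dp_size: int, dp_rank: int, seed: int = 0) -> List[list]:
    """Hypothetical stand-in for create_llama_pretraining_dataset: returns this rank's batches."""
    order = list(range(len(samples)))
    random.Random(seed).shuffle(order)  # same seed -> same permutation on every rank
    per_step = dp_size * batch_size
    order = order[: len(order) - len(order) % per_step]  # equal batch count on every rank
    shard = order[dp_rank::dp_size]  # disjoint, strided shard for this rank
    return [[samples[i] for i in shard[s : s + batch_size]] for s in range(0, len(shard), batch_size)]


rank0 = toy_dataloader_fn(list(range(10)), batch_size=2, dp_size=2, dp_rank=0, seed=7)
rank1 = toy_dataloader_fn(list(range(10)), batch_size=2, dp_size=2, dp_rank=1, seed=7)
print(len(rank0), len(rank1))  # 2 2
```

Equal batch counts matter because every DP rank takes part in the same collectives, such as the loss all-reduce in training_step.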

Update Training Script#

For a detailed introduction to each API and class, see the API guide.
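The training-script snippets in this section read hyperparameters from a `flags` object that the guide does not define. One hypothetical way to build it is with argparse. The flag names below are the ones the snippets use. The defaults are placeholders, not recommended values, and `parse_flags` is an illustrative name.

```python
import argparse


def parse_flags(argv=None) -> argparse.Namespace:
    """Hypothetical CLI for the flags.* values used by the training script; defaults are placeholders."""
    p = argparse.ArgumentParser(description="Neuron-PT-Lightning training (sketch)")
    p.add_argument("--data_dir", type=str, required=True)
    p.add_argument("--batch_size", type=int, default=1)
    p.add_argument("--seed", type=int, default=1234)
    p.add_argument("--lr", type=float, default=3e-4)
    p.add_argument("--warmup_steps", type=int, default=2000)
    p.add_argument("--max_steps", type=int, default=10000)
    p.add_argument("--steps_this_run", type=int, default=10000)
    p.add_argument("--grad_accum_usteps", type=int, default=1)
    p.add_argument("--log_dir", type=str, default="./logs")
    p.add_argument("--save_checkpoint", action="store_true")
    p.add_argument("--checkpoint_dir", type=str, default="./checkpoints")
    p.add_argument("--checkpoint_freq", type=int, default=1000)
    p.add_argument("--num_kept_checkpoint", type=int, default=2)
    return p.parse_args(argv)


flags = parse_flags(["--data_dir", "/tmp/data", "--grad_accum_usteps", "4", "--save_checkpoint"])
print(flags.grad_accum_usteps, flags.save_checkpoint)  # 4 True
```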

Create NeuronLTModule and DataModule#

model = NeuronLlamaLTModule(
    model_fn = LlamaForCausalLM,
    nxd_config = nxd_config,
    model_args = (model_config,),
    opt_cls = optimizer_cls,
    scheduler_cls = configure_scheduler,
    opt_kwargs = {
        "lr": flags.lr,
    },
    scheduler_args = (flags.warmup_steps, flags.max_steps),
    grad_accum_steps = flags.grad_accum_usteps,
    manual_opt = True,
)

dm = NeuronLightningDataModule(
    create_llama_pretraining_dataset,
    flags.data_dir,
    flags.batch_size,
    data_args = (flags.seed,),
)
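The script above passes `scheduler_cls=configure_scheduler` and `scheduler_args=(flags.warmup_steps, flags.max_steps)`. configure_optimizers then calls it as `scheduler_cls(optimizer, *scheduler_args, **scheduler_kwargs)`. configure_scheduler itself is not defined in this guide. The sketch below shows one hypothetical definition: linear warmup followed by linear decay, wrapped in PyTorch's `torch.optim.lr_scheduler.LambdaLR`. The schedule shape is an assumption, not the one used by the Neuron examples.

```python
def linear_warmup_decay(step: int, warmup_steps: int, max_steps: int) -> float:
    """LR multiplier: ramps 0 -> 1 over warmup_steps, then decays linearly to 0 at max_steps."""
    if step < warmup_steps:
        return step / max(1, warmup_steps)
    remaining = max_steps - step
    return max(0.0, remaining / max(1, max_steps - warmup_steps))


def configure_scheduler(optimizer, warmup_steps: int, max_steps: int):
    """Hypothetical scheduler_cls, called as scheduler_cls(optimizer, *scheduler_args)."""
    # Imported lazily so the LR curve above can be inspected without torch installed.
    from torch.optim.lr_scheduler import LambdaLR

    return LambdaLR(optimizer, lambda step: linear_warmup_decay(step, warmup_steps, max_steps))


print([linear_warmup_decay(s, 10, 110) for s in (0, 5, 10, 60, 110)])  # [0.0, 0.5, 1.0, 0.5, 0.0]
```

training_step calls `scheduler.step()` once per optimizer step. The step index that LambdaLR passes to the lambda therefore counts optimizer steps, not micro-batches.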

Add Strategy, Plugins, Callbacks#

strategy = NeuronXLAStrategy(
    nxd_config = nxd_config
)
plugins = []
plugins.append(NeuronXLAPrecisionPlugin())
callbacks = []
callbacks.append(NeuronTQDMProgressBar())

Create Trainer and Start Training#

trainer = Trainer(
    strategy = strategy,
    max_steps = flags.steps_this_run,
    plugins = plugins,
    enable_checkpointing = flags.save_checkpoint,
    logger = NeuronTensorBoardLogger(save_dir=flags.log_dir),
    log_every_n_steps = 1,
    callbacks = callbacks,
)
trainer.fit(model=model, datamodule=dm)
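With gradient accumulation and DP, one optimizer step consumes batch_size × grad_accum_usteps × DP-size samples. This assumes flags.batch_size is the per-rank micro-batch size, which is how the DataModule passes it to dataloader_fn. Lightning documents trainer.global_step as the number of optimizer steps taken. So `max_steps` should correspond to optimizer.step() calls, each preceded by grad_accum_usteps training_step calls per rank. The two helpers below are hypothetical and only do this arithmetic.

```python
def samples_per_optimizer_step(batch_size: int, grad_accum_usteps: int, dp_size: int) -> int:
    """Global batch: samples consumed by one optimizer.step() across all DP ranks."""
    return batch_size * grad_accum_usteps * dp_size


def training_steps_per_rank(optimizer_steps: int, grad_accum_usteps: int) -> int:
    """training_step calls each rank makes to reach a given number of optimizer steps."""
    return optimizer_steps * grad_accum_usteps


# Example: 32 workers with TP size 8 gives DP size 4 (the real value is
# trainer.strategy.data_parallel_size).
print(samples_per_optimizer_step(batch_size=1, grad_accum_usteps=16, dp_size=4))  # 64
print(training_steps_per_rank(optimizer_steps=1000, grad_accum_usteps=16))  # 16000
```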

Checkpointing#

To enable checkpoint saving, add a ModelCheckpoint callback to callbacks before creating the Trainer:

callbacks.append(
    ModelCheckpoint(
        save_top_k = flags.num_kept_checkpoint,
        monitor="global_step",
        mode="max",
        every_n_train_steps = flags.checkpoint_freq,
        dirpath = flags.checkpoint_dir,
    )
)
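With `monitor="global_step"` and `mode="max"`, the checkpoints ranked "top" are those with the largest global step. `save_top_k` therefore keeps the most recent ones. Lightning's conventions are that `save_top_k=-1` keeps every checkpoint and `save_top_k=0` keeps none. The hypothetical sketch below models only this retention rule. It assumes a save whenever global_step is a multiple of every_n_train_steps and ignores any extra end-of-training save.

```python
from typing import List


def kept_checkpoint_steps(total_steps: int, every_n_train_steps: int, save_top_k: int) -> List[int]:
    """Global steps whose checkpoints survive under monitor="global_step", mode="max"."""
    saved = list(range(every_n_train_steps, total_steps + 1, every_n_train_steps))  # ascending
    if save_top_k == -1:  # keep everything
        return saved
    if save_top_k == 0:  # save nothing
        return []
    return saved[-save_top_k:]  # highest global_step values = most recent checkpoints


print(kept_checkpoint_steps(10000, 1000, 2))  # [9000, 10000]
```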

To resume from a specific checkpoint, pass ckpt_path to trainer.fit:

trainer.fit(model=model, datamodule=dm, ckpt_path=ckpt_path)
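To resume from the newest checkpoint automatically, ckpt_path can be computed from flags.checkpoint_dir. The hypothetical helper below assumes checkpoint names contain Lightning's default `step=N` field (as in `epoch=0-step=1000.ckpt`). It works whether each checkpoint is a file or a directory. Steps are compared numerically, so `step=10000` beats `step=9000`. It returns None when nothing is found, and passing `ckpt_path=None` to trainer.fit starts training from scratch.

```python
import re
from pathlib import Path
from typing import Optional

_STEP = re.compile(r"step=(\d+)")


def latest_checkpoint(checkpoint_dir: str) -> Optional[str]:
    """Return the entry in checkpoint_dir with the highest 'step=N' in its name, or None."""
    root = Path(checkpoint_dir)
    if not root.is_dir():
        return None
    best_step, best_path = -1, None
    for entry in root.iterdir():
        match = _STEP.search(entry.name)
        if match and int(match.group(1)) > best_step:
            best_step, best_path = int(match.group(1)), str(entry)
    return best_path


# Usage with the training script:
# trainer.fit(model=model, datamodule=dm, ckpt_path=latest_checkpoint(flags.checkpoint_dir))
```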
