Integrating a new dataset/dataloader

In this section, we showcase how to integrate a new dataset/dataloader with the library.

Building Dataset module

To create a Dataset class, one can follow the guide in the PyTorch documentation.
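As a rough sketch of what such a class can look like, the hypothetical `ToyTokenDataset` below is a map-style dataset. It implements `__len__` and `__getitem__`, which is the contract a PyTorch `DataLoader` expects from a map-style dataset. In real code you would typically subclass `torch.utils.data.Dataset`; that is omitted here only so the example runs without PyTorch installed. Each sample is a dict with an `input_ids` key because the `get_batch_length` method of the DataModule below reads `batch["input_ids"]`. The `labels` key and the shift-by-one labelling are assumptions for illustration.

```python
from typing import Dict, List


class ToyTokenDataset:
    """Hypothetical map-style dataset that slices a flat token stream into
    fixed-length samples. In a real project, subclass torch.utils.data.Dataset."""

    def __init__(self, tokens: List[int], seq_length: int) -> None:
        if seq_length < 1:
            raise ValueError("seq_length must be positive")
        self.tokens = tokens
        self.seq_length = seq_length

    def __len__(self) -> int:
        # Each sample needs seq_length inputs plus one extra token for the shifted label.
        return max(0, (len(self.tokens) - 1) // self.seq_length)

    def __getitem__(self, idx: int) -> Dict[str, List[int]]:
        if not 0 <= idx < len(self):
            raise IndexError(idx)
        start = idx * self.seq_length
        window = self.tokens[start : start + self.seq_length + 1]
        # Next-token prediction: labels are the inputs shifted left by one position.
        return {"input_ids": window[:-1], "labels": window[1:]}


ds = ToyTokenDataset(list(range(10)), seq_length=4)
print(len(ds))  # 2
print(ds[0])    # {'input_ids': [0, 1, 2, 3], 'labels': [1, 2, 3, 4]}
```

A `DataLoader` wrapping such an object calls `__getitem__` with sampled indices and merges the results into batches.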

Building DataModule

To configure the dataloader, one needs to create a DataModule class. The Neuronx Distributed Training library provides a BaseDataModule, which one can subclass to implement a new DataModule. Create a new file called and add the following content:

from import BaseDataModule

class NewDataModule(BaseDataModule):
    def __init__(self, cfg, trainer):
        """DataModule class for configuring the dataset/dataloader.

        Args:
            cfg: `data` cfg in the yaml file.
            trainer: PyTorch-Lightning trainer.
        """
        super().__init__(cfg, trainer)
        # Users can use the cfg argument to pass down
        # arguments from the yaml file to the DataModule.

    def get_batch_length(self, batch):
        """Returns the length of the batch."""
        return len(batch["input_ids"])

    def process_global_batch(self, global_batch, global_batch_size=None):
        """Any custom processing of batches can be done here.

        Args:
            global_batch: list of inputs, e.g. [tokens, labels]
            global_batch_size: length of tokens and labels
        """
        return global_batch

    def train_dataloader(self):
        """This API should return a dataloader object for the training data."""

    def val_dataloader(self):
        """This API should return a dataloader object for the validation data."""

    def test_dataloader(self):
        """This API should return a dataloader object for the test data."""
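The DataModule above assumes each batch is a dict keyed by `input_ids`. The standard-library sketch below uses a hypothetical `collate_dicts` helper to show one way per-sample dicts could be merged into such a batch, and how `get_batch_length` then reports the number of samples. When samples hold tensors rather than lists, PyTorch's default `collate_fn` similarly merges a list of dicts into one dict, stacking each key's tensors.

```python
from typing import Dict, List, Sequence


def collate_dicts(samples: Sequence[Dict[str, List[int]]]) -> Dict[str, List[List[int]]]:
    """Hypothetical collate function: merge per-sample dicts into one dict of lists."""
    if not samples:
        return {}
    keys = samples[0].keys()
    return {key: [sample[key] for sample in samples] for key in keys}


def get_batch_length(batch: Dict[str, List[List[int]]]) -> int:
    # Mirrors NewDataModule.get_batch_length: the number of samples in the batch.
    return len(batch["input_ids"])


samples = [
    {"input_ids": [0, 1, 2], "labels": [1, 2, 3]},
    {"input_ids": [4, 5, 6], "labels": [5, 6, 7]},
]
batch = collate_dicts(samples)
print(batch["input_ids"])       # [[0, 1, 2], [4, 5, 6]]
print(get_batch_length(batch))  # 2
```

Note that the batch length counts samples, not tokens: the sequence length of each sample does not affect the result.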

Plug into

Once the new data module is created, we can plug it into the script under the examples folder by modifying the script as follows:

# Assuming we are using the same ModelModule we used for the Llama example.
from new_data_module import NewDataModule

data_module = NewDataModule(cfg, trainer)
model = HFLLamaModule(cfg, trainer)

# Pass the DataModule to the trainer so that it supplies the dataloaders.
trainer.fit(model, datamodule=data_module)

The rest of the code can remain the same. The trainer will now fetch its dataloaders from NewDataModule and run end-to-end training.
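Conceptually, passing `datamodule=data_module` to the trainer means the trainer asks the DataModule, not the model, for its batches. The loop below is a deliberately simplified, hypothetical stand-in to make that division of labour visible. It is not PyTorch Lightning's implementation, which also handles validation, distributed sampling, and checkpointing, and `ToyDataModule`, `ToyModel`, and `toy_fit` are illustrative names only.

```python
from typing import Dict, Iterator, List


class ToyDataModule:
    """Hypothetical stand-in for NewDataModule: yields fixed-size batches."""

    def __init__(self, samples: List[Dict[str, List[int]]], batch_size: int) -> None:
        self.samples = samples
        self.batch_size = batch_size

    def train_dataloader(self) -> Iterator[Dict[str, List[List[int]]]]:
        # Chunk the samples and merge each chunk into a dict of lists.
        for start in range(0, len(self.samples), self.batch_size):
            chunk = self.samples[start : start + self.batch_size]
            yield {key: [s[key] for s in chunk] for key in chunk[0]}


class ToyModel:
    """Hypothetical model that only counts the samples it sees."""

    def __init__(self) -> None:
        self.samples_seen = 0

    def training_step(self, batch: Dict[str, List[List[int]]]) -> None:
        self.samples_seen += len(batch["input_ids"])


def toy_fit(model: ToyModel, datamodule: ToyDataModule, epochs: int = 1) -> None:
    # Simplified: the "trainer" pulls every batch from the DataModule each epoch.
    for _ in range(epochs):
        for batch in datamodule.train_dataloader():
            model.training_step(batch)


samples = [{"input_ids": [i, i + 1], "labels": [i + 1, i + 2]} for i in range(5)]
model = ToyModel()
toy_fit(model, ToyDataModule(samples, batch_size=2), epochs=2)
print(model.samples_seen)  # 10
```

With five samples and a batch size of two, each epoch yields batches of 2, 2, and 1 samples, so two epochs process ten samples in total.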

Create config file

Next, we can create a config file under conf to be used with this new dataloader. We can start from a copy of hf_llama_7B_config.yaml and call it my_new_config.yaml. We can then edit the data key to configure the DataModule.
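The keys under `data` are whatever the DataModule reads from its `cfg` argument. As a hedged illustration, the sketch below uses `types.SimpleNamespace` as a stand-in for the parsed `data` section, assuming the real config object supports attribute access as OmegaConf-style configs do. The class `ConfiguredDataModule` and the key names `train_dir`, `seq_length`, and `micro_batch_size` are assumptions for this example and may not match the keys in hf_llama_7B_config.yaml.

```python
from types import SimpleNamespace


class ConfiguredDataModule:
    """Hypothetical DataModule that reads its settings from the `data` config."""

    def __init__(self, cfg: SimpleNamespace) -> None:
        # Key names below are illustrative assumptions, not the library's schema.
        self.train_dir = cfg.train_dir
        self.seq_length = cfg.seq_length
        # Optional key with a fallback when it is absent from the config.
        self.micro_batch_size = getattr(cfg, "micro_batch_size", 1)
        if self.seq_length < 1:
            raise ValueError("seq_length must be positive")


# Stand-in for the parsed `data:` section of my_new_config.yaml.
data_cfg = SimpleNamespace(train_dir="/path/to/train", seq_length=4096, micro_batch_size=1)
dm = ConfiguredDataModule(data_cfg)
print(dm.seq_length)  # 4096
```

Reading optional keys with a default keeps older config files working when a new key is introduced.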


For the model, we are using the same model as the Llama example. To configure a new model, please see the Integrating a New Model section.

Launching e2e training

We can now launch training with the new data module using the following command:

CONF=my_new_config.yaml ./