Integrating a new dataset/dataloader#

In this section, we show how to integrate a new dataset and dataloader with the library.

Building a Dataset module#

One can follow the guide in the PyTorch docs on creating a custom Dataset class, as sketched below.
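Below is a minimal, illustrative sketch of such a Dataset. The class name NewDataset and the field names input_ids/labels are our assumptions for this example, chosen to line up with the get_batch_length method shown later; substitute whatever matches your data.

from torch.utils.data import Dataset

class NewDataset(Dataset):
    """Minimal map-style dataset sketch; names are illustrative."""

    def __init__(self, input_ids, labels):
        # input_ids and labels are equal-length sequences of
        # pre-tokenized examples (e.g. tensors).
        self.input_ids = input_ids
        self.labels = labels

    def __len__(self):
        return len(self.input_ids)

    def __getitem__(self, idx):
        # Returning a dict keeps the batch keys consistent with
        # batch["input_ids"] used by get_batch_length below.
        return {"input_ids": self.input_ids[idx], "labels": self.labels[idx]}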

Building a DataModule#

To configure the dataloader, one needs to create a DataModule class. The Neuronx Distributed Training library provides a BaseDataModule which one can extend to implement a new DataModule. Create a new file called new_data_module.py and add the following content:

from neuronx_distributed_training.lightning_modules.data.base import BaseDataModule

class NewDataModule(BaseDataModule):
    def __init__(self, cfg, trainer):
        """
        DataModule class for configuring the dataset/dataloader

        Args:
            cfg: `data` cfg in the yaml file.
            trainer: PyTorch-Lightning trainer.
        """
        super().__init__(cfg, trainer)
        # Users can use the cfg argument to pass down
        # arguments from the yaml file to the DataModule.

    def get_batch_length(self, batch):
        """
        Returns the length of the batch.
        """
        return len(batch["input_ids"])

    def process_global_batch(self, global_batch, global_batch_size=None):
        """ Any custom processing of batches can be done here.

        Args:
            global_batch: list of inputs, e.g. [tokens, labels]
            global_batch_size: Length of tokens and labels
        """
        return global_batch

    def train_dataloader(self):
        """
        This API should return a torch.utils.data.dataloader.DataLoader object
        """
        ...

    def val_dataloader(self):
        """
        This API should return a torch.utils.data.dataloader.DataLoader object
        """
        ...

    def test_dataloader(self):
        """
        This API should return a torch.utils.data.dataloader.DataLoader object
        """
        ...
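The ... stubs above are left for the user to implement. As one possible sketch (not the library's prescribed implementation), train_dataloader could wrap the hypothetical NewDataset from the previous section in a standard DataLoader. We assume here that BaseDataModule stores the data cfg as self.cfg and that the yaml defines a micro_batch_size key; adjust both to your setup.

from torch.utils.data import DataLoader

# A possible body for NewDataModule.train_dataloader:
def train_dataloader(self):
    # Illustrative only: NewDataset is the hypothetical dataset sketched
    # earlier, and micro_batch_size is an assumed key under the `data`
    # section of the yaml config.
    dataset = NewDataset(input_ids=..., labels=...)
    return DataLoader(
        dataset,
        batch_size=self.cfg.micro_batch_size,
        shuffle=True,
        drop_last=True,
    )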

Plug into training.py#

Once the new data module is created, we can plug it into the training.py script under the examples folder, modifying it as follows:

...
# Assuming we are using the same ModelModule we used for the Llama example.
from new_data_module import NewDataModule
data_module = NewDataModule(cfg, trainer)
model = HFLLamaModule(cfg, trainer)

trainer.fit(model, datamodule=data_module)

The rest of the code can remain the same. The trainer will now use NewDataModule to fetch the dataloaders and run e2e training.

Create a config file#

Next, we can create a config file under the conf folder to be used with this new dataloader. We can start with a copy of hf_llama_7B_config.yaml; let's call this config file my_new_config.yaml. We can then edit the data key to configure the DataModule, for example as shown below.
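As a hypothetical illustration, the data key might look like the following. The key names are placeholders: whatever you put under data is passed to NewDataModule through its cfg argument, so they only need to match what your DataModule reads.

data:
  micro_batch_size: 1          # assumed key; read by the DataModule
  global_batch_size: 32        # assumed key
  train_dir: /path/to/train_data
  val_dir: /path/to/val_data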

Note

For the model, we are using the same model that the Llama example uses. To configure a new model, please check the Integrating a New Model section.

Launching e2e training#

We can now launch training using the new data module. This can be done with the following command:

CONF=my_new_config.yaml ./train.sh