This document is relevant for: Inf2, Trn1, Trn1n

Developer guide for save/load checkpoint

This document explains how to use nxd.save_checkpoint and nxd.load_checkpoint to save and load checkpoints for distributed model training. These two methods handle every checkpoint component in a single call: model, optimizer, learning rate scheduler, and any user content.

Model states are saved on data-parallel (DP) rank 0 only. When the ZeRO-1 optimizer is not enabled, optimizer states are saved the same way (on DP rank 0 only); when ZeRO-1 is enabled, every rank saves its own shard of the optimizer states. Scheduler state and user content are saved on the master rank only.
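The placement rules above can be summarized as a small decision function. This is a hypothetical sketch for illustration only: `components_saved_by` and its parameters are not part of the nxd API, and "master rank" is assumed here to mean global rank 0.

```python
def components_saved_by(dp_rank: int, global_rank: int, zero1_enabled: bool) -> set:
    """Return which checkpoint components a given rank writes (illustrative only)."""
    saved = set()
    # Model states: only data-parallel rank 0 writes (its model-parallel shard).
    if dp_rank == 0:
        saved.add("model")
    # Optimizer states: sharded across all ranks with ZeRO-1, otherwise DP rank 0 only.
    if zero1_enabled or dp_rank == 0:
        saved.add("optimizer")
    # Scheduler and user content: assumed to be written by global rank 0 (the master rank).
    if global_rank == 0:
        saved.update({"scheduler", "user_content"})
    return saved
```

For example, with ZeRO-1 enabled a rank with `dp_rank=1` writes only its optimizer shard, while without ZeRO-1 the same rank writes nothing.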

For the complete API reference, refer to the API GUIDE.

Save checkpoint:

A sample usage:

import neuronx_distributed as nxd

nxd.save_checkpoint(
    args.checkpoint_dir,  # checkpoint path
    tag=f"step_{total_steps}",  # tag, sub-directory under checkpoint path
    model=model,
    optimizer=optimizer,
    scheduler=lr_scheduler,
    user_content={"total_steps": total_steps, "batch_idx": batch_idx, "cli_args": args.__dict__},
    use_xser=True,
    async_save=True,
)

You do not have to save everything. For example, to save only the model states:

nxd.save_checkpoint(
    args.checkpoint_dir,  # checkpoint path
    tag=f"step_{total_steps}",  # tag, sub-directory under checkpoint path
    model=model,
    use_xser=True,
    async_save=True,
)

To keep only a limited number of checkpoints (for example, 5), pass num_kept_ckpts=5 to nxd.save_checkpoint.
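To illustrate what keeping a fixed number of checkpoints means, the sketch below picks which tags would be pruned when only the newest `num_kept` are retained. `tags_to_prune` is a hypothetical helper, not nxd's implementation; it assumes tags follow the `step_{total_steps}` pattern used in the save examples and that "newest" means the highest step number.

```python
import re
from typing import List

# Assumed tag convention from the examples: step_<integer>.
STEP_TAG = re.compile(r"^step_(\d+)$")


def tags_to_prune(tags: List[str], num_kept: int) -> List[str]:
    """Hypothetical helper: tags to delete so only the newest `num_kept` remain."""
    steps = []
    for tag in tags:
        match = STEP_TAG.match(tag)
        if match:  # ignore entries that do not follow the step_N convention
            steps.append((int(match.group(1)), tag))
    # Sort numerically; lexicographic order would put step_10 before step_9.
    steps.sort()
    # Everything except the last `num_kept` entries is pruned, oldest first.
    cutoff = max(len(steps) - num_kept, 0)
    return [tag for _, tag in steps[:cutoff]]
```

Sorting by the parsed integer rather than the string matters once step counts reach different numbers of digits.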

Load checkpoint:

A sample usage. If no user content is found in the checkpoint, load_checkpoint returns None:

user_content = nxd.load_checkpoint(
    args.checkpoint_dir,  # checkpoint path
    tag=f"step_{args.loading_step}",  # tag
    model=model,
    optimizer=optimizer,
    scheduler=lr_scheduler,
)
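Because load_checkpoint returns None when no user content was saved, resume logic should handle both cases. The helper below is a hypothetical sketch (not part of nxd) that restores the `total_steps` and `batch_idx` counters stored by the save_checkpoint example above and otherwise falls back to a fresh start; the zero defaults are assumptions.

```python
from typing import Any, Dict, Optional, Tuple


def resume_counters(user_content: Optional[Dict[str, Any]]) -> Tuple[int, int]:
    """Return (total_steps, batch_idx) to resume from; start at zero if nothing was saved."""
    if user_content is None:
        # Checkpoint was saved without user_content, so there is nothing to restore.
        return 0, 0
    # Keys mirror the user_content dict passed to nxd.save_checkpoint in the save example.
    return user_content.get("total_steps", 0), user_content.get("batch_idx", 0)
```

Typical usage right after loading: `total_steps, batch_idx = resume_counters(user_content)`.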

If tag is omitted, load_checkpoint tries to resume automatically from the latest checkpoint:

user_content = nxd.load_checkpoint(
    args.checkpoint_dir,  # checkpoint path
    model=model,
    optimizer=optimizer,
    scheduler=lr_scheduler,
)
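How nxd itself locates the latest checkpoint is not described here. As a rough illustration of the idea only, the sketch below scans a checkpoint directory for `step_N` subdirectories and returns the tag with the highest step. `find_latest_tag` is a hypothetical helper, and it assumes the `step_{total_steps}` tag convention from the save examples.

```python
import os
import re
from typing import Optional

# Assumed tag convention from the save examples: step_<integer>.
STEP_TAG = re.compile(r"^step_(\d+)$")


def find_latest_tag(checkpoint_dir: str) -> Optional[str]:
    """Return the step_N subdirectory with the largest N, or None if there is none."""
    if not os.path.isdir(checkpoint_dir):
        return None
    best_step, best_tag = -1, None
    for name in os.listdir(checkpoint_dir):
        match = STEP_TAG.match(name)
        # Only consider subdirectories (not files) that follow the step_N convention.
        if match and os.path.isdir(os.path.join(checkpoint_dir, name)):
            step = int(match.group(1))
            if step > best_step:
                best_step, best_tag = step, name
    return best_tag
```

The returned tag could then be passed explicitly as `tag=` when you want to log or verify which checkpoint is being resumed.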

ZeRO-1 Optimizer State Offline Conversion:

A ZeRO-1 optimizer checkpoint stores sharded states, one shard per rank. To load ZeRO-1 optimizer states under a different cluster configuration (for example, a changed DP degree), run the offline ZeRO-1 optimizer checkpoint conversion tool. The tool supports conversion from sharded to full states, from full to sharded, and from sharded to sharded.
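To make the three conversion directions concrete, the sketch below models one flattened optimizer-state tensor as a plain Python list. It assumes, for illustration only, that the flat state is padded to a multiple of the DP degree and that each rank owns one contiguous, equal-size slice; nxd's actual on-disk layout and the conversion tool's interface are not shown, and `shard`, `unshard`, and `reshard` are hypothetical names.

```python
from typing import List


def shard(full: List[float], dp_degree: int, pad_value: float = 0.0) -> List[List[float]]:
    """Full to sharded: split a flat state into `dp_degree` equal contiguous shards."""
    shard_size = -(-len(full) // dp_degree)  # ceiling division
    padded = full + [pad_value] * (shard_size * dp_degree - len(full))
    return [padded[r * shard_size:(r + 1) * shard_size] for r in range(dp_degree)]


def unshard(shards: List[List[float]], numel: int) -> List[float]:
    """Sharded to full: concatenate per-rank shards and drop the tail padding."""
    flat = [x for s in shards for x in s]
    return flat[:numel]


def reshard(shards: List[List[float]], numel: int, new_dp_degree: int) -> List[List[float]]:
    """Sharded to sharded: rebuild the full state, then split it for the new DP degree."""
    return shard(unshard(shards, numel), new_dp_degree)
```

The unpadded length (`numel`) has to be known during conversion; without it, padding cannot be told apart from real values when the shards are concatenated.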
