This document is relevant for: Inf2, Trn1, Trn1n
Developer guide for save/load checkpoint
This document introduces how to use nxd.save_checkpoint and nxd.load_checkpoint to save and load checkpoints for distributed model training. Each of these two methods handles all checkpoint contents in a single call: model, optimizer, learning rate scheduler, and any user content.
Model states are saved on data-parallel rank 0 only. When the ZeRO-1 optimizer is not turned on, optimizer states are saved the same way; when the ZeRO-1 optimizer is turned on, optimizer states are saved on all ranks. Scheduler and user content are saved on the master rank only.
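The placement rule above can be restated as a small decision function. This is only an illustrative sketch: states_saved_on_rank and its parameters are hypothetical names, not part of the nxd API, and the library makes this decision internally.

```python
def states_saved_on_rank(dp_rank, is_master, zero1_enabled):
    """Return the set of state groups that a given rank writes (hypothetical helper)."""
    saved = set()
    # Model states: data-parallel rank 0 only.
    if dp_rank == 0:
        saved.add("model")
    # Optimizer states: every rank with ZeRO-1, otherwise data-parallel rank 0 only.
    if zero1_enabled or dp_rank == 0:
        saved.add("optimizer")
    # Scheduler and user content: master rank only.
    if is_master:
        saved.add("scheduler")
        saved.add("user_content")
    return saved
```

For example, a non-master rank with a non-zero data-parallel rank writes only its optimizer shard when ZeRO-1 is on, and nothing when it is off.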
For the complete API reference, refer to the API GUIDE.
Save checkpoint:
A sample usage:
import neuronx_distributed as nxd

nxd.save_checkpoint(
    args.checkpoint_dir,  # checkpoint path
    tag=f"step_{total_steps}",  # tag, sub-directory under checkpoint path
    model=model,
    optimizer=optimizer,
    scheduler=lr_scheduler,
    # arbitrary user content stored alongside the checkpoint
    user_content={"total_steps": total_steps, "batch_idx": batch_idx, "cli_args": args.__dict__},
    use_xser=True,
    async_save=True,
)
Users can choose not to save everything. For example, to save model states only:
nxd.save_checkpoint(
    args.checkpoint_dir,  # checkpoint path
    tag=f"step_{total_steps}",  # tag, sub-directory under checkpoint path
    model=model,
    use_xser=True,
    async_save=True,
)
To keep only a limited number of checkpoints (e.g. 5), pass num_kept_ckpts=5.
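To make the effect of a retention limit concrete, the sketch below prunes a checkpoint directory by hand. It is not the library's implementation: prune_checkpoints is a hypothetical helper, and it assumes that tags follow the step_<N> pattern used in the examples above and that the checkpoints to keep are the ones with the largest N.

```python
import shutil
from pathlib import Path


def prune_checkpoints(checkpoint_dir, num_kept):
    """Delete all but the num_kept highest-numbered step_<N> sub-directories.

    Returns the names of the removed sub-directories, oldest first.
    """
    tagged = []
    for child in Path(checkpoint_dir).iterdir():
        if child.is_dir() and child.name.startswith("step_"):
            suffix = child.name[len("step_"):]
            if suffix.isdigit():
                tagged.append((int(suffix), child))
    tagged.sort()  # ascending by step number
    # Guard num_kept == 0: tagged[:-0] would be empty instead of everything.
    to_delete = tagged[:-num_kept] if num_kept > 0 else tagged
    for _, path in to_delete:
        shutil.rmtree(path)
    return [path.name for _, path in to_delete]
```

Sorting on the parsed integer rather than on the directory name matters here, because a plain string sort would place step_100 before step_9.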
Load checkpoint:
A sample usage. Note that if no user content is detected, the call returns None:
user_content = nxd.load_checkpoint(
    args.checkpoint_dir,  # checkpoint path
    tag=f"step_{args.loading_step}",  # tag
    model=model,
    optimizer=optimizer,
    scheduler=lr_scheduler,
)
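Because the return value may be None, resume logic should handle that case explicitly. A minimal sketch, assuming the user content was saved with the total_steps and batch_idx keys from the save example; resume_counters is a hypothetical helper name:

```python
def resume_counters(user_content):
    """Return (total_steps, batch_idx) to resume from.

    Falls back to (0, 0) when the checkpoint carried no user content.
    """
    if user_content is None:
        return 0, 0
    return user_content.get("total_steps", 0), user_content.get("batch_idx", 0)
```

A training script could then call total_steps, batch_idx = resume_counters(user_content) right after loading.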
If tag is not provided, this loading method tries to resume automatically from the latest checkpoint.
user_content = nxd.load_checkpoint(
    args.checkpoint_dir,  # checkpoint path
    model=model,
    optimizer=optimizer,
    scheduler=lr_scheduler,
)
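If a script needs to know which tag is the newest before loading (for example, to log the step being resumed), it can resolve the tag itself. The sketch below is an assumption-laden illustration, not the mechanism nxd.load_checkpoint uses internally: latest_step_tag is a hypothetical helper, and it assumes tags follow the step_<N> pattern from the examples above.

```python
import re


def latest_step_tag(tags):
    """Return the tag with the largest step number, or None if none match step_<N>."""
    best = None  # (step number, tag) of the newest tag seen so far
    for tag in tags:
        match = re.fullmatch(r"step_(\d+)", tag)
        if match is None:
            continue  # ignore entries that are not step tags
        step = int(match.group(1))
        if best is None or step > best[0]:
            best = (step, tag)
    return best[1] if best is not None else None
```

The list of candidate tags could come from os.listdir(args.checkpoint_dir); the numeric comparison avoids ranking step_9 above step_100.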
ZeRO-1 Optimizer State Offline Conversion:
ZeRO-1 optimizer checkpoints are sharded states stored separately for each rank. When users want to load ZeRO-1 optimizer states under a different cluster setting (e.g. with a changed DP degree), they can run the offline ZeRO-1 optimizer checkpoint conversion tool. The tool supports conversion from sharded states to full states, from full to sharded, and from sharded to sharded.
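The three conversion directions can be pictured with a toy model in which the optimizer state is a flat Python list split into contiguous shards, one per data-parallel rank. This is a conceptual sketch only: the function names are hypothetical, the contiguous partitioning is an assumption, and the real tool operates on the actual checkpoint files and tensor layouts.

```python
def shard_state(full, dp_degree):
    """Full to sharded: split a flat list into dp_degree contiguous shards."""
    shard_len = -(-len(full) // dp_degree)  # ceiling division
    return [full[i * shard_len:(i + 1) * shard_len] for i in range(dp_degree)]


def merge_shards(shards):
    """Sharded to full: concatenate the shards in rank order."""
    return [value for shard in shards for value in shard]


def reshard(shards, new_dp_degree):
    """Sharded to sharded: merge, then split again for the new DP degree."""
    return shard_state(merge_shards(shards), new_dp_degree)
```

In this model, going from DP degree 4 to DP degree 2 is simply reshard(shards, 2); the sharded-to-sharded case is the composition of the other two.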