This document is relevant for: Inf2, Trn1, Trn1n

Developer guide for LoRA finetuning#

This document describes how to enable model finetuning with LoRA in NeuronX Distributed (NxD).

For the complete API guide, refer to API.

Enable LoRA finetuning#

We first set up LoRA-related configurations:

import neuronx_distributed as nxd

lora_config = nxd.modules.lora.LoraConfig(
    enable_lora=True,
    lora_rank=16,                 # rank of the low-rank update matrices
    lora_alpha=32,                # scaling factor for the LoRA update
    lora_dropout=0.05,            # dropout applied within the LoRA layers
    bias="none",                  # do not adapt bias parameters
    lora_verbose=True,            # enable verbose LoRA logging
    target_modules=["q_proj", "v_proj", "k_proj"],  # module names that receive LoRA adapters
    save_lora_base=False,         # do not save the base model weights
    merge_lora=False,             # do not merge the adapter into the base model when saving
)
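
For intuition on lora_rank and lora_alpha: in the standard LoRA formulation, each targeted weight W is augmented with a trainable low-rank update scaled by lora_alpha / lora_rank. The sketch below only illustrates that arithmetic; it is not NxD internals, and the dimensions are placeholders.

import torch

# Standard LoRA parameterization (illustration only, not NxD internals):
# the frozen weight W is used as W + (lora_alpha / lora_rank) * (B @ A),
# where A is (rank x in_features) and B is (out_features x rank).
in_features, out_features = 4096, 4096     # placeholder dimensions
rank, alpha = 16, 32                        # lora_rank and lora_alpha from the config above
A = torch.randn(rank, in_features) * 0.01   # typically Gaussian-initialized
B = torch.zeros(out_features, rank)         # typically zero-initialized
scaling = alpha / rank                      # 32 / 16 = 2.0
delta_W = scaling * (B @ A)                 # low-rank update added to the frozen weight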

The default target modules for different model architectures can be found in model.py.
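
target_modules selects, by name, which submodules receive LoRA adapters. As a purely hypothetical example, a Llama-style model could also adapt the attention output and MLP projections; the module names below are assumptions and must match the names actually used in your model implementation (see model.py).

# Hypothetical example: adapting both attention and MLP projections.
# The module names are assumptions; verify them against your model implementation.
lora_config = nxd.modules.lora.LoraConfig(
    enable_lora=True,
    lora_rank=16,
    lora_alpha=32,
    target_modules=[
        "q_proj", "k_proj", "v_proj", "o_proj",   # attention projections
        "gate_proj", "up_proj", "down_proj",      # MLP projections
    ],
)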

We then initialize the NxD model with LoRA enabled:

nxd_config = nxd.neuronx_distributed_config(
    ...
    lora_config=lora_config,
)
model = nxd.initialize_parallel_model(nxd_config, ...)
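
With LoRA enabled, only the adapter parameters are trainable and the training loop itself is unchanged. Below is a minimal, hypothetical training step; it assumes the nxd.initialize_parallel_optimizer wrapper from the NxD training API, a model that returns an object with a .loss attribute, and a train_loader defined elsewhere.

import torch
import torch_xla.core.xla_model as xm

# Hypothetical training step; `model` and `nxd_config` come from the snippet above,
# while `train_loader` and the loss interface are placeholders for your own setup.
optimizer = nxd.initialize_parallel_optimizer(
    nxd_config, torch.optim.AdamW, model.parameters(), lr=1e-4
)

for batch in train_loader:
    optimizer.zero_grad()
    outputs = model(input_ids=batch["input_ids"], labels=batch["labels"])
    outputs.loss.backward()
    optimizer.step()
    xm.mark_step()  # execute the accumulated XLA graph on the Neuron device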

Save LoRA checkpoint#

Users can save the LoRA adapter with:

nxd.save_checkpoint(
    checkpoint_dir_str=checkpoint_dir, # checkpoint path
    tag=tag,     # sub-directory under checkpoint path
    model=model
)

Because save_lora_base=False and merge_lora=False, only the LoRA adapter is saved under checkpoint_dir/tag/. We can also set merge_lora=True to save the merged model, i.e., to merge the LoRA adapter into the base model, as sketched below.
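
For example, a configuration that produces a merged checkpoint might look like the following sketch; save_lora_base=True is an assumption here, since a merged checkpoint contains the base weights.

# Sketch: configure LoRA so that nxd.save_checkpoint() writes the merged model
# (adapter weights folded into the base weights) rather than the adapter alone.
# save_lora_base=True is an assumption, since the merged checkpoint includes base weights.
lora_config = nxd.modules.lora.LoraConfig(
    enable_lora=True,
    lora_rank=16,
    lora_alpha=32,
    target_modules=["q_proj", "v_proj", "k_proj"],
    save_lora_base=True,
    merge_lora=True,
)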

Load LoRA checkpoint#

A sample usage:

lora_config = nxd.modules.lora.LoraConfig(
    enable_lora=True,
    load_lora_from_ckpt=True,       # load LoRA adapter weights from a checkpoint
    lora_save_dir=checkpoint_dir,   # checkpoint path
    lora_load_tag=tag,              # sub-directory under checkpoint path
)
nxd_config = nxd.neuronx_distributed_config(
    ...
    lora_config=lora_config,
)
model = nxd.initialize_parallel_model(nxd_config, ...)

The NxD model will be initialized with LoRA enabled and the LoRA weights loaded. The LoRA-related configuration is the same as that of the saved LoRA adapter checkpoint.

This document is relevant for: Inf2, Trn1, Trn1n