Tutorial: Multi-LoRA serving for Llama-3.1-8B on Trn2 instances

NeuronX Distributed (NxD) Inference supports multi-LoRA serving. This tutorial provides a step-by-step guide to multi-LoRA serving with Llama-3.1-8B as the base model on a Trn2 instance. It describes two ways of running multi-LoRA serving through vLLM with NxD Inference: offline inference and an online server. The examples use LoRA adapters downloaded from Hugging Face.

Prerequisites

Set up and connect to a trn2.48xlarge instance

This tutorial requires a Trn2 instance created from a Deep Learning AMI that has the Neuron SDK pre-installed.

To set up a Trn2 instance using a Deep Learning AMI with the Neuron SDK pre-installed, see the NxD Inference Setup Guide. To use Jupyter Notebook on the Neuron instance, see this guide.

After setting up an instance, use SSH to connect to the Trn2 instance using the key pair that you chose when you launched the instance.

After you are connected, activate the Python virtual environment that includes the Neuron SDK.

source ~/aws_neuronx_venv_pytorch_2_5_nxd_inference/bin/activate

Run pip list to verify that the Neuron SDK is installed.

pip list | grep neuron

You should see Neuron packages including neuronx-distributed-inference and neuronx-cc.
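If you prefer to check from Python (for example, inside a Jupyter notebook), the following sketch looks up installed versions with the standard library's importlib.metadata. The package names are the ones listed above; adjust them if your environment differs.

```python
from importlib.metadata import PackageNotFoundError, version

# Package names as shown by `pip list`; adjust if your environment differs.
REQUIRED = ["neuronx-distributed-inference", "neuronx-cc"]


def installed_versions(packages):
    """Return {package: version string, or None if the package is not installed}."""
    result = {}
    for name in packages:
        try:
            result[name] = version(name)
        except PackageNotFoundError:
            result[name] = None
    return result


if __name__ == "__main__":
    for name, ver in installed_versions(REQUIRED).items():
        print(f"{name}: {ver or 'NOT INSTALLED'}")
```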

Install packages

NxD Inference supports running models with vLLM. This functionality is available in the AWS Neuron fork of the vLLM GitHub repository. Install the latest release branch of vLLM from the AWS Neuron fork by following the instructions in the vLLM User Guide for NxD Inference.

Download base model and LoRA adapters

To follow this tutorial, you must first download a Llama-3.1-8B-Instruct model checkpoint from Hugging Face to a local path on the Trn2 instance. Note that you may need to request access from Meta before you can download the model. For more information, see Downloading models in the Hugging Face documentation.

You must also download LoRA adapters from Hugging Face for multi-LoRA serving. As examples, you can download nvidia/llama-3.1-nemoguard-8b-topic-control, reissbaker/llama-3.1-8b-abliterated-lora, Stefano-M/aixpa_amicifamiglia_short_prompt, and GaetanMichelet/Llama-31-8B_task-2_180-samples_config-2. This tutorial assumes that each adapter is saved in its own subdirectory of /home/ubuntu/lora_adapters/, named after the adapter (for example, /home/ubuntu/lora_adapters/llama-3.1-nemoguard-8b-topic-control).
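One way to fetch the base model and the four example adapters is the snapshot_download function from the huggingface_hub library. The sketch below assumes that huggingface_hub is installed, that you have logged in to Hugging Face (for example with huggingface-cli login) if a repository is gated, and that the local directories match the paths used by the examples in this tutorial. Each adapter is saved in a directory named after the last part of its repository ID, matching the directory names in the JSON example in this tutorial.

```python
import os

# Repository IDs and local paths used in this tutorial.
BASE_MODEL_REPO = "meta-llama/Llama-3.1-8B-Instruct"
BASE_MODEL_DIR = "/home/ubuntu/model_hf/llama-3.1-8b-instruct/"
LORA_ADAPTER_DIR = "/home/ubuntu/lora_adapters/"
LORA_REPOS = [
    "nvidia/llama-3.1-nemoguard-8b-topic-control",
    "reissbaker/llama-3.1-8b-abliterated-lora",
    "Stefano-M/aixpa_amicifamiglia_short_prompt",
    "GaetanMichelet/Llama-31-8B_task-2_180-samples_config-2",
]


def adapter_local_dir(repo_id, root=LORA_ADAPTER_DIR):
    """Local directory for an adapter: <root>/<repository name without the owner>."""
    return os.path.join(root, repo_id.split("/")[-1])


def download_all():
    """Download the base model and every adapter (requires network access)."""
    # Imported lazily so adapter_local_dir works without huggingface_hub installed.
    from huggingface_hub import snapshot_download

    snapshot_download(repo_id=BASE_MODEL_REPO, local_dir=BASE_MODEL_DIR)
    for repo_id in LORA_REPOS:
        snapshot_download(repo_id=repo_id, local_dir=adapter_local_dir(repo_id))
```

Call download_all() once to download everything; defining the helpers does not download anything.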

Using vLLM V1 for multi-LoRA serving on Trn2

You will run multi-LoRA serving on Trn2 with vLLM V1 using Llama-3.1-8B-Instruct and four LoRA adapters. All four adapters are loaded into host memory, and two of them are also preloaded into HBM during model initialization. The model runs in bfloat16 precision. For more details on running model inference on Trn2 with vLLM V1, see the vLLM User Guide for NxD Inference.
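To make the split between HBM and host memory concrete, the following toy model tracks which adapters occupy the max_loras device slots and which adapter is evicted when a host-memory adapter is requested. It assumes least-recently-used (LRU) eviction purely for illustration. This tutorial does not specify the eviction policy that NxD Inference uses, so treat the class as a hypothetical sketch, not a description of the runtime.

```python
from collections import OrderedDict


class TwoTierAdapterCache:
    """Toy model of LoRA adapters split between device (HBM) and host memory.

    Assumption: least-recently-used eviction, chosen for illustration only.
    This is not the actual NxD Inference policy.
    """

    def __init__(self, host_adapters, preloaded, max_loras):
        if len(preloaded) > max_loras:
            raise ValueError("more preloaded adapters than max_loras")
        self.host = set(host_adapters)  # every registered adapter stays in host memory
        self.max_loras = max_loras  # number of device (HBM) slots
        self.device = OrderedDict.fromkeys(preloaded)  # oldest use first

    def request(self, adapter):
        """Serve a request; return the adapter evicted from HBM, or None."""
        if adapter in self.device:
            self.device.move_to_end(adapter)  # mark as most recently used
            return None
        if adapter not in self.host:
            raise KeyError(f"{adapter} is not registered")
        evicted = None
        if len(self.device) >= self.max_loras:
            evicted, _ = self.device.popitem(last=False)  # least recently used
        self.device[adapter] = None  # swap in from host memory
        return evicted
```

For example, with lora_id_1 and lora_id_2 preloaded and max_loras=2, requesting lora_id_1 and then lora_id_3 evicts lora_id_2 under this LRU assumption.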

Multi-LoRA configurations

Set the following configurations to enable multi-LoRA serving with vLLM V1:

  • enable_lora - The flag to enable multi-LoRA serving in NxD Inference. Defaults to False.

  • max_loras - The maximum number of concurrent LoRA adapters in device memory.

  • max_cpu_loras - The maximum number of concurrent LoRA adapters in host memory.

  • max_lora_rank - The highest LoRA rank that needs to be supported. Defaults to 16. If it is not specified, the maximum LoRA rank of the LoRA adapter checkpoints will be used.

  • lora-ckpt-json - The path of the JSON file that maps LoRA adapter IDs to their checkpoint paths. In the offline example, it is passed in override_neuron_config as lora_ckpt_json. The file includes three fields:

    • lora-ckpt-dir - The directory that contains the LoRA adapters.

    • lora-ckpt-paths - The mapping between the IDs of the LoRA adapters loaded into HBM at initialization and their checkpoint paths. These adapters might be evicted from HBM at runtime.

    • lora-ckpt-paths-cpu - The mapping between the IDs of the LoRA adapters held in host (CPU) memory and their checkpoint paths.

Here is an example of the JSON file:

{
    "lora-ckpt-dir": "/home/ubuntu/lora_adapters/",
    "lora-ckpt-paths": {
        "lora_id_1": "llama-3.1-nemoguard-8b-topic-control",
        "lora_id_2": "llama-3.1-8b-abliterated-lora"
    },
    "lora-ckpt-paths-cpu": {
        "lora_id_1": "llama-3.1-nemoguard-8b-topic-control",
        "lora_id_2": "llama-3.1-8b-abliterated-lora",
        "lora_id_3": "aixpa_amicifamiglia_short_prompt",
        "lora_id_4": "Llama-31-8B_task-2_180-samples_config-2"
    }
}
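Rather than writing the JSON file by hand, you can generate it from Python. The sketch below writes the example above to /home/ubuntu/lora_adapters/lora_adapters.json, the path used by the examples in this tutorial, and runs a few sanity checks. The checks (adapter counts against max_loras and max_cpu_loras, and checkpoint directories resolved relative to lora-ckpt-dir) are assumptions drawn from the option descriptions above, not validation rules published for NxD Inference.

```python
import json
import os

LORA_CKPT_JSON = "/home/ubuntu/lora_adapters/lora_adapters.json"

lora_ckpt_config = {
    "lora-ckpt-dir": "/home/ubuntu/lora_adapters/",
    "lora-ckpt-paths": {
        "lora_id_1": "llama-3.1-nemoguard-8b-topic-control",
        "lora_id_2": "llama-3.1-8b-abliterated-lora",
    },
    "lora-ckpt-paths-cpu": {
        "lora_id_1": "llama-3.1-nemoguard-8b-topic-control",
        "lora_id_2": "llama-3.1-8b-abliterated-lora",
        "lora_id_3": "aixpa_amicifamiglia_short_prompt",
        "lora_id_4": "Llama-31-8B_task-2_180-samples_config-2",
    },
}


def check_config(cfg, max_loras, max_cpu_loras):
    """Return a list of problems; the rules are assumptions, not official checks."""
    problems = []
    if len(cfg["lora-ckpt-paths"]) > max_loras:
        problems.append("more HBM adapters than max_loras")
    if len(cfg["lora-ckpt-paths-cpu"]) > max_cpu_loras:
        problems.append("more host-memory adapters than max_cpu_loras")
    # Assumption: each checkpoint path is relative to lora-ckpt-dir.
    for name, subdir in cfg["lora-ckpt-paths-cpu"].items():
        path = os.path.join(cfg["lora-ckpt-dir"], subdir)
        if not os.path.isdir(path):
            problems.append(f"{name}: missing directory {path}")
    return problems


def write_config(cfg, path=LORA_CKPT_JSON):
    """Write the mapping as indented JSON, creating the parent directory if needed."""
    os.makedirs(os.path.dirname(path), exist_ok=True)
    with open(path, "w") as f:
        json.dump(cfg, f, indent=4)
```

Run write_config(lora_ckpt_config) and then print(check_config(lora_ckpt_config, max_loras=2, max_cpu_loras=4)); an empty list means none of these checks found a problem.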

Offline inference example

You can run multi-LoRA serving offline on Trn2 with vLLM V1.

import os
from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest

MODEL_PATH="/home/ubuntu/model_hf/llama-3.1-8b-instruct/"
# Replace this with the path where you saved the JSON file. See the JSON example above for the format.
LORA_CKPT_JSON="/home/ubuntu/lora_adapters/lora_adapters.json"
# This is where the compiled model will be saved.
COMPILED_MODEL_PATH="/home/ubuntu/traced_model/llama-3.1-8B-Lora/"
os.environ["NEURON_COMPILED_ARTIFACTS"] = (COMPILED_MODEL_PATH)
os.environ["VLLM_USE_V1"] = "1"

# Sample prompts.
prompts = [
    "The president of the United States is",
    "The capital of France is",
]

# Create a sampling params object.
sampling_params = SamplingParams(top_k=1)
override_neuron_config = {
    "skip_warmup": True,
    "lora_ckpt_json": LORA_CKPT_JSON,
}

# Create an LLM with multi-LoRA serving.
llm = LLM(
    model=MODEL_PATH,
    max_num_seqs=2,
    max_model_len=64,
    tensor_parallel_size=32,
    additional_config={
        "override_neuron_config": override_neuron_config
    },
    enable_lora=True,
    max_loras=2,
    max_cpu_loras=4,
    enable_prefix_caching=False,
    enable_chunked_prefill=False,
)
# Only lora_name needs to be specified. The adapter IDs and checkpoint paths
# are supplied at LLM/server initialization, after which NxD Inference handles
# the paths, so the path argument below is only a placeholder.
# lora_id_1 is in HBM
lora_req_1 = LoRARequest("lora_id_1", 1, " ")
# lora_id_3 is in host memory and it will be dynamically swapped to HBM at runtime
lora_req_2 = LoRARequest("lora_id_3", 2, " ")
outputs = llm.generate(prompts, sampling_params, lora_request=[lora_req_1, lora_req_2])

for output in outputs:
    prompt = output.prompt
    generated_text = output.outputs[0].text
    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
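When a batch mixes several adapters, a small helper can build one LoRARequest per prompt. The sketch below reads the registered adapter names from the lora-ckpt-paths-cpu field of the JSON file, assigns each name a positive integer ID from its position in that field so the same adapter always gets the same ID, and rejects names that are not registered. The ID scheme is an illustrative choice, not a requirement stated by NxD Inference; the placeholder path " " mirrors the example above.

```python
import json


def plan_lora_requests(adapter_names, registered_names):
    """Return (lora_name, lora_int_id) pairs, one per prompt.

    IDs come from each name's position in the registry (starting at 1),
    so an adapter keeps the same ID across calls.
    """
    ids = {name: i for i, name in enumerate(registered_names, start=1)}
    unknown = [n for n in adapter_names if n not in ids]
    if unknown:
        raise ValueError(f"adapters not registered in the JSON file: {unknown}")
    return [(name, ids[name]) for name in adapter_names]


def make_lora_requests(adapter_names, lora_ckpt_json):
    """Build vLLM LoRARequest objects for the given adapter names."""
    from vllm.lora.request import LoRARequest  # imported lazily

    with open(lora_ckpt_json) as f:
        registered = list(json.load(f)["lora-ckpt-paths-cpu"])
    # " " is a placeholder path, as in the example above.
    return [
        LoRARequest(name, int_id, " ")
        for name, int_id in plan_lora_requests(adapter_names, registered)
    ]
```

For example: outputs = llm.generate(prompts, sampling_params, lora_request=make_lora_requests(["lora_id_1", "lora_id_3"], LORA_CKPT_JSON)).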

Run multi-LoRA serving with model quantization

To enable multi-LoRA serving with a quantized base model, you must pass quantization-related settings to vLLM. For example, add the following entries to override_neuron_config before you create the LLM object in the offline example.

quantization_config = {
    "quantized": True,
    "quantized_checkpoints_path": os.path.join(COMPILED_MODEL_PATH, "model_quant.pt"),
    "quantization_type": "per_channel_symmetric",
}
# Add quantization config to override_neuron_config
override_neuron_config.update(quantization_config)
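Because override_neuron_config is passed to the LLM constructor, the quantization entries only take effect if they are in the dictionary when the LLM object is created. A hypothetical helper such as build_override_neuron_config below keeps the LoRA and quantization settings together; the keys and values are the ones used in this tutorial.

```python
import os


def build_override_neuron_config(lora_ckpt_json, compiled_model_path, quantized=False):
    """Hypothetical helper: assemble override_neuron_config before creating the LLM."""
    config = {
        "skip_warmup": True,
        "lora_ckpt_json": lora_ckpt_json,
    }
    if quantized:
        # Same quantization entries as in the snippet above.
        config.update({
            "quantized": True,
            "quantized_checkpoints_path": os.path.join(compiled_model_path, "model_quant.pt"),
            "quantization_type": "per_channel_symmetric",
        })
    return config
```

For example, build the dictionary with override_neuron_config = build_override_neuron_config(LORA_CKPT_JSON, COMPILED_MODEL_PATH, quantized=True) and pass it as additional_config={"override_neuron_config": override_neuron_config} when you create the LLM object.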

Online server example

You can also run online multi-LoRA serving on Trn2 with vLLM V1. The following cell saves a launch script as start_vllm.sh, and the cell after it makes the script executable and runs it. The script starts the server in the background and waits until it is ready to serve.

%%writefile start_vllm.sh
#!/bin/bash

echo "Running vLLM server in the background..."
rm -f ./vllm_server.log

# These should be the same paths used when compiling the model.
MODEL_PATH="/home/ubuntu/model_hf/llama-3.1-8b-instruct/"
# Replace this with the path where you saved the JSON file. See the JSON example above for the format.
LORA_CKPT_JSON="/home/ubuntu/lora_adapters/lora_adapters.json"
# This is where the compiled model will be saved.
COMPILED_MODEL_PATH="/home/ubuntu/traced_model/llama-3.1-8B-Lora/"
# Replace this with the path where you saved the LoRA adapters
LORA_ADAPTER_DIR="/home/ubuntu/lora_adapters"
# Set lora_modules to register LoRA adapters during multi-LoRA serving
LORA_MODULES="lora_id_1=${LORA_ADAPTER_DIR}/llama-3.1-nemoguard-8b-topic-control "
LORA_MODULES+="lora_id_2=${LORA_ADAPTER_DIR}/llama-3.1-8b-abliterated-lora "
LORA_MODULES+="lora_id_3=${LORA_ADAPTER_DIR}/aixpa_amicifamiglia_short_prompt "
LORA_MODULES+="lora_id_4=${LORA_ADAPTER_DIR}/Llama-31-8B_task-2_180-samples_config-2 "
# Neuron-specific overrides, passed through additional_config as in the offline example.
# Inner quotes are escaped so that the value is valid JSON and ${LORA_CKPT_JSON} still expands.
ADDITIONAL_CONFIG="{\"override_neuron_config\": {\"sequence_parallel_enabled\": false, \"lora_ckpt_json\": \"${LORA_CKPT_JSON}\"}}"

export NEURON_COMPILED_ARTIFACTS=$COMPILED_MODEL_PATH
# Export the timeout so that the server process inherits it.
export VLLM_RPC_TIMEOUT=100000
# ${LORA_MODULES} is left unquoted so that it expands to one name=path argument per adapter.
nohup python -m vllm.entrypoints.openai.api_server \
    --model $MODEL_PATH \
    --max-num-seqs 2 \
    --max-model-len 64 \
    --tensor-parallel-size 32 \
    --disable-log-requests \
    --no-enable-chunked-prefill \
    --no-enable-prefix-caching \
    --enable-lora \
    --max-loras 2 \
    --max-cpu-loras 4 \
    --additional-config "${ADDITIONAL_CONFIG}" \
    --lora-modules ${LORA_MODULES} \
    --port 8000 > ./vllm_server.log 2>&1 &

SERVER_PID=$!

echo "Server started in the background with the following id: $SERVER_PID. Waiting until server is ready to serve..."

until grep -q "Server is ready to serve" ./vllm_server.log 2>/dev/null || ! kill -0 $SERVER_PID 2>/dev/null; do sleep 0.5; done
grep -q "Server is ready to serve" ./vllm_server.log 2>/dev/null && echo "vLLM Server is ready!" || (echo "vLLM Server failed, check the ./vllm_server.log file" && exit 1)
!chmod +x ./start_vllm.sh
!./start_vllm.sh

After the vLLM server starts, you can list the LoRA adapters registered with it:

%%bash
curl http://localhost:8000/v1/models | jq
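The same check can be done from Python with only the standard library. The sketch below assumes the server started by start_vllm.sh is listening on localhost:8000 and that the response follows the OpenAI-compatible format, in which each served model or registered adapter appears as an entry with an id field under data.

```python
import json
import urllib.request

BASE_URL = "http://localhost:8000"  # assumed host and port of the vLLM server


def list_model_ids(models_response):
    """Extract model and adapter IDs from an OpenAI-style /v1/models response."""
    return [entry["id"] for entry in models_response.get("data", [])]


def fetch_models(base_url=BASE_URL):
    """GET /v1/models and return the decoded JSON body."""
    with urllib.request.urlopen(f"{base_url}/v1/models", timeout=10) as resp:
        return json.load(resp)
```

With the server running, print(list_model_ids(fetch_models())) should show the base model along with the adapter IDs registered with --lora-modules.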

To route a request to a specific LoRA adapter, set the model field to one of the registered LoRA adapter IDs. The following are two sample requests:

%%bash
# request LoRA adapter in HBM
curl http://localhost:8000/v1/completions \
    -H "Content-Type: application/json" \
    -d '{
        "model": "lora_id_1",
        "prompt": "The president of the United States is",
        "max_tokens": 32,
        "temperature": 0
    }' | jq

# request LoRA adapter in host memory with dynamic swap
curl http://localhost:8000/v1/completions \
    -H "Content-Type: application/json" \
    -d '{
        "model": "lora_id_3",
        "prompt": "The capital of France is",
        "max_tokens": 32,
        "temperature": 0
    }' | jq
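For programmatic clients, the following standard-library sketch sends the same completion requests. It assumes the server is reachable at localhost:8000 and that responses follow the OpenAI-compatible completions format, where the generated text is in choices[0].text. The model field selects the adapter, exactly as in the curl examples.

```python
import json
import urllib.request

BASE_URL = "http://localhost:8000"  # assumed host and port of the vLLM server


def completion_payload(adapter_id, prompt, max_tokens=32):
    """Request body: "model" selects the LoRA adapter; temperature 0 is greedy."""
    return {"model": adapter_id, "prompt": prompt, "max_tokens": max_tokens, "temperature": 0}


def complete(adapter_id, prompt, base_url=BASE_URL):
    """POST to /v1/completions and return the generated text."""
    body = json.dumps(completion_payload(adapter_id, prompt)).encode("utf-8")
    req = urllib.request.Request(
        f"{base_url}/v1/completions",
        data=body,
        headers={"Content-Type": "application/json"},
        method="POST",
    )
    with urllib.request.urlopen(req, timeout=120) as resp:
        return json.load(resp)["choices"][0]["text"]
```

For example, loop over [("lora_id_1", "The president of the United States is"), ("lora_id_3", "The capital of France is")] and print complete(adapter, prompt) for each pair; the second request targets an adapter that starts in host memory.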