Tutorial: Deploying Llama4 Multimodal Models

This guide shows how to deploy Llama4 on an AWS Neuron Trainium2 (Trn2) instance. Llama4 is a multimodal model that accepts both text and image inputs. This tutorial uses Llama4 Scout (meta-llama/Llama-4-Scout-17B-16E-Instruct) as the example model; however, you can also use Maverick (meta-llama/Llama-4-Maverick-17B-128E-Instruct).

Examples

Step 1: Set up your development environment

As a prerequisite, this tutorial requires that you have a Trn2 instance created from a Deep Learning AMI that has the Neuron SDK pre-installed.

To set up a Trn2 instance using a Deep Learning AMI with the Neuron SDK pre-installed, see the NxD Inference Setup Guide. To use a Jupyter Notebook (.ipynb) on a Neuron-enabled instance, see this guide.

After setting up an instance, use SSH to connect to the Trn2 instance using the key pair that you chose when you launched the instance.

After you are connected, activate the Python virtual environment that includes the Neuron SDK.
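After activating the environment, you can quickly confirm that the packages used later in this tutorial are importable. This is a minimal sketch that uses only the Python standard library; the package list is our assumption, based on the imports in Steps 2 and 3, and is not an official requirements list. If a package is reported missing, revisit the setup guide.

```python
import importlib.util

# Packages imported later in this tutorial (our assumption, based on the code in Steps 2 and 3)
REQUIRED_PACKAGES = ["torch", "neuronx_distributed_inference", "vllm", "openai"]


def check_packages(packages):
    """Map each top-level package name to True if it can be imported from this environment."""
    return {name: importlib.util.find_spec(name) is not None for name in packages}


for name, found in check_packages(REQUIRED_PACKAGES).items():
    print(f"{name}: {'found' if found else 'MISSING'}")
```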

Step 2: Compile your model and save it as artifacts

The following code compiles Llama4 into artifacts that vLLM loads for serving. You don't need to download the Llama4 checkpoint from Hugging Face explicitly, because the compilation script downloads it, but you may need to request access from Meta first. For more information, see Downloading models in the Hugging Face documentation.

In neuron_config, you can define text_config and vision_config separately to configure the text decoder and the vision encoder of the multimodal architecture.

Each image is split into 1, 4, or 16 chunks depending on its resolution and aspect ratio, plus one additional chunk that describes the entire image. Because the vision encoder uses data parallelism (DP) together with tensor parallelism (TP), its input batch size (the total number of chunks) is padded to the next multiple of the DP degree, which is 4 in this configuration. The final padded batch size is:

  • 1+1 = 2 → 4: Each rank has the batch size = 4/4 = 1

  • 4+1 = 5 → 8: Each rank has the batch size = 8/4 = 2

  • 16+1 = 17 → 20: Each rank has the batch size = 20/4 = 5
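The padding rule above can be reproduced with a short calculation. This sketch assumes a single image per request and a DP degree of 4, as in this configuration; the function name padded_vision_batch is hypothetical and not part of NxD Inference.

```python
DP_DEGREE = 4  # dp_degree in vision_config


def padded_vision_batch(num_tiles, dp_degree=DP_DEGREE):
    """Return (total_chunks, padded_batch, per_rank_batch) for one image."""
    total_chunks = num_tiles + 1  # image tiles plus one chunk for the entire image
    # Round up to the next multiple of dp_degree
    padded_batch = (total_chunks + dp_degree - 1) // dp_degree * dp_degree
    return total_chunks, padded_batch, padded_batch // dp_degree


for tiles in (1, 4, 16):
    total, padded, per_rank = padded_vision_batch(tiles)
    print(f"{tiles}+1 = {total} -> {padded}: each rank has batch size {padded}/{DP_DEGREE} = {per_rank}")
```

The integer round-up avoids floating-point division, and the printed lines match the three cases listed above.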

There are a few fields you can configure to improve performance:

  • cp_degree: degree of context parallelism at the attention layer for prefill.

  • blockwise_matmul_config: the configuration of the blockwise MoE kernel for prefill.

  • attn_block_tkg_nki_kernel_enabled and attn_block_tkg_nki_kernel_cache_update: enable an NKI kernel for attention and a kernel-based KV cache update for decode (token generation) operations.
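If you want to measure what these fields contribute, one approach is to compile a second variant of the configuration with some of them turned off and compare the two. The sketch below uses a hypothetical helper, with_text_overrides, that copies a configuration dictionary and overrides keys in its text_config. Which fields to toggle, and whether a given combination compiles, is up to you and is not something this tutorial verifies.

```python
import copy

# Hypothetical set of overrides: turn off the decode attention kernel and its KV cache update
KERNELS_OFF = {
    "attn_block_tkg_nki_kernel_enabled": False,
    "attn_block_tkg_nki_kernel_cache_update": False,
}


def with_text_overrides(neuron_config, overrides):
    """Return a deep copy of neuron_config whose text_config has the given keys replaced."""
    variant = copy.deepcopy(neuron_config)  # leave the original configuration untouched
    variant["text_config"].update(overrides)
    return variant


# Trimmed-down example configuration; in practice, pass the full scout_neuron_config dictionary
example_config = {
    "text_config": {
        "cp_degree": 16,
        "attn_block_tkg_nki_kernel_enabled": True,
        "attn_block_tkg_nki_kernel_cache_update": True,
    },
    "vision_config": {"cp_degree": 1},
}
baseline_config = with_text_overrides(example_config, KERNELS_OFF)
print(baseline_config["text_config"])
```

If you compile such a variant, save it to its own JSON file and use a different traced model path so the two sets of artifacts do not overwrite each other.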

import argparse
import json

import torch
from neuronx_distributed_inference.models.config import OnDeviceSamplingConfig
from neuronx_distributed_inference.models.llama4.modeling_llama4 import NeuronLlama4ForCausalLM, Llama4InferenceConfig, Llama4NeuronConfig
from neuronx_distributed_inference.utils.hf_adapter import load_pretrained_config

# Default paths; you can override them on the command line when running this code as a script
MODEL_PATH = "meta-llama/Llama-4-Scout-17B-16E-Instruct"
TRACED_MODEL_PATH = "/home/ubuntu/llama4/traced_models/Llama-4-Scout-17B-16E-Instruct"
NEURON_CONFIG_PATH = "scout_neuron_config.json"

# Neuron configuration: text_config applies to the text decoder, vision_config to the vision encoder
scout_neuron_config = {
    "text_config": {
        "batch_size": 1,
        "is_continuous_batching": True,
        "seq_len": 16384,
        "enable_bucketing": True,
        "context_encoding_buckets": [256, 512, 1024, 2048, 4096, 8192, 10240, 16384],
        "token_generation_buckets": [256, 512, 1024, 2048, 4096, 8192, 10240, 16384],
        "torch_dtype": "float16",
        "async_mode": True,
        "world_size": 64,
        "tp_degree": 64,
        "cp_degree": 16,
        "on_device_sampling_config": {
            "dynamic": True,
            "top_k_kernel_enabled": True,
            "top_k": 1
        },
        "cast_type": "as-declared",
        "logical_neuron_cores": 2,
        "cc_pipeline_tiling_factor": 1,
        "sequence_parallel_enabled": True,
        "fused_qkv": True,
        "qkv_kernel_enabled": True,
        "attn_kernel_enabled": True,
        "attn_block_tkg_nki_kernel_enabled": True,
        "attn_block_tkg_nki_kernel_cache_update": True,
        "blockwise_matmul_config": {
            "block_size": 256,
            "use_block_parallel": True,
            "block_sharding_strategy": "HI_LO",
            "skip_dma_token": True,
            "skip_dma_weight": True,
            "parallelize_token_to_block_mapping": True
        }
    },
    "vision_config": {
        "batch_size": 1,
        "seq_len": 8192,
        "torch_dtype": "float16",
        "tp_degree": 16,
        "cp_degree": 1,
        "dp_degree": 4,
        "world_size": 64,
        "fused_qkv": True,
        "qkv_kernel_enabled": True,
        "attn_kernel_enabled": True,
        "mlp_kernel_enabled": True,
        "enable_bucketing": True,
        "buckets": [8, 28, 88],
        "logical_neuron_cores": 2,
        "save_sharded_checkpoint": True
    }
}

# Save the configuration as JSON so that build_config() can load it
with open(NEURON_CONFIG_PATH, 'w') as f:
    json.dump(scout_neuron_config, f, indent=4)


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--model-path', type=str, default=MODEL_PATH)
    parser.add_argument('--traced-model-path', type=str, default=TRACED_MODEL_PATH)
    parser.add_argument('--neuron-config-path', type=str, default=NEURON_CONFIG_PATH)
    # parse_known_args() ignores unrecognized arguments, such as the ones Jupyter passes to the kernel
    args, _ = parser.parse_known_args()
    return args


def build_config(neuron_config_path, model_path):
    # Load the JSON configuration and build separate Neuron configs for the text and vision models
    with open(neuron_config_path, 'r') as f:
        config_json = json.load(f)
    text_neuron_config = Llama4NeuronConfig(**config_json['text_config'])
    vision_neuron_config = Llama4NeuronConfig(**config_json['vision_config'])
    return Llama4InferenceConfig(
        text_neuron_config=text_neuron_config,
        vision_neuron_config=vision_neuron_config,
        load_config=load_pretrained_config(model_path)
    )


def compile_model(model_path, traced_model_path, config):
    # Trace and compile the model, then save the compiled artifacts to traced_model_path
    model = NeuronLlama4ForCausalLM(model_path, config)
    model.compile(traced_model_path)


args = parse_args()
config = build_config(args.neuron_config_path, args.model_path)
compile_model(args.model_path, args.traced_model_path, config)
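Compilation takes a long time, so a quick sanity check on the parallelism settings can catch typos before you run the compilation cell. The sketch below encodes relationships that hold for the configuration in this step: the text tp_degree equals world_size, the vision tp_degree * dp_degree (16 * 4 = 64) equals world_size, and the largest bucket equals seq_len. These are our assumptions drawn from the values shown, not documented NxD Inference constraints, and check_parallel_layout is a hypothetical helper.

```python
def check_parallel_layout(neuron_config):
    """Return a list of problems found; an empty list means every check passed.

    The checks are assumptions drawn from the values in this tutorial's configuration,
    not documented NxD Inference requirements.
    """
    text = neuron_config["text_config"]
    vision = neuron_config["vision_config"]
    problems = []
    if text["tp_degree"] != text["world_size"]:
        problems.append("text: tp_degree != world_size")
    if vision["tp_degree"] * vision["dp_degree"] != vision["world_size"]:
        problems.append("vision: tp_degree * dp_degree != world_size")
    if text["world_size"] != vision["world_size"]:
        problems.append("text and vision world_size differ")
    for key in ("context_encoding_buckets", "token_generation_buckets"):
        buckets = text.get(key)
        if buckets and max(buckets) != text["seq_len"]:
            problems.append(f"text: largest {key} entry != seq_len")
    return problems


# Parallelism-related values copied from scout_neuron_config
layout = {
    "text_config": {
        "world_size": 64, "tp_degree": 64, "seq_len": 16384,
        "context_encoding_buckets": [256, 512, 1024, 2048, 4096, 8192, 10240, 16384],
        "token_generation_buckets": [256, 512, 1024, 2048, 4096, 8192, 10240, 16384],
    },
    "vision_config": {"world_size": 64, "tp_degree": 16, "dp_degree": 4},
}
print(check_parallel_layout(layout) or "all checks passed")
```

In the notebook, you could call check_parallel_layout(scout_neuron_config) directly, since it only reads the keys shown here.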

Step 3: Deploy with vLLM inference

We provide two examples to run Llama4 with vLLM:

  • Offline inference: you provide prompts in a Python script and run it.

  • Online inference: you serve the model with an OpenAI-compatible API server and send requests to it.

Offline Example

Before running offline inference, you must trace (compile) the Llama4 model as shown in Step 2. Provide the path to the traced model by setting the NEURON_COMPILED_ARTIFACTS environment variable.
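You can also fail fast if the variable is missing or points at an empty directory. This is a minimal sketch; check_compiled_artifacts is a hypothetical helper, and it only checks that the directory exists and is non-empty, not that its contents are valid compiled artifacts.

```python
import os
from pathlib import Path


def check_compiled_artifacts(var_name="NEURON_COMPILED_ARTIFACTS"):
    """Return the artifacts directory named by var_name, or raise RuntimeError if it looks unusable."""
    value = os.environ.get(var_name)
    if not value:
        raise RuntimeError(f"{var_name} is not set; point it at the traced model directory from Step 2")
    path = Path(value)
    # Only checks existence and non-emptiness, not whether the contents are valid artifacts
    if not path.is_dir() or not any(path.iterdir()):
        raise RuntimeError(f"{var_name}={value} is not a non-empty directory; did compilation finish?")
    return path
```

Call it right after setting os.environ['NEURON_COMPILED_ARTIFACTS'] and before constructing LLM.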

import os
from vllm import LLM, SamplingParams

# Hugging Face authentication (replace with your token)
# from huggingface_hub import login
# login(token="your_hf_token_here")

# Configure Neuron environment for inference
os.environ['VLLM_NEURON_FRAMEWORK'] = "neuronx-distributed-inference"
os.environ['NEURON_COMPILED_ARTIFACTS'] = "/home/ubuntu/llama4/traced_models/Llama-4-Scout-17B-16E-Instruct"

IMAGE_URL = "https://httpbin.org/image/png"

# Initialize LLM with Neuron device configuration
llm = LLM(
    model="meta-llama/Llama-4-Scout-17B-16E-Instruct",  # or the file path to the downloaded checkpoint
    max_num_seqs=1,
    max_model_len=16384,
    device="neuron",
    tensor_parallel_size=64,
    use_v2_block_manager=True,
    limit_mm_per_prompt={"image": 5}, # Accepts up to 5 images per prompt
)
# Configure sampling for deterministic output
sampling_params = SamplingParams(top_k=1, max_tokens=100)

# Test 1: Text-only input
conversation = [
    {
        "role": "user",
        "content": [
            {"type": "text", "text": "what is the recipe of mayonnaise in two sentences?"},
        ]
    }
]
for output in llm.chat(conversation, sampling_params):
    print(f"Generated text: {output.outputs[0].text !r}")

# Test 2: Single image with text
conversation = [
    {
        "role": "user",
        "content": [
            {"type": "image_url", "image_url": {"url": IMAGE_URL}},
            {"type": "text", "text": "Describe this image"},
        ]
    }
]
for output in llm.chat(conversation, sampling_params):
    print(f"Generated text: {output.outputs[0].text !r}")

# Test 3: Multiple images with text
conversation = [
    {
        "role": "user",
        "content": [
            {"type": "image_url", "image_url": {"url": IMAGE_URL}},
            {"type": "image_url", "image_url": {"url": IMAGE_URL}},
            {"type": "text", "text": "Compare these two images, tell me the difference."},
        ]
    }
]
for output in llm.chat(conversation, sampling_params):
    print(f"Generated text: {output.outputs[0].text !r}")

Below is an example output:

Generated text: 'To make mayonnaise, combine 2 egg yolks, 1 tablespoon of lemon juice or vinegar, and a pinch of salt in a bowl, and whisk them together until smooth. Then, slowly pour in 1/2 cup of oil while continuously whisking the mixture until it thickens and emulsifies into a creamy sauce.'
Generated text: "The image depicts a cartoon-style illustration of a pig's face, characterized by its pink color and endearing expression. The pig features two small black eyes with white outlines, a curved smile, and two small nostrils on its snout. Two red circles adorn the cheeks, adding to the pig's rosy appearance.\n\n**Key Features:**\n\n* **Color:** Pink\n* **Facial Expression:** Smiling\n* **Eyes:** Small, black, with white outlines\n* **Sn"
Generated text: "The two images are identical, with no discernible differences. The only variation is a slight difference in the shade of pink used for the pig's face, but this could be due to different rendering or display settings rather than an actual difference in the images themselves.\n\n**Key Features:**\n\n* Both images feature a cartoon-style pig's head with a smiling face.\n* The pig has two small ears, two eyes, and a curved smile.\n* The background of both images is white.\n\n**Conclusion:**\nGiven"
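All three conversations above share the same message structure: a list of image_url parts followed by a text part. A small helper (hypothetical, not part of vLLM) can build these messages and enforce the five-image limit set with limit_mm_per_prompt.

```python
MAX_IMAGES_PER_PROMPT = 5  # matches limit_mm_per_prompt={"image": 5}


def build_conversation(text, image_urls=()):
    """Build a single-turn user conversation: image_url parts first, then the text prompt."""
    if len(image_urls) > MAX_IMAGES_PER_PROMPT:
        raise ValueError(f"at most {MAX_IMAGES_PER_PROMPT} images per prompt, got {len(image_urls)}")
    content = [{"type": "image_url", "image_url": {"url": url}} for url in image_urls]
    content.append({"type": "text", "text": text})
    return [{"role": "user", "content": content}]
```

For example, llm.chat(build_conversation("Compare these two images, tell me the difference.", [IMAGE_URL, IMAGE_URL]), sampling_params) builds the same conversation as Test 3. Because the online example below sends messages in the same format, the returned list can also be passed as messages= to client.chat.completions.create.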

Online Example

Before launching the vLLM server, you must trace the Llama4 model as shown in Step 2 and provide the traced model path through the NEURON_COMPILED_ARTIFACTS environment variable.

Open a terminal and start a server for the model. To accept multiple image inputs per prompt, include the optional argument --limit-mm-per-prompt.

%%bash
export VLLM_NEURON_FRAMEWORK="neuronx-distributed-inference"
export NEURON_COMPILED_ARTIFACTS="/home/ubuntu/llama4/traced_models/Llama-4-Scout-17B-16E-Instruct/"
export VLLM_RPC_TIMEOUT=100000
nohup python -m vllm.entrypoints.openai.api_server \
    --model "meta-llama/Llama-4-Scout-17B-16E-Instruct" \
    --max-num-seqs 1 \
    --max-model-len 16384 \
    --tensor-parallel-size 64 \
    --device neuron \
    --port 8000 \
    --use-v2-block-manager \
    --disable-log-requests \
    --override-neuron-config '{}' \
    --limit-mm-per-prompt image=5

...
INFO:     Started server process [25218]
INFO:     Waiting for application startup.
INFO:     Application startup complete.
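Loading the compiled model can take a while, so the client code may fail if you run it too early. The following sketch polls the server until it responds. It assumes the server exposes GET /v1/models, as vLLM's OpenAI-compatible server does; wait_for_server is a hypothetical helper that uses only the standard library.

```python
import http.client
import time
import urllib.request


def wait_for_server(base_url="http://localhost:8000/v1", timeout=900.0, interval=5.0):
    """Poll GET {base_url}/models until it returns HTTP 200.

    Returns True when the server is ready, or False if `timeout` seconds pass first.
    """
    deadline = time.monotonic() + timeout
    while True:
        try:
            with urllib.request.urlopen(f"{base_url}/models", timeout=interval) as response:
                if response.status == 200:
                    return True
        except (OSError, http.client.HTTPException):
            pass  # not ready yet: connection refused, timeout, or an HTTP error status
        if time.monotonic() >= deadline:
            return False
        time.sleep(interval)
```

Run wait_for_server() in the client terminal before sending requests; urllib's URLError and HTTPError are both subclasses of OSError, so a single except clause covers refused connections, timeouts, and error statuses.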

Open another terminal and run the following client code with Python:

from openai import OpenAI

MODEL = "meta-llama/Llama-4-Scout-17B-16E-Instruct"

client = OpenAI(
    api_key = "EMPTY",
    base_url = "http://localhost:8000/v1"
)

print("== Test text input ==")
completion = client.chat.completions.create(
    model=MODEL,
    messages=[{
        "role": "user",
        "content": [
            {"type": "text", "text": "what is the recipe of mayonnaise in two sentences?"},
        ]
    }]
)
print(completion.choices[0].message.content)


print("== Test image input ==")
completion = client.chat.completions.create(
    model=MODEL,
    messages=[{
        "role": "user",
        "content": [
            {"type": "image_url", "image_url": {"url": "https://httpbin.org/image/png"}},
            {"type": "text", "text": "Describe this image"},
        ]
    }]
)
print(completion.choices[0].message.content)


print("== Test multiple image inputs ==")
completion = client.chat.completions.create(
    model=MODEL,
    messages=[{
        "role": "user",
        "content": [
            {"type": "image_url", "image_url": {"url": "https://httpbin.org/image/png"}},
            {"type": "image_url", "image_url": {"url": "https://httpbin.org/image/png"}},
            {"type": "text", "text": "Compare these two images, tell me the difference."},
        ]
    }]
)
print(completion.choices[0].message.content)

Below is an example output:

== Test text input ==
To make mayonnaise, combine 2 egg yolks, 1 tablespoon of lemon juice or vinegar, and a pinch of salt in a bowl, and whisk them together until smooth. Then, slowly pour in 1/2 cup of oil while continuously whisking the mixture until it thickens and emulsifies into a creamy sauce.
== Test image input ==
The image depicts a cartoon-style illustration of a pig's face, characterized by its pink color and endearing expression. The pig features two small black eyes with white outlines, a curved smile, and two small nostrils on its snout. Two red circles adorn the cheeks, adding to the pig's rosy appearance.

**Key Features:**

* **Ears:** Two triangular ears are positioned at the top of the head.
* **Facial Expression:** The pig's facial expression is cheerful, with a smile and rosy cheeks.
* **Background:** The background of the image is transparent.

Overall, the image presents a cute and friendly cartoon pig face.
== Test multiple image inputs ==
The two images are identical, featuring a cartoon pig's face with a pink color and black outline. The only difference is that the first image has a lighter shade of pink compared to the second image.

**Key Features:**

* Both images depict a cartoon pig's face.
* They have the same facial features, including eyes, nose, mouth, and ears.
* The background of both images is white.

**Color Comparison:**

* The first image has a lighter pink color (RGB: 255, 182, 193).
* The second image has a slightly darker pink color (RGB: 240, 128, 128).

Overall, while the two images appear similar at first glance, they differ slightly in terms of their pink hue.