vLLM User Guide for NxD Inference#
vLLM is a popular library for LLM inference and serving that provides advanced inference features such as continuous batching.
This guide describes how to utilize AWS Inferentia and AWS Trainium AI accelerators in vLLM by using NxD Inference (neuronx-distributed-inference).
Overview#
NxD Inference integrates with vLLM by using vLLM’s Plugin System to extend the model execution components responsible for loading and invoking models within vLLM’s LLMEngine (see vLLM architecture for more details). This means input processing, scheduling and output processing follow the default vLLM behavior.
Supported Models#
The following models are supported on vLLM with NxD Inference:
Llama 2/3.1/3.3
Llama 4 Scout, Maverick
Qwen 2.5
Qwen 3
If you are adding your own model to NxD Inference, see Integrating Onboarded Model with vLLM.
Setup#
Prerequisite: Launch an instance and install drivers and tools#
Before installing vLLM with the instructions below, you must launch an Inferentia or Trainium instance and install the necessary Neuron SDK dependency libraries. Refer to these setup instructions for different ways to prepare your environment, including using Neuron DLAMIs and Neuron DLCs for quick setups.
Prerequisites:
Latest AWS Neuron SDK (Neuron SDK 2.26.1)
Python 3.8+ (compatible with vLLM requirements)
Supported AWS instances: Inf2, Trn1/Trn1n, Trn2
Installing the AWS Neuron fork of vLLM#
Neuron maintains a vLLM-Neuron Plugin that supports the latest features for NxD Inference. Follow the instructions below to obtain and configure it.
Quickstart using Docker#
You can use a preconfigured Deep Learning Container (DLC) with the AWS vLLM-Neuron plugin pre-installed. Refer to the vllm-inference-neuronx container on aws-neuron/deep-learning-containers to get started.
For a complete step-by-step tutorial on deploying the vLLM Neuron DLC, see Quickstart: Configure and deploy a vLLM server using Neuron Deep Learning Container (DLC).
Manually install from source#
Important
This is a beta preview of the vLLM Neuron plugin. For a more stable experience, consider using the AWS Neuron vLLM fork.
Install the plugin from GitHub sources using the following commands. The plugin will automatically install the correct version of vLLM along with other required dependencies.
git clone --branch 0.2.0 https://github.com/vllm-project/vllm-neuron.git
cd vllm-neuron
pip install --extra-index-url=https://pip.repos.neuron.amazonaws.com -e .
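As a quick sanity check after installation (a minimal sketch only; it confirms that vLLM imports in your environment, not that the Neuron plugin is fully configured), you can print the vLLM version that the plugin pulled in:

import vllm

# If this import fails, the plugin installation above did not complete successfully.
print(vllm.__version__)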
Usage#
Neuron Environment Setup#
Quickstart#
Here is a quick and minimal example to get up and running.
import os
from vllm import LLM, SamplingParams

# Initialize the model
llm = LLM(
    model="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
    max_num_seqs=4,
    max_model_len=128,
    tensor_parallel_size=32,
)

# Generate text
prompts = [
    "Hello, my name is",
    "The president of the United States is",
    "The capital of France is",
]
sampling_params = SamplingParams(temperature=0.0)
outputs = llm.generate(prompts, sampling_params)

for output in outputs:
    print(f"Prompt: {output.prompt}")
    print(f"Generated: {output.outputs[0].text}")
Feature Support#
| Feature | Status | Notes |
|---|---|---|
| Continuous batching | 🟢 | |
| Prefix Caching | 🟢 | |
| Speculative Decoding | 🟢 | Only Eagle V1 is supported |
| Quantization | 🟢 | INT8/FP8 quantization support |
| Dynamic sampling | 🟢 | |
| Tool calling | 🟢 | |
| CPU Sampling | 🟢 | |
| Chunked Prefill | 🚧 | |
| Multimodal | 🚧 | Only Llama 4 is supported |
🟢 Functional: Fully operational, with ongoing optimizations.
🚧 WIP: Under active development.
Feature Configuration#
NxD Inference models provide many configuration options. When you use NxD Inference through vLLM, the model is configured with a default configuration that derives the required fields from vLLM settings:
neuron_config = dict(
    tp_degree=parallel_config.tensor_parallel_size,
    ctx_batch_size=1,
    batch_size=scheduler_config.max_num_seqs,
    max_context_length=scheduler_config.max_model_len,
    seq_len=scheduler_config.max_model_len,
    enable_bucketing=True,
    is_continuous_batching=True,
    quantized=False,
    torch_dtype=TORCH_DTYPE_TO_NEURON_AMP[model_config.dtype],
    padding_side="right",
)
Use the additional_config field to provide an override_neuron_config dictionary that specifies your desired NxD Inference configuration settings. Provide the settings you want to override as a dictionary (or a JSON object when starting vLLM from the CLI) containing basic types. For example, to enable prefix caching with a block KV cache layout, run code similar to this:
additional_config=dict(
    override_neuron_config=dict(
        is_prefix_caching=True,
        is_block_kv_layout=True,
        pa_num_blocks=4096,
        pa_block_size=32,
    )
)
or, when launching vLLM from the CLI:
--additional-config '{
    "override_neuron_config": {
        "is_prefix_caching": true,
        "is_block_kv_layout": true,
        "pa_num_blocks": 4096,
        "pa_block_size": 32
    }
}'
For more information on NxD Inference features, see NxD Inference Features Configuration Guide and NxD Inference API Reference.
Scheduling and K/V Cache#
NxD Inference uses a contiguous memory layout for the K/V cache instead of PagedAttention. It integrates with vLLM's block manager by setting the block size to the maximum length supported by the model and allocating one block per configured maximum number of sequences. However, the vLLM scheduler currently does not introspect the blocks associated with each sequence when (re-)scheduling running sequences: it requires an additional free block regardless of the space available in the current block, which results in preemption. Preemption causes a large increase in latency for the affected sequence because it is rescheduled in the context encoding phase. Since NxD Inference ensures each block is large enough to fit the maximum model length, preemption is never needed in the current integration. Therefore, Neuron disables the preemption checks done by the scheduler in its fork, which significantly improves end-to-end performance of the Neuron integration.
Decoding#
On-device sampling is enabled by default. It performs the sampling logic on the Neuron devices rather than passing the generated logits back to the CPU and sampling through vLLM. This lets you use Neuron hardware to accelerate sampling and reduces the amount of data transferred between devices, leading to improved latency.
However, on-device sampling comes with some limitations. Currently, only the temperature, top_k, and top_p sampling parameters are supported; other sampling parameters are not supported through on-device sampling.
When on-device sampling is enabled, we handle the following special cases:
When top_k is set to -1, we limit top_k to 256 instead.
When temperature is set to 0, we use greedy decoding to remain compatible with existing conventions. This is the same as setting top_k to 1.
By default, on-device sampling uses a greedy decoding strategy that selects the tokens with the highest probabilities.
You can enable a different on-device sampling strategy by passing an on_device_sampling_config using the override Neuron config feature (see Feature Configuration). It is strongly recommended to use the global_top_k configuration to limit the maximum top_k value a user can request, which improves performance.
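As a sketch of how this might look (the keys inside on_device_sampling_config follow this guide's wording and are assumptions; consult the NxD Inference API Reference for the exact schema supported by your Neuron SDK version), you can pass the sampling configuration through the override mechanism described above:

from vllm import LLM

# Sketch only: on_device_sampling_config field names are assumed from the
# guide's description; verify them against the NxD Inference API Reference.
llm = LLM(
    model="meta-llama/Llama-3.3-70B-Instruct",
    max_num_seqs=4,
    max_model_len=4096,
    tensor_parallel_size=64,
    additional_config=dict(
        override_neuron_config=dict(
            on_device_sampling_config=dict(
                global_top_k=256,  # cap the top_k a request may ask for
            ),
        )
    ),
)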
Quantization#
NxD Inference supports quantization, but it has not yet been integrated with vLLM's quantization configuration.
If you want to use quantization, do not set vLLM's --quantization setting to neuron_quant.
Keep it unset and configure quantization directly through the Neuron configuration of the NxD Inference model.
For more information on how to configure and use quantization with NxD Inference, including requirements on checkpoints,
refer to Quantization in the NxD Inference Feature Guide.
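For illustration, a hedged sketch of this approach is shown below. Only the quantized field is taken from the default configuration shown earlier; the checkpoint path is hypothetical, and any additional quantization options (data type, quantized checkpoint location, and so on) must be set according to the NxD Inference Quantization guide rather than this example.

from vllm import LLM

# Sketch: enable quantization through the Neuron model configuration instead of
# vLLM's --quantization flag (which stays unset). The local path is hypothetical.
llm = LLM(
    model="/path/to/quantized-model-checkpoint",  # hypothetical local path
    max_num_seqs=4,
    max_model_len=4096,
    tensor_parallel_size=64,
    additional_config=dict(
        override_neuron_config=dict(
            quantized=True,  # field shown in the default configuration above
        )
    ),
)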
Loading pre-compiled models / Serialization Support#
Tracing and compiling the model can take a non-trivial amount of time depending on model size; for example, a relatively small model of 15 GB might take around 15 minutes to compile, and exact times depend on multiple factors. Doing this on each server start would lead to unacceptable application startup times, so we support storing and loading traced and compiled models.
Both storing and loading are controlled through the NEURON_COMPILED_ARTIFACTS environment variable. When it points to a path that contains a pre-compiled model,
we load the pre-compiled model directly, and any differing model configurations passed to the vLLM API will not trigger re-compilation.
If loading from the NEURON_COMPILED_ARTIFACTS path fails, we recompile the model with the provided configurations and store
the results in that location. If NEURON_COMPILED_ARTIFACTS is not set, we compile the model and store it in a neuron-compiled-artifacts
subdirectory in the directory of your model checkpoint.
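For example, a minimal sketch of pointing the offline API at a directory of pre-compiled artifacts might look like this (the path is illustrative):

import os
from vllm import LLM

# Point NEURON_COMPILED_ARTIFACTS at a pre-compiled model before constructing the
# engine; if loading from this path fails, the model is recompiled and the
# resulting artifacts are stored there.
os.environ["NEURON_COMPILED_ARTIFACTS"] = "/opt/models/llama-3.3-70b-neuron-artifacts"

llm = LLM(
    model="meta-llama/Llama-3.3-70B-Instruct",
    max_num_seqs=4,
    max_model_len=4096,
    tensor_parallel_size=64,
)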
Prefix Caching#
Starting in Neuron SDK 2.24, prefix caching is supported on the AWS Neuron fork of vLLM. Prefix caching allows developers to improve TTFT by reusing the KV cache of common shared prompts across inference requests. See Prefix Caching for more information on how to enable prefix caching with vLLM.
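As a sketch of how the pieces fit together, you can enable prefix caching on the vLLM side together with the Neuron-side overrides shown in the Feature Configuration section. The block counts and sizes below are illustrative and should be tuned for your model and instance.

from vllm import LLM

# Sketch: vLLM prefix caching combined with the Neuron block-KV overrides shown
# earlier in this guide; tune pa_num_blocks and pa_block_size for your setup.
llm = LLM(
    model="meta-llama/Llama-3.3-70B-Instruct",
    max_num_seqs=4,
    max_model_len=4096,
    tensor_parallel_size=64,
    enable_prefix_caching=True,
    additional_config=dict(
        override_neuron_config=dict(
            is_prefix_caching=True,
            is_block_kv_layout=True,
            pa_num_blocks=4096,
            pa_block_size=32,
        )
    ),
)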
Examples#
For more in-depth NxD Inference tutorials that include vLLM deployment steps, refer to Tutorials.
The following examples use meta-llama/Llama-3.3-70B-Instruct on a trn2.48xlarge instance.
If you have access to the model checkpoint locally, replace meta-llama/Llama-3.3-70B-Instruct with the path to your local copy.
Otherwise, you need to request access through Hugging Face and log in via huggingface-cli login using
a Hugging Face user access token before running the examples.
If you use a different instance type, you need to adjust the tp_degree according to the number of NeuronCores
available on your instance type. (For more information, see Tensor-parallelism support.)
Offline Inference Example#
Here is an example of running offline inference. Bucketing is disabled here only to demonstrate how to override Neuron configuration values; keeping it enabled generally delivers better performance.
import os
from vllm import LLM, SamplingParams

# Initialize the model
llm = LLM(
    model="meta-llama/Llama-3.3-70B-Instruct",
    max_num_seqs=4,
    max_model_len=4096,
    tensor_parallel_size=64,
    additional_config=dict(
        override_neuron_config=dict(
            enable_bucketing=False,
        )
    ),
)

# Generate text
prompts = [
    "Hello, my name is",
    "The president of the United States is",
    "The capital of France is",
]
sampling_params = SamplingParams(temperature=0.0)
outputs = llm.generate(prompts, sampling_params)

for output in outputs:
    print(f"Prompt: {output.prompt}")
    print(f"Generated: {output.outputs[0].text}")
Online Inference Example#
You can start an OpenAI API compatible server with the same settings as the offline example by running the following command:
python3 -m vllm.entrypoints.openai.api_server \
--model "meta-llama/Llama-3.3-70B-Instruct" \
--tensor-parallel-size 64 \
--max-model-len 4096 \
--max-num-seqs 4 \
--no-enable-prefix-caching \
--additional-config '{
    "override_neuron_config": {
        "enable_bucketing": false
    }
}' \
--port 8000
In addition to the sampling parameters supported by OpenAI, we also support top_k.
You can change the sampling parameters and enable or disable streaming.
from openai import OpenAI

# Client Setup
openai_api_key = "EMPTY"
openai_api_base = "http://localhost:8000/v1"
client = OpenAI(
    api_key=openai_api_key,
    base_url=openai_api_base,
)
models = client.models.list()
model_name = models.data[0].id

# Sampling Parameters
max_tokens = 1024
temperature = 1.0
top_p = 1.0
top_k = 50
stream = False

# Chat Completion Request
prompt = "Hello, my name is Llama "
response = client.chat.completions.create(
    model=model_name,
    messages=[{"role": "user", "content": prompt}],
    max_tokens=int(max_tokens),
    temperature=float(temperature),
    top_p=float(top_p),
    stream=stream,
    extra_body={"top_k": top_k},
)

# Parse the response
generated_text = ""
if stream:
    for chunk in response:
        if chunk.choices[0].delta.content is not None:
            generated_text += chunk.choices[0].delta.content
else:
    generated_text = response.choices[0].message.content

print(generated_text)
Specifying context and token buckets (online inference)#
You can tune bucketing for prefill (context encoding) and decode (token generation) by
passing override_neuron_config to the OpenAI-compatible server.
The example below targets a 1K-token workload on meta-llama/Llama-3.3-70B-Instruct with single-sequence (batch size 1) execution.
python -m vllm.entrypoints.openai.api_server \
--model "meta-llama/Llama-3.3-70B-Instruct" \
--tensor-parallel-size 64 \
--max-num-seqs 1 \
--max-model-len 1024 \
--port 8080 \
--additional-config '{
    "override_neuron_config": {
        "enable_bucketing": true,
        "context_encoding_buckets": [256, 512, 1024],
        "token_generation_buckets": [32, 64, 128, 256, 512, 768],
        "max_context_length": 1024,
        "seq_len": 1024,
        "batch_size": 1,
        "ctx_batch_size": 1,
        "tkg_batch_size": 1,
        "is_continuous_batching": true
    }
}'
Known Issues#
Chunked prefill is disabled by default on Neuron for optimal performance. To enable chunked prefill, set the environment variable DISABLE_NEURON_CUSTOM_SCHEDULER="1".
Users are required to provide a num_gpu_blocks_override argument calculated as ceil(max_model_len // block_size) * max_num_seqs when invoking vLLM to avoid a potential OOB error (see the sketch after this list).
Prefix caching with batch_size=1 generates incorrect outputs. We recommend using batch_size > 1 when prefix caching is enabled.
When using HuggingFace model IDs with both shard on load and models that have tie_word_embeddings set to true in their config (such as Qwen3-8B), you may encounter the error NotImplementedError: Cannot copy out of meta tensor; no data!. To resolve this, download the model checkpoint locally from Hugging Face and serve it from the local path instead of using the HuggingFace model ID.
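As a sketch of the num_gpu_blocks_override workaround mentioned above (all values are illustrative, and the block_size shown is an assumption that should be replaced with the block size your deployment actually uses):

import math
from vllm import LLM

# Illustrative values -- substitute the settings of your own deployment.
max_model_len = 4096
max_num_seqs = 4
block_size = 4096  # assumption: with the Neuron integration, one block spans max_model_len

# Formula from the known issue above.
num_gpu_blocks_override = math.ceil(max_model_len // block_size) * max_num_seqs

llm = LLM(
    model="meta-llama/Llama-3.3-70B-Instruct",
    max_model_len=max_model_len,
    max_num_seqs=max_num_seqs,
    tensor_parallel_size=64,
    num_gpu_blocks_override=num_gpu_blocks_override,
)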
Support#
Documentation: AWS Neuron Documentation
Issues: GitHub Issues
Community: AWS Neuron Forum