Tutorial: Deploy Qwen3-MoE 235B on Trn2 instances#

This tutorial provides a step-by-step guide to deploy Qwen/Qwen3-235B-A22B on a single trn2.48xlarge instance using vLLM V1 with the vLLM-Neuron Plugin.

Examples#

Step 1: Set up your development environment#

As a prerequisite, this tutorial requires that you have a Trn2 instance created from a Deep Learning AMI that has the Neuron SDK pre-installed.

To set up a Trn2 instance using a Deep Learning AMI with the Neuron SDK pre-installed, see the NxDI setup guide. To use a Jupyter Notebook on the Neuron instance, follow this guide.

After the instance is set up, connect to it over SSH using the key pair that you chose when you launched the instance.

After you are connected, activate the Python virtual environment that includes the Neuron SDK, then confirm that the Neuron packages are available:

pip list | grep neuron

You should see Neuron packages including neuronx-distributed-inference and neuronx-cc.
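Equivalently, you can confirm the same packages from Python using only the standard library (a minimal check; the package names are the two listed above):

[ ]:
from importlib.metadata import PackageNotFoundError, version

# Check that the core Neuron packages mentioned above are installed in the
# active virtual environment.
for pkg in ("neuronx-distributed-inference", "neuronx-cc"):
    try:
        print(pkg, version(pkg))
    except PackageNotFoundError:
        print(f"{pkg} not found -- make sure the Neuron virtual environment is activated")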

Step 2: Install the vLLM version that supports NxD Inference#

NxD Inference supports running models with vLLM. This functionality is available in the vLLM-Neuron GitHub repository. Install the latest release branch of vLLM-Neuron plugin following instructions in the vLLM User Guide for NxD Inference.

Ensure that the Neuron virtual environment is activated if you are using a new terminal instead of the one from the connection step above. Then, install the vLLM-Neuron plugin into the virtual environment.
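After the installation completes, a quick way to confirm that vLLM is importable inside the virtual environment (a minimal sanity check; it does not exercise the Neuron plugin itself) is:

[ ]:
# Minimal check: confirm that vLLM imports inside the Neuron virtual environment.
import vllm

print(vllm.__version__)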

Step 3: Download the model from HuggingFace#

To deploy Qwen/Qwen3-235B-A22B on Neuron, you first need to download the checkpoint from HuggingFace to a local path on the Trn2 instance (for more information on downloading models from HuggingFace, refer to the guide on Downloading models).

After the download, you should see a config.json file in the output folder along with weights in model-xxxx-of-xxxx.safetensors format.
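One way to download the checkpoint, assuming the huggingface_hub package is installed, is shown below; the target directory matches the default path used later in this tutorial, so adjust it if you store models elsewhere:

[ ]:
import os

from huggingface_hub import snapshot_download

# Download the Qwen3-235B-A22B checkpoint to the local path used in the rest
# of this tutorial.
snapshot_download(
    repo_id="Qwen/Qwen3-235B-A22B",
    local_dir=os.path.expanduser("~/models/Qwen3-235B-A22B/"),
)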

Step 4: Compile and deploy Qwen3 Inference#

In this step, you can use the vllm command directly to deploy the model. The neuronx-distributed-inference model loader in vLLM performs JIT compilation of the model before deploying it with the model server. Replace the default model path ~/models/Qwen3-235B-A22B/ with your own path before running the commands below.

We provide two examples to run Qwen3 with vLLM V1:

  • Offline inference: provide prompts in a Python script and run it directly.

  • Online inference: serve the model with an OpenAI-compatible server and send requests to it.

Model Compilation and Configuration#

There are a few fields you can configure to improve performance:

  • tp_degree: degree of tensor parallelism.

  • attention_dp_degree: degree of data parallelism at the attention layer for decoding.

  • cp_degree: degree of context parallelism at the attention layer for prefill.

  • moe_tp_degree: degree of tensor parallelism at the MoE layer; moe_tp_degree * moe_ep_degree must equal tp_degree.

  • moe_ep_degree: degree of expert parallelism at the MoE layer; moe_tp_degree * moe_ep_degree must equal tp_degree.

  • blockwise_matmul_config: the configuration of the blockwise MoE kernel for prefill; we recommend sharding on the intermediate dimension.

  • use_index_calc_kernel: whether to use a specialized kernel for index calculations.

  • moe_mask_padded_tokens: whether to mask padded tokens at the MoE layer.

  • qkv_kernel_enabled and qkv_nki_kernel_enabled: whether to use the fused QKV kernel.

  • qkv_cte_nki_kernel_fuse_rope: whether to use the fused QKV and RoPE kernel.

  • strided_context_parallel_kernel_enabled: whether to use the strided context parallel flash attention kernel.

[ ]:
qwen3_moe_neuron_config = {
    "tp_degree": 64,
    "attention_dp_degree": 8,
    "cp_degree": 16,
    "moe_tp_degree": 2,
    "moe_ep_degree": 32,
    "use_index_calc_kernel": True,
    "moe_mask_padded_tokens": True,
    "batch_size": 16,
    "ctx_batch_size": 1,
    "max_context_length": 16384,
    "seq_len": 16384,
    "is_continuous_batching": True,
    "fused_qkv": True,
    "blockwise_matmul_config":{"use_shard_on_intermediate_dynamic_while": True, "skip_dma_token": True},
    "on_device_sampling_config": {
        "do_sample": True,
        "temperature": 0.6,
        "top_k": 20,
        "top_p": 0.95
    },
    "enable_bucketing": True,
    "context_encoding_buckets": [10240, 16384],
    "token_generation_buckets": [10240, 16384],
    "flash_decoding_enabled": False,
    "logical_nc_config": 2,
    "sequence_parallel_enabled": True,
    "qkv_kernel_enabled": True,
    "qkv_nki_kernel_enabled": True,
    "qkv_cte_nki_kernel_fuse_rope": True,
    "attn_kernel_enabled": True,
    "strided_context_parallel_kernel_enabled": True,
    "async_mode": True
}
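As a quick sanity check, you can verify that the configuration above satisfies the constraint described in the field list (moe_tp_degree * moe_ep_degree must equal tp_degree); the assertion below is purely illustrative and is not required for deployment:

[ ]:
# Illustrative check only: the MoE parallelism degrees must multiply to the
# overall tensor parallel degree.
assert (
    qwen3_moe_neuron_config["moe_tp_degree"] * qwen3_moe_neuron_config["moe_ep_degree"]
    == qwen3_moe_neuron_config["tp_degree"]
), "moe_tp_degree * moe_ep_degree must equal tp_degree"

Here, 2 * 32 = 64 matches tp_degree, which is also the value passed as tensor_parallel_size in the examples below.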

Offline Example#

[ ]:
import os

os.environ["VLLM_NEURON_FRAMEWORK"] = "neuronx-distributed-inference"

from vllm import LLM, SamplingParams

# Create an LLM.
llm = LLM(
    model="~/models/Qwen3-235B-A22B/",
    tensor_parallel_size=64,
    max_num_seqs=16,
    max_model_len=16384,
    additional_config=dict(
        override_neuron_config=qwen3_moe_neuron_config  # Use the configuration defined above
    ),
    enable_prefix_caching=False,
    enable_chunked_prefill=False,
)

# Sample prompts.
prompts = [
    "The president of the United States is",
    "The capital of France is",
    "The future of AI is",
]
outputs = llm.generate(prompts, SamplingParams(top_k=1))

for output in outputs:
    prompt = output.prompt
    generated_text = output.outputs[0].text
    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")

Below is an example output:

Prompt: 'The president of the United States is', Generated text: ' the head of state and head of government of the United States, indirectly elected to'
Prompt: 'The capital of France is', Generated text: ' Paris. The capital of Italy is Rome. The capital of Germany is Berlin.'
Prompt: 'The future of AI is', Generated text: " not just about smarter algorithms or faster processors; it's about creating systems that can"

Online Example#

[ ]:
import json
import os

# Set the framework selection as an environment variable so that the server
# process launched below inherits it.
os.environ["VLLM_NEURON_FRAMEWORK"] = "neuronx-distributed-inference"

additional_neuron_config = json.dumps(dict(override_neuron_config=qwen3_moe_neuron_config))
start_server_cmd = f'''python3 -m vllm.entrypoints.openai.api_server \
   --model="~/models/Qwen3-235B-A22B/" \
   --tensor-parallel-size=64 \
   --max-num-seqs=16 \
   --max-model-len=16384 \
   --additional-config='{additional_neuron_config}' \
   --no-enable-chunked-prefill \
   --no-enable-prefix-caching \
   --port=8080
'''

os.system(start_server_cmd)
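Note that os.system blocks while the server is running, so run the client below from a separate terminal or notebook kernel. A simple way to wait for the server to come up (purely illustrative, assuming the requests package is available; it polls the same /v1/models endpoint that the client below queries) is:

[ ]:
import time

import requests

# Poll the OpenAI-compatible endpoint until the server responds. The generous
# timeout allows for model compilation and loading on the first launch.
for _ in range(180):
    try:
        if requests.get("http://0.0.0.0:8080/v1/models", timeout=5).status_code == 200:
            print("Server is ready")
            break
    except requests.ConnectionError:
        pass
    time.sleep(10)
else:
    raise TimeoutError("vLLM server did not become ready in time")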

Once the vLLM server is online, submit requests using the example below:

[ ]:
from openai import OpenAI


client = OpenAI(api_key="EMPTY", base_url="http://0.0.0.0:8080/v1")
models = client.models.list()
model_name = models.data[0].id

prompt = "Hello, my name is Llama "

response = client.chat.completions.create(
    model=model_name,
    messages=[{"role": "user", "content": prompt}],
    max_tokens=1024,
    temperature=1.0,
    top_p=1.0,
    stream=False,
    extra_body={"top_k": 50},
)

generated_text = response.choices[0].message.content
print(generated_text)

Below is an example output:

<think>
Okay, so the user is Llama, and they want to know if I can handle that name. Let me think. First, Llama is an animal, but people can have names like that too. I should make sure I use the correct capitalization if that's how they present themselves. The user mentioned they're trying to start a conversation, so I should respond warmly. Maybe they want to check if I can remember their name or if I can be friendly. I need to acknowledge their name properly and invite them to ask questions. Also, considering Llama isn't a common name, I should take care not to misspell it or use lowercase unless instructed. Let me confirm the name and offer assistance. I'll keep it simple and welcoming.

Wait, but maybe the user just wants to confirm they're using the correct name format. Should I include an emoji to keep the tone friendly? The example response uses a Llama face, but since Llama is their name, maybe a different emoji like a star or checkmark? Or maybe none, to stay professional. But the user wants a conversational tone, so perhaps a smiley. Let me structure the response as: "Hello, Llama! 😊 Nice to meet you. How can I assist you today?" That's friendly, uses their name correctly, and opens the door for help without assuming their intent.
</think>

Hello, Llama! 🤗 It's nice to meet you. How can I assist you today? Let me know if you have any questions or need help exploring topics!

Conclusion#

Congratulations! You now know how to deploy Qwen/Qwen3-235B-A22B on a trn2.48xlarge instance. Adjust the configuration and deployment settings to fit your requirements and use case.