Qwen3 235B A22B
Learn how to get started with the Qwen3 235B A22B model on Neuron, using the recommended online and offline serving configurations.
About Qwen3 235B A22B
Qwen3 235B A22B is a mixture-of-experts (MoE) model developed by the Qwen Team with 235B total parameters, of which 22B are activated per forward pass.
For detailed model specifications, capabilities, and checkpoints, see the official Qwen/Qwen3-235B-A22B model card on Hugging Face.
Quickstart
The following examples show how to use Qwen3 235B A22B with the NeuronX Distributed Inference (NxDI) framework and with vLLM for both online and offline use cases on Neuron devices.
Before you start…
Before running the sample code below, review how to set up your environment by following the NxDI Setup Guide. Additionally, download the model checkpoint to a local directory of your choice (such as ~/models/Qwen3-235B-A22B/).
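If you have not downloaded the checkpoint yet, the following is a minimal sketch using the Hugging Face Hub client; it assumes huggingface_hub is installed and that the target directory matches the MODEL_PATH used in the examples below.

import os

from huggingface_hub import snapshot_download

# Download the Qwen3-235B-A22B checkpoint to a local directory.
# Adjust local_dir to match the model path used in the examples below.
snapshot_download(
    repo_id="Qwen/Qwen3-235B-A22B",
    local_dir=os.path.expanduser("~/models/Qwen3-235B-A22B/"),
)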
Offline inference with NxDI: select your instance type, and update the model paths in the code below (MODEL_PATH and TRACED_MODEL_PATH) to match your chosen locations before you execute it.
import torch
from transformers import AutoTokenizer, GenerationConfig

from neuronx_distributed_inference.models.config import MoENeuronConfig, OnDeviceSamplingConfig
from neuronx_distributed_inference.models.qwen3_moe.modeling_qwen3_moe import Qwen3MoeInferenceConfig, NeuronQwen3MoeForCausalLM
from neuronx_distributed_inference.utils.hf_adapter import HuggingFaceGenerationAdapter, load_pretrained_config

MODEL_PATH = "~/models/Qwen3-235B-A22B/"
TRACED_MODEL_PATH = "~/traced_models/Qwen3-235B-A22B/"
SEED = 0
NEURON_CONFIG = MoENeuronConfig(
    tp_degree=64,
    attention_dp_degree=8,
    cp_degree=16,
    moe_tp_degree=2,
    moe_ep_degree=32,
    use_index_calc_kernel=True,
    mode_mask_padded_tokens=True,
    batch_size=16,
    ctx_batch_size=1,
    max_context_length=16384,
    seq_len=16384,
    scratch_pad_size=1024,
    torch_dtype="float16",
    is_continuous_batching=True,
    fused_qkv=True,
    blockwise_matmul_config={'use_shard_on_intermediate_dynamic_while': True, 'skip_dma_token': True},
    on_device_sampling_config={'do_sample': True, 'temperature': 0.6, 'top_k': 20, 'top_p': 0.95},
    enable_bucketing=True,
    token_generation_buckets=[10240, 16384],
    context_encoding_buckets=[10240, 16384],
    flash_decoding_enabled=False,
    logical_nc_config=2,
    cc_pipeline_tiling_factor=2,
    sequence_parallel_enabled=True,
    qkv_kernel_enabled=True,
    qkv_nki_kernel_enabled=True,
    qkv_cte_nki_kernel_fuse_rope=True,
    attn_kernel_enabled=True,
    strided_context_parallel_kernel_enabled=True,
    async_mode=True,
)

# Set random seed for reproducibility
torch.manual_seed(SEED)

# Initialize configs and tokenizer.
generation_config = GenerationConfig.from_pretrained(MODEL_PATH)
config = Qwen3MoeInferenceConfig(NEURON_CONFIG, load_config=load_pretrained_config(MODEL_PATH))

tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, padding_side="right")
tokenizer.pad_token = tokenizer.eos_token

# Compile and save model.
print("Compiling and saving model...")
model = NeuronQwen3MoeForCausalLM(MODEL_PATH, config)
model.compile(TRACED_MODEL_PATH)
tokenizer.save_pretrained(TRACED_MODEL_PATH)

# Load from compiled checkpoint.
print("Loading model from compiled checkpoint...")
model = NeuronQwen3MoeForCausalLM(TRACED_MODEL_PATH)
model.load(TRACED_MODEL_PATH)

# Generate outputs.
print("\nGenerating outputs...")
prompt = "Give me a short introduction to large language models."
messages = [
    {"role": "user", "content": prompt}
]
text = tokenizer.apply_chat_template(
    messages,
    tokenize=False,
    add_generation_prompt=True,
    enable_thinking=True,  # Switches between thinking and non-thinking modes. Default is True.
)
inputs = tokenizer([text], padding=True, return_tensors="pt")
generation_model = HuggingFaceGenerationAdapter(model)
outputs = generation_model.generate(
    inputs.input_ids,
    generation_config=generation_config,
    attention_mask=inputs.attention_mask,
    max_length=model.config.neuron_config.max_length,
)

output_tokens = tokenizer.batch_decode(outputs, skip_special_tokens=True, clean_up_tokenization_spaces=False)
print("Generated outputs:")
for i, output_token in enumerate(output_tokens):
    print(f"Output {i}: {output_token}")
Offline inference with vLLM: select your instance type, and update the model path in the code below to match your chosen location before you execute it.
import os

os.environ["VLLM_NEURON_FRAMEWORK"] = "neuronx-distributed-inference"

from vllm import LLM, SamplingParams

# Create an LLM.
llm = LLM(
    model="~/models/Qwen3-235B-A22B/",
    tensor_parallel_size=64,
    max_num_seqs=16,
    max_model_len=16384,
    additional_config={'override_neuron_config': {'tp_degree': 64, 'attention_dp_degree': 8, 'cp_degree': 16, 'moe_tp_degree': 2, 'moe_ep_degree': 32, 'use_index_calc_kernel': True, 'mode_mask_padded_tokens': True, 'batch_size': 16, 'ctx_batch_size': 1, 'max_context_length': 16384, 'seq_len': 16384, 'scratch_pad_size': 1024, 'torch_dtype': 'float16', 'is_continuous_batching': True, 'fused_qkv': True, 'blockwise_matmul_config': {'use_shard_on_intermediate_dynamic_while': True, 'skip_dma_token': True}, 'on_device_sampling_config': {'do_sample': True, 'temperature': 0.6, 'top_k': 20, 'top_p': 0.95}, 'enable_bucketing': True, 'token_generation_buckets': [10240, 16384], 'context_encoding_buckets': [10240, 16384], 'flash_decoding_enabled': False, 'logical_nc_config': 2, 'cc_pipeline_tiling_factor': 2, 'sequence_parallel_enabled': True, 'qkv_kernel_enabled': True, 'qkv_nki_kernel_enabled': True, 'qkv_cte_nki_kernel_fuse_rope': True, 'attn_kernel_enabled': True, 'strided_context_parallel_kernel_enabled': True, 'async_mode': True}},
    enable_prefix_caching=False,
    enable_chunked_prefill=False,
)

# Sample prompts.
prompts = [
    "The president of the United States is",
    "The capital of France is",
    "The future of AI is",
]
outputs = llm.generate(prompts, SamplingParams(top_k=1))

for output in outputs:
    prompt = output.prompt
    generated_text = output.outputs[0].text
    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
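The SamplingParams(top_k=1) call above requests greedy decoding. To match the sampling settings from the on_device_sampling_config (temperature 0.6, top_k 20, top_p 0.95), you can pass those values per request instead; the following sketch reuses the llm object and prompts from above and assumes per-request sampling parameters are honored by your vLLM/Neuron build.

# Sketch: per-request sampling aligned with the on-device sampling settings above.
sampling_params = SamplingParams(temperature=0.6, top_k=20, top_p=0.95, max_tokens=256)
outputs = llm.generate(prompts, sampling_params)
for output in outputs:
    print(f"Prompt: {output.prompt!r}, Generated text: {output.outputs[0].text!r}")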
Online serving with vLLM: select your instance type, and update the model path in the command below to match your chosen location before you execute it.
VLLM_NEURON_FRAMEWORK='neuronx-distributed-inference' python -m vllm.entrypoints.openai.api_server \
    --model="~/models/Qwen3-235B-A22B/" \
    --tensor-parallel-size=64 \
    --max-num-seqs=16 \
    --max-model-len=16384 \
    --additional-config='{"override_neuron_config": {"async_mode": true, "attention_dp_degree": 8, "attn_kernel_enabled": true, "batch_size": 16, "blockwise_matmul_config": {"skip_dma_token": true, "use_shard_on_intermediate_dynamic_while": true}, "cc_pipeline_tiling_factor": 2, "context_encoding_buckets": [10240, 16384], "cp_degree": 16, "ctx_batch_size": 1, "enable_bucketing": true, "flash_decoding_enabled": false, "fused_qkv": true, "is_continuous_batching": true, "logical_nc_config": 2, "max_context_length": 16384, "mode_mask_padded_tokens": true, "moe_ep_degree": 32, "moe_tp_degree": 2, "on_device_sampling_config": {"do_sample": true, "temperature": 0.6, "top_k": 20, "top_p": 0.95}, "qkv_cte_nki_kernel_fuse_rope": true, "qkv_kernel_enabled": true, "qkv_nki_kernel_enabled": true, "scratch_pad_size": 1024, "seq_len": 16384, "sequence_parallel_enabled": true, "strided_context_parallel_kernel_enabled": true, "token_generation_buckets": [10240, 16384], "torch_dtype": "float16", "tp_degree": 64, "use_index_calc_kernel": true}}' \
    --no-enable-chunked-prefill \
    --no-enable-prefix-caching \
    --port=8080
Once the vLLM server is online, submit requests using the example below:
from openai import OpenAI

client = OpenAI(api_key="EMPTY", base_url="http://0.0.0.0:8080/v1")
models = client.models.list()
model_name = models.data[0].id

prompt = "Give me a short introduction to large language models."

response = client.chat.completions.create(
    model=model_name,
    messages=[{"role": "user", "content": prompt}],
    max_tokens=1024,
    temperature=1.0,
    top_p=1.0,
    stream=False,
    extra_body={"top_k": 50},
)

generated_text = response.choices[0].message.content
print(generated_text)
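For interactive use, you can also stream tokens as they are generated. The following is a minimal sketch that reuses the client, model_name, and prompt from the example above and sets stream=True; it assumes your vLLM server build supports streaming responses.

# Sketch: stream the response chunk-by-chunk instead of waiting for the full completion.
stream = client.chat.completions.create(
    model=model_name,
    messages=[{"role": "user", "content": prompt}],
    max_tokens=1024,
    stream=True,
)
for chunk in stream:
    delta = chunk.choices[0].delta.content
    if delta:
        print(delta, end="", flush=True)
print()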
Recommended configuration
The recommended Neuron configurations below are organized by use case: throughput-optimized and latency-optimized. For the definitions of the flags listed below, see the NxDI API reference guide.
For most use cases, the following configuration optimizes throughput on Neuron devices. You can improve throughput further by increasing batch_size or using quantization; a sketch of passing this configuration to vLLM follows the configuration block.
For this specific configuration, we recommend using an Expert Parallelism (EP) degree of 32. For more details, refer to the Qwen3-MoE Inference on Trn2 tutorial.
Instance type: trn2.48xlarge
NeuronConfig(
    tp_degree=64,
    attention_dp_degree=8,
    cp_degree=16,
    moe_tp_degree=2,
    moe_ep_degree=32,
    use_index_calc_kernel=True,
    mode_mask_padded_tokens=True,
    batch_size=64,
    ctx_batch_size=1,
    max_context_length=16384,
    seq_len=16384,
    scratch_pad_size=1024,
    torch_dtype="float16",
    is_continuous_batching=True,
    fused_qkv=True,
    blockwise_matmul_config={'use_shard_on_intermediate_dynamic_while': True, 'skip_dma_token': True},
    on_device_sampling_config={'do_sample': True, 'temperature': 0.6, 'top_k': 20, 'top_p': 0.95},
    enable_bucketing=True,
    token_generation_buckets=[10240, 16384],
    context_encoding_buckets=[10240, 16384],
    flash_decoding_enabled=False,
    logical_nc_config=2,
    cc_pipeline_tiling_factor=2,
    sequence_parallel_enabled=True,
    qkv_kernel_enabled=True,
    qkv_nki_kernel_enabled=True,
    qkv_cte_nki_kernel_fuse_rope=True,
    attn_kernel_enabled=True,
    strided_context_parallel_kernel_enabled=True,
    async_mode=True,
)
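The following is a minimal sketch of applying this throughput-oriented configuration through the offline vLLM path shown in the quickstart. The only intended changes relative to that example are batch_size=64 in the Neuron config and max_num_seqs=64 in vLLM (kept in step with batch_size, following the quickstart's pattern); the model path is an assumption you should adjust.

import os

os.environ["VLLM_NEURON_FRAMEWORK"] = "neuronx-distributed-inference"

from vllm import LLM

# Same Neuron override as the quickstart vLLM examples, with batch_size raised to 64 for throughput.
throughput_override = {
    'tp_degree': 64, 'attention_dp_degree': 8, 'cp_degree': 16,
    'moe_tp_degree': 2, 'moe_ep_degree': 32,
    'use_index_calc_kernel': True, 'mode_mask_padded_tokens': True,
    'batch_size': 64, 'ctx_batch_size': 1,
    'max_context_length': 16384, 'seq_len': 16384, 'scratch_pad_size': 1024,
    'torch_dtype': 'float16', 'is_continuous_batching': True, 'fused_qkv': True,
    'blockwise_matmul_config': {'use_shard_on_intermediate_dynamic_while': True, 'skip_dma_token': True},
    'on_device_sampling_config': {'do_sample': True, 'temperature': 0.6, 'top_k': 20, 'top_p': 0.95},
    'enable_bucketing': True,
    'token_generation_buckets': [10240, 16384],
    'context_encoding_buckets': [10240, 16384],
    'flash_decoding_enabled': False, 'logical_nc_config': 2, 'cc_pipeline_tiling_factor': 2,
    'sequence_parallel_enabled': True,
    'qkv_kernel_enabled': True, 'qkv_nki_kernel_enabled': True, 'qkv_cte_nki_kernel_fuse_rope': True,
    'attn_kernel_enabled': True, 'strided_context_parallel_kernel_enabled': True,
    'async_mode': True,
}

llm = LLM(
    model="~/models/Qwen3-235B-A22B/",
    tensor_parallel_size=64,
    max_num_seqs=64,  # kept in step with batch_size above
    max_model_len=16384,
    additional_config={'override_neuron_config': throughput_override},
    enable_prefix_caching=False,
    enable_chunked_prefill=False,
)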
For most use cases, the following configuration optimizes latency on Neuron devices.
For this specific configuration, we recommend using an Expert Parallelism (EP) degree of 32. For more details, refer to the Qwen3-MoE Inference on Trn2 tutorial.
Instance type: trn2.48xlarge
NeuronConfig(
    tp_degree=64,
    attention_dp_degree=8,
    cp_degree=16,
    moe_tp_degree=2,
    moe_ep_degree=32,
    use_index_calc_kernel=True,
    mode_mask_padded_tokens=True,
    batch_size=16,
    ctx_batch_size=1,
    max_context_length=16384,
    seq_len=16384,
    scratch_pad_size=1024,
    torch_dtype="float16",
    is_continuous_batching=True,
    fused_qkv=True,
    blockwise_matmul_config={'use_shard_on_intermediate_dynamic_while': True, 'skip_dma_token': True},
    on_device_sampling_config={'do_sample': True, 'temperature': 0.6, 'top_k': 20, 'top_p': 0.95},
    enable_bucketing=True,
    token_generation_buckets=[10240, 16384],
    context_encoding_buckets=[10240, 16384],
    flash_decoding_enabled=False,
    logical_nc_config=2,
    cc_pipeline_tiling_factor=2,
    sequence_parallel_enabled=True,
    qkv_kernel_enabled=True,
    qkv_nki_kernel_enabled=True,
    qkv_cte_nki_kernel_fuse_rope=True,
    attn_kernel_enabled=True,
    strided_context_parallel_kernel_enabled=True,
    async_mode=True,
)