Llama 3.3 70B#
Learn how to get started with the Llama 3.3 70B model on Neuron, using recommended online and offline serving configurations.
About Llama 3.3 70B#
Llama 3.3 70B is Meta’s multilingual large language model with 70B parameters and a transformer architecture featuring Grouped Query Attention (GQA).
For detailed model specifications, capabilities, and checkpoints, see the official meta-llama/Llama-3.3-70B-Instruct model card on Hugging Face.
Quickstart#
The following examples show how to use Llama 3.3 70B with the NeuronX Distributed Inference (NxDI) framework and vLLM for both online and offline use cases on Neuron devices.
Before you start:
Before running the sample code below, review how to set up your environment by following the NxDI Setup Guide. Additionally, download the model checkpoint to a local directory of your choice (such as ~/models/Llama-3.3-70B-Instruct/).
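If you have not downloaded the checkpoint yet, one option is to pull it with snapshot_download from the huggingface_hub package. This is a minimal sketch, assuming huggingface_hub is installed and your Hugging Face account has been granted access to the gated meta-llama repository; adjust the local directory to match your setup.

import os

from huggingface_hub import snapshot_download

# Download the gated checkpoint to a local directory. Requires prior access
# approval on Hugging Face and a token with read permission (for example via
# `huggingface-cli login` or the HF_TOKEN environment variable).
snapshot_download(
    repo_id="meta-llama/Llama-3.3-70B-Instruct",
    local_dir=os.path.expanduser("~/models/Llama-3.3-70B-Instruct/"),
)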
The first example runs offline inference directly with NxDI: it compiles the model, saves the traced artifacts, and generates outputs for a batch of prompts. Select the code for your instance type and update MODEL_PATH and TRACED_MODEL_PATH to match your chosen paths before you execute it.
trn2.48xlarge

import torch
from transformers import AutoTokenizer, GenerationConfig

from neuronx_distributed_inference.models.config import NeuronConfig
from neuronx_distributed_inference.models.llama.modeling_llama import LlamaInferenceConfig, NeuronLlamaForCausalLM
from neuronx_distributed_inference.modules.generation.sampling import prepare_sampling_params
from neuronx_distributed_inference.utils.hf_adapter import HuggingFaceGenerationAdapter, load_pretrained_config

MODEL_PATH = "~/models/Llama-3.3-70B-Instruct/"
TRACED_MODEL_PATH = "~/traced_models/Llama-3.3-70B-Instruct/"
SEED = 0
NEURON_CONFIG = NeuronConfig(
    batch_size=1,
    tp_degree=64,
    enable_bucketing=True,
    is_continuous_batching=True,
    logical_nc_config=2,
    seq_len=16384,
)

# Set random seed for reproducibility
torch.manual_seed(SEED)

# Initialize configs and tokenizer.
generation_config = GenerationConfig.from_pretrained(MODEL_PATH)
eos = generation_config.eos_token_id
generation_config_kwargs = {
    "do_sample": True,
    "top_k": 1,
    "pad_token_id": eos[0] if isinstance(eos, list) else eos,
}
generation_config.update(**generation_config_kwargs)
config = LlamaInferenceConfig(NEURON_CONFIG, load_config=load_pretrained_config(MODEL_PATH))

tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, padding_side="right")
tokenizer.pad_token = tokenizer.eos_token

# Compile and save model.
print("Compiling and saving model...")
model = NeuronLlamaForCausalLM(MODEL_PATH, config)
model.compile(TRACED_MODEL_PATH)
tokenizer.save_pretrained(TRACED_MODEL_PATH)

# Load from compiled checkpoint.
print("Loading model from compiled checkpoint...")
model = NeuronLlamaForCausalLM(TRACED_MODEL_PATH)
model.load(TRACED_MODEL_PATH)

# Generate outputs.
print("Generating outputs...")
prompts = ["I believe the meaning of life is", "The color of the sky is"]
sampling_params = prepare_sampling_params(
    batch_size=NEURON_CONFIG.batch_size,
    top_k=[10, 5],
    top_p=[0.5, 0.9],
    temperature=[0.9, 0.5],
)
print(f"Prompts: {prompts}")

inputs = tokenizer(prompts, padding=True, return_tensors="pt")
generation_model = HuggingFaceGenerationAdapter(model)
outputs = generation_model.generate(
    inputs.input_ids,
    generation_config=generation_config,
    attention_mask=inputs.attention_mask,
    max_length=model.config.neuron_config.max_length,
    sampling_params=sampling_params,
)

output_tokens = tokenizer.batch_decode(outputs, skip_special_tokens=True, clean_up_tokenization_spaces=False)
print("Generated outputs:")
for i, output_token in enumerate(output_tokens):
    print(f"Output {i}: {output_token}")
trn1.32xlarge

import torch
from transformers import AutoTokenizer, GenerationConfig

from neuronx_distributed_inference.models.config import NeuronConfig
from neuronx_distributed_inference.models.llama.modeling_llama import LlamaInferenceConfig, NeuronLlamaForCausalLM
from neuronx_distributed_inference.modules.generation.sampling import prepare_sampling_params
from neuronx_distributed_inference.utils.hf_adapter import HuggingFaceGenerationAdapter, load_pretrained_config

MODEL_PATH = "~/models/Llama-3.3-70B-Instruct/"
TRACED_MODEL_PATH = "~/traced_models/Llama-3.3-70B-Instruct/"
SEED = 0
NEURON_CONFIG = NeuronConfig(
    batch_size=1,
    tp_degree=32,
    enable_bucketing=True,
    is_continuous_batching=True,
    logical_nc_config=1,
    seq_len=16384,
)

# Set random seed for reproducibility
torch.manual_seed(SEED)

# Initialize configs and tokenizer.
generation_config = GenerationConfig.from_pretrained(MODEL_PATH)
eos = generation_config.eos_token_id
generation_config_kwargs = {
    "do_sample": True,
    "top_k": 1,
    "pad_token_id": eos[0] if isinstance(eos, list) else eos,
}
generation_config.update(**generation_config_kwargs)
config = LlamaInferenceConfig(NEURON_CONFIG, load_config=load_pretrained_config(MODEL_PATH))

tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, padding_side="right")
tokenizer.pad_token = tokenizer.eos_token

# Compile and save model.
print("Compiling and saving model...")
model = NeuronLlamaForCausalLM(MODEL_PATH, config)
model.compile(TRACED_MODEL_PATH)
tokenizer.save_pretrained(TRACED_MODEL_PATH)

# Load from compiled checkpoint.
print("Loading model from compiled checkpoint...")
model = NeuronLlamaForCausalLM(TRACED_MODEL_PATH)
model.load(TRACED_MODEL_PATH)

# Generate outputs.
print("Generating outputs...")
prompts = ["I believe the meaning of life is", "The color of the sky is"]
sampling_params = prepare_sampling_params(
    batch_size=NEURON_CONFIG.batch_size,
    top_k=[10, 5],
    top_p=[0.5, 0.9],
    temperature=[0.9, 0.5],
)
print(f"Prompts: {prompts}")

inputs = tokenizer(prompts, padding=True, return_tensors="pt")
generation_model = HuggingFaceGenerationAdapter(model)
outputs = generation_model.generate(
    inputs.input_ids,
    generation_config=generation_config,
    attention_mask=inputs.attention_mask,
    max_length=model.config.neuron_config.max_length,
    sampling_params=sampling_params,
)

output_tokens = tokenizer.batch_decode(outputs, skip_special_tokens=True, clean_up_tokenization_spaces=False)
print("Generated outputs:")
for i, output_token in enumerate(output_tokens):
    print(f"Output {i}: {output_token}")
The next example runs offline (batch) inference through vLLM. Select the code for your instance type and update the model path to match your chosen location before you execute it.
trn2.48xlarge

import os

os.environ["VLLM_NEURON_FRAMEWORK"] = "neuronx-distributed-inference"

from vllm import LLM, SamplingParams

# Create an LLM.
llm = LLM(
    model="~/models/Llama-3.3-70B-Instruct/",
    tensor_parallel_size=64,
    max_num_seqs=1,
    max_model_len=16384,
    device="neuron",
    use_v2_block_manager=True,
    override_neuron_config={},
)

# Sample prompts.
prompts = [
    "The president of the United States is",
    "The capital of France is",
    "The future of AI is",
]
outputs = llm.generate(prompts, SamplingParams(top_k=1))

for output in outputs:
    prompt = output.prompt
    generated_text = output.outputs[0].text
    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
trn1.32xlarge

import os

os.environ["VLLM_NEURON_FRAMEWORK"] = "neuronx-distributed-inference"

from vllm import LLM, SamplingParams

# Create an LLM.
llm = LLM(
    model="~/models/Llama-3.3-70B-Instruct/",
    tensor_parallel_size=32,
    max_num_seqs=1,
    max_model_len=16384,
    device="neuron",
    use_v2_block_manager=True,
    override_neuron_config={},
)

# Sample prompts.
prompts = [
    "The president of the United States is",
    "The capital of France is",
    "The future of AI is",
]
outputs = llm.generate(prompts, SamplingParams(top_k=1))

for output in outputs:
    prompt = output.prompt
    generated_text = output.outputs[0].text
    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
For online serving, launch the vLLM OpenAI-compatible API server. Select the command for your instance type and update the model path to match your chosen location before you execute it.
trn2.48xlarge

VLLM_NEURON_FRAMEWORK='neuronx-distributed-inference' python -m vllm.entrypoints.openai.api_server \
    --model="~/models/Llama-3.3-70B-Instruct/" \
    --tensor-parallel-size=64 \
    --max-num-seqs=1 \
    --max-model-len=16384 \
    --device="neuron" \
    --use-v2-block-manager \
    --override-neuron-config='{}' \
    --port=8080
trn1.32xlarge

VLLM_NEURON_FRAMEWORK='neuronx-distributed-inference' python -m vllm.entrypoints.openai.api_server \
    --model="~/models/Llama-3.3-70B-Instruct/" \
    --tensor-parallel-size=32 \
    --max-num-seqs=1 \
    --max-model-len=16384 \
    --device="neuron" \
    --use-v2-block-manager \
    --override-neuron-config='{}' \
    --port=8080
Once the vLLM server is online, submit requests using the example below:
from openai import OpenAI

client = OpenAI(api_key="EMPTY", base_url="http://0.0.0.0:8080/v1")
models = client.models.list()
model_name = models.data[0].id

prompt = "Hello, my name is Llama "

response = client.chat.completions.create(
    model=model_name,
    messages=[{"role": "user", "content": prompt}],
    max_tokens=1024,
    temperature=1.0,
    top_p=1.0,
    stream=False,
    extra_body={"top_k": 50},
)

generated_text = response.choices[0].message.content
print(generated_text)
Recommended configuration#
The recommended Neuron configurations for throughput-optimized and latency-optimized use cases are shown below. For the definitions of the flags, see the NxDI API reference guide.
For most use cases, the configuration below can be used to optimize throughput on Neuron devices. You can also increase the batch_size or use quantization to improve throughput even further.
For this specific configuration, we recommend using a Data Parallelism (DP) degree of 2. For more details on how to implement data parallelism, refer to the Data Parallelism on Trn2 tutorial.
trn2.48xlarge
NeuronConfig(
    async_mode=True,
    batch_size=8,
    ctx_batch_size=1,
    tp_degree=32,
    attn_block_tkg_nki_kernel_cache_update=True,
    attn_block_tkg_nki_kernel_enabled=True,
    attn_kernel_enabled=True,
    cc_pipeline_tiling_factor=1,
    enable_bucketing=True,
    fused_qkv=True,
    is_continuous_batching=True,
    k_cache_transposed=True,
    kv_cache_tiling=False,
    logical_nc_config=2,
    mlp_kernel_enabled=True,
    qkv_kernel_enabled=True,
    seq_len=16384,
    sequence_parallel_enabled=True,
    token_generation_buckets=[256, 512, 1024, 2048, 4096, 8192, 10240, 12288, 16384],
    context_encoding_buckets=[256, 512, 1024, 2048, 4096, 8192, 10240, 12288, 16384],
    on_device_sampling_config={'do_sample': True, 'dynamic': True},
)
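If you serve the model with vLLM, as in the examples above, flags like these can typically be forwarded through override_neuron_config. The sketch below is an illustration rather than a tested configuration: it assumes override_neuron_config accepts NeuronConfig fields by name, and that tensor_parallel_size and max_num_seqs map to tp_degree and batch_size; confirm the exact keys against the NxDI API reference and the vLLM Neuron integration docs.

import os

os.environ["VLLM_NEURON_FRAMEWORK"] = "neuronx-distributed-inference"

from vllm import LLM

# Illustrative only: forward a subset of the throughput-oriented flags above
# via override_neuron_config. Verify key names against the NxDI API reference.
llm = LLM(
    model="~/models/Llama-3.3-70B-Instruct/",
    tensor_parallel_size=32,   # tp_degree above
    max_num_seqs=8,            # batch_size above
    max_model_len=16384,       # seq_len above
    device="neuron",
    use_v2_block_manager=True,
    override_neuron_config={
        "sequence_parallel_enabled": True,
        "fused_qkv": True,
        "on_device_sampling_config": {"do_sample": True, "dynamic": True},
    },
)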
For most use cases, the configuration below can be used to optimize latency on Neuron devices.
trn2.48xlarge
NeuronConfig(
    async_mode=True,
    batch_size=1,
    tp_degree=64,
    attn_block_tkg_nki_kernel_cache_update=True,
    attn_block_tkg_nki_kernel_enabled=True,
    attn_kernel_enabled=True,
    cc_pipeline_tiling_factor=1,
    enable_bucketing=True,
    fused_qkv=True,
    is_continuous_batching=True,
    k_cache_transposed=True,
    kv_cache_tiling=False,
    logical_nc_config=2,
    mlp_kernel_enabled=True,
    qkv_kernel_enabled=True,
    seq_len=16384,
    sequence_parallel_enabled=True,
    token_generation_buckets=[256, 512, 1024, 2048, 4096, 8192, 10240, 12288, 16384],
    context_encoding_buckets=[256, 512, 1024, 2048, 4096, 8192, 10240, 12288, 16384],
    on_device_sampling_config={'do_sample': True, 'dynamic': True},
)
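To try this latency-oriented configuration with the offline NxDI quickstart above, construct the NeuronConfig directly and reuse the rest of that script unchanged. The sketch below uses only flags taken from the listing above (trn2.48xlarge); the remaining flags can be added as keyword arguments in the same way.

from neuronx_distributed_inference.models.config import NeuronConfig

# Latency-oriented NeuronConfig, used as a drop-in replacement for
# NEURON_CONFIG in the offline quickstart script. Only a subset of the flags
# listed above is shown here.
NEURON_CONFIG = NeuronConfig(
    async_mode=True,
    batch_size=1,
    tp_degree=64,
    enable_bucketing=True,
    fused_qkv=True,
    is_continuous_batching=True,
    logical_nc_config=2,
    seq_len=16384,
    sequence_parallel_enabled=True,
    on_device_sampling_config={"do_sample": True, "dynamic": True},
)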