NxD Inference Features Configuration Guide#

NxD Inference (neuronx-distributed-inference) is an open-source PyTorch-based inference library that simplifies deep learning model deployment on AWS Inferentia and Trainium instances. Neuronx Distributed Inference includes a model hub and modules that users can reference to implement their own models on Neuron.

Checkpoint compatibility with HuggingFace Transformers#

Models included in the NxD Inference model hub are checkpoint-compatible with HuggingFace Transformers. To use a checkpoint in another format with NxD Inference, convert it to the standard HuggingFace Transformers checkpoint format first.

Checkpoint support#

NxD Inference supports both older PyTorch binary (pickle) checkpoints and newer safetensors checkpoints, in regular and sharded variants. For faster load times and reduced host memory consumption, we recommend always using safetensors.

NxD Inference supports weights stored in the model path in the following formats:

Format        Sharded   File name
Safetensors   No        model.safetensors
Safetensors   Yes       model.safetensors.index.json
Pickle        No        pytorch_model.bin
Pickle        Yes       pytorch_model.bin.index.json

If your weights are in another format, you must convert them to one of these formats before you can compile and load the model on Neuron.
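If your checkpoint loads with HuggingFace Transformers, one way to produce a safetensors checkpoint is to re-save it with save_pretrained. The following is a minimal sketch; the paths are illustrative placeholders.

from transformers import AutoModelForCausalLM

# Load the original checkpoint (for example, a pickle-format checkpoint).
hf_model = AutoModelForCausalLM.from_pretrained("/home/ubuntu/models/original-checkpoint")

# Re-save the weights in safetensors format for use with NxD Inference.
hf_model.save_pretrained("/home/ubuntu/models/converted-checkpoint", safe_serialization=True)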

Compiling models#

To run a model on Neuron with NxD Inference, you compile Python code into a NEFF file (Neuron Executable File Format), which you can load to Neuron devices using the Neuron Runtime.

When you call compile(), NxD Inference does the following:

  1. Traces the Python code to produce an HLO file.

  2. Uses the Neuron Compiler to compile the HLO file into a NEFF.

During the trace process, NxD Inference traces the model code with a sample tensor for each input. As a result, model code should avoid dynamic logic that depends on the values in an input tensor, because NxD Inference compiles only the code path that is traced for the sample input tensors.

# Configure, initialize, and compile a model.
model = NeuronLlamaForCausalLM(model_path, config)
model.compile(compiled_model_path)

Serialization support#

When you compile a model with NxD Inference, the library serializes the model to a given folder. After you have a serialized model, you can load it directly to a Neuron device without needing to compile again.

The compile function serializes sharded weights by default, and you can disable this functionality with the save_sharded_checkpoint flag in NeuronConfig.
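For example, a minimal sketch that disables sharded-weight serialization and, after a one-time compile, loads the serialized model directly (model and compiled_model_path follow the compile example above):

# Disable serialization of sharded weights (enabled by default).
neuron_config = NeuronConfig(save_sharded_checkpoint=False)

# After the model has been compiled and serialized, load it without recompiling.
model.load(compiled_model_path)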

Logical NeuronCore support#

On Trn2 instances, Neuron supports Logical NeuronCore (LNC) configuration, which combines multiple physical NeuronCores into a single logical NeuronCore. We recommend using LNC=2 on Trn2 instances.

neuron_config = NeuronConfig(logical_neuron_cores=2)

For more information about logical NeuronCore support, see Logical NeuronCore configuration.

Tensor-parallelism support#

For transformer decoders used in large language models, tensor parallelism is necessary: it shards the model's large weight matrices across multiple NeuronCores so that the NeuronCores work collaboratively on the same matrix multiplication. NxD Inference's tensor-parallelism support makes heavy use of collective operations such as all-reduce, which the Neuron runtime supports natively.

Follow these principles when setting the tensor-parallelism degree (the number of NeuronCores participating in sharded matrix multiply operations) for Neuron-optimized transformer decoder models:

  1. The number of attention heads needs to be divisible by the tensor-parallelism degree.

  2. The total data size of model weights and key-value caches needs to be smaller than the tensor-parallelism degree multiplied by the amount of memory per Neuron core.

    1. On Trn2, each Neuron core has 24GB of memory (with logical_neuron_cores set to 2).

    2. On Inf2/Trn1, each Neuron core has 16GB of memory.

  3. The Neuron runtime supports the following tensor-parallelism degrees:

    1. Trn2: 1, 2, 4, 8, 16, 32, and 64 (with logical_neuron_cores set to 2)

    2. Inf2: 1, 2, 4, 8, and 24

    3. Trn1: 1, 2, 8, 16, and 32

Examples#

  1. meta-llama/Meta-Llama-3.1-8B has 32 attention heads, and when running at batch size 1 and bfloat16 precision, the model requires about 16GB memory. Therefore, a trn1.2xlarge with 32GB device memory is sufficient.

  2. meta-llama/Meta-Llama-3.1-70B has 64 attention heads, and when running at batch size 1 and bfloat16 precision, the model requires about 148GB memory. Therefore, it can run on 16 NeuronCores on one trn1.32xlarge using 256GB device memory.
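To target a specific number of NeuronCores, set tp_degree in NeuronConfig. The following sketch corresponds to the second example (tp_degree is the same attribute used in the fused speculation example later in this guide):

neuron_config = NeuronConfig(tp_degree=16)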

Sequence Parallelism#

Sequence parallelism splits tensors across the sequence dimension to improve performance. You can enable sequence parallelism by setting sequence_parallel_enabled=True in NeuronConfig.

neuron_config = NeuronConfig(sequence_parallel_enabled=True)

Compile-time Configurations#

NxD Inference models support a variety of compile-time configurations you can use to tune model performance. For more information, see the NxD Inference API Reference.

Hugging Face generate() API support#

NxD Inference models support the HuggingFace generate() API via the HuggingFaceGenerationAdapter class. This adapter wraps a Neuron model to provide the HuggingFace generation interface.

NxD Inference supports the following HuggingFace generation modes:

  • Greedy decoding — num_beams=1 and do_sample=False.

  • Multinomial sampling — num_beams=1 and do_sample=True.

  • Assisted (speculative) decoding — assistant_model or prompt_lookup_num_tokens are specified.

NxD Inference doesn’t currently support other HuggingFace generation modes, such as beam-search sampling.

Note: When you call generate, the number of prompts must match the batch_size for the model, which is an attribute of NeuronConfig.

neuron_config = NeuronConfig(batch_size=2)
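For example, with batch_size=2, pass exactly two prompts to generate() (the prompts here are illustrative):

prompts = [
    "I believe the meaning of life is",
    "The key to a good life is",
]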

Example#

The following example demonstrates how to wrap a model with HuggingFaceGenerationAdapter to call generate().

from neuronx_distributed_inference.utils.hf_adapter import HuggingFaceGenerationAdapter

# Initialize the Neuron model, HuggingFace tokenizer, and generation config.


# Run generation with HuggingFaceGenerationAdapter.
generation_model = HuggingFaceGenerationAdapter(model)
inputs = tokenizer(prompts, padding=True, return_tensors="pt")
outputs = generation_model.generate(
    inputs.input_ids,
    generation_config=generation_config,
    attention_mask=inputs.attention_mask,
    max_length=model.neuron_config.max_length,
    **kwargs,
)

output_tokens = tokenizer.batch_decode(
    outputs, skip_special_tokens=True, clean_up_tokenization_spaces=False
)

print("Generated outputs:")
for i, output_token in enumerate(output_tokens):
    print(f"Output {i}: {output_token}")

On-device Sampling Support#

On-device sampling performs sampling logic on the Neuron device (rather than on the CPU) to achieve better performance. To enable on-device sampling, provide an OnDeviceSamplingConfig for the on_device_sampling_config attribute in NeuronConfig.

on_device_sampling_config = OnDeviceSamplingConfig(global_topk=256)
neuron_config = NeuronConfig(on_device_sampling_config=on_device_sampling_config)

Dynamic Sampling#

With dynamic sampling, you can pass different top_k, top_p, and temperature values to the forward call to configure sampling for each input in a batch. To enable dynamic sampling, provide an OnDeviceSamplingConfig with dynamic=True.

on_device_sampling_config = OnDeviceSamplingConfig(dynamic=True)
neuron_config = NeuronConfig(on_device_sampling_config=on_device_sampling_config)

To use dynamic sampling, pass a sampling_params tensor to the forward function of the model. The sampling_params tensor has shape [batch_size, 3], where the three values per batch are top_k, top_p, and temperature.

The following example demonstrates how to create sampling_params for a batch with two inputs. In the first input, top_k=50, top_p=0.5, and temperature=0.75. In the second input, top_k=5, top_p=1.0, and temperature=1.0.

sampling_params = torch.tensor([[50, 0.5, 0.75], [5, 1.0, 1.0]])

Greedy Sampling#

By default, on-device sampling uses greedy sampling, where the model picks the highest scoring token.

Multinomial (Top-K) Sampling#

With multinomial (top-k) sampling, the model picks one of the top k-highest scoring tokens. To use on-device multinomial sampling, you must enable dynamic sampling. You can configure the default top_k attribute in the OnDeviceSamplingConfig, or you can specify the top_k value in each call to the model’s forward function.

on_device_sampling_config = OnDeviceSamplingConfig(top_k=5)

Top-P Support in On-Device Sampling#

To use top-p in on-device sampling, enable dynamic sampling, and specify top_p values in the sampling_params.

Temperature Support in On-Device Sampling#

To adjust temperature in on-device sampling, enable dynamic sampling, and specify temperature values in the sampling_params.
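For example, assuming dynamic sampling is enabled as shown above, the following sampling_params applies top_p=0.9 and temperature=0.8 to a single input (the top_k value of 50 is illustrative):

sampling_params = torch.tensor([[50, 0.9, 0.8]])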

QKV Weight Fusion#

QKV weight fusion concatenates a model’s query, key and value weight matrices to achieve better performance, because larger matrices allow for more efficient data movement and compute. You can enable QKV weight fusion by setting fused_qkv=True in the NeuronConfig.

neuron_config = NeuronConfig(fused_qkv=True)

Bucketing#

LLM inference is a generation process that can produce variable-length sequences. This poses a problem because the Neuron compiler produces executables that expect statically shaped inputs and outputs. To make LLMs work with different shapes, NxD Inference supports buckets and applies padding where required. When you run inference, NxD Inference automatically chooses the smallest bucket that fits the input for optimal performance. For more information about bucketing, see Autobucketing for Inference (torch-neuronx).

Automatic Bucketing#

When automatic bucketing is enabled, NxD Inference automatically chooses buckets for each model according to the following logic:

  • Context encoding: Powers of two between 128 and the max context length.

    • Note: Max context length is equivalent to sequence length by default.

  • Token generation: Powers of two between 128 and the maximum sequence length.

To enable automatic bucketing, set enable_bucketing=True in NeuronConfig.

neuron_config = NeuronConfig(enable_bucketing=True)

Configuring Specific Buckets#

You can configure specific buckets to further optimize inference based on the input and output length distribution that you expect to process with your model. In NeuronConfig, set enable_bucketing=True, and provide a list of bucket sizes in context_encoding_buckets and/or token_generation_buckets.

neuron_config = NeuronConfig(
    enable_bucketing=True,
    context_encoding_buckets=[1024, 2048, 4096],
    token_generation_buckets=[8192]
)

Quantization#

NxD Inference supports quantization, where model weights and data are converted to a smaller data type to reduce memory bandwidth usage, which improves model performance.

Note: Quantization slightly reduces accuracy due to using data types with lower precision and/or lower range.

Model Weight Quantization#

NxD Inference supports quantizing model weights to the following data types:

  • INT8 (int8) - 8 bit int.

  • FP8 - 8 bit float.

    • f8e4m3 - 8-bit float with greater precision and less range.

      • Important: To use f8e4m3 for quantization, you must set the XLA_HANDLE_SPECIAL_SCALAR environment variable to 1.

    • f8e5m2 - 8-bit float with greater range and less precision.

NxD Inference supports the following quantization methods, which you specify with quantization_type in NeuronConfig:

  • per_tensor_symmetric

  • per_channel_symmetric

Example#

The following example demonstrates how to quantize a model to INT8. To quantize a model to a different data type, change the quantization_dtype config attribute in NeuronConfig.

from neuronx_distributed_inference.models.config import NeuronConfig
from neuronx_distributed_inference.models.llama.modeling_llama import (
    LlamaInferenceConfig,
    NeuronLlamaForCausalLM
)
from neuronx_distributed_inference.utils.hf_adapter import load_pretrained_config

model_path = "/home/ubuntu/models/Llama-3.1-8B"
quantized_model_path = "/home/ubuntu/models/Llama-3.1-8B-quantized"

neuron_config = NeuronConfig(
    quantized=True,
    quantized_checkpoints_path=quantized_model_path,
    quantization_dtype="int8",
    quantization_type="per_tensor_symmetric"
)

config = LlamaInferenceConfig(
    neuron_config,
    load_config=load_pretrained_config(model_path)
)

# Quantize the model and save it to `quantized_checkpoints_path`.
NeuronLlamaForCausalLM.save_quantized_state_dict(model_path, config)

# Compile, load, and use the model.
model = NeuronLlamaForCausalLM(model_path, config)
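From here, compilation and loading follow the same pattern as the other examples in this guide (the compiled_model_path below is an illustrative placeholder):

compiled_model_path = "/home/ubuntu/neuron_models/Llama-3.1-8B-quantized"
model.compile(compiled_model_path)
model.load(compiled_model_path)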

KV Cache Quantization#

NxD Inference supports KV cache quantization, where the model’s KV cache is quantized to a smaller data type. When enabled, the model quantizes the KV cache to the torch.float8_e4m3fn data type. Before using the KV cache, the model dequantizes the KV cache to the data type specified by torch_dtype in NeuronConfig.

To enable KV cache quantization, set kv_cache_quant=True in NeuronConfig.

neuron_config = NeuronConfig(kv_cache_quant=True)

Important: To use KV cache quantization, you must set the XLA_HANDLE_SPECIAL_SCALAR environment variable to 1.
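One way to set this environment variable is from Python before compiling and loading the model; setting it in the shell with export works as well. This is a sketch, assuming the variable is set before NxD Inference reads it.

import os

# Must be set before compiling and loading a model that uses KV cache quantization
# (the same variable is required for f8e4m3 weight quantization).
os.environ["XLA_HANDLE_SPECIAL_SCALAR"] = "1"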

Speculative Decoding#

Speculative decoding is a performance optimization technique where a smaller draft LLM model predicts the next tokens, and the larger target LLM model verifies those predictions. NxD Inference supports the following speculative decoding implementations:

  1. Vanilla speculative decoding, where a separate draft model predicts the next n tokens for the target model. Each model is compiled independently.

  2. Medusa speculative decoding, where several small model heads predict next tokens, and the target model verifies all predictions at the same time.

  3. EAGLE speculative decoding, where the draft model uses additional context from the target model to improve generation efficiency. NxD Inference supports EAGLE v1 with a flat draft structure.

Vanilla Speculative Decoding#

To use vanilla speculative decoding, you configure, compile, and load a draft model in addition to the main target model. To enable vanilla speculative decoding, set speculation_length and trace_tokengen_model=False in the target model’s NeuronConfig. The draft model’s NeuronConfig should use the same configuration but with these additional attributes reset to their defaults.

Vanilla speculative decoding currently supports only a batch size of 1.

Example#

The following example demonstrates using Llama-3.2 3B as a draft model for Llama-3.1 70B. The speculation length is set to 5 tokens.

import copy

from transformers import AutoTokenizer, GenerationConfig

from neuronx_distributed_inference.models.config import NeuronConfig
from neuronx_distributed_inference.models.llama.modeling_llama import (
    LlamaInferenceConfig,
    NeuronLlamaForCausalLM
)
from neuronx_distributed_inference.utils.accuracy import get_generate_outputs
from neuronx_distributed_inference.utils.hf_adapter import load_pretrained_config

prompts = ["I believe the meaning of life is"]

model_path = "/home/ubuntu/models/Llama-3.1-70B"
draft_model_path = "/home/ubuntu/models/Llama-3.2-3B"
compiled_model_path = "/home/ubuntu/neuron_models/Llama-3.1-70B"
compiled_draft_model_path = "/home/ubuntu/neuron_models/Llama-3.2-3B"

# Initialize target model.
neuron_config = NeuronConfig(
    speculation_length=5,
    trace_tokengen_model=False
)
config = LlamaInferenceConfig(
    neuron_config,
    load_config=load_pretrained_config(model_path)
)
model = NeuronLlamaForCausalLM(model_path, config)

# Initialize draft model.
draft_neuron_config = copy.deepcopy(neuron_config)
draft_neuron_config.speculation_length = 0
draft_neuron_config.trace_tokengen_model = True
draft_config = LlamaInferenceConfig(
    draft_neuron_config,
    load_config=load_pretrained_config(draft_model_path)
)
draft_model = NeuronLlamaForCausalLM(draft_model_path, draft_config)

# Compile and save models.
model.compile(compiled_model_path)
draft_model.compile(compiled_draft_model_path)

# Load models to the Neuron device.
model.load(compiled_model_path)
draft_model.load(compiled_draft_model_path)

# Load tokenizer and generation config.
tokenizer = AutoTokenizer.from_pretrained(model_path, padding_side=neuron_config.padding_side)
generation_config = GenerationConfig.from_pretrained(model_path)

# Run generation.
_, output_tokens = get_generate_outputs(
    model,
    prompts,
    tokenizer,
    is_hf=False,
    draft_model=draft_model,
    generation_config=generation_config
)

print("Generated outputs:")
for i, output_token in enumerate(output_tokens):
    print(f"Output {i}: {output_token}")

Medusa Speculative Decoding#

To use Medusa speculative decoding, you must use a model that is specifically fine-tuned for Medusa speculation, such as text-generation-inference/Mistral-7B-Instruct-v0.2-medusa. You must also provide a Medusa tree. For an example Medusa tree, see medusa_mc_sim_7b_63.json in the examples folder in NeuronX Distributed Inference.

To enable Medusa, set is_medusa=True, set the medusa_speculation_length, set the num_medusa_heads, and specify the medusa_tree.

import json

def load_json_file(json_path):
    with open(json_path, "r") as f:
        return json.load(f)

medusa_tree = load_json_file("medusa_mc_sim_7b_63.json")

neuron_config = NeuronConfig(
    is_medusa=True,
    medusa_speculation_length=64,
    num_medusa_heads=4,
    medusa_tree=medusa_tree
)

To run generation with a Medusa model and the HuggingFace generate() API, set the assistant_model to the target model.
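A sketch of this call, reusing the HuggingFaceGenerationAdapter pattern from the generate() example earlier in this guide (the tokenizer inputs and generation config are assumed to be set up as shown there):

generation_model = HuggingFaceGenerationAdapter(model)
outputs = generation_model.generate(
    inputs.input_ids,
    generation_config=generation_config,
    attention_mask=inputs.attention_mask,
    assistant_model=model,  # Set the assistant model to the target model.
)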

For more information about Medusa speculative decoding, see the official implementation on GitHub: FasterDecoding/Medusa.

Medusa speculative decoding currently supports only a batch size of 1.

EAGLE Speculative Decoding#

NxD Inference supports EAGLE v1 speculative decoding with a flat draft structure.

EAGLE Checkpoint Compatibility#

To use EAGLE speculative decoding, you must use a draft model that is specifically fine-tuned for EAGLE speculation. Additionally, to use EAGLE with NxD Inference, the draft model must include the LM head weights from the target model. These weights are shared between the draft and target model.

Because NxD Inference uses a flat draft structure, it predicts only one token per draft iteration. Although NxD Inference doesn’t support EAGLE with a tree structure, you can train an EAGLE checkpoint in the same way. Note that depending on your use case and dataset, you might see a lower acceptance rate with the flat draft structure than with a tree structure.

NxD Inference supports EAGLE models with or without input normalization. By default, NxD Inference expects that the EAGLE model doesn’t use input normalization. To use an EAGLE model with input normalization, set enable_eagle_draft_input_norm to True in NeuronConfig.
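For example, following the same NeuronConfig pattern used throughout this guide:

neuron_config = NeuronConfig(enable_eagle_draft_input_norm=True)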

You can find links to pretrained EAGLE draft model checkpoints for various popular models in the official EAGLE repository on GitHub: SafeAILab/EAGLE. However, these pretrained EAGLE model checkpoints don’t include the LM head weights from the target model. To use these pretrained checkpoints with NxD Inference, you must first copy the LM head weights from the target to the draft model.

The following code demonstrates how to perform this operation for a Llama-3.1-70B-Instruct target model and the corresponding EAGLE draft:

import json
import os

import torch
from safetensors import safe_open
from safetensors.torch import save_file

target_model_path = "Meta-Llama-3.1-70B-Instruct"
draft_model_path = "Llama-3.1-70B-Instruct-EAGLE-Draft"

DRAFT_MODEL_SAFETENSORS_NAME = "model.safetensors"
LM_HEAD_WEIGHT_TENSOR_NAME = "lm_head.weight"
TARGET_MODEL_SAFETENSORS_INDEX_NAME = "model.safetensors.index.json"

def find_lm_head_safetensors_location(model_dir):
    model_index_location_path = os.path.join(model_dir, TARGET_MODEL_SAFETENSORS_INDEX_NAME)

    with open(model_index_location_path, 'r') as f:
        model_index_locations = json.load(f)

    lm_head_safetensors_name = model_index_locations["weight_map"][LM_HEAD_WEIGHT_TENSOR_NAME]

    return lm_head_safetensors_name

# Find the target model `lm_head.weight` location in safetensors
target_lm_head_safetensors_name = find_lm_head_safetensors_location(target_model_path)
target_lm_head_safetensors_path = os.path.join(target_model_path, target_lm_head_safetensors_name)

# Open the target model.safetensor containing `lm_head.weight`
with safe_open(target_lm_head_safetensors_path, framework="pt") as f:
    target_lm_head = f.get_tensor(LM_HEAD_WEIGHT_TENSOR_NAME)

# Collect all tensors in the draft model
draft_model_safetensors_path = os.path.join(draft_model_path, DRAFT_MODEL_SAFETENSORS_NAME)
tensors = {}
with safe_open(draft_model_safetensors_path, framework="pt") as f:
    for key in f.keys():
        tensors[key] = f.get_tensor(key)

# Add the LM head weights and save out the new draft model.safetensors file
tensors[LM_HEAD_WEIGHT_TENSOR_NAME] = target_lm_head.type(torch.float16)
save_file(tensors, draft_model_safetensors_path)

Fused Speculation#

EAGLE speculation uses a feature called fused speculation, where the draft model and target model are fused into a single compiled model to improve performance. Fused speculation uses a different config called FusedSpecNeuronConfig, which specifies the model class, draft config, and draft model path to fuse with the target model.

Example#

import copy

from neuronx_distributed_inference.models.config import (
    FusedSpecNeuronConfig,
    NeuronConfig,
    OnDeviceSamplingConfig
)
from neuronx_distributed_inference.models.llama.modeling_llama import (
    NeuronLlamaForCausalLM,
    NeuronLlamaModel
)
from neuronx_distributed_inference.utils.accuracy import get_generate_outputs
from neuronx_distributed_inference.utils.hf_adapter import load_pretrained_config
from transformers import AutoTokenizer, GenerationConfig

prompt = "The future of AI is"

model_path = "/home/ubuntu/models/Llama-3.1-70B-Instruct"
draft_model_path = "/home/ubuntu/models/Llama-3.1-70B-Instruct-EAGLE-Draft"
compiled_model_path = "/home/ubuntu/neuron_models/Llama-3.1-70B-Instruct-EAGLE"
max_sequence_length = 1024

# Initialize on-device sampling configuration.
on_device_sampling_config = OnDeviceSamplingConfig(
    temperature=0.7,
    top_k=50,
    top_p=1.0,
)

# Initialize model configuration.
neuron_config = NeuronConfig(
    # Neuron supports EAGLE batch sizes greater than 1.
    # We set batch size to 1 in this tutorial due to a
    # limitation in the transformers library for
    # generation with speculative decoding.
    # For more information, see: https://github.com/huggingface/transformers/issues/32165
    batch_size=1,
    enable_eagle_speculation=True,
    enable_fused_speculation=True,
    max_context_length=max_sequence_length,
    max_length=max_sequence_length,
    on_device_sampling_config=on_device_sampling_config,
    seq_len=max_sequence_length,
    speculation_length=5,
    # For best performance, set to the maximum tensor
    # parallelism of your Neuron instance type.
    tp_degree=32,
    trace_tokengen_model=False
)

config = NeuronLlamaForCausalLM.get_config_cls()(
    neuron_config, load_config=load_pretrained_config(model_path)
)

# Initialize draft model configuration and set EAGLE-specific values.
draft_neuron_config = copy.deepcopy(neuron_config)
draft_neuron_config.trace_tokengen_model = True
draft_neuron_config.enable_fused_speculation = False
draft_neuron_config.is_eagle_draft = True
draft_neuron_config.sequence_parallel_enabled = False

draft_config = NeuronLlamaForCausalLM.get_config_cls()(
    draft_neuron_config, load_config=load_pretrained_config(draft_model_path))

# Initialize fused speculation configuration.
fused_spec_config = FusedSpecNeuronConfig(
    NeuronLlamaForCausalLM._model_cls,
    draft_config=draft_config,
    draft_model_path=draft_model_path,
)
config.fused_spec_config = fused_spec_config

# Initialize model from configuration.
model = NeuronLlamaForCausalLM(model_path, config)

# Compile and save model.
model.compile(compiled_model_path)

# Load model to the Neuron device.
model.load(compiled_model_path)

# Load tokenizer and generation config.
tokenizer = AutoTokenizer.from_pretrained(model_path, padding_side=neuron_config.padding_side)
generation_config = GenerationConfig.from_pretrained(model_path)
generation_config.max_length = 1024
# pad_token_id is required for Hugging Face assisted sampling.
generation_config.pad_token_id = tokenizer.eos_token_id

# Run generation and print outputs.
_, output_tokens = get_generate_outputs(
    model,
    [prompt],
    tokenizer,
    is_hf=False,
    # draft_model is not set here due to fused speculation.
    draft_model=None,
    generation_config=generation_config
)

print("Generated output:")
for _, output in enumerate(output_tokens):
    print(output)

MoE model architecture support#

NxD Inference supports mixture-of-experts (MoE) models. The library includes ready-to-use modeling code for Mixtral and DBRX. These models are built using reusable MoE modules from NeuronX Distributed Core: RouterTopK, ExpertMLPs, and MoE. You can use these modules to onboard additional MoE models.

NxD Inference also provides a helper function, initialize_moe_module, which you can use to initialize an MoE model’s MLP module from these MoE modules. For examples of how to use this helper function, see the decoder layer module implementation in the Mixtral and DBRX modeling code.

Grouped-query attention (GQA) support#

NxD Inference provides a reusable attention module, NeuronAttentionBase, which you can use when onboarding models. This module is also used in NxD Inference modeling code like Llama and Mixtral.

NxD Inference supports the following sharding strategies for the KV cache used in the attention module:

  • CONVERT_TO_MHA — Transforms a GQA attention mechanism into a traditional MHA mechanism by replicating the K/V heads to evenly match the corresponding Q heads. This consumes more memory than would otherwise be used with other sharding mechanisms but works in all cases.

  • REPLICATE_TO_TP_DEGREE — Transforms a GQA attention mechanism so that there is exactly one K/V head per tensor-parallel rank through replication. For example, 8 K/V heads with tp_degree=32 results in 32 K/V heads. This is more memory efficient but does not work for all configurations. Q heads are padded and interleaved to retain correct alignment between Q and K/V heads.

The NeuronAttentionBase module uses REPLICATE_TO_TP_DEGREE by default. If the TP degree isn’t divisible by the number of KV heads, NeuronAttentionBase uses CONVERT_TO_MHA.