Tutorial: Deploy Pixtral Large on Trn2 instances

This tutorial provides a step-by-step guide to deploy mistralai/Pixtral-Large-Instruct-2411 using NeuronX Distributed (NxD) Inference on a single trn2.48xlarge instance.

Prerequisites

Set up and connect to a trn2.48xlarge instance

This tutorial requires a Trn2 instance with a Deep Learning AMI (DLAMI) that has the Neuron SDK pre-installed. To set up such an instance, see the NxDI setup guide.

To use Jupyter Notebook on the Neuron instance, follow this guide. After you are connected, activate the Python virtual environment that includes the Neuron SDK, and then verify that the Neuron packages are installed:

pip list | grep neuron

You should see Neuron packages including neuronx-distributed-inference and neuronx-cc.

Install packages

NxD Inference supports running models with vLLM. This functionality is available in a fork of the vLLM GitHub repository.

To run NxD Inference with vLLM, download and install vLLM from this fork by following the Neuron vLLM installation guide.

If you open a new terminal instead of using the one from the connection step above, make sure that the Neuron virtual environment is activated. Then install the Neuron vLLM fork into the virtual environment.

Step 1: Download the model and convert the checkpoint

To deploy mistralai/Pixtral-Large-Instruct-2411 on Neuron, you first need to download the checkpoint from Hugging Face to a local path on the Trn2 instance. For more information on downloading models from Hugging Face, refer to the guide on Downloading models.

After you download the model, convert the original Pixtral checkpoint with the Hugging Face conversion script. After the conversion, the output folder should contain a config.json file along with weights in model-xxxx-of-xxxx.safetensors format.

Note: There is a known issue in the Hugging Face conversion script that sets image_token_index to 32000 in config.json. Manually set image_token_index to 10 before you proceed with the subsequent steps.
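If you prefer to apply this fix from the notebook, the following is a minimal sketch. It assumes that image_token_index is a top-level key in the converted config.json and that the converted checkpoint lives in /home/ubuntu/model_hf (the path used by the vLLM script later in this tutorial). The helper name fix_image_token_index is hypothetical.

```python
import json
from pathlib import Path
from typing import Optional

# Hypothetical location of the converted checkpoint; replace with your own output folder.
CONVERTED_CHECKPOINT_DIR = Path("/home/ubuntu/model_hf")


def fix_image_token_index(checkpoint_dir: Path, image_token_index: int = 10) -> Optional[int]:
    """Overwrite image_token_index in config.json and return the previous value.

    Assumes image_token_index is a top-level key of config.json.
    """
    config_path = Path(checkpoint_dir) / "config.json"
    config = json.loads(config_path.read_text())
    previous = config.get("image_token_index")
    config["image_token_index"] = image_token_index
    # Rewrite the file in place, keeping every other key unchanged.
    config_path.write_text(json.dumps(config, indent=2) + "\n")
    return previous
```

Call fix_image_token_index(CONVERTED_CHECKPOINT_DIR) once after the conversion. It returns the old value, which should be 32000 if your checkpoint was affected by the known issue.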

Step 2: Compile and deploy Pixtral Large

When you compile the model, several configuration options control its performance. They are described below; adjust them for your use case.

  • Pixtral consists of a text model and a vision encoder, so you specify their configurations separately through text_neuron_config and vision_neuron_config.

  • tp_degree : The tensor parallel degree used to shard the model across NeuronCores. It is set to 64 for the text model and 16 for the vision encoder.

  • batch_size : The batch size used to compile the models. Prefill currently always runs with batch_size = 1, so batch_size in vision_neuron_config is set to 1. Set batch_size in text_neuron_config to the number of concurrent requests you want to handle, which is the same value as the vLLM --max-num-seqs argument.

  • seq_len : Set this to the maximum sequence length that needs to be supported.

  • text_neuron_config

    • enable_bucketing : Bucketing compiles the model for a fixed set of sequence lengths (buckets), and each input is padded to a bucket that fits it. This improves performance compared with always using the maximum length. This tutorial configures the buckets explicitly.

    • context_encoding_buckets : Buckets for the prefill phase, sized by the length of the input prompt. Set them to cover the range of input lengths you expect. Here they are [2048, 4096, 10240].

    • token_generation_buckets : Buckets for the token generation (decode) phase. They cover the sequence length during decoding, that is, the prompt plus the tokens generated so far. Here they are [2048, 4096, 10240].

    • flash_decoding_enabled : Setting this to True partitions the KV cache, which improves performance for long sequences. Refer to the app note on Flash Decoding for more details.

    • sequence_parallel_enabled : Sequence Parallelism splits tensors across the sequence dimension to improve performance.

    • fused_qkv : QKV weight fusion concatenates a model’s query, key and value weight matrices to achieve better performance.

    • qkv_kernel_enabled : Enable the use of the fused QKV kernel.

    • mlp_kernel_enabled : Enable the use of the MLP kernel.

    • cc_pipeline_tiling_factor : This tutorial sets it to 1.

  • vision_neuron_config

    • buckets : For the vision encoder, buckets account for two dimensions: image size and number of images. The Pixtral Hugging Face processor splits each image into 16x16 patches. For example, a 512x512 image becomes a 32x32 grid, or 32x32 = 1024 image tokens, and six such images produce 6144 tokens. Here, the buckets are set to [2048, 4096, 6144, 8192, 10240] to handle different numbers and sizes of images.

    • seq_len : Set this to the maximum sequence length for the use case.

    • tp_degree : The vision encoder uses a tensor parallel degree of 16.
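To make the bucket arithmetic concrete, the sketch below counts 16x16 patches per image and picks a bucket from the lists configured above. It assumes that the runtime uses the smallest bucket that fits the input, and it ignores any resizing the Hugging Face processor may apply to large images. The function names are hypothetical.

```python
import math
from typing import Sequence

PATCH_SIZE = 16  # Pixtral splits images into 16x16 patches

VISION_BUCKETS = [2048, 4096, 6144, 8192, 10240]
CONTEXT_ENCODING_BUCKETS = [2048, 4096, 10240]


def num_image_patches(height: int, width: int, patch_size: int = PATCH_SIZE) -> int:
    """Patches (image tokens) for one image, assuming partial patches round up."""
    return math.ceil(height / patch_size) * math.ceil(width / patch_size)


def select_bucket(length: int, buckets: Sequence[int]) -> int:
    """Return the smallest bucket that can hold `length` (assumed selection rule)."""
    for bucket in sorted(buckets):
        if length <= bucket:
            return bucket
    raise ValueError(f"length {length} exceeds the largest bucket {max(buckets)}")


# Six 512x512 images: 6 * (32 * 32) = 6144 patches, which fits the 6144 bucket.
six_images = [(512, 512)] * 6
total = sum(num_image_patches(h, w) for h, w in six_images)
print(total, select_bucket(total, VISION_BUCKETS))  # 6144 6144

# A 3000-token prompt falls into the 4096 context-encoding bucket.
print(select_bucket(3000, CONTEXT_ENCODING_BUCKETS))  # 4096
```

Inputs larger than the largest bucket raise an error in this sketch, which is why seq_len and the largest bucket are both set to 10240.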

Compile and deploy using vLLM

In this step, you use the vllm command directly to deploy the model. The neuronx-distributed-inference model loader in vLLM performs just-in-time (JIT) compilation before the model server starts serving the model. The script below uses /home/ubuntu/model_hf/ as the path to the converted Pixtral checkpoint; replace it with your own path before you run the command.
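Hand-escaping the JSON passed to --override-neuron-config in start_vllm.sh is error-prone. As an optional sketch, you can build the same configuration as a Python dictionary, run a few sanity checks that mirror the constraints described in Step 2, and print a shell-quoted argument. The values match those in the script; the variable names are illustrative.

```python
import json
import shlex

MAX_NUM_SEQS = 4        # same as --max-num-seqs
MAX_MODEL_LEN = 10240   # same as --max-model-len

override = {
    "text_neuron_config": {
        "tp_degree": 64, "world_size": 64,
        "batch_size": MAX_NUM_SEQS, "seq_len": MAX_MODEL_LEN, "ctx_batch_size": 1,
        "flash_decoding_enabled": True, "enable_bucketing": True, "skip_warmup": True,
        "context_encoding_buckets": [2048, 4096, 10240],
        "token_generation_buckets": [2048, 4096, 10240],
        "torch_dtype": "float16", "sequence_parallel_enabled": True, "fused_qkv": True,
        "qkv_kernel_enabled": True, "mlp_kernel_enabled": True,
        "cc_pipeline_tiling_factor": 1,
    },
    "vision_neuron_config": {
        "batch_size": 1, "seq_len": MAX_MODEL_LEN, "tp_degree": 16, "world_size": 64,
        "torch_dtype": "float16", "buckets": [2048, 4096, 6144, 8192, 10240],
    },
}

text_cfg = override["text_neuron_config"]
vision_cfg = override["vision_neuron_config"]

# Sanity checks based on the configuration notes above.
assert vision_cfg["batch_size"] == 1, "prefill runs with batch_size 1"
assert max(text_cfg["context_encoding_buckets"]) <= text_cfg["seq_len"]
assert max(text_cfg["token_generation_buckets"]) <= text_cfg["seq_len"]
assert max(vision_cfg["buckets"]) <= vision_cfg["seq_len"]

# shlex.quote wraps the JSON in single quotes so the shell passes it as one argument.
override_arg = shlex.quote(json.dumps(override))
print("--override-neuron-config " + override_arg)
```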

%%writefile start_vllm.sh
#!/bin/bash

echo "Running vLLM server in the background..."
rm -f ./vllm_server.log

export NEURON_RT_INSPECT_ENABLE=0
export NEURON_RT_VIRTUAL_CORE_SIZE=2
export VLLM_NEURON_FRAMEWORK="neuronx-distributed-inference"
# Export the RPC timeout so that the vLLM server process inherits it.
export VLLM_RPC_TIMEOUT=100000

# Path to the converted Pixtral checkpoint; replace with your own path.
MODEL_PATH="/home/ubuntu/model_hf/"

# Neuron-specific overrides for the text model and the vision encoder (see Step 2).
OVERRIDE_NEURON_CONFIG='{
  "text_neuron_config": {"tp_degree": 64, "world_size": 64, "batch_size": 4, "seq_len": 10240, "ctx_batch_size": 1,
    "flash_decoding_enabled": true, "enable_bucketing": true, "skip_warmup": true,
    "context_encoding_buckets": [2048, 4096, 10240], "token_generation_buckets": [2048, 4096, 10240],
    "torch_dtype": "float16", "sequence_parallel_enabled": true, "fused_qkv": true,
    "qkv_kernel_enabled": true, "mlp_kernel_enabled": true, "cc_pipeline_tiling_factor": 1},
  "vision_neuron_config": {"batch_size": 1, "seq_len": 10240, "tp_degree": 16, "world_size": 64,
    "torch_dtype": "float16", "buckets": [2048, 4096, 6144, 8192, 10240]}
}'

nohup python3 -m vllm.entrypoints.openai.api_server \
    --model "$MODEL_PATH" \
    --limit-mm-per-prompt 'image=6' \
    --tensor-parallel-size 64 \
    --max-model-len 10240 \
    --max-num-seqs 4 \
    --device neuron \
    --override-neuron-config "$OVERRIDE_NEURON_CONFIG" > ./vllm_server.log 2>&1 &
SERVER_PID=$!

echo "Server started in the background with PID $SERVER_PID. Waiting until the server is ready to serve..."
# Wait until the server logs that startup is complete, or until the server process exits.
until grep -q "Application startup complete" ./vllm_server.log 2>/dev/null || ! kill -0 "$SERVER_PID" 2>/dev/null; do sleep 0.5; done
if grep -q "Application startup complete" ./vllm_server.log 2>/dev/null; then
    echo "vLLM Server is ready!"
else
    echo "vLLM Server failed, check the ./vllm_server.log file"
    exit 1
fi
!chmod +x ./start_vllm.sh
!./start_vllm.sh
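The script already waits for the "Application startup complete" log line. If you connect from another notebook or terminal, you can instead poll the server's OpenAI-compatible GET /v1/models endpoint, which lists the served model IDs. With this setup the ID should be the checkpoint path passed to --model. The following is a minimal sketch using only the Python standard library; the helper name wait_for_server and its defaults are hypothetical.

```python
import json
import time
import urllib.request
from typing import List, Optional


def wait_for_server(base_url: str = "http://localhost:8000",
                    timeout: float = 600.0,
                    interval: float = 2.0) -> Optional[List[str]]:
    """Poll GET {base_url}/v1/models until it responds or `timeout` seconds pass.

    Returns the served model IDs, or None if the server never answered.
    """
    deadline = time.monotonic() + timeout
    while True:
        try:
            with urllib.request.urlopen(f"{base_url}/v1/models", timeout=interval) as resp:
                payload = json.load(resp)
            return [model["id"] for model in payload.get("data", [])]
        except OSError:
            # Covers refused connections, socket timeouts, and HTTP errors
            # (urllib.error.URLError and HTTPError are OSError subclasses).
            if time.monotonic() >= deadline:
                return None
            time.sleep(interval)
```

For example, print(wait_for_server()) should list the model path once the server is up.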

Step 3: Ping the server using a client

After you deploy the model server, you can run inference by sending requests to it. The following example sends a text prompt with a single image:

import requests
import json
from huggingface_hub import hf_hub_download
from datetime import datetime, timedelta

url = "http://localhost:8000/v1/chat/completions"
headers = {"Content-Type": "application/json", "Authorization": "Bearer token"}

model = "mistralai/Pixtral-Large-Instruct-2411"
vllm_model = "/home/ubuntu/model_hf/"

def load_system_prompt(repo_id: str, filename: str) -> str:
    file_path = hf_hub_download(repo_id=repo_id, filename=filename)
    with open(file_path, "r") as file:
        system_prompt = file.read()
    today = datetime.today().strftime("%Y-%m-%d")
    yesterday = (datetime.today() - timedelta(days=1)).strftime("%Y-%m-%d")
    model_name = repo_id.split("/")[-1]
    return system_prompt.format(name=model_name, today=today, yesterday=yesterday)


SYSTEM_PROMPT = load_system_prompt(model, "SYSTEM_PROMPT.txt")

image_url = "https://huggingface.co/datasets/patrickvonplaten/random_img/resolve/main/europe.png"

messages = [
    {"role": "system", "content": SYSTEM_PROMPT},
    {
        "role": "user",
        "content": [
            {
                "type": "text",
                "text": "Which of the depicted countries has the best food? Which the second and third and fourth? Name the country, its color on the map and one its city that is visible on the map, but is not the capital. Make absolutely sure to only name a city that can be seen on the map.",
            },
            {"type": "image_url", "image_url": {"url": image_url}},
        ],
    },
]

data = {"model": vllm_model, "messages": messages}

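# Allow time for long multimodal generations and surface HTTP errors before parsing the body.
response = requests.post(url, headers=headers, data=json.dumps(data), timeout=600)
response.raise_for_status()
print(response.json()["choices"][0]["message"]["content"])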

Sample response from the model

The ranking of countries based on the best food is subjective and can vary greatly depending on personal preferences. It can be perceived as offensive by some to rank cuisines but I will do it based on commonly held opinions.

1. Italy
Color on the map: Brown
City visible on the map: Napoli (in brown color)

2. France
Color on the map: Dark teal
City visible on the map: Marseille (in dark teal color)

3. Spain
Color on the map: Red pink
City visible on the map: Barcelona (in red pink color)

4. Germany
Color on the map: Orange
City visible on the map: Cologne (in orange color)
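Because the server is started with --limit-mm-per-prompt 'image=6', a single prompt can carry up to six images. The following sketch builds such a user message for the client code above, including local images encoded as base64 data URLs. It assumes that the Neuron vLLM fork accepts data URLs in image_url the same way upstream vLLM does; the helper names and the file my_photo.png are hypothetical.

```python
import base64
from pathlib import Path
from typing import Dict, List

MAX_IMAGES_PER_PROMPT = 6  # matches --limit-mm-per-prompt 'image=6'


def image_file_to_data_url(path: str, mime_type: str = "image/png") -> str:
    """Encode a local image file as a base64 data URL."""
    encoded = base64.b64encode(Path(path).read_bytes()).decode("ascii")
    return f"data:{mime_type};base64,{encoded}"


def build_user_message(prompt: str, image_urls: List[str]) -> Dict:
    """Build one user message: a text part followed by one part per image."""
    if len(image_urls) > MAX_IMAGES_PER_PROMPT:
        raise ValueError(
            f"at most {MAX_IMAGES_PER_PROMPT} images per prompt, got {len(image_urls)}"
        )
    content = [{"type": "text", "text": prompt}]
    content += [{"type": "image_url", "image_url": {"url": url}} for url in image_urls]
    return {"role": "user", "content": content}
```

For example, replace the user message in messages with build_user_message("Compare these images.", [image_url, image_file_to_data_url("my_photo.png")]), then send data as before.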

Conclusion

Congratulations! You now know how to deploy mistralai/Pixtral-Large-Instruct-2411 on a trn2.48xlarge instance. Adjust the configurations to suit your own requirements and use case.