Inpainting Images with Black Forest Labs FLUX.1-Fill-dev on Trn1/Trn2
This tutorial provides a step-by-step guide to inpainting and outpainting images with the FLUX.1-Fill-dev model from Black Forest Labs, using NeuronX Distributed (NxD) Inference on a single trn2.48xlarge instance.
Background, Concepts, and Optimizations
Tensor and Context Parallelism
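For the latent transformer (backbone) model, this tutorial combines Tensor Parallelism (TP) with Context Parallelism (CP). Because diffusion inference is compute-bound, CP adds parallelism beyond TP by sharding the work along the sequence dimension. The CP degree is governed by world_size relative to backbone_tp_degree: world_size / backbone_tp_degree ranks share the sequence, so this tutorial's world_size = 8 and backbone_tp_degree = 4 give 2 CP ranks.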
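The arithmetic behind this split can be sketched in plain Python. This is a minimal illustration, not the NxDI implementation: it assumes the standard FLUX latent layout (the VAE downsamples each side by 8 and the transformer packs 2x2 latent patches into one token) and shows how contiguous chunks of the image-token sequence could be assigned to CP ranks. All helper names are hypothetical.

```python
# Hypothetical helpers illustrating context-parallel (CP) sequence sharding.
# Assumption: FLUX-style latents (VAE downsampling factor 8, 2x2 patch packing).

VAE_SCALE_FACTOR = 8  # assumed VAE spatial downsampling factor
PATCH_SIZE = 2        # assumed side length of each packed latent patch


def cp_degree(world_size: int, tp_degree: int) -> int:
    """Number of context-parallel ranks left over after tensor parallelism."""
    if world_size % tp_degree != 0:
        raise ValueError("world_size must be a multiple of tp_degree")
    return world_size // tp_degree


def image_seq_len(height: int, width: int) -> int:
    """Number of image tokens the backbone sees for a height x width image."""
    step = VAE_SCALE_FACTOR * PATCH_SIZE
    return (height // step) * (width // step)


def shard_sequence(seq_len: int, num_ranks: int) -> list[range]:
    """Split token indices [0, seq_len) into equal contiguous chunks, one per CP rank."""
    if seq_len % num_ranks != 0:
        raise ValueError("sequence length must divide evenly across CP ranks")
    chunk = seq_len // num_ranks
    return [range(r * chunk, (r + 1) * chunk) for r in range(num_ranks)]


if __name__ == "__main__":
    cp = cp_degree(world_size=8, tp_degree=4)
    seq = image_seq_len(1024, 1024)
    shards = shard_sequence(seq, cp)
    print(cp, seq, [len(s) for s in shards])  # 2 4096 [2048, 2048]
```

The sketch covers only image tokens; it does not model how the text tokens in FLUX's joint attention sequence are handled.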
Step 1: Set up the environment
Set up and connect to a trn2.48xlarge instance
As a prerequisite, this tutorial requires a Trn2 instance launched from a Deep Learning AMI with the Neuron SDK pre-installed. For setup instructions, see the NxDI setup guide.
After setting up an instance, use SSH to connect to the Trn2 instance using the key pair that you chose when you launched the instance.
To use a Jupyter Notebook (.ipynb) on the Neuron instance, follow the Jupyter Notebook QuickStart guide.
After you are connected, activate the Python virtual environment that includes the Neuron SDK.
source ~/aws_neuronx_venv_pytorch_2_9_nxd_inference/bin/activate
Run pip list to verify that the Neuron SDK is installed.
pip list | grep neuron
You should see Neuron packages including neuronx-distributed-inference and neuronx-cc.
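To run the same check from Python (for example, inside the notebook), a small sketch using the standard library's importlib.metadata reports the installed version of each package named above, or None if it is missing.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Optional

# Distribution names taken from the pip list check above.
NEURON_PACKAGES = ["neuronx-distributed-inference", "neuronx-cc"]


def installed_versions(packages: list[str]) -> dict[str, Optional[str]]:
    """Map each distribution name to its installed version, or None if absent."""
    result: dict[str, Optional[str]] = {}
    for name in packages:
        try:
            result[name] = version(name)
        except PackageNotFoundError:
            result[name] = None
    return result


if __name__ == "__main__":
    for name, ver in installed_versions(NEURON_PACKAGES).items():
        print(f"{name}: {ver if ver else 'NOT INSTALLED'}")
```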
Download the model
To use this sample, first download the model checkpoint from Hugging Face to a local path on the Trn2 instance. For more information, see Download models in the Hugging Face documentation. This tutorial uses black-forest-labs/FLUX.1-Fill-dev.
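One way to fetch the checkpoint is huggingface_hub's snapshot_download. The sketch below assumes huggingface_hub is installed and that, if the repository is gated, you have accepted its license on Hugging Face and are authenticated. It also checks that the subfolders loaded in Step 3 are present. The helper names are hypothetical.

```python
import os

# Subfolders that Step 3 loads from the checkpoint root.
REQUIRED_SUBFOLDERS = ["text_encoder", "text_encoder_2", "transformer", "vae"]


def download_checkpoint(local_dir: str, repo_id: str = "black-forest-labs/FLUX.1-Fill-dev") -> str:
    """Download the model snapshot into local_dir (needs huggingface_hub and network access)."""
    # Imported lazily so missing_subfolders() works without huggingface_hub installed.
    from huggingface_hub import snapshot_download

    return snapshot_download(repo_id=repo_id, local_dir=local_dir)


def missing_subfolders(ckpt_dir: str) -> list[str]:
    """Return the required component subfolders that are absent from ckpt_dir."""
    return [
        name
        for name in REQUIRED_SUBFOLDERS
        if not os.path.isdir(os.path.join(ckpt_dir, name))
    ]
```

For example, with the CKPT_DIR path used in Step 2, call download_checkpoint(CKPT_DIR) only when missing_subfolders(CKPT_DIR) is non-empty.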
!pip install matplotlib
import os
import torch
from matplotlib import pyplot as plt
from neuronx_distributed_inference.models.diffusers.flux.application import NeuronFluxApplication
from neuronx_distributed_inference.models.config import NeuronConfig
from neuronx_distributed_inference.models.diffusers.flux.clip.modeling_clip import CLIPInferenceConfig
from neuronx_distributed_inference.models.diffusers.flux.t5.modeling_t5 import T5InferenceConfig
from neuronx_distributed_inference.models.diffusers.flux.modeling_flux import FluxBackboneInferenceConfig
from neuronx_distributed_inference.models.diffusers.flux.vae.modeling_vae import VAEDecoderInferenceConfig
from neuronx_distributed_inference.utils.hf_adapter import load_pretrained_config
from neuronx_distributed_inference.utils.diffusers_adapter import load_diffusers_config
Step 2: Set Up Inference Parameters and Model Config
Start by initializing the inference parameters, which include the model parallelism configuration, image size, generation settings, and paths. Make sure that CKPT_DIR matches the local directory where you downloaded the model in Step 1.
world_size = 8
backbone_tp_degree = 4
dtype = torch.bfloat16
height, width = [1024, 1024]
guidance_scale = 3.5
num_inference_steps = 25
prompt = "Milky way galaxy in space"
# Local root directory of the FLUX.1-Fill-dev checkpoint downloaded from Hugging Face
CKPT_DIR = "/shared/models/FLUX.1-Fill-dev/"
# Compiler working directory: compiled artifacts are written here and loaded from here
BASE_COMPILE_WORK_DIR = "/tmp/flux/compiler_workdir/"
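Before compiling, it can help to sanity-check these parameters and fail fast. The checker below is a hypothetical sketch: it assumes that world_size must be a multiple of backbone_tp_degree (so the remaining ranks form whole CP groups) and that FLUX-style pipelines expect height and width to be multiples of 16 (VAE downsampling of 8 times 2x2 patch packing).

```python
import os

# Assumption: image sizes should be multiples of 16 (VAE factor 8 x 2x2 patch packing).
SIZE_MULTIPLE = 16


def check_inference_params(world_size, backbone_tp_degree, height, width,
                           num_inference_steps, ckpt_dir):
    """Return a list of human-readable problems; an empty list means all checks passed."""
    problems = []
    if world_size % backbone_tp_degree != 0:
        problems.append("world_size must be a multiple of backbone_tp_degree")
    for name, value in (("height", height), ("width", width)):
        if value % SIZE_MULTIPLE != 0:
            problems.append(f"{name}={value} is not a multiple of {SIZE_MULTIPLE}")
    if num_inference_steps < 1:
        problems.append("num_inference_steps must be at least 1")
    if not os.path.isdir(ckpt_dir):
        problems.append(f"checkpoint directory not found: {ckpt_dir}")
    return problems
```

For example, check_inference_params(world_size, backbone_tp_degree, height, width, num_inference_steps, CKPT_DIR) should return an empty list before you move on.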
Step 3: Set Up Model and Neuron Configuration
Next, initialize a configuration object for each component model in the Flux pipeline: the CLIP and T5 text encoders, the backbone transformer, and the VAE decoder. Use the following parallelism configuration for each component:
- CLIP: tp_degree of 1.
- T5: tp_degree equal to world_size, which is 8 in this example.
- Backbone transformer: when using Context Parallelism, tp_degree is half of world_size, which is 4 in this example and allows for 2 CP ranks.
- VAE decoder: tp_degree of 1.
text_encoder_path = os.path.join(CKPT_DIR, "text_encoder")
text_encoder_2_path = os.path.join(CKPT_DIR, "text_encoder_2")
backbone_path = os.path.join(CKPT_DIR, "transformer")
vae_decoder_path = os.path.join(CKPT_DIR, "vae")

# CLIP text encoder: no tensor parallelism
clip_neuron_config = NeuronConfig(
    tp_degree=1,
    world_size=world_size,
    torch_dtype=dtype,
)
clip_config = CLIPInferenceConfig(
    neuron_config=clip_neuron_config,
    load_config=load_pretrained_config(text_encoder_path),
)

# T5 text encoder: tensor-parallel across all world_size ranks
t5_neuron_config = NeuronConfig(
    tp_degree=world_size,
    world_size=world_size,
    torch_dtype=dtype,
)
t5_config = T5InferenceConfig(
    neuron_config=t5_neuron_config,
    load_config=load_pretrained_config(text_encoder_2_path),
)

# Backbone transformer: backbone_tp_degree-way TP; world_size / backbone_tp_degree CP ranks
backbone_neuron_config = NeuronConfig(
    tp_degree=backbone_tp_degree,
    world_size=world_size,
    torch_dtype=dtype,
)
backbone_config = FluxBackboneInferenceConfig(
    neuron_config=backbone_neuron_config,
    load_config=load_diffusers_config(backbone_path),
    height=height,
    width=width,
)

# VAE decoder: no tensor parallelism
decoder_neuron_config = NeuronConfig(
    tp_degree=1,
    world_size=world_size,
    torch_dtype=dtype,
)
decoder_config = VAEDecoderInferenceConfig(
    neuron_config=decoder_neuron_config,
    load_config=load_diffusers_config(vae_decoder_path),
    height=height,
    width=width,
)

# Copy the VAE scale factor (image-to-latent downsampling) onto the backbone config
setattr(
    backbone_config,
    "vae_scale_factor",
    decoder_config.vae_scale_factor,
)
Step 4: Initialize the Flux Application and Compile
Next, instantiate NeuronFluxApplication, which contains the pipeline orchestration logic and the component models. Compiling the application compiles each component model individually, using BASE_COMPILE_WORK_DIR as the compiler working directory.
flux_app = NeuronFluxApplication(
    model_path=CKPT_DIR,
    text_encoder_config=clip_config,
    text_encoder2_config=t5_config,
    backbone_config=backbone_config,
    decoder_config=decoder_config,
    height=height,
    width=width,
)
flux_app.compile(BASE_COMPILE_WORK_DIR)
Step 5: Load the Model
This step loads the compiled models (NEFFs) and their weights into device memory. Calling load on flux_app loads every component model.
flux_app.load(BASE_COMPILE_WORK_DIR)
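Compilation can take a while, so on later runs you may want to reuse existing artifacts. The helper below is a hypothetical sketch that only uses the compile and load methods shown above. It assumes that a non-empty working directory means a previous compile finished; it does not verify the NEFFs, so force a recompile after changing any configuration.

```python
import os


def compile_and_load(app, work_dir: str, force_compile: bool = False) -> bool:
    """Compile `app` into work_dir unless artifacts appear to exist, then load it.

    Returns True if compilation ran. Assumption: a non-empty work_dir is treated
    as a finished compile; pass force_compile=True after changing any config.
    """
    has_artifacts = os.path.isdir(work_dir) and len(os.listdir(work_dir)) > 0
    compiled = force_compile or not has_artifacts
    if compiled:
        app.compile(work_dir)
    app.load(work_dir)
    return compiled
```

In this tutorial, the call would be compile_and_load(flux_app, BASE_COMPILE_WORK_DIR) in place of the separate compile and load cells.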
Step 6: Load the Image and Mask
Load the source image and the mask that marks the region to fill according to the prompt. cat.png and mask.png are taken from the COCO dataset (https://cocodataset.org/#explore?id=261706). Place both files in the same directory as the notebook. Both images are resized to the height and width set in Step 2.
from diffusers.utils import load_image
from PIL import Image
def load_and_resize_image(image_path: str, height: int, width: int) -> Image.Image:
    """Load an image from a file path and resize it to the specified dimensions."""
    image = load_image(image_path)
    # PIL sizes are (width, height)
    return image.resize((width, height), Image.Resampling.LANCZOS)
image = load_and_resize_image('./cat.png', height, width)
mask_image = load_and_resize_image('./mask.png', height, width)
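If you do not have mask.png, or want to try outpainting, you can build a mask yourself with PIL. This sketch assumes the diffusers inpainting convention, where white (255) mask pixels mark the region to regenerate and black (0) pixels are preserved; the helper names are hypothetical.

```python
from PIL import Image, ImageDraw


def rectangle_mask(height: int, width: int, box: tuple[int, int, int, int]) -> Image.Image:
    """Inpainting mask: white inside box (left, top, right, bottom), black elsewhere."""
    mask = Image.new("L", (width, height), 0)
    ImageDraw.Draw(mask).rectangle(box, fill=255)
    return mask


def outpaint_canvas(image: Image.Image, pad: int) -> tuple[Image.Image, Image.Image]:
    """Outpainting: place image on a larger canvas and mask the new border for filling."""
    w, h = image.size
    canvas = Image.new("RGB", (w + 2 * pad, h + 2 * pad), (0, 0, 0))
    canvas.paste(image.convert("RGB"), (pad, pad))
    mask = Image.new("L", canvas.size, 255)       # mark the whole canvas for filling...
    mask.paste(0, (pad, pad, pad + w, pad + h))   # ...except the original image area
    return canvas, mask
```

Because the backbone is configured for a fixed height and width in Step 3, resize the outpainting canvas and mask to those dimensions (for example with Image.resize) before passing them to flux_app.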
Step 7: Generate the Filled Image
Finally, fill the masked region of the image according to the prompt and display the result:
# Use a new name so the input image from Step 6 is not overwritten
filled_image = flux_app(
    prompt=prompt,
    image=image,
    mask_image=mask_image,
    height=height,
    width=width,
    guidance_scale=guidance_scale,
    num_inference_steps=num_inference_steps,
).images[0]
plt.imshow(filled_image)
plt.show()
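Diffusion inpainting can slightly change pixels outside the mask, for example through the VAE encode/decode round trip. If the unmasked area must match the input exactly, one common optional post-processing step, sketched here under the same white-means-fill mask convention, is to composite the generated pixels back onto the original with PIL's Image.composite. This is not part of the NxDI pipeline, and the helper name is hypothetical.

```python
from PIL import Image


def keep_unmasked(generated: Image.Image, original: Image.Image, mask: Image.Image) -> Image.Image:
    """Take pixels from `generated` where the mask is white and from `original` where it is black."""
    size = generated.size
    original = original.convert("RGB").resize(size)
    mask = mask.convert("L").resize(size)  # resampling may soften mask edges slightly
    return Image.composite(generated.convert("RGB"), original, mask)
```

Pass it the generated image, the resized input image, and the resized mask from Step 6.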
Notes
Running Flux Inference on trn1
This sample can also run on a trn1.32xlarge instance with a few modifications. If you use Context Parallelism, apply the following parallelism configuration, which gives 2 CP ranks:
world_size = 16
backbone_tp_degree = 8
Otherwise, use the following configuration, where backbone_tp_degree equals world_size and there is no context parallelism:
world_size = 8
backbone_tp_degree = 8
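The choice between these settings can be captured in a small lookup. This is a hypothetical sketch that encodes only the values listed in this tutorial and derives the number of CP ranks as world_size divided by backbone_tp_degree.

```python
# Parallelism settings listed in this tutorial, keyed by (instance type, use context parallelism).
PARALLEL_CONFIGS = {
    ("trn2.48xlarge", True): {"world_size": 8, "backbone_tp_degree": 4},
    ("trn1.32xlarge", True): {"world_size": 16, "backbone_tp_degree": 8},
    ("trn1.32xlarge", False): {"world_size": 8, "backbone_tp_degree": 8},
}


def parallel_config(instance_type: str, use_context_parallel: bool) -> dict:
    """Return world_size, backbone_tp_degree, and the derived number of CP ranks."""
    key = (instance_type, use_context_parallel)
    if key not in PARALLEL_CONFIGS:
        raise KeyError(f"no configuration listed for {key}")
    cfg = dict(PARALLEL_CONFIGS[key])
    cfg["cp_degree"] = cfg["world_size"] // cfg["backbone_tp_degree"]
    return cfg
```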