This document is relevant for: Inf2, Trn1, Trn2

How to Use Convolution Kernels in UNet Training Models#

Task overview#

This topic describes how to modify a UNet training model to use NKI convolution kernels with the AWS Neuron SDK. Routing convolutions through these kernels helps avoid the out-of-memory errors that can occur when training the convolution-heavy UNet model.

Prerequisites#

  • AWS Neuron SDK 2.26 or later: Required for kernel implementation support (a version check is shown after this list)

  • trn1.32xlarge instance: Needed for model training

  • Existing UNet implementation: Base model to be modified

  • PyTorch-Neuron environment: Required for neural network operations
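
A quick way to check the installed compiler version from Python is shown below (this assumes the neuronxcc package exposes __version__; map the compiler version to an SDK release via the Neuron release notes):

import neuronxcc
print(neuronxcc.__version__)  # neuronx-cc compiler version, assumed attribute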

Instructions#

1: Import required dependencies

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Function  # used to define a custom forward/backward pair
import neuronxcc.nki as nki          # Neuron Kernel Interface (NKI)
import neuronxcc.nki.language as nl
# Private NKI convolution kernel wrapped in the next step
from neuronxcc.nki._private_kernels.conv import conv2d_dw_fb01_io01_01bf_rep_nhwc_Pcinh

2: Create the convolution wrapper function

@nki.jit
def conv_wrap(img_ref, filter_ref, out_shape):
    # NKI kernels write into a preallocated output buffer in device HBM.
    out_arr = nl.ndarray(shape=out_shape, dtype=img_ref.dtype, buffer=nl.hbm)
    # Invoke the private kernel with explicit shape, layout-permutation,
    # stride, and padding metadata.
    conv2d_dw_fb01_io01_01bf_rep_nhwc_Pcinh(img_ref, filter_ref, out_arr, **{
        'input': img_ref.shape,
        'filter': filter_ref.shape,
        'output': out_shape,
        'in_perm': [0, 1, 2, 3],
        'kern_perm': [0, 1, 2, 3],
        'out_perm': [0, 1, 2, 3],
        'stride': (1, 1),
        'padding': ((1, 1), (1, 1))})
    return out_arr
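
The wrapper takes the output shape up front because the kernel writes into a preallocated HBM buffer. A minimal helper for computing that shape (conv_out_shape is illustrative and not part of the SDK; the NHWC layout is an assumption based on the kernel's name):

def conv_out_shape(img_shape, kh, kw, out_channels, stride=(1, 1), padding=((1, 1), (1, 1))):
    # Assumes NHWC image layout, suggested by "nhwc" in the kernel name.
    n, h, w, _ = img_shape
    h_out = (h + padding[0][0] + padding[0][1] - kh) // stride[0] + 1
    w_out = (w + padding[1][0] + padding[1][1] - kw) // stride[1] + 1
    return (n, h_out, w_out, out_channels)

# With 3x3 filters, stride 1, and padding (1, 1), spatial size is preserved:
# conv_out_shape((4, 128, 128, 64), 3, 3, 64) == (4, 128, 128, 64)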

3: Implement the custom Conv2d module

class BwdConv2dWithKernel(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, padding, bias):
        super().__init__()
        # The conv_wrap kernel path hardcodes padding (1, 1) and has no bias term.
        assert padding == 1
        assert not bias
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.weight = nn.Parameter(torch.randn(out_channels, in_channels, kernel_size, kernel_size))
        nn.init.kaiming_uniform_(self.weight, a=0.0, mode='fan_in', nonlinearity='leaky_relu')

    def forward(self, x):
        # _Conv2dFunc is a custom autograd function, sketched below.
        return _Conv2dFunc.apply(x, self.weight)
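
The forward method above relies on a custom autograd Function. A minimal sketch follows, using standard PyTorch gradient helpers as a fallback (_Conv2dFunc is illustrative, not part of the SDK; the weight-gradient step in backward is the natural place to substitute a conv_wrap kernel call):

class _Conv2dFunc(Function):
    @staticmethod
    def forward(ctx, x, weight):
        ctx.save_for_backward(x, weight)
        return F.conv2d(x, weight, stride=1, padding=1)

    @staticmethod
    def backward(ctx, grad_output):
        x, weight = ctx.saved_tensors
        grad_input = torch.nn.grad.conv2d_input(
            x.shape, weight, grad_output, stride=1, padding=1)
        # Candidate spot for the NKI kernel: compute the weight gradient
        # with conv_wrap instead of the stock helper below.
        grad_weight = torch.nn.grad.conv2d_weight(
            x, weight.shape, grad_output, stride=1, padding=1)
        return grad_input, grad_weight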

4: Replace standard convolutions in the UNet model

class DoubleConvWithKernel(nn.Module):
    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        if not mid_channels:
            mid_channels = out_channels
        self.double_conv = nn.Sequential(
            BwdConv2dWithKernel(in_channels, mid_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(mid_channels),
            nn.ReLU(inplace=True),
            BwdConv2dWithKernel(mid_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.double_conv(x)
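
A quick CPU shape check confirms the block preserves spatial dimensions (illustrative only; actual training runs on the Neuron device through torch-xla):

block = DoubleConvWithKernel(3, 64)
x = torch.randn(1, 3, 128, 128)
print(block(x).shape)  # torch.Size([1, 64, 128, 128])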

5: Update the UNet model initialization

# In the UNet model's __init__:
def __init__(self, n_channels, n_classes, bilinear=False):
    super().__init__()
    self.n_channels = n_channels
    self.n_classes = n_classes
    self.bilinear = bilinear
    self.inc = DoubleConvWithKernel(n_channels, 64)
    # ... rest of initialization (see the sketch below)
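
If the model follows the widely used open-source UNet layout, the remaining blocks mirror the stock initialization, with the Down, Up, and OutConv helpers updated internally to build on DoubleConvWithKernel (the names and channel sizes below assume that layout):

    self.down1 = Down(64, 128)
    self.down2 = Down(128, 256)
    self.down3 = Down(256, 512)
    factor = 2 if bilinear else 1
    self.down4 = Down(512, 1024 // factor)
    self.up1 = Up(1024, 512 // factor, bilinear)
    self.up2 = Up(512, 256 // factor, bilinear)
    self.up3 = Up(256, 128 // factor, bilinear)
    self.up4 = Up(128, 64, bilinear)
    self.outc = OutConv(64, n_classes)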

Confirm your work#

To confirm successful implementation, verify the following:

Expected training output:
Training Device=xla:0 Epoch=1 Step=20 Loss=0.30803
Training Device=xla:0 Epoch=2 Step=560 Loss=0.01826

Check for:

  • No out-of-memory errors during execution

  • Decreasing loss values across epochs

Common issues#

Memory Errors

  • Solution: Verify all standard convolutions are replaced with BwdConv2dWithKernel implementations

Compilation Errors

  • Solution: Confirm Neuron SDK version is 2.26 or later

Kernel Errors

  • Solution: Use the kernel only for supported configurations (this guide exercises 3x3 filters with stride 1, padding 1, and no bias); the kernel errors out in unsupported scenarios. A fallback pattern is sketched below.
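
One illustrative way to guard against unsupported configurations is a small factory that falls back to the stock convolution (make_conv is a sketch, not part of the SDK):

def make_conv(in_channels, out_channels, kernel_size, padding, bias):
    # Use the kernel-backed module only for the configuration this guide
    # exercises: 3x3 filters, padding 1, no bias.
    if kernel_size == 3 and padding == 1 and not bias:
        return BwdConv2dWithKernel(in_channels, out_channels, kernel_size, padding, bias)
    return nn.Conv2d(in_channels, out_channels, kernel_size, padding=padding, bias=bias)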