This document is relevant for: Inf2, Trn1, Trn2
How to Use Convolution Kernels in UNet Training Models#
Task overview#
This topic describes how to modify a UNet training model to use NKI convolution kernels with the AWS Neuron SDK. Replacing the standard convolutions with these kernels helps avoid the out-of-memory errors that can occur when training the convolution-heavy UNet model.
Prerequisites#
AWS Neuron SDK 2.26 or later: Required for kernel implementation support
trn1.32xlarge instance: Needed for model training
Existing UNet implementation: Base model to be modified
PyTorch Neuron (torch-neuronx) environment: Required to run the PyTorch training workload on Neuron devices
Instructions#
1: Import required dependencies
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Function
# NKI APIs for defining and launching custom Neuron kernels
import neuronxcc.nki as nki
import neuronxcc.nki.language as nl
# Private NKI convolution kernel wrapped in the next step
from neuronxcc.nki._private_kernels.conv import conv2d_dw_fb01_io01_01bf_rep_nhwc_Pcinh
2: Create the convolution wrapper function
@nki.jit
def conv_wrap(img_ref, filter_ref, out_shape):
    # Allocate the kernel's output buffer in device HBM.
    out_arr = nl.ndarray(shape=out_shape, dtype=img_ref.dtype, buffer=nl.hbm)
    # Invoke the NKI convolution kernel. The *_perm arguments describe the
    # dimension ordering of the input, filter, and output tensors; this
    # wrapper is fixed to stride 1 and symmetric padding of 1.
    conv2d_dw_fb01_io01_01bf_rep_nhwc_Pcinh(img_ref, filter_ref, out_arr, **{
        'input': img_ref.shape,
        'filter': filter_ref.shape,
        'output': out_shape,
        'in_perm': [0, 1, 2, 3],
        'kern_perm': [0, 1, 2, 3],
        'out_perm': [0, 1, 2, 3],
        'stride': (1, 1),
        'padding': ((1, 1), (1, 1))})
    return out_arr
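Because the wrapper fixes stride (1, 1) and padding ((1, 1), (1, 1)), a 3x3 convolution preserves the spatial dimensions, so out_shape can be derived directly from the input shape. The standard output-size formula, shown here for reference (conv_out_dim is a helper name used only for illustration):

def conv_out_dim(size, kernel=3, stride=1, pad=1):
    # Standard convolution output-size formula.
    return (size + 2 * pad - kernel) // stride + 1

assert conv_out_dim(64) == 64  # 3x3, stride 1, pad 1 keeps H and W unchanged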
3: Implement the custom Conv2d module
class BwdConv2dWithKernel(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, padding, bias):
        super().__init__()
        # The wrapper kernel is fixed to padding 1 and no bias (see step 2).
        assert padding == 1
        assert bias is False
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.weight = nn.Parameter(torch.randn(out_channels, in_channels, kernel_size, kernel_size))
        nn.init.kaiming_uniform_(self.weight, a=0.0, mode='fan_in', nonlinearity='leaky_relu')
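The snippet above shows only the constructor. A working module also needs a forward method; the torch.autograd.Function import in step 1 indicates that the full implementation routes the forward and backward passes through a custom autograd function so gradients also use the kernel. The following is a minimal forward-only sketch, not the tutorial's exact code, assuming conv_wrap from step 2 accepts the module's tensors directly; shapes follow the PyTorch NCHW convention used by the surrounding modules, with the permutation arguments in step 2 controlling the kernel's actual layout:

    def forward(self, x):
        # Stride 1 with padding 1 and a 3x3 filter preserves H and W,
        # so the output shape follows directly from the input shape.
        n, _, h, w = x.shape
        out_shape = (n, self.out_channels, h, w)
        return conv_wrap(x, self.weight, out_shape)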
4: Replace standard convolutions in the UNet model
class DoubleConvWithKernel(nn.Module):
    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        if not mid_channels:
            mid_channels = out_channels
        self.double_conv = nn.Sequential(
            BwdConv2dWithKernel(in_channels, mid_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(mid_channels),
            nn.ReLU(inplace=True),
            BwdConv2dWithKernel(mid_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )
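As in the standard UNet DoubleConv block, the forward pass simply applies the sequential stack; shown here so the module is complete:

    def forward(self, x):
        # Apply conv -> BatchNorm -> ReLU twice via the Sequential above.
        return self.double_conv(x)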
5: Update the UNet model initialization
def __init__(self, n_channels, n_classes, bilinear=False):
    super().__init__()
    self.n_channels = n_channels
    self.n_classes = n_classes
    self.bilinear = bilinear
    self.inc = DoubleConvWithKernel(n_channels, 64)
    # ... rest of initialization
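To exercise the modified model end to end, run it on a NeuronCore through PyTorch/XLA, matching the xla:0 device shown in the expected output below. A minimal sketch, assuming the surrounding UNet class is named UNet and using illustrative channel counts and input sizes:

import torch_xla.core.xla_model as xm

device = xm.xla_device()          # resolves to xla:0 on a Neuron device
model = UNet(n_channels=3, n_classes=2).to(device)
x = torch.randn(1, 3, 256, 256).to(device)
out = model(x)                    # runs the kernel-backed convolutions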
Confirm your work#
To confirm successful implementation, verify the following:
Expected training output
Training Device=xla:0 Epoch=1 Step=20 Loss=0.30803
Training Device=xla:0 Epoch=2 Step=560 Loss=0.01826
Check for:
No out-of-memory errors during execution
Decreasing loss values across epochs
Common issues#
Memory Errors
Solution: Verify all standard convolutions are replaced with BwdConv2dWithKernel implementations
Compilation Errors
Solution: Confirm Neuron SDK version is 2.26 or later
Kernel Errors
Solution: Use the kernel only with configurations it supports, such as the 3x3, stride-1, padding-1 convolutions wrapped in step 2. The kernel raises an error when invoked with unsupported shapes or parameters.