This document is relevant for: Inf1, Inf2, Trn1, Trn2, Trn3
NCC_EVRF010
Error message: The compiler encountered simultaneous use of input and kernel dilation, which is not supported.
Erroneous code example:
import jax.numpy as jnp
from jax import lax

x = jnp.ones((1, 4, 4, 1), dtype=jnp.float32)
kernel = jnp.ones((3, 3, 1, 1), dtype=jnp.float32)
result = lax.conv_general_dilated(
    x,
    kernel,
    window_strides=(1, 1),
    padding=((2, 2), (2, 2)),
    lhs_dilation=(2, 2),  # input dilation
    rhs_dilation=(2, 2),  # kernel dilation; combining both triggers the error
    dimension_numbers=('NHWC', 'HWIO', 'NHWC')
)
If possible, use only input dilation or only kernel dilation:
import jax.numpy as jnp
from jax import lax

x = jnp.ones((1, 4, 4, 1), dtype=jnp.float32)
kernel = jnp.ones((3, 3, 1, 1), dtype=jnp.float32)
result = lax.conv_general_dilated(
    x,
    kernel,
    window_strides=(1, 1),
    padding=((2, 2), (2, 2)),
    lhs_dilation=(1, 1),  # no input dilation
    rhs_dilation=(2, 2),  # kernel dilation only
    dimension_numbers=('NHWC', 'HWIO', 'NHWC')
)
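Alternatively, apply one of the dilations manually (for example, by inserting the zeros into the input yourself) and then run a convolution that uses only the other dilation.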
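A minimal sketch of the manual approach, assuming the input dilation is the one applied by hand. The helper name `dilate_input_nhwc` is hypothetical and not part of JAX or the Neuron SDK. It uses the interior padding of `lax.pad` to insert zeros between neighbouring spatial elements, which is what `lhs_dilation` does before padding is applied. Whether the compiler accepts this form for a given model is an assumption that should be checked:

```python
import jax.numpy as jnp
from jax import lax


def dilate_input_nhwc(x, dilation):
    """Insert (d - 1) zeros between spatial elements of an NHWC array.

    Hypothetical helper for illustration only.
    """
    dh, dw = dilation
    zero = jnp.zeros((), dtype=x.dtype)
    # Each entry is (low, high, interior) padding for one axis: N, H, W, C.
    padding_config = [(0, 0, 0), (0, 0, dh - 1), (0, 0, dw - 1), (0, 0, 0)]
    return lax.pad(x, zero, padding_config)


x = jnp.ones((1, 4, 4, 1), dtype=jnp.float32)
kernel = jnp.ones((3, 3, 1, 1), dtype=jnp.float32)

# Manual input dilation: spatial size 4 becomes (4 - 1) * 2 + 1 = 7.
x_dilated = dilate_input_nhwc(x, (2, 2))

# The convolution now uses kernel dilation only.
result = lax.conv_general_dilated(
    x_dilated,
    kernel,
    window_strides=(1, 1),
    padding=((2, 2), (2, 2)),
    rhs_dilation=(2, 2),
    dimension_numbers=('NHWC', 'HWIO', 'NHWC'),
)
```

The explicit `padding` is left unchanged because it is applied to the already dilated input in both formulations.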