This document is relevant for: Inf1, Inf2, Trn1, Trn2, Trn3
NCC_EVRF013
Error message: TopK does not support int32 or int64 input tensors.
Erroneous code example:
def forward(self, x):
    # assume x is an integer tensor
    # error: cannot call TopK on integer dtypes
    k = 5
    values, indices = torch.topk(x, k=k, dim=-1)
    return values, indices
To fix this error, cast the tensor to a supported floating-point dtype, such as float32, before calling TopK.
def forward(self, x):
    # cast to float32 so that TopK receives a supported dtype
    x = x.float()
    k = 5
    values, indices = torch.topk(x, k=k, dim=-1)
    return values, indices
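Note that after the cast, the returned values are floats, and float32 represents integers exactly only up to a magnitude of 2**24. If the caller needs the original integer values back, one option is to use the float cast only to obtain the indices and then gather the values from the original tensor. The following is a minimal sketch under that assumption; `topk_int_safe` is a hypothetical helper name, not part of PyTorch or the Neuron SDK, and it has not been verified here against the Neuron compiler.

```python
import torch


def topk_int_safe(x: torch.Tensor, k: int, dim: int = -1):
    """Hypothetical helper: run TopK on an integer tensor via a float cast."""
    if x.dtype in (torch.int32, torch.int64):
        # Cast only for the TopK call. float32 is exact for |value| <= 2**24;
        # larger integers may collapse to the same float and tie.
        _, indices = torch.topk(x.float(), k=k, dim=dim)
        # Gather from the original tensor so values keep their integer dtype.
        values = torch.gather(x, dim, indices)
        return values, indices
    # Floating-point inputs can be passed to TopK directly.
    return torch.topk(x, k=k, dim=dim)
```

For inputs whose magnitude may exceed 2**24, the ordering of nearly equal elements is not guaranteed to match the integer ordering, so this approach fits best when the value range is known to be small.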