This document is relevant for: Inf1, Inf2, Trn1, Trn2, Trn3

NCC_ESPP047#

Error message: The compiler found usage of an unsupported 8-bit floating-point data type.

Erroneous code example:

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear1 = nn.Linear(10, 20)
        self.linear2 = nn.Linear(20, 10)

    def forward(self, x):
        x = self.linear1(x)
        x = torch.relu(x)
        x = self.linear2(x)
        return x

# Unsupported 8-bit floating-point data type being used here
input_tensor = torch.randn(1, 10).to(torch.float8_e4m3fn)

To fix this error:

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear1 = nn.Linear(10, 20)
        self.linear2 = nn.Linear(20, 10)

    def forward(self, x):
        x = self.linear1(x)
        x = torch.relu(x)
        x = self.linear2(x)
        return x

input_tensor = torch.randn(1, 10).to(torch.float8_e4m3fn)
# Convert to a supported type
input_tensor = input_tensor.to(torch.float16)

This document is relevant for: Inf1, Inf2, Trn1, Trn2, Trn3