This document is relevant for: Inf1

Developer Guide - PyTorch Neuron (torch-neuron) LSTM Support#

The torch-neuron package supports LSTM operations and yields high performance on both fixed-length and variable-length sequences. Most network configurations are supported, with the exception of those that require PackedSequence usage outside of LSTM or pad_packed_sequence() operations, since Neuron must guarantee that shapes remain fixed throughout the network.

The following sections describe which scenarios can and cannot be supported.

Supported Usage#

Fixed-Length Sequences#

In normal usage of an LSTM, the inputs and outputs have a fixed sequence length. This is the most basic usage of an LSTM, but it may not be applicable to applications where the input sequence length varies.

import torch
import torch_neuron

class Network(torch.nn.Module):

    def __init__(self):
        super().__init__()
        self.lstm = torch.nn.LSTM(input_size=3, hidden_size=7)

    def forward(self, inputs):
        output, (ht, ct) = self.lstm(inputs)
        return output, (ht, ct)

# Example Inputs
seq_len, batch_size, input_size = 5, 2, 3
inputs = torch.rand(seq_len, batch_size, input_size)

# Trace
torch_neuron.trace(Network(), (inputs,))

Packed Input, Padded Output, Pre-Sorted Inputs#

A common usage of an LSTM is one where the sizes of the input sequences vary, such as when the inputs are sequences of tokens.

For example, the following sentences could result in two different sequence lengths after tokenization:

# Input
text = [
    'Hello, sailor',
    'Example',
]

# ... Tokenization ...

# Result
tokens = [
    [101, 7592, 1010, 11803, 102],
    [101, 2742,  102,     0,   0],
]
lengths = [5, 3]

Because the lengths differ, the final LSTM state depends upon the length of each sequence in the batch. Torch provides a way to handle these types of sequences by densely packing batches into a PackedSequence. This is most commonly constructed by using the pack_padded_sequence() utility function prior to feeding inputs into the LSTM.

Packing the above sequences would result in the following data and batch size tensors.

data = [101, 101, 7592, 2742, 1010, 102, 11803, 102]
batch_sizes = [2, 2, 2, 1, 1]
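
These tensors can be reproduced directly with pack_padded_sequence(). The following is a minimal sketch using the token values from the example above:

import torch

tokens = torch.tensor([
    [101, 7592, 1010, 11803, 102],
    [101, 2742,  102,     0,   0],
])
lengths = torch.tensor([5, 3])

# Pack the batch-first token tensor. The sequences are already sorted
# by decreasing length, so enforce_sorted=True is valid here.
packed = torch.nn.utils.rnn.pack_padded_sequence(
    tokens,
    lengths=lengths,
    batch_first=True,
    enforce_sorted=True,
)
print(packed.data)         # tensor([  101,   101,  7592,  2742,  1010,   102, 11803,   102])
print(packed.batch_sizes)  # tensor([2, 2, 2, 1, 1])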

In addition to correctly computing the final LSTM state, using a packed sequence instead of a padded sequence also improves model performance on CPU. On Neuron, where computation is fixed to the maximum sequence length ahead of time, this does not improve performance.

When an LSTM processes a PackedSequence, the sequences must be in decreasing sorted-length order. To ensure that sequences are sorted, pack_padded_sequence() provides an enforce_sorted flag. When enforce_sorted is True, the input is expected to already contain sequences sorted by length in decreasing order along the batch dimension. Note that this sorting must be performed in application-level code, but it is only relevant when the batch size is greater than 1.
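
For example, an application could sort a batch by decreasing length before invoking the traced model. The following is a minimal sketch, assuming a seq_len-first input layout; sort_batch is a hypothetical helper, not part of torch-neuron:

import torch

def sort_batch(inputs, lengths):
    # Sort lengths in decreasing order and reorder the batch dimension
    # (dim=1 for a seq_len-first layout) to match.
    sorted_lengths, indices = torch.sort(lengths, descending=True)
    return inputs[:, indices], sorted_lengths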

The following network can compile successfully because the input and output of the network are guaranteed to have fixed shapes. The input is expected to be a padded tensor, and the output tensor is padded to the maximum sequence length using the pad_packed_sequence() function call:

import torch
import torch_neuron

class Network(torch.nn.Module):

    def __init__(self):
        super().__init__()
        self.lstm = torch.nn.LSTM(input_size=3, hidden_size=7)

    def forward(self, inputs, lengths):
        packed_input = torch.nn.utils.rnn.pack_padded_sequence(
            inputs,
            lengths=lengths,
            enforce_sorted=True,
        )
        packed_result, (ht, ct) = self.lstm(packed_input)
        padded_result, _ = torch.nn.utils.rnn.pad_packed_sequence(packed_result)
        return padded_result, ht, ct

# Example Inputs
seq_len, batch_size, input_size = 5, 2, 3
inputs = torch.rand(seq_len, batch_size, input_size)
lengths = torch.tensor([seq_len] * batch_size)

# Trace
torch_neuron.trace(Network(), (inputs, lengths))

Packed Input, Padded Output, Unsorted Inputs#

When enforce_sorted is False, the input is sorted unconditionally. This causes some CPU overhead on Neuron because unsupported operators, such as aten::sort and aten::scatter_, will be inserted into the graph. The aten::lstm operation can still be supported, but it will be less efficient than when enforce_sorted is True.

The following code can be traced, but the sorting operations run on CPU. This is not problematic in this case because aten::sort and aten::scatter_ are executed on CPU at the very beginning of the graph, just prior to Neuron execution.

Like the previous example, the call to pad_packed_sequence() ensures that the output has a fixed shape based on the maximum sequence length.

import torch
import torch_neuron

class Network(torch.nn.Module):

    def __init__(self):
        super().__init__()
        self.lstm = torch.nn.LSTM(input_size=3, hidden_size=7)

    def forward(self, inputs, lengths):
        packed_input = torch.nn.utils.rnn.pack_padded_sequence(
            inputs,
            lengths=lengths,
            enforce_sorted=False,
        )
        packed_result, (ht, ct) = self.lstm(packed_input)
        padded_result, _ = torch.nn.utils.rnn.pad_packed_sequence(packed_result)
        return padded_result, ht, ct

# Example Inputs
seq_len, batch_size, input_size = 5, 2, 3
inputs = torch.rand(seq_len, batch_size, input_size)
lengths = torch.tensor([seq_len] * batch_size)

# Trace
trace = torch_neuron.trace(Network(), (inputs, lengths))

Packed Inputs, Final Hidden & Cell State Only#

When only the final LSTM hidden and cell states are used, it does not matter whether the inputs are packed or unpacked, since these state tensors do not vary in size.

import torch
import torch_neuron

class Network(torch.nn.Module):

    def __init__(self):
        super().__init__()
        self.lstm = torch.nn.LSTM(input_size=3, hidden_size=7)

    def forward(self, inputs, lengths):
        packed_input = torch.nn.utils.rnn.pack_padded_sequence(
            inputs,
            lengths=lengths,
            enforce_sorted=True,
        )
        packed_output, (ht, ct) = self.lstm(packed_input)
        return ht, ct

# Example Inputs
seq_len, batch_size, input_size = 5, 2, 3
inputs = torch.rand(seq_len, batch_size, input_size)
lengths = torch.tensor([seq_len] * batch_size)

# Trace
trace = torch_neuron.trace(Network(), (inputs, lengths))

Note that because packed_output is unused, it does not need to be passed to pad_packed_sequence() for the LSTM to be compiled.

Unsupported Usage#

Neuron does not support the use of a PackedSequence outside of the LSTM operation and the pad_packed_sequence() operation. This is because the shape of a PackedSequence can vary depending on the input data. This is incompatible with the Neuron restriction that all tensor sizes must be known at compilation time. When a PackedSequence is used only by an LSTM or pad_packed_sequence() operation, Neuron can guarantee the size of the intermediary tensors by padding on behalf of the application.

This means that if the PackedSequence is used by a different operation or returned from the network, either all of the LSTM operations will be executed on CPU or the network compilation will fail.

PackedSequence Returned#

The following is unsupported because the PackedSequence result of the LSTM is returned by the network:

import torch
import torch_neuron

class Network(torch.nn.Module):

    def __init__(self):
        super().__init__()
        self.lstm = torch.nn.LSTM(input_size=3, hidden_size=7)

    def forward(self, inputs, lengths):
        packed_input = torch.nn.utils.rnn.pack_padded_sequence(
            inputs,
            lengths=lengths,
            enforce_sorted=False,
        )
        packed_result, (ht, ct) = self.lstm(packed_input)
        return packed_result.data, ht, ct

Behavior: In this case, compilation fails and the following warning is generated:

Operator "aten::lstm" consuming a PackedSequence input can only be supported when its corresponding PackedSequence output is unused or unpacked using "aten::_pad_packed_input". Found usage by "prim::Return"

Resolution: To avoid this error, the packed_result should be padded using pad_packed_sequence() prior to being returned from the network.
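
For example, the forward() method above could be rewritten as follows (a sketch of one possible fix, mirroring the supported examples earlier in this guide):

    def forward(self, inputs, lengths):
        packed_input = torch.nn.utils.rnn.pack_padded_sequence(
            inputs,
            lengths=lengths,
            enforce_sorted=False,
        )
        packed_result, (ht, ct) = self.lstm(packed_input)
        # Unpack to a fixed-shape padded tensor so that no PackedSequence
        # escapes the network.
        padded_result, _ = torch.nn.utils.rnn.pad_packed_sequence(packed_result)
        return padded_result, ht, ct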

Invalid PackedSequence Usage#

The following is unsupported because the PackedSequence result of the LSTM is used by a non-LSTM operator:

import torch
import torch_neuron

class Network(torch.nn.Module):

    def __init__(self):
        super().__init__()
        self.lstm = torch.nn.LSTM(input_size=3, hidden_size=7)

    def forward(self, inputs, lengths):
        packed_input = torch.nn.utils.rnn.pack_padded_sequence(
            inputs,
            lengths=lengths,
            enforce_sorted=False,
        )
        packed_result, (ht, ct) = self.lstm(packed_input)
        return torch.max(packed_result.data)

Behavior: In this case, compilation fails and the following warning is generated:

Operator "aten::lstm" consuming a PackedSequence input can only be supported when its corresponding PackedSequence output is unused or unpacked using "aten::_pad_packed_input". Found usage by "aten::max"

Resolution: To avoid this error, the packed_result should be padded using pad_packed_sequence() prior to being passed to max().
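
For example, the forward() method above could be rewritten as follows. This is a sketch of one possible fix; padding_value=float('-inf') is used here so that the padded positions do not affect the max reduction:

    def forward(self, inputs, lengths):
        packed_input = torch.nn.utils.rnn.pack_padded_sequence(
            inputs,
            lengths=lengths,
            enforce_sorted=False,
        )
        packed_result, (ht, ct) = self.lstm(packed_input)
        # Unpack before applying the reduction. The padding value is set to
        # -inf so that padded positions cannot become the maximum.
        padded_result, _ = torch.nn.utils.rnn.pad_packed_sequence(
            packed_result,
            padding_value=float('-inf'),
        )
        return torch.max(padded_result)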

This document is relevant for: Inf1