Motivation
Introduce NestedTensor as a means of converting and standardizing lists of Tensors of different sizes.
Value to the user
Construct a single data structure that any operator with a need for variably sized data can accept. Leave the construction and constraints of specialized layouts or metadata such as PackedSequence or int32 offsets to an operator-specific overload.
Use cases
torch.nn layers
After an audit of nn layers we found the following layers that either expect or optionally support some kind of representation of batches of variably sized data as input. Each of these representations has its own constraints, such as non-empty entries, being sorted by length, a choice of batch-first or batch-last, etc.
Recurrent layers
These layers are used together with torch.nn.utils.rnn, which is needed for optimized RNN kernels. They usually expect input of shape (sequence, batch, embedding). NB: the batch dimension being second is a performance artifact. Alternatively, a PackedSequence can be used to represent a collection of variably sized sequences of shapes (S1, E), (S2, E), ..., (SN, E).
PackedSequence: a specialized, layout-based data structure for CUDNN RNN functions that stores variable-length data in a compressed form. It is the closest thing we have to NestedTensor, but it is restricted in layout, specific to RNNs on GPUs, and awkward to use with DataParallel.
construction of PackedSequence:
pack_padded_sequence: returns a PackedSequence given a padded tensor plus lengths (which must be on the CPU). Lengths cannot be 0, they must be sorted for ONNX export, and batch-first or batch-last must be chosen for the result. example invocation.
pack_sequence: returns PackedSequence based on list of Tensors of shape L x *
decomposition of PackedSequence:
pad_packed_sequence: returns a padded tensor plus lengths given a PackedSequence. Particularly useful for the multi-GPU case. example invocation.
pad_sequence: returns padded tensor based on list of Tensors of shape L x * (list of Tensors means users can get lengths)
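For reference, a minimal sketch of how these four utilities fit together (shapes and values are arbitrary):

```python
import torch
from torch.nn.utils.rnn import (
    pack_sequence, pad_sequence, pack_padded_sequence, pad_packed_sequence
)

# Three variably sized sequences of shape (L_i, E), sorted by length.
seqs = [torch.randn(5, 8), torch.randn(3, 8), torch.randn(2, 8)]

# list of Tensors -> PackedSequence directly
packed = pack_sequence(seqs, enforce_sorted=True)

# list of Tensors -> padded Tensor (batch-first here); lengths are tracked separately
padded = pad_sequence(seqs, batch_first=True)        # shape (3, 5, 8)
lengths = torch.tensor([s.size(0) for s in seqs])    # lengths must live on the CPU

# padded Tensor + lengths -> PackedSequence
packed2 = pack_padded_sequence(padded, lengths, batch_first=True, enforce_sorted=True)

# PackedSequence -> padded Tensor + lengths (useful for the multi-GPU case)
unpacked, recovered_lengths = pad_packed_sequence(packed2, batch_first=True)
```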
EmbeddingBag
Input and offsets: a 1d data Tensor of concatenated vectors representing sequences, together with a 1d integer Tensor of offsets that marks the boundaries between them. EmbeddingBag here is a fusion of nn.Embedding and a reduction (specified via enum and can be sum, max or mean) across each sequence, which allows it to accept variable-length data and return a regular torch.Tensor. The offsets may match CSR if the last offset is included (optional flag). Supports empty sequences, i.e. repeated offsets, which will return zeros.
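A minimal sketch of the input-and-offsets convention:

```python
import torch
import torch.nn as nn

bag = nn.EmbeddingBag(num_embeddings=10, embedding_dim=4, mode="sum")

# Three "sequences" of token ids ([1, 2, 4], [4, 3], and an empty one),
# flattened into a single 1d tensor plus offsets marking the boundaries.
input = torch.tensor([1, 2, 4, 4, 3])
offsets = torch.tensor([0, 3, 5])  # a trailing/repeated offset denotes an empty sequence

out = bag(input, offsets)          # shape (3, 4); the empty bag's row is all zeros
```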
CrossEntropyLoss / NLLLoss / MultiLabelSoftMarginLoss / MultiLabelMarginLoss
Padding by specifying a target index value (often a universally set padding index); target entries equal to this index are ignored. The default value is -100. example initialization. The MarginLosses don't offer a choice of padding index and it is fixed to -1.
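A minimal sketch of the padding-index convention for CrossEntropyLoss:

```python
import torch
import torch.nn as nn

# Two target sequences of lengths 4 and 2, padded to length 4 with the padding index.
PAD_IDX = -100                                   # matches the default ignore_index
logits = torch.randn(2, 5, 4)                    # (batch, classes, sequence)
targets = torch.tensor([[1, 0, 3, 2],
                        [4, 2, PAD_IDX, PAD_IDX]])

loss_fn = nn.CrossEntropyLoss(ignore_index=PAD_IDX)
loss = loss_fn(logits, targets)                  # padded positions do not contribute
```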
CTCLoss
Padded input and target data together with input_lengths and target_lengths, similar to what pad_sequence returns. example construction and example use site. target_lengths needs to be int32, blank set to 0 and entries less than 256 to use the efficient CUDNN kernels. Note also that the batch dimension has to be second, the same as for RNNs and Transformers.
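A minimal sketch (arbitrary values) of the lengths-based convention for CTCLoss:

```python
import torch
import torch.nn as nn

T, N, C = 12, 2, 20                               # input length, batch, classes (0 is the blank)
log_probs = torch.randn(T, N, C).log_softmax(2)   # note: the batch dimension is second

# Two targets of different lengths, padded into a single (N, S) tensor.
targets = torch.tensor([[3, 5, 5, 7],
                        [2, 9, 0, 0]])            # the second target uses only 2 entries
input_lengths = torch.full((N,), T, dtype=torch.int32)
target_lengths = torch.tensor([4, 2], dtype=torch.int32)

ctc = nn.CTCLoss(blank=0)
loss = ctc(log_probs, targets, input_lengths, target_lengths)
```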
Transformer Layers and MultiheadAttention
Expect padded input data, plus a similarly shaped mask where 1 or True means "ignore the corresponding element". example construction of inputs. The batch dimension is second to match the RNN API.
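A minimal sketch of the mask convention for MultiheadAttention:

```python
import torch
import torch.nn as nn

S, N, E = 5, 2, 16                      # sequence, batch, embedding (batch is second)
mha = nn.MultiheadAttention(embed_dim=E, num_heads=4)

x = torch.randn(S, N, E)                # padded input
# True marks padded positions that attention should ignore; the second
# sequence only has 3 real elements.
key_padding_mask = torch.tensor([[False, False, False, False, False],
                                 [False, False, False, True,  True]])

out, attn_weights = mha(x, x, x, key_padding_mask=key_padding_mask)
```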
Collate functions in the wild
We often see users write one-off implementations to represent a list of Tensors as a single Tensor plus a Tensor mask or similar representations. This can happen via an explicit, separate factory function, as part of a collate function in the context of DataLoader, or built right into the Dataset. Tensor Cores further complicate this story due to their multiple-of-X constraints. Further, users might also implement operations such as merging batches across the time dimension, which can be quite tricky to get right. A sketch of a typical collate function follows the examples below.
Factory function examples
fairseq raw_audio_dataset: padding and masking
detr NestedTensor: padding and masking - integration as of 20210430
torchvision detection: padding and masking
torchvision ImageList: list of Tensors
parlai padded_tensor: padding and masking with special support for multiple of 8 constraint
Collate function example
Conversion examples
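To illustrate the kind of one-off code this section refers to, a minimal pad-and-mask collate sketch (not taken from any of the libraries above; the function name is made up):

```python
import torch

def pad_and_mask_collate(batch):
    # batch: list of Tensors of shape (L_i, E) with varying L_i
    max_len = max(t.size(0) for t in batch)
    emb = batch[0].size(1)
    padded = batch[0].new_zeros(len(batch), max_len, emb)
    # True marks padding, mirroring the Transformer key_padding_mask convention.
    mask = torch.ones(len(batch), max_len, dtype=torch.bool)
    for i, t in enumerate(batch):
        padded[i, : t.size(0)] = t
        mask[i, : t.size(0)] = False
    return padded, mask

# Typical use: DataLoader(dataset, batch_size=4, collate_fn=pad_and_mask_collate)
```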
Proposal
We currently don't provide a canonical implementation that allows users to standardize on a single approach to this. It's easy to make mistakes, and there can be subtle bugs when dealing with, for example, empty Tensor entries. Instead, operators could accept a specialized data structure. This is also a vehicle to standardize some of these layouts and prevent minor differences such as a different default padding index or a slightly different offset convention.
API
See the colab notebook for an example implementation, supported layouts, the API, API invocations, example resolutions and a discussion of design decisions.
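To make the shape of the proposal concrete, a hypothetical usage sketch. The constructor and conversion names below (nested_tensor, to_padded_tensor) are assumptions modeled on the pytorch/nestedtensor prototype; the actual API is defined in the notebook.

```python
import torch
import nestedtensor  # the pytorch/nestedtensor prototype package

# One structure for a list of variably sized Tensors, instead of choosing
# pad + mask, offsets or PackedSequence per operator.
sequences = [torch.randn(5, 300), torch.randn(3, 300), torch.randn(7, 300)]
nt = nestedtensor.nested_tensor(sequences)

# Operators (or operator-specific overloads) could convert internally to
# whatever specialized layout they need, e.g. a padded Tensor plus lengths.
padded = nt.to_padded_tensor(padding=0.0)  # assumed conversion helper
```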
Concrete next steps
Add additional Layout support to support the operators above
Ship a Python-only version of NestedTensor that works against PyTorch 1.6+ (based on target library constraints) with the requested features
Send integration PRs to select libraries for prototype feedback
Iterate on feedback
Alternatives
For the nn use case we could standardize some conventions. For example, all loss functions should accept a padding index and have the same default padding index. target_lengths for CTCLoss could be automatically converted to int32, etc., to run optimized CUDNN kernels. The above audit uncovered a small list of quality-of-life improvements.
For the collate use case we could provide some basic helper functions that do simple things such as creating a padded or packed tensor from a list of tensors. Effectively this skips NestedTensor as an intermediate step and provides conversion functions to/from these one-off layouts. At a minimum we should have a tutorial that provides a sanctioned approach to padding and masking instead of letting users come up with it over and over again.
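A sketch of what such a basic helper could look like (the name pad_with_lengths is hypothetical):

```python
import torch
from torch.nn.utils.rnn import pad_sequence

def pad_with_lengths(tensors, batch_first=True, padding_value=0.0):
    """Return a padded Tensor plus per-entry lengths for a list of Tensors of shape (L_i, *)."""
    lengths = torch.tensor([t.size(0) for t in tensors])
    padded = pad_sequence(tensors, batch_first=batch_first, padding_value=padding_value)
    return padded, lengths
```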
Otherwise, in the longer run, NestedTensor is meant to be a representation for lists of variably sized Tensors. The nestedtensor project is meant to provide a performant and easier-to-use replacement for this user-written dynamic shape support. However, as part of the nestedtensor project, most of the code required for the above API already exists, and there's an opportunity for useful upstreaming while NestedTensor is making its way into beta and maybe eventually into core.
NB
We currently have various operations that seem relevant in semantics or intent:
torch.nn.functional.pad
torch.nn.[Constant, Reflection, Replication, Zero]Pad[1d, 2d, 3d]
All are intended to be used in the context of padding for convolutions, which is an entirely different concept. Padding also notably shows up as an argument to convolution and pooling operators, where it likewise refers to image boundary semantics and not dynamic shape support.
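For contrast, a minimal example of that boundary-padding meaning:

```python
import torch
import torch.nn.functional as F

x = torch.arange(6.0).reshape(1, 1, 2, 3)                # (N, C, H, W)
# Adds a one-pixel zero border around the image; this has nothing to do with
# representing variably sized entries within a batch.
y = F.pad(x, (1, 1, 1, 1), mode="constant", value=0.0)   # shape (1, 1, 4, 5)
```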