diff --git a/torchrec/distributed/embeddingbag.py b/torchrec/distributed/embeddingbag.py index d2bf2594b..7d08babde 100644 --- a/torchrec/distributed/embeddingbag.py +++ b/torchrec/distributed/embeddingbag.py @@ -309,7 +309,9 @@ def construct_output_kt( ) -class VariableBatchEmbeddingBagCollectionAwaitable(LazyAwaitable[KeyedTensor]): +class VariableBatchEmbeddingBagCollectionAwaitable( + LazyGetItemMixin[str, torch.Tensor], LazyAwaitable[KeyedTensor] +): def __init__( self, awaitables: List[Awaitable[torch.Tensor]],