diff --git a/torchrec/distributed/batched_embedding_kernel.py b/torchrec/distributed/batched_embedding_kernel.py
index e0a19547c..e22c829b0 100644
--- a/torchrec/distributed/batched_embedding_kernel.py
+++ b/torchrec/distributed/batched_embedding_kernel.py
@@ -776,7 +776,11 @@ def forward(self, features: KeyedJaggedTensor) -> torch.Tensor:
         if weights is not None and not torch.is_floating_point(weights):
             weights = None
         if features.variable_stride_per_key() and isinstance(
-            self.emb_module, SplitTableBatchedEmbeddingBagsCodegen
+            self.emb_module,
+            (
+                SplitTableBatchedEmbeddingBagsCodegen,
+                DenseTableBatchedEmbeddingBagsCodegen,
+            ),
         ):
             return self.emb_module(
                 indices=features.values().long(),
diff --git a/torchrec/distributed/sharding/dp_sharding.py b/torchrec/distributed/sharding/dp_sharding.py
index 020beb012..6ffb52e4c 100644
--- a/torchrec/distributed/sharding/dp_sharding.py
+++ b/torchrec/distributed/sharding/dp_sharding.py
@@ -153,10 +153,6 @@ def forward(
             Awaitable[Awaitable[SparseFeatures]]: awaitable of awaitable of SparseFeatures.
         """
-        if sparse_features.variable_stride_per_key():
-            raise ValueError(
-                "Dense TBE kernel does not support variable batch per feature"
-            )
         return NoWait(cast(Awaitable[KeyedJaggedTensor], NoWait(sparse_features)))
diff --git a/torchrec/distributed/test_utils/test_model_parallel.py b/torchrec/distributed/test_utils/test_model_parallel.py
index de4923207..c6d81b8c2 100644
--- a/torchrec/distributed/test_utils/test_model_parallel.py
+++ b/torchrec/distributed/test_utils/test_model_parallel.py
@@ -245,15 +245,11 @@ def test_sharding_rw(
                 SharderType.EMBEDDING_BAG_COLLECTION.value,
             ]
         ),
-        kernel_type=st.sampled_from(
-            [
-                EmbeddingComputeKernel.DENSE.value,
-            ],
-        ),
-        apply_optimizer_in_backward_config=st.sampled_from([None]),
+        kernel_type=st.just(EmbeddingComputeKernel.DENSE.value),
+        apply_optimizer_in_backward_config=st.just(None),
         # TODO - need to enable optimizer overlapped behavior for data_parallel tables
     )
-    @settings(verbosity=Verbosity.verbose, max_examples=3, deadline=None)
+    @settings(verbosity=Verbosity.verbose, max_examples=1, deadline=None)
     def test_sharding_dp(
         self,
         sharder_type: str,
@@ -591,12 +587,13 @@ def test_sharding_twrw(
                 ShardingType.TABLE_WISE.value,
                 ShardingType.COLUMN_WISE.value,
                 ShardingType.ROW_WISE.value,
+                ShardingType.DATA_PARALLEL.value,
             ]
         ),
         global_constant_batch=st.booleans(),
         pooling=st.sampled_from([PoolingType.SUM, PoolingType.MEAN]),
     )
-    @settings(verbosity=Verbosity.verbose, max_examples=6, deadline=None)
+    @settings(verbosity=Verbosity.verbose, max_examples=10, deadline=None)
     def test_sharding_variable_batch(
         self,
         sharding_type: str,
@@ -608,13 +605,18 @@ def test_sharding_variable_batch(
             self.skipTest(
                 "bounds_check_indices on CPU does not support variable length (batch size)"
            )
+        kernel = (
+            EmbeddingComputeKernel.DENSE.value
+            if sharding_type == ShardingType.DATA_PARALLEL.value
+            else EmbeddingComputeKernel.FUSED.value
+        )
         self._test_sharding(
             # pyre-ignore[6]
             sharders=[
                 create_test_sharder(
                     sharder_type=SharderType.EMBEDDING_BAG_COLLECTION.value,
                     sharding_type=sharding_type,
-                    kernel_type=EmbeddingComputeKernel.FUSED.value,
+                    kernel_type=kernel,
                     device=self.device,
                 ),
             ],
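Context for reviewers: the kernel change lets a variable-batch-per-feature (variable stride per key) KeyedJaggedTensor reach DenseTableBatchedEmbeddingBagsCodegen the same way it already does for SplitTableBatchedEmbeddingBagsCodegen, and the dp_sharding.py guard that rejected such inputs for data-parallel tables is removed. Below is a minimal sketch (illustrative feature names and values, not taken from the patch) of the kind of input that now flows through the DENSE / data_parallel path:

```python
import torch
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

# Two features with different batch sizes on a single rank:
# "f1" has 3 bags, "f2" has 2 bags, so the stride varies per key.
features = KeyedJaggedTensor(
    keys=["f1", "f2"],
    values=torch.tensor([1, 2, 3, 4, 5, 6]),
    lengths=torch.tensor([1, 2, 1, 1, 1]),
    stride_per_key_per_rank=[[3], [2]],
)

# Before this patch, DpSparseFeaturesDist raised ValueError for such inputs
# when the DENSE (data_parallel) kernel was used; now they pass through.
assert features.variable_stride_per_key()
```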