diff --git a/torchrec/sparse/jagged_tensor.py b/torchrec/sparse/jagged_tensor.py index b6b4b71ae..d8e08171b 100644 --- a/torchrec/sparse/jagged_tensor.py +++ b/torchrec/sparse/jagged_tensor.py @@ -1095,7 +1095,6 @@ def _maybe_compute_stride_kjt( lengths: Optional[torch.Tensor], offsets: Optional[torch.Tensor], stride_per_key_per_rank: Optional[torch.IntTensor], - inverse_indices: Optional[Tuple[List[str], torch.Tensor]] = None, ) -> int: if stride is None: if len(keys) == 0: @@ -1103,10 +1102,6 @@ def _maybe_compute_stride_kjt( elif ( stride_per_key_per_rank is not None and stride_per_key_per_rank.numel() > 0 ): - # For VBE KJT, batch size should be based on inverse_indices when set. - if inverse_indices is not None: - return inverse_indices[1].shape[-1] - s = stride_per_key_per_rank.sum(dim=1).max().item() if not torch.jit.is_scripting() and is_non_strict_exporting(): stride = torch.sym_int(s) @@ -2156,7 +2151,6 @@ def stride(self) -> int: self._lengths, self._offsets, self._stride_per_key_per_rank, - self._inverse_indices, ) self._stride = stride return stride diff --git a/torchrec/sparse/tests/test_keyed_jagged_tensor.py b/torchrec/sparse/tests/test_keyed_jagged_tensor.py index bac0a2c52..1636a06bd 100644 --- a/torchrec/sparse/tests/test_keyed_jagged_tensor.py +++ b/torchrec/sparse/tests/test_keyed_jagged_tensor.py @@ -1017,18 +1017,6 @@ def test_meta_device_compatibility(self) -> None: lengths=torch.tensor([], device=torch.device("meta")), ) - def test_vbe_kjt_stride(self) -> None: - inverse_indices = torch.tensor([[0, 1, 0], [0, 0, 0]]) - kjt = KeyedJaggedTensor( - keys=["f1", "f2", "f3"], - values=torch.tensor([5, 6, 7, 1, 2, 3, 0, 1]), - lengths=torch.tensor([3, 3, 2]), - stride_per_key_per_rank=[[2], [1]], - inverse_indices=(["f1", "f2"], inverse_indices), - ) - - self.assertEqual(kjt.stride(), inverse_indices.shape[-1]) - class TestKeyedJaggedTensorScripting(unittest.TestCase): def test_scriptable_forward(self) -> None: