diff --git a/torchrec/distributed/tests/test_pt2.py b/torchrec/distributed/tests/test_pt2.py
index 62eae7568..ec35fc0c4 100644
--- a/torchrec/distributed/tests/test_pt2.py
+++ b/torchrec/distributed/tests/test_pt2.py
@@ -178,7 +178,42 @@ def forward(self, kjt: KeyedJaggedTensor):
             test_pt2_ir_export=True,
         )
 
+    def test_kjt_offset_per_key(self) -> None:
+        class M(torch.nn.Module):
+            def forward(self, kjt: KeyedJaggedTensor):
+                return kjt.offset_per_key()
+
+        kjt: KeyedJaggedTensor = make_kjt([2, 3, 4, 5, 6], [1, 2, 1, 1])
+
+        self._test_kjt_input_module(
+            M(),
+            kjt.keys(),
+            (kjt._values, kjt._lengths),
+            test_aot_inductor=False,
+            test_pt2_ir_export=True,
+        )
+
+    # pyre-ignore
+    def test_kjt__getitem__(self) -> None:
+        class M(torch.nn.Module):
+            def forward(self, kjt: KeyedJaggedTensor):
+                out0 = kjt["key0"]
+                out1 = kjt["key1"]
+
+                return out0, out1
+
+        kjt: KeyedJaggedTensor = make_kjt([2, 3, 4, 5, 6], [1, 2, 1, 1])
+
+        self._test_kjt_input_module(
+            M(),
+            kjt.keys(),
+            (kjt._values, kjt._lengths),
+            test_dynamo=False,
+            test_aot_inductor=False,
+            test_pt2_ir_export=True,
+        )
+
     # pyre-ignores
     @unittest.skipIf(
         torch.cuda.device_count() <= 1,
         "Not enough GPUs available",
diff --git a/torchrec/sparse/jagged_tensor.py b/torchrec/sparse/jagged_tensor.py
index cb1710be2..f7d817c51 100644
--- a/torchrec/sparse/jagged_tensor.py
+++ b/torchrec/sparse/jagged_tensor.py
@@ -244,6 +244,10 @@ def _permute_tensor_by_segments(
     return permuted_tensor, permuted_weights
 
 
+def is_non_strict_exporting() -> bool:
+    return not torch.compiler.is_dynamo_compiling() and torch.compiler.is_compiling()
+
+
 class JaggedTensorMeta(abc.ABCMeta, torch.fx._symbolic_trace.ProxyableClassMeta):
     pass
 
@@ -822,9 +826,48 @@ def _maybe_compute_offset_per_key(
             offsets=offsets,
             values=values,
         )
-        return _length_per_key, _cumsum(_length_per_key)
+
+        if is_non_strict_exporting():
+            # only torch.export non-strict case
+            return (
+                _length_per_key,
+                (
+                    torch.ops.fbgemm.asynchronous_complete_cumsum(
+                        torch._refs.tensor(
+                            _length_per_key,
+                            dtype=torch.int32,
+                            device=torch.device("cpu"),
+                            pin_memory=False,
+                            requires_grad=False,
+                        )
+                    ).tolist()
+                    if len(_length_per_key) > 0
+                    else []
+                ),
+            )
+        else:
+            return _length_per_key, _cumsum(_length_per_key)
     elif offset_per_key is None:
-        return length_per_key, _cumsum(length_per_key)
+        if is_non_strict_exporting():
+            # only torch.export non-strict case
+            return (
+                length_per_key,
+                (
+                    torch.ops.fbgemm.asynchronous_complete_cumsum(
+                        torch._refs.tensor(
+                            length_per_key,
+                            dtype=torch.int32,
+                            device=torch.device("cpu"),
+                            pin_memory=False,
+                            requires_grad=False,
+                        )
+                    ).tolist()
+                    if len(length_per_key) > 0
+                    else []
+                ),
+            )
+        else:
+            return length_per_key, _cumsum(length_per_key)
     else:
         return length_per_key, offset_per_key
 
@@ -1825,6 +1868,7 @@ def flatten_lengths(self) -> "KeyedJaggedTensor":
 
     def __getitem__(self, key: str) -> JaggedTensor:
         offset_per_key = self.offset_per_key()
+        length_per_key = self.length_per_key()
         index = self._key_indices()[key]
         start_offset = offset_per_key[index]
         end_offset = (
@@ -1832,20 +1876,60 @@ def __getitem__(self, key: str) -> JaggedTensor:
             if index + 1 < len(offset_per_key)
             else start_offset
         )
-        return JaggedTensor(
-            values=self._values[start_offset:end_offset],
-            weights=(
-                None
-                if self.weights_or_none() is None
-                else self.weights()[start_offset:end_offset]
-            ),
-            lengths=self.lengths()[
-                self.lengths_offset_per_key()[index] : self.lengths_offset_per_key()[
-                    index + 1
-                ]
-            ],
-            offsets=None,
-        )
+
+        if is_non_strict_exporting():
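+            # Non-strict export path: plain slicing with data-dependent
+            # bounds trips export's guards, so use torch.narrow() plus the
+            # torch._check calls below to bound the unbacked sizes.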
+            _lengths = torch.narrow(
+                self.lengths(),
+                0,
+                self.lengths_offset_per_key()[index],
+                self.lengths_offset_per_key()[index + 1]
+                - self.lengths_offset_per_key()[index],
+            )
+            sz = length_per_key[index]
+
+            torch._check_is_size(start_offset)
+            torch._check_is_size(sz)
+            torch._check(start_offset <= self.values().size(0))
+            torch._check(sz <= self.values().size(0))
+
+            return JaggedTensor(
+                values=torch.narrow(
+                    self.values(),
+                    0,
+                    start_offset,
+                    sz,
+                ),
+                weights=(
+                    None
+                    if self.weights_or_none() is None
+                    else torch.narrow(
+                        self.weights(),
+                        0,
+                        start_offset,
+                        sz,
+                    )
+                ),
+                lengths=_lengths,
+                offsets=None,
+            )
+        else:
+            return JaggedTensor(
+                values=self._values[start_offset:end_offset],
+                weights=(
+                    None
+                    if self.weights_or_none() is None
+                    else self.weights()[start_offset:end_offset]
+                ),
+                lengths=self.lengths()[
+                    self.lengths_offset_per_key()[
+                        index
+                    ] : self.lengths_offset_per_key()[index + 1]
+                ],
+                offsets=None,
+            )
 
     def to_dict(self) -> Dict[str, JaggedTensor]:
         _jt_dict = _maybe_compute_kjt_to_jt_dict(
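Illustrative usage (not part of the patch): a minimal sketch of how the non-strict export branches above get exercised. The module M, the literal key names, and the values/lengths (mirroring make_kjt([2, 3, 4, 5, 6], [1, 2, 1, 1]) from the new tests) are assumptions, as is a PyTorch build where torch.export.export accepts strict=False.

import torch
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

class M(torch.nn.Module):
    def forward(self, kjt: KeyedJaggedTensor) -> torch.Tensor:
        # kjt["key0"] routes through KeyedJaggedTensor.__getitem__, which
        # takes the torch.narrow branch while non-strict exporting.
        return kjt["key0"].values()

kjt = KeyedJaggedTensor(
    keys=["key0", "key1", "key2", "key3"],
    values=torch.tensor([2, 3, 4, 5, 6]),
    lengths=torch.tensor([1, 2, 1, 1]),
)

# strict=False traces without TorchDynamo, so is_non_strict_exporting()
# (is_compiling() and not is_dynamo_compiling()) is True inside forward.
ep = torch.export.export(M(), (kjt,), strict=False)
print(ep.graph_module.code)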