diff --git a/torchrec/sparse/jagged_tensor.py b/torchrec/sparse/jagged_tensor.py index 0467754d7..721cd5486 100644 --- a/torchrec/sparse/jagged_tensor.py +++ b/torchrec/sparse/jagged_tensor.py @@ -818,9 +818,54 @@ def _maybe_compute_offset_per_key( offsets=offsets, values=values, ) - return _length_per_key, _cumsum(_length_per_key) + + if ( + not torch.jit.is_scripting() + and not is_torchdynamo_compiling() + and torch.compiler.is_compiling() + ): + return ( + _length_per_key, + ( + torch.ops.fbgemm.asynchronous_complete_cumsum( + torch._refs.tensor( + _length_per_key, + dtype=torch.int32, + device=torch.device("cpu"), + pin_memory=False, + requires_grad=False, + ) + ).tolist() + if len(_length_per_key) > 0 + else [] + ), + ) + else: + return _length_per_key, _cumsum(_length_per_key) elif offset_per_key is None: - return length_per_key, _cumsum(length_per_key) + if ( + not torch.jit.is_scripting() + and not is_torchdynamo_compiling() + and torch.compiler.is_compiling() + ): + return ( + length_per_key, + ( + torch.ops.fbgemm.asynchronous_complete_cumsum( + torch._refs.tensor( + length_per_key, + dtype=torch.int32, + device=torch.device("cpu"), + pin_memory=False, + requires_grad=False, + ) + ).tolist() + if len(length_per_key) > 0 + else [] + ), + ) + else: + return length_per_key, _cumsum(length_per_key) else: return length_per_key, offset_per_key