diff --git a/torchrec/pt2/checks.py b/torchrec/pt2/checks.py index e997d39d2..76626a9f8 100644 --- a/torchrec/pt2/checks.py +++ b/torchrec/pt2/checks.py @@ -11,7 +11,7 @@ import torch -from torch.fx.experimental.symbolic_shapes import guard_or_false, guard_or_true +from torch.fx.experimental.symbolic_shapes import guard_size_oblivious USE_TORCHDYNAMO_COMPILING_PATH: bool = False @@ -91,15 +91,8 @@ def pt2_check_size_nonzero(x: torch.Tensor) -> torch.Tensor: return x -def pt2_guard_or_false(x: bool) -> bool: +def pt2_guard_size_oblivious(x: bool) -> bool: if torch.jit.is_scripting() or not is_pt2_compiling(): return x - return guard_or_false(x) - - -def pt2_guard_or_true(x: bool) -> bool: - if torch.jit.is_scripting() or not is_pt2_compiling(): - return x - - return guard_or_true(x) + return guard_size_oblivious(x) diff --git a/torchrec/sparse/jagged_tensor.py b/torchrec/sparse/jagged_tensor.py index 0c1905387..db1a26aba 100644 --- a/torchrec/sparse/jagged_tensor.py +++ b/torchrec/sparse/jagged_tensor.py @@ -25,8 +25,7 @@ pt2_check_size_nonzero, pt2_checks_all_is_size, pt2_checks_tensor_slice, - pt2_guard_or_false, - pt2_guard_or_true, + pt2_guard_size_oblivious, ) from torchrec.streamable import Pipelineable @@ -1072,7 +1071,7 @@ def _assert_tensor_has_no_elements_or_has_integers( # TODO(ivankobzarev): Use guard_size_oblivious to pass tensor.numel() == 0 once it is torch scriptable. return - assert pt2_guard_or_false(tensor.numel() == 0) or tensor.dtype in [ + assert pt2_guard_size_oblivious(tensor.numel() == 0) or tensor.dtype in [ torch.long, torch.int, torch.short, @@ -1207,7 +1206,7 @@ def _maybe_compute_length_per_key( torch.sum( pt2_check_size_nonzero(lengths.view(len(keys), stride)), dim=1 ).tolist() - if pt2_guard_or_true(lengths.numel() != 0) + if pt2_guard_size_oblivious(lengths.numel() != 0) else [0] * len(keys) ) ) @@ -1426,7 +1425,7 @@ def _maybe_compute_kjt_to_jt_dict( torch.ops.fbgemm.asynchronous_complete_cumsum(lengths) for lengths in split_lengths ] - elif pt2_guard_or_true(lengths.numel() > 0): + elif pt2_guard_size_oblivious(lengths.numel() > 0): strided_lengths = lengths.view(len(keys), stride) if not torch.jit.is_scripting() and is_torchdynamo_compiling(): torch._check(strided_lengths.size(0) > 0)