Skip to content

Commit 97573e7

Browse files
Justin Yangmeta-codesync[bot]
authored andcommitted
Back out "remove guard_size_oblivious from torchrec jagged tensors." (#3455)
Summary: Pull Request resolved: #3455 Original commit changeset: dfcabef88e59 Original Phabricator Diff: D83885644 Reviewed By: peking2, miasantreble, ianbelcher Differential Revision: D84543362 fbshipit-source-id: 0a863f5354d44a752a88f82c47b8a22bbb129e23
1 parent ed318b5 commit 97573e7

File tree

2 files changed

+7
-15
lines changed

2 files changed

+7
-15
lines changed

torchrec/pt2/checks.py

Lines changed: 3 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111

1212
import torch
1313

14-
from torch.fx.experimental.symbolic_shapes import guard_or_false, guard_or_true
14+
from torch.fx.experimental.symbolic_shapes import guard_size_oblivious
1515

1616
USE_TORCHDYNAMO_COMPILING_PATH: bool = False
1717

@@ -91,15 +91,8 @@ def pt2_check_size_nonzero(x: torch.Tensor) -> torch.Tensor:
9191
return x
9292

9393

94-
def pt2_guard_or_false(x: bool) -> bool:
94+
def pt2_guard_size_oblivious(x: bool) -> bool:
9595
if torch.jit.is_scripting() or not is_pt2_compiling():
9696
return x
9797

98-
return guard_or_false(x)
99-
100-
101-
def pt2_guard_or_true(x: bool) -> bool:
102-
if torch.jit.is_scripting() or not is_pt2_compiling():
103-
return x
104-
105-
return guard_or_true(x)
98+
return guard_size_oblivious(x)

torchrec/sparse/jagged_tensor.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,7 @@
2525
pt2_check_size_nonzero,
2626
pt2_checks_all_is_size,
2727
pt2_checks_tensor_slice,
28-
pt2_guard_or_false,
29-
pt2_guard_or_true,
28+
pt2_guard_size_oblivious,
3029
)
3130
from torchrec.streamable import Pipelineable
3231

@@ -1072,7 +1071,7 @@ def _assert_tensor_has_no_elements_or_has_integers(
10721071
# TODO(ivankobzarev): Use guard_size_oblivious to pass tensor.numel() == 0 once it is torch scriptable.
10731072
return
10741073

1075-
assert pt2_guard_or_false(tensor.numel() == 0) or tensor.dtype in [
1074+
assert pt2_guard_size_oblivious(tensor.numel() == 0) or tensor.dtype in [
10761075
torch.long,
10771076
torch.int,
10781077
torch.short,
@@ -1207,7 +1206,7 @@ def _maybe_compute_length_per_key(
12071206
torch.sum(
12081207
pt2_check_size_nonzero(lengths.view(len(keys), stride)), dim=1
12091208
).tolist()
1210-
if pt2_guard_or_true(lengths.numel() != 0)
1209+
if pt2_guard_size_oblivious(lengths.numel() != 0)
12111210
else [0] * len(keys)
12121211
)
12131212
)
@@ -1426,7 +1425,7 @@ def _maybe_compute_kjt_to_jt_dict(
14261425
torch.ops.fbgemm.asynchronous_complete_cumsum(lengths)
14271426
for lengths in split_lengths
14281427
]
1429-
elif pt2_guard_or_true(lengths.numel() > 0):
1428+
elif pt2_guard_size_oblivious(lengths.numel() > 0):
14301429
strided_lengths = lengths.view(len(keys), stride)
14311430
if not torch.jit.is_scripting() and is_torchdynamo_compiling():
14321431
torch._check(strided_lengths.size(0) > 0)

0 commit comments

Comments
 (0)