File tree Expand file tree Collapse file tree 2 files changed +7
-15
lines changed Expand file tree Collapse file tree 2 files changed +7
-15
lines changed Original file line number Diff line number Diff line change 1111
1212import torch
1313
14- from torch .fx .experimental .symbolic_shapes import guard_or_false , guard_or_true
14+ from torch .fx .experimental .symbolic_shapes import guard_size_oblivious
1515
1616USE_TORCHDYNAMO_COMPILING_PATH : bool = False
1717
@@ -91,15 +91,8 @@ def pt2_check_size_nonzero(x: torch.Tensor) -> torch.Tensor:
9191 return x
9292
9393
94- def pt2_guard_or_false (x : bool ) -> bool :
94+ def pt2_guard_size_oblivious (x : bool ) -> bool :
9595 if torch .jit .is_scripting () or not is_pt2_compiling ():
9696 return x
9797
98- return guard_or_false (x )
99-
100-
101- def pt2_guard_or_true (x : bool ) -> bool :
102- if torch .jit .is_scripting () or not is_pt2_compiling ():
103- return x
104-
105- return guard_or_true (x )
98+ return guard_size_oblivious (x )
Original file line number Diff line number Diff line change 2525 pt2_check_size_nonzero ,
2626 pt2_checks_all_is_size ,
2727 pt2_checks_tensor_slice ,
28- pt2_guard_or_false ,
29- pt2_guard_or_true ,
28+ pt2_guard_size_oblivious ,
3029)
3130from torchrec .streamable import Pipelineable
3231
@@ -1072,7 +1071,7 @@ def _assert_tensor_has_no_elements_or_has_integers(
10721071 # TODO(ivankobzarev): Use guard_size_oblivious to pass tensor.numel() == 0 once it is torch scriptable.
10731072 return
10741073
1075- assert pt2_guard_or_false (tensor .numel () == 0 ) or tensor .dtype in [
1074+ assert pt2_guard_size_oblivious (tensor .numel () == 0 ) or tensor .dtype in [
10761075 torch .long ,
10771076 torch .int ,
10781077 torch .short ,
@@ -1207,7 +1206,7 @@ def _maybe_compute_length_per_key(
12071206 torch .sum (
12081207 pt2_check_size_nonzero (lengths .view (len (keys ), stride )), dim = 1
12091208 ).tolist ()
1210- if pt2_guard_or_true (lengths .numel () != 0 )
1209+ if pt2_guard_size_oblivious (lengths .numel () != 0 )
12111210 else [0 ] * len (keys )
12121211 )
12131212 )
@@ -1426,7 +1425,7 @@ def _maybe_compute_kjt_to_jt_dict(
14261425 torch .ops .fbgemm .asynchronous_complete_cumsum (lengths )
14271426 for lengths in split_lengths
14281427 ]
1429- elif pt2_guard_or_true (lengths .numel () > 0 ):
1428+ elif pt2_guard_size_oblivious (lengths .numel () > 0 ):
14301429 strided_lengths = lengths .view (len (keys ), stride )
14311430 if not torch .jit .is_scripting () and is_torchdynamo_compiling ():
14321431 torch ._check (strided_lengths .size (0 ) > 0 )
You can’t perform that action at this time.
0 commit comments