Skip to content

Commit 7aa0e60

Browse files
authored
Aot Autograd - Dont fallback with new op (#1493)
* Aot Autograd - Dont fallback with new op * Bring back moco - Moco fails with einsum float16 error * CI
1 parent 9b69076 commit 7aa0e60

File tree

2 files changed

+1
-24
lines changed

2 files changed

+1
-24
lines changed

test/test_aot_autograd.py

Lines changed: 0 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -60,23 +60,6 @@ def fn(param, y):
6060
aot_fn(x, y)
6161
self.assertTrue(not is_safe[0])
6262

63-
def test_new(self):
64-
# https://github.com/pytorch/torchdynamo/issues/1448
65-
def fn(argsort: torch.Tensor):
66-
new = argsort.new(2, 12, 4096)
67-
x = torch.add(new, 2)
68-
return (
69-
new,
70-
x,
71-
)
72-
73-
x = torch.randn((2, 12, 4096))
74-
is_safe = [True]
75-
compiler_fn = functools.partial(compiler_safe_fn, is_safe=is_safe)
76-
aot_fn = torchdynamo.optimize(compiler_fn)(fn)
77-
aot_fn(x)
78-
self.assertTrue(not is_safe[0])
79-
8063
def test_negative_testing(self):
8164
def fn(x, y):
8265
return torch.sin(x).add_(y)

torchdynamo/optimizations/training.py

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -33,8 +33,7 @@ def is_aot_autograd_safe_to_run(gm, example_inputs):
3333
Issues
3434
1) LSTM - https://github.com/pytorch/torchdynamo/issues/1147
3535
2) LSTM - https://github.com/pytorch/functorch/issues/586
36-
3) New op - https://github.com/pytorch/torchdynamo/issues/1448
37-
4) Input mutation - https://github.com/pytorch/torchdynamo/issues/1301
36+
3) Input mutation - https://github.com/pytorch/torchdynamo/issues/1301
3837
"""
3938

4039
def raise_or_warn(reason):
@@ -52,11 +51,6 @@ def raise_or_warn(reason):
5251
if submod.__class__.__name__ == "LSTM":
5352
return raise_or_warn("LSTM")
5453

55-
# 2) new does not work with fake tensor and aot autograd
56-
for node in gm.graph.nodes:
57-
if node.op == "call_method" and node.target == "new":
58-
return raise_or_warn("new operator")
59-
6054
# 2) Mutation in the graph
6155
mutated = False
6256
try:

0 commit comments

Comments
 (0)