Skip to content
This repository was archived by the owner on Aug 1, 2025. It is now read-only.

Commit ddd5206

Browse files
committed
AOT Autograd - Don't fall back with new op
1 parent 4bb7125 commit ddd5206

File tree

3 files changed

+1
-25
lines changed

3 files changed

+1
-25
lines changed

benchmarks/common.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,6 @@ def set_model_name(name):
6868
"pytorch_struct",
6969
"speech_transformer",
7070
"vision_maskrcnn",
71-
"moco",
7271
# Huggingface
7372
"AlbertForMaskedLM", # OOM
7473
"AlbertForQuestionAnswering", # OOM

test/test_aot_autograd.py

Lines changed: 0 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -60,23 +60,6 @@ def fn(param, y):
6060
aot_fn(x, y)
6161
self.assertTrue(not is_safe[0])
6262

63-
def test_new(self):
64-
# https://github.com/pytorch/torchdynamo/issues/1448
65-
def fn(argsort: torch.Tensor):
66-
new = argsort.new(2, 12, 4096)
67-
x = torch.add(new, 2)
68-
return (
69-
new,
70-
x,
71-
)
72-
73-
x = torch.randn((2, 12, 4096))
74-
is_safe = [True]
75-
compiler_fn = functools.partial(compiler_safe_fn, is_safe=is_safe)
76-
aot_fn = torchdynamo.optimize(compiler_fn)(fn)
77-
aot_fn(x)
78-
self.assertTrue(not is_safe[0])
79-
8063
def test_negative_testing(self):
8164
def fn(x, y):
8265
return torch.sin(x).add_(y)

torchdynamo/optimizations/training.py

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,7 @@ def is_aot_autograd_safe_to_run(gm, example_inputs):
3434
Issues
3535
1) LSTM - https://github.com/pytorch/torchdynamo/issues/1147
3636
2) LSTM - https://github.com/pytorch/functorch/issues/586
37-
3) New op - https://github.com/pytorch/torchdynamo/issues/1448
38-
4) Input mutation - https://github.com/pytorch/torchdynamo/issues/1301
37+
3) Input mutation - https://github.com/pytorch/torchdynamo/issues/1301
3938
"""
4039

4140
def raise_or_warn(reason):
@@ -53,11 +52,6 @@ def raise_or_warn(reason):
5352
if submod.__class__.__name__ == "LSTM":
5453
return raise_or_warn("LSTM")
5554

56-
# 2) new does not work with fake tensor and aot autograd
57-
for node in gm.graph.nodes:
58-
if node.op == "call_method" and node.target == "new":
59-
return raise_or_warn("new operator")
60-
6155
# 2) Mutation in the graph
6256
mutated = False
6357
try:

0 commit comments

Comments
 (0)