
Commit 3f920f3

IvanKobzarev authored and pytorchmergebot committed
[aotd] Support mutations of the same input in fw and bw (pytorch#155354)
Original issue: pytorch#154820

The issue occurs when the same input is mutated in both the forward AND the backward pass. AOTD emitted the copy_ after tracing the joint function, so that single fx node represented the side effects of both mutations (forward and backward). The partitioner was then free to place it in either the forward or the backward graph.

The fix:

1/ Introduce joint_function.handle, which allows setting a "post_forward" callback so the state of the inputs can be inspected after the forward pass. We do not want to apply a mutation after the joint trace if it was already applied in forward; for that we need the "mutation_counter" and must remember the counter value at which the forward mutation was applied.

2/ Expose mutation_counter to Python. We keep the invariant that copy_ nodes exist only at the end of the joint graph.

3/ Memorize the mutation_counter and the state of the inputs after forward via the post_forward handle, and emit the post-forward mutations once the joint graph is fully traced. These mutations get a "must_be_in_forward" tag (analogous to the existing "must_be_in_backward") so the partitioner keeps them in the forward graph.

4/ Ban recompute of the source of a mutation. Recompute could apply the same op (e.g. add) in both forward and backward, so the source of a forward mutation is marked MUST_SAVE.

proxy_tensor changes: by default, proxy tensor updates the tensor_tracker, which would chain the applied mutations. We want this copy_ to be independent and applied directly to the primals, so a context manager is introduced to disable tensor_tracker updates while adding the forward mutations.

Pull Request resolved: pytorch#155354
Approved by: https://github.com/bdhirsh
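The counter bookkeeping can be illustrated with a minimal sketch. This is not the actual AOTAutograd code; MutatedInput and plan_copy_nodes are hypothetical names introduced only for this illustration of the decision rule:

from dataclasses import dataclass


@dataclass
class MutatedInput:
    counter_start: int        # mutation counter when joint tracing began
    counter_after_fw: int     # snapshot taken by the post_forward callback
    counter_after_joint: int  # counter once the joint graph is fully traced


def plan_copy_nodes(inp):
    """Decide which copy_ nodes to emit for one input and how to tag them."""
    tags = []
    if inp.counter_after_fw > inp.counter_start:
        # The input was already mutated during forward: emit a dedicated copy_
        # tagged "must_be_in_forward" so the partitioner keeps it there.
        tags.append("copy_ (must_be_in_forward)")
    if inp.counter_after_joint > inp.counter_after_fw:
        # The remaining mutation happened in backward: the copy_ at the end of
        # the joint graph stays there, tagged "must_be_in_backward".
        tags.append("copy_ (must_be_in_backward)")
    return tags


# Input mutated once in forward and once in backward (the pytorch#154820 case):
print(plan_copy_nodes(MutatedInput(0, 1, 2)))
# ['copy_ (must_be_in_forward)', 'copy_ (must_be_in_backward)']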
1 parent c82a174 commit 3f920f3

File tree

10 files changed, +396 -106 lines changed


aten/src/ATen/FunctionalStorageImpl.h

Lines changed: 3 additions & 0 deletions

@@ -122,6 +122,9 @@ struct TORCH_API FunctionalStorageImpl : public c10::StorageImpl {
 
   ~FunctionalStorageImpl() override = default;
 
+  uint64_t mutation_counter() {
+    return mutation_counter_;
+  }
   void mark_mutation() {
     mutation_counter_++;
   }

aten/src/ATen/FunctionalTensorWrapper.h

Lines changed: 3 additions & 1 deletion

@@ -74,7 +74,9 @@ struct TORCH_API FunctionalTensorWrapper : public c10::TensorImpl {
   bool has_metadata_mutation() const {
     return has_metadata_mutation_;
   }
-
+  uint64_t mutation_counter() const {
+    return functional_storage_impl()->mutation_counter();
+  }
   void mark_mutation() {
     functional_storage_impl()->mark_mutation();
   }

benchmarks/dynamo/pr_time_benchmarks/expected_results.csv

Lines changed: 2 additions & 2 deletions

@@ -14,7 +14,7 @@ add_loop_inductor_dynamic_gpu,compile_time_instruction_count,39110000000,0.025
-add_loop_inductor_gpu,compile_time_instruction_count,26180000000,0.015
+add_loop_inductor_gpu,compile_time_instruction_count,25780000000,0.015

@@ -62,7 +62,7 @@ aotdispatcher_partitioner_cpu,compile_time_instruction_count,8844000000,0.015
-aotdispatcher_partitioner_cpu2,compile_time_instruction_count,1963000000,0.015
+aotdispatcher_partitioner_cpu2,compile_time_instruction_count,1917000000,0.015

test/functorch/test_aotdispatch.py

Lines changed: 47 additions & 0 deletions

@@ -7842,6 +7842,53 @@ def _inps():
         self.assertEqual(ref_inps_after_fw, inps_after_fw)
         self.assertEqual(ref_inps_after_bw, inps_after_bw)
 
+    def test_mutation_of_input_in_fw_and_bw(self):
+        class AF(torch.autograd.Function):
+            @staticmethod
+            def forward(ctx, dummy, inplace_tensor):
+                inplace_tensor.add_(1)
+
+                ctx.inplace_tensor = inplace_tensor
+                return dummy.clone()
+
+            @staticmethod
+            def backward(ctx, grad_output):
+                inplace_tensor = ctx.inplace_tensor
+                inplace_tensor.add_(1)
+                return grad_output, None, None
+
+        def fn(dummy, inplace_tensor):
+            return AF.apply(dummy, inplace_tensor)
+
+        def inps():
+            dummy = torch.randn((2,), requires_grad=True)
+            inplace_tensor = torch.zeros((2,), requires_grad=False)
+            return dummy, inplace_tensor
+
+        def sc_inps():
+            dummy = TwoTensor(
+                torch.randn((2,), requires_grad=True),
+                torch.randn((2,), requires_grad=True),
+            )
+            inplace_tensor = TwoTensor(
+                torch.zeros((2,), requires_grad=False),
+                torch.zeros((2,), requires_grad=False),
+            )
+            return dummy, inplace_tensor
+
+        for _inps in [inps, sc_inps]:
+            dummy, inplace = _inps()
+            y = fn(dummy, inplace)
+            ref0 = inplace.clone().detach()
+            y.sum().backward()
+            ref = inplace.clone().detach()
+
+            dummy, inplace = _inps()
+            y = torch.compile(fn, backend="aot_eager", fullgraph=True)(dummy, inplace)
+            self.assertEqual(ref0, inplace)
+            y.sum().backward()
+            self.assertEqual(ref, inplace)
+
 
 class MockFXGraphCache:
     """

tools/pyi/gen_pyi.py

Lines changed: 7 additions & 0 deletions

@@ -912,6 +912,13 @@ def gen_pyi(
             "None",
         )
     ],
+    "_functionalize_mutation_counter": [
+        defs(
+            "_functionalize_mutation_counter",
+            ["t: Tensor"],
+            "_int",
+        )
+    ],
     "_functionalize_are_all_mutations_hidden_from_autograd": [
         defs(
             "_functionalize_are_all_mutations_hidden_from_autograd",

torch/_functorch/_aot_autograd/dispatch_and_compile_graph.py

Lines changed: 2 additions & 0 deletions

@@ -265,13 +265,15 @@ def aot_dispatch_autograd_graph(
         fw_metadata,
     )
     joint_fn_to_trace = create_joint(fn_prepared_for_autograd, aot_config=aot_config)
+    joint_fn_handle = joint_fn_to_trace.handle
 
     joint_fn_to_trace, updated_joint_inputs = create_functionalized_fn(
         joint_fn_to_trace,
         joint_inputs,
         meta=fw_metadata,
         aot_config=aot_config,
         trace_joint=True,
+        joint_fn_handle=joint_fn_handle,
     )
 
     # TODO: replace with AOTDispatchSubclassWrapper once we refactor
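For illustration, a hypothetical sketch of what the handle enables. The real wiring lives in create_joint / create_functionalized_fn; the JointFnHandle class, make_joint_fn, fw and bw below are assumptions made for this sketch, not the actual API:

import torch


class JointFnHandle:
    def __init__(self):
        self.post_forward = None  # optional callback(primals) run after forward


def make_joint_fn(fw_fn, bw_fn, handle):
    def joint_fn(primals, tangents):
        fw_outs = fw_fn(*primals)
        if handle.post_forward is not None:
            # Lets the caller snapshot input state (e.g. functionalization
            # mutation counters) before the backward half is traced.
            handle.post_forward(primals)
        grads = bw_fn(fw_outs, tangents)
        return fw_outs, grads

    return joint_fn


handle = JointFnHandle()
snapshots = []
handle.post_forward = lambda primals: snapshots.append(
    [p.detach().clone() for p in primals]
)


def fw(x):
    x.add_(1)  # forward mutates its input in place
    return (x * 2,)


def bw(fw_outs, tangents):
    return tuple(2 * t for t in tangents)


joint = make_joint_fn(fw, bw, handle)
x = torch.zeros(2)
joint((x,), (torch.ones(2),))
print(snapshots[0][0])  # tensor([1., 1.]) -- input state right after forward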
