[scan] Handle fn output aliasing an input (#8529)
tengyifei authored Jan 4, 2025
1 parent 0ca733b commit a2d3d4d
Showing 2 changed files with 63 additions and 2 deletions.
47 changes: 47 additions & 0 deletions test/scan/test_scan.py
@@ -220,6 +220,53 @@ def fn(carry, x):
                      device=self.device)
    self.run_test(fn, init, xs)

  def test_scan_input_output_aliases_carry(self):
    """
    Test scan still works when a fn output aliases its carry input.
    """

    def fn(carry, x):
      return carry, x + 1

    init = torch.tensor([0.0, 0.0], requires_grad=True, device=self.device)
    xs = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]],
                      requires_grad=True,
                      device=self.device)
    self.run_test(fn, init, xs)

  def test_scan_input_output_aliases_x(self):
    """
    Test scan still works when a fn output aliases its x input.
    """

    def fn(carry, x):
      return carry + 1, x

    init = torch.tensor([0.0, 0.0], requires_grad=True, device=self.device)
    xs = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]],
                      requires_grad=True,
                      device=self.device)
    self.run_test(fn, init, xs)

  def test_scan_input_in_place_mutation(self):
    """
    Test that `fn` cannot mutate its input. The semantics of in-place
    mutation inside a `scan` are unclear, so it is disallowed.
    """

    def fn(carry, x):
      carry.add_(x)
      y = x.clone()
      y.add_(42)
      return carry, y

    init = torch.tensor([0.0, 0.0], requires_grad=True, device=self.device)
    xs = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]],
                      requires_grad=True,
                      device=self.device)
    with self.assertRaisesRegex(RuntimeError, 'in-place operation'):
      self.run_test(fn, init, xs)

  def test_scan_external_in_place_mutation(self):
    """
    Test that external in-place mutations raise an exception instead of silently
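For context, here is a minimal sketch of what the new aliasing tests exercise from the caller's side. It assumes the public scan(fn, init, xs) entry point from torch_xla/experimental/scan.py and an XLA device; the tensor values and device setup are illustrative, not taken from the tests above.

    import torch
    import torch_xla.core.xla_model as xm
    from torch_xla.experimental.scan import scan

    device = xm.xla_device()

    # fn's first output aliases its carry input -- the case this commit fixes.
    def fn(carry, x):
      return carry, x + 1

    init = torch.zeros(2, device=device)
    xs = torch.arange(6.0, device=device).reshape(3, 2)
    final_carry, ys = scan(fn, init, xs)  # ys stacks the per-step outputs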
18 changes: 16 additions & 2 deletions torch_xla/experimental/scan.py
@@ -219,11 +219,25 @@ def make_fake_tensor(v: torch.Tensor, requires_grad=True) -> torch.Tensor:
  fake_x_pytree = tree_map(
      lambda v: make_fake_tensor(v[0], requires_grad=v.requires_grad), xs)

  # If an output of `fn` aliases an input, `aot_function` will handle that
  # pair of variables with an epilogue inside its generated autograd.Function,
  # which we can't access. In other words, the captured graph won't contain
  # those variables. See the description in
  # https://github.com/pytorch/pytorch/issues/85036.
  #
  # Because we're abusing AOTAutograd to capture the graph, we need AOTAutograd
  # to handle all variables in the graph rather than in the opaque epilogue.
  # This wrapper clones an output whenever it aliases an input. This hack can
  # go away if scan is instead implemented as a Dynamo compiler pass.
  def fn_no_output_aliasing(*args):
    inputs = set(tree_iter(args))
    return tree_map(lambda v: v.clone() if v in inputs else v, fn(*args))

  with torch.enable_grad():
    fw_compiler, get_fwd = _make_get_graph_compiler()
    bw_compiler, get_bwd = _make_get_graph_compiler()
    fn_compiled = aot_function(
-        fn,
+        fn_no_output_aliasing,
        fw_compiler=fw_compiler,
        bw_compiler=bw_compiler,
        partition_fn=partition_fn)
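The cloning trick above can be demonstrated in plain PyTorch, independent of scan. A minimal, self-contained sketch; the wrapper name is illustrative, not part of scan.py:

    import torch
    from torch.utils._pytree import tree_map

    def no_output_aliasing(fn):

      def wrapper(*args):
        # Membership is identity-based: torch.Tensor hashes by object id and
        # set lookup tries `is` before `==`, so only the very same tensor
        # object counts as an alias. scan.py flattens pytrees with tree_iter.
        inputs = set(args)
        return tree_map(lambda v: v.clone() if v in inputs else v, fn(*args))

      return wrapper

    def fn(carry, x):
      return carry, x + 1  # the first output aliases the carry input

    carry, x = torch.zeros(2), torch.ones(2)
    new_carry, y = no_output_aliasing(fn)(carry, x)
    assert new_carry is not carry  # the alias was cloned away
    assert torch.equal(new_carry, carry)  # values are unchanged

With the wrapper in place, aot_function sees no aliased outputs, so every output tensor appears in the captured graph rather than in AOTAutograd's opaque epilogue.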
@@ -293,7 +307,7 @@ def forward(ctx, forward, backward, init, xs):
    return carry, ys

  @staticmethod
-  def backward(ctx, grad_carry, grad_ys):
+  def backward(ctx, grad_carry, grad_ys):  # type: ignore
    activations = ctx.saved_tensors
    backward = ctx._backward
    with torch.no_grad():
