diff --git a/examples/decoder_only_model.py b/examples/decoder_only_model.py index 712423d79ad7..79040e5d24d2 100644 --- a/examples/decoder_only_model.py +++ b/examples/decoder_only_model.py @@ -7,16 +7,16 @@ from torch import nn -# the default config is intentionally kept low to make it runable on a sigle tpu v2-8 core. +# the default config is intentionally kept low to make it runnable on a single tpu v2-8 core. @dataclass class DecoderOnlyConfig: hidden_size: int = 1024 num_hidden_layers: int = 2 num_attention_heads: int = 8 num_key_value_heads: int = 4 - intermediate_size = 32 * 1024 - vocab_size = 3200 - use_flash_attention = False + intermediate_size: int = 32 * 1024 + vocab_size: int = 3200 + use_flash_attention: bool = False def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: diff --git a/test/run_tests.sh b/test/run_tests.sh index 0912d53ded5a..543bc5f84032 100755 --- a/test/run_tests.sh +++ b/test/run_tests.sh @@ -208,7 +208,8 @@ function run_xla_op_tests1 { function run_xla_op_tests2 { run_test "$CDIR/pjrt/test_dtypes.py" run_test "$CDIR/test_while_loop.py" - run_test "$CDIR/test_scan.py" + run_test "$CDIR/scan/test_scan.py" + run_test "$CDIR/scan/test_scan_layers.py" run_test "$CDIR/test_autocast.py" run_test "$CDIR/eager/test_eager.py" run_test "$CDIR/eager/test_eager_with_xla_compile.py" diff --git a/test/scan/test_scan.py b/test/scan/test_scan.py new file mode 100644 index 000000000000..f344386c8525 --- /dev/null +++ b/test/scan/test_scan.py @@ -0,0 +1,477 @@ +import sys +import os +import re +import unittest +from functools import reduce + +import torch +from functorch.compile import default_partition, min_cut_rematerialization_partition # type: ignore +from torch.utils._pytree import tree_map, tree_flatten, tree_iter, tree_leaves, PyTree + +import torch_xla +from torch_xla.experimental.scan import scan, value_and_grad_partitioned, tree_flatten_none + +parent_folder = os.path.dirname(os.path.dirname(__file__)) +sys.path.append(parent_folder) +from test_utils import XlaTestCase # type:ignore + + +def _loopy_scan(fn, init, xs): + """A simple scan implemented with for loops serving as reference + implementation.""" + carry = init + ys = [] + xs_len = len(next(iter(tree_iter(xs)))) + for i in range(xs_len): + carry, y = fn(carry, tree_map(lambda x: x[i], xs)) + ys.append(y) + + def none_stack(*ys): + if len(ys) == 0: + return None + if ys[0] is None: + assert all(y is None for y in ys) + return None + return torch.stack(ys) + + ys = tree_map(none_stack, *ys) + return carry, ys + + +class TestBase(XlaTestCase): + + def setUp(self): + super().setUp() + self.device = torch_xla.device() + + def compare_pytree(self, expected_pytree, actual_pytree): + flat_expected_pytree, expected_spec = tree_flatten(expected_pytree) + flat_actual_pytree, actual_spec = tree_flatten(actual_pytree) + assert expected_spec == actual_spec, f"{expected_spec} != {actual_spec}" + # If there are `None`, they must happen in the same location. + for expected, actual in zip(flat_expected_pytree, flat_actual_pytree): + assert (expected is None) == (actual is None), \ + f"Mismatched None. expected: {expected}, actual: {actual}" + # Get rid of `None` before passing to compareResults. 
+ flat_expected_pytree = [x for x in flat_expected_pytree if x is not None] + flat_actual_pytree = [x for x in flat_actual_pytree if x is not None] + super().compareResults(flat_expected_pytree, flat_actual_pytree) + + +class ScanTest(TestBase): + + def run_test(self, + fn, + init: PyTree, + xs: PyTree, + partition_fn=default_partition): + """Compares the result of scanning with `fn` with our optimized HLO implementation + against a for loop implementation. Checks both output values and gradients. + """ + squish = lambda t: reduce( + lambda a, b: a + b, + map(lambda v: v.sum() + if v is not None else 0, tree_leaves(t)), torch.tensor(0.0)) + dupe = lambda v: v.detach().clone().requires_grad_(v.requires_grad) + + # Actual output + init_scan = tree_map(dupe, init) + xs_scan = tree_map(dupe, xs) + final_carry, ys = scan(fn, init_scan, xs_scan, partition_fn=partition_fn) + # Add up all leaves and `backward()` once. + (squish(final_carry) + squish(ys)).backward() + torch_xla.sync() + + # Expected output + init_loop = tree_map(dupe, init) + xs_loop = tree_map(dupe, xs) + expected_final_carry, expected_ys = _loopy_scan(fn, init_loop, xs_loop) + # Add up all leaves and `backward()` once. + (squish(expected_final_carry) + squish(expected_ys)).backward() + torch_xla.sync() + + # Compare values + self.compare_pytree(expected_final_carry, final_carry) + self.compare_pytree(expected_ys, ys) + + # Compare gradients + self.compare_pytree( + tree_map(lambda v: v.grad, init_loop), + tree_map(lambda v: v.grad, init_scan)) + self.compare_pytree( + tree_map(lambda v: v.grad, xs_loop), tree_map(lambda v: v.grad, + xs_scan)) + + return final_carry, ys + + def test_scan_simple(self): + """This test uses `scan` to implement `torch.cumsum`.""" + + def step_fn(carry, x): + new_carry = carry + x + y = new_carry + return new_carry, y + + init = torch.tensor([0.0, 0.0], requires_grad=True, device=self.device) + xs = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], + requires_grad=True, + device=self.device) + final_carry, ys = self.run_test(step_fn, init, xs) + + # Also ensure that our loop-based scan is correct, with manual checks + # that replicate the step_fn. 
+ expected_final_carry = torch.sum(xs, dim=0) + init + expected_ys = torch.cumsum(xs, dim=0) + self.compare_pytree(expected_final_carry, final_carry) + self.compare_pytree(expected_ys, ys) + + def test_scan_fn_not_callable(self): + init = torch.tensor([1.0, 1.0], device=self.device) + xs = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], device=self.device) + with self.assertRaises(ValueError): + scan(1000, init, xs) # type: ignore + + def test_scan_incompatible_length(self): + init = torch.tensor([1.0, 1.0], device=self.device) + xs_1 = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], + device=self.device) + xs_2 = torch.tensor([[1.0, 2.0], [3.0, 4.0]], device=self.device) + with self.assertRaises(ValueError): + scan(lambda a, b: (a, b), init, (xs_1, xs_2)) + + def test_scan_tuples(self): + """Test scanning over the leading axis of a tuple of tensors simultaneously, + which is a simple PyTree.""" + + def fn(carry, x): + carry1, carry2 = carry + x1, x2 = x + new_carry1 = carry1 + x1.sum() + new_carry2 = carry2 + x2.sum() + y1 = x1 * 2 + torch.sum(new_carry1) + y2 = x2 * 2 + torch.sum(new_carry2) + return (new_carry1, new_carry2), (y1, y2) + + init = (torch.tensor([0.0], requires_grad=True, device=self.device), + torch.tensor([1.0, 2.0], requires_grad=True, device=self.device)) + + xs = (torch.tensor([[1.0, 2.0], [3.0, 4.0]], + requires_grad=True, + device=self.device), + torch.tensor([[5.0, 6.0, 7.0], [8.0, 9.0, 10.0]], + requires_grad=True, + device=self.device)) + + self.run_test(fn, init, xs) + + def test_scan_create_tensors(self): + """Test scanning over a function that internally creates tensors.""" + + def fn(carry, x): + a = torch.tensor([1.0, 2.0], device=self.device) + b = torch.tensor([3.0, 4.0], device=self.device) + return carry + a, x + b + + init = torch.tensor([0.0, 0.0], requires_grad=True, device=self.device) + xs = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], + requires_grad=True, + device=self.device) + self.run_test(fn, init, xs) + + def test_scan_internal_in_place_mutation(self): + """ + Test internal in-place mutations inside the `fn` to be scanned over. + """ + + def fn(carry, x): + carry = carry.clone() + carry.add_(x) + y = x.clone() + y.add_(42) + return carry, y + + init = torch.tensor([0.0, 0.0], requires_grad=True, device=self.device) + xs = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], + requires_grad=True, + device=self.device) + self.run_test(fn, init, xs) + + def test_scan_external_in_place_mutation(self): + """ + Test that external in-place mutations raise an exception instead of silently + giving wrong results. + """ + # TODO(yifeit): Modify this test when external in-place mutation is eventually supported. + weird_global = torch.tensor([0.0, 0.0], device=torch_xla.device()) + + def step_fn(carry, x): + new_carry = carry + x + weird_global.add_(1.0) + y = new_carry + weird_global + return new_carry, y + + init = torch.tensor([0.0, 0.0], device=torch_xla.device()) + xs = torch.tensor([[0.0, 0.0], [0.0, 0.0], [0.0, 0.0]], + device=torch_xla.device()) + + with self.assertRaisesRegex(AssertionError, "FakeTensor"): + scan(step_fn, init, xs) + + def test_scan_gradness(self): + """ + Test the gradient output of `scan` when various inputs require or doesn't + require gradients. 
+ """ + + def test_case(init_requires_grad: bool, xs_requires_grad: bool): + + def fn(carry, x): + new_carry = carry * x + y = new_carry + x + return new_carry, y + + init = torch.tensor([1.0, 1.0], + requires_grad=init_requires_grad, + device=self.device) + xs = torch.tensor([[2.0, 3.0], [4.0, 5.0], [6.0, 7.0]], + requires_grad=xs_requires_grad, + device=self.device) + self.run_test(fn, init, xs) + + test_case(True, True) + test_case(True, False) + test_case(False, True) + + def test_scan_output_none(self): + """ + Test scan when `fn` returns `None` as output. This case is exercised by + `scan_layers`, which only needs the carry. + """ + + def fn(carry, x): + return torch.cos(carry) + x, None + + init = torch.tensor([1.0, 1.0], requires_grad=True, device=self.device) + xs = torch.tensor([[2.0, 3.0], [4.0, 5.0], [6.0, 7.0]], + requires_grad=True, + device=self.device) + _final_carry, ys = self.run_test(fn, init, xs) + self.assertIsNone(ys) + + def test_scan_output_unit(self): + """ + Test scan when `fn` returns `()` as output. + """ + + def fn(carry, x): + return torch.cos(carry) + x, () + + init = torch.tensor([1.0, 1.0], requires_grad=True, device=self.device) + xs = torch.tensor([[2.0, 3.0], [4.0, 5.0], [6.0, 7.0]], + requires_grad=True, + device=self.device) + _final_carry, ys = self.run_test(fn, init, xs) + self.assertEqual(ys, ()) + + def test_scan_rand_in_fn(self): + """ + Test that the RNG state in each iteration of `fn` is not the same. + """ + + def step_fn(carry, x): + new_carry = carry + x + y = new_carry + torch.rand(2, device=torch_xla.device()) + return new_carry, y + + init = torch.tensor([0.0, 0.0], device=torch_xla.device()) + xs = torch.tensor([[0.0, 0.0], [0.0, 0.0], [0.0, 0.0]], + device=torch_xla.device()) + _, ys = scan(step_fn, init, xs) + # ys should be a 2D tensor with this shape. + self.assertEqual(ys.shape, (3, 2)) + # Values across the first dimension should not be the same. + self.assertNotEqual(ys[0][0], ys[1][0]) + self.assertNotEqual(ys[0][1], ys[1][1]) + + def test_scan_with_rematerialization(self): + """ + Test scanning `fn` but also the backward pass recomputes the forward. + """ + + def fn(carry, x): + for _ in range(10): + carry = torch.sin(carry) + for _ in range(10): + x = torch.sin(x) + return carry, x + + carry = torch.randn(4, 4, requires_grad=True, device=self.device) + xs = torch.randn(20, 4, 4, requires_grad=True, device=self.device) + + # Check the gradients and also cross-check with results from a run + # where we don't have activation checkpointing. + final_carry_remat, ys_remat = self.run_test( + fn, carry, xs, partition_fn=min_cut_rematerialization_partition) + final_carry, ys = self.run_test(fn, carry, xs) + super().compareResults(final_carry, final_carry_remat) + super().compareResults(ys, ys_remat) + torch_xla.sync() + + SINE_OP = re.compile(r" sine\(f32\b") + + def count_number_of_sines(partition_fn): + """ + Uses `partition_fn` to partition `fn` into forward and backward passes + while building the scan operation, then counts the number of `sine` HLO + operators in the joint graph. + + The intention is that if `partition_fn` recomputes some forward ops + during the backward, we'll see a larger number of `sine` operations since + `fn` consists of only `torch.sin` in this test. 
+ """ + own_carry = carry.clone().detach().requires_grad_() + own_xs = xs.clone().detach().requires_grad_() + final_carry, ys = scan(fn, own_carry, own_xs, partition_fn=partition_fn) + torch_xla.sync() + (torch.sum(final_carry) + torch.sum(ys)).backward() + assert own_carry.grad is not None + assert own_xs.grad is not None + text: str = torch_xla._XLAC._get_xla_tensors_hlo( + [own_carry.grad, own_xs.grad]) + return len(SINE_OP.findall(text)) + + # Check the HLO to verify that `sine(...)` recomputation happens in the backward + # in the version using `min_cut_rematerialization_partition`, and never happens + # in the default partition. + self.assertGreater( + count_number_of_sines(min_cut_rematerialization_partition), 10) + self.assertEqual(count_number_of_sines(default_partition), 0) + + +class PyTreeTest(TestBase): + + def test_tree_flatten_none(self): + pytree = ((1, 2), (None, 3), None) + flat, unflatten = tree_flatten_none(pytree) + assert tuple(flat) == (1, 2, 3) + assert unflatten(flat) == ((1, 2), (None, 3), None) + + +class ValueAndGradPartitionedTest(TestBase): + + def test_transform_linear_layer(self): + + def fn(carry, x): + new_carry = carry @ x + y = new_carry + return new_carry, y + + init = torch.tensor([[1.0, 2.0], [3.0, 4.0]], + requires_grad=True, + device=self.device) + xs = torch.tensor([[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]], + requires_grad=True, + device=self.device) + forward, backward = value_and_grad_partitioned(fn, init, xs) + + # Forward should return `(new_carry, (y, (carry, x)))`, + # because `(carry, x)` are the two intermediate activations (primals), + # and they will be packed alongside the original output `y`. + out = forward(init, xs[0]) + torch_xla.sync() + carry = init + x = xs[0] + new_carry = init @ x + y = new_carry + self.compare_pytree(out, (new_carry, (y, (carry, x)))) + + # Backward should take in `(grad_new_carry, (grad_y, (carry, x)))`, and + # return `(grad_carry, grad_x)`. `(carry, x)` are the two intermediate + # activations (primals), and they are packed alongside the gradient with + # respect to y, `grad_y`. + grad_new_carry = torch.ones_like(new_carry) + grad_y = torch.ones_like(y) + out = backward(grad_new_carry, (grad_y, (carry, x))) + torch_xla.sync() + grad_carry = (grad_new_carry + grad_y) @ x.T + grad_x = carry.T @ (grad_new_carry + grad_y) + self.compare_pytree(out, (grad_carry, grad_x)) + + def test_transform_non_trivial_pytree(self): + """ + `fn` simulates two linear layers operating on two values a and b. + Test that we can trace `fn` when it uses non-trivial pytree, and + compare gradients against those from torch.autograd. 
+ """ + + def fn(carry, x): + weights = x['weights'] + biases = x['biases'] + carry_a = carry['a'] + carry_b = carry['b'] + new_carry_a = torch.sin((carry_a @ weights) + biases) + new_carry_b = torch.cos((carry_b @ weights) + biases) + y = torch.sigmoid(new_carry_a + new_carry_b) + return {'a': new_carry_a, 'b': new_carry_b}, y + + init = { + 'a': torch.randn(2, 3, requires_grad=True, device=self.device), + 'b': torch.randn(2, 3, requires_grad=True, device=self.device) + } + x = { + 'weights': torch.randn(3, 3, requires_grad=True, device=self.device), + 'biases': torch.randn(2, 3, requires_grad=True, device=self.device) + } + + # Get the forward and backward functions using value_and_grad_partitioned + forward, backward = value_and_grad_partitioned( + fn, init, tree_map(lambda v: v.unsqueeze(0), x)) + + # Run the forward function + carry_out, (y_out, activations) = forward(init, x) + torch_xla.sync() + + # Compute expected outputs and gradients using PyTorch autograd + def compute_outputs_and_gradients(carry, x): + # Clone inputs to ensure they're independent + carry = tree_map(lambda v: v.clone().detach().requires_grad_(True), carry) + x = tree_map(lambda v: v.clone().detach().requires_grad_(True), x) + + # Forward pass + new_carry, y = fn(carry, x) + + # Run backward to compute gradients. + out, _ = tree_flatten((new_carry, y)) + torch.autograd.backward(out, tree_map(lambda v: torch.ones_like(v), out)) + + # Collect gradients + grads = { + 'init': tree_map(lambda v: v.grad, carry), + 'x': tree_map(lambda v: v.grad, x), + } + outputs = {'carry': new_carry, 'y': y} + return outputs, grads + + # Compute expected outputs and gradients + expected_outputs, expected_grads = compute_outputs_and_gradients(init, x) + + # Compare the outputs from the forward function with the expected outputs + self.compare_pytree(carry_out, expected_outputs['carry']) + self.compare_pytree(y_out, expected_outputs['y']) + + # Prepare gradients for the backward function + grad_carry = tree_map(lambda v: torch.ones_like(v), carry_out) + grad_y = torch.ones_like(y_out) + + # Run the backward function + grad_init, grad_x = backward(grad_carry, (grad_y, activations)) + torch_xla.sync() + + # Compare the gradients from the backward function with the expected gradients + self.compare_pytree(grad_init, expected_grads['init']) + self.compare_pytree(grad_x, expected_grads['x']) + + +if __name__ == '__main__': + test = unittest.main() + sys.exit(0 if test.result.wasSuccessful() else 1) diff --git a/test/scan/test_scan_layers.py b/test/scan/test_scan_layers.py new file mode 100644 index 000000000000..a1eb68bd7d27 --- /dev/null +++ b/test/scan/test_scan_layers.py @@ -0,0 +1,280 @@ +import sys +import os +example_folder = os.path.dirname(os.path.dirname( + os.path.dirname(__file__))) + "/examples" +sys.path.append(example_folder) +from decoder_only_model import DecoderOnlyConfig, DecoderOnlyModel # type:ignore + +import unittest +from copy import deepcopy +from typing import Iterable + +import torch +import torch.nn as nn + +import torch_xla +from torch_xla.experimental.scan_layers import scan_layers + +parent_folder = os.path.dirname(os.path.dirname(__file__)) +sys.path.append(parent_folder) +from test_utils import XlaTestCase # type:ignore + + +class ScanLayersTest(XlaTestCase): + + def setUp(self): + super().setUp() + + self.device = torch_xla.device() + + def assert_different_tensor(self, a: torch.Tensor, b: torch.Tensor): + assert a is not b, f"Expected {a} and {b} to be different tensors" + assert a.data is not b.data, 
f"Expected {a} and {b} to have different storage" + + def assert_while_found_in_hlo(self, tensor: torch.Tensor): + hlo_text = torch_xla._XLAC._get_xla_tensors_hlo([tensor]) + self.assertIn("while(", hlo_text) + self.assertIn("condition=", hlo_text) + self.assertIn("body=", hlo_text) + + def test_empty_layers(self): + layers = [] + input_data = torch.randn(64).to(self.device) + output = scan_layers(layers, input_data.clone()) + super().compareResults(output, input_data, abs_err=0.0001, rel_err=0.001) + + def test_linear_layers(self): + # Fix the random seed to avoid flakes. + with torch.random.fork_rng(): + with torch_xla.xm.fork_rng(): + torch.random.manual_seed(42) + torch_xla.xm.set_rng_state(42) + # We want to apply these layers sequentially + layers = [nn.Linear(64, 64).to(self.device) for _ in range(10)] + input_data = torch.randn(64).to(self.device) + torch_xla.sync(wait=True) + + layers_for_scan = deepcopy(layers) + layers_for_loop = deepcopy(layers) + torch_xla.sync() + + output = scan_layers(layers_for_scan, input_data.clone()) + self.assert_while_found_in_hlo(output) + output.sum().backward() + torch_xla.sync() + + # Test that the result is the same as for loop. + loop_output = input_data.clone() + for layer in layers_for_loop: + loop_output = layer(loop_output) + torch_xla.sync() + + super().compareResults(loop_output, output, abs_err=0.0001, rel_err=0.001) + self.assert_different_tensor(loop_output, output) + + loop_output.sum().backward() + torch_xla.sync() + + # Test that the gradients are the same too. + for layer_scan, layer_loop in zip(layers_for_scan, layers_for_loop): + assert layer_scan.weight.grad is not None + assert layer_loop.weight.grad is not None + assert layer_scan.bias.grad is not None + assert layer_loop.bias.grad is not None + super().compareResults( + layer_scan.weight.grad, + layer_loop.weight.grad, + abs_err=0.0001, + rel_err=0.001) + super().compareResults( + layer_scan.bias.grad, + layer_loop.bias.grad, + abs_err=0.0001, + rel_err=0.001) + self.assert_different_tensor(layer_scan.weight.grad, + layer_loop.weight.grad) + self.assert_different_tensor(layer_scan.bias.grad, layer_loop.bias.grad) + + def test_tuple_layers(self): + """Test applying layers that consume and return tuples. Construct a module + that transforms each element in the tuple. + """ + + class TupleModule(torch.nn.Module): + + def __init__(self): + super(TupleModule, self).__init__() + self.linear = nn.Linear(64, 64) + self.w = nn.Parameter(torch.randn(64, 64, requires_grad=True)) + + def forward(self, x, y, z): + return self.linear(x).sin(), self.linear( + y).cos(), self.linear(z) @ self.w + + layers = [TupleModule().to(self.device) for _ in range(10)] + torch_xla.sync() + + layers_for_scan = deepcopy(layers) + layers_for_loop = deepcopy(layers) + torch_xla.sync() + + # Also make input data some non-trivial graph instead of just device data. + input_data = (torch.randn(64).to(self.device) * 100, + torch.randn(64).to(self.device) * 200, + torch.randn(64).to(self.device) * 300) + a = torch.randn(64).to(self.device) + input_data = tuple(t + a for t in input_data) + output = scan_layers(layers_for_scan, input_data) + self.assert_while_found_in_hlo(output[0]) + self.assert_while_found_in_hlo(output[1]) + output[0].sum().backward() + torch_xla.sync() + + # Test that the result is the same as for loop. 
+ loop_output = input_data + for layer in layers_for_loop: + loop_output = layer(*loop_output) + torch_xla.sync() + + super().compareResults(loop_output, output, abs_err=0.0001, rel_err=0.001) + self.assert_different_tensor(loop_output[0], output[0]) + + loop_output[0].sum().backward() + torch_xla.sync() + + # Test that the gradients are the same too. + for layer_scan, layer_loop in zip(layers_for_scan, layers_for_loop): + assert layer_scan.linear.weight.grad is not None + assert layer_loop.linear.weight.grad is not None + assert layer_scan.linear.bias.grad is not None + assert layer_loop.linear.bias.grad is not None + super().compareResults( + layer_scan.linear.weight.grad, + layer_loop.linear.weight.grad, + abs_err=0.0001, + rel_err=0.001) + super().compareResults( + layer_scan.linear.bias.grad, + layer_loop.linear.bias.grad, + abs_err=0.0001, + rel_err=0.001) + self.assert_different_tensor(layer_scan.linear.weight.grad, + layer_loop.linear.weight.grad) + self.assert_different_tensor(layer_scan.linear.bias.grad, + layer_loop.linear.bias.grad) + + def test_decoder_model(self): + # Define a decoder model that composes the decoder model in the example, + # but adds the ability to run the layers with the `scan` operator. + class DecoderOnlyModelWithScan(torch.nn.Module): + + def __init__(self, **kwargs): + super(DecoderOnlyModelWithScan, self).__init__() + self.decoder = DecoderOnlyModel(**kwargs) + + @property + def layers(self) -> Iterable[torch.nn.Module]: + return self.decoder.layers + + def forward( + self, + input_ids: torch.Tensor, + ) -> torch.Tensor: + return self.decoder.forward(input_ids) + + def forward_scan( + self, + input_ids: torch.Tensor, + ) -> torch.Tensor: + inputs_embeds = self.decoder.embed_tokens(input_ids) + # embed positions + assert isinstance(inputs_embeds, torch.Tensor) + # decoder layers + hidden_states = scan_layers(self.decoder.layers, inputs_embeds) + hidden_states = self.decoder.norm(hidden_states) + # [B, S, H] -> [B, S, V] + return self.decoder.output(hidden_states) + + # Fix the random seed to avoid flakes. + with torch.random.fork_rng(): + with torch_xla.xm.fork_rng(): + torch.random.manual_seed(42) + torch_xla.xm.set_rng_state(42) + + # Make it smaller for fast model run and comparisons. + config = DecoderOnlyConfig( + hidden_size=128, intermediate_size=8 * 128, vocab_size=256) + model = DecoderOnlyModelWithScan(config=config).to(self.device) + batch_size = 2 + sequence_length = 8 + + # Generate random input_ids within the range of the vocabulary size + input_ids = torch.randint(0, config.vocab_size, + (batch_size, sequence_length)).to(self.device) + + loop_model = deepcopy(model) + scan_model = deepcopy(model) + torch_xla.sync(wait=True) + + # Run the loop-based model. + loop_output = loop_model(input_ids.clone()) + loop_output.sum().backward() + torch_xla.sync() + + # Run again, this time using `scan` + scan_output = scan_model.forward_scan(input_ids.clone()) + scan_output.sum().backward() + + # Before materializing the tensors, check that tensor HLO has `While` in it. 
+ self.assert_while_found_in_hlo(scan_output) + for layer_scan in scan_model.layers: + for (name, param_scan) in layer_scan.named_parameters(): + if param_scan.grad is not None: + self.assert_while_found_in_hlo(param_scan.grad) + + torch_xla.sync() + + # Compare results + super().compareResults( + scan_output, loop_output, abs_err=0.0001, rel_err=0.0001) + + # Check gradients + checks = 0 + for layer_scan, layer_loop in zip(scan_model.layers, loop_model.layers): + for (name, + param_scan), (name2, + param_loop) in zip(layer_scan.named_parameters(), + layer_loop.named_parameters()): + assert name == name2 + # Either the parameter should have gradient in both, or it should not + # have gradient in both. + assert (param_scan.grad is not None) == (param_loop.grad is not None) + # Check gradients + if param_scan.grad is not None and param_loop.grad is not None: + # Check that they are not the same tensor + assert id(param_scan.grad) != id(param_loop.grad) + assert id(param_scan.grad.untyped_storage()) != id( + param_loop.grad.untyped_storage()) + super().compareResults( + param_scan.grad, param_loop.grad, abs_err=0.0001, rel_err=0.0001) + checks = checks + 1 + assert checks > 0 + + def test_heterogenous_layers(self): + layer1 = nn.Linear(128, 128).to(torch_xla.device()) + layer2 = nn.Sequential(nn.Linear(128, 128).to(torch_xla.device())) + with self.assertRaisesRegex(ValueError, "mismatched keys"): + scan_layers([layer1, layer2], + torch.zeros((128,), device=torch_xla.device())) + + def test_mismatched_shapes(self): + layer1 = nn.Linear(128, 128).to(torch_xla.device()) + layer2 = nn.Linear(128, 129).to(torch_xla.device()) + with self.assertRaisesRegex(ValueError, "Shape mismatch"): + scan_layers([layer1, layer2], + torch.zeros((128,), device=torch_xla.device())) + + +if __name__ == '__main__': + test = unittest.main() + sys.exit(0 if test.result.wasSuccessful() else 1) diff --git a/test/test_operations.py b/test/test_operations.py index 8d772a140f51..cc3a73c45804 100644 --- a/test/test_operations.py +++ b/test/test_operations.py @@ -21,6 +21,7 @@ import itertools import math from numbers import Number +from functools import reduce import numpy import random import re @@ -2639,15 +2640,40 @@ def test_api(self): result = a + b - ctx = torch_xla._XLAC.lowering.LoweringContext() + ctx = torch_xla._XLAC.lowering.LoweringContext("MyCustomName") ctx.build([result]) hlo = ctx.hlo() hlo_text = ctx.hlo_text() - self.assertTrue('opcode: "parameter"' in hlo_text) - self.assertTrue('opcode: "add"' in hlo_text) + self.assertIn('MyCustomName', hlo_text) + self.assertIn('opcode: "parameter"', hlo_text) + self.assertIn('opcode: "parameter"', hlo_text) + self.assertIn('opcode: "add"', hlo_text) mapping = ctx.parameter_id_tensor_mapping() self.assertEqual(len(mapping), 2) + def test_get_parameters_scalar(self): + """Scalar tensors parameters may be shared in the HLO graph if their + numerical values are equal. `parameter_id_tensor_mapping` needs to handle + that appropriately. + """ + + device = torch_xla.device() + tensors = [] + for i in range(10): + # Add three copies of the same value. 
+ tensors.append(torch.tensor(i, device=device)) + tensors.append(torch.tensor(i, device=device)) + tensors.append(torch.tensor(i, device=device)) + result = reduce(lambda a, b: a + b, tensors) + ctx = torch_xla._XLAC.lowering.LoweringContext() + ctx.build([result]) + mapping = ctx.parameter_id_tensor_mapping() + + import json + hlo_json = json.loads(ctx.hlo_json()) + num_parameters = len(hlo_json["hostProgramShape"]["parameters"]) + self.assertEqual(len(mapping), num_parameters) + class TestGeneric(test_utils.XlaTestCase): diff --git a/test/test_scan.py b/test/test_scan.py deleted file mode 100644 index 6926c01fb013..000000000000 --- a/test/test_scan.py +++ /dev/null @@ -1,107 +0,0 @@ -import sys -import unittest -import torch_xla -import torch -from torch_xla.experimental.scan import scan -from torch.utils._pytree import tree_map, tree_flatten, tree_iter - -from test_utils import XlaTestCase - - -def _loopy_scan(fn, init, xs): - """A simple scan implemented with for loops serving as reference - implementation.""" - carry = init - ys = [] - xs_len = len(next(iter(tree_iter(xs)))) - for i in range(xs_len): - carry, y = fn(carry, tree_map(lambda x: x[i], xs)) - ys.append(y) - ys = tree_map(lambda *x: torch.stack(x), *ys) - return carry, ys - - -class ScanTest(XlaTestCase): - - def setUp(self): - self.device = torch_xla.device() - - def compare_pytree(self, expected_pytree, actual_pytree): - flat_expected_pytree, expected_spec = tree_flatten(expected_pytree) - flat_actual_pytree, actual_spec = tree_flatten(actual_pytree) - assert expected_spec == actual_spec - super().compareResults(flat_expected_pytree, flat_actual_pytree) - - def run_test(self, step_fn, init, xs): - # Actual output - final_carry, ys = scan(step_fn, init, xs) - torch_xla.sync() - - # Expected output - expected_final_carry, expected_ys = _loopy_scan(step_fn, init, xs) - torch_xla.sync() - - # Compare - self.compare_pytree(expected_final_carry, final_carry) - self.compare_pytree(expected_ys, ys) - - return final_carry, ys - - def test_scan_forward_simple(self): - """This test uses `scan` to implement `torch.cumsum`.""" - - def step_fn(carry, x): - new_carry = carry + x - y = new_carry - return new_carry, y - - init = torch.tensor([0.0, 0.0], device=self.device) - xs = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], device=self.device) - final_carry, ys = self.run_test(step_fn, init, xs) - - # Also ensure that our loop-based scan is correct, with manual checks - # that replicate the step_fn. 
- expected_final_carry = torch.sum(xs, dim=0) + init - expected_ys = torch.cumsum(xs, dim=0) - self.compare_pytree(expected_final_carry, final_carry) - self.compare_pytree(expected_ys, ys) - - def test_scan_fn_not_callable(self): - init = torch.tensor([1.0, 1.0], device=self.device) - xs = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], device=self.device) - with self.assertRaises(ValueError): - scan(1000, init, xs) # type: ignore - - def test_scan_incompatible_length(self): - init = torch.tensor([1.0, 1.0], device=self.device) - xs_1 = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], - device=self.device) - xs_2 = torch.tensor([[1.0, 2.0], [3.0, 4.0]], device=self.device) - with self.assertRaises(ValueError): - scan(lambda a, b: (a, b), init, (xs_1, xs_2)) - - def test_scan_forward_tuples(self): - """Test scanning over the leading axis of a tuple of tensors simultaneously, - which is a simple PyTree.""" - - def step_fn(carry, x): - carry1, carry2 = carry - x1, x2 = x - new_carry1 = carry1 + x1.sum() - new_carry2 = carry2 + x2.sum() - y1 = x1 * 2 - y2 = x2 * 2 - return (new_carry1, new_carry2), (y1, y2) - - init = (torch.tensor([0.0], device=self.device), - torch.tensor([1.0, 2.0], device=self.device)) - - xs = (torch.tensor([[1.0, 2.0], [3.0, 4.0]], device=self.device), - torch.tensor([[5.0, 6.0, 7.0], [8.0, 9.0, 10.0]], device=self.device)) - - self.run_test(step_fn, init, xs) - - -if __name__ == '__main__': - test = unittest.main() - sys.exit(0 if test.result.wasSuccessful() else 1) diff --git a/test/tpu/run_tests.sh b/test/tpu/run_tests.sh index c29e6f42be56..8d5e74bde03a 100755 --- a/test/tpu/run_tests.sh +++ b/test/tpu/run_tests.sh @@ -25,7 +25,8 @@ XLA_PARAMETER_WRAPPING_THREADSHOLD=1 python test/spmd/test_spmd_parameter_wrappi python3 test/pjrt/test_dtypes.py python3 test/pjrt/test_dynamic_plugin_tpu.py python3 test/test_while_loop.py -python3 test/test_scan.py +python3 test/scan/test_scan.py +python3 test/scan/test_scan_layers.py python3 test/test_pallas.py -v python3 test/test_pallas_spmd.py python3 test/test_tpu_paged_attention_kernel.py diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp index 7e36a20c0e5b..84a9d066cbff 100644 --- a/torch_xla/csrc/init_python_bindings.cpp +++ b/torch_xla/csrc/init_python_bindings.cpp @@ -978,10 +978,14 @@ void BuildProfilerSubmodule(py::module* m) { class PyLoweringContext { public: - PyLoweringContext() : PyLoweringContext(bridge::GetCurrentDevice()) {} + PyLoweringContext() + : PyLoweringContext("PyLoweringContext", bridge::GetCurrentDevice()) {} - PyLoweringContext(torch::lazy::BackendDevice device) - : lowering_ctx("PyLoweringContext", device) {} + PyLoweringContext(const std::string& name) + : PyLoweringContext(name, bridge::GetCurrentDevice()) {} + + PyLoweringContext(const std::string& name, torch::lazy::BackendDevice device) + : lowering_ctx(name, device) {} // Builds a HLO graph given a set of output tensors. void Build(std::vector tensors) { @@ -1069,7 +1073,6 @@ class PyLoweringContext { // etc.) 
std::unordered_map GetParameterIdTensorMapping() { // Find parameters in the lowering - const std::vector& param_ids = lowering_ctx.GetParameterSequence(); const std::vector& device_data = lowering_ctx.GetParametersData(); @@ -1086,7 +1089,9 @@ class PyLoweringContext { at::ScalarType dtype = MaybeUpcastToHostTorchType(literal.shape().element_type()); at::Tensor input = MakeTensorFromXlaLiteral(literal, dtype); - results[param_ids[i]] = input; + std::optional param_id = lowering_ctx.GetParameterId(device_data[i]); + XLA_CHECK(param_id.has_value()); + results[param_id.value()] = input; } return results; } @@ -1109,12 +1114,13 @@ class PyLoweringContext { torch::lazy::BackendData::Handle handle = data->GetHandle(); // Linearly search parameters and compare opaque handles - const std::vector& param_ids = lowering_ctx.GetParameterSequence(); const std::vector& device_data = lowering_ctx.GetParametersData(); for (int i = 0; i < device_data.size(); ++i) { if (device_data[i]->GetHandle() == handle) { - return param_ids[i]; + std::optional param_id = lowering_ctx.GetParameterId(device_data[i]); + XLA_CHECK(param_id.has_value()); + return param_id.value(); } } return -1; @@ -1186,7 +1192,8 @@ void BuildLoweringContextSubmodule(py::module* m) { py::class_> lowering_context_class(lowering, "LoweringContext", py::module_local()); - lowering_context_class.def(py::init<>()) + lowering_context_class.def(py::init()) + .def(py::init()) .def("build", &PyLoweringContext::Build) .def("buildforiloop", &PyLoweringContext::BuildForiLoop) .def("hlo", &PyLoweringContext::GetHlo) diff --git a/torch_xla/csrc/lowering_context.cpp b/torch_xla/csrc/lowering_context.cpp index c104be7c4386..c2db9b363095 100644 --- a/torch_xla/csrc/lowering_context.cpp +++ b/torch_xla/csrc/lowering_context.cpp @@ -3,6 +3,7 @@ #include #include +#include #include #include #include @@ -136,6 +137,16 @@ xla::XlaOp LoweringContext::GetParameter( return it->second.param; } +std::optional LoweringContext::GetParameterId( + const std::shared_ptr& data) const { + torch::lazy::BackendData::Handle handle = data->GetHandle(); + auto it = parameters_map_.find(handle); + if (it == parameters_map_.end()) { + return std::nullopt; + } + return it->second.index; +} + const std::vector& LoweringContext::GetParametersData() const { return parameters_; @@ -195,13 +206,14 @@ void LoweringContext::AssignOutputOp(const torch::lazy::Output& output, xla::XlaOp LoweringContext::GetOutputOp(const torch::lazy::Output& output) { auto it = emitted_outputs_.find(output); + if (it == emitted_outputs_.end()) { auto post_order = torch::lazy::Util::ComputePostOrder(output.node, &emit_status_); for (auto node : post_order) { LowerNode(node); } - // At this point the outpout better be present, otherwise there is an issue + // At this point the output better be present, otherwise there is an issue // with the lowering code. 
it = emitted_outputs_.find(output); XLA_CHECK(it != emitted_outputs_.end()) @@ -216,6 +228,7 @@ XlaOpVector LoweringContext::LowerNode(const torch::lazy::Node* node) { HloMetadataSetter meta_setter(this, node); const XlaNode* casted = dynamic_cast(node); + result_ops = casted->Lower(this); if (!casted->dynamic_dims().empty()) { xla::internal::XlaBuilderFriend builder_friend; diff --git a/torch_xla/csrc/lowering_context.h b/torch_xla/csrc/lowering_context.h index e645f959af01..3a36695e1c05 100644 --- a/torch_xla/csrc/lowering_context.h +++ b/torch_xla/csrc/lowering_context.h @@ -6,6 +6,7 @@ #include #include +#include #include #include #include @@ -52,6 +53,11 @@ class LoweringContext : public torch::lazy::LoweringContext { const std::shared_ptr& data, const std::unordered_set& dynamic_dims = {}); + // If a parameter associated with data has already been declared, returns its + // ID. Otherwise, returns `std::nullopt`. + std::optional GetParameterId( + const std::shared_ptr& data) const; + // Retrieves the vector holding all the tensors associated with the parameter // instructions which have been created. const std::vector& GetParametersData() const; diff --git a/torch_xla/experimental/pytreeify.py b/torch_xla/experimental/pytreeify.py new file mode 100644 index 000000000000..9fb0d282526c --- /dev/null +++ b/torch_xla/experimental/pytreeify.py @@ -0,0 +1,50 @@ +import torch.utils._pytree as pytree +from torch.autograd import Function + + +# Taken from https://github.com/pytorch/pytorch/issues/96337 +# +# The main purpose is to support autograd in the `scan` operator, which takes in +# PyTrees and outputs PyTrees. Builtin PyTorch autograd ignores tensors in +# non-trivial PyTrees such as dictionaries of tensors. This decorator adds +# arbitrary PyTree support by flattening the PyTree before handing to PyTorch and +# unflattening on the way back. 
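+#
+# A minimal usage sketch (the `Double` Function and its dict input below are
+# hypothetical illustrations, not part of this module): once decorated, a
+# Function can consume and produce a dict of tensors directly.
+#
+#   @pytreeify
+#   class Double(Function):
+#
+#     @staticmethod
+#     def forward(ctx, inp):
+#       # `inp` arrives unflattened, i.e. as the original dict.
+#       return {k: v * 2 for k, v in inp.items()}
+#
+#     @staticmethod
+#     def backward(ctx, grad):
+#       # `grad` has the same dict structure as the forward output. The return
+#       # value must match the structure of the forward's argument tuple.
+#       return ({k: g * 2 for k, g in grad.items()},)
+#
+#   out = Double.apply({'a': torch.ones(2, requires_grad=True)})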
+def pytreeify(cls): + assert issubclass(cls, Function) + + orig_fw = cls.forward + orig_bw = cls.backward + orig_apply = cls.apply + + def new_apply(*inp): + flat_inp, struct = pytree.tree_flatten(inp) + out_struct_holder = [] + flat_out = orig_apply(struct, out_struct_holder, *flat_inp) + assert flat_out is not None + assert len(out_struct_holder) == 1 + return pytree.tree_unflatten(flat_out, out_struct_holder[0]) + + def new_forward(ctx, struct, out_struct_holder, *flat_inp): + inp = pytree.tree_unflatten(flat_inp, struct) + out = orig_fw(ctx, *inp) + flat_out, out_struct = pytree.tree_flatten(out) + ctx._inp_struct = struct + ctx._out_struct = out_struct + out_struct_holder.append(out_struct) + return tuple(flat_out) + + def new_backward(ctx, *flat_grad_outputs): + grad_outputs = pytree.tree_unflatten(flat_grad_outputs, ctx._out_struct) + if not isinstance(grad_outputs, tuple): + grad_outputs = (grad_outputs,) + grad_inputs = orig_bw(ctx, *grad_outputs) + flat_grad_inputs, grad_inputs_struct = pytree.tree_flatten(grad_inputs) + if grad_inputs_struct != ctx._inp_struct: + raise RuntimeError("The backward generated an arg structure that doesn't " + "match the forward's input.") + return (None, None) + tuple(flat_grad_inputs) + + cls.apply = new_apply + cls.forward = new_forward + cls.backward = new_backward + return cls diff --git a/torch_xla/experimental/scan.py b/torch_xla/experimental/scan.py index 9008e03dbd91..7e872edc42d2 100644 --- a/torch_xla/experimental/scan.py +++ b/torch_xla/experimental/scan.py @@ -2,12 +2,50 @@ Reference: https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.scan.html +# High level design + +The implementation is factored into two layers: core and autograd. The core +layer focuses on the numerical scan operation without any gradient tracking, and +the autograd layer adds forward and backward support using the scan primitive in +core. + +## Core + +The `_scan_impl_flat` function implements the core logic of scan on flattened +tensors. It uses XLA's `While` op to iterate over the leading dimension of the +input tensors. The body of the `While` loop calls `fn` and updates the carry and +output tensors. + +The `_scan_impl_pytree` function adds PyTree support on top. It flattens the +input PyTrees, calls `_scan_impl_flat` to perform the scan on the flattened +tensors, and then unflattens the results. Because gradients are sometimes +`None`, it also hides any `None`s in PyTrees from `_scan_impl_flat`, +simplifying the latter's implementation. + +## Autograd + +The `value_and_grad_partitioned` function symbolically traces the user-provided +function `fn` to obtain the forward and backward computation graphs. It then +creates two functions, `forward` and `backward`, that can be used in the +`Scan.forward` and `Scan.backward` methods. + +The `scan` operator is implemented as a PyTorch autograd Function, `Scan`. +The `Scan.forward` method scans the forward graph over the inputs. +The `Scan.backward` method scans the backward graph over the gradients and +activations. 
""" -from typing import Callable, TypeVar +import itertools +from typing import Callable, Dict, Sequence, TypeVar, Tuple, List, Optional, overload import torch -from torch.utils._pytree import tree_map, tree_iter +import torch.autograd +from torch.utils._pytree import tree_map, tree_flatten, tree_unflatten, tree_iter, PyTree +from functorch.compile import aot_function, make_boxed_func, default_partition # type: ignore + +import torch_xla +import torch_xla.core.xla_builder as xb +from torch_xla.experimental.pytreeify import pytreeify Carry = TypeVar('Carry') X = TypeVar('X') @@ -18,11 +56,13 @@ def scan( fn: Callable[[Carry, X], tuple[Carry, Y]], init: Carry, xs: X, + partition_fn=default_partition, + # TODO: consider exposing knobs to control the RNG seed used in each `fn` iteration. ) -> tuple[Carry, Y]: """Apply a function over leading dimension of tensors while carrying along state. - + This is similar to the JAX `jax.lax.scan` function found in [1]. - + You may use it to loop over the leading dimension of tensors efficiently. If `xs` is a single tensor, this function is roughly equal to the following Python code: @@ -33,33 +73,65 @@ def scan(fn, init, xs): carry, y = fn(carry, xs[i]) ys.append(y) return carry, torch.stack(ys, dim=0) - + In the general case, `Carry`, `X`, and `Y` can be arbitrary PyTrees. This function will iterate through the leading dimension of every leaf element of `xs` simultaneously, and pass a slice of those elements to `fn` as another PyTree. This means you may scan over multiple tensors and produce multiple output tensors at once. - - Args: - fn: a Python callable that accepts two PyTrees of tensors: the carry object and the - slices of `xs` along its leading dimension. It should return two PyTrees: the carry - object and the slices of the output. The returned carry object will be passed to - the next invocation of `fn`. + Notes: - init: the initial carry object passed to the first invocation of `fn`. + `fn` must be AOTAutograd traceable. That requires PyTorch to understand the operations + within. For example if you invoke a custom kernel inside `fn`, you need to register the + custom kernel. See https://pytorch.org/tutorials/advanced/custom_ops_landing_page.html. + + Args: + fn: A Python callable that accepts two PyTrees of tensors: the carry object and the + slices of `xs` along its leading dimension. It should return two PyTrees: the carry + object and the slices of the output. The returned carry object will be passed to + the next invocation of `fn`. + + init: The initial carry object passed to the first invocation of `fn`. - xs: the input PyTree to scan over. If `xs` is a tensor, then `fn` will get slices along - the leading dimension (`xs[i]`). If `xs` is some other PyTree (e.g. tuple of - tensor), `fn` will get PyTrees of slices. In that case the leading dimension size - of the leaves in the PyTree must be the same. + xs: The input PyTree to scan over. If `xs` is a tensor, then `fn` will get slices along + the leading dimension (`xs[i]`). If `xs` is some other PyTree (e.g. tuple of + tensor), `fn` will get PyTrees of slices. In that case the leading dimension size + of the leaves in the PyTree must be the same. + + partition_fn: (Optional[Callable]) Since `scan` uses AOTAutograd to trace `fn`, you may + override what computation happen in the forward and backward passes by specifying + different partition functions. `default_partition` implies no activation checkpointing. 
+ You may specify `functorch.compile.min_cut_rematerialization_partition` to use min-cut + based activation checkpointing. You may also write your own partitioner to insert any + custom logic such as host offloading of activations. Returns: - (carry, ys): A tuple where `carry` is the last carry object returned by `fn`, and `ys` is a PyTree with the same structure as `xs`, but where the leaves are formed by stacking the leaf outputs of `fn` respectively. This means if your `fn` returns `(carry, (y1, y2))` then this function will return `(carry, (torch.stack(all_y1), torch.stack(all_y2)))`. + + Example: + + >>> # Example of using `scan` to implement `torch.cumsum`. + >>> import torch_xla.runtime + >>> import torch + >>> from torch_xla.experimental.scan import scan + >>> + >>> def fn(carry, x): + >>> new_carry = carry + x + >>> y = new_carry + >>> return new_carry, y + >>> + >>> with torch_xla.runtime.xla_device(): + >>> init = torch.tensor([0.0, 0.0], requires_grad=True) + >>> xs = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], + >>> requires_grad=True) + >>> final_carry, ys = scan(fn, init, xs) + >>> torch_xla.sync() + >>> print(final_carry) # Should be [9.0, 12.0] + >>> print(ys) # Should be [[1.0, 2.0], [4.0, 6.0], [9.0, 12.0]] [1]: https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.scan.html """ @@ -82,14 +154,506 @@ def scan(fn, init, xs): if xs_length is None: raise ValueError(f"`xs` {xs} is an empty PyTree.") - carry = init - ys = [] + forward, backward = value_and_grad_partitioned( + fn, init, xs, partition_fn=partition_fn) + carry, ys = Scan.apply(forward, backward, init, xs) # type: ignore + return carry, ys + + +def value_and_grad_partitioned( + fn: Callable[[Carry, X], tuple[Carry, Y]], + init: Carry, + xs: X, + partition_fn=default_partition) -> tuple[Callable, Callable]: + """ + Given a user `fn` to be scanned over the leading dimension of the input `xs` + PyTree and an initial carry object `init`, symbolically traces `fn` and + returns two functions, `forward` and `backward`, which wrap the forward and + backward graphs of `fn` and plumbs through intermediate activations. + Specifically, given + + `fn(carry, x) -> (new_carry, y)` + + this function will build and return + + `forward(carry, x) -> (new_carry, (y, activations))` + + `backward(grad_new_carry, (grad_y, activations)) -> (grad_carry, grad_x)` + + where `grad_y` is the gradient w.r.t `y`, and `grad_new_carry` is the gradient + w.r.t. `new_carry`. + + `activations` will always be a flat list of tensors. + + This is similar to the `value_and_grad` transform found in JAX, but additionally + partitions and returns separate forward/backward passes, so that we may later + use them in the `autograd.Function` implementation of `Scan`. + + Args: + fn: (Callable[[Carry, X], tuple[Carry, Y]]) A callable with signature + `fn(carry, x_t) -> (new_carry, y_t)`, representing the function to be scanned. + + init: (Carry) The initial carry object. + + xs: (X) A PyTree of inputs to be scanned over. + + partition_fn: An optional partitioning function used to partition fn into + forward and backward graphs. + + Returns: + A tuple of `(forward, backward)`, detailed in the docstring of this function. + """ + + # Make some fake tensors to trace the user function and obtain the + # forward and backward graphs. Note that the init/carry fake tensor + # always requires grad. 
That's because even if the user passed in some + # `init` that does not require grad, we still want gradients to flow + # through the `carry` from one iteration of the user function to the + # next. In summary, the `carry` argument used to trace a user function + # to get a correct backward pass always requires grad. + def make_fake_tensor(v: torch.Tensor, requires_grad=True) -> torch.Tensor: + return torch.empty_like( + v, dtype=v.dtype, device=v.device, requires_grad=requires_grad) + + fake_carry_pytree = tree_map(make_fake_tensor, init) + fake_x_pytree = tree_map( + lambda v: make_fake_tensor(v[0], requires_grad=v.requires_grad), xs) + + with torch.enable_grad(): + fw_compiler, get_fwd = _make_get_graph_compiler() + bw_compiler, get_bwd = _make_get_graph_compiler() + fn_compiled = aot_function( + fn, + fw_compiler=fw_compiler, + bw_compiler=bw_compiler, + partition_fn=partition_fn) + _, unflatten_bwd_out = tree_flatten_none((fake_carry_pytree, fake_x_pytree)) + out = fn_compiled(fake_carry_pytree, fake_x_pytree) + # How many outputs out of the fwd_graph is actually outputs of `fn`, and not + # intermediate activations. + num_out = len(list(tree_iter(out))) + # Capture the backward. + out, unflatten_fwd_out = tree_flatten_none(out) + torch.autograd.backward(out, tree_map(lambda v: torch.ones_like(v), out)) + + fwd_graph = get_fwd() + bwd_graph = get_bwd() + + def forward(carry, x): + flat_carry, _ = tree_flatten(carry) + flat_x, _ = tree_flatten(x) + out = fwd_graph(*flat_carry, *flat_x) + actual_out, activations = split(out, num_out) + carry, y = unflatten_fwd_out(actual_out) + y = (y, activations) + return carry, y + + def backward(carry, x): + grad_new_carry, _ = tree_flatten(carry) + (grad_y, activations) = x + grad_y, _ = tree_flatten_none(grad_y) + out = bwd_graph(*activations, *grad_new_carry, *grad_y) + grad_carry, grad_x = unflatten_bwd_out(out) + return grad_carry, grad_x + + return forward, backward + + +def _make_get_graph_compiler(): + """ + Creates a compiler that records the graph, and a getter + function to retrieve them. + """ + graph: List[Optional[torch.fx.GraphModule]] = [None] + + def forward_comp(fx_module: torch.fx.GraphModule, _): + assert graph[0] is None + graph[0] = fx_module + return make_boxed_func(fx_module) + + def get_graph(): + g = graph[0] + assert g is not None + return g + + return forward_comp, get_graph + + +@pytreeify +class Scan(torch.autograd.Function): + + @staticmethod + def forward(ctx, forward, backward, init, xs): + # Forward pass, save activations for backward + ctx._backward = backward + with torch.no_grad(): + carry, ys = _scan_impl_pytree(forward, init, xs) + ys, activations = ys + ctx.save_for_backward(*activations) + return carry, ys + + @staticmethod + def backward(ctx, grad_carry, grad_ys): + activations = ctx.saved_tensors + backward = ctx._backward + with torch.no_grad(): + # Reverse loop to propagate gradients from last iteration to first. + grad_init, grad_xs = _scan_impl_pytree( + backward, grad_carry, (grad_ys, activations), reverse=True) + return None, None, grad_init, grad_xs + + +def _scan_impl_pytree(fn, init, xs, reverse: bool = False): + """Forward logic of scan without gradient tracking. `fn` operates on + PyTrees. `init` and `xs` are also PyTrees. + + See the `Scan` class which implements an autograd `Function` and builds + autograd support on top of `_scan_impl`. 
+ """ + flat_init, unflatten_carry = tree_flatten_none(init) + flat_xs, unflatten_xs = tree_flatten_none(xs) + unflatten_y: Callable[..., PyTree] = lambda _: () # Set by `flat_fn`. + + def flat_fn( + carry: Sequence[torch.Tensor], x: Sequence[torch.Tensor] + ) -> Tuple[Sequence[torch.Tensor], Sequence[torch.Tensor]]: + nonlocal unflatten_y + carry_pytree = unflatten_carry(carry) + x_pytree = unflatten_xs(x) + carry_pytree, y_pytree = fn(carry_pytree, x_pytree) + flat_carry, _ = tree_flatten_none(carry_pytree) + flat_y, unflatten_y = tree_flatten_none(y_pytree) + return flat_carry, flat_y + + flat_carry, flat_y = _scan_impl_flat( + flat_fn, flat_init, flat_xs, reverse=reverse) + return unflatten_carry(flat_carry), unflatten_y(flat_y) + - for i in range(xs_length): - carry, y = fn(carry, tree_map(lambda x: x[i], xs)) - ys.append(y) +def tree_flatten_none(pytree: PyTree): + """ + Flattens input `pytree`, and filters out any `None` leaf PyTree nodes. + Returns the flattened list, and an unflatten function and also adds back + the removed `None`s in their correct location. + """ + flat, spec = tree_flatten(pytree) + flat, add_none = _remove_none(flat) + + def unflatten(flat): + flat = add_none(flat) + return tree_unflatten(flat, spec) + + return flat, unflatten + + +def _remove_none(s: Sequence[Optional[torch.Tensor]]): + """ + Filters out `None` values from `s`. Returns the filtered sequence, + and another function that will add back the `None` values when given a + sequence of the same structure. + """ + filtered = [v for v in s if v is not None] + none_mask = [v is None for v in s] + + def add_back_nones(s_filtered): + res = [] + idx_filtered = 0 + for is_none in none_mask: + if is_none: + res.append(None) + else: + res.append(s_filtered[idx_filtered]) + idx_filtered += 1 + return res + + return filtered, add_back_nones + + +def dynamic_update_slice(ys: xb.Op, y: xb.Op, idx: xb.Op) -> xb.Op: + # See https://openxla.org/xla/operation_semantics#dynamicupdateslice. + y = y.broadcast([1]) + indices = [idx] + for _ in range(ys.shape().rank - 1): + indices.append(idx.zeros_like()) + return ys.dynamic_update_slice(y, indices) + + +def dynamic_slice(xs: xb.Op, idx: xb.Op) -> xb.Op: + indices = [idx] + for _ in range(xs.shape().rank - 1): + indices.append(idx.zeros_like()) + slice_shape = list(xs.shape().sizes) + slice_shape[0] = 1 + sliced = xs.dynamic_slice(indices, slice_shape) + shape = list(xs.shape().sizes) + shape = shape[1:] + return sliced.reshape(shape) + + +class Builder: + + def __init__(self, name: str): + self._builder = xb.create_builder(name) + self._params = [] + self._param_tensors = [] + + def add_param(self, val: torch.Tensor): + idx = len(self._params) + param = xb.mkparam(self._builder, idx, xb.tensor_shape(val)) + self._params.append(param) + self._param_tensors.append(val) + return idx + + def params(self) -> Tuple[xb.Op, ...]: + return tuple(self._params) + + def param_tensors(self) -> Tuple[torch.Tensor, ...]: + return tuple(self._param_tensors) + + def num_params(self) -> int: + return len(self._params) + + +def _scan_impl_flat(fn, + init: Sequence[torch.Tensor], + xs: Sequence[torch.Tensor], + reverse: bool = False): + """Forward logic of scan without gradient tracking. `fn` operates on + two flat list of tensors. `init` and `xs` are also flat lists of tensors. None + of the tensors will be `None`. + + See the `Scan` class which implements an autograd `Function` and builds + autograd support on top of `_scan_impl`. 
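+
+  For example, if `init` is `[c]` and `xs` is `[x]` where `x` has shape
+  `(N, ...)`, then `fn` is called as `fn([c_t], [x_t])` and must return
+  `([c_{t+1}], [y_t])`; this function then returns the final carry list and a
+  list of stacked outputs, each with leading dimension `N`.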
+ + ## Handling of random numbers + + When `fn` generates random numbers (e.g. it uses a dropout layer), we need to + ensure that each iteration of `fn` within the scan yields different random + numbers, despite running the same HLO operations. JAX requires the user to + explicitly fork the RNG state and pass it to `fn`. In PyTorch, the RNG state + is an implicit global variable. Therefore, we take a slightly different + approach: + + - Identify usage of RNG state via `_get_tensors_xla_device_data_node`. + - Create N different copies of the RNG state contained in a tensor. + - While building the `While` op body, index into the RNG state tensor at the + current iteration and provide that seed value to `fn`. + + ## Handling of HLO parameters + + Let's say the user writes a `fn` like this: + + def fn(carry, x): + foo = torch.zeros(8) + return carry, x + foo + + `fn` will lower into an HLO computation like this: + + HloModule Fn, entry_computation_layout={ + (f32[8], f32[8], f32[8]) -> (f32[8], f32[8]) + } + + The HLO computation takes three parameters while `fn` takes two arguments. + That's because IR lowering does not distinguish if a leaf data tensor comes from + a function argument or from within the function. All data tensors are lowered + into HLO parameters. We'll call them "hoisted variables" or `hoisted_vars`, since + instead of baking the value of those tensors as literals in the HLO graph, + they are turned into additional parameters of the computation. + """ + carry_len = len(init) + xs_len = len(xs) + + # Abstractly trace and lower `fn`. + # Later we will include `fn_computation` within the while loop body. + def make_fake_tensor(v: torch.Tensor) -> torch.Tensor: + return torch.empty( + v.size(), dtype=v.dtype).to(device).requires_grad_(v.requires_grad) - # Combine the list of PyTrees into one PyTree, where the leaves are - # stacked into a new major axis. - ys = tree_map(lambda *x: torch.stack(x), *ys) + device = torch_xla.device() + fake_carry = tree_map(make_fake_tensor, init) + fake_x = tree_map(lambda v: make_fake_tensor(v[0]), xs) + fake_output_carry, fake_output_y = fn(fake_carry, fake_x) + + y_len = len(fake_output_y) + fn_outputs = fake_output_carry + fake_output_y + + fn_ctx = torch_xla._XLAC.lowering.LoweringContext("FnComputation") + fn_ctx.set_name_string("fn_ctx") + fn_ctx.build(list(fn_outputs)) + fn_hlo = fn_ctx.hlo() + fn_computation = xb.computation_from_module_proto("fn_computation", fn_hlo) + + # Figure out the shape of `ys` from the abstract tracing. + fn_carry_out, fn_y_out = split(fn_outputs, carry_len) + assert carry_len + y_len == len(fn_outputs) + fn_carry_shapes = [v.shape for v in fn_carry_out] + fn_y_shapes = [v.shape for v in fn_y_out] + for fn_carry_shape, init_leaf in zip(fn_carry_shapes, init): + assert fn_carry_shape == init_leaf.shape, f"`fn` must keep the `carry` shape unchanged. \ + Got {fn_carry_shape} but expected {init_leaf.shape}" + + builder = Builder('scan') + num_iters = next(iter(tree_iter(xs))).size(0) + ys = [ + torch.zeros((num_iters, *fn_y_shape), device=device) + for fn_y_shape in fn_y_shapes + ] + # Start the `curr_iter` loop variable at zero. + zero = torch.tensor(0, device=device) + builder.add_param(zero) + + # We are building a bigger XLA computation (the while loop) that calls + # a smaller computation (`fn_computation`). This is a mapping from + # `fn_computation` param ID to While computation param ID. + fn_param_id_to_while_param_id: Dict[int, int] = {} + + # Add carry and x. 
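+  # Each one becomes a parameter of the While computation. We also record how
+  # its parameter ID in `fn_computation` maps to its index in the While
+  # parameter list, so that the loop body can pass the right operands to
+  # `fn_computation` later.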
+  for real, fake in ((init, fake_carry), (xs, fake_x)):
+    for val, fake_val in zip(real, fake):
+      idx = builder.add_param(val)
+      param_id = fn_ctx.tensor_parameter_id(fake_val)
+      if param_id != -1:
+        fn_param_id_to_while_param_id[param_id] = idx
+
+  # Add the `ys` output buffers as params too, since our While computation
+  # consumes them, updates one slice per iteration, and returns the updated `ys`.
+  for val in ys:
+    builder.add_param(val)
+
+  # Detect hoisted variables.
+  hoisted_vars: Dict[int, torch.Tensor] = fn_ctx.parameter_id_tensor_mapping()
+  for v in itertools.chain(fake_carry, fake_x):
+    param_id = fn_ctx.tensor_parameter_id(v)
+    if param_id != -1:
+      del hoisted_vars[param_id]
+
+  # Detect RNG seed usage within the scanned function among the hoisted variables.
+  ids, i_values = torch_xla._XLAC._get_tensors_xla_device_data_node(fn_outputs)
+  seed_info_id = torch_xla._XLAC._get_seed_info_id()
+  seed_parameter_id = None
+  if seed_info_id in ids:
+    seed_idx = ids.index(seed_info_id)
+    seed_parameter_id = fn_ctx.tensor_parameter_id(i_values[seed_idx])
+    assert seed_parameter_id != -1, "`fn` uses random seed, but random seed is not \
+        a parameter to the traced HLO graph"
+
+    # Replace the single seed value with a tensor of seeds, one per iteration.
+    seed_tensor = hoisted_vars[seed_parameter_id]
+    assert seed_tensor.dtype == torch.int64
+    hoisted_vars[seed_parameter_id] = torch.randint(
+        0, 2**62, (num_iters,), dtype=torch.int64, device=torch_xla.device())
+
+  # Add hoisted variables as While computation params as well,
+  # including the potentially updated seed tensor.
+  for param_id, tensor in hoisted_vars.items():
+    idx = builder.add_param(tensor.to(torch_xla.device()))
+    fn_param_id_to_while_param_id[param_id] = idx
+
+  # We thread five objects through the body_fn:
+  #
+  # - curr_iter: the current loop iteration
+  # - carry: the scan state
+  # - xs: the flattened input pytree
+  # - ys: the flattened output of fn
+  # - hoisted_vars: tensors not provided as arguments to fn but still used by fn.
+  #
+  # We concatenate them all into one big list before entering `body_fn` and
+  # `cond_fn`, and split them back apart inside, which is easier to work with.
+  # This pair of `pack`/`unpack` functions exists for that purpose.
+  T = TypeVar('T')
+
+  def pack(curr_iter: T, carry: Sequence[T], xs: Sequence[T], ys: Sequence[T],
+           hoisted_vars: Sequence[T]) -> Tuple[T, ...]:
+    return tuple(itertools.chain((curr_iter,), carry, xs, ys, hoisted_vars))
+
+  def unpack(seq: Sequence[T]) -> Tuple[T, List[T], List[T], List[T], List[T]]:
+    curr_iter, carry, xs, ys, hoisted_vars = split(
+        list(seq), 1, carry_len, xs_len, y_len)
+    curr_iter = curr_iter[0]
+    return curr_iter, carry, xs, ys, hoisted_vars
+
+  def replace_rng_seed(curr_iter: xb.Op, *while_params: xb.Op):
+    """Slices the pre-generated seed tensor for the current iteration."""
+    if seed_parameter_id is None:
+      return while_params
+    idx = fn_param_id_to_while_param_id[seed_parameter_id]
+    replaced = list(while_params)
+    replaced[idx] = dynamic_slice(replaced[idx], curr_iter)
+    return replaced
+
+  def call_fn_computation(*while_params: xb.Op) -> xb.Op:
+    # We need to order the tensors in increasing parameter ID order when
+    # passing them to `xb.Op.call`.
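+    # For example, with a (hypothetical) mapping {0: 3, 1: 1, 2: 5}, the call
+    # receives [while_params[3], while_params[1], while_params[5]], so that
+    # parameters 0, 1 and 2 of `fn_computation` are bound to the right operands.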
+    fn_inputs = [
+        while_params[fn_param_id_to_while_param_id[i]]
+        for i in range(len(fn_param_id_to_while_param_id))
+    ]
+    return xb.Op.call(fn_computation, fn_inputs)
+
+  def cond_fn(curr_iter: xb.Op, *rest):
+    return curr_iter < xb.Op.scalar(
+        curr_iter.builder(), num_iters, dtype=xb.Type.S64)
+
+  def body_fn(*while_params: xb.Op):
+    curr_iter, carry, xs, ys, hoisted_vars = unpack(while_params)
+    if reverse:
+      max_iter = xb.Op.scalar(
+          curr_iter.builder(), num_iters - 1, dtype=xb.Type.S64)
+      idx = max_iter - curr_iter
+    else:
+      idx = curr_iter
+    x = [dynamic_slice(v, idx) for v in xs]
+    result = call_fn_computation(
+        *replace_rng_seed(idx, curr_iter, *carry, *x, *ys, *hoisted_vars))
+    for i in range(carry_len):
+      carry[i] = result.get_tuple_element(i)
+    for i in range(y_len):
+      y = result.get_tuple_element(i + carry_len)
+      ys[i] = dynamic_update_slice(ys[i], y, idx)
+    one = xb.Op.scalar(curr_iter.builder(), 1, dtype=xb.Type.S64)
+    return pack(curr_iter + one, carry, xs, ys, hoisted_vars)
+
+  res = xb.Op.mkwhile(builder.params(), cond_fn, body_fn)
+  computation = res.build('scan')
+  outputs = torch_xla._XLAC._xla_user_computation('xla::scan',
+                                                  builder.param_tensors(),
+                                                  computation)
+  _curr_iter, carry, xs, ys, _hoisted_vars = unpack(outputs)
   return carry, ys
+
+
+U = TypeVar('U')
+
+
+@overload
+def split(seq: List[U], *part_lengths: int) -> Tuple[List[U], ...]:
+  ...
+
+
+@overload
+def split(seq: Tuple[U, ...], *part_lengths: int) -> Tuple[Tuple[U, ...], ...]:
+  ...
+
+
+def split(seq: Sequence[U], *part_lengths: int) -> Tuple[Sequence[U], ...]:
+  """Splits a sequence into subsequences with given lengths.
+
+  Args:
+    seq: The sequence (list or tuple) to split.
+    *part_lengths: The lengths of all subsequences except the last; the last
+      subsequence receives whatever elements remain.
+
+  Example:
+
+    a, b, c = split((1, 2, 3, 4, 5), 2, 2)
+    # a == (1, 2), b == (3, 4), c == (5,)
+
+  Returns:
+    A tuple of subsequences (lists or tuples).
+  """
+  parts = []
+  start = 0
+  for length in part_lengths:
+    parts.append(seq[start:start + length])
+    start += length
+  parts.append(seq[start:])
+  return tuple(parts)
diff --git a/torch_xla/experimental/scan_layers.py b/torch_xla/experimental/scan_layers.py
new file mode 100644
index 000000000000..fa4c6121fa1a
--- /dev/null
+++ b/torch_xla/experimental/scan_layers.py
@@ -0,0 +1,142 @@
+from typing import Iterable, Mapping, Sequence, Dict, Tuple
+
+import torch
+import torch.nn as nn
+from torch.utils._pytree import tree_map
+from functorch.compile import default_partition
+
+from torch_xla.experimental.scan import scan
+
+
+def scan_layers(layers: Iterable[torch.nn.Module],
+                input_data,
+                partition_fn=default_partition):
+  """Runs each layer in `layers` sequentially, starting with `input_data`.
+
+  `input_data` is provided as input to the first layer in `layers`. The output of
+  one layer is provided as input to the next layer.
+
+  All modules in `layers` must have the same structure, and they must perform the same
+  calculations given the same model parameters and inputs. In practice, this means you
+  cannot use different dropout probabilities, parameter shapes, activation functions,
+  etc. across the `layers`.
+
+  Under these conditions, this function is equivalent to
+
+    sequential = torch.nn.Sequential(*layers)
+    sequential(input_data)
+
+  This function can be faster to compile since it reuses the XLA computation of the
+  first layer to perform the computation of all other layers.
+
+  Args:
+    layers: (Iterable[torch.nn.Module]) A list of layers to run.
+
+    input_data: The input to be given to the first layer from `layers`.
+
+    partition_fn: (Optional[Callable]) The graph partition function passed to AOTAutograd.
+      Since this function uses AOTAutograd to trace the layers, you may override which
+      computations happen in the forward and backward passes by specifying different
+      partition functions. `default_partition` implies no activation checkpointing. You may
+      specify `functorch.compile.min_cut_rematerialization_partition` to use min-cut based
+      activation checkpointing. You may also write your own partitioner to insert any custom
+      logic such as host offloading of activations.
+
+  Returns:
+    The output of the last layer from `layers`.
+
+  Example:
+
+    >>> import torch_xla.runtime
+    >>> import torch
+    >>> import torch.nn as nn
+    >>> from torch_xla.experimental.scan_layers import scan_layers
+    >>> with torch_xla.runtime.xla_device():
+    >>>   layers = [nn.Linear(16, 16) for i in range(10)]
+    >>>   input = torch.randn(16)
+    >>>   output = scan_layers(layers, input)
+    >>>   assert output.shape == (16,)  # Output is the 10th layer's output
+    >>>   print(output)  # Some random numbers
+  """
+  # Handle the empty `layers` case.
+  try:
+    first_layer = next(iter(layers))
+  except StopIteration:
+    return input_data
+
+  # Extract and stack the parameters and buffers into pytrees.
+  params_and_buffers = [_extract_weights_and_buffers(layer) for layer in layers]
+  params_list = [p for p, _ in params_and_buffers]
+  buffers_list = [b for _, b in params_and_buffers]
+
+  _ensure_same_structure(params_list)
+  _ensure_same_structure(buffers_list)
+
+  stacked_params = tree_map(lambda *tensors: torch.stack(tensors, dim=0),
+                            *params_list)
+  stacked_buffers = tree_map(lambda *tensors: torch.stack(tensors, dim=0),
+                             *buffers_list)
+
+  # Use the first layer as the example/template layer.
+  from copy import deepcopy
+  example_layer = deepcopy(first_layer)
+
+  # Define the function to apply at each step.
+  def one_layer(carry, params_buffers):
+    # Apply the current layer's weights and biases to the example layer,
+    # then run the resulting layer.
+    output = torch.func.functional_call(  # type: ignore
+        example_layer, params_buffers, carry, strict=True)
+    return output, None
+
+  stacked_params_buffers = (stacked_params, stacked_buffers)
+  final_carry, _ = scan(
+      one_layer, input_data, stacked_params_buffers, partition_fn=partition_fn)
+
+  return final_carry
+
+
+def _extract_weights_and_buffers(
+    module: nn.Module
+) -> Tuple[Dict[str, torch.nn.Parameter], Dict[str, torch.Tensor]]:
+  """
+  Extracts the parameters and buffers from a PyTorch module and
+  stores them in separate dictionaries.
+  """
+  weights_dict = {name: param for name, param in module.named_parameters()}
+  buffers_dict = {name: buffer for name, buffer in module.named_buffers()}
+  return weights_dict, buffers_dict
+
+
+def _ensure_same_structure(dicts: Sequence[Mapping[str, torch.Tensor]]):
+  """
+  Verifies that all dictionaries in `dicts` have the same structure:
+  they have the same keys and all the values have the same shape.
+  """
+  if not dicts:
+    return
+
+  reference_keys = set(dicts[0].keys())
+  reference_shapes = {key: dicts[0][key].shape for key in reference_keys}
+
+  for idx, current_dict in enumerate(dicts[1:], start=1):
+    current_keys = set(current_dict.keys())
+
+    # Check if keys match
+    if current_keys != reference_keys:
+      missing_keys = reference_keys - current_keys
+      extra_keys = current_keys - reference_keys
+      error_message = f"Layer {idx} has mismatched keys."
+ if missing_keys: + error_message += f" Missing keys: {missing_keys}." + if extra_keys: + error_message += f" Extra keys: {extra_keys}." + raise ValueError(error_message) + + # Check if shapes match for each key + for key in reference_keys: + ref_shape = reference_shapes[key] + current_shape = current_dict[key].shape + if ref_shape != current_shape: + raise ValueError(f"Shape mismatch for '{key}' in layer {idx}: " + f"expected {ref_shape}, got {current_shape}.")
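As a quick sanity check of the new API, here is a minimal usage sketch (editorial, not part of the patch; the function name, layer count, and shapes are illustrative) that compares `scan_layers` against the `nn.Sequential` baseline that its docstring states it is equivalent to:

import torch
import torch.nn as nn

import torch_xla
from torch_xla.experimental.scan_layers import scan_layers


def check_scan_layers_matches_sequential():
  device = torch_xla.device()
  # Four layers with identical structure, as `scan_layers` requires.
  layers = [nn.Linear(16, 16).to(device) for _ in range(4)]
  x = torch.randn(16, device=device)

  scanned = scan_layers(layers, x)
  expected = nn.Sequential(*layers)(x)
  torch_xla.sync()

  # The two paths should agree up to floating point tolerance.
  assert torch.allclose(scanned.cpu(), expected.cpu(), atol=1e-5)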