Skip to content

Commit 67cff9b

Browse files
committed
scan and apply_layers
Add the lowering of scan to HLO While op. Introduce apply_layers which can sequentially apply a bunch of layers using scan underneath. Beef up unit tests including linear layers and decoders. add regression test for parameter_id_tensor_mapping add test_apply_layers.py to test shell scripts correctly import decoder model from examples
1 parent 989ac69 commit 67cff9b

File tree

12 files changed

+716
-39
lines changed

12 files changed

+716
-39
lines changed

examples/decoder_only_model.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,16 +7,16 @@
77
from torch import nn
88

99

10-
# the default config is intentionally kept low to make it runable on a sigle tpu v2-8 core.
10+
# the default config is intentionally kept low to make it runnable on a single tpu v2-8 core.
1111
@dataclass
1212
class DecoderOnlyConfig:
1313
hidden_size: int = 1024
1414
num_hidden_layers: int = 2
1515
num_attention_heads: int = 8
1616
num_key_value_heads: int = 4
17-
intermediate_size = 32 * 1024
18-
vocab_size = 3200
19-
use_flash_attention = False
17+
intermediate_size: int = 32 * 1024
18+
vocab_size: int = 3200
19+
use_flash_attention: bool = False
2020

2121

2222
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:

test/run_tests.sh

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -197,6 +197,7 @@ function run_xla_op_tests2 {
197197
run_test "$CDIR/pjrt/test_dtypes.py"
198198
run_test "$CDIR/test_while_loop.py"
199199
run_test "$CDIR/test_scan.py"
200+
run_test "$CDIR/test_apply_layers.py"
200201
run_test "$CDIR/test_autocast.py"
201202
run_test "$CDIR/eager/test_eager.py"
202203
run_test "$CDIR/eager/test_eager_with_xla_compile.py"

test/test_apply_layers.py

Lines changed: 149 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,149 @@
1+
import sys
2+
import os
3+
example_folder = os.path.dirname(os.path.dirname(os.path.abspath(
4+
sys.argv[0]))) + "/examples"
5+
sys.path.append(example_folder)
6+
from decoder_only_model import DecoderOnlyConfig, DecoderOnlyModel # type:ignore
7+
8+
import sys
9+
import unittest
10+
from typing import Iterable
11+
12+
import torch
13+
14+
import torch_xla
15+
from torch_xla.experimental.apply_layers import apply_layers
16+
17+
from test_utils import XlaTestCase # type:ignore
18+
19+
20+
class ApplyLayersTest(XlaTestCase):
21+
22+
def setUp(self):
23+
super().setUp()
24+
25+
self.device = torch_xla.device()
26+
27+
def test_empty_layers(self):
28+
layers = []
29+
input_data = torch.randn(64).to(self.device)
30+
torch_xla.sync()
31+
output = apply_layers(layers, input_data.clone())
32+
super().compareResults(output, input_data, abs_err=0.0001, rel_err=0.01)
33+
34+
def test_linear_layers(self):
35+
# We want to apply these layers sequentially
36+
import torch.nn as nn
37+
layers = [nn.Linear(64, 64).to(self.device) for _ in range(10)]
38+
input_data = torch.randn(64).to(self.device)
39+
40+
from copy import deepcopy
41+
scan_layers = deepcopy(layers)
42+
loop_layers = deepcopy(layers)
43+
torch_xla.sync()
44+
45+
output = apply_layers(scan_layers, input_data.clone())
46+
output.sum().backward()
47+
48+
# Test that the result is the same as for loop.
49+
loop_output = input_data.clone()
50+
from copy import deepcopy
51+
for layer in loop_layers:
52+
loop_output = layer(loop_output)
53+
torch_xla.sync()
54+
55+
super().compareResults(loop_output, output, abs_err=0.0001, rel_err=0.01)
56+
57+
loop_output.sum().backward()
58+
torch_xla.sync()
59+
60+
# Test that the gradients are the same too.
61+
for layer_scan, layer_loop in zip(scan_layers, loop_layers):
62+
super().compareResults(
63+
layer_scan.weight.grad,
64+
layer_loop.weight.grad,
65+
abs_err=0.0001,
66+
rel_err=0.01)
67+
super().compareResults(
68+
layer_scan.bias.grad,
69+
layer_loop.bias.grad,
70+
abs_err=0.0001,
71+
rel_err=0.01)
72+
73+
def test_decoder_model(self):
74+
# Define a decoder model that composes the decoder model in the example,
75+
# but adds the ability to run the layers with the `scan` operator.
76+
class DecoderOnlyModelWithScan(torch.nn.Module):
77+
78+
def __init__(self, **kwargs):
79+
super(DecoderOnlyModelWithScan, self).__init__()
80+
self.decoder = DecoderOnlyModel(**kwargs)
81+
82+
@property
83+
def layers(self) -> Iterable[torch.nn.Module]:
84+
return self.decoder.layers
85+
86+
def forward(
87+
self,
88+
input_ids: torch.Tensor,
89+
) -> torch.Tensor:
90+
return self.decoder.forward(input_ids)
91+
92+
def forward_scan(
93+
self,
94+
input_ids: torch.Tensor,
95+
) -> torch.Tensor:
96+
inputs_embeds = self.decoder.embed_tokens(input_ids)
97+
# embed positions
98+
assert isinstance(inputs_embeds, torch.Tensor)
99+
# decoder layers
100+
hidden_states = apply_layers(self.decoder.layers, inputs_embeds)
101+
hidden_states = self.decoder.norm(hidden_states)
102+
# [B, S, H] -> [B, S, V]
103+
return self.decoder.output(hidden_states)
104+
105+
# Make it smaller for fast model run and comparisons.
106+
config = DecoderOnlyConfig(
107+
hidden_size=128, intermediate_size=8 * 128, vocab_size=256)
108+
model = DecoderOnlyModelWithScan(config=config).to(self.device)
109+
batch_size = 2
110+
sequence_length = 8
111+
112+
# Generate random input_ids within the range of the vocabulary size
113+
input_ids = torch.randint(0, config.vocab_size,
114+
(batch_size, sequence_length)).to(self.device)
115+
116+
from copy import deepcopy
117+
loop_model = deepcopy(model)
118+
scan_model = deepcopy(model)
119+
torch_xla.sync()
120+
121+
# Run the loop-based model.
122+
loop_output = loop_model(input_ids.clone())
123+
loop_output.sum().backward()
124+
torch_xla.sync()
125+
126+
# Run again, this time using `scan`
127+
scan_output = scan_model.forward_scan(input_ids.clone())
128+
scan_output.sum().backward()
129+
torch_xla.sync()
130+
131+
# Compare results
132+
super().compareResults(scan_output, loop_output, abs_err=0.05, rel_err=0.01)
133+
134+
# Check gradients
135+
for layer_scan, layer_loop in zip(scan_model.layers, loop_model.layers):
136+
for (name,
137+
param_scan), (name2,
138+
param_loop) in zip(layer_scan.named_parameters(),
139+
layer_loop.named_parameters()):
140+
assert name == name2
141+
if param_scan.grad is not None or param_loop.grad is not None:
142+
super().compareResults(
143+
param_scan.grad, param_loop.grad, abs_err=0.1, rel_err=0.05)
144+
print(f"Pass: {name} {param_scan.shape}")
145+
146+
147+
if __name__ == '__main__':
148+
test = unittest.main()
149+
sys.exit(0 if test.result.wasSuccessful() else 1)

test/test_operations.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
import itertools
2222
import math
2323
from numbers import Number
24+
from functools import reduce
2425
import numpy
2526
import random
2627
import re
@@ -2597,6 +2598,29 @@ def test_api(self):
25972598
mapping = ctx.parameter_id_tensor_mapping()
25982599
self.assertEqual(len(mapping), 2)
25992600

2601+
def test_get_parameters_scalar(self):
2602+
"""Scalar tensors parameters may be shared in the HLO graph if their
2603+
numerical values are equal. `parameter_id_tensor_mapping` needs to handle
2604+
that appropriately.
2605+
"""
2606+
2607+
device = xm.xla_device()
2608+
tensors = []
2609+
for i in range(10):
2610+
# Add three copies of the same value.
2611+
tensors.append(torch.tensor(i, device=device))
2612+
tensors.append(torch.tensor(i, device=device))
2613+
tensors.append(torch.tensor(i, device=device))
2614+
result = reduce(lambda a, b: a + b, tensors)
2615+
ctx = torch_xla._XLAC.lowering.LoweringContext()
2616+
ctx.build([result])
2617+
mapping = ctx.parameter_id_tensor_mapping()
2618+
2619+
import json
2620+
hlo_json = json.loads(ctx.hlo_json())
2621+
num_parameters = len(hlo_json["hostProgramShape"]["parameters"])
2622+
self.assertEqual(len(mapping), num_parameters)
2623+
26002624

26012625
class TestGeneric(test_utils.XlaTestCase):
26022626

test/test_scan.py

Lines changed: 51 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,14 @@
11
import sys
22
import unittest
3-
import torch_xla
3+
from functools import reduce
4+
45
import torch
6+
from torch.utils._pytree import tree_map, tree_flatten, tree_iter, tree_leaves, PyTree
7+
8+
import torch_xla
59
from torch_xla.experimental.scan import scan
6-
from torch.utils._pytree import tree_map, tree_flatten, tree_iter
710

8-
from test_utils import XlaTestCase
11+
from test_utils import XlaTestCase # type:ignore
912

1013

1114
def _loopy_scan(fn, init, xs):
@@ -24,6 +27,8 @@ def _loopy_scan(fn, init, xs):
2427
class ScanTest(XlaTestCase):
2528

2629
def setUp(self):
30+
super().setUp()
31+
2732
self.device = torch_xla.device()
2833

2934
def compare_pytree(self, expected_pytree, actual_pytree):
@@ -32,31 +37,54 @@ def compare_pytree(self, expected_pytree, actual_pytree):
3237
assert expected_spec == actual_spec
3338
super().compareResults(flat_expected_pytree, flat_actual_pytree)
3439

35-
def run_test(self, step_fn, init, xs):
40+
def run_test(self, fn, init: PyTree, xs: PyTree):
41+
"""Compares the result of scanning with `fn` with our optimized HLO implementation
42+
against a for loop implementation. Checks both output values and gradients.
43+
"""
3644
# Actual output
37-
final_carry, ys = scan(step_fn, init, xs)
45+
init_scan = tree_map(lambda v: v.detach().requires_grad_(), init)
46+
xs_scan = tree_map(lambda v: v.detach().requires_grad_(), xs)
47+
final_carry, ys = scan(fn, init_scan, xs_scan)
48+
# Add up all leaves in `ys` and `backward()` once.
49+
reduce(lambda a, b: a + b, map(lambda v: v.sum(), tree_leaves(ys)),
50+
torch.tensor(0.0)).backward()
3851
torch_xla.sync()
3952

4053
# Expected output
41-
expected_final_carry, expected_ys = _loopy_scan(step_fn, init, xs)
54+
init_loop = tree_map(lambda v: v.detach().requires_grad_(), init)
55+
xs_loop = tree_map(lambda v: v.detach().requires_grad_(), xs)
56+
expected_final_carry, expected_ys = _loopy_scan(fn, init_loop, xs_loop)
57+
# Add up all leaves in `ys` and `backward()` once.
58+
reduce(lambda a, b: a + b, map(lambda v: v.sum(), tree_leaves(expected_ys)),
59+
torch.tensor(0.0)).backward()
4260
torch_xla.sync()
4361

44-
# Compare
62+
# Compare values
4563
self.compare_pytree(expected_final_carry, final_carry)
4664
self.compare_pytree(expected_ys, ys)
4765

66+
# Compare gradients
67+
self.compare_pytree(
68+
tree_map(lambda v: v.grad, init_scan),
69+
tree_map(lambda v: v.grad, init_loop))
70+
self.compare_pytree(
71+
tree_map(lambda v: v.grad, xs_scan), tree_map(lambda v: v.grad,
72+
xs_loop))
73+
4874
return final_carry, ys
4975

50-
def test_scan_forward_simple(self):
76+
def test_scan_simple(self):
5177
"""This test uses `scan` to implement `torch.cumsum`."""
5278

5379
def step_fn(carry, x):
5480
new_carry = carry + x
5581
y = new_carry
5682
return new_carry, y
5783

58-
init = torch.tensor([0.0, 0.0], device=self.device)
59-
xs = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], device=self.device)
84+
init = torch.tensor([0.0, 0.0], requires_grad=True, device=self.device)
85+
xs = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]],
86+
requires_grad=True,
87+
device=self.device)
6088
final_carry, ys = self.run_test(step_fn, init, xs)
6189

6290
# Also ensure that our loop-based scan is correct, with manual checks
@@ -80,26 +108,30 @@ def test_scan_incompatible_length(self):
80108
with self.assertRaises(ValueError):
81109
scan(lambda a, b: (a, b), init, (xs_1, xs_2))
82110

83-
def test_scan_forward_tuples(self):
111+
def test_scan_tuples(self):
84112
"""Test scanning over the leading axis of a tuple of tensors simultaneously,
85113
which is a simple PyTree."""
86114

87-
def step_fn(carry, x):
115+
def fn(carry, x):
88116
carry1, carry2 = carry
89117
x1, x2 = x
90118
new_carry1 = carry1 + x1.sum()
91119
new_carry2 = carry2 + x2.sum()
92-
y1 = x1 * 2
93-
y2 = x2 * 2
120+
y1 = x1 * 2 + torch.sum(new_carry1)
121+
y2 = x2 * 2 + torch.sum(new_carry2)
94122
return (new_carry1, new_carry2), (y1, y2)
95123

96-
init = (torch.tensor([0.0], device=self.device),
97-
torch.tensor([1.0, 2.0], device=self.device))
124+
init = (torch.tensor([0.0], requires_grad=True, device=self.device),
125+
torch.tensor([1.0, 2.0], requires_grad=True, device=self.device))
98126

99-
xs = (torch.tensor([[1.0, 2.0], [3.0, 4.0]], device=self.device),
100-
torch.tensor([[5.0, 6.0, 7.0], [8.0, 9.0, 10.0]], device=self.device))
127+
xs = (torch.tensor([[1.0, 2.0], [3.0, 4.0]],
128+
requires_grad=True,
129+
device=self.device),
130+
torch.tensor([[5.0, 6.0, 7.0], [8.0, 9.0, 10.0]],
131+
requires_grad=True,
132+
device=self.device))
101133

102-
self.run_test(step_fn, init, xs)
134+
self.run_test(fn, init, xs)
103135

104136

105137
if __name__ == '__main__':

test/tpu/run_tests.sh

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ python3 test/pjrt/test_dtypes.py
2525
python3 test/pjrt/test_dynamic_plugin_tpu.py
2626
python3 test/test_while_loop.py
2727
python3 test/test_scan.py
28+
python3 test/test_apply_layers.py
2829
python3 test/test_pallas.py
2930
python3 test/test_pallas_spmd.py
3031
python3 test/test_input_output_aliases.py

torch_xla/csrc/init_python_bindings.cpp

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1064,7 +1064,6 @@ class PyLoweringContext {
10641064
// etc.)
10651065
std::unordered_map<int64_t, at::Tensor> GetParameterIdTensorMapping() {
10661066
// Find parameters in the lowering
1067-
const std::vector<size_t>& param_ids = lowering_ctx.GetParameterSequence();
10681067
const std::vector<torch::lazy::BackendDataPtr>& device_data =
10691068
lowering_ctx.GetParametersData();
10701069

@@ -1081,7 +1080,9 @@ class PyLoweringContext {
10811080
at::ScalarType dtype =
10821081
MaybeUpcastToHostTorchType(literal.shape().element_type());
10831082
at::Tensor input = MakeTensorFromXlaLiteral(literal, dtype);
1084-
results[param_ids[i]] = input;
1083+
std::optional param_id = lowering_ctx.GetParameterId(device_data[i]);
1084+
XLA_CHECK(param_id.has_value());
1085+
results[param_id.value()] = input;
10851086
}
10861087
return results;
10871088
}
@@ -1104,12 +1105,13 @@ class PyLoweringContext {
11041105
torch::lazy::BackendData::Handle handle = data->GetHandle();
11051106

11061107
// Linearly search parameters and compare opaque handles
1107-
const std::vector<size_t>& param_ids = lowering_ctx.GetParameterSequence();
11081108
const std::vector<torch::lazy::BackendDataPtr>& device_data =
11091109
lowering_ctx.GetParametersData();
11101110
for (int i = 0; i < device_data.size(); ++i) {
11111111
if (device_data[i]->GetHandle() == handle) {
1112-
return param_ids[i];
1112+
std::optional param_id = lowering_ctx.GetParameterId(device_data[i]);
1113+
XLA_CHECK(param_id.has_value());
1114+
return param_id.value();
11131115
}
11141116
}
11151117
return -1;

0 commit comments

Comments
 (0)