
Commit 0082c1b

scan and apply_layers: milestone 1
This commit adds the lowering of scan to the HLO While op. It also introduces apply_layers, which can sequentially apply a list of layers using scan underneath. In this milestone we use AOTAutograd to obtain the backward of the function being scanned. Users can either save the activations in fn or recompute them by passing different graph partitioners to AOTAutograd. Also give the lowered fn computation a more meaningful name.
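For context, here is a minimal sketch of the jax.lax.scan-style contract that scan follows: fn consumes a carry and one slice of the input and returns the new carry plus one slice of the output. The import path torch_xla.experimental.scan is an assumption inferred from the new test location (test/scan/test_scan.py); apply_layers composes the same mechanism over a sequence of layers.

import torch
import torch_xla.core.xla_model as xm
from torch_xla.experimental.scan import scan  # assumed module path

def step(carry, x):
    # One iteration of the loop body; the whole scan lowers to a single
    # HLO While op instead of an unrolled graph.
    new_carry = carry + x
    y = new_carry * 2
    return new_carry, y

device = xm.xla_device()
init = torch.zeros(4, device=device)    # carry
xs = torch.ones(10, 4, device=device)   # leading dim is the scan axis
final_carry, ys = scan(step, init, xs)  # ys stacks the per-step outputs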
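The save-vs-recompute trade-off mentioned above is controlled by AOTAutograd's partition function. A minimal sketch using functorch's public entry points, shown independently of the new scan API (how scan itself threads the partitioner through is not shown in this excerpt):

import torch
from functorch.compile import (aot_function, default_partition,
                               min_cut_rematerialization_partition)

def nop(gm, sample_inputs):
    # No-op "compiler": run the partitioned FX graph as-is.
    return gm

def f(x):
    return torch.sin(x).sum()

# default_partition saves intermediate activations for the backward pass.
f_saved = aot_function(f, fw_compiler=nop, partition_fn=default_partition)
# min_cut_rematerialization_partition recomputes them in the backward
# instead, trading compute for memory.
f_remat = aot_function(f, fw_compiler=nop,
                       partition_fn=min_cut_rematerialization_partition)

x = torch.randn(8, requires_grad=True)
f_saved(x).backward()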

File tree

13 files changed: +1610 −150 lines

examples/decoder_only_model.py

Lines changed: 4 additions & 4 deletions
@@ -7,16 +7,16 @@
 from torch import nn
 
 
-# the default config is intentionally kept low to make it runable on a sigle tpu v2-8 core.
+# the default config is intentionally kept low to make it runnable on a single tpu v2-8 core.
 @dataclass
 class DecoderOnlyConfig:
   hidden_size: int = 1024
   num_hidden_layers: int = 2
   num_attention_heads: int = 8
   num_key_value_heads: int = 4
-  intermediate_size = 32 * 1024
-  vocab_size = 3200
-  use_flash_attention = False
+  intermediate_size: int = 32 * 1024
+  vocab_size: int = 3200
+  use_flash_attention: bool = False
 
 
 def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
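A note on why the annotation fix in this hunk matters: @dataclass only promotes annotated class attributes to fields, so the un-annotated assignments were plain class attributes that the generated __init__ could not override. A minimal illustration:

from dataclasses import dataclass, fields

@dataclass
class Config:
    hidden_size: int = 1024  # annotated: becomes a dataclass field
    vocab_size = 3200        # un-annotated: plain class attribute, not a field

print([f.name for f in fields(Config)])  # ['hidden_size']
# Config(vocab_size=100) would raise a TypeError, since vocab_size is not
# an __init__ parameter -- which is why the diff adds the annotations.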

test/run_tests.sh

Lines changed: 2 additions & 1 deletion
@@ -208,7 +208,8 @@ function run_xla_op_tests1 {
 function run_xla_op_tests2 {
   run_test "$CDIR/pjrt/test_dtypes.py"
   run_test "$CDIR/test_while_loop.py"
-  run_test "$CDIR/test_scan.py"
+  run_test "$CDIR/scan/test_scan.py"
+  run_test "$CDIR/scan/test_scan_layers.py"
   run_test "$CDIR/test_autocast.py"
   run_test "$CDIR/eager/test_eager.py"
   run_test "$CDIR/eager/test_eager_with_xla_compile.py"
