Skip to content

Commit 51575db

Browse files
authored
scan and scan_layers (#7901)
1 parent 31d348e commit 51575db

File tree

13 files changed

+1610
-150
lines changed

13 files changed

+1610
-150
lines changed

examples/decoder_only_model.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,16 +7,16 @@
77
from torch import nn
88

99

10-
# the default config is intentionally kept low to make it runable on a sigle tpu v2-8 core.
10+
# the default config is intentionally kept low to make it runnable on a single tpu v2-8 core.
1111
@dataclass
1212
class DecoderOnlyConfig:
1313
hidden_size: int = 1024
1414
num_hidden_layers: int = 2
1515
num_attention_heads: int = 8
1616
num_key_value_heads: int = 4
17-
intermediate_size = 32 * 1024
18-
vocab_size = 3200
19-
use_flash_attention = False
17+
intermediate_size: int = 32 * 1024
18+
vocab_size: int = 3200
19+
use_flash_attention: bool = False
2020

2121

2222
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:

test/run_tests.sh

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -208,7 +208,8 @@ function run_xla_op_tests1 {
208208
function run_xla_op_tests2 {
209209
run_test "$CDIR/pjrt/test_dtypes.py"
210210
run_test "$CDIR/test_while_loop.py"
211-
run_test "$CDIR/test_scan.py"
211+
run_test "$CDIR/scan/test_scan.py"
212+
run_test "$CDIR/scan/test_scan_layers.py"
212213
run_test "$CDIR/test_autocast.py"
213214
run_test "$CDIR/eager/test_eager.py"
214215
run_test "$CDIR/eager/test_eager_with_xla_compile.py"

0 commit comments

Comments
 (0)