Skip to content

Commit 3bee673

Browse files
committed
add a warmup aclgraph run for ci
Signed-off-by: ganyi <pleaplusone.gy@gmail.com>
1 parent 7c9a5ed commit 3bee673

File tree

1 file changed

+7
-4
lines changed

1 file changed

+7
-4
lines changed

tests/e2e/singlecard/ops/test_rotary_embedding.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
# Only Neox style true scenario is supported for now
1818
IS_NEOX_STYLE = [True]
1919
DTYPES = [torch.half]
20-
HEAD_SIZES = [64, 96, 128, 256]
20+
HEAD_SIZES = [64, 64, 96, 128, 256]
2121
ROTARY_DIMS = [None, 32] # None means rotary dim == head size
2222
NUM_HEADS = [17] # Arbitrary values for testing
2323
BATCH_SIZES = [5] # Arbitrary values for testing
@@ -248,7 +248,9 @@ def forward(
248248
o = self.o_proj(query)
249249
return o
250250

251-
251+
# The first graph seems will have some accuracy issue when directly run pytest on the ops folder,
252+
# add a warmup graph replay for workaround
253+
ACL_GRPAH_FIRST_RUN = True
252254
@pytest.mark.parametrize("is_neox_style", IS_NEOX_STYLE)
253255
@pytest.mark.parametrize("num_tokens", BATCH_SIZES)
254256
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@@ -327,10 +329,11 @@ def custom_op_checking_backend(gm: torch.fx.GraphModule, example_input):
327329
static_positions.copy_(random_filled_positions)
328330
static_hidden_states.copy_(random_filled_hidden_states)
329331

330-
# The first graph seems will have some accuracy issue when directly run pytest on the ops folder,
331-
# add a warmup graph replay for workaround
332332
aclgraph.replay()
333333
aclgraph.replay()
334+
if ACL_GRPAH_FIRST_RUN:
335+
ACL_GRPAH_FIRST_RUN = False
336+
return
334337
output_reference = model(static_positions, static_hidden_states)
335338
torch.testing.assert_close(static_output,
336339
output_reference,

0 commit comments

Comments
 (0)