Skip to content

Commit ae7bc73

Browse files
committed
Update test to perform assertion on CPU.
Signed-off-by: Hyesoo Yang <hyeygit@gmail.com>
1 parent a9f1987 commit ae7bc73

File tree

1 file changed

+5
-1
lines changed

1 file changed

+5
-1
lines changed

tests/v1/tpu/test_topk_topp_sampler.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,11 @@ def test_topk_equivalence_to_native_impl():
3333
result_tpu = apply_top_k_top_p_tpu(logits=logits.clone(), k=k, p=None)
3434

3535
result_native = apply_top_k_top_p(logits=logits.clone(), k=k, p=None)
36-
assert torch.allclose(result_native, result_tpu)
36+
37+
xm.mark_step()
38+
39+
# Perform assertion on CPU.
40+
assert torch.allclose(result_native.cpu(), result_tpu.cpu())
3741

3842

3943
def test_topp_result_sums_past_p():

0 commit comments

Comments
 (0)