diff --git a/tests/kernels/attention/test_flash_attn.py b/tests/kernels/attention/test_flash_attn.py
index 572563c0bd82..88516b75cde2 100644
--- a/tests/kernels/attention/test_flash_attn.py
+++ b/tests/kernels/attention/test_flash_attn.py
@@ -145,7 +145,7 @@ def test_flash_attn_with_paged_kv(
     v_descale = None
     if q_dtype is not None:
         # QKV are drawn from N(0, 1): no need for a fp8 scaling factor
-        maybe_quantized_query = query.to(q_dtype)
+        maybe_quantized_query = q.to(q_dtype)
         maybe_quantized_key_cache = key_cache.to(q_dtype)
         maybe_quantized_value_cache = value_cache.to(q_dtype)
 