
Revert "[KVCACHE] Improved schedule for prefill attention" #17466

Merged (1 commit) on Oct 15, 2024

Conversation

MasterJH5574 (Contributor)

This PR reverts #17432 because we observed a correctness issue
when num_attention_heads is 28.

The correctness issue leads to incorrect end-to-end results in LLM inference.
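As context for why 28 query heads fail while 32 pass, here is a small arithmetic sketch. The tile size below is a hypothetical chosen for illustration, not taken from the kernel; the point is only that a group count of 7 is not a power of two and leaves a tail iteration under power-of-two tiling.

```python
# Hypothetical illustration: in grouped-query attention (GQA), query heads
# are partitioned into groups that share one KV head. A schedule that tiles
# over the group dimension can silently assume the group count divides the
# tile size.

def gqa_group_count(num_qo_heads: int, num_kv_heads: int) -> int:
    """Number of query heads sharing each KV head."""
    assert num_qo_heads % num_kv_heads == 0
    return num_qo_heads // num_kv_heads

groups_32 = gqa_group_count(32, 4)  # 8 groups: the configuration the unit test covers
groups_28 = gqa_group_count(28, 4)  # 7 groups: the configuration that fails

# With a hypothetical tile size of 8, 8 groups divide evenly while 7 leave
# a tail iteration, which is where an incorrect schedule can misbehave.
TILE = 8
print(groups_32 % TILE == 0)  # True
print(groups_28 % TILE == 0)  # False
```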

MasterJH5574 (Contributor, Author) commented on Oct 14, 2024

@krishnaraj36 Hi! Thank you for the great contribution on the prefill attention improvement. Unfortunately, we just ran into a correctness issue caused by this PR, so we have decided to revert it temporarily. Specifically, the prefill kernel produces incorrect results when num_qo_heads is 28 (num_kv_heads is 4, so the number of GQA groups is 7). The current unit test uses 32 as num_qo_heads, for which the improved kernel works correctly, so the test does not reveal the issue.

Here is how you can reproduce the issue:

  1. Replace this line with num_qo_heads=28.
  2. Run the test via
     python tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py
Then it should show an error like:

~/W/tvm workspace ⇡1 *3 !2 ?2 ❯ python tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py
Traceback (most recent call last):
  File "/home/ruihang/Workspace/tvm/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py", line 964, in <module>
    test_paged_attention_kv_cache_prefill_and_decode(cache_and_config)
  File "/home/ruihang/Workspace/tvm/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py", line 558, in test_paged_attention_kv_cache_prefill_and_decode
    apply_attention(kv_cache, rope_mode, batch, cached_k, cached_v)
  File "/home/ruihang/Workspace/tvm/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py", line 468, in apply_attention
    tvm.testing.assert_allclose(
  File "/home/ruihang/Workspace/tvm/python/tvm/testing/utils.py", line 120, in assert_allclose
    np.testing.assert_allclose(actual, desired, rtol=rtol, atol=atol, verbose=True)
  File "/home/ruihang/Workspace/miniconda3/envs/python311/lib/python3.11/site-packages/numpy/testing/_private/utils.py", line 1504, in assert_allclose
    assert_array_compare(compare, actual, desired, err_msg=str(err_msg),
  File "/home/ruihang/Workspace/miniconda3/envs/python311/lib/python3.11/contextlib.py", line 81, in inner
    return func(*args, **kwds)
           ^^^^^^^^^^^^^^^^^^^
  File "/home/ruihang/Workspace/miniconda3/envs/python311/lib/python3.11/site-packages/numpy/testing/_private/utils.py", line 797, in assert_array_compare
    raise AssertionError(msg)
AssertionError: 
Not equal to tolerance rtol=0.001, atol=0.001

Mismatched elements: 4852 / 10752 (45.1%)
Max absolute difference: 0.997
Max relative difference: 86.7
 x: array([[[[0.1501 , 0.9165 , 0.381  , ..., 0.7383 , 0.2344 , 0.04178],
         [0.1501 , 0.9165 , 0.381  , ..., 0.7383 , 0.2344 , 0.04178],
         [0.     , 0.     , 0.     , ..., 0.     , 0.     , 0.     ],...
 y: array([[[[0.1501 , 0.9165 , 0.381  , ..., 0.7383 , 0.2344 , 0.04178],
         [0.1501 , 0.9165 , 0.381  , ..., 0.7383 , 0.2344 , 0.04178],
         [0.1501 , 0.9165 , 0.381  , ..., 0.7383 , 0.2344 , 0.04178],...
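For reference, the failing check above comes down to NumPy's element-wise tolerance test. Below is a minimal standalone sketch of the same failure mode, with values taken from the printed arrays; it is an illustration of the assertion, not the TVM test itself.

```python
import numpy as np

# Minimal sketch of the check that fails above: assert_allclose flags an
# element when |actual - desired| > atol + rtol * |desired|.
desired = np.array([0.1501, 0.9165, 0.3810])   # like a row of y
actual = np.array([0.1501, 0.9165, 0.0])       # like the zeroed row of x

try:
    np.testing.assert_allclose(actual, desired, rtol=1e-3, atol=1e-3)
    raised = False
except AssertionError:
    raised = True
print(raised)  # True: |0.0 - 0.381| far exceeds atol + rtol * 0.381
```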

I think we are good to go with the improved kernel once the correctness issue is fixed. Would you mind taking a look? Thanks a lot in advance.

MasterJH5574 (Contributor, Author)

BTW, one more piece of information on the error: the kernel produces nondeterministic results; that is, if I run the test multiple times, the kernel produces different results each time.
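Run-to-run differences like this are usually confirmed by comparing repeated runs of the same computation. Here is a hedged sketch of that check; run_kernel is a deterministic placeholder so the snippet is runnable on its own, not the TVM attention kernel.

```python
import numpy as np

def run_kernel(x: np.ndarray) -> np.ndarray:
    # Placeholder for the attention kernel under test; a deterministic
    # stand-in so this sketch is self-contained.
    return x * 2.0

x = np.ones((4, 4), dtype=np.float32)
outputs = [run_kernel(x) for _ in range(5)]

# Bitwise-identical outputs across runs indicate determinism; any
# run-to-run difference would point at a race in the real kernel.
deterministic = all(np.array_equal(outputs[0], o) for o in outputs[1:])
print(deterministic)  # True for this stand-in
```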

krishnaraj36 (Contributor)


@MasterJH5574: Thanks for reporting this issue. Sure, we will look into it.

tqchen merged commit 0c67cd8 into main on Oct 15, 2024
32 checks passed