|
1 | 1 | # SPDX-License-Identifier: Apache-2.0 |
2 | 2 | # SPDX-FileCopyrightText: Copyright contributors to the vLLM project |
3 | 3 |
|
4 | | -import random |
5 | | - |
6 | 4 | import numpy as np |
7 | 5 | import pytest |
8 | 6 | import torch |
@@ -409,29 +407,30 @@ def test_kv_cache_stride_order(monkeypatch, model_runner): |
409 | 407 | model_runner.model_config.get_head_size() |
410 | 408 | ] |
411 | 409 | # TODO mla test |
412 | | - default_stride = list(range(5)) |
| 410 | + default_stride = tuple(range(5)) |
413 | 411 | # Permutation that gets you back to expected kv shape |
414 | | - rnd_stride = tuple(random.sample(default_stride, len(default_stride))) |
415 | | - |
416 | | - def rnd_stride_order(): |
417 | | - return rnd_stride |
418 | | - |
419 | | - # Patch the attention backend class and re-trigger the KV cache creation. |
420 | | - for attn_group in model_runner._attn_group_iterator(): |
421 | | - attn_backend = attn_group.backend |
422 | | - monkeypatch.setattr(attn_backend, "get_kv_cache_stride_order", |
423 | | - rnd_stride_order) |
424 | | - |
425 | | - model_runner.attn_groups = [] |
426 | | - model_runner.initialize_kv_cache(model_runner.kv_cache_config) |
427 | | - |
428 | | - # Shape is unchanged, but layout may differ |
429 | | - kv_cache_shape = model_runner.kv_caches[0].shape |
430 | | - assert list(kv_cache_shape) == expected_kv_cache_shape |
431 | | - if default_stride == rnd_stride: |
432 | | - assert all(kv.is_contiguous() for kv in model_runner.kv_caches) |
433 | | - else: |
434 | | - assert all(not kv.is_contiguous() for kv in model_runner.kv_caches) |
| 412 | + for test_stride in ((1, 4, 0, 2, 3), (0, 1, 2, 3, 4)): |
| 413 | + |
| 414 | + def rnd_stride_order(test_stride=test_stride): |
| 415 | + return test_stride |
| 416 | + |
| 417 | + # Patch the attention backend class and re-trigger the KV cache creation |
| 418 | + for attn_group in model_runner._attn_group_iterator(): |
| 419 | + attn_backend = attn_group.backend |
| 420 | + monkeypatch.setattr(attn_backend, "get_kv_cache_stride_order", |
| 421 | + rnd_stride_order) |
| 422 | + |
| 423 | + model_runner.attn_groups = [] |
| 424 | + model_runner.kv_caches = [] |
| 425 | + model_runner.initialize_kv_cache(model_runner.kv_cache_config) |
| 426 | + |
| 427 | + # Shape is unchanged, but layout may differ |
| 428 | + kv_cache_shape = model_runner.kv_caches[0].shape |
| 429 | + assert list(kv_cache_shape) == expected_kv_cache_shape |
| 430 | + if default_stride == test_stride: |
| 431 | + assert all(kv.is_contiguous() for kv in model_runner.kv_caches) |
| 432 | + else: |
| 433 | + assert all(not kv.is_contiguous() for kv in model_runner.kv_caches) |
435 | 434 |
|
436 | 435 |
|
437 | 436 | def test_update_config(model_runner): |
|
0 commit comments