|  | 
| 1 | 1 | # SPDX-License-Identifier: Apache-2.0 | 
| 2 | 2 | # SPDX-FileCopyrightText: Copyright contributors to the vLLM project | 
| 3 | 3 | 
 | 
| 4 |  | -import random | 
| 5 |  | - | 
| 6 | 4 | import numpy as np | 
| 7 | 5 | import pytest | 
| 8 | 6 | import torch | 
| @@ -409,29 +407,30 @@ def test_kv_cache_stride_order(monkeypatch, model_runner): | 
| 409 | 407 |         model_runner.model_config.get_head_size() | 
| 410 | 408 |     ] | 
| 411 | 409 |     # TODO mla test | 
| 412 |  | -    default_stride = list(range(5)) | 
|  | 410 | +    default_stride = tuple(range(5)) | 
| 413 | 411 |     # Permutation that gets you back to expected kv shape | 
| 414 |  | -    rnd_stride = tuple(random.sample(default_stride, len(default_stride))) | 
| 415 |  | - | 
| 416 |  | -    def rnd_stride_order(): | 
| 417 |  | -        return rnd_stride | 
| 418 |  | - | 
| 419 |  | -    # Patch the attention backend class and re-trigger the KV cache creation. | 
| 420 |  | -    for attn_group in model_runner._attn_group_iterator(): | 
| 421 |  | -        attn_backend = attn_group.backend | 
| 422 |  | -        monkeypatch.setattr(attn_backend, "get_kv_cache_stride_order", | 
| 423 |  | -                            rnd_stride_order) | 
| 424 |  | - | 
| 425 |  | -    model_runner.attn_groups = [] | 
| 426 |  | -    model_runner.initialize_kv_cache(model_runner.kv_cache_config) | 
| 427 |  | - | 
| 428 |  | -    # Shape is unchanged, but layout may differ | 
| 429 |  | -    kv_cache_shape = model_runner.kv_caches[0].shape | 
| 430 |  | -    assert list(kv_cache_shape) == expected_kv_cache_shape | 
| 431 |  | -    if default_stride == rnd_stride: | 
| 432 |  | -        assert all(kv.is_contiguous() for kv in model_runner.kv_caches) | 
| 433 |  | -    else: | 
| 434 |  | -        assert all(not kv.is_contiguous() for kv in model_runner.kv_caches) | 
|  | 412 | +    for test_stride in ((1, 4, 0, 2, 3), (0, 1, 2, 3, 4)): | 
|  | 413 | + | 
|  | 414 | +        def rnd_stride_order(test_stride=test_stride): | 
|  | 415 | +            return test_stride | 
|  | 416 | + | 
|  | 417 | +        # Patch the attention backend class and re-trigger the KV cache creation | 
|  | 418 | +        for attn_group in model_runner._attn_group_iterator(): | 
|  | 419 | +            attn_backend = attn_group.backend | 
|  | 420 | +            monkeypatch.setattr(attn_backend, "get_kv_cache_stride_order", | 
|  | 421 | +                                rnd_stride_order) | 
|  | 422 | + | 
|  | 423 | +        model_runner.attn_groups = [] | 
|  | 424 | +        model_runner.kv_caches = [] | 
|  | 425 | +        model_runner.initialize_kv_cache(model_runner.kv_cache_config) | 
|  | 426 | + | 
|  | 427 | +        # Shape is unchanged, but layout may differ | 
|  | 428 | +        kv_cache_shape = model_runner.kv_caches[0].shape | 
|  | 429 | +        assert list(kv_cache_shape) == expected_kv_cache_shape | 
|  | 430 | +        if default_stride == test_stride: | 
|  | 431 | +            assert all(kv.is_contiguous() for kv in model_runner.kv_caches) | 
|  | 432 | +        else: | 
|  | 433 | +            assert all(not kv.is_contiguous() for kv in model_runner.kv_caches) | 
| 435 | 434 | 
 | 
| 436 | 435 | 
 | 
| 437 | 436 | def test_update_config(model_runner): | 
|  | 
0 commit comments