
Commit 406386d

LiuXiaoxuanPKU authored and jimpang committed
[Minor] More fix of test_cache.py CI test failure (vllm-project#2750)
1 parent a47908b commit 406386d

Showing 2 changed files with 16 additions and 11 deletions.

tests/kernels/test_cache.py

Lines changed: 4 additions & 5 deletions
@@ -181,16 +181,15 @@ def test_swap_blocks(
     num_blocks: int,
     dtype: torch.dtype,
     seed: int,
-    device: int,
+    device: str,
 ) -> None:
     random.seed(seed)
     torch.random.manual_seed(seed)
     if torch.cuda.is_available():
         torch.cuda.manual_seed(seed)
-    src_device = f"{direction[0]}:{device}" if direction[
-        0] == "cuda" else direction[0]
-    dst_device = f"{direction[1]}:{device}" if direction[
-        1] == "cuda" else direction[1]
+
+    src_device = device if direction[0] == "cuda" else 'cpu'
+    dst_device = device if direction[1] == "cuda" else 'cpu'
 
     src_blocks = random.sample(range(num_blocks), num_mappings)
     # For the same device, mapping must not overlap
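
For reference, a minimal standalone sketch of the device-selection logic this change arrives at: the `device` test parameter is now a full device string rather than an integer index, so the test no longer assembles "cuda:<index>" itself. The `direction` tuple and `device` string mirror the test parameters in the diff; the concrete values below are illustrative only, not taken from the test suite.

# Sketch of the device-selection logic after this change (illustrative values).
direction = ("cuda", "cpu")   # copy direction: source device kind, destination device kind
device = "cuda:0"             # now a full device string, no longer an int index

# Each endpoint uses the parametrized CUDA device when its side of the
# direction names "cuda", and plain "cpu" otherwise.
src_device = device if direction[0] == "cuda" else "cpu"
dst_device = device if direction[1] == "cuda" else "cpu"

assert (src_device, dst_device) == ("cuda:0", "cpu")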

vllm/utils.py

Lines changed: 12 additions & 6 deletions
@@ -258,10 +258,13 @@ def create_kv_caches_with_random(
         key_cache = torch.empty(size=key_cache_shape,
                                 dtype=torch_dtype,
                                 device=device)
-        if cache_dtype in ["auto", "half", "bfloat16", "float"]:
-            key_cache.uniform_(-scale, scale)
-        elif cache_dtype == 'fp8_e5m2':
+        if cache_dtype == 'fp8_e5m2':
             _generate_random_fp8_e5m2(key_cache, -scale, scale)
+        elif torch_dtype in [torch.half, torch.bfloat16, torch.float]:
+            key_cache.uniform_(-scale, scale)
+        else:
+            raise ValueError(
+                f"Does not support key cache of type {cache_dtype}")
         key_caches.append(key_cache)
 
     value_cache_shape = (num_blocks, num_heads, head_size, block_size)
@@ -270,9 +273,12 @@ def create_kv_caches_with_random(
         value_cache = torch.empty(size=value_cache_shape,
                                   dtype=torch_dtype,
                                   device=device)
-        if cache_dtype in ["auto", "half", "bfloat16", "float"]:
-            value_cache.uniform_(-scale, scale)
-        elif cache_dtype == 'fp8_e5m2':
+        if cache_dtype == 'fp8_e5m2':
             _generate_random_fp8_e5m2(value_cache, -scale, scale)
+        elif torch_dtype in [torch.half, torch.bfloat16, torch.float]:
+            value_cache.uniform_(-scale, scale)
+        else:
+            raise ValueError(
+                f"Does not support value cache of type {cache_dtype}")
         value_caches.append(value_cache)
     return key_caches, value_caches
