diff --git a/vllm/attention/backends/pallas.py b/vllm/attention/backends/pallas.py
index 121ca9ec45205..5dec11e2eede7 100644
--- a/vllm/attention/backends/pallas.py
+++ b/vllm/attention/backends/pallas.py
@@ -37,11 +37,10 @@ def swap_blocks(
     ) -> None:
         src_k_cache, src_v_cache = src_kv_cache
         dst_k_cache, dst_v_cache = dst_kv_cache
+        src_indices, dst_indices = src_to_dst
+        device = dst_k_cache.device
         torch.ops.xla.dynamo_set_buffer_donor_(dst_k_cache, True)
         torch.ops.xla.dynamo_set_buffer_donor_(dst_v_cache, True)
-
-        device = dst_k_cache.device
-        src_indices, dst_indices = src_to_dst
         dst_k_cache[:, dst_indices] = src_k_cache[:, src_indices].to(device)
         dst_v_cache[:, dst_indices] = src_v_cache[:, src_indices].to(device)
 
diff --git a/vllm/worker/tpu_worker.py b/vllm/worker/tpu_worker.py
index c85bf6892fb28..28f460c31aa9b 100644
--- a/vllm/worker/tpu_worker.py
+++ b/vllm/worker/tpu_worker.py
@@ -156,14 +156,18 @@ def initialize_cache(
         self.tpu_cache = []
         tpu_cache_shape = self.model_runner.attn_backend.get_kv_cache_shape(
             num_gpu_blocks, self.block_size, num_kv_heads, head_size)
+        cpu_cache_shape = self.model_runner.attn_backend.get_kv_cache_shape(
+            num_cpu_blocks, self.block_size, num_kv_heads, head_size)
         for _ in range(num_layers):
             tpu_k_cache = torch.zeros(tpu_cache_shape,
                                       dtype=dtype,
                                       device=self.device)
             tpu_v_cache = torch.zeros_like(tpu_k_cache)
             self.tpu_cache.append((tpu_k_cache, tpu_v_cache))
-            cpu_k_cache = torch.zeros_like(tpu_k_cache, device="cpu")
-            cpu_v_cache = torch.zeros_like(tpu_v_cache, device="cpu")
+            cpu_k_cache = torch.zeros(cpu_cache_shape,
+                                      dtype=dtype,
+                                      device="cpu")
+            cpu_v_cache = torch.zeros_like(cpu_k_cache)
             self.cpu_cache.append((cpu_k_cache, cpu_v_cache))
         self._warmup_model()
 