From a9e254bac84cdfd4e3ab00e26855646a6b23b8b2 Mon Sep 17 00:00:00 2001 From: Vladimir Date: Wed, 24 Jan 2024 09:51:56 -0800 Subject: [PATCH] use a correct device when creating OptionalCUDAGuard --- csrc/cache_kernels.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/cache_kernels.cu b/csrc/cache_kernels.cu index 9f173534070a..b7523cb4c3b5 100644 --- a/csrc/cache_kernels.cu +++ b/csrc/cache_kernels.cu @@ -34,7 +34,7 @@ void swap_blocks( char *dst_ptr = static_cast(dst.data_ptr()); const int64_t block_size_in_bytes = src.element_size() * src[0].numel(); - const at::cuda::OptionalCUDAGuard device_guard(src_device); + const at::cuda::OptionalCUDAGuard device_guard(src_device.is_cuda() ? src_device : dst_device); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); // NOTE(woosuk): This can be slow if the number of blocks is large. for (const auto& pair : block_mapping) {