diff --git a/flashinfer-aot/csrc_aot/batch_decode.cu b/flashinfer-aot/csrc_aot/batch_decode.cu
index fbf6381e..95945a5c 100644
--- a/flashinfer-aot/csrc_aot/batch_decode.cu
+++ b/flashinfer-aot/csrc_aot/batch_decode.cu
@@ -43,7 +43,7 @@ std::vector<int64_t> BatchDecodeWithPagedKVCachePlan(
       int_workspace_buffer.size(0) * int_workspace_buffer.element_size();
   auto device = float_workspace_buffer.device();
   cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
-  indptr = indptr.to(torch::kCPU);
+  TORCH_CHECK(indptr.device() == torch::kCPU, "indptr must be on CPU");
 
   DecodePlanInfo plan_info;
 
@@ -150,8 +150,7 @@ std::vector<torch::Tensor> BatchDecodeWithPagedKVCacheRun(
         paged_kv_t<DTypeKV, IdType> paged_kv(
             num_kv_heads, page_size, HEAD_DIM, batch_size, kv_layout,
             static_cast<DTypeKV*>(paged_k_cache.data_ptr()),
-            static_cast<DTypeKV*>(paged_v_cache.data_ptr()),
-            kv_cache_strides,
+            static_cast<DTypeKV*>(paged_v_cache.data_ptr()), kv_cache_strides,
             static_cast<IdType*>(paged_kv_indices.data_ptr()),
             static_cast<IdType*>(paged_kv_indptr.data_ptr()),
             static_cast<IdType*>(paged_kv_last_page_len.data_ptr()));
diff --git a/flashinfer-aot/csrc_aot/batch_prefill.cu b/flashinfer-aot/csrc_aot/batch_prefill.cu
index 5a5b630d..448f4a9f 100644
--- a/flashinfer-aot/csrc_aot/batch_prefill.cu
+++ b/flashinfer-aot/csrc_aot/batch_prefill.cu
@@ -51,8 +51,8 @@ std::vector<int64_t> BatchPrefillWithKVCachePlan(
 
   auto device = float_workspace_buffer.device();
   cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
-  qo_indptr = qo_indptr.to(torch::kCPU);
-  kv_indptr = kv_indptr.to(torch::kCPU);
+  TORCH_CHECK(qo_indptr.device() == torch::kCPU, "qo_indptr must be on CPU");
+  TORCH_CHECK(kv_indptr.device() == torch::kCPU, "kv_indptr must be on CPU");
 
   PrefillPlanInfo plan_info;
 
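
With the `TORCH_CHECK`s above, the AOT planners no longer pull the indptr tensors over to the host themselves (an implicit, blocking GPU-to-CPU copy inside `plan`); they now require the caller to pass CPU tensors and fail fast otherwise. A minimal caller-side sketch of that contract (`plan_op` is a hypothetical stand-in for the bound planner, not the real symbol):

```python
import torch

# Build the page-table indptr on the host once; the planner only reads it on
# the CPU, so it never needs to live on the GPU in the first place.
kv_indptr_host = torch.tensor([0, 2, 5, 9], dtype=torch.int32, device="cpu")

# Mirrors the new TORCH_CHECK: a device tensor is now an error instead of a
# silent synchronizing copy inside the C++ planner.
assert kv_indptr_host.device.type == "cpu"
# plan_op(float_workspace, int_workspace, pinned_int_workspace, kv_indptr_host, ...)
```
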
diff --git a/python/flashinfer/decode.py b/python/flashinfer/decode.py
index 4eac068b..1584846b 100644
--- a/python/flashinfer/decode.py
+++ b/python/flashinfer/decode.py
@@ -53,7 +53,8 @@ def compile_single_decode_module(
 ):
     uri, path = gen_single_decode_cu(*args)
     return load_cuda_ops(
-        uri, [path],
+        uri,
+        [path],
         verbose=verbose,
     )
 
@@ -64,7 +65,8 @@ def compile_batch_decode_module(
 ):
     uri, path = gen_batch_decode_cu(*args)
     return load_cuda_ops(
-        uri, [path],
+        uri,
+        [path],
         verbose=verbose,
     )
 
@@ -114,6 +116,7 @@ def get_batch_decode_module(*args):
         _batch_decode_modules[args] = compile_batch_decode_module(*args)
     return _batch_decode_modules[args]
 
+
 def single_decode_with_kv_cache_with_jit_module(
     jit_module: Any,
     q: torch.Tensor,
@@ -123,8 +126,10 @@ def single_decode_with_kv_cache_with_jit_module(
     kv_layout: str = "NHD",
     window_left: int = -1,
 ):
-    tmp = _get_cache_buf("single_decode_with_kv_cache_tmp", 32 * 1024 * 1024, q.device)
-    return jit_module.run(q, k, v, tmp, TensorLayout[kv_layout].value, window_left, *args)
+    tmp = _get_cache_buf("single_decode_with_kv_cache_tmp", 32 * 1024 * 1024, q.device)
+    return jit_module.run(
+        q, k, v, tmp, TensorLayout[kv_layout].value, window_left, *args
+    )
 
 
 def single_decode_with_kv_cache(
@@ -444,6 +449,7 @@ def __init__(
 
         if use_tensor_cores:
             if use_cuda_graph:
+                # NOTE(Zihao): once created, no need to update it in plan/run
                 self._qo_indptr_buf = torch.arange(
                     self._fixed_batch_size + 1,
                     dtype=torch.int32,
@@ -555,8 +561,7 @@ def plan(
         if logits_soft_cap is None:
             logits_soft_cap = 0.0
 
-        qo_indptr = _get_range_buf(batch_size + 1, indptr.device)
-
+        qo_indptr_host = _get_range_buf(batch_size + 1, "cpu")
         if self.is_cuda_graph_enabled:
             if batch_size != self._fixed_batch_size:
                 raise ValueError(
@@ -569,21 +574,18 @@ def plan(
                 raise ValueError(
                     "The size of indices should be less than or equal to the allocated buffer"
                 )
-            self._paged_kv_indptr_buf.copy_(indptr)
-            self._paged_kv_indices_buf[: len(indices)] = indices
-            self._paged_kv_last_page_len_buf.copy_(last_page_len)
-            if self.use_tensor_cores:
-                self._qo_indptr_buf.copy_(qo_indptr)
+            self._paged_kv_indptr_buf.copy_(indptr, non_blocking=True)
+            self._paged_kv_indices_buf[: len(indices)].copy_(indices, non_blocking=True)
+            self._paged_kv_last_page_len_buf.copy_(last_page_len, non_blocking=True)
         else:
-            self._paged_kv_indptr_buf = indptr.to(self.device)
-            self._paged_kv_indices_buf = indices.to(self.device)
-            self._paged_kv_last_page_len_buf = last_page_len.to(self.device)
-            if self.use_tensor_cores:
-                self._qo_indptr_buf = qo_indptr.to(self.device)
-
-        qo_indptr = qo_indptr.to("cpu", non_blocking=True)
-        indptr = indptr.to("cpu", non_blocking=True)
+            self._paged_kv_indptr_buf = indptr.to(self.device, non_blocking=True)
+            self._paged_kv_indices_buf = indices.to(self.device, non_blocking=True)
+            self._paged_kv_last_page_len_buf = last_page_len.to(
+                self.device, non_blocking=True
+            )
+            self._qo_indptr_buf = qo_indptr_host.to(self.device, non_blocking=True)
+        indptr_host = indptr.to("cpu", non_blocking=True)
 
         if data_type is not None:
             q_data_type = data_type
             kv_data_type = data_type
@@ -612,8 +614,8 @@
                 self._float_workspace_buffer,
                 self._int_workspace_buffer,
                 self._pin_memory_int_workspace_buffer,
-                qo_indptr,
-                indptr,
+                qo_indptr_host,
+                indptr_host,
                 batch_size,
                 num_qo_heads,
                 num_kv_heads,
@@ -635,7 +637,7 @@
                 self._float_workspace_buffer,
                 self._int_workspace_buffer,
                 self._pin_memory_int_workspace_buffer,
-                indptr,
+                indptr_host,
                 batch_size,
                 num_qo_heads,
                 num_kv_heads,
diff --git a/python/flashinfer/jit/batch_decode_templ.py b/python/flashinfer/jit/batch_decode_templ.py
index 80daa551..28b9582e 100644
--- a/python/flashinfer/jit/batch_decode_templ.py
+++ b/python/flashinfer/jit/batch_decode_templ.py
@@ -42,7 +42,7 @@
       int_workspace_buffer.size(0) * int_workspace_buffer.element_size();
   auto device = float_workspace_buffer.device();
   cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
-  indptr = indptr.to(torch::kCPU);
+  TORCH_CHECK(indptr.device() == torch::kCPU, "indptr must be on CPU");
 
   DecodePlanInfo plan_info;
 
diff --git a/python/flashinfer/jit/batch_prefill_templ.py b/python/flashinfer/jit/batch_prefill_templ.py
index 1e07924a..f9a2b628 100644
--- a/python/flashinfer/jit/batch_prefill_templ.py
+++ b/python/flashinfer/jit/batch_prefill_templ.py
@@ -49,8 +49,8 @@
 
   auto device = float_workspace_buffer.device();
   cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
-  qo_indptr = qo_indptr.to(torch::kCPU);
-  kv_indptr = kv_indptr.to(torch::kCPU);
+  TORCH_CHECK(qo_indptr.device() == torch::kCPU, "qo_indptr must be on CPU");
+  TORCH_CHECK(kv_indptr.device() == torch::kCPU, "kv_indptr must be on CPU");
 
   PrefillPlanInfo plan_info;
 
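
In the decode `plan()` above, the CUDA-graph path now refreshes the preallocated device buffers with `copy_(..., non_blocking=True)` instead of synchronous assignment, and the planner is handed `qo_indptr_host`/`indptr_host` CPU tensors. A small sketch of the host-to-device half of that pattern, assuming the host metadata is staged in pinned memory (the copy can only be truly asynchronous when the CPU source is pinned):

```python
import torch

device = torch.device("cuda:0")

# Persistent buffer captured by a CUDA graph; its address must never change,
# so updates are in-place copies rather than reallocations.
paged_kv_indptr_buf = torch.zeros(9, dtype=torch.int32, device=device)

# Host-side metadata staged in pinned memory so the H2D copy can overlap.
indptr_host = torch.tensor(
    [0, 1, 3, 6, 7, 9, 12, 14, 16], dtype=torch.int32
).pin_memory()

# Enqueued on the current stream: kernels launched afterwards on the same
# stream see the new values, and the Python thread does not wait for it.
paged_kv_indptr_buf.copy_(indptr_host, non_blocking=True)
```
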
diff --git a/python/flashinfer/prefill.py b/python/flashinfer/prefill.py
index 293073cb..90da5fdf 100644
--- a/python/flashinfer/prefill.py
+++ b/python/flashinfer/prefill.py
@@ -57,7 +57,8 @@ def compile_single_prefill_module(
 ):
     uri, path = gen_single_prefill_cu(*args)
     return load_cuda_ops(
-        uri, [path],
+        uri,
+        [path],
         verbose=verbose,
     )
 
@@ -68,7 +69,8 @@ def compile_batch_prefill_module(
 ):
     uri, path = gen_batch_prefill_cu(*args)
     return load_cuda_ops(
-        uri, [path],
+        uri,
+        [path],
         verbose=verbose,
     )
 
@@ -125,6 +127,7 @@ def get_batch_prefill_module(*args):
         _batch_prefill_modules[args] = compile_batch_prefill_module(*args)
     return _batch_prefill_modules[args]
 
+
 def single_prefill_with_kv_cache_with_jit_module(
     jit_module: Any,
     q: torch.Tensor,
@@ -137,7 +140,8 @@ def single_prefill_with_kv_cache_with_jit_module(
 ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
     tmp = _get_cache_buf("single_prefill_with_kv_cache_tmp", 32 * 1024 * 1024, q.device)
     out = jit_module.run(
-        q, k, v, tmp, TensorLayout[kv_layout].value, window_left, return_lse, *args)
+        q, k, v, tmp, TensorLayout[kv_layout].value, window_left, return_lse, *args
+    )
     return out if return_lse else out[0]
 
 
@@ -726,10 +730,14 @@ def plan(
                     "The length of paged_kv_indices exceeds the allocated buffer size."
                 )
 
-            self._qo_indptr_buf.copy_(qo_indptr)
-            self._paged_kv_indptr_buf.copy_(paged_kv_indptr)
-            self._paged_kv_indices_buf[: len(paged_kv_indices)] = paged_kv_indices
-            self._paged_kv_last_page_len_buf.copy_(paged_kv_last_page_len)
+            self._qo_indptr_buf.copy_(qo_indptr, non_blocking=True)
+            self._paged_kv_indptr_buf.copy_(paged_kv_indptr, non_blocking=True)
+            self._paged_kv_indices_buf[: len(paged_kv_indices)].copy_(
+                paged_kv_indices, non_blocking=True
+            )
+            self._paged_kv_last_page_len_buf.copy_(
+                paged_kv_last_page_len, non_blocking=True
+            )
 
             if packed_custom_mask is not None:
                 if not torch.is_tensor(self._custom_mask_buf):
@@ -740,20 +748,31 @@ def plan(
                     raise ValueError(
                         "qk_indptr_buf must be initialized with a torch.Tensor in cuda graph mode if we use custom mask in attention computation."
                     )
-                self._custom_mask_buf[: len(packed_custom_mask)] = packed_custom_mask
+                self._custom_mask_buf[: len(packed_custom_mask)].copy_(
+                    packed_custom_mask, non_blocking=True
+                )
                 # NOTE(Zihao): qk_indptr has the same length as qo_indptr
-                self._qk_indptr_buf.copy_(qk_indptr)
+                self._qk_indptr_buf.copy_(qk_indptr, non_blocking=True)
         else:
-            self._qo_indptr_buf = qo_indptr.to(self.device)
-            self._paged_kv_indptr_buf = paged_kv_indptr.to(self.device)
-            self._paged_kv_indices_buf = paged_kv_indices.to(self.device)
-            self._paged_kv_last_page_len_buf = paged_kv_last_page_len.to(self.device)
+            self._qo_indptr_buf = qo_indptr.to(self.device, non_blocking=True)
+            self._paged_kv_indptr_buf = paged_kv_indptr.to(
+                self.device, non_blocking=True
+            )
+            self._paged_kv_indices_buf = paged_kv_indices.to(
+                self.device, non_blocking=True
+            )
+            self._paged_kv_last_page_len_buf = paged_kv_last_page_len.to(
+                self.device, non_blocking=True
+            )
             if packed_custom_mask is not None:
-                self._custom_mask_buf = packed_custom_mask.to(self.device)
-                self._qk_indptr_buf = qk_indptr.to(self.device)
+                self._custom_mask_buf = packed_custom_mask.to(
+                    self.device, non_blocking=True
+                )
+                self._qk_indptr_buf = qk_indptr.to(self.device, non_blocking=True)
 
-        qo_indptr = qo_indptr.to("cpu", non_blocking=True)
-        paged_kv_indptr = paged_kv_indptr.to("cpu", non_blocking=True)
+        # NOTE(Zihao): only required if qo_indptr/paged_kv_indptr are device tensors
+        qo_indptr_host = qo_indptr.to("cpu", non_blocking=True)
+        paged_kv_indptr_host = paged_kv_indptr.to("cpu", non_blocking=True)
 
         if packed_custom_mask is not None:
             mask_mode = MaskMode.CUSTOM.value
@@ -781,8 +800,8 @@
             self._float_workspace_buffer,
             self._int_workspace_buffer,
             self._pin_memory_int_workspace_buffer,
-            qo_indptr,
-            paged_kv_indptr,
+            qo_indptr_host,
+            paged_kv_indptr_host,
             batch_size,
             num_qo_heads,
             num_kv_heads,
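
The prefill `plan()` above follows the same scheme: the indptr arrays that the planner reads on the CPU are staged as `*_host` tensors via `.to("cpu", non_blocking=True)`. One thing to keep in mind with that idiom, shown here as a conservative sketch: a non-blocking device-to-host copy may still be in flight when `.to()` returns, so the host values should only be read after the issuing stream has been synchronized (or an event recorded on it has completed):

```python
import torch

device = torch.device("cuda:0")
qo_indptr = torch.arange(0, 13, 4, dtype=torch.int32, device=device)  # [0, 4, 8, 12]

# Kick off the device-to-host transfer without stalling the Python thread.
qo_indptr_host = qo_indptr.to("cpu", non_blocking=True)

# ... other plan-time work can be queued here ...

# Make sure the transfer has finished before the CPU actually reads the values.
torch.cuda.current_stream().synchronize()
print(qo_indptr_host.tolist())
```
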
diff --git a/python/flashinfer/sparse.py b/python/flashinfer/sparse.py
index bcb28f3e..8f3c0134 100644
--- a/python/flashinfer/sparse.py
+++ b/python/flashinfer/sparse.py
@@ -257,7 +257,7 @@ def plan(
         num_blocks_row = len(indptr) - 1
         qo_indptr_host = R * torch.arange(num_blocks_row + 1, dtype=torch.int32)
         qo_indptr_host[-1] = M
-        qo_indptr = qo_indptr_host.to(indptr.device)
+        qo_indptr = qo_indptr_host.to(indptr.device, non_blocking=True)
         if indices.max().item() * C > N:
             raise ValueError("indices out of bound")
         last_block_len = torch.full(
@@ -279,13 +279,13 @@
                 mask.contiguous().view(-1), qk_indptr, bitorder="little"
             )
 
-        self._qo_indptr = qo_indptr.to(self.device)
-        self._paged_kv_indptr_buf = indptr.to(self.device)
-        self._paged_kv_indices_buf = indices.to(self.device)
-        self._paged_kv_last_page_len = last_block_len.to(self.device)
+        self._qo_indptr = qo_indptr.to(self.device, non_blocking=True)
+        self._paged_kv_indptr_buf = indptr.to(self.device, non_blocking=True)
+        self._paged_kv_indices_buf = indices.to(self.device, non_blocking=True)
+        self._paged_kv_last_page_len = last_block_len.to(self.device, non_blocking=True)
         if packed_mask is not None:
-            self._packed_mask_buf = packed_mask.to(self.device)
-            self._qk_indptr_buf = qk_indptr.to(self.device)
+            self._packed_mask_buf = packed_mask.to(self.device, non_blocking=True)
+            self._qk_indptr_buf = qk_indptr.to(self.device, non_blocking=True)
             mask_mode = MaskMode.CUSTOM.value
         else:
             self._packed_mask_buf = None
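
The sparse path above avoids the device round trip altogether: `qo_indptr_host` is constructed directly on the CPU and only then pushed to the device asynchronously. A sketch of that construction, with `R` (rows per block) and `M` (total rows) as example values:

```python
import torch

R, M = 16, 1000          # example block height and total number of rows
num_blocks_row = (M + R - 1) // R

# Build the block-row indptr on the host; no GPU sync is ever needed for it.
qo_indptr_host = R * torch.arange(num_blocks_row + 1, dtype=torch.int32)
qo_indptr_host[-1] = M   # the last block may be shorter than R rows

# Asynchronous host-to-device upload for the kernel-side copy.
qo_indptr = qo_indptr_host.to("cuda:0", non_blocking=True)
```
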