diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index fa37d0c7539f6..023696f3cea9c 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -211,3 +211,6 @@ steps: - pytest -v -s distributed/test_custom_all_reduce.py - TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_basic_distributed_correctness.py - TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_basic_distributed_correctness.py + - pip install https://github.com/flashinfer-ai/flashinfer/releases/download/v0.0.5/flashinfer-0.0.5+cu121torch2.3-cp310-cp310-linux_x86_64.whl + - VLLM_ATTENTION_BACKEND=FLASHINFER TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_basic_distributed_correctness.py + - VLLM_ATTENTION_BACKEND=FLASHINFER TEST_DIST_MODEL=meta-llama/Meta-Llama-3-8B DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_basic_distributed_correctness.py \ No newline at end of file diff --git a/requirements-test.txt b/requirements-test.txt index 8b68e0e939669..3ebfc16547e44 100644 --- a/requirements-test.txt +++ b/requirements-test.txt @@ -19,4 +19,4 @@ sentence-transformers # required for embedding aiohttp # quantization -bitsandbytes==0.42.0 +bitsandbytes==0.42.0 \ No newline at end of file diff --git a/tests/basic_correctness/test_basic_correctness.py b/tests/basic_correctness/test_basic_correctness.py index 805b8883b9d94..6f44030feebb0 100644 --- a/tests/basic_correctness/test_basic_correctness.py +++ b/tests/basic_correctness/test_basic_correctness.py @@ -2,7 +2,6 @@ Run `pytest tests/basic_correctness/test_basic_correctness.py`. """ -import os import weakref import pytest @@ -13,7 +12,6 @@ "facebook/opt-125m", "meta-llama/Llama-2-7b-hf", ] -VLLM_ATTENTION_BACKEND = "VLLM_ATTENTION_BACKEND" def test_vllm_gc_ed(): @@ -39,10 +37,6 @@ def test_models( max_tokens: int, enforce_eager: bool, ) -> None: - backend_by_env_var = os.getenv(VLLM_ATTENTION_BACKEND) - if backend_by_env_var == "FLASHINFER" and enforce_eager is False: - pytest.skip("Skipping non-eager test for FlashInferBackend.") - with hf_runner(model, dtype=dtype) as hf_model: hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens) diff --git a/tests/distributed/test_basic_distributed_correctness.py b/tests/distributed/test_basic_distributed_correctness.py index eb423aef230cb..b8ae5b4c44f8d 100644 --- a/tests/distributed/test_basic_distributed_correctness.py +++ b/tests/distributed/test_basic_distributed_correctness.py @@ -21,7 +21,6 @@ os.environ["TEST_DIST_MODEL"], ] DISTRIBUTED_EXECUTOR_BACKEND = "DISTRIBUTED_EXECUTOR_BACKEND" -VLLM_ATTENTION_BACKEND = "VLLM_ATTENTION_BACKEND" @pytest.mark.skipif(torch.cuda.device_count() < 2, @@ -39,16 +38,12 @@ def test_models( ) -> None: distributed_executor_backend = os.getenv(DISTRIBUTED_EXECUTOR_BACKEND) - backend_by_env_var = os.getenv(VLLM_ATTENTION_BACKEND) - enforce_eager = backend_by_env_var == "FLASHINFER" - with hf_runner(model, dtype=dtype) as hf_model: hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens) with vllm_runner(model, dtype=dtype, tensor_parallel_size=2, - enforce_eager=enforce_eager, distributed_executor_backend=distributed_executor_backend ) as vllm_model: vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens) diff --git a/vllm/attention/backends/flashinfer.py b/vllm/attention/backends/flashinfer.py index 535d30b55bc9d..4ecac7379c7f6 100644 --- a/vllm/attention/backends/flashinfer.py +++ b/vllm/attention/backends/flashinfer.py @@ -1,10 +1,16 @@ from dataclasses import dataclass from typing import Any, Dict, List, Optional, Set, Tuple, Type -import flashinfer +try: + from flashinfer import BatchDecodeWithPagedKVCacheWrapper + from flashinfer.prefill import BatchPrefillWithPagedKVCacheWrapper + from vllm_flash_attn import flash_attn_varlen_func +except ImportError: + flash_attn_varlen_func = None + BatchDecodeWithPagedKVCacheWrapper = None + BatchPrefillWithPagedKVCacheWrapper = None + import torch -from flashinfer import BatchDecodeWithPagedKVCacheWrapper -from vllm_flash_attn import flash_attn_varlen_func from vllm import _custom_ops as ops from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, @@ -60,19 +66,16 @@ class FlashInferMetadata(AttentionMetadata): # requests only. max_prefill_seq_len: int - use_cuda_graph: bool = False + use_cuda_graph: bool = True + prefill_wrapper: Optional[BatchPrefillWithPagedKVCacheWrapper] = None decode_wrapper: Optional[BatchDecodeWithPagedKVCacheWrapper] = None - # Metadata for the prefill stage since we still - # use flash attention for prefill. + # Metadata for the prefill stage seq_start_loc: Optional[torch.Tensor] = None + query_start_loc: Optional[torch.Tensor] = None block_tables: Optional[torch.Tensor] = None - # Metadata for the decode stage - # Workspace buffer required by the kernel, the buffer should not - # be allocated/deacollated by the FalshInfermetadata object. - workspace_buffer: Optional[torch.Tensor] = None # An example for paged_kv_indices, paged_kv_indptr: # request 1, page indices [0, 5, 8] # request 2, page indices [1, 6, 7] @@ -98,6 +101,7 @@ class FlashInferMetadata(AttentionMetadata): page_size: Optional[int] = None # The data type of the paged kv cache data_type: torch.dtype = None + device: torch.device = torch.device("cuda") def __post_init__(self): # Refer to @@ -109,13 +113,35 @@ def __post_init__(self): f"Only {supported_head_sizes} are supported for head_dim,", f"received {self.head_dim}.") - # When using flashinfer, we are also creating the FlashInferMetadata, - # which will also call post_init by default, here we want to skip the - # post_init if it's the prefill phase. - if self.num_prefills == 0: - assert self.num_decode_tokens > 0 - self.decode_wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper( - self.workspace_buffer, "NHD") + def begin_forward(self): + if self.num_prefill_tokens > 0: + if self.paged_kv_indices is None: + return + + assert self.prefill_wrapper is not None + assert self.paged_kv_indices is not None + assert self.paged_kv_indptr is not None + assert self.paged_kv_last_page_len is not None + self.paged_kv_indices = self.paged_kv_indices.to(self.device) + self.paged_kv_indptr = self.paged_kv_indptr.to(self.device) + self.paged_kv_last_page_len = self.paged_kv_last_page_len.to( + self.device) + self.prefill_wrapper.begin_forward( + self.query_start_loc, self.paged_kv_indptr, + self.paged_kv_indices, self.paged_kv_last_page_len, + self.num_qo_heads, self.num_kv_heads, self.head_dim, + self.page_size) + else: + if not self.use_cuda_graph: + assert self.paged_kv_indices is not None + assert self.paged_kv_indptr is not None + assert self.paged_kv_last_page_len is not None + self.paged_kv_indices = self.paged_kv_indices.to(self.device) + self.paged_kv_indptr = self.paged_kv_indptr.to(self.device) + self.paged_kv_last_page_len = self.paged_kv_last_page_len.to( + self.device) + + assert self.decode_wrapper is not None self.decode_wrapper.begin_forward( self.paged_kv_indptr, self.paged_kv_indices, @@ -133,8 +159,9 @@ def asdict_zerocopy(self, ) -> Dict[str, Any]: if skip_fields is None: skip_fields = set() - # We need to skip the decode_wrapper field since it cannot be + # We need to skip the prefill/decode_wrapper field since it cannot be # broadcasted with nccl when TP is enabled. + skip_fields.add('prefill_wrapper') skip_fields.add('decode_wrapper') return super().asdict_zerocopy(skip_fields) @@ -168,6 +195,7 @@ def __init__( alibi_slopes: Optional[List[float]], sliding_window: Optional[int], kv_cache_dtype: str, + blocksparse_params: Optional[Dict[str, Any]] = None, ) -> None: self.num_heads = num_heads self.head_size = head_size @@ -217,10 +245,14 @@ def forward( self.kv_cache_dtype, ) + query = query.contiguous( + ) # Flashinfer requires query to be contiguous if prefill_meta := attn_metadata.prefill_metadata: - # Prompt run. - assert prefill_meta.block_tables is not None - if kv_cache is None or prefill_meta.block_tables.numel() == 0: + # We will use flash attention for prefill + # when kv_cache is not provided. + # This happens when vllm runs the profiling to + # determine the number of blocks. + if kv_cache is None: output = flash_attn_varlen_func( q=query, k=key, @@ -235,13 +267,14 @@ def forward( alibi_slopes=self.alibi_slopes, ) else: - raise NotImplementedError( - "Prefix caching is not supported with flashinfer yet.") + assert prefill_meta is not None + assert prefill_meta.prefill_wrapper is not None + output = prefill_meta.prefill_wrapper.forward(query, + kv_cache, + causal=True) else: assert attn_metadata.decode_metadata is not None assert attn_metadata.decode_metadata.decode_wrapper is not None - query = query.contiguous( - ) # Flashinfer requires query to be contiguous output = attn_metadata.decode_metadata.decode_wrapper.forward( query, kv_cache, diff --git a/vllm/attention/selector.py b/vllm/attention/selector.py index 96f88bbf4cf59..851bf52a505ee 100644 --- a/vllm/attention/selector.py +++ b/vllm/attention/selector.py @@ -77,8 +77,9 @@ def get_attn_backend( return IpexAttnBackend elif backend == _Backend.FLASHINFER: logger.info("Using Flashinfer backend.") - logger.warning("Eager mode is required for the Flashinfer backend. " - "Please make sure --enforce-eager is set.") + logger.warning(("Flashinfer will be stuck on llma-2-7b," + " please avoid using Flashinfer as the" + "backend when running on llma-2-7b.")) from vllm.attention.backends.flashinfer import FlashInferBackend return FlashInferBackend elif backend == _Backend.PALLAS: diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 08216603023d7..942063677a427 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -10,6 +10,17 @@ import torch import torch.nn as nn +try: + from flashinfer import BatchDecodeWithPagedKVCacheWrapper + from flashinfer.decode import CUDAGraphBatchDecodeWithPagedKVCacheWrapper + from flashinfer.prefill import BatchPrefillWithPagedKVCacheWrapper + FLASHINFER_WORKSPACE_BUFFER_SIZE = 128 * 1024 * 1024 +except ImportError: + BatchDecodeWithPagedKVCacheWrapper = None + CUDAGraphBatchDecodeWithPagedKVCacheWrapper = None + BatchPrefillWithPagedKVCacheWrapper = None + FLASHINFER_WORKSPACE_BUFFER_SIZE = 0 + from vllm.attention import AttentionMetadata, get_attn_backend from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, ModelConfig, ParallelConfig, SchedulerConfig, @@ -198,11 +209,14 @@ def __init__( # Lazy initialization self.model: nn.Module # Set after load_model - # Set if the backend is flashinfer. - self.flashinfer_workspace_buffer: torch.Tensor # Set after load_model. self.lora_manager: Optional[LRUCacheWorkerLoRAManager] = None + self.flashinfer_decode_workspace_buffer = None + self.flashinfer_decode_wrapper = None + self.flashinfer_prefill_workspace_buffer = None + self.flashinfer_prefill_wrapper = None + def load_model(self) -> None: with CudaMemoryProfiler() as m: self.model = get_model( @@ -450,15 +464,6 @@ def _prepare_model_input_tensors( if curr_sliding_window_blocks is not None: block_table = block_table[ -curr_sliding_window_blocks:] - if self.attn_backend.get_name() == "flashinfer": - paged_kv_indices.extend(block_table) - paged_kv_indptr.append(paged_kv_indptr[-1] + - len(block_table)) - last_page_len = seq_data.get_len( - ) % self.block_size - if last_page_len == 0: - last_page_len = self.block_size - paged_kv_last_page_len.append(last_page_len) else: # Only happens when memory profiling runs. block_table = [] @@ -505,7 +510,9 @@ def _prepare_model_input_tensors( for k, v in mm_kwargs.items(): multi_modal_kwargs_list[k].append(v) - if _is_block_tables_empty(seq_group_metadata.block_tables): + is_profile_run = _is_block_tables_empty( + seq_group_metadata.block_tables) + if is_profile_run: # During memory profiling, the block tables are not # initialized yet. In this case, we just use a dummy # slot mapping. @@ -544,6 +551,27 @@ def _prepare_model_input_tensors( slot = block_number * self.block_size + block_offset slot_mapping.append(slot) + # Prepare input tensors for flashinfer + if self.attn_backend.get_name() == "flashinfer": + seq_len = seq_data.get_len() + # Get the number of valid blocks based on sequence length. + # If seq_len = 16, block_size = 16, + # block_table_bound is 1 with 1 valid block. + # If seq_len = 15, block_size = 16, + # block_table_bound is 0 + 1 with 1 valid block. + block_table_bound = seq_len // self.block_size + 1 \ + if seq_len % self.block_size != 0 \ + else seq_len // self.block_size + + paged_kv_indices.extend(block_table[:block_table_bound]) + paged_kv_indptr.append(paged_kv_indptr[-1] + + block_table_bound) + + last_page_len = seq_len % self.block_size + if last_page_len == 0: + last_page_len = self.block_size + paged_kv_last_page_len.append(last_page_len) + batch_size = len(input_tokens) max_query_len = max(query_lens) max_prefill_seq_len = max(prefill_seq_lens, default=0) @@ -566,6 +594,12 @@ def _prepare_model_input_tensors( seq_lens.append(1) block_tables.append([]) lora_index_mapping.append(0) + + if self.attn_backend.get_name() == "flashinfer": + last_paged_kv_indptr = paged_kv_indptr[-1] + paged_kv_indptr.append(last_paged_kv_indptr) + paged_kv_last_page_len.append(0) + batch_size = graph_batch_size num_decode_tokens = batch_size @@ -589,9 +623,19 @@ def _prepare_model_input_tensors( ) assert max_query_len > 0, ("query_lens: {}".format(query_lens)) + context_lens_tensor = torch.tensor(context_lens, + dtype=torch.int, + device=self.device) + seq_lens_tensor = torch.tensor(seq_lens, dtype=torch.int, device=self.device) + query_lens_tensor = torch.tensor(query_lens, + dtype=torch.long, + device=self.device) + query_start_loc = torch.zeros(query_lens_tensor.shape[0] + 1, + dtype=torch.int32, + device=self.device) seq_start_loc = torch.zeros(seq_lens_tensor.shape[0] + 1, dtype=torch.int32, device=self.device) @@ -600,6 +644,10 @@ def _prepare_model_input_tensors( dim=0, dtype=seq_start_loc.dtype, out=seq_start_loc[1:]) + torch.cumsum(query_lens_tensor, + dim=0, + dtype=query_start_loc.dtype, + out=query_start_loc[1:]) input_tokens_tensor = torch.tensor(input_tokens, dtype=torch.long, @@ -612,30 +660,30 @@ def _prepare_model_input_tensors( device=self.device) if self.attn_backend.get_name() == "flashinfer": - if not hasattr(self, "flashinfer_workspace_buffer"): - # Allocate 16MB workspace buffer - # Follow the example of flashinfer: https://docs.flashinfer.ai/api/python/decode.html - self.flashinfer_workspace_buffer = torch.empty( - 16 * 1024 * 1024, dtype=torch.uint8, device=self.device) - paged_kv_indptr_tensor = torch.tensor(paged_kv_indptr, - dtype=torch.int, - device=self.device) - paged_kv_indices_tensor = torch.tensor(paged_kv_indices, - dtype=torch.int, - device=self.device) - paged_kv_last_page_len_tensor = torch.tensor( - paged_kv_last_page_len, dtype=torch.int, device=self.device) + if len(paged_kv_indptr) > 0: + paged_kv_indices_tensor = torch.tensor(paged_kv_indices, + device='cpu', + dtype=torch.int) + paged_kv_indptr_tensor = torch.tensor(paged_kv_indptr, + device='cpu', + dtype=torch.int) + paged_kv_last_page_len_tensor = torch.tensor( + paged_kv_last_page_len, device='cpu', dtype=torch.int) + else: + paged_kv_indices_tensor = None + paged_kv_indptr_tensor = None + paged_kv_last_page_len_tensor = None + kv_cache_dtype = get_kv_cache_torch_dtype(self.kv_cache_dtype, self.model_config.dtype) + attn_metadata = self.attn_backend.make_metadata( num_prefills=num_prefills, slot_mapping=slot_mapping_tensor, num_prefill_tokens=num_prefill_tokens, num_decode_tokens=num_decode_tokens, - use_cuda_graph=False, max_prefill_seq_len=max_prefill_seq_len, block_tables=block_tables, - workspace_buffer=self.flashinfer_workspace_buffer, paged_kv_indptr=paged_kv_indptr_tensor, paged_kv_indices=paged_kv_indices_tensor, paged_kv_last_page_len=paged_kv_last_page_len_tensor, @@ -644,25 +692,14 @@ def _prepare_model_input_tensors( num_kv_heads=self.model_config.get_num_kv_heads( self.parallel_config), head_dim=self.model_config.get_head_size(), - page_size=16, + page_size=self.block_size, seq_start_loc=seq_start_loc, - data_type=kv_cache_dtype) - else: - context_lens_tensor = torch.tensor(context_lens, - dtype=torch.int, - device=self.device) - query_lens_tensor = torch.tensor(query_lens, - dtype=torch.long, - device=self.device) - query_start_loc = torch.zeros(query_lens_tensor.shape[0] + 1, - dtype=torch.int32, - device=self.device) - - torch.cumsum(query_lens_tensor, - dim=0, - dtype=query_start_loc.dtype, - out=query_start_loc[1:]) + query_start_loc=query_start_loc, + device=self.device, + data_type=kv_cache_dtype, + use_cuda_graph=use_captured_graph) + else: attn_metadata = self.attn_backend.make_metadata( num_prefills=num_prefills, slot_mapping=slot_mapping_tensor, @@ -854,27 +891,97 @@ def capture_model(self, kv_caches: List[torch.Tensor]) -> None: bs for bs in _BATCH_SIZES_TO_CAPTURE if bs <= graph_batch_size ] + if self.attn_backend.get_name() == "flashinfer": + # For flashinfer, different batch sizes will share the + # same workspace buffer. + decode_workspace_buffer = \ + torch.empty(FLASHINFER_WORKSPACE_BUFFER_SIZE, + dtype=torch.uint8, + device=self.device) + indices_buffer = torch.empty(max_batch_size * + self.cache_config.num_gpu_blocks, + dtype=torch.int32, + device=self.device) + indptr_buffer = torch.empty(max_batch_size + 1, + dtype=torch.int32, + device=self.device) + last_page_len_buffer = torch.empty(max_batch_size, + dtype=torch.int32, + device=self.device) + with graph_capture() as graph_capture_context: # NOTE: Capturing the largest batch size first may help reduce the # memory usage of CUDA graph. for batch_size in reversed(batch_size_capture_list): - # Create dummy attn_metadata. - attn_metadata = self.attn_backend.make_metadata( - num_prefills=0, - num_prefill_tokens=0, - num_decode_tokens=batch_size, - slot_mapping=slot_mapping[:batch_size], - seq_lens=None, - seq_lens_tensor=seq_lens[:batch_size], - max_query_len=None, - max_prefill_seq_len=0, - max_decode_seq_len=self.max_seq_len_to_capture, - query_start_loc=None, - seq_start_loc=None, - context_lens_tensor=None, - block_tables=block_tables[:batch_size], - use_cuda_graph=True, - ) + if self.attn_backend.get_name() == "flashinfer": + indptr_buffer = indptr_buffer[:batch_size + 1] + last_page_len_buffer = last_page_len_buffer[:batch_size] + + num_qo_heads = self.model_config.get_num_attention_heads( + self.parallel_config) + num_kv_heads = self.model_config.get_num_kv_heads( + self.parallel_config) + if num_qo_heads // num_kv_heads >= 4: + use_tensor_cores = True + else: + use_tensor_cores = False + decode_wrapper = \ + CUDAGraphBatchDecodeWithPagedKVCacheWrapper( + decode_workspace_buffer, indptr_buffer, indices_buffer, + last_page_len_buffer, "NHD", use_tensor_cores) + kv_cache_dtype = get_kv_cache_torch_dtype( + self.kv_cache_dtype, self.model_config.dtype) + + paged_kv_indptr_tensor_host = torch.arange( + 0, batch_size + 1, dtype=torch.int32) + paged_kv_indices_tensor_host = torch.arange( + 0, batch_size, dtype=torch.int32) + paged_kv_last_page_len_tensor_host = torch.full( + (batch_size, ), self.block_size, dtype=torch.int32) + query_start_loc_host = torch.arange(0, + batch_size + 1, + dtype=torch.int32) + + attn_metadata = self.attn_backend.make_metadata( + num_prefills=0, + slot_mapping=slot_mapping[:batch_size], + num_prefill_tokens=0, + num_decode_tokens=batch_size, + max_prefill_seq_len=0, + block_tables=block_tables, + paged_kv_indptr=paged_kv_indptr_tensor_host, + paged_kv_indices=paged_kv_indices_tensor_host, + paged_kv_last_page_len= + paged_kv_last_page_len_tensor_host, + num_qo_heads=num_qo_heads, + num_kv_heads=num_kv_heads, + head_dim=self.model_config.get_head_size(), + page_size=self.block_size, + seq_start_loc=None, + query_start_loc=query_start_loc_host, + device=self.device, + data_type=kv_cache_dtype, + use_cuda_graph=True, + decode_wrapper=decode_wrapper, + prefill_wrapper=None) + attn_metadata.begin_forward() + else: + attn_metadata = self.attn_backend.make_metadata( + num_prefills=0, + num_prefill_tokens=0, + num_decode_tokens=batch_size, + slot_mapping=slot_mapping[:batch_size], + seq_lens=None, + seq_lens_tensor=seq_lens[:batch_size], + max_query_len=None, + max_prefill_seq_len=0, + max_decode_seq_len=self.max_seq_len_to_capture, + query_start_loc=None, + seq_start_loc=None, + context_lens_tensor=None, + block_tables=block_tables[:batch_size], + use_cuda_graph=True, + ) if self.lora_config: lora_mapping = LoRAMapping( @@ -883,8 +990,20 @@ def capture_model(self, kv_caches: List[torch.Tensor]) -> None: ) self.set_active_loras(set(), lora_mapping) - graph_runner = CUDAGraphRunner(self.model) - hidden_states = graph_runner.capture( + graph_runner = CUDAGraphRunner(self.model, + self.attn_backend.get_name()) + + if self.attn_backend.get_name() == "flashinfer": + graph_runner.flashinfer_indptr_buffer = indptr_buffer + graph_runner.flashinfer_indices_buffer = indices_buffer + graph_runner.flashinfer_last_page_len_buffer = \ + last_page_len_buffer + graph_runner.flashinfer_decode_workspace_buffer = \ + decode_workspace_buffer + graph_runner.flashinfer_decode_wrapper = \ + decode_wrapper + + graph_runner.capture( input_tokens[:batch_size], input_positions[:batch_size], hidden_states[:batch_size] @@ -918,11 +1037,12 @@ def make_model_input_from_broadcasted_tensor_dict( self, tensor_dict: Dict[str, Any], ) -> ModelInputForGPUWithSamplingMetadata: - return ( + model_input = \ ModelInputForGPUWithSamplingMetadata.from_broadcasted_tensor_dict( tensor_dict, attn_backend=self.attn_backend, - )) + ) + return model_input def prepare_model_input( self, @@ -970,6 +1090,36 @@ def execute_model( self.set_active_loras(model_input.lora_requests, model_input.lora_mapping) + if self.attn_backend.get_name() == "flashinfer": + assert model_input.attn_metadata is not None + assert model_input.input_tokens is not None + if self.flashinfer_decode_workspace_buffer is None: + self.flashinfer_decode_workspace_buffer = torch.empty( + FLASHINFER_WORKSPACE_BUFFER_SIZE, + dtype=torch.uint8, + device=self.device) + self.flashinfer_decode_wrapper = \ + BatchDecodeWithPagedKVCacheWrapper( + self.flashinfer_decode_workspace_buffer, "NHD") + self.flashinfer_prefill_workspace_buffer = torch.empty( + FLASHINFER_WORKSPACE_BUFFER_SIZE, + dtype=torch.uint8, + device=self.device) + self.flashinfer_prefill_wrapper = \ + BatchPrefillWithPagedKVCacheWrapper( + self.flashinfer_prefill_workspace_buffer, "NHD") + + model_input.attn_metadata.prefill_wrapper = \ + self.flashinfer_prefill_wrapper + if model_input.attn_metadata.use_cuda_graph: + batch_size = model_input.input_tokens.shape[0] + model_input.attn_metadata.decode_wrapper = self.graph_runners[ + batch_size].flashinfer_decode_wrapper + else: + model_input.attn_metadata.decode_wrapper = \ + self.flashinfer_decode_wrapper + model_input.attn_metadata.begin_forward() + # Currently cuda graph is only supported by the decode phase. assert model_input.attn_metadata is not None prefill_meta = model_input.attn_metadata.prefill_metadata @@ -1020,13 +1170,22 @@ def execute_model( class CUDAGraphRunner: - def __init__(self, model: nn.Module): + def __init__(self, model: nn.Module, backend_name: str): self.model = model + self.backend_name = backend_name + self.input_buffers: Dict[str, torch.Tensor] = {} self.output_buffers: Dict[str, torch.Tensor] = {} self._graph: Optional[torch.cuda.CUDAGraph] = None + self.flashinfer_decode_workspace_buffer: Optional[torch.Tensor] = None + self.flashinfer_indptr_buffer: Optional[torch.Tensor] = None + self.flashinfer_indices_buffer: Optional[torch.Tensor] = None + self.flashinfer_last_page_len_buffer: Optional[torch.Tensor] = None + self.flashinfer_decode_wrapper: Optional[ + CUDAGraphBatchDecodeWithPagedKVCacheWrapper] = None + @property def graph(self): assert self._graph is not None @@ -1079,14 +1238,23 @@ def capture( torch.cuda.synchronize() # Save the input and output buffers. - self.input_buffers = { - "input_ids": input_ids, - "positions": positions, - "kv_caches": kv_caches, - "slot_mapping": attn_metadata.slot_mapping, - "seq_lens_tensor": attn_metadata.decode_metadata.seq_lens_tensor, - "block_tables": attn_metadata.decode_metadata.block_tables, - } + if self.backend_name == "flashinfer": + self.input_buffers = { + "input_ids": input_ids, + "positions": positions, + "kv_caches": kv_caches, + "slot_mapping": attn_metadata.slot_mapping, + } + else: + self.input_buffers = { + "input_ids": input_ids, + "positions": positions, + "kv_caches": kv_caches, + "slot_mapping": attn_metadata.slot_mapping, + "seq_lens_tensor": + attn_metadata.decode_metadata.seq_lens_tensor, + "block_tables": attn_metadata.decode_metadata.block_tables, + } self.output_buffers = {"hidden_states": hidden_states} return hidden_states @@ -1106,10 +1274,12 @@ def forward( self.input_buffers["positions"].copy_(positions, non_blocking=True) self.input_buffers["slot_mapping"].copy_(attn_metadata.slot_mapping, non_blocking=True) - self.input_buffers["seq_lens_tensor"].copy_( - attn_metadata.decode_metadata.seq_lens_tensor, non_blocking=True) - self.input_buffers["block_tables"].copy_( - attn_metadata.decode_metadata.block_tables, non_blocking=True) + if self.backend_name != "flashinfer": + self.input_buffers["seq_lens_tensor"].copy_( + attn_metadata.decode_metadata.seq_lens_tensor, + non_blocking=True) + self.input_buffers["block_tables"].copy_( + attn_metadata.decode_metadata.block_tables, non_blocking=True) # Run the graph. self.graph.replay()