diff --git a/colossalai/inference/batch_bucket.py b/colossalai/inference/batch_bucket.py index f8571c0ca030..6332aa42ad64 100644 --- a/colossalai/inference/batch_bucket.py +++ b/colossalai/inference/batch_bucket.py @@ -31,6 +31,9 @@ def __init__( fd_interm_tensor=None, device=None, dtype=torch.float16, + enable_streamingllm: bool = False, + start_token_size: int = 4, + generated_token_size: int = 512, ): self.num_heads = num_heads self.head_dim = head_dim @@ -45,12 +48,19 @@ def __init__( self._use_spec_dec = False self._num_tokens_to_verify = None + self.enable_streamingllm = enable_streamingllm + self.start_token_size = start_token_size + self.generated_token_size = generated_token_size + self._current_batch_size = 0 self._sequences_dict = dict() self._sequences_indexes = dict() # deque(maxlen=self.max_batch_size) self._sequence_lengths = torch.zeros((self.max_batch_size,), dtype=torch.int32) self._sequence_lengths_helper = torch.zeros_like(self._sequence_lengths) - max_blocks_per_seq = (self.max_length + block_size - 1) // block_size + if enable_streamingllm: + max_blocks_per_seq = (start_token_size + generated_token_size + block_size - 1) // block_size + 1 + else: + max_blocks_per_seq = (self.max_length + block_size - 1) // block_size self._block_tables = torch.full((self.max_batch_size, max_blocks_per_seq), -1, dtype=torch.int32) self._block_tables_helper = torch.full_like(self._block_tables, -1) @@ -109,6 +119,33 @@ def batch_token_ids(self) -> List[List[int]]: out.append(seq.input_token_id + seq.output_token_id) return out + def streamingllm_update_batch(self): + """ + Update sequence_lengths and block_tables when it is necessary to swap out a block. + """ + + updated_block_ids = [] + + if self.current_batch_size > 0: + need_update = False + sequence_lengths_list = self._sequence_lengths.tolist() + block_tables_list = self._block_tables.tolist() + for batch_id in range(self.current_batch_size): + # We assume that the start token occupies the entire first block. + if self.cache_manager.check_block_full(self.block_tables_list[batch_id][-1]): + need_update = True + sequence_lengths_list[batch_id] = sequence_lengths_list[batch_id] - self.block_size + block_id = block_tables_list[batch_id].pop(1) + updated_block_ids.append(block_id) + block_tables_list[batch_id].append(-1) + if need_update: + self._sequence_lengths = torch.tensor( + sequence_lengths_list, dtype=self._sequence_lengths.dtype, device=self.device + ) + self._block_tables = torch.tensor(block_tables_list, dtype=self._block_tables.dtype, device=self.device) + + return updated_block_ids + def set_use_spec_dec(self, num_tokens_to_verify: int = 5) -> None: """Set batch bucket to use speculatvie decoding. This will notify the adjust the lengths of inputs during modeling, @@ -144,46 +181,13 @@ def _make_compact(self) -> None: self._block_tables_helper.fill_(-1) self._current_batch_size = valid_num - def add_seq( - self, - seq: Sequence, - alloc_block_table: torch.Tensor = None, - alloc_block_table_fn: Callable[[torch.Tensor, int], None] = None, - ) -> Union[torch.Tensor, None]: - """Add a single sequence to the batch. - User could opt to provide either a block table or a function to allocate block tables. 
- - Args: - seq (Sequence): The sequence to be added to the batch - alloc_block_table (torch.Tensor): The block tables to be copied and used for the sequence - alloc_block_table_fn (Callable[[torch.Tensor, int], None]): The function to allocate blocks for the sequence, - which is expected to reserve blocks and update status of kv-cache manager. - - Returns: - block_table (torch.Tensor): The block table of the added sequence, used for block allocation in kv-cache manager. - None if the sequence cannot be added. - """ - block_table = None - # TODO might consider sorting by length - if self._current_batch_size < self.max_batch_size: - self._sequences_dict[seq.request_id] = seq - self._sequences_indexes[seq.request_id] = self._current_batch_size - self._sequence_lengths[self._current_batch_size] = seq.sentence_len - # NOTE the added seq still require block table allocation by kvcache manager - block_table = self._block_tables[self._current_batch_size - 1] - if alloc_block_table is not None: - # copy block ids from provided block tables - self._block_tables[self._current_batch_size - 1] = alloc_block_table - elif alloc_block_table_fn: - alloc_block_table_fn(block_table, self._sequence_lengths[self._current_batch_size - 1].item()) - self._current_batch_size += 1 - return block_table - def add_seqs( self, seqs: List[Sequence], alloc_block_tables: torch.Tensor = None, alloc_block_tables_fn: Callable[[torch.Tensor, torch.Tensor], None] = None, + need_reused_block_table: bool = False, + streaningllm_prompt_len: int = 0, ) -> Union[torch.Tensor, None]: """Add a list of sequences to the batch. User could opt to provide either block tables or a function to allocate block tables. @@ -193,7 +197,8 @@ def add_seqs( alloc_block_tables (torch.Tensor): The block tables to be copied and used for the sequence alloc_block_table_fn (Callable[[torch.Tensor, torch.Tensor], None]): The function to allocate blocks for multiple sequences, which is expected to reserve blocks and update status of kv-cache manager. - + need_reused_block_table (bool): Whether to reuse cached block tables. + streaningllm_prompt_len (int): The length of sentences used for streamingLLM. Returns: block_tables (torch.Tensor): The block tables of the added sequences, used for block allocation in kv-cache manager. None if the sequences cannot be added. 
@@ -206,15 +211,24 @@ def add_seqs( num_seqs_to_add = min(self.max_batch_size - self._current_batch_size, len(seqs)) block_tables = None if num_seqs_to_add > 0: + # NOTE block tables to be updated by kvcache manager + block_tables = self._block_tables[self._current_batch_size : self._current_batch_size + num_seqs_to_add] for i, seq in enumerate(seqs[:num_seqs_to_add]): + if need_reused_block_table: + block_tables[i] = seq.block_table self._sequences_dict[seq.request_id] = seq self._sequences_indexes[seq.request_id] = self._current_batch_size + i # TODO external (rename): modify Sequence.sentence_len to seq_len - self._sequence_lengths[ - self._current_batch_size : self._current_batch_size + num_seqs_to_add - ] = torch.tensor([seq.sentence_len for seq in seqs[:num_seqs_to_add]], dtype=torch.int32) - # NOTE block tables to be updated by kvcache manager - block_tables = self._block_tables[self._current_batch_size : self._current_batch_size + num_seqs_to_add] + if need_reused_block_table: + self._sequence_lengths[ + self._current_batch_size : self._current_batch_size + num_seqs_to_add + ] = torch.tensor( + [streaningllm_prompt_len[seq_id] for seq_id in range(num_seqs_to_add)], dtype=torch.int32 + ) + else: + self._sequence_lengths[ + self._current_batch_size : self._current_batch_size + num_seqs_to_add + ] = torch.tensor([seq.sentence_len for seq in seqs[:num_seqs_to_add]], dtype=torch.int32) if alloc_block_tables is not None: # copy block ids from provided block tables self._block_tables[ @@ -232,7 +246,10 @@ def add_seqs( return block_tables def pop_seq_update_batch( - self, request_id: int, free_block_table_fn: Callable[[torch.Tensor], None] = None + self, + request_id: int, + free_block_table_fn: Callable[[torch.Tensor], None] = None, + update_cached_sequences_fn: Callable[[Sequence], None] = None, ) -> Tuple[Sequence, Union[torch.Tensor, None]]: """Pop a single sequence by id from the batch, and update the batch bucket status. @@ -240,12 +257,21 @@ def pop_seq_update_batch( request_id (int): The uid of the sequence free_block_table_fn (Callable): The function to free the block table of a sequence, if not provided, then we have to release the block table manually after calling this method - + update_cached_sequences_fn (Callable[[Sequence], None]): When enabling streamingllm, the previous inference sequences will be saved to cached_sequences_dict. + This function is used to update cached_sequences_dict. Returns: A tuple of: seq (Sequence): The target sequence and block_table (torch.Tensor): block table of the target sequence indicating corresponding blocks, none if the sequence is not found or free_block_table_fn is provided. """ + + # When update_cached_sequences_fn is not None, it indicates that we have enabled streamingLLM. + # At this point, completed sentences will be stored in cached_sequences_dict and will not + # be released within the current function. + assert ( + free_block_table_fn is None or update_cached_sequences_fn is None + ), f"free_block_table_fn={free_block_table_fn} and update_cached_sequences_fn={update_cached_sequences_fn} cannot be passed simultaneously." + seq: Sequence = self._sequences_dict.get(request_id) block_table = None if seq is not None: @@ -268,7 +294,11 @@ def pop_seq_update_batch( if free_block_table_fn: free_block_table_fn(self._block_tables[seq_b_idx]) else: - block_table = self._block_tables[seq_b_idx].detach().clone() + if update_cached_sequences_fn: + # When enabling streamingllm, save previous inference sequences. 
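+                            # (the block table is passed along too, so this sequence's KV cache blocks
+                            #  stay reserved for later reuse instead of being freed here.)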
+                            update_cached_sequences_fn(seq, self._block_tables[seq_b_idx])
+                        else:
+                            block_table = self._block_tables[seq_b_idx].detach().clone()
                    # replace block table of the target seq with that of the last seq in the batch
                    self._block_tables[seq_b_idx] = self._block_tables[last_seq_b_idx]
                    self._block_tables[last_seq_b_idx].fill_(-1)
@@ -276,7 +306,11 @@
            if free_block_table_fn:
                free_block_table_fn(self._block_tables[0])
            else:
-                block_table = self._block_tables[0].detach().clone()
+                if update_cached_sequences_fn:
+                    # When enabling streamingllm, save previous inference sequences.
+                    update_cached_sequences_fn(seq, self._block_tables[seq_b_idx])
+                else:
+                    block_table = self._block_tables[0].detach().clone()
            self._sequence_lengths[0].fill_(0)
            self._block_tables[0].fill_(-1)
        self._sequences_indexes.pop(request_id)
@@ -339,17 +373,29 @@ def pop_n_seqs(
        return seqs, block_tables

    def pop_finished(
-        self, free_block_table_fn: Callable[[torch.Tensor], None] = None
+        self,
+        free_block_table_fn: Callable[[torch.Tensor], None] = None,
+        update_cached_sequences_fn: Callable[[Sequence], None] = None,
    ) -> Tuple[List[Sequence], List[torch.Tensor]]:
        """Pop finished sequences in the batch and a list of block tables of the finished sequences,
        if free_block_table_fn is not provided.

        Args:
            free_block_table_fn (Callable): The function to free the block table of a single sequence
+            update_cached_sequences_fn (Callable[[Sequence], None]): When enabling streamingllm, the previous inference sequences will be saved to cached_sequences_dict.
+                This function is used to update cached_sequences_dict.
        Returns:
            A tuple of: finished_seqs (List[Sequence]): The finished sequences,
            and finished_block_tables (List[torch.Tensor]): block tables of the finished sequences.
        """
+
+        # When update_cached_sequences_fn is not None, it indicates that we have enabled streamingLLM.
+        # At this point, completed sentences will be stored in cached_sequences_dict and will not
+        # be released within the current function.
+        assert (
+            free_block_table_fn is None or update_cached_sequences_fn is None
+        ), f"free_block_table_fn={free_block_table_fn} and update_cached_sequences_fn={update_cached_sequences_fn} cannot be passed simultaneously."
+
        finished_seqs = []
        finished_block_tables = []
        for seq in self._sequences_dict.values():
@@ -360,7 +406,7 @@ def pop_finished(
        # For now, the performance difference is not significant, so we use the first method to pop seqs.
        # Precise evaluations to be done.
        for seq in finished_seqs:
-            _, block_table = self.pop_seq_update_batch(seq.request_id, free_block_table_fn)
+            _, block_table = self.pop_seq_update_batch(seq.request_id, free_block_table_fn, update_cached_sequences_fn)
            if block_table is not None:
                finished_block_tables.append(block_table)
@@ -443,6 +489,14 @@ def merge(self, other: "BatchBucket") -> List[int]:

        return unmerged_ids

+    def has_reused_seqs(self) -> bool:
+        """Check whether any sequence in the batch carries a cached block table (StreamingLLM reuse)."""
+        return any(seq.block_table is not None for seq in self._sequences_dict.values())
+
    ########## The following methods are expected to be used in modeling ###########

    # For compatibility.
diff --git a/colossalai/inference/config.py b/colossalai/inference/config.py
index 61bc7c8abc9c..7990f65a9b36 100644
--- a/colossalai/inference/config.py
+++ b/colossalai/inference/config.py
@@ -54,6 +54,7 @@ class InputMetaData(RPC_PARAM):
    Args:
        block_tables (torch.Tensor, optional): Sequences' BlockTables Defaults to None.
        sequence_lengths (torch.Tensor): A tensor containing sequence lengths.
+ current_prompt_lengths (torch.Tensor): A tensor containing current prompt lengths. fd_inter_tensor (torch.Tensor, optional): A tensor representing intermediate data for flash decoding. Defaults to None. batch_size (int, optional): The current batch size. Defaults to 64. is_prompts (bool, optional): Indicates whether prefill or decoding. Defaults to False(decoding). @@ -70,6 +71,7 @@ class InputMetaData(RPC_PARAM): block_tables: torch.Tensor = None sequence_lengths: torch.Tensor = None + current_prompt_lengths: torch.Tensor = None fd_inter_tensor: FDIntermTensors = None batch_size: int = 64 # current_batch_size is_prompts: bool = False @@ -89,6 +91,7 @@ def to_rpc_param(self) -> Dict[str, any]: return { "block_tables": self.block_tables.tolist(), "sequence_lengths": self.sequence_lengths.tolist(), + "current_prompt_lengths": self.current_prompt_lengths.tolist(), "batch_size": self.batch_size, "is_prompts": self.is_prompts, "use_cuda_kernel": self.use_cuda_kernel, @@ -117,6 +120,9 @@ def from_rpc_param(rpc_dict: Dict[str, any]) -> "InputMetaData": sequence_lengths=torch.tensor( rpc_dict["sequence_lengths"], dtype=torch.int, device=get_accelerator().get_current_device() ), + current_prompt_lengths=torch.tensor( + rpc_dict["current_prompt_lengths"], dtype=torch.int, device=get_accelerator().get_current_device() + ), batch_size=rpc_dict["batch_size"], is_prompts=rpc_dict["is_prompts"], use_cuda_kernel=rpc_dict["use_cuda_kernel"], @@ -134,6 +140,7 @@ def __repr__(self) -> str: return ( f"InputMetaData(block_tables={self.block_tables}, " f"sequence_lengths={self.sequence_lengths}, " + f"current_prompt_lengths={self.current_prompt_lengths}, " f"fd_inter_tensor={self.fd_inter_tensor}, " f"batch_size={self.batch_size}, " f"is_prompts={self.is_prompts}, " @@ -166,8 +173,9 @@ class InferenceConfig(RPC_PARAM): top_k (Optional[int]): The number of highest probability vocabulary tokens to keep for top-k-filtering, defaults to None. top_p (Optional[float]): The cumulative probability threshold for retaining tokens with a total probability above it, defaults to None. temperature (Optional[float]): Randomness used to control randomization, defaults to 1.0. - repetition_penalty (Optional[float]): The parameter that influences the model's treatment of new tokens in relation to their appearance in the prompt and the generated text. Values greater than 1 incentivize the model to introduce new tokens, whereas values less than 1 incentivize token repetition., defaults to 1.0. no_repeat_ngram_size (Optional[int]): If no_repeat_ngram_size > 0, the consecutive tokens of ngram size can only appear once in inference sentences. + repetition_penalty (Optional[float]): The parameter that influences the model's treatment of new tokens in relation to their appearance in the prompt and the generated text. Values greater than 1 incentivize the model to introduce new tokens, whereas values less than 1 incentivize token repetition., defaults to 1.0. + ignore_eos(bool): Whether to ignore the EOS token and continue generating tokens when encountering the EOS token. n_spec_tokens (int): The maximum number of speculating tokens, defaults to None. glimpse_large_kv (bool): Whether to use large KV in drafter model, defaults to False. block_size (int): The number of blocks in a logical block, defaults to 16. @@ -176,10 +184,12 @@ class InferenceConfig(RPC_PARAM): micro_batch_size (int): the micro batch size, defaults to 1. Only useful when `pp_size` > 1. micro_batch_buffer_size (int): the buffer size for micro batch. 
Normally, it should be the same as the number of pipeline stages.
        use_cuda_kernel(bool): Whether to use cuda kernel, faster but lose some precision occasionally
+        high_precision(Optional[bool]): Whether to use float32 for underlying calculations of float16 data to achieve higher precision, defaults to False.
        use_cuda_graph (bool): Whether to enforce CUDA graph execution. If False, we will disable CUDA graph and always execute the model in eager mode. If True, we will use eager execution in hybrid.
        max_context_len_to_capture (int): max context len that could be captured by CUDA Graph, per sequence
-        high_precision(Optional[bool]): Whether to use float32 for underlying calculations of float16 data to achieve higher precision, defaults to False.
-        ignore_eos(bool): Whether to ignore the EOS token and continue generating tokens when encountering the EOS token.
+        enable_streamingllm(bool): Whether to use StreamingLLM; see https://arxiv.org/pdf/2309.17453 for the underlying algorithm.
+        start_token_size(int): The number of start (attention-sink) tokens kept in the KV cache when using StreamingLLM.
+        generated_token_size(int): The number of most recently generated tokens kept in the KV cache when using StreamingLLM.
    """

    # NOTE: arrange configs according to their importance and frequency of usage
@@ -208,6 +218,7 @@ class InferenceConfig(RPC_PARAM):
    no_repeat_ngram_size: Optional[int] = 0
    repetition_penalty: Optional[float] = 1.0
    forced_eos_token_id: int = None
+    ignore_eos: bool = False

    # speculative decoding configs
    max_n_spec_tokens: int = 5
@@ -221,15 +232,19 @@
    pp_size: int = 1
    micro_batch_size: int = 1
    micro_batch_buffer_size: int = None
-    high_precision: Optional[bool] = False

    # cuda kernel option
    use_cuda_kernel: bool = False
+    high_precision: Optional[bool] = False

    # cuda_graph
    use_cuda_graph: bool = False  # NOTE only when we have the graph for specific decoding batch size can we use the cuda graph for inference
    max_context_len_to_capture: int = 512
-    ignore_eos: bool = False
+
+    # StreamingLLM
+    enable_streamingllm: bool = False
+    start_token_size: int = 4
+    generated_token_size: int = 512

    def __post_init__(self):
        self.max_context_len_to_capture = self.max_input_len + self.max_output_len
@@ -272,6 +287,22 @@ def _verify_config(self) -> None:
                "{input_text}" in self.prompt_template
            ), "The prompt template should contain '{input_text}' for formatting the input text. For example: 'USER: {input_text}\n\nASSISTANT: '"

+        if self.enable_streamingllm:
+            assert (
+                self.use_cuda_graph == False
+            ), "We currently do not support using streamingLLM and CUDA graph simultaneously."
+            assert (
+                self.max_input_len <= self.generated_token_size
+            ), f"When enabling streamingLLM, max_input_len={self.max_input_len} must be less than or equal to generated_token_size={self.generated_token_size}."
+            assert (
+                self.start_token_size <= self.block_size
+            ), f"According to the paper https://arxiv.org/pdf/2309.17453, a start_token_size greater than 4 has little impact on inference performance, so we require start_token_size to be less than or equal to the block_size={self.block_size}, but got {self.start_token_size}."
+            assert (
+                self.generated_token_size % self.block_size == 0
+            ), f"We assume that the generated_token_size should be a multiple of the block_size={self.block_size}, but got generated_token_size={self.generated_token_size}."
+            # We assume that start_token_size occupies one block.
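+            # (the configured value is therefore overridden below so that the attention-sink tokens
+            #  always fill exactly the first cache block.)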
+ self.start_token_size = self.block_size + def to_generation_config(self, model_config) -> GenerationConfig: meta_config = { "max_length": self.max_input_len + self.max_output_len, diff --git a/colossalai/inference/core/engine.py b/colossalai/inference/core/engine.py index 96c2b15ee16e..fbd95829c102 100644 --- a/colossalai/inference/core/engine.py +++ b/colossalai/inference/core/engine.py @@ -89,6 +89,9 @@ def __init__( self.use_cuda_graph = self.inference_config.use_cuda_graph if self.use_cuda_graph: + assert ( + self.inference_config.enable_streamingllm == False + ), "We currently do not support using streamingLLM and CUDA graph simultaneously." self.graph_runners: Dict[int, CUDAGraphRunner] = {} self.graph_memory_pool = None # Set during graph capture. if verbose: @@ -197,6 +200,9 @@ def init_model(self, model_or_path: Union[nn.Module, str], model_policy: Union[P @torch.inference_mode() def capture_model(self, k_cache: List[torch.Tensor], v_cache: List[torch.Tensor]): assert self.use_cuda_graph, "please turn on the cuda graph" + assert ( + self.inference_config.enable_streamingllm == False + ), "We currently do not support using streamingLLM and CUDA graph simultaneously." if self.verbose: self.logger.info("Colossal AI CUDA Graph Capture begin") @@ -348,6 +354,11 @@ def enable_spec_dec( engine.clear_spec_dec() ``` """ + + assert ( + self.inference_config.enable_streamingllm == False + ), "We currently do not support using streamingLLM and Speculative Decoding simultaneously." + if drafter_model is None and self.drafter is None: raise ValueError("Drafter not initialized. Please provide a Drafter Model") if n_spec_tokens is not None: @@ -555,6 +566,9 @@ def generate( if self.use_spec_dec: assert self.drafter is not None, "Drafter Model is not initialized." + assert ( + self.inference_config.enable_streamingllm == False + ), "We currently do not support using streamingLLM and Speculative Decoding simultaneously." while self.request_handler.check_unfinished_seqs(): output_seqs_list += self.steps_spec_dec() else: @@ -596,6 +610,7 @@ def format_prompt(self, prompts: Union[List[str], str]) -> Union[List[str], str] def add_request( self, + user_ids: Union[List[int], int] = None, request_ids: Union[List[int], int] = None, prompts: Union[List[str], str] = None, prompts_token_ids: Union[List[int], torch.Tensor, np.ndarray] = None, @@ -605,6 +620,7 @@ def add_request( Add requests. Args: + user_id (List[int], optional): The IDs of the input sequences' owner. request_ids (List[int], optional): The request ID. Defaults to None. prompts (Union[List[str], optional): Input prompts. Defaults to None. prompts_token_ids (List[List[int]], optional): token ids of input prompts. Defaults to None. @@ -617,6 +633,9 @@ def add_request( block_size = self.inference_config.block_size + if user_ids is not None and not isinstance(user_ids, list): + user_ids = [user_ids] + if request_ids is not None and not isinstance(request_ids, list): request_ids = [request_ids] @@ -653,8 +672,11 @@ def add_request( ), f"The request_id type must be int, but got {type(request_ids[0])}" assert len(request_ids) == prompts_num request_id = request_ids[i] + user_id = user_id[i] else: request_id = next(self.counter) + # Default user_id to request_id if not provided. 
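+                # (each anonymous request then gets its own entry in cached_sequences_dict, so
+                #  StreamingLLM KV-cache reuse only happens when the caller passes an explicit user_id.)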
+ user_id = request_id if prompts == None: prompt = None else: @@ -667,25 +689,51 @@ def add_request( elif max_length is not None: max_new_tokens = max_length - len(prompts_token_ids[i]) - sequence = Sequence( - request_id, - prompt, - prompts_token_ids[i], - block_size, - None, - self.tokenizer.eos_token_id, - self.tokenizer.pad_token_id, - max_output_len=max_new_tokens, - ignore_eos=self.inference_config.ignore_eos, - ) + assert ( + self.inference_config.max_output_len >= max_new_tokens + ), f"max_new_tokens={max_new_tokens} must be less than max_output_len={self.inference_config.max_output_len}." + + timestamp = time.time() + cached_sequences_dict = self.request_handler.cached_sequences_dict + + assert len( + prompts_token_ids[i] <= self.inference_config.max_input_len + ), f"The lengths of prompt must be less or equal than max_input_len={self.inference_config.max_input_len}, but got {len(prompts_token_ids[i])}." + + if self.inference_config.enable_streamingllm and user_id in cached_sequences_dict: + sequence = cached_sequences_dict[user_id] + sequence.reused( + request_id, + timestamp, + prompt, + prompts_token_ids[i], + ) + else: + sequence = Sequence( + user_id=user_id, + request_id=request_id, + prompt=prompt, + input_token_id=prompts_token_ids[i], + block_size=block_size, + sample_params=None, + block_table=None, + eos_token_id=self.tokenizer.eos_token_id, + pad_token_id=self.tokenizer.pad_token_id, + max_output_len=max_new_tokens, + ignore_eos=self.inference_config.ignore_eos, + ) + self.request_handler.add_sequence(sequence) def prepare_input(self, batch: BatchBucket) -> Tuple[torch.Tensor, torch.Tensor, InputMetaData]: input_ids = batch.get_1D_inputs() sequence_lengths = batch.get_sequence_lengths() + current_prompt_lengths = None if batch.is_prompts: - n_tokens = sequence_lengths.sum().item() + if self.inference_config.enable_streamingllm and batch.has_reused_seqs(): + current_prompt_lengths = batch.get_1D_inputs() + n_tokens = current_prompt_lengths.sum().item() else: n_tokens = batch.current_batch_size if batch.use_spec_dec: @@ -712,6 +760,7 @@ def prepare_input(self, batch: BatchBucket) -> Tuple[torch.Tensor, torch.Tensor, input_meta_data = InputMetaData( block_tables=batch.get_block_table_tensor(), sequence_lengths=sequence_lengths, + current_prompt_lengths=current_prompt_lengths, fd_inter_tensor=batch.fd_inter_tensor, batch_size=batch.current_batch_size, is_prompts=batch.is_prompts, @@ -746,14 +795,29 @@ def step(self) -> List[str]: input_token_ids, output_tensor, input_meta_data = self.prepare_input(batch) if input_meta_data.use_cuda_graph: + assert ( + self.inference_config.enable_streamingllm == False + ), "We currently do not support using streamingLLM and CUDA graph simultaneously." model_executable = self.graph_runners[input_meta_data.batch_size] else: model_executable = self.model # TODO: padding_id is used for generating attn_mask and will be removed if nopad version is supported. 
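+        # The additional trailing argument passed to the model below indicates whether this batch
+        # contains reused (StreamingLLM) prefixes, i.e. enable_streamingllm and batch.has_reused_seqs().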
- logits = model_executable(input_token_ids, output_tensor, input_meta_data, self.k_cache, self.v_cache) + logits = model_executable( + input_token_ids, + output_tensor, + input_meta_data, + self.k_cache, + self.v_cache, + self.inference_config.enable_streamingllm and batch.has_reused_seqs(), + ) if self.inference_config.pad_input: logits = logits[:, -1, :] + + if self.inference_config.enable_streamingllm: + updated_block_ids = batch.streamingllm_update_batch() + self.request_handler.streamingllm_free_block_tables(updated_block_ids) + next_tokens = search_tokens( self.generation_config, logits, input_meta_data.is_prompts, batch_token_ids=input_meta_data.batch_token_ids ) diff --git a/colossalai/inference/core/request_handler.py b/colossalai/inference/core/request_handler.py index 5085c55558b4..14435822647c 100644 --- a/colossalai/inference/core/request_handler.py +++ b/colossalai/inference/core/request_handler.py @@ -127,6 +127,15 @@ def __init__(self, inference_config: InferenceConfig, model_config: PretrainedCo ) // inference_config.block_size head_dim = model_config.hidden_size // model_config.num_attention_heads + if self.inference_config.enable_streamingllm: + # user_id -> sequences + self.cached_sequences_dict = {} + # The current default cache size is the max_batch_size. + self.cache_size = inference_config.max_batch_size + else: + self.cached_sequences_dict = None + self.cache_size = 0 + fd_inter_tensor = FDIntermTensors() if fd_inter_tensor._tensors_initialized: @@ -157,6 +166,9 @@ def __init__(self, inference_config: InferenceConfig, model_config: PretrainedCo fd_interm_tensor=fd_inter_tensor, dtype=self.dtype, device=device, + enable_streamingllm=inference_config.enable_streamingllm, + start_token_size=inference_config.start_token_size, + generated_token_size=inference_config.generated_token_size, ) self.prefill_bb = BatchBucket( num_heads=model_config.num_attention_heads // inference_config.tp_size, @@ -168,6 +180,9 @@ def __init__(self, inference_config: InferenceConfig, model_config: PretrainedCo fd_interm_tensor=fd_inter_tensor, dtype=self.dtype, device=device, + enable_streamingllm=inference_config.enable_streamingllm, + start_token_size=inference_config.start_token_size, + generated_token_size=inference_config.generated_token_size, ) def _init_cache(self, model_config): @@ -224,10 +239,17 @@ def schedule(self): for seq in self.running_list.prefill[:num_seqs_to_add]: seq.mark_running() # allocate blocks for the prefill batch - self.prefill_bb.add_seqs( - self.running_list.prefill[:num_seqs_to_add], - alloc_block_tables_fn=self.cache_manager.allocate_context_from_block_tables, - ) + if self.inference_config.enable_streamingllm and self.prefill_bb.has_reused_seqs(): + self.streamingllm_prefill_alloc( + num_seqs_to_add, + self.cache_manager.allocate_context_from_block_tables, + self.cache_manager.allocate_context_from_non_empty_block_tables, + ) + else: + self.prefill_bb.add_seqs( + self.running_list.prefill[:num_seqs_to_add], + alloc_block_tables_fn=self.cache_manager.allocate_context_from_block_tables, + ) return self.prefill_bb @@ -236,7 +258,7 @@ def schedule(self): self.running_bb.block_tables, self.running_bb.seq_lengths, self.running_bb.current_batch_size ) if seqs_ids_to_recycle: - seqs_to_recycle = self.running_bb.pop_seqs(seqs_ids_to_recycle) + seqs_to_recycle, _ = self.running_bb.pop_seqs(seqs_ids_to_recycle, self.cached_sequences_dict) for seq in seqs_to_recycle: seq.recycle() self.running_list.remove(seq) @@ -343,13 +365,78 @@ def update(self): # since we want to 
reuse the memory recorded on the block tables
            self.prefill_bb.clear(free_block_tables_fn=None)

-        finished_seqs, _ = self.running_bb.pop_finished(self.cache_manager.free_block_table)
+        if self.inference_config.enable_streamingllm:
+            finished_seqs, _ = self.running_bb.pop_finished(update_cached_sequences_fn=self.update_cached_sequences)
+        else:
+            finished_seqs, _ = self.running_bb.pop_finished(free_block_table_fn=self.cache_manager.free_block_table)
        for seq in finished_seqs:
            self.running_list.remove(seq)

        self.done_list.extend(finished_seqs)

        return finished_seqs

+    def update_cached_sequences(self, seq: Sequence, block_table: torch.Tensor):
+        if seq.user_id not in self.cached_sequences_dict:
+            if len(self.cached_sequences_dict) >= self.cache_size:
+                outdated_seq = min(self.cached_sequences_dict.values(), key=lambda x: x.timestamp)
+                self.cache_manager.free_block_table(outdated_seq.block_table)
+                del self.cached_sequences_dict[outdated_seq.user_id]
+            seq.block_table = block_table.clone()
+            self.cached_sequences_dict[seq.user_id] = seq
+
+    def streamingllm_free_block_tables(self, updated_block_ids: List[int]):
+        """
+        Free the blocks that need to be swapped out.
+        """
+        self.cache_manager.streamingllm_free_block_tables(updated_block_ids)
+
+    def streamingllm_prefill_alloc(
+        self,
+        num_seqs_to_add: int = 0,
+    ):
+        current_seqs = self.running_list.prefill[:num_seqs_to_add]
+
+        reused_seqs = [seq for seq in current_seqs if seq.block_table is not None]
+        normal_seqs = [seq for seq in current_seqs if seq.block_table is None]
+
+        if normal_seqs:
+            self.prefill_bb.add_seqs(
+                normal_seqs,
+                alloc_block_tables_fn=self.cache_manager.allocate_context_from_block_tables,
+            )
+
+        if reused_seqs:
+            block_size = self.inference_config.block_size
+            start_block_num = self.inference_config.start_token_size // block_size
+            max_blocks_per_sequence = self.cache_manager.get_max_blocks_per_sequence()
+            streaningllm_prompt_lens = []
+            for seq in reused_seqs:
+                block_table = seq.block_table
+                streaningllm_prompt_len = self.cache_manager.get_used_slots(block_table)
+                vaild_block_num = sum(1 for x in block_table if x > 0)
+                unused_block_num = max_blocks_per_sequence - vaild_block_num
+                assert (
+                    seq.input_len <= self.inference_config.generated_token_size
+                ), f"When enabling streamingLLM, the prompt length of seq {seq.request_id} must be less than or equal to generated_token_size={self.inference_config.generated_token_size}, but got {seq.input_len}."
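+                # need_swap_blocks = blocks required by the new prompt minus the blocks still free in this
+                # sequence's fixed-size block table; if positive, that many of the oldest generated blocks
+                # are shifted out below (the start/attention-sink block is always kept).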
+ need_swap_blocks = (seq.input_len + block_size - 1) // block_size - unused_block_num + + if streaningllm_prompt_len + seq.input_len >= max_blocks_per_sequence: + streaningllm_prompt_len = max_blocks_per_sequence + else: + streaningllm_prompt_len = streaningllm_prompt_len + seq.input_len + + if need_swap_blocks > 0: + block_table[start_block_num : vaild_block_num - need_swap_blocks] = block_table[ + start_block_num + need_swap_blocks : vaild_block_num + ] + block_table[vaild_block_num - need_swap_blocks :] = [-1] * ( + max_blocks_per_sequence - vaild_block_num + need_swap_blocks + ) + + self.prefill_bb.add_seqs( + reused_seqs, + alloc_block_tables_fn=self.cache_manager.allocate_context_from_non_empty_block_tables, + need_reused_block_table=True, + streaningllm_prompt_lens=streaningllm_prompt_lens, + ) + class RPCRequestHandler(RequestHandler): """ diff --git a/colossalai/inference/kv_cache/kvcache_manager.py b/colossalai/inference/kv_cache/kvcache_manager.py index a20bd8ee79ea..97bba2296c9f 100644 --- a/colossalai/inference/kv_cache/kvcache_manager.py +++ b/colossalai/inference/kv_cache/kvcache_manager.py @@ -78,11 +78,18 @@ def __init__(self, config: InferenceConfig, model_config: PretrainedConfig) -> N self.max_output_length = config.max_output_len # Cache block settings self.block_size = config.block_size + # NOTE: `num_blocks` is not prompted, but evaluated from the maximum input/output length, and the maximum batch size - self.max_blocks_per_sequence = ( - self.max_input_length + self.max_output_length + self.block_size - 1 - ) // self.block_size - self.num_blocks = self.max_blocks_per_sequence * self.max_batch_size * self.beam_width + if config.enable_streamingllm: + self.max_blocks_per_sequence = ( + config.start_token_size + config.generated_token_size + self.block_size - 1 + ) // self.block_size + 1 + self.num_blocks = self.max_blocks_per_sequence * self.max_batch_size * self.beam_width * 2 + else: + self.max_blocks_per_sequence = ( + self.max_input_length + self.max_output_length + self.block_size - 1 + ) // self.block_size + self.num_blocks = self.max_blocks_per_sequence * self.max_batch_size * self.beam_width # Physical cache allocation if config.use_cuda_kernel: @@ -284,6 +291,82 @@ def allocate_context_from_block_tables(self, block_tables: torch.Tensor, context block.add_ref() self._allocate_on_block(block, block.block_size) + def allocate_context_from_non_empty_block_tables( + self, block_tables: torch.Tensor, context_lengths: torch.Tensor + ) -> None: + """Allocate logical cache blocks for a batch of sequences during prefill stage from non-empty block_tables. + + Args: + block_tables (torch.Tensor): [bsz, max_blocks_per_sequence] + context_lengths (torch.Tensor): [bsz]] + """ + assert block_tables.dim() == 2 + assert block_tables.size(0) == context_lengths.size(0) + + blocks_required = context_lengths // self.block_size + num_blocks_required = torch.sum(blocks_required).item() + assert isinstance(num_blocks_required, int) + if num_blocks_required > self._available_blocks: + self.logger.error( + f"Lacking blocks to allocate. Available blocks {self._available_blocks}; blocks asked {num_blocks_required}." 
+ ) + return + + block_start = torch.sum(block_tables > 0, dim=1) + + for i, block_id in enumerate(block_tables[block_start] - 1): + block: CacheBlock = self._cache_blocks[block_id] + self._allocate_on_block( + block, + block.block_size + if context_lengths[i] % block.block_size == 0 + else context_lengths[i].item() % block.block_size, + ) + + bsz = block_tables.size(0) + # Try contiguous allocation + torch.cumsum(self._block_states, dim=-1, out=self._block_states_cum[1:]) + torch.subtract( + self._block_states_cum[num_blocks_required:], + self._block_states_cum[:-num_blocks_required], + out=self._block_finder[num_blocks_required - 1 :], + ) + end_indexes = torch.nonzero(self._block_finder == num_blocks_required, as_tuple=False).view(-1) + if end_indexes.numel() > 0: + # contiguous cache exists + end_idx = end_indexes[0].item() + 1 # open interval + start_idx = end_idx - num_blocks_required # closed interval + alloc_block_ids = torch.arange(start_idx, end_idx) + for i in range(bsz): + curr_required = blocks_required[i] + block_tables[i, block_start[i] : curr_required] = torch.arange( + start_idx, start_idx + curr_required, device=block_tables.device + ) + start_idx += curr_required + else: + # non-contiguous cache + available_block_ids = torch.nonzero(self._block_states > 0).view(-1) + alloc_block_ids = available_block_ids[:num_blocks_required] + alloc_block_ids = alloc_block_ids.to(dtype=block_tables.dtype, device=block_tables.device) + start_idx = 0 + for i in range(bsz): + curr_required = blocks_required[i] + block_tables[i, block_start[i] : curr_required] = alloc_block_ids[start_idx, start_idx + curr_required] + start_idx += curr_required + + # Update cache blocks + self._block_states[alloc_block_ids] = 0 + self._available_blocks -= num_blocks_required + last_block_locs = torch.cumsum(blocks_required, dim=0) - 1 + last_block_locs = last_block_locs.to(device=alloc_block_ids.device) + + for block_id in alloc_block_ids: + if block_id in alloc_block_ids[last_block_locs]: + continue + block: CacheBlock = self._cache_blocks[block_id] + block.add_ref() + self._allocate_on_block(block, block.block_size) + def allocate_token_from_block_table(self, block_table: torch.Tensor, context_len: int) -> None: """Allocate the logical cache block for a single sequence during decoding stage, and updates the provided block table if a new cache block is needed. @@ -446,6 +529,26 @@ def clear_all(self) -> None: self._available_blocks = self.num_blocks self._block_states[:] = 1 + def streamingllm_free_block_tables(self, updated_block_ids: List[int]): + """ + Free the block that needs to be swapped out. 
+ """ + for global_block_id in updated_block_ids: + if global_block_id < 0: + return + block: CacheBlock = self._cache_blocks[global_block_id] + block.remove_ref() + if not block.has_ref(): + block.allocated_size = 0 + self._available_blocks += 1 + self._block_states[global_block_id] = 1 + + def get_used_slots(self, block_table: torch.Tensor): + num_positive_elements = torch.sum(block_table > 0, dim=0).items() + last_block_id = block_table[num_positive_elements - 1] + block: CacheBlock = self._cache_blocks[last_block_id] + return (num_positive_elements - 1) * self.block_size + block.allocated_size + def get_physical_cache(self, layer_id: int, block_idx: int) -> Tuple[torch.Tensor, torch.Tensor]: """Get the tensor corresponding to the cache block with the prompted id for a specific layer.""" return self._kv_caches[0][layer_id][block_idx], self._kv_caches[1][layer_id][block_idx] @@ -498,6 +601,12 @@ def _init_device_caches( v_cache.append(torch.zeros(valloc_shape, dtype=self.kv_cache_dtype, device=self.device)) return k_cache, v_cache + def check_block_full(self, global_block_id: int) -> bool: + if global_block_id < 0: + return False + block: CacheBlock = self._cache_blocks[global_block_id] + return block.available_space <= 0 + class RPCKVCacheManager(KVCacheManager): def __init__(self, config: InferenceConfig, model_config: PretrainedConfig, verbose: bool = False) -> None: @@ -533,11 +642,18 @@ def __init__(self, config: InferenceConfig, model_config: PretrainedConfig, verb self.max_output_length = config.max_output_len # Cache block settings self.block_size = config.block_size + # NOTE: `num_blocks` is not prompted, but evaluated from the maximum input/output length, and the maximum batch size - self.max_blocks_per_sequence = ( - self.max_input_length + self.max_output_length + self.block_size - 1 - ) // self.block_size - self.num_blocks = self.max_blocks_per_sequence * self.max_batch_size * self.beam_width + if config.enable_streamingllm: + self.max_blocks_per_sequence = ( + config.start_token_size + config.generated_token_size + self.block_size - 1 + ) // self.block_size + 1 + self.num_blocks = self.max_blocks_per_sequence * self.max_batch_size * self.beam_width * 2 + else: + self.max_blocks_per_sequence = ( + self.max_input_length + self.max_output_length + self.block_size - 1 + ) // self.block_size + self.num_blocks = self.max_blocks_per_sequence * self.max_batch_size * self.beam_width # Logical cache blocks allocation self._available_blocks = self.num_blocks diff --git a/colossalai/inference/logit_processors.py b/colossalai/inference/logit_processors.py index ea73f833242e..4aaed83d10e8 100644 --- a/colossalai/inference/logit_processors.py +++ b/colossalai/inference/logit_processors.py @@ -157,6 +157,7 @@ def apply_forced_eos_token_id( select_indexes = [] num_sequences = logits.shape[0] + # NOTE (yuehuayingxueluo): When streamingLLM is enabled, the following logic may cause errors. 
sequence_lengths = sequence_lengths[:num_sequences] max_lengths = max_lengths[:num_sequences] for i, (sequence_length, max_out_length) in enumerate(zip(sequence_lengths, max_lengths)): diff --git a/colossalai/inference/modeling/models/nopadding_llama.py b/colossalai/inference/modeling/models/nopadding_llama.py index f6f160eb7e96..4b952a6990e9 100644 --- a/colossalai/inference/modeling/models/nopadding_llama.py +++ b/colossalai/inference/modeling/models/nopadding_llama.py @@ -107,6 +107,9 @@ def llama_model_forward( use_cuda_kernel = False logger.warning("CUDA kernel is disabled for speculative-decoding.") + if inputmetadata.current_prompt_lengths: + current_prompt_lengths = inputmetadata.current_prompt_lengths + hidden_states = self.embed_tokens(input_tokens_ids) cu_seqlens = None @@ -133,9 +136,21 @@ def llama_model_forward( total_length = hidden_states.size(0) cos = torch.empty((total_length, hidden_dim), dtype=self._cos_cached.dtype, device=self._cos_cached.device) sin = torch.empty((total_length, hidden_dim), dtype=self._sin_cached.dtype, device=self._sin_cached.device) - inference_ops.get_cos_and_sin( - self._cos_cached, self._sin_cached, cos, sin, sequence_lengths, kv_seq_len, inputmetadata.is_prompts - ) + if current_prompt_lengths: + inference_ops.prefix_get_cos_and_sin( + self._cos_cached, + self._sin_cached, + cos, + sin, + sequence_lengths, + current_prompt_lengths, + kv_seq_len, + inputmetadata.is_prompts, + ) + else: + inference_ops.get_cos_and_sin( + self._cos_cached, self._sin_cached, cos, sin, sequence_lengths, kv_seq_len, inputmetadata.is_prompts + ) cos_sin = (cos, sin) else: cos_sin = get_xine_cache(sequence_lengths, self._cos_cached, self._sin_cached, inputmetadata.is_prompts) @@ -157,6 +172,7 @@ def llama_model_forward( is_verifier=inputmetadata.use_spec_dec, tokens_to_verify=tokens_to_verify, sequence_lengths=sequence_lengths, + current_prompt_lengths=current_prompt_lengths, cos_sin=cos_sin, fd_inter_tensor=inputmetadata.fd_inter_tensor, kv_seq_len=kv_seq_len, @@ -186,6 +202,7 @@ def llama_decoder_layer_forward( k_cache: torch.Tensor, v_cache: torch.Tensor, sequence_lengths: torch.Tensor, + current_prompt_lengths: torch.Tensor, cos_sin: Tuple[torch.Tensor], fd_inter_tensor: FDIntermTensors, is_prompts: bool = True, @@ -233,6 +250,7 @@ def llama_decoder_layer_forward( is_verifier=is_verifier, tokens_to_verify=tokens_to_verify, sequence_lengths=sequence_lengths, + current_prompt_lengths=current_prompt_lengths, cos_sin=cos_sin, fd_inter_tensor=fd_inter_tensor, kv_seq_len=kv_seq_len, @@ -486,6 +504,7 @@ def forward( k_cache: torch.Tensor, v_cache: torch.Tensor, sequence_lengths: torch.Tensor, + current_prompt_lengths: torch.Tensor, cos_sin: Tuple[torch.Tensor], fd_inter_tensor: FDIntermTensors, is_prompts: bool = True, @@ -534,26 +553,50 @@ def forward( block_size = k_cache.size(-2) if is_prompts: - if not is_verifier and use_cuda_kernel and query_states.dtype != torch.float32 and use_flash_attn2: - # flash attn 2 currently only supports FP16/BF16. - inference_ops.rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1], high_precision) - inference_ops.context_kv_cache_memcpy( - key_states, value_states, k_cache, v_cache, sequence_lengths, cu_seqlens, block_tables, kv_seq_len - ) + if current_prompt_lengths: + if not is_verifier and use_cuda_kernel and query_states.dtype != torch.float32 and use_flash_attn2: + # flash attn 2 currently only supports FP16/BF16. 
+ inference_ops.rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1], high_precision) + inference_ops.context_kv_cache_memcpy( + key_states, + value_states, + k_cache, + v_cache, + sequence_lengths, + cu_seqlens, + block_tables, + kv_seq_len, + ) - attn_output = flash_attn_varlen_func( - query_states, - key_states, - value_states, - cu_seqlens_q=cu_seqlens, - cu_seqlens_k=cu_seqlens, - max_seqlen_q=kv_seq_len, - max_seqlen_k=kv_seq_len, - dropout_p=0.0, - softmax_scale=sm_scale, - causal=True, - ) - attn_output = attn_output.view(token_nums, -1) + attn_output = flash_attn_varlen_func( + query_states, + key_states, + value_states, + cu_seqlens_q=cu_seqlens, + cu_seqlens_k=cu_seqlens, + max_seqlen_q=kv_seq_len, + max_seqlen_k=kv_seq_len, + dropout_p=0.0, + softmax_scale=sm_scale, + causal=True, + ) + attn_output = attn_output.view(token_nums, -1) + else: + rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1]) + attn_output = context_attention_unpadded( + q=query_states, + k=key_states, + v=value_states, + k_cache=k_cache, + v_cache=v_cache, + context_lengths=sequence_lengths, + block_tables=block_tables, + block_size=block_size, + output=output_tensor, + max_seq_len=kv_seq_len, + sm_scale=sm_scale, + use_new_kcache_layout=use_cuda_kernel, + ) else: rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1]) attn_output = context_attention_unpadded( @@ -563,6 +606,7 @@ def forward( k_cache=k_cache, v_cache=v_cache, context_lengths=sequence_lengths, + current_prompt_lengths=current_prompt_lengths, block_tables=block_tables, block_size=block_size, output=output_tensor, diff --git a/colossalai/inference/struct.py b/colossalai/inference/struct.py index 1a3094a27e2d..0666da4590c1 100644 --- a/colossalai/inference/struct.py +++ b/colossalai/inference/struct.py @@ -2,6 +2,8 @@ from dataclasses import dataclass from typing import Any, List +import torch + from colossalai.logging import get_dist_logger logger = get_dist_logger(__name__) @@ -51,7 +53,9 @@ class Sequence: """Store information of input sequence. Args: + user_id (int): The ID of the input sequence's owner. request_id (int): The ID of input sequence. + timestamp: (float): It represents the last time seq was called. prompt (str): The prompt of input sequence. input_token_id (List[int]): The tokens ID of input sequence. block_size (int): The block size of input sequence. @@ -62,13 +66,17 @@ class Sequence: max_output_len (int): Maximum output length. ignore_eos(bool): Whether to ignore the EOS token and continue generating tokens when encountering the EOS token. output(str): The output of sequence + """ + user_id: int request_id: int + timestamp: float prompt: str input_token_id: List[int] block_size: int sample_params: Any # SampleParams needs to be imported later. + block_table: torch.Tensor eos_token_id: int pad_token_id: int max_output_len: int = 256 @@ -78,6 +86,7 @@ class Sequence: def __post_init__(self): self.output_token_id = [] + self.streaningllm_prompt_len = 0 self.status = RequestStatus.WAITING @property @@ -94,6 +103,13 @@ def input_len(self) -> int: """ return len(self.input_token_id) + @property + def streaningllm_prompt_len(self) -> int: + """ + Get the length of sentences used for streamingLLM. 
+ """ + return self.streaningllm_prompt_len + @property def output_len(self) -> int: """ @@ -162,6 +178,14 @@ def recycle(self) -> None: is already done but it still in running list" self.status = RequestStatus.RECYCLED + def reused(self, request_id: int, timestamp: float, prompt: str, input_token_id: List[int]) -> None: + self.request_id = request_id + self.timestamp = timestamp + self.prompt = prompt + self.input_token_id = input_token_id + self.output_token_id = [] + self.status = RequestStatus.WAITING + def __repr__(self) -> str: return ( f"(request_id={self.request_id}, " diff --git a/colossalai/kernel/triton/context_attn_unpad.py b/colossalai/kernel/triton/context_attn_unpad.py index 9c69c4125d62..3f0258d06a17 100644 --- a/colossalai/kernel/triton/context_attn_unpad.py +++ b/colossalai/kernel/triton/context_attn_unpad.py @@ -363,6 +363,213 @@ def _fwd_context_paged_attention_kernel_v2( return +# Triton 2.1.0 +# TODO(yuanheng-zhao): This is a temporary dispatch to use the new layout for kcache +# merge `_fwd_context_paged_attention_kernel_v2` with `_fwd_context_paged_attention_kernel` later +# as the kcache layout has been supported in the whole triton flow. +@triton.jit +def _prefix_fwd_context_paged_attention_kernel_v1( + Q, + K, + V, + O, + KCache, # [num_blocks, num_kv_heads, head_dim // x, block_size, x] + VCache, # [num_blocks, num_kv_heads, block_size, head_dim] + BLOCK_TABLES, # [num_seqs, max_blocks_per_sequence] + batch_size, + stride_qt, + stride_qh, + stride_qd, + stride_kt, + stride_kh, + stride_kd, + stride_vt, + stride_vh, + stride_vd, + stride_ot, + stride_oh, + stride_od, + stride_kcb, # k cache stride(0) - num_blocks + stride_kch, # k cache stride(1) - num_kv_heads + stride_kcsplit_x, # k cache stride(2) - head_dim // x + stride_kcs, # k cache stride(3) - block_szie + stride_kcd, # k cache stride(4) - x + stride_vcb, # v cache stride(0) - num_blocks + stride_vch, # v cache stride(1) - num_kv_heads + stride_vcbs, # v cache stride(2) - block_size + stride_vcd, # v cache stride(3) - head_dim + stride_bts, + stride_btb, + context_lengths, + current_prompt_lengths, + sm_scale, + KV_GROUPS: tl.constexpr, + BLOCK_SIZE: tl.constexpr, + HEAD_DIM: tl.constexpr, + KCACHE_X: tl.constexpr, # k stride on the second last dimension + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, +): + cur_seq_idx = tl.program_id(0) + if cur_seq_idx >= batch_size: + return + cur_head_idx = tl.program_id(1) + block_start_m = tl.program_id(2) # Br, max_input_len // Block_M + cur_kv_head_idx = cur_head_idx // KV_GROUPS + + # NOTE It requires BLOCK_M, BLOCK_N, and BLOCK_SIZE to be the same + tl.static_assert(BLOCK_M == BLOCK_N) + tl.static_assert(BLOCK_N == BLOCK_SIZE) + + # get the context sequence length from provided context lengths tensor + context_seq_len = tl.load(context_lengths + cur_seq_idx) + # get the current prompt length from provided context lengths tensor + current_prompt_len = tl.load(current_prompt_lengths + cur_seq_idx) + + if block_start_m * BLOCK_M >= context_seq_len: + return + + # NOTE when talking to fused QKV and a nopadding context attention, + # we assume that the input Q/K/V is contiguous, and thus here `prev_seq_len_sum` + # could be considered as the start index of the context sequence. + # FIXME might want to explore better way to get the summation of prev seq lengths. + # `tl.sum(tensor[:end])` is invalid as tensor slice is not supported in triton. 
+ prev_seq_len_sum = 0 + for i in range(0, cur_seq_idx): + prev_seq_len_sum += tl.load(current_prompt_lengths + i) + + offset_q = prev_seq_len_sum * stride_qt + cur_head_idx * stride_qh + offset_kv = prev_seq_len_sum * stride_kt + cur_kv_head_idx * stride_kh + + computed_blocks = (context_seq_len - current_prompt_len) // BLOCK_SIZE + computed_slots = (context_seq_len - current_prompt_len) % BLOCK_SIZE + + # block table for the context sequence + block_table_ptr = BLOCK_TABLES + cur_seq_idx * stride_bts + # block indexes on block table (i.e. 0, 1, 2, ..., max_blocks_per_seq) + # Consider `block_start_m` as the logical block idx in the context block table, + # as we have BLOCK_M the same size as the block size. + cur_block_idx = block_start_m + cur_block_id = tl.load(block_table_ptr + cur_block_idx * stride_btb) + offsets_dmodel = tl.arange(0, HEAD_DIM) + block_range = tl.arange(0, BLOCK_SIZE) + + if block_start_m >= computed_blocks: + offset_kvcache = cur_block_id * stride_vcb + cur_kv_head_idx * stride_vch + + offsets_m = block_start_m * BLOCK_M + tl.arange(0, BLOCK_M) + offsets_n = tl.arange(0, BLOCK_N) + + if cur_head_idx % KV_GROUPS == 0: + # Copy k to corresponding cache block + if block_start_m == computed_blocks: + block_range = tl.arange(computed_slots, BLOCK_SIZE) + + X_range = tl.arange(0, KCACHE_X) + # unroll the loop aggressively + for split_x_group_id in tl.static_range(HEAD_DIM // KCACHE_X): + offsets_dmodel_x_partition = tl.arange(split_x_group_id * KCACHE_X, (split_x_group_id + 1) * KCACHE_X) + offsets_k = ( + K + offset_kv + offsets_dmodel_x_partition[None, :] * stride_kd + offsets_m[:, None] * stride_kt + ) + k = tl.load(offsets_k, mask=offsets_m[:, None] < current_prompt_len, other=0.0) + # HACK: KCache must be contiguous in order to apply the following offsets calculation + offsets_kcache = ( + KCache + + offset_kvcache + + split_x_group_id * BLOCK_SIZE * KCACHE_X + + block_range[:, None] * KCACHE_X + + X_range[None, :] + ) + tl.store(offsets_kcache, k, mask=block_range[:, None] < current_prompt_len - block_start_m * BLOCK_SIZE) + # Copy v to corresponding cache block + offsets_dmodel = tl.arange(0, HEAD_DIM) # offsets_dmodel + offsets_vt = block_start_m * BLOCK_N + offsets_n + offsets_v = V + offset_kv + offsets_vt[None, :] * stride_vt + offsets_dmodel[:, None] * stride_vd + v = tl.load(offsets_v, mask=offsets_vt[None, :] < current_prompt_len, other=0.0) + offsets_vcache = ( + VCache + offset_kvcache + block_range[None, :] * stride_vcbs + offsets_dmodel[:, None] * stride_vcd + ) + tl.store(offsets_vcache, v, mask=block_range[None, :] < current_prompt_len - block_start_m * BLOCK_SIZE) + + prev_seq_len_sum = 0 + for i in range(0, cur_seq_idx): + prev_seq_len_sum += tl.load(context_lengths + i) + + offset_kv_cache = cur_block_id * stride_kcb + cur_kv_head_idx * stride_kch + + Q_block_ptr = tl.make_block_ptr( + base=Q + offset_q, + shape=(current_prompt_len, HEAD_DIM), + strides=(stride_qt, stride_qd), + offsets=(block_start_m * BLOCK_M, 0), + block_shape=(BLOCK_M, HEAD_DIM), + order=(1, 0), + ) + + K_block_ptr = tl.make_block_ptr( + base=KCache + offset_kv_cache, + shape=(HEAD_DIM, current_prompt_len), + strides=(stride_kd, stride_kt), + offsets=(0, 0), + block_shape=(HEAD_DIM, BLOCK_N), + order=(0, 1), + ) + V_block_ptr = tl.make_block_ptr( + base=VCache + offset_kv_cache, + shape=(current_prompt_len, HEAD_DIM), + strides=(stride_vt, stride_vd), + offsets=(0, 0), + block_shape=(BLOCK_N, HEAD_DIM), + order=(1, 0), + ) + O_block_ptr = tl.make_block_ptr( + base=O + offset_q, 
+ shape=(current_prompt_len, HEAD_DIM), + strides=(stride_ot, stride_od), + offsets=(block_start_m * BLOCK_M, 0), + block_shape=(BLOCK_M, HEAD_DIM), + order=(1, 0), + ) + + m_i = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32) + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32) + + Q_i = tl.load(Q_block_ptr, boundary_check=(1, 0)) + + for block_start_n in range(0, (block_start_m + 1) * BLOCK_M, BLOCK_N): + block_start_n = tl.multiple_of(block_start_n, BLOCK_N) + + k = tl.load(K_block_ptr, boundary_check=(0, 1)) + S_ij = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + S_ij += tl.dot(Q_i, k) + S_ij *= sm_scale + S_ij += tl.where(offsets_m[:, None] >= (block_start_n + offsets_n[None, :]), 0, float("-inf")) + + m_ij = tl.max(S_ij, 1) # rowmax(Sij) + m_ij = tl.maximum(m_i, m_ij) # m_ij + S_ij -= m_ij[:, None] + p_ij_hat = tl.exp(S_ij) + scale = tl.exp(m_i - m_ij) + l_ij = scale * l_i + tl.sum(p_ij_hat, 1) + acc = acc * scale[:, None] + + v = tl.load(V_block_ptr, boundary_check=(1, 0)) + p_ij_hat = p_ij_hat.to(v.type.element_ty) + + acc += tl.dot(p_ij_hat, v) + l_i = l_ij + m_i = m_ij + K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N)) + V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0)) + + acc = acc / l_i[:, None] + tl.store(O_block_ptr, acc.to(O.type.element_ty), boundary_check=(1, 0)) + + return + + # Triton 2.1.0 @triton.jit def _alibi_fwd_context_paged_attention_kernel( @@ -553,11 +760,12 @@ def context_attention_unpadded( q: torch.Tensor, # [num_tokens, num_heads, head_dim] k: torch.Tensor, # [num_tokens, num_kv_heads, head_dim] v: torch.Tensor, # [num_tokens, num_kv_heads, head_dim] - k_cache: torch.Tensor, # [num_blocks, num_kv_heads, block_size, head_dim] + k_cache: torch.Tensor, # [num_blocks, num_kv_heads, head_dim // x, block_size, x] v_cache: torch.Tensor, # [num_blocks, num_kv_heads, block_size, head_dim] context_lengths: torch.Tensor, # [num_seqs] - block_tables: torch.Tensor, # [num_seqs, max_blocks_per_sequence], - block_size: int, + current_prompt_lengths: torch.Tensor = None, # [num_seqs] + block_tables: torch.Tensor = None, # [num_seqs, max_blocks_per_sequence], + block_size: int = 16, output: torch.Tensor = None, # [num_tokens, num_heads, head_dim] alibi_slopes: torch.Tensor = None, # [num_heads] max_seq_len: int = None, @@ -612,42 +820,86 @@ def context_attention_unpadded( ), "Alibi Slopes will be supported with new kcache layout later when the whole triton flow is ready" x = k_cache_shape[4] # Intuition: 16 // dtype_size - _fwd_context_paged_attention_kernel_v2[grid]( - q, - k, - v, - output, - k_cache, - v_cache, - block_tables, - num_seqs, - q.stride(0), - q.stride(1), - q.stride(2), - k.stride(0), - k.stride(1), - k.stride(2), - v.stride(0), - v.stride(1), - v.stride(2), - output.stride(0), - head_dim, - 1, - v_cache.stride(0), - v_cache.stride(1), - v_cache.stride(2), - v_cache.stride(3), - block_tables.stride(0), - block_tables.stride(1), - context_lengths, - sm_scale, - KV_GROUPS=num_kv_group, - BLOCK_SIZE=block_size, - HEAD_DIM=Lk, - KCACHE_X=x, - BLOCK_M=BLOCK_M, - BLOCK_N=BLOCK_N, - ) + if current_prompt_lengths: + _prefix_fwd_context_paged_attention_kernel_v1( + q, + k, + v, + output, + k_cache, + v_cache, + block_tables, + num_seqs, + q.stride(0), + q.stride(1), + q.stride(2), + k.stride(0), + k.stride(1), + k.stride(2), + v.stride(0), + v.stride(1), + v.stride(2), + output.stride(0), + head_dim, + 1, + k_cache.stride(0), + k_cache.stride(1), + k_cache.stride(2), + k_cache.stride(3), + k_cache.stride(4), 
+ v_cache.stride(0), + v_cache.stride(1), + v_cache.stride(2), + v_cache.stride(3), + block_tables.stride(0), + block_tables.stride(1), + context_lengths, + current_prompt_lengths, + sm_scale, + KV_GROUPS=num_kv_group, + BLOCK_SIZE=block_size, + HEAD_DIM=Lk, + KCACHE_X=x, + BLOCK_M=BLOCK_M, + BLOCK_N=BLOCK_N, + ) + else: + _fwd_context_paged_attention_kernel_v2[grid]( + q, + k, + v, + output, + k_cache, + v_cache, + block_tables, + num_seqs, + q.stride(0), + q.stride(1), + q.stride(2), + k.stride(0), + k.stride(1), + k.stride(2), + v.stride(0), + v.stride(1), + v.stride(2), + output.stride(0), + head_dim, + 1, + v_cache.stride(0), + v_cache.stride(1), + v_cache.stride(2), + v_cache.stride(3), + block_tables.stride(0), + block_tables.stride(1), + context_lengths, + sm_scale, + KV_GROUPS=num_kv_group, + BLOCK_SIZE=block_size, + HEAD_DIM=Lk, + KCACHE_X=x, + BLOCK_M=BLOCK_M, + BLOCK_N=BLOCK_N, + ) return output if alibi_slopes is not None: diff --git a/examples/inference/llama/benchmark_llama.py b/examples/inference/llama/benchmark_llama.py index 2d24d87adfd1..8b6d471ade5f 100644 --- a/examples/inference/llama/benchmark_llama.py +++ b/examples/inference/llama/benchmark_llama.py @@ -142,6 +142,8 @@ def benchmark_inference(args): block_size=32, tp_size=args.tp_size, use_cuda_kernel=True, + enable_streamingllm=True, + generated_token_size=32, ) engine = InferenceEngine(model, tokenizer, inference_config, verbose=True) elif args.mode == "vllm": diff --git a/examples/inference/llama/llama_generation.py b/examples/inference/llama/llama_generation.py index c0a1a585a1b9..73db7dff623e 100644 --- a/examples/inference/llama/llama_generation.py +++ b/examples/inference/llama/llama_generation.py @@ -41,6 +41,9 @@ def infer(args): block_size=16, tp_size=args.tp_size, use_cuda_kernel=args.use_cuda_kernel, + enable_streamingllm=args.enable_streamingllm, + start_token_size=args.start_token_size, + generated_token_size=args.generated_token_size, ) coordinator.print_on_master(f"Initializing Inference Engine...") engine = InferenceEngine(model, tokenizer, inference_config, model_policy=POLICY_CLS(), verbose=True) @@ -56,6 +59,8 @@ def infer(args): temperature=args.temperature, top_k=args.top_k, top_p=args.top_p, + no_repeat_ngram_size=args.no_repeat_ngram_size, + repetition_penalty=args.repetition_penalty, ) coordinator.print_on_master(f"Generating...") out = engine.generate(prompts=[args.prompt], generation_config=generation_config) @@ -100,6 +105,25 @@ def infer(args): parser.add_argument("--temperature", type=float, default=1.0, help="Temperature for generation") parser.add_argument("--top_k", type=int, default=50, help="Top k for generation") parser.add_argument("--top_p", type=float, default=1.0, help="Top p for generation") + parser.add_argument("--enable_streamingllm", action="store_true", help="Whether to use StreamingLLM") + parser.add_argument( + "--start_token_size", type=int, default=4, help="The size of the start_token, When using StreamingLLM," + ) + parser.add_argument( + "--generated_token_size", type=int, default=512, help="The size of the generated_token, When using StreamingLLM" + ) + parser.add_argument( + "--no_repeat_ngram_size", + type=int, + default=0, + help="If no_repeat_ngram_size > 0, the consecutive tokens of ngram size can only appear once in inference sentences.", + ) + parser.add_argument( + "--repetition_penalty", + type=float, + default=1.0, + help="The parameter that influences the model's treatment of new tokens in relation to their appearance in the prompt and the generated 
text. Values greater than 1 incentivize the model to introduce new tokens, whereas values less than 1 incentivize token repetition., defaults to 1.0.", + ) args = parser.parse_args() infer(args) diff --git a/examples/inference/llama/run_benchmark.sh b/examples/inference/llama/run_benchmark.sh index 1927159765ba..1f66c6b57f97 100755 --- a/examples/inference/llama/run_benchmark.sh +++ b/examples/inference/llama/run_benchmark.sh @@ -24,8 +24,8 @@ CUDA_VISIBLE_DEVICES_set_n_least_memory_usage() { CUDA_VISIBLE_DEVICES_set_n_least_memory_usage 1 # benchmark llama2-7b one single GPU -for input_len in 128 512 1024; do - for output_len in 128 256; do +for input_len in 128; do + for output_len in 256; do for bsz in 16 32 64; do python3 ${PY_SCRIPT} -m llama2-7b --tp_size 1 -b ${bsz} -s ${input_len} --output_len ${output_len} --mode ${mode} --test_random_weight | tee logs/${bsz}_${input_len}_${output_len}_${mode}_${GPU}.txt done
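
For context, a minimal end-to-end sketch of enabling the new StreamingLLM options (not part of this diff): the checkpoint path and generation settings are placeholders, and the launch/engine calls are assumed to follow the existing examples in examples/inference/llama/llama_generation.py.

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig

    import colossalai
    from colossalai.inference.config import InferenceConfig
    from colossalai.inference.core.engine import InferenceEngine

    colossalai.launch_from_torch()  # assumed to be run under torchrun, as in the existing examples

    model_path = "<path-to-llama-checkpoint>"  # placeholder
    model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=torch.float16).cuda()
    tokenizer = AutoTokenizer.from_pretrained(model_path)

    # StreamingLLM keeps `start_token_size` attention-sink tokens plus a sliding window of at most
    # `generated_token_size` generated tokens in the KV cache; _verify_config() checks that the start
    # tokens fit in one block and that generated_token_size is a multiple of block_size.
    inference_config = InferenceConfig(
        max_batch_size=4,
        max_input_len=256,
        max_output_len=256,
        block_size=16,
        use_cuda_kernel=True,
        enable_streamingllm=True,
        start_token_size=4,
        generated_token_size=512,
    )

    engine = InferenceEngine(model, tokenizer, inference_config, verbose=True)
    generation_config = GenerationConfig(max_new_tokens=128, pad_token_id=tokenizer.pad_token_id)
    out = engine.generate(prompts=["Introduce StreamingLLM briefly."], generation_config=generation_config)
    print(out)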