From 7169cbae6d86c754872a22ef27dd35d6e2abf209 Mon Sep 17 00:00:00 2001 From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com> Date: Fri, 20 Sep 2024 14:25:51 +0200 Subject: [PATCH 01/29] wip --- .../models/flash_causal_lm.py | 121 ++++++++++++++---- 1 file changed, 97 insertions(+), 24 deletions(-) diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index 33fe30a87cd..8933f13e952 100644 --- a/server/text_generation_server/models/flash_causal_lm.py +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -149,11 +149,26 @@ class FlashCausalLMBatch(Batch): max_seqlen: int - # Prefill metadata tensors to efficiently compute logprobs + # Prefill metadata tensors prefill_head_indices: Optional[torch.Tensor] - prefill_next_token_indices: Optional[torch.tensor] + prefill_next_token_indices: Optional[torch.Tensor] prefill_cu_outlens: Optional[List[int]] + # Whether at least one request is prefilling/chunking + # == any(prefilling_mask) + prefilling: bool + # For each request, whether they are still prefilling/chunking + prefilling_mask: List[bool] + # For each request, whether the model output should be used or discarded + # If we are chunking, we don't care about the output as it might be different + # from the token in the prompt + use_output_token: List[bool] + + # If the request is decoding, `next_chunk_length = 1` + # `None if not batch.prefilling` + next_chunk_lengths: Optional[List[int]] + next_chunk_lengths_tensor: Optional[torch.Tensor] + # Prefixes prefix_ids: List[List[int]] @@ -232,11 +247,14 @@ def from_tokenized( prefix_ids = [] requests_idx_mapping = {} + chunking = False all_prefill_logprobs = True no_prefill_logprobs = True prefill_head_indices = [] prefill_next_token_indices = [] prefill_cu_outlens = [0] + next_chunk_lengths = [] + use_output_token = [] next_token_chooser_parameters = [] stopping_criterias = [] @@ -276,6 +294,7 @@ def from_tokenized( assert prefix_len > 0 prefix_len -= 1 + # Commented as it's costly. 
# log_master(logger.debug, "Tokenized input ids {tokenized_input}") prefix_ids.append(tokenized_input[:prefix_len]) @@ -284,9 +303,18 @@ def from_tokenized( input_length = len(tokenized_input) input_lengths.append(input_length) + if True: + # This request only requires one prefill and no chunking + use_output_token.append(True) + next_chunk_lengths.append(1) + else: + chunking = True + raise NotImplementedError + prefix_offsets.append(input_length - 5) read_offsets.append(input_length) + # FIXME: use all input tokens not just postfix ones all_input_ids.append(tokenized_input) # Position ids @@ -357,6 +385,7 @@ def from_tokenized( # Create tensor to slice into the kv tensor in prefill if sliding_window is not None: + raise NotImplementedError request_prefill_cache_indices = torch.arange( cumulative_length + max(0, input_length - sliding_window), cumulative_length + input_length, @@ -368,6 +397,7 @@ def from_tokenized( no_prefill_logprobs = no_prefill_logprobs and not r.prefill_logprobs if r.prefill_logprobs: + raise NotImplementedError prefill_head_indices.append(request_position_ids + cumulative_length) prefill_next_token_indices.append( prefill_out_cumulative_length + input_length - 1 @@ -445,6 +475,12 @@ def from_tokenized( adapter_segments, dtype=torch.int32, device=device ) + if chunking: + next_chunk_lengths_tensor = torch.tensor(next_chunk_lengths, dtype=torch.int64, device=device) + else: + next_chunk_lengths = None + next_chunk_lengths_tensor = None + if all_prefill_logprobs: prefill_head_indices = None prefill_next_token_indices = cu_seqlen_prefill[1:] - 1 @@ -491,6 +527,11 @@ def from_tokenized( prefill_head_indices=prefill_head_indices, prefill_next_token_indices=prefill_next_token_indices, prefill_cu_outlens=prefill_cu_outlens, + prefilling=True, + prefilling_mask=[True] * pb.requests.len(), + use_output_token=use_output_token, + next_chunk_lengths=next_chunk_lengths, + next_chunk_lengths_tensor=next_chunk_lengths_tensor, input_lengths=input_lengths, input_lengths_tensor=input_lengths_tensor, prefix_offsets=prefix_offsets, @@ -1426,7 +1467,7 @@ def forward( max_s = batch.max_seqlen lm_head_indices = batch.prefill_head_indices - if cu_seqlen_prefill is None and self.max_past() is not None: + if not batch.prefilling and self.max_past() is not None: # In decode, not prefill, we're actually overwriting the KV-cache # in a circular buffer mode. # This makes sure the max_s for the decode pass is correct. 
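The chunking metadata introduced in this patch (`prefilling`, `prefilling_mask`, `use_output_token`, `next_chunk_lengths`) is only wired up for the single-chunk case so far; the chunking branch still raises NotImplementedError. As a rough illustration of the intended semantics, not part of the diff, the per-request chunk lengths could be planned from a fixed per-forward token budget (cf. the `TOKEN_BUDGET` constant added later in the series). `plan_next_chunk_lengths` and `remaining_prompt_tokens` are hypothetical names used only for this sketch; a decoding request always contributes exactly one token.

    # Illustrative sketch only, not the implementation in this patch.
    from typing import List

    def plan_next_chunk_lengths(
        remaining_prompt_tokens: List[int], token_budget: int
    ) -> List[int]:
        # A request with no prompt tokens left is decoding and always gets
        # next_chunk_length == 1; a prefilling request gets at most
        # `token_budget` prompt tokens in the next forward pass.
        return [
            1 if remaining == 0 else min(remaining, token_budget)
            for remaining in remaining_prompt_tokens
        ]

    # e.g. plan_next_chunk_lengths([0, 3, 50], token_budget=8) -> [1, 3, 8]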
@@ -1440,7 +1481,7 @@ def forward( else: cuda_graph = None - if cu_seqlen_prefill is not None or cuda_graph is None: + if batch.prefilling or cuda_graph is None: if ATTENTION == "flashinfer": block_tables = block_tables_to_ragged( block_tables=block_tables, @@ -1475,6 +1516,7 @@ def forward( adapter_data=adapter_data, ) if batch.prefill_cache_indices is not None: + raise NotImplementedError batch.prefill_cache_indices = None return logits, speculative_logits @@ -1528,7 +1570,6 @@ def generate_token( self, batch: FlashCausalLMBatch ) -> Tuple[List[Generation], Optional[FlashCausalLMBatch], Tuple[int, int]]: start = time.time_ns() - prefill = batch.cu_seqlen_prefill is not None prefill_logprobs = batch.prefill_next_token_indices is not None # Update adapter indices for speculative tokens (if present) @@ -1554,13 +1595,13 @@ def generate_token( adapter_data = AdapterBatchData.from_meta( adapter_meta, self.layer_to_adapter_weights, - prefill, + batch.prefilling, batch.prefill_head_indices, ) out, speculative_logits = self.forward(batch, adapter_data) - if prefill: + if batch.prefilling: next_token_logits = ( out[batch.prefill_next_token_indices] if prefill_logprobs else out ) @@ -1597,22 +1638,31 @@ def generate_token( batch.top_n_tokens, batch.top_n_tokens_tensor, logprobs, accepted_ids ) - if prefill: + if batch.prefilling: if len(batch) > 1 and prefill_logprobs: # We create the prefill_tokens_indices tensor that will be used to gather prefill logprobs # When batch == 1, we will just use the batch.input_ids values directly prefill_tokens_indices = batch.input_ids.new_zeros(len(out)) - next_position_ids = batch.position_ids.new_empty(len(batch)) - batch.slot_indices = batch.slot_indices[batch.cu_seqlen_prefill[1:] - 1] - # We do not need cu_seqlen_prefill anymore - batch.cu_seqlen_prefill = None + if batch.next_chunk_lengths is None: + # We are done prefilling after this forward + next_position_ids = batch.position_ids.new_empty(len(batch)) + # [BATCH_SIZE] + # Last slot for each request, will be incremented later + batch.slot_indices = batch.slot_indices[batch.cu_seqlen_prefill[1:] - 1] + else: + # We still have prefill chunks to go through + next_forward_size = sum(batch.next_chunk_lengths) + next_position_ids = batch.position_ids.new_empty(next_forward_size) + batch.slot_indices = batch.slot_indices.new_empty(next_forward_size) + batch.cu_seqlen_prefill[1:] = torch.cumsum(batch.next_chunk_lengths_tensor, dim=0) else: prefill_logprobs = None next_position_ids = batch.position_ids # Cumulative length cumulative_length = 0 + cumulative_chunk_lengths = 0 # Results generations: List[Generation] = [] @@ -1625,21 +1675,32 @@ def generate_token( # one, we need to first do a GPU <-> CPU sync # It is faster if we delay this sync for the maximum amount of time - # For each member of the batch index = 0 + # For each member of the batch for i, (input_length, all_input_ids, n_accepted_ids) in enumerate(iterator): # Indexing metadata start_index = cumulative_length end_index = cumulative_length + input_length - if prefill: + if batch.prefilling: + if batch.next_chunk_lengths is not None: + next_chunk_length = batch.next_chunk_lengths[i] + else: + next_chunk_length = 1 + # Indexing metadata out_start_index = batch.prefill_cu_outlens[i] out_end_index = batch.prefill_cu_outlens[i + 1] out_length = out_end_index - out_start_index + position_start_index = + # Initialize position_ids # In decode, we do not need this as we can just increment position ids + + + next_position_ids + next_position_ids[i] = 
batch.position_ids[end_index - 1] # Initialize adapter indices @@ -1651,6 +1712,7 @@ def generate_token( # Used to gather prefill logprobs # Copy batch.input_ids to prefill_token_indices if prefill_logprobs: + raise NotImplementedError if len(batch) > 1: prefill_tokens_indices[out_start_index : out_end_index - 1] = ( batch.input_ids[start_index + 1 : start_index + out_length] @@ -1668,14 +1730,23 @@ def generate_token( cumulative_length += input_length # Update values - batch.input_ids = next_input_ids[accepted_ids.cumsum(dim=-1) - 1] - batch.speculative_ids = speculative_ids - batch.position_ids = next_position_ids + accepted_ids - batch.input_lengths_tensor += accepted_ids - batch.slot_indices += accepted_ids - batch.adapter_meta.adapter_indices = next_adapter_indices - - if prefill: + if batch.next_prefilling_chunk_lengths is None: + # We are done prefilling + batch.prefilling = False + batch.next_prefilling_chunk_lengths = None + batch.next_prefilling_chunk_lengths_tensor = None + # We do not need cu_seqlen_prefill anymore + batch.cu_seqlen_prefill = None + + if not batch.prefilling: + batch.input_ids = next_input_ids[accepted_ids.cumsum(dim=-1) - 1] + batch.speculative_ids = speculative_ids + batch.position_ids = next_position_ids + accepted_ids + batch.input_lengths_tensor += accepted_ids + batch.slot_indices += accepted_ids + batch.adapter_meta.adapter_indices = next_adapter_indices + + if batch.prefilling: # adjust segment lengths to account for all request lengths being 1 during decoding adapter_segments, _ = find_segments(batch.adapter_meta.adapter_indices) batch.adapter_meta.adapter_segments = torch.tensor( @@ -1684,7 +1755,8 @@ def generate_token( device=batch.adapter_meta.adapter_segments.device, ) - if prefill and prefill_logprobs: + if batch.prefilling and prefill_logprobs: + raise NotImplementedError # Get prefill logprobs prefill_logprobs_tensor = torch.log_softmax(out, -1) prefill_logprobs = torch.gather( @@ -1795,7 +1867,8 @@ def generate_token( generated_text = None # Prefill - if prefill and request.prefill_logprobs: + if batch.prefilling and request.prefill_logprobs: + raise NotImplementedError out_start_index = batch.prefill_cu_outlens[i] out_end_index = batch.prefill_cu_outlens[i + 1] From de043b53c4ff18e88292cea948f35896e59c5ad1 Mon Sep 17 00:00:00 2001 From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com> Date: Wed, 25 Sep 2024 13:57:18 +0200 Subject: [PATCH 02/29] rollback --- backends/v3/src/backend.rs | 4 +- server/tests/conftest.py | 4 - .../models/flash_causal_lm.py | 121 ++++-------------- .../text_generation_server/models/globals.py | 4 +- 4 files changed, 28 insertions(+), 105 deletions(-) diff --git a/backends/v3/src/backend.rs b/backends/v3/src/backend.rs index f8a10ca2606..77fdb0419bc 100644 --- a/backends/v3/src/backend.rs +++ b/backends/v3/src/backend.rs @@ -36,9 +36,9 @@ impl BackendV3 { speculate: u32, ) -> Self { let prefix_caching = - std::env::var("USE_PREFIX_CACHING").expect("Expect prefix caching env var"); + std::env::var("USE_PREFIX_CACHING").unwrap_or("1".to_string()); let prefix_caching = matches!(prefix_caching.as_str(), "true" | "1"); - let attention: String = std::env::var("ATTENTION").expect("attention env var"); + let attention: String = std::env::var("ATTENTION").unwrap_or("flashinfer".to_string()); let attention: Attention = attention .parse() diff --git a/server/tests/conftest.py b/server/tests/conftest.py index d99771f8ad0..1efeba5864e 100644 --- a/server/tests/conftest.py +++ b/server/tests/conftest.py @@ -2,10 
+2,6 @@ import os from text_generation_server.pb import generate_pb2 -os.environ["USE_PREFIX_CACHING"] = "1" -os.environ["ATTENTION"] = "flashinfer" - - @pytest.fixture def default_pb_parameters(): return generate_pb2.NextTokenChooserParameters( diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index 8933f13e952..33fe30a87cd 100644 --- a/server/text_generation_server/models/flash_causal_lm.py +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -149,26 +149,11 @@ class FlashCausalLMBatch(Batch): max_seqlen: int - # Prefill metadata tensors + # Prefill metadata tensors to efficiently compute logprobs prefill_head_indices: Optional[torch.Tensor] - prefill_next_token_indices: Optional[torch.Tensor] + prefill_next_token_indices: Optional[torch.tensor] prefill_cu_outlens: Optional[List[int]] - # Whether at least one request is prefilling/chunking - # == any(prefilling_mask) - prefilling: bool - # For each request, whether they are still prefilling/chunking - prefilling_mask: List[bool] - # For each request, whether the model output should be used or discarded - # If we are chunking, we don't care about the output as it might be different - # from the token in the prompt - use_output_token: List[bool] - - # If the request is decoding, `next_chunk_length = 1` - # `None if not batch.prefilling` - next_chunk_lengths: Optional[List[int]] - next_chunk_lengths_tensor: Optional[torch.Tensor] - # Prefixes prefix_ids: List[List[int]] @@ -247,14 +232,11 @@ def from_tokenized( prefix_ids = [] requests_idx_mapping = {} - chunking = False all_prefill_logprobs = True no_prefill_logprobs = True prefill_head_indices = [] prefill_next_token_indices = [] prefill_cu_outlens = [0] - next_chunk_lengths = [] - use_output_token = [] next_token_chooser_parameters = [] stopping_criterias = [] @@ -294,7 +276,6 @@ def from_tokenized( assert prefix_len > 0 prefix_len -= 1 - # Commented as it's costly. 
# log_master(logger.debug, "Tokenized input ids {tokenized_input}") prefix_ids.append(tokenized_input[:prefix_len]) @@ -303,18 +284,9 @@ def from_tokenized( input_length = len(tokenized_input) input_lengths.append(input_length) - if True: - # This request only requires one prefill and no chunking - use_output_token.append(True) - next_chunk_lengths.append(1) - else: - chunking = True - raise NotImplementedError - prefix_offsets.append(input_length - 5) read_offsets.append(input_length) - # FIXME: use all input tokens not just postfix ones all_input_ids.append(tokenized_input) # Position ids @@ -385,7 +357,6 @@ def from_tokenized( # Create tensor to slice into the kv tensor in prefill if sliding_window is not None: - raise NotImplementedError request_prefill_cache_indices = torch.arange( cumulative_length + max(0, input_length - sliding_window), cumulative_length + input_length, @@ -397,7 +368,6 @@ def from_tokenized( no_prefill_logprobs = no_prefill_logprobs and not r.prefill_logprobs if r.prefill_logprobs: - raise NotImplementedError prefill_head_indices.append(request_position_ids + cumulative_length) prefill_next_token_indices.append( prefill_out_cumulative_length + input_length - 1 @@ -475,12 +445,6 @@ def from_tokenized( adapter_segments, dtype=torch.int32, device=device ) - if chunking: - next_chunk_lengths_tensor = torch.tensor(next_chunk_lengths, dtype=torch.int64, device=device) - else: - next_chunk_lengths = None - next_chunk_lengths_tensor = None - if all_prefill_logprobs: prefill_head_indices = None prefill_next_token_indices = cu_seqlen_prefill[1:] - 1 @@ -527,11 +491,6 @@ def from_tokenized( prefill_head_indices=prefill_head_indices, prefill_next_token_indices=prefill_next_token_indices, prefill_cu_outlens=prefill_cu_outlens, - prefilling=True, - prefilling_mask=[True] * pb.requests.len(), - use_output_token=use_output_token, - next_chunk_lengths=next_chunk_lengths, - next_chunk_lengths_tensor=next_chunk_lengths_tensor, input_lengths=input_lengths, input_lengths_tensor=input_lengths_tensor, prefix_offsets=prefix_offsets, @@ -1467,7 +1426,7 @@ def forward( max_s = batch.max_seqlen lm_head_indices = batch.prefill_head_indices - if not batch.prefilling and self.max_past() is not None: + if cu_seqlen_prefill is None and self.max_past() is not None: # In decode, not prefill, we're actually overwriting the KV-cache # in a circular buffer mode. # This makes sure the max_s for the decode pass is correct. 
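These hunks restore the original dispatch, where `batch.cu_seqlen_prefill is not None` is the signal that the batch is still in the prefill phase. For reference, `cu_seqlen_prefill` is the exclusive cumulative sum of the per-request input lengths, which is why `cu_seqlen_prefill[1:] - 1` yields the index of the last token of each request in the flattened batch. Illustrative values, not part of the diff:

    import torch

    input_lengths = [3, 5, 2]                        # tokens fed to this forward pass
    cu_seqlen_prefill = torch.tensor([0, 3, 8, 10])  # [0] + cumulative sums
    last_token_indices = cu_seqlen_prefill[1:] - 1   # tensor([2, 7, 9])
    # This layout is what `prefill_next_token_indices` and the
    # `batch.slot_indices[batch.cu_seqlen_prefill[1:] - 1]` gather above rely on.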
@@ -1481,7 +1440,7 @@ def forward( else: cuda_graph = None - if batch.prefilling or cuda_graph is None: + if cu_seqlen_prefill is not None or cuda_graph is None: if ATTENTION == "flashinfer": block_tables = block_tables_to_ragged( block_tables=block_tables, @@ -1516,7 +1475,6 @@ def forward( adapter_data=adapter_data, ) if batch.prefill_cache_indices is not None: - raise NotImplementedError batch.prefill_cache_indices = None return logits, speculative_logits @@ -1570,6 +1528,7 @@ def generate_token( self, batch: FlashCausalLMBatch ) -> Tuple[List[Generation], Optional[FlashCausalLMBatch], Tuple[int, int]]: start = time.time_ns() + prefill = batch.cu_seqlen_prefill is not None prefill_logprobs = batch.prefill_next_token_indices is not None # Update adapter indices for speculative tokens (if present) @@ -1595,13 +1554,13 @@ def generate_token( adapter_data = AdapterBatchData.from_meta( adapter_meta, self.layer_to_adapter_weights, - batch.prefilling, + prefill, batch.prefill_head_indices, ) out, speculative_logits = self.forward(batch, adapter_data) - if batch.prefilling: + if prefill: next_token_logits = ( out[batch.prefill_next_token_indices] if prefill_logprobs else out ) @@ -1638,31 +1597,22 @@ def generate_token( batch.top_n_tokens, batch.top_n_tokens_tensor, logprobs, accepted_ids ) - if batch.prefilling: + if prefill: if len(batch) > 1 and prefill_logprobs: # We create the prefill_tokens_indices tensor that will be used to gather prefill logprobs # When batch == 1, we will just use the batch.input_ids values directly prefill_tokens_indices = batch.input_ids.new_zeros(len(out)) - if batch.next_chunk_lengths is None: - # We are done prefilling after this forward - next_position_ids = batch.position_ids.new_empty(len(batch)) - # [BATCH_SIZE] - # Last slot for each request, will be incremented later - batch.slot_indices = batch.slot_indices[batch.cu_seqlen_prefill[1:] - 1] - else: - # We still have prefill chunks to go through - next_forward_size = sum(batch.next_chunk_lengths) - next_position_ids = batch.position_ids.new_empty(next_forward_size) - batch.slot_indices = batch.slot_indices.new_empty(next_forward_size) - batch.cu_seqlen_prefill[1:] = torch.cumsum(batch.next_chunk_lengths_tensor, dim=0) + next_position_ids = batch.position_ids.new_empty(len(batch)) + batch.slot_indices = batch.slot_indices[batch.cu_seqlen_prefill[1:] - 1] + # We do not need cu_seqlen_prefill anymore + batch.cu_seqlen_prefill = None else: prefill_logprobs = None next_position_ids = batch.position_ids # Cumulative length cumulative_length = 0 - cumulative_chunk_lengths = 0 # Results generations: List[Generation] = [] @@ -1675,32 +1625,21 @@ def generate_token( # one, we need to first do a GPU <-> CPU sync # It is faster if we delay this sync for the maximum amount of time - index = 0 # For each member of the batch + index = 0 for i, (input_length, all_input_ids, n_accepted_ids) in enumerate(iterator): # Indexing metadata start_index = cumulative_length end_index = cumulative_length + input_length - if batch.prefilling: - if batch.next_chunk_lengths is not None: - next_chunk_length = batch.next_chunk_lengths[i] - else: - next_chunk_length = 1 - + if prefill: # Indexing metadata out_start_index = batch.prefill_cu_outlens[i] out_end_index = batch.prefill_cu_outlens[i + 1] out_length = out_end_index - out_start_index - position_start_index = - # Initialize position_ids # In decode, we do not need this as we can just increment position ids - - - next_position_ids - next_position_ids[i] = batch.position_ids[end_index 
- 1] # Initialize adapter indices @@ -1712,7 +1651,6 @@ def generate_token( # Used to gather prefill logprobs # Copy batch.input_ids to prefill_token_indices if prefill_logprobs: - raise NotImplementedError if len(batch) > 1: prefill_tokens_indices[out_start_index : out_end_index - 1] = ( batch.input_ids[start_index + 1 : start_index + out_length] @@ -1730,23 +1668,14 @@ def generate_token( cumulative_length += input_length # Update values - if batch.next_prefilling_chunk_lengths is None: - # We are done prefilling - batch.prefilling = False - batch.next_prefilling_chunk_lengths = None - batch.next_prefilling_chunk_lengths_tensor = None - # We do not need cu_seqlen_prefill anymore - batch.cu_seqlen_prefill = None - - if not batch.prefilling: - batch.input_ids = next_input_ids[accepted_ids.cumsum(dim=-1) - 1] - batch.speculative_ids = speculative_ids - batch.position_ids = next_position_ids + accepted_ids - batch.input_lengths_tensor += accepted_ids - batch.slot_indices += accepted_ids - batch.adapter_meta.adapter_indices = next_adapter_indices - - if batch.prefilling: + batch.input_ids = next_input_ids[accepted_ids.cumsum(dim=-1) - 1] + batch.speculative_ids = speculative_ids + batch.position_ids = next_position_ids + accepted_ids + batch.input_lengths_tensor += accepted_ids + batch.slot_indices += accepted_ids + batch.adapter_meta.adapter_indices = next_adapter_indices + + if prefill: # adjust segment lengths to account for all request lengths being 1 during decoding adapter_segments, _ = find_segments(batch.adapter_meta.adapter_indices) batch.adapter_meta.adapter_segments = torch.tensor( @@ -1755,8 +1684,7 @@ def generate_token( device=batch.adapter_meta.adapter_segments.device, ) - if batch.prefilling and prefill_logprobs: - raise NotImplementedError + if prefill and prefill_logprobs: # Get prefill logprobs prefill_logprobs_tensor = torch.log_softmax(out, -1) prefill_logprobs = torch.gather( @@ -1867,8 +1795,7 @@ def generate_token( generated_text = None # Prefill - if batch.prefilling and request.prefill_logprobs: - raise NotImplementedError + if prefill and request.prefill_logprobs: out_start_index = batch.prefill_cu_outlens[i] out_end_index = batch.prefill_cu_outlens[i + 1] diff --git a/server/text_generation_server/models/globals.py b/server/text_generation_server/models/globals.py index 6c518c2caa5..1830dc42f67 100644 --- a/server/text_generation_server/models/globals.py +++ b/server/text_generation_server/models/globals.py @@ -5,9 +5,9 @@ from text_generation_server.utils.log import log_master -PREFIX_CACHING = os.getenv("USE_PREFIX_CACHING").lower() in {"1", "true"} +PREFIX_CACHING = os.getenv("USE_PREFIX_CACHING", "1").lower() in {"1", "true"} log_master(logger.info, f"Using prefix caching = {PREFIX_CACHING}") -ATTENTION = os.getenv("ATTENTION") +ATTENTION = os.getenv("ATTENTION", "flashinfer") _expected = {"paged", "flashdecoding", "flashinfer"} assert ( ATTENTION in _expected From 838756eb188950ee109096e31d0743a2a5e77711 Mon Sep 17 00:00:00 2001 From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com> Date: Wed, 25 Sep 2024 14:40:47 +0200 Subject: [PATCH 03/29] refactor to use prefix/postfix namming + fix all_input_ids_tensor --- .../layers/attention/common.py | 14 +- .../models/flash_causal_lm.py | 295 ++++++++++-------- 2 files changed, 173 insertions(+), 136 deletions(-) diff --git a/server/text_generation_server/layers/attention/common.py b/server/text_generation_server/layers/attention/common.py index d6e512c0172..648b010a75a 100644 --- 
a/server/text_generation_server/layers/attention/common.py +++ b/server/text_generation_server/layers/attention/common.py @@ -9,7 +9,7 @@ @dataclass class Seqlen: - input_lengths: torch.Tensor + postfix_lengths: torch.Tensor prefix_lengths: torch.Tensor cu_seqlen_q: Optional[torch.Tensor] cu_seqlen_k: Optional[torch.Tensor] @@ -18,16 +18,16 @@ class Seqlen: def __init__( self, - input_lengths, + postfix_lengths, prefix_lengths, cu_seqlen_q=None, max_q=None, max_k=None, ): - self.input_lengths = input_lengths + self.postfix_lengths = postfix_lengths self.prefix_lengths = prefix_lengths - device = self.input_lengths.device - shape = self.input_lengths.shape + device = self.postfix_lengths.device + shape = self.postfix_lengths.shape if cu_seqlen_q is None: cu_seqlen_q = torch.arange( shape[0] + 1, @@ -43,7 +43,7 @@ def __init__( # cuda graphs don't like this and this is necessary to clamp within mistral # Although FA2 might not want the clamping # cu_seqlen_k[0] = 0 - total = self.input_lengths + self.prefix_lengths + total = self.postfix_lengths + self.prefix_lengths torch.cumsum(total, -1, out=cu_seqlen_k[1:]) self.cu_seqlen_q = cu_seqlen_q @@ -59,7 +59,7 @@ def clamp(self, max): @dataclass class Seqlen: - input_lengths: torch.Tensor + postfix_lengths: torch.Tensor prefix_lengths: torch.Tensor cu_seqlen_q: torch.Tensor max_q: int diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index 33fe30a87cd..bb35886c7e1 100644 --- a/server/text_generation_server/models/flash_causal_lm.py +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -143,9 +143,6 @@ class FlashCausalLMBatch(Batch): block_tables_tensor: torch.Tensor # tensor of length \sum_{i=0}^{b} max_s_i holding the paged attention slots for all sequences slots: torch.Tensor - # size [b], containing the number of blocks that can be retrieved from the cache - prefix_lens: List[int] - prefix_lens_tensor: torch.Tensor max_seqlen: int @@ -162,8 +159,14 @@ class FlashCausalLMBatch(Batch): all_input_ids_tensor: torch.Tensor # Lengths of all generations present in the batch - input_lengths: List[int] - input_lengths_tensor: torch.Tensor + postfix_lengths: List[int] + postfix_lengths_tensor: torch.Tensor + # size [b], containing the number of blocks that can be retrieved from the cache + prefix_lengths: List[int] + prefix_lengths_tensor: torch.Tensor + prompt_lengths: List[int] + prompt_lengths_tensor: torch.Tensor + prefix_offsets: List[Optional[int]] read_offsets: List[Optional[int]] @@ -225,10 +228,13 @@ def from_tokenized( slot_indices = [] prefill_cache_indices = [] - input_lengths = [] + prefix_lengths = [] + postfix_lengths = [] + prompt_lengths = [] prefix_offsets = [] read_offsets = [] all_input_ids = [] + all_postfix_ids = [] prefix_ids = [] requests_idx_mapping = {} @@ -257,7 +263,6 @@ def from_tokenized( block_tables = [] slots = [] - prefix_lens = [] # Parse batch for i, (r, tokenized_input) in enumerate( @@ -266,37 +271,39 @@ def from_tokenized( # request id -> idx in list mapping requests_idx_mapping[r.id] = i - orig_input_length = len(tokenized_input) + prompt_length = len(tokenized_input) + prompt_lengths.append(prompt_length) - prefix_len = r.prefix_len + prefix_length = r.prefix_len assert ( - prefix_len <= orig_input_length - ), f"Prefix {prefix_len} vs input {orig_input_length}" - if prefix_len == orig_input_length: - assert prefix_len > 0 - prefix_len -= 1 + prefix_length <= prompt_length + ), f"Prefix {prefix_length} vs input 
{prompt_length}" + if prefix_length == prompt_length: + assert prefix_length > 0 + prefix_length -= 1 # Commented as it's costly. # log_master(logger.debug, "Tokenized input ids {tokenized_input}") - prefix_ids.append(tokenized_input[:prefix_len]) - tokenized_input = tokenized_input[prefix_len:] + prefix_ids.append(tokenized_input[:prefix_length]) + postfix_ids = tokenized_input[prefix_length:] - input_length = len(tokenized_input) - input_lengths.append(input_length) + postfix_length = len(postfix_ids) + postfix_lengths.append(postfix_length) - prefix_offsets.append(input_length - 5) - read_offsets.append(input_length) + prefix_offsets.append(postfix_length - 5) + read_offsets.append(postfix_length) + all_postfix_ids.append(postfix_ids) all_input_ids.append(tokenized_input) # Position ids request_position_ids = torch.arange( - prefix_len, orig_input_length, dtype=torch.int32 + prefix_length, prompt_length, dtype=torch.int32 ) position_ids.append(request_position_ids) # Add cumulative lengths of all previous inputs - cu_seqlen_prefill.append(cumulative_length + input_length) + cu_seqlen_prefill.append(cumulative_length + postfix_length) next_token_chooser_parameters.append(r.parameters) @@ -309,7 +316,7 @@ def from_tokenized( ADAPTER_TO_INDEX = get_adapter_to_index() adapter_index = ADAPTER_TO_INDEX.get(r.adapter_id, 0) - adapter_indices_list.append(torch.full((input_length,), adapter_index)) + adapter_indices_list.append(torch.full((postfix_length,), adapter_index)) adapter_set.add(adapter_index) # Paged attention @@ -318,11 +325,11 @@ def from_tokenized( speculative_length = 0 if speculative_length is None else speculative_length # Tokens that need to be mapped to blocks. - block_tokens = orig_input_length + max_new_tokens - 1 + speculative_length + block_tokens = prompt_length + max_new_tokens - 1 + speculative_length # Tokens that need to be mapped to slots. We don't need slots for the # cached prefix (if present). 
- slot_tokens = input_length + max_new_tokens - 1 + speculative_length + slot_tokens = postfix_length + max_new_tokens - 1 + speculative_length # blocks and slots can be empty (for example in warmup) if not r.blocks: @@ -338,19 +345,19 @@ def from_tokenized( else: request_blocks = r.blocks request_slots = r.slots[ - prefix_len: #: orig_input_length + max_new_tokens + speculative_length + prefix_length: #: orig_input_length + max_new_tokens + speculative_length ] block_tables.append(request_blocks) slots.extend(request_slots) - prefix_lens.append(prefix_len) + prefix_lengths.append(prefix_length) num_blocks += len(request_blocks) start_slots.append(cumulative_slot_tokens) request_slot_indices = torch.arange( cumulative_slot_tokens, - cumulative_slot_tokens + input_length, + cumulative_slot_tokens + postfix_length, dtype=torch.int64, ) slot_indices.append(request_slot_indices) @@ -358,8 +365,8 @@ def from_tokenized( # Create tensor to slice into the kv tensor in prefill if sliding_window is not None: request_prefill_cache_indices = torch.arange( - cumulative_length + max(0, input_length - sliding_window), - cumulative_length + input_length, + cumulative_length + max(0, postfix_length - sliding_window), + cumulative_length + postfix_length, dtype=torch.int64, ) prefill_cache_indices.append(request_prefill_cache_indices) @@ -370,14 +377,16 @@ def from_tokenized( if r.prefill_logprobs: prefill_head_indices.append(request_position_ids + cumulative_length) prefill_next_token_indices.append( - prefill_out_cumulative_length + input_length - 1 + prefill_out_cumulative_length + postfix_length - 1 ) - prefill_cu_outlens.append(prefill_out_cumulative_length + input_length) - prefill_out_cumulative_length += input_length + prefill_cu_outlens.append( + prefill_out_cumulative_length + postfix_length + ) + prefill_out_cumulative_length += postfix_length else: prefill_head_indices.append( torch.tensor( - [cumulative_length + input_length - 1], dtype=torch.int32 + [cumulative_length + postfix_length - 1], dtype=torch.int32 ) ) prefill_next_token_indices.append(prefill_out_cumulative_length) @@ -385,12 +394,13 @@ def from_tokenized( prefill_out_cumulative_length += 1 # Update - cumulative_length += input_length + cumulative_length += postfix_length cumulative_slot_tokens += slot_tokens - max_seqlen = max(max_seqlen, input_length) + max_seqlen = max(max_seqlen, postfix_length) max_blocks = max(max_blocks, len(request_blocks)) max_length = max( - max_length, input_length + max_new_tokens + speculative_length + max_length, + prefix_length + postfix_length + max_new_tokens + speculative_length, ) adapter_indices = torch.cat(adapter_indices_list).to( @@ -415,13 +425,13 @@ def from_tokenized( ) if len(pb.requests) > 1: - input_ids = np.concatenate(all_input_ids, dtype=np.int64) + input_ids = np.concatenate(all_postfix_ids, dtype=np.int64) position_ids = torch.cat(position_ids) slot_indices = torch.cat(slot_indices) if sliding_window is not None: prefill_cache_indices = torch.cat(prefill_cache_indices) else: - input_ids = all_input_ids[0] + input_ids = all_postfix_ids[0] position_ids = position_ids[0] slot_indices = slot_indices[0] if sliding_window is not None: @@ -436,8 +446,11 @@ def from_tokenized( prefill_cache_indices.to(device) if sliding_window is not None else None ) input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device) - input_lengths_tensor = torch.tensor( - input_lengths, dtype=torch.int32, device=device + postfix_lengths_tensor = torch.tensor( + postfix_lengths, dtype=torch.int32, 
device=device + ) + prompt_lengths_tensor = torch.tensor( + prompt_lengths, dtype=torch.int32, device=device ) adapter_segments, adapter_segment_indices = find_segments(adapter_indices) @@ -470,7 +483,9 @@ def from_tokenized( for i, request_blocks in enumerate(block_tables): block_tables_tensor[i, : len(request_blocks)] = torch.tensor(request_blocks) block_tables_tensor = block_tables_tensor.to(device) - prefix_lens_tensor = torch.tensor(prefix_lens, dtype=torch.int32, device=device) + prefix_lengths_tensor = torch.tensor( + prefix_lengths, dtype=torch.int32, device=device + ) return cls( batch_id=pb.id, @@ -485,14 +500,16 @@ def from_tokenized( block_tables=block_tables, block_tables_tensor=block_tables_tensor, slots=slots, - prefix_lens=prefix_lens, - prefix_lens_tensor=prefix_lens_tensor, + prefix_lengths=prefix_lengths, + prefix_lengths_tensor=prefix_lengths_tensor, max_seqlen=max_seqlen, prefill_head_indices=prefill_head_indices, prefill_next_token_indices=prefill_next_token_indices, prefill_cu_outlens=prefill_cu_outlens, - input_lengths=input_lengths, - input_lengths_tensor=input_lengths_tensor, + postfix_lengths=postfix_lengths, + postfix_lengths_tensor=postfix_lengths_tensor, + prompt_lengths=prompt_lengths, + prompt_lengths_tensor=prompt_lengths_tensor, prefix_offsets=prefix_offsets, read_offsets=read_offsets, all_input_ids=all_input_ids, @@ -556,8 +573,8 @@ def filter(self, request_ids: List[int]) -> "FlashCausalLMBatch": all_input_ids = [] prefix_ids = [] - input_lengths = [] - prefix_lens = [] + postfix_lengths = [] + prefix_lengths = [] prefix_offsets = [] read_offsets = [] @@ -578,15 +595,15 @@ def filter(self, request_ids: List[int]) -> "FlashCausalLMBatch": requests.append(self.requests[idx]) # Get length - request_input_length = self.input_lengths[idx] - prefix_len = self.prefix_lens[idx] + request_input_length = self.postfix_lengths[idx] + prefix_length = self.prefix_lengths[idx] max_seqlen = max(max_seqlen, request_input_length) all_input_ids.append(self.all_input_ids[idx]) prefix_ids.append(self.prefix_ids[idx]) - input_lengths.append(request_input_length) - prefix_lens.append(prefix_len) + postfix_lengths.append(request_input_length) + prefix_lengths.append(prefix_length) prefix_offsets.append(self.prefix_offsets[idx]) read_offsets.append(self.read_offsets[idx]) @@ -629,9 +646,9 @@ def filter(self, request_ids: List[int]) -> "FlashCausalLMBatch": adapter_indices = self.adapter_meta.adapter_indices[indices] all_input_ids_tensor = self.all_input_ids_tensor[indices] block_tables_tensor = self.block_tables_tensor[indices] - input_lengths_tensor = self.input_lengths_tensor[indices] + postfix_lengths_tensor = self.postfix_lengths_tensor[indices] slots = self.slots[slot_filtering_indices] - prefix_lens_tensor = self.prefix_lens_tensor[indices] + prefix_lengths_tensor = self.prefix_lengths_tensor[indices] next_token_chooser = self.next_token_chooser.filter(indices) top_n_tokens_tensor = self.top_n_tokens_tensor[indices] speculative_ids = ( @@ -666,10 +683,10 @@ def filter(self, request_ids: List[int]) -> "FlashCausalLMBatch": prefill_head_indices=None, prefill_next_token_indices=None, prefill_cu_outlens=None, - input_lengths=input_lengths, - input_lengths_tensor=input_lengths_tensor, - prefix_lens=prefix_lens, - prefix_lens_tensor=prefix_lens_tensor, + postfix_lengths=postfix_lengths, + postfix_lengths_tensor=postfix_lengths_tensor, + prefix_lengths=prefix_lengths, + prefix_lengths_tensor=prefix_lengths_tensor, prefix_offsets=prefix_offsets, read_offsets=read_offsets, 
all_input_ids=all_input_ids, @@ -720,7 +737,7 @@ def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch + speculative_length - stopping_criteria.current_tokens for input_length, stopping_criteria in zip( - b.input_lengths, b.stopping_criterias + b.postfix_lengths, b.stopping_criterias ) ), ) @@ -729,13 +746,15 @@ def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch position_ids = batches[0].position_ids.new_empty(total_batch_size) slots = batches[0].slots.new_empty(total_slots) slot_indices = batches[0].slot_indices.new_empty(total_batch_size) - input_lengths_tensor = batches[0].input_lengths_tensor.new_empty( + postfix_lengths_tensor = batches[0].postfix_lengths_tensor.new_empty( total_batch_size ) block_tables_tensor = batches[0].block_tables_tensor.new_zeros( (total_batch_size, max_blocks) ) - prefix_lens_tensor = batches[0].prefix_lens_tensor.new_empty(total_batch_size) + prefix_lengths_tensor = batches[0].prefix_lengths_tensor.new_empty( + total_batch_size + ) all_input_ids_tensor = batches[0].all_input_ids_tensor.new_zeros( (total_batch_size, max_length) ) @@ -753,11 +772,11 @@ def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch start_slots = [] block_tables = [] - prefix_lens = [] + prefix_lengths = [] all_input_ids = [] prefix_ids = [] - input_lengths = [] + postfix_lengths = [] prefix_offsets = [] read_offsets = [] @@ -790,7 +809,7 @@ def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch input_ids[start_index:end_index] = batch.input_ids position_ids[start_index:end_index] = batch.position_ids slot_indices[start_index:end_index] = batch.slot_indices + cumulative_slots - input_lengths_tensor[start_index:end_index] = batch.input_lengths_tensor + postfix_lengths_tensor[start_index:end_index] = batch.postfix_lengths_tensor top_n_tokens_tensor[start_index:end_index] = batch.top_n_tokens_tensor slots[slots_start_index:slots_end_index] = batch.slots @@ -817,16 +836,16 @@ def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch start_index:end_index, : batch.block_tables_tensor.shape[1] ] = batch.block_tables_tensor[:, :max_blocks] - prefix_lens_tensor[start_index:end_index] = batch.prefix_lens_tensor + prefix_lengths_tensor[start_index:end_index] = batch.prefix_lengths_tensor start_slots.append(batch.start_slots + cumulative_slots) block_tables.extend(batch.block_tables) - prefix_lens.extend(batch.prefix_lens) + prefix_lengths.extend(batch.prefix_lengths) all_input_ids.extend(batch.all_input_ids) prefix_ids.extend(batch.prefix_ids) - input_lengths.extend(batch.input_lengths) + postfix_lengths.extend(batch.postfix_lengths) prefix_offsets.extend(batch.prefix_offsets) read_offsets.extend(batch.read_offsets) @@ -872,15 +891,15 @@ def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch slot_indices=slot_indices, block_tables=block_tables, block_tables_tensor=block_tables_tensor, - prefix_lens=prefix_lens, - prefix_lens_tensor=prefix_lens_tensor, + prefix_lengths=prefix_lengths, + prefix_lengths_tensor=prefix_lengths_tensor, slots=slots, max_seqlen=max_seqlen, prefill_head_indices=None, prefill_next_token_indices=None, prefill_cu_outlens=None, - input_lengths=input_lengths, - input_lengths_tensor=input_lengths_tensor, + postfix_lengths=postfix_lengths, + postfix_lengths_tensor=postfix_lengths_tensor, prefix_offsets=prefix_offsets, read_offsets=read_offsets, all_input_ids=all_input_ids, @@ -1100,9 +1119,9 @@ def cuda_graph_warmup(self, 
bs: int, max_s: int, max_bt: int): input_ids = torch.zeros(bs, dtype=torch.int64, device=self.device) position_ids = torch.zeros(bs, dtype=torch.int32, device=self.device) slots = torch.arange(bs, dtype=torch.int64, device=self.device) - input_lengths = [max_s] * bs + postfix_lengths = [max_s] * bs prefix_lengths = [0] * bs - input_lengths_tensor = ( + postfix_lengths_tensor = ( torch.ones(bs, dtype=torch.int32, device=self.device) * max_s ) prefix_lengths_tensor = torch.zeros(bs, dtype=torch.int32, device=self.device) @@ -1114,8 +1133,8 @@ def cuda_graph_warmup(self, bs: int, max_s: int, max_bt: int): if ATTENTION == "flashinfer": block_tables = block_tables_to_ragged( block_tables=block_tables, - input_lengths=input_lengths, - prefix_lens=prefix_lengths, + postfix_lengths=postfix_lengths, + prefix_lengths=prefix_lengths, ) from text_generation_server.layers.attention.flashinfer import ( create_decode_state_cuda_graphs, @@ -1143,7 +1162,7 @@ def cuda_graph_warmup(self, bs: int, max_s: int, max_bt: int): "kv_cache": self.kv_cache, "block_tables": block_tables, "slots": slots, - "input_lengths": input_lengths_tensor, + "postfix_lengths": postfix_lengths_tensor, "prefix_lengths": prefix_lengths_tensor, "state": state, "graph": graph, @@ -1154,12 +1173,12 @@ def cuda_graph_warmup(self, bs: int, max_s: int, max_bt: int): with self._forward_context( block_tables=block_tables, cu_seqlen_prefill=None, - input_lengths_tensor=input_lengths_tensor, + postfix_lengths_tensor=postfix_lengths_tensor, state=state, - prefix_lens_tensor=prefix_lengths_tensor, + prefix_lengths_tensor=prefix_lengths_tensor, ): seqlen = Seqlen( - input_lengths=input_lengths_tensor, + postfix_lengths=postfix_lengths_tensor, prefix_lengths=prefix_lengths_tensor, cu_seqlen_q=None, max_q=1, @@ -1183,7 +1202,7 @@ def cuda_graph_warmup(self, bs: int, max_s: int, max_bt: int): with torch.cuda.graph(graph, pool=MEM_POOL): seqlen = Seqlen( - input_lengths=input_lengths_tensor, + postfix_lengths=postfix_lengths_tensor, prefix_lengths=prefix_lengths_tensor, cu_seqlen_q=None, max_q=1, @@ -1340,15 +1359,17 @@ def tunableop_warmup(self, seqlen: int): slots = torch.arange(seqlen, dtype=torch.int64, device=self.device) # Dummy value, some models (starcoder2) don't accept `None`. 
- input_lengths = torch.ones(seqlen, dtype=torch.int32, device=self.device) - prefix_lens_tensor = torch.zeros(seqlen, dtype=torch.int32, device=self.device) + postfix_lengths = torch.ones(seqlen, dtype=torch.int32, device=self.device) + prefix_lengths_tensor = torch.zeros( + seqlen, dtype=torch.int32, device=self.device + ) cu_seqlen_prefill = torch.tensor( [0, seqlen], device=self.device, dtype=torch.int32 ) max_s = seqlen seqlen = Seqlen( - input_lengths=input_lengths, - prefix_lengths=prefix_lens_tensor, + postfix_lengths=postfix_lengths, + prefix_lengths=prefix_lengths_tensor, cu_seqlen_q=cu_seqlen_prefill, max_q=1, max_k=seqlen, @@ -1379,7 +1400,7 @@ def forward( kv_cache = self.kv_cache block_tables = batch.block_tables_tensor slots = batch.slots[batch.slot_indices] - input_lengths = batch.input_lengths_tensor + postfix_lengths = batch.postfix_lengths_tensor max_s = batch.max_seqlen lm_head_indices = batch.prefill_head_indices @@ -1396,11 +1417,11 @@ def forward( position_ids.unsqueeze(-1).expand(B, new_length) + arange ).view(-1) slots = (slots.unsqueeze(-1).expand(B, new_length) + arange_int).view(-1) - input_lengths = ( - input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int + postfix_lengths = ( + postfix_lengths.unsqueeze(-1).expand(B, new_length) + arange_int ).view(-1) - prefix_lens_tensor = ( - batch.prefix_lens_tensor.unsqueeze(-1).expand(B, new_length) + prefix_lengths_tensor = ( + batch.prefix_lengths_tensor.unsqueeze(-1).expand(B, new_length) ).reshape(-1) # Add Copy the block tables for all members @@ -1421,8 +1442,8 @@ def forward( kv_cache = self.kv_cache block_tables = batch.block_tables_tensor slots = batch.slots[batch.slot_indices] - input_lengths = batch.input_lengths_tensor - prefix_lens_tensor = batch.prefix_lens_tensor + postfix_lengths = batch.postfix_lengths_tensor + prefix_lengths_tensor = batch.prefix_lengths_tensor max_s = batch.max_seqlen lm_head_indices = batch.prefill_head_indices @@ -1444,19 +1465,19 @@ def forward( if ATTENTION == "flashinfer": block_tables = block_tables_to_ragged( block_tables=block_tables, - input_lengths=batch.input_lengths, - prefix_lens=batch.prefix_lens, + postfix_lengths=batch.postfix_lengths, + prefix_lengths=batch.prefix_lengths, ) with self._forward_context( block_tables=block_tables, cu_seqlen_prefill=cu_seqlen_prefill, - input_lengths_tensor=input_lengths, - prefix_lens_tensor=prefix_lens_tensor, + postfix_lengths_tensor=postfix_lengths, + prefix_lengths_tensor=prefix_lengths_tensor, ): - max_k = (input_lengths + prefix_lens_tensor).max().item() + max_k = (postfix_lengths + prefix_lengths_tensor).max().item() seqlen = Seqlen( - input_lengths=input_lengths, - prefix_lengths=prefix_lens_tensor, + postfix_lengths=postfix_lengths, + prefix_lengths=prefix_lengths_tensor, cu_seqlen_q=cu_seqlen_prefill, max_q=max_s, max_k=max_k, @@ -1485,8 +1506,8 @@ def forward( if ATTENTION == "flashinfer": block_tables = block_tables_to_ragged( block_tables=block_tables, - input_lengths=batch.input_lengths, - prefix_lens=batch.prefix_lens, + postfix_lengths=batch.postfix_lengths, + prefix_lengths=batch.prefix_lengths, ) # assert block_tables.shape[0] >= slots.shape[0] cuda_graph["block_tables"][: block_tables.shape[0]] = block_tables @@ -1499,16 +1520,18 @@ def forward( # so it doesn't matter if we override it with bogus values. 
cuda_graph["slots"].fill_(0) cuda_graph["slots"][: slots.shape[0]] = slots - cuda_graph["input_lengths"].zero_() - cuda_graph["input_lengths"][: input_lengths.shape[0]] = input_lengths + cuda_graph["postfix_lengths"].zero_() + cuda_graph["postfix_lengths"][: postfix_lengths.shape[0]] = postfix_lengths cuda_graph["prefix_lengths"].zero_() - cuda_graph["prefix_lengths"][: prefix_lens_tensor.shape[0]] = prefix_lens_tensor + cuda_graph["prefix_lengths"][ + : prefix_lengths_tensor.shape[0] + ] = prefix_lengths_tensor with self._forward_context( block_tables=cuda_graph["block_tables"], cu_seqlen_prefill=None, - input_lengths_tensor=cuda_graph["input_lengths"], - prefix_lens_tensor=cuda_graph["prefix_lengths"], + postfix_lengths_tensor=cuda_graph["postfix_lengths"], + prefix_lengths_tensor=cuda_graph["prefix_lengths"], state=cuda_graph["state"], ): # Replay the graph @@ -1586,7 +1609,7 @@ def generate_token( accepted_ids, speculative_ids, ) = batch.next_token_chooser( - batch.all_input_ids_tensor[:, : batch.max_seqlen], + batch.all_input_ids_tensor[:, : max(batch.postfix_lengths)], next_token_logits, speculate, batch.speculative_ids, @@ -1619,7 +1642,12 @@ def generate_token( stopped = True # Zipped iterator - iterator = zip(batch.input_lengths, batch.all_input_ids, accepted_ids) + iterator = zip( + batch.prefix_lengths, + batch.postfix_lengths, + batch.all_input_ids, + accepted_ids, + ) # We do two for loops as the first one can run completely asynchronously from the GPU while for the second # one, we need to first do a GPU <-> CPU sync @@ -1627,10 +1655,15 @@ def generate_token( # For each member of the batch index = 0 - for i, (input_length, all_input_ids, n_accepted_ids) in enumerate(iterator): + for i, ( + prefix_length, + postfix_length, + all_input_ids, + n_accepted_ids, + ) in enumerate(iterator): # Indexing metadata start_index = cumulative_length - end_index = cumulative_length + input_length + end_index = cumulative_length + postfix_length if prefill: # Indexing metadata @@ -1662,16 +1695,18 @@ def generate_token( ] for j in range(n_accepted_ids): - batch.all_input_ids_tensor[i, input_length + j] = next_input_ids[index] + batch.all_input_ids_tensor[i, prefix_length + postfix_length + j] = ( + next_input_ids[index] + ) index += 1 - cumulative_length += input_length + cumulative_length += postfix_length # Update values batch.input_ids = next_input_ids[accepted_ids.cumsum(dim=-1) - 1] batch.speculative_ids = speculative_ids batch.position_ids = next_position_ids + accepted_ids - batch.input_lengths_tensor += accepted_ids + batch.postfix_lengths_tensor += accepted_ids batch.slot_indices += accepted_ids batch.adapter_meta.adapter_indices = next_adapter_indices @@ -1702,7 +1737,7 @@ def generate_token( # Zipped iterator iterator = zip( batch.requests, - batch.input_lengths, + batch.postfix_lengths, batch.prefix_offsets, batch.read_offsets, batch.stopping_criterias, @@ -1867,9 +1902,9 @@ def generate_token( ) # Update values - batch.input_lengths[i] = input_length + n_accepted_ids - if batch.input_lengths[i] > batch.max_seqlen: - batch.max_seqlen = batch.input_lengths[i] + batch.postfix_lengths[i] = input_length + n_accepted_ids + if batch.postfix_lengths[i] > batch.max_seqlen: + batch.max_seqlen = batch.postfix_lengths[i] batch.prefix_offsets[i] = prefix_offset batch.read_offsets[i] = read_offset batch.all_input_ids[i] = all_input_ids @@ -1893,8 +1928,8 @@ def _forward_context( *, block_tables: torch.Tensor, cu_seqlen_prefill: Optional[torch.Tensor], - input_lengths_tensor: torch.Tensor, - 
prefix_lens_tensor: torch.Tensor, + postfix_lengths_tensor: torch.Tensor, + prefix_lengths_tensor: torch.Tensor, state: Optional[Any] = None, ) -> ContextManager: if ATTENTION != "flashinfer": @@ -1905,7 +1940,7 @@ def _forward_context( use_prefill_with_paged_kv_state, ) - # has_prefix_lens = any(prefix_len > 0 for prefix_len in prefix_lens) + # has_prefix_lengths = any(prefix_length > 0 for prefix_length in prefix_lengths) if cu_seqlen_prefill is not None: return use_prefill_with_paged_kv_state( @@ -1914,12 +1949,12 @@ def _forward_context( ), # block_tables=block_tables_to_ragged( # block_tables=block_tables, - # input_lengths=input_lengths, - # prefix_lens=prefix_lens, + # postfix_lengths=postfix_lengths, + # prefix_lengths=prefix_lengths, # ), block_tables=block_tables, cu_seqlens=cu_seqlen_prefill, - input_lengths=input_lengths_tensor + prefix_lens_tensor, + input_lengths=postfix_lengths_tensor + prefix_lengths_tensor, num_heads=self.num_heads, num_kv_heads=self.num_kv_heads, head_size=self.head_size, @@ -1928,10 +1963,10 @@ def _forward_context( window_left=self.sliding_window, ) else: - assert input_lengths_tensor is not None + assert postfix_lengths_tensor is not None return use_decode_state( state=state if state is not None else self.decode_state, - input_lengths=input_lengths_tensor + prefix_lens_tensor, + input_lengths=postfix_lengths_tensor + prefix_lengths_tensor, block_tables=block_tables, num_heads=self.num_heads, num_kv_heads=self.num_kv_heads, @@ -1943,19 +1978,21 @@ def _forward_context( def block_tables_to_ragged( - *, block_tables: torch.Tensor, input_lengths: List[int], prefix_lens: List[int] + *, block_tables: torch.Tensor, postfix_lengths: List[int], prefix_lengths: List[int] ) -> torch.Tensor: """Convert block table to ragged format compatible with FlashInfer.""" - assert len(input_lengths) == len(prefix_lens) + assert len(postfix_lengths) == len(prefix_lengths) - total_len = sum(input_lengths) + sum(prefix_lens) + total_len = sum(postfix_lengths) + sum(prefix_lengths) block_tables_ragged = torch.empty( total_len, dtype=torch.int32, device=block_tables.device ) offset = 0 - for i, (input_length, prefix_len) in enumerate(zip(input_lengths, prefix_lens)): - seq_len = prefix_len + input_length + for i, (input_length, prefix_length) in enumerate( + zip(postfix_lengths, prefix_lengths) + ): + seq_len = prefix_length + input_length block_tables_ragged[offset : offset + seq_len] = block_tables[i][:seq_len] offset += seq_len From e4f9110e147d50101cda8c1f0432692663f4bcc8 Mon Sep 17 00:00:00 2001 From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com> Date: Wed, 25 Sep 2024 14:54:59 +0200 Subject: [PATCH 04/29] maybe patching vlms? 
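This propagates the prefix/postfix renaming from the previous commit into the VLM forward path (`input_lengths` -> `postfix_lengths`, `prefix_lens_tensor` -> `prefix_lengths_tensor`) and routes the CUDA-graph replay through `_forward_context`, mirroring `flash_causal_lm.py`. For reference, the renamed `Seqlen` fields combine the cached prefix and the newly fed postfix into the key-length cumulative sum; the snippet below is an illustrative sketch with made-up values, not part of the diff:

    import torch

    postfix_lengths = torch.tensor([4, 1])   # tokens processed in this forward pass
    prefix_lengths = torch.tensor([16, 0])   # tokens already present in the KV cache
    total = postfix_lengths + prefix_lengths
    cu_seqlen_k = torch.zeros(len(total) + 1, dtype=torch.int64)
    torch.cumsum(total, -1, out=cu_seqlen_k[1:])
    # cu_seqlen_k -> tensor([0, 20, 21]), matching Seqlen.__init__ in
    # layers/attention/common.py from the previous commit.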
--- .../models/vlm_causal_lm.py | 57 +++++++++++-------- 1 file changed, 33 insertions(+), 24 deletions(-) diff --git a/server/text_generation_server/models/vlm_causal_lm.py b/server/text_generation_server/models/vlm_causal_lm.py index 7f7d2e4d9f4..937811d7d44 100644 --- a/server/text_generation_server/models/vlm_causal_lm.py +++ b/server/text_generation_server/models/vlm_causal_lm.py @@ -294,7 +294,7 @@ def forward( kv_cache = self.kv_cache block_tables = batch.block_tables_tensor slots = batch.slots[batch.slot_indices] - input_lengths = batch.input_lengths_tensor + postfix_lengths = batch.postfix_lengths_tensor max_s = batch.max_seqlen lm_head_indices = batch.prefill_head_indices @@ -311,11 +311,11 @@ def forward( position_ids.unsqueeze(-1).expand(B, new_length) + arange ).view(-1) slots = (slots.unsqueeze(-1).expand(B, new_length) + arange_int).view(-1) - input_lengths = ( - input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int + postfix_lengths = ( + postfix_lengths.unsqueeze(-1).expand(B, new_length) + arange_int ).view(-1) - prefix_lens_tensor = ( - batch.prefix_lens_tensor.unsqueeze(-1).expand(B, new_length) + prefix_lengths_tensor = ( + batch.prefix_lengths_tensor.unsqueeze(-1).expand(B, new_length) ).reshape(-1) # Add Copy the block tables for all members @@ -336,8 +336,8 @@ def forward( kv_cache = self.kv_cache block_tables = batch.block_tables_tensor slots = batch.slots[batch.slot_indices] - input_lengths = batch.input_lengths_tensor - prefix_lens_tensor = batch.prefix_lens_tensor + postfix_lengths = batch.postfix_lengths_tensor + prefix_lengths_tensor = batch.prefix_lengths_tensor max_s = batch.max_seqlen lm_head_indices = batch.prefill_head_indices @@ -357,23 +357,23 @@ def forward( else: cuda_graph = None if cu_seqlen_prefill is not None or cuda_graph is None: - input_lengths = input_lengths + prefix_lens_tensor + input_lengths = postfix_lengths + prefix_lengths_tensor if PREFIX_CACHING: block_tables = block_tables_to_ragged( block_tables=block_tables, - input_lengths=batch.input_lengths, - prefix_lens=batch.prefix_lens, + postfix_lengths=batch.postfix_lengths, + prefix_lengths=batch.prefix_lengths, ) with self._forward_context( block_tables=block_tables, cu_seqlen_prefill=cu_seqlen_prefill, - input_lengths_tensor=input_lengths, - prefix_lens_tensor=prefix_lens_tensor, + postfix_lengths_tensor=postfix_lengths, + prefix_lengths_tensor=prefix_lengths_tensor, ): - max_k = (input_lengths + prefix_lens_tensor).max().item() + max_k = (postfix_lengths + prefix_lengths_tensor).max().item() seqlen = Seqlen( - input_lengths=input_lengths, - prefix_lengths=prefix_lens_tensor, + postfix_lengths=postfix_lengths, + prefix_lengths=prefix_lengths_tensor, cu_seqlen_q=cu_seqlen_prefill, max_q=max_s, max_k=max_k, @@ -410,8 +410,8 @@ def forward( if ATTENTION == "flashinfer": block_tables = block_tables_to_ragged( block_tables=block_tables, - input_lengths=batch.input_lengths, - prefix_lens=batch.prefix_lens, + postfix_lengths=batch.postfix_lengths, + prefix_lengths=batch.prefix_lengths, ) cuda_graph["block_tables"][: block_tables.shape[0]] = block_tables else: @@ -420,13 +420,22 @@ def forward( ] = block_tables cuda_graph["slots"].fill_(-1) cuda_graph["slots"][: slots.shape[0]] = slots - cuda_graph["input_lengths"].zero_() - cuda_graph["input_lengths"][: input_lengths.shape[0]] = ( - input_lengths + prefix_lens_tensor - ) - - # Replay the graph - cuda_graph["graph"].replay() + cuda_graph["postfix_lengths"].zero_() + cuda_graph["postfix_lengths"][: postfix_lengths.shape[0]] = 
postfix_lengths + cuda_graph["prefix_lengths"].zero_() + cuda_graph["prefix_lengths"][ + : prefix_lengths_tensor.shape[0] + ] = prefix_lengths_tensor + + with self._forward_context( + block_tables=cuda_graph["block_tables"], + cu_seqlen_prefill=None, + postfix_lengths_tensor=cuda_graph["postfix_lengths"], + prefix_lengths_tensor=cuda_graph["prefix_lengths"], + state=cuda_graph["state"], + ): + # Replay the graph + cuda_graph["graph"].replay() # Slice output to the correct shape speculative_logits = ( From a85f5ebecd5a35141c7fda741f32420696351188 Mon Sep 17 00:00:00 2001 From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com> Date: Wed, 25 Sep 2024 15:34:08 +0200 Subject: [PATCH 05/29] fix filter and concat --- .../models/flash_causal_lm.py | 33 ++++++++++++++----- 1 file changed, 24 insertions(+), 9 deletions(-) diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index bb35886c7e1..4cc285bfb96 100644 --- a/server/text_generation_server/models/flash_causal_lm.py +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -573,11 +573,13 @@ def filter(self, request_ids: List[int]) -> "FlashCausalLMBatch": all_input_ids = [] prefix_ids = [] + prompt_lengths = [] postfix_lengths = [] prefix_lengths = [] prefix_offsets = [] read_offsets = [] + stopping_criterias = [] top_n_tokens = [] adapter_set = set() @@ -595,14 +597,15 @@ def filter(self, request_ids: List[int]) -> "FlashCausalLMBatch": requests.append(self.requests[idx]) # Get length - request_input_length = self.postfix_lengths[idx] + request_postfix_length = self.postfix_lengths[idx] prefix_length = self.prefix_lengths[idx] - max_seqlen = max(max_seqlen, request_input_length) + max_seqlen = max(max_seqlen, request_postfix_length) all_input_ids.append(self.all_input_ids[idx]) prefix_ids.append(self.prefix_ids[idx]) - postfix_lengths.append(request_input_length) + prompt_lengths.append(self.prompt_lengths[idx]) + postfix_lengths.append(request_postfix_length) prefix_lengths.append(prefix_length) prefix_offsets.append(self.prefix_offsets[idx]) read_offsets.append(self.read_offsets[idx]) @@ -626,17 +629,17 @@ def filter(self, request_ids: List[int]) -> "FlashCausalLMBatch": start_slots.append(cumulative_max_length) # Copy to tensor (CPU) - slot_indices[i] = cumulative_max_length + request_input_length - 1 + slot_indices[i] = cumulative_max_length + request_postfix_length - 1 # Set slice slot_filtering_indices[ self.start_slots[idx] : self.start_slots[idx] - + request_input_length + + request_postfix_length + remaining_tokens - 1 ] = True - cumulative_max_length += request_input_length + remaining_tokens - 1 + cumulative_max_length += request_postfix_length + remaining_tokens - 1 max_blocks = max(max_blocks, len(request_block_table)) @@ -647,6 +650,7 @@ def filter(self, request_ids: List[int]) -> "FlashCausalLMBatch": all_input_ids_tensor = self.all_input_ids_tensor[indices] block_tables_tensor = self.block_tables_tensor[indices] postfix_lengths_tensor = self.postfix_lengths_tensor[indices] + prompt_lengths_tensor = self.prompt_lengths_tensor[indices] slots = self.slots[slot_filtering_indices] prefix_lengths_tensor = self.prefix_lengths_tensor[indices] next_token_chooser = self.next_token_chooser.filter(indices) @@ -683,6 +687,8 @@ def filter(self, request_ids: List[int]) -> "FlashCausalLMBatch": prefill_head_indices=None, prefill_next_token_indices=None, prefill_cu_outlens=None, + prompt_lengths=prompt_lengths, + 
prompt_lengths_tensor=prompt_lengths_tensor, postfix_lengths=postfix_lengths, postfix_lengths_tensor=postfix_lengths_tensor, prefix_lengths=prefix_lengths, @@ -732,12 +738,13 @@ def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch max_length = max( max_length, max( - input_length + prefix_length + + postfix_length + stopping_criteria.max_new_tokens + speculative_length - stopping_criteria.current_tokens - for input_length, stopping_criteria in zip( - b.postfix_lengths, b.stopping_criterias + for prefix_length, postfix_length, stopping_criteria in zip( + b.prefix_lengths, b.postfix_lengths, b.stopping_criterias ) ), ) @@ -746,6 +753,9 @@ def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch position_ids = batches[0].position_ids.new_empty(total_batch_size) slots = batches[0].slots.new_empty(total_slots) slot_indices = batches[0].slot_indices.new_empty(total_batch_size) + prompt_lengths_tensor = batches[0].prompt_lengths_tensor.new_empty( + total_batch_size + ) postfix_lengths_tensor = batches[0].postfix_lengths_tensor.new_empty( total_batch_size ) @@ -776,6 +786,7 @@ def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch all_input_ids = [] prefix_ids = [] + prompt_lengths = [] postfix_lengths = [] prefix_offsets = [] read_offsets = [] @@ -809,6 +820,7 @@ def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch input_ids[start_index:end_index] = batch.input_ids position_ids[start_index:end_index] = batch.position_ids slot_indices[start_index:end_index] = batch.slot_indices + cumulative_slots + prompt_lengths_tensor[start_index:end_index] = batch.prompt_lengths_tensor postfix_lengths_tensor[start_index:end_index] = batch.postfix_lengths_tensor top_n_tokens_tensor[start_index:end_index] = batch.top_n_tokens_tensor slots[slots_start_index:slots_end_index] = batch.slots @@ -845,6 +857,7 @@ def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch all_input_ids.extend(batch.all_input_ids) prefix_ids.extend(batch.prefix_ids) + prompt_lengths.extend(batch.prompt_lengths) postfix_lengths.extend(batch.postfix_lengths) prefix_offsets.extend(batch.prefix_offsets) read_offsets.extend(batch.read_offsets) @@ -898,6 +911,8 @@ def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch prefill_head_indices=None, prefill_next_token_indices=None, prefill_cu_outlens=None, + prompt_lengths=prompt_lengths, + prompt_lengths_tensor=prompt_lengths_tensor, postfix_lengths=postfix_lengths, postfix_lengths_tensor=postfix_lengths_tensor, prefix_offsets=prefix_offsets, From 962ccfd5b765f8a28a29db5ed8d48434fcbdd56d Mon Sep 17 00:00:00 2001 From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com> Date: Thu, 26 Sep 2024 17:10:00 +0200 Subject: [PATCH 06/29] wip, no filter, no concat --- .../models/flash_causal_lm.py | 542 ++++++++++++++---- server/text_generation_server/models/types.py | 8 + 2 files changed, 452 insertions(+), 98 deletions(-) diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index 4cc285bfb96..8a9512c9cd1 100644 --- a/server/text_generation_server/models/flash_causal_lm.py +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -64,6 +64,8 @@ # Will be set in init SLIDING_WINDOW: Optional[int] = None +TOKEN_BUDGET = 8 + def set_sliding_window(sliding_window: int): global SLIDING_WINDOW @@ -144,12 +146,14 @@ class FlashCausalLMBatch(Batch): # tensor of length 
\sum_{i=0}^{b} max_s_i holding the paged attention slots for all sequences slots: torch.Tensor - max_seqlen: int + max_postfix_length: int + max_current_length: int # Prefill metadata tensors to efficiently compute logprobs prefill_head_indices: Optional[torch.Tensor] prefill_next_token_indices: Optional[torch.tensor] prefill_cu_outlens: Optional[List[int]] + prefill_tokens: List[Optional[Tokens]] # Prefixes prefix_ids: List[List[int]] @@ -257,7 +261,8 @@ def from_tokenized( prefill_out_cumulative_length = 0 num_blocks = 0 - max_seqlen = 0 + max_postfix_length = 0 + max_current_length = 0 max_length = 0 max_blocks = 0 @@ -285,20 +290,21 @@ def from_tokenized( # Commented as it's costly. # log_master(logger.debug, "Tokenized input ids {tokenized_input}") prefix_ids.append(tokenized_input[:prefix_length]) - postfix_ids = tokenized_input[prefix_length:] + postfix_ids = tokenized_input[prefix_length : prefix_length + 10] + # postfix_ids = tokenized_input[prefix_length:] postfix_length = len(postfix_ids) postfix_lengths.append(postfix_length) - prefix_offsets.append(postfix_length - 5) - read_offsets.append(postfix_length) + prefix_offsets.append(prompt_length - 5) + read_offsets.append(prompt_length) all_postfix_ids.append(postfix_ids) all_input_ids.append(tokenized_input) # Position ids request_position_ids = torch.arange( - prefix_length, prompt_length, dtype=torch.int32 + prefix_length, prefix_length + postfix_length, dtype=torch.int32 ) position_ids.append(request_position_ids) @@ -396,11 +402,12 @@ def from_tokenized( # Update cumulative_length += postfix_length cumulative_slot_tokens += slot_tokens - max_seqlen = max(max_seqlen, postfix_length) max_blocks = max(max_blocks, len(request_blocks)) + max_postfix_length = max(max_postfix_length, postfix_length) + max_current_length = max(max_current_length, prefix_length + postfix_length) max_length = max( max_length, - prefix_length + postfix_length + max_new_tokens + speculative_length, + prompt_length + max_new_tokens + speculative_length, ) adapter_indices = torch.cat(adapter_indices_list).to( @@ -502,10 +509,12 @@ def from_tokenized( slots=slots, prefix_lengths=prefix_lengths, prefix_lengths_tensor=prefix_lengths_tensor, - max_seqlen=max_seqlen, + max_postfix_length=max_postfix_length, + max_current_length=max_current_length, prefill_head_indices=prefill_head_indices, prefill_next_token_indices=prefill_next_token_indices, prefill_cu_outlens=prefill_cu_outlens, + prefill_tokens=[None] * len(pb.requests), postfix_lengths=postfix_lengths, postfix_lengths_tensor=postfix_lengths_tensor, prompt_lengths=prompt_lengths, @@ -565,7 +574,8 @@ def filter(self, request_ids: List[int]) -> "FlashCausalLMBatch": # Create on CPU to only move to GPU once instead of at every copy slot_indices = torch.empty(len(request_ids), dtype=torch.int64) - max_seqlen = 0 + max_postfix_length = 0 + max_current_length = 0 requests = [] start_slots = [] @@ -579,6 +589,7 @@ def filter(self, request_ids: List[int]) -> "FlashCausalLMBatch": prefix_offsets = [] read_offsets = [] + prefill_tokens = [] stopping_criterias = [] top_n_tokens = [] @@ -598,15 +609,18 @@ def filter(self, request_ids: List[int]) -> "FlashCausalLMBatch": # Get length request_postfix_length = self.postfix_lengths[idx] - prefix_length = self.prefix_lengths[idx] - max_seqlen = max(max_seqlen, request_postfix_length) + request_prefix_length = self.prefix_lengths[idx] + max_postfix_length = max(max_postfix_length, request_postfix_length) + max_current_length = max( + max_current_length, 
request_prefix_length + request_postfix_length + ) all_input_ids.append(self.all_input_ids[idx]) prefix_ids.append(self.prefix_ids[idx]) prompt_lengths.append(self.prompt_lengths[idx]) postfix_lengths.append(request_postfix_length) - prefix_lengths.append(prefix_length) + prefix_lengths.append(request_prefix_length) prefix_offsets.append(self.prefix_offsets[idx]) read_offsets.append(self.read_offsets[idx]) @@ -614,6 +628,7 @@ def filter(self, request_ids: List[int]) -> "FlashCausalLMBatch": stopping_criterias.append(stopping_criteria) top_n_tokens.append(self.top_n_tokens[idx]) + prefill_tokens.append(self.prefill_tokens[idx]) ADAPTER_TO_INDEX = get_adapter_to_index() adapter_index = ADAPTER_TO_INDEX.get(self.requests[idx].adapter_id, 0) @@ -683,10 +698,12 @@ def filter(self, request_ids: List[int]) -> "FlashCausalLMBatch": block_tables=block_tables, block_tables_tensor=block_tables_tensor, slots=slots, - max_seqlen=max_seqlen, + max_postfix_length=max_postfix_length, + max_current_length=max_current_length, prefill_head_indices=None, prefill_next_token_indices=None, prefill_cu_outlens=None, + prefill_tokens=prefill_tokens, prompt_lengths=prompt_lengths, prompt_lengths_tensor=prompt_lengths_tensor, postfix_lengths=postfix_lengths, @@ -725,7 +742,8 @@ def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch total_slots = 0 max_blocks = 0 max_length = 0 - max_seqlen = 0 + max_postfix_length = 0 + max_current_length = 0 for b in batches: total_batch_size += len(b) total_slots += len(b.slots) @@ -734,7 +752,8 @@ def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch b.speculative_ids.shape[1] if b.speculative_ids is not None else 0 ) max_blocks = max(max_blocks, b.max_blocks) - max_seqlen = max(max_seqlen, b.max_seqlen) + max_postfix_length = max(max_postfix_length, b.max_postfix_length) + max_current_length = max(max_current_length, b.max_current_length) max_length = max( max_length, max( @@ -791,6 +810,8 @@ def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch prefix_offsets = [] read_offsets = [] + prefill_tokens = [] + next_token_chooser_parameters = [] fsm_grammar_states = [] stopping_criterias = [] @@ -862,6 +883,8 @@ def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch prefix_offsets.extend(batch.prefix_offsets) read_offsets.extend(batch.read_offsets) + prefill_tokens.extend(batch.prefill_tokens) + next_token_chooser_parameters.extend([r.parameters for r in batch.requests]) fsm_grammar_states.extend(batch.next_token_chooser.fsm_grammar_states) stopping_criterias.extend(batch.stopping_criterias) @@ -907,10 +930,12 @@ def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch prefix_lengths=prefix_lengths, prefix_lengths_tensor=prefix_lengths_tensor, slots=slots, - max_seqlen=max_seqlen, + max_postfix_length=max_postfix_length, + max_current_length=max_current_length, prefill_head_indices=None, prefill_next_token_indices=None, prefill_cu_outlens=None, + prefill_tokens=prefill_tokens, prompt_lengths=prompt_lengths, prompt_lengths_tensor=prompt_lengths_tensor, postfix_lengths=postfix_lengths, @@ -1416,7 +1441,7 @@ def forward( block_tables = batch.block_tables_tensor slots = batch.slots[batch.slot_indices] postfix_lengths = batch.postfix_lengths_tensor - max_s = batch.max_seqlen + max_s = batch.max_current_length lm_head_indices = batch.prefill_head_indices speculative_ids = batch.speculative_ids @@ -1459,7 +1484,7 @@ def forward( slots = 
batch.slots[batch.slot_indices] postfix_lengths = batch.postfix_lengths_tensor prefix_lengths_tensor = batch.prefix_lengths_tensor - max_s = batch.max_seqlen + max_s = batch.max_current_length lm_head_indices = batch.prefill_head_indices if cu_seqlen_prefill is None and self.max_past() is not None: @@ -1608,15 +1633,47 @@ def generate_token( if prefill_logprobs else speculative_logits ) - next_adapter_indices = batch.adapter_meta.adapter_indices.new_empty( - len(batch) - ) - + if len(batch) > 1 and prefill_logprobs: + # We create the prefill_tokens_indices tensor that will be used to gather prefill logprobs + # When batch == 1, we will just use the batch.input_ids values directly + prefill_tokens_indices = batch.input_ids.new_zeros(len(out)) else: + prefill_logprobs = None next_token_logits = out next_adapter_indices = batch.adapter_meta.adapter_indices - speculate = get_speculate() + finished_prefilling = True + next_chunk_lengths = [] + if prefill: + # Budget in tokens for the next batch + # We remove next input ids to always have enough space for at least a single decode + # for the remaining requests + batch_budget = TOKEN_BUDGET - len(batch) + for prefix_length, postfix_length, prompt_length in zip( + batch.prefix_lengths, batch.postfix_lengths, batch.prompt_lengths + ): + remaining_prefill_tokens = max( + prompt_length - prefix_length - postfix_length, 0 + ) + if remaining_prefill_tokens > 0: + next_chunk_length = max( + min(remaining_prefill_tokens, batch_budget), 1 + ) + batch_budget -= next_chunk_length + finished_prefilling = False + else: + # Since speculation will be turned off, this is always true + next_chunk_length = 1 + next_chunk_lengths.append(next_chunk_length) + + # Turn off speculative if some requests are still prefilling + # It makes the logic easier to follow + if prefill and not finished_prefilling: + speculate = 0 + speculative_logits = None + else: + speculate = get_speculate() + ( next_input_ids, next_token_logprobs, @@ -1624,7 +1681,7 @@ def generate_token( accepted_ids, speculative_ids, ) = batch.next_token_chooser( - batch.all_input_ids_tensor[:, : max(batch.postfix_lengths)], + batch.all_input_ids_tensor[:, : batch.max_current_length], next_token_logits, speculate, batch.speculative_ids, @@ -1635,18 +1692,15 @@ def generate_token( batch.top_n_tokens, batch.top_n_tokens_tensor, logprobs, accepted_ids ) - if prefill: - if len(batch) > 1 and prefill_logprobs: - # We create the prefill_tokens_indices tensor that will be used to gather prefill logprobs - # When batch == 1, we will just use the batch.input_ids values directly - prefill_tokens_indices = batch.input_ids.new_zeros(len(out)) - + # Since we are done prefilling, all the tensors that were concatenating values for all the requests + # instantly become of shape [BATCH_SIZE] + if prefill and finished_prefilling: next_position_ids = batch.position_ids.new_empty(len(batch)) batch.slot_indices = batch.slot_indices[batch.cu_seqlen_prefill[1:] - 1] - # We do not need cu_seqlen_prefill anymore - batch.cu_seqlen_prefill = None - else: - prefill_logprobs = None + next_adapter_indices = batch.adapter_meta.adapter_indices.new_empty( + len(batch) + ) + elif not prefill: next_position_ids = batch.position_ids # Cumulative length @@ -1658,6 +1712,7 @@ def generate_token( # Zipped iterator iterator = zip( + batch.prompt_lengths, batch.prefix_lengths, batch.postfix_lengths, batch.all_input_ids, @@ -1671,6 +1726,7 @@ def generate_token( # For each member of the batch index = 0 for i, ( + prompt_length, prefix_length, 
postfix_length, all_input_ids, @@ -1686,15 +1742,16 @@ def generate_token( out_end_index = batch.prefill_cu_outlens[i + 1] out_length = out_end_index - out_start_index - # Initialize position_ids - # In decode, we do not need this as we can just increment position ids - next_position_ids[i] = batch.position_ids[end_index - 1] + if finished_prefilling: + # Initialize position_ids + # In decode, we do not need this as we can just increment position ids + next_position_ids[i] = batch.position_ids[end_index - 1] - # Initialize adapter indices - # In decode, we only have one token per row in the batch, so grab last index - next_adapter_indices[i] = batch.adapter_meta.adapter_indices[ - end_index - 1 - ] + # Initialize adapter indices + # In decode, we only have one token per row in the batch, so grab last index + next_adapter_indices[i] = batch.adapter_meta.adapter_indices[ + end_index - 1 + ] # Used to gather prefill logprobs # Copy batch.input_ids to prefill_token_indices @@ -1709,30 +1766,29 @@ def generate_token( start_index + 1 : start_index + out_length ] - for j in range(n_accepted_ids): - batch.all_input_ids_tensor[i, prefix_length + postfix_length + j] = ( - next_input_ids[index] - ) - index += 1 + # Represent whether this request is still prefilling + # If it is, the tokens we decoded should be ignored + accept_tokens = prefix_length + postfix_length >= prompt_length - cumulative_length += postfix_length + if accept_tokens: + # Only save tokens if we are done prefilling for this request + for j in range(n_accepted_ids): + batch.all_input_ids_tensor[ + i, prefix_length + postfix_length + j + ] = next_input_ids[index] + index += 1 - # Update values - batch.input_ids = next_input_ids[accepted_ids.cumsum(dim=-1) - 1] - batch.speculative_ids = speculative_ids - batch.position_ids = next_position_ids + accepted_ids - batch.postfix_lengths_tensor += accepted_ids - batch.slot_indices += accepted_ids - batch.adapter_meta.adapter_indices = next_adapter_indices + cumulative_length += postfix_length - if prefill: - # adjust segment lengths to account for all request lengths being 1 during decoding - adapter_segments, _ = find_segments(batch.adapter_meta.adapter_indices) - batch.adapter_meta.adapter_segments = torch.tensor( - adapter_segments, - dtype=torch.int32, - device=batch.adapter_meta.adapter_segments.device, - ) + # Update values + # These values can be updated without a GPU -> CPU sync + if not prefill or (prefill and finished_prefilling): + batch.input_ids = next_input_ids[accepted_ids.cumsum(dim=-1) - 1] + batch.speculative_ids = speculative_ids + batch.position_ids = next_position_ids + accepted_ids + batch.postfix_lengths_tensor += accepted_ids + batch.slot_indices += accepted_ids + batch.adapter_meta.adapter_indices = next_adapter_indices if prefill and prefill_logprobs: # Get prefill logprobs @@ -1743,15 +1799,265 @@ def generate_token( # GPU <-> CPU sync prefill_logprobs = prefill_logprobs.view(-1).tolist() + # Does a GPU <-> CPU sync internally + if prefill and finished_prefilling: + # adjust segment lengths to account for all request lengths being 1 during decoding + adapter_segments, _ = find_segments(batch.adapter_meta.adapter_indices) + batch.adapter_meta.adapter_segments = torch.tensor( + adapter_segments, + dtype=torch.int32, + device=batch.adapter_meta.adapter_segments.device, + ) + # GPU <-> CPU sync next_token_logprobs = next_token_logprobs.tolist() next_token_ids = next_input_ids.tolist() accepted_ids = accepted_ids.tolist() + + # Update values if we need to continue 
prefilling + # This represents the `else` case of the `Update values` if above + # but since this require the `next_token_ids` to be on CPU, it is better to do it here + skip_tokens = {} + if prefill and not finished_prefilling: + # Speculation must be ignored while we prefill even with chunking + # it simplifies everything + assert batch.speculative_ids is None + + all_postfix_ids = [] + sliding_window = get_sliding_windows() + position_ids = [] + cu_seqlen_prefill = [0] + start_slots = [] + slot_indices = [] + prefill_cache_indices = [] + all_prefill_logprobs = True + no_prefill_logprobs = True + prefill_head_indices = [] + prefill_next_token_indices = [] + prefill_cu_outlens = [0] + + # Cumulative length + cumulative_length = 0 + cumulative_slot_tokens = 0 + prefill_out_cumulative_length = 0 + + slots = [] + adapter_indices_list = [] + + for i, ( + r, + next_token_id, + all_input_ids, + prefix_length, + postfix_length, + prompt_length, + next_chunk_length, + ) in enumerate( + zip( + batch.requests, + next_token_ids, + batch.all_input_ids, + batch.prefix_lengths, + batch.postfix_lengths, + batch.prompt_lengths, + next_chunk_lengths, + ) + ): + continue_prefilling = prefix_length + postfix_length < prompt_length + skip_tokens[r.id] = True + if continue_prefilling: + # Update prefix length + prefix_length = prefix_length + postfix_length + batch.prefix_lengths[i] = prefix_length + + # Update postfix length + postfix_length = next_chunk_length + batch.max_postfix_length = max( + batch.max_postfix_length, postfix_length + ) + batch.postfix_lengths[i] = postfix_length + + # Potentially update max_current_length + current_length = prefix_length + postfix_length + batch.max_current_length = max( + batch.max_current_length, current_length + ) + + # Get new prompt IDs to prefill + postfix_ids = all_input_ids[ + prefix_length : prefix_length + postfix_length + ] + + # Position ids + request_position_ids = torch.arange( + prefix_length, prefix_length + postfix_length, dtype=torch.int32 + ) + position_ids.append(request_position_ids) + + # Add cumulative lengths of all previous inputs + cu_seqlen_prefill.append(cumulative_length + postfix_length) + + request_slots = r.slots[prefix_length:] + request_slot_indices = torch.arange( + cumulative_slot_tokens, + cumulative_slot_tokens + postfix_length, + dtype=torch.int64, + ) + + # Create tensor to slice into the kv tensor in prefill + if sliding_window is not None: + request_prefill_cache_indices = torch.arange( + cumulative_length + max(0, postfix_length - sliding_window), + cumulative_length + postfix_length, + dtype=torch.int64, + ) + + all_prefill_logprobs = all_prefill_logprobs and r.prefill_logprobs + no_prefill_logprobs = no_prefill_logprobs and not r.prefill_logprobs + + if r.prefill_logprobs: + prefill_head_indices.append( + request_position_ids + cumulative_length + ) + prefill_next_token_indices.append( + prefill_out_cumulative_length + postfix_length - 1 + ) + prefill_cu_outlens.append( + prefill_out_cumulative_length + postfix_length + ) + prefill_out_cumulative_length += postfix_length + else: + prefill_head_indices.append( + torch.tensor( + [cumulative_length + postfix_length - 1], + dtype=torch.int32, + ) + ) + prefill_next_token_indices.append(prefill_out_cumulative_length) + prefill_cu_outlens.append(prefill_out_cumulative_length + 1) + prefill_out_cumulative_length += 1 + + else: + # This request is done prefilling, the new id is the one selected the sampling method + postfix_ids = [next_token_id] + + # Position_ids + 
position_ids.append( + torch.tensor( + (prefix_length + postfix_length,), dtype=torch.int32 + ) + ) + + # Add this request token + cu_seqlen_prefill.append(cumulative_length + 1) + + request_slots = r.slots[prefix_length:] + request_slot_indices = torch.tensor( + (cumulative_slot_tokens + postfix_length,), dtype=torch.int64 + ) + + # Create tensor to slice into the kv tensor in prefill + if sliding_window is not None: + request_prefill_cache_indices = torch.tensor( + [cumulative_length], dtype=torch.int64 + ) + + prefill_head_indices.append( + torch.tensor([cumulative_length], dtype=torch.int32) + ) + prefill_next_token_indices.append(prefill_out_cumulative_length) + prefill_cu_outlens.append(prefill_out_cumulative_length + 1) + prefill_out_cumulative_length += 1 + + all_postfix_ids.extend(postfix_ids) + start_slots.append(cumulative_slot_tokens) + slots.extend(request_slots) + slot_indices.append(request_slot_indices) + + if sliding_window is not None: + prefill_cache_indices.append(request_prefill_cache_indices) + + ADAPTER_TO_INDEX = get_adapter_to_index() + adapter_index = ADAPTER_TO_INDEX.get(r.adapter_id, 0) + adapter_indices_list.append( + torch.full((postfix_length,), adapter_index) + ) + + # Update + cumulative_length += postfix_length + cumulative_slot_tokens += len(request_slots) + + device = batch.input_ids.device + batch.start_slots = torch.tensor(start_slots, dtype=torch.int64) + + if len(batch) > 1: + position_ids = torch.cat(position_ids) + slot_indices = torch.cat(slot_indices) + if sliding_window is not None: + prefill_cache_indices = torch.cat(prefill_cache_indices) + else: + position_ids = position_ids[0] + slot_indices = slot_indices[0] + if sliding_window is not None: + prefill_cache_indices = prefill_cache_indices[0] + + cu_seqlen_prefill = torch.tensor( + cu_seqlen_prefill, device=device, dtype=torch.int32 + ) + batch.cu_seqlen_prefill = cu_seqlen_prefill + batch.position_ids = position_ids.to(device) + batch.slot_indices = slot_indices.to(device) + batch.prefill_cache_indices = ( + prefill_cache_indices.to(device) if sliding_window is not None else None + ) + batch.input_ids = torch.tensor( + all_postfix_ids, dtype=torch.int64, device=device + ) + batch.postfix_lengths_tensor = torch.tensor( + batch.postfix_lengths, dtype=torch.int32, device=device + ) + + if all_prefill_logprobs: + prefill_head_indices = None + prefill_next_token_indices = cu_seqlen_prefill[1:] - 1 + elif no_prefill_logprobs: + prefill_head_indices = cu_seqlen_prefill[1:] - 1 + prefill_next_token_indices = None + else: + prefill_head_indices = torch.tensor( + torch.cat(prefill_head_indices), dtype=torch.int64, device=device + ) + prefill_next_token_indices = torch.tensor( + prefill_next_token_indices, dtype=torch.int64, device=device + ) + + batch.prefill_head_indices = prefill_head_indices + batch.prefill_next_token_indices = prefill_next_token_indices + batch.slots = torch.tensor(slots, dtype=torch.int64, device=device) + batch.prefix_lengths_tensor = torch.tensor( + batch.prefix_lengths, dtype=torch.int32, device=device + ) + adapter_indices = torch.cat(adapter_indices_list).to( + dtype=torch.int64, device=device + ) + adapter_segments, adapter_segment_indices = find_segments(adapter_indices) + adapter_segments = torch.tensor( + adapter_segments, dtype=torch.int32, device=device + ) + batch.adapter_meta = AdapterBatchMetadata( + adapter_indices=adapter_indices, + adapter_set=batch.adapter_meta.adapter_set, + adapter_segments=adapter_segments, + segment_indices=adapter_segment_indices, + ) + 
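# A minimal, self-contained sketch of the chunk scheduling that the added code above
# implements, assuming the same TOKEN_BUDGET constant and per-request
# (prompt_length, prefix_length, postfix_length) bookkeeping used in this patch
# (the helper name `plan_next_chunks` is illustrative, not part of the codebase):
#
#     def plan_next_chunks(lengths, token_budget):
#         # Keep one token of budget per request so finished requests can still decode.
#         budget = token_budget - len(lengths)
#         next_chunk_lengths, prefilling_mask = [], []
#         for prompt_length, prefix_length, postfix_length in lengths:
#             remaining = max(prompt_length - prefix_length - postfix_length, 0)
#             if remaining > 0:
#                 # Still prefilling: take as much of the prompt as the budget allows,
#                 # but always at least one token so the request makes progress.
#                 chunk = max(min(remaining, budget), 1)
#                 budget -= chunk
#                 prefilling_mask.append(True)
#             else:
#                 # Done prefilling: the next forward is a single decode token.
#                 chunk = 1
#                 prefilling_mask.append(False)
#             next_chunk_lengths.append(chunk)
#         return next_chunk_lengths, prefilling_mask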
start_decode = time.time_ns() # Zipped iterator iterator = zip( batch.requests, + batch.prompt_lengths, + batch.prefix_lengths, batch.postfix_lengths, batch.prefix_offsets, batch.read_offsets, @@ -1770,7 +2076,9 @@ def generate_token( index = 0 for i, ( request, - input_length, + prompt_length, + prefix_length, + postfix_length, prefix_offset, read_offset, stopping_criteria, @@ -1783,6 +2091,61 @@ def generate_token( top_token_ids, top_token_logprobs, ) in enumerate(iterator): + # Compute logprobs first as, even though we might skip the token, + # it can still be required to compute the logprobs + # modulo on request.id as it is robust to batch.filter whereas the index in the batch is not and we need + # this state to be stable + if request.id % self.world_size == self.rank: + # Prefill + if prefill and request.prefill_logprobs: + out_start_index = batch.prefill_cu_outlens[i] + out_end_index = batch.prefill_cu_outlens[i + 1] + + request_prefill_tokens = batch.prefill_tokens[i] + + request_prefill_logprobs = prefill_logprobs[ + out_start_index : out_end_index - 1 + ] + prefill_token_ids = all_input_ids[:-1] + + if request_prefill_tokens is None: + # Remove generated token to only have prefill and add nan for first prompt token + request_prefill_logprobs = [float("nan")] * ( + len(prefix_ids) + 1 + ) + request_prefill_logprobs + prefill_token_ids = prefix_ids + prefill_token_ids + + prefill_texts = self.tokenizer.batch_decode( + prefill_token_ids, + clean_up_tokenization_spaces=False, + skip_special_tokens=False, + ) + + prefill_tokens = Tokens( + prefix_ids + prefill_token_ids, + request_prefill_logprobs, + prefill_texts, + is_special=[], + ) + if request_prefill_tokens is not None: + prefill_tokens = request_prefill_tokens + prefill_tokens + + batch.prefill_tokens[i] = prefill_tokens + else: + batch.prefill_tokens[i] = None + + # Represent whether this request is still prefilling + # If it is, the tokens we decoded should be ignored + skip_token = skip_tokens.get(request.id, False) + + if skip_token: + # Make sure that we do not stop as even though this request did not create a token, it is still + # processing + stopped = False + # Skip the rest of the decoding + # Values were updated before this for loop + continue + # Append next token to all tokens next_token_texts = [] left = 0 @@ -1823,7 +2186,7 @@ def generate_token( # Shard generations # All generations will be appended in the rust sharded client - if i % self.world_size == self.rank: + if request.id % self.world_size == self.rank: if stop: # Decode generated tokens output_text, _, _ = self.decode_token( @@ -1844,31 +2207,6 @@ def generate_token( else: generated_text = None - # Prefill - if prefill and request.prefill_logprobs: - out_start_index = batch.prefill_cu_outlens[i] - out_end_index = batch.prefill_cu_outlens[i + 1] - - # Remove generated token to only have prefill and add nan for first prompt token - request_prefill_logprobs = ( - [float("nan")] * (len(prefix_ids) + 1) - ) + prefill_logprobs[out_start_index : out_end_index - 1] - prefill_token_ids = all_input_ids[:-1] - prefill_texts = self.tokenizer.batch_decode( - prefix_ids + prefill_token_ids, - clean_up_tokenization_spaces=False, - skip_special_tokens=False, - ) - - prefill_tokens = Tokens( - prefix_ids + prefill_token_ids, - request_prefill_logprobs, - prefill_texts, - is_special=[], - ) - else: - prefill_tokens = None - if top_n_tokens > 0: all_top_tokens = [] for top_token_ids, top_token_logprobs in zip( @@ -1896,7 +2234,7 @@ def generate_token( generation = 
Generation( request.id, - prefill_tokens, + batch.prefill_tokens[i], Tokens( _next_token_ids, _next_token_logprobs, @@ -1917,9 +2255,13 @@ def generate_token( ) # Update values - batch.postfix_lengths[i] = input_length + n_accepted_ids - if batch.postfix_lengths[i] > batch.max_seqlen: - batch.max_seqlen = batch.postfix_lengths[i] + current_postfix_length = postfix_length + n_accepted_ids + batch.max_postfix_length = max( + batch.max_postfix_length, current_postfix_length + ) + batch.postfix_lengths[i] = current_postfix_length + current_length = prefix_length + current_postfix_length + batch.max_current_length = max(batch.max_current_length, current_length) batch.prefix_offsets[i] = prefix_offset batch.read_offsets[i] = read_offset batch.all_input_ids[i] = all_input_ids @@ -1930,9 +2272,13 @@ def generate_token( decode_ns = time.time_ns() - start_decode return generations, None, (forward_ns, decode_ns) - batch.prefill_cu_outlens = None - batch.prefill_head_indices = None - batch.prefill_next_token_indices = None + if prefill and finished_prefilling: + # We do not need prefill tensors anymore + batch.cu_seqlen_prefill = None + batch.prefill_cache_indices = None + batch.prefill_cu_outlens = None + batch.prefill_head_indices = None + batch.prefill_next_token_indices = None forward_ns = start_decode - start decode_ns = time.time_ns() - start_decode diff --git a/server/text_generation_server/models/types.py b/server/text_generation_server/models/types.py index d4e7cca7504..ed9ae98959c 100644 --- a/server/text_generation_server/models/types.py +++ b/server/text_generation_server/models/types.py @@ -74,6 +74,14 @@ def to_pb(self) -> generate_pb2.Tokens: def __len__(self): return len(self.token_ids) + def __add__(self, other: "Tokens") -> "Tokens": + return Tokens( + self.token_ids + other.token_ids, + self.logprobs + other.logprobs, + self.texts + other.texts, + self.is_special + other.is_special, + ) + @dataclass class Generation: From 0e31619893a44ba0a429c01c49ee628ab5c89569 Mon Sep 17 00:00:00 2001 From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com> Date: Mon, 30 Sep 2024 11:03:13 +0200 Subject: [PATCH 07/29] current --- .../models/flash_causal_lm.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index 8a9512c9cd1..1e81e673cfa 100644 --- a/server/text_generation_server/models/flash_causal_lm.py +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -226,6 +226,7 @@ def from_tokenized( device: torch.device, ) -> "FlashCausalLMBatch": sliding_window = get_sliding_windows() + speculate = get_speculate() position_ids = [] cu_seqlen_prefill = [0] start_slots = [] @@ -280,17 +281,21 @@ def from_tokenized( prompt_lengths.append(prompt_length) prefix_length = r.prefix_len + postfix_length = prefix_length + 10 assert ( prefix_length <= prompt_length ), f"Prefix {prefix_length} vs input {prompt_length}" if prefix_length == prompt_length: assert prefix_length > 0 prefix_length -= 1 + if prefix_length + postfix_length < prompt_length: + # FIXME: speculate is not supported for context chunking at the moment + assert speculate == 0 # Commented as it's costly. 
# log_master(logger.debug, "Tokenized input ids {tokenized_input}") prefix_ids.append(tokenized_input[:prefix_length]) - postfix_ids = tokenized_input[prefix_length : prefix_length + 10] + postfix_ids = tokenized_input[prefix_length : postfix_length] # postfix_ids = tokenized_input[prefix_length:] postfix_length = len(postfix_ids) @@ -1864,8 +1869,8 @@ def generate_token( ) ): continue_prefilling = prefix_length + postfix_length < prompt_length - skip_tokens[r.id] = True if continue_prefilling: + skip_tokens[r.id] = True # Update prefix length prefix_length = prefix_length + postfix_length batch.prefix_lengths[i] = prefix_length @@ -1980,11 +1985,11 @@ def generate_token( ADAPTER_TO_INDEX = get_adapter_to_index() adapter_index = ADAPTER_TO_INDEX.get(r.adapter_id, 0) adapter_indices_list.append( - torch.full((postfix_length,), adapter_index) + torch.full((next_chunk_length,), adapter_index) ) # Update - cumulative_length += postfix_length + cumulative_length += next_chunk_length cumulative_slot_tokens += len(request_slots) device = batch.input_ids.device From 173bc99ab36b9b8a76aa6f973d7431c2eb697931 Mon Sep 17 00:00:00 2001 From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com> Date: Mon, 30 Sep 2024 17:58:14 +0200 Subject: [PATCH 08/29] add prepare_for_prefill --- .../models/flash_causal_lm.py | 882 ++++++++---------- .../text_generation_server/utils/segments.py | 1 + 2 files changed, 385 insertions(+), 498 deletions(-) diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index 1e81e673cfa..2836dcc8c2e 100644 --- a/server/text_generation_server/models/flash_causal_lm.py +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -120,39 +120,47 @@ class FlashCausalLMBatch(Batch): # Decoder values input_ids: torch.Tensor - position_ids: torch.Tensor + # Will be set by `generate_token` and reset after each prefill forward before staying set in decode + position_ids: Optional[torch.Tensor] speculative_ids: Optional[torch.Tensor] - # Flash Attention values - - # tensor of length b containing the cumulative sequence lengths of the sequences in the batch, only used in prefill - cu_seqlen_prefill: Optional[torch.Tensor] - # Prefill cache indices is used to slice into the kv tensor before caching it into the paged attention buffers - # as we only keep SLIDING_WINDOW values instead of the whole tensor - prefill_cache_indices: Optional[torch.Tensor] - - # Paged Attention values - # Set when creating the batch # CPU tensor of length b indicating the start of each sequence in slots - start_slots: torch.Tensor + # Will be set by `generate_token` and reset after each prefill forward before staying set in decode + start_slots: Optional[torch.Tensor] # tensor of indices of the currently used slots, length = \sum_{i=0}^{b} s_i in prefill, length = b in decode - slot_indices: torch.Tensor + # Will be set by `generate_token` and reset after each prefill forward before staying set in decode + slot_indices: Optional[torch.Tensor] # list of length b of list of length s_i // block_size block_tables: List[List[int]] # tensor of size [b, max_total_seqlen // block_size] holding the paged attention block tables for all sequences block_tables_tensor: torch.Tensor # tensor of length \sum_{i=0}^{b} max_s_i holding the paged attention slots for all sequences - slots: torch.Tensor + # Will be set by `generate_token` and reset after each prefill forward before staying set in decode + slots: Optional[torch.Tensor] 
max_postfix_length: int max_current_length: int + # Whether this batch contains at least one request that is prefilling + prefilling: bool + # Whether each request is prefilling + prefilling_mask: List[bool] + # Prefill metadata tensors to efficiently compute logprobs + # tensor of length b containing the cumulative sequence lengths of the sequences in the batch, only used in prefill + cu_seqlen_prefill: Optional[torch.Tensor] + # Prefill cache indices is used to slice into the kv tensor before caching it into the paged attention buffers + # as we only keep SLIDING_WINDOW values instead of the whole tensor + prefill_cache_indices: Optional[torch.Tensor] + # Will be set by `generate_token` and reset after each prefill forward prefill_head_indices: Optional[torch.Tensor] + # Will be set by `generate_token` and reset after each prefill forward prefill_next_token_indices: Optional[torch.tensor] + # Will be set by `generate_token` and reset after each prefill forward prefill_cu_outlens: Optional[List[int]] + # Will be set by `generate_token` and reset after each prefill forward prefill_tokens: List[Optional[Tokens]] # Prefixes @@ -164,12 +172,13 @@ class FlashCausalLMBatch(Batch): # Lengths of all generations present in the batch postfix_lengths: List[int] - postfix_lengths_tensor: torch.Tensor # size [b], containing the number of blocks that can be retrieved from the cache prefix_lengths: List[int] - prefix_lengths_tensor: torch.Tensor prompt_lengths: List[int] - prompt_lengths_tensor: torch.Tensor + # Will be set by `generate_token` and reset after each prefill forward before staying set in decode + postfix_lengths_tensor: Optional[torch.Tensor] + prefix_lengths_tensor: Optional[torch.Tensor] + prompt_lengths_tensor: Optional[torch.Tensor] prefix_offsets: List[Optional[int]] read_offsets: List[Optional[int]] @@ -181,7 +190,8 @@ class FlashCausalLMBatch(Batch): top_n_tokens_tensor: torch.Tensor # Adapter metadata for each request - adapter_meta: AdapterBatchMetadata + # Will be set by `generate_token` and reset after each prefill forward before staying set in decode + adapter_meta: Optional[AdapterBatchMetadata] # Number of blocks in this batch num_blocks: int @@ -225,13 +235,7 @@ def from_tokenized( dtype: torch.dtype, device: torch.device, ) -> "FlashCausalLMBatch": - sliding_window = get_sliding_windows() speculate = get_speculate() - position_ids = [] - cu_seqlen_prefill = [0] - start_slots = [] - slot_indices = [] - prefill_cache_indices = [] prefix_lengths = [] postfix_lengths = [] @@ -243,24 +247,10 @@ def from_tokenized( prefix_ids = [] requests_idx_mapping = {} - all_prefill_logprobs = True - no_prefill_logprobs = True - prefill_head_indices = [] - prefill_next_token_indices = [] - prefill_cu_outlens = [0] - next_token_chooser_parameters = [] stopping_criterias = [] top_n_tokens = [] - adapter_indices_list = [] - adapter_set = set() - - # Cumulative length - cumulative_length = 0 - cumulative_slot_tokens = 0 - prefill_out_cumulative_length = 0 - num_blocks = 0 max_postfix_length = 0 max_current_length = 0 @@ -268,7 +258,6 @@ def from_tokenized( max_blocks = 0 block_tables = [] - slots = [] # Parse batch for i, (r, tokenized_input) in enumerate( @@ -292,8 +281,6 @@ def from_tokenized( # FIXME: speculate is not supported for context chunking at the moment assert speculate == 0 - # Commented as it's costly. 
- # log_master(logger.debug, "Tokenized input ids {tokenized_input}") prefix_ids.append(tokenized_input[:prefix_length]) postfix_ids = tokenized_input[prefix_length : postfix_length] # postfix_ids = tokenized_input[prefix_length:] @@ -307,15 +294,6 @@ def from_tokenized( all_postfix_ids.append(postfix_ids) all_input_ids.append(tokenized_input) - # Position ids - request_position_ids = torch.arange( - prefix_length, prefix_length + postfix_length, dtype=torch.int32 - ) - position_ids.append(request_position_ids) - - # Add cumulative lengths of all previous inputs - cu_seqlen_prefill.append(cumulative_length + postfix_length) - next_token_chooser_parameters.append(r.parameters) stopping_criteria = StoppingCriteria.from_pb( @@ -325,11 +303,6 @@ def from_tokenized( stopping_criterias.append(stopping_criteria) top_n_tokens.append(r.top_n_tokens) - ADAPTER_TO_INDEX = get_adapter_to_index() - adapter_index = ADAPTER_TO_INDEX.get(r.adapter_id, 0) - adapter_indices_list.append(torch.full((postfix_length,), adapter_index)) - adapter_set.add(adapter_index) - # Paged attention # Remove one as the first token des not have a past speculative_length = get_speculate() @@ -338,75 +311,21 @@ def from_tokenized( # Tokens that need to be mapped to blocks. block_tokens = prompt_length + max_new_tokens - 1 + speculative_length - # Tokens that need to be mapped to slots. We don't need slots for the - # cached prefix (if present). - slot_tokens = postfix_length + max_new_tokens - 1 + speculative_length - # blocks and slots can be empty (for example in warmup) if not r.blocks: needed_blocks = math.ceil(block_tokens / BLOCK_SIZE) request_blocks = [ b for b in range(num_blocks, num_blocks + needed_blocks) ] - request_slots = [ - s - for b in request_blocks - for s in range(b * BLOCK_SIZE, (b + 1) * BLOCK_SIZE) - ] else: request_blocks = r.blocks - request_slots = r.slots[ - prefix_length: #: orig_input_length + max_new_tokens + speculative_length - ] block_tables.append(request_blocks) - slots.extend(request_slots) prefix_lengths.append(prefix_length) num_blocks += len(request_blocks) - start_slots.append(cumulative_slot_tokens) - - request_slot_indices = torch.arange( - cumulative_slot_tokens, - cumulative_slot_tokens + postfix_length, - dtype=torch.int64, - ) - slot_indices.append(request_slot_indices) - - # Create tensor to slice into the kv tensor in prefill - if sliding_window is not None: - request_prefill_cache_indices = torch.arange( - cumulative_length + max(0, postfix_length - sliding_window), - cumulative_length + postfix_length, - dtype=torch.int64, - ) - prefill_cache_indices.append(request_prefill_cache_indices) - - all_prefill_logprobs = all_prefill_logprobs and r.prefill_logprobs - no_prefill_logprobs = no_prefill_logprobs and not r.prefill_logprobs - - if r.prefill_logprobs: - prefill_head_indices.append(request_position_ids + cumulative_length) - prefill_next_token_indices.append( - prefill_out_cumulative_length + postfix_length - 1 - ) - prefill_cu_outlens.append( - prefill_out_cumulative_length + postfix_length - ) - prefill_out_cumulative_length += postfix_length - else: - prefill_head_indices.append( - torch.tensor( - [cumulative_length + postfix_length - 1], dtype=torch.int32 - ) - ) - prefill_next_token_indices.append(prefill_out_cumulative_length) - prefill_cu_outlens.append(prefill_out_cumulative_length + 1) - prefill_out_cumulative_length += 1 # Update - cumulative_length += postfix_length - cumulative_slot_tokens += slot_tokens max_blocks = max(max_blocks, len(request_blocks)) 
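# A minimal sketch of the paged-attention block accounting performed just above,
# assuming one KV slot per token and BLOCK_SIZE slots per block as elsewhere in this
# file (the helper name `blocks_for_request` is illustrative only):
#
#     import math
#
#     def blocks_for_request(prompt_length, max_new_tokens, speculative_length,
#                            block_size, first_free_block):
#         # Same accounting as above: prompt plus generated and speculative tokens,
#         # minus one.
#         block_tokens = prompt_length + max_new_tokens - 1 + speculative_length
#         needed_blocks = math.ceil(block_tokens / block_size)
#         return list(range(first_free_block, first_free_block + needed_blocks))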
max_postfix_length = max(max_postfix_length, postfix_length) max_current_length = max(max_current_length, prefix_length + postfix_length) @@ -415,14 +334,9 @@ def from_tokenized( prompt_length + max_new_tokens + speculative_length, ) - adapter_indices = torch.cat(adapter_indices_list).to( - dtype=torch.int64, device=device - ) - next_token_chooser = HeterogeneousNextTokenChooser.from_pb( next_token_chooser_parameters, dtype, device, tokenizer ) - start_slots = torch.tensor(start_slots, dtype=torch.int64) # Padded all_input_ids_tensor all_input_ids_tensor = np.zeros( @@ -438,92 +352,37 @@ def from_tokenized( if len(pb.requests) > 1: input_ids = np.concatenate(all_postfix_ids, dtype=np.int64) - position_ids = torch.cat(position_ids) - slot_indices = torch.cat(slot_indices) - if sliding_window is not None: - prefill_cache_indices = torch.cat(prefill_cache_indices) else: input_ids = all_postfix_ids[0] - position_ids = position_ids[0] - slot_indices = slot_indices[0] - if sliding_window is not None: - prefill_cache_indices = prefill_cache_indices[0] - - cu_seqlen_prefill = torch.tensor( - cu_seqlen_prefill, device=device, dtype=torch.int32 - ) - position_ids = position_ids.to(device) - slot_indices = slot_indices.to(device) - prefill_cache_indices = ( - prefill_cache_indices.to(device) if sliding_window is not None else None - ) input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device) - postfix_lengths_tensor = torch.tensor( - postfix_lengths, dtype=torch.int32, device=device - ) - prompt_lengths_tensor = torch.tensor( - prompt_lengths, dtype=torch.int32, device=device - ) - - adapter_segments, adapter_segment_indices = find_segments(adapter_indices) - adapter_segments = torch.tensor( - adapter_segments, dtype=torch.int32, device=device - ) - if all_prefill_logprobs: - prefill_head_indices = None - prefill_next_token_indices = cu_seqlen_prefill[1:] - 1 - elif no_prefill_logprobs: - prefill_head_indices = cu_seqlen_prefill[1:] - 1 - prefill_next_token_indices = None - else: - prefill_head_indices = torch.tensor( - torch.cat(prefill_head_indices), dtype=torch.int64, device=device - ) - prefill_next_token_indices = torch.tensor( - prefill_next_token_indices, dtype=torch.int64, device=device - ) top_n_tokens_tensor = torch.tensor( top_n_tokens, device=device, dtype=torch.int64 ) - slots = torch.tensor(slots, dtype=torch.int64, device=device) - block_tables_tensor = torch.zeros( (len(block_tables), max_blocks), dtype=torch.int32, device="cpu" ) for i, request_blocks in enumerate(block_tables): block_tables_tensor[i, : len(request_blocks)] = torch.tensor(request_blocks) block_tables_tensor = block_tables_tensor.to(device) - prefix_lengths_tensor = torch.tensor( - prefix_lengths, dtype=torch.int32, device=device - ) return cls( batch_id=pb.id, requests=pb.requests, requests_idx_mapping=requests_idx_mapping, input_ids=input_ids, - position_ids=position_ids, - cu_seqlen_prefill=cu_seqlen_prefill, - prefill_cache_indices=prefill_cache_indices, - start_slots=start_slots, - slot_indices=slot_indices, + block_tables=block_tables, block_tables_tensor=block_tables_tensor, - slots=slots, prefix_lengths=prefix_lengths, - prefix_lengths_tensor=prefix_lengths_tensor, max_postfix_length=max_postfix_length, max_current_length=max_current_length, - prefill_head_indices=prefill_head_indices, - prefill_next_token_indices=prefill_next_token_indices, - prefill_cu_outlens=prefill_cu_outlens, + prefilling=True, + prefilling_mask=[True] * len(pb.requests), prefill_tokens=[None] * len(pb.requests), 
postfix_lengths=postfix_lengths, - postfix_lengths_tensor=postfix_lengths_tensor, prompt_lengths=prompt_lengths, - prompt_lengths_tensor=prompt_lengths_tensor, prefix_offsets=prefix_offsets, read_offsets=read_offsets, all_input_ids=all_input_ids, @@ -535,13 +394,22 @@ def from_tokenized( top_n_tokens_tensor=top_n_tokens_tensor, num_blocks=num_blocks, max_blocks=max_blocks, - adapter_meta=AdapterBatchMetadata( - adapter_indices=adapter_indices, - adapter_set=adapter_set, - adapter_segments=adapter_segments, - segment_indices=adapter_segment_indices, - ), speculative_ids=None, + + # These values will be set by `FlashCausalLMBatch.prepare_for_prefill` + position_ids=None, + cu_seqlen_prefill=None, + prefill_cache_indices=None, + start_slots=None, + slot_indices=None, + slots=None, + prefill_head_indices=None, + prefill_next_token_indices=None, + prefill_cu_outlens=None, + prefix_lengths_tensor=None, + postfix_lengths_tensor=None, + prompt_lengths_tensor=None, + adapter_meta=None, ) @classmethod @@ -594,6 +462,7 @@ def filter(self, request_ids: List[int]) -> "FlashCausalLMBatch": prefix_offsets = [] read_offsets = [] + prefilling_mask = [] prefill_tokens = [] stopping_criterias = [] @@ -604,6 +473,7 @@ def filter(self, request_ids: List[int]) -> "FlashCausalLMBatch": max_blocks = 0 # Cumulative length cumulative_max_length = 0 + prefilling=False for i, request_id in enumerate(request_ids): idx = self.requests_idx_mapping[request_id] @@ -612,6 +482,11 @@ def filter(self, request_ids: List[int]) -> "FlashCausalLMBatch": requests.append(self.requests[idx]) + # Prefilling + request_prefilling = self.prefilling_mask[idx] + prefilling = prefilling or request_prefilling + prefilling_mask.append(request_prefilling) + # Get length request_postfix_length = self.postfix_lengths[idx] request_prefix_length = self.prefix_lengths[idx] @@ -705,6 +580,8 @@ def filter(self, request_ids: List[int]) -> "FlashCausalLMBatch": slots=slots, max_postfix_length=max_postfix_length, max_current_length=max_current_length, + prefilling=prefilling, + prefilling_mask=prefilling_mask, prefill_head_indices=None, prefill_next_token_indices=None, prefill_cu_outlens=None, @@ -742,6 +619,7 @@ def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch requests = [] requests_idx_mapping = {} + prefilling = False num_blocks = 0 total_batch_size = 0 total_slots = 0 @@ -772,6 +650,7 @@ def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch ) ), ) + prefilling = prefilling or b.prefilling input_ids = batches[0].input_ids.new_empty(total_batch_size) position_ids = batches[0].position_ids.new_empty(total_batch_size) @@ -821,6 +700,7 @@ def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch fsm_grammar_states = [] stopping_criterias = [] top_n_tokens = [] + prefilling_mask = [] # Cumulative length cumulative_batch_size = 0 @@ -878,6 +758,7 @@ def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch start_slots.append(batch.start_slots + cumulative_slots) + prefilling_mask = prefilling_mask.extend(batch.prefilling_mask) block_tables.extend(batch.block_tables) prefix_lengths.extend(batch.prefix_lengths) all_input_ids.extend(batch.all_input_ids) @@ -937,6 +818,8 @@ def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch slots=slots, max_postfix_length=max_postfix_length, max_current_length=max_current_length, + prefilling=prefilling, + prefilling_mask=prefilling_mask, prefill_head_indices=None, 
prefill_next_token_indices=None, prefill_cu_outlens=None, @@ -965,6 +848,193 @@ def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch ), ) + def prepare_for_prefill(self): + # Prepare values if we need to continue prefilling + # Speculation must be ignored while we prefill even with chunking + # it simplifies everything + assert self.speculative_ids is None + + sliding_window = get_sliding_windows() + position_ids = [] + cu_seqlen_prefill = [0] + start_slots = [] + slot_indices = [] + prefill_cache_indices = [] + all_prefill_logprobs = True + no_prefill_logprobs = True + prefill_head_indices = [] + prefill_next_token_indices = [] + prefill_cu_outlens = [0] + + # Cumulative length + cumulative_length = 0 + cumulative_slot_tokens = 0 + prefill_out_cumulative_length = 0 + + slots = [] + adapter_indices_list = [] + adapter_set = set() + + for i, ( + r, + prefix_length, + postfix_length, + prompt_length, + request_prefilling, + blocks + ) in enumerate( + zip( + self.requests, + self.prefix_lengths, + self.postfix_lengths, + self.prompt_lengths, + self.prefilling_mask, + self.block_tables + ) + ): + next_chunk_length = postfix_length + # Position ids + request_position_ids = torch.arange( + prefix_length, prefix_length + postfix_length, dtype=torch.int32 + ) + position_ids.append(request_position_ids) + + # Add cumulative lengths of all previous inputs + cu_seqlen_prefill.append(cumulative_length + postfix_length) + + if not r.slots: + request_slots = [ + s + for b in blocks + for s in range(b * BLOCK_SIZE, (b + 1) * BLOCK_SIZE) + ] + else: + request_slots = r.slots + + request_slots = request_slots[prefix_length:] + request_slot_indices = torch.arange( + cumulative_slot_tokens, + cumulative_slot_tokens + postfix_length, + dtype=torch.int64, + ) + + # Create tensor to slice into the kv tensor in prefill + if sliding_window is not None: + request_prefill_cache_indices = torch.arange( + cumulative_length + max(0, postfix_length - sliding_window), + cumulative_length + postfix_length, + dtype=torch.int64, + ) + + # Prefill logprobs is ignored if the request is done prefilling + prefill_logprobs = r.prefill_logprobs and request_prefilling + + all_prefill_logprobs = all_prefill_logprobs and prefill_logprobs + no_prefill_logprobs = no_prefill_logprobs and not prefill_logprobs + + if prefill_logprobs: + prefill_head_indices.append( + request_position_ids + cumulative_length + ) + prefill_next_token_indices.append( + prefill_out_cumulative_length + postfix_length - 1 + ) + prefill_cu_outlens.append( + prefill_out_cumulative_length + postfix_length + ) + prefill_out_cumulative_length += postfix_length + else: + prefill_head_indices.append( + torch.tensor( + [cumulative_length + postfix_length - 1], + dtype=torch.int32, + ) + ) + prefill_next_token_indices.append(prefill_out_cumulative_length) + prefill_cu_outlens.append(prefill_out_cumulative_length + 1) + prefill_out_cumulative_length += 1 + + + start_slots.append(cumulative_slot_tokens) + slots.extend(request_slots) + slot_indices.append(request_slot_indices) + + if sliding_window is not None: + prefill_cache_indices.append(request_prefill_cache_indices) + + ADAPTER_TO_INDEX = get_adapter_to_index() + adapter_index = ADAPTER_TO_INDEX.get(r.adapter_id, 0) + adapter_indices_list.append( + torch.full((next_chunk_length,), adapter_index) + ) + adapter_set.add(adapter_index) + + # Update + cumulative_length += next_chunk_length + cumulative_slot_tokens += len(request_slots) + + device = self.input_ids.device + self.start_slots = 
torch.tensor(start_slots, dtype=torch.int64) + + if len(self) > 1: + position_ids = torch.cat(position_ids) + slot_indices = torch.cat(slot_indices) + if sliding_window is not None: + prefill_cache_indices = torch.cat(prefill_cache_indices) + else: + position_ids = position_ids[0] + slot_indices = slot_indices[0] + if sliding_window is not None: + prefill_cache_indices = prefill_cache_indices[0] + + self.prefill_cu_outlens = prefill_cu_outlens + cu_seqlen_prefill = torch.tensor( + cu_seqlen_prefill, device=device, dtype=torch.int32 + ) + self.cu_seqlen_prefill = cu_seqlen_prefill + self.position_ids = position_ids.to(device) + self.slot_indices = slot_indices.to(device) + self.prefill_cache_indices = ( + prefill_cache_indices.to(device) if sliding_window is not None else None + ) + self.postfix_lengths_tensor = torch.tensor( + self.postfix_lengths, dtype=torch.int32, device=device + ) + + if all_prefill_logprobs: + prefill_head_indices = None + prefill_next_token_indices = cu_seqlen_prefill[1:] - 1 + elif no_prefill_logprobs: + prefill_head_indices = cu_seqlen_prefill[1:] - 1 + prefill_next_token_indices = None + else: + prefill_head_indices = torch.tensor( + torch.cat(prefill_head_indices), dtype=torch.int64, device=device + ) + prefill_next_token_indices = torch.tensor( + prefill_next_token_indices, dtype=torch.int64, device=device + ) + + self.prefill_head_indices = prefill_head_indices + self.prefill_next_token_indices = prefill_next_token_indices + self.slots = torch.tensor(slots, dtype=torch.int64, device=device) + self.prefix_lengths_tensor = torch.tensor( + self.prefix_lengths, dtype=torch.int32, device=device + ) + adapter_indices = torch.cat(adapter_indices_list).to( + dtype=torch.int64, device=device + ) + adapter_segments, adapter_segment_indices = find_segments(adapter_indices) + adapter_segments = torch.tensor( + adapter_segments, dtype=torch.int32, device=device + ) + self.adapter_meta = AdapterBatchMetadata( + adapter_indices=adapter_indices, + adapter_set=adapter_set, + adapter_segments=adapter_segments, + segment_indices=adapter_segment_indices, + ) + def __len__(self): return len(self.requests) @@ -1596,7 +1666,10 @@ def generate_token( self, batch: FlashCausalLMBatch ) -> Tuple[List[Generation], Optional[FlashCausalLMBatch], Tuple[int, int]]: start = time.time_ns() - prefill = batch.cu_seqlen_prefill is not None + prefill = batch.prefilling + if prefill: + batch.prepare_for_prefill() + prefill_logprobs = batch.prefill_next_token_indices is not None # Update adapter indices for speculative tokens (if present) @@ -1650,6 +1723,7 @@ def generate_token( finished_prefilling = True next_chunk_lengths = [] if prefill: + next_prefilling_mask = [] # Budget in tokens for the next batch # We remove next input ids to always have enough space for at least a single decode # for the remaining requests @@ -1666,11 +1740,16 @@ def generate_token( ) batch_budget -= next_chunk_length finished_prefilling = False + next_prefilling_mask.append(True) else: # Since speculation will be turned off, this is always true next_chunk_length = 1 + next_prefilling_mask.append(False) next_chunk_lengths.append(next_chunk_length) + batch.prefilling = not finished_prefilling + batch.prefilling_mask = next_prefilling_mask + # Turn off speculative if some requests are still prefilling # It makes the logic easier to follow if prefill and not finished_prefilling: @@ -1708,13 +1787,6 @@ def generate_token( elif not prefill: next_position_ids = batch.position_ids - # Cumulative length - cumulative_length = 0 - - 
# Results - generations: List[Generation] = [] - stopped = True - # Zipped iterator iterator = zip( batch.prompt_lengths, @@ -1730,6 +1802,8 @@ def generate_token( # For each member of the batch index = 0 + # Cumulative length + cumulative_length = 0 for i, ( prompt_length, prefix_length, @@ -1822,242 +1896,51 @@ def generate_token( # Update values if we need to continue prefilling # This represents the `else` case of the `Update values` if above # but since this require the `next_token_ids` to be on CPU, it is better to do it here - skip_tokens = {} if prefill and not finished_prefilling: # Speculation must be ignored while we prefill even with chunking # it simplifies everything assert batch.speculative_ids is None all_postfix_ids = [] - sliding_window = get_sliding_windows() - position_ids = [] - cu_seqlen_prefill = [0] - start_slots = [] - slot_indices = [] - prefill_cache_indices = [] - all_prefill_logprobs = True - no_prefill_logprobs = True - prefill_head_indices = [] - prefill_next_token_indices = [] - prefill_cu_outlens = [0] - - # Cumulative length - cumulative_length = 0 - cumulative_slot_tokens = 0 - prefill_out_cumulative_length = 0 - - slots = [] - adapter_indices_list = [] - for i, ( - r, + request_prefilling, next_token_id, all_input_ids, prefix_length, postfix_length, - prompt_length, next_chunk_length, ) in enumerate( zip( - batch.requests, + batch.prefilling_mask, next_token_ids, batch.all_input_ids, batch.prefix_lengths, batch.postfix_lengths, - batch.prompt_lengths, next_chunk_lengths, ) ): - continue_prefilling = prefix_length + postfix_length < prompt_length - if continue_prefilling: - skip_tokens[r.id] = True - # Update prefix length - prefix_length = prefix_length + postfix_length - batch.prefix_lengths[i] = prefix_length - - # Update postfix length - postfix_length = next_chunk_length - batch.max_postfix_length = max( - batch.max_postfix_length, postfix_length - ) - batch.postfix_lengths[i] = postfix_length - - # Potentially update max_current_length - current_length = prefix_length + postfix_length - batch.max_current_length = max( - batch.max_current_length, current_length - ) - + if request_prefilling: + next_prefix_length = prefix_length + postfix_length # Get new prompt IDs to prefill postfix_ids = all_input_ids[ - prefix_length : prefix_length + postfix_length + next_prefix_length : next_prefix_length + next_chunk_length ] - - # Position ids - request_position_ids = torch.arange( - prefix_length, prefix_length + postfix_length, dtype=torch.int32 - ) - position_ids.append(request_position_ids) - - # Add cumulative lengths of all previous inputs - cu_seqlen_prefill.append(cumulative_length + postfix_length) - - request_slots = r.slots[prefix_length:] - request_slot_indices = torch.arange( - cumulative_slot_tokens, - cumulative_slot_tokens + postfix_length, - dtype=torch.int64, - ) - - # Create tensor to slice into the kv tensor in prefill - if sliding_window is not None: - request_prefill_cache_indices = torch.arange( - cumulative_length + max(0, postfix_length - sliding_window), - cumulative_length + postfix_length, - dtype=torch.int64, - ) - - all_prefill_logprobs = all_prefill_logprobs and r.prefill_logprobs - no_prefill_logprobs = no_prefill_logprobs and not r.prefill_logprobs - - if r.prefill_logprobs: - prefill_head_indices.append( - request_position_ids + cumulative_length - ) - prefill_next_token_indices.append( - prefill_out_cumulative_length + postfix_length - 1 - ) - prefill_cu_outlens.append( - prefill_out_cumulative_length + postfix_length - ) - 
prefill_out_cumulative_length += postfix_length - else: - prefill_head_indices.append( - torch.tensor( - [cumulative_length + postfix_length - 1], - dtype=torch.int32, - ) - ) - prefill_next_token_indices.append(prefill_out_cumulative_length) - prefill_cu_outlens.append(prefill_out_cumulative_length + 1) - prefill_out_cumulative_length += 1 - else: # This request is done prefilling, the new id is the one selected the sampling method postfix_ids = [next_token_id] - # Position_ids - position_ids.append( - torch.tensor( - (prefix_length + postfix_length,), dtype=torch.int32 - ) - ) - - # Add this request token - cu_seqlen_prefill.append(cumulative_length + 1) - - request_slots = r.slots[prefix_length:] - request_slot_indices = torch.tensor( - (cumulative_slot_tokens + postfix_length,), dtype=torch.int64 - ) - - # Create tensor to slice into the kv tensor in prefill - if sliding_window is not None: - request_prefill_cache_indices = torch.tensor( - [cumulative_length], dtype=torch.int64 - ) - - prefill_head_indices.append( - torch.tensor([cumulative_length], dtype=torch.int32) - ) - prefill_next_token_indices.append(prefill_out_cumulative_length) - prefill_cu_outlens.append(prefill_out_cumulative_length + 1) - prefill_out_cumulative_length += 1 - all_postfix_ids.extend(postfix_ids) - start_slots.append(cumulative_slot_tokens) - slots.extend(request_slots) - slot_indices.append(request_slot_indices) - - if sliding_window is not None: - prefill_cache_indices.append(request_prefill_cache_indices) - - ADAPTER_TO_INDEX = get_adapter_to_index() - adapter_index = ADAPTER_TO_INDEX.get(r.adapter_id, 0) - adapter_indices_list.append( - torch.full((next_chunk_length,), adapter_index) - ) - - # Update - cumulative_length += next_chunk_length - cumulative_slot_tokens += len(request_slots) - - device = batch.input_ids.device - batch.start_slots = torch.tensor(start_slots, dtype=torch.int64) - - if len(batch) > 1: - position_ids = torch.cat(position_ids) - slot_indices = torch.cat(slot_indices) - if sliding_window is not None: - prefill_cache_indices = torch.cat(prefill_cache_indices) - else: - position_ids = position_ids[0] - slot_indices = slot_indices[0] - if sliding_window is not None: - prefill_cache_indices = prefill_cache_indices[0] - - cu_seqlen_prefill = torch.tensor( - cu_seqlen_prefill, device=device, dtype=torch.int32 - ) - batch.cu_seqlen_prefill = cu_seqlen_prefill - batch.position_ids = position_ids.to(device) - batch.slot_indices = slot_indices.to(device) - batch.prefill_cache_indices = ( - prefill_cache_indices.to(device) if sliding_window is not None else None - ) - batch.input_ids = torch.tensor( - all_postfix_ids, dtype=torch.int64, device=device - ) - batch.postfix_lengths_tensor = torch.tensor( - batch.postfix_lengths, dtype=torch.int32, device=device - ) - if all_prefill_logprobs: - prefill_head_indices = None - prefill_next_token_indices = cu_seqlen_prefill[1:] - 1 - elif no_prefill_logprobs: - prefill_head_indices = cu_seqlen_prefill[1:] - 1 - prefill_next_token_indices = None - else: - prefill_head_indices = torch.tensor( - torch.cat(prefill_head_indices), dtype=torch.int64, device=device - ) - prefill_next_token_indices = torch.tensor( - prefill_next_token_indices, dtype=torch.int64, device=device - ) - - batch.prefill_head_indices = prefill_head_indices - batch.prefill_next_token_indices = prefill_next_token_indices - batch.slots = torch.tensor(slots, dtype=torch.int64, device=device) - batch.prefix_lengths_tensor = torch.tensor( - batch.prefix_lengths, dtype=torch.int32, 
device=device - ) - adapter_indices = torch.cat(adapter_indices_list).to( - dtype=torch.int64, device=device - ) - adapter_segments, adapter_segment_indices = find_segments(adapter_indices) - adapter_segments = torch.tensor( - adapter_segments, dtype=torch.int32, device=device - ) - batch.adapter_meta = AdapterBatchMetadata( - adapter_indices=adapter_indices, - adapter_set=batch.adapter_meta.adapter_set, - adapter_segments=adapter_segments, - segment_indices=adapter_segment_indices, + batch.input_ids = batch.input_ids.new_tensor( + all_postfix_ids, dtype=torch.int64 ) start_decode = time.time_ns() + # Results + generations: List[Generation] = [] + stopped = True + # Zipped iterator iterator = zip( batch.requests, @@ -2072,11 +1955,14 @@ def generate_token( batch.next_token_chooser.do_sample, batch.next_token_chooser.seeds, batch.top_n_tokens, + batch.prefilling_mask, accepted_ids, batch_top_token_ids, batch_top_token_logprobs, ) + # Reset max_postfix_length + batch.max_postfix_length = 0 # For each member of the batch index = 0 for i, ( @@ -2092,6 +1978,7 @@ def generate_token( do_sample, seed, top_n_tokens, + request_prefilling, n_accepted_ids, top_token_ids, top_token_logprobs, @@ -2139,134 +2026,133 @@ def generate_token( else: batch.prefill_tokens[i] = None - # Represent whether this request is still prefilling # If it is, the tokens we decoded should be ignored - skip_token = skip_tokens.get(request.id, False) - - if skip_token: + if request_prefilling: # Make sure that we do not stop as even though this request did not create a token, it is still # processing stopped = False - # Skip the rest of the decoding - # Values were updated before this for loop - continue - - # Append next token to all tokens - next_token_texts = [] - left = 0 - - if n_accepted_ids > 1: - log_master(logger.debug, f"speculated ids {n_accepted_ids - 1}") - - current_stopped = False - for j in range(index, index + n_accepted_ids): - # Generated token - next_token_id = next_token_ids[j] - all_input_ids.append(next_token_id) - next_token_text, prefix_offset, read_offset = self.decode_token( - all_input_ids, - prefix_offset, - read_offset, - ) - next_token_texts.append(next_token_text) - - stop, reason = stopping_criteria( - next_token_id, - next_token_text, - ) - - if stop: - left = index + n_accepted_ids - j - 1 - current_stopped = True - break - else: - current_stopped = False - stopped = stopped and current_stopped - - _next_token_ids = next_token_ids[index : index + n_accepted_ids - left] - _next_token_logprobs = next_token_logprobs[ - index : index + n_accepted_ids - left - ] - index += n_accepted_ids - - # Shard generations - # All generations will be appended in the rust sharded client - if request.id % self.world_size == self.rank: - if stop: - # Decode generated tokens - output_text, _, _ = self.decode_token( + new_postfix_length = next_chunk_lengths[i] + else: + new_postfix_length = n_accepted_ids + # Append next token to all tokens + next_token_texts = [] + left = 0 + + if n_accepted_ids > 1: + log_master(logger.debug, f"speculated ids {n_accepted_ids - 1}") + + current_stopped = False + for j in range(index, index + n_accepted_ids): + # Generated token + next_token_id = next_token_ids[j] + all_input_ids.append(next_token_id) + next_token_text, prefix_offset, read_offset = self.decode_token( all_input_ids, - prefix_offset=len(all_input_ids) - - stopping_criteria.current_tokens - - 1, - read_offset=len(all_input_ids) - - stopping_criteria.current_tokens, - skip_special_tokens=True, + prefix_offset, + 
read_offset, ) - generated_text = GeneratedText( - output_text, - stopping_criteria.current_tokens, - reason, - seed if do_sample else None, + next_token_texts.append(next_token_text) + + stop, reason = stopping_criteria( + next_token_id, + next_token_text, ) - else: - generated_text = None - - if top_n_tokens > 0: - all_top_tokens = [] - for top_token_ids, top_token_logprobs in zip( - top_token_ids, top_token_logprobs - ): - toptoken_texts = self.tokenizer.batch_decode( - top_token_ids, - clean_up_tokenization_spaces=False, - skip_special_tokens=False, + + if stop: + left = index + n_accepted_ids - j - 1 + current_stopped = True + break + else: + current_stopped = False + stopped = stopped and current_stopped + + _next_token_ids = next_token_ids[index : index + n_accepted_ids - left] + _next_token_logprobs = next_token_logprobs[ + index : index + n_accepted_ids - left + ] + index += n_accepted_ids + + # Shard generations + # All generations will be appended in the rust sharded client + if request.id % self.world_size == self.rank: + if stop: + # Decode generated tokens + output_text, _, _ = self.decode_token( + all_input_ids, + prefix_offset=len(all_input_ids) + - stopping_criteria.current_tokens + - 1, + read_offset=len(all_input_ids) + - stopping_criteria.current_tokens, + skip_special_tokens=True, ) - special_toptokens = [ - token_id in self.all_special_ids - for token_id in top_token_ids - ] - top_tokens = Tokens( - top_token_ids, - top_token_logprobs, - toptoken_texts, - special_toptokens, + generated_text = GeneratedText( + output_text, + stopping_criteria.current_tokens, + reason, + seed if do_sample else None, ) - all_top_tokens.append(top_tokens) - top_tokens = all_top_tokens - else: - top_tokens = None - - generation = Generation( - request.id, - batch.prefill_tokens[i], - Tokens( - _next_token_ids, - _next_token_logprobs, - next_token_texts, - [nid in self.all_special_ids for nid in _next_token_ids], - ), - generated_text, - top_tokens, - ) + else: + generated_text = None + + if top_n_tokens > 0: + all_top_tokens = [] + for top_token_ids, top_token_logprobs in zip( + top_token_ids, top_token_logprobs + ): + toptoken_texts = self.tokenizer.batch_decode( + top_token_ids, + clean_up_tokenization_spaces=False, + skip_special_tokens=False, + ) + special_toptokens = [ + token_id in self.all_special_ids + for token_id in top_token_ids + ] + top_tokens = Tokens( + top_token_ids, + top_token_logprobs, + toptoken_texts, + special_toptokens, + ) + all_top_tokens.append(top_tokens) + top_tokens = all_top_tokens + else: + top_tokens = None + + generation = Generation( + request.id, + batch.prefill_tokens[i], + Tokens( + _next_token_ids, + _next_token_logprobs, + next_token_texts, + [nid in self.all_special_ids for nid in _next_token_ids], + ), + generated_text, + top_tokens, + ) - generations.append(generation) + generations.append(generation) - # accept each new token for this specific request since we may - # have more than one new token per request with speculative decoding - for next_token_id in _next_token_ids: - batch.next_token_chooser = ( - batch.next_token_chooser.advance_grammar_single(i, next_token_id) - ) + # accept each new token for this specific request since we may + # have more than one new token per request with speculative decoding + for next_token_id in _next_token_ids: + batch.next_token_chooser = ( + batch.next_token_chooser.advance_grammar_single(i, next_token_id) + ) # Update values - current_postfix_length = postfix_length + n_accepted_ids + current_prefix_length = 
prefix_length + postfix_length + batch.prefix_lengths[i] = current_prefix_length + current_postfix_length = new_postfix_length batch.max_postfix_length = max( batch.max_postfix_length, current_postfix_length ) batch.postfix_lengths[i] = current_postfix_length - current_length = prefix_length + current_postfix_length + current_length = current_prefix_length + current_postfix_length batch.max_current_length = max(batch.max_current_length, current_length) + batch.prefix_offsets[i] = prefix_offset batch.read_offsets[i] = read_offset batch.all_input_ids[i] = all_input_ids diff --git a/server/text_generation_server/utils/segments.py b/server/text_generation_server/utils/segments.py index f5961102108..b3f923694e3 100644 --- a/server/text_generation_server/utils/segments.py +++ b/server/text_generation_server/utils/segments.py @@ -7,6 +7,7 @@ import torch +# FIXME: this should be optimized def find_segments( adapter_indices: Union[torch.Tensor, List[int]] ) -> Tuple[List[int], List[int]]: From 34f5dc525ea351d9f30b14b6169fde4ffb7c27ee Mon Sep 17 00:00:00 2001 From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com> Date: Tue, 1 Oct 2024 09:51:34 +0200 Subject: [PATCH 09/29] working --- .../models/flash_causal_lm.py | 234 ++++++++++-------- 1 file changed, 137 insertions(+), 97 deletions(-) diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index 2836dcc8c2e..8ee9d184cde 100644 --- a/server/text_generation_server/models/flash_causal_lm.py +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -16,7 +16,7 @@ AutoTokenizer, GenerationConfig, ) -from typing import Any, ContextManager, Iterable, Optional, Tuple, List, Type, Dict +from typing import Any, ContextManager, Iterable, Optional, Tuple, List, Type, Dict, Union from text_generation_server.adapters import AdapterBatchData, AdapterBatchMetadata from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE @@ -119,7 +119,9 @@ class FlashCausalLMBatch(Batch): requests_idx_mapping: Dict[int, int] # Decoder values - input_ids: torch.Tensor + # Can be a list for easy filtering + # If `input_ids` is a list, it needs to be materialized to a tensor first + input_ids: Union[torch.Tensor, List[List[int]]] # Will be set by `generate_token` and reset after each prefill forward before staying set in decode position_ids: Optional[torch.Tensor] speculative_ids: Optional[torch.Tensor] @@ -178,7 +180,7 @@ class FlashCausalLMBatch(Batch): # Will be set by `generate_token` and reset after each prefill forward before staying set in decode postfix_lengths_tensor: Optional[torch.Tensor] prefix_lengths_tensor: Optional[torch.Tensor] - prompt_lengths_tensor: Optional[torch.Tensor] + prompt_lengths_tensor: torch.Tensor prefix_offsets: List[Optional[int]] read_offsets: List[Optional[int]] @@ -350,12 +352,6 @@ def from_tokenized( all_input_ids_tensor, dtype=torch.int64, device=device ) - if len(pb.requests) > 1: - input_ids = np.concatenate(all_postfix_ids, dtype=np.int64) - else: - input_ids = all_postfix_ids[0] - input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device) - top_n_tokens_tensor = torch.tensor( top_n_tokens, device=device, dtype=torch.int64 ) @@ -366,12 +362,15 @@ def from_tokenized( for i, request_blocks in enumerate(block_tables): block_tables_tensor[i, : len(request_blocks)] = torch.tensor(request_blocks) block_tables_tensor = block_tables_tensor.to(device) + prompt_lengths_tensor = torch.tensor( + prompt_lengths, dtype=torch.int32, device=device 
+ ) return cls( batch_id=pb.id, requests=pb.requests, requests_idx_mapping=requests_idx_mapping, - input_ids=input_ids, + input_ids=all_postfix_ids, block_tables=block_tables, block_tables_tensor=block_tables_tensor, @@ -395,6 +394,7 @@ def from_tokenized( num_blocks=num_blocks, max_blocks=max_blocks, speculative_ids=None, + prompt_lengths_tensor=prompt_lengths_tensor, # These values will be set by `FlashCausalLMBatch.prepare_for_prefill` position_ids=None, @@ -408,7 +408,6 @@ def from_tokenized( prefill_cu_outlens=None, prefix_lengths_tensor=None, postfix_lengths_tensor=None, - prompt_lengths_tensor=None, adapter_meta=None, ) @@ -455,6 +454,7 @@ def filter(self, request_ids: List[int]) -> "FlashCausalLMBatch": block_tables = [] all_input_ids = [] prefix_ids = [] + input_ids = [] prompt_lengths = [] postfix_lengths = [] @@ -473,7 +473,6 @@ def filter(self, request_ids: List[int]) -> "FlashCausalLMBatch": max_blocks = 0 # Cumulative length cumulative_max_length = 0 - prefilling=False for i, request_id in enumerate(request_ids): idx = self.requests_idx_mapping[request_id] @@ -484,9 +483,13 @@ def filter(self, request_ids: List[int]) -> "FlashCausalLMBatch": # Prefilling request_prefilling = self.prefilling_mask[idx] - prefilling = prefilling or request_prefilling prefilling_mask.append(request_prefilling) + # Input ids if the request was part of a prefilling batch + # If the batch was decoding we can index into the tensor directly later + if self.prefilling: + input_ids.append(self.input_ids[idx]) + # Get length request_postfix_length = self.postfix_lengths[idx] request_prefix_length = self.prefix_lengths[idx] @@ -538,32 +541,48 @@ def filter(self, request_ids: List[int]) -> "FlashCausalLMBatch": max_blocks = max(max_blocks, len(request_block_table)) - # Index into tensors - input_ids = self.input_ids[indices] - position_ids = self.position_ids[indices] - adapter_indices = self.adapter_meta.adapter_indices[indices] all_input_ids_tensor = self.all_input_ids_tensor[indices] block_tables_tensor = self.block_tables_tensor[indices] - postfix_lengths_tensor = self.postfix_lengths_tensor[indices] - prompt_lengths_tensor = self.prompt_lengths_tensor[indices] - slots = self.slots[slot_filtering_indices] - prefix_lengths_tensor = self.prefix_lengths_tensor[indices] next_token_chooser = self.next_token_chooser.filter(indices) top_n_tokens_tensor = self.top_n_tokens_tensor[indices] speculative_ids = ( self.speculative_ids[indices] if self.speculative_ids is not None else None ) + prompt_lengths_tensor = self.prompt_lengths_tensor[indices] - start_slots = torch.tensor(start_slots, dtype=torch.int64) - - # Move to GPU now that we have the whole tensor - slot_indices = slot_indices.to(device) - - adapter_segments, adapter_segment_indices = find_segments(adapter_indices) - adapter_segments = torch.tensor( - adapter_segments, dtype=torch.int32, device=device - ) - # assert sum(len(b) for b in block_tables) == (block_tables_tensor != 0).sum() + if self.prefilling: + # These values will be set by `FlashCausalLMBatch.prepare_for_prefill` + position_ids=None + start_slots=None + slot_indices=None + slots=None + prefix_lengths_tensor=None + postfix_lengths_tensor=None + adapter_meta=None + else: + # Index into tensors + input_ids = self.input_ids[indices] + position_ids = self.position_ids[indices] + adapter_indices = self.adapter_meta.adapter_indices[indices] + postfix_lengths_tensor = self.postfix_lengths_tensor[indices] + slots = self.slots[slot_filtering_indices] + prefix_lengths_tensor = 
self.prefix_lengths_tensor[indices] + + start_slots = torch.tensor(start_slots, dtype=torch.int64) + + # Move to GPU now that we have the whole tensor + slot_indices = slot_indices.to(device) + + adapter_segments, adapter_segment_indices = find_segments(adapter_indices) + adapter_segments = torch.tensor( + adapter_segments, dtype=torch.int32, device=device + ) + adapter_meta = AdapterBatchMetadata( + adapter_indices=adapter_indices, + adapter_set=adapter_set, + adapter_segments=adapter_segments, + segment_indices=adapter_segment_indices, + ) return type(self)( batch_id=self.batch_id, @@ -580,7 +599,7 @@ def filter(self, request_ids: List[int]) -> "FlashCausalLMBatch": slots=slots, max_postfix_length=max_postfix_length, max_current_length=max_current_length, - prefilling=prefilling, + prefilling=self.prefilling, prefilling_mask=prefilling_mask, prefill_head_indices=None, prefill_next_token_indices=None, @@ -604,12 +623,7 @@ def filter(self, request_ids: List[int]) -> "FlashCausalLMBatch": num_blocks=num_blocks, max_blocks=max_blocks, speculative_ids=speculative_ids, - adapter_meta=AdapterBatchMetadata( - adapter_indices=adapter_indices, - adapter_set=adapter_set, - adapter_segments=adapter_segments, - segment_indices=adapter_segment_indices, - ), + adapter_meta=adapter_meta, ) @classmethod @@ -652,38 +666,51 @@ def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch ) prefilling = prefilling or b.prefilling - input_ids = batches[0].input_ids.new_empty(total_batch_size) - position_ids = batches[0].position_ids.new_empty(total_batch_size) - slots = batches[0].slots.new_empty(total_slots) - slot_indices = batches[0].slot_indices.new_empty(total_batch_size) + if prefilling: + input_ids = [] + # These values will be set by `FlashCausalLMBatch.prepare_for_prefill` + position_ids=None + start_slots=None + slots=None + slot_indices=None + prefix_lengths_tensor=None + postfix_lengths_tensor=None + adapter_meta=None + adapter_segment_builder=None + else: + input_ids = batches[0].input_ids.new_empty(total_batch_size) + position_ids = batches[0].position_ids.new_empty(total_batch_size) + start_slots = [] + slots = batches[0].slots.new_empty(total_slots) + slot_indices = batches[0].slot_indices.new_empty(total_batch_size) + postfix_lengths_tensor = batches[0].postfix_lengths_tensor.new_empty( + total_batch_size + ) + prefix_lengths_tensor = batches[0].prefix_lengths_tensor.new_empty( + total_batch_size + ) + total_indices_size = sum( + b.adapter_meta.adapter_indices.shape[0] for b in batches + ) + adapter_indices = batches[0].adapter_meta.adapter_indices.new_empty( + total_indices_size + ) + adapter_segment_builder = SegmentConcatBuilder() + adapter_set = set() + prompt_lengths_tensor = batches[0].prompt_lengths_tensor.new_empty( total_batch_size ) - postfix_lengths_tensor = batches[0].postfix_lengths_tensor.new_empty( - total_batch_size - ) block_tables_tensor = batches[0].block_tables_tensor.new_zeros( (total_batch_size, max_blocks) ) - prefix_lengths_tensor = batches[0].prefix_lengths_tensor.new_empty( - total_batch_size - ) all_input_ids_tensor = batches[0].all_input_ids_tensor.new_zeros( (total_batch_size, max_length) ) top_n_tokens_tensor = batches[0].top_n_tokens_tensor.new_zeros( total_batch_size, ) - total_indices_size = sum( - b.adapter_meta.adapter_indices.shape[0] for b in batches - ) - adapter_indices = batches[0].adapter_meta.adapter_indices.new_empty( - total_indices_size - ) - adapter_set = set() - adapter_segment_builder = SegmentConcatBuilder() - start_slots = 
[] block_tables = [] prefix_lengths = [] all_input_ids = [] @@ -723,29 +750,7 @@ def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch slots_end_index = cumulative_slots + len(batch.slots) # Copy tensors (GPU) - input_ids[start_index:end_index] = batch.input_ids - position_ids[start_index:end_index] = batch.position_ids - slot_indices[start_index:end_index] = batch.slot_indices + cumulative_slots - prompt_lengths_tensor[start_index:end_index] = batch.prompt_lengths_tensor - postfix_lengths_tensor[start_index:end_index] = batch.postfix_lengths_tensor top_n_tokens_tensor[start_index:end_index] = batch.top_n_tokens_tensor - slots[slots_start_index:slots_end_index] = batch.slots - - # Copy over adapter indices - adapter_start_index = cumulative_adapter_indices_size - adapter_end_index = ( - cumulative_adapter_indices_size - + batch.adapter_meta.adapter_indices.shape[0] - ) - adapter_indices[adapter_start_index:adapter_end_index] = ( - batch.adapter_meta.adapter_indices - ) - cumulative_adapter_indices_size = adapter_end_index - adapter_set.update(batch.adapter_meta.adapter_set) - adapter_segment_builder.concat( - batch.adapter_meta.adapter_segments, batch.adapter_meta.segment_indices - ) - all_input_ids_tensor[ start_index:end_index, : batch.all_input_ids_tensor.shape[1] ] = batch.all_input_ids_tensor[:, :max_length] @@ -753,12 +758,38 @@ def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch block_tables_tensor[ start_index:end_index, : batch.block_tables_tensor.shape[1] ] = batch.block_tables_tensor[:, :max_blocks] + prompt_lengths_tensor[start_index:end_index] = batch.prompt_lengths_tensor - prefix_lengths_tensor[start_index:end_index] = batch.prefix_lengths_tensor + if not prefilling: + input_ids[start_index:end_index] = batch.input_ids + position_ids[start_index:end_index] = batch.position_ids + slot_indices[start_index:end_index] = batch.slot_indices + cumulative_slots + postfix_lengths_tensor[start_index:end_index] = batch.postfix_lengths_tensor + slots[slots_start_index:slots_end_index] = batch.slots + + # Copy over adapter indices + adapter_start_index = cumulative_adapter_indices_size + adapter_end_index = ( + cumulative_adapter_indices_size + + batch.adapter_meta.adapter_indices.shape[0] + ) + adapter_indices[adapter_start_index:adapter_end_index] = ( + batch.adapter_meta.adapter_indices + ) + cumulative_adapter_indices_size = adapter_end_index + adapter_set.update(batch.adapter_meta.adapter_set) + adapter_segment_builder.concat( + batch.adapter_meta.adapter_segments, batch.adapter_meta.segment_indices + ) + prefix_lengths_tensor[start_index:end_index] = batch.prefix_lengths_tensor - start_slots.append(batch.start_slots + cumulative_slots) + start_slots.append(batch.start_slots + cumulative_slots) + else: + if isinstance(batch.input_ids, torch.Tensor): + batch.input_ids = batch.input_ids.view(-1, 1).tolist() + input_ids.extend(batch.input_ids) - prefilling_mask = prefilling_mask.extend(batch.prefilling_mask) + prefilling_mask.extend(batch.prefilling_mask) block_tables.extend(batch.block_tables) prefix_lengths.extend(batch.prefix_lengths) all_input_ids.extend(batch.all_input_ids) @@ -781,7 +812,8 @@ def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch cumulative_batch_size += len(batch) cumulative_slots += len(batch.slots) - start_slots = torch.concat(start_slots) + if start_slots is not None: + start_slots = torch.concat(start_slots) # assert sum(len(b) for b in block_tables) == (block_tables_tensor != 
0).sum() @@ -799,7 +831,14 @@ def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch else None ) - adapter_segments, adapter_segment_indices = adapter_segment_builder.build() + if adapter_segment_builder is not None: + adapter_segments, adapter_segment_indices = adapter_segment_builder.build() + adapter_meta = AdapterBatchMetadata( + adapter_indices=adapter_indices, + adapter_set=adapter_set, + adapter_segments=adapter_segments, + segment_indices=adapter_segment_indices, + ) return cls( batch_id=batches[0].batch_id, @@ -840,12 +879,7 @@ def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch num_blocks=num_blocks, max_blocks=max_blocks, speculative_ids=speculative_ids, - adapter_meta=AdapterBatchMetadata( - adapter_indices=adapter_indices, - adapter_set=adapter_set, - adapter_segments=adapter_segments, - segment_indices=adapter_segment_indices, - ), + adapter_meta=adapter_meta, ) def prepare_for_prefill(self): @@ -973,9 +1007,16 @@ def prepare_for_prefill(self): cumulative_length += next_chunk_length cumulative_slot_tokens += len(request_slots) - device = self.input_ids.device + device = self.block_tables_tensor.device self.start_slots = torch.tensor(start_slots, dtype=torch.int64) + if isinstance(self.input_ids, list): + if len(self) > 1: + input_ids = np.concatenate(self.input_ids, dtype=np.int64) + else: + input_ids = self.input_ids[0] + self.input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device) + if len(self) > 1: position_ids = torch.cat(position_ids) slot_indices = torch.cat(slot_indices) @@ -1865,7 +1906,8 @@ def generate_token( batch.input_ids = next_input_ids[accepted_ids.cumsum(dim=-1) - 1] batch.speculative_ids = speculative_ids batch.position_ids = next_position_ids + accepted_ids - batch.postfix_lengths_tensor += accepted_ids + batch.prefix_lengths_tensor += batch.postfix_lengths_tensor + batch.postfix_lengths_tensor = accepted_ids batch.slot_indices += accepted_ids batch.adapter_meta.adapter_indices = next_adapter_indices @@ -1929,11 +1971,9 @@ def generate_token( # This request is done prefilling, the new id is the one selected the sampling method postfix_ids = [next_token_id] - all_postfix_ids.extend(postfix_ids) + all_postfix_ids.append(postfix_ids) - batch.input_ids = batch.input_ids.new_tensor( - all_postfix_ids, dtype=torch.int64 - ) + batch.input_ids = all_postfix_ids start_decode = time.time_ns() @@ -2014,7 +2054,7 @@ def generate_token( ) prefill_tokens = Tokens( - prefix_ids + prefill_token_ids, + prefill_token_ids, request_prefill_logprobs, prefill_texts, is_special=[], From 7f9abde3f8a769acf1fb61c824b807e13bde80c1 Mon Sep 17 00:00:00 2001 From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com> Date: Wed, 2 Oct 2024 12:59:44 +0200 Subject: [PATCH 10/29] load tested --- backends/client/src/v3/client.rs | 1 + backends/client/src/v3/sharded_client.rs | 1 + backends/v3/src/backend.rs | 53 +++-- backends/v3/src/client/grpc_client.rs | 1 + backends/v3/src/client/mod.rs | 9 - backends/v3/src/client/sharded_client.rs | 21 +- backends/v3/src/lib.rs | 4 + backends/v3/src/main.rs | 28 +-- backends/v3/src/queue.rs | 141 +++++++------- benchmark/src/generation.rs | 1 + proto/v3/generate.proto | 5 + server/tests/conftest.py | 2 +- .../models/causal_lm.py | 1 + .../models/flash_causal_lm.py | 181 +++++++++++------- .../text_generation_server/models/globals.py | 2 +- .../models/idefics_causal_lm.py | 1 + server/text_generation_server/models/mamba.py | 1 + 
server/text_generation_server/models/model.py | 16 ++ .../models/seq2seq_lm.py | 1 + .../models/vlm_causal_lm.py | 3 +- server/text_generation_server/server.py | 3 + .../utils/prefill_chunking.py | 24 +++ 22 files changed, 306 insertions(+), 194 deletions(-) create mode 100644 server/text_generation_server/utils/prefill_chunking.py diff --git a/backends/client/src/v3/client.rs b/backends/client/src/v3/client.rs index 479d31bf290..61d1ea1b595 100644 --- a/backends/client/src/v3/client.rs +++ b/backends/client/src/v3/client.rs @@ -159,6 +159,7 @@ impl Client { blocks: vec![], slots: vec![], prefix_len: 0, + postfix_len: truncate, // Set sampling parameters to also take these ops into account in the max memory parameters: Some(NextTokenChooserParameters { temperature: 0.9, diff --git a/backends/client/src/v3/sharded_client.rs b/backends/client/src/v3/sharded_client.rs index 645c076a26b..8872f8bdf8f 100644 --- a/backends/client/src/v3/sharded_client.rs +++ b/backends/client/src/v3/sharded_client.rs @@ -246,6 +246,7 @@ impl Health for ShardedClient { blocks: vec![0], slots: (0..16).collect(), prefix_len: 0, + postfix_len: 1, adapter_id: None, }; let batch = Batch { diff --git a/backends/v3/src/backend.rs b/backends/v3/src/backend.rs index 77fdb0419bc..bfe7932fc99 100644 --- a/backends/v3/src/backend.rs +++ b/backends/v3/src/backend.rs @@ -34,9 +34,13 @@ impl BackendV3 { requires_padding: bool, window_size: Option, speculate: u32, + support_chunking: bool, ) -> Self { - let prefix_caching = - std::env::var("USE_PREFIX_CACHING").unwrap_or("1".to_string()); + if support_chunking { + tracing::warn!("Model supports prefill chunking. `waiting_served_ratio` and `max_waiting_tokens` will be ignored."); + } + + let prefix_caching = std::env::var("USE_PREFIX_CACHING").unwrap_or("1".to_string()); let prefix_caching = matches!(prefix_caching.as_str(), "true" | "1"); let attention: String = std::env::var("ATTENTION").unwrap_or("flashinfer".to_string()); @@ -52,6 +56,7 @@ impl BackendV3 { window_size, speculate, max_batch_total_tokens, + support_chunking, ); let batching_task_notifier = Arc::new(Notify::new()); @@ -63,6 +68,7 @@ impl BackendV3 { max_batch_total_tokens, max_waiting_tokens, max_batch_size, + support_chunking, queue.clone(), batching_task_notifier.clone(), )); @@ -127,6 +133,7 @@ pub(crate) async fn batching_task( max_batch_total_tokens: u32, max_waiting_tokens: usize, max_batch_size: Option, + support_chunking: bool, queue: Queue, notifier: Arc, ) { @@ -158,28 +165,44 @@ pub(crate) async fn batching_task( // Get current batch info let batch_size = batch.size; let batch_max_tokens = batch.max_tokens; + let current_tokens = batch.current_tokens; let mut batches = vec![batch]; metrics::gauge!("tgi_batch_current_size").set(batch_size as f64); metrics::gauge!("tgi_batch_current_max_tokens").set(batch_max_tokens as f64); - let min_size = if waiting_tokens >= max_waiting_tokens { - // If we didn't onboard any new requests since >= max_waiting_tokens, we try - // to add a new batch even though its size might be small - None + let token_budget = max_batch_total_tokens.saturating_sub(batch_max_tokens); + + let (min_size, max_size, prefill_token_budget) = if support_chunking { + // Since the next batch will be concatenated with the current batch, + // the current batch tokens must be subtracted to the prefill budget + // In the future, we could concatenate beforehand + let prefill_token_budget = max_batch_prefill_tokens - current_tokens; + // We can ignore min_size and max_size + // Models than rely on 
max_size cannot support chunking + // Regarding min_size, chunking allow us to consistently run at the compute + // bound, making min_size useless. + (None, None, prefill_token_budget) } else { - // Minimum batch size - // TODO: temporarily disable to avoid incorrect deallocation + - // reallocation when using prefix caching. - Some((batch_size as f32 * waiting_served_ratio).floor() as usize) - }; + let min_size = if waiting_tokens >= max_waiting_tokens { + // If we didn't onboard any new requests since >= max_waiting_tokens, we try + // to add a new batch even though its size might be small + None + } else { + // Minimum batch size + // TODO: temporarily disable to avoid incorrect deallocation + + // reallocation when using prefix caching. + Some((batch_size as f32 * waiting_served_ratio).floor() as usize) + }; - let token_budget = max_batch_total_tokens.saturating_sub(batch_max_tokens); - let max_size = - max_batch_size.map(|max_size| max_size.saturating_sub(batch_size as usize)); + let max_size = + max_batch_size.map(|max_size| max_size.saturating_sub(batch_size as usize)); + + (min_size, max_size, max_batch_prefill_tokens) + }; // Try to get a new batch if let Some((mut new_entries, new_batch, span)) = queue - .next_batch(min_size, max_size, max_batch_prefill_tokens, token_budget) + .next_batch(min_size, max_size, prefill_token_budget, token_budget) .await { // Tracking metrics diff --git a/backends/v3/src/client/grpc_client.rs b/backends/v3/src/client/grpc_client.rs index 648662db39b..3b4432a7cdc 100644 --- a/backends/v3/src/client/grpc_client.rs +++ b/backends/v3/src/client/grpc_client.rs @@ -159,6 +159,7 @@ impl Client { blocks: vec![], slots: vec![], prefix_len: 0, + postfix_len: truncate, // Set sampling parameters to also take these ops into account in the max memory parameters: Some(NextTokenChooserParameters { temperature: 0.9, diff --git a/backends/v3/src/client/mod.rs b/backends/v3/src/client/mod.rs index 755431f4633..d4ac50c9c46 100644 --- a/backends/v3/src/client/mod.rs +++ b/backends/v3/src/client/mod.rs @@ -29,15 +29,6 @@ pub trait Health { async fn model_health(&self) -> Result<()>; } -#[derive(Debug)] -pub struct ShardInfo { - pub requires_padding: bool, - pub dtype: String, - pub device_type: String, - pub window_size: Option, - pub speculate: u32, -} - #[derive(Error, Debug, Clone)] pub enum ClientError { #[error("Could not connect to Text Generation server: {0}")] diff --git a/backends/v3/src/client/sharded_client.rs b/backends/v3/src/client/sharded_client.rs index ea77a696648..97a1eab6dab 100644 --- a/backends/v3/src/client/sharded_client.rs +++ b/backends/v3/src/client/sharded_client.rs @@ -1,6 +1,6 @@ -use crate::client::{ClientError, Result}; +use crate::client::Health; /// Multi shard Client -use crate::client::{Health, ShardInfo}; +use crate::client::{ClientError, Result}; use crate::client::grpc_client::{DecodeTimings, PrefillTimings}; use crate::client::{ @@ -49,13 +49,13 @@ impl ShardedClient { /// Get the model info #[instrument(skip(self))] - pub async fn info(&mut self) -> Result { + pub async fn info(&mut self) -> Result { let futures: Vec<_> = self .clients .iter_mut() .map(|client| client.info()) .collect(); - join_all(futures).await.pop().unwrap().map(ShardInfo::from) + join_all(futures).await.pop().unwrap() } /// GRPC health check @@ -194,18 +194,6 @@ impl ShardedClient { } } -impl From for ShardInfo { - fn from(value: InfoResponse) -> Self { - Self { - requires_padding: value.requires_padding, - dtype: value.dtype, - device_type: value.device_type, 
- window_size: value.window_size, - speculate: value.speculate, - } - } -} - #[async_trait] impl Health for ShardedClient { async fn device_health(&self) -> Result<()> { @@ -248,6 +236,7 @@ impl Health for ShardedClient { slots: (0..16).collect(), prefix_len: 0, adapter_id: None, + postfix_len: 1, }; let batch = Batch { id: u64::MAX, diff --git a/backends/v3/src/lib.rs b/backends/v3/src/lib.rs index af66b21eb22..0a7ef2239f4 100644 --- a/backends/v3/src/lib.rs +++ b/backends/v3/src/lib.rs @@ -29,6 +29,8 @@ pub struct BackendInfo { pub max_waiting_tokens: usize, #[schema(nullable = true, example = "null")] pub max_batch_size: Option, + #[schema(example = "false")] + pub support_chunking: bool, } #[allow(clippy::too_many_arguments)] @@ -110,6 +112,7 @@ pub async fn connect_backend( model_device_type: shard_info.device_type.clone(), model_dtype: shard_info.dtype.clone(), speculate: shard_info.speculate as usize, + support_chunking: shard_info.support_chunking, }; let backend = BackendV3::new( @@ -122,6 +125,7 @@ pub async fn connect_backend( shard_info.requires_padding, shard_info.window_size, shard_info.speculate, + shard_info.support_chunking, ); tracing::info!("Using backend V3"); diff --git a/backends/v3/src/main.rs b/backends/v3/src/main.rs index 471ddb5a70a..b4751bd5373 100644 --- a/backends/v3/src/main.rs +++ b/backends/v3/src/main.rs @@ -131,25 +131,12 @@ async fn main() -> Result<(), RouterError> { "`max_input_tokens` must be < `max_total_tokens`".to_string(), )); } - if max_input_tokens as u32 > max_batch_prefill_tokens { - return Err(RouterError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be >= `max_input_tokens`. Given: {max_batch_prefill_tokens} and {max_input_tokens}"))); - } if validation_workers == 0 { return Err(RouterError::ArgumentValidation( "`validation_workers` must be > 0".to_string(), )); } - - if let Some(ref max_batch_total_tokens) = max_batch_total_tokens { - if max_batch_prefill_tokens > *max_batch_total_tokens { - return Err(RouterError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be <= `max_batch_total_tokens`. Given: {max_batch_prefill_tokens} and {max_batch_total_tokens}"))); - } - if max_total_tokens as u32 > *max_batch_total_tokens { - return Err(RouterError::ArgumentValidation(format!("`max_total_tokens` must be <= `max_batch_total_tokens`. Given: {max_total_tokens} and {max_batch_total_tokens}"))); - } - } - if let Some(max_batch_size) = max_batch_size { if max_batch_size == 0 { return Err(RouterError::ArgumentValidation( @@ -158,7 +145,7 @@ async fn main() -> Result<(), RouterError> { } } - let (backend, _backend_info) = connect_backend( + let (backend, backend_info) = connect_backend( max_input_tokens, max_total_tokens, master_shard_uds_path, @@ -170,6 +157,19 @@ async fn main() -> Result<(), RouterError> { ) .await?; + // Validate remaining args now that the backend is known + let support_chunking = backend_info.support_chunking; + let max_batch_total_tokens = backend_info.max_batch_total_tokens; + if max_input_tokens as u32 > max_batch_prefill_tokens && !support_chunking { + return Err(RouterError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be >= `max_input_tokens`. Given: {max_batch_prefill_tokens} and {max_input_tokens}"))); + } + if max_batch_prefill_tokens > max_batch_total_tokens { + return Err(RouterError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be <= `max_batch_total_tokens`. 
Given: {max_batch_prefill_tokens} and {max_batch_total_tokens}"))); + } + if max_total_tokens as u32 > max_batch_total_tokens { + return Err(RouterError::ArgumentValidation(format!("`max_total_tokens` must be <= `max_batch_total_tokens`. Given: {max_total_tokens} and {max_batch_total_tokens}"))); + } + // Run server server::run( backend, diff --git a/backends/v3/src/queue.rs b/backends/v3/src/queue.rs index f8123b57aa2..7db0aba3be8 100644 --- a/backends/v3/src/queue.rs +++ b/backends/v3/src/queue.rs @@ -4,7 +4,7 @@ use crate::client::{ Batch, GrammarType, NextTokenChooserParameters, Request, StoppingCriteriaParameters, }; use nohash_hasher::{BuildNoHashHasher, IntMap}; -use std::cmp::{max, min}; +use std::cmp::max; use std::collections::VecDeque; use text_generation_router::infer::InferError; use text_generation_router::infer::InferStreamResponse; @@ -50,6 +50,7 @@ impl Queue { window_size: Option, speculate: u32, max_batch_total_tokens: u32, + support_chunking: bool, ) -> Self { // Create channel let (queue_sender, queue_receiver) = mpsc::unbounded_channel(); @@ -62,6 +63,7 @@ impl Queue { window_size, speculate, max_batch_total_tokens, + support_chunking, queue_receiver, )); @@ -108,6 +110,7 @@ impl Queue { } // Background task responsible of the queue state +#[allow(clippy::too_many_arguments)] async fn queue_task( requires_padding: bool, block_size: u32, @@ -115,6 +118,7 @@ async fn queue_task( window_size: Option, speculate: u32, max_batch_total_tokens: u32, + support_chunking: bool, mut receiver: mpsc::UnboundedReceiver, ) { let mut state = State::new( @@ -124,6 +128,7 @@ async fn queue_task( window_size, speculate, max_batch_total_tokens, + support_chunking, ); while let Some(cmd) = receiver.recv().await { @@ -166,12 +171,14 @@ struct State { /// Paged Attention block size block_size: u32, - /// Sliding window - window_size: Option, - /// Speculation amount speculate: u32, + /// Whether the model allow the prefill chunking + /// If it does, the last request in the batch will be split to exactly match the prefill + /// token budget + support_chunking: bool, + /// Paged Attention Block Allocation block_allocator: Option, } @@ -184,6 +191,7 @@ impl State { window_size: Option, speculate: u32, max_batch_total_tokens: u32, + support_chunking: bool, ) -> Self { let block_allocator = (!requires_padding).then(|| { BlockAllocator::new( @@ -199,8 +207,8 @@ impl State { next_id: 0, next_batch_id: 0, block_size, - window_size, speculate, + support_chunking, block_allocator, } } @@ -268,7 +276,7 @@ impl State { continue; } - let block_allocation = match &self.block_allocator { + let (block_allocation, postfix_len) = match &self.block_allocator { None => { // We pad to max input length in the Python shards // We need to take these padding tokens into the equation @@ -285,34 +293,9 @@ impl State { self.entries.push_front((id, entry)); break 'entry_loop; } - None + (None, entry.request.input_length) } - Some(_block_allocator) => { - prefill_tokens += entry.request.input_length; - let max_new_tokens = match self.window_size { - None => entry.request.stopping_parameters.max_new_tokens, - Some(window_size) => min( - window_size.saturating_sub(entry.request.input_length), - entry.request.stopping_parameters.max_new_tokens, - ), - }; - decode_tokens += max_new_tokens; - - if prefill_tokens > prefill_token_budget - || (prefill_tokens + decode_tokens + self.speculate) > token_budget - { - // Entry is over budget - // Add it back to the front - tracing::debug!("Over budget: prefill_tokens={prefill_tokens} 
> {prefill_token_budget} || {prefill_tokens} + {decode_tokens} + {} > {token_budget}", self.speculate); - self.entries.push_front((id, entry)); - break; - } - - let tokens = entry.request.input_length - + entry.request.stopping_parameters.max_new_tokens - + self.speculate - - 1; - + Some(block_allocator) => { // If users wants the prefill logprobs, we cannot reuse the cache. // So no input_ids for the radix tree. let input_ids = if entry.request.decoder_input_details { @@ -321,10 +304,65 @@ impl State { entry.request.input_ids.clone() }; - Some((tokens, input_ids)) + let tokens = entry.request.input_length + + entry.request.stopping_parameters.max_new_tokens + + self.speculate + - 1; + tracing::debug!("Allocating {tokens} with {input_ids:?}"); + + let block_allocation = match block_allocator.allocate(tokens, input_ids).await { + None => { + // Entry is over budget + // Add it back to the front + tracing::debug!("Over budget: not enough free blocks"); + self.entries.push_front((id, entry)); + break 'entry_loop; + } + Some(mut block_allocation) => { + tracing::debug!("Allocation: {block_allocation:?}"); + max_blocks = max(max_blocks, block_allocation.blocks.len() as u32); + + if block_allocation.prefix_len == entry.request.input_length { + // The whole request was found in the radix trie + // However, for the transformer forward to work, we need to + // have at least one token of postfix. + block_allocation.prefix_len -= 1; + } + + block_allocation + } + }; + + let mut postfix_len = entry.request.input_length - block_allocation.prefix_len; + + // Check equality too as if we don't we might end up with a postfix_len = 0 + // in the next iteration of the loop + if prefill_tokens + postfix_len >= prefill_token_budget { + // Entry is over budget + if self.support_chunking { + // We support chunking, just set postfix_len to exactly match prefill_token_budget + postfix_len = prefill_token_budget - prefill_tokens; + // Push this entry inside the batch + batch.push((id, entry, Some(block_allocation), postfix_len)); + break 'entry_loop; + } else { + // We don't support chunking, this entry needs to go back to the buffer + // Add it back to the front + tracing::debug!( + "Over budget: prefill_tokens={} > {prefill_token_budget}", + prefill_tokens + postfix_len + ); + self.entries.push_front((id, entry)); + break 'entry_loop; + } + } + + prefill_tokens += postfix_len; + + (Some(block_allocation), postfix_len) } }; - batch.push((id, entry, block_allocation)); + batch.push((id, entry, block_allocation, postfix_len)); if Some(batch.len()) == max_size { break; } @@ -342,7 +380,7 @@ impl State { // Batch is too small if batch.len() < min_size { // Add back entries to the queue in the correct order - for (id, entry, _) in batch.into_iter().rev() { + for (id, entry, _, _) in batch.into_iter().rev() { self.entries.push_front((id, entry)); } return None; @@ -353,29 +391,7 @@ impl State { let mut batch_entries = IntMap::with_capacity_and_hasher(self.entries.len(), BuildNoHashHasher::default()); - for (id, mut entry, block_allocation) in batch { - let block_allocation = if let (Some((tokens, input_ids)), Some(block_allocator)) = - (block_allocation, &self.block_allocator) - { - tracing::debug!("Allocating {tokens} with {input_ids:?}"); - match block_allocator.allocate(tokens, input_ids).await { - None => { - // Entry is over budget - // Add it back to the front - tracing::debug!("Over budget: not enough free blocks"); - self.entries.push_front((id, entry)); - continue; - } - Some(block_allocation) => { - 
tracing::debug!("Allocation: {block_allocation:?}"); - max_blocks = max(max_blocks, block_allocation.blocks.len() as u32); - Some(block_allocation) - } - } - } else { - None - }; - tracing::debug!("Accepting entry"); + for (id, mut entry, block_allocation, postfix_len) in batch { // Create a new span to link the batch back to this entry let entry_batch_span = info_span!(parent: &entry.span, "infer"); // Add relationships @@ -429,6 +445,7 @@ impl State { slots, prefix_len, adapter_id: entry.request.adapter_id.clone(), + postfix_len, }); // Set batch_time entry.batch_time = Some(Instant::now()); @@ -436,12 +453,6 @@ impl State { batch_entries.insert(id, entry); } - // Empty batch - if batch_requests.is_empty() { - tracing::debug!("Filterered out all entries"); - return None; - } - // Final batch size let size = batch_requests.len() as u32; next_batch_span.record("batch_size", size); diff --git a/benchmark/src/generation.rs b/benchmark/src/generation.rs index 789c7b514fc..fff221ef582 100644 --- a/benchmark/src/generation.rs +++ b/benchmark/src/generation.rs @@ -159,6 +159,7 @@ async fn prefill( blocks: vec![], slots: vec![], prefix_len: 0, + postfix_len: sequence_length, adapter_id: None, }) .collect(); diff --git a/proto/v3/generate.proto b/proto/v3/generate.proto index 34894bdaba4..cfb92ba8fa5 100644 --- a/proto/v3/generate.proto +++ b/proto/v3/generate.proto @@ -34,6 +34,7 @@ message InfoResponse { string device_type = 3; optional uint32 window_size = 4; uint32 speculate = 5; + bool support_chunking = 6; } /// Empty request @@ -139,6 +140,8 @@ message Request { uint32 prefix_len = 12; /// Context truncation bool add_special_tokens = 13; + /// Postfix length for prefill chunking + uint32 postfix_len = 14; } message Batch { @@ -163,6 +166,8 @@ message CachedBatch { uint32 size = 3; /// Maximum number of tokens this batch will grow to uint32 max_tokens = 4; + /// Number of tokens in the next forward + uint32 current_tokens = 5; } enum FinishReason { diff --git a/server/tests/conftest.py b/server/tests/conftest.py index 1efeba5864e..b1a30e02002 100644 --- a/server/tests/conftest.py +++ b/server/tests/conftest.py @@ -1,7 +1,7 @@ import pytest -import os from text_generation_server.pb import generate_pb2 + @pytest.fixture def default_pb_parameters(): return generate_pb2.NextTokenChooserParameters( diff --git a/server/text_generation_server/models/causal_lm.py b/server/text_generation_server/models/causal_lm.py index 28534d0f73b..1378f59055c 100644 --- a/server/text_generation_server/models/causal_lm.py +++ b/server/text_generation_server/models/causal_lm.py @@ -76,6 +76,7 @@ def to_pb(self) -> generate_pb2.CachedBatch: request_ids=[r.id for r in self.requests], size=len(self), max_tokens=self.max_tokens, + current_tokens=len(self), ) @classmethod diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index 8ee9d184cde..b39fe0ff6f1 100644 --- a/server/text_generation_server/models/flash_causal_lm.py +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -16,7 +16,17 @@ AutoTokenizer, GenerationConfig, ) -from typing import Any, ContextManager, Iterable, Optional, Tuple, List, Type, Dict, Union +from typing import ( + Any, + ContextManager, + Iterable, + Optional, + Tuple, + List, + Type, + Dict, + Union, +) from text_generation_server.adapters import AdapterBatchData, AdapterBatchMetadata from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE @@ -24,6 +34,10 @@ from text_generation_server.utils.import_utils 
import SYSTEM from text_generation_server.models import Model from text_generation_server.utils.log import log_master +from text_generation_server.utils.prefill_chunking import ( + get_support_chunking, + get_max_prefill_tokens, +) from text_generation_server.utils.tokens import batch_top_tokens from text_generation_server.utils.speculate import get_speculate from text_generation_server.utils import ( @@ -60,12 +74,9 @@ tracer = trace.get_tracer(__name__) - # Will be set in init SLIDING_WINDOW: Optional[int] = None -TOKEN_BUDGET = 8 - def set_sliding_window(sliding_window: int): global SLIDING_WINDOW @@ -206,6 +217,11 @@ def to_pb(self) -> generate_pb2.CachedBatch: request_ids=[r.id for r in self.requests], size=len(self), max_tokens=self.num_blocks * BLOCK_SIZE, + current_tokens=( + sum([len(i) for i in self.input_ids]) + if isinstance(self.input_ids, list) + else len(self.input_ids) + ), ) @classmethod @@ -272,7 +288,7 @@ def from_tokenized( prompt_lengths.append(prompt_length) prefix_length = r.prefix_len - postfix_length = prefix_length + 10 + postfix_length = r.postfix_len assert ( prefix_length <= prompt_length ), f"Prefix {prefix_length} vs input {prompt_length}" @@ -282,10 +298,13 @@ def from_tokenized( if prefix_length + postfix_length < prompt_length: # FIXME: speculate is not supported for context chunking at the moment assert speculate == 0 + assert get_support_chunking() + assert postfix_length > 0 prefix_ids.append(tokenized_input[:prefix_length]) - postfix_ids = tokenized_input[prefix_length : postfix_length] - # postfix_ids = tokenized_input[prefix_length:] + postfix_ids = tokenized_input[ + prefix_length : prefix_length + postfix_length + ] postfix_length = len(postfix_ids) postfix_lengths.append(postfix_length) @@ -371,7 +390,6 @@ def from_tokenized( requests=pb.requests, requests_idx_mapping=requests_idx_mapping, input_ids=all_postfix_ids, - block_tables=block_tables, block_tables_tensor=block_tables_tensor, prefix_lengths=prefix_lengths, @@ -395,7 +413,6 @@ def from_tokenized( max_blocks=max_blocks, speculative_ids=None, prompt_lengths_tensor=prompt_lengths_tensor, - # These values will be set by `FlashCausalLMBatch.prepare_for_prefill` position_ids=None, cu_seqlen_prefill=None, @@ -431,7 +448,7 @@ def filter(self, request_ids: List[int]) -> "FlashCausalLMBatch": if len(request_ids) == len(self): return self - device = self.input_ids.device + device = self.block_tables_tensor.device # New values after filtering requests_idx_mapping = {} @@ -552,13 +569,13 @@ def filter(self, request_ids: List[int]) -> "FlashCausalLMBatch": if self.prefilling: # These values will be set by `FlashCausalLMBatch.prepare_for_prefill` - position_ids=None - start_slots=None - slot_indices=None - slots=None - prefix_lengths_tensor=None - postfix_lengths_tensor=None - adapter_meta=None + position_ids = None + start_slots = None + slot_indices = None + slots = None + prefix_lengths_tensor = None + postfix_lengths_tensor = None + adapter_meta = None else: # Index into tensors input_ids = self.input_ids[indices] @@ -643,24 +660,24 @@ def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch max_current_length = 0 for b in batches: total_batch_size += len(b) - total_slots += len(b.slots) + max_blocks = max(max_blocks, b.max_blocks) + # If `b` is prefilling and was just filtered, `b.slots` is None + # `total_slots` is not used if any of the batches is prefilling + total_slots += len(b.slots) if not b.prefilling else 0 num_blocks += b.num_blocks speculative_length = ( 
b.speculative_ids.shape[1] if b.speculative_ids is not None else 0 ) - max_blocks = max(max_blocks, b.max_blocks) max_postfix_length = max(max_postfix_length, b.max_postfix_length) max_current_length = max(max_current_length, b.max_current_length) max_length = max( max_length, max( - prefix_length - + postfix_length + prompt_length + stopping_criteria.max_new_tokens + speculative_length - - stopping_criteria.current_tokens - for prefix_length, postfix_length, stopping_criteria in zip( - b.prefix_lengths, b.postfix_lengths, b.stopping_criterias + for prompt_length, stopping_criteria in zip( + b.prompt_lengths, b.stopping_criterias ) ), ) @@ -669,14 +686,14 @@ def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch if prefilling: input_ids = [] # These values will be set by `FlashCausalLMBatch.prepare_for_prefill` - position_ids=None - start_slots=None - slots=None - slot_indices=None - prefix_lengths_tensor=None - postfix_lengths_tensor=None - adapter_meta=None - adapter_segment_builder=None + position_ids = None + start_slots = None + slots = None + slot_indices = None + prefix_lengths_tensor = None + postfix_lengths_tensor = None + adapter_meta = None + adapter_segment_builder = None else: input_ids = batches[0].input_ids.new_empty(total_batch_size) position_ids = batches[0].position_ids.new_empty(total_batch_size) @@ -746,8 +763,6 @@ def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch start_index = cumulative_batch_size end_index = cumulative_batch_size + len(batch) - slots_start_index = cumulative_slots - slots_end_index = cumulative_slots + len(batch.slots) # Copy tensors (GPU) top_n_tokens_tensor[start_index:end_index] = batch.top_n_tokens_tensor @@ -761,10 +776,17 @@ def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch prompt_lengths_tensor[start_index:end_index] = batch.prompt_lengths_tensor if not prefilling: + slots_start_index = cumulative_slots + slots_end_index = cumulative_slots + len(batch.slots) + input_ids[start_index:end_index] = batch.input_ids position_ids[start_index:end_index] = batch.position_ids - slot_indices[start_index:end_index] = batch.slot_indices + cumulative_slots - postfix_lengths_tensor[start_index:end_index] = batch.postfix_lengths_tensor + slot_indices[start_index:end_index] = ( + batch.slot_indices + cumulative_slots + ) + postfix_lengths_tensor[start_index:end_index] = ( + batch.postfix_lengths_tensor + ) slots[slots_start_index:slots_end_index] = batch.slots # Copy over adapter indices @@ -779,11 +801,17 @@ def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch cumulative_adapter_indices_size = adapter_end_index adapter_set.update(batch.adapter_meta.adapter_set) adapter_segment_builder.concat( - batch.adapter_meta.adapter_segments, batch.adapter_meta.segment_indices + batch.adapter_meta.adapter_segments, + batch.adapter_meta.segment_indices, + ) + prefix_lengths_tensor[start_index:end_index] = ( + batch.prefix_lengths_tensor ) - prefix_lengths_tensor[start_index:end_index] = batch.prefix_lengths_tensor start_slots.append(batch.start_slots + cumulative_slots) + + # Update + cumulative_slots += len(batch.slots) else: if isinstance(batch.input_ids, torch.Tensor): batch.input_ids = batch.input_ids.view(-1, 1).tolist() @@ -810,7 +838,6 @@ def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch # Update cumulative_batch_size += len(batch) - cumulative_slots += len(batch.slots) if start_slots is not None: start_slots = 
torch.concat(start_slots) @@ -915,7 +942,7 @@ def prepare_for_prefill(self): postfix_length, prompt_length, request_prefilling, - blocks + blocks, ) in enumerate( zip( self.requests, @@ -923,7 +950,7 @@ def prepare_for_prefill(self): self.postfix_lengths, self.prompt_lengths, self.prefilling_mask, - self.block_tables + self.block_tables, ) ): next_chunk_length = postfix_length @@ -967,9 +994,7 @@ def prepare_for_prefill(self): no_prefill_logprobs = no_prefill_logprobs and not prefill_logprobs if prefill_logprobs: - prefill_head_indices.append( - request_position_ids + cumulative_length - ) + prefill_head_indices.append(request_position_ids + cumulative_length) prefill_next_token_indices.append( prefill_out_cumulative_length + postfix_length - 1 ) @@ -988,7 +1013,6 @@ def prepare_for_prefill(self): prefill_cu_outlens.append(prefill_out_cumulative_length + 1) prefill_out_cumulative_length += 1 - start_slots.append(cumulative_slot_tokens) slots.extend(request_slots) slot_indices.append(request_slot_indices) @@ -998,9 +1022,7 @@ def prepare_for_prefill(self): ADAPTER_TO_INDEX = get_adapter_to_index() adapter_index = ADAPTER_TO_INDEX.get(r.adapter_id, 0) - adapter_indices_list.append( - torch.full((next_chunk_length,), adapter_index) - ) + adapter_indices_list.append(torch.full((next_chunk_length,), adapter_index)) adapter_set.add(adapter_index) # Update @@ -1240,6 +1262,7 @@ def __init__( rank=rank, world_size=world_size, sliding_window=config.sliding_window, + support_chunking=True, ) @property @@ -1764,29 +1787,43 @@ def generate_token( finished_prefilling = True next_chunk_lengths = [] if prefill: - next_prefilling_mask = [] - # Budget in tokens for the next batch - # We remove next input ids to always have enough space for at least a single decode - # for the remaining requests - batch_budget = TOKEN_BUDGET - len(batch) - for prefix_length, postfix_length, prompt_length in zip( - batch.prefix_lengths, batch.postfix_lengths, batch.prompt_lengths - ): - remaining_prefill_tokens = max( - prompt_length - prefix_length - postfix_length, 0 - ) - if remaining_prefill_tokens > 0: - next_chunk_length = max( - min(remaining_prefill_tokens, batch_budget), 1 + if get_support_chunking(): + next_prefilling_mask = [] + # Budget in tokens for the next batch + # We remove len(batch) to always have enough space for at least a single decode + # for the remaining requests + batch_budget = get_max_prefill_tokens() - len(batch) + # We reverse to prioritize older requests + # zip() is not reversible so reverse the underlying lists instead + for prefix_length, postfix_length, prompt_length in zip( + reversed(batch.prefix_lengths), + reversed(batch.postfix_lengths), + reversed(batch.prompt_lengths), + ): + remaining_prefill_tokens = max( + prompt_length - prefix_length - postfix_length, 0 ) - batch_budget -= next_chunk_length - finished_prefilling = False - next_prefilling_mask.append(True) - else: - # Since speculation will be turned off, this is always true - next_chunk_length = 1 - next_prefilling_mask.append(False) - next_chunk_lengths.append(next_chunk_length) + if remaining_prefill_tokens > 0: + next_chunk_length = max( + min(remaining_prefill_tokens, batch_budget), 1 + ) + batch_budget -= next_chunk_length + finished_prefilling = False + next_prefilling_mask.append(True) + else: + # Since speculation will be turned off, this is always true + next_chunk_length = 1 + next_prefilling_mask.append(False) + next_chunk_lengths.append(next_chunk_length) + + # Reverse back the obtained values² + 
next_chunk_lengths.reverse() + next_prefilling_mask.reverse() + else: + # The model does not support chunking + # We know we only do a single prefill + finished_prefilling = True + next_prefilling_mask = [False] * len(batch) batch.prefilling = not finished_prefilling batch.prefilling_mask = next_prefilling_mask @@ -2179,7 +2216,9 @@ def generate_token( # have more than one new token per request with speculative decoding for next_token_id in _next_token_ids: batch.next_token_chooser = ( - batch.next_token_chooser.advance_grammar_single(i, next_token_id) + batch.next_token_chooser.advance_grammar_single( + i, next_token_id + ) ) # Update values diff --git a/server/text_generation_server/models/globals.py b/server/text_generation_server/models/globals.py index 1830dc42f67..6bf8d3ffcea 100644 --- a/server/text_generation_server/models/globals.py +++ b/server/text_generation_server/models/globals.py @@ -18,7 +18,7 @@ raise RuntimeError("Prefix caching is only supported with flashinfer") MEM_POOL = torch.cuda.graph_pool_handle() if torch.cuda.is_available() else None -TGI_WIGGLE_ROOM = float(os.getenv("TGI_WIGGLE_ROOM", "0.95")) +TGI_WIGGLE_ROOM = float(os.getenv("TGI_WIGGLE_ROOM", "0.90")) assert TGI_WIGGLE_ROOM > 0 assert TGI_WIGGLE_ROOM < 1 diff --git a/server/text_generation_server/models/idefics_causal_lm.py b/server/text_generation_server/models/idefics_causal_lm.py index 9a7a6fe156a..34b74ba8856 100644 --- a/server/text_generation_server/models/idefics_causal_lm.py +++ b/server/text_generation_server/models/idefics_causal_lm.py @@ -83,6 +83,7 @@ def to_pb(self) -> generate_pb2.CachedBatch: request_ids=[r.id for r in self.requests], size=len(self), max_tokens=self.max_tokens, + current_tokens=len(self), ) @classmethod diff --git a/server/text_generation_server/models/mamba.py b/server/text_generation_server/models/mamba.py index f6dcde68ab6..dfc61fb8875 100644 --- a/server/text_generation_server/models/mamba.py +++ b/server/text_generation_server/models/mamba.py @@ -116,6 +116,7 @@ def to_pb(self) -> generate_pb2.CachedBatch: request_ids=[r.id for r in self.requests], size=len(self), max_tokens=self.max_tokens, + current_tokens=len(self), ) @classmethod diff --git a/server/text_generation_server/models/model.py b/server/text_generation_server/models/model.py index 20402e07afa..02f3dbf9329 100644 --- a/server/text_generation_server/models/model.py +++ b/server/text_generation_server/models/model.py @@ -5,8 +5,11 @@ from typing import List, Tuple, Optional, TypeVar, Type, Dict from collections import defaultdict from transformers import PreTrainedTokenizerBase +from loguru import logger from text_generation_server.models.types import Batch, Generation +from text_generation_server.utils.log import log_master +from text_generation_server.utils.prefill_chunking import set_support_chunking from text_generation_server.utils.speculate import get_speculate from text_generation_server.pb.generate_pb2 import InfoResponse from text_generation_server.adapters.weights import LayerAdapterWeights @@ -31,6 +34,7 @@ def __init__( sliding_window: Optional[int] = None, speculate: Optional[int] = None, adapter_id: str = BASE_MODEL_ADAPTER_ID, + support_chunking: bool = False, ): self.model_id = model_id self.model = model.eval() @@ -60,6 +64,17 @@ def __init__( speculate = get_speculate() self.speculate = speculate + if speculate != 0 and support_chunking: + log_master( + logger.warning, + "Prefill chunking does not support speculation yet. 
" + "Prefill chunking will be turned off", + ) + support_chunking = False + + self.support_chunking = support_chunking + set_support_chunking(support_chunking) + self.has_position_ids = ( inspect.signature(model.forward).parameters.get("position_ids", None) is not None @@ -78,6 +93,7 @@ def info(self) -> InfoResponse: device_type=self.device.type, window_size=self.sliding_window, speculate=self.speculate, + support_chunking=self.support_chunking, ) @property diff --git a/server/text_generation_server/models/seq2seq_lm.py b/server/text_generation_server/models/seq2seq_lm.py index 04d4c28ba3e..e2d7aa4d5ba 100644 --- a/server/text_generation_server/models/seq2seq_lm.py +++ b/server/text_generation_server/models/seq2seq_lm.py @@ -80,6 +80,7 @@ def to_pb(self) -> generate_pb2.CachedBatch: request_ids=[r.id for r in self.requests], size=len(self), max_tokens=self.max_tokens, + current_tokens=len(self), ) @classmethod diff --git a/server/text_generation_server/models/vlm_causal_lm.py b/server/text_generation_server/models/vlm_causal_lm.py index 937811d7d44..1a578d7b1d9 100644 --- a/server/text_generation_server/models/vlm_causal_lm.py +++ b/server/text_generation_server/models/vlm_causal_lm.py @@ -357,7 +357,6 @@ def forward( else: cuda_graph = None if cu_seqlen_prefill is not None or cuda_graph is None: - input_lengths = postfix_lengths + prefix_lengths_tensor if PREFIX_CACHING: block_tables = block_tables_to_ragged( block_tables=block_tables, @@ -424,7 +423,7 @@ def forward( cuda_graph["postfix_lengths"][: postfix_lengths.shape[0]] = postfix_lengths cuda_graph["prefix_lengths"].zero_() cuda_graph["prefix_lengths"][ - : prefix_lengths_tensor.shape[0] + : prefix_lengths_tensor.shape[0] ] = prefix_lengths_tensor with self._forward_context( diff --git a/server/text_generation_server/server.py b/server/text_generation_server/server.py index 46e342a4fbf..bd4b3a535f9 100644 --- a/server/text_generation_server/server.py +++ b/server/text_generation_server/server.py @@ -15,6 +15,7 @@ from text_generation_server.interceptor import ExceptionInterceptor from text_generation_server.models import Model, get_model_with_lora_adapters from text_generation_server.utils.adapter import AdapterInfo +from text_generation_server.utils.prefill_chunking import set_max_prefill_tokens try: from text_generation_server.models.pali_gemma import PaliGemmaBatch @@ -96,6 +97,8 @@ async def FilterBatch(self, request, context): return generate_pb2.FilterBatchResponse(batch=filtered_batch.to_pb()) async def Warmup(self, request, context): + set_max_prefill_tokens(request.max_prefill_tokens) + if self.quantize in {"exl2", "gptq"}: try: # When using GPTQ, Exllama kernels need some global kernels diff --git a/server/text_generation_server/utils/prefill_chunking.py b/server/text_generation_server/utils/prefill_chunking.py new file mode 100644 index 00000000000..c227d30f512 --- /dev/null +++ b/server/text_generation_server/utils/prefill_chunking.py @@ -0,0 +1,24 @@ +from typing import Optional + +SUPPORT_CHUNKING: Optional[bool] = None +MAX_PREFILL_TOKENS: Optional[int] = None + + +def set_support_chunking(support_chunking: bool): + global SUPPORT_CHUNKING + SUPPORT_CHUNKING = support_chunking + + +def get_support_chunking() -> bool: + global SUPPORT_CHUNKING + return SUPPORT_CHUNKING + + +def set_max_prefill_tokens(max_prefill_tokens: int): + global MAX_PREFILL_TOKENS + MAX_PREFILL_TOKENS = max_prefill_tokens + + +def get_max_prefill_tokens() -> int: + global MAX_PREFILL_TOKENS + return MAX_PREFILL_TOKENS From 
4db5e7dde6a91f0e8182011368afaa05798885ed Mon Sep 17 00:00:00 2001 From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com> Date: Wed, 2 Oct 2024 14:10:33 +0200 Subject: [PATCH 11/29] re-create slots --- .../models/flash_causal_lm.py | 64 +++++++++++++------ 1 file changed, 46 insertions(+), 18 deletions(-) diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index b39fe0ff6f1..cf2b6ea7e36 100644 --- a/server/text_generation_server/models/flash_causal_lm.py +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -462,12 +462,12 @@ def filter(self, request_ids: List[int]) -> "FlashCausalLMBatch": ) # Create on CPU to only move to GPU once instead of at every copy - slot_indices = torch.empty(len(request_ids), dtype=torch.int64) + # slot_indices = torch.empty(len(request_ids), dtype=torch.int64) max_postfix_length = 0 max_current_length = 0 requests = [] - start_slots = [] + # start_slots = [] block_tables = [] all_input_ids = [] prefix_ids = [] @@ -491,12 +491,18 @@ def filter(self, request_ids: List[int]) -> "FlashCausalLMBatch": # Cumulative length cumulative_max_length = 0 + start_slots = [] + slots = [] + slot_indices = [] + cumulative_slot_tokens = 0 + for i, request_id in enumerate(request_ids): idx = self.requests_idx_mapping[request_id] indices.append(idx) requests_idx_mapping[request_id] = i - requests.append(self.requests[idx]) + request = self.requests[idx] + requests.append(request) # Prefilling request_prefilling = self.prefilling_mask[idx] @@ -508,6 +514,7 @@ def filter(self, request_ids: List[int]) -> "FlashCausalLMBatch": input_ids.append(self.input_ids[idx]) # Get length + request_prompt_length = self.prompt_lengths[idx] request_postfix_length = self.postfix_lengths[idx] request_prefix_length = self.prefix_lengths[idx] max_postfix_length = max(max_postfix_length, request_postfix_length) @@ -518,7 +525,7 @@ def filter(self, request_ids: List[int]) -> "FlashCausalLMBatch": all_input_ids.append(self.all_input_ids[idx]) prefix_ids.append(self.prefix_ids[idx]) - prompt_lengths.append(self.prompt_lengths[idx]) + prompt_lengths.append(request_prompt_length) postfix_lengths.append(request_postfix_length) prefix_lengths.append(request_prefix_length) prefix_offsets.append(self.prefix_offsets[idx]) @@ -534,27 +541,45 @@ def filter(self, request_ids: List[int]) -> "FlashCausalLMBatch": adapter_index = ADAPTER_TO_INDEX.get(self.requests[idx].adapter_id, 0) adapter_set.add(adapter_index) - remaining_tokens = ( - stopping_criteria.max_new_tokens - stopping_criteria.current_tokens - ) + # remaining_tokens = ( + # stopping_criteria.max_new_tokens - stopping_criteria.current_tokens + # ) request_block_table = self.block_tables[idx] num_blocks += len(request_block_table) block_tables.append(request_block_table) - start_slots.append(cumulative_max_length) + # start_slots.append(cumulative_max_length) # Copy to tensor (CPU) - slot_indices[i] = cumulative_max_length + request_postfix_length - 1 + # slot_indices[i] = cumulative_max_length + request_postfix_length - 1 # Set slice - slot_filtering_indices[ - self.start_slots[idx] : self.start_slots[idx] - + request_postfix_length - + remaining_tokens - - 1 - ] = True + #FIXME + # slot_filtering_indices[ + # self.start_slots[idx] : self.start_slots[idx] + # + request_postfix_length + # + remaining_tokens + # - 1 + # ] = True + + if not self.prefilling: + if not request.slots: + request_slots = [ + s + for b in request_block_table + for s in range(b * 
BLOCK_SIZE, (b + 1) * BLOCK_SIZE) + ] + else: + request_slots = request.slots - cumulative_max_length += request_postfix_length + remaining_tokens - 1 + request_slots = request_slots[request_prefix_length:] + start_slots.append(cumulative_slot_tokens) + slots.extend(request_slots) + slot_indices.append(cumulative_slot_tokens) + + cumulative_slot_tokens += len(request_slots) + + # cumulative_max_length += request_postfix_length + remaining_tokens - 1 max_blocks = max(max_blocks, len(request_block_table)) @@ -577,18 +602,21 @@ def filter(self, request_ids: List[int]) -> "FlashCausalLMBatch": postfix_lengths_tensor = None adapter_meta = None else: + slot_indices = torch.tensor(slot_indices, dtype=torch.int64, device=device) + slots = torch.tensor(slots, dtype=torch.int64, device=device) + # Index into tensors input_ids = self.input_ids[indices] position_ids = self.position_ids[indices] adapter_indices = self.adapter_meta.adapter_indices[indices] postfix_lengths_tensor = self.postfix_lengths_tensor[indices] - slots = self.slots[slot_filtering_indices] + # slots = self.slots[slot_filtering_indices] prefix_lengths_tensor = self.prefix_lengths_tensor[indices] start_slots = torch.tensor(start_slots, dtype=torch.int64) # Move to GPU now that we have the whole tensor - slot_indices = slot_indices.to(device) + # slot_indices = slot_indices.to(device) adapter_segments, adapter_segment_indices = find_segments(adapter_indices) adapter_segments = torch.tensor( From b49978ff67cb171223562efa98f9a717aa215859 Mon Sep 17 00:00:00 2001 From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com> Date: Wed, 2 Oct 2024 14:17:26 +0200 Subject: [PATCH 12/29] re-create slots --- backends/client/src/v3/client.rs | 6 ++- backends/v3/src/backend.rs | 60 ++++++++++++++---------- backends/v3/src/client/grpc_client.rs | 3 +- backends/v3/src/client/sharded_client.rs | 3 +- proto/v3/generate.proto | 2 + server/text_generation_server/server.py | 10 ++++ 6 files changed, 57 insertions(+), 27 deletions(-) diff --git a/backends/client/src/v3/client.rs b/backends/client/src/v3/client.rs index 61d1ea1b595..5191f8dd7ec 100644 --- a/backends/client/src/v3/client.rs +++ b/backends/client/src/v3/client.rs @@ -237,7 +237,11 @@ impl Client { &mut self, batches: Vec, ) -> Result<(Vec, Option, DecodeTimings)> { - let request = tonic::Request::new(DecodeRequest { batches }).inject_context(); + let request = tonic::Request::new(DecodeRequest { + batch: None, + batches, + }) + .inject_context(); let response = self.stub.decode(request).await?.into_inner(); Ok(( response.generations, diff --git a/backends/v3/src/backend.rs b/backends/v3/src/backend.rs index bfe7932fc99..183f4e5252e 100644 --- a/backends/v3/src/backend.rs +++ b/backends/v3/src/backend.rs @@ -200,6 +200,8 @@ pub(crate) async fn batching_task( (min_size, max_size, max_batch_prefill_tokens) }; + let mut additional_batch = None; + // Try to get a new batch if let Some((mut new_entries, new_batch, span)) = queue .next_batch(min_size, max_size, prefill_token_budget, token_budget) @@ -210,31 +212,40 @@ pub(crate) async fn batching_task( metrics::counter!("tgi_batch_concat", "reason" => "backpressure") .increment(1); } else { - metrics::counter!("tgi_batch_concat", "reason" => "wait_exceeded") - .increment(1); + let counter = if support_chunking { + metrics::counter!("tgi_batch_concat", "reason" => "chunking") + } else { + metrics::counter!("tgi_batch_concat", "reason" => "wait_exceeded") + }; + counter.increment(1); } - entries.iter_mut().for_each(|(_, entry)| { - // 
Create a new span to add the info that this entry is waiting - // because a new batch is being computed - let entry_waiting_span = info_span!(parent: &entry.span, "waiting"); - // Add relationships - span.follows_from(&entry_waiting_span); - entry_waiting_span.follows_from(&span); - // Update entry - entry.temp_span = Some(entry_waiting_span); - }); - - // Generate one token for this new batch to have the attention past in cache - let new_cached_batch = prefill(&mut client, new_batch, &mut new_entries) - .instrument(span) - .await; - // Reset waiting counter - waiting_tokens = 1; - // Extend current batch with the new batch - if let Some(new_cached_batch) = new_cached_batch { + if support_chunking { entries.extend(new_entries); - batches.push(new_cached_batch); + additional_batch = Some(new_batch); + } else { + entries.iter_mut().for_each(|(_, entry)| { + // Create a new span to add the info that this entry is waiting + // because a new batch is being computed + let entry_waiting_span = info_span!(parent: &entry.span, "waiting"); + // Add relationships + span.follows_from(&entry_waiting_span); + entry_waiting_span.follows_from(&span); + // Update entry + entry.temp_span = Some(entry_waiting_span); + }); + + // Generate one token for this new batch to have the attention past in cache + let new_cached_batch = prefill(&mut client, new_batch, &mut new_entries) + .instrument(span) + .await; + // Reset waiting counter + waiting_tokens = 1; + // Extend current batch with the new batch + if let Some(new_cached_batch) = new_cached_batch { + entries.extend(new_entries); + batches.push(new_cached_batch); + } } } @@ -252,7 +263,7 @@ pub(crate) async fn batching_task( entry.temp_span = Some(entry_batch_span); }); - cached_batch = decode(&mut client, batches, &mut entries) + cached_batch = decode(&mut client, additional_batch, batches, &mut entries) .instrument(next_batch_span) .await; waiting_tokens += 1; @@ -306,6 +317,7 @@ async fn prefill( #[instrument(skip_all)] async fn decode( client: &mut ShardedClient, + batch: Option, batches: Vec, entries: &mut IntMap, ) -> Option { @@ -313,7 +325,7 @@ async fn decode( let batch_ids: Vec = batches.iter().map(|b| b.id).collect(); metrics::counter!("tgi_batch_inference_count", "method" => "decode").increment(1); - match client.decode(batches).await { + match client.decode(batch, batches).await { Ok((generations, next_batch, timings)) => { let start_filtering_time = Instant::now(); // Send generated tokens and filter stopped entries diff --git a/backends/v3/src/client/grpc_client.rs b/backends/v3/src/client/grpc_client.rs index 3b4432a7cdc..ab93db9b488 100644 --- a/backends/v3/src/client/grpc_client.rs +++ b/backends/v3/src/client/grpc_client.rs @@ -235,9 +235,10 @@ impl Client { #[instrument(skip_all, fields(size = batches.iter().map(|batch|{batch.size}).sum::()))] pub async fn decode( &mut self, + batch: Option, batches: Vec, ) -> Result<(Vec, Option, DecodeTimings)> { - let request = tonic::Request::new(DecodeRequest { batches }).inject_context(); + let request = tonic::Request::new(DecodeRequest { batches, batch }).inject_context(); let response = self.stub.decode(request).await?.into_inner(); Ok(( response.generations, diff --git a/backends/v3/src/client/sharded_client.rs b/backends/v3/src/client/sharded_client.rs index 97a1eab6dab..8af4b26fee3 100644 --- a/backends/v3/src/client/sharded_client.rs +++ b/backends/v3/src/client/sharded_client.rs @@ -167,12 +167,13 @@ impl ShardedClient { #[instrument(skip_all, fields(size = batches.iter().map(| batch | 
{batch.size}).sum::< u32 > ()))] pub async fn decode( &mut self, + batch: Option, batches: Vec, ) -> Result<(Vec, Option, DecodeTimings)> { let futures: Vec<_> = self .clients .iter_mut() - .map(|client| Box::pin(client.decode(batches.clone()))) + .map(|client| Box::pin(client.decode(batch.clone(), batches.clone()))) .collect(); #[allow(clippy::type_complexity)] let results: Result, Option, DecodeTimings)>> = diff --git a/proto/v3/generate.proto b/proto/v3/generate.proto index cfb92ba8fa5..15a93ac950e 100644 --- a/proto/v3/generate.proto +++ b/proto/v3/generate.proto @@ -243,6 +243,8 @@ message PrefillResponse { message DecodeRequest { /// Cached batches repeated CachedBatch batches = 1; + /// Optional Batch + optional Batch batch = 2; } message DecodeResponse { diff --git a/server/text_generation_server/server.py b/server/text_generation_server/server.py index bd4b3a535f9..d89df966ec8 100644 --- a/server/text_generation_server/server.py +++ b/server/text_generation_server/server.py @@ -179,6 +179,16 @@ async def Decode(self, request, context): if len(batches) == 0: raise ValueError("All batches are empty") + if self.model.support_chunking: + if request.HasField("batch"): + batch = self.model.batch_type.from_pb( + request.batch, + self.model.tokenizer, + self.model.dtype, + self.model.device, + ) + batches.append(batch) + if len(batches) > 1: start_concat = time.time_ns() batch = self.model.batch_type.concatenate(batches) From ff4155dfea4e194e3db8c53839df5542fb181dae Mon Sep 17 00:00:00 2001 From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com> Date: Wed, 2 Oct 2024 19:16:36 +0200 Subject: [PATCH 13/29] fix slot_filtering_indices --- backends/v3/src/backend.rs | 1 - .../models/flash_causal_lm.py | 101 +++++------------- 2 files changed, 26 insertions(+), 76 deletions(-) diff --git a/backends/v3/src/backend.rs b/backends/v3/src/backend.rs index 183f4e5252e..84152ff8030 100644 --- a/backends/v3/src/backend.rs +++ b/backends/v3/src/backend.rs @@ -175,7 +175,6 @@ pub(crate) async fn batching_task( let (min_size, max_size, prefill_token_budget) = if support_chunking { // Since the next batch will be concatenated with the current batch, // the current batch tokens must be subtracted to the prefill budget - // In the future, we could concatenate beforehand let prefill_token_budget = max_batch_prefill_tokens - current_tokens; // We can ignore min_size and max_size // Models than rely on max_size cannot support chunking diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index cf2b6ea7e36..b283a5fb43f 100644 --- a/server/text_generation_server/models/flash_causal_lm.py +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -138,9 +138,6 @@ class FlashCausalLMBatch(Batch): speculative_ids: Optional[torch.Tensor] # Set when creating the batch - # CPU tensor of length b indicating the start of each sequence in slots - # Will be set by `generate_token` and reset after each prefill forward before staying set in decode - start_slots: Optional[torch.Tensor] # tensor of indices of the currently used slots, length = \sum_{i=0}^{b} s_i in prefill, length = b in decode # Will be set by `generate_token` and reset after each prefill forward before staying set in decode slot_indices: Optional[torch.Tensor] @@ -417,7 +414,6 @@ def from_tokenized( position_ids=None, cu_seqlen_prefill=None, prefill_cache_indices=None, - start_slots=None, slot_indices=None, slots=None, prefill_head_indices=None, @@ -462,12 +458,11 
@@ def filter(self, request_ids: List[int]) -> "FlashCausalLMBatch": ) # Create on CPU to only move to GPU once instead of at every copy - # slot_indices = torch.empty(len(request_ids), dtype=torch.int64) + slot_indices = torch.empty(len(request_ids), dtype=torch.int64) max_postfix_length = 0 max_current_length = 0 requests = [] - # start_slots = [] block_tables = [] all_input_ids = [] prefix_ids = [] @@ -491,30 +486,18 @@ def filter(self, request_ids: List[int]) -> "FlashCausalLMBatch": # Cumulative length cumulative_max_length = 0 - start_slots = [] - slots = [] - slot_indices = [] - cumulative_slot_tokens = 0 - for i, request_id in enumerate(request_ids): idx = self.requests_idx_mapping[request_id] indices.append(idx) requests_idx_mapping[request_id] = i - request = self.requests[idx] - requests.append(request) + requests.append(self.requests[idx]) # Prefilling request_prefilling = self.prefilling_mask[idx] prefilling_mask.append(request_prefilling) - # Input ids if the request was part of a prefilling batch - # If the batch was decoding we can index into the tensor directly later - if self.prefilling: - input_ids.append(self.input_ids[idx]) - # Get length - request_prompt_length = self.prompt_lengths[idx] request_postfix_length = self.postfix_lengths[idx] request_prefix_length = self.prefix_lengths[idx] max_postfix_length = max(max_postfix_length, request_postfix_length) @@ -525,7 +508,7 @@ def filter(self, request_ids: List[int]) -> "FlashCausalLMBatch": all_input_ids.append(self.all_input_ids[idx]) prefix_ids.append(self.prefix_ids[idx]) - prompt_lengths.append(request_prompt_length) + prompt_lengths.append(self.prompt_lengths[idx]) postfix_lengths.append(request_postfix_length) prefix_lengths.append(request_prefix_length) prefix_offsets.append(self.prefix_offsets[idx]) @@ -541,45 +524,31 @@ def filter(self, request_ids: List[int]) -> "FlashCausalLMBatch": adapter_index = ADAPTER_TO_INDEX.get(self.requests[idx].adapter_id, 0) adapter_set.add(adapter_index) - # remaining_tokens = ( - # stopping_criteria.max_new_tokens - stopping_criteria.current_tokens - # ) - request_block_table = self.block_tables[idx] num_blocks += len(request_block_table) block_tables.append(request_block_table) - # start_slots.append(cumulative_max_length) - - # Copy to tensor (CPU) - # slot_indices[i] = cumulative_max_length + request_postfix_length - 1 - - # Set slice - #FIXME - # slot_filtering_indices[ - # self.start_slots[idx] : self.start_slots[idx] - # + request_postfix_length - # + remaining_tokens - # - 1 - # ] = True - - if not self.prefilling: - if not request.slots: - request_slots = [ - s - for b in request_block_table - for s in range(b * BLOCK_SIZE, (b + 1) * BLOCK_SIZE) - ] - else: - request_slots = request.slots - request_slots = request_slots[request_prefix_length:] - start_slots.append(cumulative_slot_tokens) - slots.extend(request_slots) - slot_indices.append(cumulative_slot_tokens) + # Input ids if the request was part of a prefilling batch + # If the batch was decoding we can index into the tensor directly later + if self.prefilling: + input_ids.append(self.input_ids[idx]) + else: + # Copy to tensor (CPU) + slot_indices[i] = cumulative_max_length + + remaining_tokens = ( + stopping_criteria.max_new_tokens - stopping_criteria.current_tokens + ) - cumulative_slot_tokens += len(request_slots) + # Set slice + slot_filtering_indices[ + self.slot_indices[idx] : self.slot_indices[idx] + + request_postfix_length + + remaining_tokens + - 1 + ] = True - # cumulative_max_length += 
request_postfix_length + remaining_tokens - 1 + cumulative_max_length += request_postfix_length + remaining_tokens - 1 max_blocks = max(max_blocks, len(request_block_table)) @@ -595,28 +564,22 @@ def filter(self, request_ids: List[int]) -> "FlashCausalLMBatch": if self.prefilling: # These values will be set by `FlashCausalLMBatch.prepare_for_prefill` position_ids = None - start_slots = None slot_indices = None slots = None prefix_lengths_tensor = None postfix_lengths_tensor = None adapter_meta = None else: - slot_indices = torch.tensor(slot_indices, dtype=torch.int64, device=device) - slots = torch.tensor(slots, dtype=torch.int64, device=device) - # Index into tensors input_ids = self.input_ids[indices] position_ids = self.position_ids[indices] adapter_indices = self.adapter_meta.adapter_indices[indices] postfix_lengths_tensor = self.postfix_lengths_tensor[indices] - # slots = self.slots[slot_filtering_indices] + slots = self.slots[slot_filtering_indices] prefix_lengths_tensor = self.prefix_lengths_tensor[indices] - start_slots = torch.tensor(start_slots, dtype=torch.int64) - # Move to GPU now that we have the whole tensor - # slot_indices = slot_indices.to(device) + slot_indices = slot_indices.to(device) adapter_segments, adapter_segment_indices = find_segments(adapter_indices) adapter_segments = torch.tensor( @@ -637,7 +600,6 @@ def filter(self, request_ids: List[int]) -> "FlashCausalLMBatch": position_ids=position_ids, cu_seqlen_prefill=None, prefill_cache_indices=None, - start_slots=start_slots, slot_indices=slot_indices, block_tables=block_tables, block_tables_tensor=block_tables_tensor, @@ -715,7 +677,6 @@ def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch input_ids = [] # These values will be set by `FlashCausalLMBatch.prepare_for_prefill` position_ids = None - start_slots = None slots = None slot_indices = None prefix_lengths_tensor = None @@ -725,7 +686,6 @@ def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch else: input_ids = batches[0].input_ids.new_empty(total_batch_size) position_ids = batches[0].position_ids.new_empty(total_batch_size) - start_slots = [] slots = batches[0].slots.new_empty(total_slots) slot_indices = batches[0].slot_indices.new_empty(total_batch_size) postfix_lengths_tensor = batches[0].postfix_lengths_tensor.new_empty( @@ -836,8 +796,6 @@ def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch batch.prefix_lengths_tensor ) - start_slots.append(batch.start_slots + cumulative_slots) - # Update cumulative_slots += len(batch.slots) else: @@ -867,11 +825,6 @@ def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch # Update cumulative_batch_size += len(batch) - if start_slots is not None: - start_slots = torch.concat(start_slots) - - # assert sum(len(b) for b in block_tables) == (block_tables_tensor != 0).sum() - next_token_chooser = HeterogeneousNextTokenChooser.from_pb( next_token_chooser_parameters, dtype=batches[0].next_token_chooser.dtype, @@ -903,7 +856,6 @@ def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch position_ids=position_ids, cu_seqlen_prefill=None, prefill_cache_indices=None, - start_slots=start_slots, slot_indices=slot_indices, block_tables=block_tables, block_tables_tensor=block_tables_tensor, @@ -946,7 +898,6 @@ def prepare_for_prefill(self): sliding_window = get_sliding_windows() position_ids = [] cu_seqlen_prefill = [0] - start_slots = [] slot_indices = [] prefill_cache_indices = [] all_prefill_logprobs = 
True @@ -1041,7 +992,6 @@ def prepare_for_prefill(self): prefill_cu_outlens.append(prefill_out_cumulative_length + 1) prefill_out_cumulative_length += 1 - start_slots.append(cumulative_slot_tokens) slots.extend(request_slots) slot_indices.append(request_slot_indices) @@ -1058,7 +1008,6 @@ def prepare_for_prefill(self): cumulative_slot_tokens += len(request_slots) device = self.block_tables_tensor.device - self.start_slots = torch.tensor(start_slots, dtype=torch.int64) if isinstance(self.input_ids, list): if len(self) > 1: @@ -1762,6 +1711,8 @@ def generate_token( if prefill: batch.prepare_for_prefill() + log_master(logger.info, f"Tokens in this forward: {len(batch.input_ids)}") + prefill_logprobs = batch.prefill_next_token_indices is not None # Update adapter indices for speculative tokens (if present) From c8a033b6366740af3b39a36d55f3dd1616c919a8 Mon Sep 17 00:00:00 2001 From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com> Date: Mon, 7 Oct 2024 12:02:25 +0200 Subject: [PATCH 14/29] feedback loop --- backends/client/src/v3/client.rs | 13 +-- backends/client/src/v3/sharded_client.rs | 5 +- backends/v2/src/backend.rs | 20 ++--- backends/v3/src/backend.rs | 86 +++++++++---------- backends/v3/src/client/grpc_client.rs | 21 +++-- backends/v3/src/client/sharded_client.rs | 8 +- backends/v3/src/lib.rs | 14 ++- backends/v3/src/queue.rs | 4 + benchmark/src/generation.rs | 2 +- proto/v3/generate.proto | 9 +- router/src/lib.rs | 39 --------- .../models/causal_lm.py | 2 +- .../models/flash_causal_lm.py | 64 +++++++------- server/text_generation_server/models/model.py | 4 + .../models/seq2seq_lm.py | 2 +- server/text_generation_server/server.py | 23 ++--- 16 files changed, 153 insertions(+), 163 deletions(-) diff --git a/backends/client/src/v3/client.rs b/backends/client/src/v3/client.rs index 5191f8dd7ec..8280795daf0 100644 --- a/backends/client/src/v3/client.rs +++ b/backends/client/src/v3/client.rs @@ -218,8 +218,13 @@ impl Client { pub async fn prefill( &mut self, batch: Batch, + cached_batch: Option, ) -> Result<(Vec, Option, PrefillTimings)> { - let request = tonic::Request::new(PrefillRequest { batch: Some(batch) }).inject_context(); + let request = tonic::Request::new(PrefillRequest { + batch: Some(batch), + cached_batch, + }) + .inject_context(); let response = self.stub.prefill(request).await?.into_inner(); Ok(( response.generations, @@ -237,11 +242,7 @@ impl Client { &mut self, batches: Vec, ) -> Result<(Vec, Option, DecodeTimings)> { - let request = tonic::Request::new(DecodeRequest { - batch: None, - batches, - }) - .inject_context(); + let request = tonic::Request::new(DecodeRequest { batches }).inject_context(); let response = self.stub.decode(request).await?.into_inner(); Ok(( response.generations, diff --git a/backends/client/src/v3/sharded_client.rs b/backends/client/src/v3/sharded_client.rs index 8872f8bdf8f..39e99776107 100644 --- a/backends/client/src/v3/sharded_client.rs +++ b/backends/client/src/v3/sharded_client.rs @@ -134,11 +134,12 @@ impl ShardedClient { pub async fn prefill( &mut self, batch: Batch, + cached_batch: Option, ) -> Result<(Vec, Option, PrefillTimings)> { let futures: Vec<_> = self .clients .iter_mut() - .map(|client| Box::pin(client.prefill(batch.clone()))) + .map(|client| Box::pin(client.prefill(batch.clone(), cached_batch.clone()))) .collect(); #[allow(clippy::type_complexity)] let results: Result, Option, PrefillTimings)>> = @@ -256,7 +257,7 @@ impl Health for ShardedClient { max_tokens: 2, max_blocks: 1, }; - 
self.clone().prefill(batch).await?; + self.clone().prefill(batch, None).await?; Ok(()) } } diff --git a/backends/v2/src/backend.rs b/backends/v2/src/backend.rs index 086fc6dc4a6..bc264138d28 100644 --- a/backends/v2/src/backend.rs +++ b/backends/v2/src/backend.rs @@ -6,7 +6,7 @@ use nohash_hasher::IntMap; use std::sync::Arc; use text_generation_router::infer::{Backend, GeneratedText, InferError, InferStreamResponse}; use text_generation_router::validation::ValidGenerateRequest; -use text_generation_router::{Attention, FinishReason, PrefillToken, Token}; +use text_generation_router::{FinishReason, PrefillToken, Token}; use tokio::sync::mpsc::error::SendError; use tokio::sync::{mpsc, Notify}; use tokio::time::Instant; @@ -36,18 +36,14 @@ impl BackendV2 { speculate: u32, ) -> Self { // Infer shared state - let attention = if let Ok(attention) = std::env::var("ATTENTION") { - attention - .parse() - .unwrap_or_else(|_| panic!("Invalid attention was specified :`{attention}`")) - } else { - Attention::Paged - }; - let block_size = if attention == Attention::FlashDecoding { - 256 - } else { - 16 + let attention = std::env::var("ATTENTION").unwrap_or("paged".to_string()); + let block_size = match attention.as_str() { + "flashinfer" => 1, + "flashdecoding" => 256, + "paged" => 16, + _ => unreachable!(), }; + let queue = Queue::new(requires_padding, block_size, window_size, speculate); let batching_task_notifier = Arc::new(Notify::new()); diff --git a/backends/v3/src/backend.rs b/backends/v3/src/backend.rs index 84152ff8030..a5c0f5125b2 100644 --- a/backends/v3/src/backend.rs +++ b/backends/v3/src/backend.rs @@ -1,12 +1,14 @@ -use crate::client::{Batch, CachedBatch, ClientError, Generation, Health, ShardedClient}; /// Batching and inference logic +use crate::client::{ + Batch, CachedBatch, ClientError, Generation, Health, InfoResponse, ShardedClient, +}; use crate::queue::{Entry, Queue}; use async_trait::async_trait; use nohash_hasher::IntMap; use std::sync::Arc; use text_generation_router::infer::{Backend, GeneratedText, InferError, InferStreamResponse}; use text_generation_router::validation::ValidGenerateRequest; -use text_generation_router::{Attention, FinishReason, PrefillToken, Token}; +use text_generation_router::{FinishReason, PrefillToken, Token}; use tokio::sync::mpsc::error::SendError; use tokio::sync::{mpsc, Notify}; use tokio::time::Instant; @@ -31,32 +33,22 @@ impl BackendV3 { max_batch_total_tokens: u32, max_waiting_tokens: usize, max_batch_size: Option, - requires_padding: bool, - window_size: Option, - speculate: u32, - support_chunking: bool, + shard_info: InfoResponse, ) -> Self { - if support_chunking { + if shard_info.support_chunking { tracing::warn!("Model supports prefill chunking. 
`waiting_served_ratio` and `max_waiting_tokens` will be ignored."); } - let prefix_caching = std::env::var("USE_PREFIX_CACHING").unwrap_or("1".to_string()); - let prefix_caching = matches!(prefix_caching.as_str(), "true" | "1"); - let attention: String = std::env::var("ATTENTION").unwrap_or("flashinfer".to_string()); - - let attention: Attention = attention - .parse() - .unwrap_or_else(|_| panic!("Invalid attention was specified :`{attention}`")); - let block_size = attention.block_size(); + let block_size = shard_info.block_size; let queue = Queue::new( - requires_padding, + shard_info.requires_padding, block_size, - prefix_caching, - window_size, - speculate, + shard_info.use_prefix_caching, + shard_info.window_size, + shard_info.speculate, max_batch_total_tokens, - support_chunking, + shard_info.support_chunking, ); let batching_task_notifier = Arc::new(Notify::new()); @@ -68,7 +60,7 @@ impl BackendV3 { max_batch_total_tokens, max_waiting_tokens, max_batch_size, - support_chunking, + shard_info.support_chunking, queue.clone(), batching_task_notifier.clone(), )); @@ -154,7 +146,7 @@ pub(crate) async fn batching_task( ) .await { - let mut cached_batch = prefill(&mut client, batch, &mut entries) + let mut cached_batch = prefill(&mut client, batch, None, &mut entries) .instrument(span) .await; let mut waiting_tokens = 1; @@ -175,7 +167,8 @@ pub(crate) async fn batching_task( let (min_size, max_size, prefill_token_budget) = if support_chunking { // Since the next batch will be concatenated with the current batch, // the current batch tokens must be subtracted to the prefill budget - let prefill_token_budget = max_batch_prefill_tokens - current_tokens; + let prefill_token_budget = + max_batch_prefill_tokens.saturating_sub(current_tokens); // We can ignore min_size and max_size // Models than rely on max_size cannot support chunking // Regarding min_size, chunking allow us to consistently run at the compute @@ -199,10 +192,8 @@ pub(crate) async fn batching_task( (min_size, max_size, max_batch_prefill_tokens) }; - let mut additional_batch = None; - // Try to get a new batch - if let Some((mut new_entries, new_batch, span)) = queue + if let Some((new_entries, new_batch, span)) = queue .next_batch(min_size, max_size, prefill_token_budget, token_budget) .await { @@ -218,11 +209,11 @@ pub(crate) async fn batching_task( }; counter.increment(1); } - - if support_chunking { - entries.extend(new_entries); - additional_batch = Some(new_batch); + let cached_batch = if support_chunking { + // Concat current batch to the new one + batches.pop() } else { + // Request are waiting only if we don't support chunking entries.iter_mut().for_each(|(_, entry)| { // Create a new span to add the info that this entry is waiting // because a new batch is being computed @@ -233,18 +224,23 @@ pub(crate) async fn batching_task( // Update entry entry.temp_span = Some(entry_waiting_span); }); + None + }; + entries.extend(new_entries); - // Generate one token for this new batch to have the attention past in cache - let new_cached_batch = prefill(&mut client, new_batch, &mut new_entries) + // Generate one token for this new batch to have the attention past in cache + let new_cached_batch = + prefill(&mut client, new_batch, cached_batch, &mut entries) .instrument(span) .await; - // Reset waiting counter - waiting_tokens = 1; - // Extend current batch with the new batch - if let Some(new_cached_batch) = new_cached_batch { - entries.extend(new_entries); - batches.push(new_cached_batch); - } + // Reset waiting counter + 
waiting_tokens = 1; + // Extend current batch with the new batch + if let Some(new_cached_batch) = new_cached_batch { + batches.push(new_cached_batch); + } else if support_chunking { + // New cached batch is empty, no work left + break; } } @@ -262,7 +258,7 @@ pub(crate) async fn batching_task( entry.temp_span = Some(entry_batch_span); }); - cached_batch = decode(&mut client, additional_batch, batches, &mut entries) + cached_batch = decode(&mut client, batches, &mut entries) .instrument(next_batch_span) .await; waiting_tokens += 1; @@ -277,13 +273,14 @@ pub(crate) async fn batching_task( async fn prefill( client: &mut ShardedClient, batch: Batch, + cached_batch: Option, entries: &mut IntMap, ) -> Option { let start_time = Instant::now(); let batch_id = batch.id; metrics::counter!("tgi_batch_inference_count", "method" => "prefill").increment(1); - match client.prefill(batch).await { + match client.prefill(batch, cached_batch).await { Ok((generations, next_batch, timings)) => { let start_filtering_time = Instant::now(); // Send generated tokens and filter stopped entries @@ -292,6 +289,10 @@ async fn prefill( // Filter next batch and remove requests that were stopped let next_batch = filter_batch(client, next_batch, entries).await; + if let Some(concat_duration) = timings.concat { + metrics::histogram!("tgi_batch_concat_duration", "method" => "decode") + .record(concat_duration.as_secs_f64()); + } metrics::histogram!("tgi_batch_forward_duration", "method" => "prefill") .record(timings.forward.as_secs_f64()); metrics::histogram!("tgi_batch_decode_duration", "method" => "prefill") @@ -316,7 +317,6 @@ async fn prefill( #[instrument(skip_all)] async fn decode( client: &mut ShardedClient, - batch: Option, batches: Vec, entries: &mut IntMap, ) -> Option { @@ -324,7 +324,7 @@ async fn decode( let batch_ids: Vec = batches.iter().map(|b| b.id).collect(); metrics::counter!("tgi_batch_inference_count", "method" => "decode").increment(1); - match client.decode(batch, batches).await { + match client.decode(batches).await { Ok((generations, next_batch, timings)) => { let start_filtering_time = Instant::now(); // Send generated tokens and filter stopped entries diff --git a/backends/v3/src/client/grpc_client.rs b/backends/v3/src/client/grpc_client.rs index ab93db9b488..804c77d4163 100644 --- a/backends/v3/src/client/grpc_client.rs +++ b/backends/v3/src/client/grpc_client.rs @@ -218,13 +218,23 @@ impl Client { pub async fn prefill( &mut self, batch: Batch, + cached_batch: Option, ) -> Result<(Vec, Option, PrefillTimings)> { - let request = tonic::Request::new(PrefillRequest { batch: Some(batch) }).inject_context(); + let request = tonic::Request::new(PrefillRequest { + batch: Some(batch), + cached_batch, + }) + .inject_context(); let response = self.stub.prefill(request).await?.into_inner(); Ok(( response.generations, response.batch, - PrefillTimings::new(response.forward_ns, response.decode_ns, response.total_ns), + PrefillTimings::new( + response.concat_ns, + response.forward_ns, + response.decode_ns, + response.total_ns, + ), )) } @@ -235,10 +245,9 @@ impl Client { #[instrument(skip_all, fields(size = batches.iter().map(|batch|{batch.size}).sum::()))] pub async fn decode( &mut self, - batch: Option, batches: Vec, ) -> Result<(Vec, Option, DecodeTimings)> { - let request = tonic::Request::new(DecodeRequest { batches, batch }).inject_context(); + let request = tonic::Request::new(DecodeRequest { batches }).inject_context(); let response = self.stub.decode(request).await?.into_inner(); Ok(( 
response.generations, @@ -254,14 +263,16 @@ impl Client { } pub struct PrefillTimings { + pub concat: Option, pub forward: Duration, pub decode: Duration, pub total: Duration, } impl PrefillTimings { - fn new(forward_ns: u64, decode_ns: u64, total_ns: u64) -> Self { + fn new(concat_ns: Option, forward_ns: u64, decode_ns: u64, total_ns: u64) -> Self { Self { + concat: concat_ns.map(Duration::from_nanos), forward: Duration::from_nanos(forward_ns), decode: Duration::from_nanos(decode_ns), total: Duration::from_nanos(total_ns), diff --git a/backends/v3/src/client/sharded_client.rs b/backends/v3/src/client/sharded_client.rs index 8af4b26fee3..e25bf71e55d 100644 --- a/backends/v3/src/client/sharded_client.rs +++ b/backends/v3/src/client/sharded_client.rs @@ -135,11 +135,12 @@ impl ShardedClient { pub async fn prefill( &mut self, batch: Batch, + cached_batch: Option, ) -> Result<(Vec, Option, PrefillTimings)> { let futures: Vec<_> = self .clients .iter_mut() - .map(|client| Box::pin(client.prefill(batch.clone()))) + .map(|client| Box::pin(client.prefill(batch.clone(), cached_batch.clone()))) .collect(); #[allow(clippy::type_complexity)] let results: Result, Option, PrefillTimings)>> = @@ -167,13 +168,12 @@ impl ShardedClient { #[instrument(skip_all, fields(size = batches.iter().map(| batch | {batch.size}).sum::< u32 > ()))] pub async fn decode( &mut self, - batch: Option, batches: Vec, ) -> Result<(Vec, Option, DecodeTimings)> { let futures: Vec<_> = self .clients .iter_mut() - .map(|client| Box::pin(client.decode(batch.clone(), batches.clone()))) + .map(|client| Box::pin(client.decode(batches.clone()))) .collect(); #[allow(clippy::type_complexity)] let results: Result, Option, DecodeTimings)>> = @@ -246,7 +246,7 @@ impl Health for ShardedClient { max_tokens: 2, max_blocks: 1, }; - self.clone().prefill(batch).await?; + self.clone().prefill(batch, None).await?; Ok(()) } } diff --git a/backends/v3/src/lib.rs b/backends/v3/src/lib.rs index 0a7ef2239f4..7daf9eaeca7 100644 --- a/backends/v3/src/lib.rs +++ b/backends/v3/src/lib.rs @@ -31,6 +31,12 @@ pub struct BackendInfo { pub max_batch_size: Option, #[schema(example = "false")] pub support_chunking: bool, + #[schema(example = "false")] + pub prefix_caching: bool, + #[schema(example = "flashinfer")] + pub attention_impl: String, + #[schema(example = "1")] + pub block_size: u32, } #[allow(clippy::too_many_arguments)] @@ -113,6 +119,9 @@ pub async fn connect_backend( model_dtype: shard_info.dtype.clone(), speculate: shard_info.speculate as usize, support_chunking: shard_info.support_chunking, + prefix_caching: shard_info.use_prefix_caching, + attention_impl: shard_info.attention_impl.clone(), + block_size: shard_info.block_size, }; let backend = BackendV3::new( @@ -122,10 +131,7 @@ pub async fn connect_backend( max_batch_total_tokens, max_waiting_tokens, max_batch_size, - shard_info.requires_padding, - shard_info.window_size, - shard_info.speculate, - shard_info.support_chunking, + shard_info, ); tracing::info!("Using backend V3"); diff --git a/backends/v3/src/queue.rs b/backends/v3/src/queue.rs index 7db0aba3be8..a07c725cbc7 100644 --- a/backends/v3/src/queue.rs +++ b/backends/v3/src/queue.rs @@ -89,6 +89,10 @@ impl Queue { prefill_token_budget: u32, token_budget: u32, ) -> Option { + if prefill_token_budget == 0 || token_budget == 0 { + return None; + }; + // Create response channel let (response_sender, response_receiver) = oneshot::channel(); // Send next batch command to the background task managing the state diff --git 
a/benchmark/src/generation.rs b/benchmark/src/generation.rs index fff221ef582..43a84e7023a 100644 --- a/benchmark/src/generation.rs +++ b/benchmark/src/generation.rs @@ -174,7 +174,7 @@ async fn prefill( // Run prefill let start_time = Instant::now(); - let (_, decode_batch, _) = client.prefill(batch.clone()).await?; + let (_, decode_batch, _) = client.prefill(batch.clone(), None).await?; // Get latency let latency = start_time.elapsed(); diff --git a/proto/v3/generate.proto b/proto/v3/generate.proto index 15a93ac950e..e4dfefefefd 100644 --- a/proto/v3/generate.proto +++ b/proto/v3/generate.proto @@ -35,6 +35,9 @@ message InfoResponse { optional uint32 window_size = 4; uint32 speculate = 5; bool support_chunking = 6; + bool use_prefix_caching = 7; + string attention_impl = 8; + uint32 block_size = 9; } /// Empty request @@ -225,6 +228,8 @@ message FilterBatchResponse { message PrefillRequest { /// Batch Batch batch = 1; + /// Optional cached batch + CachedBatch cached_batch = 2; } message PrefillResponse { @@ -238,13 +243,13 @@ message PrefillResponse { uint64 decode_ns = 4; /// Total elapsed time in nanoseconds uint64 total_ns = 5; + /// Concatenate elapsed time in nanoseconds + optional uint64 concat_ns = 6; } message DecodeRequest { /// Cached batches repeated CachedBatch batches = 1; - /// Optional Batch - optional Batch batch = 2; } message DecodeResponse { diff --git a/router/src/lib.rs b/router/src/lib.rs index b29c9395d91..fdbd931eafb 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -18,45 +18,6 @@ use tracing::warn; use utoipa::ToSchema; use validation::Validation; -#[derive(PartialEq)] -pub enum Attention { - Paged, - FlashDecoding, - FlashInfer, -} - -impl Attention { - pub fn block_size(&self) -> u32 { - match self { - Attention::FlashDecoding => 256, - Attention::FlashInfer => 1, - Attention::Paged => 16, - } - } -} - -#[derive(Debug)] -pub struct ParseError; - -impl std::fmt::Display for ParseError { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "Cannot parse attention value") - } -} -impl std::error::Error for ParseError {} - -impl std::str::FromStr for Attention { - type Err = ParseError; - fn from_str(s: &str) -> Result { - match s { - "paged" => Ok(Attention::Paged), - "flashdecoding" => Ok(Attention::FlashDecoding), - "flashinfer" => Ok(Attention::FlashInfer), - _ => Err(ParseError), - } - } -} - /// Hub type #[derive(Clone, Debug, Deserialize)] pub struct HubModelInfo { diff --git a/server/text_generation_server/models/causal_lm.py b/server/text_generation_server/models/causal_lm.py index 1378f59055c..0e5c016359b 100644 --- a/server/text_generation_server/models/causal_lm.py +++ b/server/text_generation_server/models/causal_lm.py @@ -76,7 +76,7 @@ def to_pb(self) -> generate_pb2.CachedBatch: request_ids=[r.id for r in self.requests], size=len(self), max_tokens=self.max_tokens, - current_tokens=len(self), + current_tokens=len(self.input_ids), ) @classmethod diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index b283a5fb43f..65552ff7064 100644 --- a/server/text_generation_server/models/flash_causal_lm.py +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -171,7 +171,7 @@ class FlashCausalLMBatch(Batch): # Will be set by `generate_token` and reset after each prefill forward prefill_cu_outlens: Optional[List[int]] # Will be set by `generate_token` and reset after each prefill forward - prefill_tokens: List[Optional[Tokens]] + 
prefill_logprob_tokens: List[Optional[Tokens]] # Prefixes prefix_ids: List[List[int]] @@ -290,8 +290,7 @@ def from_tokenized( prefix_length <= prompt_length ), f"Prefix {prefix_length} vs input {prompt_length}" if prefix_length == prompt_length: - assert prefix_length > 0 - prefix_length -= 1 + assert False, "unreachable" if prefix_length + postfix_length < prompt_length: # FIXME: speculate is not supported for context chunking at the moment assert speculate == 0 @@ -303,7 +302,9 @@ def from_tokenized( prefix_length : prefix_length + postfix_length ] - postfix_length = len(postfix_ids) + assert ( + len(postfix_ids) == postfix_length + ), "Rust and Python tokenizers are not aligned" postfix_lengths.append(postfix_length) prefix_offsets.append(prompt_length - 5) @@ -394,7 +395,7 @@ def from_tokenized( max_current_length=max_current_length, prefilling=True, prefilling_mask=[True] * len(pb.requests), - prefill_tokens=[None] * len(pb.requests), + prefill_logprob_tokens=[None] * len(pb.requests), postfix_lengths=postfix_lengths, prompt_lengths=prompt_lengths, prefix_offsets=prefix_offsets, @@ -475,7 +476,7 @@ def filter(self, request_ids: List[int]) -> "FlashCausalLMBatch": read_offsets = [] prefilling_mask = [] - prefill_tokens = [] + prefill_logprob_tokens = [] stopping_criterias = [] top_n_tokens = [] @@ -518,7 +519,7 @@ def filter(self, request_ids: List[int]) -> "FlashCausalLMBatch": stopping_criterias.append(stopping_criteria) top_n_tokens.append(self.top_n_tokens[idx]) - prefill_tokens.append(self.prefill_tokens[idx]) + prefill_logprob_tokens.append(self.prefill_logprob_tokens[idx]) ADAPTER_TO_INDEX = get_adapter_to_index() adapter_index = ADAPTER_TO_INDEX.get(self.requests[idx].adapter_id, 0) @@ -611,7 +612,7 @@ def filter(self, request_ids: List[int]) -> "FlashCausalLMBatch": prefill_head_indices=None, prefill_next_token_indices=None, prefill_cu_outlens=None, - prefill_tokens=prefill_tokens, + prefill_logprob_tokens=prefill_logprob_tokens, prompt_lengths=prompt_lengths, prompt_lengths_tensor=prompt_lengths_tensor, postfix_lengths=postfix_lengths, @@ -726,7 +727,7 @@ def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch prefix_offsets = [] read_offsets = [] - prefill_tokens = [] + prefill_logprob_tokens = [] next_token_chooser_parameters = [] fsm_grammar_states = [] @@ -814,7 +815,7 @@ def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch prefix_offsets.extend(batch.prefix_offsets) read_offsets.extend(batch.read_offsets) - prefill_tokens.extend(batch.prefill_tokens) + prefill_logprob_tokens.extend(batch.prefill_logprob_tokens) next_token_chooser_parameters.extend([r.parameters for r in batch.requests]) fsm_grammar_states.extend(batch.next_token_chooser.fsm_grammar_states) @@ -869,7 +870,7 @@ def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch prefill_head_indices=None, prefill_next_token_indices=None, prefill_cu_outlens=None, - prefill_tokens=prefill_tokens, + prefill_logprob_tokens=prefill_logprob_tokens, prompt_lengths=prompt_lengths, prompt_lengths_tensor=prompt_lengths_tensor, postfix_lengths=postfix_lengths, @@ -1769,9 +1770,10 @@ def generate_token( if get_support_chunking(): next_prefilling_mask = [] # Budget in tokens for the next batch - # We remove len(batch) to always have enough space for at least a single decode - # for the remaining requests - batch_budget = get_max_prefill_tokens() - len(batch) + # We remove (len(batch) - 1) to always have enough space for at least a single decode + # for 
the remaining requests -1 because the first request does not need to be removed from the budget + # (ex: you have one request in the batch, you want it to take the full budget not budget -1) + batch_budget = get_max_prefill_tokens() - (len(batch) - 1) # We reverse to prioritize older requests # zip() is not reversible so reverse the underlying lists instead for prefix_length, postfix_length, prompt_length in zip( @@ -1790,6 +1792,7 @@ def generate_token( finished_prefilling = False next_prefilling_mask.append(True) else: + # FIXME: use true number of accepted tokens instead of 1 # Since speculation will be turned off, this is always true next_chunk_length = 1 next_prefilling_mask.append(False) @@ -1807,14 +1810,7 @@ def generate_token( batch.prefilling = not finished_prefilling batch.prefilling_mask = next_prefilling_mask - # Turn off speculative if some requests are still prefilling - # It makes the logic easier to follow - if prefill and not finished_prefilling: - speculate = 0 - speculative_logits = None - else: - speculate = get_speculate() - + speculate = get_speculate() ( next_input_ids, next_token_logprobs, @@ -1914,7 +1910,7 @@ def generate_token( ] = next_input_ids[index] index += 1 - cumulative_length += postfix_length + cumulative_length += postfix_length # Update values # These values can be updated without a GPU -> CPU sync @@ -2045,18 +2041,18 @@ def generate_token( # this state to be stable if request.id % self.world_size == self.rank: # Prefill - if prefill and request.prefill_logprobs: + if request_prefilling and request.prefill_logprobs: out_start_index = batch.prefill_cu_outlens[i] out_end_index = batch.prefill_cu_outlens[i + 1] - request_prefill_tokens = batch.prefill_tokens[i] - request_prefill_logprobs = prefill_logprobs[ out_start_index : out_end_index - 1 ] prefill_token_ids = all_input_ids[:-1] - if request_prefill_tokens is None: + past_prefill_logprob_tokens = batch.prefill_logprob_tokens[i] + + if past_prefill_logprob_tokens is None: # Remove generated token to only have prefill and add nan for first prompt token request_prefill_logprobs = [float("nan")] * ( len(prefix_ids) + 1 @@ -2069,18 +2065,20 @@ def generate_token( skip_special_tokens=False, ) - prefill_tokens = Tokens( + prefill_logprob_tokens = Tokens( prefill_token_ids, request_prefill_logprobs, prefill_texts, is_special=[], ) - if request_prefill_tokens is not None: - prefill_tokens = request_prefill_tokens + prefill_tokens + if past_prefill_logprob_tokens is not None: + prefill_logprob_tokens = ( + past_prefill_logprob_tokens + prefill_logprob_tokens + ) - batch.prefill_tokens[i] = prefill_tokens + batch.prefill_logprob_tokens[i] = prefill_logprob_tokens else: - batch.prefill_tokens[i] = None + batch.prefill_logprob_tokens[i] = None # If it is, the tokens we decoded should be ignored if request_prefilling: @@ -2178,7 +2176,7 @@ def generate_token( generation = Generation( request.id, - batch.prefill_tokens[i], + batch.prefill_logprob_tokens[i], Tokens( _next_token_ids, _next_token_logprobs, diff --git a/server/text_generation_server/models/model.py b/server/text_generation_server/models/model.py index 02f3dbf9329..05d36ba315b 100644 --- a/server/text_generation_server/models/model.py +++ b/server/text_generation_server/models/model.py @@ -7,6 +7,7 @@ from transformers import PreTrainedTokenizerBase from loguru import logger +from text_generation_server.models.globals import ATTENTION, PREFIX_CACHING, BLOCK_SIZE from text_generation_server.models.types import Batch, Generation from 
text_generation_server.utils.log import log_master from text_generation_server.utils.prefill_chunking import set_support_chunking @@ -94,6 +95,9 @@ def info(self) -> InfoResponse: window_size=self.sliding_window, speculate=self.speculate, support_chunking=self.support_chunking, + use_prefix_caching=PREFIX_CACHING, + attention_impl=ATTENTION, + block_size=BLOCK_SIZE, ) @property diff --git a/server/text_generation_server/models/seq2seq_lm.py b/server/text_generation_server/models/seq2seq_lm.py index e2d7aa4d5ba..0a1d0824f55 100644 --- a/server/text_generation_server/models/seq2seq_lm.py +++ b/server/text_generation_server/models/seq2seq_lm.py @@ -80,7 +80,7 @@ def to_pb(self) -> generate_pb2.CachedBatch: request_ids=[r.id for r in self.requests], size=len(self), max_tokens=self.max_tokens, - current_tokens=len(self), + current_tokens=len(self.input_ids), ) @classmethod diff --git a/server/text_generation_server/server.py b/server/text_generation_server/server.py index d89df966ec8..cc7979d4785 100644 --- a/server/text_generation_server/server.py +++ b/server/text_generation_server/server.py @@ -153,6 +153,18 @@ async def Prefill(self, request, context): request.batch, self.model.tokenizer, self.model.dtype, self.model.device ) + concat_ns = None + if self.model.support_chunking: + if request.HasField("cached_batch"): + cached_batch = self.cache.pop(request.cached_batch.id) + if cached_batch is None: + raise ValueError( + f"Batch ID {request.cached_batch.id} not found in cache." + ) + start_concat = time.time_ns() + batch = self.model.batch_type.concatenate([batch, cached_batch]) + concat_ns = time.time_ns() - start_concat + generations, next_batch, timings = self.model.generate_token(batch) self.cache.set(next_batch) @@ -162,6 +174,7 @@ async def Prefill(self, request, context): forward_ns=timings[0], decode_ns=timings[1], total_ns=time.time_ns() - start, + concat_ns=concat_ns, ) async def Decode(self, request, context): @@ -179,16 +192,6 @@ async def Decode(self, request, context): if len(batches) == 0: raise ValueError("All batches are empty") - if self.model.support_chunking: - if request.HasField("batch"): - batch = self.model.batch_type.from_pb( - request.batch, - self.model.tokenizer, - self.model.dtype, - self.model.device, - ) - batches.append(batch) - if len(batches) > 1: start_concat = time.time_ns() batch = self.model.batch_type.concatenate(batches) From 4ddea01c6e4f057caf2233fe1e9e82d57732bb21 Mon Sep 17 00:00:00 2001 From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com> Date: Mon, 7 Oct 2024 12:11:50 +0200 Subject: [PATCH 15/29] remove log --- server/text_generation_server/models/flash_causal_lm.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index 65552ff7064..4c3024f4a37 100644 --- a/server/text_generation_server/models/flash_causal_lm.py +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -1712,8 +1712,6 @@ def generate_token( if prefill: batch.prepare_for_prefill() - log_master(logger.info, f"Tokens in this forward: {len(batch.input_ids)}") - prefill_logprobs = batch.prefill_next_token_indices is not None # Update adapter indices for speculative tokens (if present) From 460e83044493f5efed7112b10fa0bbef3dff02e9 Mon Sep 17 00:00:00 2001 From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com> Date: Mon, 7 Oct 2024 14:45:52 +0200 Subject: [PATCH 16/29] fix benchmarker --- benchmark/src/main.rs | 13 +++++++++++++ 
server/text_generation_server/interceptor.py | 5 ++++- .../models/flash_causal_lm.py | 3 ++- server/text_generation_server/server.py | 10 +++++++--- 4 files changed, 26 insertions(+), 5 deletions(-) diff --git a/benchmark/src/main.rs b/benchmark/src/main.rs index 2ee3d7c551a..0bb5dc0cf33 100644 --- a/benchmark/src/main.rs +++ b/benchmark/src/main.rs @@ -178,6 +178,19 @@ fn main() -> Result<(), Box> { .clear_cache(None) .await .expect("Unable to clear cache"); + + // Warmup shard + let max_batch_size = batch_size.iter().max().unwrap(); + sharded_client + .warmup( + sequence_length, + sequence_length * max_batch_size, + (sequence_length + decode_length) * max_batch_size, + Some(*max_batch_size as usize), + ) + .await + .expect("Unable to warmup"); + tracing::info!("Connected"); // Run app diff --git a/server/text_generation_server/interceptor.py b/server/text_generation_server/interceptor.py index 57df172575a..a5c023e4ecf 100644 --- a/server/text_generation_server/interceptor.py +++ b/server/text_generation_server/interceptor.py @@ -9,6 +9,9 @@ class ExceptionInterceptor(AsyncServerInterceptor): + def __init__(self, shutdown_callback): + self.shutdown_callback = shutdown_callback + async def intercept( self, method: Callable, @@ -25,7 +28,7 @@ async def intercept( # Runtime Error cannot be recovered from if isinstance(err, RuntimeError): - exit(1) + self.shutdown_callback() if torch.cuda.is_available(): torch.cuda.empty_cache() diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index 4c3024f4a37..6e8a0097746 100644 --- a/server/text_generation_server/models/flash_causal_lm.py +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -1383,6 +1383,7 @@ def cuda_graph_warmup(self, bs: int, max_s: int, max_bt: int): def warmup(self, batch: FlashCausalLMBatch): # The warmup batch is the biggest batch we could ever receive + self.kv_cache = [] empty_cache() try: @@ -1402,7 +1403,7 @@ def warmup(self, batch: FlashCausalLMBatch): _, batch, _ = self.generate_token(batch) except torch.cuda.OutOfMemoryError as e: raise RuntimeError( - f"Not enough memory to handle {len(batch.input_ids)} prefill tokens. " + f"Not enough memory to handle {batch.to_pb().current_tokens} prefill tokens. 
" f"You need to decrease `--max-batch-prefill-tokens`" ) from e diff --git a/server/text_generation_server/server.py b/server/text_generation_server/server.py index cc7979d4785..da85d19d891 100644 --- a/server/text_generation_server/server.py +++ b/server/text_generation_server/server.py @@ -47,9 +47,12 @@ def __init__(self): signal.signal(signal.SIGINT, self.exit_gracefully) signal.signal(signal.SIGTERM, self.exit_gracefully) + def set_keep_processing(self, value: bool): + self.KEEP_PROCESSING = value + def exit_gracefully(self, signum, frame): print(f"Exiting gracefully: Signal {signum}") - self.KEEP_PROCESSING = False + self.set_keep_processing(False) class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer): @@ -268,10 +271,12 @@ async def serve_inner( logger.exception("Error when initializing model") raise + signal_handler = SignalHandler() + set_adapter_to_index(adapter_to_index) server = aio.server( interceptors=[ - ExceptionInterceptor(), + ExceptionInterceptor(lambda: signal_handler.set_keep_processing(False)), UDSOpenTelemetryAioServerInterceptor(), ], options=[ @@ -292,7 +297,6 @@ async def serve_inner( await server.start() logger.info("Server started at {}".format(local_url)) - signal_handler = SignalHandler() while signal_handler.KEEP_PROCESSING: await asyncio.sleep(0.5) From 8188deac224ffed2d057d429d0f28d0a3a6f5744 Mon Sep 17 00:00:00 2001 From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com> Date: Mon, 7 Oct 2024 15:08:30 +0200 Subject: [PATCH 17/29] fix vlm and seq2seq --- .../test_grammar_response_format_llama.py | 1 - .../layers/attention/flash_attn_triton.py | 1 - .../models/custom_modeling/mllama.py | 11 +++++++--- .../models/seq2seq_lm.py | 2 +- .../models/vlm_causal_lm.py | 5 ++--- .../text_generation_server/utils/adapter.py | 21 +++++++++++-------- 6 files changed, 23 insertions(+), 18 deletions(-) diff --git a/integration-tests/models/test_grammar_response_format_llama.py b/integration-tests/models/test_grammar_response_format_llama.py index 25bf9d98eff..eb3268cea4f 100644 --- a/integration-tests/models/test_grammar_response_format_llama.py +++ b/integration-tests/models/test_grammar_response_format_llama.py @@ -25,7 +25,6 @@ async def llama_grammar(llama_grammar_handle): @pytest.mark.release @pytest.mark.asyncio async def test_grammar_response_format_llama_json(llama_grammar, response_snapshot): - class Weather(BaseModel): unit: str temperature: List[int] diff --git a/server/text_generation_server/layers/attention/flash_attn_triton.py b/server/text_generation_server/layers/attention/flash_attn_triton.py index 3a6f9a730ce..fd180f0f19a 100644 --- a/server/text_generation_server/layers/attention/flash_attn_triton.py +++ b/server/text_generation_server/layers/attention/flash_attn_triton.py @@ -699,7 +699,6 @@ def check_args( class _attention(torch.autograd.Function): - @staticmethod def forward( ctx, diff --git a/server/text_generation_server/models/custom_modeling/mllama.py b/server/text_generation_server/models/custom_modeling/mllama.py index 6e091a74fe8..be0a4b5d7c3 100644 --- a/server/text_generation_server/models/custom_modeling/mllama.py +++ b/server/text_generation_server/models/custom_modeling/mllama.py @@ -493,9 +493,14 @@ def forward( aspect_ratio_ids: torch.Tensor, attention_mask: torch.Tensor, ) -> torch.Tensor: - batch_size, num_concurrent_media, num_tiles, num_channels, height, width = ( - pixel_values.shape - ) + ( + batch_size, + num_concurrent_media, + num_tiles, + num_channels, + height, + width, + ) = 
pixel_values.shape pixel_values = pixel_values.reshape( batch_size * num_concurrent_media * num_tiles, num_channels, height, width diff --git a/server/text_generation_server/models/seq2seq_lm.py b/server/text_generation_server/models/seq2seq_lm.py index 0a1d0824f55..42cd572a332 100644 --- a/server/text_generation_server/models/seq2seq_lm.py +++ b/server/text_generation_server/models/seq2seq_lm.py @@ -80,7 +80,7 @@ def to_pb(self) -> generate_pb2.CachedBatch: request_ids=[r.id for r in self.requests], size=len(self), max_tokens=self.max_tokens, - current_tokens=len(self.input_ids), + current_tokens=len(self.decoder_input_ids), ) @classmethod diff --git a/server/text_generation_server/models/vlm_causal_lm.py b/server/text_generation_server/models/vlm_causal_lm.py index 1a578d7b1d9..7484e44885c 100644 --- a/server/text_generation_server/models/vlm_causal_lm.py +++ b/server/text_generation_server/models/vlm_causal_lm.py @@ -295,7 +295,7 @@ def forward( block_tables = batch.block_tables_tensor slots = batch.slots[batch.slot_indices] postfix_lengths = batch.postfix_lengths_tensor - max_s = batch.max_seqlen + max_s = batch.max_current_length lm_head_indices = batch.prefill_head_indices speculative_ids = batch.speculative_ids @@ -338,7 +338,7 @@ def forward( slots = batch.slots[batch.slot_indices] postfix_lengths = batch.postfix_lengths_tensor prefix_lengths_tensor = batch.prefix_lengths_tensor - max_s = batch.max_seqlen + max_s = batch.max_current_length lm_head_indices = batch.prefill_head_indices if cu_seqlen_prefill is None and self.max_past() is not None: @@ -347,7 +347,6 @@ def forward( # This makes sure the max_s for the decode pass is correct. max_s = min(self.max_past(), max_s) - bs = input_ids.shape[0] # Try to find an associated cuda graph bs = input_ids.shape[0] sorted_padded_bs = sorted([k for k in self.cuda_graphs.keys() if k >= bs]) diff --git a/server/text_generation_server/utils/adapter.py b/server/text_generation_server/utils/adapter.py index 2b61f9bb448..09254b68a43 100644 --- a/server/text_generation_server/utils/adapter.py +++ b/server/text_generation_server/utils/adapter.py @@ -120,15 +120,18 @@ def _load_and_merge( if adapter.id == BASE_MODEL_ADAPTER_ID: raise ValueError("Base model adapter cannot be merged.") - module_map, adapter_config, adapter_weight_names, adapter_tokenizer = ( - load_module_map( - model_id, - adapter.revision, - adapter.id, - adapter.path, - weight_names, - trust_remote_code, - ) + ( + module_map, + adapter_config, + adapter_weight_names, + adapter_tokenizer, + ) = load_module_map( + model_id, + adapter.revision, + adapter.id, + adapter.path, + weight_names, + trust_remote_code, ) adapters_to_merge.append((module_map, adapter_config)) From 3924b87a046c59a552007152051b59cc42a41ac8 Mon Sep 17 00:00:00 2001 From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com> Date: Mon, 7 Oct 2024 15:14:03 +0200 Subject: [PATCH 18/29] rename to cache and input lengths --- .../layers/attention/common.py | 22 +- .../models/flash_causal_lm.py | 360 +++++++++--------- .../models/mllama_causal_lm.py | 2 +- .../models/vlm_causal_lm.py | 48 +-- 4 files changed, 210 insertions(+), 222 deletions(-) diff --git a/server/text_generation_server/layers/attention/common.py b/server/text_generation_server/layers/attention/common.py index 648b010a75a..8f9d93a1b46 100644 --- a/server/text_generation_server/layers/attention/common.py +++ b/server/text_generation_server/layers/attention/common.py @@ -9,8 +9,8 @@ @dataclass class Seqlen: - postfix_lengths: torch.Tensor - 
prefix_lengths: torch.Tensor + input_lengths: torch.Tensor + cache_lengths: torch.Tensor cu_seqlen_q: Optional[torch.Tensor] cu_seqlen_k: Optional[torch.Tensor] max_q: int @@ -18,16 +18,16 @@ class Seqlen: def __init__( self, - postfix_lengths, - prefix_lengths, + input_lengths, + cache_lengths, cu_seqlen_q=None, max_q=None, max_k=None, ): - self.postfix_lengths = postfix_lengths - self.prefix_lengths = prefix_lengths - device = self.postfix_lengths.device - shape = self.postfix_lengths.shape + self.input_lengths = input_lengths + self.cache_lengths = cache_lengths + device = self.input_lengths.device + shape = self.input_lengths.shape if cu_seqlen_q is None: cu_seqlen_q = torch.arange( shape[0] + 1, @@ -43,7 +43,7 @@ def __init__( # cuda graphs don't like this and this is necessary to clamp within mistral # Although FA2 might not want the clamping # cu_seqlen_k[0] = 0 - total = self.postfix_lengths + self.prefix_lengths + total = self.input_lengths + self.cache_lengths torch.cumsum(total, -1, out=cu_seqlen_k[1:]) self.cu_seqlen_q = cu_seqlen_q @@ -59,8 +59,8 @@ def clamp(self, max): @dataclass class Seqlen: - postfix_lengths: torch.Tensor - prefix_lengths: torch.Tensor + input_lengths: torch.Tensor + cache_lengths: torch.Tensor cu_seqlen_q: torch.Tensor max_q: int max_k: int diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index 6e8a0097746..9a34dfc55fc 100644 --- a/server/text_generation_server/models/flash_causal_lm.py +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -150,7 +150,7 @@ class FlashCausalLMBatch(Batch): # Will be set by `generate_token` and reset after each prefill forward before staying set in decode slots: Optional[torch.Tensor] - max_postfix_length: int + max_input_length: int max_current_length: int # Whether this batch contains at least one request that is prefilling @@ -181,13 +181,13 @@ class FlashCausalLMBatch(Batch): all_input_ids_tensor: torch.Tensor # Lengths of all generations present in the batch - postfix_lengths: List[int] + input_lengths: List[int] # size [b], containing the number of blocks that can be retrieved from the cache - prefix_lengths: List[int] + cache_lengths: List[int] prompt_lengths: List[int] # Will be set by `generate_token` and reset after each prefill forward before staying set in decode - postfix_lengths_tensor: Optional[torch.Tensor] - prefix_lengths_tensor: Optional[torch.Tensor] + input_lengths_tensor: Optional[torch.Tensor] + cache_lengths_tensor: Optional[torch.Tensor] prompt_lengths_tensor: torch.Tensor prefix_offsets: List[Optional[int]] @@ -252,8 +252,8 @@ def from_tokenized( ) -> "FlashCausalLMBatch": speculate = get_speculate() - prefix_lengths = [] - postfix_lengths = [] + cache_lengths = [] + input_lengths = [] prompt_lengths = [] prefix_offsets = [] read_offsets = [] @@ -267,7 +267,7 @@ def from_tokenized( top_n_tokens = [] num_blocks = 0 - max_postfix_length = 0 + max_input_length = 0 max_current_length = 0 max_length = 0 max_blocks = 0 @@ -284,28 +284,26 @@ def from_tokenized( prompt_length = len(tokenized_input) prompt_lengths.append(prompt_length) - prefix_length = r.prefix_len - postfix_length = r.postfix_len + cache_length = r.prefix_len + input_length = r.postfix_len assert ( - prefix_length <= prompt_length - ), f"Prefix {prefix_length} vs input {prompt_length}" - if prefix_length == prompt_length: + cache_length <= prompt_length + ), f"Prefix {cache_length} vs input {prompt_length}" + if cache_length == prompt_length: assert 
False, "unreachable" - if prefix_length + postfix_length < prompt_length: + if cache_length + input_length < prompt_length: # FIXME: speculate is not supported for context chunking at the moment assert speculate == 0 assert get_support_chunking() - assert postfix_length > 0 + assert input_length > 0 - prefix_ids.append(tokenized_input[:prefix_length]) - postfix_ids = tokenized_input[ - prefix_length : prefix_length + postfix_length - ] + prefix_ids.append(tokenized_input[:cache_length]) + postfix_ids = tokenized_input[cache_length : cache_length + input_length] assert ( - len(postfix_ids) == postfix_length + len(postfix_ids) == input_length ), "Rust and Python tokenizers are not aligned" - postfix_lengths.append(postfix_length) + input_lengths.append(input_length) prefix_offsets.append(prompt_length - 5) read_offsets.append(prompt_length) @@ -341,13 +339,13 @@ def from_tokenized( block_tables.append(request_blocks) - prefix_lengths.append(prefix_length) + cache_lengths.append(cache_length) num_blocks += len(request_blocks) # Update max_blocks = max(max_blocks, len(request_blocks)) - max_postfix_length = max(max_postfix_length, postfix_length) - max_current_length = max(max_current_length, prefix_length + postfix_length) + max_input_length = max(max_input_length, input_length) + max_current_length = max(max_current_length, cache_length + input_length) max_length = max( max_length, prompt_length + max_new_tokens + speculative_length, @@ -390,13 +388,13 @@ def from_tokenized( input_ids=all_postfix_ids, block_tables=block_tables, block_tables_tensor=block_tables_tensor, - prefix_lengths=prefix_lengths, - max_postfix_length=max_postfix_length, + cache_lengths=cache_lengths, + max_input_length=max_input_length, max_current_length=max_current_length, prefilling=True, prefilling_mask=[True] * len(pb.requests), prefill_logprob_tokens=[None] * len(pb.requests), - postfix_lengths=postfix_lengths, + input_lengths=input_lengths, prompt_lengths=prompt_lengths, prefix_offsets=prefix_offsets, read_offsets=read_offsets, @@ -420,8 +418,8 @@ def from_tokenized( prefill_head_indices=None, prefill_next_token_indices=None, prefill_cu_outlens=None, - prefix_lengths_tensor=None, - postfix_lengths_tensor=None, + cache_lengths_tensor=None, + input_lengths_tensor=None, adapter_meta=None, ) @@ -460,7 +458,7 @@ def filter(self, request_ids: List[int]) -> "FlashCausalLMBatch": # Create on CPU to only move to GPU once instead of at every copy slot_indices = torch.empty(len(request_ids), dtype=torch.int64) - max_postfix_length = 0 + max_input_length = 0 max_current_length = 0 requests = [] @@ -470,8 +468,8 @@ def filter(self, request_ids: List[int]) -> "FlashCausalLMBatch": input_ids = [] prompt_lengths = [] - postfix_lengths = [] - prefix_lengths = [] + input_lengths = [] + cache_lengths = [] prefix_offsets = [] read_offsets = [] @@ -499,19 +497,19 @@ def filter(self, request_ids: List[int]) -> "FlashCausalLMBatch": prefilling_mask.append(request_prefilling) # Get length - request_postfix_length = self.postfix_lengths[idx] - request_prefix_length = self.prefix_lengths[idx] - max_postfix_length = max(max_postfix_length, request_postfix_length) + request_input_length = self.input_lengths[idx] + request_cache_length = self.cache_lengths[idx] + max_input_length = max(max_input_length, request_input_length) max_current_length = max( - max_current_length, request_prefix_length + request_postfix_length + max_current_length, request_cache_length + request_input_length ) all_input_ids.append(self.all_input_ids[idx]) 
prefix_ids.append(self.prefix_ids[idx]) prompt_lengths.append(self.prompt_lengths[idx]) - postfix_lengths.append(request_postfix_length) - prefix_lengths.append(request_prefix_length) + input_lengths.append(request_input_length) + cache_lengths.append(request_cache_length) prefix_offsets.append(self.prefix_offsets[idx]) read_offsets.append(self.read_offsets[idx]) @@ -544,12 +542,12 @@ def filter(self, request_ids: List[int]) -> "FlashCausalLMBatch": # Set slice slot_filtering_indices[ self.slot_indices[idx] : self.slot_indices[idx] - + request_postfix_length + + request_input_length + remaining_tokens - 1 ] = True - cumulative_max_length += request_postfix_length + remaining_tokens - 1 + cumulative_max_length += request_input_length + remaining_tokens - 1 max_blocks = max(max_blocks, len(request_block_table)) @@ -567,17 +565,17 @@ def filter(self, request_ids: List[int]) -> "FlashCausalLMBatch": position_ids = None slot_indices = None slots = None - prefix_lengths_tensor = None - postfix_lengths_tensor = None + cache_lengths_tensor = None + input_lengths_tensor = None adapter_meta = None else: # Index into tensors input_ids = self.input_ids[indices] position_ids = self.position_ids[indices] adapter_indices = self.adapter_meta.adapter_indices[indices] - postfix_lengths_tensor = self.postfix_lengths_tensor[indices] + input_lengths_tensor = self.input_lengths_tensor[indices] slots = self.slots[slot_filtering_indices] - prefix_lengths_tensor = self.prefix_lengths_tensor[indices] + cache_lengths_tensor = self.cache_lengths_tensor[indices] # Move to GPU now that we have the whole tensor slot_indices = slot_indices.to(device) @@ -605,7 +603,7 @@ def filter(self, request_ids: List[int]) -> "FlashCausalLMBatch": block_tables=block_tables, block_tables_tensor=block_tables_tensor, slots=slots, - max_postfix_length=max_postfix_length, + max_input_length=max_input_length, max_current_length=max_current_length, prefilling=self.prefilling, prefilling_mask=prefilling_mask, @@ -615,10 +613,10 @@ def filter(self, request_ids: List[int]) -> "FlashCausalLMBatch": prefill_logprob_tokens=prefill_logprob_tokens, prompt_lengths=prompt_lengths, prompt_lengths_tensor=prompt_lengths_tensor, - postfix_lengths=postfix_lengths, - postfix_lengths_tensor=postfix_lengths_tensor, - prefix_lengths=prefix_lengths, - prefix_lengths_tensor=prefix_lengths_tensor, + input_lengths=input_lengths, + input_lengths_tensor=input_lengths_tensor, + cache_lengths=cache_lengths, + cache_lengths_tensor=cache_lengths_tensor, prefix_offsets=prefix_offsets, read_offsets=read_offsets, all_input_ids=all_input_ids, @@ -647,7 +645,7 @@ def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch total_slots = 0 max_blocks = 0 max_length = 0 - max_postfix_length = 0 + max_input_length = 0 max_current_length = 0 for b in batches: total_batch_size += len(b) @@ -659,7 +657,7 @@ def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch speculative_length = ( b.speculative_ids.shape[1] if b.speculative_ids is not None else 0 ) - max_postfix_length = max(max_postfix_length, b.max_postfix_length) + max_input_length = max(max_input_length, b.max_input_length) max_current_length = max(max_current_length, b.max_current_length) max_length = max( max_length, @@ -680,8 +678,8 @@ def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch position_ids = None slots = None slot_indices = None - prefix_lengths_tensor = None - postfix_lengths_tensor = None + cache_lengths_tensor = None + 
input_lengths_tensor = None adapter_meta = None adapter_segment_builder = None else: @@ -689,10 +687,10 @@ def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch position_ids = batches[0].position_ids.new_empty(total_batch_size) slots = batches[0].slots.new_empty(total_slots) slot_indices = batches[0].slot_indices.new_empty(total_batch_size) - postfix_lengths_tensor = batches[0].postfix_lengths_tensor.new_empty( + input_lengths_tensor = batches[0].input_lengths_tensor.new_empty( total_batch_size ) - prefix_lengths_tensor = batches[0].prefix_lengths_tensor.new_empty( + cache_lengths_tensor = batches[0].cache_lengths_tensor.new_empty( total_batch_size ) total_indices_size = sum( @@ -718,12 +716,12 @@ def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch ) block_tables = [] - prefix_lengths = [] + cache_lengths = [] all_input_ids = [] prefix_ids = [] prompt_lengths = [] - postfix_lengths = [] + input_lengths = [] prefix_offsets = [] read_offsets = [] @@ -773,9 +771,7 @@ def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch slot_indices[start_index:end_index] = ( batch.slot_indices + cumulative_slots ) - postfix_lengths_tensor[start_index:end_index] = ( - batch.postfix_lengths_tensor - ) + input_lengths_tensor[start_index:end_index] = batch.input_lengths_tensor slots[slots_start_index:slots_end_index] = batch.slots # Copy over adapter indices @@ -793,9 +789,7 @@ def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch batch.adapter_meta.adapter_segments, batch.adapter_meta.segment_indices, ) - prefix_lengths_tensor[start_index:end_index] = ( - batch.prefix_lengths_tensor - ) + cache_lengths_tensor[start_index:end_index] = batch.cache_lengths_tensor # Update cumulative_slots += len(batch.slots) @@ -806,12 +800,12 @@ def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch prefilling_mask.extend(batch.prefilling_mask) block_tables.extend(batch.block_tables) - prefix_lengths.extend(batch.prefix_lengths) + cache_lengths.extend(batch.cache_lengths) all_input_ids.extend(batch.all_input_ids) prefix_ids.extend(batch.prefix_ids) prompt_lengths.extend(batch.prompt_lengths) - postfix_lengths.extend(batch.postfix_lengths) + input_lengths.extend(batch.input_lengths) prefix_offsets.extend(batch.prefix_offsets) read_offsets.extend(batch.read_offsets) @@ -860,10 +854,10 @@ def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch slot_indices=slot_indices, block_tables=block_tables, block_tables_tensor=block_tables_tensor, - prefix_lengths=prefix_lengths, - prefix_lengths_tensor=prefix_lengths_tensor, + cache_lengths=cache_lengths, + cache_lengths_tensor=cache_lengths_tensor, slots=slots, - max_postfix_length=max_postfix_length, + max_input_length=max_input_length, max_current_length=max_current_length, prefilling=prefilling, prefilling_mask=prefilling_mask, @@ -873,8 +867,8 @@ def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch prefill_logprob_tokens=prefill_logprob_tokens, prompt_lengths=prompt_lengths, prompt_lengths_tensor=prompt_lengths_tensor, - postfix_lengths=postfix_lengths, - postfix_lengths_tensor=postfix_lengths_tensor, + input_lengths=input_lengths, + input_lengths_tensor=input_lengths_tensor, prefix_offsets=prefix_offsets, read_offsets=read_offsets, all_input_ids=all_input_ids, @@ -918,30 +912,30 @@ def prepare_for_prefill(self): for i, ( r, - prefix_length, - postfix_length, + cache_length, + input_length, 
prompt_length, request_prefilling, blocks, ) in enumerate( zip( self.requests, - self.prefix_lengths, - self.postfix_lengths, + self.cache_lengths, + self.input_lengths, self.prompt_lengths, self.prefilling_mask, self.block_tables, ) ): - next_chunk_length = postfix_length + next_chunk_length = input_length # Position ids request_position_ids = torch.arange( - prefix_length, prefix_length + postfix_length, dtype=torch.int32 + cache_length, cache_length + input_length, dtype=torch.int32 ) position_ids.append(request_position_ids) # Add cumulative lengths of all previous inputs - cu_seqlen_prefill.append(cumulative_length + postfix_length) + cu_seqlen_prefill.append(cumulative_length + input_length) if not r.slots: request_slots = [ @@ -952,18 +946,18 @@ def prepare_for_prefill(self): else: request_slots = r.slots - request_slots = request_slots[prefix_length:] + request_slots = request_slots[cache_length:] request_slot_indices = torch.arange( cumulative_slot_tokens, - cumulative_slot_tokens + postfix_length, + cumulative_slot_tokens + input_length, dtype=torch.int64, ) # Create tensor to slice into the kv tensor in prefill if sliding_window is not None: request_prefill_cache_indices = torch.arange( - cumulative_length + max(0, postfix_length - sliding_window), - cumulative_length + postfix_length, + cumulative_length + max(0, input_length - sliding_window), + cumulative_length + input_length, dtype=torch.int64, ) @@ -976,16 +970,14 @@ def prepare_for_prefill(self): if prefill_logprobs: prefill_head_indices.append(request_position_ids + cumulative_length) prefill_next_token_indices.append( - prefill_out_cumulative_length + postfix_length - 1 - ) - prefill_cu_outlens.append( - prefill_out_cumulative_length + postfix_length + prefill_out_cumulative_length + input_length - 1 ) - prefill_out_cumulative_length += postfix_length + prefill_cu_outlens.append(prefill_out_cumulative_length + input_length) + prefill_out_cumulative_length += input_length else: prefill_head_indices.append( torch.tensor( - [cumulative_length + postfix_length - 1], + [cumulative_length + input_length - 1], dtype=torch.int32, ) ) @@ -1038,8 +1030,8 @@ def prepare_for_prefill(self): self.prefill_cache_indices = ( prefill_cache_indices.to(device) if sliding_window is not None else None ) - self.postfix_lengths_tensor = torch.tensor( - self.postfix_lengths, dtype=torch.int32, device=device + self.input_lengths_tensor = torch.tensor( + self.input_lengths, dtype=torch.int32, device=device ) if all_prefill_logprobs: @@ -1059,8 +1051,8 @@ def prepare_for_prefill(self): self.prefill_head_indices = prefill_head_indices self.prefill_next_token_indices = prefill_next_token_indices self.slots = torch.tensor(slots, dtype=torch.int64, device=device) - self.prefix_lengths_tensor = torch.tensor( - self.prefix_lengths, dtype=torch.int32, device=device + self.cache_lengths_tensor = torch.tensor( + self.cache_lengths, dtype=torch.int32, device=device ) adapter_indices = torch.cat(adapter_indices_list).to( dtype=torch.int64, device=device @@ -1276,12 +1268,12 @@ def cuda_graph_warmup(self, bs: int, max_s: int, max_bt: int): input_ids = torch.zeros(bs, dtype=torch.int64, device=self.device) position_ids = torch.zeros(bs, dtype=torch.int32, device=self.device) slots = torch.arange(bs, dtype=torch.int64, device=self.device) - postfix_lengths = [max_s] * bs - prefix_lengths = [0] * bs - postfix_lengths_tensor = ( + input_lengths = [max_s] * bs + cache_lengths = [0] * bs + input_lengths_tensor = ( torch.ones(bs, dtype=torch.int32, 
device=self.device) * max_s ) - prefix_lengths_tensor = torch.zeros(bs, dtype=torch.int32, device=self.device) + cache_lengths_tensor = torch.zeros(bs, dtype=torch.int32, device=self.device) block_tables = torch.arange( max_bt, dtype=torch.int32, device=self.device ).repeat(bs) @@ -1290,8 +1282,8 @@ def cuda_graph_warmup(self, bs: int, max_s: int, max_bt: int): if ATTENTION == "flashinfer": block_tables = block_tables_to_ragged( block_tables=block_tables, - postfix_lengths=postfix_lengths, - prefix_lengths=prefix_lengths, + input_lengths=input_lengths, + cache_lengths=cache_lengths, ) from text_generation_server.layers.attention.flashinfer import ( create_decode_state_cuda_graphs, @@ -1319,8 +1311,8 @@ def cuda_graph_warmup(self, bs: int, max_s: int, max_bt: int): "kv_cache": self.kv_cache, "block_tables": block_tables, "slots": slots, - "postfix_lengths": postfix_lengths_tensor, - "prefix_lengths": prefix_lengths_tensor, + "input_lengths": input_lengths_tensor, + "cache_lengths": cache_lengths_tensor, "state": state, "graph": graph, } @@ -1330,13 +1322,13 @@ def cuda_graph_warmup(self, bs: int, max_s: int, max_bt: int): with self._forward_context( block_tables=block_tables, cu_seqlen_prefill=None, - postfix_lengths_tensor=postfix_lengths_tensor, + input_lengths_tensor=input_lengths_tensor, state=state, - prefix_lengths_tensor=prefix_lengths_tensor, + cache_lengths_tensor=cache_lengths_tensor, ): seqlen = Seqlen( - postfix_lengths=postfix_lengths_tensor, - prefix_lengths=prefix_lengths_tensor, + input_lengths=input_lengths_tensor, + cache_lengths=cache_lengths_tensor, cu_seqlen_q=None, max_q=1, max_k=max_s, @@ -1359,8 +1351,8 @@ def cuda_graph_warmup(self, bs: int, max_s: int, max_bt: int): with torch.cuda.graph(graph, pool=MEM_POOL): seqlen = Seqlen( - postfix_lengths=postfix_lengths_tensor, - prefix_lengths=prefix_lengths_tensor, + input_lengths=input_lengths_tensor, + cache_lengths=cache_lengths_tensor, cu_seqlen_q=None, max_q=1, max_k=max_s, @@ -1517,8 +1509,8 @@ def tunableop_warmup(self, seqlen: int): slots = torch.arange(seqlen, dtype=torch.int64, device=self.device) # Dummy value, some models (starcoder2) don't accept `None`. 
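# [Editor's sketch, not part of the original patch] The Seqlen container renamed in this
# series derives its cumulative key lengths from input_lengths + cache_lengths. Below is
# a standalone version of that computation, assuming one int32 entry per request; the
# helper name is illustrative.
import torch

def cu_seqlen_k(input_lengths: torch.Tensor, cache_lengths: torch.Tensor) -> torch.Tensor:
    # Request i attends to its cached tokens plus the tokens of the current forward,
    # so cu[i]:cu[i + 1] indexes request i in the ragged KV layout.
    total = (input_lengths + cache_lengths).to(torch.int64)
    cu = torch.zeros(total.shape[0] + 1, dtype=torch.int64, device=total.device)
    cu[1:] = torch.cumsum(total, dim=0)
    return cu

# Example: a decoding request (cache 7, input 1) and a chunk-prefilling request
# (cache 0, input 4) give tensor([0, 8, 12]):
# cu_seqlen_k(torch.tensor([1, 4]), torch.tensor([7, 0]))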
- postfix_lengths = torch.ones(seqlen, dtype=torch.int32, device=self.device) - prefix_lengths_tensor = torch.zeros( + input_lengths = torch.ones(seqlen, dtype=torch.int32, device=self.device) + cache_lengths_tensor = torch.zeros( seqlen, dtype=torch.int32, device=self.device ) cu_seqlen_prefill = torch.tensor( @@ -1526,8 +1518,8 @@ def tunableop_warmup(self, seqlen: int): ) max_s = seqlen seqlen = Seqlen( - postfix_lengths=postfix_lengths, - prefix_lengths=prefix_lengths_tensor, + input_lengths=input_lengths, + cache_lengths=cache_lengths_tensor, cu_seqlen_q=cu_seqlen_prefill, max_q=1, max_k=seqlen, @@ -1558,7 +1550,7 @@ def forward( kv_cache = self.kv_cache block_tables = batch.block_tables_tensor slots = batch.slots[batch.slot_indices] - postfix_lengths = batch.postfix_lengths_tensor + input_lengths = batch.input_lengths_tensor max_s = batch.max_current_length lm_head_indices = batch.prefill_head_indices @@ -1575,11 +1567,11 @@ def forward( position_ids.unsqueeze(-1).expand(B, new_length) + arange ).view(-1) slots = (slots.unsqueeze(-1).expand(B, new_length) + arange_int).view(-1) - postfix_lengths = ( - postfix_lengths.unsqueeze(-1).expand(B, new_length) + arange_int + input_lengths = ( + input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int ).view(-1) - prefix_lengths_tensor = ( - batch.prefix_lengths_tensor.unsqueeze(-1).expand(B, new_length) + cache_lengths_tensor = ( + batch.cache_lengths_tensor.unsqueeze(-1).expand(B, new_length) ).reshape(-1) # Add Copy the block tables for all members @@ -1600,8 +1592,8 @@ def forward( kv_cache = self.kv_cache block_tables = batch.block_tables_tensor slots = batch.slots[batch.slot_indices] - postfix_lengths = batch.postfix_lengths_tensor - prefix_lengths_tensor = batch.prefix_lengths_tensor + input_lengths = batch.input_lengths_tensor + cache_lengths_tensor = batch.cache_lengths_tensor max_s = batch.max_current_length lm_head_indices = batch.prefill_head_indices @@ -1623,19 +1615,19 @@ def forward( if ATTENTION == "flashinfer": block_tables = block_tables_to_ragged( block_tables=block_tables, - postfix_lengths=batch.postfix_lengths, - prefix_lengths=batch.prefix_lengths, + input_lengths=batch.input_lengths, + cache_lengths=batch.cache_lengths, ) with self._forward_context( block_tables=block_tables, cu_seqlen_prefill=cu_seqlen_prefill, - postfix_lengths_tensor=postfix_lengths, - prefix_lengths_tensor=prefix_lengths_tensor, + input_lengths_tensor=input_lengths, + cache_lengths_tensor=cache_lengths_tensor, ): - max_k = (postfix_lengths + prefix_lengths_tensor).max().item() + max_k = (input_lengths + cache_lengths_tensor).max().item() seqlen = Seqlen( - postfix_lengths=postfix_lengths, - prefix_lengths=prefix_lengths_tensor, + input_lengths=input_lengths, + cache_lengths=cache_lengths_tensor, cu_seqlen_q=cu_seqlen_prefill, max_q=max_s, max_k=max_k, @@ -1664,8 +1656,8 @@ def forward( if ATTENTION == "flashinfer": block_tables = block_tables_to_ragged( block_tables=block_tables, - postfix_lengths=batch.postfix_lengths, - prefix_lengths=batch.prefix_lengths, + input_lengths=batch.input_lengths, + cache_lengths=batch.cache_lengths, ) # assert block_tables.shape[0] >= slots.shape[0] cuda_graph["block_tables"][: block_tables.shape[0]] = block_tables @@ -1678,18 +1670,18 @@ def forward( # so it doesn't matter if we override it with bogus values. 
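# [Editor's sketch, not part of the original patch] The CUDA-graph replay path in this
# hunk copies the live batch into fixed-size captured buffers and pads the tail with
# harmless values. A minimal standalone illustration; the helper name is made up.
import torch

def fill_graph_buffer(buffer: torch.Tensor, values: torch.Tensor, pad_value: int = 0) -> None:
    # Graphs replay with the tensor sizes they were captured with, so the unused
    # tail must hold padding that cannot corrupt the KV cache or attention inputs.
    buffer.fill_(pad_value)
    buffer[: values.shape[0]] = values

# Example with CPU tensors:
# buf = torch.empty(8, dtype=torch.int32)
# fill_graph_buffer(buf, torch.tensor([3, 5, 2], dtype=torch.int32))
# buf is now tensor([3, 5, 2, 0, 0, 0, 0, 0], dtype=torch.int32)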
cuda_graph["slots"].fill_(0) cuda_graph["slots"][: slots.shape[0]] = slots - cuda_graph["postfix_lengths"].zero_() - cuda_graph["postfix_lengths"][: postfix_lengths.shape[0]] = postfix_lengths - cuda_graph["prefix_lengths"].zero_() - cuda_graph["prefix_lengths"][ - : prefix_lengths_tensor.shape[0] - ] = prefix_lengths_tensor + cuda_graph["input_lengths"].zero_() + cuda_graph["input_lengths"][: input_lengths.shape[0]] = input_lengths + cuda_graph["cache_lengths"].zero_() + cuda_graph["cache_lengths"][ + : cache_lengths_tensor.shape[0] + ] = cache_lengths_tensor with self._forward_context( block_tables=cuda_graph["block_tables"], cu_seqlen_prefill=None, - postfix_lengths_tensor=cuda_graph["postfix_lengths"], - prefix_lengths_tensor=cuda_graph["prefix_lengths"], + input_lengths_tensor=cuda_graph["input_lengths"], + cache_lengths_tensor=cuda_graph["cache_lengths"], state=cuda_graph["state"], ): # Replay the graph @@ -1775,13 +1767,13 @@ def generate_token( batch_budget = get_max_prefill_tokens() - (len(batch) - 1) # We reverse to prioritize older requests # zip() is not reversible so reverse the underlying lists instead - for prefix_length, postfix_length, prompt_length in zip( - reversed(batch.prefix_lengths), - reversed(batch.postfix_lengths), + for cache_length, input_length, prompt_length in zip( + reversed(batch.cache_lengths), + reversed(batch.input_lengths), reversed(batch.prompt_lengths), ): remaining_prefill_tokens = max( - prompt_length - prefix_length - postfix_length, 0 + prompt_length - cache_length - input_length, 0 ) if remaining_prefill_tokens > 0: next_chunk_length = max( @@ -1842,8 +1834,8 @@ def generate_token( # Zipped iterator iterator = zip( batch.prompt_lengths, - batch.prefix_lengths, - batch.postfix_lengths, + batch.cache_lengths, + batch.input_lengths, batch.all_input_ids, accepted_ids, ) @@ -1858,14 +1850,14 @@ def generate_token( cumulative_length = 0 for i, ( prompt_length, - prefix_length, - postfix_length, + cache_length, + input_length, all_input_ids, n_accepted_ids, ) in enumerate(iterator): # Indexing metadata start_index = cumulative_length - end_index = cumulative_length + postfix_length + end_index = cumulative_length + input_length if prefill: # Indexing metadata @@ -1899,17 +1891,17 @@ def generate_token( # Represent whether this request is still prefilling # If it is, the tokens we decoded should be ignored - accept_tokens = prefix_length + postfix_length >= prompt_length + accept_tokens = cache_length + input_length >= prompt_length if accept_tokens: # Only save tokens if we are done prefilling for this request for j in range(n_accepted_ids): - batch.all_input_ids_tensor[ - i, prefix_length + postfix_length + j - ] = next_input_ids[index] + batch.all_input_ids_tensor[i, cache_length + input_length + j] = ( + next_input_ids[index] + ) index += 1 - cumulative_length += postfix_length + cumulative_length += input_length # Update values # These values can be updated without a GPU -> CPU sync @@ -1917,8 +1909,8 @@ def generate_token( batch.input_ids = next_input_ids[accepted_ids.cumsum(dim=-1) - 1] batch.speculative_ids = speculative_ids batch.position_ids = next_position_ids + accepted_ids - batch.prefix_lengths_tensor += batch.postfix_lengths_tensor - batch.postfix_lengths_tensor = accepted_ids + batch.cache_lengths_tensor += batch.input_lengths_tensor + batch.input_lengths_tensor = accepted_ids batch.slot_indices += accepted_ids batch.adapter_meta.adapter_indices = next_adapter_indices @@ -1959,24 +1951,24 @@ def generate_token( request_prefilling, 
next_token_id, all_input_ids, - prefix_length, - postfix_length, + cache_length, + input_length, next_chunk_length, ) in enumerate( zip( batch.prefilling_mask, next_token_ids, batch.all_input_ids, - batch.prefix_lengths, - batch.postfix_lengths, + batch.cache_lengths, + batch.input_lengths, next_chunk_lengths, ) ): if request_prefilling: - next_prefix_length = prefix_length + postfix_length + next_cache_length = cache_length + input_length # Get new prompt IDs to prefill postfix_ids = all_input_ids[ - next_prefix_length : next_prefix_length + next_chunk_length + next_cache_length : next_cache_length + next_chunk_length ] else: # This request is done prefilling, the new id is the one selected the sampling method @@ -1996,8 +1988,8 @@ def generate_token( iterator = zip( batch.requests, batch.prompt_lengths, - batch.prefix_lengths, - batch.postfix_lengths, + batch.cache_lengths, + batch.input_lengths, batch.prefix_offsets, batch.read_offsets, batch.stopping_criterias, @@ -2012,15 +2004,15 @@ def generate_token( batch_top_token_logprobs, ) - # Reset max_postfix_length - batch.max_postfix_length = 0 + # Reset max_input_length + batch.max_input_length = 0 # For each member of the batch index = 0 for i, ( request, prompt_length, - prefix_length, - postfix_length, + cache_length, + input_length, prefix_offset, read_offset, stopping_criteria, @@ -2084,9 +2076,9 @@ def generate_token( # Make sure that we do not stop as even though this request did not create a token, it is still # processing stopped = False - new_postfix_length = next_chunk_lengths[i] + new_input_length = next_chunk_lengths[i] else: - new_postfix_length = n_accepted_ids + new_input_length = n_accepted_ids # Append next token to all tokens next_token_texts = [] left = 0 @@ -2198,14 +2190,12 @@ def generate_token( ) # Update values - current_prefix_length = prefix_length + postfix_length - batch.prefix_lengths[i] = current_prefix_length - current_postfix_length = new_postfix_length - batch.max_postfix_length = max( - batch.max_postfix_length, current_postfix_length - ) - batch.postfix_lengths[i] = current_postfix_length - current_length = current_prefix_length + current_postfix_length + current_cache_length = cache_length + input_length + batch.cache_lengths[i] = current_cache_length + current_input_length = new_input_length + batch.max_input_length = max(batch.max_input_length, current_input_length) + batch.input_lengths[i] = current_input_length + current_length = current_cache_length + current_input_length batch.max_current_length = max(batch.max_current_length, current_length) batch.prefix_offsets[i] = prefix_offset @@ -2235,8 +2225,8 @@ def _forward_context( *, block_tables: torch.Tensor, cu_seqlen_prefill: Optional[torch.Tensor], - postfix_lengths_tensor: torch.Tensor, - prefix_lengths_tensor: torch.Tensor, + input_lengths_tensor: torch.Tensor, + cache_lengths_tensor: torch.Tensor, state: Optional[Any] = None, ) -> ContextManager: if ATTENTION != "flashinfer": @@ -2247,7 +2237,7 @@ def _forward_context( use_prefill_with_paged_kv_state, ) - # has_prefix_lengths = any(prefix_length > 0 for prefix_length in prefix_lengths) + # has_cache_lengths = any(cache_length > 0 for cache_length in cache_lengths) if cu_seqlen_prefill is not None: return use_prefill_with_paged_kv_state( @@ -2256,12 +2246,12 @@ def _forward_context( ), # block_tables=block_tables_to_ragged( # block_tables=block_tables, - # postfix_lengths=postfix_lengths, - # prefix_lengths=prefix_lengths, + # input_lengths=input_lengths, + # cache_lengths=cache_lengths, # ), 
block_tables=block_tables, cu_seqlens=cu_seqlen_prefill, - input_lengths=postfix_lengths_tensor + prefix_lengths_tensor, + input_lengths=input_lengths_tensor + cache_lengths_tensor, num_heads=self.num_heads, num_kv_heads=self.num_kv_heads, head_size=self.head_size, @@ -2270,10 +2260,10 @@ def _forward_context( window_left=self.sliding_window, ) else: - assert postfix_lengths_tensor is not None + assert input_lengths_tensor is not None return use_decode_state( state=state if state is not None else self.decode_state, - input_lengths=postfix_lengths_tensor + prefix_lengths_tensor, + input_lengths=input_lengths_tensor + cache_lengths_tensor, block_tables=block_tables, num_heads=self.num_heads, num_kv_heads=self.num_kv_heads, @@ -2285,21 +2275,19 @@ def _forward_context( def block_tables_to_ragged( - *, block_tables: torch.Tensor, postfix_lengths: List[int], prefix_lengths: List[int] + *, block_tables: torch.Tensor, input_lengths: List[int], cache_lengths: List[int] ) -> torch.Tensor: """Convert block table to ragged format compatible with FlashInfer.""" - assert len(postfix_lengths) == len(prefix_lengths) + assert len(input_lengths) == len(cache_lengths) - total_len = sum(postfix_lengths) + sum(prefix_lengths) + total_len = sum(input_lengths) + sum(cache_lengths) block_tables_ragged = torch.empty( total_len, dtype=torch.int32, device=block_tables.device ) offset = 0 - for i, (input_length, prefix_length) in enumerate( - zip(postfix_lengths, prefix_lengths) - ): - seq_len = prefix_length + input_length + for i, (input_length, cache_length) in enumerate(zip(input_lengths, cache_lengths)): + seq_len = cache_length + input_length block_tables_ragged[offset : offset + seq_len] = block_tables[i][:seq_len] offset += seq_len diff --git a/server/text_generation_server/models/mllama_causal_lm.py b/server/text_generation_server/models/mllama_causal_lm.py index 9e19e171503..3aa475c3179 100644 --- a/server/text_generation_server/models/mllama_causal_lm.py +++ b/server/text_generation_server/models/mllama_causal_lm.py @@ -285,7 +285,7 @@ def forward( max_k = (input_lengths + prefix_lens_tensor).max().item() seqlen = Seqlen( input_lengths=input_lengths, - prefix_lengths=prefix_lens_tensor, + cache_lengths=prefix_lens_tensor, cu_seqlen_q=cu_seqlen_prefill, max_q=max_s, max_k=max_k, diff --git a/server/text_generation_server/models/vlm_causal_lm.py b/server/text_generation_server/models/vlm_causal_lm.py index 7484e44885c..a06add13f6e 100644 --- a/server/text_generation_server/models/vlm_causal_lm.py +++ b/server/text_generation_server/models/vlm_causal_lm.py @@ -294,7 +294,7 @@ def forward( kv_cache = self.kv_cache block_tables = batch.block_tables_tensor slots = batch.slots[batch.slot_indices] - postfix_lengths = batch.postfix_lengths_tensor + input_lengths = batch.input_lengths_tensor max_s = batch.max_current_length lm_head_indices = batch.prefill_head_indices @@ -311,11 +311,11 @@ def forward( position_ids.unsqueeze(-1).expand(B, new_length) + arange ).view(-1) slots = (slots.unsqueeze(-1).expand(B, new_length) + arange_int).view(-1) - postfix_lengths = ( - postfix_lengths.unsqueeze(-1).expand(B, new_length) + arange_int + input_lengths = ( + input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int ).view(-1) - prefix_lengths_tensor = ( - batch.prefix_lengths_tensor.unsqueeze(-1).expand(B, new_length) + cache_lengths_tensor = ( + batch.cache_lengths_tensor.unsqueeze(-1).expand(B, new_length) ).reshape(-1) # Add Copy the block tables for all members @@ -336,8 +336,8 @@ def forward( kv_cache = 
self.kv_cache block_tables = batch.block_tables_tensor slots = batch.slots[batch.slot_indices] - postfix_lengths = batch.postfix_lengths_tensor - prefix_lengths_tensor = batch.prefix_lengths_tensor + input_lengths = batch.input_lengths_tensor + cache_lengths_tensor = batch.cache_lengths_tensor max_s = batch.max_current_length lm_head_indices = batch.prefill_head_indices @@ -359,19 +359,19 @@ def forward( if PREFIX_CACHING: block_tables = block_tables_to_ragged( block_tables=block_tables, - postfix_lengths=batch.postfix_lengths, - prefix_lengths=batch.prefix_lengths, + input_lengths=batch.input_lengths, + cache_lengths=batch.cache_lengths, ) with self._forward_context( block_tables=block_tables, cu_seqlen_prefill=cu_seqlen_prefill, - postfix_lengths_tensor=postfix_lengths, - prefix_lengths_tensor=prefix_lengths_tensor, + input_lengths_tensor=input_lengths, + cache_lengths_tensor=cache_lengths_tensor, ): - max_k = (postfix_lengths + prefix_lengths_tensor).max().item() + max_k = (input_lengths + cache_lengths_tensor).max().item() seqlen = Seqlen( - postfix_lengths=postfix_lengths, - prefix_lengths=prefix_lengths_tensor, + input_lengths=input_lengths, + cache_lengths=cache_lengths_tensor, cu_seqlen_q=cu_seqlen_prefill, max_q=max_s, max_k=max_k, @@ -408,8 +408,8 @@ def forward( if ATTENTION == "flashinfer": block_tables = block_tables_to_ragged( block_tables=block_tables, - postfix_lengths=batch.postfix_lengths, - prefix_lengths=batch.prefix_lengths, + input_lengths=batch.input_lengths, + cache_lengths=batch.cache_lengths, ) cuda_graph["block_tables"][: block_tables.shape[0]] = block_tables else: @@ -418,18 +418,18 @@ def forward( ] = block_tables cuda_graph["slots"].fill_(-1) cuda_graph["slots"][: slots.shape[0]] = slots - cuda_graph["postfix_lengths"].zero_() - cuda_graph["postfix_lengths"][: postfix_lengths.shape[0]] = postfix_lengths - cuda_graph["prefix_lengths"].zero_() - cuda_graph["prefix_lengths"][ - : prefix_lengths_tensor.shape[0] - ] = prefix_lengths_tensor + cuda_graph["input_lengths"].zero_() + cuda_graph["input_lengths"][: input_lengths.shape[0]] = input_lengths + cuda_graph["cache_lengths"].zero_() + cuda_graph["cache_lengths"][ + : cache_lengths_tensor.shape[0] + ] = cache_lengths_tensor with self._forward_context( block_tables=cuda_graph["block_tables"], cu_seqlen_prefill=None, - postfix_lengths_tensor=cuda_graph["postfix_lengths"], - prefix_lengths_tensor=cuda_graph["prefix_lengths"], + input_lengths_tensor=cuda_graph["input_lengths"], + cache_lengths_tensor=cuda_graph["cache_lengths"], state=cuda_graph["state"], ): # Replay the graph From ea4b739a9f9f903a13200ba97c073f2bbda11e56 Mon Sep 17 00:00:00 2001 From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com> Date: Mon, 7 Oct 2024 17:12:31 +0200 Subject: [PATCH 19/29] fix prefill logprobs --- server/text_generation_server/models/flash_causal_lm.py | 9 ++++++--- server/text_generation_server/server.py | 2 +- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index 9a34dfc55fc..4e9f9c668e4 100644 --- a/server/text_generation_server/models/flash_causal_lm.py +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -1757,6 +1757,7 @@ def generate_token( finished_prefilling = True next_chunk_lengths = [] + current_prefilling_mask = batch.prefilling_mask if prefill: if get_support_chunking(): next_prefilling_mask = [] @@ -1998,6 +1999,7 @@ def generate_token( 
batch.next_token_chooser.do_sample, batch.next_token_chooser.seeds, batch.top_n_tokens, + current_prefilling_mask, batch.prefilling_mask, accepted_ids, batch_top_token_ids, @@ -2021,7 +2023,8 @@ def generate_token( do_sample, seed, top_n_tokens, - request_prefilling, + request_was_prefilling, + request_is_prefilling, n_accepted_ids, top_token_ids, top_token_logprobs, @@ -2032,7 +2035,7 @@ def generate_token( # this state to be stable if request.id % self.world_size == self.rank: # Prefill - if request_prefilling and request.prefill_logprobs: + if request_was_prefilling and request.prefill_logprobs: out_start_index = batch.prefill_cu_outlens[i] out_end_index = batch.prefill_cu_outlens[i + 1] @@ -2072,7 +2075,7 @@ def generate_token( batch.prefill_logprob_tokens[i] = None # If it is, the tokens we decoded should be ignored - if request_prefilling: + if request_is_prefilling: # Make sure that we do not stop as even though this request did not create a token, it is still # processing stopped = False diff --git a/server/text_generation_server/server.py b/server/text_generation_server/server.py index da85d19d891..aef00fb5f5d 100644 --- a/server/text_generation_server/server.py +++ b/server/text_generation_server/server.py @@ -165,7 +165,7 @@ async def Prefill(self, request, context): f"Batch ID {request.cached_batch.id} not found in cache." ) start_concat = time.time_ns() - batch = self.model.batch_type.concatenate([batch, cached_batch]) + batch = self.model.batch_type.concatenate([cached_batch, batch]) concat_ns = time.time_ns() - start_concat generations, next_batch, timings = self.model.generate_token(batch) From 08953c5975745c868cf6383dac84ed9d6d59229f Mon Sep 17 00:00:00 2001 From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com> Date: Tue, 8 Oct 2024 19:23:45 +0200 Subject: [PATCH 20/29] fix launcher --- benchmark/src/main.rs | 12 ---- launcher/src/main.rs | 12 ---- .../models/flash_causal_lm.py | 66 +++++++++++-------- 3 files changed, 37 insertions(+), 53 deletions(-) diff --git a/benchmark/src/main.rs b/benchmark/src/main.rs index 0bb5dc0cf33..2e2e9a11c12 100644 --- a/benchmark/src/main.rs +++ b/benchmark/src/main.rs @@ -179,18 +179,6 @@ fn main() -> Result<(), Box> { .await .expect("Unable to clear cache"); - // Warmup shard - let max_batch_size = batch_size.iter().max().unwrap(); - sharded_client - .warmup( - sequence_length, - sequence_length * max_batch_size, - (sequence_length + decode_length) * max_batch_size, - Some(*max_batch_size as usize), - ) - .await - .expect("Unable to warmup"); - tracing::info!("Connected"); // Run app diff --git a/launcher/src/main.rs b/launcher/src/main.rs index 214adcdc909..2389ac6338f 100644 --- a/launcher/src/main.rs +++ b/launcher/src/main.rs @@ -1727,12 +1727,6 @@ fn main() -> Result<(), LauncherError> { "`max_input_tokens must be < `max_total_tokens`".to_string(), )); } - if max_input_tokens as u32 > max_batch_prefill_tokens { - return Err(LauncherError::ArgumentValidation(format!( - "`max_batch_prefill_tokens` must be >= `max_input_tokens`. 
Given: {} and {}", - max_batch_prefill_tokens, max_input_tokens - ))); - } if matches!(args.quantize, Some(Quantization::Bitsandbytes)) { tracing::warn!("Bitsandbytes is deprecated, use `eetq` instead, which provides better latencies overall and is drop-in in most cases."); @@ -1786,12 +1780,6 @@ fn main() -> Result<(), LauncherError> { } if let Some(ref max_batch_total_tokens) = args.max_batch_total_tokens { - if max_batch_prefill_tokens > *max_batch_total_tokens { - return Err(LauncherError::ArgumentValidation(format!( - "`max_batch_prefill_tokens` must be <= `max_batch_total_tokens`. Given: {} and {}", - max_batch_prefill_tokens, max_batch_total_tokens - ))); - } if max_total_tokens as u32 > *max_batch_total_tokens { return Err(LauncherError::ArgumentValidation(format!( "`max_total_tokens` must be <= `max_batch_total_tokens`. Given: {} and {}", diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index 4e9f9c668e4..8d33a2b3f54 100644 --- a/server/text_generation_server/models/flash_causal_lm.py +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -173,9 +173,6 @@ class FlashCausalLMBatch(Batch): # Will be set by `generate_token` and reset after each prefill forward prefill_logprob_tokens: List[Optional[Tokens]] - # Prefixes - prefix_ids: List[List[int]] - # All tokens all_input_ids: List[List[int]] all_input_ids_tensor: torch.Tensor @@ -259,7 +256,6 @@ def from_tokenized( read_offsets = [] all_input_ids = [] all_postfix_ids = [] - prefix_ids = [] requests_idx_mapping = {} next_token_chooser_parameters = [] @@ -297,7 +293,6 @@ def from_tokenized( assert get_support_chunking() assert input_length > 0 - prefix_ids.append(tokenized_input[:cache_length]) postfix_ids = tokenized_input[cache_length : cache_length + input_length] assert ( @@ -400,7 +395,6 @@ def from_tokenized( read_offsets=read_offsets, all_input_ids=all_input_ids, all_input_ids_tensor=all_input_ids_tensor, - prefix_ids=prefix_ids, next_token_chooser=next_token_chooser, stopping_criterias=stopping_criterias, top_n_tokens=top_n_tokens, @@ -464,7 +458,6 @@ def filter(self, request_ids: List[int]) -> "FlashCausalLMBatch": requests = [] block_tables = [] all_input_ids = [] - prefix_ids = [] input_ids = [] prompt_lengths = [] @@ -505,7 +498,6 @@ def filter(self, request_ids: List[int]) -> "FlashCausalLMBatch": ) all_input_ids.append(self.all_input_ids[idx]) - prefix_ids.append(self.prefix_ids[idx]) prompt_lengths.append(self.prompt_lengths[idx]) input_lengths.append(request_input_length) @@ -621,7 +613,6 @@ def filter(self, request_ids: List[int]) -> "FlashCausalLMBatch": read_offsets=read_offsets, all_input_ids=all_input_ids, all_input_ids_tensor=all_input_ids_tensor, - prefix_ids=prefix_ids, next_token_chooser=next_token_chooser, stopping_criterias=stopping_criterias, top_n_tokens=top_n_tokens, @@ -718,7 +709,6 @@ def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch block_tables = [] cache_lengths = [] all_input_ids = [] - prefix_ids = [] prompt_lengths = [] input_lengths = [] @@ -802,7 +792,6 @@ def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch block_tables.extend(batch.block_tables) cache_lengths.extend(batch.cache_lengths) all_input_ids.extend(batch.all_input_ids) - prefix_ids.extend(batch.prefix_ids) prompt_lengths.extend(batch.prompt_lengths) input_lengths.extend(batch.input_lengths) @@ -873,7 +862,6 @@ def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch 
read_offsets=read_offsets, all_input_ids=all_input_ids, all_input_ids_tensor=all_input_ids_tensor, - prefix_ids=prefix_ids, next_token_chooser=next_token_chooser, stopping_criterias=stopping_criterias, top_n_tokens=top_n_tokens, @@ -1839,6 +1827,8 @@ def generate_token( batch.input_lengths, batch.all_input_ids, accepted_ids, + current_prefilling_mask, + batch.prefilling_mask, ) # We do two for loops as the first one can run completely asynchronously from the GPU while for the second @@ -1855,6 +1845,8 @@ def generate_token( input_length, all_input_ids, n_accepted_ids, + request_was_prefilling, + request_is_prefilling, ) in enumerate(iterator): # Indexing metadata start_index = cumulative_length @@ -1864,7 +1856,6 @@ def generate_token( # Indexing metadata out_start_index = batch.prefill_cu_outlens[i] out_end_index = batch.prefill_cu_outlens[i + 1] - out_length = out_end_index - out_start_index if finished_prefilling: # Initialize position_ids @@ -1880,21 +1871,25 @@ def generate_token( # Used to gather prefill logprobs # Copy batch.input_ids to prefill_token_indices if prefill_logprobs: + # If the request was prefilling and cache_length == 0, the first token is a bogus token + # and needs to be removed. We do so by incrementing the start_index + if request_was_prefilling and cache_length == 0: + start_index += 1 + + # If the request was prefilling, and it is done prefilling, the last token was generated and is + # therefore not part of the prefill. We remove it by decrementing out_end_index + if request_was_prefilling and not request_is_prefilling: + out_end_index -= 1 + if len(batch) > 1: - prefill_tokens_indices[out_start_index : out_end_index - 1] = ( - batch.input_ids[start_index + 1 : start_index + out_length] + prefill_tokens_indices[out_start_index:out_end_index] = ( + batch.input_ids[start_index:end_index] ) else: # Set prefill_tokens_indices to the correct slice - prefill_tokens_indices = batch.input_ids[ - start_index + 1 : start_index + out_length - ] + prefill_tokens_indices = batch.input_ids[start_index:end_index] - # Represent whether this request is still prefilling - # If it is, the tokens we decoded should be ignored - accept_tokens = cache_length + input_length >= prompt_length - - if accept_tokens: + if not request_is_prefilling: # Only save tokens if we are done prefilling for this request for j in range(n_accepted_ids): batch.all_input_ids_tensor[i, cache_length + input_length + j] = ( @@ -1995,7 +1990,6 @@ def generate_token( batch.read_offsets, batch.stopping_criterias, batch.all_input_ids, - batch.prefix_ids, batch.next_token_chooser.do_sample, batch.next_token_chooser.seeds, batch.top_n_tokens, @@ -2019,7 +2013,6 @@ def generate_token( read_offset, stopping_criteria, all_input_ids, - prefix_ids, do_sample, seed, top_n_tokens, @@ -2039,19 +2032,30 @@ def generate_token( out_start_index = batch.prefill_cu_outlens[i] out_end_index = batch.prefill_cu_outlens[i + 1] + # log_master(logger.info, f"{prefill_logprobs}") + + if not request_is_prefilling: + # If the request is done prefilling, then the last logprob is a generated token + # We need to remove it + out_end_index -= 1 + request_prefill_logprobs = prefill_logprobs[ - out_start_index : out_end_index - 1 + out_start_index:out_end_index + ] + prefill_token_ids = all_input_ids[ + cache_length : cache_length + input_length ] - prefill_token_ids = all_input_ids[:-1] past_prefill_logprob_tokens = batch.prefill_logprob_tokens[i] if past_prefill_logprob_tokens is None: # Remove generated token to only have prefill and add 
nan for first prompt token request_prefill_logprobs = [float("nan")] * ( - len(prefix_ids) + 1 + cache_length + 1 ) + request_prefill_logprobs - prefill_token_ids = prefix_ids + prefill_token_ids + prefill_token_ids = ( + all_input_ids[:cache_length] + prefill_token_ids + ) prefill_texts = self.tokenizer.batch_decode( prefill_token_ids, @@ -2059,6 +2063,10 @@ def generate_token( skip_special_tokens=False, ) + # log_master(logger.info, f"{prefill_token_ids}") + # log_master(logger.info, f"{request_prefill_logprobs}") + # log_master(logger.info, f"{prefill_texts}") + prefill_logprob_tokens = Tokens( prefill_token_ids, request_prefill_logprobs, From 3ace1b2f8d32672b27af36f881a218cb1b0f34e2 Mon Sep 17 00:00:00 2001 From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com> Date: Wed, 9 Oct 2024 17:33:15 +0200 Subject: [PATCH 21/29] fix logprobs? --- .../models/flash_causal_lm.py | 57 ++++++++----------- 1 file changed, 24 insertions(+), 33 deletions(-) diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index 8d33a2b3f54..b7202c04ff7 100644 --- a/server/text_generation_server/models/flash_causal_lm.py +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -956,7 +956,11 @@ def prepare_for_prefill(self): no_prefill_logprobs = no_prefill_logprobs and not prefill_logprobs if prefill_logprobs: - prefill_head_indices.append(request_position_ids + cumulative_length) + prefill_head_indices.append(torch.arange( + cumulative_length, + cumulative_length + input_length, + dtype=torch.int64 + )) prefill_next_token_indices.append( prefill_out_cumulative_length + input_length - 1 ) @@ -966,7 +970,7 @@ def prepare_for_prefill(self): prefill_head_indices.append( torch.tensor( [cumulative_length + input_length - 1], - dtype=torch.int32, + dtype=torch.int64, ) ) prefill_next_token_indices.append(prefill_out_cumulative_length) @@ -1029,9 +1033,7 @@ def prepare_for_prefill(self): prefill_head_indices = cu_seqlen_prefill[1:] - 1 prefill_next_token_indices = None else: - prefill_head_indices = torch.tensor( - torch.cat(prefill_head_indices), dtype=torch.int64, device=device - ) + prefill_head_indices = torch.cat(prefill_head_indices).to(device) prefill_next_token_indices = torch.tensor( prefill_next_token_indices, dtype=torch.int64, device=device ) @@ -1822,6 +1824,7 @@ def generate_token( # Zipped iterator iterator = zip( + batch.requests, batch.prompt_lengths, batch.cache_lengths, batch.input_lengths, @@ -1840,6 +1843,7 @@ def generate_token( # Cumulative length cumulative_length = 0 for i, ( + request, prompt_length, cache_length, input_length, @@ -1849,7 +1853,7 @@ def generate_token( request_is_prefilling, ) in enumerate(iterator): # Indexing metadata - start_index = cumulative_length + _start_index = cumulative_length end_index = cumulative_length + input_length if prefill: @@ -1869,25 +1873,16 @@ def generate_token( ] # Used to gather prefill logprobs - # Copy batch.input_ids to prefill_token_indices - if prefill_logprobs: - # If the request was prefilling and cache_length == 0, the first token is a bogus token - # and needs to be removed. We do so by incrementing the start_index - if request_was_prefilling and cache_length == 0: - start_index += 1 - - # If the request was prefilling, and it is done prefilling, the last token was generated and is - # therefore not part of the prefill. 
We remove it by decrementing out_end_index - if request_was_prefilling and not request_is_prefilling: - out_end_index -= 1 - + # Copy batch.all_input_ids_tensor to prefill_token_indices + if request.prefill_logprobs and request_was_prefilling: + # Logprobs generated by the model are for the next token + # So we need to translate the id tensor by 1 + ids = batch.all_input_ids_tensor[i, cache_length + 1: cache_length + input_length + 1] if len(batch) > 1: - prefill_tokens_indices[out_start_index:out_end_index] = ( - batch.input_ids[start_index:end_index] - ) + prefill_tokens_indices[out_start_index : out_end_index] = ids else: # Set prefill_tokens_indices to the correct slice - prefill_tokens_indices = batch.input_ids[start_index:end_index] + prefill_tokens_indices = ids if not request_is_prefilling: # Only save tokens if we are done prefilling for this request @@ -2031,30 +2026,30 @@ def generate_token( if request_was_prefilling and request.prefill_logprobs: out_start_index = batch.prefill_cu_outlens[i] out_end_index = batch.prefill_cu_outlens[i + 1] - - # log_master(logger.info, f"{prefill_logprobs}") - if not request_is_prefilling: - # If the request is done prefilling, then the last logprob is a generated token + # The request is dones prefilling, meaning that we started generating new tokens + # The last logprob is a logprob for a generated token that was not part of the prompt # We need to remove it out_end_index -= 1 request_prefill_logprobs = prefill_logprobs[ out_start_index:out_end_index ] + # Logprobs generated by the model are for the next token + # So we need to translate the id tensor by 1 prefill_token_ids = all_input_ids[ - cache_length : cache_length + input_length + cache_length + 1 : cache_length + input_length + 1 ] past_prefill_logprob_tokens = batch.prefill_logprob_tokens[i] if past_prefill_logprob_tokens is None: - # Remove generated token to only have prefill and add nan for first prompt token + # add nan for cached prompt tokens/first token request_prefill_logprobs = [float("nan")] * ( cache_length + 1 ) + request_prefill_logprobs prefill_token_ids = ( - all_input_ids[:cache_length] + prefill_token_ids + all_input_ids[:cache_length + 1] + prefill_token_ids ) prefill_texts = self.tokenizer.batch_decode( @@ -2063,10 +2058,6 @@ def generate_token( skip_special_tokens=False, ) - # log_master(logger.info, f"{prefill_token_ids}") - # log_master(logger.info, f"{request_prefill_logprobs}") - # log_master(logger.info, f"{prefill_texts}") - prefill_logprob_tokens = Tokens( prefill_token_ids, request_prefill_logprobs, From 57f55fe8346bf9dbcf2c598eb610b2966249cd07 Mon Sep 17 00:00:00 2001 From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com> Date: Wed, 9 Oct 2024 19:17:18 +0200 Subject: [PATCH 22/29] idk at this point --- .../models/flash_causal_lm.py | 63 +++++++++---------- 1 file changed, 30 insertions(+), 33 deletions(-) diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index b7202c04ff7..05bad9243e0 100644 --- a/server/text_generation_server/models/flash_causal_lm.py +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -758,11 +758,12 @@ def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch input_ids[start_index:end_index] = batch.input_ids position_ids[start_index:end_index] = batch.position_ids + slots[slots_start_index:slots_end_index] = batch.slots slot_indices[start_index:end_index] = ( batch.slot_indices + cumulative_slots ) 
input_lengths_tensor[start_index:end_index] = batch.input_lengths_tensor - slots[slots_start_index:slots_end_index] = batch.slots + cache_lengths_tensor[start_index:end_index] = batch.cache_lengths_tensor # Copy over adapter indices adapter_start_index = cumulative_adapter_indices_size @@ -779,7 +780,6 @@ def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch batch.adapter_meta.adapter_segments, batch.adapter_meta.segment_indices, ) - cache_lengths_tensor[start_index:end_index] = batch.cache_lengths_tensor # Update cumulative_slots += len(batch.slots) @@ -1614,13 +1614,12 @@ def forward( input_lengths_tensor=input_lengths, cache_lengths_tensor=cache_lengths_tensor, ): - max_k = (input_lengths + cache_lengths_tensor).max().item() seqlen = Seqlen( input_lengths=input_lengths, cache_lengths=cache_lengths_tensor, cu_seqlen_q=cu_seqlen_prefill, max_q=max_s, - max_k=max_k, + max_k=batch.max_current_length, ) logits, speculative_logits = self.model.forward( input_ids=input_ids, @@ -1852,46 +1851,44 @@ def generate_token( request_was_prefilling, request_is_prefilling, ) in enumerate(iterator): - # Indexing metadata - _start_index = cumulative_length - end_index = cumulative_length + input_length + if prefill and finished_prefilling: + # Indexing metadata + _start_index = cumulative_length + end_index = cumulative_length + input_length + + # Initialize position_ids + # In decode, we do not need this as we can just increment position ids + next_position_ids[i] = batch.position_ids[end_index - 1] + + # Initialize adapter indices + # In decode, we only have one token per row in the batch, so grab last index + next_adapter_indices[i] = batch.adapter_meta.adapter_indices[ + end_index - 1 + ] - if prefill: + # Used to gather prefill logprobs + # Copy batch.all_input_ids_tensor to prefill_token_indices + if request.prefill_logprobs and request_was_prefilling: # Indexing metadata out_start_index = batch.prefill_cu_outlens[i] out_end_index = batch.prefill_cu_outlens[i + 1] - if finished_prefilling: - # Initialize position_ids - # In decode, we do not need this as we can just increment position ids - next_position_ids[i] = batch.position_ids[end_index - 1] - - # Initialize adapter indices - # In decode, we only have one token per row in the batch, so grab last index - next_adapter_indices[i] = batch.adapter_meta.adapter_indices[ - end_index - 1 - ] - - # Used to gather prefill logprobs - # Copy batch.all_input_ids_tensor to prefill_token_indices - if request.prefill_logprobs and request_was_prefilling: - # Logprobs generated by the model are for the next token - # So we need to translate the id tensor by 1 - ids = batch.all_input_ids_tensor[i, cache_length + 1: cache_length + input_length + 1] - if len(batch) > 1: - prefill_tokens_indices[out_start_index : out_end_index] = ids - else: - # Set prefill_tokens_indices to the correct slice - prefill_tokens_indices = ids + # Logprobs generated by the model are for the next token + # So we need to translate the id tensor by 1 + ids = batch.all_input_ids_tensor[i, cache_length + 1: cache_length + input_length + 1] + if len(batch) > 1: + prefill_tokens_indices[out_start_index : out_end_index] = ids + else: + # Set prefill_tokens_indices to the correct slice + prefill_tokens_indices = ids if not request_is_prefilling: # Only save tokens if we are done prefilling for this request for j in range(n_accepted_ids): batch.all_input_ids_tensor[i, cache_length + input_length + j] = ( - next_input_ids[index] + next_input_ids[index + j] ) - index += 1 
- + index += n_accepted_ids cumulative_length += input_length # Update values From d73c5c634dea446ca53b4a9ca3e1e6d28961c2a6 Mon Sep 17 00:00:00 2001 From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com> Date: Wed, 9 Oct 2024 19:39:14 +0200 Subject: [PATCH 23/29] max input length --- server/text_generation_server/models/flash_causal_lm.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index 05bad9243e0..7ebe3deab9b 100644 --- a/server/text_generation_server/models/flash_causal_lm.py +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -1618,7 +1618,7 @@ def forward( input_lengths=input_lengths, cache_lengths=cache_lengths_tensor, cu_seqlen_q=cu_seqlen_prefill, - max_q=max_s, + max_q=batch.max_input_length, max_k=batch.max_current_length, ) logits, speculative_logits = self.model.forward( @@ -2236,8 +2236,6 @@ def _forward_context( use_prefill_with_paged_kv_state, ) - # has_cache_lengths = any(cache_length > 0 for cache_length in cache_lengths) - if cu_seqlen_prefill is not None: return use_prefill_with_paged_kv_state( state=( From d361197aab814aaa7e24b469b607b1c78090be72 Mon Sep 17 00:00:00 2001 From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com> Date: Wed, 9 Oct 2024 20:04:06 +0200 Subject: [PATCH 24/29] omfg --- .../models/flash_causal_lm.py | 28 +++++++++++++------ 1 file changed, 19 insertions(+), 9 deletions(-) diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index 7ebe3deab9b..98de8c79817 100644 --- a/server/text_generation_server/models/flash_causal_lm.py +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -956,11 +956,13 @@ def prepare_for_prefill(self): no_prefill_logprobs = no_prefill_logprobs and not prefill_logprobs if prefill_logprobs: - prefill_head_indices.append(torch.arange( - cumulative_length, - cumulative_length + input_length, - dtype=torch.int64 - )) + prefill_head_indices.append( + torch.arange( + cumulative_length, + cumulative_length + input_length, + dtype=torch.int64, + ) + ) prefill_next_token_indices.append( prefill_out_cumulative_length + input_length - 1 ) @@ -1875,9 +1877,11 @@ def generate_token( # Logprobs generated by the model are for the next token # So we need to translate the id tensor by 1 - ids = batch.all_input_ids_tensor[i, cache_length + 1: cache_length + input_length + 1] + ids = batch.all_input_ids_tensor[ + i, cache_length + 1 : cache_length + input_length + 1 + ] if len(batch) > 1: - prefill_tokens_indices[out_start_index : out_end_index] = ids + prefill_tokens_indices[out_start_index:out_end_index] = ids else: # Set prefill_tokens_indices to the correct slice prefill_tokens_indices = ids @@ -2014,6 +2018,12 @@ def generate_token( top_token_ids, top_token_logprobs, ) in enumerate(iterator): + if all_input_ids[:2] == [1986, 374] and not request_is_prefilling: + log_master( + logger.info, + f"{request.id} {next_token_ids} {self.tokenizer.batch_decode(next_token_ids)}", + ) + # Compute logprobs first as, even though we might skip the token, # it can still be required to compute the logprobs # modulo on request.id as it is robust to batch.filter whereas the index in the batch is not and we need @@ -2046,7 +2056,7 @@ def generate_token( cache_length + 1 ) + request_prefill_logprobs prefill_token_ids = ( - all_input_ids[:cache_length + 1] + prefill_token_ids + all_input_ids[: 
cache_length + 1] + prefill_token_ids ) prefill_texts = self.tokenizer.batch_decode( @@ -2114,7 +2124,6 @@ def generate_token( _next_token_logprobs = next_token_logprobs[ index : index + n_accepted_ids - left ] - index += n_accepted_ids # Shard generations # All generations will be appended in the rust sharded client @@ -2189,6 +2198,7 @@ def generate_token( ) # Update values + index += n_accepted_ids current_cache_length = cache_length + input_length batch.cache_lengths[i] = current_cache_length current_input_length = new_input_length From f85a308ef1e88fa65b7778c6a60b525774beac28 Mon Sep 17 00:00:00 2001 From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com> Date: Wed, 9 Oct 2024 20:05:39 +0200 Subject: [PATCH 25/29] remove debugging lines --- server/text_generation_server/models/flash_causal_lm.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index 98de8c79817..7e256dcfc7a 100644 --- a/server/text_generation_server/models/flash_causal_lm.py +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -2018,12 +2018,6 @@ def generate_token( top_token_ids, top_token_logprobs, ) in enumerate(iterator): - if all_input_ids[:2] == [1986, 374] and not request_is_prefilling: - log_master( - logger.info, - f"{request.id} {next_token_ids} {self.tokenizer.batch_decode(next_token_ids)}", - ) - # Compute logprobs first as, even though we might skip the token, # it can still be required to compute the logprobs # modulo on request.id as it is robust to batch.filter whereas the index in the batch is not and we need From b7a1280f2530b0ffce95d8ff54d565082ccc6f15 Mon Sep 17 00:00:00 2001 From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com> Date: Thu, 10 Oct 2024 14:52:09 +0200 Subject: [PATCH 26/29] fix tests --- backends/client/src/v3/client.rs | 4 +- backends/client/src/v3/sharded_client.rs | 4 +- backends/v3/src/client/grpc_client.rs | 4 +- backends/v3/src/client/sharded_client.rs | 4 +- backends/v3/src/queue.rs | 44 +++++----- benchmark/src/generation.rs | 4 +- integration-tests/conftest.py | 26 +++++- .../models/test_flash_pali_gemma.py | 24 ++---- integration-tests/models/test_idefics.py | 24 +----- integration-tests/models/test_idefics2.py | 28 ++---- integration-tests/models/test_llava_next.py | 14 +-- integration-tests/models/test_mllama.py | 15 ---- proto/v3/generate.proto | 10 ++- .../models/flash_causal_lm.py | 37 +++++--- .../models/mllama_causal_lm.py | 86 +++++++++++-------- .../models/vlm_causal_lm.py | 14 +-- 16 files changed, 162 insertions(+), 180 deletions(-) diff --git a/backends/client/src/v3/client.rs b/backends/client/src/v3/client.rs index 8280795daf0..d43f789e7ca 100644 --- a/backends/client/src/v3/client.rs +++ b/backends/client/src/v3/client.rs @@ -158,8 +158,8 @@ impl Client { // Blocks and slots will be set on the server side if we use paged attention blocks: vec![], slots: vec![], - prefix_len: 0, - postfix_len: truncate, + cache_len: 0, + chunk_len: None, // Set sampling parameters to also take these ops into account in the max memory parameters: Some(NextTokenChooserParameters { temperature: 0.9, diff --git a/backends/client/src/v3/sharded_client.rs b/backends/client/src/v3/sharded_client.rs index 39e99776107..854a5895eba 100644 --- a/backends/client/src/v3/sharded_client.rs +++ b/backends/client/src/v3/sharded_client.rs @@ -246,8 +246,8 @@ impl Health for ShardedClient { // Block 0 is reserved for health checks 
blocks: vec![0], slots: (0..16).collect(), - prefix_len: 0, - postfix_len: 1, + cache_len: 0, + chunk_len: None, adapter_id: None, }; let batch = Batch { diff --git a/backends/v3/src/client/grpc_client.rs b/backends/v3/src/client/grpc_client.rs index 804c77d4163..fe810f24742 100644 --- a/backends/v3/src/client/grpc_client.rs +++ b/backends/v3/src/client/grpc_client.rs @@ -158,8 +158,8 @@ impl Client { // Blocks and slots will be set on the server side if we use paged attention blocks: vec![], slots: vec![], - prefix_len: 0, - postfix_len: truncate, + cache_len: 0, + chunk_len: None, // Set sampling parameters to also take these ops into account in the max memory parameters: Some(NextTokenChooserParameters { temperature: 0.9, diff --git a/backends/v3/src/client/sharded_client.rs b/backends/v3/src/client/sharded_client.rs index e25bf71e55d..e181cd28d2f 100644 --- a/backends/v3/src/client/sharded_client.rs +++ b/backends/v3/src/client/sharded_client.rs @@ -235,9 +235,9 @@ impl Health for ShardedClient { // Block 0 is reserved for health checks blocks: vec![0], slots: (0..16).collect(), - prefix_len: 0, + cache_len: 0, adapter_id: None, - postfix_len: 1, + chunk_len: None, }; let batch = Batch { id: u64::MAX, diff --git a/backends/v3/src/queue.rs b/backends/v3/src/queue.rs index a07c725cbc7..36fbed87c18 100644 --- a/backends/v3/src/queue.rs +++ b/backends/v3/src/queue.rs @@ -280,7 +280,7 @@ impl State { continue; } - let (block_allocation, postfix_len) = match &self.block_allocator { + let block_allocation = match &self.block_allocator { None => { // We pad to max input length in the Python shards // We need to take these padding tokens into the equation @@ -297,7 +297,7 @@ impl State { self.entries.push_front((id, entry)); break 'entry_loop; } - (None, entry.request.input_length) + None } Some(block_allocator) => { // If users wants the prefill logprobs, we cannot reuse the cache. 
@@ -337,7 +337,7 @@ impl State { } }; - let mut postfix_len = entry.request.input_length - block_allocation.prefix_len; + let postfix_len = entry.request.input_length - block_allocation.prefix_len; // Check equality too as if we don't we might end up with a postfix_len = 0 // in the next iteration of the loop @@ -345,9 +345,9 @@ impl State { // Entry is over budget if self.support_chunking { // We support chunking, just set postfix_len to exactly match prefill_token_budget - postfix_len = prefill_token_budget - prefill_tokens; + let chunk_len = prefill_token_budget - prefill_tokens; // Push this entry inside the batch - batch.push((id, entry, Some(block_allocation), postfix_len)); + batch.push((id, entry, Some(block_allocation), Some(chunk_len))); break 'entry_loop; } else { // We don't support chunking, this entry needs to go back to the buffer @@ -363,10 +363,10 @@ impl State { prefill_tokens += postfix_len; - (Some(block_allocation), postfix_len) + Some(block_allocation) } }; - batch.push((id, entry, block_allocation, postfix_len)); + batch.push((id, entry, block_allocation, None)); if Some(batch.len()) == max_size { break; } @@ -395,7 +395,7 @@ impl State { let mut batch_entries = IntMap::with_capacity_and_hasher(self.entries.len(), BuildNoHashHasher::default()); - for (id, mut entry, block_allocation, postfix_len) in batch { + for (id, mut entry, block_allocation, chunk_len) in batch { // Create a new span to link the batch back to this entry let entry_batch_span = info_span!(parent: &entry.span, "infer"); // Add relationships @@ -447,9 +447,9 @@ impl State { top_n_tokens: entry.request.top_n_tokens, blocks, slots, - prefix_len, + cache_len: prefix_len, adapter_id: entry.request.adapter_id.clone(), - postfix_len, + chunk_len, }); // Set batch_time entry.batch_time = Some(Instant::now()); @@ -582,7 +582,7 @@ mod tests { #[tokio::test] async fn test_append() { - let mut state = State::new(false, 1, false, None, 0, 16); + let mut state = State::new(false, 1, false, None, 0, 16, false); let (entry, _guard) = default_entry(); assert_eq!(state.next_id, 0); @@ -598,7 +598,7 @@ mod tests { #[tokio::test] async fn test_next_batch_empty() { - let mut state = State::new(false, 1, false, None, 0, 16); + let mut state = State::new(false, 1, false, None, 0, 16, false); assert!(state.next_batch(None, None, 1, 1).await.is_none()); assert!(state.next_batch(Some(1), None, 1, 1).await.is_none()); @@ -606,7 +606,7 @@ mod tests { #[tokio::test] async fn test_next_batch_min_size() { - let mut state = State::new(false, 1, false, None, 0, 16); + let mut state = State::new(false, 1, false, None, 0, 16, false); let (entry1, _guard1) = default_entry(); let (entry2, _guard2) = default_entry(); state.append(entry1); @@ -638,7 +638,7 @@ mod tests { #[tokio::test] async fn test_next_batch_max_size() { - let mut state = State::new(false, 1, false, None, 0, 16); + let mut state = State::new(false, 1, false, None, 0, 16, false); let (entry1, _guard1) = default_entry(); let (entry2, _guard2) = default_entry(); state.append(entry1); @@ -658,7 +658,7 @@ mod tests { #[tokio::test] async fn test_next_batch_token_budget() { - let mut state = State::new(false, 1, false, None, 0, 2); + let mut state = State::new(false, 1, false, None, 0, 2, false); let (entry1, _guard1) = default_entry(); let (entry2, _guard2) = default_entry(); state.append(entry1); @@ -691,14 +691,14 @@ mod tests { #[tokio::test] async fn test_queue_append() { - let queue = Queue::new(false, 1, false, None, 0, 16); + let queue = Queue::new(false, 1, false, 
None, 0, 16, false); let (entry, _guard) = default_entry(); queue.append(entry); } #[tokio::test] async fn test_queue_next_batch_empty() { - let queue = Queue::new(false, 1, false, None, 0, 16); + let queue = Queue::new(false, 1, false, None, 0, 16, false); assert!(queue.next_batch(None, None, 1, 1).await.is_none()); assert!(queue.next_batch(Some(1), None, 1, 1).await.is_none()); @@ -706,7 +706,7 @@ mod tests { #[tokio::test] async fn test_queue_next_batch_min_size() { - let queue = Queue::new(false, 1, false, None, 0, 16); + let queue = Queue::new(false, 1, false, None, 0, 16, false); let (entry1, _guard1) = default_entry(); let (entry2, _guard2) = default_entry(); queue.append(entry1); @@ -739,7 +739,7 @@ mod tests { #[tokio::test] async fn test_queue_next_batch_max_size() { - let queue = Queue::new(false, 1, false, None, 0, 16); + let queue = Queue::new(false, 1, false, None, 0, 16, false); let (entry1, _guard1) = default_entry(); let (entry2, _guard2) = default_entry(); queue.append(entry1); @@ -755,7 +755,7 @@ mod tests { #[tokio::test] async fn test_queue_next_batch_token_budget() { - let queue = Queue::new(false, 1, false, None, 0, 16); + let queue = Queue::new(false, 1, false, None, 0, 16, false); let (entry1, _guard1) = default_entry(); let (entry2, _guard2) = default_entry(); queue.append(entry1); @@ -780,7 +780,7 @@ mod tests { #[tokio::test] async fn test_queue_next_batch_token_speculate() { - let queue = Queue::new(false, 1, false, None, 2, 16); + let queue = Queue::new(false, 1, false, None, 2, 16, false); let (entry1, _guard1) = default_entry(); let (entry2, _guard2) = default_entry(); queue.append(entry1); @@ -799,7 +799,7 @@ mod tests { #[tokio::test] async fn test_queue_next_batch_dropped_receiver() { - let queue = Queue::new(false, 1, false, None, 0, 16); + let queue = Queue::new(false, 1, false, None, 0, 16, false); let (entry, _) = default_entry(); queue.append(entry); diff --git a/benchmark/src/generation.rs b/benchmark/src/generation.rs index 43a84e7023a..63fc780818b 100644 --- a/benchmark/src/generation.rs +++ b/benchmark/src/generation.rs @@ -158,8 +158,8 @@ async fn prefill( top_n_tokens: top_n_tokens.unwrap_or(0), blocks: vec![], slots: vec![], - prefix_len: 0, - postfix_len: sequence_length, + cache_len: 0, + chunk_len: None, adapter_id: None, }) .collect(); diff --git a/integration-tests/conftest.py b/integration-tests/conftest.py index dbe69244fed..dfbff7e5cb3 100644 --- a/integration-tests/conftest.py +++ b/integration-tests/conftest.py @@ -9,13 +9,16 @@ import sys import tempfile import time -from typing import Dict, List, Optional - import docker import pytest +import base64 + +from pathlib import Path +from typing import Dict, List, Optional from aiohttp import ClientConnectorError, ClientOSError, ServerDisconnectedError from docker.errors import NotFound from syrupy.extensions.json import JSONSnapshotExtension + from text_generation import AsyncClient from text_generation.types import ( BestOfSequence, @@ -639,3 +642,22 @@ async def generate_load_inner( return responses return generate_load_inner + + +# TODO fix the server parsser to count inline image tokens correctly +@pytest.fixture +def chicken(): + path = Path(__file__).parent / "images" / "chicken_on_money.png" + + with open(path, "rb") as image_file: + encoded_string = base64.b64encode(image_file.read()) + return f"data:image/png;base64,{encoded_string.decode('utf-8')}" + + +@pytest.fixture +def cow_beach(): + path = Path(__file__).parent / "images" / "cow_beach.png" + + with open(path, "rb") as 
image_file: + encoded_string = base64.b64encode(image_file.read()) + return f"data:image/png;base64,{encoded_string.decode('utf-8')}" diff --git a/integration-tests/models/test_flash_pali_gemma.py b/integration-tests/models/test_flash_pali_gemma.py index 52ecaed4612..93962eb3e8f 100644 --- a/integration-tests/models/test_flash_pali_gemma.py +++ b/integration-tests/models/test_flash_pali_gemma.py @@ -1,5 +1,4 @@ import pytest -import base64 @pytest.fixture(scope="module") @@ -20,24 +19,11 @@ async def flash_pali_gemma(flash_pali_gemma_handle): return flash_pali_gemma_handle.client -def get_chicken(): - with open("integration-tests/images/chicken_on_money.png", "rb") as image_file: - encoded_string = base64.b64encode(image_file.read()) - return f"data:image/png;base64,{encoded_string.decode('utf-8')}" - - -def get_cow_beach(): - with open("integration-tests/images/cow_beach.png", "rb") as image_file: - encoded_string = base64.b64encode(image_file.read()) - return f"data:image/png;base64,{encoded_string.decode('utf-8')}" - - @pytest.mark.release @pytest.mark.asyncio @pytest.mark.private -async def test_flash_pali_gemma(flash_pali_gemma, response_snapshot): - cow = get_cow_beach() - inputs = f"![]({cow})Where is the cow standing?\n" +async def test_flash_pali_gemma(flash_pali_gemma, response_snapshot, cow_beach): + inputs = f"![]({cow_beach})Where is the cow standing?\n" response = await flash_pali_gemma.generate(inputs, max_new_tokens=20) assert response.generated_text == "beach" @@ -47,9 +33,9 @@ async def test_flash_pali_gemma(flash_pali_gemma, response_snapshot): @pytest.mark.release @pytest.mark.asyncio @pytest.mark.private -async def test_flash_pali_gemma_two_images(flash_pali_gemma, response_snapshot): - chicken = get_chicken() - cow_beach = get_cow_beach() +async def test_flash_pali_gemma_two_images( + flash_pali_gemma, response_snapshot, chicken, cow_beach +): response = await flash_pali_gemma.generate( f"caption![]({chicken})![]({cow_beach})\n", max_new_tokens=20, diff --git a/integration-tests/models/test_idefics.py b/integration-tests/models/test_idefics.py index eb573385b7c..e5d08bb74c6 100644 --- a/integration-tests/models/test_idefics.py +++ b/integration-tests/models/test_idefics.py @@ -1,5 +1,4 @@ import pytest -import base64 @pytest.fixture(scope="module") @@ -16,22 +15,8 @@ async def idefics(idefics_handle): return idefics_handle.client -# TODO fix the server parsser to count inline image tokens correctly -def get_chicken(): - with open("integration-tests/images/chicken_on_money.png", "rb") as image_file: - encoded_string = base64.b64encode(image_file.read()) - return f"data:image/png;base64,{encoded_string.decode('utf-8')}" - - -def get_cow_beach(): - with open("integration-tests/images/cow_beach.png", "rb") as image_file: - encoded_string = base64.b64encode(image_file.read()) - return f"data:image/png;base64,{encoded_string.decode('utf-8')}" - - @pytest.mark.asyncio -async def test_idefics(idefics, response_snapshot): - chicken = get_chicken() +async def test_idefics(idefics, response_snapshot, chicken): response = await idefics.generate( f"User:![]({chicken})Can you tell me a very short story based on the image?", max_new_tokens=10, @@ -48,9 +33,7 @@ async def test_idefics(idefics, response_snapshot): @pytest.mark.release @pytest.mark.asyncio @pytest.mark.private -async def test_idefics_two_images(idefics, response_snapshot): - chicken = get_chicken() - cow_beach = get_cow_beach() +async def test_idefics_two_images(idefics, response_snapshot, chicken, cow_beach): response 
= await idefics.generate( f"User:![]({chicken})![]({cow_beach})Where are the cow and chicken? \nAssistant:", max_new_tokens=20, @@ -63,8 +46,7 @@ async def test_idefics_two_images(idefics, response_snapshot): @pytest.mark.release @pytest.mark.asyncio -async def test_idefics_load(idefics, generate_load, response_snapshot): - chicken = get_chicken() +async def test_idefics_load(idefics, generate_load, response_snapshot, chicken): responses = await generate_load( idefics, f"User:![]({chicken})Can you tell me a very short story based on the image?", diff --git a/integration-tests/models/test_idefics2.py b/integration-tests/models/test_idefics2.py index c5f48da3525..881e37f9b95 100644 --- a/integration-tests/models/test_idefics2.py +++ b/integration-tests/models/test_idefics2.py @@ -1,18 +1,4 @@ import pytest -import base64 - - -# TODO fix the server parsser to count inline image tokens correctly -def get_chicken(): - with open("integration-tests/images/chicken_on_money.png", "rb") as image_file: - encoded_string = base64.b64encode(image_file.read()) - return f"data:image/png;base64,{encoded_string.decode('utf-8')}" - - -def get_cow_beach(): - with open("integration-tests/images/cow_beach.png", "rb") as image_file: - encoded_string = base64.b64encode(image_file.read()) - return f"data:image/png;base64,{encoded_string.decode('utf-8')}" @pytest.fixture(scope="module") @@ -31,8 +17,9 @@ async def flash_idefics2_next(flash_idefics2_next_handle): @pytest.mark.asyncio @pytest.mark.private -async def test_flash_idefics2_next_simple(flash_idefics2_next, response_snapshot): - chicken = get_chicken() +async def test_flash_idefics2_next_simple( + flash_idefics2_next, response_snapshot, chicken +): response = await flash_idefics2_next.generate( f"User:![]({chicken})Write me a short story \nAssistant:", max_new_tokens=10, @@ -46,9 +33,9 @@ async def test_flash_idefics2_next_simple(flash_idefics2_next, response_snapshot @pytest.mark.asyncio @pytest.mark.private -async def test_flash_idefics2_two_images(flash_idefics2_next, response_snapshot): - chicken = get_chicken() - cow_beach = get_cow_beach() +async def test_flash_idefics2_two_images( + flash_idefics2_next, response_snapshot, chicken, cow_beach +): response = await flash_idefics2_next.generate( f"User:![]({chicken})![]({cow_beach})Where are the cow and chicken? 
\nAssistant:", max_new_tokens=20, @@ -87,9 +74,8 @@ async def test_flash_idefics2_next_all_params(flash_idefics2_next, response_snap @pytest.mark.asyncio @pytest.mark.private async def test_flash_idefics2_next_load( - flash_idefics2_next, generate_load, response_snapshot + flash_idefics2_next, generate_load, response_snapshot, chicken ): - chicken = get_chicken() responses = await generate_load( flash_idefics2_next, f"User:![]({chicken})Write me a short story \nAssistant:", diff --git a/integration-tests/models/test_llava_next.py b/integration-tests/models/test_llava_next.py index ea277d713e0..1ac8f172db7 100644 --- a/integration-tests/models/test_llava_next.py +++ b/integration-tests/models/test_llava_next.py @@ -1,12 +1,4 @@ import pytest -import base64 - - -# TODO fix the server parsser to count inline image tokens correctly -def get_chicken(): - with open("integration-tests/images/chicken_on_money.png", "rb") as image_file: - encoded_string = base64.b64encode(image_file.read()) - return f"data:image/png;base64,{encoded_string.decode('utf-8')}" @pytest.fixture(scope="module") @@ -29,8 +21,7 @@ async def flash_llava_next(flash_llava_next_handle): @pytest.mark.release @pytest.mark.asyncio @pytest.mark.private -async def test_flash_llava_next_simple(flash_llava_next, response_snapshot): - chicken = get_chicken() +async def test_flash_llava_next_simple(flash_llava_next, response_snapshot, chicken): response = await flash_llava_next.generate( f"User:![]({chicken})Can you tell me a very short story based on the image?", max_new_tokens=10, @@ -70,9 +61,8 @@ async def test_flash_llava_next_all_params(flash_llava_next, response_snapshot): @pytest.mark.asyncio @pytest.mark.private async def test_flash_llava_next_load( - flash_llava_next, generate_load, response_snapshot + flash_llava_next, generate_load, response_snapshot, chicken ): - chicken = get_chicken() responses = await generate_load( flash_llava_next, f"User:![]({chicken})Can you tell me a very short story based on the image?", diff --git a/integration-tests/models/test_mllama.py b/integration-tests/models/test_mllama.py index 1b4264aacd1..02781707e05 100644 --- a/integration-tests/models/test_mllama.py +++ b/integration-tests/models/test_mllama.py @@ -1,5 +1,4 @@ import pytest -import base64 import asyncio @@ -15,22 +14,8 @@ async def mllama(mllama_handle): return mllama_handle.client -# TODO fix the server parsser to count inline image tokens correctly -def get_chicken(): - with open("integration-tests/images/chicken_on_money.png", "rb") as image_file: - encoded_string = base64.b64encode(image_file.read()) - return f"data:image/png;base64,{encoded_string.decode('utf-8')}" - - -def get_cow_beach(): - with open("integration-tests/images/cow_beach.png", "rb") as image_file: - encoded_string = base64.b64encode(image_file.read()) - return f"data:image/png;base64,{encoded_string.decode('utf-8')}" - - @pytest.mark.asyncio async def test_mllama_simpl(mllama, response_snapshot): - # chicken = get_chicken() response = await mllama.chat( max_tokens=10, temperature=0.0, diff --git a/proto/v3/generate.proto b/proto/v3/generate.proto index e4dfefefefd..c91e7cc43b2 100644 --- a/proto/v3/generate.proto +++ b/proto/v3/generate.proto @@ -139,12 +139,14 @@ message Request { repeated uint32 slots = 10; /// LORA adapter index optional string adapter_id = 11; - /// Prefix length that can be retrieved from the KV cache. - uint32 prefix_len = 12; + /// Tokens that can be retrieved from the KV cache. 
+ /// This value is set for the first prefill and never reset + uint32 cache_len = 12; /// Context truncation bool add_special_tokens = 13; - /// Postfix length for prefill chunking - uint32 postfix_len = 14; + /// Chunk of tokens that must be computed for the first prefill + /// This value is set for the first prefill and never reset + optional uint32 chunk_len = 14; } message Batch { diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index 7e256dcfc7a..8222722aafe 100644 --- a/server/text_generation_server/models/flash_causal_lm.py +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -280,24 +280,36 @@ def from_tokenized( prompt_length = len(tokenized_input) prompt_lengths.append(prompt_length) - cache_length = r.prefix_len - input_length = r.postfix_len + cache_length = r.cache_len + assert ( cache_length <= prompt_length ), f"Prefix {cache_length} vs input {prompt_length}" if cache_length == prompt_length: assert False, "unreachable" - if cache_length + input_length < prompt_length: - # FIXME: speculate is not supported for context chunking at the moment - assert speculate == 0 - assert get_support_chunking() - assert input_length > 0 - postfix_ids = tokenized_input[cache_length : cache_length + input_length] + # `chunk_len` is an optional field in the protobuf + # It is only set if the model support chunking + if r.HasField("chunk_len"): + input_length = r.chunk_len + + if cache_length + input_length < prompt_length: + # FIXME: speculate is not supported for context chunking at the moment + assert speculate == 0 + assert get_support_chunking() + assert input_length > 0 + + postfix_ids = tokenized_input[ + cache_length : cache_length + input_length + ] + assert ( + len(postfix_ids) == input_length + ), "Rust and Python tokenizers are not aligned" + else: + # Use all the remaining ids + postfix_ids = tokenized_input[cache_length:] + input_length = len(postfix_ids) - assert ( - len(postfix_ids) == input_length - ), "Rust and Python tokenizers are not aligned" input_lengths.append(input_length) prefix_offsets.append(prompt_length - 5) @@ -1097,6 +1109,7 @@ def __init__( head_size: Optional[int] = None, skip_special_tokens: bool = True, kv_cache_dtype: Optional[torch.dtype] = None, + support_chunking: bool = True, ): self.quantize = quantize self.process_group, rank, world_size = initialize_torch_distributed() @@ -1224,7 +1237,7 @@ def __init__( rank=rank, world_size=world_size, sliding_window=config.sliding_window, - support_chunking=True, + support_chunking=support_chunking, ) @property diff --git a/server/text_generation_server/models/mllama_causal_lm.py b/server/text_generation_server/models/mllama_causal_lm.py index 3aa475c3179..83e44039e12 100644 --- a/server/text_generation_server/models/mllama_causal_lm.py +++ b/server/text_generation_server/models/mllama_causal_lm.py @@ -1,14 +1,17 @@ -from io import BytesIO -from PIL import Image import torch + +import numpy as np + from typing import Iterable, Optional, Tuple, List, Dict from text_generation_server.pb.generate_pb2 import Request - +from io import BytesIO +from PIL import Image from dataclasses import dataclass from opentelemetry import trace from transformers import ( PreTrainedTokenizerBase, ) + from text_generation_server.models.vlm_causal_lm import VlmCausalLMBatch, VlmCausalLM from text_generation_server.pb import generate_pb2 from text_generation_server.models.flash_causal_lm import ( @@ -167,6 +170,13 @@ def from_pb_processor( 
batch.all_input_ids_tensor = batch.all_input_ids_tensor.clamp( max=config.text_config.vocab_size - 1 ) + if isinstance(batch.input_ids, list): + if len(batch) > 1: + input_ids = np.concatenate(batch.input_ids, dtype=np.int64) + else: + input_ids = batch.input_ids[0] + batch.input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device) + batch.input_ids = batch.input_ids.clamp(max=config.text_config.vocab_size - 1) if image_inputs is not None: @@ -190,7 +200,7 @@ def from_pb_processor( class MllamaCausalLM(VlmCausalLM): def forward( self, - batch: VlmCausalLMBatch, + batch: MllamaCausalLMBatch, adapter_data: Optional[Dict[str, torch.Tensor]] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: # Model Forward @@ -202,7 +212,7 @@ def forward( block_tables = batch.block_tables_tensor slots = batch.slots[batch.slot_indices] input_lengths = batch.input_lengths_tensor - max_s = batch.max_seqlen + max_s = batch.max_current_length lm_head_indices = batch.prefill_head_indices speculative_ids = batch.speculative_ids @@ -221,8 +231,8 @@ def forward( input_lengths = ( input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int ).view(-1) - prefix_lens_tensor = ( - batch.prefix_lens_tensor.unsqueeze(-1).expand(B, new_length) + cache_lengths_tensor = ( + batch.cache_lengths_tensor.unsqueeze(-1).expand(B, new_length) ).reshape(-1) # Add Copy the block tables for all members @@ -244,8 +254,8 @@ def forward( block_tables = batch.block_tables_tensor slots = batch.slots[batch.slot_indices] input_lengths = batch.input_lengths_tensor - prefix_lens_tensor = batch.prefix_lens_tensor - max_s = batch.max_seqlen + cache_lengths_tensor = batch.cache_lengths_tensor + max_s = batch.max_current_length lm_head_indices = batch.prefill_head_indices if cu_seqlen_prefill is None and self.max_past() is not None: @@ -254,7 +264,6 @@ def forward( # This makes sure the max_s for the decode pass is correct. max_s = min(self.max_past(), max_s) - bs = input_ids.shape[0] # Try to find an associated cuda graph bs = input_ids.shape[0] sorted_padded_bs = sorted([k for k in self.cuda_graphs.keys() if k >= bs]) @@ -269,38 +278,25 @@ def forward( # Only run cuda graphs when there's no images. 
or batch.cross_attention_states is not None ): - input_lengths = input_lengths + prefix_lens_tensor if PREFIX_CACHING: block_tables = block_tables_to_ragged( block_tables=block_tables, input_lengths=batch.input_lengths, - prefix_lens=batch.prefix_lens, + cache_lengths=batch.cache_lengths, ) with self._forward_context( block_tables=block_tables, cu_seqlen_prefill=cu_seqlen_prefill, input_lengths_tensor=input_lengths, - prefix_lens_tensor=prefix_lens_tensor, + cache_lengths_tensor=cache_lengths_tensor, ): - max_k = (input_lengths + prefix_lens_tensor).max().item() seqlen = Seqlen( input_lengths=input_lengths, - cache_lengths=prefix_lens_tensor, + cache_lengths=cache_lengths_tensor, cu_seqlen_q=cu_seqlen_prefill, - max_q=max_s, - max_k=max_k, + max_q=batch.max_input_length, + max_k=batch.max_current_length, ) - - if batch.pixel_values is not None: - cross_attention_states = self.model.vision_forward( - pixel_values=batch.pixel_values, - aspect_ratio_ids=batch.aspect_ratio_ids, - aspect_ratio_mask=batch.aspect_ratio_mask, - ) - batch.cross_attention_states = cross_attention_states - - cross_attention_states = batch.cross_attention_states - logits, speculative_logits = self.model.forward( input_ids=input_ids, position_ids=position_ids, @@ -312,14 +308,18 @@ def forward( max_s=max_s, prefill_cache_indices=batch.prefill_cache_indices, lm_head_indices=lm_head_indices, - cross_attention_states=cross_attention_states, - adapter_data=adapter_data, - image_indices=batch.image_indices[:], + pixel_values=batch.pixel_values, + pixel_attention_mask=batch.pixel_attention_mask, + image_sizes=batch.image_sizes, ) if batch.prefill_cache_indices is not None: batch.prefill_cache_indices = None if batch.pixel_values is not None: batch.pixel_values = None + if batch.pixel_attention_mask is not None: + batch.pixel_attention_mask = None + if batch.image_sizes is not None: + batch.image_sizes = None return logits, speculative_logits # Copy inputs to the static inputs of the cuda graph @@ -330,22 +330,34 @@ def forward( block_tables = block_tables_to_ragged( block_tables=block_tables, input_lengths=batch.input_lengths, - prefix_lens=batch.prefix_lens, + cache_lengths=batch.cache_lengths, ) cuda_graph["block_tables"][: block_tables.shape[0]] = block_tables else: cuda_graph["block_tables"][ : block_tables.shape[0], : block_tables.shape[1] ] = block_tables + + # XXX: This is working only because block 0 is reserved for the healthcheck + # so it doesn't matter if we override it with bogus values. 
cuda_graph["slots"].fill_(0) cuda_graph["slots"][: slots.shape[0]] = slots cuda_graph["input_lengths"].zero_() - cuda_graph["input_lengths"][: input_lengths.shape[0]] = ( - input_lengths + prefix_lens_tensor - ) + cuda_graph["input_lengths"][: input_lengths.shape[0]] = input_lengths + cuda_graph["cache_lengths"].zero_() + cuda_graph["cache_lengths"][ + : cache_lengths_tensor.shape[0] + ] = cache_lengths_tensor - # Replay the graph - cuda_graph["graph"].replay() + with self._forward_context( + block_tables=cuda_graph["block_tables"], + cu_seqlen_prefill=None, + input_lengths_tensor=cuda_graph["input_lengths"], + cache_lengths_tensor=cuda_graph["cache_lengths"], + state=cuda_graph["state"], + ): + # Replay the graph + cuda_graph["graph"].replay() # Slice output to the correct shape speculative_logits = ( diff --git a/server/text_generation_server/models/vlm_causal_lm.py b/server/text_generation_server/models/vlm_causal_lm.py index a06add13f6e..150cf0d07d7 100644 --- a/server/text_generation_server/models/vlm_causal_lm.py +++ b/server/text_generation_server/models/vlm_causal_lm.py @@ -271,6 +271,8 @@ def __init__( model_id=model_id, revision=revision, trust_remote_code=trust_remote_code, + # FIXME: VLM do not work with context chunking yet + support_chunking=False, **kwargs, ) @@ -356,7 +358,7 @@ def forward( else: cuda_graph = None if cu_seqlen_prefill is not None or cuda_graph is None: - if PREFIX_CACHING: + if ATTENTION == "flashinfer": block_tables = block_tables_to_ragged( block_tables=block_tables, input_lengths=batch.input_lengths, @@ -368,13 +370,12 @@ def forward( input_lengths_tensor=input_lengths, cache_lengths_tensor=cache_lengths_tensor, ): - max_k = (input_lengths + cache_lengths_tensor).max().item() seqlen = Seqlen( input_lengths=input_lengths, cache_lengths=cache_lengths_tensor, cu_seqlen_q=cu_seqlen_prefill, - max_q=max_s, - max_k=max_k, + max_q=batch.max_input_length, + max_k=batch.max_current_length, ) logits, speculative_logits = self.model.forward( input_ids=input_ids, @@ -416,7 +417,10 @@ def forward( cuda_graph["block_tables"][ : block_tables.shape[0], : block_tables.shape[1] ] = block_tables - cuda_graph["slots"].fill_(-1) + + # XXX: This is working only because block 0 is reserved for the healthcheck + # so it doesn't matter if we override it with bogus values. 
+ cuda_graph["slots"].fill_(0) cuda_graph["slots"][: slots.shape[0]] = slots cuda_graph["input_lengths"].zero_() cuda_graph["input_lengths"][: input_lengths.shape[0]] = input_lengths From f923a3fb6823069a984eda07f1e086f83363ce07 Mon Sep 17 00:00:00 2001 From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com> Date: Thu, 10 Oct 2024 16:01:18 +0200 Subject: [PATCH 27/29] fix mllama --- backends/v3/src/queue.rs | 2 +- .../models/mllama_causal_lm.py | 21 ++++++++++++------- 2 files changed, 15 insertions(+), 8 deletions(-) diff --git a/backends/v3/src/queue.rs b/backends/v3/src/queue.rs index 36fbed87c18..414045a1a8c 100644 --- a/backends/v3/src/queue.rs +++ b/backends/v3/src/queue.rs @@ -546,7 +546,7 @@ mod tests { request: ValidGenerateRequest { inputs: vec![], input_ids: Some(Arc::new(vec![])), - input_length: 0, + input_length: 1, add_special_tokens: true, truncate: 0, decoder_input_details: false, diff --git a/server/text_generation_server/models/mllama_causal_lm.py b/server/text_generation_server/models/mllama_causal_lm.py index 83e44039e12..6399f92c14c 100644 --- a/server/text_generation_server/models/mllama_causal_lm.py +++ b/server/text_generation_server/models/mllama_causal_lm.py @@ -297,6 +297,17 @@ def forward( max_q=batch.max_input_length, max_k=batch.max_current_length, ) + + if batch.pixel_values is not None: + cross_attention_states = self.model.vision_forward( + pixel_values=batch.pixel_values, + aspect_ratio_ids=batch.aspect_ratio_ids, + aspect_ratio_mask=batch.aspect_ratio_mask, + ) + batch.cross_attention_states = cross_attention_states + + cross_attention_states = batch.cross_attention_states + logits, speculative_logits = self.model.forward( input_ids=input_ids, position_ids=position_ids, @@ -308,18 +319,14 @@ def forward( max_s=max_s, prefill_cache_indices=batch.prefill_cache_indices, lm_head_indices=lm_head_indices, - pixel_values=batch.pixel_values, - pixel_attention_mask=batch.pixel_attention_mask, - image_sizes=batch.image_sizes, + cross_attention_states=cross_attention_states, + adapter_data=adapter_data, + image_indices=batch.image_indices[:], ) if batch.prefill_cache_indices is not None: batch.prefill_cache_indices = None if batch.pixel_values is not None: batch.pixel_values = None - if batch.pixel_attention_mask is not None: - batch.pixel_attention_mask = None - if batch.image_sizes is not None: - batch.image_sizes = None return logits, speculative_logits # Copy inputs to the static inputs of the cuda graph From df982999198991ad972f415d0e3d073c0f5d0fd3 Mon Sep 17 00:00:00 2001 From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com> Date: Thu, 10 Oct 2024 16:54:42 +0200 Subject: [PATCH 28/29] fix cargo tests --- backends/v3/src/queue.rs | 24 ++++++++++++++++-------- 1 file changed, 16 insertions(+), 8 deletions(-) diff --git a/backends/v3/src/queue.rs b/backends/v3/src/queue.rs index 414045a1a8c..6662b8de1f9 100644 --- a/backends/v3/src/queue.rs +++ b/backends/v3/src/queue.rs @@ -339,15 +339,23 @@ impl State { let postfix_len = entry.request.input_length - block_allocation.prefix_len; - // Check equality too as if we don't we might end up with a postfix_len = 0 - // in the next iteration of the loop - if prefill_tokens + postfix_len >= prefill_token_budget { + if prefill_tokens + postfix_len > prefill_token_budget { // Entry is over budget if self.support_chunking { // We support chunking, just set postfix_len to exactly match prefill_token_budget - let chunk_len = prefill_token_budget - prefill_tokens; - // Push this entry inside 
the batch - batch.push((id, entry, Some(block_allocation), Some(chunk_len))); + let chunk_len = prefill_token_budget.saturating_sub(prefill_tokens); + if chunk_len > 0 { + // Push this entry inside the batch + batch.push((id, entry, Some(block_allocation), Some(chunk_len))); + } else { + // We cannot prefill even one token for this entry + // Add it back to the queue + self.entries.push_front((id, entry)); + } + tracing::debug!( + "Matched budget: prefill_tokens={} == {prefill_token_budget}", + prefill_tokens + postfix_len + ); break 'entry_loop; } else { // We don't support chunking, this entry needs to go back to the buffer @@ -658,7 +666,7 @@ mod tests { #[tokio::test] async fn test_next_batch_token_budget() { - let mut state = State::new(false, 1, false, None, 0, 2, false); + let mut state = State::new(false, 1, false, None, 0, 16, false); let (entry1, _guard1) = default_entry(); let (entry2, _guard2) = default_entry(); state.append(entry1); @@ -780,7 +788,7 @@ mod tests { #[tokio::test] async fn test_queue_next_batch_token_speculate() { - let queue = Queue::new(false, 1, false, None, 2, 16, false); + let queue = Queue::new(true, 1, false, None, 2, 16, false); let (entry1, _guard1) = default_entry(); let (entry2, _guard2) = default_entry(); queue.append(entry1); From 5e70158b2c7bef801933cca487424aa3d174d3a5 Mon Sep 17 00:00:00 2001 From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com> Date: Fri, 11 Oct 2024 15:19:14 +0200 Subject: [PATCH 29/29] remove support chunking for paged --- server/text_generation_server/models/model.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/server/text_generation_server/models/model.py b/server/text_generation_server/models/model.py index 05d36ba315b..1da6e3e34b4 100644 --- a/server/text_generation_server/models/model.py +++ b/server/text_generation_server/models/model.py @@ -72,6 +72,12 @@ def __init__( "Prefill chunking will be turned off", ) support_chunking = False + if ATTENTION not in ["flashinfer", "flashdecoding"] and support_chunking: + log_master( + logger.warning, + "Prefill chunking is only supported with `flashinfer` or `flashdecoding` attention types.", + ) + support_chunking = False self.support_chunking = support_chunking set_support_chunking(support_chunking)
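
A short standalone sketch of the scheduling rule that the queue changes above converge on may help when reviewing. This is an illustration only, written in Python with hypothetical names; the actual logic lives in backends/v3/src/queue.rs and additionally handles block allocation and requeueing:

def tokens_to_prefill_now(
    prompt_length: int,
    cache_len: int,
    prefill_tokens_in_batch: int,
    prefill_token_budget: int,
    support_chunking: bool,
) -> int:
    """How many prompt tokens this request should prefill in the current batch.

    Returns 0 if the request cannot be scheduled in this batch.
    """
    # Prompt tokens that are not yet in the KV cache
    remaining = prompt_length - cache_len
    if prefill_tokens_in_batch + remaining <= prefill_token_budget:
        # Everything fits; the real scheduler leaves `chunk_len` unset in this case
        return remaining
    if not support_chunking:
        # Without chunking, the request has to wait for a batch with enough budget
        return 0
    # With chunking, fill the batch exactly up to the budget (0 means "try again later")
    return max(prefill_token_budget - prefill_tokens_in_batch, 0)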
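
The prefill-logprob alignment settled on in patches 21 and 24 can likewise be sketched in isolation. Again this is only an illustration with hypothetical names: the server additionally prepends NaN entries for tokens that were already in the KV cache (including the very first prompt token, which has no logprob), which is omitted here.

from typing import List, Tuple


def align_prefill_logprobs(
    all_input_ids: List[int],
    chunk_logprobs: List[float],
    cache_length: int,
    request_is_prefilling: bool,
) -> Tuple[List[int], List[float]]:
    # The model scores the *next* token at every position, so logprob k of this
    # chunk belongs to prompt token cache_length + 1 + k.
    if not request_is_prefilling:
        # The last logprob scores the first generated token, not a prompt token
        chunk_logprobs = chunk_logprobs[:-1]
    token_ids = all_input_ids[cache_length + 1 : cache_length + 1 + len(chunk_logprobs)]
    return token_ids, chunk_logprobs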