From 9c11f1eb2fa369fb89064af9485d7219fd761a31 Mon Sep 17 00:00:00 2001
From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com>
Date: Wed, 9 Oct 2024 17:33:15 +0200
Subject: [PATCH] fix logprobs?

---
 .../models/flash_causal_lm.py | 57 ++++++++-----------
 1 file changed, 24 insertions(+), 33 deletions(-)

diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index 8d33a2b3f54..b7202c04ff7 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -956,7 +956,11 @@ def prepare_for_prefill(self):
             no_prefill_logprobs = no_prefill_logprobs and not prefill_logprobs
 
             if prefill_logprobs:
-                prefill_head_indices.append(request_position_ids + cumulative_length)
+                prefill_head_indices.append(torch.arange(
+                    cumulative_length,
+                    cumulative_length + input_length,
+                    dtype=torch.int64
+                ))
                 prefill_next_token_indices.append(
                     prefill_out_cumulative_length + input_length - 1
                 )
@@ -966,7 +970,7 @@ def prepare_for_prefill(self):
                 prefill_head_indices.append(
                     torch.tensor(
                         [cumulative_length + input_length - 1],
-                        dtype=torch.int32,
+                        dtype=torch.int64,
                     )
                 )
                 prefill_next_token_indices.append(prefill_out_cumulative_length)
@@ -1029,9 +1033,7 @@ def prepare_for_prefill(self):
             prefill_head_indices = cu_seqlen_prefill[1:] - 1
             prefill_next_token_indices = None
         else:
-            prefill_head_indices = torch.tensor(
-                torch.cat(prefill_head_indices), dtype=torch.int64, device=device
-            )
+            prefill_head_indices = torch.cat(prefill_head_indices).to(device)
             prefill_next_token_indices = torch.tensor(
                 prefill_next_token_indices, dtype=torch.int64, device=device
             )
@@ -1822,6 +1824,7 @@ def generate_token(
 
         # Zipped iterator
         iterator = zip(
+            batch.requests,
             batch.prompt_lengths,
             batch.cache_lengths,
             batch.input_lengths,
@@ -1840,6 +1843,7 @@ def generate_token(
         # Cumulative length
         cumulative_length = 0
         for i, (
+            request,
             prompt_length,
            cache_length,
             input_length,
@@ -1849,7 +1853,7 @@ def generate_token(
             request_is_prefilling,
         ) in enumerate(iterator):
             # Indexing metadata
-            start_index = cumulative_length
+            _start_index = cumulative_length
             end_index = cumulative_length + input_length
 
             if prefill:
@@ -1869,25 +1873,16 @@ def generate_token(
                 ]
 
                 # Used to gather prefill logprobs
-                # Copy batch.input_ids to prefill_token_indices
-                if prefill_logprobs:
-                    # If the request was prefilling and cache_length == 0, the first token is a bogus token
-                    # and needs to be removed. We do so by incrementing the start_index
-                    if request_was_prefilling and cache_length == 0:
-                        start_index += 1
-
-                    # If the request was prefilling, and it is done prefilling, the last token was generated and is
-                    # therefore not part of the prefill. We remove it by decrementing out_end_index
-                    if request_was_prefilling and not request_is_prefilling:
-                        out_end_index -= 1
-
+                # Copy batch.all_input_ids_tensor to prefill_token_indices
+                if request.prefill_logprobs and request_was_prefilling:
+                    # Logprobs generated by the model are for the next token
+                    # So we need to translate the id tensor by 1
+                    ids = batch.all_input_ids_tensor[i, cache_length + 1: cache_length + input_length + 1]
                     if len(batch) > 1:
-                        prefill_tokens_indices[out_start_index:out_end_index] = (
-                            batch.input_ids[start_index:end_index]
-                        )
+                        prefill_tokens_indices[out_start_index : out_end_index] = ids
                     else:
                         # Set prefill_tokens_indices to the correct slice
-                        prefill_tokens_indices = batch.input_ids[start_index:end_index]
+                        prefill_tokens_indices = ids
 
             if not request_is_prefilling:
                 # Only save tokens if we are done prefilling for this request
@@ -2031,30 +2026,30 @@ def generate_token(
                 if request_was_prefilling and request.prefill_logprobs:
                     out_start_index = batch.prefill_cu_outlens[i]
                     out_end_index = batch.prefill_cu_outlens[i + 1]
-
-                    # log_master(logger.info, f"{prefill_logprobs}")
-
                     if not request_is_prefilling:
-                        # If the request is done prefilling, then the last logprob is a generated token
+                        # The request is done prefilling, meaning that we started generating new tokens
+                        # The last logprob is a logprob for a generated token that was not part of the prompt
                         # We need to remove it
                         out_end_index -= 1
 
                     request_prefill_logprobs = prefill_logprobs[
                         out_start_index:out_end_index
                     ]
+                    # Logprobs generated by the model are for the next token
+                    # So we need to translate the id tensor by 1
                     prefill_token_ids = all_input_ids[
-                        cache_length : cache_length + input_length
+                        cache_length + 1 : cache_length + input_length + 1
                     ]
 
                     past_prefill_logprob_tokens = batch.prefill_logprob_tokens[i]
 
                     if past_prefill_logprob_tokens is None:
-                        # Remove generated token to only have prefill and add nan for first prompt token
+                        # add nan for cached prompt tokens/first token
                        request_prefill_logprobs = [float("nan")] * (
                             cache_length + 1
                         ) + request_prefill_logprobs
                         prefill_token_ids = (
-                            all_input_ids[:cache_length] + prefill_token_ids
+                            all_input_ids[:cache_length + 1] + prefill_token_ids
                         )
 
                     prefill_texts = self.tokenizer.batch_decode(
@@ -2063,10 +2058,6 @@ def generate_token(
                         skip_special_tokens=False,
                     )
 
-                    # log_master(logger.info, f"{prefill_token_ids}")
-                    # log_master(logger.info, f"{request_prefill_logprobs}")
-                    # log_master(logger.info, f"{prefill_texts}")
-
                     prefill_logprob_tokens = Tokens(
                         prefill_token_ids,
                         request_prefill_logprobs,
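Note on the indexing change (not part of the patch): the logprob the model emits at prefill position t scores the token at position t + 1, which is why the token ids paired with prefill logprobs are now sliced as cache_length + 1 : cache_length + input_length + 1, and why cache_length + 1 NaN entries pad the positions that never receive a logprob. The sketch below is a minimal, self-contained illustration of that alignment with made-up tensors; the variable names mirror the patch, but nothing here is the real FlashCausalLMBatch state.

import torch

torch.manual_seed(0)

vocab_size = 8
cache_length = 2   # prompt tokens already in the KV cache for this request
input_length = 3   # prompt tokens processed by this prefill chunk

# Full prompt ids, long enough to contain the "next token" of the last chunk position.
all_input_ids = torch.randint(vocab_size, (cache_length + input_length + 1,))

# One row of vocab logprobs per position of the chunk; row t scores the token at t + 1.
chunk_logprobs = torch.log_softmax(torch.randn(input_length, vocab_size), dim=-1)

# Shift by one, mirroring all_input_ids[cache_length + 1 : cache_length + input_length + 1].
target_ids = all_input_ids[cache_length + 1 : cache_length + input_length + 1]
request_prefill_logprobs = (
    chunk_logprobs.gather(1, target_ids.unsqueeze(1)).squeeze(1).tolist()
)

# Positions with no logprob (cached tokens plus the very first prompt token) get NaN,
# so ids and logprobs stay the same length.
request_prefill_logprobs = [float("nan")] * (cache_length + 1) + request_prefill_logprobs
prefill_token_ids = all_input_ids[: cache_length + input_length + 1].tolist()
assert len(prefill_token_ids) == len(request_prefill_logprobs)

In the patch itself the same one-token shift appears twice: once when gathering ids into prefill_tokens_indices via batch.all_input_ids_tensor[i, cache_length + 1 : ...], and once when slicing prefill_token_ids from all_input_ids before decoding them into the returned Tokens.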