Commit

fix logprobs?
OlivierDehaene committed Oct 9, 2024
1 parent 9290c6e commit 9c11f1e
Showing 1 changed file with 24 additions and 33 deletions.
57 changes: 24 additions & 33 deletions server/text_generation_server/models/flash_causal_lm.py
@@ -956,7 +956,11 @@ def prepare_for_prefill(self):
no_prefill_logprobs = no_prefill_logprobs and not prefill_logprobs

if prefill_logprobs:
prefill_head_indices.append(request_position_ids + cumulative_length)
prefill_head_indices.append(torch.arange(
cumulative_length,
cumulative_length + input_length,
dtype=torch.int64
))
prefill_next_token_indices.append(
prefill_out_cumulative_length + input_length - 1
)
@@ -966,7 +970,7 @@ def prepare_for_prefill(self):
prefill_head_indices.append(
torch.tensor(
[cumulative_length + input_length - 1],
dtype=torch.int32,
dtype=torch.int64,
)
)
prefill_next_token_indices.append(prefill_out_cumulative_length)
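
The two hunks above switch the per-request head indices to an explicit torch.arange over the request's span in the flattened batch and make both branches produce int64 indices. A minimal standalone sketch of that construction, with made-up lengths and flags rather than the repository's actual batch structures:

import torch

# Hypothetical flattened batch: request A has 4 prompt tokens and wants
# prefill logprobs, request B has 3 tokens and does not.
input_lengths = [4, 3]
wants_prefill_logprobs = [True, False]

prefill_head_indices = []
cumulative_length = 0
for input_length, wants in zip(input_lengths, wants_prefill_logprobs):
    if wants:
        # Keep every prompt position of this request in the flattened batch.
        prefill_head_indices.append(
            torch.arange(
                cumulative_length,
                cumulative_length + input_length,
                dtype=torch.int64,
            )
        )
    else:
        # Only the last position is needed to sample the next token.
        prefill_head_indices.append(
            torch.tensor([cumulative_length + input_length - 1], dtype=torch.int64)
        )
    cumulative_length += input_length

print(torch.cat(prefill_head_indices))  # tensor([0, 1, 2, 3, 6])
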
@@ -1029,9 +1033,7 @@ def prepare_for_prefill(self):
prefill_head_indices = cu_seqlen_prefill[1:] - 1
prefill_next_token_indices = None
else:
prefill_head_indices = torch.tensor(
torch.cat(prefill_head_indices), dtype=torch.int64, device=device
)
prefill_head_indices = torch.cat(prefill_head_indices).to(device)
prefill_next_token_indices = torch.tensor(
prefill_next_token_indices, dtype=torch.int64, device=device
)
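
The hunk above also drops the torch.tensor(...) wrapper around an already-concatenated tensor: since every piece is now created as int64, a plain torch.cat(...).to(device) is enough, and wrapping an existing tensor in torch.tensor copy-constructs it (recent PyTorch emits a warning for that pattern). A tiny sketch, with an illustrative device string:

import torch

# Both pieces are already int64, so concatenation preserves the dtype and only
# a device move is needed.
pieces = [torch.arange(0, 4, dtype=torch.int64), torch.tensor([6], dtype=torch.int64)]

head_indices = torch.cat(pieces).to("cpu")  # was: torch.tensor(torch.cat(pieces), ...)
assert head_indices.dtype == torch.int64
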
@@ -1822,6 +1824,7 @@ def generate_token(

# Zipped iterator
iterator = zip(
batch.requests,
batch.prompt_lengths,
batch.cache_lengths,
batch.input_lengths,
@@ -1840,6 +1843,7 @@ def generate_token(
# Cumulative length
cumulative_length = 0
for i, (
request,
prompt_length,
cache_length,
input_length,
@@ -1849,7 +1853,7 @@ def generate_token(
request_is_prefilling,
) in enumerate(iterator):
# Indexing metadata
start_index = cumulative_length
_start_index = cumulative_length
end_index = cumulative_length + input_length

if prefill:
@@ -1869,25 +1873,16 @@ def generate_token(
]

# Used to gather prefill logprobs
# Copy batch.input_ids to prefill_token_indices
if prefill_logprobs:
# If the request was prefilling and cache_length == 0, the first token is a bogus token
# and needs to be removed. We do so by incrementing the start_index
if request_was_prefilling and cache_length == 0:
start_index += 1

# If the request was prefilling, and it is done prefilling, the last token was generated and is
# therefore not part of the prefill. We remove it by decrementing out_end_index
if request_was_prefilling and not request_is_prefilling:
out_end_index -= 1

# Copy batch.all_input_ids_tensor to prefill_token_indices
if request.prefill_logprobs and request_was_prefilling:
# Logprobs generated by the model are for the next token
# So we need to translate the id tensor by 1
ids = batch.all_input_ids_tensor[i, cache_length + 1: cache_length + input_length + 1]
if len(batch) > 1:
prefill_tokens_indices[out_start_index:out_end_index] = (
batch.input_ids[start_index:end_index]
)
prefill_tokens_indices[out_start_index : out_end_index] = ids
else:
# Set prefill_tokens_indices to the correct slice
prefill_tokens_indices = batch.input_ids[start_index:end_index]
prefill_tokens_indices = ids

if not request_is_prefilling:
# Only save tokens if we are done prefilling for this request
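
The hunk above now gathers the ids used for prefill logprobs from batch.all_input_ids_tensor, shifted by one position, instead of from batch.input_ids: the logit produced at position i scores token i + 1. A toy illustration with made-up token ids:

import torch

# Toy request: 5 prompt tokens; the first sampled token (999) has already been
# appended to the running id tensor, as in all_input_ids_tensor.
prompt_ids = [101, 7592, 2088, 2003, 102]
all_ids = torch.tensor(prompt_ids + [999])

cache_length, input_length = 0, len(prompt_ids)

# The logit at position i scores token i + 1, so the ids paired with the
# prefill logprobs are shifted one to the right.
ids = all_ids[cache_length + 1 : cache_length + input_length + 1]
print(ids)  # tensor([7592, 2088, 2003,  102,  999])
# The trailing 999 is a generated token, not part of the prompt; that is why
# out_end_index is decremented once the request has finished prefilling.
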
@@ -2031,30 +2026,30 @@ def generate_token(
if request_was_prefilling and request.prefill_logprobs:
out_start_index = batch.prefill_cu_outlens[i]
out_end_index = batch.prefill_cu_outlens[i + 1]

# log_master(logger.info, f"{prefill_logprobs}")

if not request_is_prefilling:
# If the request is done prefilling, then the last logprob is a generated token
# The request is done prefilling, meaning that we started generating new tokens
# The last logprob is a logprob for a generated token that was not part of the prompt
# We need to remove it
out_end_index -= 1

request_prefill_logprobs = prefill_logprobs[
out_start_index:out_end_index
]
# Logprobs generated by the model are for the next token
# So we need to translate the id tensor by 1
prefill_token_ids = all_input_ids[
cache_length : cache_length + input_length
cache_length + 1 : cache_length + input_length + 1
]

past_prefill_logprob_tokens = batch.prefill_logprob_tokens[i]

if past_prefill_logprob_tokens is None:
# Remove generated token to only have prefill and add nan for first prompt token
# add nan for cached prompt tokens/first token
request_prefill_logprobs = [float("nan")] * (
cache_length + 1
) + request_prefill_logprobs
prefill_token_ids = (
all_input_ids[:cache_length] + prefill_token_ids
all_input_ids[:cache_length + 1] + prefill_token_ids
)

prefill_texts = self.tokenizer.batch_decode(
@@ -2063,10 +2058,6 @@ def generate_token(
skip_special_tokens=False,
)

# log_master(logger.info, f"{prefill_token_ids}")
# log_master(logger.info, f"{request_prefill_logprobs}")
# log_master(logger.info, f"{prefill_texts}")

prefill_logprob_tokens = Tokens(
prefill_token_ids,
request_prefill_logprobs,
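
Stepping back from the last hunk: the NaN-padding branch keeps the prompt tokens and their logprobs aligned when a prefix of the prompt was never scored, either because it was served from the cache or because it is the very first token, which nothing precedes. A toy walk-through with made-up ids and logprob values:

# Hypothetical request: 5-token prompt whose first `cache_length` tokens came
# from the prefix cache, so the forward pass only ran positions 2..4.
cache_length, input_length = 2, 3
prompt_ids = [11, 22, 33, 44, 55]

# Position i scores token i + 1, so only tokens 44 and 55 get real prompt
# logprobs (the last logit scores a generated token and is dropped).
request_prefill_logprobs = [-0.7, -1.3]
prefill_token_ids = prompt_ids[cache_length + 1 : cache_length + input_length + 1]

# Cached tokens and the very first token have no logprob: pad with NaN.
request_prefill_logprobs = [float("nan")] * (cache_length + 1) + request_prefill_logprobs
prefill_token_ids = prompt_ids[: cache_length + 1] + prefill_token_ids

assert len(prefill_token_ids) == len(request_prefill_logprobs) == len(prompt_ids)
print(prefill_token_ids)         # [11, 22, 33, 44, 55]
print(request_prefill_logprobs)  # [nan, nan, nan, -0.7, -1.3]
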
