fix prefill logprobs
OlivierDehaene committed Oct 7, 2024
1 parent 09f5cb1 commit 9c6aed1
Showing 2 changed files with 7 additions and 4 deletions.
9 changes: 6 additions & 3 deletions server/text_generation_server/models/flash_causal_lm.py
@@ -1757,6 +1757,7 @@ def generate_token(

         finished_prefilling = True
         next_chunk_lengths = []
+        current_prefilling_mask = batch.prefilling_mask
         if prefill:
             if get_support_chunking():
                 next_prefilling_mask = []
@@ -1998,6 +1999,7 @@ def generate_token(
             batch.next_token_chooser.do_sample,
             batch.next_token_chooser.seeds,
             batch.top_n_tokens,
+            current_prefilling_mask,
             batch.prefilling_mask,
             accepted_ids,
             batch_top_token_ids,
@@ -2021,7 +2023,8 @@ def generate_token(
             do_sample,
             seed,
             top_n_tokens,
-            request_prefilling,
+            request_was_prefilling,
+            request_is_prefilling,
             n_accepted_ids,
             top_token_ids,
             top_token_logprobs,
@@ -2032,7 +2035,7 @@ def generate_token(
             # this state to be stable
             if request.id % self.world_size == self.rank:
                 # Prefill
-                if request_prefilling and request.prefill_logprobs:
+                if request_was_prefilling and request.prefill_logprobs:
                     out_start_index = batch.prefill_cu_outlens[i]
                     out_end_index = batch.prefill_cu_outlens[i + 1]

Expand Down Expand Up @@ -2072,7 +2075,7 @@ def generate_token(
batch.prefill_logprob_tokens[i] = None

# If it is, the tokens we decoded should be ignored
if request_prefilling:
if request_is_prefilling:
# Make sure that we do not stop as even though this request did not create a token, it is still
# processing
stopped = False
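
Why two masks: batch.prefilling_mask is mutated inside the same generate_token pass that later iterates over requests, so the per-request loop needs both the mask as it was before the pass (request_was_prefilling, which gates returning prefill logprobs) and as it is after the pass (request_is_prefilling, which gates ignoring the decoded token and not stopping). A minimal, self-contained sketch of that distinction, using toy names rather than TGI's real classes:

from dataclasses import dataclass, field
from typing import List


@dataclass
class ToyBatch:
    # True while a request still has prompt chunks left to prefill.
    prefilling_mask: List[bool] = field(default_factory=list)
    remaining_prefill_chunks: List[int] = field(default_factory=list)


def generate_token(batch: ToyBatch) -> List[str]:
    # Snapshot the mask before mutating it (mirrors `current_prefilling_mask`).
    was_prefilling = list(batch.prefilling_mask)

    # Simulate consuming one prefill chunk per request and updating the mask.
    for i, chunks in enumerate(batch.remaining_prefill_chunks):
        if batch.prefilling_mask[i]:
            batch.remaining_prefill_chunks[i] = chunks - 1
            batch.prefilling_mask[i] = batch.remaining_prefill_chunks[i] > 0
    is_prefilling = batch.prefilling_mask

    actions = []
    for was, still in zip(was_prefilling, is_prefilling):
        if was:
            # This pass processed prompt tokens, so prefill logprobs exist.
            actions.append("emit prefill logprobs")
        if still:
            # More prompt chunks remain: ignore the decoded token, do not stop.
            actions.append("ignore decoded token, keep prefilling")
        else:
            actions.append("decode normally")
    return actions


batch = ToyBatch(prefilling_mask=[True, True], remaining_prefill_chunks=[1, 2])
print(generate_token(batch))
# ['emit prefill logprobs', 'decode normally',
#  'emit prefill logprobs', 'ignore decoded token, keep prefilling']
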
2 changes: 1 addition & 1 deletion server/text_generation_server/server.py
@@ -165,7 +165,7 @@ async def Prefill(self, request, context):
                         f"Batch ID {request.cached_batch.id} not found in cache."
                     )
                 start_concat = time.time_ns()
-                batch = self.model.batch_type.concatenate([batch, cached_batch])
+                batch = self.model.batch_type.concatenate([cached_batch, batch])
                 concat_ns = time.time_ns() - start_concat

         generations, next_batch, timings = self.model.generate_token(batch)
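
The argument order passed to concatenate fixes the per-request order of the merged batch: with the swap, requests from the cached batch come first and requests from the new prefill batch come last, and any positional bookkeeping (masks, cumulative offsets) inherits that order. A toy sketch of the mechanism, using a hypothetical ToyBatch rather than the real FlashCausalLMBatch.concatenate:

from dataclasses import dataclass
from typing import List


@dataclass
class ToyBatch:
    request_ids: List[int]
    prefilling_mask: List[bool]

    @classmethod
    def concatenate(cls, batches: List["ToyBatch"]) -> "ToyBatch":
        # Requests keep the order of the batches they came from.
        return cls(
            request_ids=[rid for b in batches for rid in b.request_ids],
            prefilling_mask=[p for b in batches for p in b.prefilling_mask],
        )


cached_batch = ToyBatch(request_ids=[0, 1], prefilling_mask=[False, False])
batch = ToyBatch(request_ids=[2], prefilling_mask=[True])

merged = ToyBatch.concatenate([cached_batch, batch])
print(merged.request_ids)      # [0, 1, 2] -- cached requests first, new prefill last
print(merged.prefilling_mask)  # [False, False, True]
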
