Commit
idk at this point
OlivierDehaene committed Oct 9, 2024
1 parent 9c11f1e commit 1b43d2d
Showing 1 changed file with 30 additions and 33 deletions.
server/text_generation_server/models/flash_causal_lm.py
@@ -758,11 +758,12 @@ def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch

input_ids[start_index:end_index] = batch.input_ids
position_ids[start_index:end_index] = batch.position_ids
slots[slots_start_index:slots_end_index] = batch.slots
slot_indices[start_index:end_index] = (
batch.slot_indices + cumulative_slots
)
input_lengths_tensor[start_index:end_index] = batch.input_lengths_tensor
slots[slots_start_index:slots_end_index] = batch.slots
cache_lengths_tensor[start_index:end_index] = batch.cache_lengths_tensor

# Copy over adapter indices
adapter_start_index = cumulative_adapter_indices_size
@@ -779,7 +780,6 @@ def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch
batch.adapter_meta.adapter_segments,
batch.adapter_meta.segment_indices,
)
cache_lengths_tensor[start_index:end_index] = batch.cache_lengths_tensor

# Update
cumulative_slots += len(batch.slots)
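
For readers skimming this hunk: the copies above follow a standard cumulative-offset pattern when concatenating batches. Below is a minimal, self-contained sketch of that pattern — illustrative only, not the repository's code; the dict keys and the helper name concatenate_sketch are made up. Per-batch tensors are copied into preallocated buffers, and slot_indices are shifted by the number of slots already copied so they keep pointing into the concatenated slots tensor.

# Minimal sketch (illustrative only, not the repository code): cumulative-offset
# concatenation of per-batch tensors, mirroring the bookkeeping visible above.
import torch

def concatenate_sketch(batches):
    total_tokens = sum(len(b["input_ids"]) for b in batches)
    total_slots = sum(len(b["slots"]) for b in batches)

    input_ids = torch.empty(total_tokens, dtype=torch.int64)
    slot_indices = torch.empty(total_tokens, dtype=torch.int64)
    slots = torch.empty(total_slots, dtype=torch.int64)

    cumulative_length = 0
    cumulative_slots = 0
    for batch in batches:
        start_index = cumulative_length
        end_index = cumulative_length + len(batch["input_ids"])
        slots_start_index = cumulative_slots
        slots_end_index = cumulative_slots + len(batch["slots"])

        input_ids[start_index:end_index] = batch["input_ids"]
        slots[slots_start_index:slots_end_index] = batch["slots"]
        # slot_indices point into the concatenated `slots` tensor, so they are
        # shifted by the number of slots copied from earlier batches.
        slot_indices[start_index:end_index] = batch["slot_indices"] + cumulative_slots

        cumulative_length += len(batch["input_ids"])
        cumulative_slots += len(batch["slots"])

    return input_ids, slots, slot_indices
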
@@ -1614,13 +1614,12 @@ def forward(
input_lengths_tensor=input_lengths,
cache_lengths_tensor=cache_lengths_tensor,
):
max_k = (input_lengths + cache_lengths_tensor).max().item()
seqlen = Seqlen(
input_lengths=input_lengths,
cache_lengths=cache_lengths_tensor,
cu_seqlen_q=cu_seqlen_prefill,
max_q=max_s,
max_k=max_k,
max_k=batch.max_current_length,
)
logits, speculative_logits = self.model.forward(
input_ids=input_ids,
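
This hunk swaps a per-forward reduction, max_k = (input_lengths + cache_lengths_tensor).max().item(), for a precomputed batch.max_current_length. A small sketch of the invariant this appears to rely on; the assumption here (not stated in the diff) is that max_current_length tracks the largest cache_length + input_length in the batch, so both expressions yield the same max_k without a tensor reduction and .item() sync in the forward path.

# Illustrative sketch (assumption, not repository code): max_current_length is
# taken to be the precomputed max of cache_length + input_length per request.
import torch

cache_lengths_tensor = torch.tensor([12, 0, 37])  # tokens already in the KV cache
input_lengths = torch.tensor([3, 20, 1])          # tokens in this forward pass

# Old line in the hunk: reduce over the batch on every forward call.
max_k_old = (input_lengths + cache_lengths_tensor).max().item()

# Assumed precomputed value kept on the batch.
max_current_length = max(
    c + i for c, i in zip(cache_lengths_tensor.tolist(), input_lengths.tolist())
)

assert max_k_old == max_current_length  # both give 38 here
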
@@ -1852,46 +1851,44 @@ def generate_token(
request_was_prefilling,
request_is_prefilling,
) in enumerate(iterator):
# Indexing metadata
_start_index = cumulative_length
end_index = cumulative_length + input_length
if prefill and finished_prefilling:
# Indexing metadata
_start_index = cumulative_length
end_index = cumulative_length + input_length

# Initialize position_ids
# In decode, we do not need this as we can just increment position ids
next_position_ids[i] = batch.position_ids[end_index - 1]

# Initialize adapter indices
# In decode, we only have one token per row in the batch, so grab last index
next_adapter_indices[i] = batch.adapter_meta.adapter_indices[
end_index - 1
]

if prefill:
# Used to gather prefill logprobs
# Copy batch.all_input_ids_tensor to prefill_token_indices
if request.prefill_logprobs and request_was_prefilling:
# Indexing metadata
out_start_index = batch.prefill_cu_outlens[i]
out_end_index = batch.prefill_cu_outlens[i + 1]

if finished_prefilling:
# Initialize position_ids
# In decode, we do not need this as we can just increment position ids
next_position_ids[i] = batch.position_ids[end_index - 1]

# Initialize adapter indices
# In decode, we only have one token per row in the batch, so grab last index
next_adapter_indices[i] = batch.adapter_meta.adapter_indices[
end_index - 1
]

# Used to gather prefill logprobs
# Copy batch.all_input_ids_tensor to prefill_token_indices
if request.prefill_logprobs and request_was_prefilling:
# Logprobs generated by the model are for the next token
# So we need to translate the id tensor by 1
ids = batch.all_input_ids_tensor[i, cache_length + 1: cache_length + input_length + 1]
if len(batch) > 1:
prefill_tokens_indices[out_start_index : out_end_index] = ids
else:
# Set prefill_tokens_indices to the correct slice
prefill_tokens_indices = ids
# Logprobs generated by the model are for the next token
# So we need to translate the id tensor by 1
ids = batch.all_input_ids_tensor[i, cache_length + 1: cache_length + input_length + 1]
if len(batch) > 1:
prefill_tokens_indices[out_start_index : out_end_index] = ids
else:
# Set prefill_tokens_indices to the correct slice
prefill_tokens_indices = ids

if not request_is_prefilling:
# Only save tokens if we are done prefilling for this request
for j in range(n_accepted_ids):
batch.all_input_ids_tensor[i, cache_length + input_length + j] = (
next_input_ids[index]
next_input_ids[index + j]
)
index += 1

index += n_accepted_ids
cumulative_length += input_length

# Update values
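
Two parts of this last hunk are worth unpacking. First, the "translate the id tensor by 1" comment: logprobs computed at a prompt position score the following token, so the ids gathered for prefill logprobs are shifted by one, which is what the cache_length + 1 slice does. A standalone illustration with hypothetical values, not the repository code:

# Illustrative sketch (not the repository code): logits at position t score the
# token at position t + 1, hence the +1 shift on both slice bounds.
import torch

cache_length = 4   # tokens already processed in earlier chunks
input_length = 3   # tokens processed in this prefill chunk
all_input_ids = torch.tensor([7, 7, 7, 7, 11, 12, 13, 14, 0, 0])  # one request's row

ids = all_input_ids[cache_length + 1 : cache_length + input_length + 1]
assert ids.tolist() == [12, 13, 14]  # targets for the logprobs of this chunk
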
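
Second, the next_input_ids indexing: next_input_ids is a flat tensor holding n_accepted_ids entries per request, so request i writes next_input_ids[index + j] into its own row and the cursor then advances by n_accepted_ids per request, as the updated lines do. A minimal sketch with made-up values — not the repository code, and whether the cursor also advances for requests that are still prefilling is an assumption drawn from where index += n_accepted_ids sits in the hunk:

# Illustrative sketch (not the repository code): writing accepted speculative
# ids, laid out flat per request, back into each request's row.
import torch

cache_lengths = [4, 0, 7]
input_lengths = [2, 5, 1]
accepted_ids = [1, 3, 2]                       # n_accepted_ids per request
request_is_prefilling = [False, True, False]   # request 1 is still prefilling

all_input_ids_tensor = torch.zeros((3, 16), dtype=torch.int64)
next_input_ids = torch.arange(100, 100 + sum(accepted_ids))  # flat: 100..105

index = 0
for i, n_accepted_ids in enumerate(accepted_ids):
    if not request_is_prefilling[i]:
        # Only save tokens if we are done prefilling for this request.
        for j in range(n_accepted_ids):
            all_input_ids_tensor[i, cache_lengths[i] + input_lengths[i] + j] = (
                next_input_ids[index + j]
            )
    # Advance the flat cursor by this request's accepted ids so the next
    # request reads its own slice (assumed unconditional, as in the hunk).
    index += n_accepted_ids

assert all_input_ids_tensor[0, 6].item() == 100
assert all_input_ids_tensor[2, 8:10].tolist() == [104, 105]
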
