Commit 1b43d2d

idk at this point
1 parent 9c11f1e commit 1b43d2d

1 file changed (+30, -33 lines)

server/text_generation_server/models/flash_causal_lm.py

Lines changed: 30 additions & 33 deletions
@@ -758,11 +758,12 @@ def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch

             input_ids[start_index:end_index] = batch.input_ids
             position_ids[start_index:end_index] = batch.position_ids
+            slots[slots_start_index:slots_end_index] = batch.slots
             slot_indices[start_index:end_index] = (
                 batch.slot_indices + cumulative_slots
             )
             input_lengths_tensor[start_index:end_index] = batch.input_lengths_tensor
-            slots[slots_start_index:slots_end_index] = batch.slots
+            cache_lengths_tensor[start_index:end_index] = batch.cache_lengths_tensor

             # Copy over adapter indices
             adapter_start_index = cumulative_adapter_indices_size
@@ -779,7 +780,6 @@ def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch
                 batch.adapter_meta.adapter_segments,
                 batch.adapter_meta.segment_indices,
             )
-            cache_lengths_tensor[start_index:end_index] = batch.cache_lengths_tensor

             # Update
             cumulative_slots += len(batch.slots)
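
These two hunks move the slots copy up next to the other per-request copies in FlashCausalLMBatch.concatenate and relocate the cache_lengths_tensor copy into the same block, so all of the offset-based writes happen together. A minimal sketch of that copy pattern, with made-up field names and dict-based sub-batches standing in for the real class (concat_slots is illustrative, not the actual API):

import torch

def concat_slots(batches):
    # Preallocate flat output buffers sized from all sub-batches.
    total_rows = sum(len(b["input_ids"]) for b in batches)
    total_slots = sum(len(b["slots"]) for b in batches)
    input_ids = torch.empty(total_rows, dtype=torch.int64)
    slots = torch.empty(total_slots, dtype=torch.int64)
    slot_indices = torch.empty(total_rows, dtype=torch.int64)

    cumulative_rows = 0
    cumulative_slots = 0
    for b in batches:
        start, end = cumulative_rows, cumulative_rows + len(b["input_ids"])
        s_start, s_end = cumulative_slots, cumulative_slots + len(b["slots"])

        # Each sub-batch is written into the shared buffers at cumulative offsets.
        input_ids[start:end] = b["input_ids"]
        slots[s_start:s_end] = b["slots"]
        # slot_indices are per-row pointers into `slots`, so they are shifted by
        # the number of slots already copied, exactly as in the hunk above.
        slot_indices[start:end] = b["slot_indices"] + cumulative_slots

        cumulative_rows, cumulative_slots = end, s_end
    return input_ids, slots, slot_indices

# Example: the second batch's slot_indices get shifted by the 3 slots of the first.
a = {"input_ids": torch.tensor([1, 2]), "slots": torch.tensor([10, 11, 12]), "slot_indices": torch.tensor([0, 1])}
b = {"input_ids": torch.tensor([3]), "slots": torch.tensor([20]), "slot_indices": torch.tensor([0])}
print(concat_slots([a, b]))  # slot_indices -> [0, 1, 3]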
@@ -1614,13 +1614,12 @@ def forward(
                input_lengths_tensor=input_lengths,
                cache_lengths_tensor=cache_lengths_tensor,
            ):
-               max_k = (input_lengths + cache_lengths_tensor).max().item()
                seqlen = Seqlen(
                    input_lengths=input_lengths,
                    cache_lengths=cache_lengths_tensor,
                    cu_seqlen_q=cu_seqlen_prefill,
                    max_q=max_s,
-                   max_k=max_k,
+                   max_k=batch.max_current_length,
                )
                logits, speculative_logits = self.model.forward(
                    input_ids=input_ids,
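
Here the per-call max_k = (input_lengths + cache_lengths_tensor).max().item() is dropped in favor of batch.max_current_length, a value the batch already tracks on the host. A rough sketch of the two patterns, assuming a CUDA device when available and that max_current_length equals the largest cache length plus input length; reading the change as avoiding the device-to-host sync that .item() forces is an inference, not something stated in the commit:

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"

# Illustrative tensors standing in for the per-request lengths seen by forward();
# `max_current_length` plays the role of the Python int the batch already tracks.
input_lengths = torch.randint(1, 512, (64,), device=device)
cache_lengths_tensor = torch.randint(0, 2048, (64,), device=device)
max_current_length = int((input_lengths + cache_lengths_tensor).max())

# Removed pattern: a GPU reduction followed by .item(), which blocks the host
# until the device has produced the value, on every forward call.
max_k_old = (input_lengths + cache_lengths_tensor).max().item()

# New pattern: reuse the host-side value, no extra per-call work.
max_k_new = max_current_length

assert max_k_old == max_k_new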
@@ -1852,46 +1851,44 @@ def generate_token(
             request_was_prefilling,
             request_is_prefilling,
         ) in enumerate(iterator):
-            # Indexing metadata
-            _start_index = cumulative_length
-            end_index = cumulative_length + input_length
+            if prefill and finished_prefilling:
+                # Indexing metadata
+                _start_index = cumulative_length
+                end_index = cumulative_length + input_length
+
+                # Initialize position_ids
+                # In decode, we do not need this as we can just increment position ids
+                next_position_ids[i] = batch.position_ids[end_index - 1]
+
+                # Initialize adapter indices
+                # In decode, we only have one token per row in the batch, so grab last index
+                next_adapter_indices[i] = batch.adapter_meta.adapter_indices[
+                    end_index - 1
+                ]

-            if prefill:
+            # Used to gather prefill logprobs
+            # Copy batch.all_input_ids_tensor to prefill_token_indices
+            if request.prefill_logprobs and request_was_prefilling:
                 # Indexing metadata
                 out_start_index = batch.prefill_cu_outlens[i]
                 out_end_index = batch.prefill_cu_outlens[i + 1]

-                if finished_prefilling:
-                    # Initialize position_ids
-                    # In decode, we do not need this as we can just increment position ids
-                    next_position_ids[i] = batch.position_ids[end_index - 1]
-
-                    # Initialize adapter indices
-                    # In decode, we only have one token per row in the batch, so grab last index
-                    next_adapter_indices[i] = batch.adapter_meta.adapter_indices[
-                        end_index - 1
-                    ]
-
-                # Used to gather prefill logprobs
-                # Copy batch.all_input_ids_tensor to prefill_token_indices
-                if request.prefill_logprobs and request_was_prefilling:
-                    # Logprobs generated by the model are for the next token
-                    # So we need to translate the id tensor by 1
-                    ids = batch.all_input_ids_tensor[i, cache_length + 1: cache_length + input_length + 1]
-                    if len(batch) > 1:
-                        prefill_tokens_indices[out_start_index : out_end_index] = ids
-                    else:
-                        # Set prefill_tokens_indices to the correct slice
-                        prefill_tokens_indices = ids
+                # Logprobs generated by the model are for the next token
+                # So we need to translate the id tensor by 1
+                ids = batch.all_input_ids_tensor[i, cache_length + 1: cache_length + input_length + 1]
+                if len(batch) > 1:
+                    prefill_tokens_indices[out_start_index : out_end_index] = ids
+                else:
+                    # Set prefill_tokens_indices to the correct slice
+                    prefill_tokens_indices = ids

             if not request_is_prefilling:
                 # Only save tokens if we are done prefilling for this request
                 for j in range(n_accepted_ids):
                     batch.all_input_ids_tensor[i, cache_length + input_length + j] = (
-                        next_input_ids[index]
+                        next_input_ids[index + j]
                     )
-                    index += 1
-
+            index += n_accepted_ids
             cumulative_length += input_length

             # Update values
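
The generate_token hunk narrows the position-id and adapter-index initialization to the iteration where prefilling has just finished, gathers prefill logprobs whenever the request was prefilling and asked for them, and rewrites the accepted-token bookkeeping: tokens are read at next_input_ids[index + j] and index now advances by n_accepted_ids per request. A toy sketch of that cursor bookkeeping, assuming (as in the reconstruction above) that the increment runs for every request, including ones still mid-prefill; the tensors and counts are invented:

import torch

# `next_input_ids` is a flat tensor holding n_accepted_ids entries per request
# (speculative decoding can accept several tokens at once), so the cursor must
# advance by n_accepted_ids for every request, even ones whose tokens are
# discarded because they are still prefilling.
next_input_ids = torch.tensor([11, 12, 21, 31, 32, 33])  # 3 requests: 2, 1, 3 accepted ids
accepted = [2, 1, 3]
still_prefilling = [False, True, False]  # request 1 is mid-prefill, its tokens are skipped

saved = {}
index = 0
for i, (n_accepted_ids, is_prefilling) in enumerate(zip(accepted, still_prefilling)):
    if not is_prefilling:
        # Mirrors `next_input_ids[index + j]` in the diff.
        saved[i] = [int(next_input_ids[index + j]) for j in range(n_accepted_ids)]
    # Mirrors `index += n_accepted_ids`: advance unconditionally so later
    # requests read the right slice even when this one was skipped.
    index += n_accepted_ids

print(saved)  # {0: [11, 12], 2: [31, 32, 33]}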
