@@ -758,11 +758,12 @@ def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch
758
758
759
759
input_ids [start_index :end_index ] = batch .input_ids
760
760
position_ids [start_index :end_index ] = batch .position_ids
761
+ slots [slots_start_index :slots_end_index ] = batch .slots
761
762
slot_indices [start_index :end_index ] = (
762
763
batch .slot_indices + cumulative_slots
763
764
)
764
765
input_lengths_tensor [start_index :end_index ] = batch .input_lengths_tensor
765
- slots [ slots_start_index : slots_end_index ] = batch .slots
766
+ cache_lengths_tensor [ start_index : end_index ] = batch .cache_lengths_tensor
766
767
767
768
# Copy over adapter indices
768
769
adapter_start_index = cumulative_adapter_indices_size
@@ -779,7 +780,6 @@ def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch
779
780
batch .adapter_meta .adapter_segments ,
780
781
batch .adapter_meta .segment_indices ,
781
782
)
782
- cache_lengths_tensor [start_index :end_index ] = batch .cache_lengths_tensor
783
783
784
784
# Update
785
785
cumulative_slots += len (batch .slots )
@@ -1614,13 +1614,12 @@ def forward(
1614
1614
input_lengths_tensor = input_lengths ,
1615
1615
cache_lengths_tensor = cache_lengths_tensor ,
1616
1616
):
1617
- max_k = (input_lengths + cache_lengths_tensor ).max ().item ()
1618
1617
seqlen = Seqlen (
1619
1618
input_lengths = input_lengths ,
1620
1619
cache_lengths = cache_lengths_tensor ,
1621
1620
cu_seqlen_q = cu_seqlen_prefill ,
1622
1621
max_q = max_s ,
1623
- max_k = max_k ,
1622
+ max_k = batch . max_current_length ,
1624
1623
)
1625
1624
logits , speculative_logits = self .model .forward (
1626
1625
input_ids = input_ids ,
@@ -1852,46 +1851,44 @@ def generate_token(
1852
1851
request_was_prefilling ,
1853
1852
request_is_prefilling ,
1854
1853
) in enumerate (iterator ):
1855
- # Indexing metadata
1856
- _start_index = cumulative_length
1857
- end_index = cumulative_length + input_length
1854
+ if prefill and finished_prefilling :
1855
+ # Indexing metadata
1856
+ _start_index = cumulative_length
1857
+ end_index = cumulative_length + input_length
1858
+
1859
+ # Initialize position_ids
1860
+ # In decode, we do not need this as we can just increment position ids
1861
+ next_position_ids [i ] = batch .position_ids [end_index - 1 ]
1862
+
1863
+ # Initialize adapter indices
1864
+ # In decode, we only have one token per row in the batch, so grab last index
1865
+ next_adapter_indices [i ] = batch .adapter_meta .adapter_indices [
1866
+ end_index - 1
1867
+ ]
1858
1868
1859
- if prefill :
1869
+ # Used to gather prefill logprobs
1870
+ # Copy batch.all_input_ids_tensor to prefill_token_indices
1871
+ if request .prefill_logprobs and request_was_prefilling :
1860
1872
# Indexing metadata
1861
1873
out_start_index = batch .prefill_cu_outlens [i ]
1862
1874
out_end_index = batch .prefill_cu_outlens [i + 1 ]
1863
1875
1864
- if finished_prefilling :
1865
- # Initialize position_ids
1866
- # In decode, we do not need this as we can just increment position ids
1867
- next_position_ids [i ] = batch .position_ids [end_index - 1 ]
1868
-
1869
- # Initialize adapter indices
1870
- # In decode, we only have one token per row in the batch, so grab last index
1871
- next_adapter_indices [i ] = batch .adapter_meta .adapter_indices [
1872
- end_index - 1
1873
- ]
1874
-
1875
- # Used to gather prefill logprobs
1876
- # Copy batch.all_input_ids_tensor to prefill_token_indices
1877
- if request .prefill_logprobs and request_was_prefilling :
1878
- # Logprobs generated by the model are for the next token
1879
- # So we need to translate the id tensor by 1
1880
- ids = batch .all_input_ids_tensor [i , cache_length + 1 : cache_length + input_length + 1 ]
1881
- if len (batch ) > 1 :
1882
- prefill_tokens_indices [out_start_index : out_end_index ] = ids
1883
- else :
1884
- # Set prefill_tokens_indices to the correct slice
1885
- prefill_tokens_indices = ids
1876
+ # Logprobs generated by the model are for the next token
1877
+ # So we need to translate the id tensor by 1
1878
+ ids = batch .all_input_ids_tensor [i , cache_length + 1 : cache_length + input_length + 1 ]
1879
+ if len (batch ) > 1 :
1880
+ prefill_tokens_indices [out_start_index : out_end_index ] = ids
1881
+ else :
1882
+ # Set prefill_tokens_indices to the correct slice
1883
+ prefill_tokens_indices = ids
1886
1884
1887
1885
if not request_is_prefilling :
1888
1886
# Only save tokens if we are done prefilling for this request
1889
1887
for j in range (n_accepted_ids ):
1890
1888
batch .all_input_ids_tensor [i , cache_length + input_length + j ] = (
1891
- next_input_ids [index ]
1889
+ next_input_ids [index + j ]
1892
1890
)
1893
- index += 1
1894
-
1891
+ index += n_accepted_ids
1895
1892
cumulative_length += input_length
1896
1893
1897
1894
# Update values
0 commit comments