 import weakref
 from collections import deque, namedtuple
 from contextlib import contextmanager
-from typing import Dict, List, Optional, Tuple, Union
+from typing import Dict, List, Optional, Union

 import torch
@@ -308,7 +308,7 @@ def __init__(self,
         if is_trace_enabled("TLLM_TRACE_EXECUTOR_LOOP"):
             self.event_loop = trace_func(self.event_loop)

-        if self.draft_model_engine is not None:
+        if self.drafter is not None:
             if self.event_loop.__name__ != self._executor_loop.__name__:
                 raise NotImplementedError(
                     "Drafting is not supported for selected executor loop. "
@@ -905,10 +905,6 @@ def _executor_loop_pp(self):

     def _executor_loop(self):
         torch.cuda.set_device(self.device_id)
-        is_ngram = hasattr(
-            self.model_engine, "spec_config"
-        ) and self.model_engine.spec_config is not None and self.model_engine.spec_config.spec_dec_mode.is_ngram(
-        )
         with self._profiler() as profile_step:
             sample_state = None
             iter_start_time = time.time()
@@ -931,7 +927,7 @@ def _executor_loop(self):

                 self._pad_attention_dp_dummy_request()

-                if self.draft_model_engine is not None or is_ngram or self.drafter is not None:
+                if self.drafter is not None:
                     self._prepare_draft_requests(self.active_requests)

                 scheduled_batch, fitting_disagg_gen_init_requests, num_fitting_reqs = self._schedule(
@@ -971,11 +967,9 @@ def _executor_loop(self):
                         scheduled_batch)

                     self.resource_manager.prepare_resources(scheduled_batch)
-                    if self.draft_model_engine is not None:
-                        self._prepare_draft_tokens(scheduled_batch)
-
                     if self.drafter is not None:
-                        self.drafter.prepare_draft_tokens(scheduled_batch)
+                        self.drafter.prepare_draft_tokens(
+                            scheduled_batch, self.resource_manager)

                     if self.kv_cache_transceiver:
                         # For generation requests which have completed KV cache transfer
@@ -1798,188 +1792,6 @@ def _update_requests(self, sample_state: SampleState):
                 logger.error(f"Encountered an error in sampling: {error_msg}")
                 self._handle_errors(error_msg)

-    @nvtx_range("_prepare_draft_batch")
-    def _prepare_draft_batch(
-            self, scheduled_requests: ScheduledRequests
-    ) -> Tuple[ScheduledRequests, Dict[int, LlmRequest]]:
-        """
-        Prepares a batch for the draft model engine. Draft tokens are only produced
-        for generation requests.
-
-        The requests are prepared as follows:
-        1. The first time the draft engine sees a request, it's a context request.
-        2. Otherwise, if draft tokens were accepted on the last target model decoding
-        step, it's a chunked context request (we process all the accepted tokens together).
-        3. Otherwise, it's a generation request.
-        """
-        try:
-            draft_batch = ScheduledRequests()
-
-            for request in scheduled_requests.generation_requests:
-                if request.py_draft_pages_allocated == 0:
-                    # No space for draft tokens.
-                    continue
-
-                # Stop drafting when we hit the max seqlen. We still need dummy draft
-                # tokens attached to the requests to make sure everything works properly
-                # with CUDA graph. These dummy tokens are already added by
-                # _prepare_draft_requests to make the KV cache/scheduler aware of the fact
-                # that we want to do spec decoding, so no need to do anything else here.
-                # This makes the perf for this case suboptimal, but that's OK - this is
-                # a corner case for weird models like the llama 3.1 8b EAGLE3 implementation.
-                if request.max_beam_num_tokens - 1 >= self.draft_model_engine.max_seq_len:
-                    continue
-
-                num_draft_tokens = len(
-                    request.py_last_draft_tokens
-                ) if request.py_last_draft_tokens is not None else 0
-                request.py_draft_tokens = []
-
-                num_accepted_tokens = request.py_num_accepted_draft_tokens
-                num_rejected_tokens = num_draft_tokens - num_accepted_tokens
-                assert num_rejected_tokens >= 0
-
-                spec_config = self.model_engine.spec_config
-                beam_idx = 0
-                input_tokens = spec_config.get_draft_model_prompt(
-                    request.get_tokens()[beam_idx])
-
-                def create_new_request(input_tokens):
-                    return LlmRequest(
-                        request_id=request.py_request_id,
-                        max_new_tokens=request.py_max_new_tokens,
-                        input_tokens=input_tokens,
-                        sampling_config=request.sampling_config,
-                        return_perf_metrics=request.return_perf_metrics,
-                        is_streaming=False,
-                        is_draft=True)
-
-                if request.max_beam_num_tokens - 1 == request.py_prompt_len:
-                    # This is the first time the draft model is seeing this request.
-                    # Prepare a context request. We discard the first token and take
-                    # the newly decoded one - this is the convention for EAGLE 2 and 3.
-                    new_request = create_new_request(input_tokens)
-                    draft_batch.context_requests.append(new_request)
-                elif num_accepted_tokens == 0:
-                    new_request = create_new_request(input_tokens[:-1])
-                    # Explicitly add the last token so get_last_tokens() returns
-                    # the right value
-                    new_request.add_new_token(input_tokens[-1], beam_idx)
-                    new_request.state = LlmRequestState.GENERATION_IN_PROGRESS
-                    draft_batch.generation_requests.append(new_request)
-                else:
-                    new_request = create_new_request(input_tokens)
-                    new_request.context_chunk_size = num_accepted_tokens + 1
-                    new_request.context_current_position = len(
-                        input_tokens) - num_accepted_tokens - 1
-                    new_request.context_chunk_size = num_accepted_tokens + 1
-                    new_request.context_current_position = len(
-                        input_tokens) - num_accepted_tokens - 1
-
-                    draft_batch.context_requests.append(new_request)
-
-                new_request.py_stop_words_list = request.py_stop_words_list
-
-            return draft_batch
-
-        except Exception as e:
-            traceback.print_exc()
-            error_msg = str(e)
-            logger.error(f"Encountered an error in decode: {error_msg}")
-            self._handle_errors(error_msg)
-
-    @nvtx_range("_prepare_draft_tokens")
-    def _prepare_draft_tokens(self, scheduled_requests: ScheduledRequests):
-        if not self.draft_model_engine:
-            raise ValueError("Draft model engine is not set")
-
-        try:
-            draft_batch = self._prepare_draft_batch(scheduled_requests)
-
-            if draft_batch.batch_size == 0:
-                return
-            self.draft_seq_slot_manager.prepare_resources(draft_batch)
-
-            req_id_to_old_request = {
-                req.py_request_id: req
-                for req in scheduled_requests.all_requests()
-            }
-
-            # Disable cuda graph for the 1st draft model forward
-            if self.model_engine.spec_config.spec_dec_mode.needs_kv_cache_recompute(
-            ):
-                with self.draft_model_engine.no_cuda_graph():
-                    outputs = self.draft_model_engine.forward(
-                        draft_batch, self.resource_manager)
-            else:
-                outputs = self.draft_model_engine.forward(
-                    draft_batch, self.resource_manager)
-            if hasattr(self.draft_model_engine.model.model, 'd2t'):
-                outputs['d2t'] = self.draft_model_engine.model.model.d2t.data
-
-            sample_state = self._sample_async(draft_batch, outputs)
-            previous_batch = sample_state
-
-            self._update_request_states(draft_batch)
-
-            def _process_decoded_tokens(draft_batch):
-                new_requests = []
-                for req in draft_batch.all_requests():
-                    target_model_req = req_id_to_old_request[req.py_request_id]
-                    target_model_req.py_draft_tokens.append(
-                        req.get_last_tokens(0))
-                    if req.state != LlmRequestState.GENERATION_COMPLETE and len(
-                            target_model_req.py_draft_tokens
-                    ) < target_model_req.py_draft_pages_allocated:
-                        new_requests.append(req)
-                    else:
-                        self.draft_seq_slot_manager.free_resources(req)
-
-                return new_requests
-
-            # The TRTLLM attention kernels cannot handle generation requests with
-            # different seqlens. No issues with flashinfer, should we look into removing
-            # this? Just needs proper kernel support.
-            def _pad_to_max_draft_tokens():
-                for req in scheduled_requests.generation_requests:
-                    max_draft_len = self.max_draft_len
-                    num_draft_tokens = len(req.py_draft_tokens)
-                    req.py_draft_tokens.extend(
-                        0 for _ in range(max_draft_len - num_draft_tokens))
-
-            draft_batch.generation_requests = draft_batch.context_requests + draft_batch.generation_requests
-            draft_batch.context_requests = []
-
-            for i in range(self.max_draft_len - 1):
-                if len(draft_batch.generation_requests) == 0:
-                    break
-
-                outputs = self.draft_model_engine.forward(
-                    draft_batch,
-                    self.resource_manager,
-                    new_tensors_device=previous_batch.device)
-
-                if hasattr(self.draft_model_engine.model.model, 'd2t'):
-                    outputs[
-                        'd2t'] = self.draft_model_engine.model.model.d2t.data
-                sample_state = self._sample_async(draft_batch, outputs)
-                self._update_request_states(draft_batch)
-                self._update_requests(previous_batch)
-                new_requests = _process_decoded_tokens(
-                    previous_batch.scheduled_requests)
-                draft_batch.generation_requests = new_requests
-                previous_batch = sample_state
-            self._update_requests(previous_batch)
-            new_requests = _process_decoded_tokens(
-                previous_batch.scheduled_requests)
-            _pad_to_max_draft_tokens()
-
-        except Exception as e:
-            traceback.print_exc()
-            error_msg = str(e)
-            logger.error(f"Encountered an error in decode: {error_msg}")
-            self._handle_errors(error_msg)
-
     def _handle_errors(self, error_msg: Optional[str] = None):
         error_responses = {}
         error_msg = error_msg or "error"
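
The diff above replaces the executor's inline _prepare_draft_batch/_prepare_draft_tokens path with a single call to self.drafter.prepare_draft_tokens(scheduled_batch, self.resource_manager). The following is a minimal illustrative sketch of what such a drafter abstraction could look like; it is not the actual TensorRT-LLM interface, and the Drafter base class and NoopDrafter shown here are hypothetical, with only the call signature taken from the diff.

# Illustrative sketch only; Drafter/NoopDrafter are hypothetical names.
from abc import ABC, abstractmethod


class Drafter(ABC):
    """Hypothetical base class for speculative-decoding drafters."""

    @abstractmethod
    def prepare_draft_tokens(self, scheduled_requests, resource_manager) -> None:
        """Populate py_draft_tokens for each scheduled generation request."""
        ...


class NoopDrafter(Drafter):
    """Trivial drafter that proposes no draft tokens (placeholder behavior)."""

    def prepare_draft_tokens(self, scheduled_requests, resource_manager) -> None:
        for request in scheduled_requests.generation_requests:
            request.py_draft_tokens = []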