
Commit 58d22a7

[TRTLLM-6352][feat] Migrate EAGLE3 and draft/target speculation to Drafter (#6007)
Signed-off-by: ziyixiong-nv <fxiong@nvidia.com>
1 parent 9518e14 commit 58d22a7

6 files changed: 388 additions & 200 deletions

tensorrt_llm/_torch/pyexecutor/py_executor.py

Lines changed: 5 additions & 193 deletions
@@ -11,7 +11,7 @@
 import weakref
 from collections import deque, namedtuple
 from contextlib import contextmanager
-from typing import Dict, List, Optional, Tuple, Union
+from typing import Dict, List, Optional, Union

 import torch

@@ -308,7 +308,7 @@ def __init__(self,
         if is_trace_enabled("TLLM_TRACE_EXECUTOR_LOOP"):
             self.event_loop = trace_func(self.event_loop)

-        if self.draft_model_engine is not None:
+        if self.drafter is not None:
             if self.event_loop.__name__ != self._executor_loop.__name__:
                 raise NotImplementedError(
                     "Drafting is not supported for selected executor loop. "
@@ -905,10 +905,6 @@ def _executor_loop_pp(self):

     def _executor_loop(self):
         torch.cuda.set_device(self.device_id)
-        is_ngram = hasattr(
-            self.model_engine, "spec_config"
-        ) and self.model_engine.spec_config is not None and self.model_engine.spec_config.spec_dec_mode.is_ngram(
-        )
         with self._profiler() as profile_step:
             sample_state = None
             iter_start_time = time.time()
@@ -931,7 +927,7 @@ def _executor_loop(self):

                 self._pad_attention_dp_dummy_request()

-                if self.draft_model_engine is not None or is_ngram or self.drafter is not None:
+                if self.drafter is not None:
                     self._prepare_draft_requests(self.active_requests)

                 scheduled_batch, fitting_disagg_gen_init_requests, num_fitting_reqs = self._schedule(
@@ -971,11 +967,9 @@ def _executor_loop(self):
                         scheduled_batch)

                 self.resource_manager.prepare_resources(scheduled_batch)
-                if self.draft_model_engine is not None:
-                    self._prepare_draft_tokens(scheduled_batch)
-
                 if self.drafter is not None:
-                    self.drafter.prepare_draft_tokens(scheduled_batch)
+                    self.drafter.prepare_draft_tokens(
+                        scheduled_batch, self.resource_manager)

                 if self.kv_cache_transceiver:
                     # For generation requests which have completed KV cache transfer
@@ -1798,188 +1792,6 @@ def _update_requests(self, sample_state: SampleState):
             logger.error(f"Encountered an error in sampling: {error_msg}")
             self._handle_errors(error_msg)

-    @nvtx_range("_prepare_draft_batch")
-    def _prepare_draft_batch(
-            self, scheduled_requests: ScheduledRequests
-    ) -> Tuple[ScheduledRequests, Dict[int, LlmRequest]]:
-        """
-        Prepares a batch for the draft model engine. Draft tokens are only produced
-        for generation requests.
-
-        The requests are prepared as follows:
-        1. The first time the draft engine sees a request, it's a context request.
-        2. Otherwise, if draft tokens were accepted on the last target model decoding
-        step, it's a chunked context request (we process all the accepted tokens together).
-        3. Otherwise, it's a generation request.
-        """
-        try:
-            draft_batch = ScheduledRequests()
-
-            for request in scheduled_requests.generation_requests:
-                if request.py_draft_pages_allocated == 0:
-                    # No space for draft tokens.
-                    continue
-
-                # Stop drafting when we hit the max seqlen. We still need dummy draft
-                # tokens attached to the requests to make sure everything works properly
-                # with CUDA graph. These dummy tokens are already added by
-                # _prepare_draft_requests to make the KV cache/scheduler aware of the fact
-                # that we want to do spec decoding, so no need to do anything else here.
-                # This makes the perf for this case suboptimal, but that's OK - this is
-                # a corner case for weird models like the llama 3.1 8b EAGLE3 implementation.
-                if request.max_beam_num_tokens - 1 >= self.draft_model_engine.max_seq_len:
-                    continue
-
-                num_draft_tokens = len(
-                    request.py_last_draft_tokens
-                ) if request.py_last_draft_tokens is not None else 0
-                request.py_draft_tokens = []
-
-                num_accepted_tokens = request.py_num_accepted_draft_tokens
-                num_rejected_tokens = num_draft_tokens - num_accepted_tokens
-                assert num_rejected_tokens >= 0
-
-                spec_config = self.model_engine.spec_config
-                beam_idx = 0
-                input_tokens = spec_config.get_draft_model_prompt(
-                    request.get_tokens()[beam_idx])
-
-                def create_new_request(input_tokens):
-                    return LlmRequest(
-                        request_id=request.py_request_id,
-                        max_new_tokens=request.py_max_new_tokens,
-                        input_tokens=input_tokens,
-                        sampling_config=request.sampling_config,
-                        return_perf_metrics=request.return_perf_metrics,
-                        is_streaming=False,
-                        is_draft=True)
-
-                if request.max_beam_num_tokens - 1 == request.py_prompt_len:
-                    # This is the first time the draft model is seeing this request.
-                    # Prepare a context request. We discard the first token and take
-                    # the newly decoded one - this is the convention for EAGLE 2 and 3.
-                    new_request = create_new_request(input_tokens)
-                    draft_batch.context_requests.append(new_request)
-                elif num_accepted_tokens == 0:
-                    new_request = create_new_request(input_tokens[:-1])
-                    # Explicitly add the last token so get_last_tokens() returns
-                    # the right value
-                    new_request.add_new_token(input_tokens[-1], beam_idx)
-                    new_request.state = LlmRequestState.GENERATION_IN_PROGRESS
-                    draft_batch.generation_requests.append(new_request)
-                else:
-                    new_request = create_new_request(input_tokens)
-                    new_request.context_chunk_size = num_accepted_tokens + 1
-                    new_request.context_current_position = len(
-                        input_tokens) - num_accepted_tokens - 1
-                    new_request.context_chunk_size = num_accepted_tokens + 1
-                    new_request.context_current_position = len(
-                        input_tokens) - num_accepted_tokens - 1
-
-                    draft_batch.context_requests.append(new_request)
-
-                new_request.py_stop_words_list = request.py_stop_words_list
-
-            return draft_batch
-
-        except Exception as e:
-            traceback.print_exc()
-            error_msg = str(e)
-            logger.error(f"Encountered an error in decode: {error_msg}")
-            self._handle_errors(error_msg)
-
-    @nvtx_range("_prepare_draft_tokens")
-    def _prepare_draft_tokens(self, scheduled_requests: ScheduledRequests):
-        if not self.draft_model_engine:
-            raise ValueError("Draft model engine is not set")
-
-        try:
-            draft_batch = self._prepare_draft_batch(scheduled_requests)
-
-            if draft_batch.batch_size == 0:
-                return
-            self.draft_seq_slot_manager.prepare_resources(draft_batch)
-
-            req_id_to_old_request = {
-                req.py_request_id: req
-                for req in scheduled_requests.all_requests()
-            }
-
-            # Disable cuda graph for the 1st draft model forward
-            if self.model_engine.spec_config.spec_dec_mode.needs_kv_cache_recompute(
-            ):
-                with self.draft_model_engine.no_cuda_graph():
-                    outputs = self.draft_model_engine.forward(
-                        draft_batch, self.resource_manager)
-            else:
-                outputs = self.draft_model_engine.forward(
-                    draft_batch, self.resource_manager)
-            if hasattr(self.draft_model_engine.model.model, 'd2t'):
-                outputs['d2t'] = self.draft_model_engine.model.model.d2t.data
-
-            sample_state = self._sample_async(draft_batch, outputs)
-            previous_batch = sample_state
-
-            self._update_request_states(draft_batch)
-
-            def _process_decoded_tokens(draft_batch):
-                new_requests = []
-                for req in draft_batch.all_requests():
-                    target_model_req = req_id_to_old_request[req.py_request_id]
-                    target_model_req.py_draft_tokens.append(
-                        req.get_last_tokens(0))
-                    if req.state != LlmRequestState.GENERATION_COMPLETE and len(
-                            target_model_req.py_draft_tokens
-                    ) < target_model_req.py_draft_pages_allocated:
-                        new_requests.append(req)
-                    else:
-                        self.draft_seq_slot_manager.free_resources(req)
-
-                return new_requests
-
-            # The TRTLLM attention kernels cannot handle generation requests with
-            # different seqlens. No issues with flashinfer, should we look into removing
-            # this? Just needs proper kernel support.
-            def _pad_to_max_draft_tokens():
-                for req in scheduled_requests.generation_requests:
-                    max_draft_len = self.max_draft_len
-                    num_draft_tokens = len(req.py_draft_tokens)
-                    req.py_draft_tokens.extend(
-                        0 for _ in range(max_draft_len - num_draft_tokens))
-
-            draft_batch.generation_requests = draft_batch.context_requests + draft_batch.generation_requests
-            draft_batch.context_requests = []
-
-            for i in range(self.max_draft_len - 1):
-                if len(draft_batch.generation_requests) == 0:
-                    break
-
-                outputs = self.draft_model_engine.forward(
-                    draft_batch,
-                    self.resource_manager,
-                    new_tensors_device=previous_batch.device)
-
-                if hasattr(self.draft_model_engine.model.model, 'd2t'):
-                    outputs[
-                        'd2t'] = self.draft_model_engine.model.model.d2t.data
-                sample_state = self._sample_async(draft_batch, outputs)
-                self._update_request_states(draft_batch)
-                self._update_requests(previous_batch)
-                new_requests = _process_decoded_tokens(
-                    previous_batch.scheduled_requests)
-                draft_batch.generation_requests = new_requests
-                previous_batch = sample_state
-            self._update_requests(previous_batch)
-            new_requests = _process_decoded_tokens(
-                previous_batch.scheduled_requests)
-            _pad_to_max_draft_tokens()
-
-        except Exception as e:
-            traceback.print_exc()
-            error_msg = str(e)
-            logger.error(f"Encountered an error in decode: {error_msg}")
-            self._handle_errors(error_msg)
-
     def _handle_errors(self, error_msg: Optional[str] = None):
         error_responses = {}
         error_msg = error_msg or "error"
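For context on what was migrated: the deleted _prepare_draft_tokens runs the draft engine once in a context phase, then up to max_draft_len - 1 more generation steps, appending one proposed token per step to the target request's py_draft_tokens and finally padding every request to max_draft_len. The snippet below is a self-contained toy that mirrors only that control flow; ToyRequest, propose_next, and draft_loop are illustrative stand-ins, not TensorRT-LLM APIs.

from dataclasses import dataclass, field
from typing import Callable, List


@dataclass
class ToyRequest:
    """Stand-in for LlmRequest with only the fields the loop touches."""
    tokens: List[int]
    py_draft_tokens: List[int] = field(default_factory=list)
    py_draft_pages_allocated: int = 4


def draft_loop(requests: List[ToyRequest],
               propose_next: Callable[[List[int]], int],
               max_draft_len: int) -> None:
    """Condensed shape of the removed draft/target loop (illustrative only)."""
    active = list(requests)
    for _ in range(max_draft_len):
        if not active:
            break
        still_active = []
        for req in active:
            # One "draft forward" per step; propose_next stands in for the
            # draft engine forward plus sampling.
            token = propose_next(req.tokens + req.py_draft_tokens)
            req.py_draft_tokens.append(token)
            # Keep drafting only while the allocated draft pages are not full,
            # mirroring the py_draft_pages_allocated check above.
            if len(req.py_draft_tokens) < req.py_draft_pages_allocated:
                still_active.append(req)
        active = still_active

    # Mirror of the removed _pad_to_max_draft_tokens: all generation requests
    # must end up with the same draft length for the TRTLLM attention kernels.
    for req in requests:
        req.py_draft_tokens.extend(
            0 for _ in range(max_draft_len - len(req.py_draft_tokens)))


if __name__ == "__main__":
    reqs = [ToyRequest(tokens=[1, 2, 3]),
            ToyRequest(tokens=[7], py_draft_pages_allocated=2)]
    draft_loop(reqs, propose_next=lambda ctx: ctx[-1] + 1, max_draft_len=3)
    print([r.py_draft_tokens for r in reqs])  # second request padded with zeros

In the new design this whole loop lives behind drafter.prepare_draft_tokens, so the executor only needs the single call shown in the hunk above.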

tensorrt_llm/_torch/pyexecutor/py_executor_creator.py

Lines changed: 2 additions & 1 deletion
@@ -382,7 +382,8 @@ def create_py_executor(

     # Drafter for speculative decoding
     with mem_monitor.observe_creation_stage(_ExecutorCreationStage.DRAFTER):
-        drafter = get_spec_drafter(model_engine, spec_resource_manager)
+        drafter = get_spec_drafter(model_engine, draft_model_engine, sampler,
+                                   spec_resource_manager)

     with mem_monitor.observe_creation_stage(
             _ExecutorCreationStage.INIT_EXTRA_RESOURCES
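The signature change is the creator-side counterpart of the migration: a model-based drafter needs the draft engine and the target sampler, while model-free drafting does not. The sketch below is hypothetical; get_spec_drafter_sketch, the stub classes, and the dispatch rule are assumptions for illustration, not code from this commit.

from dataclasses import dataclass
from typing import Any, Optional


@dataclass
class _ModelDrafterStub:
    """Stand-in for a model-based drafter (EAGLE3 / draft-target)."""
    spec_config: Any
    draft_model_engine: Any
    sampler: Any
    spec_resource_manager: Any


@dataclass
class _NGramDrafterStub:
    """Stand-in for a model-free (NGram-style) drafter."""
    spec_config: Any
    spec_resource_manager: Any


def get_spec_drafter_sketch(model_engine, draft_model_engine, sampler,
                            spec_resource_manager) -> Optional[object]:
    # No speculative decoding configured: the executor's drafter stays None
    # and the drafting branches in _executor_loop are skipped.
    spec_config = getattr(model_engine, "spec_config", None)
    if spec_config is None:
        return None
    if draft_model_engine is not None:
        # Draft/target and EAGLE3 speculation use a separate draft engine and
        # the target sampler, which is why both are now passed to the factory.
        return _ModelDrafterStub(spec_config, draft_model_engine, sampler,
                                 spec_resource_manager)
    return _NGramDrafterStub(spec_config, spec_resource_manager)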
Lines changed: 7 additions & 0 deletions
@@ -1,16 +1,23 @@
 from abc import ABC, abstractmethod
+from typing import Optional

+from ..pyexecutor.resource_manager import ResourceManager
 from ..pyexecutor.scheduler import ScheduledRequests


 class Drafter(ABC):
+    """Abstract base class for all drafter implementations."""

     @abstractmethod
     def prepare_draft_tokens(
         self,
         scheduled_requests: ScheduledRequests,
+        resource_manager: Optional[ResourceManager] = None,
     ) -> None:
         """
         Prepare the drafter tokens for the forward computation this step.
+
+        Args:
+            scheduled_requests: The scheduled requests for this iteration
         """
         raise NotImplementedError