Enables trt-llm config for lookahead decoding #1391

Merged · 9 commits · Feb 13, 2025
Changes from 7 commits
87 changes: 83 additions & 4 deletions truss/base/trt_llm_config.py
@@ -11,8 +11,6 @@
from huggingface_hub.utils import validate_repo_id
from pydantic import BaseModel, PydanticDeprecatedSince20, model_validator, validator

from truss.base.constants import BEI_REQUIRED_MAX_NUM_TOKENS

logger = logging.getLogger(__name__)
# Suppress Pydantic V1 warnings, because we have to use it for backwards compat.
warnings.filterwarnings("ignore", category=PydanticDeprecatedSince20)
@@ -84,6 +82,7 @@ class TrussTRTLLMBatchSchedulerPolicy(str, Enum):

class TrussSpecDecMode(str, Enum):
DRAFT_EXTERNAL = "DRAFT_TOKENS_EXTERNAL"
LOOKAHEAD_DECODING = "LOOKAHEAD_DECODING"


class TrussTRTLLMRuntimeConfiguration(BaseModel):
@@ -144,6 +143,7 @@ def _bei_specfic_migration(self):
f"Your setting of `build.max_seq_len={self.max_seq_len}` is not used and "
"automatically inferred from the model repo config.json -> `max_position_embeddings`"
)
from truss.base.constants import BEI_REQUIRED_MAX_NUM_TOKENS

if self.max_num_tokens < BEI_REQUIRED_MAX_NUM_TOKENS:
logger.warning(
@@ -210,17 +210,96 @@ def max_draft_len(self) -> Optional[int]:

class TrussSpeculatorConfiguration(BaseModel):
speculative_decoding_mode: TrussSpecDecMode = TrussSpecDecMode.DRAFT_EXTERNAL
num_draft_tokens: int
num_draft_tokens: Optional[int] = None
checkpoint_repository: Optional[CheckpointRepository] = None
runtime: TrussTRTLLMRuntimeConfiguration = TrussTRTLLMRuntimeConfiguration()
build: Optional[TrussTRTLLMBuildConfiguration] = None
lookahead_windows_size: Optional[int] = None
lookahead_ngram_size: Optional[int] = None
lookahead_verification_set_size: Optional[int] = None

def __init__(self, **data):
super().__init__(**data)
self._validate_checkpoint()
self._validate_spec_dec_mode()

def _assert_draft_tokens(self):
if self.num_draft_tokens > 512 or self.num_draft_tokens < 0:
if self.speculative_decoding_mode == TrussSpecDecMode.LOOKAHEAD_DECODING:
reason = (
f"This is automatically calculated value of lookahead_windows_size={self.lookahead_windows_size}, "
f" lookahead_ngram_size={self.lookahead_ngram_size}, lookahead_verification_set_size={self.lookahead_verification_set_size}. "
f"Please any of them."
michaelfeil marked this conversation as resolved.
Show resolved Hide resolved
)
else:
reason = "You set this value under speculator.num_draft_tokens"
raise ValueError(
f"num_draft_tokens must be less than or equal to 512. But you requested num_draft_tokens={self.num_draft_tokens}. {reason}"
)

@staticmethod
def lade_max_draft_len(
windows_size: int, ngram_size: int, verification_set_size: int
) -> int:
"""calculate the maximum number of tokens with baseten lookahead algorithm"""
michaelfeil marked this conversation as resolved.
Show resolved Hide resolved
return (0 if (ngram_size == 1) else ngram_size - 2) + (
windows_size - 1 + verification_set_size
) * (ngram_size - 1)
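# Hedged worked example (numbers mirror the test file below; not an official reference):
#   lade_max_draft_len(10, 10, 10)
#       = (10 - 2) + (10 - 1 + 10) * (10 - 1)
#       = 8 + 19 * 9
#       = 179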

def _validate_spec_dec_mode(self):
if self.speculative_decoding_mode == TrussSpecDecMode.DRAFT_EXTERNAL:
if not self.num_draft_tokens:
raise ValueError(
"Draft external mode requires num_draft_tokens to be set."
)
elif self.speculative_decoding_mode == TrussSpecDecMode.LOOKAHEAD_DECODING:
if not all(
[
self.lookahead_windows_size,
self.lookahead_ngram_size,
self.lookahead_verification_set_size,
]
):
raise ValueError(
"Lookahead decoding mode requires lookahead_windows_size, lookahead_ngram_size, lookahead_verification_set_size to be set."
)
lade_num_draft_tokens = self.lade_max_draft_len(
self.lookahead_windows_size,
self.lookahead_ngram_size,
self.lookahead_verification_set_size,
)
if not ENGINE_BUILDER_TRUSS_RUNTIME_MIGRATION:
if (
self.num_draft_tokens
and self.num_draft_tokens != lade_num_draft_tokens
):
raise ValueError(
f"num_draft_tokens is automatically calculated based on lookahead_windows_size, lookahead_ngram_size, lookahead_verification_set_size. "
f"Please remove num_draft_tokens or set it to exactly {lade_num_draft_tokens}. You set it to {self.num_draft_tokens}."
)
self.num_draft_tokens = lade_num_draft_tokens
if self.num_draft_tokens > 128:
logger.warning(
f"Lookahead decoding mode generates up to {self.num_draft_tokens} speculative tokens per step and may have performance implications. "
"We recommend a simpler config, e.g. lookahead_windows_size=7, lookahead_ngram_size=5, lookahead_verification_set_size=3."
)
else:
# server side on engine-builder
if not self.num_draft_tokens:
raise ValueError(
"num_draft_tokens is required in lookahead decoding mode but not set"
)
if self.num_draft_tokens < lade_num_draft_tokens:
raise ValueError(
"num_draft_tokens is less than the calculated value based on lookahead_windows_size, lookahead_ngram_size, lookahead_verification_set_size"
)

self._assert_draft_tokens()

def _validate_checkpoint(self):
if not (bool(self.checkpoint_repository) ^ bool(self.build)):
if self.speculative_decoding_mode == TrussSpecDecMode.DRAFT_EXTERNAL and not (
bool(self.checkpoint_repository) ^ bool(self.build)
):
raise ValueError(
"Speculative decoding requires exactly one of checkpoint_repository or build to be configured."
)
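For illustration, a minimal sketch of enabling lookahead decoding with the "simpler config" recommended by the warning above. It assumes the client-side path (ENGINE_BUILDER_TRUSS_RUNTIME_MIGRATION unset), where num_draft_tokens is derived automatically by _validate_spec_dec_mode:

from truss.base.trt_llm_config import (
    TrussSpecDecMode,
    TrussSpeculatorConfiguration,
)

speculator = TrussSpeculatorConfiguration(
    speculative_decoding_mode=TrussSpecDecMode.LOOKAHEAD_DECODING,
    lookahead_windows_size=7,
    lookahead_ngram_size=5,
    lookahead_verification_set_size=3,
)

# Derived via lade_max_draft_len: (5 - 2) + (7 - 1 + 3) * (5 - 1) = 3 + 36 = 39
assert speculator.num_draft_tokens == 39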
75 changes: 75 additions & 0 deletions truss/tests/trt_llm/test_trt_llm_config.py
@@ -1,5 +1,10 @@
import copy

import pytest
from truss.base.trt_llm_config import (
TRTLLMConfiguration,
TrussSpecDecMode,
TrussSpeculatorConfiguration,
TrussTRTLLMBatchSchedulerPolicy,
TrussTRTLLMBuildConfiguration,
TrussTRTLLMRuntimeConfiguration,
@@ -39,3 +44,73 @@ def test_trt_llm_configuration_init_and_migrate_deprecated_runtime_fields_existi
"request_default_max_tokens": 10,
"total_token_limit": 100,
}


def test_trt_llm_chunked_prefill_fix(trtllm_config):
"""make sure that the chunked prefill validation is working"""
trt_llm_config = TRTLLMConfiguration(**trtllm_config["trt_llm"])

assert trt_llm_config.build.plugin_configuration.paged_kv_cache is True
assert trt_llm_config.build.plugin_configuration.use_paged_context_fmha is True
assert trt_llm_config.runtime.enable_chunked_context is True

with pytest.raises(ValueError):
trt_llm2 = copy.deepcopy(trt_llm_config)
trt_llm2.build.plugin_configuration.paged_kv_cache = False
TRTLLMConfiguration(**trt_llm2.model_dump())

with pytest.raises(
ValueError
): # verify you can't disable paged context fmha without also disabling enable_chunked_context
trt_llm2 = copy.deepcopy(trt_llm_config)
trt_llm2.build.plugin_configuration.use_paged_context_fmha = False
TRTLLMConfiguration(**trt_llm2.model_dump())

trt_llm2 = copy.deepcopy(trt_llm_config)
trt_llm2.runtime.enable_chunked_context = False
trt_llm2.build.plugin_configuration.use_paged_context_fmha = False
TRTLLMConfiguration(**trt_llm2.model_dump())


def test_trt_llm_lookahead_decoding(trtllm_config):
trt_llm_config = TRTLLMConfiguration(**trtllm_config["trt_llm"])

with pytest.raises(ValueError):
trt_llm_config.build.speculator = TrussSpeculatorConfiguration(
speculative_decoding_mode=TrussSpecDecMode.LOOKAHEAD_DECODING,
num_draft_tokens=None,
lookahead_windows_size=None,
lookahead_ngram_size=None,
lookahead_verification_set_size=None,
)
# need to specify lookahead_windows_size, lookahead_ngram_size, and lookahead_verification_set_size
TRTLLMConfiguration(**trt_llm_config.model_dump())

trt_llm_config.build.speculator = TrussSpeculatorConfiguration(
speculative_decoding_mode=TrussSpecDecMode.LOOKAHEAD_DECODING,
num_draft_tokens=None, # will be overwritten
lookahead_windows_size=10,
lookahead_ngram_size=10,
lookahead_verification_set_size=10,
)
with_spec = TRTLLMConfiguration(**trt_llm_config.model_dump())
assert with_spec.build.speculator.lookahead_windows_size == 10
assert with_spec.build.speculator.lookahead_ngram_size == 10
assert with_spec.build.speculator.lookahead_verification_set_size == 10
assert (
with_spec.build.speculator.speculative_decoding_mode
== TrussSpecDecMode.LOOKAHEAD_DECODING
)
assert with_spec.build.speculator.num_draft_tokens == 179

with pytest.raises(ValueError):
trt_llm_config.build.speculator = TrussSpeculatorConfiguration(
speculative_decoding_mode=TrussSpecDecMode.LOOKAHEAD_DECODING,
num_draft_tokens=None,
lookahead_windows_size=100,
lookahead_ngram_size=100,
lookahead_verification_set_size=100,
)
# need to specify num_draft_tokens
TRTLLMConfiguration(**trt_llm_config.model_dump())
# will lead to ValueError -> too many draft tokens are generated with 100 lookahead windows
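A quick sanity check (informal arithmetic, mirroring lade_max_draft_len) of why the 100/100/100 case overflows the 512-token cap enforced by _assert_draft_tokens:

# lade_max_draft_len(100, 100, 100)
#     = (100 - 2) + (100 - 1 + 100) * (100 - 1)
#     = 98 + 199 * 99
#     = 19799  # far above the 512 limit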