diff --git a/truss/base/trt_llm_config.py b/truss/base/trt_llm_config.py
index 1f508bc31..1677e3e4f 100644
--- a/truss/base/trt_llm_config.py
+++ b/truss/base/trt_llm_config.py
@@ -11,8 +11,6 @@ from huggingface_hub.utils import validate_repo_id
 from pydantic import BaseModel, PydanticDeprecatedSince20, model_validator, validator
 
-from truss.base.constants import BEI_REQUIRED_MAX_NUM_TOKENS
-
 logger = logging.getLogger(__name__)
 # Suppress Pydantic V1 warnings, because we have to use it for backwards compat.
 warnings.filterwarnings("ignore", category=PydanticDeprecatedSince20)
@@ -84,6 +82,7 @@ class TrussTRTLLMBatchSchedulerPolicy(str, Enum):
 
 class TrussSpecDecMode(str, Enum):
     DRAFT_EXTERNAL = "DRAFT_TOKENS_EXTERNAL"
+    LOOKAHEAD_DECODING = "LOOKAHEAD_DECODING"
 
 
 class TrussTRTLLMRuntimeConfiguration(BaseModel):
@@ -144,6 +143,7 @@ def _bei_specfic_migration(self):
                 f"Your setting of `build.max_seq_len={self.max_seq_len}` is not used and "
                 "automatically inferred from the model repo config.json -> `max_position_embeddings`"
             )
+        from truss.base.constants import BEI_REQUIRED_MAX_NUM_TOKENS
 
         if self.max_num_tokens < BEI_REQUIRED_MAX_NUM_TOKENS:
             logger.warning(
@@ -210,17 +210,98 @@ def max_draft_len(self) -> Optional[int]:
 
 class TrussSpeculatorConfiguration(BaseModel):
     speculative_decoding_mode: TrussSpecDecMode = TrussSpecDecMode.DRAFT_EXTERNAL
-    num_draft_tokens: int
+    num_draft_tokens: Optional[int] = None
     checkpoint_repository: Optional[CheckpointRepository] = None
     runtime: TrussTRTLLMRuntimeConfiguration = TrussTRTLLMRuntimeConfiguration()
     build: Optional[TrussTRTLLMBuildConfiguration] = None
+    lookahead_windows_size: Optional[int] = None
+    lookahead_ngram_size: Optional[int] = None
+    lookahead_verification_set_size: Optional[int] = None
 
     def __init__(self, **data):
         super().__init__(**data)
         self._validate_checkpoint()
+        self._validate_spec_dec_mode()
+
+    def _assert_draft_tokens(self):
+        if self.num_draft_tokens > 2048 or self.num_draft_tokens < 0:
+            if self.speculative_decoding_mode == TrussSpecDecMode.LOOKAHEAD_DECODING:
+                reason = (
+                    f"This is the value automatically calculated from lookahead_windows_size={self.lookahead_windows_size}, "
+                    f"lookahead_ngram_size={self.lookahead_ngram_size}, lookahead_verification_set_size={self.lookahead_verification_set_size}. "
+                    "Please lower any of them."
+                )
+            else:
+                reason = "You set this value under speculator.num_draft_tokens."
+            raise ValueError(
+                f"num_draft_tokens must be between 0 and 2048, but you requested num_draft_tokens={self.num_draft_tokens}. {reason}"
+            )
+
+    @staticmethod
+    def lade_max_draft_len(
+        windows_size: int, ngram_size: int, verification_set_size: int
+    ) -> int:
+        """Calculate the maximum number of draft tokens for the lookahead decoding algorithm: https://github.com/NVIDIA/TensorRT-LLM/tree/main/examples/lookahead#overview"""
+        return (0 if (ngram_size == 1) else ngram_size - 2) + (
+            windows_size - 1 + verification_set_size
+        ) * (ngram_size - 1)
+
+    def _validate_spec_dec_mode(self):
+        if self.speculative_decoding_mode == TrussSpecDecMode.DRAFT_EXTERNAL:
+            if not self.num_draft_tokens:
+                raise ValueError(
+                    "Draft external mode requires num_draft_tokens to be set."
+                )
+        elif self.speculative_decoding_mode == TrussSpecDecMode.LOOKAHEAD_DECODING:
+            if not all(
+                [
+                    self.lookahead_windows_size,
+                    self.lookahead_ngram_size,
+                    self.lookahead_verification_set_size,
+                ]
+            ):
+                raise ValueError(
+                    "Lookahead decoding mode requires lookahead_windows_size, lookahead_ngram_size, and lookahead_verification_set_size to be set."
+                )
+            lade_num_draft_tokens = self.lade_max_draft_len(
+                self.lookahead_windows_size,
+                self.lookahead_ngram_size,
+                self.lookahead_verification_set_size,
+            )
+            if not ENGINE_BUILDER_TRUSS_RUNTIME_MIGRATION:
+                if (
+                    self.num_draft_tokens
+                    and self.num_draft_tokens != lade_num_draft_tokens
+                ):
+                    raise ValueError(
+                        "num_draft_tokens is automatically calculated from lookahead_windows_size, lookahead_ngram_size, and lookahead_verification_set_size. "
+                        f"Please remove num_draft_tokens or set it to exactly {lade_num_draft_tokens}. You set it to {self.num_draft_tokens}."
+                    )
+                self.num_draft_tokens = lade_num_draft_tokens
+                if self.num_draft_tokens > 512:
+                    logger.warning(
+                        f"Lookahead decoding mode generates up to {self.num_draft_tokens} speculative tokens per step and may have performance implications. "
+                        "We recommend a simpler config, e.g. lookahead_windows_size=7, lookahead_ngram_size=5, lookahead_verification_set_size=3."
+                    )
+            else:
+                # server side on engine-builder
+                if not self.num_draft_tokens:
+                    raise ValueError(
+                        "num_draft_tokens is required in lookahead decoding mode but not set."
+                    )
+                if (
+                    self.num_draft_tokens < lade_num_draft_tokens
+                ):  # the engine must be built with at least the calculated draft tokens; requests may then use higher values at request time.
+                    raise ValueError(
+                        "num_draft_tokens is less than the value calculated from lookahead_windows_size, lookahead_ngram_size, and lookahead_verification_set_size."
+                    )
+
+        self._assert_draft_tokens()
 
     def _validate_checkpoint(self):
-        if not (bool(self.checkpoint_repository) ^ bool(self.build)):
+        if self.speculative_decoding_mode == TrussSpecDecMode.DRAFT_EXTERNAL and not (
+            bool(self.checkpoint_repository) ^ bool(self.build)
+        ):
             raise ValueError(
                 "Speculative decoding requires exactly one of checkpoint_repository or build to be configured."
             )
diff --git a/truss/tests/trt_llm/test_trt_llm_config.py b/truss/tests/trt_llm/test_trt_llm_config.py
index 91c565c00..963c7e80c 100644
--- a/truss/tests/trt_llm/test_trt_llm_config.py
+++ b/truss/tests/trt_llm/test_trt_llm_config.py
@@ -1,5 +1,10 @@
+import copy
+
+import pytest
+
 from truss.base.trt_llm_config import (
     TRTLLMConfiguration,
+    TrussSpecDecMode,
+    TrussSpeculatorConfiguration,
     TrussTRTLLMBatchSchedulerPolicy,
     TrussTRTLLMBuildConfiguration,
     TrussTRTLLMRuntimeConfiguration,
@@ -39,3 +44,73 @@ def test_trt_llm_configuration_init_and_migrate_deprecated_runtime_fields_existi
         "request_default_max_tokens": 10,
         "total_token_limit": 100,
     }
+
+
+def test_trt_llm_chunked_prefill_fix(trtllm_config):
+    """Make sure that the chunked prefill validation is working."""
+    trt_llm_config = TRTLLMConfiguration(**trtllm_config["trt_llm"])
+
+    assert trt_llm_config.build.plugin_configuration.paged_kv_cache is True
+    assert trt_llm_config.build.plugin_configuration.use_paged_context_fmha is True
+    assert trt_llm_config.runtime.enable_chunked_context is True
+
+    with pytest.raises(ValueError):
+        trt_llm2 = copy.deepcopy(trt_llm_config)
+        trt_llm2.build.plugin_configuration.paged_kv_cache = False
+        TRTLLMConfiguration(**trt_llm2.model_dump())
+
+    with pytest.raises(
+        ValueError
+    ):  # verify you can't disable paged context fmha without also disabling enable_chunked_context
+        trt_llm2 = copy.deepcopy(trt_llm_config)
+        trt_llm2.build.plugin_configuration.use_paged_context_fmha = False
+        TRTLLMConfiguration(**trt_llm2.model_dump())
+
+    trt_llm2 = copy.deepcopy(trt_llm_config)
+    trt_llm2.runtime.enable_chunked_context = False
+    trt_llm2.build.plugin_configuration.use_paged_context_fmha = False
+    TRTLLMConfiguration(**trt_llm2.model_dump())
+
+
+def test_trt_llm_lookahead_decoding(trtllm_config):
+    trt_llm_config = TRTLLMConfiguration(**trtllm_config["trt_llm"])
+
+    with pytest.raises(ValueError):
+        # need to specify lookahead_windows_size, lookahead_ngram_size, and lookahead_verification_set_size
+        trt_llm_config.build.speculator = TrussSpeculatorConfiguration(
+            speculative_decoding_mode=TrussSpecDecMode.LOOKAHEAD_DECODING,
+            num_draft_tokens=None,
+            lookahead_windows_size=None,
+            lookahead_ngram_size=None,
+            lookahead_verification_set_size=None,
+        )
+        TRTLLMConfiguration(**trt_llm_config.model_dump())
+
+    trt_llm_config.build.speculator = TrussSpeculatorConfiguration(
+        speculative_decoding_mode=TrussSpecDecMode.LOOKAHEAD_DECODING,
+        num_draft_tokens=None,  # will be overwritten
+        lookahead_windows_size=10,
+        lookahead_ngram_size=10,
+        lookahead_verification_set_size=10,
+    )
+    with_spec = TRTLLMConfiguration(**trt_llm_config.model_dump())
+    assert with_spec.build.speculator.lookahead_windows_size == 10
+    assert with_spec.build.speculator.lookahead_ngram_size == 10
+    assert with_spec.build.speculator.lookahead_verification_set_size == 10
+    assert (
+        with_spec.build.speculator.speculative_decoding_mode
+        == TrussSpecDecMode.LOOKAHEAD_DECODING
+    )
+    assert with_spec.build.speculator.num_draft_tokens == 179
+
+    with pytest.raises(ValueError):
+        # will lead to ValueError -> too many draft tokens would be generated with 100 lookahead windows
+        trt_llm_config.build.speculator = TrussSpeculatorConfiguration(
+            speculative_decoding_mode=TrussSpecDecMode.LOOKAHEAD_DECODING,
+            num_draft_tokens=None,
+            lookahead_windows_size=100,
+            lookahead_ngram_size=100,
+            lookahead_verification_set_size=100,
+        )
+        TRTLLMConfiguration(**trt_llm_config.model_dump())
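
For context on the values asserted in test_trt_llm_lookahead_decoding, here is a minimal standalone sketch of the draft-token arithmetic (it assumes only the formula from lade_max_draft_len above and is not part of the diff):

    # Mirrors TrussSpeculatorConfiguration.lade_max_draft_len from the diff above.
    def lade_max_draft_len(
        windows_size: int, ngram_size: int, verification_set_size: int
    ) -> int:
        return (0 if ngram_size == 1 else ngram_size - 2) + (
            windows_size - 1 + verification_set_size
        ) * (ngram_size - 1)

    # windows=10, ngram=10, verification=10 -> (10 - 2) + (10 - 1 + 10) * (10 - 1)
    # = 8 + 171 = 179, matching `num_draft_tokens == 179` in the test.
    assert lade_max_draft_len(10, 10, 10) == 179

    # windows=100, ngram=100, verification=100 -> 98 + 199 * 99 = 19799, which
    # exceeds the 2048 cap enforced by _assert_draft_tokens, hence the expected
    # ValueError in the last test case.
    assert lade_max_draft_len(100, 100, 100) == 19799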