Draft: Deepseek: Start Eagle work #6210
base: main
Conversation
Walkthrough

This update integrates speculative decoding metadata tracking into the DeepseekV3 and Eagle3 models, modifies fusion and capture logic based on new configuration parameters, and updates configuration classes to support flexible capture layer counts. It also adjusts model initialization, forward passes, and weight loading to accommodate these changes.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant User
    participant DeepseekV3ForCausalLM
    participant DeepseekV3Model
    participant DeepseekV3DecoderLayer
    participant SpecMetadata
    User->>DeepseekV3ForCausalLM: forward(input, spec_metadata)
    DeepseekV3ForCausalLM->>DeepseekV3Model: forward(input, spec_metadata)
    loop for each layer
        DeepseekV3Model->>DeepseekV3DecoderLayer: forward(hidden_states, spec_metadata)
        alt spec_metadata and capture
            DeepseekV3DecoderLayer->>SpecMetadata: capture(hidden_states, residual)
        end
    end
    DeepseekV3Model-->>DeepseekV3ForCausalLM: output
    DeepseekV3ForCausalLM-->>User: output
```
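To make the diagram concrete, here is a minimal sketch of the per-layer capture hook it implies. The `is_layer_capture`/`capture` names follow the diagram and the review below; the signatures and layer body are assumptions, not the PR's actual code.

```python
from typing import Tuple

import torch


class DecoderLayerSketch(torch.nn.Module):
    """Hypothetical decoder layer illustrating the capture hook."""

    def __init__(self, layer_idx: int):
        super().__init__()
        self.layer_idx = layer_idx

    def forward(
        self,
        hidden_states: torch.Tensor,
        residual: torch.Tensor,
        spec_metadata=None,  # SpecMetadata-like object or None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # ... attention / MoE computation would go here ...
        if spec_metadata is not None and spec_metadata.is_layer_capture(self.layer_idx):
            # Record the post-layer states for the Eagle3 draft model.
            spec_metadata.capture(hidden_states, residual)
        return hidden_states, residual
```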
Estimated code review effort: 4 (90–120 minutes)
Actionable comments posted: 2
🧹 Nitpick comments (2)
tensorrt_llm/llmapi/llm_args.py (1)

`320-330`: Add validation for the new `num_capture_layers` field

`num_capture_layers` must be a positive integer; otherwise back-end code that allocates per-layer buffers will blow up at run-time. Consider guarding the field with a validator so that bad configs fail fast.

```diff
 class EagleDecodingConfig(DecodingBaseConfig):
 @@
     max_non_leaves_per_layer: Optional[int] = None
     eagle3_one_model: Optional[bool] = True
     num_capture_layers: Optional[int] = 3
+
+    # Ensure the value is positive
+    @field_validator("num_capture_layers")
+    @classmethod
+    def _check_num_capture_layers(cls, v: int):
+        if v is not None and v <= 0:
+            raise ValueError("num_capture_layers must be a positive integer")
+        return v
```
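As a sanity check on the suggested fix, here is a self-contained sketch of the validator in isolation, using pydantic's `field_validator` (which the suggestion assumes); `EagleDecodingConfigSketch` is a simplified stand-in, not the real class:

```python
from typing import Optional

from pydantic import BaseModel, ValidationError, field_validator


class EagleDecodingConfigSketch(BaseModel):
    """Simplified stand-in for EagleDecodingConfig."""
    num_capture_layers: Optional[int] = 3

    @field_validator("num_capture_layers")
    @classmethod
    def _check_num_capture_layers(cls, v):
        if v is not None and v <= 0:
            raise ValueError("num_capture_layers must be a positive integer")
        return v


EagleDecodingConfigSketch(num_capture_layers=2)  # accepted
try:
    EagleDecodingConfigSketch(num_capture_layers=0)
except ValidationError as e:
    print(e)  # fails fast instead of crashing later at buffer allocation
```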
tensorrt_llm/_torch/models/modeling_deepseekv3.py (1)

`219-219`: Consider breaking the long line for better readability.

The logic is correct, but the line exceeds the 120-character limit.

```diff
-        predicted_tokens_per_seq = model_config.spec_config.max_draft_len + 1 if model_config.spec_config is not None else 1
+        predicted_tokens_per_seq = (model_config.spec_config.max_draft_len + 1
+                                    if model_config.spec_config is not None else 1)
```
📜 Review details
Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (6)

- tensorrt_llm/_torch/models/modeling_deepseekv3.py (12 hunks)
- tensorrt_llm/_torch/models/modeling_speculative.py (1 hunks)
- tensorrt_llm/_torch/speculative/eagle3.py (2 hunks)
- tensorrt_llm/_torch/speculative/interface.py (2 hunks)
- tensorrt_llm/_torch/speculative/utils.py (2 hunks)
- tensorrt_llm/llmapi/llm_args.py (2 hunks)
🧰 Additional context used
🪛 Ruff (0.12.2)
tensorrt_llm/_torch/models/modeling_deepseekv3.py
219-219: Line too long (124 > 120)
(E501)
🔇 Additional comments (14)
tensorrt_llm/_torch/speculative/utils.py (1)

`39-39`: LGTM! Correctly integrates the new capture layers configuration.

The addition of `num_capture_layers=spec_config.num_capture_layers` to both the `Eagle3SpecMetadata` and `Eagle3OneModelSpecMetadata` constructors properly propagates the configurable parameter from the spec config to the metadata objects.

Also applies to: 49-49
tensorrt_llm/_torch/speculative/interface.py (2)

`32-32`: LGTM! Correctly narrows engine usage to Eagle3 one-model mode.

The change to `use_one_engine()` removes the MTP check, now focusing exclusively on the EAGLE3_ONE_MODEL mode. This aligns with the architectural changes where only Eagle3 one-model mode uses a single engine.

`181-182`: LGTM! Provides an appropriate base implementation.

The new `is_layer_capture` method provides a sensible default implementation that returns `False`. This will be overridden in concrete metadata classes that actually perform layer capture (like the Eagle3 metadata classes).
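A minimal sketch of the default-plus-override pattern described above; the class and method names follow the review, while the field layout and capture set are illustrative assumptions:

```python
from dataclasses import dataclass
from typing import Tuple


@dataclass
class SpecMetadataSketch:
    """Base metadata: by default, no layer is captured."""

    def is_layer_capture(self, layer_id: int) -> bool:
        return False


@dataclass
class Eagle3SpecMetadataSketch(SpecMetadataSketch):
    """Eagle3 metadata: capture only the configured layers."""
    layers_to_capture: Tuple[int, ...] = (1, 16, 28)  # illustrative values

    def is_layer_capture(self, layer_id: int) -> bool:
        return layer_id in self.layers_to_capture
```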
tensorrt_llm/_torch/models/modeling_speculative.py (1)

`158-163`: Conditional `fc` initialization is safe.

The `apply_eagle3_fc` method first checks whether the last dimension of `hidden_states` matches `self.model.hidden_size` before calling `self.model.fc`. Since the constructor only creates `self.fc` when `num_capture_layers > 1` (i.e., when the concatenated hidden states exceed `hidden_size`), the layer will always exist exactly when it's needed. No other code paths reference `self.model.fc` unguarded.

No changes required.
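For clarity, a hedged sketch of the invariant being described: `fc` only exists when capture states are concatenated, and `apply_eagle3_fc` only uses it in exactly that case. Names mirror the review; the module layout is an assumption:

```python
import torch


class Eagle3InputProjectionSketch(torch.nn.Module):
    def __init__(self, hidden_size: int, num_capture_layers: int):
        super().__init__()
        self.hidden_size = hidden_size
        if num_capture_layers > 1:
            # Concatenated capture states are wider than hidden_size,
            # so a down-projection is required (and created) only here.
            self.fc = torch.nn.Linear(
                num_capture_layers * hidden_size, hidden_size, bias=False)

    def apply_eagle3_fc(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # The width differs from hidden_size exactly when
        # num_capture_layers > 1, which is exactly when self.fc exists.
        if hidden_states.shape[-1] != self.hidden_size:
            hidden_states = self.fc(hidden_states)
        return hidden_states
```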
tensorrt_llm/_torch/speculative/eagle3.py (2)

`38-41`: LGTM! Correctly uses configurable capture layers.

The hidden states tensor initialization now properly uses `config.num_capture_layers` instead of the hardcoded value of 3, making the tensor size consistent with the configurable number of capture layers.
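To illustrate the sizing rule, a small hypothetical helper (the parameter names and dtype are assumptions; the actual tensor initialization lives in `Eagle3SpecMetadata`):

```python
import torch


def make_capture_buffer(max_num_tokens: int, hidden_size: int,
                        num_capture_layers: int,
                        dtype: torch.dtype = torch.bfloat16) -> torch.Tensor:
    # One hidden_size-wide slot per captured layer, so the buffer width
    # tracks num_capture_layers instead of a hardcoded factor of 3.
    return torch.empty((max_num_tokens, num_capture_layers * hidden_size),
                       dtype=dtype)


buf = make_capture_buffer(max_num_tokens=8192, hidden_size=7168,
                          num_capture_layers=1)
assert buf.shape == (8192, 7168)
```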
`94-95`: Confirm single-layer capture logic in Eagle3SpecMetadata

The `__post_init__` in `Eagle3SpecMetadata` now treats `num_capture_layers == 1` as a special case and sets `self.layers_to_capture = (self.num_layers - 1,)` (i.e., the last layer) instead of capturing the first layer. I didn't find any existing documentation or comments that specify this design decision.

Please verify that targeting the last hidden layer for single-layer capture is the intended behavior, and consider adding or updating inline comments or spec documentation to clearly describe this change.

• File: tensorrt_llm/_torch/speculative/eagle3.py, lines 94–95
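A sketch of the selection rule under discussion; the single-layer branch matches the review, while the multi-layer spread is an assumed placeholder, not the actual `__post_init__` heuristic:

```python
from typing import Tuple


def select_layers_to_capture(num_layers: int,
                             num_capture_layers: int) -> Tuple[int, ...]:
    if num_capture_layers == 1:
        # Special case flagged by the review: capture only the LAST
        # hidden layer rather than the first.
        return (num_layers - 1,)
    # Placeholder spread for the multi-layer case (assumed, for
    # illustration only): early, middle, and late layers.
    return (1, num_layers // 2, num_layers - 2)


print(select_layers_to_capture(32, 1))  # (31,)
```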
tensorrt_llm/_torch/models/modeling_deepseekv3.py (8)

`52-52`: LGTM! Base class change aligns with speculative decoding architecture.

The change from `DecoderModelForCausalLM` to `SpecDecOneEngineForCausalLM` correctly integrates the model with the speculative decoding framework.

`65-65`: LGTM! Import changes support speculative decoding integration.

The imports for `MTPEagleWorker`, `MTPWorker`, and `SpecMetadata` are properly added to support the new speculative decoding functionality.

`730-767`: LGTM! Fusion control logic properly handles layer capture.

The conditional disabling of fusion optimizations during layer capture is correctly implemented. This ensures that hidden states can be properly captured for speculative decoding without interference from fusion optimizations.
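A hedged sketch of the gating decision described above (the flag and attribute names are assumptions; the PR's actual logic at lines 730–767 is more involved):

```python
def fusion_enabled(enable_fusion: bool, layer_idx: int, spec_metadata) -> bool:
    # Fused allreduce/norm kernels would consume the intermediate
    # hidden states, so fusion must be turned off whenever this layer
    # is a capture target for speculative decoding.
    if spec_metadata is not None and spec_metadata.is_layer_capture(layer_idx):
        return False
    return enable_fusion
```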
`849-853`: LGTM! Hidden state capture is correctly positioned.

The capture logic is properly placed after MoE computation when fusion is disabled, ensuring the captured states include the complete MoE output.

`901-904`: LGTM! Consistent hidden state capture implementation.

The MLP forward method correctly implements the same capture pattern as the MoE method, maintaining consistency across the codebase.

`1051-1071`: LGTM! Proper propagation of speculative metadata.

The `spec_metadata` parameter is correctly passed through all decoder layers, enabling consistent speculative decoding behavior.

`1077-1082`: LGTM! Correct base class and type parameters.

The inheritance from `SpecDecOneEngineForCausalLM` with proper type parameters correctly integrates the model with the speculative decoding framework.

`1313-1315`: LGTM! Correctly skips draft model weights.

The weight loading logic properly skips modules prefixed with "draft_model", which aligns with the speculative decoding architecture where draft models are handled separately.
```python
        return_context_logits: bool = False,
        **kwargs,
    ) -> torch.Tensor:
        #TODO Izzy check this
```
💡 Verification agent

🧩 Analysis chain

Address the TODO comment.

The TODO comment lacks context about what needs to be checked. Please clarify or resolve this before merging.

Could you clarify what needs to be checked regarding `num_generations_per_batch`? Is this related to the speculative decoding batch size calculation?
🏁 Script executed:

```bash
#!/bin/bash
# Show context around the TODO comment
rg -n "TODO Izzy" -C5 tensorrt_llm/_torch/models/modeling_deepseekv3.py
echo "----"
# Show usage and definition context of num_generations_per_batch
rg -n "num_generations_per_batch" -C5 tensorrt_llm/_torch/models/modeling_deepseekv3.py
```
Length of output: 1168
Clarify and remove the TODO for the `num_generations_per_batch` assignment

The `# TODO Izzy check this` at line 1137 needs to be resolved before merging:

- Confirm that setting `attn_metadata.num_generations_per_batch = self.model_nextn + 1` correctly reflects the intended number of speculative generations (i.e., should it be `model_nextn`, `model_nextn + 1`, or another value?).
- Once confirmed, remove the `# TODO` and, if helpful, replace it with a concise comment explaining why this formula is used.

No further changes are required beyond clarifying and documenting this assignment.
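For what it's worth, the presumed reading of the formula can be written out explicitly (this is an interpretation, not confirmed behavior):

```python
# With MTP-style speculation, each iteration scores the one "real"
# next token plus model_nextn drafted tokens, hence the "+ 1".
model_nextn = 3                                # assumed example value
num_generations_per_batch = model_nextn + 1    # drafts + target token
assert num_generations_per_batch == 4
```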
🤖 Prompt for AI Agents

```
In tensorrt_llm/_torch/models/modeling_deepseekv3.py at line 1137, confirm
whether setting attn_metadata.num_generations_per_batch to self.model_nextn + 1
correctly represents the intended number of speculative generations. After
verification, remove the TODO comment and add a brief explanatory comment
clarifying why this value is used, ensuring the code is clear and documented.
```
```diff
     @classmethod
     def from_dict(cls, data: dict):
-        return cls(**data)
+        out = cls(**data)
+        out.max_draft_len = out.num_nextn_predict_layers
+        return out
```
`from_dict` silently overwrites user-provided `max_draft_len`

The helper now forces `max_draft_len = num_nextn_predict_layers`, even when the input JSON explicitly sets a different `max_draft_len`. This can mask mis-configurations and produces surprising behaviour.

Recommend only filling the value when it is missing, and erroring out on mismatches:
```diff
-    @classmethod
-    def from_dict(cls, data: dict):
-        out = cls(**data)
-        out.max_draft_len = out.num_nextn_predict_layers
-        return out
+    @classmethod
+    def from_dict(cls, data: dict):
+        out = cls(**data)
+        if out.max_draft_len is None:
+            # Back-compat: derive from num_nextn_predict_layers
+            out.max_draft_len = out.num_nextn_predict_layers
+        elif out.max_draft_len != out.num_nextn_predict_layers:
+            raise ValueError(
+                "max_draft_len and num_nextn_predict_layers must match "
+                f"(got {out.max_draft_len} vs {out.num_nextn_predict_layers})"
+            )
+        return out
```
📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```python
    @classmethod
    def from_dict(cls, data: dict):
        out = cls(**data)
        if out.max_draft_len is None:
            # Back-compat: derive from num_nextn_predict_layers
            out.max_draft_len = out.num_nextn_predict_layers
        elif out.max_draft_len != out.num_nextn_predict_layers:
            raise ValueError(
                "max_draft_len and num_nextn_predict_layers must match "
                f"(got {out.max_draft_len} vs {out.num_nextn_predict_layers})"
            )
        return out
```
🤖 Prompt for AI Agents

```
In tensorrt_llm/llmapi/llm_args.py around lines 435 to 440, the from_dict method
overwrites any user-provided max_draft_len by setting it to
num_nextn_predict_layers unconditionally. To fix this, modify the method to only
set max_draft_len to num_nextn_predict_layers if max_draft_len is not already
provided in the input data. Additionally, add a check to raise an error if
max_draft_len is provided but does not match num_nextn_predict_layers, to
prevent silent misconfigurations.
```
Signed-off-by: Izzy Putterman <iputterman@nvidia.com>

Force-pushed from 913313c to 1f65f9b
Actionable comments posted: 0
♻️ Duplicate comments (1)

tensorrt_llm/_torch/models/modeling_deepseekv3.py (1)

`1138-1138`: Address the TODO comment before merging.

This TODO comment was previously flagged and still needs to be resolved. Please clarify the correctness of the `num_generations_per_batch` calculation or remove the TODO.
🧹 Nitpick comments (1)

tensorrt_llm/_torch/models/modeling_deepseekv3.py (1)

`219-219`: Fix line length violation.

The line exceeds the 120-character limit. Consider breaking it across multiple lines for better readability.

```diff
-        predicted_tokens_per_seq = model_config.spec_config.max_draft_len + 1 if model_config.spec_config is not None else 1
+        predicted_tokens_per_seq = (
+            model_config.spec_config.max_draft_len + 1
+            if model_config.spec_config is not None
+            else 1
+        )
```
📜 Review details
Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (6)

- tensorrt_llm/_torch/models/modeling_deepseekv3.py (12 hunks)
- tensorrt_llm/_torch/models/modeling_speculative.py (1 hunks)
- tensorrt_llm/_torch/speculative/eagle3.py (4 hunks)
- tensorrt_llm/_torch/speculative/interface.py (2 hunks)
- tensorrt_llm/_torch/speculative/utils.py (2 hunks)
- tensorrt_llm/llmapi/llm_args.py (2 hunks)
🚧 Files skipped from review as they are similar to previous changes (5)
- tensorrt_llm/_torch/speculative/utils.py
- tensorrt_llm/_torch/speculative/interface.py
- tensorrt_llm/_torch/models/modeling_speculative.py
- tensorrt_llm/llmapi/llm_args.py
- tensorrt_llm/_torch/speculative/eagle3.py
🧰 Additional context used
🪛 Ruff (0.12.2)
tensorrt_llm/_torch/models/modeling_deepseekv3.py
219-219: Line too long (124 > 120)
(E501)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Pre-commit Check
🔇 Additional comments (9)

tensorrt_llm/_torch/models/modeling_deepseekv3.py (9)

`52-52`: Import changes look correct for speculative decoding integration.

The updated imports properly reflect the transition from MTP-specific to general speculative decoding metadata handling, and the base class change aligns with the new architecture.

Also applies to: 65-65, 67-67

`730-730`: Method signature updates are well-designed.

The addition of the optional `spec_metadata` parameter maintains backward compatibility while enabling speculative decoding functionality.

Also applies to: 1051-1052

`748-750`: Fusion configuration logic is correct.

Properly disables fusion when layer capture is required for speculative decoding, ensuring hidden states remain accessible for capture.

Also applies to: 759-761

`755-755`: Parameter passing is consistent and correct.

Proper propagation of `spec_metadata` through the call chain enables speculative decoding functionality at all relevant layers.

Also applies to: 766-766, 774-774

`849-852`: Hidden state capture logic is well-implemented.

Correctly captures hidden states and residuals when required for speculative decoding, with appropriate conditional checks.

Also applies to: 901-904

`1078-1079`: Base class change aligns with speculative decoding architecture.

The transition to `SpecDecOneEngineForCausalLM` is consistent with the overall speculative decoding integration.

`1083-1083`: Constructor changes improve precision.

The more specific check for MTP mode and proper parameter passing to the parent constructor enhance the initialization logic.

Also applies to: 1086-1087

`1134-1134`: Speculative metadata integration is correct.

The addition and propagation of the `spec_metadata` parameter enables proper speculative decoding functionality in the forward pass.

Also applies to: 1145-1145

`1315-1317`: Draft model weight exclusion is appropriate.

Correctly skips loading draft model weights, maintaining proper separation between draft and main models in speculative decoding.
Signed-off-by: Izzy Putterman <iputterman@nvidia.com>