Draft: Deepseek: Start Eagle work #6210


Open
wants to merge 3 commits into main from iputterman/deepseek-eagle

Conversation

@IzzyPutterman (Collaborator) commented Jul 21, 2025

Summary by CodeRabbit

  • New Features

    • Added support for tracking speculative decoding metadata in DeepseekV3 and Eagle3 models.
    • Introduced a configuration option to set the number of capture layers for Eagle decoding.
  • Improvements

    • Enhanced flexibility in model behavior based on speculative decoding modes and capture layer settings.
    • Improved configuration handling for draft length synchronization in MTP decoding.
  • Bug Fixes

    • Adjusted model weight loading to skip draft model parameters, preventing unintended loading.

@IzzyPutterman requested review from a team as code owners July 21, 2025 06:13

coderabbitai bot commented Jul 21, 2025

Walkthrough

This update integrates speculative decoding metadata tracking into DeepseekV3 and Eagle3 models, modifies fusion and capture logic based on new configuration parameters, and updates configuration classes to support flexible capture layer counts. It also adjusts model initialization, forward passes, and weight loading to accommodate these changes.

Changes

  • tensorrt_llm/_torch/models/modeling_deepseekv3.py: Integrated speculative decoding metadata handling, changed the base class to SpecDecOneEngineForCausalLM, updated method signatures to accept and propagate spec_metadata, modified fusion logic based on capture state, adjusted MTP layer initialization conditionally, and modified weight loading to skip draft model modules.
  • tensorrt_llm/_torch/models/modeling_speculative.py: Updated the Eagle3DraftModel constructor to conditionally initialize the fc layer based on num_capture_layers, adjusting its input size accordingly.
  • tensorrt_llm/_torch/speculative/eagle3.py: Changed hidden state tensor allocation to use num_capture_layers instead of a fixed multiplier, updated capture layer logic in the Eagle3SpecMetadata and Eagle3OneModelSpecMetadata initializations, and added a num_capture_layers field to Eagle3OneModelSpecMetadata.
  • tensorrt_llm/_torch/speculative/interface.py: Modified SpeculativeDecodingMode.use_one_engine to remove the is_mtp() condition, added an is_layer_capture method to SpecMetadata that returns False, and refactored the extend_ctx method with explicit conditional logic.
  • tensorrt_llm/_torch/speculative/utils.py: Updated get_spec_metadata to pass num_capture_layers from the config to the Eagle3SpecMetadata and Eagle3OneModelSpecMetadata constructors.
  • tensorrt_llm/llmapi/llm_args.py: Added a num_capture_layers field to EagleDecodingConfig and updated MTPDecodingConfig.from_dict to set max_draft_len equal to num_nextn_predict_layers after instantiation.

Sequence Diagram(s)

sequenceDiagram
    participant User
    participant DeepseekV3ForCausalLM
    participant DeepseekV3Model
    participant DeepseekV3DecoderLayer
    participant SpecMetadata

    User->>DeepseekV3ForCausalLM: forward(input, spec_metadata)
    DeepseekV3ForCausalLM->>DeepseekV3Model: forward(input, spec_metadata)
    loop for each layer
        DeepseekV3Model->>DeepseekV3DecoderLayer: forward(hidden_states, spec_metadata)
        alt spec_metadata and capture
            DeepseekV3DecoderLayer->>SpecMetadata: capture(hidden_states, residual)
        end
    end
    DeepseekV3Model-->>DeepseekV3ForCausalLM: output
    DeepseekV3ForCausalLM-->>User: output

Estimated code review effort

4 (90–120 minutes)

Possibly related PRs

Suggested labels

Community want to contribute

Suggested reviewers

  • byshiue
  • mikeiovine
  • HuiGao-NV

Poem

In the warren where models leap and bound,
New layers to capture, new configs are found.
Speculative hops through metadata’s maze,
Drafts and fusions in a speculative haze.
With every carrot of code we chew,
The engine grows smarter—thanks to you! 🥕


📜 Recent review details

Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 1f65f9b and 008c8aa.

📒 Files selected for processing (1)
  • tensorrt_llm/_torch/speculative/interface.py (3 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
  • tensorrt_llm/_torch/speculative/interface.py

@coderabbitai bot left a comment


Actionable comments posted: 2

🧹 Nitpick comments (2)
tensorrt_llm/llmapi/llm_args.py (1)

320-330: Add validation for the new num_capture_layers field

num_capture_layers must be a positive integer; otherwise back-end code that allocates per-layer buffers will blow up at run-time. Consider guarding the field with a validator so that bad configs fail fast.

 class EagleDecodingConfig(DecodingBaseConfig):
@@
     max_non_leaves_per_layer: Optional[int] = None
     eagle3_one_model: Optional[bool] = True
     num_capture_layers: Optional[int] = 3
+
+    # Ensure the value is positive
+    @field_validator("num_capture_layers")
+    @classmethod
+    def _check_num_capture_layers(cls, v: int):
+        if v is not None and v <= 0:
+            raise ValueError("num_capture_layers must be a positive integer")
+        return v
tensorrt_llm/_torch/models/modeling_deepseekv3.py (1)

219-219: Consider breaking the long line for better readability.

The logic is correct, but the line exceeds the 120-character limit.

-        predicted_tokens_per_seq = model_config.spec_config.max_draft_len + 1 if model_config.spec_config is not None else 1
+        predicted_tokens_per_seq = (model_config.spec_config.max_draft_len + 1 
+                                    if model_config.spec_config is not None else 1)
📜 Review details

Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 88076ee and 913313c.

📒 Files selected for processing (6)
  • tensorrt_llm/_torch/models/modeling_deepseekv3.py (12 hunks)
  • tensorrt_llm/_torch/models/modeling_speculative.py (1 hunks)
  • tensorrt_llm/_torch/speculative/eagle3.py (2 hunks)
  • tensorrt_llm/_torch/speculative/interface.py (2 hunks)
  • tensorrt_llm/_torch/speculative/utils.py (2 hunks)
  • tensorrt_llm/llmapi/llm_args.py (2 hunks)
🧰 Additional context used
🪛 Ruff (0.12.2)
tensorrt_llm/_torch/models/modeling_deepseekv3.py

219-219: Line too long (124 > 120)

(E501)

🔇 Additional comments (14)
tensorrt_llm/_torch/speculative/utils.py (1)

39-39: LGTM! Correctly integrates the new capture layers configuration.

The addition of num_capture_layers=spec_config.num_capture_layers to both Eagle3SpecMetadata and Eagle3OneModelSpecMetadata constructors properly propagates the configurable parameter from the spec config to the metadata objects.

Also applies to: 49-49
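The propagation is a straight pass-through from config to metadata. A simplified sketch of the pattern (class and field names trimmed down; this is not the actual factory signature in utils.py):

# Simplified sketch of the config-to-metadata propagation; the real
# get_spec_metadata handles more speculative modes and many more fields.
from dataclasses import dataclass


@dataclass
class EagleConfigSketch:
    max_draft_len: int = 3
    num_capture_layers: int = 3


@dataclass
class Eagle3MetadataSketch:
    max_draft_len: int
    num_capture_layers: int


def get_spec_metadata_sketch(spec_config: EagleConfigSketch) -> Eagle3MetadataSketch:
    # The new field is forwarded verbatim from the decoding config.
    return Eagle3MetadataSketch(
        max_draft_len=spec_config.max_draft_len,
        num_capture_layers=spec_config.num_capture_layers,
    )


assert get_spec_metadata_sketch(EagleConfigSketch()).num_capture_layers == 3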

tensorrt_llm/_torch/speculative/interface.py (2)

32-32: LGTM! Correctly narrows engine usage to Eagle3 one-model mode.

The change to use_one_engine() removes the MTP check, now focusing exclusively on the EAGLE3_ONE_MODEL mode. This aligns with the architectural changes where only Eagle3 one-model mode uses a single engine.


181-182: LGTM! Provides appropriate base implementation.

The new is_layer_capture method provides a sensible default implementation that returns False. This will be overridden in concrete metadata classes that actually perform layer capture (like Eagle3 metadata classes).
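For readers skimming the diff, here is a small self-contained sketch of this default-plus-override pattern. The layer_id signature and the concrete layer indices are assumptions, not the exact TensorRT-LLM definitions:

# Sketch only: base default returns False, an Eagle3-style metadata class
# overrides it to report its capture layers.
from dataclasses import dataclass
from typing import Tuple


class SpecMetadataSketch:
    def is_layer_capture(self, layer_id: int) -> bool:
        # Base default: no hidden-state capture for any layer.
        return False


@dataclass
class Eagle3MetadataSketch(SpecMetadataSketch):
    layers_to_capture: Tuple[int, ...] = (1, 15, 28)

    def is_layer_capture(self, layer_id: int) -> bool:
        return layer_id in self.layers_to_capture


assert not SpecMetadataSketch().is_layer_capture(15)
assert Eagle3MetadataSketch().is_layer_capture(15)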

tensorrt_llm/_torch/models/modeling_speculative.py (1)

158-163: Conditional fc initialization is safe.

The apply_eagle3_fc method first checks whether the last dimension of hidden_states matches self.model.hidden_size before calling self.model.fc. Since the constructor only creates self.fc when num_capture_layers > 1 (i.e., when concatenated hidden states exceed hidden_size), the layer will always exist exactly when it’s needed. No other code paths reference self.model.fc unguarded.

No changes required.
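A minimal sketch of the conditional construction and the guard it relies on, using simplified names rather than the actual Eagle3DraftModel code:

# Hedged sketch: fc is only needed to project the concatenation of several
# captured hidden states back down to hidden_size, so it is only created
# when more than one layer is captured.
import torch
import torch.nn as nn


class DraftModelSketch(nn.Module):
    def __init__(self, hidden_size: int, num_capture_layers: int):
        super().__init__()
        self.hidden_size = hidden_size
        if num_capture_layers > 1:
            self.fc = nn.Linear(num_capture_layers * hidden_size, hidden_size, bias=False)
        else:
            self.fc = None

    def apply_fc(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Mirrors the guard described above: skip the projection when the
        # tensor is already hidden_size wide (single capture layer).
        if hidden_states.shape[-1] == self.hidden_size:
            return hidden_states
        return self.fc(hidden_states)


m = DraftModelSketch(hidden_size=8, num_capture_layers=3)
x = torch.randn(2, 3 * 8)
assert m.apply_fc(x).shape == (2, 8)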

tensorrt_llm/_torch/speculative/eagle3.py (2)

38-41: LGTM! Correctly uses configurable capture layers.

The hidden states tensor initialization now properly uses config.num_capture_layers instead of the hardcoded value of 3, making the tensor size consistent with the configurable number of capture layers.


94-95: Confirm single-layer capture logic in Eagle3SpecMetadata

The __post_init__ in Eagle3SpecMetadata now treats num_capture_layers == 1 as a special case and sets

self.layers_to_capture = (self.num_layers - 1,)

(i.e. the last layer) instead of capturing the first layer. I didn’t find any existing documentation or comments that specify this design decision.

Please verify that targeting the last hidden layer for single‐layer capture is the intended behavior, and consider adding or updating inline comments or spec documentation to clearly describe this change.

• File: tensorrt_llm/_torch/speculative/eagle3.py, lines 94–95
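A compact, purely illustrative sketch of the selection logic being questioned (the multi-layer indices below are assumptions, not the exact values used in Eagle3SpecMetadata.__post_init__):

# Illustrative only: shows the single-layer special case (capture the last
# hidden layer) next to a multi-layer spread.
def choose_layers_to_capture(num_layers: int, num_capture_layers: int) -> tuple:
    if num_capture_layers == 1:
        # Single-layer capture targets the last hidden layer.
        return (num_layers - 1,)
    # Otherwise spread captures across early, middle, and late layers.
    return (1, num_layers // 2, num_layers - 4)


print(choose_layers_to_capture(num_layers=32, num_capture_layers=1))  # (31,)
print(choose_layers_to_capture(num_layers=32, num_capture_layers=3))  # (1, 16, 28)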

tensorrt_llm/_torch/models/modeling_deepseekv3.py (8)

52-52: LGTM! Base class change aligns with speculative decoding architecture.

The change from DecoderModelForCausalLM to SpecDecOneEngineForCausalLM correctly integrates the model with the speculative decoding framework.


65-65: LGTM! Import changes support speculative decoding integration.

The imports for MTPEagleWorker, MTPWorker, and SpecMetadata are properly added to support the new speculative decoding functionality.


730-767: LGTM! Fusion control logic properly handles layer capture.

The conditional disabling of fusion optimizations during layer capture is correctly implemented. This ensures that hidden states can be properly captured for speculative decoding without interference from fusion optimizations.


849-853: LGTM! Hidden state capture is correctly positioned.

The capture logic is properly placed after MoE computation when fusion is disabled, ensuring the captured states include the complete MoE output.


901-904: LGTM! Consistent hidden state capture implementation.

The MLP forward method correctly implements the same capture pattern as the MoE method, maintaining consistency across the codebase.
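The shared capture pattern referred to in these two comments boils down to a small hook at the end of the layer forward. A minimal sketch follows; the method names on spec_metadata (is_layer_capture, maybe_capture_hidden_states) follow the review discussion, but their exact signatures in the codebase are assumptions:

# Sketch of the per-layer capture hook; RecordingMetadata is a hypothetical
# stand-in for the Eagle3 spec metadata object.
import torch


class RecordingMetadata:
    def __init__(self, layers_to_capture):
        self.layers_to_capture = set(layers_to_capture)
        self.captured = {}

    def is_layer_capture(self, layer_idx: int) -> bool:
        return layer_idx in self.layers_to_capture

    def maybe_capture_hidden_states(self, layer_idx, hidden_states, residual):
        # Store the post-MoE/MLP activations for the draft model.
        self.captured[layer_idx] = (hidden_states, residual)


def finish_layer(hidden_states, residual, layer_idx, spec_metadata=None):
    # ... MoE / MLP output already computed into hidden_states ...
    # Fusion is disabled on capture layers, so hidden_states and residual
    # are still separate tensors here and can be recorded.
    if spec_metadata is not None and spec_metadata.is_layer_capture(layer_idx):
        spec_metadata.maybe_capture_hidden_states(layer_idx, hidden_states, residual)
    return hidden_states


meta = RecordingMetadata(layers_to_capture=[30])
h, r = torch.randn(2, 8), torch.randn(2, 8)
finish_layer(h, r, layer_idx=30, spec_metadata=meta)
assert 30 in meta.captured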


1051-1071: LGTM! Proper propagation of speculative metadata.

The spec_metadata parameter is correctly passed through all decoder layers, enabling consistent speculative decoding behavior.


1077-1082: LGTM! Correct base class and type parameters.

The inheritance from SpecDecOneEngineForCausalLM with proper type parameters correctly integrates the model with the speculative decoding framework.


1313-1315: LGTM! Correctly skips draft model weights.

The weight loading logic properly skips modules prefixed with "draft_model", which aligns with the speculative decoding architecture where draft models are handled separately.
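As a hedged sketch of the skip described above (the real load_weights logic is more involved; this only illustrates filtering out "draft_model"-prefixed parameters):

# Draft-model parameters are handled by the speculative-decoding path, so
# the target model ignores anything under the "draft_model" prefix.
def filter_draft_weights(weights: dict) -> dict:
    return {name: value for name, value in weights.items()
            if not name.startswith("draft_model")}


weights = {
    "model.layers.0.mlp.gate.weight": "...",
    "draft_model.model.fc.weight": "...",
}
assert list(filter_draft_weights(weights)) == ["model.layers.0.mlp.gate.weight"]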

Comment on code in tensorrt_llm/_torch/models/modeling_deepseekv3.py:

        return_context_logits: bool = False,
        **kwargs,
    ) -> torch.Tensor:
        #TODO Izzy check this

💡 Verification agent

🧩 Analysis chain

Address the TODO comment.

The TODO comment lacks context about what needs to be checked. Please clarify or resolve this before merging.

Could you clarify what needs to be checked regarding num_generations_per_batch? Is this related to the speculative decoding batch size calculation?


🏁 Script executed:

#!/bin/bash
# Show context around the TODO comment
rg -n "TODO Izzy" -C5 tensorrt_llm/_torch/models/modeling_deepseekv3.py

echo "----"

# Show usage and definition context of num_generations_per_batch
rg -n "num_generations_per_batch" -C5 tensorrt_llm/_torch/models/modeling_deepseekv3.py

Length of output: 1168


Clarify and remove the TODO for num_generations_per_batch assignment

The # TODO Izzy check this at line 1137 needs to be resolved before merging:

  • Confirm that setting
    attn_metadata.num_generations_per_batch = self.model_nextn + 1
    correctly reflects the intended number of speculative generations (i.e., should it be model_nextn, model_nextn + 1, or another value?).
  • Once confirmed, remove the # TODO and, if helpful, replace it with a concise comment explaining why this formula is used.

No further changes are required beyond clarifying and documenting this assignment.

🤖 Prompt for AI Agents
In tensorrt_llm/_torch/models/modeling_deepseekv3.py at line 1137, confirm
whether setting attn_metadata.num_generations_per_batch to self.model_nextn + 1
correctly represents the intended number of speculative generations. After
verification, remove the TODO comment and add a brief explanatory comment
clarifying why this value is used, ensuring the code is clear and documented.

Comment on lines 435 to 440 (tensorrt_llm/llmapi/llm_args.py):

     @classmethod
     def from_dict(cls, data: dict):
-        return cls(**data)
+        out = cls(**data)
+        out.max_draft_len = out.num_nextn_predict_layers
+        return out


⚠️ Potential issue

from_dict silently overwrites user-provided max_draft_len

The helper now forces max_draft_len = num_nextn_predict_layers, even when the input JSON explicitly sets a different max_draft_len. This can mask mis-configurations and produces surprising behaviour.

Recommend only filling the value when it is missing, and erroring out on mismatches:

-    @classmethod
-    def from_dict(cls, data: dict):
-        out = cls(**data)
-        out.max_draft_len = out.num_nextn_predict_layers
-        return out
+    @classmethod
+    def from_dict(cls, data: dict):
+        out = cls(**data)
+        if out.max_draft_len is None:
+            # Back-compat: derive from num_nextn_predict_layers
+            out.max_draft_len = out.num_nextn_predict_layers
+        elif out.max_draft_len != out.num_nextn_predict_layers:
+            raise ValueError(
+                "max_draft_len and num_nextn_predict_layers must match "
+                f"(got {out.max_draft_len} vs {out.num_nextn_predict_layers})"
+            )
+        return out
🤖 Prompt for AI Agents
In tensorrt_llm/llmapi/llm_args.py around lines 435 to 440, the from_dict method
overwrites any user-provided max_draft_len by setting it to
num_nextn_predict_layers unconditionally. To fix this, modify the method to only
set max_draft_len to num_nextn_predict_layers if max_draft_len is not already
provided in the input data. Additionally, add a check to raise an error if
max_draft_len is provided but does not match num_nextn_predict_layers, to
prevent silent misconfigurations.
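If that guard is adopted, a conflicting config fails fast instead of being silently rewritten. A rough, self-contained illustration using a stand-in dataclass rather than the real pydantic MTPDecodingConfig:

# Stand-in illustration only: MTPConfigSketch is a hypothetical dataclass
# that exists purely to show the proposed fail-fast from_dict behaviour.
from dataclasses import dataclass
from typing import Optional


@dataclass
class MTPConfigSketch:
    num_nextn_predict_layers: int = 1
    max_draft_len: Optional[int] = None

    @classmethod
    def from_dict(cls, data: dict) -> "MTPConfigSketch":
        out = cls(**data)
        if out.max_draft_len is None:
            # Back-compat: derive the draft length from the MTP layer count.
            out.max_draft_len = out.num_nextn_predict_layers
        elif out.max_draft_len != out.num_nextn_predict_layers:
            raise ValueError("max_draft_len and num_nextn_predict_layers must match")
        return out


print(MTPConfigSketch.from_dict({"num_nextn_predict_layers": 3}).max_draft_len)  # 3
try:
    MTPConfigSketch.from_dict({"num_nextn_predict_layers": 3, "max_draft_len": 5})
except ValueError as err:
    print("rejected:", err)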

Signed-off-by: Izzy Putterman <iputterman@nvidia.com>
Signed-off-by: Izzy Putterman <iputterman@nvidia.com>
@IzzyPutterman force-pushed the iputterman/deepseek-eagle branch from 913313c to 1f65f9b on July 21, 2025 22:09

@coderabbitai bot left a comment


Actionable comments posted: 0

♻️ Duplicate comments (1)
tensorrt_llm/_torch/models/modeling_deepseekv3.py (1)

1138-1138: Address the TODO comment before merging.

This TODO comment was previously flagged and still needs to be resolved. Please clarify the correctness of the num_generations_per_batch calculation or remove the TODO.

🧹 Nitpick comments (1)
tensorrt_llm/_torch/models/modeling_deepseekv3.py (1)

219-219: Fix line length violation.

The line exceeds the 120-character limit. Consider breaking it across multiple lines for better readability.

-        predicted_tokens_per_seq = model_config.spec_config.max_draft_len + 1 if model_config.spec_config is not None else 1
+        predicted_tokens_per_seq = (
+            model_config.spec_config.max_draft_len + 1 
+            if model_config.spec_config is not None 
+            else 1
+        )
📜 Review details

Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 913313c and 1f65f9b.

📒 Files selected for processing (6)
  • tensorrt_llm/_torch/models/modeling_deepseekv3.py (12 hunks)
  • tensorrt_llm/_torch/models/modeling_speculative.py (1 hunks)
  • tensorrt_llm/_torch/speculative/eagle3.py (4 hunks)
  • tensorrt_llm/_torch/speculative/interface.py (2 hunks)
  • tensorrt_llm/_torch/speculative/utils.py (2 hunks)
  • tensorrt_llm/llmapi/llm_args.py (2 hunks)
🚧 Files skipped from review as they are similar to previous changes (5)
  • tensorrt_llm/_torch/speculative/utils.py
  • tensorrt_llm/_torch/speculative/interface.py
  • tensorrt_llm/_torch/models/modeling_speculative.py
  • tensorrt_llm/llmapi/llm_args.py
  • tensorrt_llm/_torch/speculative/eagle3.py
🧰 Additional context used
🪛 Ruff (0.12.2)
tensorrt_llm/_torch/models/modeling_deepseekv3.py

219-219: Line too long (124 > 120)

(E501)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Pre-commit Check
🔇 Additional comments (9)
tensorrt_llm/_torch/models/modeling_deepseekv3.py (9)

52-52: Import changes look correct for speculative decoding integration.

The updated imports properly reflect the transition from MTP-specific to general speculative decoding metadata handling, and the base class change aligns with the new architecture.

Also applies to: 65-65, 67-67


730-730: Method signature updates are well-designed.

The addition of the optional spec_metadata parameter maintains backward compatibility while enabling speculative decoding functionality.

Also applies to: 1051-1052


748-750: Fusion configuration logic is correct.

Properly disables fusion when layer capture is required for speculative decoding, ensuring hidden states remain accessible for capture.

Also applies to: 759-761


755-755: Parameter passing is consistent and correct.

Proper propagation of spec_metadata through the call chain enables speculative decoding functionality at all relevant layers.

Also applies to: 766-766, 774-774


849-852: Hidden state capture logic is well-implemented.

Correctly captures hidden states and residuals when required for speculative decoding, with appropriate conditional checks.

Also applies to: 901-904


1078-1079: Base class change aligns with speculative decoding architecture.

The transition to SpecDecOneEngineForCausalLM is consistent with the overall speculative decoding integration.


1083-1083: Constructor changes improve precision.

The more specific check for MTP mode and proper parameter passing to the parent constructor enhance the initialization logic.

Also applies to: 1086-1087


1134-1134: Speculative metadata integration is correct.

The addition and propagation of spec_metadata parameter enables proper speculative decoding functionality in the forward pass.

Also applies to: 1145-1145


1315-1317: Draft model weight exclusion is appropriate.

Correctly skips loading draft model weights, maintaining proper separation between draft and main models in speculative decoding.

Signed-off-by: Izzy Putterman <iputterman@nvidia.com>