-
Notifications
You must be signed in to change notification settings - Fork 525
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat(pt): support loss plugin for external package #4248
feat(pt): support loss plugin for external package #4248
Conversation
📝 Walkthrough📝 WalkthroughWalkthroughThe pull request introduces modifications to the Changes
Possibly related PRs
Suggested reviewers
Thank you for using CodeRabbit. We offer it for free to the OSS community and would appreciate your support in helping us grow. If you find it useful, would you consider giving us a shout-out on your favorite social media? 🪧 TipsChatThere are 3 ways to chat with CodeRabbit:
Note: Be mindful of the bot's finite context window. It's strongly recommended to break down tasks such as reading entire modules into smaller chunks. For a focused discussion, use review comments to chat about specific files and their changes, instead of using the PR comments. CodeRabbit Commands (Invoked using PR comments)
Other keywords and placeholders
CodeRabbit Configuration File (
|
Codecov ReportAttention: Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## devel #4248 +/- ##
==========================================
- Coverage 84.23% 84.22% -0.02%
==========================================
Files 548 548
Lines 51425 51431 +6
Branches 3051 3053 +2
==========================================
- Hits 43317 43316 -1
- Misses 7150 7155 +5
- Partials 958 960 +2 ☔ View full report in Codecov by Sentry. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 0
🧹 Outside diff range and nitpick comments (2)
deepmd/pt/loss/loss.py (2)
45-63
: LGTM: Well-structured factory method with room for enhancementThe
get_loss
classmethod provides a clean factory pattern for instantiating loss modules. The docstring follows NumPy style and clearly explains the default behavior.Consider adding a note in the docstring about parameter validation:
Parameters ---------- loss_params : dict - The loss parameters + The loss parameters. Note: Parameter validation should be implemented + by concrete subclasses that override this method.
Line range hint
17-63
: Excellent plugin architecture designThe integration of the plugin registry with
TaskLoss
creates a flexible extension point for custom loss implementations. This architecture:
- Enables external packages to register custom losses
- Maintains backward compatibility through the base implementation
- Provides a standardized way to instantiate loss modules
This is a solid foundation for the external package support objective.
Consider documenting the plugin system usage in the module docstring to help external package developers understand how to register their custom losses.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
📒 Files selected for processing (1)
- deepmd/pt/loss/loss.py (2 hunks)
🔇 Additional comments (2)
deepmd/pt/loss/loss.py (2)
12-14
: LGTM: Clean import additionThe import of
make_plugin_registry
is properly placed and follows the project's import style.
17-17
: LGTM: Proper inheritance order for plugin supportThe inheritance order is correct with
torch.nn.Module
first (for PyTorch functionality), followed byABC
(for abstract methods), and the plugin registry mixin last.Let's verify if any external loss plugins are already registered:
✅ Verification successful
Plugin registry implementation is consistent with codebase patterns
The inheritance change is consistent with the established plugin architecture in the codebase:
- The
make_plugin_registry("loss")
pattern is used consistently across multiple core components (models, descriptors, fitting, etc.)- The inheritance order matches other plugin-enabled classes like
DescriptorBlock
andModel
- The loss plugin system is properly integrated with argument checking via
loss_args_plugin
indeepmd/utils/argcheck.py
🏁 Scripts executed
The following scripts were executed for the analysis:
Script:
#!/bin/bash # Search for other files that might register loss plugins rg -l "TaskLoss.register" --type pyLength of output: 35
Script:
#!/bin/bash # Let's try a broader search to find loss plugin registrations and implementations # First, find all Python files that might contain loss-related code echo "=== Files containing 'loss' in their path ===" fd "loss" -e py echo -e "\n=== Looking for plugin registration patterns ===" # Search for register_plugin or similar patterns rg -i "register.*loss|loss.*register" --type py echo -e "\n=== Looking for TaskLoss implementations ===" # Search for classes that inherit from TaskLoss ast-grep --pattern 'class $_ (TaskLoss)'Length of output: 850
Script:
#!/bin/bash # Let's check the actual implementation of loss plugins and registration mechanism echo "=== Checking loss_args_plugin implementation ===" rg -A 5 "loss_args_plugin\s*=\s*" --type py echo -e "\n=== Looking for actual loss implementations ===" # Search for any class that might be a loss implementation ast-grep --pattern 'class $_ ($$$Loss$$$)' echo -e "\n=== Checking make_plugin_registry usage ===" rg "make_plugin_registry" --type py -A 2Length of output: 5439
support loss plugin for external package in pt backend
Summary by CodeRabbit
New Features
TaskLoss
, enhancing loss functionality in training.Bug Fixes
TaskLoss
class for better registration within the system.