-
Notifications
You must be signed in to change notification settings - Fork 525
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
fix(pt): fix not used sys_probs
#4353
Conversation
📝 WalkthroughWalkthroughThe changes involve refactoring the data sampling logic in the Changes
Sequence Diagram(s)sequenceDiagram
participant Trainer
participant DataLoader
participant Sampler
Trainer->>DataLoader: get_data_loader()
DataLoader->>Sampler: get_sampler_from_params(training_params)
Sampler-->>DataLoader: return sampler
DataLoader-->>Trainer: return dataloader
Thank you for using CodeRabbit. We offer it for free to the OSS community and would appreciate your support in helping us grow. If you find it useful, would you consider giving us a shout-out on your favorite social media? 🪧 TipsChatThere are 3 ways to chat with CodeRabbit:
Note: Be mindful of the bot's finite context window. It's strongly recommended to break down tasks such as reading entire modules into smaller chunks. For a focused discussion, use review comments to chat about specific files and their changes, instead of using the PR comments. CodeRabbit Commands (Invoked using PR comments)
Other keywords and placeholders
CodeRabbit Configuration File (
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 2
🧹 Outside diff range and nitpick comments (4)
source/tests/pt/test_sampler.py (2)
17-17
: Consider documenting deprecation ofget_weighted_sampler
Since both
get_sampler_from_params
andget_weighted_sampler
are imported, and the new tests use the former, it appears thatget_weighted_sampler
might be deprecated. Consider adding a deprecation notice if that's the case.
109-119
: LGTM! Consider enhancing the test documentationThe test effectively verifies that
sys_probs
takes precedence when bothsys_probs
andauto_prob
are provided, which aligns with fixing the "not used sys_probs" issue mentioned in the PR title.Consider adding a docstring to better document the test's purpose:
def test_sys_probs_end2end(self): + """ + Verifies that sys_probs takes precedence over auto_prob when both are provided, + ensuring that manually specified system probabilities are not ignored. + """ sys_probs = [0.1, 0.4, 0.5]deepmd/pt/utils/dataloader.py (2)
311-324
: Add documentation and type hints for better code clarity.The function logic looks good, but could benefit from improved documentation and type safety:
Consider applying these improvements:
-def get_sampler_from_params(_data, _params): +def get_sampler_from_params(_data: DpLoaderSet, _params: dict) -> WeightedRandomSampler: + """Create a weighted sampler based on provided parameters. + + Parameters + ---------- + _data : DpLoaderSet + The dataset for which to create the sampler + _params : dict + Dictionary containing sampling parameters: + - sys_probs: Optional[str] - System probabilities string + - auto_prob: Optional[str] - Automatic probability configuration + + Returns + ------- + WeightedRandomSampler + The configured sampler instance + """ if ( "sys_probs" in _params and _params["sys_probs"] is not None ): # use sys_probs first
309-310
: Remove extra blank line.Keep only one blank line between function definitions for consistent spacing.
def get_weighted_sampler(training_data, prob_style, sys_prob=False): # ... existing implementation ... - def get_sampler_from_params(_data, _params):
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
📒 Files selected for processing (3)
deepmd/pt/train/training.py
(3 hunks)deepmd/pt/utils/dataloader.py
(1 hunks)source/tests/pt/test_sampler.py
(2 hunks)
🔇 Additional comments (3)
deepmd/pt/train/training.py (3)
50-50
: LGTM: Import statement updated correctly.
The import statement has been updated to use the renamed function get_sampler_from_params
, which aligns with the refactoring changes.
163-163
: LGTM: Simplified sampler creation logic.
The code now uses the new get_sampler_from_params
function with appropriate warning for missing sampler cases.
184-184
: Verify parameter structure assumptions.
The code now accesses specific data subsets using "training_data" and "validation_data" keys. While this improves clarity, we should verify that these keys are guaranteed to exist in the parameters.
Also applies to: 191-193
✅ Verification successful
Parameter structure validation is confirmed.
The codebase has robust validation for both parameters:
- "training_data" is a required parameter with strict validation of its structure
- "validation_data" is an optional parameter (can be None) with equivalent validation when present
The parameter structure assumptions are safe and guaranteed by the built-in argument validation system.
🏁 Scripts executed
The following scripts were executed for the analysis:
Script:
#!/bin/bash
# Description: Verify the parameter structure in training configuration files
# Look for training parameter definitions to confirm the expected structure
# Search for training parameter definitions
rg -A 10 "training_data.*=|training_data.*:"
# Search for validation parameter definitions
rg -A 10 "validation_data.*=|validation_data.*:"
Length of output: 403111
sys_probs
was not used in pt, because its priority was lower than that ofauto_prob
, whileauto_prob
always has its default values.See #4346 (reply in thread) .
Summary by CodeRabbit