-
Notifications
You must be signed in to change notification settings - Fork 525
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Feat/add dos finetune UT #3876
Feat/add dos finetune UT #3876
Conversation
WalkthroughThe overall changes involve modifying the Changes
Sequence Diagram(s)sequenceDiagram
participant TestSuite
participant TrainingModule
participant StatisticsModule
TestSuite->>+TrainingModule: call single_model_finetune()
TrainingModule-->>-TrainingModule: Simplified _model assignment
TrainingModule-->>TestSuite: Return result
TestSuite->>+StatisticsModule: call compute_output_stats_global(model_pred)
StatisticsModule-->>StatisticsModule: Check if model_pred is empty
alt model_pred is empty
StatisticsModule-->>TestSuite: Return {}, {}
else model_pred is not empty
StatisticsModule->>StatisticsModule: Proceed with computation
StatisticsModule-->>TestSuite: Return computed stats
end
Thank you for using CodeRabbit. We offer it for free to the OSS community and would appreciate your support in helping us grow. If you find it useful, would you consider giving us a shout-out on your favorite social media? TipsChatThere are 3 ways to chat with CodeRabbit:
Note: Be mindful of the bot's finite context window. It's strongly recommended to break down tasks such as reading entire modules into smaller chunks. For a focused discussion, use review comments to chat about specific files and their changes, instead of using the PR comments. CodeRabbit Commands (invoked as PR comments)
Additionally, you can add CodeRabbit Configration File (
|
for more information, see https://pre-commit.ci
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 2
Outside diff range and nitpick comments (5)
deepmd/pt/train/training.py (5)
Line range hint
368-368
: Refactor the usage ofconfig.get("learning_rate_dict")
to remove unnecessaryNone
default.- config.get("learning_rate_dict", None) + config.get("learning_rate_dict")Also applies to: 462-462
Line range hint
560-560
: Optimize dictionary key checks by removing.keys()
for better performance.- key in dict.keys() + key in dictAlso applies to: 1132-1132
Line range hint
705-705
: Use context handlers for file operations to ensure files are properly closed after operations, avoiding resource leaks.- fout = open(self.disp_file, mode="w", buffering=1) + with open(self.disp_file, mode="w", buffering=1) as fout:- fout1 = open(record_file, mode="w", buffering=1) + with open(record_file, mode="w", buffering=1) as fout1:Also applies to: 709-709
Line range hint
751-754
: Use a ternary operator for cleaner and more concise logic when settingpref_lr
.- if _step_id < self.warmup_steps: - pref_lr = _lr.start_lr - else: - pref_lr = cur_lr + pref_lr = _lr.start_lr if _step_id < self.warmup_steps else cur_lr
Line range hint
861-861
: Remove unused loop control variableii
or replace it with_
to indicate it's intentionally unused.- for ii in range(valid_numb_batch): + for _ in range(valid_numb_batch):
Codecov ReportAll modified and coverable lines are covered by tests ✅
Additional details and impacted files@@ Coverage Diff @@
## devel #3876 +/- ##
=======================================
Coverage 82.70% 82.71%
=======================================
Files 517 517
Lines 50137 50138 +1
Branches 2984 2987 +3
=======================================
+ Hits 41467 41472 +5
+ Misses 7760 7756 -4
Partials 910 910 ☔ View full report in Codecov by Sentry. |
for more information, see https://pre-commit.ci
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 1
Outside diff range and nitpick comments (3)
deepmd/pt/utils/stat.py (3)
Line range hint
88-88
: Avoid using mutable default arguments.- def compute_output_stats(merged: Union[Callable[[], List[dict]], List[dict]] = [], ntypes: int, keys: Union[str, List[str]] = ["energy"], ... + def compute_output_stats(merged: Union[Callable[[], List[dict]], List[dict]] = None, ntypes: int, keys: Union[str, List[str]] = None, ... + if merged is None: + merged = [] + if keys is None: + keys = ["energy"]Also applies to: 238-238
Line range hint
154-154
: Remove unused variable.- natoms = system["natoms"]
Line range hint
166-166
: Bind loop variables in function definition.- for nframes in range(len(sampled)): + for nframes, system in enumerate(sampled):Also applies to: 167-167
Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com> Signed-off-by: Anyang Peng <137014849+anyangml@users.noreply.github.com>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 0
Outside diff range and nitpick comments (4)
deepmd/pt/utils/stat.py (4)
Line range hint
88-88
: Avoid using mutable default arguments to prevent unexpected behaviors.- def make_stat_input(datasets, dataloaders, nbatches=[]): + def make_stat_input(datasets, dataloaders, nbatches=None): + if nbatches is None: + nbatches = []
Line range hint
154-154
: Remove the unused variablenatoms
to clean up the code.- natoms = system["natoms"]
Line range hint
238-238
: Avoid using mutable default arguments to prevent unexpected behaviors.- def compute_output_stats(merged, ntypes, keys=[], ...): + def compute_output_stats(merged, ntypes, keys=None, ...): + if keys is None: + keys = []
Line range hint
291-291
: Simplify dictionary key checks by removing.keys()
.- if kk in preset_bias.keys(): + if kk in preset_bias:Also applies to: 334-334, 343-343, 450-450, 487-487, 493-493, 495-495, 500-500
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 0
Outside diff range and nitpick comments (4)
deepmd/pt/utils/stat.py (4)
Line range hint
238-238
: Refactor to avoid using.keys()
when checking for key existence in dictionaries.- if kk in preset_bias.keys() + if kk in preset_biasThis change should be applied to all similar instances in the file.
Also applies to: 334-334, 343-343, 450-450, 487-487, 493-493, 495-495, 500-500
Line range hint
88-88
: Avoid using mutable data structures as default arguments to prevent potential bugs related to mutable default values.- def compute_output_stats_global(sampled: List[dict], ntypes: int, keys: List[str] = ["energy"], rcond: Optional[float] = None, preset_bias: Optional[Dict[str, List[Optional[torch.Tensor]]]] = None, model_pred: Optional[Dict[str, np.ndarray]] = None): + def compute_output_stats_global(sampled: List[dict], ntypes: int, keys: Optional[List[str]] = None, rcond: Optional[float] = None, preset_bias: Optional[Dict[str, List[Optional[torch.Tensor]]]] = None, model_pred: Optional[Dict[str, np.ndarray]] = None): + if keys is None: + keys = ["energy"]Also applies to: 238-238
Line range hint
154-154
: Remove the unused variablenatoms
as it is defined but never used.- natoms = { - kk: [ - system["atype"] - for system in sampled - if ("atom_" + kk) in system and system.get(f"find_atom_{kk}", 0) > 0 - ] - for kk in keys - }
Line range hint
166-166
: Bind the loop variablesnframes
andsystem
to their respective functions to avoid potential scoping issues.- nf = {kk: merged_natoms[kk].shape[0] for kk in keys if kk in merged_natoms} + nf = {kk: merged_natoms[kk].shape[0] for kk, system in enumerate(sampled) if kk in merged_natoms}Also applies to: 167-167
<!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit - **Bug Fixes** - Improved the `compute_output_stats_global` function to handle empty model predictions gracefully. - **Tests** - Enhanced finetuning tests with new model support and additional checks for "dos" model setup. <!-- end of auto-generated comment: release notes by coderabbit.ai --> --------- Signed-off-by: Anyang Peng <137014849+anyangml@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>
Summary by CodeRabbit
Bug Fixes
compute_output_stats_global
function to handle empty model predictions gracefully.Tests