-
Notifications
You must be signed in to change notification settings - Fork 525
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat(pt/tf): add bias changing param/interface #3933
Conversation
WalkthroughWalkthroughThe changes introduce a new feature to the Changes
Sequence Diagram(s)sequenceDiagram
participant User
participant CLI as CLI (deepmd/main.py)
participant EntryPoint as EntryPoint (deepmd/pt/entrypoints/main.py)
participant Training as Training (deepmd/pt/train/training.py)
participant ModelWrapper as ModelWrapper (deepmd/pt/train/wrapper.py)
participant StatUtil as StatUtil (deepmd/pt/utils/stat.py)
User->>CLI: dp --pt change-bias --config config.json
CLI->>EntryPoint: Parse and hand over control to `change_bias` function
EntryPoint->>Training: Initialize training with change_bias_after_training
Training->>ModelWrapper: Load model and data
ModelWrapper->>StatUtil: Compute data statistics
StatUtil->>ModelWrapper: Returned statistics
ModelWrapper->>Training: Adjust bias based on statistics
Training->>EntryPoint: Save updated model
EntryPoint->>User: Model bias updated and saved
In this sequence diagram, the process of updating the model's output bias using the newly introduced subcommand Recent review detailsConfiguration used: CodeRabbit UI Files selected for processing (1)
Files skipped from review as they are similar to previous changes (1)
Tip Early access features: enabledWe are currently testing the following features in early access:
Note:
Thank you for using CodeRabbit. We offer it for free to the OSS community and would appreciate your support in helping us grow. If you find it useful, would you consider giving us a shout-out on your favorite social media? TipsChatThere are 3 ways to chat with CodeRabbit:
Note: Be mindful of the bot's finite context window. It's strongly recommended to break down tasks such as reading entire modules into smaller chunks. For a focused discussion, use review comments to chat about specific files and their changes, instead of using the PR comments. CodeRabbit Commands (invoked as PR comments)
Additionally, you can add CodeRabbit Configration File (
|
Codecov ReportAttention: Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## devel #3933 +/- ##
==========================================
- Coverage 82.88% 82.84% -0.04%
==========================================
Files 520 520
Lines 50679 50787 +108
Branches 3015 3015
==========================================
+ Hits 42003 42074 +71
- Misses 7740 7775 +35
- Partials 936 938 +2 ☔ View full report in Codecov by Sentry. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 6
Outside diff range and nitpick comments (16)
deepmd/pt/entrypoints/main.py (4)
390-391
: Ensure correct file extension handling.Consider using a more robust method to handle file extensions to avoid potential errors.
file_extension = FLAGS.INPUT.split(".")[-1] if file_extension not in ["pt", "pth"]: raise RuntimeError("The model provided must be a checkpoint file with a .pt extension or a frozen model with a .pth extension")
406-409
: Use a more descriptive error message.Provide more context in the error message to help users understand the issue.
raise RuntimeError( "The model provided must be a checkpoint file with a .pt extension or a frozen model with a .pth extension. Provided file: {FLAGS.INPUT}" )
416-421
: Improve assertion error messages.Provide more informative error messages for assertions to aid debugging.
assert model_branch is not None, "For multitask model, the model branch must be set!" assert model_branch in model_params["model_dict"], ( f"For multitask model, the model branch must be in the 'model_dict'! Available options are: {list(model_params['model_dict'].keys())}." )
437-441
: Combine data requirements for clarity.Combine the data requirement assignment into a single line for better readability.
data_requirement = mock_loss.label_requirement + training.get_additional_data_requirement(model_to_change)deepmd/pt/utils/stat.py (1)
Line range hint
127-127
: Remove unused variablef
.The variable
f
is assigned but never used.- with h5py.File(stat_file_path_single, "w") as f: + with h5py.File(stat_file_path_single, "w"):deepmd/main.py (1)
Line range hint
114-114
: Usekey not in dict
instead ofkey not in dict.keys()
.Simplify the condition to improve readability and performance.
- if default_backend not in BACKEND_TABLE.keys(): + if default_backend not in BACKEND_TABLE:deepmd/tf/train/trainer.py (3)
Line range hint
151-155
: Simplify nested if-statements.Use a single
if
statement instead of nestedif
statements.- if self.mixed_prec is not None: - if ( - self.mixed_prec["compute_prec"] not in ("float16", "bfloat16") - or self.mixed_prec["output_prec"] != "float32" - ): + if self.mixed_prec is not None and ( + self.mixed_prec["compute_prec"] not in ("float16", "bfloat16") + or self.mixed_prec["output_prec"] != "float32" + ):
Line range hint
371-371
: Use context handler for opening files.Use
with open(...) as ...:
to ensure files are properly closed after their suite finishes.- fp = open(self.disp_file, "w") - fp = open(self.disp_file, "w") - fp = open(self.disp_file, "w") - fp = open(self.disp_file, "w") - fp = open(self.disp_file, "a") + with open(self.disp_file, "w") as fp: + with open(self.disp_file, "w") as fp: + with open(self.disp_file, "w") as fp: + with open(self.disp_file, "w") as fp: + with open(self.disp_file, "a") as fp:Also applies to: 378-378, 387-387, 392-392, 420-420
Line range hint
464-464
: Remove unused variables.Remove the assignments to unused variables to clean up the code.
- tb_valid_writer = None - fitting_key = None - test_time = 0 - except FileNotFoundError as e:Also applies to: 490-490, 550-550, 822-822
deepmd/pt/train/training.py (6)
Line range hint
294-294
: Useconfig.get("learning_rate_dict")
instead ofconfig.get("learning_rate_dict", None)
.Replace
config.get("learning_rate_dict", None)
withconfig.get("learning_rate_dict")
.- if self.multi_task and config.get("learning_rate_dict", None) is not None: + if self.multi_task and config.get("learning_rate_dict") is not None:Also applies to: 388-388
Line range hint
486-486
: Usekey in dict
instead ofkey in dict.keys()
.Remove
.keys()
.- missing_keys = [key for key in self.model_keys if key not in self.optim_dict.keys()] + missing_keys = [key for key in self.model_keys if key not in self.optim_dict]
Line range hint
633-633
: Use context handler for opening files.Replace with
with open(...) as ...:
.- fout = open(self.disp_file, mode="w", buffering=1) if self.rank == 0 else None + fout = open(self.disp_file, mode="w", buffering=1) if self.rank == 0 else None if fout: with fout: # your code - fout1 = open(record_file, mode="w", buffering=1) + fout1 = open(record_file, mode="w", buffering=1) if fout1: with fout1: # your codeAlso applies to: 637-637
Line range hint
681-684
: Use ternary operator forpref_lr
.Replace
if
-else
-block withpref_lr = _lr.start_lr if _step_id < self.warmup_steps else cur_lr
.- if _step_id < self.warmup_steps: - pref_lr = _lr.start_lr - else: - pref_lr = cur_lr + pref_lr = _lr.start_lr if _step_id < self.warmup_steps else cur_lr
Line range hint
791-791
: Rename unused loop control variableii
to_ii
.Rename
ii
to_ii
.- for ii, model_key in enumerate(self.model_keys): + for _ii, model_key in enumerate(self.model_keys):
Line range hint
1089-1089
: Usekey in dict
instead ofkey in dict.keys()
.Remove
.keys()
.- target_keys = [i for i in _random_state_dict.keys() if i != "_extra_state" and f".{_model_key}." in i] + target_keys = [i for i in _random_state_dict if i != "_extra_state" and f".{_model_key}." in i]deepmd/utils/argcheck.py (1)
Line range hint
1167-1174
: Remove unused local variables.The variables
link_lf
,link_se_e2_a
,link_se_e2_r
,link_se_e3
,link_se_a_tpe
,link_hybrid
,link_se_atten
, andlink_se_atten_v2
are assigned but never used.- link_lf = make_link("loc_frame", "model/descriptor[loc_frame]") - link_se_e2_a = make_link("se_e2_a", "model/descriptor[se_e2_a]") - link_se_e2_r = make_link("se_e2_r", "model/descriptor[se_e2_r]") - link_se_e3 = make_link("se_e3", "model/descriptor[se_e3]") - link_se_a_tpe = make_link("se_a_tpe", "model/descriptor[se_a_tpe]") - link_hybrid = make_link("hybrid", "model/descriptor[hybrid]") - link_se_atten = make_link("se_atten", "model/descriptor[se_atten]") - link_se_atten_v2 = make_link("se_atten_v2", "model/descriptor[se_atten_v2]")
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
- please add unit test.
- along with stat from data, shall we allow the user explicitly setting the bias number?
@coderabbitai do you have any idea how to add a ut for command |
Tip For best results, initiate chat on the files or code changes. @iProzd, to add a unit test for the
import unittest
from deepmd.pt.entrypoints.main import change_bias
class TestChangeBias(unittest.TestCase):
def setUp(self):
# Setup any necessary preconditions here
self.flags = argparse.Namespace(
model_path="path/to/model.pt",
data_path="path/to/data",
num_batches=10,
mode="change"
)
def test_change_bias(self):
# Call the change_bias function with the setup flags
result = change_bias(self.flags)
# Add assertions to verify the expected outcomes
self.assertIsNotNone(result)
# Add more assertions as needed to validate the bias change
if __name__ == '__main__':
unittest.main() This example assumes you have a suitable model and data for testing. You may need to adjust paths and parameters according to your specific setup. [ADD_LEARNING] Analysis chainScripts executedThe following scripts were executed for the analysis: Script: #!/bin/bash
# Find the implementation of the change_bias command
ast-grep --lang python --pattern $'def change_bias($_) {
$$$
}'
# Locate existing unit tests for similar commands
fd --type f --name "test_*.py"
Length of output: 802 Script: #!/bin/bash
# Correctly find the implementation of the change_bias command using a simpler pattern
ast-grep --lang python --pattern $'def change_bias($_) {$$$}'
# Correctly locate existing unit tests for similar commands
fd "test_*.py" --type f
Length of output: 685 |
Tip For best results, initiate chat on the files or code changes. # Let's review the relevant parts of deepmd/pt/entrypoints/main.py to understand the implementation of the change-bias command.
# Excerpt from deepmd/pt/entrypoints/main.py:
import argparse
import copy
from deepmd.__version__ import __version__
from deepmd.common import expand_sys_str
from deepmd.pt.train import training
from deepmd.pt.train.wrapper import ModelWrapper
from deepmd.pt.train.entry import preprocess_shared_params
from deepmd.pt.utils.stat import make_stat_input
def show(FLAGS):
# existing show function implementation
def change_bias(FLAGS):
# The function handling the change-bias command
# Load model parameters, process data systems, adjust biases, and save the updated model.
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Deep Potential Model Training")
subparsers = parser.add_subparsers(dest="command")
# Other subcommands
# Adding the change-bias subcommand
change_bias_parser = subparsers.add_parser("change-bias", help="Change model output bias based on input data")
change_bias_parser.add_argument("model", type=str, help="Path to the model file")
change_bias_parser.add_argument("-s", "--system", type=str, required=True, help="System data")
change_bias_parser.add_argument("-n", "--num", type=int, default=10, help="Number of samples")
change_bias_parser.add_argument("-m", "--mode", type=str, choices=["change", "reset"], default="change", help="Mode of bias adjustment")
change_bias_parser.set_defaults(func=change_bias)
FLAGS = parser.parse_args()
if FLAGS.command == "change-bias":
change_bias(FLAGS) # Now, let's create a unit test for the change-bias command using pytest.
import pytest
from deepmd.pt.entrypoints.main import change_bias
@pytest.fixture
def mock_flags():
class MockFlags:
def __init__(self):
self.command = "change-bias"
self.model = "model.pt"
self.system = "data"
self.num = 10
self.mode = "change"
return MockFlags()
def test_change_bias(mock_flags, monkeypatch):
# Mock the functions and methods used within change_bias
monkeypatch.setattr("deepmd.pt.train.wrapper.ModelWrapper", lambda *args, **kwargs: None)
monkeypatch.setattr("deepmd.pt.utils.stat.make_stat_input", lambda *args, **kwargs: None)
monkeypatch.setattr("deepmd.pt.train.training", lambda *args, **kwargs: None)
monkeypatch.setattr("deepmd.common.expand_sys_str", lambda *args, **kwargs: None)
# Call the change_bias function with the mock flags
change_bias(mock_flags)
# Add assertions to verify expected behavior
assert True # Replace with actual assertions based on expected outcomes
# Note: The actual assertions will depend on the specific behavior and outcomes of the change_bias function. @iProzd, I've created a unit test for the import pytest
from deepmd.pt.entrypoints.main import change_bias
@pytest.fixture
def mock_flags():
class MockFlags:
def __init__(self):
self.command = "change-bias"
self.model = "model.pt"
self.system = "data"
self.num = 10
self.mode = "change"
return MockFlags()
def test_change_bias(mock_flags, monkeypatch):
# Mock the functions and methods used within change_bias
monkeypatch.setattr("deepmd.pt.train.wrapper.ModelWrapper", lambda *args, **kwargs: None)
monkeypatch.setattr("deepmd.pt.utils.stat.make_stat_input", lambda *args, **kwargs: None)
monkeypatch.setattr("deepmd.pt.train.training", lambda *args, **kwargs: None)
monkeypatch.setattr("deepmd.common.expand_sys_str", lambda *args, **kwargs: None)
# Call the change_bias function with the mock flags
change_bias(mock_flags)
# Add assertions to verify expected behavior
assert True # Replace with actual assertions based on expected outcomes You can adjust the assertions based on the specific behavior and outcomes of the [ADD_LEARNING] Analysis chainScripts executedThe following scripts were executed for the analysis: Script: #!/bin/bash
# Locate the implementation of the `change-bias` command.
ast-grep --lang python --pattern $'def change_bias($_):' --json
# Identify the testing framework by looking for test files and their imports.
fd -e py test | xargs grep -E 'import (unittest|pytest)'
Length of output: 11928 Script: #!/bin/bash
# Search for the `change-bias` command within the deepmd directory.
rg 'change-bias' deepmd
Length of output: 245 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 2
Outside diff range and nitpick comments (7)
deepmd/pt/entrypoints/main.py (1)
Line range hint
517-520
: Use a ternary operator for cleaner code.Replace the
if-else
block with a ternary operator.- if not isinstance(args, argparse.Namespace): - FLAGS = parse_args(args=args) - else: - FLAGS = args + FLAGS = parse_args(args=args) if not isinstance(args, argparse.Namespace) else argsTools
GitHub Check: codecov/patch
[warning] 393-403: deepmd/pt/entrypoints/main.py#L393-L403
Added lines #L393 - L403 were not covered by tests
[warning] 405-405: deepmd/pt/entrypoints/main.py#L405
Added line #L405 was not covered by tests
[warning] 409-411: deepmd/pt/entrypoints/main.py#L409-L411
Added lines #L409 - L411 were not covered by tests
[warning] 414-415: deepmd/pt/entrypoints/main.py#L414-L415
Added lines #L414 - L415 were not covered by tests
[warning] 418-418: deepmd/pt/entrypoints/main.py#L418
Added line #L418 was not covered by tests
[warning] 422-424: deepmd/pt/entrypoints/main.py#L422-L424
Added lines #L422 - L424 were not covered by tests
[warning] 429-432: deepmd/pt/entrypoints/main.py#L429-L432
Added lines #L429 - L432 were not covered by tests
[warning] 435-435: deepmd/pt/entrypoints/main.py#L435
Added line #L435 was not covered by tests
[warning] 437-437: deepmd/pt/entrypoints/main.py#L437
Added line #L437 was not covered by tests
[warning] 439-439: deepmd/pt/entrypoints/main.py#L439
Added line #L439 was not covered by tests
[warning] 442-442: deepmd/pt/entrypoints/main.py#L442
Added line #L442 was not covered by tests
[warning] 445-446: deepmd/pt/entrypoints/main.py#L445-L446
Added lines #L445 - L446 were not covered by tests
[warning] 449-450: deepmd/pt/entrypoints/main.py#L449-L450
Added lines #L449 - L450 were not covered by tests
[warning] 455-455: deepmd/pt/entrypoints/main.py#L455
Added line #L455 was not covered by tests
[warning] 458-459: deepmd/pt/entrypoints/main.py#L458-L459
Added lines #L458 - L459 were not covered by testsdeepmd/pt/model/atomic_model/base_atomic_model.py (4)
Line range hint
80-81
: Avoid using mutable data structures for argument defaults.Replace with
None
and initialize within the function.- atom_exclude_types: List[int] = [], - pair_exclude_types: List[Tuple[int, int]] = [], + atom_exclude_types: Optional[List[int]] = None, + pair_exclude_types: Optional[List[Tuple[int, int]]] = None,
Line range hint
95-95
: Remove unused local variable.The local variable
ntypes
is assigned but never used.- ntypes = self.get_ntypes()
Line range hint
257-257
: Usekey in dict
instead ofkey in dict.keys()
.Remove
.keys()
for better performance and readability.- if key not in dict.keys(): + if key not in dict:
Line range hint
545-546
: Usekey in dict
instead ofkey in dict.keys()
.Remove
.keys()
for better performance and readability.- if key in dict.keys(): - if key not in dict.keys(): + if key in dict: + if key not in dict:deepmd/main.py (2)
Line range hint
95-95
: Addstacklevel
keyword argument towarnings.warn
.To ensure the warning points to the correct line in the user's code, add the
stacklevel
argument.- warnings.warn( - f"The option `{option_string}` is deprecated. It will be ignored.", - FutureWarning, - ) + warnings.warn( + f"The option `{option_string}` is deprecated. It will be ignored.", + FutureWarning, + stacklevel=2 + )Tools
Ruff
60-63: Within an
except
clause, raise exceptions withraise ... from err
orraise ... from None
to distinguish them from errors in exception handling(B904)
GitHub Check: codecov/patch
[warning] 54-54: deepmd/main.py#L54
Added line #L54 was not covered by tests
[warning] 56-60: deepmd/main.py#L56-L60
Added lines #L56 - L60 were not covered by tests
Line range hint
126-126
: Usekey not in dict
instead ofkey not in dict.keys()
.Remove
.keys()
for better performance and readability.- if default_backend not in BACKEND_TABLE.keys(): + if default_backend not in BACKEND_TABLE:Tools
Ruff
60-63: Within an
except
clause, raise exceptions withraise ... from err
orraise ... from None
to distinguish them from errors in exception handling(B904)
GitHub Check: codecov/patch
[warning] 54-54: deepmd/main.py#L54
Added line #L54 was not covered by tests
[warning] 56-60: deepmd/main.py#L56-L60
Added lines #L56 - L60 were not covered by tests
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 7
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 4
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 1
Outside diff range and nitpick comments (2)
deepmd/main.py (2)
Line range hint
83-83
: Addstacklevel
argument to the warning.Specify the
stacklevel
argument in the warning to provide more context about where the warning was triggered.- warnings.warn( + warnings.warn( + f"The option `{option_string}` is deprecated. It will be ignored.", + FutureWarning, + stacklevel=2,
Line range hint
114-114
: Usekey not in dict
instead ofkey not in dict.keys()
.Remove
.keys()
for better readability and performance.- if default_backend not in BACKEND_TABLE.keys(): + if default_backend not in BACKEND_TABLE:
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 2
As planned in deepmodeling#3933 (comment) Signed-off-by: Jinzhe Zeng <jinzhe.zeng@rutgers.edu>
As planned in #3933 (comment) <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit - **Bug Fixes** - Improved exception handling across multiple modules to provide more context in error messages. - **Chores** - Added rule `B904` for `raise-without-from-inside-except` to `pyproject.toml`. These changes enhance debugging by offering clearer error messages and context, thus improving the overall user experience when dealing with exceptions. <!-- end of auto-generated comment: release notes by coderabbit.ai --> Signed-off-by: Jinzhe Zeng <jinzhe.zeng@rutgers.edu>
Add bias changing param/interface For pt/tf, add `training/change_bias_after_training` to change out bias once after training. For pt, add a separate command `change-bias` to change trained model(pt/pth, multi/single) out bias for specific data: ``` dp change-bias model.pt -s data -n 10 -m change ``` UTs for this feature are still in consideration. <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit - **New Features** - Added a new subcommand `change-bias` to adjust model output bias in the PyTorch backend. - Introduced test cases for changing model biases via new test suite. - **Documentation** - Added documentation for the new `change-bias` command, including usage and options. - Updated `index.rst` to include a new entry for `change-bias` under the `Model` section. - **Bug Fixes** - Adjusted data handling in `make_stat_input` to limit processing to a specified number of batches. - **Refactor** - Restructured training configuration to include the parameter `change_bias_after_training`. - Modularized data requirement handling and bias adjustment functions. <!-- end of auto-generated comment: release notes by coderabbit.ai -->
As planned in deepmodeling#3933 (comment) <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit - **Bug Fixes** - Improved exception handling across multiple modules to provide more context in error messages. - **Chores** - Added rule `B904` for `raise-without-from-inside-except` to `pyproject.toml`. These changes enhance debugging by offering clearer error messages and context, thus improving the overall user experience when dealing with exceptions. <!-- end of auto-generated comment: release notes by coderabbit.ai --> Signed-off-by: Jinzhe Zeng <jinzhe.zeng@rutgers.edu>
Add bias changing param/interface
For pt/tf, add
training/change_bias_after_training
to change out bias once after training.For pt, add a separate command
change-bias
to change trained model(pt/pth, multi/single) out bias for specific data:UTs for this feature are still in consideration.
Summary by CodeRabbit
New Features
change-bias
to adjust model output bias in the PyTorch backend.Documentation
change-bias
command, including usage and options.index.rst
to include a new entry forchange-bias
under theModel
section.Bug Fixes
make_stat_input
to limit processing to a specified number of batches.Refactor
change_bias_after_training
.