-
Notifications
You must be signed in to change notification settings - Fork 525
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
fix(pt): optimize graph memory usage #4006
Conversation
WalkthroughWalkthroughThe recent changes introduce a new Changes
Possibly related issues
Thank you for using CodeRabbit. We offer it for free to the OSS community and would appreciate your support in helping us grow. If you find it useful, would you consider giving us a shout-out on your favorite social media? TipsChatThere are 3 ways to chat with CodeRabbit:
Note: Be mindful of the bot's finite context window. It's strongly recommended to break down tasks such as reading entire modules into smaller chunks. For a focused discussion, use review comments to chat about specific files and their changes, instead of using the PR comments. CodeRabbit Commands (invoked as PR comments)
Additionally, you can add CodeRabbit Configuration File (
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 0
Outside diff range, codebase verification and nitpick comments (8)
deepmd/pt/model/model/dp_zbl_model.py (1)
114-115
: Ensure all calls toforward_lower
include theinference
parameter.The
forward_lower
function indeepmd/pt/model/model/dp_zbl_model.py
has been updated to include theinference
parameter. However, many calls to this function in the codebase do not include this parameter, which could lead to runtime errors or unexpected behavior.
- Files and lines to update:
source/tests/pt/model/test_dp_model.py
- Lines: 0, 1, 2
source/tests/pt/model/test_forward_lower.py
- Lines: 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15
source/tests/pt/model/test_ener_spin_model.py
- Lines: 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15
source/tests/universal/common/cases/model/utils.py
- Lines: 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15
source/tests/universal/pt/backend.py
- Lines: 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15
Please update these calls to include the
inference
parameter to ensure consistency and correct functionality.Analysis chain
Verify the function usage in the codebase.
The code changes are approved.
However, ensure that all function calls to
forward_lower
match the new signature.Scripts executed
The following scripts were executed for the analysis:
Script:
#!/bin/bash # Description: Verify all function calls to `forward_lower` match the new signature. # Test: Search for the function usage. Expect: Only occurances of the new signature. rg --type python -A 5 $'forward_lower'Length of output: 68
Script:
#!/bin/bash # Description: Verify all function calls to `forward_lower` match the new signature. # Test: Search for the function usage. Expect: Only occurrences of the new signature. rg --type py -A 5 $'forward_lower'Length of output: 8436
Script:
#!/bin/bash # Description: Verify the function definition of `forward_lower` in `deepmd/pt/model/model/dp_zbl_model.py` includes the `inference` parameter # and check if all calls to this function in the codebase include this parameter. # Test: Search for the function definition of `forward_lower` in the specified file. ast-grep --lang python --pattern $'def forward_lower($$$):\n $$$' deepmd/pt/model/model/dp_zbl_model.py # Test: Search for all calls to `forward_lower` in the codebase and check if they include the `inference` parameter. rg --type py 'forward_lower' -A 5Length of output: 11331
deepmd/pt/model/model/transform_output.py (3)
125-126
: Update function call to match new signatureThe function call to
take_deriv
indeepmd/pt/model/model/transform_output.py
does not include the newinference
parameter. Please update the function call to match the new signature.
deepmd/pt/model/model/transform_output.py
:dr, dc = take_deriv( vv, model_ret[kk_redu], vdef, coord_ext, do_virial=vdef.c_differentiable, inference=False # Add this line )Analysis chain
Verify the function usage in the codebase.
The code changes are approved.
However, ensure that all function calls to
take_deriv
match the new signature.Scripts executed
The following scripts were executed for the analysis:
Script:
#!/bin/bash # Description: Verify all function calls to `take_deriv` match the new signature. # Test: Search for the function usage. Expect: Only occurances of the new signature. rg --type python -A 5 $'take_deriv'Length of output: 65
Script:
#!/bin/bash # Description: Verify all function calls to `take_deriv` match the new signature. # Test: Search for the function usage. Expect: Only occurrences of the new signature. rg --type py -A 5 $'take_deriv'Length of output: 896
167-168
: Update function calls to match new signatureThe function
fit_output_to_model_output
has a new parameterinference
set toFalse
by default. However, the following function calls do not include this new parameter and need to be updated:
deepmd/pt/model/model/make_model.py
deepmd/dpmodel/model/make_model.py
Please update these function calls to include the
inference
parameter to match the new signature.Analysis chain
Verify the function usage in the codebase.
The code changes are approved.
However, ensure that all function calls to
fit_output_to_model_output
match the new signature.Scripts executed
The following scripts were executed for the analysis:
Script:
#!/bin/bash # Description: Verify all function calls to `fit_output_to_model_output` match the new signature. # Test: Search for the function usage. Expect: Only occurances of the new signature. rg --type python -A 5 $'fit_output_to_model_output'Length of output: 81
Script:
#!/bin/bash # Description: Verify all function calls to `fit_output_to_model_output` match the new signature. # Search for the function definition to understand its context. ast-grep --lang python --pattern $'def fit_output_to_model_output($_, $_, $_, $_, $_, $_, $_, $_, $_, inference=False):\n $$$' # Search for the function usage in the codebase. rg 'fit_output_to_model_output'Length of output: 626
Script:
#!/bin/bash # Description: Verify all function calls to `fit_output_to_model_output` match the new signature. # Check the function calls in the identified files. rg 'fit_output_to_model_output' deepmd/pt/model/model/make_model.py -A 3 rg 'fit_output_to_model_output' deepmd/dpmodel/model/make_model.py -A 3Length of output: 590
75-76
: Update all calls totask_deriv_one
to match the new signature.The
task_deriv_one
function has been updated to include a newinference
parameter, but not all calls to this function reflect this change. Ensure that all calls totask_deriv_one
include theinference
parameter to maintain consistency and avoid potential issues.
deepmd/pt/model/model/transform_output.py
: Update the call totask_deriv_one
to include theinference
parameter.Analysis chain
Verify the function usage in the codebase.
The code changes are approved.
However, ensure that all function calls to
task_deriv_one
match the new signature.Scripts executed
The following scripts were executed for the analysis:
Script:
#!/bin/bash # Description: Verify all function calls to `task_deriv_one` match the new signature. # Test: Search for the function usage. Expect: Only occurances of the new signature. rg --type python -A 5 $'task_deriv_one'Length of output: 69
Script:
#!/bin/bash # Description: Verify all function calls to `task_deriv_one` match the new signature. # Test: Search for the function usage. Expect: Only occurrences of the new signature. rg 'task_deriv_one' --glob '*.py' -A 5Length of output: 895
Script:
#!/bin/bash # Description: Verify all function calls and definitions of `task_deriv_one` match the new signature. # Test: Search for the function definition and usage. Expect: Only occurrences of the new signature. rg 'def task_deriv_one' --glob '*.py' -A 5 rg 'task_deriv_one(' --glob '*.py' -A 5Length of output: 594
Script:
#!/bin/bash # Description: Verify all function calls to `task_deriv_one` match the new signature. # Test: Search for the function usage. Expect: Only occurrences of the new signature. rg 'task_deriv_one\(' --glob '*.py' -A 5Length of output: 899
deepmd/pt/model/model/make_model.py (2)
243-244
: Update the docstring for clarity.The docstring should be updated to match the new parameter name and its purpose.
- inference - Whether only perform inference rather than undergoing training. + inference + Whether to perform inference only, rather than undergoing training.
218-218
: Ensure all calls toforward_common_lower
include the newinference
parameter.The following function calls to
forward_common_lower
need to be updated to match the new signature:
source/tests/pt/model/test_ener_spin_model.py
source/tests/pt/model/test_dp_model.py
deepmd/pt/model/model/spin_model.py
deepmd/pt/model/model/polar_model.py
deepmd/pt/model/model/make_model.py
deepmd/pt/model/model/ener_model.py
deepmd/pt/model/model/dp_zbl_model.py
deepmd/pt/model/model/dos_model.py
deepmd/pt/model/model/dipole_model.py
Analysis chain
LGTM! But verify the function usage in the codebase.
The addition of the
inference
parameter is approved.However, ensure that all function calls to
forward_common_lower
match the new signature.Scripts executed
The following scripts were executed for the analysis:
Script:
#!/bin/bash # Description: Verify all function calls to `forward_common_lower` match the new signature. # Test: Search for the function usage. Expect: Only occurrences of the new signature. rg --type python -A 5 $'forward_common_lower'Length of output: 75
Script:
#!/bin/bash # Description: Verify all function calls to `forward_common_lower` match the new signature. # Test: Search for the function usage. Expect: Only occurrences of the new signature. rg -A 5 'forward_common_lower'Length of output: 8451
deepmd/pt/model/model/spin_model.py (2)
608-608
: Verify and updateforward_lower
function calls to match the new signature.The addition of the
inference
parameter toforward_lower
is approved. However, some function calls do not include this parameter. Please update the following instances to ensure they match the new signature:
- source/tests/universal/common/cases/model/utils.py
- Line:
ret_lower.append(module.forward_lower(**input_dict_lower))
- source/tests/pt/model/test_dp_model.py
- Line:
ret2 = md0.forward_lower(*args, do_atomic_virial=True)
- Line:
ret0 = md0.forward_lower(*args)
- Line:
ret1 = md1.forward_lower(*args)
- Line:
ret0 = md0.forward_lower(*args, do_atomic_virial=True)
- Line:
ret1 = md1.forward_lower(*args, do_atomic_virial=True)
- source/api_cc/src/DeepPotPT.cc
- Line:
.run_method("forward_lower", coord_wrapped_Tensor, atype_Tensor, firstneigh_tensor, optional_tensor, fparam_tensor, aparam_tensor, do_atom_virial_tensor, comm_dict)
- Line:
.run_method("forward_lower", coord_wrapped_Tensor, atype_Tensor, firstneigh_tensor, optional_tensor, fparam_tensor, aparam_tensor, do_atom_virial_tensor)
Please ensure that all calls to
forward_lower
include theinference
parameter set toTrue
.Analysis chain
LGTM! But verify the function usage in the codebase.
The addition of the
inference
parameter is approved.However, ensure that all function calls to
forward_lower
match the new signature.Scripts executed
The following scripts were executed for the analysis:
Script:
#!/bin/bash # Description: Verify all function calls to `forward_lower` match the new signature. # Test: Search for the function usage. Expect: Only occurrences of the new signature. rg --type python -A 5 $'forward_lower'Length of output: 68
Script:
#!/bin/bash # Description: Verify all function calls to `forward_lower` match the new signature. # Test: Search for the function usage. Expect: Only occurrences of the new signature. rg -A 5 'forward_lower'Length of output: 9418
462-462
: Ensure all calls toforward_common_lower
include theinference
parameter.The
inference
parameter is not found in the function calls toforward_common_lower
in the following files:
source/tests/pt/model/test_ener_spin_model.py
source/tests/pt/model/test_dp_model.py
deepmd/pt/model/model/spin_model.py
deepmd/pt/model/model/polar_model.py
deepmd/pt/model/model/make_model.py
deepmd/pt/model/model/ener_model.py
deepmd/pt/model/model/dp_zbl_model.py
deepmd/pt/model/model/dos_model.py
deepmd/pt/model/model/dipole_model.py
Please update these calls to match the new function signature.
Analysis chain
LGTM! But verify the function usage in the codebase.
The addition of the
inference
parameter is approved.However, ensure that all function calls to
forward_common_lower
match the new signature.Scripts executed
The following scripts were executed for the analysis:
Script:
#!/bin/bash # Description: Verify all function calls to `forward_common_lower` match the new signature. # Test: Search for the function usage. Expect: Only occurrences of the new signature. rg --type python -A 5 $'forward_common_lower'Length of output: 75
Script:
#!/bin/bash # Description: Verify all function calls to `forward_common_lower` match the new signature. # Test: Search for the function usage. Expect: Only occurrences of the new signature. rg --type py -A 5 $'forward_common_lower'Length of output: 8461
Script:
#!/bin/bash # Description: Check if the `inference` parameter is included in the function calls to `forward_common_lower`. # Test: Search for the function calls and check for the `inference` parameter. rg --type py 'forward_common_lower' -A 5 | grep -B 5 'inference'Length of output: 62
Codecov ReportAll modified and coverable lines are covered by tests ✅
Additional details and impacted files@@ Coverage Diff @@
## devel #4006 +/- ##
=======================================
Coverage 82.84% 82.85%
=======================================
Files 522 522
Lines 50920 50922 +2
Branches 3015 3015
=======================================
+ Hits 42186 42189 +3
Misses 7796 7796
+ Partials 938 937 -1 ☔ View full report in Codecov by Sentry. |
- Remove atomic virial graph. - Remove force graph during inference. After this, the lammps memory saves **50% for dpa1** (attn_layer=0) and **80% for dpa2** (layer=12). <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit - **New Features** - Introduced a new `inference` parameter to key model functions, enhancing flexibility for inference scenarios during model execution. - Added functionality to output a mapping array to a CSV file, improving data handling capabilities. - **Bug Fixes** - Improved the behavior of the model during inference versus training, potentially impacting downstream processing based on the output. <!-- end of auto-generated comment: release notes by coderabbit.ai -->
After this, the lammps memory saves 50% for dpa1 (attn_layer=0) and 80% for dpa2 (layer=12).
Summary by CodeRabbit
New Features
inference
parameter to key model functions, enhancing flexibility for inference scenarios during model execution.Bug Fixes