-
Notifications
You must be signed in to change notification settings - Fork 525
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
fix(pt): fix precision #4344
fix(pt): fix precision #4344
Conversation
📝 Walkthrough📝 WalkthroughWalkthroughThe pull request introduces several modifications across multiple descriptor classes in the Changes
Possibly related PRs
Suggested reviewers
📜 Recent review detailsConfiguration used: CodeRabbit UI 📒 Files selected for processing (1)
💤 Files with no reviewable changes (1)
Thank you for using CodeRabbit. We offer it for free to the OSS community and would appreciate your support in helping us grow. If you find it useful, would you consider giving us a shout-out on your favorite social media? 🪧 TipsChatThere are 3 ways to chat with CodeRabbit:
Note: Be mindful of the bot's finite context window. It's strongly recommended to break down tasks such as reading entire modules into smaller chunks. For a focused discussion, use review comments to chat about specific files and their changes, instead of using the PR comments. CodeRabbit Commands (Invoked using PR comments)
Other keywords and placeholders
CodeRabbit Configuration File (
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 3
🧹 Outside diff range and nitpick comments (3)
deepmd/pt/model/descriptor/se_r.py (2)
459-460
: LGTM! Consistent precision handling for input tensor.The explicit casting of
coord_ext
to internal precision ensures type consistency throughout the computation pipeline.Consider adding a comment explaining why this casting is necessary:
# cast the input to internal precsion coord_ext = coord_ext.to(dtype=self.prec) +# ensures consistent precision throughout the computation pipeline
deserialize
method does not handle thetrainable
field.The
trainable
field is added to serialization but is not processed during deserialization, which may lead to inconsistencies in the model's state.
- File:
deepmd/pt/model/descriptor/se_r.py
- Lines: 456-463
🔗 Analysis chain
Line range hint
583-583
: Verify deserialization handling of the new trainable field.The addition of
"trainable": True
to serialization is good. However, we should verify that this field is properly handled during deserialization.🏁 Scripts executed
The following scripts were executed for the analysis:
Script:
#!/bin/bash # Description: Check if the deserialize method handles the trainable field # Expected: The deserialize method should accept and process the trainable field # Search for deserialize method implementations that handle trainable field rg -A 10 "def deserialize.*DescrptSeR" # Search for any tests related to serialization/deserialization rg -l "test.*[Ss]erialize.*DescrptSeR"Length of output: 1639
Script:
#!/bin/bash # Description: Verify if the deserialize methods handle the 'trainable' field # Expected: The deserialize methods should process the 'trainable' field # Search for 'trainable' usage within deserialize methods in se_r.py files rg -A 10 "def deserialize.*DescrptSeR" | rg "trainable"Length of output: 55
Script:
#!/bin/bash # Description: Search for all occurrences of 'trainable' in the codebase # Expected: To identify where and how 'trainable' is being used or set # Search for 'trainable' in all Python files rg --type py "trainable" # Additionally, search for any mentions in documentation or comments rg --type py "#.*trainable"Length of output: 42557
deepmd/pt/model/task/fitting.py (1)
406-406
: Fix typo in comment on line 406There's a misspelling in the comment: 'precsion' should be 'precision'.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
📒 Files selected for processing (13)
deepmd/pt/model/descriptor/dpa1.py
(2 hunks)deepmd/pt/model/descriptor/dpa2.py
(4 hunks)deepmd/pt/model/descriptor/repformers.py
(3 hunks)deepmd/pt/model/descriptor/se_a.py
(2 hunks)deepmd/pt/model/descriptor/se_atten.py
(1 hunks)deepmd/pt/model/descriptor/se_r.py
(2 hunks)deepmd/pt/model/descriptor/se_t.py
(2 hunks)deepmd/pt/model/descriptor/se_t_tebd.py
(4 hunks)deepmd/pt/model/network/mlp.py
(0 hunks)deepmd/pt/model/task/dipole.py
(1 hunks)deepmd/pt/model/task/fitting.py
(3 hunks)deepmd/pt/model/task/polarizability.py
(1 hunks)source/tests/pt/model/test_compressed_descriptor_dpa2.py
(0 hunks)
💤 Files with no reviewable changes (2)
- deepmd/pt/model/network/mlp.py
- source/tests/pt/model/test_compressed_descriptor_dpa2.py
🧰 Additional context used
🪛 Ruff
deepmd/pt/model/descriptor/se_t_tebd.py
448-448: Local variable nall
is assigned to but never used
Remove assignment to unused variable nall
(F841)
🔇 Additional comments (26)
deepmd/pt/model/task/dipole.py (2)
Line range hint 197-197
: LGTM: Class attribute declaration improves TorchScript compatibility
The explicit class-level declaration of exclude_types
with type annotation is correct for TorchScript compatibility with PyTorch 2.0.0.
189-190
: LGTM: Precision casting ensures type compatibility
The explicit casting to match gr
's precision before the batch matrix multiplication is correct and aligns with the PR's objective of fixing precision issues.
Let's verify the precision handling across the codebase:
deepmd/pt/model/task/polarizability.py (2)
Line range hint 384-384
: LGTM: Type annotation improves JIT compatibility.
The explicit type annotation for exclude_types
improves type safety and compatibility with PyTorch JIT in version 2.0.0.
241-243
: LGTM: Precision casting ensures type consistency.
The added precision casting ensures that tensor operations between out
and gr
are performed at the same precision level, preventing potential numerical instabilities or precision loss.
Let's verify the precision handling in similar files:
deepmd/pt/model/descriptor/se_r.py (1)
523-523
: LGTM! Consistent precision handling for output tensor.
Casting sw
to global precision before returning ensures consistent output precision across the framework.
deepmd/pt/model/descriptor/repformers.py (2)
25-27
: LGTM: Clean import addition
The addition of PRECISION_DICT
import is well-organized and necessary for the precision management changes.
293-294
: LGTM: Consistent precision handling
The tensor initialization now correctly uses the class's precision type, ensuring consistency across tensor operations.
deepmd/pt/model/descriptor/se_a.py (2)
340-351
: LGTM! Proper precision handling implemented.
The changes correctly handle precision by:
- Casting input coordinates to internal precision for calculations
- Casting outputs back to global precision for consistency
824-825
: LGTM! Improved code readability.
The return statement has been reformatted for better readability while maintaining the same functionality.
deepmd/pt/model/descriptor/dpa1.py (2)
681-682
: LGTM: Explicit precision casting of input tensor
The explicit casting of extended_coord
to internal precision ensures consistent tensor operations throughout the method.
698-704
: LGTM: Consistent output precision handling
The return statement correctly casts all output tensors to the global precision setting, with proper handling of the optional g2
tensor.
deepmd/pt/model/descriptor/se_t.py (2)
376-387
: LGTM: Precision handling is properly implemented
The precision casting is correctly implemented with:
- Input coordinates cast to internal precision for computations
- Output tensors consistently cast to global precision before returning
864-864
: LGTM: Consistent with precision handling pattern
The result tensor is correctly returned in internal precision, allowing the outer DescrptSeT class to handle any necessary precision casting.
deepmd/pt/model/descriptor/se_t_tebd.py (4)
444-445
: LGTM: Proper precision handling
The explicit casting of input tensor to internal precision type improves numerical stability.
461-467
: LGTM: Consistent precision handling in return values
The return values are properly cast to global precision type, and unused return values are explicitly set to None, improving code clarity.
551-552
: LGTM: Consistent precision type usage
The initialization of mean and stddev tensors now correctly uses the class's precision type instead of a hardcoded type, maintaining consistency throughout the class.
856-856
: LGTM: Consistent return value handling
The return value structure aligns with the parent class's forward method, maintaining consistency across the codebase.
deepmd/pt/model/descriptor/dpa2.py (4)
30-32
: LGTM: Clean import addition
The addition of PRECISION_DICT
import is well-placed and necessary for the precision handling changes.
274-274
: LGTM: Proper precision initialization
The precision is correctly initialized using the PRECISION_DICT mapping, ensuring consistent precision handling throughout the class.
752-754
: LGTM: Proper input precision casting
Input tensor is correctly cast to the internal precision type before computations.
820-826
: LGTM: Proper output precision casting
All output tensors are correctly cast to the global precision type, ensuring consistent precision in the output interface.
deepmd/pt/model/descriptor/se_atten.py (1)
230-231
: LGTM: Consistent precision handling
The change properly uses self.prec
for tensor initialization, ensuring consistent precision handling throughout the descriptor block. This aligns with the PR's objective of fixing precision-related issues.
deepmd/pt/model/task/fitting.py (4)
407-409
: Casting input tensors to internal precision
Inputs descriptor
, fparam
, and aparam
are properly cast to the internal precision using to(self.prec)
, ensuring computational consistency across the model.
484-489
: Correct adjustment of bias term addition in mixed types
Adding the bias term self.bias_atom_e[atype]
after computing atom_property
aligns with the intended computation flow and enhances clarity.
504-504
: Consistent application of masking in non-mixed types
Using torch.where
to apply the mask ensures that atom_property
is zeroed out for the excluded atom types, maintaining consistency in the output.
509-511
: Explicit conversion of mask to boolean before application
Converting the mask to boolean using .bool()
before applying it with torch.where
enhances code robustness and prevents potential type-related errors.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 1
🧹 Outside diff range and nitpick comments (3)
deepmd/pt/model/descriptor/se_a.py (3)
121-121
: Validate theprecision
parameter to prevent potentialKeyError
Consider adding validation for the
precision
parameter to ensure it is a valid key inPRECISION_DICT
. This will prevent possibleKeyError
exceptions if an unsupported precision string is provided.
341-352
: Ensure consistent precision casting for all input tensorsCurrently, only
coord_ext
is cast toself.prec
, butatype_ext
andnlist
remain uncast. For consistency and to avoid potential precision mismatches during computations, consider castingatype_ext
andnlist
toself.prec
as well. Additionally, verify that casting outputsg1
,rot_mat
, andsw
toenv.GLOBAL_PT_FLOAT_PRECISION
aligns with the intended precision handling and does not introduce unintended precision loss.
825-826
: Cast output tensors to global precision to maintain consistencyThe
forward
method returnsresult
androt_mat
tensors in internal precision (self.prec
). To maintain consistency across the codebase and prevent potential precision mismatches in downstream operations, consider casting these output tensors toenv.GLOBAL_PT_FLOAT_PRECISION
before returning.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
📒 Files selected for processing (5)
deepmd/pt/model/descriptor/dpa1.py
(4 hunks)deepmd/pt/model/descriptor/se_a.py
(3 hunks)deepmd/pt/model/descriptor/se_t.py
(3 hunks)deepmd/pt/model/descriptor/se_t_tebd.py
(5 hunks)deepmd/pt/model/task/fitting.py
(3 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
- deepmd/pt/model/descriptor/dpa1.py
🧰 Additional context used
🪛 Ruff
deepmd/pt/model/descriptor/se_t_tebd.py
449-449: Local variable nall
is assigned to but never used
Remove assignment to unused variable nall
(F841)
🔇 Additional comments (11)
deepmd/pt/model/task/fitting.py (3)
509-511
: LGTM: Improved type safety in mask handling
The explicit conversion to boolean type using to(torch.bool)
improves type safety and ensures consistent masking behavior across different PyTorch versions.
484-489
: Verify numerical stability with the new operation order
The bias addition has been moved after the vacuum contribution subtraction. While this change maintains precision consistency, please verify that it doesn't affect numerical stability, especially for edge cases with large bias values or near-zero atom properties.
Let's check for potential numerical stability issues in test cases:
✅ Verification successful
Numerical stability verified with the current test suite
All relevant tests using assert_allclose
and torch.testing.assert_close
pass, ensuring that the new operation order does not impact numerical stability.
🏁 Scripts executed
The following scripts were executed for the analysis:
Script:
#!/bin/bash
# Search for relevant test cases and their assertions
rg -U "test.*mixed.*type.*" --type py -A 10
rg "assert.*close" --type py -A 2
Length of output: 113756
406-409
: LGTM: Proper precision casting for input tensors
The explicit precision casting ensures consistent numerical precision across tensor operations.
Let's verify that all tensor operations in the codebase use consistent precision:
✅ Verification successful
LGTM: Precision casting changes ensure consistent tensor operations
The added .to(self.prec)
casting ensures consistent numerical precision across tensor operations.
🏁 Scripts executed
The following scripts were executed for the analysis:
Script:
#!/bin/bash
# Search for tensor operations that might have inconsistent precision
rg -U "\.to\((torch\.)?(float|double|half)" --type py
Length of output: 113756
Script:
#!/bin/bash
# Search for all tensor casting operations in the main codebase to verify precision consistency, excluding test directories
rg -U "\.to\(" --type py --ignore-dir tests
Length of output: 488
Script:
#!/bin/bash
# Search for all tensor casting operations in the main codebase to verify precision consistency, excluding test directories
rg -U "\.to\(" --type py -g '!tests/*'
Length of output: 45923
Script:
#!/bin/bash
# Search for all tensor casting operations in the main codebase to verify precision consistency, excluding all test directories
rg -U "\.to\(" --type py -g '!**/tests/**'
Length of output: 11515
deepmd/pt/model/descriptor/se_t.py (4)
157-157
: LGTM: Proper precision initialization
The addition of self.prec
ensures consistent precision handling throughout the class by mapping the precision string to the corresponding PyTorch dtype.
377-388
: LGTM: Consistent precision handling in forward pass
The changes properly handle precision by:
- Casting input coordinates to internal precision
- Casting outputs back to global precision
- Maintaining correct handling of None values in the return tuple
865-865
: LGTM: Removed redundant precision casting
The removal of explicit precision casting in DescrptBlockSeT.forward
is correct since the precision conversion is now handled in the parent class's forward method, avoiding redundant operations.
157-157
: Verify precision handling across the model
While the changes properly implement precision handling in this file, it would be good to verify that similar precision handling patterns are followed consistently across other descriptor implementations.
Also applies to: 377-388, 865-865
✅ Verification successful
Precision handling is consistent across the model
🏁 Scripts executed
The following scripts were executed for the analysis:
Script:
#!/bin/bash
# Description: Check precision handling patterns across descriptor implementations
# Expected results: Similar precision handling patterns in other descriptor files
# Look for precision initialization patterns
echo "Checking precision initialization patterns:"
rg -A 1 "self\.prec\s*=\s*PRECISION_DICT\[precision\]" --type py
# Look for precision casting patterns in forward methods
echo "Checking precision casting patterns in forward methods:"
rg -B 5 -A 5 "\.to\(dtype=(?:self\.prec|env\.GLOBAL_PT_FLOAT_PRECISION)\)" --type py
Length of output: 10788
deepmd/pt/model/descriptor/se_t_tebd.py (4)
164-164
: LGTM: Improved precision handling
The changes enhance type safety by ensuring consistent precision handling through explicit initialization and tensor casting.
Also applies to: 445-446
462-468
: LGTM: Improved return value handling
The changes enhance clarity and type safety by:
- Explicitly returning None for unused values
- Ensuring consistent precision by casting to global precision type
552-553
: LGTM: Consistent precision in tensor initialization
The changes ensure type consistency by initializing tensors with the internal precision type.
857-857
: LGTM: Improved return value formatting
The change enhances code readability by using consistent line breaks and maintaining a uniform return structure.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
you may check https://github.com/wanghan-iapcm/deepmd-kit/tree/fix-fit-type for an implementation for fitting.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 1
🧹 Outside diff range and nitpick comments (1)
deepmd/pt/model/task/invar_fitting.py (1)
186-186
: LGTM! Consider enhancing the JIT-related comment.The type annotation correctly addresses JIT compilation issues. Consider expanding the comment to better document the specific JIT-related problem this solves.
- # make jit happy with torch 2.0.0 + # Type annotation required for JIT compilation compatibility in PyTorch 2.0.0 + # Without this, JIT fails to infer the type of exclude_types class attribute
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
📒 Files selected for processing (6)
deepmd/pt/model/task/fitting.py
(3 hunks)deepmd/pt/model/task/invar_fitting.py
(1 hunks)deepmd/pt/model/task/polarizability.py
(1 hunks)source/tests/pt/model/test_dipole_fitting.py
(2 hunks)source/tests/pt/model/test_polarizability_fitting.py
(1 hunks)source/tests/pt/model/test_property_fitting.py
(2 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
- deepmd/pt/model/task/polarizability.py
🔇 Additional comments (10)
deepmd/pt/model/task/invar_fitting.py (1)
180-183
: Verify the impact of return type change on dependent code.
The changes improve type safety by:
- Wrapping the output in a dictionary with explicit key
- Ensuring consistent precision via explicit conversion
However, this modifies the method's contract from returning a tensor to returning a dictionary.
✅ Verification successful
Further investigation needed to assess the impact of the return type change.
Further investigation required to comprehensively assess the impact of the return type change.
Change to return type verified: No dependent usages found.
🏁 Scripts executed
The following scripts were executed for the analysis:
Script:
#!/bin/bash
# Search for direct tensor usage of InvarFitting.forward output
rg -l "InvarFitting.*forward.*\(" | xargs rg "\.forward\(.*\).*\+" -A 2
Length of output: 73
Script:
#!/bin/bash
# Find all calls to InvarFitting.forward
rg "InvarFitting\.forward\(" > forward_calls.txt
# For each call, check how the returned value is used
while IFS= read -r file; do
echo "Analyzing $file"
rg "InvarFitting\.forward\(" "$file" -A 5
done < forward_calls.txt
Length of output: 52
Script:
#!/bin/bash
# Find all calls to InvarFitting.forward
rg "InvarFitting\.forward\(" > forward_calls.txt
# For each call, display the surrounding context to understand usage
while IFS= read -r file; do
echo "Analyzing calls in $file:"
rg "InvarFitting\.forward\(" "$file" -C 5
done < forward_calls.txt
Length of output: 52
source/tests/pt/model/test_dipole_fitting.py (3)
265-265
: LGTM: Improved parameter clarity in permutation test
The change from 0
to None
for fparam
and aparam
better reflects that these parameters are intentionally unused in the permutation invariance test.
306-306
: LGTM: Improved parameter clarity in translation test
The change from 0
to None
for fparam
and aparam
better reflects that these parameters are intentionally unused in the translation invariance test.
265-265
: Verify consistent parameter usage across test files
Let's verify that similar test methods in related files follow the same pattern of using None
for unused parameters.
Also applies to: 306-306
source/tests/pt/model/test_polarizability_fitting.py (1)
329-329
: LGTM! Improved type consistency for JIT compilation.
The change from fparam=0, aparam=0
to fparam=None, aparam=None
is correct and consistent with the PolarFittingNet
initialization where numb_fparam=0, numb_aparam=0
. When these numbers are 0, passing None
is more type-consistent than passing integers.
Let's verify similar patterns in other test methods:
✅ Verification successful
LGTM! Improved type consistency for JIT compilation.
🏁 Scripts executed
The following scripts were executed for the analysis:
Script:
#!/bin/bash
# Description: Check for similar parameter patterns in test methods
# Expected: All test methods should use None instead of 0 for fparam/aparam when numb_fparam/numb_aparam=0
# Search for PolarFittingNet initialization with zero parameters
echo "Searching for PolarFittingNet with zero parameters..."
rg "PolarFittingNet\([^)]*numb_fparam=0[^)]*numb_aparam=0"
# Search for corresponding function calls
echo "Searching for corresponding function calls..."
rg "ft0\([^)]*fparam=[^)]*aparam=[^)]*\)"
Length of output: 1914
deepmd/pt/model/task/fitting.py (5)
406-409
: LGTM: Explicit precision casting improves type safety
The explicit casting of input tensors to internal precision helps prevent mixed precision issues and ensures consistent numerical handling throughout the network.
480-481
: LGTM: Consistent precision in output tensor initialization
The output tensor initialization now correctly uses the internal precision type, maintaining consistency with input tensors.
484-489
: LGTM: Improved precision handling in mixed types case
The bias addition now properly handles precision casting, maintaining numerical consistency throughout the computation.
503-503
: LGTM: Consistent precision handling in non-mixed types case
The bias casting to internal precision maintains consistency with the mixed types implementation.
509-511
: LGTM: Explicit boolean mask conversion prevents type promotion issues
The explicit conversion of mask to boolean type is a necessary improvement that prevents implicit type promotion issues, as previously noted in the reviews. This change enhances type safety and makes the code's behavior more predictable.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 0
🧹 Outside diff range and nitpick comments (2)
deepmd/pt/utils/env.py (1)
Line range hint
53-61
: Fix typo in variable nameRESERVED_PRECISON_DICT
The variable name has a typo:
PRECISON
should bePRECISION
.-RESERVED_PRECISON_DICT = { +RESERVED_PRECISION_DICT = {Remember to update the variable name in
__all__
as well:- "RESERVED_PRECISON_DICT", + "RESERVED_PRECISION_DICT",🧰 Tools
🪛 GitHub Check: CodeQL
[notice] 18-18: Unused global variable
The global variable 'DP_DTYPE_PROMOTION_STRICT' is not used.deepmd/pt/model/network/mlp.py (1)
204-205
: Consider adding error handling for type conversionsThe precision handling logic is well-implemented, preserving the original tensor precision. However, consider adding error handling for potential type conversion failures.
if not DP_DTYPE_PROMOTION_STRICT: - xx = xx.to(self.prec) + try: + xx = xx.to(self.prec) + except RuntimeError as e: + raise RuntimeError(f"Failed to convert input tensor to {self.prec}: {e}")Also applies to: 220-221
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
📒 Files selected for processing (2)
deepmd/pt/model/network/mlp.py
(3 hunks)deepmd/pt/utils/env.py
(1 hunks)
🧰 Additional context used
🪛 GitHub Check: CodeQL
deepmd/pt/utils/env.py
[notice] 18-18: Unused global variable
The global variable 'DP_DTYPE_PROMOTION_STRICT' is not used.
🔇 Additional comments (4)
deepmd/pt/utils/env.py (2)
Line range hint 37-52
: LGTM! Robust precision handling implementation
The precision dictionary implementation is well-structured with:
- Comprehensive coverage of precision types
- Strong validation through assertions
- Proper bidirectional mapping
🧰 Tools
🪛 GitHub Check: CodeQL
[notice] 18-18: Unused global variable
The global variable 'DP_DTYPE_PROMOTION_STRICT' is not used.
18-18
: Verify the usage of DP_DTYPE_PROMOTION_STRICT
The new environment variable appears to be unused in the codebase. Please ensure that:
- The variable is actually used in the type promotion logic
- Documentation is added explaining its purpose and impact
Let's verify its usage across the codebase:
✅ Verification successful
DP_DTYPE_PROMOTION_STRICT
is used in the codebase
- Found usage in:
deepmd/pt/model/network/mlp.py
🏁 Scripts executed
The following scripts were executed for the analysis:
Script:
#!/bin/bash
# Search for any usage of the new environment variable
rg "DP_DTYPE_PROMOTION_STRICT" --type python
# Search for potential type promotion related code that might need this flag
ast-grep --pattern 'torch.promote_types($$$)'
Length of output: 123
Script:
#!/bin/bash
# Search for any usage of the new environment variable in Python files
rg "DP_DTYPE_PROMOTION_STRICT" -g "*.py"
# Search for potential type promotion related code that might need this flag
ast-grep --pattern 'torch.promote_types($$$)'
Length of output: 403
🧰 Tools
🪛 GitHub Check: CodeQL
[notice] 18-18: Unused global variable
The global variable 'DP_DTYPE_PROMOTION_STRICT' is not used.
deepmd/pt/model/network/mlp.py (2)
35-37
: LGTM: Clean import addition for precision control
The addition of DP_DTYPE_PROMOTION_STRICT
import is well-organized with related precision-handling imports.
204-205
: Verify performance impact of conditional type conversions
The conditional type conversions may impact performance. Let's verify the overhead.
Also applies to: 220-221
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 0
🧹 Outside diff range and nitpick comments (2)
deepmd/pt/model/task/dipole.py (2)
Line range hint
199-199
: Document the JIT workaround more explicitlyWhile the comment indicates this is a JIT compatibility fix, it would be helpful to:
- Document the specific JIT issue being addressed
- Reference any relevant PyTorch issues/documentation
Consider updating the comment to be more descriptive:
-# make jit happy with torch 2.0.0 +# Explicitly declare class variable for TorchScript compatibility in PyTorch 2.0.0 +# See: [reference to PyTorch issue/documentation]
Based on the gathered information, I can now provide a final response:
The empty
compute_output_stats
implementation needs attentionThe empty implementation in
DipoleFittingNet
is problematic because:
- The method is required for computing output statistics (e.g., dipole bias) from packed data
- Other fitting networks in the codebase (energy, polarizability, DOS) have proper implementations
- There's a utility function
compute_output_stats
indeepmd/pt/utils/stat.py
that should be utilized🔗 Analysis chain
Line range hint
134-171
: Verify the empty implementation ofcompute_output_stats
The method is defined with comprehensive documentation but lacks implementation. This could impact the model's ability to compute and utilize output statistics.
Let's check if this is implemented in other similar files:
Would you like assistance in implementing this method based on other fitting networks' implementations?
🏁 Scripts executed
The following scripts were executed for the analysis:
Script:
#!/bin/bash # Search for compute_output_stats implementations in other fitting networks ast-grep --pattern 'def compute_output_stats($$$) { $$$ }' 'deepmd/pt/model/task/*.py'Length of output: 161
Script:
#!/bin/bash # First, let's find all Python files in the task directory fd ".py$" deepmd/pt/model/task/ # Then, let's search for compute_output_stats implementations rg "def compute_output_stats" -A 10Length of output: 6785
Script:
#!/bin/bash # Let's check the base class to understand if this is an abstract method rg "class GeneralFitting" -A 20 deepmd/pt/model/task/ # Also check if there's an implementation in the utils that this class should use rg "compute_output_stats.*dipole" -A 10 deepmd/pt/utils/stat.pyLength of output: 1459
Script:
#!/bin/bash # Let's check if compute_output_stats is defined in the base Fitting class rg "class Fitting" -A 30 deepmd/pt/model/task/ # Also check if there are any implementations of compute_output_stats in other fitting networks rg "class.*FittingNet.*:.*\n.*compute_output_stats" -A 30 deepmd/pt/model/task/Length of output: 2764
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
📒 Files selected for processing (2)
deepmd/pt/model/task/dipole.py
(1 hunks)deepmd/pt/model/task/polarizability.py
(1 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
- deepmd/pt/model/task/polarizability.py
🔇 Additional comments (1)
deepmd/pt/model/task/dipole.py (1)
183-184
: LGTM: Proper precision handling added
The addition of explicit precision casting for the rotation matrix gr
ensures consistent numerical precision throughout the computation pipeline.
Tried to implement the decorator as in #4343, but encountered JIT errors.
Summary by CodeRabbit
Release Notes
New Features
DipoleFittingNet
class for excluding specific types.Bug Fixes
Documentation
Tests
fparam
andaparam
across multiple test cases.