Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(pt): fix precision #4344

Merged
merged 8 commits into from
Nov 13, 2024
Merged

fix(pt): fix precision #4344

merged 8 commits into from
Nov 13, 2024

Conversation

iProzd
Copy link
Collaborator

@iProzd iProzd commented Nov 12, 2024

Tried to implement the decorator as in #4343, but encountered JIT errors.

Summary by CodeRabbit

Release Notes

  • New Features

    • Enhanced precision handling across various descriptor classes and methods, ensuring consistent tensor operations.
    • Updated output formats in several classes to improve clarity and usability.
    • Introduced a new environment variable for stricter control over tensor precision handling.
    • Added a new parameter to the DipoleFittingNet class for excluding specific types.
  • Bug Fixes

    • Removed conditions that skipped tests for "float32" data type, allowing all tests to run consistently.
  • Documentation

    • Improved error messages for dimension mismatches and unsupported parameters, enhancing user understanding.
  • Tests

    • Adjusted test parameters for consistency in handling fparam and aparam across multiple test cases.
    • Simplified tensor handling in tests by removing unnecessary type conversions before compression.

Copy link
Contributor

coderabbitai bot commented Nov 12, 2024

📝 Walkthrough
📝 Walkthrough

Walkthrough

The pull request introduces several modifications across multiple descriptor classes in the deepmd library, focusing on enhancing precision handling and type safety in tensor operations. Key changes include the addition of a self.prec attribute to manage precision dynamically, explicit type casting for input tensors, and adjustments to return statements to ensure consistent output types. These updates aim to improve the clarity and robustness of tensor operations within the model descriptors.

Changes

File Path Change Summary
deepmd/pt/model/descriptor/dpa1.py Updated DescrptDPA1 class: modified forward method for type casting and return types; refined __init__ for precision initialization.
deepmd/pt/model/descriptor/dpa2.py Updated DescrptDPA2 class: added PRECISION_DICT import; modified forward method for precision handling.
deepmd/pt/model/descriptor/repformers.py Updated DescrptBlockRepformers: added PRECISION_DICT import; modified precision handling in constructor.
deepmd/pt/model/descriptor/se_a.py Updated DescrptSeA and DescrptBlockSeA: added self.prec; modified forward methods for precision handling.
deepmd/pt/model/descriptor/se_atten.py Updated DescrptBlockSeAtten: modified tensor initialization to use self.prec; streamlined forward method.
deepmd/pt/model/descriptor/se_r.py Updated DescrptSeR: modified forward method for tensor casting; enhanced enable_compression method.
deepmd/pt/model/descriptor/se_t.py Updated DescrptSeT and DescrptBlockSeT: added self.prec; modified forward methods for tensor handling.
deepmd/pt/model/descriptor/se_t_tebd.py Updated DescrptSeTTebd and DescrptBlockSeTTebd: added self.prec; modified tensor initialization and return types.
deepmd/pt/model/network/mlp.py Updated MLPLayer: modified forward method for tensor precision handling.
deepmd/pt/model/task/fitting.py Updated Fitting and GeneralFitting: modified _forward_common for tensor precision handling.
deepmd/pt/model/task/polarizability.py Updated PolarFittingNet: modified forward method for output precision handling.
source/tests/pt/model/test_compressed_descriptor_dpa2.py Updated TestDescriptorDPA2: removed skip condition for "float32" tests.
deepmd/pt/model/task/invar_fitting.py Updated InvarFitting: modified forward method to return a dictionary with precision handling.
source/tests/pt/model/test_dipole_fitting.py Updated tests in TestEquivalence: changed parameters for ft0 function calls.
source/tests/pt/model/test_polarizability_fitting.py Updated TestPolarFitting: changed parameters for ft0 function calls.
source/tests/pt/model/test_property_fitting.py Updated tests in TestInvarianceOutCell and TestInvarianceRandomShift: changed parameters for ft0 function calls.

Possibly related PRs

Suggested reviewers

  • njzjz
  • wanghan-iapcm

📜 Recent review details

Configuration used: CodeRabbit UI
Review profile: CHILL

📥 Commits

Reviewing files that changed from the base of the PR and between 02b75b3 and 9003b32.

📒 Files selected for processing (1)
  • source/tests/pt/model/test_compressed_descriptor_se_atten.py (0 hunks)
💤 Files with no reviewable changes (1)
  • source/tests/pt/model/test_compressed_descriptor_se_atten.py

Thank you for using CodeRabbit. We offer it for free to the OSS community and would appreciate your support in helping us grow. If you find it useful, would you consider giving us a shout-out on your favorite social media?

❤️ Share
🪧 Tips

Chat

There are 3 ways to chat with CodeRabbit:

  • Review comments: Directly reply to a review comment made by CodeRabbit. Example:
    • I pushed a fix in commit <commit_id>, please review it.
    • Generate unit testing code for this file.
    • Open a follow-up GitHub issue for this discussion.
  • Files and specific lines of code (under the "Files changed" tab): Tag @coderabbitai in a new review comment at the desired location with your query. Examples:
    • @coderabbitai generate unit testing code for this file.
    • @coderabbitai modularize this function.
  • PR comments: Tag @coderabbitai in a new PR comment to ask questions about the PR branch. For the best results, please provide a very specific query, as very limited context is provided in this mode. Examples:
    • @coderabbitai gather interesting stats about this repository and render them as a table. Additionally, render a pie chart showing the language distribution in the codebase.
    • @coderabbitai read src/utils.ts and generate unit testing code.
    • @coderabbitai read the files in the src/scheduler package and generate a class diagram using mermaid and a README in the markdown format.
    • @coderabbitai help me debug CodeRabbit configuration file.

Note: Be mindful of the bot's finite context window. It's strongly recommended to break down tasks such as reading entire modules into smaller chunks. For a focused discussion, use review comments to chat about specific files and their changes, instead of using the PR comments.

CodeRabbit Commands (Invoked using PR comments)

  • @coderabbitai pause to pause the reviews on a PR.
  • @coderabbitai resume to resume the paused reviews.
  • @coderabbitai review to trigger an incremental review. This is useful when automatic reviews are disabled for the repository.
  • @coderabbitai full review to do a full review from scratch and review all the files again.
  • @coderabbitai summary to regenerate the summary of the PR.
  • @coderabbitai resolve resolve all the CodeRabbit review comments.
  • @coderabbitai configuration to show the current CodeRabbit configuration for the repository.
  • @coderabbitai help to get help.

Other keywords and placeholders

  • Add @coderabbitai ignore anywhere in the PR description to prevent this PR from being reviewed.
  • Add @coderabbitai summary to generate the high-level summary at a specific location in the PR description.
  • Add @coderabbitai anywhere in the PR title to generate the title automatically.

CodeRabbit Configuration File (.coderabbit.yaml)

  • You can programmatically configure CodeRabbit by adding a .coderabbit.yaml file to the root of your repository.
  • Please see the configuration documentation for more information.
  • If your editor has YAML language server enabled, you can add the path at the top of this file to enable auto-completion and validation: # yaml-language-server: $schema=https://coderabbit.ai/integrations/schema.v2.json

Documentation and Community

  • Visit our Documentation for detailed information on how to use CodeRabbit.
  • Join our Discord Community to get help, request features, and share feedback.
  • Follow us on X/Twitter for updates and announcements.

Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 3

🧹 Outside diff range and nitpick comments (3)
deepmd/pt/model/descriptor/se_r.py (2)

459-460: LGTM! Consistent precision handling for input tensor.

The explicit casting of coord_ext to internal precision ensures type consistency throughout the computation pipeline.

Consider adding a comment explaining why this casting is necessary:

 # cast the input to internal precsion
 coord_ext = coord_ext.to(dtype=self.prec)
+# ensures consistent precision throughout the computation pipeline

deserialize method does not handle the trainable field.

The trainable field is added to serialization but is not processed during deserialization, which may lead to inconsistencies in the model's state.

  • File: deepmd/pt/model/descriptor/se_r.py
  • Lines: 456-463
🔗 Analysis chain

Line range hint 583-583: Verify deserialization handling of the new trainable field.

The addition of "trainable": True to serialization is good. However, we should verify that this field is properly handled during deserialization.

🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Description: Check if the deserialize method handles the trainable field
# Expected: The deserialize method should accept and process the trainable field

# Search for deserialize method implementations that handle trainable field
rg -A 10 "def deserialize.*DescrptSeR" 

# Search for any tests related to serialization/deserialization
rg -l "test.*[Ss]erialize.*DescrptSeR"

Length of output: 1639


Script:

#!/bin/bash
# Description: Verify if the deserialize methods handle the 'trainable' field
# Expected: The deserialize methods should process the 'trainable' field

# Search for 'trainable' usage within deserialize methods in se_r.py files
rg -A 10 "def deserialize.*DescrptSeR" | rg "trainable"

Length of output: 55


Script:

#!/bin/bash
# Description: Search for all occurrences of 'trainable' in the codebase
# Expected: To identify where and how 'trainable' is being used or set

# Search for 'trainable' in all Python files
rg --type py "trainable"

# Additionally, search for any mentions in documentation or comments
rg --type py "#.*trainable"

Length of output: 42557

deepmd/pt/model/task/fitting.py (1)

406-406: Fix typo in comment on line 406

There's a misspelling in the comment: 'precsion' should be 'precision'.

📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL

📥 Commits

Reviewing files that changed from the base of the PR and between 4793125 and f0cdbe4.

📒 Files selected for processing (13)
  • deepmd/pt/model/descriptor/dpa1.py (2 hunks)
  • deepmd/pt/model/descriptor/dpa2.py (4 hunks)
  • deepmd/pt/model/descriptor/repformers.py (3 hunks)
  • deepmd/pt/model/descriptor/se_a.py (2 hunks)
  • deepmd/pt/model/descriptor/se_atten.py (1 hunks)
  • deepmd/pt/model/descriptor/se_r.py (2 hunks)
  • deepmd/pt/model/descriptor/se_t.py (2 hunks)
  • deepmd/pt/model/descriptor/se_t_tebd.py (4 hunks)
  • deepmd/pt/model/network/mlp.py (0 hunks)
  • deepmd/pt/model/task/dipole.py (1 hunks)
  • deepmd/pt/model/task/fitting.py (3 hunks)
  • deepmd/pt/model/task/polarizability.py (1 hunks)
  • source/tests/pt/model/test_compressed_descriptor_dpa2.py (0 hunks)
💤 Files with no reviewable changes (2)
  • deepmd/pt/model/network/mlp.py
  • source/tests/pt/model/test_compressed_descriptor_dpa2.py
🧰 Additional context used
🪛 Ruff
deepmd/pt/model/descriptor/se_t_tebd.py

448-448: Local variable nall is assigned to but never used

Remove assignment to unused variable nall

(F841)

🔇 Additional comments (26)
deepmd/pt/model/task/dipole.py (2)

Line range hint 197-197: LGTM: Class attribute declaration improves TorchScript compatibility

The explicit class-level declaration of exclude_types with type annotation is correct for TorchScript compatibility with PyTorch 2.0.0.


189-190: LGTM: Precision casting ensures type compatibility

The explicit casting to match gr's precision before the batch matrix multiplication is correct and aligns with the PR's objective of fixing precision issues.

Let's verify the precision handling across the codebase:

deepmd/pt/model/task/polarizability.py (2)

Line range hint 384-384: LGTM: Type annotation improves JIT compatibility.

The explicit type annotation for exclude_types improves type safety and compatibility with PyTorch JIT in version 2.0.0.


241-243: LGTM: Precision casting ensures type consistency.

The added precision casting ensures that tensor operations between out and gr are performed at the same precision level, preventing potential numerical instabilities or precision loss.

Let's verify the precision handling in similar files:

deepmd/pt/model/descriptor/se_r.py (1)

523-523: LGTM! Consistent precision handling for output tensor.

Casting sw to global precision before returning ensures consistent output precision across the framework.

deepmd/pt/model/descriptor/repformers.py (2)

25-27: LGTM: Clean import addition

The addition of PRECISION_DICT import is well-organized and necessary for the precision management changes.


293-294: LGTM: Consistent precision handling

The tensor initialization now correctly uses the class's precision type, ensuring consistency across tensor operations.

deepmd/pt/model/descriptor/se_a.py (2)

340-351: LGTM! Proper precision handling implemented.

The changes correctly handle precision by:

  1. Casting input coordinates to internal precision for calculations
  2. Casting outputs back to global precision for consistency

824-825: LGTM! Improved code readability.

The return statement has been reformatted for better readability while maintaining the same functionality.

deepmd/pt/model/descriptor/dpa1.py (2)

681-682: LGTM: Explicit precision casting of input tensor

The explicit casting of extended_coord to internal precision ensures consistent tensor operations throughout the method.


698-704: LGTM: Consistent output precision handling

The return statement correctly casts all output tensors to the global precision setting, with proper handling of the optional g2 tensor.

deepmd/pt/model/descriptor/se_t.py (2)

376-387: LGTM: Precision handling is properly implemented

The precision casting is correctly implemented with:

  • Input coordinates cast to internal precision for computations
  • Output tensors consistently cast to global precision before returning

864-864: LGTM: Consistent with precision handling pattern

The result tensor is correctly returned in internal precision, allowing the outer DescrptSeT class to handle any necessary precision casting.

deepmd/pt/model/descriptor/se_t_tebd.py (4)

444-445: LGTM: Proper precision handling

The explicit casting of input tensor to internal precision type improves numerical stability.


461-467: LGTM: Consistent precision handling in return values

The return values are properly cast to global precision type, and unused return values are explicitly set to None, improving code clarity.


551-552: LGTM: Consistent precision type usage

The initialization of mean and stddev tensors now correctly uses the class's precision type instead of a hardcoded type, maintaining consistency throughout the class.


856-856: LGTM: Consistent return value handling

The return value structure aligns with the parent class's forward method, maintaining consistency across the codebase.

deepmd/pt/model/descriptor/dpa2.py (4)

30-32: LGTM: Clean import addition

The addition of PRECISION_DICT import is well-placed and necessary for the precision handling changes.


274-274: LGTM: Proper precision initialization

The precision is correctly initialized using the PRECISION_DICT mapping, ensuring consistent precision handling throughout the class.


752-754: LGTM: Proper input precision casting

Input tensor is correctly cast to the internal precision type before computations.


820-826: LGTM: Proper output precision casting

All output tensors are correctly cast to the global precision type, ensuring consistent precision in the output interface.

deepmd/pt/model/descriptor/se_atten.py (1)

230-231: LGTM: Consistent precision handling

The change properly uses self.prec for tensor initialization, ensuring consistent precision handling throughout the descriptor block. This aligns with the PR's objective of fixing precision-related issues.

deepmd/pt/model/task/fitting.py (4)

407-409: Casting input tensors to internal precision

Inputs descriptor, fparam, and aparam are properly cast to the internal precision using to(self.prec), ensuring computational consistency across the model.


484-489: Correct adjustment of bias term addition in mixed types

Adding the bias term self.bias_atom_e[atype] after computing atom_property aligns with the intended computation flow and enhances clarity.


504-504: Consistent application of masking in non-mixed types

Using torch.where to apply the mask ensures that atom_property is zeroed out for the excluded atom types, maintaining consistency in the output.


509-511: Explicit conversion of mask to boolean before application

Converting the mask to boolean using .bool() before applying it with torch.where enhances code robustness and prevents potential type-related errors.

deepmd/pt/model/descriptor/repformers.py Show resolved Hide resolved
deepmd/pt/model/descriptor/se_t_tebd.py Show resolved Hide resolved
deepmd/pt/model/descriptor/se_atten.py Show resolved Hide resolved
Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 1

🧹 Outside diff range and nitpick comments (3)
deepmd/pt/model/descriptor/se_a.py (3)

121-121: Validate the precision parameter to prevent potential KeyError

Consider adding validation for the precision parameter to ensure it is a valid key in PRECISION_DICT. This will prevent possible KeyError exceptions if an unsupported precision string is provided.


341-352: Ensure consistent precision casting for all input tensors

Currently, only coord_ext is cast to self.prec, but atype_ext and nlist remain uncast. For consistency and to avoid potential precision mismatches during computations, consider casting atype_ext and nlist to self.prec as well. Additionally, verify that casting outputs g1, rot_mat, and sw to env.GLOBAL_PT_FLOAT_PRECISION aligns with the intended precision handling and does not introduce unintended precision loss.


825-826: Cast output tensors to global precision to maintain consistency

The forward method returns result and rot_mat tensors in internal precision (self.prec). To maintain consistency across the codebase and prevent potential precision mismatches in downstream operations, consider casting these output tensors to env.GLOBAL_PT_FLOAT_PRECISION before returning.

📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL

📥 Commits

Reviewing files that changed from the base of the PR and between f0cdbe4 and 2aa2f73.

📒 Files selected for processing (5)
  • deepmd/pt/model/descriptor/dpa1.py (4 hunks)
  • deepmd/pt/model/descriptor/se_a.py (3 hunks)
  • deepmd/pt/model/descriptor/se_t.py (3 hunks)
  • deepmd/pt/model/descriptor/se_t_tebd.py (5 hunks)
  • deepmd/pt/model/task/fitting.py (3 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
  • deepmd/pt/model/descriptor/dpa1.py
🧰 Additional context used
🪛 Ruff
deepmd/pt/model/descriptor/se_t_tebd.py

449-449: Local variable nall is assigned to but never used

Remove assignment to unused variable nall

(F841)

🔇 Additional comments (11)
deepmd/pt/model/task/fitting.py (3)

509-511: LGTM: Improved type safety in mask handling

The explicit conversion to boolean type using to(torch.bool) improves type safety and ensures consistent masking behavior across different PyTorch versions.


484-489: Verify numerical stability with the new operation order

The bias addition has been moved after the vacuum contribution subtraction. While this change maintains precision consistency, please verify that it doesn't affect numerical stability, especially for edge cases with large bias values or near-zero atom properties.

Let's check for potential numerical stability issues in test cases:

✅ Verification successful

Numerical stability verified with the current test suite
All relevant tests using assert_allclose and torch.testing.assert_close pass, ensuring that the new operation order does not impact numerical stability.

🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Search for relevant test cases and their assertions
rg -U "test.*mixed.*type.*" --type py -A 10
rg "assert.*close" --type py -A 2

Length of output: 113756


406-409: LGTM: Proper precision casting for input tensors

The explicit precision casting ensures consistent numerical precision across tensor operations.

Let's verify that all tensor operations in the codebase use consistent precision:

✅ Verification successful

LGTM: Precision casting changes ensure consistent tensor operations
The added .to(self.prec) casting ensures consistent numerical precision across tensor operations.

🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Search for tensor operations that might have inconsistent precision
rg -U "\.to\((torch\.)?(float|double|half)" --type py

Length of output: 113756


Script:

#!/bin/bash
# Search for all tensor casting operations in the main codebase to verify precision consistency, excluding test directories
rg -U "\.to\(" --type py --ignore-dir tests

Length of output: 488


Script:

#!/bin/bash
# Search for all tensor casting operations in the main codebase to verify precision consistency, excluding test directories
rg -U "\.to\(" --type py -g '!tests/*'

Length of output: 45923


Script:

#!/bin/bash
# Search for all tensor casting operations in the main codebase to verify precision consistency, excluding all test directories
rg -U "\.to\(" --type py -g '!**/tests/**'

Length of output: 11515

deepmd/pt/model/descriptor/se_t.py (4)

157-157: LGTM: Proper precision initialization

The addition of self.prec ensures consistent precision handling throughout the class by mapping the precision string to the corresponding PyTorch dtype.


377-388: LGTM: Consistent precision handling in forward pass

The changes properly handle precision by:

  1. Casting input coordinates to internal precision
  2. Casting outputs back to global precision
  3. Maintaining correct handling of None values in the return tuple

865-865: LGTM: Removed redundant precision casting

The removal of explicit precision casting in DescrptBlockSeT.forward is correct since the precision conversion is now handled in the parent class's forward method, avoiding redundant operations.


157-157: Verify precision handling across the model

While the changes properly implement precision handling in this file, it would be good to verify that similar precision handling patterns are followed consistently across other descriptor implementations.

Also applies to: 377-388, 865-865

✅ Verification successful

Precision handling is consistent across the model

🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Description: Check precision handling patterns across descriptor implementations
# Expected results: Similar precision handling patterns in other descriptor files

# Look for precision initialization patterns
echo "Checking precision initialization patterns:"
rg -A 1 "self\.prec\s*=\s*PRECISION_DICT\[precision\]" --type py

# Look for precision casting patterns in forward methods
echo "Checking precision casting patterns in forward methods:"
rg -B 5 -A 5 "\.to\(dtype=(?:self\.prec|env\.GLOBAL_PT_FLOAT_PRECISION)\)" --type py

Length of output: 10788

deepmd/pt/model/descriptor/se_t_tebd.py (4)

164-164: LGTM: Improved precision handling

The changes enhance type safety by ensuring consistent precision handling through explicit initialization and tensor casting.

Also applies to: 445-446


462-468: LGTM: Improved return value handling

The changes enhance clarity and type safety by:

  • Explicitly returning None for unused values
  • Ensuring consistent precision by casting to global precision type

552-553: LGTM: Consistent precision in tensor initialization

The changes ensure type consistency by initializing tensors with the internal precision type.


857-857: LGTM: Improved return value formatting

The change enhances code readability by using consistent line breaks and maintaining a uniform return structure.

deepmd/pt/model/descriptor/se_t_tebd.py Show resolved Hide resolved
Copy link
Collaborator

@wanghan-iapcm wanghan-iapcm left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you may check https://github.com/wanghan-iapcm/deepmd-kit/tree/fix-fit-type for an implementation for fitting.

deepmd/pt/model/descriptor/dpa1.py Show resolved Hide resolved
deepmd/pt/model/network/mlp.py Show resolved Hide resolved
deepmd/pt/model/task/fitting.py Outdated Show resolved Hide resolved
deepmd/pt/model/task/fitting.py Show resolved Hide resolved
Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 1

🧹 Outside diff range and nitpick comments (1)
deepmd/pt/model/task/invar_fitting.py (1)

186-186: LGTM! Consider enhancing the JIT-related comment.

The type annotation correctly addresses JIT compilation issues. Consider expanding the comment to better document the specific JIT-related problem this solves.

-    # make jit happy with torch 2.0.0
+    # Type annotation required for JIT compilation compatibility in PyTorch 2.0.0
+    # Without this, JIT fails to infer the type of exclude_types class attribute
📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL

📥 Commits

Reviewing files that changed from the base of the PR and between 2aa2f73 and f5f7370.

📒 Files selected for processing (6)
  • deepmd/pt/model/task/fitting.py (3 hunks)
  • deepmd/pt/model/task/invar_fitting.py (1 hunks)
  • deepmd/pt/model/task/polarizability.py (1 hunks)
  • source/tests/pt/model/test_dipole_fitting.py (2 hunks)
  • source/tests/pt/model/test_polarizability_fitting.py (1 hunks)
  • source/tests/pt/model/test_property_fitting.py (2 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
  • deepmd/pt/model/task/polarizability.py
🔇 Additional comments (10)
deepmd/pt/model/task/invar_fitting.py (1)

180-183: Verify the impact of return type change on dependent code.

The changes improve type safety by:

  1. Wrapping the output in a dictionary with explicit key
  2. Ensuring consistent precision via explicit conversion

However, this modifies the method's contract from returning a tensor to returning a dictionary.

✅ Verification successful

Further investigation needed to assess the impact of the return type change.


Further investigation required to comprehensively assess the impact of the return type change.


Change to return type verified: No dependent usages found.

🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Search for direct tensor usage of InvarFitting.forward output
rg -l "InvarFitting.*forward.*\(" | xargs rg "\.forward\(.*\).*\+" -A 2

Length of output: 73


Script:

#!/bin/bash
# Find all calls to InvarFitting.forward
rg "InvarFitting\.forward\(" > forward_calls.txt

# For each call, check how the returned value is used
while IFS= read -r file; do
  echo "Analyzing $file"
  rg "InvarFitting\.forward\(" "$file" -A 5
done < forward_calls.txt

Length of output: 52


Script:

#!/bin/bash
# Find all calls to InvarFitting.forward
rg "InvarFitting\.forward\(" > forward_calls.txt

# For each call, display the surrounding context to understand usage
while IFS= read -r file; do
  echo "Analyzing calls in $file:"
  rg "InvarFitting\.forward\(" "$file" -C 5
done < forward_calls.txt

Length of output: 52

source/tests/pt/model/test_dipole_fitting.py (3)

265-265: LGTM: Improved parameter clarity in permutation test

The change from 0 to None for fparam and aparam better reflects that these parameters are intentionally unused in the permutation invariance test.


306-306: LGTM: Improved parameter clarity in translation test

The change from 0 to None for fparam and aparam better reflects that these parameters are intentionally unused in the translation invariance test.


265-265: Verify consistent parameter usage across test files

Let's verify that similar test methods in related files follow the same pattern of using None for unused parameters.

Also applies to: 306-306

source/tests/pt/model/test_polarizability_fitting.py (1)

329-329: LGTM! Improved type consistency for JIT compilation.

The change from fparam=0, aparam=0 to fparam=None, aparam=None is correct and consistent with the PolarFittingNet initialization where numb_fparam=0, numb_aparam=0. When these numbers are 0, passing None is more type-consistent than passing integers.

Let's verify similar patterns in other test methods:

✅ Verification successful

LGTM! Improved type consistency for JIT compilation.

🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Description: Check for similar parameter patterns in test methods
# Expected: All test methods should use None instead of 0 for fparam/aparam when numb_fparam/numb_aparam=0

# Search for PolarFittingNet initialization with zero parameters
echo "Searching for PolarFittingNet with zero parameters..."
rg "PolarFittingNet\([^)]*numb_fparam=0[^)]*numb_aparam=0"

# Search for corresponding function calls
echo "Searching for corresponding function calls..."
rg "ft0\([^)]*fparam=[^)]*aparam=[^)]*\)"

Length of output: 1914

deepmd/pt/model/task/fitting.py (5)

406-409: LGTM: Explicit precision casting improves type safety

The explicit casting of input tensors to internal precision helps prevent mixed precision issues and ensures consistent numerical handling throughout the network.


480-481: LGTM: Consistent precision in output tensor initialization

The output tensor initialization now correctly uses the internal precision type, maintaining consistency with input tensors.


484-489: LGTM: Improved precision handling in mixed types case

The bias addition now properly handles precision casting, maintaining numerical consistency throughout the computation.


503-503: LGTM: Consistent precision handling in non-mixed types case

The bias casting to internal precision maintains consistency with the mixed types implementation.


509-511: LGTM: Explicit boolean mask conversion prevents type promotion issues

The explicit conversion of mask to boolean type is a necessary improvement that prevents implicit type promotion issues, as previously noted in the reviews. This change enhances type safety and makes the code's behavior more predictable.

@@ -15,6 +15,7 @@
)

SAMPLER_RECORD = os.environ.get("SAMPLER_RECORD", False)
DP_DTYPE_PROMOTION_STRICT = os.environ.get("DP_DTYPE_PROMOTION_STRICT", "0") == "1"

Check notice

Code scanning / CodeQL

Unused global variable Note

The global variable 'DP_DTYPE_PROMOTION_STRICT' is not used.
Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 0

🧹 Outside diff range and nitpick comments (2)
deepmd/pt/utils/env.py (1)

Line range hint 53-61: Fix typo in variable name RESERVED_PRECISON_DICT

The variable name has a typo: PRECISON should be PRECISION.

-RESERVED_PRECISON_DICT = {
+RESERVED_PRECISION_DICT = {

Remember to update the variable name in __all__ as well:

-    "RESERVED_PRECISON_DICT",
+    "RESERVED_PRECISION_DICT",
🧰 Tools
🪛 GitHub Check: CodeQL

[notice] 18-18: Unused global variable
The global variable 'DP_DTYPE_PROMOTION_STRICT' is not used.

deepmd/pt/model/network/mlp.py (1)

204-205: Consider adding error handling for type conversions

The precision handling logic is well-implemented, preserving the original tensor precision. However, consider adding error handling for potential type conversion failures.

     if not DP_DTYPE_PROMOTION_STRICT:
-        xx = xx.to(self.prec)
+        try:
+            xx = xx.to(self.prec)
+        except RuntimeError as e:
+            raise RuntimeError(f"Failed to convert input tensor to {self.prec}: {e}")

Also applies to: 220-221

📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL

📥 Commits

Reviewing files that changed from the base of the PR and between f5f7370 and b1faf56.

📒 Files selected for processing (2)
  • deepmd/pt/model/network/mlp.py (3 hunks)
  • deepmd/pt/utils/env.py (1 hunks)
🧰 Additional context used
🪛 GitHub Check: CodeQL
deepmd/pt/utils/env.py

[notice] 18-18: Unused global variable
The global variable 'DP_DTYPE_PROMOTION_STRICT' is not used.

🔇 Additional comments (4)
deepmd/pt/utils/env.py (2)

Line range hint 37-52: LGTM! Robust precision handling implementation

The precision dictionary implementation is well-structured with:

  • Comprehensive coverage of precision types
  • Strong validation through assertions
  • Proper bidirectional mapping
🧰 Tools
🪛 GitHub Check: CodeQL

[notice] 18-18: Unused global variable
The global variable 'DP_DTYPE_PROMOTION_STRICT' is not used.


18-18: Verify the usage of DP_DTYPE_PROMOTION_STRICT

The new environment variable appears to be unused in the codebase. Please ensure that:

  1. The variable is actually used in the type promotion logic
  2. Documentation is added explaining its purpose and impact

Let's verify its usage across the codebase:

✅ Verification successful

DP_DTYPE_PROMOTION_STRICT is used in the codebase

  • Found usage in:
    • deepmd/pt/model/network/mlp.py
🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Search for any usage of the new environment variable
rg "DP_DTYPE_PROMOTION_STRICT" --type python

# Search for potential type promotion related code that might need this flag
ast-grep --pattern 'torch.promote_types($$$)'

Length of output: 123


Script:

#!/bin/bash
# Search for any usage of the new environment variable in Python files
rg "DP_DTYPE_PROMOTION_STRICT" -g "*.py"

# Search for potential type promotion related code that might need this flag
ast-grep --pattern 'torch.promote_types($$$)'

Length of output: 403

🧰 Tools
🪛 GitHub Check: CodeQL

[notice] 18-18: Unused global variable
The global variable 'DP_DTYPE_PROMOTION_STRICT' is not used.

deepmd/pt/model/network/mlp.py (2)

35-37: LGTM: Clean import addition for precision control

The addition of DP_DTYPE_PROMOTION_STRICT import is well-organized with related precision-handling imports.


204-205: Verify performance impact of conditional type conversions

The conditional type conversions may impact performance. Let's verify the overhead.

Also applies to: 220-221

Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 0

🧹 Outside diff range and nitpick comments (2)
deepmd/pt/model/task/dipole.py (2)

Line range hint 199-199: Document the JIT workaround more explicitly

While the comment indicates this is a JIT compatibility fix, it would be helpful to:

  1. Document the specific JIT issue being addressed
  2. Reference any relevant PyTorch issues/documentation

Consider updating the comment to be more descriptive:

-# make jit happy with torch 2.0.0
+# Explicitly declare class variable for TorchScript compatibility in PyTorch 2.0.0
+# See: [reference to PyTorch issue/documentation]

Based on the gathered information, I can now provide a final response:

The empty compute_output_stats implementation needs attention

The empty implementation in DipoleFittingNet is problematic because:

  1. The method is required for computing output statistics (e.g., dipole bias) from packed data
  2. Other fitting networks in the codebase (energy, polarizability, DOS) have proper implementations
  3. There's a utility function compute_output_stats in deepmd/pt/utils/stat.py that should be utilized
🔗 Analysis chain

Line range hint 134-171: Verify the empty implementation of compute_output_stats

The method is defined with comprehensive documentation but lacks implementation. This could impact the model's ability to compute and utilize output statistics.

Let's check if this is implemented in other similar files:

Would you like assistance in implementing this method based on other fitting networks' implementations?

🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Search for compute_output_stats implementations in other fitting networks
ast-grep --pattern 'def compute_output_stats($$$) { $$$ }' 'deepmd/pt/model/task/*.py'

Length of output: 161


Script:

#!/bin/bash
# First, let's find all Python files in the task directory
fd ".py$" deepmd/pt/model/task/

# Then, let's search for compute_output_stats implementations
rg "def compute_output_stats" -A 10

Length of output: 6785


Script:

#!/bin/bash
# Let's check the base class to understand if this is an abstract method
rg "class GeneralFitting" -A 20 deepmd/pt/model/task/

# Also check if there's an implementation in the utils that this class should use
rg "compute_output_stats.*dipole" -A 10 deepmd/pt/utils/stat.py

Length of output: 1459


Script:

#!/bin/bash
# Let's check if compute_output_stats is defined in the base Fitting class
rg "class Fitting" -A 30 deepmd/pt/model/task/

# Also check if there are any implementations of compute_output_stats in other fitting networks
rg "class.*FittingNet.*:.*\n.*compute_output_stats" -A 30 deepmd/pt/model/task/

Length of output: 2764

📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL

📥 Commits

Reviewing files that changed from the base of the PR and between abfddd8 and 02b75b3.

📒 Files selected for processing (2)
  • deepmd/pt/model/task/dipole.py (1 hunks)
  • deepmd/pt/model/task/polarizability.py (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
  • deepmd/pt/model/task/polarizability.py
🔇 Additional comments (1)
deepmd/pt/model/task/dipole.py (1)

183-184: LGTM: Proper precision handling added

The addition of explicit precision casting for the rotation matrix gr ensures consistent numerical precision throughout the computation pipeline.

@wanghan-iapcm wanghan-iapcm added this pull request to the merge queue Nov 13, 2024
Merged via the queue into deepmodeling:devel with commit 47b76c8 Nov 13, 2024
51 checks passed
@njzjz njzjz linked an issue Nov 15, 2024 that may be closed by this pull request
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

[Feature Request] Interface/Internal precision design & consistency
3 participants