Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(dpmodel/jax): fix fparam and aparam support in DeepEval #4285

Merged
merged 2 commits into from
Oct 31, 2024

Conversation

njzjz
Copy link
Member

@njzjz njzjz commented Oct 30, 2024

Summary by CodeRabbit

  • New Features

    • Enhanced error messages for improved clarity when input dimensions are incorrect.
    • Added support for optional fitting and atomic parameters in model evaluations.
  • Bug Fixes

    • Removed restrictions on providing fitting and atomic parameters, allowing for more flexible evaluations.
  • Tests

    • Introduced a new test class to validate the handling of fitting and atomic parameters in model evaluations.

Signed-off-by: Jinzhe Zeng <jinzhe.zeng@rutgers.edu>
Signed-off-by: Jinzhe Zeng <jinzhe.zeng@rutgers.edu>
Copy link
Contributor

coderabbitai bot commented Oct 30, 2024

📝 Walkthrough

Walkthrough

The pull request introduces modifications across several files related to the DeepEval and GeneralFitting classes. Key changes include enhancements to error messaging in GeneralFitting, the addition of optional parameters fparam and aparam in the eval and _eval_model methods of DeepEval, and updates to the deserialize_to_file function in the serialization module to simplify symbolic shape handling. A new test class is also added to validate the functionality of the model with these parameters.

Changes

File Path Change Summary
deepmd/dpmodel/fitting/general_fitting.py Updated error messages in _call_common method to use f-string formatting for better clarity.
deepmd/dpmodel/infer/deep_eval.py Added optional parameters fparam and aparam to eval and _eval_model methods; updated method signatures.
deepmd/jax/infer/deep_eval.py Similar updates as in deepmd/dpmodel/infer/deep_eval.py, adding fparam and aparam to methods.
deepmd/jax/utils/serialization.py Simplified unpacking in deserialize_to_file function; updated handling of model dimensions.
source/tests/consistent/io/test_io.py Introduced TestDeepPotFparamAparam class for testing new fitting and atomic parameters; updated test logic.

Possibly related PRs

  • feat pt : Support property fitting #3867: This PR introduces support for property fitting, which is relevant to the changes in the GeneralFitting class in the main PR, as both involve modifications to fitting parameters and handling of fparam and aparam.
  • feat(jax/array-api): energy fitting #4204: This PR modifies the GeneralFitting class, which is directly related to the changes made in the main PR regarding the _call_common method and its handling of input parameters, enhancing the overall fitting functionality.

Suggested labels

Examples

Suggested reviewers

  • wanghan-iapcm
  • anyangml

Thank you for using CodeRabbit. We offer it for free to the OSS community and would appreciate your support in helping us grow. If you find it useful, would you consider giving us a shout-out on your favorite social media?

❤️ Share
🪧 Tips

Chat

There are 3 ways to chat with CodeRabbit:

  • Review comments: Directly reply to a review comment made by CodeRabbit. Example:
    • I pushed a fix in commit <commit_id>, please review it.
    • Generate unit testing code for this file.
    • Open a follow-up GitHub issue for this discussion.
  • Files and specific lines of code (under the "Files changed" tab): Tag @coderabbitai in a new review comment at the desired location with your query. Examples:
    • @coderabbitai generate unit testing code for this file.
    • @coderabbitai modularize this function.
  • PR comments: Tag @coderabbitai in a new PR comment to ask questions about the PR branch. For the best results, please provide a very specific query, as very limited context is provided in this mode. Examples:
    • @coderabbitai gather interesting stats about this repository and render them as a table. Additionally, render a pie chart showing the language distribution in the codebase.
    • @coderabbitai read src/utils.ts and generate unit testing code.
    • @coderabbitai read the files in the src/scheduler package and generate a class diagram using mermaid and a README in the markdown format.
    • @coderabbitai help me debug CodeRabbit configuration file.

Note: Be mindful of the bot's finite context window. It's strongly recommended to break down tasks such as reading entire modules into smaller chunks. For a focused discussion, use review comments to chat about specific files and their changes, instead of using the PR comments.

CodeRabbit Commands (Invoked using PR comments)

  • @coderabbitai pause to pause the reviews on a PR.
  • @coderabbitai resume to resume the paused reviews.
  • @coderabbitai review to trigger an incremental review. This is useful when automatic reviews are disabled for the repository.
  • @coderabbitai full review to do a full review from scratch and review all the files again.
  • @coderabbitai summary to regenerate the summary of the PR.
  • @coderabbitai resolve resolve all the CodeRabbit review comments.
  • @coderabbitai configuration to show the current CodeRabbit configuration for the repository.
  • @coderabbitai help to get help.

Other keywords and placeholders

  • Add @coderabbitai ignore anywhere in the PR description to prevent this PR from being reviewed.
  • Add @coderabbitai summary to generate the high-level summary at a specific location in the PR description.
  • Add @coderabbitai anywhere in the PR title to generate the title automatically.

CodeRabbit Configuration File (.coderabbit.yaml)

  • You can programmatically configure CodeRabbit by adding a .coderabbit.yaml file to the root of your repository.
  • Please see the configuration documentation for more information.
  • If your editor has YAML language server enabled, you can add the path at the top of this file to enable auto-completion and validation: # yaml-language-server: $schema=https://coderabbit.ai/integrations/schema.v2.json

Documentation and Community

  • Visit our Documentation for detailed information on how to use CodeRabbit.
  • Join our Discord Community to get help, request features, and share feedback.
  • Follow us on X/Twitter for updates and announcements.

Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 1

🧹 Outside diff range and nitpick comments (2)
deepmd/dpmodel/infer/deep_eval.py (1)

326-333: Consider adding dimension validation.

While the reshaping logic is correct, consider adding validation to ensure the input dimensions match the model's expectations:

  • fparam.shape[-1] == self.get_dim_fparam()
  • aparam.shape[-1] == self.get_dim_aparam()

Example implementation:

 if fparam is not None:
+    if fparam.shape[-1] != self.get_dim_fparam():
+        raise ValueError(f"fparam dimension {fparam.shape[-1]} does not match model's dim_fparam {self.get_dim_fparam()}")
     fparam_input = fparam.reshape(nframes, self.get_dim_fparam())
 else:
     fparam_input = None
 if aparam is not None:
+    if aparam.shape[-1] != self.get_dim_aparam():
+        raise ValueError(f"aparam dimension {aparam.shape[-1]} does not match model's dim_aparam {self.get_dim_aparam()}")
     aparam_input = aparam.reshape(nframes, natoms, self.get_dim_aparam())
 else:
     aparam_input = None
deepmd/jax/infer/deep_eval.py (1)

336-343: Add dimension validation for fparam and aparam

While the reshaping logic is correct, consider adding validation for input dimensions to provide clearer error messages. This would help users identify dimension mismatches earlier.

Add validation before reshaping:

         if fparam is not None:
+            if fparam.size % self.get_dim_fparam() != 0:
+                raise ValueError(f"fparam size {fparam.size} is not divisible by dim_fparam {self.get_dim_fparam()}")
             fparam_input = fparam.reshape(nframes, self.get_dim_fparam())
         else:
             fparam_input = None
         if aparam is not None:
+            expected_size = nframes * natoms * self.get_dim_aparam()
+            if aparam.size != expected_size:
+                raise ValueError(f"aparam size {aparam.size} does not match expected size {expected_size}")
             aparam_input = aparam.reshape(nframes, natoms, self.get_dim_aparam())
         else:
             aparam_input = None
📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL

📥 Commits

Reviewing files that changed from the base of the PR and between d165fee and 78e9980.

📒 Files selected for processing (5)
  • deepmd/dpmodel/fitting/general_fitting.py (2 hunks)
  • deepmd/dpmodel/infer/deep_eval.py (3 hunks)
  • deepmd/jax/infer/deep_eval.py (4 hunks)
  • deepmd/jax/utils/serialization.py (1 hunks)
  • source/tests/consistent/io/test_io.py (3 hunks)
🔇 Additional comments (12)
deepmd/jax/utils/serialization.py (2)

54-54: LGTM: Improved flexibility by removing hardcoded shape parameters.

The change to use only necessary shape parameters (nf, nloc, nghost) is more maintainable, as fparam and aparam dimensions are now dynamically obtained from the model.


60-65: LGTM: Proper handling of optional fparam and aparam parameters.

The implementation correctly handles the shape structures for both fparam and aparam, with appropriate null checks and dimensions.

Let's verify the consistency of fparam and aparam usage across the codebase:

✅ Verification successful

Based on the search results, I can see comprehensive evidence of consistent handling of fparam and aparam dimensions across the codebase:

  1. The get_dim_fparam() and get_dim_aparam() methods are consistently implemented across different model classes and frameworks (JAX, PyTorch, TensorFlow)

  2. The shape handling in the review matches the expected tensor shapes used throughout the codebase:

  • fparam: (nf, dim_fparam)
  • aparam: (nf, nloc, dim_aparam)
  1. The null checks (if model.get_dim_fparam() and if model.get_dim_aparam()) are consistently used before accessing these parameters

  2. The dimensions are properly propagated through model compression, testing, and API layers

LGTM: The fparam and aparam shape handling is consistent with the codebase patterns.

The implementation correctly handles optional parameters with proper null checks and maintains consistent tensor shapes across all usage points.

🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Description: Verify consistent handling of fparam and aparam dimensions
# Expected: All dimension getters should be used consistently with shape structs

# Check for any inconsistencies in fparam dimension handling
rg -A 2 "get_dim_fparam|dim_fparam"

# Check for any inconsistencies in aparam dimension handling
rg -A 2 "get_dim_aparam|dim_aparam"

Length of output: 127297

source/tests/consistent/io/test_io.py (3)

139-140: LGTM! Well-structured test enhancements.

The changes properly handle both supported and unsupported cases for fparam and aparam:

  • Correctly extracts array dimensions from input data
  • Creates parameters only when dimensions > 0
  • Uses appropriate array shapes for both parameter types

Also applies to: 150-157, 162-163


216-254: LGTM! Well-structured test class.

The test class properly sets up a test environment for validating fparam and aparam support:

  • Clear test scope through dedicated class
  • Proper initialization of test parameters
  • Good inheritance structure

256-257: LGTM! Proper test cleanup.

The tearDown method correctly inherits and calls the parent class cleanup logic.

deepmd/dpmodel/infer/deep_eval.py (2)

217-217: LGTM: Parameter forwarding is correct.

The fparam and aparam parameters are correctly forwarded to the _eval_model method.


339-344: Verify model interface compatibility.

The implementation looks correct. Let's verify that all model implementations support the fparam and aparam parameters in their call interface.

✅ Verification successful

Based on the search results, I can now generate a final response. The code shows that all model implementations in the codebase properly support fparam and aparam parameters through their base classes and interfaces:

Model interface compatibility is verified.

The codebase shows consistent support for fparam and aparam parameters across all model implementations:

  • All atomic models inherit from BaseAtomicModel which defines the interface with fparam and aparam support
  • The parameters are properly propagated through the model hierarchy:
    • Base models define abstract methods get_dim_fparam() and get_dim_aparam()
    • Concrete implementations like DPAtomicModel, LinearEnergyAtomicModel, etc. implement these methods
    • Models handle both parameters in their __call__ and forward methods with proper type checking and shape validation
🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Description: Check if all model implementations support fparam and aparam parameters

# Search for model class definitions and their call methods
ast-grep --pattern 'class $_ {
  $$$
  def __call__($_, fparam=$_, aparam=$_) {
    $$$
  }
  $$$
}'

Length of output: 102


Script:

#!/bin/bash
# Let's try a different approach to find model implementations and their call signatures

# First, let's find all Python files that might contain model implementations
echo "=== Files containing 'class' and '__call__' ==="
rg -l "class.*:.*def __call__" --type py

echo -e "\n=== Model class implementations with their call signatures ==="
rg "class.*:.*\n.*\n.*def __call__" -A 5 --type py

Length of output: 545


Script:

#!/bin/bash
# Let's try a simpler approach first to find model classes
echo "=== Python files with potential model implementations ==="
rg "class.*Model" --type py -l

echo -e "\n=== Model class implementations ==="
for file in $(rg "class.*Model" --type py -l); do
    echo "=== File: $file ==="
    # Print the class definition and a few lines after it
    rg "class.*Model" "$file" -A 10
done

# Also search specifically for __call__ methods with fparam/aparam
echo -e "\n=== Call methods with fparam/aparam ==="
rg "def __call__.*(?:fparam|aparam)" --type py -A 5

Length of output: 136305


Script:

#!/bin/bash
# Let's search for __call__ methods in model implementations with a more focused approach
echo "=== Searching for __call__ methods in model files ==="
rg "def __call__" --type py deepmd/dpmodel/model/ deepmd/dpmodel/atomic_model/ -A 5

echo -e "\n=== Searching specifically for fparam/aparam usage in model files ==="
rg "fparam|aparam" --type py deepmd/dpmodel/model/ deepmd/dpmodel/atomic_model/ -C 2

Length of output: 29695

deepmd/jax/infer/deep_eval.py (3)

227-227: LGTM: Parameters correctly passed to _eval_model

The addition of fparam and aparam parameters to the _eval_model call is consistent with the method signature and documentation.


317-318: LGTM: Method signature properly updated

The addition of fparam and aparam parameters with correct type hints (Optional[np.ndarray]) aligns with the parent method signature.


352-353: LGTM: Proper JAX array conversion and parameter passing

The parameters are correctly converted to JAX arrays and passed to the model with appropriate names.

deepmd/dpmodel/fitting/general_fitting.py (2)

391-392: LGTM! Clear and informative error message.

The error message is well-formatted using f-strings and provides clear information about the dimension mismatch by showing both the actual and expected values.


412-413: LGTM! Consistent error message style.

The error message maintains consistency with the fparam error message style, using f-strings and providing clear information about the dimension mismatch.

source/tests/consistent/io/test_io.py Show resolved Hide resolved
Copy link

codecov bot commented Oct 30, 2024

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 84.31%. Comparing base (d165fee) to head (78e9980).
Report is 7 commits behind head on devel.

Additional details and impacted files
@@            Coverage Diff             @@
##            devel    #4285      +/-   ##
==========================================
+ Coverage   84.29%   84.31%   +0.01%     
==========================================
  Files         553      553              
  Lines       51820    51828       +8     
  Branches     3052     3052              
==========================================
+ Hits        43683    43699      +16     
+ Misses       7177     7169       -8     
  Partials      960      960              

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

@njzjz njzjz mentioned this pull request Oct 31, 2024
@wanghan-iapcm wanghan-iapcm added this pull request to the merge queue Oct 31, 2024
@github-merge-queue github-merge-queue bot removed this pull request from the merge queue due to failed status checks Oct 31, 2024
@njzjz njzjz added this pull request to the merge queue Oct 31, 2024
Merged via the queue into deepmodeling:devel with commit ff04d8b Oct 31, 2024
60 checks passed
@njzjz njzjz deleted the fix-dpmodel-jax-fparam-aparam branch October 31, 2024 20:26
github-merge-queue bot pushed a commit that referenced this pull request Nov 2, 2024
For the frozen model, store two exported functions: one enables
do_atomic_virial and the other doesn't.

This PR is in conflict with #4285 (in `serialization.py`), and the
conflict must be resolved after one is merged.

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

## Release Notes

- **New Features**
- Introduced a new parameter for enhanced atomic virial data handling in
model evaluations.
- Added support for atomic virial calculations in multiple model
evaluation methods.
- Updated export functionality to dynamically include atomic virial data
based on user input.

- **Bug Fixes**
- Improved output structures across various backends to accommodate new
atomic virial data.

- **Tests**
- Enhanced test cases to verify the new atomic virial functionalities
and ensure compatibility with existing evaluations.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Signed-off-by: Jinzhe Zeng <jinzhe.zeng@rutgers.edu>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
@coderabbitai coderabbitai bot mentioned this pull request Dec 1, 2024
9 tasks
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants