-
Notifications
You must be signed in to change notification settings - Fork 526
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat(jax/array-api): dipole/polarizability fitting #4278
Conversation
Signed-off-by: Jinzhe Zeng <jinzhe.zeng@rutgers.edu>
📝 WalkthroughWalkthroughThe pull request introduces significant modifications to the Changes
Possibly related PRs
Suggested labels
Suggested reviewers
Thank you for using CodeRabbit. We offer it for free to the OSS community and would appreciate your support in helping us grow. If you find it useful, would you consider giving us a shout-out on your favorite social media? 🪧 TipsChatThere are 3 ways to chat with CodeRabbit:
Note: Be mindful of the bot's finite context window. It's strongly recommended to break down tasks such as reading entire modules into smaller chunks. For a focused discussion, use review comments to chat about specific files and their changes, instead of using the PR comments. CodeRabbit Commands (Invoked using PR comments)
Other keywords and placeholders
CodeRabbit Configuration File (
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 6
🧹 Outside diff range and nitpick comments (6)
deepmd/jax/fitting/fitting.py (1)
70-82
: LGTM with a suggestion: Consider extracting polar-specific attribute handling
The implementation is correct and properly handles polar-specific attributes. However, to improve maintainability, consider extracting the polar-specific attribute handling into a helper function.
Consider refactoring like this:
+def setattr_for_polar_fitting(name: str, value: Any) -> Any:
+ if name in {"scale", "constant_matrix"}:
+ value = to_jax_array(value)
+ if value is not None:
+ value = ArrayAPIVariable(value)
+ return value
@BaseFitting.register("polar")
@flax_module
class PolarFittingNet(PolarFittingNetDP):
def __setattr__(self, name: str, value: Any) -> None:
value = setattr_for_general_fitting(name, value)
- if name in {
- "scale",
- "constant_matrix",
- }:
- value = to_jax_array(value)
- if value is not None:
- value = ArrayAPIVariable(value)
+ value = setattr_for_polar_fitting(name, value)
return super().__setattr__(name, value)
This change would:
- Make the code more maintainable by isolating polar-specific logic
- Make it easier to test the attribute handling separately
- Follow the same pattern as
setattr_for_general_fitting
source/tests/consistent/fitting/test_dipole.py (2)
167-175
: Consider adding type hints for return value.
The JAX evaluation implementation is clean and consistent with other backends.
Consider adding a return type hint for better type safety:
- def eval_jax(self, jax_obj: Any) -> Any:
+ def eval_jax(self, jax_obj: Any) -> np.ndarray:
177-185
: Consider adding type hints for return value.
The Array API Strict evaluation implementation is clean and consistent with other backends.
Consider adding a return type hint for better type safety:
- def eval_array_api_strict(self, array_api_strict_obj: Any) -> Any:
+ def eval_array_api_strict(self, array_api_strict_obj: Any) -> np.ndarray:
doc/model/train-fitting-tensor.md (1)
Line range hint 1-999
: Document needs JAX-specific sections.
The document should be updated to include JAX-specific information in several sections:
- Add a JAX tab in the examples section showing paths to JAX input files
- Add JAX-specific configuration examples in the "The fitting Network" section
- Add JAX command in the "Train the Model" section
Example addition for the "Train the Model" section:
:::::{tab-set}
:::{tab-item} TensorFlow {{ tensorflow_icon }}
```bash
dp train input.json
:::
:::{tab-item} PyTorch {{ pytorch_icon }}
dp --pt train input.json
:::
+:::{tab-item} JAX {{ jax_icon }}
+
+bash +dp --jax train input.json +
+
+:::
+
::::
</details>
<details>
<summary>source/tests/consistent/fitting/test_polar.py (2)</summary>
`89-90`: **Consider shortening variable name `array_api_strict_class` for readability.**
The variable `array_api_strict_class` is descriptive but somewhat long. For improved readability, consider renaming it to align with naming conventions.
For example:
```diff
jax_class = PolarFittingJAX
-array_api_strict_class = PolarFittingArrayAPIStrict
+api_strict_class = PolarFittingArrayAPIStrict
This makes the variable name shorter while still conveying its purpose.
92-93
: Add docstrings for new properties skip_jax
and skip_array_api_strict
.
The properties skip_jax
and skip_array_api_strict
are introduced without docstrings. Adding brief docstrings would enhance code readability and maintainability.
For example:
skip_jax = not INSTALLED_JAX
+"""Indicates whether to skip JAX tests based on installation."""
skip_array_api_strict = not INSTALLED_ARRAY_API_STRICT
+"""Indicates whether to skip Array API Strict tests based on installation."""
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
📒 Files selected for processing (8)
- deepmd/dpmodel/fitting/dipole_fitting.py (2 hunks)
- deepmd/dpmodel/fitting/polarizability_fitting.py (6 hunks)
- deepmd/jax/fitting/init.py (1 hunks)
- deepmd/jax/fitting/fitting.py (2 hunks)
- doc/model/train-fitting-tensor.md (1 hunks)
- source/tests/array_api_strict/fitting/fitting.py (2 hunks)
- source/tests/consistent/fitting/test_dipole.py (4 hunks)
- source/tests/consistent/fitting/test_polar.py (4 hunks)
🔇 Additional comments (18)
deepmd/jax/fitting/__init__.py (1)
3-6
: LGTM! Clean and well-structured changes.
The additions follow the module's existing patterns and maintain proper organization.
Let's verify the existence of the newly imported classes:
Also applies to: 12-13
✅ Verification successful
Classes are properly implemented and exposed
The verification confirms that both DipoleFittingNet
and PolarFittingNet
classes are:
- Properly defined in
deepmd/jax/fitting/fitting.py
- Correctly registered with
@BaseFitting.register
decorators - Inherit from their respective DP model counterparts
- Follow the same pattern as existing fitting networks
- Properly exposed in the module's
__init__.py
🏁 Scripts executed
The following scripts were executed for the analysis:
Script:
#!/bin/bash
# Description: Verify that the new fitting network classes exist in the source module
# Test: Check if the new classes are defined in the fitting module
ast-grep --pattern 'class DipoleFittingNet'
ast-grep --pattern 'class PolarFittingNet'
Length of output: 90
Script:
#!/bin/bash
# Let's try a broader search to locate these classes in the codebase
# Search for files containing DipoleFittingNet
echo "=== Files containing DipoleFittingNet ==="
rg "DipoleFittingNet" -l
# Search for files containing PolarFittingNet
echo -e "\n=== Files containing PolarFittingNet ==="
rg "PolarFittingNet" -l
# Get content of any Python files in the fitting directory
echo -e "\n=== Content of Python files in fitting directory ==="
fd "\.py$" "deepmd/jax/fitting" --exec cat {}
Length of output: 4356
source/tests/array_api_strict/fitting/fitting.py (2)
6-11
: LGTM: Import statements are well-organized
The new imports for dipole and polarizability fitting components follow the established pattern and are correctly placed.
52-55
: LGTM: DipoleFittingNet implementation is consistent
The implementation follows the established pattern of using setattr_for_general_fitting
for array backend compatibility, maintaining consistency with other fitting networks.
deepmd/jax/fitting/fitting.py (2)
6-6
: LGTM: Clean import additions
The new imports for dipole and polar fitting base classes are well-organized and properly aliased.
Also applies to: 9-11
62-67
: LGTM: Well-structured DipoleFittingNet implementation
The implementation follows the established pattern, properly integrating with the fitting framework through registration and maintaining consistency with other fitting networks.
source/tests/consistent/fitting/test_dipole.py (2)
15-16
: LGTM! Well-structured backend integration.
The new imports and class properties for JAX and Array API Strict backends are cleanly integrated, following the existing patterns for conditional imports and skip flags.
Also applies to: 37-51, 89-93
92-93
: Verify test coverage across all backends.
Let's ensure all test cases are being executed for each backend.
✅ Verification successful
Let me gather more information about the test coverage and backend handling.
Let me check the actual test implementations to understand how different backends are being tested.
Based on the comprehensive test file content, I can now provide the final response:
Test coverage across backends is properly implemented
The test implementation shows a robust framework for testing all backends:
- Tests are parameterized with different configurations (resnet_dt, precision, mixed_types)
- Each backend (TF, PT, JAX, Array API Strict) has dedicated evaluation methods
- Skip conditions are properly handled for optional backends (JAX, Array API Strict)
- Common test infrastructure ensures consistent testing across all available backends
- Appropriate tolerance levels are set based on precision
🏁 Scripts executed
The following scripts were executed for the analysis:
Script:
#!/bin/bash
# Description: Check test coverage for dipole fitting across backends
# Test: Look for test execution patterns across different backends
echo "Checking test patterns across backends..."
rg -A 5 "def test_" source/tests/consistent/fitting/test_dipole.py
# Test: Verify skip conditions are properly handled
echo "Checking skip conditions..."
rg "skip_.*=.*not.*INSTALLED" source/tests/consistent/fitting/test_dipole.py
Length of output: 397
Script:
#!/bin/bash
# Check for test decorators and backend-specific test implementations
echo "Checking test decorators and backend handling..."
rg -B 2 -A 10 "@pytest" source/tests/consistent/fitting/test_dipole.py
# Look for backend-specific imports and configurations
echo "Checking backend-related imports and configurations..."
rg "INSTALLED_|import|backend" source/tests/consistent/fitting/test_dipole.py
# Check for any parametrized tests that might handle different backends
echo "Checking for parametrized tests..."
rg -B 2 -A 5 "@pytest.mark.parametrize" source/tests/consistent/fitting/test_dipole.py
Length of output: 1584
Script:
#!/bin/bash
# Get the full content of the test file to understand the test structure
echo "Checking full test implementation..."
cat source/tests/consistent/fitting/test_dipole.py
# Specifically look for test methods and their skip decorators
echo "Checking test methods and skip decorators..."
rg -B 1 -A 3 "(\s+def test_|@pytest.mark.skip)" source/tests/consistent/fitting/test_dipole.py
Length of output: 6222
doc/model/train-fitting-tensor.md (1)
1-1
: LGTM: JAX backend added to title.
The addition of the JAX icon to the title correctly reflects the new backend support.
source/tests/consistent/fitting/test_polar.py (3)
167-176
: Ensure proper conversion of inputs in eval_jax
method.
The eval_jax
method correctly converts inputs to JAX arrays using jnp.asarray
. This ensures compatibility with JAX operations.
177-186
: Ensure proper conversion of inputs in eval_array_api_strict
method.
The eval_array_api_strict
method uses array_api_strict.asarray
to convert inputs, which is appropriate for Array API Strict compatibility.
47-49
:
Verify relative import path for PolarFittingArrayAPIStrict
.
Please check if the relative import path is correct in the following import statement:
from ...array_api_strict.fitting.fitting import (
PolarFittingNet as PolarFittingArrayAPIStrict,
)
Using '...'
in relative imports moves up three levels in the package hierarchy. Ensure that this path correctly locates PolarFittingNet
.
To confirm the validity of the import path, you can run the following script:
✅ Verification successful
The relative import path is correct
The import path from ...array_api_strict.fitting.fitting import PolarFittingNet as PolarFittingArrayAPIStrict
is verified to be correct because:
- The importing file is at
source/tests/consistent/fitting/test_polar.py
- Going up three levels (
...
) from this location leads tosource/tests/
- The target file exists at
source/tests/array_api_strict/fitting/fitting.py
- The
PolarFittingNet
class is defined in the target file
🏁 Scripts executed
The following scripts were executed for the analysis:
Script:
#!/bin/bash
# Description: Verify the existence of 'PolarFittingNet' in the specified path.
# Expected: The path leads to a file defining 'PolarFittingNet'.
# Search for 'PolarFittingNet' class definition in 'fitting.py' files.
fd --type f --name 'fitting.py' --exec grep -H 'class PolarFittingNet' {}
Length of output: 326
Script:
#!/bin/bash
# Let's try a different approach to locate the file and class definition
# First, find the current file location
fd "test_polar.py$"
# Then search for fitting.py files
fd "fitting.py$"
# Finally, let's use ast-grep to find the class definition
ast-grep --pattern 'class PolarFittingNet {
$$$
}'
Length of output: 1168
Script:
#!/bin/bash
# Let's examine the content of the relevant files to trace the import path
# First, check the content of test_polar.py to understand the context
cat source/tests/consistent/fitting/test_polar.py
# Then check the target file where PolarFittingNet should be defined
cat source/tests/array_api_strict/fitting/fitting.py
Length of output: 8046
deepmd/dpmodel/fitting/dipole_fitting.py (4)
9-9
: Addition of array_api_compat
import for backend compatibility
The import of array_api_compat
enhances the code's compatibility with different array backends, facilitating backend-agnostic operations.
211-211
: Ensure consistency of input arrays when obtaining array namespace
When using xp = array_api_compat.array_namespace(descriptor, atype)
, please verify that descriptor
and atype
are from the same array backend to prevent any potential inconsistencies during computations.
219-221
: Use of xp.reshape
for backend-agnostic reshaping
Replacing np.reshape
with xp.reshape
ensures that reshaping operations are compatible with the array backend in use.
223-225
: Confirm equivalence of matrix multiplication to original np.einsum
operation
The replacement of the commented out np.einsum
with out = out @ gr
followed by reshaping simplifies the code and leverages the matrix multiplication operator. Please confirm that out
and gr
have compatible shapes and that this operation yields the same results as the original einsum expression.
deepmd/dpmodel/fitting/polarizability_fitting.py (3)
9-9
: Import array_api_compat
for array backend compatibility
The addition of import array_api_compat
ensures compatibility with various array backends, which is appropriate for enhancing flexibility.
18-20
: Import to_numpy_array
for consistent serialization
Importing to_numpy_array
from deepmd.dpmodel.common
facilitates consistent serialization of arrays, which is good practice.
131-142
:
Fix missing handling in elif isinstance(scale, float):
There is no code under the elif isinstance(scale, float):
condition. This will result in a SyntaxError
or unintended behavior. You should handle this case by converting the float scale
into a list with length equal to ntypes
.
Apply this diff to fix the missing handling:
elif isinstance(scale, float):
+ scale = [scale for _ in range(ntypes)]
else:
raise ValueError(
"Scale must be a list of float of length ntypes or a float."
)
Likely invalid or redundant comment.
Codecov ReportAttention: Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## devel #4278 +/- ##
==========================================
- Coverage 84.37% 84.30% -0.07%
==========================================
Files 551 553 +2
Lines 51585 51844 +259
Branches 3052 3052
==========================================
+ Hits 43524 43707 +183
- Misses 7100 7177 +77
+ Partials 961 960 -1 ☔ View full report in Codecov by Sentry. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
Summary by CodeRabbit
Release Notes
New Features
DipoleFittingNet
andPolarFittingNet
classes for enhanced fitting functionality.Bug Fixes
DipoleFitting
andPolarFitting
classes.Documentation
Tests