-
Notifications
You must be signed in to change notification settings - Fork 526
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Chore: refactor dpmodel model #4296
Chore: refactor dpmodel model #4296
Conversation
for more information, see https://pre-commit.ci
📝 Walkthrough📝 Walkthrough📝 WalkthroughWalkthroughThe pull request introduces several new atomic models to the Changes
Possibly related PRs
Suggested labels
Suggested reviewers
Thank you for using CodeRabbit. We offer it for free to the OSS community and would appreciate your support in helping us grow. If you find it useful, would you consider giving us a shout-out on your favorite social media? 🪧 TipsChatThere are 3 ways to chat with CodeRabbit:
Note: Be mindful of the bot's finite context window. It's strongly recommended to break down tasks such as reading entire modules into smaller chunks. For a focused discussion, use review comments to chat about specific files and their changes, instead of using the PR comments. CodeRabbit Commands (Invoked using PR comments)
Other keywords and placeholders
CodeRabbit Configuration File (
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 2
🧹 Outside diff range and nitpick comments (8)
deepmd/dpmodel/atomic_model/dos_atomic_model.py (1)
11-14
: Add documentation and improve type safety.The class implementation looks functionally correct, but could benefit from better documentation and type safety:
- Add a class docstring explaining the purpose and parameters
- Add type hints for better code maintainability
- Improve the assertion error message
Here's the suggested implementation:
class DPDOSAtomicModel(DPAtomicModel): + """Deep Potential Density of States (DOS) Atomic Model. + + Args: + descriptor: The descriptor network + fitting (DOSFittingNet): The fitting network for DOS calculations + type_map: Mapping of atom types + **kwargs: Additional arguments passed to parent class + """ - def __init__(self, descriptor, fitting, type_map, **kwargs): + def __init__( + self, + descriptor: Any, + fitting: DOSFittingNet, + type_map: List[str], + **kwargs + ) -> None: - assert isinstance(fitting, DOSFittingNet) + assert isinstance(fitting, DOSFittingNet), ( + f"fitting must be an instance of DOSFittingNet, got {type(fitting)}" + ) super().__init__(descriptor, fitting, type_map, **kwargs)deepmd/dpmodel/atomic_model/energy_atomic_model.py (1)
14-16
: Optimize isinstance checks for better readability.The multiple
isinstance
checks can be merged into a single call for better readability.- assert isinstance(fitting, EnergyFittingNet) or isinstance( - fitting, InvarFitting - ) + assert isinstance(fitting, (EnergyFittingNet, InvarFitting))🧰 Tools
🪛 Ruff
14-16: Multiple
isinstance
calls forfitting
, merge into a single callMerge
isinstance
calls forfitting
(SIM101)
deepmd/dpmodel/model/ener_model.py (1)
2-3
: LGTM! Good architectural improvement.The switch from
DPAtomicModel
toDPEnergyAtomicModel
represents a positive architectural change that improves separation of concerns by using a specialized model for energy calculations.This change:
- Enhances type safety through specialized model types
- Improves code maintainability by making the energy-specific requirements explicit
- Makes the codebase more modular by separating different atomic model concerns
Also applies to: 16-16
deepmd/dpmodel/atomic_model/dipole_atomic_model.py (2)
13-16
: Add type hints and improve assertion message.While the implementation is correct, consider adding type hints and a more descriptive assertion message for better maintainability and debugging.
- def __init__(self, descriptor, fitting, type_map, **kwargs): - assert isinstance(fitting, DipoleFitting) + def __init__( + self, + descriptor: "Descriptor", + fitting: DipoleFitting, + type_map: list[str], + **kwargs + ) -> None: + assert isinstance(fitting, DipoleFitting), f"Expected DipoleFitting instance but got {type(fitting)}"
18-24
: Add detailed docstring for apply_out_stat method.While the implementation is correct, consider adding a more detailed docstring explaining why dipole calculations don't require bias application. This would help future maintainers understand the design decision.
def apply_out_stat( self, ret: dict[str, np.ndarray], atype: np.ndarray, - ): - # dipole not applying bias + ) -> dict[str, np.ndarray]: + """Apply output statistics for dipole calculations. + + Unlike other atomic models, dipole calculations do not require bias application + as the dipole values are already in the correct physical units and scale. + + Args: + ret: Dictionary containing the model outputs + atype: Array of atom types + + Returns: + The unmodified input dictionary + """deepmd/dpmodel/atomic_model/polar_atomic_model.py (3)
15-17
: Add type annotations to the constructor parametersAdding type hints enhances code readability and assists in static analysis. This is especially helpful for complex codebases and when using tools like MyPy for type checking.
Apply this diff to include type annotations:
from .dp_atomic_model import ( DPAtomicModel, ) class DPPolarAtomicModel(DPAtomicModel): - def __init__(self, descriptor, fitting, type_map, **kwargs): + def __init__( + self, + descriptor: DescriptorType, + fitting: PolarFitting, + type_map: List[int], + **kwargs, + ): if not isinstance(fitting, PolarFitting): raise TypeError("fitting must be an instance of PolarFitting") super().__init__(descriptor, fitting, type_map, **kwargs)Remember to import
List
from thetyping
module and define or importDescriptorType
accordingly:from typing import List # from .descriptor_module import DescriptorType # Adjust the import as needed
20-33
: Reformat the docstring for better clarity and compliance with PEP 257The docstring for
apply_out_stat
should follow standard conventions to ensure proper rendering and readability. Specifically, the description should start on the first line after the opening quotes, and there should be a blank line before the parameters section.Apply this diff to reformat the docstring:
def apply_out_stat( self, ret: dict[str, np.ndarray], atype: np.ndarray, ): - """Apply the stat to each atomic output. - - Parameters + """ + Apply statistical modifications to atomic outputs based on provided bias keys. + + Parameters ---------- ret The returned dict by the forward_atomic method atype The atom types. nf x nloc - - """ + """
34-35
: Handle unused variableout_std
The variable
out_std
is retrieved but not used in the subsequent code. If it's not needed, consider omitting it to keep the code clean.Apply this diff to ignore the unused variable:
""" out_bias, out_std = self._fetch_out_stat(self.bias_keys) + # out_std is not used; consider removing if unnecessary
Or unpack only the needed variable:
- out_bias, out_std = self._fetch_out_stat(self.bias_keys) + out_bias, _ = self._fetch_out_stat(self.bias_keys)
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
📒 Files selected for processing (7)
deepmd/dpmodel/atomic_model/__init__.py
(2 hunks)deepmd/dpmodel/atomic_model/dipole_atomic_model.py
(1 hunks)deepmd/dpmodel/atomic_model/dos_atomic_model.py
(1 hunks)deepmd/dpmodel/atomic_model/energy_atomic_model.py
(1 hunks)deepmd/dpmodel/atomic_model/polar_atomic_model.py
(1 hunks)deepmd/dpmodel/model/dos_model.py
(1 hunks)deepmd/dpmodel/model/ener_model.py
(2 hunks)
✅ Files skipped from review due to trivial changes (1)
- deepmd/dpmodel/model/dos_model.py
🧰 Additional context used
🪛 Ruff
deepmd/dpmodel/atomic_model/energy_atomic_model.py
14-16: Multiple isinstance
calls for fitting
, merge into a single call
Merge isinstance
calls for fitting
(SIM101)
🔇 Additional comments (10)
deepmd/dpmodel/atomic_model/dos_atomic_model.py (2)
1-9
: LGTM! Well-organized imports and proper license header.
The imports are properly organized with specific imports rather than wildcards, and the license header is correctly included.
11-14
: Verify test coverage and module registration.
Since this is a new atomic model class, we should ensure:
- The class is properly registered in the module's init.py
- Adequate test coverage exists
Let's verify these aspects:
✅ Verification successful
Class is properly registered and tested
The verification shows that:
- The
DPDOSAtomicModel
is properly registered indeepmd/dpmodel/atomic_model/__init__.py
, both in imports and__all__
- There are multiple test files covering DOS functionality:
source/tests/consistent/fitting/test_dos.py
source/tests/tf/test_deepdos.py
source/tests/tf/test_fitting_dos.py
source/tests/tf/test_model_dos.py
- The model is properly used in the codebase, including in the PyTorch implementation (
deepmd/pt/model/model/dos_model.py
)
🏁 Scripts executed
The following scripts were executed for the analysis:
Script:
#!/bin/bash
# Check if the class is properly registered in __init__.py
echo "Checking module registration:"
rg "DPDOSAtomicModel" "deepmd/dpmodel/atomic_model/__init__.py"
# Check for test coverage
echo -e "\nChecking test coverage:"
fd -e py -g "*test*dos*" "tests/"
Length of output: 367
Script:
#!/bin/bash
# Let's first check if tests directory exists at a different location
echo "Looking for test directories:"
fd -t d -g "*test*"
# Let's also verify the content of __init__.py to understand the registration
echo -e "\nChecking __init__.py content:"
cat "deepmd/dpmodel/atomic_model/__init__.py"
# Let's also check for any test files containing 'dos' in their name anywhere in the repo
echo -e "\nLooking for DOS-related test files:"
fd -t f -e py -g "*dos*test*"
fd -t f -e py -g "*test*dos*"
# Let's also check for usage of this model in the codebase
echo -e "\nChecking usage of DPDOSAtomicModel:"
rg "DPDOSAtomicModel" --type py
Length of output: 2749
deepmd/dpmodel/atomic_model/energy_atomic_model.py (2)
1-10
: LGTM! Well-organized imports.
The imports are properly organized, specific, and follow good practices with clear separation between external and local imports.
12-17
: LGTM! Well-implemented type-safe model class.
The class properly extends DPAtomicModel and ensures type safety for the fitting parameter while maintaining extensibility through kwargs.
🧰 Tools
🪛 Ruff
14-16: Multiple isinstance
calls for fitting
, merge into a single call
Merge isinstance
calls for fitting
(SIM101)
deepmd/dpmodel/model/ener_model.py (1)
2-3
: Verify consistent usage of DPEnergyAtomicModel.
Let's ensure this architectural change is consistently applied across the codebase.
Also applies to: 16-16
✅ Verification successful
Let me generate additional verification scripts to check the inheritance hierarchy and usage patterns more thoroughly.
This will help us verify:
- All classes inheriting from DPAtomicModel
- How DPEnergyModel is being instantiated
- Any direct usage of DPAtomicModel in energy-related contexts
Let me do one final verification to check the complete implementation of the energy-related models.
This will help us see:
- The complete implementation of DPEnergyAtomicModel
- All usages of DPEnergyModel_ to ensure proper architectural consistency
Architectural change is properly implemented and consistent
Based on the verification results:
DPEnergyAtomicModel
correctly inherits fromDPAtomicModel
and enforces proper fitting types (EnergyFittingNet
orInvarFitting
)- The model is consistently used through
DPEnergyModel_
factory in both PyTorch and core implementations - No direct instantiations of
DPAtomicModel
found in energy contexts except in tests - Clean inheritance hierarchy and proper type checking in place
🏁 Scripts executed
The following scripts were executed for the analysis:
Script:
#!/bin/bash
# Description: Check for any remaining references to DPAtomicModel in energy-related contexts
# and verify proper usage of DPEnergyAtomicModel
# Search for any remaining references to DPAtomicModel in energy-related files
echo "Checking for old DPAtomicModel references in energy-related files..."
rg -l "DPAtomicModel" | rg -i "energy|ener"
# Verify proper usage of DPEnergyAtomicModel
echo "Verifying DPEnergyAtomicModel usage..."
rg -A 3 "DPEnergyAtomicModel"
Length of output: 3234
Script:
#!/bin/bash
# Check the inheritance hierarchy and implementation details
echo "Checking DPAtomicModel base class usage..."
ast-grep --pattern 'class $_ (DPAtomicModel)'
echo "Checking energy model instantiations..."
rg -B2 -A2 "DPEnergyModel_\("
echo "Checking if there are any direct DPAtomicModel instantiations in energy contexts..."
rg -B2 -A2 "DPAtomicModel\(" | rg -i "energy|ener"
Length of output: 834
Script:
#!/bin/bash
# Check the complete implementation of energy models
echo "Checking DPEnergyAtomicModel implementation..."
rg -l "class DPEnergyAtomicModel" | xargs cat
echo "Checking DPEnergyModel_ usage..."
rg -l "DPEnergyModel_" | xargs cat
Length of output: 6350
deepmd/dpmodel/atomic_model/dipole_atomic_model.py (2)
1-10
: LGTM! Well-organized imports with proper licensing.
The imports are properly organized with clear grouping and specific imports. The LGPL-3.0 license header is correctly included.
1-24
: Verify integration with DipoleFitting and module registration.
Let's verify the proper integration of this new atomic model with the codebase.
✅ Verification successful
Let me gather more information about the DipoleFitting implementation since the previous search didn't show its contents.
Based on the gathered information, I can now provide the final response:
Integration with DipoleFitting is properly implemented and registered
- DipoleFitting class is properly implemented in
deepmd/dpmodel/fitting/dipole_fitting.py
with all necessary methods and proper inheritance from GeneralFitting - DPDipoleAtomicModel is correctly registered in
__init__.py
- The implementation includes proper handling of rotational equivariance and all required parameters
🏁 Scripts executed
The following scripts were executed for the analysis:
Script:
#!/bin/bash
# Description: Verify DipoleFitting implementation and model registration
# Test 1: Check if DipoleFitting class exists and has expected methods
echo "Checking DipoleFitting implementation..."
ast-grep --pattern 'class DipoleFitting {
$$$
}'
# Test 2: Verify model is properly registered in __init__.py
echo "Checking model registration..."
rg "DPDipoleAtomicModel" "deepmd/dpmodel/atomic_model/__init__.py"
Length of output: 331
Script:
#!/bin/bash
# Let's check the actual implementation of DipoleFitting
echo "Checking DipoleFitting implementation details..."
rg -A 10 "class DipoleFitting"
# Also verify the import path is correct
echo "Verifying import path..."
fd "dipole_fitting.py" --exec cat {}
Length of output: 25527
deepmd/dpmodel/atomic_model/__init__.py (2)
Line range hint 20-44
: LGTM! Well-structured imports.
The new atomic model imports are well-organized, following the module's existing import pattern and maintaining alphabetical order.
50-56
: LGTM! Proper public API exposure.
The new atomic models are correctly exposed in the __all__
list, maintaining alphabetical ordering and consistent with the module's public API design.
Let's verify that all the new atomic models properly extend the base DPAtomicModel:
✅ Verification successful
The previous script had incorrect Python syntax in the pattern. Let me try with the correct Python class inheritance syntax:
This should help us verify:
- The actual class definitions of the atomic models
- Whether they properly import and inherit from DPAtomicModel
Let's try a different approach by searching for the model files and their content:
This should help us locate and examine the atomic model implementations to verify their inheritance structure.
Let me search for the atomic model class definitions to verify their inheritance:
Let me search for the actual class implementations to verify their inheritance structure:
Based on the code analysis, I can now provide a definitive response about the atomic models' inheritance structure:
All atomic models properly extend DPAtomicModel and implement required interfaces
The code shows that:
- All specialized atomic models (DPEnergyAtomicModel, DPDOSAtomicModel, DPDipoleAtomicModel, DPPolarAtomicModel, DPPropertyAtomicModel) correctly inherit from DPAtomicModel
- DPAtomicModel itself inherits from BaseAtomicModel and is registered with the "@register" decorator as the "standard" type
- Each specialized model properly asserts the correct fitting net type in its constructor
- The inheritance chain is maintained across all implementations (numpy, torch, jax)
- All models implement the required interfaces defined in BaseAtomicModel
The new atomic models in __all__
are correctly exposed and their implementation follows the proper inheritance pattern.
🏁 Scripts executed
The following scripts were executed for the analysis:
Script:
#!/bin/bash
# Description: Verify that new atomic models properly extend DPAtomicModel
# Expected: Each new model should show inheritance from DPAtomicModel
echo "Checking inheritance patterns for new atomic models..."
ast-grep --pattern 'class DP$_AtomicModel($$$) extends DPAtomicModel:
$$$'
Length of output: 197
Script:
#!/bin/bash
# Search for class definitions of atomic models
ast-grep --pattern 'class DP$_AtomicModel($$$):
$$$'
# Also search for base class imports
rg "from.*DPAtomicModel.*import" -A 1
Length of output: 96
Script:
#!/bin/bash
# Find all potential atomic model files
fd ".*atomic.*model.*\.py$" --type f
# Then let's examine their content
echo "=== Checking content of atomic model files ==="
fd ".*atomic.*model.*\.py$" --type f --exec cat {}
Length of output: 260328
Script:
#!/bin/bash
# Search for atomic model class definitions
rg "class DP.*AtomicModel\(" -A 2
# Also search for base class imports
rg "from .* import.*AtomicModel" -A 1
Length of output: 5411
Script:
#!/bin/bash
# Check the content of atomic model files
fd ".*atomic_model.*\.py$" --type f -e py -E "test_*" -x cat {}
Length of output: 135733
deepmd/dpmodel/atomic_model/polar_atomic_model.py (1)
58-58
: Verify shape compatibility before addition
When adding modified_bias
to ret[kk]
, ensure that their shapes are compatible to prevent broadcasting errors. Mismatched shapes could result in runtime exceptions.
Run the following script to check the shapes of ret[kk]
and modified_bias
across all keys:
Please double-check the shapes in the code to confirm compatibility.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 3
🧹 Outside diff range and nitpick comments (6)
deepmd/dpmodel/model/polar_model.py (2)
17-17
: Consider adding type annotation for better code clarity.The dynamically created model could benefit from type annotation:
-DPPolarModel_ = make_model(DPPolarAtomicModel) +DPPolarModel_ = make_model(DPPolarAtomicModel) # type: type[DPPolarAtomicModel]
20-30
: Add class documentation for better maintainability.Consider adding a docstring to explain:
- The purpose and use cases of the PolarModel
- Expected parameters for initialization
- Any specific behaviors or requirements
@BaseModel.register("polar") class PolarModel(DPModelCommon, DPPolarModel_): + """Implements a polar model for atomic systems. + + This model combines functionality from DPModelCommon and DPPolarAtomicModel + to handle polar properties in atomic simulations. + + Args: + *args: Variable length argument list passed to DPPolarModel_ + **kwargs: Arbitrary keyword arguments passed to DPPolarModel_ + """ model_type = "polar"deepmd/dpmodel/model/dipole_model.py (2)
18-18
: Consider adding type annotation for better code clarity.The dynamically created model could benefit from a type annotation to make its type more explicit.
-DPDipoleModel_ = make_model(DPDipoleAtomicModel) +DPDipoleModel_ = make_model(DPDipoleAtomicModel) # type: type[DPDipoleAtomicModel]
21-31
: Add docstring with parameter documentation.The class would benefit from comprehensive documentation explaining its purpose, parameters, and usage examples.
@BaseModel.register("dipole") class DipoleModel(DPModelCommon, DPDipoleModel_): model_type = "dipole" + """Dipole model implementation. + + This class combines DPModelCommon functionality with dipole-specific behavior + from DPDipoleAtomicModel. + + Args: + *args: Positional arguments passed to DPDipoleModel_ + **kwargs: Keyword arguments passed to DPDipoleModel_ + + Example: + >>> model = DipoleModel(descriptor=desc, fitting=fit, type_map=type_map) + """ + def __init__( self, *args, **kwargs, ):deepmd/pt/model/model/polar_model.py (1)
25-25
: LGTM! Consider adding docstring.The renaming to
DPPolarModel_
better reflects the model's purpose. However, adding a docstring to thePolarModel
class would improve code documentation.Add a docstring explaining the purpose and usage of the
PolarModel
class:@BaseModel.register("polar") class PolarModel(DPModelCommon, DPPolarModel_): + """Implements a model for calculating atomic and global polarizability. + + This model combines DPModelCommon functionality with polarizability-specific + calculations provided by DPPolarAtomicModel. + """ model_type = "polar"Also applies to: 29-29
deepmd/pt/model/model/dipole_model.py (1)
Line range hint
29-124
: Consider enhancing documentation and type hints.The model implementation contains complex tensor operations and conditional logic. Consider:
- Adding docstrings to methods explaining the tensor shapes and operations
- Improving type hints for tensor parameters (e.g., specify tensor dimensions)
- Adding examples in docstrings showing typical usage patterns
Example docstring format:
def forward( self, coord: torch.Tensor, # shape: [batch_size, n_atoms, 3] atype: torch.Tensor, # shape: [batch_size, n_atoms] box: Optional[torch.Tensor] = None, # shape: [batch_size, 3, 3] fparam: Optional[torch.Tensor] = None, aparam: Optional[torch.Tensor] = None, do_atomic_virial: bool = False, ) -> dict[str, torch.Tensor]: """Calculate dipole and its derivatives. Args: coord: Atomic coordinates atype: Atomic types box: Simulation box fparam: Frame parameters aparam: Atomic parameters do_atomic_virial: If True, compute atomic virial Returns: Dictionary containing: - dipole: Atomic dipoles [batch_size, n_atoms, 3] - global_dipole: System dipole [batch_size, 3] - force: Forces if grad_r enabled [batch_size, n_atoms, 3] - virial: Virial if grad_c enabled [batch_size, 9] - atom_virial: Atomic virial if requested [batch_size, n_atoms, 9] """
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
📒 Files selected for processing (5)
deepmd/dpmodel/model/dipole_model.py
(1 hunks)deepmd/dpmodel/model/dos_model.py
(1 hunks)deepmd/dpmodel/model/polar_model.py
(1 hunks)deepmd/pt/model/model/dipole_model.py
(2 hunks)deepmd/pt/model/model/polar_model.py
(2 hunks)
🔇 Additional comments (9)
deepmd/dpmodel/model/dos_model.py (3)
1-15
: LGTM: Clean imports organization
The imports are well-organized and properly separated by source, with clear grouping between external and relative imports.
20-22
: Verify inheritance order and model registration
The class structure looks good, but let's verify the inheritance order is correct for proper method resolution.
#!/bin/bash
# Description: Check other model implementations for consistent inheritance patterns
# Expected: Should find similar inheritance patterns in other model implementations
# Check other model class definitions
ast-grep --pattern 'class $_Model(DPModelCommon, $_) {
$$$
}'
# Check model registrations
rg "@BaseModel.register" -A 2
17-17
: Verify DOS model factory implementation
The use of factory pattern is good. Let's verify the model factory implementation handles DOS-specific requirements correctly.
✅ Verification successful
Let me gather more information about the DOS model implementation and factory pattern.
Let me check one more aspect to understand the DOS atomic model implementation.
Factory pattern correctly handles DOS model requirements
Based on the verification results:
- The
make_model
factory function is properly implemented to create models from atomic models, providing consistent interfaces for both extended and local region predictions - The DOS model implementation follows the same pattern as other models (energy, dipole, etc.) with proper registration via
@BaseModel.register("dos")
DPDOSAtomicModel
correctly enforces DOS-specific requirements by validating the fitting network type withassert isinstance(fitting, DOSFittingNet)
🏁 Scripts executed
The following scripts were executed for the analysis:
Script:
#!/bin/bash
# Description: Check make_model implementation and its usage with DOS models
# Expected: Should find make_model implementation and other DOS model usages
# Check make_model implementation
rg -A 10 "def make_model"
# Check other DOS model usages for consistency
ast-grep --pattern 'make_model($DOS)'
Length of output: 3263
Script:
#!/bin/bash
# Check DOS model implementation and factory pattern
rg -A 10 "class DPDOSModel_"
rg -A 10 "def make_model"
# Check if there are any DOS-specific requirements in the model
rg -A 5 "class DOSModel" deepmd/dpmodel/model/dos_model.py deepmd/pt/model/model/dos_model.py
Length of output: 3137
Script:
#!/bin/bash
# Check DOS atomic model implementation
rg -A 10 "class DPDOSAtomicModel"
# Check if there are any DOS-specific requirements in the atomic model
ast-grep --pattern 'class $DOS extends BaseAtomicModel {
$$$
}'
Length of output: 934
deepmd/dpmodel/model/polar_model.py (1)
1-15
: LGTM! Well-organized imports.
The imports are logically grouped and follow good practices.
deepmd/dpmodel/model/dipole_model.py (1)
4-16
: LGTM! Well-organized imports.
The imports are clearly organized, specific, and follow good practices by importing only the necessary components.
deepmd/pt/model/model/polar_model.py (2)
Line range hint 31-38
: LGTM! Constructor implementation is correct.
The constructor properly initializes both parent classes and maintains proper type hints.
25-29
: Verify dependent code for model name changes.
The change from DPDOSModel_
to DPPolarModel_
might affect dependent code. Let's verify there are no remaining references to the old model name.
✅ Verification successful
No issues found with model name changes
The search results show that DPDOSModel
references are properly contained within their respective dos_model.py
files, and the new polar_model.py
is correctly imported where needed. The model structure follows a consistent pattern across different implementations (dpmodel, pt, tf), with proper separation of concerns. There's no evidence of any lingering cross-dependencies or incorrect model name usage.
🏁 Scripts executed
The following scripts were executed for the analysis:
Script:
#!/bin/bash
# Search for any remaining references to DPDOSModel
rg "DPDOSModel" --type python
# Search for imports or usage of the polar model
rg "from.*polar_model import|import.*polar_model" --type python
Length of output: 157
Script:
#!/bin/bash
# Search for any remaining references to DPDOSModel without file type restriction
rg "DPDOSModel"
# Search for imports or usage of the polar model without file type restriction
rg "from.*polar_model import|import.*polar_model"
# Search for any files containing 'model' in their name to understand the model structure
fd "model.*\.py$"
Length of output: 5479
deepmd/pt/model/model/dipole_model.py (2)
38-38
: LGTM! Constructor initialization is correct.
The constructor properly initializes both parent classes in the correct order.
25-25
: LGTM! Verify model behavior after renaming.
The renaming from DPDOSModel_
to DPDipoleModel_
better reflects the model's purpose. The inheritance and registration are correctly implemented.
Let's verify that this is the only instance of DPDOSModel_
that needs to be updated:
Also applies to: 29-29
✅ Verification successful
No issues found - DPDOSModel_
is correctly used in DOS-specific models
The search results show that DPDOSModel_
is only used in the DOS (Density of States) model files (dos_model.py
), which is the correct location for this class. The renaming to DPDipoleModel_
in the dipole model file is appropriate and doesn't conflict with the DOS model implementation.
🏁 Scripts executed
The following scripts were executed for the analysis:
Script:
#!/bin/bash
# Search for any remaining references to DPDOSModel_
rg "DPDOSModel_"
Length of output: 501
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 1
🧹 Outside diff range and nitpick comments (7)
deepmd/dpmodel/atomic_model/energy_atomic_model.py (1)
14-16
: Merge multiple isinstance checks for better readability.The multiple
isinstance
checks can be simplified into a single call.- if not ( - isinstance(fitting, EnergyFittingNet) or isinstance(fitting, InvarFitting) - ): + if not isinstance(fitting, (EnergyFittingNet, InvarFitting)):🧰 Tools
🪛 Ruff
15-15: Multiple
isinstance
calls forfitting
, merge into a single callMerge
isinstance
calls forfitting
(SIM101)
deepmd/pt/model/atomic_model/energy_atomic_model.py (1)
15-22
: Good improvement in error handling!The replacement of assertions with proper type checking and error raising is a good practice, especially for a public API. The error message is clear and informative.
Consider optimizing the isinstance checks as suggested by the static analysis:
- if not ( - isinstance(fitting, EnergyFittingNet) - or isinstance(fitting, EnergyFittingNetDirect) - or isinstance(fitting, InvarFitting) - ): + if not isinstance(fitting, (EnergyFittingNet, EnergyFittingNetDirect, InvarFitting)):This change maintains the same functionality while making the code more concise and potentially more efficient.
🧰 Tools
🪛 Ruff
16-18: Multiple
isinstance
calls forfitting
, merge into a single callMerge
isinstance
calls forfitting
(SIM101)
deepmd/dpmodel/atomic_model/dipole_atomic_model.py (3)
4-6
: Add type hints to imports for better code maintainability.-from deepmd.dpmodel.fitting.dipole_fitting import ( - DipoleFitting, -) +from deepmd.dpmodel.fitting.dipole_fitting import ( + DipoleFitting, # type: DipoleFittingType +)
13-20
: Add class docstring and type hints to improve code documentation.The class lacks a docstring explaining its purpose and usage. Additionally, the
__init__
method could benefit from type hints.class DPDipoleAtomicModel(DPAtomicModel): + """Atomic model for dipole calculations. + + This model extends DPAtomicModel to handle dipole-specific calculations + without applying bias in the output statistics. + """ - def __init__(self, descriptor, fitting, type_map, **kwargs): + def __init__( + self, + descriptor: Any, + fitting: DipoleFitting, + type_map: list[str], + **kwargs: Any + ) -> None:
21-27
: Enhance method documentation with a proper docstring.The current comment is minimal. A proper docstring would better explain the method's purpose and behavior.
def apply_out_stat( self, ret: dict[str, np.ndarray], atype: np.ndarray, - ): - # dipole not applying bias + ) -> dict[str, np.ndarray]: + """Process output statistics for dipole calculations. + + For dipole calculations, no bias is applied to the output, + so the input dictionary is returned as-is. + + Args: + ret: Dictionary containing output arrays + atype: Array of atomic types + + Returns: + The unmodified input dictionary + """ return retdeepmd/pt/model/atomic_model/property_atomic_model.py (1)
16-19
: LGTM! Good improvement in error handling.The change from assertion to explicit TypeError is a good improvement that:
- Provides better error messages
- Follows Python's error handling best practices
- Cannot be disabled by the -O flag (unlike assertions)
Consider making the error message even more helpful by including the actual type received:
- "fitting must be an instance of PropertyFittingNet for DPPropertyAtomicModel" + f"fitting must be an instance of PropertyFittingNet for DPPropertyAtomicModel, got {type(fitting).__name__}"deepmd/pt/model/atomic_model/polar_atomic_model.py (1)
Line range hint
1-67
: Consider enhancing code maintainability with type hints and documentation.The implementation could benefit from:
- Adding type hints to method parameters and return values
- Documenting the mathematical operations in
apply_out_stat
- Extracting the shift_diag logic to a separate method
Here's a suggested improvement for the class structure:
class DPPolarAtomicModel(DPAtomicModel): - def __init__(self, descriptor, fitting, type_map, **kwargs): + def __init__( + self, + descriptor: torch.nn.Module, + fitting: PolarFittingNet, + type_map: list[str], + **kwargs + ) -> None: if not isinstance(fitting, PolarFittingNet): raise TypeError( "fitting must be an instance of PolarFittingNet for DPPolarAtomicModel" ) super().__init__(descriptor, fitting, type_map, **kwargs) def apply_out_stat( self, ret: dict[str, torch.Tensor], atype: torch.Tensor, ) -> dict[str, torch.Tensor]: """Apply the stat to each atomic output. Parameters ---------- ret : dict[str, torch.Tensor] The returned dict by the forward_atomic method. Each tensor has shape (nframes, nloc, 9) representing the 3x3 polarizability matrix. atype : torch.Tensor The atom types tensor of shape (nframes, nloc). Returns ------- dict[str, torch.Tensor] Modified tensors with applied statistics and diagonal shifts. """ out_bias, out_std = self._fetch_out_stat(self.bias_keys) if self.fitting_net.shift_diag: ret = self._apply_diagonal_shift(ret, atype, out_bias) return ret + def _apply_diagonal_shift( + self, + ret: dict[str, torch.Tensor], + atype: torch.Tensor, + out_bias: dict[str, torch.Tensor], + ) -> dict[str, torch.Tensor]: + """Apply diagonal shift to the polarizability matrices. + + This method handles the diagonal shift operation for polarizability matrices + when shift_diag is enabled. It computes the mean of diagonal elements and + applies the shift using the fitting network's scale. + """ + nframes, nloc = atype.shape + device = out_bias[self.bias_keys[0]].device + dtype = out_bias[self.bias_keys[0]].dtype + + for kk in self.bias_keys: + ntypes = out_bias[kk].shape[0] + temp = torch.zeros(ntypes, dtype=dtype, device=device) + for i in range(ntypes): + temp[i] = torch.mean(torch.diagonal(out_bias[kk][i].reshape(3, 3))) + modified_bias = temp[atype] + + modified_bias = ( + modified_bias.unsqueeze(-1) + * (self.fitting_net.scale.to(atype.device))[atype] + ) + + eye = torch.eye(3, dtype=dtype, device=device) + eye = eye.repeat(nframes, nloc, 1, 1) + modified_bias = modified_bias.unsqueeze(-1) * eye + + ret[kk] = ret[kk] + modified_bias + return ret
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
📒 Files selected for processing (10)
deepmd/dpmodel/atomic_model/dipole_atomic_model.py
(1 hunks)deepmd/dpmodel/atomic_model/dos_atomic_model.py
(1 hunks)deepmd/dpmodel/atomic_model/energy_atomic_model.py
(1 hunks)deepmd/dpmodel/atomic_model/polar_atomic_model.py
(1 hunks)deepmd/dpmodel/atomic_model/property_atomic_model.py
(1 hunks)deepmd/pt/model/atomic_model/dipole_atomic_model.py
(1 hunks)deepmd/pt/model/atomic_model/dos_atomic_model.py
(1 hunks)deepmd/pt/model/atomic_model/energy_atomic_model.py
(1 hunks)deepmd/pt/model/atomic_model/polar_atomic_model.py
(1 hunks)deepmd/pt/model/atomic_model/property_atomic_model.py
(1 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
- deepmd/dpmodel/atomic_model/dos_atomic_model.py
🧰 Additional context used
🪛 Ruff
deepmd/dpmodel/atomic_model/energy_atomic_model.py
15-15: Multiple isinstance
calls for fitting
, merge into a single call
Merge isinstance
calls for fitting
(SIM101)
deepmd/pt/model/atomic_model/energy_atomic_model.py
16-18: Multiple isinstance
calls for fitting
, merge into a single call
Merge isinstance
calls for fitting
(SIM101)
🔇 Additional comments (11)
deepmd/pt/model/atomic_model/dos_atomic_model.py (2)
13-16
: Improved error handling with explicit TypeError
The change from assertion to explicit TypeError is a good improvement as it:
- Provides better error handling that won't be disabled by Python's -O flag
- Gives a more descriptive error message to users
- Uses the appropriate exception type for type checking
Line range hint 1-17
: LGTM: Clean and well-structured implementation
The overall implementation follows good practices:
- Clear inheritance hierarchy
- Proper type checking before parent initialization
- Minimal and focused imports
- Appropriate license header
deepmd/dpmodel/atomic_model/property_atomic_model.py (1)
13-16
: LGTM! Improved error handling implementation.
The change from assertions to explicit type checking is a good improvement that:
- Provides clearer error messages to users
- Maintains strong type safety at runtime
- Follows Python best practices for type checking
deepmd/dpmodel/atomic_model/energy_atomic_model.py (1)
12-20
: LGTM! Clean and well-structured implementation.
The class is well-implemented with:
- Proper inheritance from
DPAtomicModel
- Type validation for the
fitting
parameter - Correct delegation to parent constructor
🧰 Tools
🪛 Ruff
15-15: Multiple isinstance
calls for fitting
, merge into a single call
Merge isinstance
calls for fitting
(SIM101)
deepmd/pt/model/atomic_model/dipole_atomic_model.py (1)
16-19
: LGTM! Improved error handling with explicit type checking.
The replacement of assertion with explicit type checking and TypeError
is a good improvement because it:
- Provides a more informative error message
- Follows Python's error handling best practices
- Maintains strong type safety at runtime
deepmd/pt/model/atomic_model/energy_atomic_model.py (1)
15-22
: Verify consistent error handling across atomic models
Let's ensure this error handling pattern is consistently implemented across all atomic models.
✅ Verification successful
Error handling is consistent across all atomic models
The verification shows that all atomic model classes in both deepmd/dpmodel/atomic_model/
and deepmd/pt/model/atomic_model/
consistently implement type checking for their respective fitting networks in __init__
:
- DPDipoleAtomicModel checks for DipoleFittingNet
- DPDOSAtomicModel checks for DOSFittingNet
- DPPolarAtomicModel checks for PolarFittingNet
- DPPropertyAtomicModel checks for PropertyFittingNet
- DPEnergyAtomicModel checks for EnergyFittingNet/EnergyFittingNetDirect/InvarFitting
Each model raises a TypeError with a clear message when the fitting parameter is of incorrect type.
🏁 Scripts executed
The following scripts were executed for the analysis:
Script:
#!/bin/bash
# Description: Check error handling implementation in other atomic models
# Expected: Similar TypeError raising pattern in other atomic model classes
# Search for similar type checking patterns in other atomic models
rg -l "DPAtomicModel" | xargs rg -A 5 "def __init__.*fitting"
Length of output: 6240
🧰 Tools
🪛 Ruff
16-18: Multiple isinstance
calls for fitting
, merge into a single call
Merge isinstance
calls for fitting
(SIM101)
deepmd/dpmodel/atomic_model/dipole_atomic_model.py (1)
1-27
: Implementation looks good!
The DPDipoleAtomicModel
is well-implemented with proper type checking and a clean structure. The suggested improvements are mainly for documentation and don't affect the functionality.
deepmd/pt/model/atomic_model/property_atomic_model.py (1)
16-19
: Consistent error handling approach across atomic models.
This change aligns well with similar updates in other atomic model classes (DPDipoleAtomicModel
, DPDOSAtomicModel
, DPEnergyAtomicModel
, and DPPolarAtomicModel
), establishing a consistent pattern for type validation across the codebase.
Let's verify the consistency of error handling across other atomic models:
✅ Verification successful
Type validation is consistently implemented across all atomic models
The verification confirms that all atomic model classes (DPPropertyAtomicModel
, DPPolarAtomicModel
, DPDOSAtomicModel
, DPEnergyAtomicModel
, and DPDipoleAtomicModel
) follow the same pattern for type validation:
- They all perform
isinstance
checks on thefitting
parameter - They raise
TypeError
with clear, model-specific error messages - The error messages consistently follow the format "fitting must be an instance of [ExpectedType] for [ModelName]"
🏁 Scripts executed
The following scripts were executed for the analysis:
Script:
#!/bin/bash
# Description: Check error handling patterns in other atomic models
# Expected: Similar TypeError raising pattern in other atomic model files
# Search for type checking patterns in atomic models
rg -A 3 "isinstance.*fitting" "deepmd/pt/model/atomic_model/"
Length of output: 2279
deepmd/pt/model/atomic_model/polar_atomic_model.py (1)
16-19
: LGTM! Improved error handling in constructor.
The replacement of assertion with explicit type checking and TypeError
is a good improvement that provides clearer error messages and follows Python's best practices.
deepmd/dpmodel/atomic_model/polar_atomic_model.py (2)
15-20
: Proper type checking in the constructor
Replacing the assert
statement with an explicit type check and raising a TypeError
enhances robustness by ensuring the type validation is not bypassed in optimized modes.
46-46
: Ensure safe reshaping and diagonal computation
Codecov ReportAttention: Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## devel #4296 +/- ##
==========================================
- Coverage 84.50% 84.46% -0.04%
==========================================
Files 596 603 +7
Lines 56665 56789 +124
Branches 3460 3461 +1
==========================================
+ Hits 47883 47967 +84
- Misses 7654 7696 +42
+ Partials 1128 1126 -2 ☔ View full report in Codecov by Sentry. 🚨 Try these New Features:
|
Signed-off-by: Anyang Peng <137014849+anyangml@users.noreply.github.com>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 0
🧹 Outside diff range and nitpick comments (5)
source/tests/consistent/model/test_dos.py (5)
40-49
: Add documentation for parameterized test cases.The test parameterization would benefit from documentation explaining what each test case represents and what scenarios they cover.
Consider adding a docstring explaining:
- The purpose of each parameter combination
- What
([], [[0, 1]])
and([], [1])
test cases validate
54-76
: Add documentation for the model configuration.The data dictionary contains complex configuration for the DOS model. Consider:
- Adding type hints for the dictionary structure
- Documenting the purpose and constraints of each configuration section
- Adding validation for critical parameters
Example structure:
from typing import TypedDict, List class DescriptorConfig(TypedDict): type: str sel: List[int] rcut_smth: float # ... other fields class FittingNetConfig(TypedDict): type: str numb_dos: int # ... other fields class ModelConfig(TypedDict): type_map: List[str] descriptor: DescriptorConfig fitting_net: FittingNetConfig
103-105
: Document why JAX backend is skipped.The
skip_jax
property is hardcoded toTrue
. Consider adding a TODO comment or documentation explaining if JAX support is planned for future implementation.
122-144
: Document test data initialization.The coordinates array contains hardcoded values without explanation. Consider:
- Adding comments explaining what this test structure represents
- Using named constants or helper methods to make the structure more clear
- Documenting the expected behavior with this test structure
213-219
: Use named constants for TensorFlow return values.The TensorFlow backend uses magic numbers for array indexing. Consider defining constants to make the code more maintainable:
class TFReturnIndices: ENERGY = 0 ATOM_ENERGY = 1 FORCE = 2 VIRIAL = 3 ATOM_VIRIAL = 4 # Usage in extract_ret: return ( ret[TFReturnIndices.ENERGY].ravel(), ret[TFReturnIndices.ATOM_ENERGY].ravel(), ret[TFReturnIndices.FORCE].ravel(), ret[TFReturnIndices.VIRIAL].ravel(), ret[TFReturnIndices.ATOM_VIRIAL].ravel(), )
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
📒 Files selected for processing (2)
deepmd/pt/model/atomic_model/dipole_atomic_model.py
(1 hunks)source/tests/consistent/model/test_dos.py
(1 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
- deepmd/pt/model/atomic_model/dipole_atomic_model.py
🧰 Additional context used
🪛 GitHub Check: CodeQL
source/tests/consistent/model/test_dos.py
[notice] 53-53: Unused local variable
Variable pair_exclude_types is not used.
[notice] 53-53: Unused local variable
Variable atom_exclude_types is not used.
🔇 Additional comments (1)
source/tests/consistent/model/test_dos.py (1)
114-114
: Document the PyTorch-specific initialization.
The line model.atomic_model.out_bias.uniform_()
needs explanation for why this initialization is necessary only for the PyTorch backend.
Let's verify if this initialization is consistent with other test files:
for more information, see https://pre-commit.ci
for more information, see https://pre-commit.ci
for more information, see https://pre-commit.ci
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 0
🧹 Outside diff range and nitpick comments (1)
source/tests/consistent/model/common.py (1)
56-70
: Consider improving error handling and documentation.While the conditional logic is functional, there are several areas for improvement:
- The error message should indicate supported keys
- Consider validating that required keys exist in
ret
dictionary- Consider splitting this into separate methods for better separation of concerns
Here's a suggested improvement:
if ret_key == "energy": + required_keys = ["energy", "atom_ener", "force", "virial", "atom_virial"] + missing_keys = [k for k in required_keys if k not in ret] + if missing_keys: + raise KeyError(f"Missing required keys for energy model: {missing_keys}") ret_list = [ ret["energy"], ret["atom_ener"], ret["force"], ret["virial"], ret["atom_virial"], ] elif ret_key == "dos": + required_keys = ["dos", "atom_dos"] + missing_keys = [k for k in required_keys if k not in ret] + if missing_keys: + raise KeyError(f"Missing required keys for DOS model: {missing_keys}") ret_list = [ ret["dos"], ret["atom_dos"], ] else: - raise NotImplementedError + raise ValueError(f"Unsupported ret_key '{ret_key}'. Supported values are: 'energy', 'dos'")
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
📒 Files selected for processing (3)
deepmd/dpmodel/model/model.py
(2 hunks)source/tests/consistent/model/common.py
(2 hunks)source/tests/consistent/model/test_dos.py
(1 hunks)
🚧 Files skipped from review as they are similar to previous changes (2)
- deepmd/dpmodel/model/model.py
- source/tests/consistent/model/test_dos.py
🔇 Additional comments (2)
source/tests/consistent/model/common.py (2)
37-39
: LGTM! Well-structured method signature change.
The addition of the ret_key
parameter with type hint and backward-compatible default value is clean and follows good practices.
Line range hint 37-71
: Verify test coverage for the new functionality.
The changes introduce new behavior that should be tested. Let's verify the test coverage:
✅ Verification successful
Test coverage for DOS functionality is adequate
The code is well-tested through:
source/tests/consistent/model/test_dos.py
explicitly tests the DOS model functionality usingbuild_tf_model
withret_key="dos"
- The test class
TestDOS
includes comprehensive test cases covering both DOS and atom_dos outputs - Additional DOS-related tests exist in
source/tests/consistent/fitting/test_dos.py
that verify the DOS fitting functionality
The new ret_key
parameter and its DOS-related functionality have proper test coverage.
🏁 Scripts executed
The following scripts were executed for the analysis:
Script:
#!/bin/bash
# Description: Check for test cases covering the new ret_key parameter
# Look for test files that might test this functionality
rg -l "build_tf_model.*ret_key" source/tests/
# Look for test cases using DOS models
rg -l "test_dos" source/tests/
Length of output: 78
Script:
#!/bin/bash
# Let's try a broader search to understand the test coverage
# Look for test files that use build_tf_model
rg -l "build_tf_model" source/tests/
# Look for DOS-related tests with broader patterns
rg -l -e "dos" -e "DOS" source/tests/
# Check the content of the test file containing this code
cat source/tests/consistent/model/common.py
# Look for any test classes or functions in this file
rg "class.*Test|def test_" source/tests/consistent/model/common.py
Length of output: 4750
Script:
#!/bin/bash
# Let's check the test files that actually use this utility class and verify DOS testing
# Check test_dos.py content for DOS-related test coverage
cat source/tests/consistent/model/test_dos.py
# Check if test_dos.py uses build_tf_model with ret_key parameter
rg -A 5 "build_tf_model" source/tests/consistent/model/test_dos.py
# Check fitting DOS tests as well
cat source/tests/consistent/fitting/test_dos.py
Length of output: 13136
TODO: to add dipole, polar, property consistent UT |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 2
🧹 Outside diff range and nitpick comments (3)
source/tests/consistent/model/common.py (1)
37-39
: LGTM! Consider adding docstring.The signature change is well-structured with proper type hints and a sensible default value. However, adding a docstring would help document the purpose of the
ret_key
parameter and its allowed values.def build_tf_model( self, obj, natoms, coords, atype, box, suffix, ret_key: str = "energy" ): + """Build TensorFlow model for testing. + + Args: + obj: Model object to build + natoms: Number of atoms + coords: Coordinates + atype: Atom types + box: Simulation box + suffix: Model suffix + ret_key: Type of return values to fetch. Can be "energy", "dos", "dipole", or "polar" + + Returns: + Tuple of (return_list, feed_dict) where return_list contains model outputs based on ret_key + """deepmd/dpmodel/model/model.py (2)
Line range hint
70-109
: Improve error handling and maintainabilitySeveral improvements could make this code more maintainable:
- The change from ValueError to RuntimeError for unknown fitting types needs justification, as ValueError seems more appropriate for invalid input.
- Consider using a mapping dictionary for model types to improve maintainability.
- Update the docstring to reflect that the function can return various model types, not just EnergyModel.
Here's a suggested improvement:
-def get_standard_model(data: dict) -> EnergyModel: +def get_standard_model(data: dict) -> BaseModel: """Get a model from a dictionary. Parameters ---------- data : dict The data to construct the model. + + Returns + ------- + BaseModel + One of: DipoleModel, PolarModel, DOSModel, EnergyModel, or PropertyModel + based on the fitting type. + + Raises + ------ + ValueError + If type_embedding is present or if fitting type is unknown. """ if "type_embedding" in data: raise ValueError( "In the DP backend, type_embedding is not at the model level, but within the descriptor. See type embedding documentation for details." ) data = copy.deepcopy(data) ntypes = len(data["type_map"]) descriptor, fitting, fitting_net_type = _get_standard_model_components(data, ntypes) atom_exclude_types = data.get("atom_exclude_types", []) pair_exclude_types = data.get("pair_exclude_types", []) - if fitting_net_type == "dipole": - modelcls = DipoleModel - elif fitting_net_type == "polar": - modelcls = PolarModel - elif fitting_net_type == "dos": - modelcls = DOSModel - elif fitting_net_type in ["ener", "direct_force_ener"]: - modelcls = EnergyModel - elif fitting_net_type == "property": - modelcls = PropertyModel - else: - raise RuntimeError(f"Unknown fitting type: {fitting_net_type}") + MODEL_TYPES = { + "dipole": DipoleModel, + "polar": PolarModel, + "dos": DOSModel, + "ener": EnergyModel, + "direct_force_ener": EnergyModel, + "property": PropertyModel, + } + + modelcls = MODEL_TYPES.get(fitting_net_type) + if modelcls is None: + raise ValueError(f"Unknown fitting type: {fitting_net_type}") model = modelcls( descriptor=descriptor, fitting=fitting, type_map=data["type_map"], atom_exclude_types=atom_exclude_types, pair_exclude_types=pair_exclude_types, ) return model
89-95
: Add unit tests for new model typesThe code adds support for dipole, polar, and DOS models, but as noted in the PR objectives, unit tests are still needed to ensure consistent behavior.
Would you like me to help create unit tests for these new model types? I can:
- Generate test cases for dipole models
- Generate test cases for polar models
- Generate test cases for DOS models
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
📒 Files selected for processing (5)
deepmd/dpmodel/model/dipole_model.py
(1 hunks)deepmd/dpmodel/model/dos_model.py
(1 hunks)deepmd/dpmodel/model/model.py
(4 hunks)deepmd/dpmodel/model/polar_model.py
(1 hunks)source/tests/consistent/model/common.py
(2 hunks)
🚧 Files skipped from review as they are similar to previous changes (3)
- deepmd/dpmodel/model/dipole_model.py
- deepmd/dpmodel/model/dos_model.py
- deepmd/dpmodel/model/polar_model.py
🔇 Additional comments (2)
source/tests/consistent/model/common.py (1)
56-80
: Verify test coverage for different model types.
The code handles multiple model types (energy, dos, dipole, polar). Let's verify that we have corresponding test cases.
✅ Verification successful
Test coverage for model types is comprehensive
Based on the verification results, we can confirm that all model types (energy, dos, dipole, polar) have corresponding test implementations:
- Energy models: Tested in
test_ener.py
with comprehensive test cases for energy, forces, and virial calculations - DOS models: Tested in
test_dos.py
with tests for DOS and atomic DOS calculations - Dipole models: Multiple test files including
test_dipole.py
,test_deepdipole.py
, and framework-specific tests in tf/pt directories - Polar models: Multiple test files including
test_polar.py
,test_deeppolar.py
, and framework-specific tests in tf/pt directories
The test coverage appears to be thorough across different frameworks (TensorFlow, PyTorch, JAX) and includes both high-level and low-level interface testing.
🏁 Scripts executed
The following scripts were executed for the analysis:
Script:
#!/bin/bash
# Description: Check test coverage for different model types
# Look for test files and methods testing different model types
echo "Searching for test files..."
rg -l "ModelTest" source/tests/
echo "Checking test methods for different model types..."
rg "def test.*(?i)(energy|dos|dipole|polar)" source/tests/
echo "Checking actual usage of build_tf_model with different ret_keys..."
rg "build_tf_model.*ret_key\s*=" source/tests/
Length of output: 1826
Script:
#!/bin/bash
# Let's search more specifically for test files related to each model type
echo "Checking test files for energy models..."
rg -l "test.*ener" source/tests/consistent/model/
rg -C 3 "ret_key.*=.*energy" source/tests/consistent/model/
echo -e "\nChecking test files for DOS models..."
rg -l "test.*dos" source/tests/consistent/model/
rg -C 3 "ret_key.*=.*dos" source/tests/consistent/model/
echo -e "\nChecking test files for dipole models..."
rg -l "test.*dipole" source/tests/consistent/model/
rg -C 3 "ret_key.*=.*dipole" source/tests/consistent/model/
echo -e "\nChecking test files for polar models..."
rg -l "test.*polar" source/tests/consistent/model/
rg -C 3 "ret_key.*=.*polar" source/tests/consistent/model/
Length of output: 3820
Script:
#!/bin/bash
# Let's check the actual test implementations for each model type
echo "Checking test_dos.py implementation..."
cat source/tests/consistent/model/test_dos.py
echo -e "\nChecking test_ener.py implementation..."
cat source/tests/consistent/model/test_ener.py
echo -e "\nSearching for dipole and polar test implementations..."
fd -e py -x grep -l "class.*Test.*Dipole" {} \; source/tests/
fd -e py -x grep -l "class.*Test.*Polar" {} \; source/tests/
Length of output: 19516
Script:
#!/bin/bash
# Let's fix the search for dipole and polar tests with correct fd syntax
echo "Searching for dipole test implementations..."
fd -e py . source/tests/ -x grep -l "class.*Test.*Dipole" {}
echo -e "\nSearching for polar test implementations..."
fd -e py . source/tests/ -x grep -l "class.*Test.*Polar" {}
# Let's also check if there are any test files with dipole or polar in their names
echo -e "\nSearching for test files with dipole/polar in names..."
fd -e py ".*dipole.*|.*polar.*" source/tests/
Length of output: 1817
deepmd/dpmodel/model/model.py (1)
Line range hint 2-39
: LGTM: Well-organized imports
The new imports are logically grouped and necessary for supporting the new model types.
Summary by CodeRabbit
New Features
Bug Fixes
Documentation