fix(np/pt): explicit dtype and device. #4241
Conversation
📝 Walkthrough: This pull request includes multiple changes across various files, adding explicit dtype and device specifications to array and tensor creation.
Actionable comments posted: 4
🧹 Outside diff range and nitpick comments (27)
deepmd/utils/out_stat.py (1)
Line range hint 122-128: Consider consistent dtype usage throughout statistical operations.
While the array initializations now use explicit dtype, consider also specifying dtype in the statistical operations for complete consistency:

 output_bias[type_i] = (
-    output[mask].mean(axis=0) if output[mask].size > 0 else np.nan
+    output[mask].mean(axis=0, dtype=GLOBAL_NP_FLOAT_PRECISION) if output[mask].size > 0 else np.nan
 )
 output_std[type_i] = (
-    output[mask].std(axis=0) if output[mask].size > 0 else np.nan
+    output[mask].std(axis=0, dtype=GLOBAL_NP_FLOAT_PRECISION) if output[mask].size > 0 else np.nan
 )

source/tests/common/test_auto_batch_size.py (1)
Line range hint 1-140: Consider adding explicit dtypes to array operations.
While the current changes are correct, to better align with the PR objectives of explicit dtype handling, consider adding explicit dtypes to array operations like xp.zeros(), xp.ones(), and xp.zeros_like(). This would prevent potential precision issues and make the tests more robust. Example improvement:

- dd1 = xp.zeros((10000, 2, 1))
+ dd1 = xp.zeros((10000, 2, 1), dtype=GLOBAL_NP_FLOAT_PRECISION)

This would need to be applied to all array operations in the test file. However, since this is a test file and the actual precision is controlled by the imported modules being tested, this is a minor enhancement rather than a critical change.
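For illustration, a minimal standalone sketch of the difference an explicit dtype makes, with plain NumPy standing in for the array-API namespace xp and np.float32 standing in for GLOBAL_NP_FLOAT_PRECISION (both are assumptions, not the test's actual imports):

```python
import numpy as np  # stand-in for the array-API namespace `xp` used in the tests

dd_default = np.zeros((10000, 2, 1))                      # dtype silently defaults to float64
dd_explicit = np.zeros((10000, 2, 1), dtype=np.float32)   # explicit, e.g. GLOBAL_NP_FLOAT_PRECISION
print(dd_default.dtype, dd_explicit.dtype)                # float64 float32
```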
deepmd/pt/model/descriptor/descriptor.py (1)
148-157: Fix indentation for consistency.
The explicit dtype handling is good, but the indentation seems excessive. Consider reducing it to match the file's style.

- base_class.mean.copy_(
-     torch.tensor(
-         mean, device=env.DEVICE, dtype=base_class.mean.dtype
-     )
- )
+ base_class.mean.copy_(
+     torch.tensor(
+         mean, device=env.DEVICE, dtype=base_class.mean.dtype
+     )
+ )

deepmd/pt/model/task/ener.py (1)
129-135: Consider using GLOBAL_NP_FLOAT_PRECISION for consistency.
While the explicit dtype specification is good, consider using GLOBAL_NP_FLOAT_PRECISION instead of np.float64 to maintain consistency with the global precision settings:

- bias_atom_e = np.zeros([self.ntypes], dtype=np.float64)
+ bias_atom_e = np.zeros([self.ntypes], dtype=env.GLOBAL_NP_FLOAT_PRECISION)

deepmd/utils/spin.py (1)
43-43: Consider making type_dtype a class constant.
Good practice using a single variable for a consistent dtype. Consider making it a class constant since it's used across multiple methods:

 class Spin:
+    TYPE_DTYPE = np.int32
     def __init__(
         self,
         use_spin: list[bool],
         virtual_scale: Union[list[float], float],
     ) -> None:
-        type_dtype = np.int32
+        type_dtype = self.TYPE_DTYPE

deepmd/dpmodel/utils/nlist.py (1)
293-295: Consider extracting the dtype assignment for better readability.
The code correctly uses explicit dtype from nbuff for coordinate indices. Consider extracting the dtype to a variable for better readability:

+idx_dtype = nbuff.dtype
-xi = xp.arange(-int(nbuff[0]), int(nbuff[0]) + 1, 1, dtype=nbuff.dtype)
-yi = xp.arange(-int(nbuff[1]), int(nbuff[1]) + 1, 1, dtype=nbuff.dtype)
-zi = xp.arange(-int(nbuff[2]), int(nbuff[2]) + 1, 1, dtype=nbuff.dtype)
+xi = xp.arange(-int(nbuff[0]), int(nbuff[0]) + 1, 1, dtype=idx_dtype)
+yi = xp.arange(-int(nbuff[1]), int(nbuff[1]) + 1, 1, dtype=idx_dtype)
+zi = xp.arange(-int(nbuff[2]), int(nbuff[2]) + 1, 1, dtype=idx_dtype)

deepmd/utils/pair_tab.py (3)
39-39: LGTM! Consider adding a type annotation.
The addition of an explicit data_type attribute is good practice. Consider adding a type annotation for better code clarity:

- self.data_type = np.float64
+ self.data_type: np.dtype = np.float64
172-179: LGTM! Consider extracting common array initialization logic.
Good addition of explicit dtypes. The code ensures type consistency by using self.vdata.dtype. However, there's duplicated array initialization logic between the zero-padding and extrapolation cases. Consider extracting the common initialization logic:

+    def _init_pad_array(self, size: int) -> np.ndarray:
+        """Initialize padding array with proper dtype and grid points."""
+        pad_array = np.zeros((size, self.ncol), dtype=self.vdata.dtype)
+        pad_array[:, 0] = np.linspace(
+            self.rmax + self.hh,
+            self.rmax + self.hh * size,
+            size,
+            dtype=self.vdata.dtype,
+        )
+        return pad_array
+
     def _check_table_upper_boundary(self) -> None:
         # ... existing code ...
         if np.all(upper_val == 0):
             if self.rcut > self.rmax:
-                pad_zero = np.zeros(
-                    (rcut_idx - upper_idx, self.ncol), dtype=self.vdata.dtype
-                )
-                pad_zero[:, 0] = np.linspace(
-                    self.rmax + self.hh,
-                    self.rmax + self.hh * (rcut_idx - upper_idx),
-                    rcut_idx - upper_idx,
-                    dtype=self.vdata.dtype,
-                )
+                pad_zero = self._init_pad_array(rcut_idx - upper_idx)
                 self.vdata = np.concatenate((self.vdata, pad_zero), axis=0)
         else:
             if self.rcut > self.rmax:
-                pad_extrapolation = np.zeros(
-                    (rcut_idx - upper_idx, self.ncol), dtype=self.vdata.dtype
-                )
-                pad_extrapolation[:, 0] = np.linspace(
-                    self.rmax + self.hh,
-                    self.rmax + self.hh * (rcut_idx - upper_idx),
-                    rcut_idx - upper_idx,
-                    dtype=self.vdata.dtype,
-                )
+                pad_extrapolation = self._init_pad_array(rcut_idx - upper_idx)

Also applies to: 193-201
Line range hint 262-274: LGTM! Consider using pre-allocated arrays for performance.
Good addition of explicit dtypes. The code ensures type consistency in the cubic spline interpolation coefficients. For better performance, consider pre-allocating the arrays used in the loop:

 def _make_data(self):
     data = np.zeros(
         [self.ntypes * self.ntypes * 4 * self.nspline], dtype=self.data_type
     )
     stride = 4 * self.nspline
     idx_iter = 0
     xx = self.vdata[:, 0]
+    dtmp = np.zeros(stride, dtype=self.data_type)  # Pre-allocate outside loop
+    dd = np.zeros_like(xx, dtype=self.data_type)  # Pre-allocate outside loop
     for t0 in range(self.ntypes):
         for t1 in range(t0, self.ntypes):
             vv = self.vdata[:, 1 + idx_iter]
             cs = CubicSpline(xx, vv, bc_type="clamped")
-            dd = cs(xx, 1)
+            np.copyto(dd, cs(xx, 1))  # Reuse pre-allocated array
             dd *= self.hh
-            dtmp = np.zeros(stride, dtype=self.data_type)
+            dtmp.fill(0)  # Reset pre-allocated array

deepmd/dpmodel/fitting/polarizability_fitting.py (3)
Line range hint 65-83: Enhance scale parameter validation.
While the scale parameter validation has been improved, consider adding these additional checks for robustness:
- Validate that scale values are non-negative (since they're used as multiplicative factors)
- Consider adding a warning if any scale value is zero, as this would nullify the output
 if self.scale is None:
     self.scale = [1.0 for _ in range(ntypes)]
 else:
     if isinstance(self.scale, list):
         assert (
             len(self.scale) == ntypes
         ), "Scale should be a list of length ntypes."
+        assert all(isinstance(x, (int, float)) for x in self.scale), "All scale values must be numeric"
+        assert all(x >= 0 for x in self.scale), "Scale values must be non-negative"
+        if any(x == 0 for x in self.scale):
+            import warnings
+            warnings.warn("Some scale values are zero, which will nullify the output for those types")
     elif isinstance(self.scale, float):
+        assert self.scale >= 0, "Scale value must be non-negative"
         self.scale = [self.scale for _ in range(ntypes)]
     else:
         raise ValueError(
             "Scale must be a list of float of length ntypes or a float."
         )
Line range hint 306-311: Optimize matrix operations for better performance.
The current implementation of identity matrix creation and bias application could be optimized:

- eye = np.eye(3, dtype=descriptor.dtype)
- eye = np.tile(eye, (nframes, nloc, 1, 1))
- # (nframes, nloc, 3, 3)
- bias = np.expand_dims(bias, axis=-1) * eye
+ # More efficient: create broadcasted identity matrix directly
+ eye = np.broadcast_to(np.eye(3, dtype=descriptor.dtype), (nframes, nloc, 3, 3))
+ bias = bias[..., None, None] * eye  # Use broadcasting instead of expand_dims

This optimization (see the short sketch after the list below):
- Reduces memory allocation by using broadcasting instead of tiling
- Simplifies the bias expansion using modern numpy broadcasting syntax
- Maintains the same numerical precision while improving performance
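As a rough, self-contained illustration of the memory difference (the frame and atom counts here are made up, not taken from the PR):

```python
import numpy as np

nframes, nloc = 4, 128
eye_tiled = np.tile(np.eye(3), (nframes, nloc, 1, 1))         # materializes nframes*nloc copies of the 3x3 identity
eye_view = np.broadcast_to(np.eye(3), (nframes, nloc, 3, 3))  # zero-copy, read-only view of one 3x3 array
print(eye_tiled.shape == eye_view.shape)   # True: same logical shape
print(eye_view.strides)                    # (0, 0, 24, 8): broadcast dimensions use stride 0
bias = np.random.rand(nframes, nloc)
out = bias[..., None, None] * eye_view     # broadcasting works just as with the tiled array
```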
Line range hint 249-311: Enhance documentation and add numerical stability checks.
The matrix operations in the call method would benefit from:
- More detailed shape documentation in docstrings
- Numerical stability checks for the matrix operations
Consider adding these improvements:
- Update the docstring to include shape information for the output:
 def call(self, ...):
     """
     ...
     Returns
     -------
     dict[str, np.ndarray]
         A dictionary containing:
         - 'polarizability': Polarizability tensor of shape (nframes, nloc, 3, 3)
     """
- Add numerical stability checks:
 out = np.einsum("ij,ijk->ijk", out, gr)
+# Check for numerical stability
+if not np.all(np.isfinite(out)):
+    raise ValueError("Non-finite values detected in polarizability calculation")

deepmd/dpmodel/atomic_model/pairtab_atomic_model.py (2)
307-308: LGTM: Consider enhancing the comment for clarity.
The explicit dtype=np.int64 specification is correct. However, the comment could be more descriptive. Consider updating it to explain why int64 is required:

-# index type is int64
+# Use int64 for array indices to prevent integer overflow with large arrays
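To illustrate why a 64-bit index type matters (a standalone sketch, not code from the PR): casting an index value that exceeds the int32 range wraps around silently, which explicit int64 indices avoid.

```python
import numpy as np

big_index = np.int64(3_000_000_000)        # a plausible flat index into a very large array
print(big_index.astype(np.int32))          # wraps to a negative value (-1294967296), no error raised
print(np.arange(3, dtype=np.int64).dtype)  # int64: safe regardless of the platform's default int
```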
Line range hint 207-224: Consider performance optimization opportunities.
The current implementation in the forward_atomic and _pair_tabulated_inter methods creates multiple intermediate arrays and performs repeated indexing operations. Consider these optimizations:
- Pre-allocate arrays for frequently used shapes to reduce memory allocations
- Combine multiple array operations to reduce the number of temporary arrays
- Use vectorized operations where possible instead of element-wise operations
Would you like me to provide specific code examples for these optimizations?
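For illustration only, here is a generic sketch of the kind of vectorization meant above; all names, shapes, and the lookup scheme are hypothetical and not the deepmd implementation:

```python
import numpy as np

nframes, nloc, nnei, nbins = 4, 96, 32, 1000
rr = np.random.rand(nframes, nloc, nnei)   # hypothetical pair distances
table = np.random.rand(nbins)              # hypothetical 1-D lookup table
hh = 1.0 / nbins                           # table spacing

# Vectorized: one cast and one gather over the whole array instead of a Python
# loop per pair, with the intermediate index array created once, explicitly int64.
idx = np.clip((rr / hh).astype(np.int64), 0, nbins - 1)
energies = table[idx]                      # shape (nframes, nloc, nnei)
```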
deepmd/pt/utils/nlist.py (2)
459-467: LGTM: Consistent dtype handling for range tensors.
Good practice to explicitly match the dtype with nbuff_cpu. The tensors are appropriately created on the CPU, as they're small and will be transferred to the GPU later. Consider using a consistent style for device specification:

- -nbuff_cpu[0], nbuff_cpu[0] + 1, 1, device="cpu", dtype=nbuff_cpu.dtype
+ -nbuff_cpu[0], nbuff_cpu[0] + 1, 1, device=torch.device("cpu"), dtype=nbuff_cpu.dtype
151-157: Fix indentation for better readability.
The indentation in this block is inconsistent with the surrounding code. Consider adjusting it to match the file's style:

- [
-     rr,
-     torch.ones(
-         [batch_size, nloc, nsel - nnei], device=rr.device, dtype=rr.dtype
-     )
-     + rcut,
- ],
+ [
+     rr,
+     torch.ones(
+         [batch_size, nloc, nsel - nnei], device=rr.device, dtype=rr.dtype
+     )
+     + rcut,
+ ],

deepmd/pt/model/descriptor/se_r.py (1)
278-283: LGTM: Proper tensor dtype handling in compute_input_stats.
The changes correctly maintain type consistency by explicitly using the destination tensor's dtype. However, consider adding error handling for potential shape mismatches between the computed statistics and the destination tensors.
Consider adding shape validation before the copy operations:
 if not self.set_davg_zero:
+    if mean.shape != self.mean.shape:
+        raise ValueError(f"Shape mismatch: computed mean shape {mean.shape} != destination shape {self.mean.shape}")
     self.mean.copy_(
         torch.tensor(mean, device=env.DEVICE, dtype=self.mean.dtype)
     )
+if stddev.shape != self.stddev.shape:
+    raise ValueError(f"Shape mismatch: computed stddev shape {stddev.shape} != destination shape {self.stddev.shape}")
 self.stddev.copy_(
     torch.tensor(stddev, device=env.DEVICE, dtype=self.stddev.dtype)
 )

deepmd/pt/model/atomic_model/pairtab_atomic_model.py (1)
272-276: LGTM! Explicit device and dtype specifications improve robustness.
The changes correctly specify the device and dtype for the index tensor, preventing potential device mismatches and ensuring compatibility with PyTorch's indexing operations. While the code could be slightly more concise using arange(len(extended_atype)), the current implementation is clear and explicit about the tensor's shape. Consider this alternative for a more concise implementation:

- torch.arange(
-     extended_atype.size(0),
-     device=extended_coord.device,
-     dtype=torch.int64,
- )[:, None, None],
+ torch.arange(len(extended_atype), device=extended_coord.device, dtype=torch.int64)[:, None, None],

deepmd/pt/utils/stat.py (1)
586-592: LGTM! Good improvement in dtype handling.
The addition of explicit dtype checks and matching improves type safety. The assertion ensures that the bias and std arrays maintain consistent dtypes before padding, which prevents potential dtype-related issues during concatenation.
Consider adding a comment explaining why the dtype consistency is important, for better maintainability:
 assert (
     bias_atom_e[kk].dtype is std_atom_e[kk].dtype
-), "bias and std should be of the same dtypes"
+), "bias and std must have matching dtypes to ensure consistent numerical precision during concatenation"

deepmd/pt/model/atomic_model/linear_atomic_model.py (2)
93-95: Consider consistent integer dtype usage.
While explicitly specifying dtype=torch.int32 is good, note that this tensor is later cast to int64 in _sort_rcuts_sels. Consider using int64 consistently to avoid unnecessary type-conversion overhead.

- self.get_model_nsels(), device=env.DEVICE, dtype=torch.int32
+ self.get_model_nsels(), device=env.DEVICE, dtype=torch.int64
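A small standalone sketch (hypothetical values) of why creating index tensors as int64 up front can save a later cast; torch.gather, for example, expects int64 ("long") indices:

```python
import torch

nsel = torch.tensor([1, 3, 0], dtype=torch.int32)       # created as int32
values = torch.randn(2, 5)
idx = nsel.to(torch.int64).unsqueeze(0).expand(2, -1)    # extra cast required before gather
picked = torch.gather(values, 1, idx)
# Creating the tensor with dtype=torch.int64 in the first place skips the .to() call.
```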
Line range hint 1-577: Consider standardizing dtype specifications across all tensor initializations.
While these changes add explicit dtypes to integer tensors, consider adopting a consistent pattern for all tensor initializations in the file. For example:

- Use torch.int64 for indices and counts that might get large
- Use torch.int32 for small integers like atom types
- Document dtype choices in docstrings for public methods that return tensors
This would make the code more maintainable and prevent potential dtype-related issues.
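One possible way to encode such a convention, purely as an illustration; these constants and the helper function are assumptions, not part of the codebase:

```python
import torch

INDEX_DTYPE = torch.int64  # indices and counts that may grow large
TYPE_DTYPE = torch.int32   # small enumerations such as atom types

def make_type_ids(ntypes: int, device: torch.device) -> torch.Tensor:
    """Return atom-type ids.

    Returns
    -------
    torch.Tensor
        Shape (ntypes,), dtype TYPE_DTYPE (documented so callers know what to expect).
    """
    return torch.arange(ntypes, device=device, dtype=TYPE_DTYPE)
```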
deepmd/pt/infer/deep_eval.py (1)
540-546: Maintain consistency with the _eval_model implementation.
The dtype specification is more verbose here compared to the simpler approach used in _eval_model. Apply this change to match that style:

- np.full(
-     np.abs(shape),
-     np.nan,
-     dtype=NP_PRECISION_DICT[
-         RESERVED_PRECISON_DICT[GLOBAL_PT_FLOAT_PRECISION]
-     ],
- )
+ np.full(np.abs(shape), np.nan, dtype=prec)

deepmd/pt/model/descriptor/se_a.py (1)
Line range hint 673-677: Consider consistent dtype handling across all tensor creations.
For consistency with the recent changes, consider making dtype explicit in other tensor creations:

- In the xyz_scatter initialization:

 xyz_scatter = torch.zeros(
     [nfnl, 4, self.filter_neuron[-1]],
-    dtype=self.prec,
+    dtype=self.prec,  # Consider using env.GLOBAL_PT_FLOAT_PRECISION for consistency
     device=extended_coord.device,
 )

- In the return statement, consider explicitly specifying the device:

 return (
-    result.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION),
-    rot_mat.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION),
+    result.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=env.DEVICE),
+    rot_mat.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=env.DEVICE),
     None,
     None,
     sw,
 )

Also applies to: 766-772
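As a side note on the suggested .to() call, a generic sketch (not the deepmd code): dtype and device can be changed in a single call, and PyTorch returns the tensor unchanged when it already matches the request.

```python
import torch

x = torch.randn(4, 3, dtype=torch.float32)
y = x.to(dtype=torch.float64, device="cpu")  # one call converts dtype and device together
z = y.to(dtype=torch.float64, device="cpu")  # already matching: same tensor returned, no copy
print(y.dtype, z is y)                        # torch.float64 True
```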
deepmd/utils/data.py (1)
Line range hint 269-679: Consider using a constant for the index dtype.
Since np.int64 is used consistently across multiple index arrays, consider defining it as a constant at the module level for better maintainability.

+ # At the top of the file with other imports
+ INDEX_DTYPE = np.int64

- idx = np.arange(self.iterator, iterator_1, dtype=np.int64)
+ idx = np.arange(self.iterator, iterator_1, dtype=INDEX_DTYPE)
- idx = np.arange(ntests_, dtype=np.int64)
+ idx = np.arange(ntests_, dtype=INDEX_DTYPE)
- idx = np.arange(natoms, dtype=np.int64)
+ idx = np.arange(natoms, dtype=INDEX_DTYPE)
- natoms_vec = np.zeros(ntypes, dtype=np.int64)
+ natoms_vec = np.zeros(ntypes, dtype=INDEX_DTYPE)
- idx = np.arange(nframes, dtype=np.int64)
+ idx = np.arange(nframes, dtype=INDEX_DTYPE)
- idx = np.arange(natoms, dtype=np.int64)
+ idx = np.arange(natoms, dtype=INDEX_DTYPE)

deepmd/pt/model/descriptor/se_t_tebd.py (1)
701-706: LGTM! Consider a minor improvement for consistency.
The explicit dtype handling is good and aligns with the PR objectives. However, for better consistency with the rest of the codebase, consider using env.GLOBAL_PT_FLOAT_PRECISION instead of self.mean.dtype, since it is used elsewhere in the file.

- torch.tensor(mean, device=env.DEVICE, dtype=self.mean.dtype)
+ torch.tensor(mean, device=env.DEVICE, dtype=env.GLOBAL_PT_FLOAT_PRECISION)
- torch.tensor(stddev, device=env.DEVICE, dtype=self.stddev.dtype)
+ torch.tensor(stddev, device=env.DEVICE, dtype=env.GLOBAL_PT_FLOAT_PRECISION)

deepmd/pt/model/descriptor/se_atten.py (2)
Line range hint 676-677: Consider improving parameter documentation.
The smooth and bias parameters in the constructor could benefit from docstring documentation explaining their purpose and impact on the attention mechanism. Add parameter descriptions to the docstring:

 def __init__(
     self,
     nnei: int,
     embed_dim: int,
     hidden_dim: int,
     num_heads: int = 1,
     dotr: bool = False,
     do_mask: bool = False,
     scaling_factor: float = 1.0,
     normalize: bool = True,
     temperature: Optional[float] = None,
     bias: bool = True,
     smooth: bool = True,
     precision: str = DEFAULT_PRECISION,
     seed: Optional[Union[int, list[int]]] = None,
- ):
-     """Construct a multi-head neighbor-wise attention net."""
+ ):
+     """Construct a multi-head neighbor-wise attention net.
+
+     Parameters
+     ----------
+     bias : bool
+         Whether to include bias terms in the linear transformations
+     smooth : bool
+         Whether to apply smooth attention weights using the switch function
+     """

Also applies to: 678-679
Line range hint 729-729: Document the attnw_shift parameter.
The attnw_shift parameter in the forward method lacks documentation explaining its purpose and how to choose a good value. Add a parameter description to the docstring:

 attnw_shift: float = 20.0,
 ):
     """Compute the multi-head gated self-attention.

     Parameters
     ----------
+    attnw_shift : float, default=20.0
+        Shift value added to attention weights before softmax when using smooth attention.
+        Higher values create sharper attention distributions.
     """
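For intuition only, a rough sketch of how a shift combined with a switch function can replace hard masking in smooth attention; the variable names and the exact formula here are assumptions and may differ from the deepmd implementation:

```python
import torch

def smooth_attn_weights(logits: torch.Tensor, sw: torch.Tensor, attnw_shift: float = 20.0) -> torch.Tensor:
    """logits: (..., nnei, nnei) raw attention scores; sw: (..., nnei) switch function in [0, 1]."""
    # Shift the logits up, scale by the switch value of both neighbors, then shift back:
    # pairs whose switch goes to zero end up near -attnw_shift and receive ~zero softmax
    # weight, and the weights change smoothly as neighbors cross the cutoff (no hard -inf mask).
    shifted = (logits + attnw_shift) * sw[..., None, :] * sw[..., :, None] - attnw_shift
    return torch.softmax(shifted, dim=-1)
```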
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
📒 Files selected for processing (29)
- deepmd/dpmodel/atomic_model/base_atomic_model.py (2 hunks)
- deepmd/dpmodel/atomic_model/linear_atomic_model.py (2 hunks)
- deepmd/dpmodel/atomic_model/pairtab_atomic_model.py (2 hunks)
- deepmd/dpmodel/fitting/general_fitting.py (4 hunks)
- deepmd/dpmodel/fitting/polarizability_fitting.py (1 hunks)
- deepmd/dpmodel/infer/deep_eval.py (2 hunks)
- deepmd/dpmodel/utils/nlist.py (3 hunks)
- deepmd/infer/model_devi.py (2 hunks)
- deepmd/pt/infer/deep_eval.py (5 hunks)
- deepmd/pt/model/atomic_model/linear_atomic_model.py (2 hunks)
- deepmd/pt/model/atomic_model/pairtab_atomic_model.py (1 hunks)
- deepmd/pt/model/descriptor/descriptor.py (1 hunks)
- deepmd/pt/model/descriptor/repformers.py (2 hunks)
- deepmd/pt/model/descriptor/se_a.py (1 hunks)
- deepmd/pt/model/descriptor/se_atten.py (1 hunks)
- deepmd/pt/model/descriptor/se_r.py (2 hunks)
- deepmd/pt/model/descriptor/se_t.py (1 hunks)
- deepmd/pt/model/descriptor/se_t_tebd.py (1 hunks)
- deepmd/pt/model/model/init.py (1 hunks)
- deepmd/pt/model/task/ener.py (1 hunks)
- deepmd/pt/train/training.py (1 hunks)
- deepmd/pt/utils/nlist.py (3 hunks)
- deepmd/pt/utils/stat.py (1 hunks)
- deepmd/utils/data.py (5 hunks)
- deepmd/utils/data_system.py (3 hunks)
- deepmd/utils/out_stat.py (2 hunks)
- deepmd/utils/pair_tab.py (6 hunks)
- deepmd/utils/spin.py (2 hunks)
- source/tests/common/test_auto_batch_size.py (2 hunks)
🧰 Additional context used
🔇 Additional comments (50)
deepmd/utils/out_stat.py (2)
10-12: LGTM: Clean import addition.
The import of GLOBAL_NP_FLOAT_PRECISION is properly placed and follows the project's import style.
122-123: LGTM: Explicit dtype specification added.
The addition of explicit dtype using GLOBAL_NP_FLOAT_PRECISION for statistical arrays improves type safety and precision control.

source/tests/common/test_auto_batch_size.py (2)
4-6: LGTM: Clean import restructuring.
The import restructuring improves code organization while maintaining the same functionality.

83-83: LGTM: Consistent decorator usage.
The decorator now uses the structured import, maintaining the same test behavior.
deepmd/pt/model/descriptor/descriptor.py (1)
Line range hint 1-266: Well-structured code with good type safety.
The code demonstrates excellent practices:
- Comprehensive type hints
- Clear documentation
- No pylint suppressions
- Good use of abstract methods and inheritance
deepmd/pt/model/task/ener.py (1)
Line range hint 1-266: Implementation follows best practices for dtype handling.
The changes effectively implement explicit dtype specifications using environment variables consistently throughout the file. The code:
- Uses environment variables for dtype/device specifications
- Properly handles precision conversion for tensor operations
- Aligns with the PR objectives of removing pylint disable comments
deepmd/utils/spin.py (3)
9-11: LGTM: Import of global precision type.
The addition of the GLOBAL_NP_FLOAT_PRECISION import aligns with the PR's objective of making data types explicit.

51-61: LGTM: Consistent dtype usage in array initializations.
All array initializations now properly specify dtype=type_dtype, improving type safety while maintaining the original logic.

66-68: Verify precision requirements for virtual_scale calculations.
Good addition of the explicit GLOBAL_NP_FLOAT_PRECISION dtype. Since this affects numerical calculations, let's verify the precision requirements across the codebase.

✅ Verification successful
Based on the search results, I can now generate the final response, since I have enough context about the usage of GLOBAL_NP_FLOAT_PRECISION across the codebase.

Verified: Virtual scale precision is consistent with project standards
The addition of the GLOBAL_NP_FLOAT_PRECISION dtype to the virtual_scale initialization is correct and consistent with the project's precision requirements. The search results show that:

- GLOBAL_NP_FLOAT_PRECISION is the standard precision type used throughout the codebase for numpy arrays
- It's consistently used for similar numerical array initializations in other parts of the project
- The change aligns with the project's precision handling patterns, particularly in the spin-related calculations
🏁 Scripts executed
The following scripts were executed for the analysis:
Script:
#!/bin/bash
# Description: Check for other virtual_scale usages and their precision handling
# to ensure consistency across the codebase.
echo "Checking for virtual_scale usage patterns:"
rg -l "virtual_scale" | while read -r file; do
  echo "=== $file ==="
  rg "virtual_scale.*=.*zeros\(" "$file" -A 2
done
echo -e "\nChecking for GLOBAL_NP_FLOAT_PRECISION usage patterns:"
rg "GLOBAL_NP_FLOAT_PRECISION" -A 2

Length of output: 58228
deepmd/dpmodel/atomic_model/base_atomic_model.py (2)
20-22: LGTM! Clean import of GLOBAL_NP_FLOAT_PRECISION.
The import is properly placed and aligns with the PR objective of making dtype specifications explicit.

60-65: LGTM! Explicit dtype specification added.
The addition of explicit dtype using GLOBAL_NP_FLOAT_PRECISION for both arrays is correct and aligns with the PR objectives.
Let's verify the consistency of precision across the model:
✅ Verification successful
Based on the search results, I can now generate the final response:
Precision usage is consistent across the codebase
The verification confirms that:
- GLOBAL_NP_FLOAT_PRECISION is consistently used across the codebase for numpy array initializations where precision matters
- In linear_atomic_model.py, the same precision is used for array initialization
- Other modules (descriptors, fitting, inference) also consistently use GLOBAL_NP_FLOAT_PRECISION
- Where precision is not specified (e.g., zeros_like, ones_like), it inherits from the input array's dtype, which maintains consistency
🏁 Scripts executed
The following scripts were executed for the analysis:
Script:
#!/bin/bash # Description: Check for other numpy array initializations in the model to ensure consistent precision # Search for numpy array initializations rg -A 2 "np\.(zeros|ones|array|empty)" deepmd/dpmodel/atomic_model/ # Search for other uses of GLOBAL_NP_FLOAT_PRECISION rg "GLOBAL_NP_FLOAT_PRECISION" deepmd/dpmodel/Length of output: 4815
deepmd/pt/model/model/__init__.py (1)
103-103: LGTM! Explicit dtype improves type safety.
The addition of dtype=bool to np.full() is a good practice that:
- Prevents potential type coercion issues
- Makes the code intention clearer
- Aligns with the PR objective of adding explicit dtypes
deepmd/dpmodel/utils/nlist.py (3)
Line range hint 134-139: LGTM: Type-safe array concatenation.
The concatenation operations correctly inherit dtypes from their source arrays rr and nlist, maintaining type consistency throughout the operation.

280-281: LGTM: Explicit int64 dtype for index arrays.
Good practice using an explicit int64 dtype for index arrays. This is crucial for handling large molecular systems and prevents potential integer overflow issues.

Line range hint 1-320: Verify removal of all pylint dtype disable comments.
The changes align well with the PR objective of making dtypes explicit. Let's verify that all pylint disable comments for dtype have been removed from this file.
✅ Verification successful
Let me broaden the search to check for any pylint disable comments in the file to ensure we haven't missed anything dtype-related that might be using a different format.
No dtype-related disable comments found in the file
The verification confirms that there are no pylint disable comments or noqa comments related to dtype in the file. The changes successfully make dtypes explicit without requiring any suppression of linter warnings.
🏁 Scripts executed
The following scripts were executed for the analysis:
Script:
#!/bin/bash # Check for any remaining pylint disable comments related to dtype rg -i "pylint: disable=(no-explicit-dtype|no-explicit-device)" "deepmd/dpmodel/utils/nlist.py"Length of output: 94
Script:
#!/bin/bash # Check for any pylint disable comments in the file rg -i "pylint: disable" "deepmd/dpmodel/utils/nlist.py" # Also check with # noqa format which is sometimes used rg -i "# noqa" "deepmd/dpmodel/utils/nlist.py"Length of output: 102
deepmd/utils/pair_tab.py (1)
60-60: LGTM! Explicit dtype in data loading.
Good addition of explicit dtype in loadtxt. This ensures consistent precision throughout the computation pipeline.

deepmd/dpmodel/infer/deep_eval.py (3)
27-29: LGTM: Clean import addition.
The addition of the GLOBAL_NP_FLOAT_PRECISION import aligns with the PR objective of making dtype explicit throughout the codebase.

346-346: LGTM: Added explicit dtype to np.full.
Good addition of explicit dtype using GLOBAL_NP_FLOAT_PRECISION for consistent numerical precision handling.

351-352: Consider revisiting the "hacky" implementation while maintaining explicit dtype.
While the addition of explicit dtype is good, the comment suggests this is a temporary solution. Consider:
- Documenting why np.abs(shape) is necessary
- Creating a ticket to properly handle this edge case
- Adding tests to verify the behavior
The dtype change itself is approved, but the underlying implementation might need attention.
Let's check for similar patterns in the codebase:
deepmd/dpmodel/atomic_model/pairtab_atomic_model.py (1)
207-210
: LGTM: Explicit dtype for array indices enhances type safetyThe explicit specification of
dtype=np.int64
for array indices is a good practice that prevents potential integer overflow issues when dealing with large arrays.deepmd/dpmodel/atomic_model/linear_atomic_model.py (3)
15-17
: LGTM: Import of global precision constantThe addition of
GLOBAL_NP_FLOAT_PRECISION
import aligns with the PR objective of making dtype explicit and ensures consistent float precision across the codebase.
292-296
: LGTM: Explicit dtype specification for weightsThe weights are now correctly initialized with explicit dtype using
GLOBAL_NP_FLOAT_PRECISION
, ensuring consistent numerical precision across the interface. The implementation efficiently uses broadcasting and numpy operations.
Line range hint
379-379
: Verify version compatibility handlingThe version check has been updated from 1 to 2, which implies a breaking change in the serialization format. This could affect backward compatibility with existing serialized models.
Let's check for any existing serialized models and documentation about this version change:
✅ Verification successful
Let me analyze the results. I notice that:
- The version check update from 1 to 2 is consistently applied across many files in the codebase
- Most files use
check_version_compatibility(data.pop("@version", 1), 2, 1)
, but atomic models specifically usecheck_version_compatibility(data.pop("@version", 1), 2, 2)
- There are compatibility test inputs in
source/tests/tf/compat_inputs/
with different versionsLet me check the compatibility test files and version handling implementation:
Let me check the implementation of version compatibility and any related documentation:
Based on the comprehensive analysis of the results, I can now provide a final response:
Version compatibility change is part of a coordinated update across the codebase
The version change from 1 to 2 in
linear_atomic_model.py
is part of a systematic version update across the entire codebase. The evidence shows:
- The version check has been consistently updated from 1 to 2 across multiple components, with atomic models specifically using
check_version_compatibility(data.pop("@version", 1), 2, 2)
for stricter compatibility.- The version compatibility implementation in
deepmd/utils/version.py
properly handles version ranges and provides clear error messages.- The codebase maintains backward compatibility by:
- Using default version 1 when version is not specified (
@version, 1
)- Setting minimum supported version to 1 in most components
- Preserving compatibility test files for different versions in
source/tests/tf/compat_inputs/
🏁 Scripts executed
The following scripts were executed for the analysis:
Script:
#!/bin/bash # Search for documentation about version changes echo "Checking for version documentation..." rg -l "version.*2.*breaking.*change|migration.*guide" --type md # Search for potential serialized model files echo "Checking for potential serialized model files..." fd -e pkl -e json -e yaml -e h5 -e pb # Search for other version checks in the codebase echo "Checking for other version checks..." rg "check_version_compatibility.*1.*2"Length of output: 7372
Script:
#!/bin/bash # Check compatibility test files content echo "Checking compatibility test files..." rg -A 5 "@version" source/tests/tf/compat_inputs/ # Check version compatibility implementation echo "Checking version compatibility implementation..." rg -A 5 "def check_version_compatibility"Length of output: 644
Script:
#!/bin/bash # Check the version compatibility implementation echo "Checking version compatibility implementation..." cat deepmd/utils/version.py # Check for version-related documentation echo "Checking for version documentation..." fd README.md | xargs rg -A 5 "version|compatibility" # Check commit history for version changes echo "Checking commit history..." git log -p --all | rg -A 5 "check_version_compatibility"Length of output: 61586
deepmd/infer/model_devi.py (2)
331-331
: LGTM! Explicit dtype specification improves code clarity.The explicit
dtype=np.int64
specification for the frame indices array ensures consistent behavior across different numpy versions and platforms.
505-505
: LGTM! Consistent dtype specification.The explicit
dtype=np.int64
specification here maintains consistency with the similar change incalc_model_devi
, ensuring uniform integer type usage for frame indices throughout the codebase.deepmd/pt/utils/nlist.py (2)
437-440
: LGTM: Explicit int64 dtype for index operationsThe explicit specification of
dtype=torch.int64
for index operations is a good practice that ensures type safety and prevents potential overflow issues with large arrays.
Line range hint
1-495
: Verify removal of implicit dtype usageLet's verify that we haven't missed any implicit dtype usage in the file.
✅ Verification successful
Let me verify the remaining instances with a more focused search to ensure we catch all potential cases.
All tensor creation operations have explicit dtype specifications
After reviewing the search results, I can confirm that all tensor creation operations (
torch.zeros
,torch.ones
,torch.arange
) in the file have explicit dtype specifications. Each instance properly inherits dtype from input tensors or explicitly specifies it. For example:
torch.zeros(1, dtype=coord.dtype, ...)
torch.ones([...], dtype=rr.dtype, ...)
torch.arange(..., dtype=nbuff_cpu.dtype)
🏁 Scripts executed
The following scripts were executed for the analysis:
Script:
#!/bin/bash # Description: Check for potential implicit dtype usage in tensor creation # Test: Search for tensor creation without explicit dtype. # Expect: No matches for tensor creation without dtype specification. # Look for torch.tensor, torch.zeros, torch.ones, torch.full without dtype rg "torch\.(tensor|zeros|ones|full)\([^)]*\)" -A 1 "deepmd/pt/utils/nlist.py" | \ rg -v "dtype=" # Look for arange without dtype rg "torch\.arange\([^)]*\)" -A 1 "deepmd/pt/utils/nlist.py" | \ rg -v "dtype="Length of output: 217
Script:
#!/bin/bash # Get more context around the potential implicit dtype usage rg "torch\.(zeros|ones|full|tensor|arange)" -B 2 -A 2 "deepmd/pt/utils/nlist.py"Length of output: 1531
deepmd/pt/model/descriptor/se_r.py (1)
210-219
: LGTM: Proper tensor dtype handling in share_paramsThe changes correctly maintain type consistency by explicitly using the source tensor's dtype during tensor copying operations. This aligns with the PR objectives of making dtype specifications explicit.
deepmd/dpmodel/fitting/general_fitting.py (5)
17-17
: LGTM: Import added for explicit dtype handling.The addition of
PRECISION_DICT
import aligns with the PR objective of making data types explicit.
137-137
: LGTM: Precision initialization added.The precision value is correctly retrieved from
PRECISION_DICT
using a case-insensitive lookup, which will be used for array initializations.
151-154
: LGTM: Explicit dtype added to bias_atom_e initialization.The array initialization now uses explicit dtype from the precision value, maintaining type safety while preserving the existing shape checks.
156-157
: LGTM: Explicit dtype added to parameter initializations.The fparam and aparam arrays are now initialized with explicit dtype from the precision value, ensuring type consistency throughout the model.
Also applies to: 161-162
415-415
: Verify type consistency between descriptor and model precision.While the output array correctly inherits its dtype from the input descriptor, we should verify that this doesn't create type mismatches with the model's precision.
✅ Verification successful
Let me gather more specific information about the dtype handling in the relevant files.
This will help us understand:
- How dtype is handled in the fitting code
- How dtype is managed in the descriptor base class
- How model precision is handled during model creation
Based on the gathered information, I can now provide the final response:
No type consistency issues with descriptor dtype inheritance
The code is correctly handling dtype consistency:
- The model has a well-defined precision handling system using
PRECISION_DICT
and global precision settings- Input data is automatically cast to the model's precision in
make_model.py
- The descriptor dtype is propagated through the fitting network, maintaining consistency with the model's precision settings
- Output data is properly converted back to the input precision when needed
🏁 Scripts executed
The following scripts were executed for the analysis:
Script:
#!/bin/bash # Description: Check for potential dtype mismatches between descriptor and model precision # Look for places where descriptor dtype is set or modified # Search for descriptor dtype assignments echo "Searching for descriptor dtype assignments..." rg -l "dtype.*descriptor" # Search for precision conversions echo "Searching for precision/dtype conversions..." rg -l "astype.*prec|to.*dtype"Length of output: 4922
Script:
#!/bin/bash # Check dtype handling in general_fitting.py and related descriptor code # Check general_fitting.py for dtype handling echo "Checking dtype handling in general_fitting.py:" rg -A 5 -B 5 "dtype.*descriptor" deepmd/dpmodel/fitting/general_fitting.py # Check descriptor base class for dtype handling echo -e "\nChecking descriptor base class for dtype handling:" rg -A 5 -B 5 "dtype" deepmd/tf/descriptor/descriptor.py # Check model precision handling echo -e "\nChecking model precision handling:" rg -A 5 -B 5 "precision|dtype" deepmd/dpmodel/model/make_model.pyLength of output: 5254
deepmd/pt/model/atomic_model/linear_atomic_model.py (1)
304-308
: LGTM! Good explicit dtype specification with clear justificationThe use of
torch.int32
is well-justified by the comment and appropriate for mapping atom types, as the number of atom types in molecular systems is typically small.deepmd/pt/model/descriptor/repformers.py (2)
469-478
: LGTM: Explicit dtype and device specifications for border_op parameters.The changes correctly specify
torch.int32
dtype and device for size parameters passed to border_op, which aligns with C++ integer expectations.
541-546
: LGTM: Proper device and dtype specifications for statistical tensors.The changes correctly specify the device and maintain dtype consistency with existing tensors when copying mean and stddev values.
deepmd/pt/infer/deep_eval.py (3)
397-397: LGTM: Centralized precision handling.
Good practice to define precision once and reuse it consistently throughout the method.

Line range hint 407-421: LGTM: Explicit dtype and device specification.
Good improvements in type safety by making dtype and device explicit in tensor creation.

460-460: ⚠️ Potential issue: Verify np.nan compatibility with all possible dtypes.
Using np.nan with the prec dtype might cause issues if prec is an integer dtype. Consider adding a type check or using a type-specific sentinel value. Consider this safer alternative:

-np.full(np.abs(shape), np.nan, dtype=prec)
+np.full(np.abs(shape), np.nan if np.issubdtype(prec, np.floating) else -1, dtype=prec)
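A quick standalone check of the issue being raised here (illustrative only):

```python
import numpy as np

print(np.full(3, np.nan, dtype=np.float64))    # [nan nan nan]
try:
    np.full(3, np.nan, dtype=np.int64)         # NaN has no integer representation
except ValueError as err:
    print(err)                                  # e.g. "cannot convert float NaN to integer"
fill = np.nan if np.issubdtype(np.int64, np.floating) else -1
print(np.full(3, fill, dtype=np.int64))        # [-1 -1 -1]
```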
deepmd/pt/model/descriptor/se_a.py (2)

559-564: LGTM! Explicit dtype handling improves code quality.
The changes properly specify the device and dtype when creating tensors, which aligns with the PR objectives of making data types explicit. The implementation follows PyTorch best practices for tensor creation.
559-564
: Verify precision consistency across tensor operations.The changes maintain numerical precision by using the same dtype as the target tensors (
self.mean.dtype
andself.stddev.dtype
). Let's verify that all tensor operations maintain consistent precision:✅ Verification successful
Based on the search results, I can now generate the final response as I have sufficient information about precision handling across the codebase.
Precision consistency is maintained across tensor operations
The code changes maintain numerical precision consistency because:
- All tensor initializations consistently use either:
- The precision from the instance (
self.prec
)- The target tensor's dtype (
self.mean.dtype
,self.stddev.dtype
)- The global precision setting (
env.GLOBAL_PT_FLOAT_PRECISION
)- Explicit dtype conversions via
.to(dtype=...)
are used when needed to ensure precision matching- The pattern of precision handling is consistent across all descriptor implementations (se_a, se_r, se_t, etc.)
🏁 Scripts executed
The following scripts were executed for the analysis:
Script:
#!/bin/bash # Description: Check for potential precision mismatches in tensor operations # Test: Look for tensor creation and conversion patterns rg -A 2 "torch\.(tensor|zeros|ones|empty)" deepmd/pt/model/descriptor/ rg -A 2 "\.to\(dtype=" deepmd/pt/model/descriptor/Length of output: 13016
deepmd/pt/model/descriptor/se_t.py (1)
609-614
: LGTM! Explicit dtype handling improves type safety.The changes properly specify the dtype when copying tensor data, which:
- Ensures type consistency between source and destination tensors
- Prevents potential silent type conversions
- Aligns with the PR objective of making dtypes explicit
deepmd/utils/data.py (6)
269-269
: LGTM: Explicit dtype specification for batch indicesThe change to use
np.int64
for batch indices is appropriate and aligns with the PR objectives.
293-293
: LGTM: Explicit dtype specification for test indicesThe change to use
np.int64
for test indices is appropriate and aligns with the PR objectives.
381-381
: LGTM: Explicit dtype specification for atom indicesThe change to use
np.int64
for atom indices in the selection map is appropriate and aligns with the PR objectives.
388-388
: LGTM: Explicit dtype specification for atom count arrayThe change to use
np.int64
for the atom count array is appropriate and aligns with the PR objectives.
438-438
: LGTM: Explicit dtype specification for shuffle indicesThe change to use
np.int64
for shuffle indices is appropriate and aligns with the PR objectives.
679-679
: LGTM: Explicit dtype specification for atom mapping indicesThe change to use
np.int64
for atom mapping indices is appropriate and aligns with the PR objectives.deepmd/pt/train/training.py (1)
944-946
: LGTM! Explicit dtype specification improves type safety.The addition of
dtype=np.int32
for the model index array is a good practice as it:
- Makes the data type explicit, improving code clarity
- Uses an appropriate integer size for model indices
- Aligns with the PR's objective of cleaning up implicit data types
deepmd/utils/data_system.py (1)
462-464
: Consider usingnp.intp
for array indices to ensure platform compatibility
Actionable comments posted: 2
🧹 Outside diff range and nitpick comments (3)
source/tests/common/test_common.py (1)
16-17: Add a docstring and use a more descriptive test method name.
The test class and method would benefit from documentation explaining their purpose and the scenarios being tested.

 class TestGetXPPrecision(unittest.TestCase):
-    def test(self):
+    """Test suite for get_xp_precision function."""
+    def test_precision_type_mapping(self):
+        """Test mapping of various precision string inputs to their corresponding types."""

deepmd/dpmodel/common.py (1)
53-58: Enhance function documentation with type hints and return value.
The docstring should include parameter descriptions and return type information for better clarity.

 def get_xp_precision(
     xp: Any,
     precision: str,
 ):
-    """Get the precision from the API compatible namespace."""
+    """Get the precision from the API compatible namespace.
+
+    Parameters
+    ----------
+    xp : Any
+        The array API compatible namespace (e.g., numpy, torch, jax.numpy)
+    precision : str
+        The precision string identifier
+
+    Returns
+    -------
+    Any
+        The corresponding precision type from the xp namespace
+
+    Raises
+    ------
+    ValueError
+        If the precision string is not supported for the given namespace
+    """

deepmd/dpmodel/fitting/general_fitting.py (1)
155-157: Use consistent precision for bias_atom_e.
In lines 155-157 and 160, GLOBAL_NP_FLOAT_PRECISION is used as the dtype for bias_atom_e. To maintain consistency with the specified precision of the class, consider using self.prec instead. Apply this diff:

 if bias_atom_e is None:
     self.bias_atom_e = np.zeros(
-        [self.ntypes, net_dim_out], dtype=GLOBAL_NP_FLOAT_PRECISION
+        [self.ntypes, net_dim_out], dtype=self.prec
     )
 else:
     assert bias_atom_e.shape == (self.ntypes, net_dim_out)
-    self.bias_atom_e = bias_atom_e.astype(GLOBAL_NP_FLOAT_PRECISION)
+    self.bias_atom_e = bias_atom_e.astype(self.prec)

Also applies to: 160-160
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
📒 Files selected for processing (6)
- deepmd/dpmodel/common.py (1 hunks)
- deepmd/dpmodel/fitting/general_fitting.py (6 hunks)
- deepmd/dpmodel/utils/nlist.py (3 hunks)
- deepmd/pt/model/task/fitting.py (1 hunks)
- deepmd/pt/utils/nlist.py (3 hunks)
- source/tests/common/test_common.py (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
- deepmd/dpmodel/utils/nlist.py
🧰 Additional context used
🔇 Additional comments (11)
source/tests/common/test_common.py (2)
1-14: LGTM! Well-organized imports following best practices.
The imports are logically grouped and include all necessary dependencies for the test implementation.

8-10: Verify the get_xp_precision implementation.
Let's verify the implementation to ensure our test recommendations align with the actual function behavior.
✅ Verification successful
Based on the implementation found, I can now generate the final response since I have sufficient information about the
get_xp_precision
function.Function implementation verified - maps precision strings to data types
The
get_xp_precision
function is correctly implemented indeepmd/dpmodel/common.py
. It:
- Takes two parameters:
xp
(API namespace) andprecision
(string)- Maps precision string identifiers to their corresponding data types
- Supports multiple precision types including:
- Floating point: float16/half, float32/single, float64/double, bfloat16
- Integer: int32, int64
- Boolean: bool
- Special cases: 'default' and 'global' that use precision dictionaries
- Throws ValueError for unsupported precision values
🏁 Scripts executed
The following scripts were executed for the analysis:
Script:
#!/bin/bash # Search for get_xp_precision implementation ast-grep --pattern 'def get_xp_precision($_, $_): $$$ 'Length of output: 1677
deepmd/dpmodel/common.py (2)
70-73: Consider optimizing recursive calls for default and global precision.
The recursive calls for the "default" and "global" cases could be simplified by directly accessing the final precision type. However, the current implementation maintains better code readability and flexibility for future changes.

53-78: Implementation looks good!
The function effectively maps precision strings to their corresponding types (a brief sketch follows the list below) while:
- Handling all precision types defined in PRECISION_DICT
- Providing proper error handling for unsupported precisions
- Maintaining special handling for bfloat16
- Integrating well with existing constants and assertions
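A hedged sketch of such a string-to-dtype mapper, for illustration only; the real get_xp_precision also covers bfloat16 and the 'default'/'global' aliases, and its exact structure may differ:

```python
import numpy as np

def precision_from_string(xp, precision: str):
    """Map a precision name onto the corresponding dtype of an array-API namespace `xp`."""
    mapping = {
        "float16": xp.float16, "half": xp.float16,
        "float32": xp.float32, "single": xp.float32,
        "float64": xp.float64, "double": xp.float64,
        "int32": xp.int32, "int64": xp.int64, "bool": bool,
    }
    try:
        return mapping[precision]
    except KeyError:
        raise ValueError(f"unsupported precision {precision!r}") from None

print(precision_from_string(np, "single"))  # <class 'numpy.float32'>
```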
Let's verify the usage of this new function across the codebase:
✅ Verification successful
Let me analyze the results we have. The
rg
search shows thatget_xp_precision
is:
- Properly tested in
source/tests/common/test_common.py
with all supported precision types- Used in
deepmd/dpmodel/fitting/general_fitting.py
for array creation and type conversion- Has recursive calls within itself for handling "default" and "global" precisions
Let's verify if there are any other potential places where dtype is specified directly that should be using this function instead:
This will help us ensure the function is being used consistently across the codebase where precision types are needed.
Based on the search results, I can now generate the final response:
The
get_xp_precision
function implementation is correct and properly usedThe function is used in two key places:
- In tests (
source/tests/common/test_common.py
) where all precision types are properly verified- In the fitting module (
deepmd/dpmodel/fitting/general_fitting.py
) where it's used for array creation and type conversionThe search for direct dtype specifications shows that most direct dtype usage is in test files or for specific cases like:
- Array initialization in tests
- Integer array creation for indices/masks
- Type conversion for numpy arrays in data loading/processing
The function is used where precision control is needed for model operations, while direct dtype specifications are appropriately used for infrastructure code (tests, data loading, etc.) where specific types are required.
🏁 Scripts executed
The following scripts were executed for the analysis:
Script:
#!/bin/bash # Description: Check the usage of get_xp_precision function # and ensure it's being used consistently with dtype specifications # Look for direct function calls echo "Direct function calls:" rg "get_xp_precision\(" -A 2 # Look for related dtype specifications that might need to use this function echo "\nPotential places where the function could be used:" rg "dtype\s*=\s*['\"](?:float|int|bool)" --type pythonLength of output: 2091
Script:
#!/bin/bash # Search for direct dtype specifications in python files echo "Direct dtype specifications:" rg "dtype\s*=\s*(?:np\.|xp\.)?(float|int|bool)" --type-add 'python:*.py' --type python -A 2Length of output: 79376
deepmd/pt/utils/nlist.py (3)
151-157: LGTM: Explicit dtype in tensor concatenation.
The explicit dtype specification in tensor creation ensures type consistency during concatenation operations.

437-440: LGTM: Appropriate use of int64 for index tensors.
Using torch.int64 for index tensors is the correct choice, as it ensures safe handling of large arrays and follows PyTorch best practices.

455-467: Consider adding error handling for device transfer operations.
While the explicit dtype handling is correct, the device transfer operation at line 472 (xyz.to(device=device, non_blocking=True)) could benefit from error handling to gracefully handle CUDA out-of-memory or device transfer failures. Consider wrapping the device transfer in a try-except block:

- xyz = xyz.to(device=device, non_blocking=True)
+ try:
+     xyz = xyz.to(device=device, non_blocking=True)
+ except RuntimeError as e:
+     if "out of memory" in str(e):
+         raise RuntimeError(
+             f"Failed to transfer coordinate indices to device. "
+             f"Consider reducing the cutoff radius (current: {rcut}) "
+             f"or using a smaller system size."
+         ) from e
+     raise

Let's verify the current error handling practices in the codebase:
deepmd/pt/model/task/fitting.py (1)
180-182: LGTM! Proper use of explicit dtype and device.
The initialization of the bias_atom_e tensor now correctly uses env.GLOBAL_PT_FLOAT_PRECISION for the explicit dtype and properly specifies the device. This change aligns with the PR objectives and follows PyTorch best practices.

deepmd/dpmodel/fitting/general_fitting.py (3)
162-163: Consistent precision usage for parameter initialization.
Great job using self.prec as the dtype for initializing fparam_avg, fparam_inv_std, aparam_avg, and aparam_inv_std. This ensures consistency with the specified precision throughout the class.

Also applies to: 167-168
421-423: Verify the handling of precision in get_xp_precision.
In line 422, get_xp_precision(xp, self.precision) is used to determine the dtype for initializing outs. Ensure that get_xp_precision correctly handles the value of self.precision, especially if custom or unexpected precision strings are provided. Run the following script to check the accepted precision values in get_xp_precision:

✅ Verification successful
Let me gather more information about how
self.precision
is initialized and used in the codebase to ensure it matches the supported values.
Let me search for the definition of DEFAULT_PRECISION to complete the verification.
Based on the gathered information, I can now provide the final response:
The precision handling in
get_xp_precision
is properly implemented and safeThe implementation in
get_xp_precision
correctly handles all supported precision values:
- "float16"/"half" → xp.float16
- "float32"/"single" → xp.float32
- "float64"/"double" → xp.float64
The
precision
parameter is consistently initialized across the codebase with a default value of "float64" (DEFAULT_PRECISION), and the values passed to it are strictly controlled through type hints (str) and default parameters. This ensures that only valid precision strings reach theget_xp_precision
function.🏁 Scripts executed
The following scripts were executed for the analysis:
Script:
#!/bin/bash # Description: Verify that 'get_xp_precision' supports the provided precision strings # Test: Search for the implementation of 'get_xp_precision' and check how it handles 'self.precision' rg -A 10 'def get_xp_precision' deepmd/dpmodel/common.pyLength of output: 427
Script:
#!/bin/bash # Search for where precision is set as a class attribute ast-grep --pattern 'self.precision = $_' # Search for where precision is passed as an argument rg -B 3 'precision: str'Length of output: 28401
Script:
#!/bin/bash # Search for DEFAULT_PRECISION definition rg "DEFAULT_PRECISION ="Length of output: 134
449-449: Confirm 'global' is a valid argument for get_xp_precision.
In line 449, get_xp_precision(xp, "global") is used to cast outs. Verify that "global" is an accepted value for the precision parameter in get_xp_precision and that it returns the correct global precision. Run the following script to confirm support for the "global" precision argument:

✅ Verification successful
Let me search for any other references to "global" precision in the codebase to ensure we have a complete understanding.
Let me search for the definition of
get_xp_precision
to verify how it handles the "global" argument.
Let me search for more context around the
get_xp_precision
function to see the complete implementation.
Based on the code analysis, I can now provide a definitive response about the use of "global" as an argument to get_xp_precision.

"global" is a valid argument for get_xp_precision
The implementation shows that when precision == "global", the function returns the precision corresponding to GLOBAL_NP_FLOAT_PRECISION by recursively calling itself with the resolved precision type. This means "global" is a valid argument that will return the globally configured precision for the system. The code at line 449 is correct and safe to use. When "global" is passed to get_xp_precision, it will properly resolve to the system's global precision setting through GLOBAL_NP_FLOAT_PRECISION.
The following scripts were executed for the analysis:
Script:
#!/bin/bash # Description: Check if 'get_xp_precision' supports 'global' as a precision argument # Test: Inspect the definition of 'get_xp_precision' to see if 'global' is handled appropriately rg -A 10 'def get_xp_precision' deepmd/dpmodel/common.pyLength of output: 427
Script:
#!/bin/bash # Search for any references to "global" precision in the codebase rg -B 2 -A 2 '"global".*precision' # Also search for any configuration or constant definitions that might define valid precision values rg -B 2 -A 2 'PRECISION|precision.*='Length of output: 477662
Script:
#!/bin/bash # Search for get_xp_precision definition rg -A 10 'def get_xp_precision'Length of output: 677
Script:
#!/bin/bash # Search for more context around get_xp_precision function rg -B 5 -A 20 'def get_xp_precision'Length of output: 1564
Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com> Signed-off-by: Han Wang <92130845+wanghan-iapcm@users.noreply.github.com>
for more information, see https://pre-commit.ci
Actionable comments posted: 1
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
📒 Files selected for processing (2)
- deepmd/dpmodel/fitting/general_fitting.py (6 hunks)
- source/tests/common/test_common.py (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
- source/tests/common/test_common.py
🧰 Additional context used
🔇 Additional comments (8)
deepmd/dpmodel/fitting/general_fitting.py (8)
17-17: Appropriate import of PRECISION_DICT.
The PRECISION_DICT is now imported, which is necessary for precision handling.

21-21: Imported get_xp_precision for precision management.
The function get_xp_precision is imported to facilitate consistent precision handling across computations.

32-34: Imported GLOBAL_NP_FLOAT_PRECISION.
Importing GLOBAL_NP_FLOAT_PRECISION from deepmd.env ensures access to the global numpy float precision setting.
141-145: Added validation for the precision parameter.
The code now validates the precision parameter against PRECISION_DICT, raising a ValueError if an unsupported precision is provided. This prevents potential KeyError exceptions and improves robustness.

166-167: Initialized fparam_avg and fparam_inv_std with the specified precision.
The arrays fparam_avg and fparam_inv_std are now initialized with dtype=self.prec, ensuring consistency with the specified precision.

171-172: Initialized aparam_avg and aparam_inv_std with the specified precision.
The arrays aparam_avg and aparam_inv_std are initialized with dtype=self.prec, maintaining precision consistency.
425-427: Initialized the outs array with the specified precision.
The outs array is initialized with the appropriate precision using get_xp_precision(xp, self.precision), ensuring consistent precision in computations.

453-453: Casting outputs to global precision.
The outputs are cast to global precision before returning, standardizing the output precision regardless of the internal computation precision.
Codecov Report
Attention: Patch coverage is
Additional details and impacted files

@@ Coverage Diff @@
## devel #4241 +/- ##
==========================================
+ Coverage 84.55% 84.57% +0.02%
==========================================
Files 537 547 +10
Lines 51238 51360 +122
Branches 3047 3047
==========================================
+ Hits 43323 43440 +117
- Misses 6969 6973 +4
- Partials 946 947 +1

☔ View full report in Codecov by Sentry.
fix: #4016
Summary by CodeRabbit
Release Notes
New Features
Bug Fixes
Documentation
Refactor
Tests