Skip to content

Commit

Permalink
Merge branch 'devel' into devel
Browse files Browse the repository at this point in the history
Signed-off-by: Anyang Peng <137014849+anyangml@users.noreply.github.com>
  • Loading branch information
anyangml authored Feb 23, 2024
2 parents 344910a + 94f0ad1 commit 9fe7bbd
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 2 deletions.
5 changes: 3 additions & 2 deletions deepmd/pt/model/descriptor/se_a.py
Original file line number Diff line number Diff line change
Expand Up @@ -459,6 +459,7 @@ def forward(
else:
assert self.filter_layers is not None
dmatrix = dmatrix.view(-1, self.nnei, 4)
dmatrix = dmatrix.to(dtype=self.prec)
nfnl = dmatrix.shape[0]
# pre-allocate a shape to pass jit
xyz_scatter = torch.zeros(
Expand Down Expand Up @@ -489,8 +490,8 @@ def forward(
result = result.view(-1, nloc, self.filter_neuron[-1] * self.axis_neuron)
rot_mat = rot_mat.view([-1, nloc] + list(rot_mat.shape[1:])) # noqa:RUF005
return (
result,
rot_mat,
result.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION),
rot_mat.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION),
None,
None,
sw,
Expand Down
6 changes: 6 additions & 0 deletions source/tests/consistent/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,7 @@ def test_tf_consistent_with_ref(self):
np.testing.assert_allclose(
rr1.ravel(), rr2.ravel(), rtol=self.rtol, atol=self.atol
)
assert rr1.dtype == rr2.dtype, f"{rr1.dtype} != {rr2.dtype}"

def test_tf_self_consistent(self):
"""Test whether TF is self consistent."""
Expand All @@ -276,6 +277,7 @@ def test_tf_self_consistent(self):
np.testing.assert_equal(data1, data2)
for rr1, rr2 in zip(ret1, ret2):
np.testing.assert_allclose(rr1, rr2, rtol=self.rtol, atol=self.atol)
assert rr1.dtype == rr2.dtype, f"{rr1.dtype} != {rr2.dtype}"

def test_dp_consistent_with_ref(self):
"""Test whether DP and reference are consistent."""
Expand All @@ -293,6 +295,7 @@ def test_dp_consistent_with_ref(self):
np.testing.assert_equal(data1, data2)
for rr1, rr2 in zip(ret1, ret2):
np.testing.assert_allclose(rr1, rr2, rtol=self.rtol, atol=self.atol)
assert rr1.dtype == rr2.dtype, f"{rr1.dtype} != {rr2.dtype}"

def test_dp_self_consistent(self):
"""Test whether DP is self consistent."""
Expand All @@ -306,6 +309,7 @@ def test_dp_self_consistent(self):
for rr1, rr2 in zip(ret1, ret2):
if isinstance(rr1, np.ndarray) and isinstance(rr2, np.ndarray):
np.testing.assert_allclose(rr1, rr2, rtol=self.rtol, atol=self.atol)
assert rr1.dtype == rr2.dtype, f"{rr1.dtype} != {rr2.dtype}"
else:
self.assertEqual(rr1, rr2)

Expand All @@ -330,6 +334,7 @@ def test_pt_consistent_with_ref(self):
np.testing.assert_equal(data1, data2)
for rr1, rr2 in zip(ret1, ret2):
np.testing.assert_allclose(rr1, rr2, rtol=self.rtol, atol=self.atol)
assert rr1.dtype == rr2.dtype, f"{rr1.dtype} != {rr2.dtype}"

def test_pt_self_consistent(self):
"""Test whether PT is self consistent."""
Expand All @@ -343,6 +348,7 @@ def test_pt_self_consistent(self):
for rr1, rr2 in zip(ret1, ret2):
if isinstance(rr1, np.ndarray) and isinstance(rr2, np.ndarray):
np.testing.assert_allclose(rr1, rr2, rtol=self.rtol, atol=self.atol)
assert rr1.dtype == rr2.dtype, f"{rr1.dtype} != {rr2.dtype}"
else:
self.assertEqual(rr1, rr2)

Expand Down
37 changes: 37 additions & 0 deletions source/tests/consistent/descriptor/test_se_e2_a.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
(True, False), # resnet_dt
(True, False), # type_one_side
([], [[0, 1]]), # excluded_types
("float32", "float64"), # precision
)
class TestSeA(CommonTest, DescriptorTest, unittest.TestCase):
@property
Expand All @@ -47,6 +48,7 @@ def data(self) -> dict:
resnet_dt,
type_one_side,
excluded_types,
precision,
) = self.param
return {
"sel": [10, 10],
Expand All @@ -57,6 +59,7 @@ def data(self) -> dict:
"resnet_dt": resnet_dt,
"type_one_side": type_one_side,
"exclude_types": excluded_types,
"precision": precision,
"seed": 1145141919810,
}

Expand All @@ -66,6 +69,7 @@ def skip_pt(self) -> bool:
resnet_dt,
type_one_side,
excluded_types,
precision,
) = self.param
return not type_one_side or CommonTest.skip_pt

Expand All @@ -75,6 +79,7 @@ def skip_dp(self) -> bool:
resnet_dt,
type_one_side,
excluded_types,
precision,
) = self.param
return not type_one_side or CommonTest.skip_dp

Expand Down Expand Up @@ -147,3 +152,35 @@ def eval_pt(self, pt_obj: Any) -> Any:

def extract_ret(self, ret: Any, backend) -> Tuple[np.ndarray, ...]:
return (ret[0],)

@property
def rtol(self) -> float:
"""Relative tolerance for comparing the return value."""
(
resnet_dt,
type_one_side,
excluded_types,
precision,
) = self.param
if precision == "float64":
return 1e-10
elif precision == "float32":
return 1e-4
else:
raise ValueError(f"Unknown precision: {precision}")

@property
def atol(self) -> float:
"""Absolute tolerance for comparing the return value."""
(
resnet_dt,
type_one_side,
excluded_types,
precision,
) = self.param
if precision == "float64":
return 1e-10
elif precision == "float32":
return 1e-4
else:
raise ValueError(f"Unknown precision: {precision}")

0 comments on commit 9fe7bbd

Please sign in to comment.