Skip to content

Commit 8215320

Browse files
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent 048612e commit 8215320

File tree

2 files changed

+26
-11
lines changed

2 files changed

+26
-11
lines changed

deepmd/tf/nvnmd/entrypoints/mapt.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -457,7 +457,9 @@ def run_u2s(self):
457457
keys = list(dic_ph.keys())
458458
vals = list(dic_ph.values())
459459

460-
u = N2 * np.reshape(np.arange(0, N + 1, dtype=GLOBAL_NP_FLOAT_PRECISION) / N, [-1, 1])
460+
u = N2 * np.reshape(
461+
np.arange(0, N + 1, dtype=GLOBAL_NP_FLOAT_PRECISION) / N, [-1, 1]
462+
)
461463
res_lst = run_sess(sess, vals, feed_dict={dic_ph["u"]: u})
462464
res_dic = dict(zip(keys, res_lst))
463465

@@ -563,7 +565,13 @@ def run_s2g(self):
563565
keys = list(dic_ph.keys())
564566
vals = list(dic_ph.values())
565567

566-
s = N2 * np.reshape(np.arange(0, N + 1, dtype=GLOBAL_NP_FLOAT_PRECISION) / N, [-1, 1]) + smin_
568+
s = (
569+
N2
570+
* np.reshape(
571+
np.arange(0, N + 1, dtype=GLOBAL_NP_FLOAT_PRECISION) / N, [-1, 1]
572+
)
573+
+ smin_
574+
)
567575
res_lst = run_sess(sess, vals, feed_dict={dic_ph["s"]: s})
568576
res_dic = dict(zip(keys, res_lst))
569577

@@ -602,7 +610,7 @@ def build_t2g(self):
602610
# type_embedding of i, j atoms -> two_side_type_embedding
603611
type_embedding = dic_ph["t_ebd"]
604612
padding_ntypes = type_embedding.shape[0]
605-
type_embedding_nei = tf.tile( # pylint: disable=no-explicit-dtype
613+
type_embedding_nei = tf.tile( # pylint: disable=no-explicit-dtype
606614
tf.reshape(type_embedding, [1, padding_ntypes, -1]),
607615
[padding_ntypes, 1, 1],
608616
) # (ntypes) * ntypes * Y

source/tests/tf/test_nvnmd_entrypoints.py

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -51,11 +51,15 @@ class TestNvnmdEntrypointsV0(tf.test.TestCase):
5151
def test_mapt_cnn_v0(self) -> None:
5252
config_file = str(tests_path / "nvnmd" / "ref" / "config_v0_cnn.npy")
5353
weight_file = str(tests_path / "nvnmd" / "ref" / "weight_v0_cnn.npy")
54-
output_filename = f'{tests_path}/nvnmd/out/map_v0_cnn.npy'
55-
parts = [f'{tests_path}/nvnmd/out/map_v0_cnn_part_1.npy', f'{tests_path}/nvnmd/out/map_v0_cnn_part_2.npy', f'{tests_path}/nvnmd/out/map_v0_cnn_part_3.npy']
56-
with open(output_filename, 'wb') as output_file:
54+
output_filename = f"{tests_path}/nvnmd/out/map_v0_cnn.npy"
55+
parts = [
56+
f"{tests_path}/nvnmd/out/map_v0_cnn_part_1.npy",
57+
f"{tests_path}/nvnmd/out/map_v0_cnn_part_2.npy",
58+
f"{tests_path}/nvnmd/out/map_v0_cnn_part_3.npy",
59+
]
60+
with open(output_filename, "wb") as output_file:
5761
for part_filename in parts:
58-
with open(part_filename, 'rb') as part_file:
62+
with open(part_filename, "rb") as part_file:
5963
output_file.write(part_file.read())
6064
map_file = str(tests_path / "nvnmd" / "out" / "map_v0_cnn.npy")
6165
# mapt
@@ -528,11 +532,14 @@ class TestNvnmdEntrypointsV1(tf.test.TestCase):
528532
def test_mapt_cnn_v1(self) -> None:
529533
config_file = str(tests_path / "nvnmd" / "ref" / "config_v1_cnn.npy")
530534
weight_file = str(tests_path / "nvnmd" / "ref" / "weight_v1_cnn.npy")
531-
output_filename = f'{tests_path}/nvnmd/out/map_v1_cnn.npy'
532-
parts = [f'{tests_path}/nvnmd/out/map_v1_cnn_part_1.npy', f'{tests_path}/nvnmd/out/map_v1_cnn_part_2.npy']
533-
with open(output_filename, 'wb') as output_file:
535+
output_filename = f"{tests_path}/nvnmd/out/map_v1_cnn.npy"
536+
parts = [
537+
f"{tests_path}/nvnmd/out/map_v1_cnn_part_1.npy",
538+
f"{tests_path}/nvnmd/out/map_v1_cnn_part_2.npy",
539+
]
540+
with open(output_filename, "wb") as output_file:
534541
for part_filename in parts:
535-
with open(part_filename, 'rb') as part_file:
542+
with open(part_filename, "rb") as part_file:
536543
output_file.write(part_file.read())
537544
map_file = str(tests_path / "nvnmd" / "out" / "map_v1_cnn.npy")
538545
# mapt

0 commit comments

Comments
 (0)