From dee7c76206cac6dc1aa44ad79b990e1816a47d84 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 17 Nov 2024 07:03:16 +0000 Subject: [PATCH] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- deepmd/pt/utils/stat.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/deepmd/pt/utils/stat.py b/deepmd/pt/utils/stat.py index 175167edcb..1c5e3f1c52 100644 --- a/deepmd/pt/utils/stat.py +++ b/deepmd/pt/utils/stat.py @@ -548,8 +548,14 @@ def compute_output_stats_atomic( } # reshape outputs [nframes, nloc * ndim] --> reshape to [nframes * nloc, 1, ndim] for concatenation # reshape natoms [nframes, nloc] --> reshape to [nframes * nolc, 1] for concatenation - natoms = {k: [sys_v.reshape(-1,1) for sys_v in v] for k, v in natoms.items()} - outputs = {k: [sys.reshape(natoms[k][sys_idx].shape[0], 1, -1) for sys_idx, sys in enumerate(v)] for k, v in outputs.items()} + natoms = {k: [sys_v.reshape(-1, 1) for sys_v in v] for k, v in natoms.items()} + outputs = { + k: [ + sys.reshape(natoms[k][sys_idx].shape[0], 1, -1) + for sys_idx, sys in enumerate(v) + ] + for k, v in outputs.items() + } merged_output = { kk: to_numpy_array(torch.cat(outputs[kk]))