diff --git a/deepmd/pt/utils/stat.py b/deepmd/pt/utils/stat.py index 175167edcb..1c5e3f1c52 100644 --- a/deepmd/pt/utils/stat.py +++ b/deepmd/pt/utils/stat.py @@ -548,8 +548,14 @@ def compute_output_stats_atomic( } # reshape outputs [nframes, nloc * ndim] --> reshape to [nframes * nloc, 1, ndim] for concatenation # reshape natoms [nframes, nloc] --> reshape to [nframes * nolc, 1] for concatenation - natoms = {k: [sys_v.reshape(-1,1) for sys_v in v] for k, v in natoms.items()} - outputs = {k: [sys.reshape(natoms[k][sys_idx].shape[0], 1, -1) for sys_idx, sys in enumerate(v)] for k, v in outputs.items()} + natoms = {k: [sys_v.reshape(-1, 1) for sys_v in v] for k, v in natoms.items()} + outputs = { + k: [ + sys.reshape(natoms[k][sys_idx].shape[0], 1, -1) + for sys_idx, sys in enumerate(v) + ] + for k, v in outputs.items() + } merged_output = { kk: to_numpy_array(torch.cat(outputs[kk]))