Skip to content

Commit 0f91fe0

Browse files
Init branch
1 parent 3f0087b commit 0f91fe0

File tree

3 files changed

+23
-3
lines changed

3 files changed

+23
-3
lines changed

deepmd/pt/model/atomic_model/base_atomic_model.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -487,6 +487,26 @@ def change_out_bias(
487487
else:
488488
raise RuntimeError("Unknown bias_adjust_mode mode: " + bias_adjust_mode)
489489

490+
def compute_fitting_stat(
491+
self,
492+
sample_merged
493+
) -> None:
494+
"""Compute the input statistics (e.g. mean and stddev) for the fittings from packed data..
495+
496+
Parameters
497+
----------
498+
sample_merged : Union[Callable[[], list[dict]], list[dict]]
499+
- list[dict]: A list of data samples from various data systems.
500+
Each element, `merged[i]`, is a data dictionary containing `keys`: `torch.Tensor`
501+
originating from the `i`-th data system.
502+
- Callable[[], list[dict]]: A lazy function that returns data samples in the above format
503+
only when needed. Since the sampling process can be slow and memory-intensive,
504+
the lazy function helps by only sampling once.
505+
"""
506+
self.fitting_net.compute_input_stats(
507+
sample_merged, protection=self.data_stat_protect
508+
)
509+
490510
def _get_forward_wrapper_func(self) -> Callable[..., torch.Tensor]:
491511
"""Get a forward wrapper of the atomic model for output bias calculation."""
492512

deepmd/pt/model/atomic_model/dp_atomic_model.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -324,9 +324,7 @@ def wrapped_sampler():
324324
return sampled
325325

326326
self.descriptor.compute_input_stats(wrapped_sampler, stat_file_path)
327-
self.fitting_net.compute_input_stats(
328-
wrapped_sampler, protection=self.data_stat_protect
329-
)
327+
self.compute_fitting_stat(wrapped_sampler)
330328
if compute_or_load_out_stat:
331329
self.compute_or_load_out_stat(wrapped_sampler, stat_file_path)
332330

deepmd/pt/model/model/make_model.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -230,6 +230,8 @@ def change_out_bias(
230230
merged,
231231
bias_adjust_mode=bias_adjust_mode,
232232
)
233+
if bias_adjust_mode == "set-by-statistic":
234+
self.atomic_model.compute_fitting_stat(merged)
233235

234236
def forward_common_lower(
235237
self,

0 commit comments

Comments
 (0)