Skip to content
18 changes: 18 additions & 0 deletions deepmd/pd/model/atomic_model/base_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -515,6 +515,24 @@ def change_out_bias(
else:
raise RuntimeError("Unknown bias_adjust_mode mode: " + bias_adjust_mode)

def compute_fitting_stat(
    self,
    sample_merged: Union[Callable[[], list[dict]], list[dict]],
) -> None:
    """Compute the input statistics (e.g. mean and stddev) for the fittings from packed data.

    The base implementation is intentionally a no-op; subclasses whose
    fitting nets consume input statistics (e.g. the DP atomic model)
    override this hook.

    Parameters
    ----------
    sample_merged : Union[Callable[[], list[dict]], list[dict]]
        - list[dict]: A list of data samples from various data systems.
          Each element, `merged[i]`, is a data dictionary containing `keys`: `paddle.Tensor`
          originating from the `i`-th data system.
        - Callable[[], list[dict]]: A lazy function that returns data samples in the above format
          only when needed. Since the sampling process can be slow and memory-intensive,
          the lazy function helps by only sampling once.
    """

def _get_forward_wrapper_func(self) -> Callable[..., paddle.Tensor]:
"""Get a forward wrapper of the atomic model for output bias calculation."""

Expand Down
20 changes: 20 additions & 0 deletions deepmd/pd/model/atomic_model/dp_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -403,6 +403,26 @@ def wrapped_sampler():
if compute_or_load_out_stat:
self.compute_or_load_out_stat(wrapped_sampler, stat_file_path)

def compute_fitting_stat(
    self,
    sample_merged: Union[Callable[[], list[dict]], list[dict]],
) -> None:
    """Collect fitting-net input statistics (e.g. mean and stddev) from sampled data.

    Parameters
    ----------
    sample_merged : Union[Callable[[], list[dict]], list[dict]]
        Either the sampled data itself (one dictionary of `paddle.Tensor`
        per data system) or a zero-argument callable producing that list
        lazily, so the potentially slow and memory-intensive sampling is
        performed at most once.
    """
    # Forward the sampled data to the fitting net together with the
    # model's data-stat protection value.
    fitting = self.fitting_net
    fitting.compute_input_stats(sample_merged, protection=self.data_stat_protect)

def get_dim_fparam(self) -> int:
"""Get the number (dimension) of frame parameters of this atomic model."""
return self.fitting_net.get_dim_fparam()
Expand Down
2 changes: 2 additions & 0 deletions deepmd/pd/model/model/make_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,8 @@ def change_out_bias(
merged,
bias_adjust_mode=bias_adjust_mode,
)
if bias_adjust_mode == "set-by-statistic":
self.atomic_model.compute_fitting_stat(merged)

def forward_common_lower(
self,
Expand Down
18 changes: 18 additions & 0 deletions deepmd/pt/model/atomic_model/base_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -493,6 +493,24 @@ def change_out_bias(
else:
raise RuntimeError("Unknown bias_adjust_mode mode: " + bias_adjust_mode)

def compute_fitting_stat(
    self,
    sample_merged: Union[Callable[[], list[dict]], list[dict]],
) -> None:
    """Compute the input statistics (e.g. mean and stddev) for the fittings from packed data.

    The base implementation is intentionally a no-op; subclasses whose
    fitting nets consume input statistics (e.g. the DP atomic model)
    override this hook.

    Parameters
    ----------
    sample_merged : Union[Callable[[], list[dict]], list[dict]]
        - list[dict]: A list of data samples from various data systems.
          Each element, `merged[i]`, is a data dictionary containing `keys`: `torch.Tensor`
          originating from the `i`-th data system.
        - Callable[[], list[dict]]: A lazy function that returns data samples in the above format
          only when needed. Since the sampling process can be slow and memory-intensive,
          the lazy function helps by only sampling once.
    """

def _get_forward_wrapper_func(self) -> Callable[..., torch.Tensor]:
"""Get a forward wrapper of the atomic model for output bias calculation."""

Expand Down
25 changes: 22 additions & 3 deletions deepmd/pt/model/atomic_model/dp_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
Any,
Callable,
Optional,
Union,
)

import torch
Expand Down Expand Up @@ -328,12 +329,30 @@ def wrapped_sampler() -> list[dict]:
return sampled

self.descriptor.compute_input_stats(wrapped_sampler, stat_file_path)
self.fitting_net.compute_input_stats(
wrapped_sampler, protection=self.data_stat_protect
)
self.compute_fitting_stat(wrapped_sampler)
if compute_or_load_out_stat:
self.compute_or_load_out_stat(wrapped_sampler, stat_file_path)

def compute_fitting_stat(
    self,
    sample_merged: Union[Callable[[], list[dict]], list[dict]],
) -> None:
    """Collect fitting-net input statistics (e.g. mean and stddev) from sampled data.

    Parameters
    ----------
    sample_merged : Union[Callable[[], list[dict]], list[dict]]
        Either the sampled data itself (one dictionary of `torch.Tensor`
        per data system) or a zero-argument callable producing that list
        lazily, so the potentially slow and memory-intensive sampling is
        performed at most once.
    """
    # Forward the sampled data to the fitting net together with the
    # model's data-stat protection value.
    fitting = self.fitting_net
    fitting.compute_input_stats(sample_merged, protection=self.data_stat_protect)

def get_dim_fparam(self) -> int:
"""Get the number (dimension) of frame parameters of this atomic model."""
return self.fitting_net.get_dim_fparam()
Expand Down
2 changes: 2 additions & 0 deletions deepmd/pt/model/model/make_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,8 @@ def change_out_bias(
merged,
bias_adjust_mode=bias_adjust_mode,
)
if bias_adjust_mode == "set-by-statistic":
self.atomic_model.compute_fitting_stat(merged)

def forward_common_lower(
self,
Expand Down
7 changes: 6 additions & 1 deletion source/tests/pd/test_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,11 @@ def test_dp_train(self) -> None:
state_dict_trained[state_key].numpy(),
state_dict_finetuned_empty[state_key].numpy(),
)
if "fitting_net" not in state_key:
if (
("fitting_net" not in state_key)
or ("fparam" in state_key)
or ("aparam" in state_key)
):
np.testing.assert_allclose(
state_dict_trained[state_key].numpy(),
state_dict_finetuned_random[state_key].numpy(),
Expand Down Expand Up @@ -190,6 +194,7 @@ def setUp(self) -> None:
self.config["training"]["save_freq"] = 1
self.set_path = Path(__file__).parent / "water/data/data_0" / "set.000"
shutil.copyfile(self.set_path / "energy.npy", self.set_path / "fparam.npy")
self.config["model"]["data_stat_nbatch"] = 100

def tearDown(self) -> None:
(self.set_path / "fparam.npy").unlink(missing_ok=True)
Expand Down
7 changes: 6 additions & 1 deletion source/tests/pt/test_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,11 @@ def test_dp_train(self) -> None:
state_dict_trained[state_key],
state_dict_finetuned_empty[state_key],
)
if "fitting_net" not in state_key:
if (
("fitting_net" not in state_key)
or ("fparam" in state_key)
or ("aparam" in state_key)
):
torch.testing.assert_close(
state_dict_trained[state_key],
state_dict_finetuned_random[state_key],
Expand Down Expand Up @@ -256,6 +260,7 @@ def setUp(self) -> None:
self.config["training"]["save_freq"] = 1
self.set_path = Path(__file__).parent / "water/data/data_0" / "set.000"
shutil.copyfile(self.set_path / "energy.npy", self.set_path / "fparam.npy")
self.config["model"]["data_stat_nbatch"] = 100

def tearDown(self) -> None:
(self.set_path / "fparam.npy").unlink(missing_ok=True)
Expand Down