diff --git a/deepmd/pd/model/atomic_model/base_atomic_model.py b/deepmd/pd/model/atomic_model/base_atomic_model.py
index 4f40117fb7..a66b075498 100644
--- a/deepmd/pd/model/atomic_model/base_atomic_model.py
+++ b/deepmd/pd/model/atomic_model/base_atomic_model.py
@@ -515,6 +515,24 @@ def change_out_bias(
         else:
             raise RuntimeError("Unknown bias_adjust_mode mode: " + bias_adjust_mode)
 
+    def compute_fitting_stat(
+        self,
+        sample_merged: Union[Callable[[], list[dict]], list[dict]],
+    ) -> None:
+        """Compute the input statistics (e.g. mean and stddev) for the fittings from packed data.
+
+        Parameters
+        ----------
+        sample_merged : Union[Callable[[], list[dict]], list[dict]]
+            - list[dict]: A list of data samples from various data systems.
+                Each element, `merged[i]`, is a data dictionary containing `keys`: `paddle.Tensor`
+                originating from the `i`-th data system.
+            - Callable[[], list[dict]]: A lazy function that returns data samples in the above format
+                only when needed. Since the sampling process can be slow and memory-intensive,
+                the lazy function helps by only sampling once.
+        """
+        pass
+
     def _get_forward_wrapper_func(self) -> Callable[..., paddle.Tensor]:
         """Get a forward wrapper of the atomic model for output bias calculation."""
diff --git a/deepmd/pd/model/atomic_model/dp_atomic_model.py b/deepmd/pd/model/atomic_model/dp_atomic_model.py
index 816245c28a..c09abd2221 100644
--- a/deepmd/pd/model/atomic_model/dp_atomic_model.py
+++ b/deepmd/pd/model/atomic_model/dp_atomic_model.py
@@ -403,6 +403,26 @@ def wrapped_sampler():
         if compute_or_load_out_stat:
             self.compute_or_load_out_stat(wrapped_sampler, stat_file_path)
 
+    def compute_fitting_stat(
+        self,
+        sample_merged: Union[Callable[[], list[dict]], list[dict]],
+    ) -> None:
+        """Compute the input statistics (e.g. mean and stddev) for the fittings from packed data.
+
+        Parameters
+        ----------
+        sample_merged : Union[Callable[[], list[dict]], list[dict]]
+            - list[dict]: A list of data samples from various data systems.
+                Each element, `merged[i]`, is a data dictionary containing `keys`: `paddle.Tensor`
+                originating from the `i`-th data system.
+            - Callable[[], list[dict]]: A lazy function that returns data samples in the above format
+                only when needed. Since the sampling process can be slow and memory-intensive,
+                the lazy function helps by only sampling once.
+        """
+        self.fitting_net.compute_input_stats(
+            sample_merged, protection=self.data_stat_protect
+        )
+
     def get_dim_fparam(self) -> int:
         """Get the number (dimension) of frame parameters of this atomic model."""
         return self.fitting_net.get_dim_fparam()
diff --git a/deepmd/pd/model/model/make_model.py b/deepmd/pd/model/model/make_model.py
index 42c406f8d7..bc57113f4d 100644
--- a/deepmd/pd/model/model/make_model.py
+++ b/deepmd/pd/model/model/make_model.py
@@ -228,6 +228,8 @@ def change_out_bias(
             merged,
             bias_adjust_mode=bias_adjust_mode,
         )
+        if bias_adjust_mode == "set-by-statistic":
+            self.atomic_model.compute_fitting_stat(merged)
 
     def forward_common_lower(
         self,
diff --git a/deepmd/pt/model/atomic_model/base_atomic_model.py b/deepmd/pt/model/atomic_model/base_atomic_model.py
index b8ba0a1981..6377a1c3db 100644
--- a/deepmd/pt/model/atomic_model/base_atomic_model.py
+++ b/deepmd/pt/model/atomic_model/base_atomic_model.py
@@ -493,6 +493,24 @@ def change_out_bias(
         else:
             raise RuntimeError("Unknown bias_adjust_mode mode: " + bias_adjust_mode)
 
+    def compute_fitting_stat(
+        self,
+        sample_merged: Union[Callable[[], list[dict]], list[dict]],
+    ) -> None:
+        """Compute the input statistics (e.g. mean and stddev) for the fittings from packed data.
+
+        Parameters
+        ----------
+        sample_merged : Union[Callable[[], list[dict]], list[dict]]
+            - list[dict]: A list of data samples from various data systems.
+                Each element, `merged[i]`, is a data dictionary containing `keys`: `torch.Tensor`
+                originating from the `i`-th data system.
+            - Callable[[], list[dict]]: A lazy function that returns data samples in the above format
+                only when needed. Since the sampling process can be slow and memory-intensive,
+                the lazy function helps by only sampling once.
+        """
+        pass
+
     def _get_forward_wrapper_func(self) -> Callable[..., torch.Tensor]:
         """Get a forward wrapper of the atomic model for output bias calculation."""
diff --git a/deepmd/pt/model/atomic_model/dp_atomic_model.py b/deepmd/pt/model/atomic_model/dp_atomic_model.py
index 5b7d96560f..fea7779d91 100644
--- a/deepmd/pt/model/atomic_model/dp_atomic_model.py
+++ b/deepmd/pt/model/atomic_model/dp_atomic_model.py
@@ -5,6 +5,7 @@
     Any,
     Callable,
     Optional,
+    Union,
 )
 
 import torch
@@ -328,12 +329,30 @@ def wrapped_sampler() -> list[dict]:
             return sampled
 
         self.descriptor.compute_input_stats(wrapped_sampler, stat_file_path)
-        self.fitting_net.compute_input_stats(
-            wrapped_sampler, protection=self.data_stat_protect
-        )
+        self.compute_fitting_stat(wrapped_sampler)
         if compute_or_load_out_stat:
             self.compute_or_load_out_stat(wrapped_sampler, stat_file_path)
 
+    def compute_fitting_stat(
+        self,
+        sample_merged: Union[Callable[[], list[dict]], list[dict]],
+    ) -> None:
+        """Compute the input statistics (e.g. mean and stddev) for the fittings from packed data.
+
+        Parameters
+        ----------
+        sample_merged : Union[Callable[[], list[dict]], list[dict]]
+            - list[dict]: A list of data samples from various data systems.
+                Each element, `merged[i]`, is a data dictionary containing `keys`: `torch.Tensor`
+                originating from the `i`-th data system.
+            - Callable[[], list[dict]]: A lazy function that returns data samples in the above format
+                only when needed. Since the sampling process can be slow and memory-intensive,
+                the lazy function helps by only sampling once.
+        """
+        self.fitting_net.compute_input_stats(
+            sample_merged, protection=self.data_stat_protect
+        )
+
     def get_dim_fparam(self) -> int:
         """Get the number (dimension) of frame parameters of this atomic model."""
         return self.fitting_net.get_dim_fparam()
diff --git a/deepmd/pt/model/model/make_model.py b/deepmd/pt/model/model/make_model.py
index 53d32977b0..e18f5e90bf 100644
--- a/deepmd/pt/model/model/make_model.py
+++ b/deepmd/pt/model/model/make_model.py
@@ -232,6 +232,8 @@ def change_out_bias(
             merged,
             bias_adjust_mode=bias_adjust_mode,
         )
+        if bias_adjust_mode == "set-by-statistic":
+            self.atomic_model.compute_fitting_stat(merged)
 
     def forward_common_lower(
         self,
diff --git a/source/tests/pd/test_training.py b/source/tests/pd/test_training.py
index 0dc36fa314..f3d7860881 100644
--- a/source/tests/pd/test_training.py
+++ b/source/tests/pd/test_training.py
@@ -89,7 +89,11 @@ def test_dp_train(self) -> None:
                     state_dict_trained[state_key].numpy(),
                     state_dict_finetuned_empty[state_key].numpy(),
                 )
-                if "fitting_net" not in state_key:
+                if (
+                    ("fitting_net" not in state_key)
+                    or ("fparam" in state_key)
+                    or ("aparam" in state_key)
+                ):
                     np.testing.assert_allclose(
                         state_dict_trained[state_key].numpy(),
                         state_dict_finetuned_random[state_key].numpy(),
@@ -190,6 +194,7 @@ def setUp(self) -> None:
         self.config["training"]["save_freq"] = 1
         self.set_path = Path(__file__).parent / "water/data/data_0" / "set.000"
         shutil.copyfile(self.set_path / "energy.npy", self.set_path / "fparam.npy")
+        self.config["model"]["data_stat_nbatch"] = 100
 
     def tearDown(self) -> None:
         (self.set_path / "fparam.npy").unlink(missing_ok=True)
diff --git a/source/tests/pt/test_training.py b/source/tests/pt/test_training.py
index da239212b0..ff4f00f912 100644
--- a/source/tests/pt/test_training.py
+++ b/source/tests/pt/test_training.py
@@ -92,7 +92,11 @@ def test_dp_train(self) -> None:
                     state_dict_trained[state_key],
                     state_dict_finetuned_empty[state_key],
                 )
-                if "fitting_net" not in state_key:
+                if (
+                    ("fitting_net" not in state_key)
+                    or ("fparam" in state_key)
+                    or ("aparam" in state_key)
+                ):
                     torch.testing.assert_close(
                         state_dict_trained[state_key],
                         state_dict_finetuned_random[state_key],
@@ -256,6 +260,7 @@ def setUp(self) -> None:
         self.config["training"]["save_freq"] = 1
         self.set_path = Path(__file__).parent / "water/data/data_0" / "set.000"
         shutil.copyfile(self.set_path / "energy.npy", self.set_path / "fparam.npy")
+        self.config["model"]["data_stat_nbatch"] = 100
 
     def tearDown(self) -> None:
         (self.set_path / "fparam.npy").unlink(missing_ok=True)