File tree Expand file tree Collapse file tree 3 files changed +23
-3
lines changed Expand file tree Collapse file tree 3 files changed +23
-3
lines changed Original file line number Diff line number Diff line change @@ -487,6 +487,26 @@ def change_out_bias(
487487 else :
488488 raise RuntimeError ("Unknown bias_adjust_mode mode: " + bias_adjust_mode )
489489
490+ def compute_fitting_stat (
491+ self ,
492+ sample_merged
493+ ) -> None :
494+ """Compute the input statistics (e.g. mean and stddev) for the fittings from packed data..
495+
496+ Parameters
497+ ----------
498+ sample_merged : Union[Callable[[], list[dict]], list[dict]]
499+ - list[dict]: A list of data samples from various data systems.
500+ Each element, `merged[i]`, is a data dictionary containing `keys`: `torch.Tensor`
501+ originating from the `i`-th data system.
502+ - Callable[[], list[dict]]: A lazy function that returns data samples in the above format
503+ only when needed. Since the sampling process can be slow and memory-intensive,
504+ the lazy function helps by only sampling once.
505+ """
506+ self .fitting_net .compute_input_stats (
507+ sample_merged , protection = self .data_stat_protect
508+ )
509+
490510 def _get_forward_wrapper_func (self ) -> Callable [..., torch .Tensor ]:
491511 """Get a forward wrapper of the atomic model for output bias calculation."""
492512
Original file line number Diff line number Diff line change @@ -324,9 +324,7 @@ def wrapped_sampler():
324324 return sampled
325325
326326 self .descriptor .compute_input_stats (wrapped_sampler , stat_file_path )
327- self .fitting_net .compute_input_stats (
328- wrapped_sampler , protection = self .data_stat_protect
329- )
327+ self .compute_fitting_stat (wrapped_sampler )
330328 if compute_or_load_out_stat :
331329 self .compute_or_load_out_stat (wrapped_sampler , stat_file_path )
332330
Original file line number Diff line number Diff line change @@ -230,6 +230,8 @@ def change_out_bias(
230230 merged ,
231231 bias_adjust_mode = bias_adjust_mode ,
232232 )
233+ if bias_adjust_mode == "set-by-statistic" :
234+ self .atomic_model .compute_fitting_stat (merged )
233235
234236 def forward_common_lower (
235237 self ,
You can’t perform that action at this time.
0 commit comments