refact: compute_output_stats and change_out_bias #3639

Merged · 4 commits · Apr 3, 2024
3 changes: 3 additions & 0 deletions deepmd/dpmodel/atomic_model/make_base_atomic_model.py
@@ -51,6 +51,9 @@
"""
return self.fitting_output_def()

def get_output_keys(self) -> List[str]:
return list(self.atomic_output_def().keys())
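
A note on this addition: get_output_keys exposes the keys of the atomic output definition so that the bias code later in this PR can iterate over every output variable instead of hardcoding ["energy"]. A minimal self-contained sketch of the idea (ToyAtomicModel and its dict-based output def are hypothetical stand-ins, not the real DeePMD classes):

from typing import Dict, List


class ToyAtomicModel:
    """Hypothetical stand-in for the atomic model interface."""

    def atomic_output_def(self) -> Dict[str, object]:
        # The real method returns a FittingOutputDef; a plain dict keyed by
        # output variable name is enough to illustrate the shape.
        return {"energy": object()}

    def get_output_keys(self) -> List[str]:
        # Mirrors the added method: just the keys of the output definition.
        return list(self.atomic_output_def().keys())


print(ToyAtomicModel().get_output_keys())  # ["energy"]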

@abstractmethod
def get_rcut(self) -> float:
"""Get the cut-off radius."""
123 changes: 50 additions & 73 deletions deepmd/pt/model/atomic_model/base_atomic_model.py
@@ -8,9 +8,9 @@
List,
Optional,
Tuple,
Union,
)

import numpy as np
import torch

from deepmd.dpmodel.atomic_model import (
@@ -30,9 +30,6 @@
from deepmd.pt.utils.stat import (
compute_output_stats,
)
from deepmd.pt.utils.utils import (
to_numpy_array,
)
from deepmd.utils.path import (
DPPath,
)
@@ -190,115 +187,95 @@
"pair_exclude_types": self.pair_exclude_types,
}

def get_forward_wrapper_func(self) -> Callable[..., torch.Tensor]:
"""Get a forward wrapper of the atomic model for output bias calculation."""

def model_forward(coord, atype, box, fparam=None, aparam=None):
with torch.no_grad():  # no_grad is essential so this pure torch forward function can be used with auto_batchsize
(
extended_coord,
extended_atype,
mapping,
nlist,
) = extend_input_and_build_neighbor_list(
coord,
atype,
self.get_rcut(),
self.get_sel(),
mixed_types=self.mixed_types(),
box=box,
)
atomic_ret = self.forward_common_atomic(
extended_coord,
extended_atype,
nlist,
mapping=mapping,
fparam=fparam,
aparam=aparam,
)
return {kk: vv.detach() for kk, vv in atomic_ret.items()}

return model_forward

def compute_or_load_stat(
self,
sampled_func,
merged: Union[Callable[[], List[dict]], List[dict]],
stat_file_path: Optional[DPPath] = None,
):
"""
Compute or load the statistics parameters of the model,
such as mean and standard deviation of descriptors or the energy bias of the fitting net.
When `sampled` is provided, all the statistics parameters will be calculated (or re-calculated for update),
and saved in the `stat_file_path`(s).
When `sampled` is not provided, it will check the existence of `stat_file_path`(s)
and load the calculated statistics parameters.
Compute the output statistics (e.g. energy bias) for the fitting net from packed data.

Parameters
----------
sampled_func
The sampled data frames from different data systems.
stat_file_path
The path to the statistics files.
merged : Union[Callable[[], List[dict]], List[dict]]
- List[dict]: A list of data samples from various data systems.
Each element, `merged[i]`, is a data dictionary containing `keys`: `torch.Tensor`
originating from the `i`-th data system.
- Callable[[], List[dict]]: A lazy function that returns data samples in the above format
only when needed. Since the sampling process can be slow and memory-intensive,
the lazy function helps by only sampling once.
stat_file_path : Optional[DPPath]
The path to the stat file.

"""
raise NotImplementedError

def change_out_bias(
self,
merged,
origin_type_map,
full_type_map,
sample_merged,
bias_adjust_mode="change-by-statistic",
) -> None:
"""Change the output bias according to the input data and the pretrained model.

Parameters
----------
merged : Union[Callable[[], List[dict]], List[dict]]
sample_merged : Union[Callable[[], List[dict]], List[dict]]
- List[dict]: A list of data samples from various data systems.
Each element, `merged[i]`, is a data dictionary containing `keys`: `torch.Tensor`
originating from the `i`-th data system.
- Callable[[], List[dict]]: A lazy function that returns data samples in the above format
only when needed. Since the sampling process can be slow and memory-intensive,
the lazy function helps by only sampling once.
origin_type_map : List[str]
The original type_map of the dataset; these are the targets for changing the output bias.
full_type_map : List[str]
The full type_map of the pre-trained model.
bias_adjust_mode : str
The mode for changing the output bias: ['change-by-statistic', 'set-by-statistic']
'change-by-statistic' : perform predictions on the labels of the target dataset,
and run a least-squares fit on the errors to obtain the target shift as the bias.
'set-by-statistic' : directly use the statistical output bias of the target dataset.
"""
sorter = np.argsort(full_type_map)
missing_types = [t for t in origin_type_map if t not in full_type_map]
assert (
not missing_types
), f"Some types are not in the pre-trained model: {list(missing_types)} !"
idx_type_map = sorter[
np.searchsorted(full_type_map, origin_type_map, sorter=sorter)
]
original_bias = self.get_out_bias()
if bias_adjust_mode == "change-by-statistic":
delta_bias = compute_output_stats(
merged,
sample_merged,
self.get_ntypes(),
keys=["energy"],
model_forward=self.get_forward_wrapper_func(),
keys=self.get_output_keys(),
model_forward=self._get_forward_wrapper_func(),
)["energy"]
self.set_out_bias(delta_bias, add=True)
elif bias_adjust_mode == "set-by-statistic":
bias_atom = compute_output_stats(
merged,
sample_merged,
self.get_ntypes(),
keys=["energy"],
keys=self.get_output_keys(),
)["energy"]
self.set_out_bias(bias_atom)
else:
raise RuntimeError("Unknown bias_adjust_mode mode: " + bias_adjust_mode)
bias_atom = self.get_out_bias()
log.info(
f"Change output bias of {origin_type_map!s} "
f"from {to_numpy_array(original_bias[idx_type_map]).reshape(-1)!s} "
f"to {to_numpy_array(bias_atom[idx_type_map]).reshape(-1)!s}."
)
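
To make the two bias_adjust_mode branches concrete, here is a small numpy illustration of the statistics involved; it sketches the least-squares idea only and is not the actual compute_output_stats implementation:

import numpy as np

rng = np.random.default_rng(0)
ntypes, nframes = 2, 50

# counts[i, t]: number of atoms of type t in frame i (toy data).
counts = rng.integers(1, 10, size=(nframes, ntypes)).astype(np.float64)
true_bias = np.array([1.5, -0.7])
labels = counts @ true_bias + 0.01 * rng.normal(size=nframes)

# 'set-by-statistic': least-squares fit of a per-type bias directly to the labels.
bias_atom, *_ = np.linalg.lstsq(counts, labels, rcond=None)

# 'change-by-statistic': fit the errors of the current model's predictions,
# then add the resulting shift onto the existing bias.
current_bias = np.zeros(ntypes)
predictions = counts @ current_bias
delta_bias, *_ = np.linalg.lstsq(counts, labels - predictions, rcond=None)
new_bias = current_bias + delta_bias

print(bias_atom)  # close to [1.5, -0.7]
print(new_bias)   # identical here because current_bias started at zero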

def _get_forward_wrapper_func(self) -> Callable[..., torch.Tensor]:
"""Get a forward wrapper of the atomic model for output bias calculation."""

def model_forward(coord, atype, box, fparam=None, aparam=None):
with torch.no_grad():  # no_grad is essential so this pure torch forward function can be used with auto_batchsize
(
extended_coord,
extended_atype,
mapping,
nlist,
) = extend_input_and_build_neighbor_list(
coord,
atype,
self.get_rcut(),
self.get_sel(),
mixed_types=self.mixed_types(),
box=box,
)
atomic_ret = self.forward_common_atomic(
extended_coord,
extended_atype,
nlist,
mapping=mapping,
fparam=fparam,
aparam=aparam,
)
return {kk: vv.detach() for kk, vv in atomic_ret.items()}

return model_forward
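
The inline comment stresses that torch.no_grad() matters because this wrapper is called repeatedly by an automatic batch-size driver, so no autograd graph should accumulate across calls. A self-contained sketch of just that pattern, with a toy linear net standing in for the atomic model:

import torch


def make_forward_wrapper(net: torch.nn.Module):
    # Sketch of the wrapper pattern: run the network without building an
    # autograd graph and hand back detached outputs, so a driver that calls
    # this repeatedly (e.g. an automatic batch-size search) never accumulates
    # gradient state.
    def model_forward(coord: torch.Tensor) -> dict:
        with torch.no_grad():
            energy = net(coord.reshape(coord.shape[0], -1))
        return {"energy": energy.detach()}

    return model_forward


net = torch.nn.Linear(15, 1)         # toy stand-in for the atomic model
forward = make_forward_wrapper(net)
out = forward(torch.randn(4, 5, 3))  # 4 frames, 5 atoms, 3 coordinates
print(out["energy"].shape)           # torch.Size([4, 1])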
11 changes: 3 additions & 8 deletions deepmd/pt/model/model/make_model.py
@@ -172,11 +172,12 @@
model_predict = self.output_type_cast(model_predict, input_prec)
return model_predict

def get_out_bias(self) -> torch.Tensor:
return self.atomic_model.get_out_bias()

def change_out_bias(
self,
merged,
origin_type_map,
full_type_map,
bias_adjust_mode="change-by-statistic",
) -> None:
"""Change the output bias of atomic model according to the input data and the pretrained model.
@@ -190,10 +191,6 @@
- Callable[[], List[dict]]: A lazy function that returns data samples in the above format
only when needed. Since the sampling process can be slow and memory-intensive,
the lazy function helps by only sampling once.
origin_type_map : List[str]
The original type_map of the dataset; these are the targets for changing the output bias.
full_type_map : List[str]
The full type_map of the pre-trained model.
bias_adjust_mode : str
The mode for changing the output bias: ['change-by-statistic', 'set-by-statistic']
'change-by-statistic' : perform predictions on the labels of the target dataset,
@@ -202,8 +199,6 @@
"""
self.atomic_model.change_out_bias(
merged,
origin_type_map,
full_type_map,
bias_adjust_mode=bias_adjust_mode,
)

4 changes: 2 additions & 2 deletions deepmd/pt/model/task/invar_fitting.py
@@ -167,11 +167,11 @@ def compute_output_stats(
bias_atom_e = compute_output_stats(
merged,
self.ntypes,
keys=["energy"],
keys=[self.var_name],
stat_file_path=stat_file_path,
rcond=self.rcond,
atom_ener=self.atom_ener,
)["energy"]
)[self.var_name]
self.bias_atom_e.copy_(bias_atom_e.view([self.ntypes, self.dim_out]))

def output_def(self) -> FittingOutputDef:
42 changes: 33 additions & 9 deletions deepmd/pt/train/training.py
@@ -570,21 +570,17 @@
_model_params["new_type_map"],
)
if isinstance(_model, EnergyModel):
_model.change_out_bias(
_sample_func,
bias_adjust_mode=_model_params.get(
"bias_adjust_mode", "change-by-statistic"
),
origin_type_map=new_type_map,
full_type_map=old_type_map,
_model = _model_change_out_bias(
_model, new_type_map, _sample_func, _model_params
)
else:
# need to be updated
pass
return _model

# finetune
if not self.multi_task:
single_model_finetune(
self.model = single_model_finetune(
self.model, model_params, self.get_sample_func
)
else:
@@ -593,7 +589,7 @@
log.info(
f"Model branch {model_key} will be fine-tuned. This may take a long time..."
)
single_model_finetune(
self.model[model_key] = single_model_finetune(
self.model[model_key],
model_params["model_dict"][model_key],
self.get_sample_func[model_key],
@@ -1148,3 +1144,31 @@
print_str += " %8.1e\n" % cur_lr
fout.write(print_str)
fout.flush()


def _model_change_out_bias(
_model,
new_type_map,
_sample_func,
_model_params,
):
old_bias = _model.get_out_bias()
_model.change_out_bias(
_sample_func,
bias_adjust_mode=_model_params.get("bias_adjust_mode", "change-by-statistic"),
)
new_bias = _model.get_out_bias()

model_type_map = _model.get_type_map()
sorter = np.argsort(model_type_map)
missing_types = [t for t in new_type_map if t not in model_type_map]
assert (
not missing_types
), f"Some types are not in the pre-trained model: {list(missing_types)} !"
idx_type_map = sorter[np.searchsorted(model_type_map, new_type_map, sorter=sorter)]
log.info(
f"Change output bias of {new_type_map!s} "
f"from {to_numpy_array(old_bias[idx_type_map]).reshape(-1)!s} "
f"to {to_numpy_array(new_bias[idx_type_map]).reshape(-1)!s}."
)
return _model
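
The sorter/searchsorted pair in _model_change_out_bias is the standard numpy idiom for mapping one list of type names onto their indices in another. A standalone sketch of just that step, with toy type maps:

import numpy as np

# Map each name in new_type_map to its index in model_type_map.
model_type_map = ["H", "O", "C"]   # order as stored in the model
new_type_map = ["C", "H"]          # order used by the new dataset

sorter = np.argsort(model_type_map)            # indices that sort the names
pos_in_sorted = np.searchsorted(model_type_map, new_type_map, sorter=sorter)
idx_type_map = sorter[pos_in_sorted]           # back to original indices

print(idx_type_map)  # [2 0]: "C" is model type 2, "H" is model type 0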