From e4c6f8ff8ac936f1f8ea1e5d7ae7678a2c4cc9c9 Mon Sep 17 00:00:00 2001 From: Duo <50307526+iProzd@users.noreply.github.com> Date: Tue, 23 Jul 2024 15:39:58 +0800 Subject: [PATCH 1/3] fix(pt): optimize graph memory usage --- deepmd/pt/model/model/dipole_model.py | 1 + deepmd/pt/model/model/dos_model.py | 1 + deepmd/pt/model/model/dp_zbl_model.py | 1 + deepmd/pt/model/model/ener_model.py | 1 + deepmd/pt/model/model/make_model.py | 4 ++++ deepmd/pt/model/model/polar_model.py | 1 + deepmd/pt/model/model/spin_model.py | 3 +++ deepmd/pt/model/model/transform_output.py | 29 +++++++++++++++++++---- 8 files changed, 37 insertions(+), 4 deletions(-) diff --git a/deepmd/pt/model/model/dipole_model.py b/deepmd/pt/model/model/dipole_model.py index 2d732c2800..703549d48e 100644 --- a/deepmd/pt/model/model/dipole_model.py +++ b/deepmd/pt/model/model/dipole_model.py @@ -111,6 +111,7 @@ def forward_lower( fparam=fparam, aparam=aparam, do_atomic_virial=do_atomic_virial, + inference=True, ) if self.get_fitting_net() is not None: model_predict = {} diff --git a/deepmd/pt/model/model/dos_model.py b/deepmd/pt/model/model/dos_model.py index 0dd6af7b80..496f008125 100644 --- a/deepmd/pt/model/model/dos_model.py +++ b/deepmd/pt/model/model/dos_model.py @@ -101,6 +101,7 @@ def forward_lower( fparam=fparam, aparam=aparam, do_atomic_virial=do_atomic_virial, + inference=True, ) if self.get_fitting_net() is not None: model_predict = {} diff --git a/deepmd/pt/model/model/dp_zbl_model.py b/deepmd/pt/model/model/dp_zbl_model.py index 4d7c16eb7d..dd07f6cc55 100644 --- a/deepmd/pt/model/model/dp_zbl_model.py +++ b/deepmd/pt/model/model/dp_zbl_model.py @@ -112,6 +112,7 @@ def forward_lower( fparam=fparam, aparam=aparam, do_atomic_virial=do_atomic_virial, + inference=True, ) model_predict = {} diff --git a/deepmd/pt/model/model/ener_model.py b/deepmd/pt/model/model/ener_model.py index fb16478bc0..5b8af3bdf8 100644 --- a/deepmd/pt/model/model/ener_model.py +++ b/deepmd/pt/model/model/ener_model.py @@ -115,6 +115,7 @@ def forward_lower( aparam=aparam, do_atomic_virial=do_atomic_virial, comm_dict=comm_dict, + inference=True, ) if self.get_fitting_net() is not None: model_predict = {} diff --git a/deepmd/pt/model/model/make_model.py b/deepmd/pt/model/model/make_model.py index 32432725d3..641168acdb 100644 --- a/deepmd/pt/model/model/make_model.py +++ b/deepmd/pt/model/model/make_model.py @@ -215,6 +215,7 @@ def forward_common_lower( aparam: Optional[torch.Tensor] = None, do_atomic_virial: bool = False, comm_dict: Optional[Dict[str, torch.Tensor]] = None, + inference: bool = False, ): """Return model prediction. Lower interface that takes extended atomic coordinates and types, nlist, and mapping @@ -239,6 +240,8 @@ def forward_common_lower( whether calculate atomic virial. comm_dict The data needed for communication for parallel inference. + inference + Whether only perform inference rather than undergoing training. Returns ------- @@ -267,6 +270,7 @@ def forward_common_lower( self.atomic_output_def(), cc_ext, do_atomic_virial=do_atomic_virial, + inference=inference, ) model_predict = self.output_type_cast(model_predict, input_prec) return model_predict diff --git a/deepmd/pt/model/model/polar_model.py b/deepmd/pt/model/model/polar_model.py index 449fdbe700..b31dbf906f 100644 --- a/deepmd/pt/model/model/polar_model.py +++ b/deepmd/pt/model/model/polar_model.py @@ -95,6 +95,7 @@ def forward_lower( fparam=fparam, aparam=aparam, do_atomic_virial=do_atomic_virial, + inference=True, ) if self.get_fitting_net() is not None: model_predict = {} diff --git a/deepmd/pt/model/model/spin_model.py b/deepmd/pt/model/model/spin_model.py index 72e6797ea8..033cd24b11 100644 --- a/deepmd/pt/model/model/spin_model.py +++ b/deepmd/pt/model/model/spin_model.py @@ -459,6 +459,7 @@ def forward_common_lower( fparam: Optional[torch.Tensor] = None, aparam: Optional[torch.Tensor] = None, do_atomic_virial: bool = False, + inference: bool = False, ): nframes, nloc = nlist.shape[:2] ( @@ -479,6 +480,7 @@ def forward_common_lower( fparam=fparam, aparam=aparam, do_atomic_virial=do_atomic_virial, + inference=inference, ) model_output_type = self.backbone_model.model_output_type() if "mask" in model_output_type: @@ -603,6 +605,7 @@ def forward_lower( fparam=fparam, aparam=aparam, do_atomic_virial=do_atomic_virial, + inference=True, ) model_predict = {} model_predict["atom_energy"] = model_ret["energy"] diff --git a/deepmd/pt/model/model/transform_output.py b/deepmd/pt/model/model/transform_output.py index 9fde6a1589..56b0ba7850 100644 --- a/deepmd/pt/model/model/transform_output.py +++ b/deepmd/pt/model/model/transform_output.py @@ -33,15 +33,27 @@ def atomic_virial_corr( faked_grad = torch.ones_like(sumce0) lst = torch.jit.annotate(List[Optional[torch.Tensor]], [faked_grad]) extended_virial_corr0 = torch.autograd.grad( - [sumce0], [extended_coord], grad_outputs=lst, create_graph=True + [sumce0], + [extended_coord], + grad_outputs=lst, + create_graph=False, + retain_graph=True, )[0] assert extended_virial_corr0 is not None extended_virial_corr1 = torch.autograd.grad( - [sumce1], [extended_coord], grad_outputs=lst, create_graph=True + [sumce1], + [extended_coord], + grad_outputs=lst, + create_graph=False, + retain_graph=True, )[0] assert extended_virial_corr1 is not None extended_virial_corr2 = torch.autograd.grad( - [sumce2], [extended_coord], grad_outputs=lst, create_graph=True + [sumce2], + [extended_coord], + grad_outputs=lst, + create_graph=False, + retain_graph=True, )[0] assert extended_virial_corr2 is not None extended_virial_corr = torch.concat( @@ -61,11 +73,16 @@ def task_deriv_one( extended_coord: torch.Tensor, do_virial: bool = True, do_atomic_virial: bool = False, + inference: bool = False, ): faked_grad = torch.ones_like(energy) lst = torch.jit.annotate(List[Optional[torch.Tensor]], [faked_grad]) extended_force = torch.autograd.grad( - [energy], [extended_coord], grad_outputs=lst, create_graph=True + [energy], + [extended_coord], + grad_outputs=lst, + create_graph=not inference, + retain_graph=True, )[0] assert extended_force is not None extended_force = -extended_force @@ -106,6 +123,7 @@ def take_deriv( coord_ext: torch.Tensor, do_virial: bool = False, do_atomic_virial: bool = False, + inference: bool = False, ): size = 1 for ii in vdef.shape: @@ -123,6 +141,7 @@ def take_deriv( coord_ext, do_virial=do_virial, do_atomic_virial=do_atomic_virial, + inference=inference, ) # nf x nloc x 1 x 3, nf x nloc x 1 x 9 ffi = ffi.unsqueeze(-2) @@ -146,6 +165,7 @@ def fit_output_to_model_output( fit_output_def: FittingOutputDef, coord_ext: torch.Tensor, do_atomic_virial: bool = False, + inference: bool = False, ) -> Dict[str, torch.Tensor]: """Transform the output of the fitting network to the model output. @@ -169,6 +189,7 @@ def fit_output_to_model_output( coord_ext, do_virial=vdef.c_differentiable, do_atomic_virial=do_atomic_virial, + inference=inference, ) model_ret[kk_derv_r] = dr if vdef.c_differentiable: From 994258314b97dd1e7c237b1b78070f13891ee513 Mon Sep 17 00:00:00 2001 From: Duo <50307526+iProzd@users.noreply.github.com> Date: Tue, 23 Jul 2024 17:27:03 +0800 Subject: [PATCH 2/3] use self.training --- deepmd/pt/entrypoints/main.py | 4 +++- deepmd/pt/model/model/dipole_model.py | 1 - deepmd/pt/model/model/dos_model.py | 1 - deepmd/pt/model/model/dp_zbl_model.py | 1 - deepmd/pt/model/model/ener_model.py | 1 - deepmd/pt/model/model/make_model.py | 5 +---- deepmd/pt/model/model/polar_model.py | 1 - deepmd/pt/model/model/spin_model.py | 3 --- deepmd/pt/model/model/transform_output.py | 12 ++++++------ source/lmp/pair_deepmd.cpp | 18 ++++++++++++++++++ 10 files changed, 28 insertions(+), 19 deletions(-) diff --git a/deepmd/pt/entrypoints/main.py b/deepmd/pt/entrypoints/main.py index b0edf66878..d43e9afdd2 100644 --- a/deepmd/pt/entrypoints/main.py +++ b/deepmd/pt/entrypoints/main.py @@ -335,7 +335,9 @@ def train(FLAGS): def freeze(FLAGS): - model = torch.jit.script(inference.Tester(FLAGS.model, head=FLAGS.head).model) + model = inference.Tester(FLAGS.model, head=FLAGS.head).model + model.eval() + model = torch.jit.script(model) extra_files = {} torch.jit.save( model, diff --git a/deepmd/pt/model/model/dipole_model.py b/deepmd/pt/model/model/dipole_model.py index 703549d48e..2d732c2800 100644 --- a/deepmd/pt/model/model/dipole_model.py +++ b/deepmd/pt/model/model/dipole_model.py @@ -111,7 +111,6 @@ def forward_lower( fparam=fparam, aparam=aparam, do_atomic_virial=do_atomic_virial, - inference=True, ) if self.get_fitting_net() is not None: model_predict = {} diff --git a/deepmd/pt/model/model/dos_model.py b/deepmd/pt/model/model/dos_model.py index 496f008125..0dd6af7b80 100644 --- a/deepmd/pt/model/model/dos_model.py +++ b/deepmd/pt/model/model/dos_model.py @@ -101,7 +101,6 @@ def forward_lower( fparam=fparam, aparam=aparam, do_atomic_virial=do_atomic_virial, - inference=True, ) if self.get_fitting_net() is not None: model_predict = {} diff --git a/deepmd/pt/model/model/dp_zbl_model.py b/deepmd/pt/model/model/dp_zbl_model.py index dd07f6cc55..4d7c16eb7d 100644 --- a/deepmd/pt/model/model/dp_zbl_model.py +++ b/deepmd/pt/model/model/dp_zbl_model.py @@ -112,7 +112,6 @@ def forward_lower( fparam=fparam, aparam=aparam, do_atomic_virial=do_atomic_virial, - inference=True, ) model_predict = {} diff --git a/deepmd/pt/model/model/ener_model.py b/deepmd/pt/model/model/ener_model.py index 5b8af3bdf8..fb16478bc0 100644 --- a/deepmd/pt/model/model/ener_model.py +++ b/deepmd/pt/model/model/ener_model.py @@ -115,7 +115,6 @@ def forward_lower( aparam=aparam, do_atomic_virial=do_atomic_virial, comm_dict=comm_dict, - inference=True, ) if self.get_fitting_net() is not None: model_predict = {} diff --git a/deepmd/pt/model/model/make_model.py b/deepmd/pt/model/model/make_model.py index 641168acdb..d7c75a4c6e 100644 --- a/deepmd/pt/model/model/make_model.py +++ b/deepmd/pt/model/model/make_model.py @@ -215,7 +215,6 @@ def forward_common_lower( aparam: Optional[torch.Tensor] = None, do_atomic_virial: bool = False, comm_dict: Optional[Dict[str, torch.Tensor]] = None, - inference: bool = False, ): """Return model prediction. Lower interface that takes extended atomic coordinates and types, nlist, and mapping @@ -240,8 +239,6 @@ def forward_common_lower( whether calculate atomic virial. comm_dict The data needed for communication for parallel inference. - inference - Whether only perform inference rather than undergoing training. Returns ------- @@ -270,7 +267,7 @@ def forward_common_lower( self.atomic_output_def(), cc_ext, do_atomic_virial=do_atomic_virial, - inference=inference, + create_graph=self.training, ) model_predict = self.output_type_cast(model_predict, input_prec) return model_predict diff --git a/deepmd/pt/model/model/polar_model.py b/deepmd/pt/model/model/polar_model.py index b31dbf906f..449fdbe700 100644 --- a/deepmd/pt/model/model/polar_model.py +++ b/deepmd/pt/model/model/polar_model.py @@ -95,7 +95,6 @@ def forward_lower( fparam=fparam, aparam=aparam, do_atomic_virial=do_atomic_virial, - inference=True, ) if self.get_fitting_net() is not None: model_predict = {} diff --git a/deepmd/pt/model/model/spin_model.py b/deepmd/pt/model/model/spin_model.py index 74d0a25a79..551c0b86b2 100644 --- a/deepmd/pt/model/model/spin_model.py +++ b/deepmd/pt/model/model/spin_model.py @@ -467,7 +467,6 @@ def forward_common_lower( fparam: Optional[torch.Tensor] = None, aparam: Optional[torch.Tensor] = None, do_atomic_virial: bool = False, - inference: bool = False, ): nframes, nloc = nlist.shape[:2] ( @@ -488,7 +487,6 @@ def forward_common_lower( fparam=fparam, aparam=aparam, do_atomic_virial=do_atomic_virial, - inference=inference, ) model_output_type = self.backbone_model.model_output_type() if "mask" in model_output_type: @@ -613,7 +611,6 @@ def forward_lower( fparam=fparam, aparam=aparam, do_atomic_virial=do_atomic_virial, - inference=True, ) model_predict = {} model_predict["atom_energy"] = model_ret["energy"] diff --git a/deepmd/pt/model/model/transform_output.py b/deepmd/pt/model/model/transform_output.py index 56b0ba7850..42ea926d5c 100644 --- a/deepmd/pt/model/model/transform_output.py +++ b/deepmd/pt/model/model/transform_output.py @@ -73,7 +73,7 @@ def task_deriv_one( extended_coord: torch.Tensor, do_virial: bool = True, do_atomic_virial: bool = False, - inference: bool = False, + create_graph: bool = True, ): faked_grad = torch.ones_like(energy) lst = torch.jit.annotate(List[Optional[torch.Tensor]], [faked_grad]) @@ -81,7 +81,7 @@ def task_deriv_one( [energy], [extended_coord], grad_outputs=lst, - create_graph=not inference, + create_graph=create_graph, retain_graph=True, )[0] assert extended_force is not None @@ -123,7 +123,7 @@ def take_deriv( coord_ext: torch.Tensor, do_virial: bool = False, do_atomic_virial: bool = False, - inference: bool = False, + create_graph: bool = True, ): size = 1 for ii in vdef.shape: @@ -141,7 +141,7 @@ def take_deriv( coord_ext, do_virial=do_virial, do_atomic_virial=do_atomic_virial, - inference=inference, + create_graph=create_graph, ) # nf x nloc x 1 x 3, nf x nloc x 1 x 9 ffi = ffi.unsqueeze(-2) @@ -165,7 +165,7 @@ def fit_output_to_model_output( fit_output_def: FittingOutputDef, coord_ext: torch.Tensor, do_atomic_virial: bool = False, - inference: bool = False, + create_graph: bool = True, ) -> Dict[str, torch.Tensor]: """Transform the output of the fitting network to the model output. @@ -189,7 +189,7 @@ def fit_output_to_model_output( coord_ext, do_virial=vdef.c_differentiable, do_atomic_virial=do_atomic_virial, - inference=inference, + create_graph=create_graph, ) model_ret[kk_derv_r] = dr if vdef.c_differentiable: diff --git a/source/lmp/pair_deepmd.cpp b/source/lmp/pair_deepmd.cpp index d4fbdd3363..8b31455750 100644 --- a/source/lmp/pair_deepmd.cpp +++ b/source/lmp/pair_deepmd.cpp @@ -473,6 +473,7 @@ void PairDeepMD::compute(int eflag, int vflag) { double **x = atom->x; double **f = atom->f; int *type = atom->type; + int *tag_array = atom->tag; int nlocal = atom->nlocal; int nghost = 0; if (do_ghost) { @@ -494,6 +495,21 @@ void PairDeepMD::compute(int eflag, int vflag) { } } } + // make mapping array + int *mapping = new int[nall]; + for (int i = 0; i < nall; ++i) { + mapping[i] = atom->map(tag_array[i]); + } + // write mapping + std::cout << "mapping.size(): " << nall << std::endl; + std::ofstream outfile_mapping("mapping_data.csv"); + if (!outfile_mapping.is_open()) { + std::cerr << "Error opening file for writing" << std::endl; + } + for (size_t i = 0; i < nall; i += 1) { + outfile_mapping << mapping[i] << "\n"; + } + outfile_mapping.close(); vector dtype(nall); for (int ii = 0; ii < nall; ++ii) { @@ -1293,6 +1309,8 @@ void PairDeepMD::coeff(int narg, char **arg) { void PairDeepMD::init_style() { #if LAMMPS_VERSION_NUMBER >= 20220324 neighbor->add_request(this, NeighConst::REQ_FULL); + atom->map_user = 2; + atom->map_init(1); #else int irequest = neighbor->request(this, instance_me); neighbor->requests[irequest]->half = 0; From 0f44a60692656948c2ce1c83ccd890f4e6c1628b Mon Sep 17 00:00:00 2001 From: Duo <50307526+iProzd@users.noreply.github.com> Date: Tue, 23 Jul 2024 17:31:21 +0800 Subject: [PATCH 3/3] Update pair_deepmd.cpp --- source/lmp/pair_deepmd.cpp | 18 ------------------ 1 file changed, 18 deletions(-) diff --git a/source/lmp/pair_deepmd.cpp b/source/lmp/pair_deepmd.cpp index 8b31455750..d4fbdd3363 100644 --- a/source/lmp/pair_deepmd.cpp +++ b/source/lmp/pair_deepmd.cpp @@ -473,7 +473,6 @@ void PairDeepMD::compute(int eflag, int vflag) { double **x = atom->x; double **f = atom->f; int *type = atom->type; - int *tag_array = atom->tag; int nlocal = atom->nlocal; int nghost = 0; if (do_ghost) { @@ -495,21 +494,6 @@ void PairDeepMD::compute(int eflag, int vflag) { } } } - // make mapping array - int *mapping = new int[nall]; - for (int i = 0; i < nall; ++i) { - mapping[i] = atom->map(tag_array[i]); - } - // write mapping - std::cout << "mapping.size(): " << nall << std::endl; - std::ofstream outfile_mapping("mapping_data.csv"); - if (!outfile_mapping.is_open()) { - std::cerr << "Error opening file for writing" << std::endl; - } - for (size_t i = 0; i < nall; i += 1) { - outfile_mapping << mapping[i] << "\n"; - } - outfile_mapping.close(); vector dtype(nall); for (int ii = 0; ii < nall; ++ii) { @@ -1309,8 +1293,6 @@ void PairDeepMD::coeff(int narg, char **arg) { void PairDeepMD::init_style() { #if LAMMPS_VERSION_NUMBER >= 20220324 neighbor->add_request(this, NeighConst::REQ_FULL); - atom->map_user = 2; - atom->map_init(1); #else int irequest = neighbor->request(this, instance_me); neighbor->requests[irequest]->half = 0;