From 7f9300da7d99b69a062c0b7ea28bac2159c9e7e7 Mon Sep 17 00:00:00 2001 From: Duo <50307526+iProzd@users.noreply.github.com> Date: Wed, 24 Jul 2024 15:09:29 +0800 Subject: [PATCH] fix(pt): optimize graph memory usage (#4006) - Remove atomic virial graph. - Remove force graph during inference. After this change, LAMMPS memory usage is reduced by **50% for dpa1** (attn_layer=0) and **80% for dpa2** (layer=12). ## Summary by CodeRabbit - **New Features** - Introduced a new `inference` parameter to key model functions, enhancing flexibility for inference scenarios during model execution. - Added functionality to output a mapping array to a CSV file, improving data handling capabilities. - **Bug Fixes** - Improved the behavior of the model during inference versus training, potentially impacting downstream processing based on the output. --- deepmd/pt/entrypoints/main.py | 4 +++- deepmd/pt/model/model/make_model.py | 1 + deepmd/pt/model/model/transform_output.py | 29 +++++++++++++++++++---- 3 files changed, 29 insertions(+), 5 deletions(-) diff --git a/deepmd/pt/entrypoints/main.py b/deepmd/pt/entrypoints/main.py index b0edf66878..d43e9afdd2 100644 --- a/deepmd/pt/entrypoints/main.py +++ b/deepmd/pt/entrypoints/main.py @@ -335,7 +335,9 @@ def train(FLAGS): def freeze(FLAGS): - model = torch.jit.script(inference.Tester(FLAGS.model, head=FLAGS.head).model) + model = inference.Tester(FLAGS.model, head=FLAGS.head).model + model.eval() + model = torch.jit.script(model) extra_files = {} torch.jit.save( model, diff --git a/deepmd/pt/model/model/make_model.py b/deepmd/pt/model/model/make_model.py index 32432725d3..d7c75a4c6e 100644 --- a/deepmd/pt/model/model/make_model.py +++ b/deepmd/pt/model/model/make_model.py @@ -267,6 +267,7 @@ def forward_common_lower( self.atomic_output_def(), cc_ext, do_atomic_virial=do_atomic_virial, + create_graph=self.training, ) model_predict = self.output_type_cast(model_predict, input_prec) return model_predict diff --git a/deepmd/pt/model/model/transform_output.py 
b/deepmd/pt/model/model/transform_output.py index 9fde6a1589..42ea926d5c 100644 --- a/deepmd/pt/model/model/transform_output.py +++ b/deepmd/pt/model/model/transform_output.py @@ -33,15 +33,27 @@ def atomic_virial_corr( faked_grad = torch.ones_like(sumce0) lst = torch.jit.annotate(List[Optional[torch.Tensor]], [faked_grad]) extended_virial_corr0 = torch.autograd.grad( - [sumce0], [extended_coord], grad_outputs=lst, create_graph=True + [sumce0], + [extended_coord], + grad_outputs=lst, + create_graph=False, + retain_graph=True, )[0] assert extended_virial_corr0 is not None extended_virial_corr1 = torch.autograd.grad( - [sumce1], [extended_coord], grad_outputs=lst, create_graph=True + [sumce1], + [extended_coord], + grad_outputs=lst, + create_graph=False, + retain_graph=True, )[0] assert extended_virial_corr1 is not None extended_virial_corr2 = torch.autograd.grad( - [sumce2], [extended_coord], grad_outputs=lst, create_graph=True + [sumce2], + [extended_coord], + grad_outputs=lst, + create_graph=False, + retain_graph=True, )[0] assert extended_virial_corr2 is not None extended_virial_corr = torch.concat( @@ -61,11 +73,16 @@ def task_deriv_one( extended_coord: torch.Tensor, do_virial: bool = True, do_atomic_virial: bool = False, + create_graph: bool = True, ): faked_grad = torch.ones_like(energy) lst = torch.jit.annotate(List[Optional[torch.Tensor]], [faked_grad]) extended_force = torch.autograd.grad( - [energy], [extended_coord], grad_outputs=lst, create_graph=True + [energy], + [extended_coord], + grad_outputs=lst, + create_graph=create_graph, + retain_graph=True, )[0] assert extended_force is not None extended_force = -extended_force @@ -106,6 +123,7 @@ def take_deriv( coord_ext: torch.Tensor, do_virial: bool = False, do_atomic_virial: bool = False, + create_graph: bool = True, ): size = 1 for ii in vdef.shape: @@ -123,6 +141,7 @@ def take_deriv( coord_ext, do_virial=do_virial, do_atomic_virial=do_atomic_virial, + create_graph=create_graph, ) # nf x nloc x 1 x 3, nf 
x nloc x 1 x 9 ffi = ffi.unsqueeze(-2) @@ -146,6 +165,7 @@ def fit_output_to_model_output( fit_output_def: FittingOutputDef, coord_ext: torch.Tensor, do_atomic_virial: bool = False, + create_graph: bool = True, ) -> Dict[str, torch.Tensor]: """Transform the output of the fitting network to the model output. @@ -169,6 +189,7 @@ def fit_output_to_model_output( coord_ext, do_virial=vdef.c_differentiable, do_atomic_virial=do_atomic_virial, + create_graph=create_graph, ) model_ret[kk_derv_r] = dr if vdef.c_differentiable: