From c07a56f67dbe3610b1d440cbd5118b381b5ae47a Mon Sep 17 00:00:00 2001
From: Duo <50307526+iProzd@users.noreply.github.com>
Date: Thu, 25 Jul 2024 16:28:49 +0800
Subject: [PATCH] fix(2024Q1): optimize graph memory (copy #4006)

---
 deepmd/pt/entrypoints/main.py             | 4 +++-
 deepmd/pt/model/model/make_model.py       | 1 +
 deepmd/pt/model/model/transform_output.py | 13 +++++++++----
 3 files changed, 13 insertions(+), 5 deletions(-)

diff --git a/deepmd/pt/entrypoints/main.py b/deepmd/pt/entrypoints/main.py
index 7b1463a3b2..99d6aab97d 100644
--- a/deepmd/pt/entrypoints/main.py
+++ b/deepmd/pt/entrypoints/main.py
@@ -282,7 +282,9 @@ def train(FLAGS):
 
 
 def freeze(FLAGS):
-    model = torch.jit.script(inference.Tester(FLAGS.model, head=FLAGS.head).model)
+    model = inference.Tester(FLAGS.model, head=FLAGS.head).model
+    model.eval()
+    model = torch.jit.script(model)
     torch.jit.save(
         model,
         FLAGS.output,
diff --git a/deepmd/pt/model/model/make_model.py b/deepmd/pt/model/model/make_model.py
index 0e89c05b79..595e5fba74 100644
--- a/deepmd/pt/model/model/make_model.py
+++ b/deepmd/pt/model/model/make_model.py
@@ -265,6 +265,7 @@ def forward_common_lower(
             self.atomic_output_def(),
             cc_ext,
             do_atomic_virial=do_atomic_virial,
+            create_graph=self.training,
         )
         model_predict = self.output_type_cast(model_predict, input_prec)
         return model_predict
diff --git a/deepmd/pt/model/model/transform_output.py b/deepmd/pt/model/model/transform_output.py
index 730e6b29d0..e8a99fb5fc 100644
--- a/deepmd/pt/model/model/transform_output.py
+++ b/deepmd/pt/model/model/transform_output.py
@@ -33,15 +33,15 @@ def atomic_virial_corr(
     faked_grad = torch.ones_like(sumce0)
     lst = torch.jit.annotate(List[Optional[torch.Tensor]], [faked_grad])
     extended_virial_corr0 = torch.autograd.grad(
-        [sumce0], [extended_coord], grad_outputs=lst, create_graph=True
+        [sumce0], [extended_coord], grad_outputs=lst, create_graph=False, retain_graph=True,
     )[0]
     assert extended_virial_corr0 is not None
     extended_virial_corr1 = torch.autograd.grad(
-        [sumce1], [extended_coord], grad_outputs=lst, create_graph=True
+        [sumce1], [extended_coord], grad_outputs=lst, create_graph=False, retain_graph=True,
     )[0]
     assert extended_virial_corr1 is not None
     extended_virial_corr2 = torch.autograd.grad(
-        [sumce2], [extended_coord], grad_outputs=lst, create_graph=True
+        [sumce2], [extended_coord], grad_outputs=lst, create_graph=False, retain_graph=True,
     )[0]
     assert extended_virial_corr2 is not None
     extended_virial_corr = torch.concat(
@@ -61,11 +61,12 @@ def task_deriv_one(
     extended_coord: torch.Tensor,
     do_virial: bool = True,
     do_atomic_virial: bool = False,
+    create_graph: bool = True,
 ):
     faked_grad = torch.ones_like(energy)
     lst = torch.jit.annotate(List[Optional[torch.Tensor]], [faked_grad])
     extended_force = torch.autograd.grad(
-        [energy], [extended_coord], grad_outputs=lst, create_graph=True
+        [energy], [extended_coord], grad_outputs=lst, create_graph=create_graph, retain_graph=True,
     )[0]
     assert extended_force is not None
     extended_force = -extended_force
@@ -106,6 +107,7 @@ def take_deriv(
     coord_ext: torch.Tensor,
     do_virial: bool = False,
     do_atomic_virial: bool = False,
+    create_graph: bool = True,
 ):
     size = 1
     for ii in vdef.shape:
@@ -123,6 +125,7 @@ def take_deriv(
             coord_ext,
             do_virial=do_virial,
             do_atomic_virial=do_atomic_virial,
+            create_graph=create_graph,
         )
         # nf x nloc x 1 x 3, nf x nloc x 1 x 9
         ffi = ffi.unsqueeze(-2)
@@ -146,6 +149,7 @@ def fit_output_to_model_output(
     fit_output_def: FittingOutputDef,
     coord_ext: torch.Tensor,
     do_atomic_virial: bool = False,
+    create_graph: bool = True,
 ) -> Dict[str, torch.Tensor]:
     """Transform the output of the fitting network to the model output.
 
@@ -169,6 +173,7 @@ def fit_output_to_model_output(
             coord_ext,
             do_virial=vdef.c_differentiable,
             do_atomic_virial=do_atomic_virial,
+            create_graph=create_graph,
         )
         model_ret[kk_derv_r] = dr
         if vdef.c_differentiable: