fix(pt): optimize graph memory usage (#4006)
- Remove atomic virial graph.
- Remove force graph during inference.

After this change, LAMMPS memory usage drops by **50% for dpa1** (attn_layer=0) and
**80% for dpa2** (layer=12).
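The saving comes from how `torch.autograd.grad` builds graphs: `create_graph=True` records the graph of the gradient itself (needed only when training on force/virial labels), while `create_graph=False` lets that graph be skipped at inference. A minimal sketch of the effect on a toy energy (illustrative code, not DeePMD's; `coord`/`energy` here are stand-ins):

```python
import torch

# Toy stand-in for a model's predicted energy; not DeePMD code.
coord = torch.randn(16, 3, requires_grad=True)
energy = (coord**2).sum()

training = False  # inference
force = -torch.autograd.grad(
    [energy],
    [coord],
    grad_outputs=[torch.ones_like(energy)],
    create_graph=training,  # False at inference: no graph-of-the-gradient is built
)[0]
print(force.requires_grad)  # False at inference; True if training were True
```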

<!-- This is an auto-generated comment: release notes by coderabbit.ai -->
## Summary by CodeRabbit

- **New Features**
  - Introduced a new `inference` parameter to key model functions, enhancing flexibility for inference scenarios during model execution.
  - Added functionality to output a mapping array to a CSV file, improving data handling capabilities.

- **Bug Fixes**
  - Improved the behavior of the model during inference versus training, potentially impacting downstream processing based on the output.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->
iProzd authored Jul 24, 2024
1 parent 269ed3e commit 7f9300d
Showing 3 changed files with 29 additions and 5 deletions.
4 changes: 3 additions & 1 deletion deepmd/pt/entrypoints/main.py
@@ -335,7 +335,9 @@ def train(FLAGS):
 
 
 def freeze(FLAGS):
-    model = torch.jit.script(inference.Tester(FLAGS.model, head=FLAGS.head).model)
+    model = inference.Tester(FLAGS.model, head=FLAGS.head).model
+    model.eval()
+    model = torch.jit.script(model)
     extra_files = {}
     torch.jit.save(
         model,
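`model.eval()` is now called before `torch.jit.script`, so the module is scripted (and later saved) in eval mode; together with the `create_graph=self.training` change in `make_model.py` below, the frozen model then skips derivative-graph creation at inference. A toy sketch of the call order, assuming an illustrative `ToyModel` rather than the DeePMD `Tester`:

```python
import torch


class ToyModel(torch.nn.Module):
    # Illustrative module; TorchScript can branch on self.training.
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.training:
            return x * 2
        return x


model = ToyModel()
model.eval()  # switch to eval mode first, mirroring freeze() above
scripted = torch.jit.script(model)
torch.jit.save(scripted, "toy_frozen.pth", _extra_files={})
```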
1 change: 1 addition & 0 deletions deepmd/pt/model/model/make_model.py
@@ -267,6 +267,7 @@ def forward_common_lower(
             self.atomic_output_def(),
             cc_ext,
             do_atomic_virial=do_atomic_virial,
+            create_graph=self.training,
         )
         model_predict = self.output_type_cast(model_predict, input_prec)
         return model_predict
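`self.training` is the standard `torch.nn.Module` flag toggled by `train()`/`eval()`, so the derivative graph is only kept while the model is in training mode. A rough, self-contained sketch of the idea (a toy energy model, not the DeePMD `forward_common_lower`):

```python
import torch


class ToyEnergyModel(torch.nn.Module):
    # Illustrative only: mimics wiring create_graph=self.training into the
    # gradient call that produces forces.
    def forward(self, coord: torch.Tensor) -> torch.Tensor:
        coord = coord.requires_grad_(True)
        energy = (coord**2).sum()
        (grad,) = torch.autograd.grad(
            energy,
            coord,
            create_graph=self.training,  # keep the graph only while training
            retain_graph=True,
        )
        return -grad  # force


model = ToyEnergyModel().eval()
force = model(torch.randn(8, 3))
print(force.requires_grad)  # False: no higher-order graph kept at inference
```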
29 changes: 25 additions & 4 deletions deepmd/pt/model/model/transform_output.py
@@ -33,15 +33,27 @@ def atomic_virial_corr(
     faked_grad = torch.ones_like(sumce0)
     lst = torch.jit.annotate(List[Optional[torch.Tensor]], [faked_grad])
     extended_virial_corr0 = torch.autograd.grad(
-        [sumce0], [extended_coord], grad_outputs=lst, create_graph=True
+        [sumce0],
+        [extended_coord],
+        grad_outputs=lst,
+        create_graph=False,
+        retain_graph=True,
     )[0]
     assert extended_virial_corr0 is not None
     extended_virial_corr1 = torch.autograd.grad(
-        [sumce1], [extended_coord], grad_outputs=lst, create_graph=True
+        [sumce1],
+        [extended_coord],
+        grad_outputs=lst,
+        create_graph=False,
+        retain_graph=True,
     )[0]
     assert extended_virial_corr1 is not None
     extended_virial_corr2 = torch.autograd.grad(
-        [sumce2], [extended_coord], grad_outputs=lst, create_graph=True
+        [sumce2],
+        [extended_coord],
+        grad_outputs=lst,
+        create_graph=False,
+        retain_graph=True,
     )[0]
     assert extended_virial_corr2 is not None
     extended_virial_corr = torch.concat(
@@ -61,11 +73,16 @@ def task_deriv_one(
     extended_coord: torch.Tensor,
     do_virial: bool = True,
     do_atomic_virial: bool = False,
+    create_graph: bool = True,
 ):
     faked_grad = torch.ones_like(energy)
     lst = torch.jit.annotate(List[Optional[torch.Tensor]], [faked_grad])
     extended_force = torch.autograd.grad(
-        [energy], [extended_coord], grad_outputs=lst, create_graph=True
+        [energy],
+        [extended_coord],
+        grad_outputs=lst,
+        create_graph=create_graph,
+        retain_graph=True,
     )[0]
     assert extended_force is not None
     extended_force = -extended_force
@@ -106,6 +123,7 @@ def take_deriv(
     coord_ext: torch.Tensor,
     do_virial: bool = False,
     do_atomic_virial: bool = False,
+    create_graph: bool = True,
 ):
     size = 1
     for ii in vdef.shape:
@@ -123,6 +141,7 @@
             coord_ext,
             do_virial=do_virial,
             do_atomic_virial=do_atomic_virial,
+            create_graph=create_graph,
         )
         # nf x nloc x 1 x 3, nf x nloc x 1 x 9
         ffi = ffi.unsqueeze(-2)
@@ -146,6 +165,7 @@ def fit_output_to_model_output(
     fit_output_def: FittingOutputDef,
     coord_ext: torch.Tensor,
     do_atomic_virial: bool = False,
+    create_graph: bool = True,
 ) -> Dict[str, torch.Tensor]:
     """Transform the output of the fitting network to
     the model output.
@@ -169,6 +189,7 @@
                     coord_ext,
                     do_virial=vdef.c_differentiable,
                     do_atomic_virial=do_atomic_virial,
+                    create_graph=create_graph,
                 )
                 model_ret[kk_derv_r] = dr
                 if vdef.c_differentiable:
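Two flags recur in these hunks: `retain_graph=True` keeps the forward graph alive because several `torch.autograd.grad` calls read it in sequence (the force, then the virial-correction terms), while `create_graph=False` (or a `create_graph` flag that is off at inference) skips recording the graph of the gradient itself, which only higher-order derivatives during training would need. A small sketch of that combination on a toy energy whose two reductions share intermediates (illustrative names, not DeePMD code):

```python
import torch

coord = torch.randn(8, 3, requires_grad=True)
atom_energy = coord.pow(2).sum(dim=-1)              # shared intermediate
energy = atom_energy.sum()                          # total energy
sum_ce = (atom_energy * coord.norm(dim=-1)).sum()   # stand-in for a second reduction

# First gradient: retain_graph=True keeps the shared graph alive for the next call;
# create_graph=False avoids building the graph of this gradient.
force = -torch.autograd.grad(
    [energy],
    [coord],
    grad_outputs=[torch.ones_like(energy)],
    create_graph=False,
    retain_graph=True,
)[0]

# Second gradient reuses parts of the same graph; it would fail here if the
# first call had not set retain_graph=True.
virial_like = torch.autograd.grad(
    [sum_ce],
    [coord],
    grad_outputs=[torch.ones_like(sum_ce)],
    create_graph=False,
    retain_graph=True,
)[0]

print(force.shape, virial_like.shape)  # torch.Size([8, 3]) torch.Size([8, 3])
```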
