Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

dp&pt: let DPAtomicModel fetch attributes from Fitting #3292

Merged
merged 1 commit into from
Feb 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions deepmd/dpmodel/atomic_model/dp_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,11 +146,11 @@

def get_dim_fparam(self) -> int:
"""Get the number (dimension) of frame parameters of this atomic model."""
return 0
return self.fitting.get_dim_fparam()

Check warning on line 149 in deepmd/dpmodel/atomic_model/dp_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/atomic_model/dp_atomic_model.py#L149

Added line #L149 was not covered by tests

def get_dim_aparam(self) -> int:
"""Get the number (dimension) of atomic parameters of this atomic model."""
return 0
return self.fitting.get_dim_aparam()

Check warning on line 153 in deepmd/dpmodel/atomic_model/dp_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/atomic_model/dp_atomic_model.py#L153

Added line #L153 was not covered by tests

def get_sel_type(self) -> List[int]:
"""Get the selected atom types of this model.
Expand All @@ -159,7 +159,7 @@
to the result of the model.
If returning an empty list, all atom types are selected.
"""
return []
return self.fitting.get_sel_type()

Check warning on line 162 in deepmd/dpmodel/atomic_model/dp_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/atomic_model/dp_atomic_model.py#L162

Added line #L162 was not covered by tests

def is_aparam_nall(self) -> bool:
"""Check whether the shape of atomic parameters is (nframes, nall, ndim).
Expand Down
9 changes: 3 additions & 6 deletions deepmd/pt/model/atomic_model/dp_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,14 +194,12 @@ def compute_or_load_stat(
@torch.jit.export
def get_dim_fparam(self) -> int:
"""Get the number (dimension) of frame parameters of this atomic model."""
# TODO: self.fitting_net.get_dim_fparam()
return 0
return self.fitting_net.get_dim_fparam()

@torch.jit.export
def get_dim_aparam(self) -> int:
"""Get the number (dimension) of atomic parameters of this atomic model."""
# TODO: self.fitting_net.get_dim_aparam()
return 0
return self.fitting_net.get_dim_aparam()

@torch.jit.export
def get_sel_type(self) -> List[int]:
Expand All @@ -211,8 +209,7 @@ def get_sel_type(self) -> List[int]:
to the result of the model.
If returning an empty list, all atom types are selected.
"""
# TODO: self.fitting_net.get_sel_type()
return []
return self.fitting_net.get_sel_type()

@torch.jit.export
def is_aparam_nall(self) -> bool:
Expand Down