Skip to content

Commit

Permalink
feat: allow model arguments to be registered outside (deepmodeling#3995)
Browse files Browse the repository at this point in the history
<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

- **New Features**
- Introduced new plugins for handling model arguments, including
`standard`, `frozen`, `pairtab`, `pairwise_dprc`, and `linear_ener`
models.
  - Added support for hybrid models that require another model as input.

- **Improvements**
- Enhanced argument-checking mechanisms to accommodate complex return
types and nested structures.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Signed-off-by: Jinzhe Zeng <jinzhe.zeng@rutgers.edu>
  • Loading branch information
njzjz authored and Mathieu Taillefumier committed Sep 18, 2024
1 parent 1332d75 commit 5183372
Showing 1 changed file with 30 additions and 15 deletions.
45 changes: 30 additions & 15 deletions deepmd/utils/argcheck.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
Callable,
List,
Optional,
Union,
)

from dargs import (
Expand Down Expand Up @@ -165,7 +166,10 @@ def __init__(self) -> None:

def register(
self, name: str, alias: Optional[List[str]] = None, doc: str = ""
) -> Callable[[], List[Argument]]:
) -> Callable[
[Union[Callable[[], Argument], Callable[[], List[Argument]]]],
Union[Callable[[], Argument], Callable[[], List[Argument]]],
]:
"""Register a descriptor argument plugin.
Parameters
Expand All @@ -177,8 +181,8 @@ def register(
Returns
-------
Callable[[], List[Argument]]
the registered descriptor argument method
Callable[[Union[Callable[[], Argument], Callable[[], List[Argument]]]], Union[Callable[[], Argument], Callable[[], List[Argument]]]]
decorator to return the registered descriptor argument method
Examples
--------
Expand Down Expand Up @@ -209,9 +213,17 @@ def get_all_argument(self, exclude_hybrid: bool = False) -> List[Argument]:
for (name, alias, doc), metd in self.__plugin.plugins.items():
if exclude_hybrid and name == "hybrid":
continue
arguments.append(
Argument(name=name, dtype=dict, sub_fields=metd(), alias=alias, doc=doc)
)
args = metd()
if isinstance(args, Argument):
arguments.append(args)
elif isinstance(args, list):
arguments.append(
Argument(
name=name, dtype=dict, sub_fields=metd(), alias=alias, doc=doc
)
)
else:
raise ValueError(f"Invalid return type {type(args)}")
return arguments


Expand Down Expand Up @@ -1517,6 +1529,11 @@ def model_compression_type_args():
)


model_args_plugin = ArgsPlugin()
# for models that require another model as input
hybrid_model_args_plugin = ArgsPlugin()


def model_args(exclude_hybrid=False):
doc_type_map = "A list of strings. Give the name to each type of atoms. It is noted that the number of atom type of training system must be less than 128 in a GPU environment. If not given, type.raw in each system should use the same type indexes, and type_map.raw will take no effect."
doc_data_stat_nbatch = "The model determines the normalization from the statistics of the data. This key specifies the number of `frames` in each `system` used for statistics."
Expand All @@ -1540,12 +1557,7 @@ def model_args(exclude_hybrid=False):

hybrid_models = []
if not exclude_hybrid:
hybrid_models.extend(
[
pairwise_dprc(),
linear_ener_model_args(),
]
)
hybrid_models.extend(hybrid_model_args_plugin.get_all_argument())
return Argument(
"model",
dict,
Expand Down Expand Up @@ -1644,9 +1656,7 @@ def model_args(exclude_hybrid=False):
Variant(
"type",
[
standard_model_args(),
frozen_model_args(),
pairtab_model_args(),
*model_args_plugin.get_all_argument(),
*hybrid_models,
],
optional=True,
Expand All @@ -1656,6 +1666,7 @@ def model_args(exclude_hybrid=False):
)


@model_args_plugin.register("standard")
def standard_model_args() -> Argument:
doc_descrpt = "The descriptor of atomic environment."
doc_fitting = "The fitting of physical properties."
Expand All @@ -1680,6 +1691,7 @@ def standard_model_args() -> Argument:
return ca


@hybrid_model_args_plugin.register("pairwise_dprc")
def pairwise_dprc() -> Argument:
qm_model_args = model_args(exclude_hybrid=True)
qm_model_args.name = "qm_model"
Expand All @@ -1699,6 +1711,7 @@ def pairwise_dprc() -> Argument:
return ca


@model_args_plugin.register("frozen")
def frozen_model_args() -> Argument:
doc_model_file = "Path to the frozen model file."
ca = Argument(
Expand All @@ -1711,6 +1724,7 @@ def frozen_model_args() -> Argument:
return ca


@model_args_plugin.register("pairtab")
def pairtab_model_args() -> Argument:
doc_tab_file = "Path to the tabulation file."
doc_rcut = "The cut-off radius."
Expand All @@ -1731,6 +1745,7 @@ def pairtab_model_args() -> Argument:
return ca


@hybrid_model_args_plugin.register("linear_ener")
def linear_ener_model_args() -> Argument:
doc_weights = (
"If the type is list of float, a list of weights for each model. "
Expand Down

0 comments on commit 5183372

Please sign in to comment.