diff --git a/deepmd/utils/argcheck.py b/deepmd/utils/argcheck.py index 2fdda0aadd..8ee0a480a7 100644 --- a/deepmd/utils/argcheck.py +++ b/deepmd/utils/argcheck.py @@ -6,6 +6,7 @@ Callable, List, Optional, + Union, ) from dargs import ( @@ -165,7 +166,10 @@ def __init__(self) -> None: def register( self, name: str, alias: Optional[List[str]] = None, doc: str = "" - ) -> Callable[[], List[Argument]]: + ) -> Callable[ + [Union[Callable[[], Argument], Callable[[], List[Argument]]]], + Union[Callable[[], Argument], Callable[[], List[Argument]]], + ]: """Register a descriptor argument plugin. Parameters @@ -177,8 +181,8 @@ def register( Returns ------- - Callable[[], List[Argument]] - the registered descriptor argument method + Callable[[Union[Callable[[], Argument], Callable[[], List[Argument]]]], Union[Callable[[], Argument], Callable[[], List[Argument]]]] + decorator to return the registered descriptor argument method Examples -------- @@ -209,9 +213,17 @@ def get_all_argument(self, exclude_hybrid: bool = False) -> List[Argument]: for (name, alias, doc), metd in self.__plugin.plugins.items(): if exclude_hybrid and name == "hybrid": continue - arguments.append( - Argument(name=name, dtype=dict, sub_fields=metd(), alias=alias, doc=doc) - ) + args = metd() + if isinstance(args, Argument): + arguments.append(args) + elif isinstance(args, list): + arguments.append( + Argument( + name=name, dtype=dict, sub_fields=metd(), alias=alias, doc=doc + ) + ) + else: + raise ValueError(f"Invalid return type {type(args)}") return arguments @@ -1517,6 +1529,11 @@ def model_compression_type_args(): ) +model_args_plugin = ArgsPlugin() +# for models that require another model as input +hybrid_model_args_plugin = ArgsPlugin() + + def model_args(exclude_hybrid=False): doc_type_map = "A list of strings. Give the name to each type of atoms. It is noted that the number of atom type of training system must be less than 128 in a GPU environment. If not given, type.raw in each system should use the same type indexes, and type_map.raw will take no effect." doc_data_stat_nbatch = "The model determines the normalization from the statistics of the data. This key specifies the number of `frames` in each `system` used for statistics." @@ -1540,12 +1557,7 @@ def model_args(exclude_hybrid=False): hybrid_models = [] if not exclude_hybrid: - hybrid_models.extend( - [ - pairwise_dprc(), - linear_ener_model_args(), - ] - ) + hybrid_models.extend(hybrid_model_args_plugin.get_all_argument()) return Argument( "model", dict, @@ -1644,9 +1656,7 @@ def model_args(exclude_hybrid=False): Variant( "type", [ - standard_model_args(), - frozen_model_args(), - pairtab_model_args(), + *model_args_plugin.get_all_argument(), *hybrid_models, ], optional=True, @@ -1656,6 +1666,7 @@ def model_args(exclude_hybrid=False): ) +@model_args_plugin.register("standard") def standard_model_args() -> Argument: doc_descrpt = "The descriptor of atomic environment." doc_fitting = "The fitting of physical properties." @@ -1680,6 +1691,7 @@ def standard_model_args() -> Argument: return ca +@hybrid_model_args_plugin.register("pairwise_dprc") def pairwise_dprc() -> Argument: qm_model_args = model_args(exclude_hybrid=True) qm_model_args.name = "qm_model" @@ -1699,6 +1711,7 @@ def pairwise_dprc() -> Argument: return ca +@model_args_plugin.register("frozen") def frozen_model_args() -> Argument: doc_model_file = "Path to the frozen model file." ca = Argument( @@ -1711,6 +1724,7 @@ def frozen_model_args() -> Argument: return ca +@model_args_plugin.register("pairtab") def pairtab_model_args() -> Argument: doc_tab_file = "Path to the tabulation file." doc_rcut = "The cut-off radius." @@ -1731,6 +1745,7 @@ def pairtab_model_args() -> Argument: return ca +@hybrid_model_args_plugin.register("linear_ener") def linear_ener_model_args() -> Argument: doc_weights = ( "If the type is list of float, a list of weights for each model. "