From ca83b530cfae1f03fac69decc08224389f912378 Mon Sep 17 00:00:00 2001 From: Duo <50307526+iProzd@users.noreply.github.com> Date: Fri, 28 Jun 2024 17:50:37 +0800 Subject: [PATCH 1/3] feat(pt): support multitask argcheck --- deepmd/pt/entrypoints/main.py | 5 +- deepmd/utils/argcheck.py | 82 +++++++++++++++---- .../pytorch_example/input_torch.json | 1 - 3 files changed, 69 insertions(+), 19 deletions(-) diff --git a/deepmd/pt/entrypoints/main.py b/deepmd/pt/entrypoints/main.py index 8e8aab939e..2cd51c83d2 100644 --- a/deepmd/pt/entrypoints/main.py +++ b/deepmd/pt/entrypoints/main.py @@ -241,9 +241,8 @@ def train(FLAGS): ) # argcheck - if not multi_task: - config = update_deepmd_input(config, warning=True, dump="input_v2_compat.json") - config = normalize(config) + config = update_deepmd_input(config, warning=True, dump="input_v2_compat.json") + config = normalize(config, multi_task=multi_task) # do neighbor stat min_nbor_dist = None diff --git a/deepmd/utils/argcheck.py b/deepmd/utils/argcheck.py index d34726e7b1..0d0ca4eabd 100644 --- a/deepmd/utils/argcheck.py +++ b/deepmd/utils/argcheck.py @@ -2325,7 +2325,9 @@ def mixed_precision_args(): # ! added by Denghui. ) -def training_args(): # ! modified by Ziyao: data configuration isolated. +def training_args( + multi_task=False, +): # ! modified by Ziyao: data configuration isolated. doc_numb_steps = "Number of training batch. Each training uses one batch of data." doc_seed = "The random seed for getting frames from the training data set." doc_disp_file = "The file for printing learning curve." @@ -2364,14 +2366,30 @@ def training_args(): # ! modified by Ziyao: data configuration isolated. ) doc_opt_type = "The type of optimizer to use." doc_kf_blocksize = "The blocksize for the Kalman filter." + doc_model_prob = "The visiting probability of each model for each training step in the multi-task mode." + doc_data_dict = "The multiple definition of the data, used in the multi-task mode." arg_training_data = training_data_args() arg_validation_data = validation_data_args() mixed_precision_data = mixed_precision_args() - args = [ + data_args = [ arg_training_data, arg_validation_data, + Argument( + "stat_file", str, optional=True, doc=doc_only_pt_supported + doc_stat_file + ), + ] + args = ( + data_args + if not multi_task + else [ + Argument("model_prob", dict, optional=True, default={}, doc=doc_model_prob), + Argument("data_dict", dict, data_args, repeat=True, doc=doc_data_dict), + ] + ) + + args += [ mixed_precision_data, Argument( "numb_steps", int, optional=False, doc=doc_numb_steps, alias=["stop_batch"] @@ -2438,9 +2456,6 @@ def training_args(): # ! modified by Ziyao: data configuration isolated. optional=True, doc=doc_only_pt_supported + doc_gradient_max_norm, ), - Argument( - "stat_file", str, optional=True, doc=doc_only_pt_supported + doc_stat_file - ), ] variants = [ Variant( @@ -2472,6 +2487,34 @@ def training_args(): # ! modified by Ziyao: data configuration isolated. return Argument("training", dict, args, variants, doc=doc_training) +def multi_model_args(): + model_dict = model_args() + model_dict.name = "model_dict" + model_dict.repeat = True + model_dict.doc = ( + "The multiple definition of the model, used in the multi-task mode." + ) + doc_shared_dict = "The definition of the shared parameters used in the `model_dict` within multi-task mode." + return Argument( + "model", + dict, + [ + model_dict, + Argument( + "shared_dict", dict, optional=True, default={}, doc=doc_shared_dict + ), + ], + ) + + +def multi_loss_args(): + loss_dict = loss_args() + loss_dict.name = "loss_dict" + loss_dict.repeat = True + loss_dict.doc = "The multiple definition of the loss, used in the multi-task mode." + return loss_dict + + def make_index(keys): ret = [] for ii in keys: @@ -2502,14 +2545,23 @@ def gen_json(**kwargs): ) -def gen_args(**kwargs) -> List[Argument]: - return [ - model_args(), - learning_rate_args(), - loss_args(), - training_args(), - nvnmd_args(), - ] +def gen_args(multi_task=False) -> List[Argument]: + if not multi_task: + return [ + model_args(), + learning_rate_args(), + loss_args(), + training_args(multi_task=multi_task), + nvnmd_args(), + ] + else: + return [ + multi_model_args(), + learning_rate_args(), + multi_loss_args(), + training_args(multi_task=multi_task), + nvnmd_args(), + ] def gen_json_schema() -> str: @@ -2524,8 +2576,8 @@ def gen_json_schema() -> str: return json.dumps(generate_json_schema(arg)) -def normalize(data): - base = Argument("base", dict, gen_args()) +def normalize(data, multi_task=False): + base = Argument("base", dict, gen_args(multi_task=multi_task)) data = base.normalize_value(data, trim_pattern="_*") base.check_value(data, strict=True) diff --git a/examples/water_multi_task/pytorch_example/input_torch.json b/examples/water_multi_task/pytorch_example/input_torch.json index 801848f077..04d848538d 100644 --- a/examples/water_multi_task/pytorch_example/input_torch.json +++ b/examples/water_multi_task/pytorch_example/input_torch.json @@ -67,7 +67,6 @@ "_comment": "that's all" }, "loss_dict": { - "_comment": " that's all", "water_1": { "type": "ener", "start_pref_e": 0.02, From ad16a70ddb3f39e819c73a98a8d5b81386ab9a8d Mon Sep 17 00:00:00 2001 From: Duo <50307526+iProzd@users.noreply.github.com> Date: Fri, 28 Jun 2024 23:05:00 +0800 Subject: [PATCH 2/3] Update test_examples.py --- source/tests/common/test_examples.py | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/source/tests/common/test_examples.py b/source/tests/common/test_examples.py index f7f1593f6f..6498d7beb2 100644 --- a/source/tests/common/test_examples.py +++ b/source/tests/common/test_examples.py @@ -15,6 +15,10 @@ normalize, ) +from ..pt.test_multitask import ( + preprocess_shared_params, +) + p_examples = Path(__file__).parent.parent.parent.parent / "examples" input_files = ( @@ -51,11 +55,18 @@ p_examples / "water" / "dpa2" / "input_torch.json", ) +input_files_multi = ( + p_examples / "water_multi_task" / "pytorch_example" / "input_torch.json", +) + class TestExamples(unittest.TestCase): def test_arguments(self): - for fn in input_files: + for fn in input_files + input_files_multi: + multi_task = fn in input_files_multi fn = str(fn) with self.subTest(fn=fn): jdata = j_loader(fn) - normalize(jdata) + if multi_task: + jdata["model"], _ = preprocess_shared_params(jdata["model"]) + normalize(jdata, multi_task=multi_task) From 24ab43d26c551ef940126184afb5e6a5230d0001 Mon Sep 17 00:00:00 2001 From: Duo <50307526+iProzd@users.noreply.github.com> Date: Tue, 2 Jul 2024 12:24:02 +0800 Subject: [PATCH 3/3] Update pyproject.toml --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index d9cbeb44e4..ea306f0d5c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -39,7 +39,7 @@ dependencies = [ 'numpy', 'scipy', 'pyyaml', - 'dargs >= 0.4.6', + 'dargs >= 0.4.7', 'typing_extensions; python_version < "3.8"', 'importlib_metadata>=1.4; python_version < "3.8"', 'h5py',