Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(pt): support multitask argcheck #3925

Merged
merged 4 commits into from
Jul 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions deepmd/pt/entrypoints/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,9 +241,8 @@ def train(FLAGS):
)

# argcheck
if not multi_task:
config = update_deepmd_input(config, warning=True, dump="input_v2_compat.json")
config = normalize(config)
config = update_deepmd_input(config, warning=True, dump="input_v2_compat.json")
config = normalize(config, multi_task=multi_task)

# do neighbor stat
min_nbor_dist = None
Expand Down
82 changes: 67 additions & 15 deletions deepmd/utils/argcheck.py
Original file line number Diff line number Diff line change
Expand Up @@ -2325,7 +2325,9 @@
)


def training_args(): # ! modified by Ziyao: data configuration isolated.
def training_args(
multi_task=False,
): # ! modified by Ziyao: data configuration isolated.
doc_numb_steps = "Number of training batch. Each training uses one batch of data."
doc_seed = "The random seed for getting frames from the training data set."
doc_disp_file = "The file for printing learning curve."
Expand Down Expand Up @@ -2364,14 +2366,30 @@
)
doc_opt_type = "The type of optimizer to use."
doc_kf_blocksize = "The blocksize for the Kalman filter."
doc_model_prob = "The visiting probability of each model for each training step in the multi-task mode."
doc_data_dict = "The multiple definition of the data, used in the multi-task mode."

arg_training_data = training_data_args()
arg_validation_data = validation_data_args()
mixed_precision_data = mixed_precision_args()

args = [
data_args = [
arg_training_data,
arg_validation_data,
Argument(
"stat_file", str, optional=True, doc=doc_only_pt_supported + doc_stat_file
),
]
args = (
data_args
if not multi_task
else [
Argument("model_prob", dict, optional=True, default={}, doc=doc_model_prob),
Argument("data_dict", dict, data_args, repeat=True, doc=doc_data_dict),
]
)

args += [
mixed_precision_data,
Argument(
"numb_steps", int, optional=False, doc=doc_numb_steps, alias=["stop_batch"]
Expand Down Expand Up @@ -2438,9 +2456,6 @@
optional=True,
doc=doc_only_pt_supported + doc_gradient_max_norm,
),
Argument(
"stat_file", str, optional=True, doc=doc_only_pt_supported + doc_stat_file
),
]
variants = [
Variant(
Expand Down Expand Up @@ -2472,6 +2487,34 @@
return Argument("training", dict, args, variants, doc=doc_training)


def multi_model_args():
model_dict = model_args()
model_dict.name = "model_dict"
model_dict.repeat = True
model_dict.doc = (

Check warning on line 2494 in deepmd/utils/argcheck.py

View check run for this annotation

Codecov / codecov/patch

deepmd/utils/argcheck.py#L2491-L2494

Added lines #L2491 - L2494 were not covered by tests
"The multiple definition of the model, used in the multi-task mode."
)
doc_shared_dict = "The definition of the shared parameters used in the `model_dict` within multi-task mode."
return Argument(

Check warning on line 2498 in deepmd/utils/argcheck.py

View check run for this annotation

Codecov / codecov/patch

deepmd/utils/argcheck.py#L2497-L2498

Added lines #L2497 - L2498 were not covered by tests
"model",
dict,
[
model_dict,
Argument(
"shared_dict", dict, optional=True, default={}, doc=doc_shared_dict
),
],
)


def multi_loss_args():
    """Build the ``loss_dict`` argument used in the multi-task mode.

    Starts from the argument returned by ``loss_args()``, renames it to
    ``loss_dict``, flags it as repeatable and overrides its doc string.

    Returns
    -------
    The modified argument object obtained from ``loss_args()``.
    """
    multi_loss = loss_args()
    # Re-purpose the single-task loss argument as a repeated, per-task entry.
    multi_loss.name = "loss_dict"
    multi_loss.repeat = True
    multi_loss.doc = "The multiple definition of the loss, used in the multi-task mode."
    return multi_loss

Check warning on line 2515 in deepmd/utils/argcheck.py

View check run for this annotation

Codecov / codecov/patch

deepmd/utils/argcheck.py#L2511-L2515

Added lines #L2511 - L2515 were not covered by tests


def make_index(keys):
ret = []
for ii in keys:
Expand Down Expand Up @@ -2502,14 +2545,23 @@
)


def gen_args(**kwargs) -> List[Argument]:
return [
model_args(),
learning_rate_args(),
loss_args(),
training_args(),
nvnmd_args(),
]
def gen_args(multi_task=False) -> List[Argument]:
if not multi_task:
return [
model_args(),
learning_rate_args(),
loss_args(),
training_args(multi_task=multi_task),
nvnmd_args(),
]
else:
return [

Check warning on line 2558 in deepmd/utils/argcheck.py

View check run for this annotation

Codecov / codecov/patch

deepmd/utils/argcheck.py#L2558

Added line #L2558 was not covered by tests
multi_model_args(),
learning_rate_args(),
multi_loss_args(),
training_args(multi_task=multi_task),
nvnmd_args(),
]


def gen_json_schema() -> str:
Expand All @@ -2524,8 +2576,8 @@
return json.dumps(generate_json_schema(arg))


def normalize(data):
base = Argument("base", dict, gen_args())
def normalize(data, multi_task=False):
base = Argument("base", dict, gen_args(multi_task=multi_task))
data = base.normalize_value(data, trim_pattern="_*")
base.check_value(data, strict=True)

Expand Down
1 change: 0 additions & 1 deletion examples/water_multi_task/pytorch_example/input_torch.json
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,6 @@
"_comment": "that's all"
},
"loss_dict": {
"_comment": " that's all",
"water_1": {
"type": "ener",
"start_pref_e": 0.02,
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ dependencies = [
'numpy',
'scipy',
'pyyaml',
'dargs >= 0.4.6',
'dargs >= 0.4.7',
'typing_extensions; python_version < "3.8"',
'importlib_metadata>=1.4; python_version < "3.8"',
'h5py',
Expand Down
15 changes: 13 additions & 2 deletions source/tests/common/test_examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,10 @@
normalize,
)

from ..pt.test_multitask import (
preprocess_shared_params,
)

p_examples = Path(__file__).parent.parent.parent.parent / "examples"

input_files = (
Expand Down Expand Up @@ -51,11 +55,18 @@
p_examples / "water" / "dpa2" / "input_torch.json",
)

input_files_multi = (
p_examples / "water_multi_task" / "pytorch_example" / "input_torch.json",
)


class TestExamples(unittest.TestCase):
def test_arguments(self):
for fn in input_files:
for fn in input_files + input_files_multi:
multi_task = fn in input_files_multi
fn = str(fn)
with self.subTest(fn=fn):
jdata = j_loader(fn)
normalize(jdata)
if multi_task:
jdata["model"], _ = preprocess_shared_params(jdata["model"])
normalize(jdata, multi_task=multi_task)