From 1b54b4acbe7376d37fcc392e8762b4ccaec021b6 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Sat, 15 Apr 2023 20:31:29 -0400 Subject: [PATCH 1/2] fix batch size for simplify #803 changed the behavior of sys_idx in the fp step and caused there to be lots of systems. However, it failed to try to get the batch size of these systems. --- dpgen/generator/run.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/dpgen/generator/run.py b/dpgen/generator/run.py index 21d516508..0001df56d 100644 --- a/dpgen/generator/run.py +++ b/dpgen/generator/run.py @@ -393,8 +393,11 @@ def make_train(iter_index, jdata, mdata): init_data_sys.append( os.path.normpath(os.path.join("..", "data.iters", sys_single)) ) + batch_size = sys_batch_size[sys_idx] if sys_idx < len( + sys_batch_size + ) else "auto" init_batch_size.append( - detect_batch_size(sys_batch_size[sys_idx], sys_single) + detect_batch_size(batch_size, sys_single) ) # establish tasks jinput = jdata["default_training_param"] From d585995a101560763eb4c579f977319a738f50a4 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 16 Apr 2023 00:32:25 +0000 Subject: [PATCH 2/2] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- dpgen/generator/run.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/dpgen/generator/run.py b/dpgen/generator/run.py index 0001df56d..293424a4d 100644 --- a/dpgen/generator/run.py +++ b/dpgen/generator/run.py @@ -393,12 +393,12 @@ def make_train(iter_index, jdata, mdata): init_data_sys.append( os.path.normpath(os.path.join("..", "data.iters", sys_single)) ) - batch_size = sys_batch_size[sys_idx] if sys_idx < len( - sys_batch_size - ) else "auto" - init_batch_size.append( - detect_batch_size(batch_size, sys_single) + batch_size = ( + sys_batch_size[sys_idx] + if sys_idx < len(sys_batch_size) + else "auto" ) + init_batch_size.append(detect_batch_size(batch_size, sys_single)) # establish tasks jinput = jdata["default_training_param"] try: