do finetune in iter-0 instead of additional finetune step
Signed-off-by: zjgemi <liuxin_zijian@163.com>
zjgemi committed Jun 19, 2024
1 parent 43c3b90 commit 3a45d4c
Showing 5 changed files with 24 additions and 203 deletions.
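
In short, the pretrained-model finetune no longer runs as a separate workflow step ahead of the concurrent-learning loop: when inputs.do_finetune is set, the iter-0 train step itself runs with finetune_mode="finetune", and every later iteration falls back to "no". A minimal sketch of the intended per-iteration behavior, assuming a simplified config dict; the helper below is illustrative and not part of this commit:

def finetune_mode_for_iteration(config: dict, iteration: int) -> str:
    # illustrative (hypothetical) helper: which finetune_mode the train step receives
    do_finetune = config.get("inputs", {}).get("do_finetune", False)
    if do_finetune and iteration == 0:
        return "finetune"  # iter-0 finetunes from the pretrained init models
    return "no"            # later iterations train as usual

config = {"inputs": {"do_finetune": True, "mixed_type": False}}
assert finetune_mode_for_iteration(config, 0) == "finetune"
assert finetune_mode_for_iteration(config, 1) == "no"
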
85 changes: 6 additions & 79 deletions dpgen2/entrypoint/submit.py
@@ -412,53 +412,6 @@ def make_optional_parameter(
return {"data_mixed_type": mixed_type, "finetune_mode": finetune_mode}


def make_finetune_step(
config,
prep_train_config,
run_train_config,
upload_python_packages,
numb_models,
template_script,
train_config,
init_models,
init_data,
iter_data,
valid_data=None,
):
finetune_optional_parameter = {
"mixed_type": config["inputs"]["mixed_type"],
"finetune_mode": "finetune",
}

finetune_op = PrepRunDPTrain(
"finetune",
PrepDPTrain,
RunDPTrain,
prep_config=prep_train_config,
run_config=run_train_config,
upload_python_packages=upload_python_packages,
finetune=True,
valid_data=valid_data,
)
finetune_step = Step(
"finetune-step",
template=finetune_op,
parameters={
"block_id": "finetune",
"numb_models": numb_models,
"template_script": template_script,
"train_config": train_config,
"run_optional_parameter": finetune_optional_parameter,
},
artifacts={
"init_models": init_models,
"init_data": init_data,
"iter_data": iter_data,
},
)
return finetune_step


def get_systems_from_data(data, data_prefix=None):
data = [data] if isinstance(data, str) else data
assert isinstance(data, list)
@@ -610,32 +563,17 @@ def workflow_concurrent_learning(
else:
init_models = None

finetune_step = None
optional_parameter = make_optional_parameter(
config["inputs"]["mixed_type"],
)

if config["inputs"].get("do_finetune", False):
finetune_step = make_finetune_step(
config,
prep_train_config,
run_train_config,
upload_python_packages,
numb_models,
template_script,
train_config,
init_models,
init_data,
iter_data,
valid_data=valid_data,
)

init_models = finetune_step.outputs.artifacts["models"]
template_script = finetune_step.outputs.parameters["template_script"]

if train_config["init_model_policy"] != "yes":
logging.warning("In finetune mode, init_model_policy is forced to be 'yes'")
train_config["init_model_policy"] = "yes"
optional_parameter = make_optional_parameter(
config["inputs"]["mixed_type"],
finetune_mode="train-init",
finetune_mode="finetune",
)

# here the scheduler is passed as input parameter to the concurrent_learning_op
Expand All @@ -658,7 +596,7 @@ def workflow_concurrent_learning(
"iter_data": iter_data,
},
)
return dpgen_step, finetune_step
return dpgen_step

Check failure on line 599 in dpgen2/entrypoint/submit.py (GitHub Actions / pyright):
Expression of type "Step" cannot be assigned to return type "Tuple[Step, Step | None]"; "Step" is incompatible with "Tuple[Step, Step | None]" (reportGeneralTypeIssues)


def get_scheduler_ids(
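
With the standalone step gone, the do_finetune branch above only has to force init_model_policy to "yes" and switch the run-time optional parameter. A hedged sketch of the resulting run_optional_parameter, restating the make_optional_parameter helper kept in submit.py (values are illustrative):

def make_optional_parameter(mixed_type=False, finetune_mode="no"):
    # same shape as the helper retained at the top of this diff
    return {"data_mixed_type": mixed_type, "finetune_mode": finetune_mode}

# iter-0 with do_finetune enabled, as in the new branch above
assert make_optional_parameter(False, finetune_mode="finetune") == {
    "data_mixed_type": False,
    "finetune_mode": "finetune",
}
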
@@ -743,9 +681,7 @@ def submit_concurrent_learning(

global_config_workflow(wf_config)

dpgen_step, finetune_step = workflow_concurrent_learning(
wf_config,
)
dpgen_step = workflow_concurrent_learning(wf_config)

if reuse_step is not None and replace_scheduler:
scheduler_new = copy.deepcopy(
Expand Down Expand Up @@ -781,17 +717,9 @@ def submit_concurrent_learning(
"conf_selector",
selector,
)
# the modify-train-script step will be added as reuse step.
# the following hack is not needed anymore.
# wf_config["inputs"]["do_finetune"] = False
# finetune will not be done again if the old process is reused.

wf = Workflow(name=wf_config["name"], parallelism=wf_config["parallelism"])

if wf_config["inputs"].get("do_finetune", False):
assert finetune_step is not None
wf.add(finetune_step)

wf.add(dpgen_step)

Check failure on line 723 in dpgen2/entrypoint/submit.py (GitHub Actions / pyright):
Argument of type "Tuple[Step, Step | None]" cannot be assigned to parameter "step" of type "Step | List[Step] | Task | List[Task]" in function "add"; "Tuple[Step, Step | None]" is incompatible with "Step", "List[Step]", "Task", and "List[Task]" (reportGeneralTypeIssues)

# for debug purpose, we may not really submit the wf
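
Both pyright failures point at the same leftover: workflow_concurrent_learning still declares the old (dpgen_step, finetune_step) tuple as its return type. A minimal, hypothetical follow-up sketch, not part of this commit, of the annotation change that would clear them:

def workflow_concurrent_learning(wf_config: dict) -> "Step":  # was Tuple[Step, Step | None]
    # hypothetical follow-up: the body already returns the single dpgen Step after this
    # commit, so only the declared return type needs updating for pyright to pass and
    # for wf.add(dpgen_step) to type-check.
    ...
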
@@ -885,7 +813,6 @@ def get_resubmit_keys(
"prep-run-train",
"prep-train",
"run-train",
"modify-train-script",
"prep-caly-input",
"prep-caly-model-devi",
"run-caly-model-devi",
10 changes: 9 additions & 1 deletion dpgen2/flow/dpgen_loop.py
@@ -77,6 +77,13 @@ def make_block_optional_parameter(cl_optional_parameter):
}


def make_next_optional_parameter(optional_parameter):
return {
"data_mixed_type": optional_parameter["data_mixed_type"],
"finetune_mode": "no", # not to do finetune for `next` loop
}


class SchedulerWrapper(OP):
@classmethod
def get_input_sign(cls):
@@ -426,7 +433,8 @@ def _loop(
"exploration_scheduler": scheduler_step.outputs.parameters[
"exploration_scheduler"
],
"optional_parameter": steps.inputs.parameters["optional_parameter"],
"optional_parameter": make_next_optional_parameter(
steps.inputs.parameters["optional_parameter"]),
"expl_task_grp": scheduler_step.outputs.parameters["expl_task_grp"],
}
next_step = Step(
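
The helper added at the top of this file and wired into the next_step parameters above guarantees that only the first pass through the loop can run in finetune mode; every `next` iteration is demoted to a plain training step. A small usage sketch with a plain dict (at workflow-build time the argument is a dflow parameter rather than a dict; the dict here only shows the shape):

def make_next_optional_parameter(optional_parameter):
    return {
        "data_mixed_type": optional_parameter["data_mixed_type"],
        "finetune_mode": "no",  # never finetune again in the `next` loop
    }

iter0 = {"data_mixed_type": False, "finetune_mode": "finetune"}
assert make_next_optional_parameter(iter0) == {
    "data_mixed_type": False,
    "finetune_mode": "no",
}
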
11 changes: 4 additions & 7 deletions dpgen2/op/run_dp_train.py
@@ -71,10 +71,9 @@ def _make_train_command(
return command
# case of init model and finetune
assert checkpoint is None
do_init_model_or_train_init = do_init_model or finetune_mode == "train-init"
case_init_model = do_init_model_or_train_init and (not init_model_with_finetune)
case_init_model = do_init_model and (not init_model_with_finetune)
case_finetune = finetune_mode == "finetune" or (
do_init_model_or_train_init and init_model_with_finetune
do_init_model and init_model_with_finetune
)
if case_init_model:
init_flag = "--init-frz-model" if impl == "tensorflow" else "--init-model"
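
After this hunk, the train command no longer needs the "train-init" alias: do_init_model alone decides between initializing from a previous model and finetuning. A hedged sketch of the case selection as it reads above (flag names are the ones visible in the diff; actual command assembly is elided):

def select_train_case(do_init_model: bool, finetune_mode: str, init_model_with_finetune: bool) -> str:
    # mirrors the branch structure shown above, without the removed "train-init" mode
    case_init_model = do_init_model and not init_model_with_finetune
    case_finetune = finetune_mode == "finetune" or (do_init_model and init_model_with_finetune)
    if case_init_model:
        return "init-model"  # --init-frz-model (tensorflow) or --init-model (other backends)
    if case_finetune:
        return "finetune"
    return "fresh-training"

assert select_train_case(False, "finetune", False) == "finetune"
assert select_train_case(True, "no", False) == "init-model"
assert select_train_case(False, "no", False) == "fresh-training"
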
@@ -128,9 +127,7 @@ def _make_train_command_old(
checkpoint,
train_script_name,
]
elif (
do_init_model or finetune_mode == "train-init"
) and not init_model_with_finetune:
elif do_init_model and not init_model_with_finetune:
if impl == "pytorch":
command = dp_command + [
"train",
Expand All @@ -146,7 +143,7 @@ def _make_train_command_old(
train_script_name,
]
elif finetune_mode == "finetune" or (
(do_init_model or finetune_mode == "train-init") and init_model_with_finetune
do_init_model and init_model_with_finetune
):
command = (
dp_command
110 changes: 3 additions & 107 deletions dpgen2/superop/prep_run_dp_train.py
@@ -54,86 +54,15 @@
from dpgen2.utils.step_config import normalize as normalize_step_dict


class ModifyTrainScript(OP):
r"""Modify the training scripts to prepare them for training
tasks in dpgen step.
Read the training scripts modified by finetune, and replace
the original template scripts to be compatible with pre-trained models.
New templates are returned as `op["template_script"]`.
"""

@classmethod
def get_input_sign(cls):
return OPIOSign(
{
"numb_models": int,
"scripts": Artifact(Path),
}
)

@classmethod
def get_output_sign(cls):
return OPIOSign(
{
"template_script": BigParameter(List[dict]),
}
)

@OP.exec_sign_check
def execute(
self,
ip: OPIO,
) -> OPIO:
r"""Execute the OP.
Parameters
----------
ip : dict
Input dict with components:
- `scripts`: (`Artifact(Path)`) Training scripts from finetune.
- `numb_models`: (`int`) Number of DP models to train.
Returns
-------
op : dict
Output dict with components:
- `template_script`: (`List[dict]`) One template from one finetuning task. The length of the list should be the same as `numb_models`.
"""
scripts = ip["scripts"]
new_template_script = []
numb_models = ip["numb_models"]

for ii in range(numb_models):
subdir = Path(train_task_pattern % ii)
train_script = Path(scripts) / subdir / train_script_name
with open(train_script, "r") as fp:
train_dict = json.load(fp)
new_template_script.append(train_dict)

op = OPIO(
{
"template_script": new_template_script,
}
)
return op


class PrepRunDPTrain(Steps):
def __init__(
self,
name: str,
prep_train_op: Type[OP],
run_train_op: Type[RunDPTrain],
modify_train_script_op: Type[ModifyTrainScript] = ModifyTrainScript,
prep_config: dict = normalize_step_dict({}),
run_config: dict = normalize_step_dict({}),
upload_python_packages: Optional[List[os.PathLike]] = None,
finetune: bool = False,
valid_data: Optional[S3Artifact] = None,
):
self._input_parameters = {
@@ -173,28 +102,22 @@ def __init__(
)

self._keys = ["prep-train", "run-train"]
if finetune:
self._keys.append("modify-train-script")
self.step_keys = {}
ii = "prep-train"
self.step_keys[ii] = "--".join(["%s" % self.inputs.parameters["block_id"], ii])
ii = "run-train"
self.step_keys[ii] = "--".join(
["%s" % self.inputs.parameters["block_id"], ii + "-{{item}}"]
)
ii = "modify-train-script"
self.step_keys[ii] = "--".join(["%s" % self.inputs.parameters["block_id"], ii])

self = _prep_run_dp_train(
self,
self.step_keys,
prep_train_op,
run_train_op,
modify_train_script_op,
prep_config=prep_config,
run_config=run_config,
upload_python_packages=upload_python_packages,
finetune=finetune,
valid_data=valid_data,
)

@@ -224,11 +147,9 @@ def _prep_run_dp_train(
step_keys,
prep_train_op: Type[OP],
run_train_op: Type[RunDPTrain],
modify_train_script_op: Type[OP],
prep_config: dict = normalize_step_dict({}),
run_config: dict = normalize_step_dict({}),
upload_python_packages: Optional[List[os.PathLike]] = None,
finetune: bool = False,
valid_data: Optional[S3Artifact] = None,
):
prep_config = deepcopy(prep_config)
@@ -297,34 +218,9 @@ def _prep_run_dp_train(
)
train_steps.add(run_train)

if finetune:
modify_train_script = Step(
"modify-train-script",
template=PythonOPTemplate(
modify_train_script_op,
python_packages=upload_python_packages,
**prep_template_config,
),
parameters={
"numb_models": train_steps.inputs.parameters["numb_models"],
},
artifacts={
"scripts": run_train.outputs.artifacts["script"],
},
key=step_keys["modify-train-script"],
executor=prep_executor,
**prep_config,
)
train_steps.add(modify_train_script)
train_steps.outputs.parameters[
"template_script"
].value_from_parameter = modify_train_script.outputs.parameters[
"template_script"
]
else:
train_steps.outputs.parameters[
"template_script"
].value_from_parameter = train_steps.inputs.parameters["template_script"]
train_steps.outputs.parameters[
"template_script"
].value_from_parameter = train_steps.inputs.parameters["template_script"]
train_steps.outputs.artifacts["scripts"]._from = run_train.outputs.artifacts[
"script"
]
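
With ModifyTrainScript removed, the superop registers only two step keys and wires template_script straight from input to output regardless of finetune mode. A hedged sketch of the resulting key set, assuming the same block_id/"--" join convention shown in the hunks above (the block_id value is a hypothetical example):

block_id = "iter-000000"
step_keys = {
    "prep-train": f"{block_id}--prep-train",
    "run-train": f"{block_id}--run-train-{{{{item}}}}",
}
# No "modify-train-script" key is registered any more, matching
# self._keys == ["prep-train", "run-train"] in the diff above.
print(step_keys)
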
11 changes: 2 additions & 9 deletions dpgen2/utils/dflow_query.py
@@ -37,8 +37,6 @@ def matched_step_key(
if (
re.match(f"iter-[0-9]*--{jj}-[0-9]*", kk)
or re.match(f"iter-[0-9]*--{jj}", kk)
or re.match(f"finetune--{jj}-[0-9]*", kk)
or re.match(f"finetune--{jj}", kk)
or re.match(f"init--{jj}", kk)
):
ret.append(kk)
@@ -119,16 +117,11 @@ def find_slice_ranges(
status = "not-found"
for idx, ii in enumerate(keys):
if status == "not-found":
if re.match(f"iter-[0-9]*--{sliced_subkey}-[0-9]*", ii) or re.match(
f"finetune--{sliced_subkey}-[0-9]*", ii
):
if re.match(f"iter-[0-9]*--{sliced_subkey}-[0-9]*", ii):
status = "found"
tmp_range.append(idx)
elif status == "found":
if not (
re.match(f"iter-[0-9]*--{sliced_subkey}-[0-9]*", ii)
or re.match(f"finetune--{sliced_subkey}-[0-9]*", ii)
):
if not re.match(f"iter-[0-9]*--{sliced_subkey}-[0-9]*", ii):
status = "not-found"
tmp_range.append(idx)
found_range.append(tmp_range)
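
Since no step key starts with "finetune--" any more, the query helpers only need the iter-/init- patterns kept above. A small self-contained check of the surviving matching rule (the key strings are examples in the repo's naming scheme):

import re

def matches(jj: str, kk: str) -> bool:
    # the patterns remaining in matched_step_key after this commit
    return bool(
        re.match(f"iter-[0-9]*--{jj}-[0-9]*", kk)
        or re.match(f"iter-[0-9]*--{jj}", kk)
        or re.match(f"init--{jj}", kk)
    )

assert matches("run-train", "iter-000000--run-train-0001")
assert matches("prep-train", "iter-000003--prep-train")
assert not matches("run-train", "finetune--run-train-0001")  # no longer matched for reuse
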
