Adding ability to resume run in Simple API (#1205)
* Adding continue
zphang authored Oct 21, 2020
1 parent f4bca4e commit 0b3dff5
Showing 3 changed files with 120 additions and 151 deletions.
guides/tutorials/quick_start_main.md (60 changes: 51 additions, 9 deletions)
@@ -52,24 +52,66 @@ python jiant/proj/main/scripts/configurator.py \
--task_config_base_path ${EXP_DIR}/tasks/configs \
--task_cache_base_path ${EXP_DIR}/cache/${MODEL_TYPE} \
--epochs 3 \
-    --train_batch_size 4 \
+    --train_batch_size 16 \
--eval_batch_multiplier 2 \
--do_train --do_val
```

5. Finally, we train our model.
```bash
python jiant/proj/main/runscript.py \
-    run_with_continue \
+    run \
--ZZsrc ${EXP_DIR}/models/${MODEL_TYPE}/config.json \
--jiant_task_container_config_path ${EXP_DIR}/runconfigs/${MODEL_TYPE}/${TASK}.json \
--model_load_mode from_transformers \
--learning_rate 1e-5 \
-    --eval_every_steps 1000 \
-    --no_improvements_for_n_evals 30 \
-    --save_checkpoint_every_steps 1000 \
-    --delete_checkpoint_if_done \
-    --do_save --do_train --do_val \
-    --force_overwrite \
+    --do_train --do_val \
+    --do_save --force_overwrite \
--output_dir ${EXP_DIR}/runs/${MODEL_TYPE}/${TASK}
```


## Additional Options

### Checkpointing

To enable checkpointing (so that interrupted runs can be resumed), use the `run_with_continue` mode and set the `--save_checkpoint_every_steps` argument. For example:

```bash
python jiant/proj/main/runscript.py \
run_with_continue \
...
...
--save_checkpoint_every_steps 500 \
--delete_checkpoint_if_done
```

This saves a checkpoint to a `checkpoint.p` file in the run's output directory every 500 training steps. If the process gets killed, you can rerun the exact same command and training will resume from the latest checkpoint.

Note that checkpoints are for resuming training, not for saving snapshots of model weights at different points in training: in addition to model weights, they contain optimizer state and other run metadata. To save regular snapshots of model weights, see [Model Snapshots](#model-snapshots).

We also set the `--delete_checkpoint_if_done` flag so that the checkpoint is deleted once training completes.
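
Putting this together with the quick-start command above, a complete checkpointed run with the full API looks roughly like the following sketch (same paths and hyperparameters as the quick start; only the mode and the two checkpointing flags are new):

```bash
python jiant/proj/main/runscript.py \
    run_with_continue \
    --ZZsrc ${EXP_DIR}/models/${MODEL_TYPE}/config.json \
    --jiant_task_container_config_path ${EXP_DIR}/runconfigs/${MODEL_TYPE}/${TASK}.json \
    --model_load_mode from_transformers \
    --learning_rate 1e-5 \
    --save_checkpoint_every_steps 500 \
    --delete_checkpoint_if_done \
    --do_train --do_val \
    --do_save --force_overwrite \
    --output_dir ${EXP_DIR}/runs/${MODEL_TYPE}/${TASK}
```

If this run is interrupted, rerunning the exact same command picks up from the `checkpoint.p` in the output directory instead of starting over.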

### Model Snapshots

To save snapshots of model weights at regular intervals, use the `--save_every_steps` argument. For example:

```
--save_every_steps 500
```

will save a pickle of model weights every 500 training steps.

### Early Stopping

For early stopping, we evaluate on the validation set at regular intervals over the course of training and select the model weights with the best validation performance. For speed, we often evaluate on only a subset of the validation set rather than the full set. For example, use the following arguments:

```
--eval_every_steps 1000
--no_improvements_for_n_evals 30
--eval_subset_num 500
```

* `--eval_every_steps 1000` indicates that we will evaluate the model on a validation subset every 1000 training steps.
* `--no_improvements_for_n_evals 30` indicates that if validation performance does not improve for 30 consecutive validation evaluations, we will end the training phase.
* `--eval_subset_num 500` indicates that we will evaluate on the first 500 validation examples for early stopping. This value is `500` by default.
guides/tutorials/quick_start_simple.md (53 changes: 50 additions, 3 deletions)
@@ -7,32 +7,79 @@ We will assume that `jiant` and its dependencies have already been installed.
## Workflow

First, let us assume that we will be working with the following working directory:
-```
+```bash
EXP_DIR=/path/to/exp
```

For this training example, we'll use the RTE task from GLUE and the RoBERTa-base model.

1. We'll get the data using `jiant`'s download script:
-```
+```bash
python jiant/scripts/download_data/runscript.py \
download \
--tasks rte \
--output_path ${EXP_DIR}/tasks
```

2. Now that the data is ready, we can use `jiant`'s "Simple" CLI to perform training with a single command:
-```
+```bash
python jiant/proj/simple/runscript.py \
run \
--run_name simple \
--exp_dir ${EXP_DIR} \
--data_dir ${EXP_DIR}/tasks \
--model_type roberta-base \
--tasks rte \
--learning_rate 1e-5 \
--train_batch_size 16 \
--num_train_epochs 3 \
--do_save
```

The "Simple" CLI subsumes several steps under the hood, including downloading the `roberta-base` model, tokenizing and caching the data, writing a [run-configuration](../general/in_depth_into.md#write-run-config), and performing the training.


## Additional Options

### Checkpointing

To enable checkpointing (so that interrupted runs can be resumed), use the `run_with_continue` mode and set the `--save_checkpoint_every_steps` argument. For example:

```bash
python jiant/proj/simple/runscript.py \
run_with_continue \
...
...
--save_checkpoint_every_steps 500 \
--delete_checkpoint_if_done
```

This saves a checkpoint to a `checkpoint.p` file in the run's output directory every 500 training steps. If the process gets killed, you can rerun the exact same command and training will resume from the latest checkpoint.

Note that checkpoints are for resuming training, not for saving snapshots of model weights at different points in training: in addition to model weights, they contain optimizer state and other run metadata. To save regular snapshots of model weights, see [Model Snapshots](#model-snapshots).

We also set the `--delete_checkpoint_if_done` flag so that the checkpoint is deleted once training completes.
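
Concretely, for the RTE example above, the checkpointed training command would look roughly like this sketch (identical to the quick-start command except for the mode and the last two flags):

```bash
python jiant/proj/simple/runscript.py \
    run_with_continue \
    --run_name simple \
    --exp_dir ${EXP_DIR} \
    --data_dir ${EXP_DIR}/tasks \
    --model_type roberta-base \
    --tasks rte \
    --learning_rate 1e-5 \
    --train_batch_size 16 \
    --num_train_epochs 3 \
    --do_save \
    --save_checkpoint_every_steps 500 \
    --delete_checkpoint_if_done
```

If the run is interrupted, rerunning this exact command resumes from `${EXP_DIR}/runs/simple/checkpoint.p`.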

### Model Snapshots

To save snapshots of model weights at regular intervals, use the `--save_every_steps` argument. For example:

```
--save_every_steps 500
```

will save a pickle of model weights every 500 training steps.
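
A saved snapshot can later be used to initialize a new Simple API run via `--model_weights_path`. A sketch (the snapshot filename here is a hypothetical placeholder; point it at whichever snapshot file your run wrote to its output directory):

```bash
python jiant/proj/simple/runscript.py \
    run \
    --run_name simple_from_snapshot \
    --exp_dir ${EXP_DIR} \
    --data_dir ${EXP_DIR}/tasks \
    --model_type roberta-base \
    --model_weights_path ${EXP_DIR}/runs/simple/my_snapshot.p \
    --tasks rte \
    --learning_rate 1e-5 \
    --train_batch_size 16 \
    --num_train_epochs 3 \
    --do_save
```

When `--model_weights_path` is supplied, the Simple API loads those weights with the `partial` model load mode instead of starting from fresh `transformers` weights.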

### Early Stopping

For early stopping, we evaluate on the validation set at regular intervals over the course of training and select the model weights with the best validation performance. For speed, we often evaluate on only a subset of the validation set rather than the full set. For example, use the following arguments:

```
--eval_every_steps 1000
--no_improvements_for_n_evals 30
--eval_subset_num 500
```

* `--eval_every_steps 1000` indicates that we will evaluate the model on a validation subset every 1000 training steps.
* `--no_improvements_for_n_evals 30` indicates that if validation performance does not improve for 30 consecutive validation evaluations, we will end the training phase.
* `--eval_subset_num 500` indicates that we will evaluate on the first 500 validation examples for early stopping. This value is `500` by default.
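
With these defaults, training ends once 30 × 1000 = 30,000 consecutive training steps pass without a validation improvement. For the Simple API, the flags are appended to the training command in the same way, e.g. (a sketch, with the unchanged quick-start arguments elided):

```bash
python jiant/proj/simple/runscript.py \
    run \
    --run_name simple \
    --exp_dir ${EXP_DIR} \
    --data_dir ${EXP_DIR}/tasks \
    ... \
    --eval_every_steps 1000 \
    --no_improvements_for_n_evals 30 \
    --eval_subset_num 500
```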
jiant/proj/simple/runscript.py (158 changes: 19 additions, 139 deletions)
Expand Up @@ -36,7 +36,6 @@ class RunConfiguration(zconf.RunConfig):
max_seq_length = zconf.attr(type=int, default=256)
num_train_epochs = zconf.attr(type=float, default=3)
train_examples_cap = zconf.attr(type=int, default=None)
dry_run = zconf.attr(action="store_true")
create_config = zconf.attr(action="store_true")

# === Running Setup === #
@@ -98,7 +97,7 @@ def create_and_write_task_configs(task_name_list, data_dir, task_config_base_pat
return task_config_path_dict


def run_simple(args: RunConfiguration):
def run_simple(args: RunConfiguration, with_continue: bool = False):

model_cache_path = replace_none(
args.model_cache_path, default=os.path.join(args.exp_dir, "models")
@@ -195,8 +194,18 @@ def run_simple(args: RunConfiguration):
model_cache_path, args.model_type, "model", f"{args.model_type}.p"
)
run_output_dir = os.path.join(args.exp_dir, "runs", args.run_name)
runscript.run_loop(
runscript.RunConfiguration(

if (
args.save_checkpoint_every_steps
and os.path.exists(os.path.join(run_output_dir, "checkpoint.p"))
and with_continue
):
print("Resuming")
checkpoint = torch.load(os.path.join(run_output_dir, "checkpoint.p"))
run_args = runscript.RunConfiguration.from_dict(checkpoint["metadata"]["args"])
else:
print("Running from start")
run_args = runscript.RunConfiguration(
# === Required parameters === #
jiant_task_container_config_path=jiant_task_container_config_path,
output_dir=run_output_dir,
@@ -234,148 +243,19 @@ def run_simple(args: RunConfiguration):
server_ip=args.server_ip,
server_port=args.server_port,
)
)
py_io.write_file(args.to_json(), os.path.join(run_output_dir, "simple_run_config.json"))


def dry_run(args: RunConfiguration):

model_cache_path = replace_none(
args.model_cache_path, default=os.path.join(args.exp_dir, "models")
)

print("\n# === Step 1: Write task configs based on templates === #")
full_task_name_list = sorted(list(set(args.train_tasks + args.val_tasks + args.test_tasks)))
for task_name in full_task_name_list:
print(
f"""
python jiant/proj/main/write_configs.py \\
--task_name {task_name} \\
--task_data_dir {os.path.join(args.data_dir, task_name)} \\
--task_config_path {os.path.join(args.data_dir, "configs", f"{task_name}_config.json")}
""".strip()
)

print("\n# === Step 2: Download models === #")
print(
f"""
python jiant/proj/main/export_model.py \\
--model_type {args.model_type} \\
--output_base_path {os.path.join(model_cache_path, args.model_type)}
""".strip()
)

print("\n# === Step 3: Tokenize and cache === #")
phase_task_dict = {
"train": args.train_tasks,
"val": args.val_tasks,
"test": args.test_tasks,
}
for task_name in full_task_name_list:
phases_to_do = []
for phase, phase_task_list in phase_task_dict.items():
if task_name in phase_task_list:
phases_to_do.append(phase)
print(
f"""
python jiant/proj/main/tokenize_and_cache.py \\
--task_config_path {os.path.join(args.data_dir, "configs", f"{task_name}_config.json")} \\
--model_type {args.model_type} \\
--model_tokenizer_path {os.path.join(model_cache_path, args.model_type, "tokenizer")} \\
--output_dir {os.path.join(args.exp_dir, "cache", task_name)} \\
--phases {",".join(phases_to_do)} \\
--max_seq_length {args.max_seq_length} \\
--smart_truncate \\
--do_iter
""".strip()
)
checkpoint = None

print("\n# === Step 4: Generate jiant_task_container_config === #")
s = f"""
python jiant/proj/main/scripts/configurator.py \\
SimpleAPIMultiTaskConfigurator \\
{os.path.join(args.exp_dir, "run_configs", f"{args.run_name}_config.json")} \\
--task_config_base_path {os.path.join(args.data_dir, "configs")} \\
--task_cache_base_path {os.path.join(args.exp_dir, "cache")} \\
--train_task_name_list {",".join(args.train_tasks)} \\
--val_task_name_list {",".join(args.val_tasks)} \\
--test_task_name_list {",".join(args.test_tasks)} \\
--train_batch_size {args.train_batch_size} \\
--eval_batch_multiplier 2 \\
--epochs {args.num_train_epochs} \\
--num_gpus {torch.cuda.device_count()}
""".strip()
if args.train_examples_cap:
s += f" \\\n --train_examples_cap {args.train_examples_cap}"
print(s.strip())

print("\n# === Step 5: Train/Eval! === #")
if args.model_weights_path:
model_load_mode = "partial"
model_weights_path = args.model_weights_path
else:
# From Transformers
if any(task_name.startswith("mlm_") for task_name in full_task_name_list):
model_load_mode = "from_transformers_with_mlm"
else:
model_load_mode = "from_transformers"
model_weights_path = os.path.join(
model_cache_path, args.model_type, "model", f"{args.model_type}.p"
)
s = f"""
python jiant/proj/main/runscript.py \\
run \\
--jiant_task_container_config_path \
{os.path.join(args.exp_dir, "run_configs", f"{args.run_name}_config.json")} \\
--output_dir {os.path.join(args.exp_dir, "runs", args.run_name)} \\
--model_type {args.model_type} \\
--model_path {model_weights_path} \\
--model_config_path \
{os.path.join(model_cache_path, args.model_type, "model", f"{args.model_type}.json")} \\
--model_tokenizer_path {os.path.join(model_cache_path, args.model_type, "tokenizer")} \\
--model_load_mode {model_load_mode}
""".strip()
if args.train_tasks:
s += " \\\n --do_train"
if args.val_tasks:
s += " \\\n --do_val"
covered_attrs = [
"jiant_task_container_config_path",
"output_dir",
"model_type",
"model_path",
"model_config_path",
"model_tokenizer_path",
"model_load_mode",
]
for attr in runscript.RunConfiguration.__attrs_attrs__:
if attr.name in covered_attrs:
continue
if not hasattr(args, attr.name):
continue
args_attr = getattr(args, attr.name)
if attr.default == args_attr:
continue
if attr.default is None and args_attr is None:
continue
if (
"argparse_kwargs" in attr.metadata
and "action" in attr.metadata["argparse_kwargs"]
and attr.metadata["argparse_kwargs"]["action"] == "store_true"
):
s += f" \\\n --{attr.name}"
else:
s += f" \\\n --{attr.name} {args_attr}"
print(s.strip())
runscript.run_loop(args=run_args, checkpoint=checkpoint)
py_io.write_file(args.to_json(), os.path.join(run_output_dir, "simple_run_config.json"))


def main():
mode, cl_args = zconf.get_mode_and_cl_args()
args = RunConfiguration.default_run_cli(cl_args=cl_args)
if mode == "run":
run_simple(args)
elif mode == "dry_run":
dry_run(args)
run_simple(args, with_continue=False)
if mode == "run_with_continue":
run_simple(args, with_continue=True)
else:
raise zconf.ModeLookupError(mode)

