This repository has been archived by the owner on Jul 1, 2024. It is now read-only.

Commit

Take a separate argument for loading checkpoints, take paths for pretrained checkpoints (#379)

Summary:
Pull Request resolved: #379

There are now two arguments for checkpoints:
- `checkpoint_folder`: This is where checkpoints are saved to. Checkpoints are saved inside a subdirectory (the flow id) within this folder.
- `checkpoint_load_path`: This is where checkpoints are loaded from. This can be a file or a directory.

Note that this means that training will not automatically resume by picking up a checkpoint from the `checkpoint_folder` if the user doesn't specify a `checkpoint_load_path`, but I think explicit is better than implicit in this scenario.
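To make the split concrete, here is a sketch of how the two flags might be used together. The config path follows the Getting started tutorial; the folder name and the `<flow_id>` subdirectory are illustrative placeholders, not exact paths:

```bash
# Save checkpoints under an explicit folder; a per-flow subdirectory is created inside it
./classy_train.py --config configs/template_config.json --checkpoint_folder ./my_checkpoints

# Resume later by pointing explicitly at a saved checkpoint (a file or a directory);
# nothing is resumed automatically if --checkpoint_load_path is omitted
./classy_train.py --config configs/template_config.json \
    --checkpoint_load_path ./my_checkpoints/<flow_id>
```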

Also renamed `pretrained_checkpoint_folder` to `pretrained_checkpoint_path` and updated its help string, since this argument no longer needs to point to a folder.
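A hedged example of the renamed flag for fine-tuning; the fine-tuning config name is a placeholder, and the checkpoint path may be either a file or a folder:

```bash
# Fine-tune from a pre-trained checkpoint; training state is not resumed from it.
# The config must build a fine-tuning task (see the FineTuningTask assert in classy_train.py).
./classy_train.py --config configs/my_finetuning_config.json \
    --pretrained_checkpoint_path ./my_checkpoints/<flow_id>/checkpoint.torch
```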

Updated the Getting started tutorial to reflect the changes.

Reviewed By: vreis

Differential Revision: D19760255

fbshipit-source-id: 7c66c7b66c0d4a192dcc6b403852668d01f7932b
mannatsingh authored and facebook-github-bot committed Feb 7, 2020
1 parent 9d18f50 commit 8575416
Showing 4 changed files with 25 additions and 19 deletions.
7 changes: 3 additions & 4 deletions classy_train.py
@@ -75,14 +75,13 @@ def main(args, config):

     task = build_task(config)
 
-    # Load checkpoint, if available. This automatically resumes from an
-    # existing checkpoint, in case training is being restarted.
-    checkpoint = load_checkpoint(args.checkpoint_folder)
+    # Load checkpoint, if available.
+    checkpoint = load_checkpoint(args.checkpoint_load_path)
     task.set_checkpoint(checkpoint)
 
     # Load a checkpoint containing a pre-trained model. This is how we
     # implement fine-tuning of existing models.
-    pretrained_checkpoint = load_checkpoint(args.pretrained_checkpoint_folder)
+    pretrained_checkpoint = load_checkpoint(args.pretrained_checkpoint_path)
     if pretrained_checkpoint is not None:
         assert isinstance(
             task, FineTuningTask
22 changes: 15 additions & 7 deletions classy_vision/generic/opts.py
@@ -34,19 +34,27 @@ def add_generic_args(parser):
"--checkpoint_folder",
default="",
type=str,
help="""folder to use for checkpoints:
help="""folder to use for saving checkpoints:
epochal checkpoints are stored as model_<epoch>.torch,
latest epoch checkpoint is at checkpoint.torch""",
)
parser.add_argument(
"--pretrained_checkpoint_folder",
"--checkpoint_load_path",
default="",
type=str,
help="""folder to use for pre-trained checkpoints:
epochal checkpoints are stored as model_<epoch>.torch,
latest epoch checkpoint is at checkpoint.torch,
checkpoint is used for fine-tuning task, and it will
not resume training from the checkpoint""",
help="""path to load a checkpoint from, which can be a file or a directory:
If the path is a directory, the checkpoint file is assumed to be
checkpoint.torch""",
)
parser.add_argument(
"--pretrained_checkpoint_path",
default="",
type=str,
help="""path to load a pre-trained checkpoints from, which can be a file or a
directory:
If the path is a directory, the checkpoint file is assumed to be
checkpoint.torch. This checkpoint is only used for fine-tuning
tasks, and training will not resume from this checkpoint.""",
)
parser.add_argument(
"--checkpoint_period",
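As the new help strings describe, a load path may be either a single file or a directory; a directory is resolved to the `checkpoint.torch` inside it. A sketch of the two accepted forms, reusing the tutorial's placeholder paths:

```bash
# Pass the checkpoint directory; the checkpoint.torch inside it is loaded
./classy_train.py --config configs/template_config.json \
    --checkpoint_load_path ./output_<timestamp>/checkpoints

# Or pass a specific checkpoint file, e.g. an epochal snapshot
./classy_train.py --config configs/template_config.json \
    --checkpoint_load_path ./output_<timestamp>/checkpoints/model_<epoch>.torch
```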
3 changes: 2 additions & 1 deletion classy_vision/templates/synthetic/hydra_configs/args.yaml
@@ -4,7 +4,8 @@ checkpoint_folder: ""
 checkpoint_period: 1
 log_freq: 5
 num_workers: 4
-pretrained_checkpoint_folder: ""
+checkpoint_load_path: ""
+pretrained_checkpoint_path: ""
 profiler: False
 skip_tensorboard: False
 show_progress: False
12 changes: 5 additions & 7 deletions tutorials/getting_started.ipynb
@@ -235,7 +235,7 @@
"source": [
"## 4. Loading checkpoints\n",
"\n",
"Now that we've run `classy_train.py`, let's see how to load the resulting model. At the end of execution, `classy_train.py` will print the checkpoint directory used for that run. Each run will output to a different directory, typically named `output_<timestamp>/checkpoints`."
"Now that we've run `classy_train.py`, let's see how to load the resulting model. At the end of execution, `classy_train.py` will print the checkpoint directory used for that run. Each run will output to a different directory, typically named `output_<timestamp>/checkpoints`. This can be configured by passing the `--checkpoint_folder` argument to `classy_train.py`"
]
},
{
@@ -269,18 +269,16 @@
"\n",
"## 5. Resuming from checkpoints\n",
"\n",
"Resuming from a checkpoint is as simple as training: `classy_train.py` takes a `--checkpoint_folder` argument, which specifies the checkpoint to resume from:"
"Resuming from a checkpoint is as simple as training: `classy_train.py` takes a `--checkpoint_load_path` argument, which specifies the checkpoint path to resume from:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"metadata": {},
"outputs": [],
"source": [
"! ./classy_train.py --config configs/template_config.json --checkpoint_folder ./output_<timestamp>/checkpoints"
"! ./classy_train.py --config configs/template_config.json --checkpoint_load_path ./output_<timestamp>/checkpoints"
]
},
{
@@ -554,7 +552,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.5+"
"version": "3.7.3"
}
},
"nbformat": 4,
