Upgrade PyTorch Lightning to 1.0.2 #7852

Merged: 13 commits, Oct 28, 2020
Changes shown below are from 6 of the 13 commits.
examples/lightning_base.py (5 changes: 3 additions & 2 deletions)
@@ -337,7 +337,7 @@ def add_generic_args(parser, root_dir) -> None:
 def generic_train(
     model: BaseTransformer,
     args: argparse.Namespace,
-    early_stopping_callback=False,
+    early_stopping_callback=None,
     logger=True,  # can pass WandbLogger() here
     extra_callbacks=[],
     checkpoint_callback=None,
@@ -355,6 +355,8 @@ def generic_train(
         checkpoint_callback = pl.callbacks.ModelCheckpoint(
             filepath=args.output_dir, prefix="checkpoint", monitor="val_loss", mode="min", save_top_k=1
         )
+    if early_stopping_callback:
+        extra_callbacks.append(early_stopping_callback)
     if logging_callback is None:
         logging_callback = LoggingCallback()

@@ -376,7 +378,6 @@ def generic_train(
         callbacks=[logging_callback] + extra_callbacks,
         logger=logger,
         checkpoint_callback=checkpoint_callback,
-        early_stop_callback=early_stopping_callback,
         **train_params,
     )
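For context on these three hunks: PyTorch Lightning 1.0 dropped the Trainer's early_stop_callback argument, so the early-stopping callback is now routed through the callbacks list, and the parameter default changes from False to None accordingly. A minimal sketch of the resulting pattern, assuming PL 1.0.x (the monitor and patience values are illustrative, not taken from this PR):

import pytorch_lightning as pl

# Sketch, assuming PL 1.0.x: early stopping is passed via `callbacks`
# instead of the removed `early_stop_callback` Trainer argument.
early_stopping = pl.callbacks.EarlyStopping(monitor="val_loss", mode="min", patience=3)
trainer = pl.Trainer(max_epochs=3, callbacks=[early_stopping])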

examples/requirements.txt (2 changes: 1 addition & 1 deletion)
@@ -5,7 +5,7 @@ psutil
 sacrebleu
 rouge-score
 tensorflow_datasets
-pytorch-lightning==0.9.0
+pytorch-lightning==1.0.2
 matplotlib
 git-python==1.0.3
 faiss-cpu
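A quick, hedged sanity check that the new pin is what actually gets imported (not part of this PR, just a guard against a stale 0.9.x install):

import pytorch_lightning as pl

# Fail fast if an old 0.9.x install is still shadowing the new pin.
assert pl.__version__.startswith("1.0"), f"expected PL 1.0.x, got {pl.__version__}"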
examples/seq2seq/callbacks.py (1 change: 0 additions & 1 deletion)
@@ -102,7 +102,6 @@ def get_checkpoint_callback(output_dir, metric, save_top_k=1, lower_is_better=False):
         monitor=f"val_{metric}",
         mode="min" if "loss" in metric else "max",
         save_top_k=save_top_k,
-        period=0,  # maybe save a checkpoint every time val is run, not just end of epoch.
     )
     return checkpoint_callback
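Why period=0 is dropped: the 0.9-era trick of period=0 to save a checkpoint on every validation run doesn't carry over cleanly to PL 1.0.x, so the argument is removed and the default cadence (period=1, at most one save per epoch) applies. A hedged sketch of the resulting callback for metric="rouge2" (the filepath value is illustrative; the example passes its own output_dir):

from pytorch_lightning.callbacks import ModelCheckpoint

# Sketch, assuming PL 1.0.x: no period=0, so the default once-per-epoch cadence applies.
checkpoint_callback = ModelCheckpoint(
    filepath="output_dir",  # illustrative; callers pass their own path
    monitor="val_rouge2",   # f"val_{metric}" with metric="rouge2"
    mode="max",             # "min" if "loss" in metric else "max"
    save_top_k=1,
)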

examples/text-classification/run_pl_glue.py (2 changes: 1 addition & 1 deletion)
@@ -192,7 +192,7 @@ def main():
 
     # Optionally, predict on dev set and write to output_dir
     if args.do_predict:
-        checkpoints = list(sorted(glob.glob(os.path.join(args.output_dir, "checkpointepoch=*.ckpt"), recursive=True)))
+        checkpoints = list(sorted(glob.glob(os.path.join(args.output_dir, "checkpoint-epoch=*.ckpt"), recursive=True)))
         model = model.load_from_checkpoint(checkpoints[-1])
         return trainer.test(model)
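Why the glob gains a dash: under PL 1.0.x the checkpoint prefix is joined to the default "epoch=N" filename with "-", so files come out as checkpoint-epoch=0.ckpt instead of 0.9's checkpointepoch=0.ckpt. A small hedged sketch of selecting the newest one (the directory name is assumed; note the lexicographic sort only tracks epoch order while epoch numbers stay single-digit):

import glob
import os

# Sketch: newest checkpoint under PL 1.0.x naming.
checkpoints = sorted(glob.glob(os.path.join("output_dir", "checkpoint-epoch=*.ckpt")))
latest = checkpoints[-1]  # lexicographic; fine for single-digit epoch counts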

examples/token-classification/run_pl_ner.py (6 changes: 3 additions & 3 deletions)
@@ -207,9 +207,9 @@ def add_model_specific_args(parser, root_dir):
 
     if args.do_predict:
         # See https://github.com/huggingface/transformers/issues/3159
-        # pl use this format to create a checkpoint:
+        # pl use this default format to create a checkpoint:
         # https://github.com/PyTorchLightning/pytorch-lightning/blob/master\
-        # /pytorch_lightning/callbacks/model_checkpoint.py#L169
-        checkpoints = list(sorted(glob.glob(os.path.join(args.output_dir, "checkpointepoch=*.ckpt"), recursive=True)))
+        # /pytorch_lightning/callbacks/model_checkpoint.py#L322
+        checkpoints = list(sorted(glob.glob(os.path.join(args.output_dir, "checkpoint-epoch=*.ckpt"), recursive=True)))
         model = model.load_from_checkpoint(checkpoints[-1])
         trainer.test(model)
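A side note that applies to this hunk and the run_pl_glue.py one above: load_from_checkpoint is a classmethod that reconstructs a fresh LightningModule from the saved hyperparameters, so model = model.load_from_checkpoint(...) rebinds the name to a new instance rather than mutating the trained one. Calling it on the class reads more clearly (a sketch continuing the names above; NERTransformer is the module class defined in this example script):

# Sketch: equivalent to the rebinding above, but explicit about the classmethod.
best_model = NERTransformer.load_from_checkpoint(checkpoints[-1])
trainer.test(best_model)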