diff --git a/pytorch_translate/train.py b/pytorch_translate/train.py index b82b0cf9..01d2ce79 100644 --- a/pytorch_translate/train.py +++ b/pytorch_translate/train.py @@ -228,11 +228,11 @@ def setup_training_state(args, trainer, task): # ignore previous directory args and just use the absolute path as is. checkpoint_path = os.path.join(args.save_dir, args.restore_file) restore_state = True - if os.path.exists(checkpoint_path): + if os.path.isfile(checkpoint_path): print( f"| Using --save-dir={args.save_dir}, --restore-file={args.restore_file}." ) - elif args.pretrained_checkpoint_file and os.path.exists( + elif args.pretrained_checkpoint_file and os.path.isfile( args.pretrained_checkpoint_file ): checkpoint_path = args.pretrained_checkpoint_file