ref: checkpoint connector methods 4/n #3474

Merged · 9 commits · Sep 12, 2020
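This PR continues the connector refactor: checkpoint methods that used to be called on `Trainer` are now called on the `CheckpointConnector`. Inside the connector, `self.trainer.restore_training_state(...)` and `self.trainer.hpc_load(...)` become `self.restore_training_state(...)` and `self.hpc_load(...)`; external callers (tuner and tests) go through `trainer.checkpoint_connector`. A minimal sketch of the call-site change, assuming a `trainer` set up as in the tests below (the paths here are hypothetical, for illustration only):

```python
save_dir = "lightning_logs/hpc_ckpts"      # hypothetical directory for illustration
ckpt_path = "lightning_logs/epoch=0.ckpt"  # hypothetical checkpoint for illustration

# before this PR the helpers were called directly on the Trainer:
#   trainer.hpc_save(save_dir, trainer.logger)
#   trainer.hpc_load(save_dir, on_gpu=trainer.on_gpu)
#   trainer.restore(ckpt_path, on_gpu=trainer.on_gpu)

# after this PR the same helpers are reached through the checkpoint connector:
trainer.checkpoint_connector.hpc_save(save_dir, trainer.logger)
trainer.checkpoint_connector.hpc_load(save_dir, on_gpu=trainer.on_gpu)
trainer.checkpoint_connector.restore(ckpt_path, on_gpu=trainer.on_gpu)
```

The diffs below update the connector internals, the tuner, and the tests accordingly.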
1 change: 1 addition & 0 deletions .pyrightconfig.json
@@ -32,6 +32,7 @@
     "pytorch_lightning/trainer/distrib_data_parallel.py",
     "pytorch_lightning/trainer/lr_scheduler_connector.py",
     "pytorch_lightning/trainer/training_loop_temp.py",
+    "pytorch_lightning/trainer/connectors/checkpoint_connector.py",
     "pytorch_lightning/tuner"
   ],

pytorch_lightning/trainer/connectors/checkpoint_connector.py
@@ -116,7 +116,7 @@ def restore(self, checkpoint_path: str, on_gpu: bool):
            amp.load_state_dict(checkpoint['amp_scaling_state'])

        # load training state (affects trainer only)
-       self.trainer.restore_training_state(checkpoint)
+       self.restore_training_state(checkpoint)

    def restore_training_state(self, checkpoint):
        """
@@ -187,7 +187,7 @@ def restore_hpc_weights_if_needed(self, model: LightningModule):

        # if hpc weights exist restore model
        if len(hpc_weight_paths) > 0:
-           self.trainer.hpc_load(folderpath, self.trainer.on_gpu)
+           self.hpc_load(folderpath, self.trainer.on_gpu)
            did_restore = True
        return did_restore

@@ -321,7 +321,7 @@ def dump_checkpoint(self, weights_only: bool = False) -> dict:
        if self.trainer.amp_backend == AMPType.NATIVE and not self.trainer.use_tpu and self.trainer.scaler is not None:
            checkpoint['native_amp_scaling_state'] = self.trainer.scaler.state_dict()
        elif self.trainer.amp_backend == AMPType.APEX:
-           checkpoint['amp_scaling_state'] = self.trainer.state_dict()
+           checkpoint['amp_scaling_state'] = amp.state_dict()

        # add the module_arguments and state_dict from the model
        model = self.trainer.get_model()
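The second change in this hunk looks like a fix as well as a move: with the Apex backend, the scaling state is now taken from `amp.state_dict()`, which makes the dump side symmetric with the restore path above (`amp.load_state_dict(checkpoint['amp_scaling_state'])`). A minimal sketch of that save/restore round trip, assuming NVIDIA Apex is installed and `amp.initialize(...)` has already been run (illustration only, not the connector code):

```python
from apex import amp  # assumes Apex is installed and amp.initialize(...) was called

# dump side: store the loss-scaler state alongside the model weights
checkpoint = {'amp_scaling_state': amp.state_dict()}

# restore side: feed the same blob back into apex amp
amp.load_state_dict(checkpoint['amp_scaling_state'])
```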
4 changes: 0 additions & 4 deletions pytorch_lightning/trainer/training_tricks.py
@@ -42,10 +42,6 @@ class TrainerTrainingTricksMixin(ABC):
    def get_model(self) -> LightningModule:
        """Warning: this is just empty shell for code implemented in other class."""

-   @abstractmethod
-   def fit(self, *args):
-       """Warning: this is just empty shell for code implemented in other class."""
-
    def print_nan_gradients(self) -> None:
        model = self.get_model()
        for param in model.parameters():
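For context, the remaining `get_model` stub shows the pattern these mixins rely on: an abstract "empty shell" lets the mixin call a method that the concrete `Trainer` implements in another class, and the `fit` stub can be dropped presumably because nothing in this mixin calls `self.fit` any more. A minimal, hypothetical sketch of the pattern (simplified names, not the actual Trainer classes):

```python
from abc import ABC, abstractmethod


class TricksMixinSketch(ABC):
    """Hypothetical mixin illustrating the 'empty shell' pattern above."""

    @abstractmethod
    def get_model(self):
        """Warning: this is just an empty shell for code implemented in another class."""

    def count_parameters(self) -> int:
        # implemented in the mixin, but relies on get_model() from the concrete class
        return sum(p.numel() for p in self.get_model().parameters())


class TinyTrainer(TricksMixinSketch):
    def __init__(self, model):
        self._model = model

    def get_model(self):
        return self._model
```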
2 changes: 1 addition & 1 deletion pytorch_lightning/tuner/batch_size_scaling.py
@@ -108,7 +108,7 @@ def scale_batch_size(trainer,
    log.info(f'Finished batch size finder, will continue with full run using batch size {new_size}')

    # Restore initial state of model
-   trainer.restore(str(save_path), on_gpu=trainer.on_gpu)
+   trainer.checkpoint_connector.restore(str(save_path), on_gpu=trainer.on_gpu)
    os.remove(save_path)

    # Finish by resetting variables so trainer is ready to fit model
2 changes: 1 addition & 1 deletion pytorch_lightning/tuner/lr_finder.py
@@ -179,7 +179,7 @@ def lr_find(
    lr_finder._total_batch_idx = trainer.total_batch_idx  # for debug purpose

    # Reset model state
-   trainer.restore(str(save_path), on_gpu=trainer.on_gpu)
+   trainer.checkpoint_connector.restore(str(save_path), on_gpu=trainer.on_gpu)
    os.remove(save_path)

    # Finish by resetting variables so trainer is ready to fit model
4 changes: 2 additions & 2 deletions tests/base/develop_pipelines.py
@@ -71,8 +71,8 @@ def run_model_test(trainer_options, model, on_gpu: bool = True, version=None, wi
    trainer.init_optimizers(pretrained_model)

    # test HPC loading / saving
-   trainer.hpc_save(save_dir, logger)
-   trainer.hpc_load(save_dir, on_gpu=on_gpu)
+   trainer.checkpoint_connector.hpc_save(save_dir, logger)
+   trainer.checkpoint_connector.hpc_load(save_dir, on_gpu=on_gpu)


def run_prediction(dataloader, trained_model, dp=False, min_acc=0.50):
4 changes: 2 additions & 2 deletions tests/models/data/horovod/train_default_model.py
@@ -78,8 +78,8 @@ def run_test_from_config(trainer_options):
    run_prediction(dataloader, pretrained_model)

    # test HPC loading / saving
-   trainer.hpc_save(ckpt_path, trainer.logger)
-   trainer.hpc_load(ckpt_path, on_gpu=args.on_gpu)
+   trainer.checkpoint_connector.hpc_save(ckpt_path, trainer.logger)
+   trainer.checkpoint_connector.hpc_load(ckpt_path, on_gpu=args.on_gpu)

    if args.on_gpu:
        trainer = Trainer(gpus=1, distributed_backend='horovod', max_epochs=1)
2 changes: 1 addition & 1 deletion tests/models/test_cpu.py
@@ -56,7 +56,7 @@ def test_cpu_slurm_save_load(tmpdir):

    # test HPC saving
    # simulate snapshot on slurm
-   saved_filepath = trainer.hpc_save(trainer.weights_save_path, logger)
+   saved_filepath = trainer.checkpoint_connector.hpc_save(trainer.weights_save_path, logger)
    assert os.path.exists(saved_filepath)

    # new logger file to get meta
2 changes: 1 addition & 1 deletion tests/models/test_restore.py
@@ -227,7 +227,7 @@ def test_dp_resume(tmpdir):
    # HPC LOAD/SAVE
    # ---------------------------
    # save
-   trainer.hpc_save(tmpdir, logger)
+   trainer.checkpoint_connector.hpc_save(tmpdir, logger)

    # init new trainer
    new_logger = tutils.get_default_logger(tmpdir, version=logger.version)
2 changes: 1 addition & 1 deletion tests/trainer/test_trainer.py
@@ -457,7 +457,7 @@ def test_model_checkpoint_only_weights(tmpdir):

    # assert restoring train state fails
    with pytest.raises(KeyError, match='checkpoint contains only the model'):
-       trainer.restore_training_state(checkpoint)
+       trainer.checkpoint_connector.restore_training_state(checkpoint)


def test_model_freeze_unfreeze():