Commit cecebbd

Deprecate TrainingTypePlugin.on_save and Accelerator.on_save

1 parent e1442d2

File tree: 5 files changed (+2 −12 lines)

CHANGELOG.md

Lines changed: 1 addition & 1 deletion
@@ -111,7 +111,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Deprecated `DataModule` properties: `train_transforms`, `val_transforms`, `test_transforms`, `size`, `dims` ([#8851](https://github.com/PyTorchLightning/pytorch-lightning/pull/8851))
 
 
--
+- Deprecated `TrainingTypePlugin.on_save` and `Accelerator.on_save` ([#8987](https://github.com/PyTorchLightning/pytorch-lightning/pull/8987))
 
 
 -
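
This change removes the `on_save` indirection from the checkpoint path, so the checkpoint dict now reaches the I/O layer unmodified. Code that previously overrode `TrainingTypePlugin.on_save` to adjust the dict can do the same through a custom `CheckpointIO` plugin, which the diffs below show is already the final save target. A minimal sketch, assuming the `CheckpointIO` API of this era; the class name and the dropped key are hypothetical, and wiring it in via `Trainer(plugins=...)` follows the plugin mechanism of contemporary releases:

    from typing import Any, Dict

    from pytorch_lightning import Trainer
    from pytorch_lightning.plugins import TorchCheckpointIO


    class FilteredCheckpointIO(TorchCheckpointIO):
        """Hypothetical stand-in for an old `on_save` override."""

        def save_checkpoint(self, checkpoint: Dict[str, Any], path, *args, **kwargs) -> None:
            # Adjust the checkpoint dict just before it is written,
            # exactly where `on_save` used to run.
            checkpoint.pop("hyper_parameters", None)  # illustrative only
            super().save_checkpoint(checkpoint, path, *args, **kwargs)


    trainer = Trainer(plugins=[FilteredCheckpointIO()])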

pytorch_lightning/accelerators/accelerator.py

Lines changed: 0 additions & 3 deletions
@@ -371,9 +371,6 @@ def lightning_module_state_dict(self) -> Dict[str, Union[Any, Tensor]]:
         """
         return self.training_type_plugin.lightning_module_state_dict()
 
-    def on_save(self, checkpoint: Dict[str, Union[Any, Tensor]]) -> Dict[str, Union[Any, Tensor]]:
-        return self.training_type_plugin.on_save(checkpoint)
-
     def barrier(self, name: Optional[str] = None) -> None:
         self.training_type_plugin.barrier(name=name)

pytorch_lightning/plugins/training_type/ddp_spawn.py

Lines changed: 1 addition & 1 deletion
@@ -301,7 +301,7 @@ def __transfer_distrib_spawn_state_on_fit_end(self, trainer: "pl.Trainer", resul
         last_path = None
         if trainer.state.fn == TrainerFn.FITTING and best_model_path is not None and len(best_model_path) > 0:
             last_path = re.sub(".ckpt", ".tmp_end.ckpt", best_model_path)
-            atomic_save(self.on_save(state_dict), last_path)
+            atomic_save(state_dict, last_path)
 
         # todo, pass complete checkpoint as state dictionary
         self.mp_queue.put(best_model_path)
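
With `on_save` gone, the spawn plugin hands the raw `state_dict` directly to `atomic_save`. Lightning's `atomic_save` utility avoids half-written checkpoint files by serializing fully in memory before touching the target path. A minimal sketch of that idea in plain PyTorch; the function name here is illustrative, not the library's:

    import io

    import torch


    def atomic_save_sketch(obj, filepath: str) -> None:
        # Serialize completely in memory first, so a failure inside
        # torch.save cannot leave a truncated file on disk.
        buffer = io.BytesIO()
        torch.save(obj, buffer)
        with open(filepath, "wb") as f:
            f.write(buffer.getvalue())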

pytorch_lightning/plugins/training_type/training_type_plugin.py

Lines changed: 0 additions & 5 deletions
@@ -201,9 +201,6 @@ def validation_step_end(self, output):
     def test_step_end(self, output):
         return output
 
-    def on_save(self, checkpoint: Dict[str, Union[Any, torch.Tensor]]) -> Dict[str, Union[Any, torch.Tensor]]:
-        return checkpoint
-
     def process_dataloader(self, dataloader: Union[Iterable, DataLoader]) -> Union[Iterable, DataLoader]:
         """Wraps the dataloader if necessary
 
@@ -273,8 +270,6 @@ def save_checkpoint(self, checkpoint: Dict[str, Any], filepath: str) -> None:
             checkpoint: dict containing model and trainer state
             filepath: write-target file's path
         """
-        # dump states as a checkpoint dictionary object
-        checkpoint = self.on_save(checkpoint)
         if self.should_rank_save_checkpoint:
             return self.checkpoint_io.save_checkpoint(checkpoint, filepath)
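
After this hunk, `save_checkpoint` forwards the assembled checkpoint dict straight to `checkpoint_io`. User code that used `on_save` for last-minute edits to the dict can also use the stable `on_save_checkpoint` hook on the LightningModule, which runs while the Trainer assembles the checkpoint. A minimal sketch; the model class and the extra key are hypothetical:

    import pytorch_lightning as pl


    class MyModel(pl.LightningModule):
        def on_save_checkpoint(self, checkpoint):
            # Runs while the Trainer assembles the checkpoint dict,
            # before it reaches the plugin's (now hook-free) save path.
            checkpoint["extra_state"] = {"note": "hypothetical"}  # illustrative only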

pytorch_lightning/trainer/connectors/checkpoint_connector.py

Lines changed: 0 additions & 2 deletions
@@ -294,8 +294,6 @@ def hpc_save(self, folderpath: str, logger):
 
         model.on_hpc_save(checkpoint)
 
-        checkpoint = self.trainer.accelerator.on_save(checkpoint)
-
         # do the actual save
         # TODO: fix for anything with multiprocess DP, DDP, DDP2
         try:
