[2/n] Directly call TrainingTypePlugin APIs instead of going through the Accelerator #9901

Merged (7 commits, Oct 14, 2021)
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -288,6 +288,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Updated error message for interactive incompatible plugins ([#9896](https://github.com/PyTorchLightning/pytorch-lightning/pull/9896))


- Updated several places in the loops and trainer to access `training_type_plugin` directly instead of `accelerator` ([#9901](https://github.com/PyTorchLightning/pytorch-lightning/pull/9901))


### Deprecated

- Deprecated trainer argument `terminate_on_nan` in favour of `detect_anomaly`([#9175](https://github.com/PyTorchLightning/pytorch-lightning/pull/9175))
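To make the changelog entry above concrete: throughout the loops and the trainer, call sites that used to go through the `Accelerator` pass-throughs now call the `TrainingTypePlugin` API directly. A minimal before/after sketch of the pattern, using the `process_dataloader` call from the hunks below (`trainer` stands for an already-constructed `pytorch_lightning.Trainer`):

```python
# Before: the loop reaches the training type plugin indirectly, via the
# Accelerator's thin pass-through wrapper.
dataloader = trainer.accelerator.process_dataloader(trainer.train_dataloader)

# After: the loop calls the TrainingTypePlugin API directly.
dataloader = trainer.training_type_plugin.process_dataloader(trainer.train_dataloader)
```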
237 changes: 222 additions & 15 deletions pytorch_lightning/accelerators/accelerator.py

Large diffs are not rendered by default.

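The `accelerator.py` diff itself is not rendered above, but the deprecation tests added at the bottom of this PR (each pass-through asserted with `pytest.deprecated_call(match="will be removed in v1.6")`) indicate its overall shape: the former pass-through members now emit a deprecation warning and then delegate to the training type plugin. A minimal sketch under that assumption; the class name and message wording are illustrative, and the `rank_zero_deprecation` helper from `pytorch_lightning.utilities` is assumed to be the warning mechanism. This is not the actual diff:

```python
from pytorch_lightning.utilities import rank_zero_deprecation


class AcceleratorPassThroughSketch:
    """Illustrative stand-in for the real Accelerator; only the delegation pattern matters."""

    def __init__(self, training_type_plugin):
        self.training_type_plugin = training_type_plugin

    def post_training_step(self) -> None:
        # Warn (on rank zero only), then forward to the training type plugin.
        rank_zero_deprecation(
            "`Accelerator.post_training_step` is deprecated and will be removed in v1.6."
            " Call `training_type_plugin.post_training_step` directly."
        )
        return self.training_type_plugin.post_training_step()

    @property
    def restore_checkpoint_after_pre_dispatch(self) -> bool:
        # Properties follow the same pattern: warn, then read from the plugin.
        rank_zero_deprecation(
            "`Accelerator.restore_checkpoint_after_pre_dispatch` is deprecated and will be"
            " removed in v1.6. Use the property on the training type plugin instead."
        )
        return self.training_type_plugin.restore_checkpoint_after_pre_dispatch
```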
2 changes: 1 addition & 1 deletion pytorch_lightning/loops/dataloader/evaluation_loop.py
@@ -100,7 +100,7 @@ def advance(self, *args: Any, **kwargs: Any) -> None:
void(*args, **kwargs)

dataloader_idx: int = self.current_dataloader_idx
dataloader = self.trainer.accelerator.process_dataloader(self.current_dataloader)
dataloader = self.trainer.training_type_plugin.process_dataloader(self.current_dataloader)
dataloader = self.trainer.data_connector.get_profiled_dataloader(dataloader, dataloader_idx=dataloader_idx)
dl_max_batches = self._max_batches[dataloader_idx]

2 changes: 1 addition & 1 deletion pytorch_lightning/loops/dataloader/prediction_loop.py
@@ -84,7 +84,7 @@ def on_run_start(self) -> None:
def advance(self, *args: Any, **kwargs: Any) -> None:
"""Predicts one entire dataloader."""
void(*args, **kwargs)
dataloader = self.trainer.accelerator.process_dataloader(self.current_dataloader)
dataloader = self.trainer.training_type_plugin.process_dataloader(self.current_dataloader)
dataloader_iter = enumerate(dataloader)
dl_max_batches = self.max_batches[self.current_dataloader_idx]

4 changes: 2 additions & 2 deletions pytorch_lightning/loops/fit_loop.py
@@ -204,7 +204,7 @@ def on_advance_start(self) -> None:

def advance(self) -> None:
"""Runs one whole epoch."""
dataloader = self.trainer.accelerator.process_dataloader(self.trainer.train_dataloader)
dataloader = self.trainer.training_type_plugin.process_dataloader(self.trainer.train_dataloader)
data_fetcher = self.trainer.data_connector.get_profiled_dataloader(dataloader)

with self.trainer.profiler.profile("run_training_epoch"):
@@ -234,7 +234,7 @@ def on_run_end(self) -> None:
self.trainer.call_hook("on_train_end")

# give accelerators a chance to finish
self.trainer.accelerator.on_train_end()
self.trainer.training_type_plugin.on_train_end()

def teardown(self) -> None:
self.epoch_loop.teardown()
2 changes: 1 addition & 1 deletion pytorch_lightning/loops/optimization/manual_loop.py
@@ -107,7 +107,7 @@ def advance(self, batch: Any, batch_idx: int) -> None:  # type: ignore[override]
lightning_module._current_fx_name = "training_step"
with self.trainer.profiler.profile("training_step"):
training_step_output = self.trainer.accelerator.training_step(step_kwargs)
self.trainer.accelerator.post_training_step()
self.trainer.training_type_plugin.post_training_step()

del step_kwargs

2 changes: 1 addition & 1 deletion pytorch_lightning/loops/optimization/optimizer_loop.py
@@ -450,7 +450,7 @@ def _training_step(self, split_batch: Any, batch_idx: int, opt_idx: int) -> Clos
lightning_module._current_fx_name = "training_step"
with self.trainer.profiler.profile("training_step"):
training_step_output = self.trainer.accelerator.training_step(step_kwargs)
self.trainer.accelerator.post_training_step()
self.trainer.training_type_plugin.post_training_step()

del step_kwargs

pytorch_lightning/trainer/connectors/checkpoint_connector.py
@@ -465,7 +465,7 @@ def save_checkpoint(self, filepath: _PATH, weights_only: bool = False) -> None:
weights_only: saving model weights only
"""
_checkpoint = self.dump_checkpoint(weights_only)
self.trainer.accelerator.save_checkpoint(_checkpoint, filepath)
self.trainer.training_type_plugin.save_checkpoint(_checkpoint, filepath)

def _get_lightning_module_state_dict(self) -> Dict[str, torch.Tensor]:
metrics = (
@@ -478,7 +478,7 @@ def _get_lightning_module_state_dict(self) -> Dict[str, torch.Tensor]:
metric.persistent(True)
metric.sync()

state_dict = self.trainer.accelerator.lightning_module_state_dict()
state_dict = self.trainer.training_type_plugin.lightning_module_state_dict()

for metric in metrics:
# sync can be a no-op (e.g. on cpu) so `unsync` would raise a user error exception if we don't check
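For orientation, the checkpoint path after this change runs from the connector straight to the training type plugin, which (as the `CheckpointIO`-based test further down suggests) typically hands the actual write off to its checkpoint IO object. A rough sketch with illustrative classes and method bodies, not the real implementations:

```python
import torch


class TorchCheckpointIOSketch:
    def save_checkpoint(self, checkpoint: dict, filepath: str) -> None:
        torch.save(checkpoint, filepath)


class TrainingTypePluginSketch:
    def __init__(self, checkpoint_io) -> None:
        self.checkpoint_io = checkpoint_io

    @property
    def should_rank_save_checkpoint(self) -> bool:
        # Single-process sketch: always write. Distributed plugins would
        # normally restrict this so that only one rank touches the file.
        return True

    def save_checkpoint(self, checkpoint: dict, filepath: str) -> None:
        if self.should_rank_save_checkpoint:
            self.checkpoint_io.save_checkpoint(checkpoint, filepath)


# connector.save_checkpoint -> plugin.save_checkpoint -> checkpoint_io.save_checkpoint
plugin = TrainingTypePluginSketch(TorchCheckpointIOSketch())
plugin.save_checkpoint({"state_dict": {}}, "example.ckpt")
```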
34 changes: 19 additions & 15 deletions pytorch_lightning/trainer/trainer.py
@@ -1020,13 +1020,13 @@ def _run(self, model: "pl.LightningModule") -> Optional[Union[_EVALUATE_OUTPUT,
self.callback_connector.attach_model_logging_functions(model)

# attach model to the training type plugin
self.accelerator.connect(model)
self.training_type_plugin.connect(model)

# hook
self.data_connector.prepare_data()
self.callback_connector._attach_model_callbacks()

if self._ckpt_path and not self.accelerator.restore_checkpoint_after_pre_dispatch:
if self._ckpt_path and not self.training_type_plugin.restore_checkpoint_after_pre_dispatch:
self._load_checkpoint_weights()

# ----------------------------
@@ -1037,7 +1037,7 @@ def _run(self, model: "pl.LightningModule") -> Optional[Union[_EVALUATE_OUTPUT,
self._call_setup_hook() # allow user to setup lightning_module in accelerator environment

# check if we should delay restoring checkpoint till later
if not self.accelerator.restore_checkpoint_after_pre_dispatch:
if not self.training_type_plugin.restore_checkpoint_after_pre_dispatch:
self.checkpoint_connector.resume_start()
self._restore_modules_and_callbacks()

@@ -1055,9 +1055,9 @@ def _run(self, model: "pl.LightningModule") -> Optional[Union[_EVALUATE_OUTPUT,
| ||
{self._dispatch} ||
| || LIGHTNING
{self.accelerator.start_training} ||
or {self.accelerator.start_evaluating} ||
or {self.accelerator.start_predicting} || FLOW
{self.training_type_plugin.start_training} ||
or {self.training_type_plugin.start_evaluating} ||
or {self.training_type_plugin.start_predicting} || FLOW
| ||
{self.run_stage} ||
| || DIRECTION
@@ -1087,7 +1087,7 @@ def _run(self, model: "pl.LightningModule") -> Optional[Union[_EVALUATE_OUTPUT,
# plugin will setup fitting (e.g. ddp will launch child processes)
self._pre_dispatch()

if self.accelerator.restore_checkpoint_after_pre_dispatch:
if self.training_type_plugin.restore_checkpoint_after_pre_dispatch:
if self._ckpt_path:
self._load_checkpoint_weights()

@@ -1119,7 +1119,7 @@ def _run(self, model: "pl.LightningModule") -> Optional[Union[_EVALUATE_OUTPUT,
self.state.status = TrainerStatus.FINISHED
self.state.stage = None

return self.accelerator.results
return self.training_type_plugin.results

def _pre_dispatch(self):
self.accelerator.pre_dispatch(self)
@@ -1173,11 +1173,11 @@ def _post_dispatch(self):

def _dispatch(self):
if self.evaluating:
self.accelerator.start_evaluating(self)
self.training_type_plugin.start_evaluating(self)
elif self.predicting:
self.accelerator.start_predicting(self)
self.training_type_plugin.start_predicting(self)
else:
self.accelerator.start_training(self)
self.training_type_plugin.start_training(self)

def run_stage(self):
self.accelerator.dispatch(self)
@@ -1509,22 +1509,26 @@ def precision_plugin(self) -> PrecisionPlugin:

@property
def global_rank(self) -> int:
return self.accelerator.training_type_plugin.global_rank
return self.training_type_plugin.global_rank

@property
def local_rank(self) -> int:
# some training types define a local rank
return getattr(self.accelerator.training_type_plugin, "local_rank", 0)
return getattr(self.training_type_plugin, "local_rank", 0)

@property
def node_rank(self) -> int:
# some training types define a node rank
return getattr(self.accelerator.training_type_plugin, "node_rank", 0)
return getattr(self.training_type_plugin, "node_rank", 0)

@property
def world_size(self) -> int:
# some training types define a world size
return getattr(self.accelerator.training_type_plugin, "world_size", 1)
return getattr(self.training_type_plugin, "world_size", 1)

@property
def should_rank_save_checkpoint(self) -> bool:
return self.training_type_plugin.should_rank_save_checkpoint

@property
def _distrib_type(self) -> DistributedType:
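The rank and world-size properties above use `getattr` with single-process defaults, so plugins that do not define `local_rank`, `node_rank`, or `world_size` still resolve to sensible values. A small self-contained sketch of that fallback behaviour (the plugin classes are hypothetical):

```python
class MinimalPlugin:
    """Defines none of the optional rank attributes."""


class DistributedPlugin:
    local_rank = 3
    node_rank = 1
    world_size = 8


for plugin in (MinimalPlugin(), DistributedPlugin()):
    print(
        getattr(plugin, "local_rank", 0),   # falls back to 0
        getattr(plugin, "node_rank", 0),    # falls back to 0
        getattr(plugin, "world_size", 1),   # falls back to 1
    )
# MinimalPlugin     -> 0 0 1
# DistributedPlugin -> 3 1 8
```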
4 changes: 2 additions & 2 deletions tests/accelerators/test_cpu.py
@@ -43,7 +43,7 @@ def test_restore_checkpoint_after_pre_dispatch_default():
"""Assert default for restore_checkpoint_after_pre_dispatch is False."""
plugin = SingleDevicePlugin(torch.device("cpu"))
accelerator = CPUAccelerator(training_type_plugin=plugin, precision_plugin=PrecisionPlugin())
assert not accelerator.restore_checkpoint_after_pre_dispatch
assert not accelerator.training_type_plugin.restore_checkpoint_after_pre_dispatch
assert not plugin.restore_checkpoint_after_pre_dispatch


@@ -77,7 +77,7 @@ def load_checkpoint(self, checkpoint_path: Union[str, Path]) -> Dict[str, Any]:
plugin = TestPlugin(torch.device("cpu"), checkpoint_io=TorchCheckpointIO())
accelerator = CPUAccelerator(training_type_plugin=plugin, precision_plugin=PrecisionPlugin())

assert accelerator.restore_checkpoint_after_pre_dispatch == restore_after_pre_dispatch
assert accelerator.training_type_plugin.restore_checkpoint_after_pre_dispatch == restore_after_pre_dispatch
assert plugin.restore_checkpoint_after_pre_dispatch == restore_after_pre_dispatch

trainer = Trainer(
61 changes: 60 additions & 1 deletion tests/deprecated_api/test_remove_1-6.py
@@ -330,7 +330,7 @@ def test_v1_6_0_deprecated_device_dtype_mixin_import():
from pytorch_lightning.utilities.device_dtype_mixin import DeviceDtypeModuleMixin # noqa: F401


def test_v1_6_0_deprecated_accelerator_collective():
def test_v1_6_0_deprecated_accelerator_pass_through_functions():
from pytorch_lightning.plugins.precision import PrecisionPlugin
from pytorch_lightning.plugins.training_type import SingleDevicePlugin

@@ -347,3 +347,62 @@ def test_v1_6_0_deprecated_accelerator_collective():
with pytest.deprecated_call(match="will be removed in v1.6"):
tensor = torch.rand(2, 2, requires_grad=True)
accelerator.all_gather(tensor)

with pytest.deprecated_call(match="will be removed in v1.6"):
model = BoringModel()
accelerator.connect(model)

with pytest.deprecated_call(match="will be removed in v1.6"):
accelerator.post_training_step()

with pytest.deprecated_call(match="will be removed in v1.6"):
tensor = torch.rand(2, 2, requires_grad=True)
accelerator.training_step_end(tensor)

with pytest.deprecated_call(match="will be removed in v1.6"):
tensor = torch.rand(2, 2, requires_grad=True)
accelerator.test_step_end(tensor)

with pytest.deprecated_call(match="will be removed in v1.6"):
tensor = torch.rand(2, 2, requires_grad=True)
accelerator.validation_step_end(tensor)

with pytest.deprecated_call(match="will be removed in v1.6"):
accelerator.lightning_module_state_dict()

with pytest.deprecated_call(match="will be removed in v1.6"):
dl = model.train_dataloader()
accelerator.process_dataloader(dl)

with pytest.deprecated_call(match="will be removed in v1.6"):
accelerator.results

with pytest.deprecated_call(match="will be removed in v1.6"):
accelerator.setup_optimizers_in_pre_dispatch

with pytest.deprecated_call(match="will be removed in v1.6"):
accelerator.restore_checkpoint_after_pre_dispatch

with pytest.deprecated_call(match="will be removed in v1.6"):
accelerator.on_validation_start()

with pytest.deprecated_call(match="will be removed in v1.6"):
accelerator.on_test_start()

with pytest.deprecated_call(match="will be removed in v1.6"):
accelerator.on_predict_start()

with pytest.deprecated_call(match="will be removed in v1.6"):
accelerator.on_validation_end()

with pytest.deprecated_call(match="will be removed in v1.6"):
accelerator.on_test_end()

with pytest.deprecated_call(match="will be removed in v1.6"):
accelerator.on_predict_end()

with pytest.deprecated_call(match="will be removed in v1.6"):
accelerator.on_train_end()

with pytest.deprecated_call(match="will be removed in v1.6"):
accelerator.on_train_batch_start(batch=None, batch_idx=0)
@@ -120,7 +120,7 @@ def test_fully_sharded_plugin_checkpoint_multi_gpus(tmpdir):

def _assert_save_equality(trainer, ckpt_path, cls=TestFSDPModel):
# Use FullySharded to get the state dict for the sake of comparison
model_state_dict = trainer.accelerator.lightning_module_state_dict()
model_state_dict = trainer.training_type_plugin.lightning_module_state_dict()

if trainer.is_global_zero:
saved_model = cls.load_from_checkpoint(ckpt_path)
4 changes: 2 additions & 2 deletions tests/plugins/test_ddp_plugin.py
@@ -108,7 +108,7 @@ def test_ddp_configure_ddp():
)
# test wrap the model if fitting
trainer.state.fn = TrainerFn.FITTING
trainer.accelerator.connect(model)
trainer.training_type_plugin.connect(model)
trainer.accelerator.setup_environment()
trainer.accelerator.setup(trainer)
trainer.lightning_module.trainer = trainer
Expand All @@ -122,7 +122,7 @@ def test_ddp_configure_ddp():
plugins=[ddp_plugin],
)
# test do not wrap the model if trainerFN is not fitting
trainer.accelerator.connect(model)
trainer.training_type_plugin.connect(model)
trainer.accelerator.setup_environment()
trainer.accelerator.setup(trainer)
trainer.lightning_module.trainer = trainer