Update TPUSpawnPlugin spawn methods (#10022)
awaelchli authored Oct 19, 2021
1 parent e44921e commit e0c83ee
Showing 2 changed files with 19 additions and 18 deletions.
5 changes: 4 additions & 1 deletion CHANGELOG.md
@@ -207,7 +207,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- LightningLite:
* Added `PrecisionPlugin.forward_context`, making it the default implementation for all `{train,val,test,predict}_step_context()` methods ([#9988](https://github.com/PyTorchLightning/pytorch-lightning/pull/9988))
- * Added `DDPSpawnPlugin.spawn()` for spawning new processes of a given function ([#10018](https://github.com/PyTorchLightning/pytorch-lightning/pull/10018))
+ * Added `DDPSpawnPlugin.spawn()` for spawning new processes of a given function ([#10018](https://github.com/PyTorchLightning/pytorch-lightning/pull/10018), [#10022](https://github.com/PyTorchLightning/pytorch-lightning/pull/10022))
* Added `TrainingTypePlugin.{_setup_model, _setup_optimizer}` methods ([#9994](https://github.com/PyTorchLightning/pytorch-lightning/pull/9994))
* Implemented `DataParallelPlugin._setup_model` ([#10010](https://github.com/PyTorchLightning/pytorch-lightning/pull/10010))
* Implemented `DeepSpeedPlugin._setup_models_and_optimizers` ([#10009](https://github.com/PyTorchLightning/pytorch-lightning/pull/10009))
@@ -508,6 +508,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Remove deprecated `distributed_backend` from `Trainer` ([#10017](https://github.com/PyTorchLightning/pytorch-lightning/pull/10017))


+ - Removed `process_idx` from the `{DDPSpawnPlugin,TPUSpawnPlugin}.new_process` methods ([#10022](https://github.com/PyTorchLightning/pytorch-lightning/pull/10022))


### Fixed


32 changes: 15 additions & 17 deletions pytorch_lightning/plugins/training_type/tpu_spawn.py
@@ -15,7 +15,8 @@
import os
import re
import time
- from typing import Any, Dict, List, Optional, Union
+ from multiprocessing.queues import SimpleQueue
+ from typing import Any, Callable, Dict, List, Optional, Union

import torch
import torch.multiprocessing as mp
@@ -148,17 +149,9 @@ def init_ddp_connection(self, global_rank: int, world_size: int) -> None:
def set_world_ranks(self, process_idx: int = 0) -> None:
pass

- def new_process(self, process_idx: int, trainer, mp_queue) -> None:
+ def new_process(self, trainer: "pl.Trainer", mp_queue: SimpleQueue) -> None:
self.mp_queue = mp_queue

- reset_seed()
-
- self.tpu_local_core_rank = xm.get_local_ordinal()
- self.tpu_global_core_rank = xm.get_ordinal()
-
- # set warning rank
- rank_zero_only.rank = self.global_rank
-
if self.tpu_global_core_rank != 0 and trainer.progress_bar_callback is not None:
trainer.progress_bar_callback.disable()

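For maintainers of downstream code that overrides `new_process`, here is a minimal sketch of what this hunk changes at the override level. The subclass name `MyTPUSpawnPlugin` and its body are hypothetical illustrations; only the old and new signatures come from the diff above.

```python
import pytorch_lightning as pl
from pytorch_lightning.plugins.training_type.tpu_spawn import TPUSpawnPlugin


class MyTPUSpawnPlugin(TPUSpawnPlugin):
    # Before this commit an override had to accept the process index:
    #     def new_process(self, process_idx: int, trainer, mp_queue) -> None: ...
    # After it, `process_idx` is dropped and the seed/rank setup that used to live
    # here runs elsewhere (see `_worker_setup` later in this diff).
    def new_process(self, trainer: "pl.Trainer", mp_queue) -> None:
        # `self.tpu_global_core_rank` is already set by the time `new_process` runs;
        # the kept context line in the hunk above relies on it as well.
        if self.tpu_global_core_rank == 0:
            print("extra per-worker setup on the rank-zero process")  # placeholder
        super().new_process(trainer, mp_queue)  # keep the upstream behaviour
```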
@@ -261,26 +254,31 @@ def _close_logger(self, trainer) -> None:
if trainer.logger is not None:
trainer.logger.finalize("success")

- def get_mp_spawn_kwargs(self, trainer: "pl.Trainer") -> dict:
+ def get_mp_spawn_kwargs(self, trainer: Optional["pl.Trainer"] = None) -> Dict[str, Any]:
return {
- "args": (trainer, self.mp_queue),
"nprocs": len(self.parallel_devices),
"start_method": self.start_method,
}

+ def spawn(self, function: Callable, *args: Any, **kwargs: Any) -> None:
+ xmp.spawn(self._wrapped_function, args=(function, args, kwargs), **self.get_mp_spawn_kwargs())
+
+ def _worker_setup(self, process_idx: int):
+ reset_seed()
+ self.tpu_local_core_rank = xm.get_local_ordinal()
+ self.tpu_global_core_rank = xm.get_ordinal()
+ rank_zero_only.rank = self.global_rank
+
def start_training(self, trainer: "pl.Trainer") -> None:
# TODO: the precision plugin is called in accelerator setup and should be moved
if "XLA_USE_BF16" in os.environ:
del os.environ["XLA_USE_BF16"]
self._close_logger(trainer)
- xmp.spawn(self.new_process, **self.get_mp_spawn_kwargs(trainer))
+ return super().start_training(trainer)

def start_evaluating(self, trainer: "pl.Trainer") -> None:
self._close_logger(trainer)
- xmp.spawn(self.new_process, **self.get_mp_spawn_kwargs(trainer))
-
- def start_predicting(self, trainer: "pl.Trainer") -> None:
- xmp.spawn(self.new_process, **self.get_mp_spawn_kwargs(trainer))
+ return super().start_evaluating(trainer)

def training_step(self, *args, **kwargs):
return self.model(*args, **kwargs)
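To make the new control flow easier to follow, below is a minimal, self-contained sketch of the spawn-wrapper pattern that `spawn()`, `_wrapped_function` (not shown in this diff; presumably inherited from `DDPSpawnPlugin`, whose `spawn()` the CHANGELOG entry references), and `_worker_setup` implement, written against plain `torch.multiprocessing` so it runs without a TPU. Everything here (`ToySpawnPlugin`, `train_fn`, the hard-coded spawn kwargs) is an illustrative assumption, not Lightning's actual implementation.

```python
import torch.multiprocessing as mp


class ToySpawnPlugin:
    """Illustrative stand-in mirroring the spawn/_wrapped_function/_worker_setup split."""

    def __init__(self, nprocs: int = 2) -> None:
        self.nprocs = nprocs
        self.local_rank = 0

    def spawn(self, function, *args, **kwargs) -> None:
        # The plugin owns process creation; callers just hand over a function.
        mp.spawn(self._wrapped_function, args=(function, args, kwargs), nprocs=self.nprocs)

    def _wrapped_function(self, process_idx, function, args, kwargs) -> None:
        # Per-process setup happens here, so `function` never needs to see `process_idx`.
        self._worker_setup(process_idx)
        function(*args, **kwargs)

    def _worker_setup(self, process_idx) -> None:
        # In the real TPU plugin this is where the seed and XLA ranks are initialized.
        self.local_rank = process_idx


def train_fn(message: str) -> None:
    print(message)


if __name__ == "__main__":
    ToySpawnPlugin(nprocs=2).spawn(train_fn, "hello from a spawned worker")
```

The design point of the commit is visible in this sketch: the plugin owns process creation and per-process setup, while the spawned function (here `train_fn`, in Lightning `new_process`) no longer receives a process index.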