Add debug flag to TPU Training Plugins (PT_XLA_DEBUG) #7219

Merged · 6 commits merged on Apr 27, 2021
Changes from 3 commits
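For context: the `debug` flag added by this PR is forwarded to the TPU training type plugins and, when enabled, exports the `PT_XLA_DEBUG` environment variable so that PyTorch/XLA emits its debug reporting during the run. A minimal usage sketch, mirroring the test added at the bottom of this diff (the import path and the trainer arguments are assumptions, not part of the diff):

```python
import pytorch_lightning as pl
from pytorch_lightning.plugins import TPUSpawnPlugin

# Passing debug=True makes the plugin set PT_XLA_DEBUG=1 before dispatch;
# the TPUAccelerator removes the variable again in teardown().
trainer = pl.Trainer(
    tpu_cores=8,
    plugins=TPUSpawnPlugin(debug=True),
    max_epochs=1,
)
trainer.fit(MyLightningModule())  # MyLightningModule is a placeholder for your own model
```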
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -114,6 +114,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added new `EarlyStopping` parameters `stopping_threshold` and `divergence_threshold` ([#6868](https://github.com/PyTorchLightning/pytorch-lightning/pull/6868))


- Added `debug` flag to TPU Training Plugins ([#7219](https://github.com/PyTorchLightning/pytorch-lightning/pull/7219))



### Changed

4 changes: 3 additions & 1 deletion pytorch_lightning/accelerators/tpu.py
@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from typing import Any, Callable, Union

from torch.optim import Optimizer
@@ -51,7 +52,8 @@ def setup(self, trainer: 'pl.Trainer', model: 'pl.LightningModule') -> None:
return super().setup(trainer, model)

def teardown(self) -> None:
pass
if "PT_XLA_DEBUG" in os.environ:
del os.environ["PT_XLA_DEBUG"]

def run_optimizer_step(
self, optimizer: Optimizer, optimizer_idx: int, lambda_closure: Callable, **kwargs: Any
8 changes: 7 additions & 1 deletion pytorch_lightning/plugins/training_type/single_tpu.py
@@ -11,6 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os

import torch

from pytorch_lightning.plugins.training_type.single_device import SingleDevicePlugin
@@ -24,11 +26,12 @@
class SingleTPUPlugin(SingleDevicePlugin):
""" Plugin for training on a single TPU device. """

def __init__(self, device: int):
def __init__(self, device: int, debug: bool = False):

device = xm.xla_device(device)
super().__init__(device)

self.debug = debug
self.tpu_local_core_rank = 0
self.tpu_global_core_rank = 0

@@ -47,6 +50,9 @@ def pre_dispatch(self) -> None:
if isinstance(self.device, int):
self.device = xm.xla_device(self.device)

if self.debug:
os.environ["PT_XLA_DEBUG"] = str(1)

self.tpu_local_core_rank = xm.get_local_ordinal()
self.tpu_global_core_rank = xm.get_ordinal()

7 changes: 6 additions & 1 deletion pytorch_lightning/plugins/training_type/tpu_spawn.py
@@ -49,8 +49,9 @@
class TPUSpawnPlugin(DDPSpawnPlugin):
""" Plugin for training multiple TPU devices using the :func:`torch.multiprocessing.spawn` method. """

def __init__(self, parallel_devices: Optional[List[int]] = None, **_: Any) -> None:
def __init__(self, parallel_devices: Optional[List[int]] = None, debug: bool = False, **_: Any) -> None:
super().__init__(parallel_devices, num_nodes=1, cluster_environment=None, sync_batchnorm=False)
self.debug = debug
self.tpu_local_core_rank = 0
self.tpu_global_core_rank = 0
self.start_method = None
Expand Down Expand Up @@ -104,6 +105,10 @@ def connect(self, model: 'pl.LightningModule') -> None:
self.wrapped_model = xmp.MpModelWrapper(LightningDistributedModule(model))
return super().connect(model)

def pre_dispatch(self):
if self.debug:
os.environ["PT_XLA_DEBUG"] = str(1)

def setup(self, model: Module) -> Module:
self.create_mp_queue()
return self.model
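Taken together with the accelerator change above, the environment variable has a simple lifecycle: the plugin sets it just before dispatch, and the accelerator removes it on teardown so it does not leak into later runs in the same process. A condensed sketch of that flow (free functions here stand in for the plugin and accelerator methods in the diff):

```python
import os

def plugin_pre_dispatch(debug: bool) -> None:
    # TPU plugins: export PT_XLA_DEBUG before training starts when debug=True.
    if debug:
        os.environ["PT_XLA_DEBUG"] = str(1)

def accelerator_teardown() -> None:
    # TPUAccelerator: drop the variable after training so it does not persist
    # into subsequent runs within the same process.
    if "PT_XLA_DEBUG" in os.environ:
        del os.environ["PT_XLA_DEBUG"]
```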
28 changes: 28 additions & 0 deletions tests/models/test_tpu.py
@@ -442,3 +442,31 @@ def test_sync_dist(rank):
assert res["test_tensor"].item() == 8, "Result-Log does not work properly with TPU Spawn and Tensors"

xmp.spawn(test_sync_dist, nprocs=8, start_method='fork')


@RunIf(tpu=True)
@pl_multi_process_test
def test_tpu_debug_mode(tmpdir):
"""Test if debug mode works on TPU."""

class DebugModel(BoringModel):

def on_train_start(self):
assert os.environ.get("PT_XLA_DEBUG") == str(1), "PT_XLA_DEBUG was not set in environment variables"

def teardown(self, stage):
assert "PT_XLA_DEBUG" not in os.environ

tutils.reset_seed()
trainer_options = dict(
default_root_dir=tmpdir,
progress_bar_refresh_rate=0,
max_epochs=4,
tpu_cores=8,
limit_train_batches=0.4,
limit_val_batches=0.4,
plugins=TPUSpawnPlugin(debug=True),
)

model = DebugModel()
tpipes.run_model_test(trainer_options, model, on_gpu=False, with_hpc=False)