From b71439536084ead970f3b5acce5847e772149e53 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3=A4lchli?=
Date: Wed, 5 Jan 2022 12:48:48 +0100
Subject: [PATCH 01/38] add xla environment class

---
 .../plugins/environments/__init__.py          |  1 +
 .../plugins/environments/xla_environment.py   | 63 +++++++++++++++++++
 2 files changed, 64 insertions(+)
 create mode 100644 pytorch_lightning/plugins/environments/xla_environment.py

diff --git a/pytorch_lightning/plugins/environments/__init__.py b/pytorch_lightning/plugins/environments/__init__.py
index 1878a725071ad..ca10268f21877 100644
--- a/pytorch_lightning/plugins/environments/__init__.py
+++ b/pytorch_lightning/plugins/environments/__init__.py
@@ -17,3 +17,4 @@
 from pytorch_lightning.plugins.environments.lsf_environment import LSFEnvironment  # noqa: F401
 from pytorch_lightning.plugins.environments.slurm_environment import SLURMEnvironment  # noqa: F401
 from pytorch_lightning.plugins.environments.torchelastic_environment import TorchElasticEnvironment  # noqa: F401
+from pytorch_lightning.plugins.environments.xla_environment import XLAEnvironment  # noqa: F401
diff --git a/pytorch_lightning/plugins/environments/xla_environment.py b/pytorch_lightning/plugins/environments/xla_environment.py
new file mode 100644
index 0000000000000..162d1f03ac28f
--- /dev/null
+++ b/pytorch_lightning/plugins/environments/xla_environment.py
@@ -0,0 +1,63 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+
+from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
+from pytorch_lightning.utilities import _TPU_AVAILABLE
+
+
+class XLAEnvironment(ClusterEnvironment):
+    """XLAEnvironment
+
+    A list of environment variables set by XLA can be found
+    `here <https://github.com/pytorch/xla/blob/master/torch_xla/core/xla_env_vars.py>`_.
+ """ + + def __init__(self) -> None: + super().__init__() + + @property + def creates_processes_externally(self) -> bool: + return False + + @property + def main_address(self) -> str: + return os.environ["TPU_MESH_CONTROLLER_ADDRESS"] + + @property + def main_port(self) -> int: + return int(os.environ["TPU_MESH_CONTROLLER_PORT"]) + + @staticmethod + def detect() -> bool: + return _TPU_AVAILABLE + + def world_size(self) -> int: + return int(os.environ.get("XRT_SHARD_WORLD_SIZE", 1)) + + def set_world_size(self, size: int) -> None: + pass + + def global_rank(self) -> int: + return int(os.environ.get("XRT_SHARD_ORDINAL", 0)) + + def set_global_rank(self, rank: int) -> None: + pass + + def local_rank(self) -> int: + return int(os.environ.get("XRT_SHARD_LOCAL_ORDINAL", 0)) + + def node_rank(self) -> int: + return int(os.environ.get("XRT_HOST_ORDINAL", 0)) From 078a01fa845259a8e46a226581e33c6dee66695a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 5 Jan 2022 12:48:57 +0100 Subject: [PATCH 02/38] add api reference --- docs/source/api_references.rst | 1 + docs/source/extensions/plugins.rst | 1 + 2 files changed, 2 insertions(+) diff --git a/docs/source/api_references.rst b/docs/source/api_references.rst index 67457d88aa571..a49fd15103b66 100644 --- a/docs/source/api_references.rst +++ b/docs/source/api_references.rst @@ -198,6 +198,7 @@ Cluster Environments TorchElasticEnvironment KubeflowEnvironment SLURMEnvironment + XLAEnvironment Checkpoint IO Plugins ^^^^^^^^^^^^^^^^^^^^^ diff --git a/docs/source/extensions/plugins.rst b/docs/source/extensions/plugins.rst index 14c2118b7445c..5ab4fc4d8a2dc 100644 --- a/docs/source/extensions/plugins.rst +++ b/docs/source/extensions/plugins.rst @@ -158,3 +158,4 @@ Cluster Environments TorchElasticEnvironment KubeflowEnvironment SLURMEnvironment + XLAEnvironment From 64c57c4897fbc8917306a4985c101d6e44a2bef0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 5 Jan 2022 12:52:31 +0100 Subject: [PATCH 03/38] integrate --- pytorch_lightning/strategies/tpu_spawn.py | 5 +++++ .../trainer/connectors/accelerator_connector.py | 4 ++-- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/strategies/tpu_spawn.py b/pytorch_lightning/strategies/tpu_spawn.py index 8f40f59052a71..be1f7ab5cf0fd 100644 --- a/pytorch_lightning/strategies/tpu_spawn.py +++ b/pytorch_lightning/strategies/tpu_spawn.py @@ -24,6 +24,8 @@ import pytorch_lightning as pl from pytorch_lightning.overrides import LightningDistributedModule +from pytorch_lightning.plugins import ClusterEnvironment +from pytorch_lightning.plugins.environments import XLAEnvironment from pytorch_lightning.plugins.io.xla_plugin import XLACheckpointIO from pytorch_lightning.plugins.precision import PrecisionPlugin from pytorch_lightning.strategies.ddp_spawn import _FakeQueue, _SpawnOutput, DDPSpawnStrategy @@ -55,15 +57,18 @@ def __init__( self, accelerator: Optional["pl.accelerators.accelerator.Accelerator"] = None, parallel_devices: Optional[List[int]] = None, + cluster_environment: Optional[ClusterEnvironment] = None, checkpoint_io: Optional[XLACheckpointIO] = None, precision_plugin: Optional[PrecisionPlugin] = None, debug: bool = False, **_: Any, ) -> None: checkpoint_io = checkpoint_io or XLACheckpointIO() + cluster_environment = cluster_environment or XLAEnvironment() super().__init__( accelerator=accelerator, parallel_devices=parallel_devices, + cluster_environment=cluster_environment, checkpoint_io=checkpoint_io, 
            precision_plugin=precision_plugin,
        )

diff --git a/pytorch_lightning/trainer/connectors/accelerator_connector.py b/pytorch_lightning/trainer/connectors/accelerator_connector.py
index 72801701072f4..79f137aa2e368 100644
--- a/pytorch_lightning/trainer/connectors/accelerator_connector.py
+++ b/pytorch_lightning/trainer/connectors/accelerator_connector.py
@@ -44,7 +44,7 @@
     LightningEnvironment,
     LSFEnvironment,
     SLURMEnvironment,
-    TorchElasticEnvironment,
+    TorchElasticEnvironment, XLAEnvironment,
 )
 from pytorch_lightning.strategies import (
     DataParallelStrategy,
@@ -808,7 +808,7 @@ def select_cluster_environment(self) -> ClusterEnvironment:
             rank_zero_info("Multiprocessing is handled by SLURM.")
             return SLURMEnvironment()

-        for env_type in (TorchElasticEnvironment, KubeflowEnvironment, LSFEnvironment):
+        for env_type in (TorchElasticEnvironment, KubeflowEnvironment, LSFEnvironment, XLAEnvironment):
             if env_type.detect():
                 return env_type()

From cee674befb13734a3736c8892fd05b95bf9fd1ac Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3%A4lchli?=
Date: Wed, 5 Jan 2022 13:08:48 +0100
Subject: [PATCH 04/38] use xenv

---
 .../plugins/environments/xla_environment.py   | 19 +++++++++----------
 pytorch_lightning/strategies/tpu_spawn.py     |  3 +--
 2 files changed, 10 insertions(+), 12 deletions(-)

diff --git a/pytorch_lightning/plugins/environments/xla_environment.py b/pytorch_lightning/plugins/environments/xla_environment.py
index 162d1f03ac28f..1da9d73ca2dac 100644
--- a/pytorch_lightning/plugins/environments/xla_environment.py
+++ b/pytorch_lightning/plugins/environments/xla_environment.py
@@ -11,12 +11,14 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
 import os

 from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
 from pytorch_lightning.utilities import _TPU_AVAILABLE

+if _TPU_AVAILABLE:
+    import torch_xla.core.xla_env_vars as xenv
+

 class XLAEnvironment(ClusterEnvironment):
     """XLAEnvironment

     A list of environment variables set by XLA can be found
     `here <https://github.com/pytorch/xla/blob/master/torch_xla/core/xla_env_vars.py>`_.
""" - def __init__(self) -> None: - super().__init__() - @property def creates_processes_externally(self) -> bool: return False @property def main_address(self) -> str: - return os.environ["TPU_MESH_CONTROLLER_ADDRESS"] + return os.environ[xenv.TPU_MESH_CTLER_ADDR] @property def main_port(self) -> int: - return int(os.environ["TPU_MESH_CONTROLLER_PORT"]) + return int(os.environ[xenv.TPU_MESH_CTLER_PORT]) @staticmethod def detect() -> bool: return _TPU_AVAILABLE def world_size(self) -> int: - return int(os.environ.get("XRT_SHARD_WORLD_SIZE", 1)) + return int(os.environ.get(xenv.WORLD_SIZE, 1)) def set_world_size(self, size: int) -> None: pass def global_rank(self) -> int: - return int(os.environ.get("XRT_SHARD_ORDINAL", 0)) + return int(os.environ.get(xenv.ORDINAL, 0)) def set_global_rank(self, rank: int) -> None: pass def local_rank(self) -> int: - return int(os.environ.get("XRT_SHARD_LOCAL_ORDINAL", 0)) + return int(os.environ.get(xenv.LOCAL_ORDINAL, 0)) def node_rank(self) -> int: - return int(os.environ.get("XRT_HOST_ORDINAL", 0)) + return int(os.environ.get(xenv.HOST_ORDINAL, 0)) diff --git a/pytorch_lightning/strategies/tpu_spawn.py b/pytorch_lightning/strategies/tpu_spawn.py index be1f7ab5cf0fd..71aa8f6e90a94 100644 --- a/pytorch_lightning/strategies/tpu_spawn.py +++ b/pytorch_lightning/strategies/tpu_spawn.py @@ -153,8 +153,7 @@ def distributed_sampler_kwargs(self) -> Dict[str, int]: @property def is_distributed(self) -> bool: - # HOST_WORLD_SIZE is None outside the xmp.spawn process - return os.getenv(xenv.HOST_WORLD_SIZE, None) and self.world_size != 1 + return self.world_size > 1 def process_dataloader(self, dataloader: DataLoader) -> MpDeviceLoader: TPUSpawnStrategy._validate_dataloader(dataloader) From f509dc9c139dad09bb46ee9272883bd610d92489 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 5 Jan 2022 14:59:14 +0100 Subject: [PATCH 05/38] remove properties --- pytorch_lightning/strategies/tpu_spawn.py | 12 ------------ 1 file changed, 12 deletions(-) diff --git a/pytorch_lightning/strategies/tpu_spawn.py b/pytorch_lightning/strategies/tpu_spawn.py index 71aa8f6e90a94..099b03a1437f1 100644 --- a/pytorch_lightning/strategies/tpu_spawn.py +++ b/pytorch_lightning/strategies/tpu_spawn.py @@ -77,18 +77,6 @@ def __init__( self.tpu_global_core_rank = 0 self.start_method = "fork" - @property - def global_rank(self) -> int: - return self.tpu_global_core_rank - - @property - def local_rank(self) -> int: - return self.tpu_local_core_rank - - @property - def world_size(self) -> int: - return xm.xrt_world_size() - @property def root_device(self) -> torch.device: return xm.xla_device() From 7d192cbc06818f2b314dcc8f7280b5ad773e8932 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 5 Jan 2022 14:03:30 +0000 Subject: [PATCH 06/38] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- pytorch_lightning/plugins/environments/xla_environment.py | 2 +- pytorch_lightning/trainer/connectors/accelerator_connector.py | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/plugins/environments/xla_environment.py b/pytorch_lightning/plugins/environments/xla_environment.py index 1da9d73ca2dac..b3d029c3b9e04 100644 --- a/pytorch_lightning/plugins/environments/xla_environment.py +++ b/pytorch_lightning/plugins/environments/xla_environment.py @@ -21,7 +21,7 @@ class XLAEnvironment(ClusterEnvironment): - """XLAEnvironment + 
"""XLAEnvironment. A list of environment variables set by XLA can be found `here `_. diff --git a/pytorch_lightning/trainer/connectors/accelerator_connector.py b/pytorch_lightning/trainer/connectors/accelerator_connector.py index 79f137aa2e368..3654050166673 100644 --- a/pytorch_lightning/trainer/connectors/accelerator_connector.py +++ b/pytorch_lightning/trainer/connectors/accelerator_connector.py @@ -44,7 +44,8 @@ LightningEnvironment, LSFEnvironment, SLURMEnvironment, - TorchElasticEnvironment, XLAEnvironment, + TorchElasticEnvironment, + XLAEnvironment, ) from pytorch_lightning.strategies import ( DataParallelStrategy, From ce427e5376a03f223ed22d9f6ee8475ad494a33c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 5 Jan 2022 15:37:32 +0100 Subject: [PATCH 07/38] test environment selection --- tests/models/test_tpu.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/models/test_tpu.py b/tests/models/test_tpu.py index 236f0225367c0..7a4d4fd7d5825 100644 --- a/tests/models/test_tpu.py +++ b/tests/models/test_tpu.py @@ -24,6 +24,7 @@ from pytorch_lightning import Trainer from pytorch_lightning.accelerators import TPUAccelerator from pytorch_lightning.callbacks import EarlyStopping +from pytorch_lightning.plugins.environments import XLAEnvironment from pytorch_lightning.strategies import TPUSpawnStrategy from pytorch_lightning.trainer.connectors.logger_connector.result import _Sync from pytorch_lightning.utilities import _AcceleratorType, _TPU_AVAILABLE @@ -314,6 +315,7 @@ def test_tpu_choice(tmpdir, tpu_cores, expected_tpu_id, error_expected): else: trainer = Trainer(default_root_dir=tmpdir, tpu_cores=tpu_cores) assert trainer._accelerator_connector.tpu_id == expected_tpu_id + assert isinstance(trainer.strategy.cluster_environment, XLAEnvironment) @pytest.mark.parametrize( From 6df74bc40dd8715e95e7ba0e540b697b053e6877 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Sun, 6 Feb 2022 16:50:54 +0100 Subject: [PATCH 08/38] update --- .../plugins/environments/xla_environment.py | 23 +++++++++++++++---- .../connectors/accelerator_connector.py | 8 ++++++- 2 files changed, 25 insertions(+), 6 deletions(-) diff --git a/pytorch_lightning/plugins/environments/xla_environment.py b/pytorch_lightning/plugins/environments/xla_environment.py index b3d029c3b9e04..a5496d6bc80e0 100644 --- a/pytorch_lightning/plugins/environments/xla_environment.py +++ b/pytorch_lightning/plugins/environments/xla_environment.py @@ -11,13 +11,18 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import logging import os from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment from pytorch_lightning.utilities import _TPU_AVAILABLE +from pytorch_lightning.utilities.exceptions import MisconfigurationException if _TPU_AVAILABLE: import torch_xla.core.xla_env_vars as xenv + import torch_xla.core.xla_model as xm + +log = logging.getLogger(__name__) class XLAEnvironment(ClusterEnvironment): @@ -27,6 +32,14 @@ class XLAEnvironment(ClusterEnvironment): `here `_. """ + def __init__(self): + super().__init__() + if not _TPU_AVAILABLE: + raise MisconfigurationException( + "The `XLAEnvironment` can only be used on a machine with TPU devices and with the `torch_xla` library" + " installed." 
+            )
+
     @property
     def creates_processes_externally(self) -> bool:
         return False
@@ -44,19 +57,19 @@ def detect() -> bool:
         return _TPU_AVAILABLE

     def world_size(self) -> int:
-        return int(os.environ.get(xenv.WORLD_SIZE, 1))
+        return xm.xrt_world_size()

     def set_world_size(self, size: int) -> None:
-        pass
+        log.debug("XLAEnvironment.set_world_size was called, but setting world size is not allowed. Ignored.")

     def global_rank(self) -> int:
-        return int(os.environ.get(xenv.ORDINAL, 0))
+        return xm.get_ordinal()

     def set_global_rank(self, rank: int) -> None:
-        pass
+        log.debug("XLAEnvironment.set_global_rank was called, but setting global rank is not allowed. Ignored.")

     def local_rank(self) -> int:
-        return int(os.environ.get(xenv.LOCAL_ORDINAL, 0))
+        return xm.get_local_ordinal()

     def node_rank(self) -> int:
         return int(os.environ.get(xenv.HOST_ORDINAL, 0))
diff --git a/pytorch_lightning/trainer/connectors/accelerator_connector.py b/pytorch_lightning/trainer/connectors/accelerator_connector.py
index 6987efd2f8e49..3d003b3f750b2 100644
--- a/pytorch_lightning/trainer/connectors/accelerator_connector.py
+++ b/pytorch_lightning/trainer/connectors/accelerator_connector.py
@@ -815,7 +815,13 @@ def select_cluster_environment(self) -> ClusterEnvironment:
             rank_zero_info("Multiprocessing is handled by SLURM.")
             return SLURMEnvironment()

-        for env_type in (BaguaEnvironment, TorchElasticEnvironment, KubeflowEnvironment, LSFEnvironment, XLAEnvironment):
+        for env_type in (
+            BaguaEnvironment,
+            TorchElasticEnvironment,
+            KubeflowEnvironment,
+            LSFEnvironment,
+            XLAEnvironment,
+        ):
             if env_type.detect():
                 return env_type()

From 46cd7a300e224b786b9239842f4ac80d5929b602 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3%A4lchli?=
Date: Mon, 7 Feb 2022 03:07:33 +0100
Subject: [PATCH 09/38] notebooks

---
 _notebooks | 1 -
 1 file changed, 1 deletion(-)
 delete mode 160000 _notebooks

diff --git a/_notebooks b/_notebooks
deleted file mode 160000
index 0c325829101d5..0000000000000
--- a/_notebooks
+++ /dev/null
@@ -1 +0,0 @@
-Subproject commit 0c325829101d5a6ebf32ed99bbf5b09badf04a59
From ad7acc04761c63440a0f47b1678a3bc6be5d9899 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3%A4lchli?=
Date: Mon, 7 Feb 2022 03:07:38 +0100
Subject: [PATCH 10/38] notebooks

---
 _notebooks | 1 +
 1 file changed, 1 insertion(+)
 create mode 160000 _notebooks

diff --git a/_notebooks b/_notebooks
new file mode 160000
index 0000000000000..290fb466de1fc
--- /dev/null
+++ b/_notebooks
@@ -0,0 +1 @@
+Subproject commit 290fb466de1fcc2ac6025f74b56906592911e856
From e5fae8f788a8466cdf497d6bd050a43e2b4fc413 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3%A4lchli?=
Date: Mon, 7 Feb 2022 03:11:45 +0100
Subject: [PATCH 11/38] update

---
 .../plugins/environments/xla_environment.py   | 10 +---------
 1 file changed, 1 insertion(+), 9 deletions(-)

diff --git a/pytorch_lightning/plugins/environments/xla_environment.py b/pytorch_lightning/plugins/environments/xla_environment.py
index a5496d6bc80e0..225e0debf069a 100644
--- a/pytorch_lightning/plugins/environments/xla_environment.py
+++ b/pytorch_lightning/plugins/environments/xla_environment.py
@@ -26,20 +26,12 @@

 class XLAEnvironment(ClusterEnvironment):
-    """XLAEnvironment.
+    """Cluster environment for training on a TPU Pod with the `PyTorch/XLA <https://pytorch.org/xla>`_ library.

     A list of environment variables set by XLA can be found
     `here <https://github.com/pytorch/xla/blob/master/torch_xla/core/xla_env_vars.py>`_.
""" - def __init__(self): - super().__init__() - if not _TPU_AVAILABLE: - raise MisconfigurationException( - "The `XLAEnvironment` can only be used on a machine with TPU devices and with the `torch_xla` library" - " installed." - ) - @property def creates_processes_externally(self) -> bool: return False From d084d51aac083d4cadeb680baf7bfe770178cc99 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Mon, 7 Feb 2022 03:11:54 +0100 Subject: [PATCH 12/38] update --- pytorch_lightning/plugins/environments/xla_environment.py | 1 - 1 file changed, 1 deletion(-) diff --git a/pytorch_lightning/plugins/environments/xla_environment.py b/pytorch_lightning/plugins/environments/xla_environment.py index 225e0debf069a..a78ebeb36a6a4 100644 --- a/pytorch_lightning/plugins/environments/xla_environment.py +++ b/pytorch_lightning/plugins/environments/xla_environment.py @@ -16,7 +16,6 @@ from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment from pytorch_lightning.utilities import _TPU_AVAILABLE -from pytorch_lightning.utilities.exceptions import MisconfigurationException if _TPU_AVAILABLE: import torch_xla.core.xla_env_vars as xenv From 5dca2f8c9a524458fa88831ee9ed37e00365a81f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Mon, 7 Feb 2022 03:31:01 +0100 Subject: [PATCH 13/38] test tests --- .../environments/test_xla_environment.py | 76 +++++++++++++++++++ 1 file changed, 76 insertions(+) create mode 100644 tests/plugins/environments/test_xla_environment.py diff --git a/tests/plugins/environments/test_xla_environment.py b/tests/plugins/environments/test_xla_environment.py new file mode 100644 index 0000000000000..c513aa629674e --- /dev/null +++ b/tests/plugins/environments/test_xla_environment.py @@ -0,0 +1,76 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import os +from unittest import mock + +import pytest + +import pytorch_lightning as pl +from pytorch_lightning.plugins.environments import XLAEnvironment +from tests.helpers.runif import RunIf + + +@RunIf(tpu=True) +@mock.patch.dict(os.environ, {}) +def test_default_attributes(): + """Test the default attributes when no environment variables are set.""" + env = XLAEnvironment() + assert not env.creates_processes_externally + assert env.world_size() == 1 + assert env.global_rank() == 0 + assert env.local_rank() == 0 + assert env.node_rank() == 0 + + with pytest.raises(KeyError): + # main_address is required to be passed as env variable + _ = env.main_address + with pytest.raises(KeyError): + # main_port is required to be passed as env variable + _ = env.main_port + + +@RunIf(tpu=True) +@mock.patch.dict( + os.environ, + { + "TPU_MESH_CONTROLLER_ADDRESS": "1.2.3.4", + "TPU_MESH_CONTROLLER_PORT": "500", + "XRT_SHARD_WORLD_SIZE": "1", + "XRT_SHARD_ORDINAL": "0", + "XRT_SHARD_LOCAL_ORDINAL": "2", + "XRT_HOST_ORDINAL": "3", + }, +) +def test_attributes_from_environment_variables(): + """Test that the default cluster environment takes the attributes from the environment variables.""" + env = XLAEnvironment() + assert env.main_address == "1.2.3.4" + assert env.main_port == 500 + assert env.world_size() == 1 + assert env.global_rank() == 0 + assert env.local_rank() == 2 + assert env.node_rank() == 3 + env.set_global_rank(100) + assert env.global_rank() == 0 + env.set_world_size(100) + assert env.world_size() == 0 + + +def test_detect(monkeypatch): + """Test the detection of a xla environment configuration.""" + monkeypatch.setattr(pl.plugins.environments.xla_environment, "_TPU_AVAILABLE", False) + assert not XLAEnvironment.detect() + + monkeypatch.setattr(pl.plugins.environments.xla_environment, "_TPU_AVAILABLE", True) + assert XLAEnvironment.detect() From 472e200def470028c1075ec93fd3ff7147de603f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Mon, 7 Feb 2022 03:33:22 +0100 Subject: [PATCH 14/38] include test case --- dockers/tpu-tests/tpu_test_cases.jsonnet | 1 + 1 file changed, 1 insertion(+) diff --git a/dockers/tpu-tests/tpu_test_cases.jsonnet b/dockers/tpu-tests/tpu_test_cases.jsonnet index 6f96ad95357a0..65ab0fe6b0be3 100644 --- a/dockers/tpu-tests/tpu_test_cases.jsonnet +++ b/dockers/tpu-tests/tpu_test_cases.jsonnet @@ -33,6 +33,7 @@ local tputests = base.BaseTest { echo $KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS export XRT_TPU_CONFIG="tpu_worker;0;${KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS:7}" coverage run --source=pytorch_lightning -m pytest -v --capture=no \ + tests/plugins/environments/test_xla_environment.py \ tests/strategies/test_tpu_spawn.py \ tests/profiler/test_xla_profiler.py \ pytorch_lightning/utilities/xla_device.py \ From 1833b62aaff919561f558a047309666e98cb0fa1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Mon, 7 Feb 2022 03:45:39 +0100 Subject: [PATCH 15/38] fix test --- tests/models/test_tpu.py | 4 +++- tests/plugins/environments/test_xla_environment.py | 2 +- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/models/test_tpu.py b/tests/models/test_tpu.py index 7a4d4fd7d5825..ff0600696b909 100644 --- a/tests/models/test_tpu.py +++ b/tests/models/test_tpu.py @@ -315,7 +315,9 @@ def test_tpu_choice(tmpdir, tpu_cores, expected_tpu_id, error_expected): else: trainer = Trainer(default_root_dir=tmpdir, tpu_cores=tpu_cores) assert trainer._accelerator_connector.tpu_id == expected_tpu_id - assert 
isinstance(trainer.strategy.cluster_environment, XLAEnvironment) + assert not isinstance(trainer.strategy, TPUSpawnStrategy) or isinstance( + trainer.strategy.cluster_environment, XLAEnvironment + ) @pytest.mark.parametrize( diff --git a/tests/plugins/environments/test_xla_environment.py b/tests/plugins/environments/test_xla_environment.py index c513aa629674e..bd4af5f144fc1 100644 --- a/tests/plugins/environments/test_xla_environment.py +++ b/tests/plugins/environments/test_xla_environment.py @@ -64,7 +64,7 @@ def test_attributes_from_environment_variables(): env.set_global_rank(100) assert env.global_rank() == 0 env.set_world_size(100) - assert env.world_size() == 0 + assert env.world_size() == 1 def test_detect(monkeypatch): From dcf3ccbf22205c9d6f9dc1c4d1d3a77d5e630fb2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Mon, 14 Feb 2022 00:43:04 +0100 Subject: [PATCH 16/38] fix --- pytorch_lightning/strategies/tpu_spawn.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/strategies/tpu_spawn.py b/pytorch_lightning/strategies/tpu_spawn.py index c5467e55efcd1..e0746704d8f1a 100644 --- a/pytorch_lightning/strategies/tpu_spawn.py +++ b/pytorch_lightning/strategies/tpu_spawn.py @@ -141,7 +141,7 @@ def distributed_sampler_kwargs(self) -> Dict[str, int]: @property def is_distributed(self) -> bool: - return self.world_size > 1 + return os.getenv(xenv.HOST_WORLD_SIZE, None) and self.world_size != 1 def process_dataloader(self, dataloader: DataLoader) -> MpDeviceLoader: TPUSpawnStrategy._validate_dataloader(dataloader) From 1131bf72524ce3eb05c4d92ce201b9fafa6e6446 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Mon, 14 Feb 2022 04:58:08 +0100 Subject: [PATCH 17/38] reset --- pytorch_lightning/strategies/tpu_spawn.py | 1 + 1 file changed, 1 insertion(+) diff --git a/pytorch_lightning/strategies/tpu_spawn.py b/pytorch_lightning/strategies/tpu_spawn.py index ebed904bb640a..767a371e99358 100644 --- a/pytorch_lightning/strategies/tpu_spawn.py +++ b/pytorch_lightning/strategies/tpu_spawn.py @@ -142,6 +142,7 @@ def distributed_sampler_kwargs(self) -> Dict[str, int]: @property def is_distributed(self) -> bool: + # HOST_WORLD_SIZE is None outside the xmp.spawn process return os.getenv(xenv.HOST_WORLD_SIZE, None) and self.world_size != 1 def process_dataloader(self, dataloader: DataLoader) -> MpDeviceLoader: From 32390a90b242e420557a973c37f7ad571ff2801b Mon Sep 17 00:00:00 2001 From: Kaushik B Date: Wed, 16 Feb 2022 13:16:55 +0530 Subject: [PATCH 18/38] temp fix --- pytorch_lightning/strategies/tpu_spawn.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/strategies/tpu_spawn.py b/pytorch_lightning/strategies/tpu_spawn.py index 767a371e99358..68361e7083c4b 100644 --- a/pytorch_lightning/strategies/tpu_spawn.py +++ b/pytorch_lightning/strategies/tpu_spawn.py @@ -183,7 +183,7 @@ def _collect_rank_zero_results(self, trainer: "pl.Trainer", results: Any) -> Opt self.checkpoint_io.save_checkpoint(state_dict, weights_path) # We use `local_rank` here as separate filesystems are used for each VM for TPU Pod Training - if self.local_rank != 0: + if self.tpu_local_core_rank != 0: return # adds the `callback_metrics` to the queue @@ -313,7 +313,7 @@ def remove_checkpoint(self, filepath: _PATH) -> None: Args: filepath: Path to checkpoint """ - if self.local_rank == 0: + if self.tpu_local_core_rank == 0: self.checkpoint_io.remove_checkpoint(filepath) def all_gather(self, tensor: torch.Tensor, group: 
Optional[Any] = None, sync_grads: bool = False) -> torch.Tensor: From 7dcf6c4aeeacb6338bb8d2fd26ac93172965d2d6 Mon Sep 17 00:00:00 2001 From: Kaushik B Date: Fri, 18 Feb 2022 16:52:44 +0530 Subject: [PATCH 19/38] Update --- .../trainer/connectors/accelerator_connector.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/connectors/accelerator_connector.py b/pytorch_lightning/trainer/connectors/accelerator_connector.py index 0952eba731e18..12e0356a443d3 100644 --- a/pytorch_lightning/trainer/connectors/accelerator_connector.py +++ b/pytorch_lightning/trainer/connectors/accelerator_connector.py @@ -487,7 +487,13 @@ def _choose_and_init_cluster_environment(self) -> ClusterEnvironment: if self._is_slurm_managing_tasks(): rank_zero_info("Multiprocessing is handled by SLURM.") return SLURMEnvironment() - for env_type in (BaguaEnvironment, TorchElasticEnvironment, KubeflowEnvironment, LSFEnvironment): + for env_type in ( + BaguaEnvironment, + TorchElasticEnvironment, + KubeflowEnvironment, + LSFEnvironment, + XLAEnvironment, + ): if env_type.detect(): return env_type() return LightningEnvironment() From d2700b804d4e22372c874b19f057bf9e7326e952 Mon Sep 17 00:00:00 2001 From: Kaushik B Date: Fri, 18 Feb 2022 18:47:07 +0530 Subject: [PATCH 20/38] Update --- pytorch_lightning/strategies/tpu_spawn.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/strategies/tpu_spawn.py b/pytorch_lightning/strategies/tpu_spawn.py index ff7b1892b20e1..685e6ed82f808 100644 --- a/pytorch_lightning/strategies/tpu_spawn.py +++ b/pytorch_lightning/strategies/tpu_spawn.py @@ -140,13 +140,17 @@ def _setup_model(self, model: Module) -> Module: @property def distributed_sampler_kwargs(self) -> Dict[str, int]: - return dict(num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal()) + return dict(num_replicas=self.world_size, rank=self.global_rank) @property def is_distributed(self) -> bool: # HOST_WORLD_SIZE is None outside the xmp.spawn process return os.getenv(xenv.HOST_WORLD_SIZE, None) and self.world_size != 1 + @property + def local_rank(self) -> int: + return self.cluster_environment.local_rank() + def process_dataloader(self, dataloader: DataLoader) -> MpDeviceLoader: TPUSpawnStrategy._validate_dataloader(dataloader) dataloader = MpDeviceLoader(dataloader, self.root_device) @@ -266,7 +270,7 @@ def remove_checkpoint(self, filepath: _PATH) -> None: Args: filepath: Path to checkpoint """ - if self.tpu_local_core_rank == 0: + if self.local_rank == 0: self.checkpoint_io.remove_checkpoint(filepath) def all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> torch.Tensor: From ba54586b089ffd88e5e02ca586a023611a9f6948 Mon Sep 17 00:00:00 2001 From: Kaushik B Date: Fri, 18 Feb 2022 19:51:13 +0530 Subject: [PATCH 21/38] Update --- dockers/tpu-tests/tpu_test_cases.jsonnet | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dockers/tpu-tests/tpu_test_cases.jsonnet b/dockers/tpu-tests/tpu_test_cases.jsonnet index 65ab0fe6b0be3..e2477d58df309 100644 --- a/dockers/tpu-tests/tpu_test_cases.jsonnet +++ b/dockers/tpu-tests/tpu_test_cases.jsonnet @@ -33,7 +33,7 @@ local tputests = base.BaseTest { echo $KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS export XRT_TPU_CONFIG="tpu_worker;0;${KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS:7}" coverage run --source=pytorch_lightning -m pytest -v --capture=no \ - tests/plugins/environments/test_xla_environment.py \ + // tests/plugins/environments/test_xla_environment.py \ 
tests/strategies/test_tpu_spawn.py \ tests/profiler/test_xla_profiler.py \ pytorch_lightning/utilities/xla_device.py \ From 983a9e77856ae28e6a199ce85dee3ecf45cea911 Mon Sep 17 00:00:00 2001 From: Kaushik B Date: Wed, 23 Feb 2022 14:06:19 +0530 Subject: [PATCH 22/38] Update tests --- dockers/tpu-tests/tpu_test_cases.jsonnet | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dockers/tpu-tests/tpu_test_cases.jsonnet b/dockers/tpu-tests/tpu_test_cases.jsonnet index e2477d58df309..65ab0fe6b0be3 100644 --- a/dockers/tpu-tests/tpu_test_cases.jsonnet +++ b/dockers/tpu-tests/tpu_test_cases.jsonnet @@ -33,7 +33,7 @@ local tputests = base.BaseTest { echo $KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS export XRT_TPU_CONFIG="tpu_worker;0;${KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS:7}" coverage run --source=pytorch_lightning -m pytest -v --capture=no \ - // tests/plugins/environments/test_xla_environment.py \ + tests/plugins/environments/test_xla_environment.py \ tests/strategies/test_tpu_spawn.py \ tests/profiler/test_xla_profiler.py \ pytorch_lightning/utilities/xla_device.py \ From 4c6f73a46755360c0830cf2907dd31a2fc28c00e Mon Sep 17 00:00:00 2001 From: Kaushik B Date: Thu, 24 Feb 2022 15:10:00 +0530 Subject: [PATCH 23/38] Update tests --- tests/strategies/test_tpu_spawn.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/strategies/test_tpu_spawn.py b/tests/strategies/test_tpu_spawn.py index 138b3875d35d9..185f5bc14a5df 100644 --- a/tests/strategies/test_tpu_spawn.py +++ b/tests/strategies/test_tpu_spawn.py @@ -96,7 +96,7 @@ def test_model_tpu_one_core(): trainer = Trainer(tpu_cores=1, fast_dev_run=True, strategy=TPUSpawnStrategy(debug=True)) # assert training strategy attributes for device setting assert isinstance(trainer.strategy, TPUSpawnStrategy) - assert trainer.strategy.root_device == torch.device("xla", index=1) + # assert trainer.strategy.root_device == torch.device("xla", index=1) model = BoringModelTPU() trainer.fit(model) assert "PT_XLA_DEBUG" not in os.environ From 1d16728598bf576fabfd08a239e0dccad2a196ab Mon Sep 17 00:00:00 2001 From: Kaushik B Date: Thu, 24 Feb 2022 15:38:42 +0530 Subject: [PATCH 24/38] Update tests --- tests/strategies/test_tpu_spawn.py | 34 +++++++++++++----------------- 1 file changed, 15 insertions(+), 19 deletions(-) diff --git a/tests/strategies/test_tpu_spawn.py b/tests/strategies/test_tpu_spawn.py index 185f5bc14a5df..2531b8b185849 100644 --- a/tests/strategies/test_tpu_spawn.py +++ b/tests/strategies/test_tpu_spawn.py @@ -11,12 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
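
Stepping back from the CI churn in patches 21-24: the invariant all of these tests guard is that a TPU run ends up with an `XLAEnvironment` attached to its strategy. A sketch of that observable behaviour, assuming a TPU VM with `torch_xla` available (the `devices=8` value is hypothetical; the assertions mirror the one added to `test_tpu_choice` in patch 15):

    from pytorch_lightning import Trainer
    from pytorch_lightning.plugins.environments import XLAEnvironment
    from pytorch_lightning.strategies import TPUSpawnStrategy

    trainer = Trainer(accelerator="tpu", devices=8)  # hypothetical TPU VM run
    assert isinstance(trainer.strategy, TPUSpawnStrategy)
    assert isinstance(trainer.strategy.cluster_environment, XLAEnvironment)
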
-import os from unittest import mock from unittest.mock import MagicMock import pytest -import torch from torch.utils.data import DataLoader from pytorch_lightning import Trainer @@ -24,8 +22,6 @@ from pytorch_lightning.utilities.exceptions import MisconfigurationException from tests.helpers.boring_model import BoringModel, RandomDataset from tests.helpers.dataloaders import CustomNotImplementedErrorDataloader -from tests.helpers.runif import RunIf -from tests.helpers.utils import pl_multi_process_test class BoringModelNoDataloaders(BoringModel): @@ -83,20 +79,20 @@ def test_error_process_iterable_dataloader(_): TPUSpawnStrategy(MagicMock()).process_dataloader(_loader_no_len) -class BoringModelTPU(BoringModel): - def on_train_start(self) -> None: - assert self.device == torch.device("xla", index=1) - assert os.environ.get("PT_XLA_DEBUG") == "1" +# class BoringModelTPU(BoringModel): +# def on_train_start(self) -> None: +# assert self.device == torch.device("xla", index=1) +# assert os.environ.get("PT_XLA_DEBUG") == "1" -@RunIf(tpu=True) -@pl_multi_process_test -def test_model_tpu_one_core(): - """Tests if device/debug flag is set correctly when training and after teardown for TPUSpawnStrategy.""" - trainer = Trainer(tpu_cores=1, fast_dev_run=True, strategy=TPUSpawnStrategy(debug=True)) - # assert training strategy attributes for device setting - assert isinstance(trainer.strategy, TPUSpawnStrategy) - # assert trainer.strategy.root_device == torch.device("xla", index=1) - model = BoringModelTPU() - trainer.fit(model) - assert "PT_XLA_DEBUG" not in os.environ +# @RunIf(tpu=True) +# @pl_multi_process_test +# def test_model_tpu_one_core(): +# """Tests if device/debug flag is set correctly when training and after teardown for TPUSpawnStrategy.""" +# trainer = Trainer(tpu_cores=1, fast_dev_run=True, strategy=TPUSpawnStrategy(debug=True)) +# # assert training strategy attributes for device setting +# assert isinstance(trainer.strategy, TPUSpawnStrategy) +# # assert trainer.strategy.root_device == torch.device("xla", index=1) +# model = BoringModelTPU() +# trainer.fit(model) +# assert "PT_XLA_DEBUG" not in os.environ From ac90b9694241409bd3aaac390b41cf97e2e8f3d1 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Sat, 14 May 2022 17:54:21 -0400 Subject: [PATCH 25/38] debug --- dockers/tpu-tests/tpu_test_cases.jsonnet | 2 +- pytorch_lightning/strategies/tpu_spawn.py | 20 +++++++---- .../connectors/accelerator_connector.py | 3 +- tests/models/test_tpu.py | 1 - tests/strategies/test_tpu_spawn.py | 34 +++++++++++-------- 5 files changed, 36 insertions(+), 24 deletions(-) diff --git a/dockers/tpu-tests/tpu_test_cases.jsonnet b/dockers/tpu-tests/tpu_test_cases.jsonnet index 78df8b9e8da26..2823110d663d8 100644 --- a/dockers/tpu-tests/tpu_test_cases.jsonnet +++ b/dockers/tpu-tests/tpu_test_cases.jsonnet @@ -34,12 +34,12 @@ local tputests = base.BaseTest { export XRT_TPU_CONFIG="tpu_worker;0;${KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS:7}" # TODO (@kaushikb11): Add device stats tests here coverage run --source=pytorch_lightning -m pytest -v --capture=no \ - tests/plugins/environments/test_xla_environment.py \ tests/strategies/test_tpu_spawn.py \ tests/profiler/test_xla_profiler.py \ pytorch_lightning/utilities/xla_device.py \ tests/accelerators/test_tpu.py \ tests/models/test_tpu.py + tests/plugins/environments/test_xla_environment.py \ test_exit_code=$? 
echo "\n||| END PYTEST LOGS |||\n" coverage xml diff --git a/pytorch_lightning/strategies/tpu_spawn.py b/pytorch_lightning/strategies/tpu_spawn.py index b7e3f01ee4765..8612a7e7595f2 100644 --- a/pytorch_lightning/strategies/tpu_spawn.py +++ b/pytorch_lightning/strategies/tpu_spawn.py @@ -65,7 +65,7 @@ def __init__( **_: Any, ) -> None: checkpoint_io = checkpoint_io or XLACheckpointIO() - cluster_environment = cluster_environment or XLAEnvironment() + # cluster_environment = cluster_environment or XLAEnvironment() super().__init__( accelerator=accelerator, parallel_devices=parallel_devices, @@ -78,6 +78,18 @@ def __init__( self.tpu_global_core_rank = 0 self.start_method = "fork" + @property + def global_rank(self) -> int: + return self.tpu_global_core_rank + + @property + def local_rank(self) -> int: + return self.tpu_local_core_rank + + @property + def world_size(self) -> int: + return xm.xrt_world_size() + @property def root_device(self) -> torch.device: return xm.xla_device() @@ -141,17 +153,13 @@ def _setup_model(self, model: Module) -> Module: @property def distributed_sampler_kwargs(self) -> Dict[str, int]: - return dict(num_replicas=self.world_size, rank=self.global_rank) + return dict(num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal()) @property def is_distributed(self) -> bool: # HOST_WORLD_SIZE is None outside the xmp.spawn process return os.getenv(xenv.HOST_WORLD_SIZE, None) and self.world_size != 1 - @property - def local_rank(self) -> int: - return self.cluster_environment.local_rank() - def process_dataloader(self, dataloader: DataLoader) -> MpDeviceLoader: TPUSpawnStrategy._validate_dataloader(dataloader) dataloader = MpDeviceLoader(dataloader, self.root_device) diff --git a/pytorch_lightning/trainer/connectors/accelerator_connector.py b/pytorch_lightning/trainer/connectors/accelerator_connector.py index a2cae866e9459..11b1743ee1d37 100644 --- a/pytorch_lightning/trainer/connectors/accelerator_connector.py +++ b/pytorch_lightning/trainer/connectors/accelerator_connector.py @@ -202,6 +202,7 @@ def __init__( # 3. Instantiate ClusterEnvironment self.cluster_environment: ClusterEnvironment = self._choose_and_init_cluster_environment() + # assert isinstance(self.cluster_environment, XLAEnvironment) # 4. Instantiate Strategy - Part 1 if self._strategy_flag is None: @@ -539,7 +540,7 @@ def _choose_and_init_cluster_environment(self) -> ClusterEnvironment: TorchElasticEnvironment, KubeflowEnvironment, LSFEnvironment, - XLAEnvironment, + # XLAEnvironment, ): if env_type.detect(): return env_type() diff --git a/tests/models/test_tpu.py b/tests/models/test_tpu.py index b0ca2ad873316..82f10875a1dec 100644 --- a/tests/models/test_tpu.py +++ b/tests/models/test_tpu.py @@ -24,7 +24,6 @@ from pytorch_lightning import Trainer from pytorch_lightning.accelerators import TPUAccelerator from pytorch_lightning.callbacks import EarlyStopping -from pytorch_lightning.plugins.environments import XLAEnvironment from pytorch_lightning.strategies import TPUSpawnStrategy from pytorch_lightning.trainer.connectors.logger_connector.result import _Sync from pytorch_lightning.utilities import _TPU_AVAILABLE diff --git a/tests/strategies/test_tpu_spawn.py b/tests/strategies/test_tpu_spawn.py index 2531b8b185849..407c8d646502d 100644 --- a/tests/strategies/test_tpu_spawn.py +++ b/tests/strategies/test_tpu_spawn.py @@ -11,10 +11,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
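
The `is_distributed` guard kept as context in the hunk above deserves a gloss: `xenv.HOST_WORLD_SIZE` names the `XRT_HOST_WORLD_SIZE` variable, which exists only inside the `xmp.spawn` worker processes. Outside them `os.getenv` returns `None` and the `and` short-circuits, so the property is falsy regardless of the world size. A self-contained sketch of that short-circuit (standard library only; the world size of 8 is hypothetical):

    import os

    os.environ.pop("XRT_HOST_WORLD_SIZE", None)            # outside xmp.spawn
    outside = os.getenv("XRT_HOST_WORLD_SIZE", None) and 8 != 1
    assert not outside                                     # None short-circuits the `and`

    os.environ["XRT_HOST_WORLD_SIZE"] = "1"                # inside a spawned worker
    inside = os.getenv("XRT_HOST_WORLD_SIZE", None) and 8 != 1
    assert inside                                          # truthy once the variable exists

Note that the expression yields `None` or `bool` rather than a strict `bool`; a quirk worth knowing when reading the `-> bool` annotation.
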
+import os from unittest import mock from unittest.mock import MagicMock import pytest +import torch from torch.utils.data import DataLoader from pytorch_lightning import Trainer @@ -22,6 +24,8 @@ from pytorch_lightning.utilities.exceptions import MisconfigurationException from tests.helpers.boring_model import BoringModel, RandomDataset from tests.helpers.dataloaders import CustomNotImplementedErrorDataloader +from tests.helpers.runif import RunIf +from tests.helpers.utils import pl_multi_process_test class BoringModelNoDataloaders(BoringModel): @@ -79,20 +83,20 @@ def test_error_process_iterable_dataloader(_): TPUSpawnStrategy(MagicMock()).process_dataloader(_loader_no_len) -# class BoringModelTPU(BoringModel): -# def on_train_start(self) -> None: -# assert self.device == torch.device("xla", index=1) -# assert os.environ.get("PT_XLA_DEBUG") == "1" +class BoringModelTPU(BoringModel): + def on_train_start(self) -> None: + assert self.device == torch.device("xla", index=1) + assert os.environ.get("PT_XLA_DEBUG") == "1" -# @RunIf(tpu=True) -# @pl_multi_process_test -# def test_model_tpu_one_core(): -# """Tests if device/debug flag is set correctly when training and after teardown for TPUSpawnStrategy.""" -# trainer = Trainer(tpu_cores=1, fast_dev_run=True, strategy=TPUSpawnStrategy(debug=True)) -# # assert training strategy attributes for device setting -# assert isinstance(trainer.strategy, TPUSpawnStrategy) -# # assert trainer.strategy.root_device == torch.device("xla", index=1) -# model = BoringModelTPU() -# trainer.fit(model) -# assert "PT_XLA_DEBUG" not in os.environ +@RunIf(tpu=True) +@pl_multi_process_test +def test_model_tpu_one_core(): + """Tests if device/debug flag is set correctly when training and after teardown for TPUSpawnStrategy.""" + trainer = Trainer(accelerator="tpu", devices=1, fast_dev_run=True, strategy=TPUSpawnStrategy(debug=True)) + # assert training strategy attributes for device setting + assert isinstance(trainer.strategy, TPUSpawnStrategy) + assert trainer.strategy.root_device == torch.device("xla", index=1) + model = BoringModelTPU() + trainer.fit(model) + assert "PT_XLA_DEBUG" not in os.environ From 51c923939866c3fc82531103421ea3aee8856ea0 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Sat, 14 May 2022 18:19:30 -0400 Subject: [PATCH 26/38] select env --- pytorch_lightning/strategies/tpu_spawn.py | 2 +- pytorch_lightning/trainer/connectors/accelerator_connector.py | 3 +-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/strategies/tpu_spawn.py b/pytorch_lightning/strategies/tpu_spawn.py index 8612a7e7595f2..ec2654d414552 100644 --- a/pytorch_lightning/strategies/tpu_spawn.py +++ b/pytorch_lightning/strategies/tpu_spawn.py @@ -65,7 +65,7 @@ def __init__( **_: Any, ) -> None: checkpoint_io = checkpoint_io or XLACheckpointIO() - # cluster_environment = cluster_environment or XLAEnvironment() + cluster_environment = cluster_environment or XLAEnvironment() super().__init__( accelerator=accelerator, parallel_devices=parallel_devices, diff --git a/pytorch_lightning/trainer/connectors/accelerator_connector.py b/pytorch_lightning/trainer/connectors/accelerator_connector.py index 11b1743ee1d37..a2cae866e9459 100644 --- a/pytorch_lightning/trainer/connectors/accelerator_connector.py +++ b/pytorch_lightning/trainer/connectors/accelerator_connector.py @@ -202,7 +202,6 @@ def __init__( # 3. 
Instantiate ClusterEnvironment self.cluster_environment: ClusterEnvironment = self._choose_and_init_cluster_environment() - # assert isinstance(self.cluster_environment, XLAEnvironment) # 4. Instantiate Strategy - Part 1 if self._strategy_flag is None: @@ -540,7 +539,7 @@ def _choose_and_init_cluster_environment(self) -> ClusterEnvironment: TorchElasticEnvironment, KubeflowEnvironment, LSFEnvironment, - # XLAEnvironment, + XLAEnvironment, ): if env_type.detect(): return env_type() From e3bfbacd8ef6955ee3e30150a5f9d67b9bf24596 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Sat, 14 May 2022 18:34:35 -0400 Subject: [PATCH 27/38] debug --- pytorch_lightning/strategies/tpu_spawn.py | 1 - 1 file changed, 1 deletion(-) diff --git a/pytorch_lightning/strategies/tpu_spawn.py b/pytorch_lightning/strategies/tpu_spawn.py index ec2654d414552..2a7f4c1779f4c 100644 --- a/pytorch_lightning/strategies/tpu_spawn.py +++ b/pytorch_lightning/strategies/tpu_spawn.py @@ -65,7 +65,6 @@ def __init__( **_: Any, ) -> None: checkpoint_io = checkpoint_io or XLACheckpointIO() - cluster_environment = cluster_environment or XLAEnvironment() super().__init__( accelerator=accelerator, parallel_devices=parallel_devices, From 478a70533763a764ee3ca077efbaddaf2ac32f4c Mon Sep 17 00:00:00 2001 From: awaelchli Date: Sat, 14 May 2022 18:55:55 -0400 Subject: [PATCH 28/38] debug --- pytorch_lightning/strategies/tpu_spawn.py | 2 +- pytorch_lightning/trainer/connectors/accelerator_connector.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/strategies/tpu_spawn.py b/pytorch_lightning/strategies/tpu_spawn.py index 2a7f4c1779f4c..252b0de9cf6f0 100644 --- a/pytorch_lightning/strategies/tpu_spawn.py +++ b/pytorch_lightning/strategies/tpu_spawn.py @@ -68,7 +68,7 @@ def __init__( super().__init__( accelerator=accelerator, parallel_devices=parallel_devices, - cluster_environment=cluster_environment, + cluster_environment=XLAEnvironment(), checkpoint_io=checkpoint_io, precision_plugin=precision_plugin, ) diff --git a/pytorch_lightning/trainer/connectors/accelerator_connector.py b/pytorch_lightning/trainer/connectors/accelerator_connector.py index a2cae866e9459..c1057b4ccd00e 100644 --- a/pytorch_lightning/trainer/connectors/accelerator_connector.py +++ b/pytorch_lightning/trainer/connectors/accelerator_connector.py @@ -539,7 +539,7 @@ def _choose_and_init_cluster_environment(self) -> ClusterEnvironment: TorchElasticEnvironment, KubeflowEnvironment, LSFEnvironment, - XLAEnvironment, + # XLAEnvironment, ): if env_type.detect(): return env_type() From 1f34ba70172e72255705b24d8d2dd8ffcc83181a Mon Sep 17 00:00:00 2001 From: awaelchli Date: Sat, 14 May 2022 19:12:27 -0400 Subject: [PATCH 29/38] debug --- pytorch_lightning/strategies/tpu_spawn.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/pytorch_lightning/strategies/tpu_spawn.py b/pytorch_lightning/strategies/tpu_spawn.py index 252b0de9cf6f0..c2698f9d95409 100644 --- a/pytorch_lightning/strategies/tpu_spawn.py +++ b/pytorch_lightning/strategies/tpu_spawn.py @@ -58,7 +58,6 @@ def __init__( self, accelerator: Optional["pl.accelerators.accelerator.Accelerator"] = None, parallel_devices: Optional[List[int]] = None, - cluster_environment: Optional[ClusterEnvironment] = None, checkpoint_io: Optional[XLACheckpointIO] = None, precision_plugin: Optional[PrecisionPlugin] = None, debug: bool = False, @@ -152,7 +151,7 @@ def _setup_model(self, model: Module) -> Module: @property def distributed_sampler_kwargs(self) -> Dict[str, int]: - return 
dict(num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal()) + return dict(num_replicas=self.world_size, rank=self.global_rank) @property def is_distributed(self) -> bool: From cbbd80e332f37c6c63ac5499fd7d164514201115 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Sat, 14 May 2022 19:20:52 -0400 Subject: [PATCH 30/38] debug --- pytorch_lightning/strategies/tpu_spawn.py | 12 ------------ 1 file changed, 12 deletions(-) diff --git a/pytorch_lightning/strategies/tpu_spawn.py b/pytorch_lightning/strategies/tpu_spawn.py index c2698f9d95409..f7e60a1180590 100644 --- a/pytorch_lightning/strategies/tpu_spawn.py +++ b/pytorch_lightning/strategies/tpu_spawn.py @@ -76,18 +76,6 @@ def __init__( self.tpu_global_core_rank = 0 self.start_method = "fork" - @property - def global_rank(self) -> int: - return self.tpu_global_core_rank - - @property - def local_rank(self) -> int: - return self.tpu_local_core_rank - - @property - def world_size(self) -> int: - return xm.xrt_world_size() - @property def root_device(self) -> torch.device: return xm.xla_device() From bf514878d821cd29675f18f2a0ae0e58a7826b16 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Sat, 14 May 2022 19:32:33 -0400 Subject: [PATCH 31/38] remove --- pytorch_lightning/trainer/connectors/accelerator_connector.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/accelerator_connector.py b/pytorch_lightning/trainer/connectors/accelerator_connector.py index c1057b4ccd00e..5bd53604fa39f 100644 --- a/pytorch_lightning/trainer/connectors/accelerator_connector.py +++ b/pytorch_lightning/trainer/connectors/accelerator_connector.py @@ -50,7 +50,6 @@ LSFEnvironment, SLURMEnvironment, TorchElasticEnvironment, - XLAEnvironment, ) from pytorch_lightning.plugins.layer_sync import LayerSync, NativeSyncBatchNorm from pytorch_lightning.strategies import ( @@ -539,7 +538,6 @@ def _choose_and_init_cluster_environment(self) -> ClusterEnvironment: TorchElasticEnvironment, KubeflowEnvironment, LSFEnvironment, - # XLAEnvironment, ): if env_type.detect(): return env_type() From 976ee6c6ed12bce1fcc2de7ddeeb0977fc27db05 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Sat, 14 May 2022 19:59:30 -0400 Subject: [PATCH 32/38] format --- .../trainer/connectors/accelerator_connector.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/accelerator_connector.py b/pytorch_lightning/trainer/connectors/accelerator_connector.py index 5bd53604fa39f..753234bc21d83 100644 --- a/pytorch_lightning/trainer/connectors/accelerator_connector.py +++ b/pytorch_lightning/trainer/connectors/accelerator_connector.py @@ -533,12 +533,7 @@ def _choose_and_init_cluster_environment(self) -> ClusterEnvironment: if self._is_slurm_managing_tasks(): rank_zero_info("Multiprocessing is handled by SLURM.") return SLURMEnvironment() - for env_type in ( - BaguaEnvironment, - TorchElasticEnvironment, - KubeflowEnvironment, - LSFEnvironment, - ): + for env_type in (BaguaEnvironment, TorchElasticEnvironment, KubeflowEnvironment, LSFEnvironment): if env_type.detect(): return env_type() return LightningEnvironment() From 39e9aa8970b77c9e212f083bbf5ca7bfd6797494 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Sat, 14 May 2022 20:03:01 -0400 Subject: [PATCH 33/38] add changelog --- CHANGELOG.md | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 3cd7495990bc4..681b8c32c96a5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -60,7 +60,9 @@ The format is based on [Keep a 
Changelog](http://keepachangelog.com/en/1.0.0/). - Added breaking of lazy graph across training, validation, test and predict steps when training with habana accelerators to ensure better performance ([#12938](https://github.com/PyTorchLightning/pytorch-lightning/pull/12938)) -- +- Added `XLAEnvironment` cluster environment plugin ([#11330](https://github.com/PyTorchLightning/pytorch-lightning/pull/11330)) + + ### Changed From 8d1b7c9320894ca0d8f59ee9637ec32a19f438a8 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Sat, 14 May 2022 20:05:37 -0400 Subject: [PATCH 34/38] fix test entry --- dockers/tpu-tests/tpu_test_cases.jsonnet | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/dockers/tpu-tests/tpu_test_cases.jsonnet b/dockers/tpu-tests/tpu_test_cases.jsonnet index 2823110d663d8..cf2dd6ffed03d 100644 --- a/dockers/tpu-tests/tpu_test_cases.jsonnet +++ b/dockers/tpu-tests/tpu_test_cases.jsonnet @@ -38,8 +38,8 @@ local tputests = base.BaseTest { tests/profiler/test_xla_profiler.py \ pytorch_lightning/utilities/xla_device.py \ tests/accelerators/test_tpu.py \ - tests/models/test_tpu.py - tests/plugins/environments/test_xla_environment.py \ + tests/models/test_tpu.py \ + tests/plugins/environments/test_xla_environment.py test_exit_code=$? echo "\n||| END PYTEST LOGS |||\n" coverage xml From b8224b03b6b6253eb8e97d56f4620113b1940a62 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Sat, 14 May 2022 20:20:54 -0400 Subject: [PATCH 35/38] remove unused import --- pytorch_lightning/strategies/tpu_spawn.py | 1 - 1 file changed, 1 deletion(-) diff --git a/pytorch_lightning/strategies/tpu_spawn.py b/pytorch_lightning/strategies/tpu_spawn.py index f7e60a1180590..b28db40779ec2 100644 --- a/pytorch_lightning/strategies/tpu_spawn.py +++ b/pytorch_lightning/strategies/tpu_spawn.py @@ -21,7 +21,6 @@ import pytorch_lightning as pl from pytorch_lightning.overrides import LightningDistributedModule -from pytorch_lightning.plugins import ClusterEnvironment from pytorch_lightning.plugins.environments import XLAEnvironment from pytorch_lightning.plugins.io.xla_plugin import XLACheckpointIO from pytorch_lightning.plugins.precision import PrecisionPlugin From 06232f8b2ad9eb5590ce3530d76d4b2874e2c55b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Sun, 15 May 2022 12:09:19 -0400 Subject: [PATCH 36/38] Apply suggestions from code review Co-authored-by: Rohit Gupta --- tests/plugins/environments/test_xla_environment.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/plugins/environments/test_xla_environment.py b/tests/plugins/environments/test_xla_environment.py index bd4af5f144fc1..60d87497ad05e 100644 --- a/tests/plugins/environments/test_xla_environment.py +++ b/tests/plugins/environments/test_xla_environment.py @@ -22,7 +22,7 @@ @RunIf(tpu=True) -@mock.patch.dict(os.environ, {}) +@mock.patch.dict(os.environ, {}, clear=True) def test_default_attributes(): """Test the default attributes when no environment variables are set.""" env = XLAEnvironment() @@ -51,6 +51,7 @@ def test_default_attributes(): "XRT_SHARD_LOCAL_ORDINAL": "2", "XRT_HOST_ORDINAL": "3", }, + clear=True ) def test_attributes_from_environment_variables(): """Test that the default cluster environment takes the attributes from the environment variables.""" From b2f86f2223850fd72d7b11d01330699dc9a04cac Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 15 May 2022 16:10:38 +0000 Subject: [PATCH 37/38] [pre-commit.ci] auto 
fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/plugins/environments/test_xla_environment.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/plugins/environments/test_xla_environment.py b/tests/plugins/environments/test_xla_environment.py index 60d87497ad05e..ea3219f20f6f6 100644 --- a/tests/plugins/environments/test_xla_environment.py +++ b/tests/plugins/environments/test_xla_environment.py @@ -51,7 +51,7 @@ def test_default_attributes(): "XRT_SHARD_LOCAL_ORDINAL": "2", "XRT_HOST_ORDINAL": "3", }, - clear=True + clear=True, ) def test_attributes_from_environment_variables(): """Test that the default cluster environment takes the attributes from the environment variables.""" From e6a4de11eab45e4c358f42a4b39d0cad96683a06 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 21 Jun 2022 21:48:22 +0000 Subject: [PATCH 38/38] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/plugins/environments/test_xla_environment.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/plugins/environments/test_xla_environment.py b/tests/plugins/environments/test_xla_environment.py index ea3219f20f6f6..21ef9bb5bf171 100644 --- a/tests/plugins/environments/test_xla_environment.py +++ b/tests/plugins/environments/test_xla_environment.py @@ -15,10 +15,10 @@ from unittest import mock import pytest +from tests.helpers.runif import RunIf import pytorch_lightning as pl from pytorch_lightning.plugins.environments import XLAEnvironment -from tests.helpers.runif import RunIf @RunIf(tpu=True)
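
One closing note on the `clear=True` that patches 36-37 introduce: by default `mock.patch.dict` only overlays the given keys, so a live XRT session on the CI host could leak real `XRT_*` variables into the assertions; clearing first makes the fixture hermetic. A self-contained illustration (standard library only; the variable value is hypothetical):

    import os
    from unittest import mock

    with mock.patch.dict(os.environ, {"XRT_HOST_ORDINAL": "3"}, clear=True):
        assert os.environ["XRT_HOST_ORDINAL"] == "3"
        assert "PATH" not in os.environ   # ambient variables are cleared too
    # on exit, the original os.environ is restored unchanged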