diff --git a/CHANGELOG.md b/CHANGELOG.md
index a83ef6a55d515..deb7041ac91bf 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -67,11 +67,13 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added `teardown()` method to `Accelerator` ([#11935](https://github.com/PyTorchLightning/pytorch-lightning/pull/11935))
 
 
--
 
 - Added a `timeout` argument to `DDPStrategy`. ([#13244](https://github.com/PyTorchLightning/pytorch-lightning/pull/13244))
 
 
--
+
+
+- Added `XLAEnvironment` cluster environment plugin ([#11330](https://github.com/PyTorchLightning/pytorch-lightning/pull/11330))
+
 
 ### Changed
diff --git a/dockers/tpu-tests/tpu_test_cases.jsonnet b/dockers/tpu-tests/tpu_test_cases.jsonnet
index 952cdd74ad3d8..5ec110729bd93 100644
--- a/dockers/tpu-tests/tpu_test_cases.jsonnet
+++ b/dockers/tpu-tests/tpu_test_cases.jsonnet
@@ -38,7 +38,8 @@ local tputests = base.BaseTest {
       strategies/test_tpu_spawn.py \
       profiler/test_xla_profiler.py \
       accelerators/test_tpu.py \
-      models/test_tpu.py
+      models/test_tpu.py \
+      plugins/environments/test_xla_environment.py
     test_exit_code=$?
     echo "\n||| END PYTEST LOGS |||\n"
     coverage xml
diff --git a/docs/source-pytorch/api_references.rst b/docs/source-pytorch/api_references.rst
index 7fe1e58661fbb..a147340d36df4 100644
--- a/docs/source-pytorch/api_references.rst
+++ b/docs/source-pytorch/api_references.rst
@@ -198,6 +198,7 @@ environments
     LSFEnvironment
     SLURMEnvironment
     TorchElasticEnvironment
+    XLAEnvironment
 
 io
 ""
diff --git a/docs/source-pytorch/extensions/plugins.rst b/docs/source-pytorch/extensions/plugins.rst
index 392a07219b45b..6ea8d42815f46 100644
--- a/docs/source-pytorch/extensions/plugins.rst
+++ b/docs/source-pytorch/extensions/plugins.rst
@@ -117,3 +117,4 @@ You can define the interface of your own cluster environment based on the requir
     LSFEnvironment
     SLURMEnvironment
     TorchElasticEnvironment
+    XLAEnvironment
diff --git a/src/pytorch_lightning/plugins/environments/__init__.py b/src/pytorch_lightning/plugins/environments/__init__.py
index eab64cfe2daf5..3417f6007041b 100644
--- a/src/pytorch_lightning/plugins/environments/__init__.py
+++ b/src/pytorch_lightning/plugins/environments/__init__.py
@@ -18,3 +18,4 @@
 from pytorch_lightning.plugins.environments.lsf_environment import LSFEnvironment  # noqa: F401
 from pytorch_lightning.plugins.environments.slurm_environment import SLURMEnvironment  # noqa: F401
 from pytorch_lightning.plugins.environments.torchelastic_environment import TorchElasticEnvironment  # noqa: F401
+from pytorch_lightning.plugins.environments.xla_environment import XLAEnvironment  # noqa: F401
diff --git a/src/pytorch_lightning/plugins/environments/xla_environment.py b/src/pytorch_lightning/plugins/environments/xla_environment.py
new file mode 100644
index 0000000000000..a78ebeb36a6a4
--- /dev/null
+++ b/src/pytorch_lightning/plugins/environments/xla_environment.py
@@ -0,0 +1,66 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+import os
+
+from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
+from pytorch_lightning.utilities import _TPU_AVAILABLE
+
+if _TPU_AVAILABLE:
+    import torch_xla.core.xla_env_vars as xenv
+    import torch_xla.core.xla_model as xm
+
+log = logging.getLogger(__name__)
+
+
+class XLAEnvironment(ClusterEnvironment):
+    """Cluster environment for training on a TPU Pod with the `PyTorch/XLA <https://pytorch.org/xla>`_ library.
+
+    A list of environment variables set by XLA can be found
+    `here <https://github.com/pytorch/xla/blob/master/torch_xla/core/xla_env_vars.py>`_.
+    """
+
+    @property
+    def creates_processes_externally(self) -> bool:
+        return False
+
+    @property
+    def main_address(self) -> str:
+        return os.environ[xenv.TPU_MESH_CTLER_ADDR]
+
+    @property
+    def main_port(self) -> int:
+        return int(os.environ[xenv.TPU_MESH_CTLER_PORT])
+
+    @staticmethod
+    def detect() -> bool:
+        return _TPU_AVAILABLE
+
+    def world_size(self) -> int:
+        return xm.xrt_world_size()
+
+    def set_world_size(self, size: int) -> None:
+        log.debug("XLAEnvironment.set_world_size was called, but setting world size is not allowed. Ignored.")
+
+    def global_rank(self) -> int:
+        return xm.get_ordinal()
+
+    def set_global_rank(self, rank: int) -> None:
+        log.debug("XLAEnvironment.set_global_rank was called, but setting global rank is not allowed. Ignored.")
+
+    def local_rank(self) -> int:
+        return xm.get_local_ordinal()
+
+    def node_rank(self) -> int:
+        return int(os.environ.get(xenv.HOST_ORDINAL, 0))
diff --git a/src/pytorch_lightning/strategies/tpu_spawn.py b/src/pytorch_lightning/strategies/tpu_spawn.py
index 9fd3796f89009..464eb6b57d4de 100644
--- a/src/pytorch_lightning/strategies/tpu_spawn.py
+++ b/src/pytorch_lightning/strategies/tpu_spawn.py
@@ -22,6 +22,7 @@
 
 import pytorch_lightning as pl
 from pytorch_lightning.overrides import LightningDistributedModule
+from pytorch_lightning.plugins.environments import XLAEnvironment
 from pytorch_lightning.plugins.io.xla_plugin import XLACheckpointIO
 from pytorch_lightning.plugins.precision import PrecisionPlugin
 from pytorch_lightning.strategies.ddp_spawn import DDPSpawnStrategy
@@ -67,6 +68,7 @@ def __init__(
         super().__init__(
             accelerator=accelerator,
             parallel_devices=parallel_devices,
+            cluster_environment=XLAEnvironment(),
             checkpoint_io=checkpoint_io,
             precision_plugin=precision_plugin,
         )
@@ -75,18 +77,6 @@ def __init__(
         self.tpu_global_core_rank = 0
         self.start_method = "fork"
 
-    @property
-    def global_rank(self) -> int:
-        return self.tpu_global_core_rank
-
-    @property
-    def local_rank(self) -> int:
-        return self.tpu_local_core_rank
-
-    @property
-    def world_size(self) -> int:
-        return xm.xrt_world_size()
-
     @property
     def root_device(self) -> torch.device:
         return xm.xla_device()
@@ -150,7 +140,7 @@ def _setup_model(self, model: Module) -> Module:
 
     @property
     def distributed_sampler_kwargs(self) -> Dict[str, int]:
-        return dict(num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal())
+        return dict(num_replicas=self.world_size, rank=self.global_rank)
 
     @property
     def is_distributed(self) -> bool:
diff --git a/tests/plugins/environments/test_xla_environment.py b/tests/plugins/environments/test_xla_environment.py
new file mode 100644
index 0000000000000..21ef9bb5bf171
--- /dev/null
+++ b/tests/plugins/environments/test_xla_environment.py
@@ -0,0 +1,77 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+from unittest import mock
+
+import pytest
+from tests.helpers.runif import RunIf
+
+import pytorch_lightning as pl
+from pytorch_lightning.plugins.environments import XLAEnvironment
+
+
+@RunIf(tpu=True)
+@mock.patch.dict(os.environ, {}, clear=True)
+def test_default_attributes():
+    """Test the default attributes when no environment variables are set."""
+    env = XLAEnvironment()
+    assert not env.creates_processes_externally
+    assert env.world_size() == 1
+    assert env.global_rank() == 0
+    assert env.local_rank() == 0
+    assert env.node_rank() == 0
+
+    with pytest.raises(KeyError):
+        # main_address is required to be passed as env variable
+        _ = env.main_address
+    with pytest.raises(KeyError):
+        # main_port is required to be passed as env variable
+        _ = env.main_port
+
+
+@RunIf(tpu=True)
+@mock.patch.dict(
+    os.environ,
+    {
+        "TPU_MESH_CONTROLLER_ADDRESS": "1.2.3.4",
+        "TPU_MESH_CONTROLLER_PORT": "500",
+        "XRT_SHARD_WORLD_SIZE": "1",
+        "XRT_SHARD_ORDINAL": "0",
+        "XRT_SHARD_LOCAL_ORDINAL": "2",
+        "XRT_HOST_ORDINAL": "3",
+    },
+    clear=True,
+)
+def test_attributes_from_environment_variables():
+    """Test that the default cluster environment takes the attributes from the environment variables."""
+    env = XLAEnvironment()
+    assert env.main_address == "1.2.3.4"
+    assert env.main_port == 500
+    assert env.world_size() == 1
+    assert env.global_rank() == 0
+    assert env.local_rank() == 2
+    assert env.node_rank() == 3
+    env.set_global_rank(100)
+    assert env.global_rank() == 0
+    env.set_world_size(100)
+    assert env.world_size() == 1
+
+
+def test_detect(monkeypatch):
+    """Test the detection of an XLA environment configuration."""
+    monkeypatch.setattr(pl.plugins.environments.xla_environment, "_TPU_AVAILABLE", False)
+    assert not XLAEnvironment.detect()
+
+    monkeypatch.setattr(pl.plugins.environments.xla_environment, "_TPU_AVAILABLE", True)
+    assert XLAEnvironment.detect()
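
For context, and not part of the patch itself, here is a minimal sketch of how the new `XLAEnvironment` resolves its attributes. It assumes a host with `torch_xla` installed and a TPU attached (the rank/world-size queries go through `xm` and are unavailable otherwise); the address, port, and ordinal values below are made up for illustration and simply mirror the variables patched in `test_xla_environment.py`.

```python
import os
from unittest import mock

from pytorch_lightning.plugins.environments import XLAEnvironment

# Illustrative values only; on a real TPU Pod these variables are set by XLA.
fake_xla_vars = {
    "TPU_MESH_CONTROLLER_ADDRESS": "1.2.3.4",  # read by XLAEnvironment.main_address
    "TPU_MESH_CONTROLLER_PORT": "8476",        # read by XLAEnvironment.main_port (made-up port)
    "XRT_HOST_ORDINAL": "0",                   # read by XLAEnvironment.node_rank
}

with mock.patch.dict(os.environ, fake_xla_vars):
    env = XLAEnvironment()
    # main_address/main_port/node_rank come straight from the environment variables,
    # while world_size/global_rank/local_rank are queried from torch_xla at runtime.
    print(env.main_address, env.main_port, env.node_rank())
    print(env.world_size(), env.global_rank(), env.local_rank())
```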
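Likewise only a sketch of what the `TPUSpawnStrategy` change means in practice: the strategy now hands an `XLAEnvironment` to its parent class, so `global_rank`, `local_rank`, and `world_size` are served by the cluster environment rather than by the properties removed in this diff. This assumes a `pytorch_lightning` build containing this change; the rank and world-size lookups still require a TPU host with `torch_xla`.

```python
from pytorch_lightning.plugins.environments import XLAEnvironment
from pytorch_lightning.strategies import TPUSpawnStrategy

strategy = TPUSpawnStrategy()

# The cluster environment is now created internally by TPUSpawnStrategy.__init__.
assert isinstance(strategy.cluster_environment, XLAEnvironment)

# On a TPU host, these properties delegate to the XLAEnvironment
# (and, through it, to torch_xla) instead of strategy-local attributes:
print(strategy.world_size, strategy.global_rank, strategy.local_rank)
```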