Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add XLAEnvironment plugin #11330

Merged
merged 45 commits into from
Jun 22, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
45 commits
Select commit Hold shift + click to select a range
b714395
add xla environment class
awaelchli Jan 5, 2022
078a01f
add api reference
awaelchli Jan 5, 2022
64c57c4
integrate
awaelchli Jan 5, 2022
cee674b
use xenv
awaelchli Jan 5, 2022
f509dc9
remove properties
awaelchli Jan 5, 2022
7d192cb
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 5, 2022
ce427e5
test environment selection
awaelchli Jan 5, 2022
460349f
Merge branch 'master' into feature/xla_environment
awaelchli Feb 6, 2022
6df74bc
update
awaelchli Feb 6, 2022
35a7474
Merge branch 'master' into feature/xla_environment
awaelchli Feb 7, 2022
46cd7a3
notebooks
awaelchli Feb 7, 2022
ad7acc0
notebooks
awaelchli Feb 7, 2022
e5fae8f
update
awaelchli Feb 7, 2022
d084d51
update
awaelchli Feb 7, 2022
5dca2f8
test tests
awaelchli Feb 7, 2022
472e200
include test case
awaelchli Feb 7, 2022
1833b62
fix test
awaelchli Feb 7, 2022
dcf3ccb
fix
awaelchli Feb 13, 2022
970c1b0
Merge branch 'master' into feature/xla_environment
awaelchli Feb 13, 2022
1131bf7
reset
awaelchli Feb 14, 2022
32390a9
temp fix
kaushikb11 Feb 16, 2022
2fcdb64
Merge branch 'master' into feature/xla_environment
kaushikb11 Feb 18, 2022
7dcf6c4
Update
kaushikb11 Feb 18, 2022
d2700b8
Update
kaushikb11 Feb 18, 2022
ba54586
Update
kaushikb11 Feb 18, 2022
983a9e7
Update tests
kaushikb11 Feb 23, 2022
4c6f73a
Update tests
kaushikb11 Feb 24, 2022
1d16728
Update tests
kaushikb11 Feb 24, 2022
f1d9cd9
Merge branch 'master' into feature/xla_environment
awaelchli Apr 5, 2022
d3fac36
Merge branch 'master' into feature/xla_environment
awaelchli May 14, 2022
ac90b96
debug
awaelchli May 14, 2022
51c9239
select env
awaelchli May 14, 2022
e3bfbac
debug
awaelchli May 14, 2022
478a705
debug
awaelchli May 14, 2022
1f34ba7
debug
awaelchli May 14, 2022
cbbd80e
debug
awaelchli May 14, 2022
bf51487
remove
awaelchli May 14, 2022
976ee6c
format
awaelchli May 14, 2022
39e9aa8
add changelog
awaelchli May 15, 2022
8d1b7c9
fix test entry
awaelchli May 15, 2022
b8224b0
remove unused import
awaelchli May 15, 2022
06232f8
Apply suggestions from code review
awaelchli May 15, 2022
b2f86f2
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 15, 2022
9be2f12
Merge branch 'master' into feature/xla_environment
Borda Jun 21, 2022
e6a4de1
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 21, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -67,11 +67,13 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).


- Added `teardown()` method to `Accelerator` ([#11935](https://github.com/PyTorchLightning/pytorch-lightning/pull/11935))
-


- Added a `timeout` argument to `DDPStrategy`. ([#13244](https://github.com/PyTorchLightning/pytorch-lightning/pull/13244))
-


- Added `XLAEnvironment` cluster environment plugin ([#11330](https://github.com/PyTorchLightning/pytorch-lightning/pull/11330))



### Changed
Expand Down
3 changes: 2 additions & 1 deletion dockers/tpu-tests/tpu_test_cases.jsonnet
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,8 @@ local tputests = base.BaseTest {
strategies/test_tpu_spawn.py \
profiler/test_xla_profiler.py \
accelerators/test_tpu.py \
models/test_tpu.py
models/test_tpu.py \
plugins/environments/test_xla_environment.py
test_exit_code=$?
echo "\n||| END PYTEST LOGS |||\n"
coverage xml
Expand Down
1 change: 1 addition & 0 deletions docs/source-pytorch/api_references.rst
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,7 @@ environments
LSFEnvironment
SLURMEnvironment
TorchElasticEnvironment
XLAEnvironment

io
""
Expand Down
1 change: 1 addition & 0 deletions docs/source-pytorch/extensions/plugins.rst
Original file line number Diff line number Diff line change
Expand Up @@ -117,3 +117,4 @@ You can define the interface of your own cluster environment based on the requir
LSFEnvironment
SLURMEnvironment
TorchElasticEnvironment
XLAEnvironment
1 change: 1 addition & 0 deletions src/pytorch_lightning/plugins/environments/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,3 +18,4 @@
from pytorch_lightning.plugins.environments.lsf_environment import LSFEnvironment # noqa: F401
from pytorch_lightning.plugins.environments.slurm_environment import SLURMEnvironment # noqa: F401
from pytorch_lightning.plugins.environments.torchelastic_environment import TorchElasticEnvironment # noqa: F401
from pytorch_lightning.plugins.environments.xla_environment import XLAEnvironment # noqa: F401
66 changes: 66 additions & 0 deletions src/pytorch_lightning/plugins/environments/xla_environment.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os

from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
from pytorch_lightning.utilities import _TPU_AVAILABLE

if _TPU_AVAILABLE:
import torch_xla.core.xla_env_vars as xenv
import torch_xla.core.xla_model as xm

log = logging.getLogger(__name__)


class XLAEnvironment(ClusterEnvironment):
"""Cluster environment for training on a TPU Pod with the `PyTorch/XLA <https://pytorch.org/xla>`_ library.

A list of environment variables set by XLA can be found
`here <https://github.com/pytorch/xla/blob/master/torch_xla/core/xla_env_vars.py>`_.
"""

@property
def creates_processes_externally(self) -> bool:
return False

@property
def main_address(self) -> str:
return os.environ[xenv.TPU_MESH_CTLER_ADDR]

@property
def main_port(self) -> int:
return int(os.environ[xenv.TPU_MESH_CTLER_PORT])

@staticmethod
def detect() -> bool:
return _TPU_AVAILABLE

def world_size(self) -> int:
return xm.xrt_world_size()

def set_world_size(self, size: int) -> None:
log.debug("XLAEnvironment.set_world_size was called, but setting world size is not allowed. Ignored.")

def global_rank(self) -> int:
return xm.get_ordinal()

def set_global_rank(self, rank: int) -> None:
log.debug("XLAEnvironment.set_global_rank was called, but setting global rank is not allowed. Ignored.")

def local_rank(self) -> int:
return xm.get_local_ordinal()

def node_rank(self) -> int:
return int(os.environ.get(xenv.HOST_ORDINAL, 0))
16 changes: 3 additions & 13 deletions src/pytorch_lightning/strategies/tpu_spawn.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

import pytorch_lightning as pl
from pytorch_lightning.overrides import LightningDistributedModule
from pytorch_lightning.plugins.environments import XLAEnvironment
from pytorch_lightning.plugins.io.xla_plugin import XLACheckpointIO
from pytorch_lightning.plugins.precision import PrecisionPlugin
from pytorch_lightning.strategies.ddp_spawn import DDPSpawnStrategy
Expand Down Expand Up @@ -67,6 +68,7 @@ def __init__(
super().__init__(
accelerator=accelerator,
parallel_devices=parallel_devices,
cluster_environment=XLAEnvironment(),
checkpoint_io=checkpoint_io,
precision_plugin=precision_plugin,
)
Expand All @@ -75,18 +77,6 @@ def __init__(
self.tpu_global_core_rank = 0
self.start_method = "fork"

@property
def global_rank(self) -> int:
return self.tpu_global_core_rank

@property
def local_rank(self) -> int:
return self.tpu_local_core_rank

@property
def world_size(self) -> int:
return xm.xrt_world_size()

@property
def root_device(self) -> torch.device:
return xm.xla_device()
Expand Down Expand Up @@ -150,7 +140,7 @@ def _setup_model(self, model: Module) -> Module:

@property
def distributed_sampler_kwargs(self) -> Dict[str, int]:
return dict(num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal())
return dict(num_replicas=self.world_size, rank=self.global_rank)

@property
def is_distributed(self) -> bool:
Expand Down
77 changes: 77 additions & 0 deletions tests/plugins/environments/test_xla_environment.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from unittest import mock

import pytest
from tests.helpers.runif import RunIf

import pytorch_lightning as pl
from pytorch_lightning.plugins.environments import XLAEnvironment


@RunIf(tpu=True)
@mock.patch.dict(os.environ, {}, clear=True)
def test_default_attributes():
"""Test the default attributes when no environment variables are set."""
env = XLAEnvironment()
assert not env.creates_processes_externally
assert env.world_size() == 1
assert env.global_rank() == 0
assert env.local_rank() == 0
assert env.node_rank() == 0

with pytest.raises(KeyError):
# main_address is required to be passed as env variable
_ = env.main_address
with pytest.raises(KeyError):
# main_port is required to be passed as env variable
_ = env.main_port


@RunIf(tpu=True)
@mock.patch.dict(
os.environ,
{
"TPU_MESH_CONTROLLER_ADDRESS": "1.2.3.4",
"TPU_MESH_CONTROLLER_PORT": "500",
"XRT_SHARD_WORLD_SIZE": "1",
"XRT_SHARD_ORDINAL": "0",
"XRT_SHARD_LOCAL_ORDINAL": "2",
"XRT_HOST_ORDINAL": "3",
},
clear=True,
)
def test_attributes_from_environment_variables():
"""Test that the default cluster environment takes the attributes from the environment variables."""
env = XLAEnvironment()
assert env.main_address == "1.2.3.4"
assert env.main_port == 500
assert env.world_size() == 1
assert env.global_rank() == 0
assert env.local_rank() == 2
assert env.node_rank() == 3
env.set_global_rank(100)
assert env.global_rank() == 0
env.set_world_size(100)
assert env.world_size() == 1


def test_detect(monkeypatch):
"""Test the detection of a xla environment configuration."""
monkeypatch.setattr(pl.plugins.environments.xla_environment, "_TPU_AVAILABLE", False)
assert not XLAEnvironment.detect()

monkeypatch.setattr(pl.plugins.environments.xla_environment, "_TPU_AVAILABLE", True)
assert XLAEnvironment.detect()