Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add XLAEnvironment plugin #11330

Merged
merged 45 commits into from
Jun 22, 2022
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
Show all changes
45 commits
Select commit Hold shift + click to select a range
b714395
add xla environment class
awaelchli Jan 5, 2022
078a01f
add api reference
awaelchli Jan 5, 2022
64c57c4
integrate
awaelchli Jan 5, 2022
cee674b
use xenv
awaelchli Jan 5, 2022
f509dc9
remove properties
awaelchli Jan 5, 2022
7d192cb
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 5, 2022
ce427e5
test environment selection
awaelchli Jan 5, 2022
460349f
Merge branch 'master' into feature/xla_environment
awaelchli Feb 6, 2022
6df74bc
update
awaelchli Feb 6, 2022
35a7474
Merge branch 'master' into feature/xla_environment
awaelchli Feb 7, 2022
46cd7a3
notebooks
awaelchli Feb 7, 2022
ad7acc0
notebooks
awaelchli Feb 7, 2022
e5fae8f
update
awaelchli Feb 7, 2022
d084d51
update
awaelchli Feb 7, 2022
5dca2f8
test tests
awaelchli Feb 7, 2022
472e200
include test case
awaelchli Feb 7, 2022
1833b62
fix test
awaelchli Feb 7, 2022
dcf3ccb
fix
awaelchli Feb 13, 2022
970c1b0
Merge branch 'master' into feature/xla_environment
awaelchli Feb 13, 2022
1131bf7
reset
awaelchli Feb 14, 2022
32390a9
temp fix
kaushikb11 Feb 16, 2022
2fcdb64
Merge branch 'master' into feature/xla_environment
kaushikb11 Feb 18, 2022
7dcf6c4
Update
kaushikb11 Feb 18, 2022
d2700b8
Update
kaushikb11 Feb 18, 2022
ba54586
Update
kaushikb11 Feb 18, 2022
983a9e7
Update tests
kaushikb11 Feb 23, 2022
4c6f73a
Update tests
kaushikb11 Feb 24, 2022
1d16728
Update tests
kaushikb11 Feb 24, 2022
f1d9cd9
Merge branch 'master' into feature/xla_environment
awaelchli Apr 5, 2022
d3fac36
Merge branch 'master' into feature/xla_environment
awaelchli May 14, 2022
ac90b96
debug
awaelchli May 14, 2022
51c9239
select env
awaelchli May 14, 2022
e3bfbac
debug
awaelchli May 14, 2022
478a705
debug
awaelchli May 14, 2022
1f34ba7
debug
awaelchli May 14, 2022
cbbd80e
debug
awaelchli May 14, 2022
bf51487
remove
awaelchli May 14, 2022
976ee6c
format
awaelchli May 14, 2022
39e9aa8
add changelog
awaelchli May 15, 2022
8d1b7c9
fix test entry
awaelchli May 15, 2022
b8224b0
remove unused import
awaelchli May 15, 2022
06232f8
Apply suggestions from code review
awaelchli May 15, 2022
b2f86f2
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 15, 2022
9be2f12
Merge branch 'master' into feature/xla_environment
Borda Jun 21, 2022
e6a4de1
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 21, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/source/api_references.rst
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,7 @@ Cluster Environments
TorchElasticEnvironment
KubeflowEnvironment
SLURMEnvironment
XLAEnvironment

Checkpoint IO Plugins
^^^^^^^^^^^^^^^^^^^^^
Expand Down
1 change: 1 addition & 0 deletions docs/source/extensions/plugins.rst
Original file line number Diff line number Diff line change
Expand Up @@ -158,3 +158,4 @@ Cluster Environments
TorchElasticEnvironment
KubeflowEnvironment
SLURMEnvironment
XLAEnvironment
1 change: 1 addition & 0 deletions pytorch_lightning/plugins/environments/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,4 @@
from pytorch_lightning.plugins.environments.lsf_environment import LSFEnvironment # noqa: F401
from pytorch_lightning.plugins.environments.slurm_environment import SLURMEnvironment # noqa: F401
from pytorch_lightning.plugins.environments.torchelastic_environment import TorchElasticEnvironment # noqa: F401
from pytorch_lightning.plugins.environments.xla_environment import XLAEnvironment # noqa: F401
62 changes: 62 additions & 0 deletions pytorch_lightning/plugins/environments/xla_environment.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os

from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
from pytorch_lightning.utilities import _TPU_AVAILABLE

if _TPU_AVAILABLE:
import torch_xla.core.xla_env_vars as xenv


class XLAEnvironment(ClusterEnvironment):
awaelchli marked this conversation as resolved.
Show resolved Hide resolved
"""XLAEnvironment.

A list of environment variables set by XLA can be found
`here <https://github.com/pytorch/xla/blob/master/torch_xla/core/xla_env_vars.py>`_.
"""

@property
def creates_processes_externally(self) -> bool:
return False

@property
def main_address(self) -> str:
return os.environ[xenv.TPU_MESH_CTLER_ADDR]

@property
def main_port(self) -> int:
return int(os.environ[xenv.TPU_MESH_CTLER_PORT])

@staticmethod
def detect() -> bool:
return _TPU_AVAILABLE

def world_size(self) -> int:
return int(os.environ.get(xenv.WORLD_SIZE, 1))

def set_world_size(self, size: int) -> None:
pass

def global_rank(self) -> int:
return int(os.environ.get(xenv.ORDINAL, 0))

def set_global_rank(self, rank: int) -> None:
pass

def local_rank(self) -> int:
return int(os.environ.get(xenv.LOCAL_ORDINAL, 0))

def node_rank(self) -> int:
return int(os.environ.get(xenv.HOST_ORDINAL, 0))
20 changes: 6 additions & 14 deletions pytorch_lightning/strategies/tpu_spawn.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@

import pytorch_lightning as pl
from pytorch_lightning.overrides import LightningDistributedModule
from pytorch_lightning.plugins import ClusterEnvironment
from pytorch_lightning.plugins.environments import XLAEnvironment
from pytorch_lightning.plugins.io.xla_plugin import XLACheckpointIO
from pytorch_lightning.plugins.precision import PrecisionPlugin
from pytorch_lightning.strategies.ddp_spawn import _FakeQueue, _SpawnOutput, DDPSpawnStrategy
Expand Down Expand Up @@ -55,15 +57,18 @@ def __init__(
self,
accelerator: Optional["pl.accelerators.accelerator.Accelerator"] = None,
parallel_devices: Optional[List[int]] = None,
cluster_environment: Optional[ClusterEnvironment] = None,
checkpoint_io: Optional[XLACheckpointIO] = None,
precision_plugin: Optional[PrecisionPlugin] = None,
debug: bool = False,
**_: Any,
) -> None:
checkpoint_io = checkpoint_io or XLACheckpointIO()
cluster_environment = cluster_environment or XLAEnvironment()
super().__init__(
accelerator=accelerator,
parallel_devices=parallel_devices,
cluster_environment=cluster_environment,
checkpoint_io=checkpoint_io,
precision_plugin=precision_plugin,
)
Expand All @@ -72,18 +77,6 @@ def __init__(
self.tpu_global_core_rank = 0
awaelchli marked this conversation as resolved.
Show resolved Hide resolved
self.start_method = "fork"

@property
def global_rank(self) -> int:
return self.tpu_global_core_rank

@property
def local_rank(self) -> int:
return self.tpu_local_core_rank

@property
def world_size(self) -> int:
return xm.xrt_world_size()

@property
def root_device(self) -> torch.device:
return xm.xla_device()
Expand Down Expand Up @@ -148,8 +141,7 @@ def distributed_sampler_kwargs(self) -> Dict[str, int]:

@property
def is_distributed(self) -> bool:
# HOST_WORLD_SIZE is None outside the xmp.spawn process
return os.getenv(xenv.HOST_WORLD_SIZE, None) and self.world_size != 1
return self.world_size > 1
kaushikb11 marked this conversation as resolved.
Show resolved Hide resolved

def process_dataloader(self, dataloader: DataLoader) -> MpDeviceLoader:
TPUSpawnStrategy._validate_dataloader(dataloader)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
LSFEnvironment,
SLURMEnvironment,
TorchElasticEnvironment,
XLAEnvironment,
)
from pytorch_lightning.strategies import (
DataParallelStrategy,
Expand Down Expand Up @@ -808,7 +809,7 @@ def select_cluster_environment(self) -> ClusterEnvironment:
rank_zero_info("Multiprocessing is handled by SLURM.")
return SLURMEnvironment()

for env_type in (TorchElasticEnvironment, KubeflowEnvironment, LSFEnvironment):
for env_type in (TorchElasticEnvironment, KubeflowEnvironment, LSFEnvironment, XLAEnvironment):
if env_type.detect():
return env_type()

Expand Down