Skip to content

Commit

Permalink
Avoid circular imports when lightning-habana or lightning-graphcore i…
Browse files Browse the repository at this point in the history
…s installed (#18226)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
awaelchli and pre-commit-ci[bot] authored Aug 3, 2023
1 parent 28c401c commit 0aeeb60
Show file tree
Hide file tree
Showing 6 changed files with 41 additions and 31 deletions.
4 changes: 2 additions & 2 deletions src/lightning/pytorch/trainer/configuration_validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from lightning.fabric.utilities.warnings import PossibleUserWarning
from lightning.pytorch.trainer.states import TrainerFn
from lightning.pytorch.utilities.exceptions import MisconfigurationException
from lightning.pytorch.utilities.imports import _LIGHTNING_GRAPHCORE_AVAILABLE
from lightning.pytorch.utilities.imports import _lightning_graphcore_available
from lightning.pytorch.utilities.model_helpers import is_overridden
from lightning.pytorch.utilities.rank_zero import rank_zero_deprecation, rank_zero_warn
from lightning.pytorch.utilities.signature_utils import is_param_in_hook_signature
Expand Down Expand Up @@ -123,7 +123,7 @@ def __verify_batch_transfer_support(trainer: "pl.Trainer") -> None:
datahook_selector = trainer._data_connector._datahook_selector
assert datahook_selector is not None
for hook in batch_transfer_hooks:
if _LIGHTNING_GRAPHCORE_AVAILABLE:
if _lightning_graphcore_available():
from lightning_graphcore import IPUAccelerator

# TODO: This code could be done in a hook in the IPUAccelerator as it's a simple error check
Expand Down
26 changes: 13 additions & 13 deletions src/lightning/pytorch/trainer/connectors/accelerator_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,8 @@
from lightning.pytorch.utilities.imports import (
_LIGHTNING_BAGUA_AVAILABLE,
_LIGHTNING_COLOSSALAI_AVAILABLE,
_LIGHTNING_GRAPHCORE_AVAILABLE,
_LIGHTNING_HABANA_AVAILABLE,
_lightning_graphcore_available,
_lightning_habana_available,
)
from lightning.pytorch.utilities.rank_zero import rank_zero_info, rank_zero_warn

Expand Down Expand Up @@ -338,12 +338,12 @@ def _choose_auto_accelerator(self) -> str:
"""Choose the accelerator type (str) based on availability."""
if XLAAccelerator.is_available():
return "tpu"
if _LIGHTNING_GRAPHCORE_AVAILABLE:
if _lightning_graphcore_available():
from lightning_graphcore import IPUAccelerator

if IPUAccelerator.is_available():
return "ipu"
if _LIGHTNING_HABANA_AVAILABLE:
if _lightning_habana_available():
from lightning_habana import HPUAccelerator

if HPUAccelerator.is_available():
Expand Down Expand Up @@ -411,7 +411,7 @@ def _choose_and_init_cluster_environment(self) -> ClusterEnvironment:

def _choose_strategy(self) -> Union[Strategy, str]:
if self._accelerator_flag == "ipu":
if not _LIGHTNING_GRAPHCORE_AVAILABLE:
if not _lightning_graphcore_available():
raise ImportError(
"You have passed `accelerator='ipu'` but the IPU integration is not installed."
" Please run `pip install lightning-graphcore` or check out"
Expand All @@ -421,7 +421,7 @@ def _choose_strategy(self) -> Union[Strategy, str]:

return IPUStrategy.strategy_name
if self._accelerator_flag == "hpu":
if not _LIGHTNING_HABANA_AVAILABLE:
if not _lightning_habana_available():
raise ImportError(
"You have asked for HPU but you miss install related integration."
" Please run `pip install lightning-habana` or see for further instructions"
Expand Down Expand Up @@ -490,7 +490,7 @@ def _check_and_init_precision(self) -> PrecisionPlugin:
if isinstance(self._precision_plugin_flag, PrecisionPlugin):
return self._precision_plugin_flag

if _LIGHTNING_GRAPHCORE_AVAILABLE:
if _lightning_graphcore_available():
from lightning_graphcore import IPUAccelerator, IPUPrecision

# TODO: For the strategies that have a fixed precision class, we don't really need this logic
Expand All @@ -500,7 +500,7 @@ def _check_and_init_precision(self) -> PrecisionPlugin:
if isinstance(self.accelerator, IPUAccelerator):
return IPUPrecision(self._precision_flag)

if _LIGHTNING_HABANA_AVAILABLE:
if _lightning_habana_available():
from lightning_habana import HPUAccelerator, HPUPrecisionPlugin

if isinstance(self.accelerator, HPUAccelerator):
Expand Down Expand Up @@ -567,7 +567,7 @@ def _validate_precision_choice(self) -> None:
f"The `XLAAccelerator` can only be used with a `XLAPrecisionPlugin`,"
f" found: {self._precision_plugin_flag}."
)
if _LIGHTNING_HABANA_AVAILABLE:
if _lightning_habana_available():
from lightning_habana import HPUAccelerator

if isinstance(self.accelerator, HPUAccelerator) and self._precision_flag not in (
Expand Down Expand Up @@ -622,7 +622,7 @@ def _lazy_init_strategy(self) -> None:
f" found {self.strategy.__class__.__name__}."
)

if _LIGHTNING_HABANA_AVAILABLE:
if _lightning_habana_available():
from lightning_habana import HPUAccelerator, HPUParallelStrategy, SingleHPUStrategy

if isinstance(self.accelerator, HPUAccelerator) and not isinstance(
Expand All @@ -641,7 +641,7 @@ def is_distributed(self) -> bool:
DeepSpeedStrategy,
XLAStrategy,
]
if _LIGHTNING_HABANA_AVAILABLE:
if _lightning_habana_available():
from lightning_habana import HPUParallelStrategy

distributed_strategies.append(HPUParallelStrategy)
Expand Down Expand Up @@ -694,7 +694,7 @@ def _register_external_accelerators_and_strategies() -> None:
if "bagua" not in StrategyRegistry:
BaguaStrategy.register_strategies(StrategyRegistry)

if _LIGHTNING_HABANA_AVAILABLE:
if _lightning_habana_available():
from lightning_habana import HPUAccelerator, HPUParallelStrategy, SingleHPUStrategy

# TODO: Prevent registering multiple times
Expand All @@ -705,7 +705,7 @@ def _register_external_accelerators_and_strategies() -> None:
if "hpu_single" not in StrategyRegistry:
SingleHPUStrategy.register_strategies(StrategyRegistry)

if _LIGHTNING_GRAPHCORE_AVAILABLE:
if _lightning_graphcore_available():
from lightning_graphcore import IPUAccelerator, IPUStrategy

# TODO: Prevent registering multiple times
Expand Down
6 changes: 3 additions & 3 deletions src/lightning/pytorch/trainer/connectors/data_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
from lightning.pytorch.utilities.combined_loader import CombinedLoader
from lightning.pytorch.utilities.data import _is_dataloader_shuffled, _update_dataloader
from lightning.pytorch.utilities.exceptions import MisconfigurationException
from lightning.pytorch.utilities.imports import _LIGHTNING_GRAPHCORE_AVAILABLE
from lightning.pytorch.utilities.imports import _lightning_graphcore_available
from lightning.pytorch.utilities.model_helpers import is_overridden
from lightning.pytorch.utilities.rank_zero import rank_zero_warn, WarningCache
from lightning.pytorch.utilities.types import EVAL_DATALOADERS, TRAIN_DATALOADERS
Expand Down Expand Up @@ -165,7 +165,7 @@ def attach_datamodule(
datamodule.trainer = trainer

def _requires_distributed_sampler(self, dataloader: DataLoader) -> bool:
if _LIGHTNING_GRAPHCORE_AVAILABLE:
if _lightning_graphcore_available():
from lightning_graphcore import IPUAccelerator

# `DistributedSampler` is never used with `poptorch.DataLoader`
Expand All @@ -190,7 +190,7 @@ def _prepare_dataloader(self, dataloader: object, shuffle: bool, mode: RunningSt
if not isinstance(dataloader, DataLoader):
return dataloader

if _LIGHTNING_GRAPHCORE_AVAILABLE:
if _lightning_graphcore_available():
from lightning_graphcore import IPUAccelerator

# IPUs use a custom `poptorch.DataLoader` which we might need to convert to
Expand Down
10 changes: 5 additions & 5 deletions src/lightning/pytorch/trainer/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
XLAProfiler,
)
from lightning.pytorch.utilities.exceptions import MisconfigurationException
from lightning.pytorch.utilities.imports import _LIGHTNING_GRAPHCORE_AVAILABLE, _LIGHTNING_HABANA_AVAILABLE
from lightning.pytorch.utilities.imports import _lightning_graphcore_available, _lightning_habana_available
from lightning.pytorch.utilities.rank_zero import rank_zero_info, rank_zero_warn


Expand Down Expand Up @@ -158,7 +158,7 @@ def _log_device_info(trainer: "pl.Trainer") -> None:
num_tpu_cores = trainer.num_devices if isinstance(trainer.accelerator, XLAAccelerator) else 0
rank_zero_info(f"TPU available: {XLAAccelerator.is_available()}, using: {num_tpu_cores} TPU cores")

if _LIGHTNING_GRAPHCORE_AVAILABLE:
if _lightning_graphcore_available():
from lightning_graphcore import IPUAccelerator

num_ipus = trainer.num_devices if isinstance(trainer.accelerator, IPUAccelerator) else 0
Expand All @@ -168,7 +168,7 @@ def _log_device_info(trainer: "pl.Trainer") -> None:
ipu_available = False
rank_zero_info(f"IPU available: {ipu_available}, using: {num_ipus} IPUs")

if _LIGHTNING_HABANA_AVAILABLE:
if _lightning_habana_available():
from lightning_habana import HPUAccelerator

num_hpus = trainer.num_devices if isinstance(trainer.accelerator, HPUAccelerator) else 0
Expand All @@ -192,13 +192,13 @@ def _log_device_info(trainer: "pl.Trainer") -> None:
if XLAAccelerator.is_available() and not isinstance(trainer.accelerator, XLAAccelerator):
rank_zero_warn("TPU available but not used. You can set it by doing `Trainer(accelerator='tpu')`.")

if _LIGHTNING_GRAPHCORE_AVAILABLE:
if _lightning_graphcore_available():
from lightning_graphcore import IPUAccelerator

if IPUAccelerator.is_available() and not isinstance(trainer.accelerator, IPUAccelerator):
rank_zero_warn("IPU available but not used. You can set it by doing `Trainer(accelerator='ipu')`.")

if _LIGHTNING_HABANA_AVAILABLE:
if _lightning_habana_available():
from lightning_habana import HPUAccelerator

if HPUAccelerator.is_available() and not isinstance(trainer.accelerator, HPUAccelerator):
Expand Down
14 changes: 12 additions & 2 deletions src/lightning/pytorch/utilities/imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,5 +41,15 @@ def _try_import_module(module_name: str) -> bool:
return False


_LIGHTNING_GRAPHCORE_AVAILABLE = RequirementCache("lightning-graphcore") and _try_import_module("lightning_graphcore")
_LIGHTNING_HABANA_AVAILABLE = RequirementCache("lightning-habana") and _try_import_module("lightning_habana")
@functools.lru_cache(maxsize=1)
def _lightning_graphcore_available() -> bool:
# This is defined as a function instead of a constant to avoid circular imports, because `lightning_graphcore`
# also imports Lightning
return bool(RequirementCache("lightning-graphcore")) and _try_import_module("lightning_graphcore")


@functools.lru_cache(maxsize=1)
def _lightning_habana_available() -> bool:
# This is defined as a function instead of a constant to avoid circular imports, because `lightning_habana`
# also imports Lightning
return bool(RequirementCache("lightning-habana")) and _try_import_module("lightning_habana")
Original file line number Diff line number Diff line change
Expand Up @@ -53,14 +53,14 @@
from lightning.pytorch.strategies.launchers import _SubprocessScriptLauncher
from lightning.pytorch.trainer.connectors.accelerator_connector import _AcceleratorConnector, _set_torch_flags
from lightning.pytorch.utilities.exceptions import MisconfigurationException
from lightning.pytorch.utilities.imports import _LIGHTNING_GRAPHCORE_AVAILABLE, _LIGHTNING_HABANA_AVAILABLE
from lightning.pytorch.utilities.imports import _lightning_graphcore_available, _lightning_habana_available
from tests_pytorch.conftest import mock_cuda_count, mock_mps_count, mock_tpu_available, mock_xla_available
from tests_pytorch.helpers.runif import RunIf

if _LIGHTNING_GRAPHCORE_AVAILABLE:
if _lightning_graphcore_available():
from lightning_graphcore import IPUAccelerator, IPUStrategy

if _LIGHTNING_HABANA_AVAILABLE:
if _lightning_habana_available():
from lightning_habana import HPUAccelerator, SingleHPUStrategy


Expand Down Expand Up @@ -935,7 +935,7 @@ def _mock_tpu_available(value):
assert connector.strategy.launcher.is_interactive_compatible

# Single/Multi IPU: strategy is the same
if _LIGHTNING_GRAPHCORE_AVAILABLE:
if _lightning_graphcore_available():
with monkeypatch.context():
mock_cuda_count(monkeypatch, 0)
mock_mps_count(monkeypatch, 0)
Expand All @@ -949,7 +949,7 @@ def _mock_tpu_available(value):
assert connector.strategy.launcher is None

# Single HPU
if _LIGHTNING_HABANA_AVAILABLE:
if _lightning_habana_available():
import lightning_habana

with monkeypatch.context():
Expand All @@ -967,7 +967,7 @@ def _mock_tpu_available(value):
monkeypatch.undo() # for some reason `.context()` is not working properly
_mock_interactive()

if not is_interactive and _LIGHTNING_HABANA_AVAILABLE: # HPU does not support interactive environments
if not is_interactive and _lightning_habana_available(): # HPU does not support interactive environments
from lightning_habana import HPUParallelStrategy

# Multi HPU
Expand Down

0 comments on commit 0aeeb60

Please sign in to comment.