diff --git a/ignite/distributed/comp_models/xla.py b/ignite/distributed/comp_models/xla.py index ac8922adb417..fdd4be80d181 100644 --- a/ignite/distributed/comp_models/xla.py +++ b/ignite/distributed/comp_models/xla.py @@ -38,7 +38,7 @@ def create_from_context() -> Optional["_XlaDistModel"]: return _XlaDistModel() @staticmethod - def create_from_backend(backend: str = "xla-tpu", **kwargs) -> "_XlaDistModel": + def create_from_backend(backend: str = XLA_TPU, **kwargs) -> "_XlaDistModel": return _XlaDistModel(backend=backend, **kwargs) def __init__(self, backend=None, **kwargs): @@ -57,7 +57,7 @@ def _create_from_backend(self, backend, **kwargs): self._setup_attrs() def _init_from_context(self): - self._backend = "xla-tpu" + self._backend = XLA_TPU self._setup_attrs() def _compute_nproc_per_node(self): @@ -110,7 +110,7 @@ def spawn( nproc_per_node: int = 1, nnodes: int = 1, node_rank: int = 0, - backend: str = "xla-tpu", + backend: str = XLA_TPU, **kwargs ): import os