fix: remove torchdistx support
lausannel authored Sep 27, 2024
1 parent 4d2ee13 commit b7977f5
Showing 1 changed file with 0 additions and 20 deletions.
torch_xla/experimental/spmd_fully_sharded_data_parallel.py: 0 additions, 20 deletions
@@ -11,8 +11,6 @@
 import torch_xla.core.xla_model as xm
 import torch_xla.distributed.spmd as spmd
 from torch_xla.distributed.fsdp.wrap import recursive_wrap
-from torch_xla.distributed.fsdp._init_utils import _materialize_module
-from torch_xla.distributed.fsdp.wrap import recursive_wrap
 from torch_xla.distributed.fsdp.xla_fully_sharded_data_parallel import _cast_floats_tensors, FLOAT_DTYPES


@@ -53,18 +51,6 @@ class SpmdFullyShardedDataParallel(nn.Module):
         dtype for full parameters for computation. This defaults to
         ``torch.float32`` but can be set to ``torch.float16`` or
         ``torch.bfloat16``. The sharded parameters will always be in FP32.
-  Note:
-    Support for TorchDistX initialization:
-    This implementation supports using TorchDistX's ``deferred_init`` for module initialization,
-    which can help save host memory. When using `deferred_init`, the module will be initialized
-    by a default initialization function that calls torchdistX's ``materialize_module``.
-  Example:
-    >>> # With torchdistX
-    >>> module = deferred_init.deferred_init(MyModule, device="cuda")
-    >>> # Will initialize via deferred_init.materialize_module().
-    >>> fsdp_model = FSDPv2(module)
   """
 
   def __init__(
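For reference, the docstring deleted above described the TorchDistX path that this commit drops. Below is a minimal sketch of that now-removed flow, expanded from the ``>>>`` example in the deleted docstring; it assumes the optional torchdistx package is installed and that a global SPMD mesh has already been registered, and the module name is illustrative.

```python
# Sketch of the deferred-init flow removed by this commit; it relied on the
# optional torchdistx package and on FSDPv2 materializing the module itself.
import torch.nn as nn
from torchdistx import deferred_init  # optional dependency, no longer supported here
from torch_xla.experimental.spmd_fully_sharded_data_parallel import (
    SpmdFullyShardedDataParallel as FSDPv2)


class MyModule(nn.Module):  # hypothetical example module

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(128, 128)

    def forward(self, x):
        return self.linear(x)


# Construct the module without allocating real parameter storage on the host...
module = deferred_init.deferred_init(MyModule)
# ...and let the pre-commit FSDPv2 wrapper materialize it during construction.
# After this commit, a normally initialized module must be passed instead.
fsdp_model = FSDPv2(module)
```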
@@ -132,12 +118,6 @@ def __init__(
           f"compute_dtype must be one of {FLOAT_DTYPES}, not {compute_dtype}")
     self.compute_dtype = compute_dtype or torch.float32
 
-    _materialize_module(
-        module,
-        None, [],
-        deferred_init_check_fn=lambda k: not isinstance(
-            k, SpmdFullyShardedDataParallel))
-
     # Let's move the module to xla device in case it's not moved
     # by the caller already.
     self._orig_module = module.to(xm.xla_device())
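With the deferred-init hook gone, modules must be fully initialized before wrapping; as the retained code above shows, FSDPv2 still moves the wrapped module to the XLA device itself. The sketch below shows post-commit usage under stated assumptions: the mesh layout, the ``bfloat16`` choice for ``compute_dtype``, and the example module are illustrative, not part of this commit.

```python
# Post-commit usage sketch: construct the module eagerly and let FSDPv2 move
# it to the XLA device; compute_dtype mirrors the docstring kept in the diff.
import numpy as np
import torch
import torch.nn as nn
import torch_xla.distributed.spmd as spmd
import torch_xla.runtime as xr
from torch_xla.experimental.spmd_fully_sharded_data_parallel import (
    SpmdFullyShardedDataParallel as FSDPv2)

xr.use_spmd()  # enable SPMD execution mode

# Illustrative mesh over all devices, registered globally so FSDPv2 can pick
# it up without an explicit mesh argument.
num_devices = xr.global_runtime_device_count()
mesh = spmd.Mesh(np.array(range(num_devices)), (num_devices, 1), ("fsdp", "model"))
spmd.set_global_mesh(mesh)

module = nn.Linear(128, 128)  # plain, eagerly initialized module
fsdp_model = FSDPv2(module, compute_dtype=torch.bfloat16)
```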
