diff --git a/torch_xla/experimental/spmd_fully_sharded_data_parallel.py b/torch_xla/experimental/spmd_fully_sharded_data_parallel.py
index 7128d892391..a6d0c64911e 100644
--- a/torch_xla/experimental/spmd_fully_sharded_data_parallel.py
+++ b/torch_xla/experimental/spmd_fully_sharded_data_parallel.py
@@ -11,8 +11,6 @@
 import torch_xla.core.xla_model as xm
 import torch_xla.distributed.spmd as spmd
 from torch_xla.distributed.fsdp.wrap import recursive_wrap
-from torch_xla.distributed.fsdp._init_utils import _materialize_module
-from torch_xla.distributed.fsdp.wrap import recursive_wrap
 from torch_xla.distributed.fsdp.xla_fully_sharded_data_parallel import _cast_floats_tensors, FLOAT_DTYPES
 
 
@@ -53,18 +51,6 @@ class SpmdFullyShardedDataParallel(nn.Module):
       dtype for full parameters for computation. This defaults to ``torch.float32``
      but can be set to ``torch.float16`` or ``torch.bfloat16``.
       The sharded parameters will always be in FP32.
-
-  Note:
-    Support for TorchDistX initialization:
-    This implementation supports using TorchDistX's ``deferred_init`` for module initialization,
-    which can help save host memory. When using `deferred_init`, the module will be initialized
-    by a default initialization function that calls torchdistX's ``materialize_module``.
-
-  Example:
-    >>> # With torchdistX
-    >>> module = deferred_init.deferred_init(MyModule, device="cuda")
-    >>> # Will initialize via deferred_init.materialize_module().
-    >>> fsdp_model = FSDPv2(module)
   """
 
   def __init__(
@@ -132,12 +118,6 @@ def __init__(
           f"compute_dtype must be one of {FLOAT_DTYPES}, not {compute_dtype}")
     self.compute_dtype = compute_dtype or torch.float32
 
-    _materialize_module(
-        module,
-        None, [],
-        deferred_init_check_fn=lambda k: not isinstance(
-            k, SpmdFullyShardedDataParallel))
-
     # Let's move the module to xla device in case it's not moved
     # by the caller already.
     self._orig_module = module.to(xm.xla_device())
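
With the TorchDistX `deferred_init` path removed, callers are expected to initialize the module normally; the wrapper itself moves it to the XLA device via `module.to(xm.xla_device())`. Below is a minimal usage sketch under that assumption. The mesh shape, the toy `nn.Linear` module, and the exact `mesh=` keyword are illustrative rather than taken from this diff; `compute_dtype=` follows the docstring retained in this file.

```python
# Minimal FSDPv2 usage sketch (not part of this PR): plain eager module
# initialization, no deferred_init / materialize step.
import numpy as np
import torch
import torch.nn as nn

import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs
from torch_xla.experimental.spmd_fully_sharded_data_parallel import (
    SpmdFullyShardedDataParallel as FSDPv2)

xr.use_spmd()  # FSDPv2 relies on the SPMD execution mode

# Illustrative mesh: shard along a single "fsdp" axis over all devices.
num_devices = xr.global_runtime_device_count()
mesh = xs.Mesh(np.array(range(num_devices)), (num_devices, 1), ("fsdp", "model"))

# Initialize on the host as usual; FSDPv2 calls module.to(xm.xla_device())
# internally, so the caller does not have to move the module first.
module = nn.Linear(1024, 1024)

# Assumed keyword names for illustration; compute_dtype is validated against
# FLOAT_DTYPES in __init__ as shown in the hunk above.
model = FSDPv2(module, mesh=mesh, compute_dtype=torch.bfloat16)
```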