adding mark_step_on_freeing as a temp workaround to pytorch#3455
ronghanghu committed Apr 28, 2022
1 parent 4d77c4a commit 3b33247
Showing 1 changed file with 21 additions and 1 deletion.
22 changes: 21 additions & 1 deletion torch_xla/distributed/fsdp/xla_fully_sharded_data_parallel.py
@@ -125,6 +125,12 @@ class XlaFullyShardedDataParallel(nn.Module):
if ``True``, use PyTorch XLA 1.10's all_gather implementation,
which performs all_gather via padding and all_reduce and avoids
the GRPC error (see https://github.com/pytorch/xla/issues/3423).
mark_step_on_freeing (bool, Optional):
if ``True``, call `xm.mark_step` upon freeing full parameters.
This is a temporary and inefficient workaround to avoid XLA compiler
fusion that breaks parameter freeing in nested FSDP. It is useful
only when ``reshard_after_forward`` is ``True``. See details in
https://github.com/pytorch/xla/issues/3455#issuecomment-1085448513.
"""

def __init__(
@@ -134,6 +140,7 @@ def __init__(
flatten_parameters: bool = True,
execute_sharding_on_init: bool = True,
use_all_gather_via_all_reduce: bool = True,
mark_step_on_freeing: bool = False,
):
is_forward_defined = (
hasattr(module, "forward") and hasattr(module.forward, "__func__") and
@@ -232,8 +239,16 @@ def __init__(

if execute_sharding_on_init:
# Execute the parameter sharding immediately and free up the memory
xm.mark_step()
gc.collect()
xm.mark_step()

# TODO (ronghanghu): remove when https://github.com/pytorch/xla/issues/3455 is resolved
# This is a temporary workaround before we have a mature solution
# to avoid undesired fusion with the XLA compiler optimization barrier (see
# https://github.com/pytorch/xla/issues/3455#issuecomment-1085448513
# for details). This workaround notably increases the execution time and
# may trigger more compilation, so we need a permanent solution to #3455.
self._mark_step_on_freeing = mark_step_on_freeing

def _get_gradient_predivide_factor(self, world_size: int) -> float:
factor: int = 1
@@ -864,6 +879,11 @@ def _free_full_params(self, params: Optional[List[Parameter]] = None) -> None:
# free the original full parameter
p.data = self._dummy_data_placeholder
p._has_full_param = False
# immediately execute the parameter freeing as a workaround to undesired XLA fusion
# see https://github.com/pytorch/xla/issues/3455#issuecomment-1085448513 for details
# TODO (ronghanghu): remove when https://github.com/pytorch/xla/issues/3455 is resolved
if self._mark_step_on_freeing:
xm.mark_step()

def assert_state(self, state: Union[TrainingState,
List[TrainingState]]) -> None:
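
For reference, below is a minimal usage sketch of the new flag. The toy model, the nested wrapping, the training step, and the reshard_after_forward argument (referenced in the docstring above but not shown in this diff) are illustrative assumptions rather than part of the commit; the import path simply mirrors the file changed above.

# Illustrative sketch only (names and arguments assumed as noted above).
import torch
import torch.nn as nn
import torch_xla.core.xla_model as xm
from torch_xla.distributed.fsdp.xla_fully_sharded_data_parallel import (
    XlaFullyShardedDataParallel as FSDP)

device = xm.xla_device()

# Hypothetical toy model; the inner and outer wrapping illustrates nested FSDP,
# which is where the premature-fusion problem in pytorch/xla#3455 shows up.
inner = FSDP(nn.Linear(128, 128).to(device), mark_step_on_freeing=True)
model = FSDP(
    nn.Sequential(inner, nn.ReLU(), nn.Linear(128, 10).to(device)),
    reshard_after_forward=True,  # the workaround only matters when this is True
    mark_step_on_freeing=True,   # temporary workaround added by this commit
)

out = model(torch.randn(4, 128, device=device))
out.sum().backward()
xm.mark_step()  # execute the pending XLA graph for this step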
