Start on composable FSDP + DDP #1981

Draft · wants to merge 4 commits into main
Conversation

@t-vi (Collaborator) commented Apr 22, 2025

Work in progress on composable FSDP + DDP.

Wants to fix #1980

@t-vi (Collaborator, Author) commented Apr 23, 2025

The goal is to get this to work:

import torch, thunder
import torch.distributed
from thunder.distributed.transforms.fsdp_v2 import FSDPTransform
from thunder.distributed.transforms.ddp_v2 import DDPTransform

mesh = torch.distributed.device_mesh.init_device_mesh("cuda", (2, 2), mesh_dim_names=("ddp", "fsdp"))

with torch.device("cuda"):
    m = torch.nn.Sequential(torch.nn.Linear(256, 256), torch.nn.ReLU(), torch.nn.Linear(256, 256))
    inp = torch.randn(4, 256)

jm = thunder.jit(
    m,
    transforms=[
        FSDPTransform(process_group=mesh["fsdp"].get_group()),
        DDPTransform(mesh["ddp"].get_group(), broadcast_from=0, bucket_size_in_mb=25.0),
    ],
)

res = jm(inp)
res.sum().backward()

torch.distributed.destroy_process_group()
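
For context, a script like this would be launched with one process per GPU, e.g. torchrun --standalone --nproc-per-node=4 script.py on a single machine with four GPUs (the launcher invocation and file name are assumptions for illustration, not part of the PR).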

Currently, this hits a problem in optimize_allreduce_in_ddp_backward, which apply_bucketing_to_grad_allreduce calls on the forward trace (puzzling to me given the name). I am not sure yet whether it would be as easy as skipping the call when we have a forward trace - @crcrpar, do you have any idea?

@crcrpar (Collaborator) commented Apr 23, 2025

Would the problem be this, by chance?

for key in output_tensor_to_index_and_prod_bsym._dict:
    _, bsym = output_tensor_to_index_and_prod_bsym.get_by_name(key)
    if bsym.sym.id == dist_prims.PrimIDs.WAIT:
        bsym_of_allreduce: BoundSymbol = producers[bsym.flat_proxy_args[0]]
        utils.check(
            bsym_of_allreduce.sym.id,
            dist_prims.PrimIDs.ALL_REDUCE,
            lambda: f"{bsym.sym.id=}, {bsym_of_allreduce.sym.id=}",
        )
        grad_before_after_allreduce[bsym.flat_proxy_outs[0]] = bsym_of_allreduce.flat_proxy_args[0]

I didn't think about hybrid sharded data parallel at all when writing it. I think it'd be worth replacing the utils.check(...) with an if bsym_of_allreduce.sym.id is dist_prims.PrimIDs.ALL_REDUCE check.
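
A minimal sketch of that replacement, reusing the names from the snippet above (whether simply skipping waits whose producer is not an all_reduce is sufficient here is an assumption, not something verified in this PR):

for key in output_tensor_to_index_and_prod_bsym._dict:
    _, bsym = output_tensor_to_index_and_prod_bsym.get_by_name(key)
    if bsym.sym.id == dist_prims.PrimIDs.WAIT:
        bsym_of_allreduce: BoundSymbol = producers[bsym.flat_proxy_args[0]]
        # Instead of asserting, tolerate waits whose producer is not an
        # all_reduce (for example the all_gather waits inserted by FSDP).
        if bsym_of_allreduce.sym.id is not dist_prims.PrimIDs.ALL_REDUCE:
            continue
        grad_before_after_allreduce[bsym.flat_proxy_outs[0]] = bsym_of_allreduce.flat_proxy_args[0]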

@t-vi (Collaborator, Author) commented Apr 25, 2025

Thanks! So I have some changes elsewhere, too.

@t-vi (Collaborator, Author) commented Apr 29, 2025

On my branch, the trouble is that the bucketing drops the WAIT for FSDP; it should probably be conditioned on the process group instead(?). If I comment out / skip this:

visit_callable = BatchAllReduceVisitor(
    process_group=compile_data.process_group_for_ddp,
    flat_backward_trace_output=flat_backward_trace_output,
    backward_trace_output_spec=backward_trace_output_spec,
    gradient_buckets=gradient_buckets,
    prims_to_filter={dist_prims.PrimIDs.ALL_REDUCE, dist_prims.PrimIDs.WAIT},
)
updated_bwd_trace = visitor_transform(
    backward_trace,
    visit_callable,
    provenance="Batching all_reduce calls",
)
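
If the bucketing is instead to be conditioned on the process group, a rough sketch of a predicate for that (purely illustrative; the helper name is made up, and it assumes the DDP process group appears by identity among the bound symbol's arguments, as the traces below suggest):

def is_ddp_allreduce(bsym, ddp_group) -> bool:
    # Only ALL_REDUCE bound symbols that run on the DDP process group are
    # candidates for bucketing; FSDP's all_gathers and any collectives on
    # the FSDP group are left untouched.
    if bsym.sym.id is not dist_prims.PrimIDs.ALL_REDUCE:
        return False
    return any(arg is ddp_group for arg in bsym.args)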

I get the following for a two-layer MLP (do you have comments, apart from the lack of bucketing?):

# Constructed by Delete Last Used (took 0 milliseconds)
import torch
import torch.nn.functional
from thunder.executors.torchex import no_autocast

@torch.no_grad()
@no_autocast
def computation(input, t_0_bias, t_0_weight, t_2_bias, t_2_weight):
  # input: "cuda:0 f32[4, 256]"
  # t_0_bias: "cuda:0 f32[128]"
  # t_0_weight: "cuda:0 f32[128, 256]"
  # t_2_bias: "cuda:0 f32[128]"
  # t_2_weight: "cuda:0 f32[128, 256]"
  ft31 = torch_all_gather_prim_impl(t_0_bias, _torch_distributed_distributed_c10d_ProcessGroup_8, True, None)  # ft31: "FUTURE thunder.devices.Device(type='cuda', index=0) f32[256]"
  ft33 = torch_all_gather_prim_impl(t_0_weight, _torch_distributed_distributed_c10d_ProcessGroup_8, True, None)  # ft33: "FUTURE thunder.devices.Device(type='cuda', index=0) f32[256, 256]"
  ft36 = torch_all_gather_prim_impl(t_2_bias, _torch_distributed_distributed_c10d_ProcessGroup_8, True, None)  # ft36: "FUTURE thunder.devices.Device(type='cuda', index=0) f32[256]"
  ft45 = torch_all_gather_prim_impl(t_2_weight, _torch_distributed_distributed_c10d_ProcessGroup_8, True, None)  # ft45: "FUTURE thunder.devices.Device(type='cuda', index=0) f32[256, 256]"
  t32 = torch_wait_prim_impl(ft31)  # t32: "cuda:0 f32[256]"
  del ft31
  t34 = torch_wait_prim_impl(ft33)  # t34: "cuda:0 f32[256, 256]"
  del ft33

  # /home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/nn/modules/linear.py:125:          return F.linear(input, self.weight, self.bias)
  t35 = torch.nn.functional.linear(input, t34, t32)  # t35: "cuda:0 f32[4, 256]"
    # t35 = ltorch.linear(input, t34, t32)  # t35: "cuda:0 f32[4, 256]"
      # t35 = prims.linear(input, t34, t32)  # t35: "cuda:0 f32[4, 256]"
  del t34, t32
  [t24, t6] = nvFusion0(t35)
    # t24 = prims.gt(t35, 0.0)  # t24: "cuda:0 b8[4, 256]"
    # t6 = prims.where(t24, t35, 0.0)  # t6: "cuda:0 f32[4, 256]"
  del t35
  t44 = torch_wait_prim_impl(ft36)  # t44: "cuda:0 f32[256]"
  del ft36
  t47 = torch_wait_prim_impl(ft45)  # t47: "cuda:0 f32[256, 256]"
  del ft45

  # /home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/nn/modules/linear.py:125:          return F.linear(input, self.weight, self.bias)
  t49 = torch.nn.functional.linear(t6, t47, t44)  # t49: "cuda:0 f32[4, 256]"
    # t49 = ltorch.linear(t6, t47, t44)  # t49: "cuda:0 f32[4, 256]"
      # t49 = prims.linear(t6, t47, t44)  # t49: "cuda:0 f32[4, 256]"
  del t44
  return {'output': (t49,), 'flat_args': [input, t_0_bias, t_0_weight, t_2_bias, t_2_weight], 'flat_output': (t49,)}, ((input, t24, t47, t6), ())
# Constructed by Delete Last Used (took 0 milliseconds)
import torch
from thunder.executors.torchex import no_autocast

@torch.no_grad()
@no_autocast
def backward_fn(saved_for_backward, cotangents):
  # saved_for_backward: "Collection"
  # cotangents: "Collection"
  C0, _, = saved_for_backward
  # C0: "Collection"
  # None
  clear_mutable_collection(saved_for_backward)
  del saved_for_backward
  t0, = cotangents
  # t0: "cuda:0 f32[4, 256]"
  clear_mutable_collection(cotangents)
  del cotangents
  input, t24, t47, t6, = C0
  # input: "cuda:0 f32[4, 256]"
  # t24: "cuda:0 b8[4, 256]"
  # t47: "cuda:0 f32[256, 256]"
  # t6: "cuda:0 f32[4, 256]"
  clear_mutable_collection(C0)
  del C0
  bw_t95 = torch.reshape(t0, (-1, 256))  # bw_t95: "cuda:0 f32[4, 256]"
    # bw_t95 = ltorch.reshape(t0, (-1, 256))  # bw_t95: "cuda:0 f32[4, 256]"
      # bw_t95 = prims.reshape(t0, (4, 256))  # bw_t95: "cuda:0 f32[4, 256]"
  bw_t76 = torch.matmul(bw_t95, t47)  # bw_t76: "cuda:0 f32[4, 256]"
    # bw_t76 = ltorch.matmul(bw_t95, t47)  # bw_t76: "cuda:0 f32[4, 256]"
      # bw_t76 = prims.matmul(bw_t95, t47)  # bw_t76: "cuda:0 f32[4, 256]"
  del t47
  bw_t96 = torch.permute(bw_t95, (1, 0))  # bw_t96: "cuda:0 f32[256, 4]"
    # bw_t96 = ltorch.permute(bw_t95, (1, 0))  # bw_t96: "cuda:0 f32[256, 4]"
      # bw_t96 = prims.transpose(bw_t95, (1, 0))  # bw_t96: "cuda:0 f32[256, 4]"
  del bw_t95
  bw_t97 = torch.reshape(t6, (-1, 256))  # bw_t97: "cuda:0 f32[4, 256]"
    # bw_t97 = ltorch.reshape(t6, (-1, 256))  # bw_t97: "cuda:0 f32[4, 256]"
      # bw_t97 = prims.reshape(t6, (4, 256))  # bw_t97: "cuda:0 f32[4, 256]"
  del t6
  bw_t77 = torch.matmul(bw_t96, bw_t97)  # bw_t77: "cuda:0 f32[256, 256]"
    # bw_t77 = ltorch.matmul(bw_t96, bw_t97)  # bw_t77: "cuda:0 f32[256, 256]"
      # bw_t77 = prims.matmul(bw_t96, bw_t97)  # bw_t77: "cuda:0 f32[256, 256]"
  del bw_t96, bw_t97
  [bw_t42, bw_t43] = nvFusion0(t0, bw_t77)
    # bw_t42 = prims.sum(t0, (0,))  # bw_t42: "cuda:0 f32[256]"
    # bw_t43 = prims.div(bw_t77, 2.0)  # bw_t43: "cuda:0 f32[256, 256]"
  del t0, bw_t77
  ft78 = torch_all_reduce_prim_impl(bw_t43, _DistributedReduceOps_4, _torch_distributed_distributed_c10d_ProcessGroup_6, True, False)  # ft78: "FUTURE thunder.devices.Device(type='cuda', index=0) f32[256, 256]"
  del bw_t43
  [bw_t46] = nvFusion1(bw_t42)
    # bw_t46 = prims.div(bw_t42, 2.0)  # bw_t46: "cuda:0 f32[256]"
  del bw_t42
  ft80 = torch_all_reduce_prim_impl(bw_t46, _DistributedReduceOps_4, _torch_distributed_distributed_c10d_ProcessGroup_6, True, False)  # ft80: "FUTURE thunder.devices.Device(type='cuda', index=0) f32[256]"
  del bw_t46
  [bw_t54, bw_t55] = nvFusion2(t24, bw_t76, input)
    # bw_t50 = prims.where(t24, bw_t76, 0.0)  # bw_t50: "cuda:0 f32[4, 256]"
    # bw_t53 = prims.reshape(bw_t50, (4, 256))  # bw_t53: "cuda:0 f32[4, 256]"
    # bw_t54 = prims.transpose(bw_t53, (1, 0))  # bw_t54: "cuda:0 f32[256, 4]"
    # bw_t55 = prims.reshape(input, (4, 256))  # bw_t55: "cuda:0 f32[4, 256]"
  del input
  bw_t82 = torch.matmul(bw_t54, bw_t55)  # bw_t82: "cuda:0 f32[256, 256]"
    # bw_t82 = ltorch.matmul(bw_t54, bw_t55)  # bw_t82: "cuda:0 f32[256, 256]"
      # bw_t82 = prims.matmul(bw_t54, bw_t55)  # bw_t82: "cuda:0 f32[256, 256]"
  del bw_t54, bw_t55
  [bw_t58, bw_t57] = nvFusion3(t24, bw_t76, bw_t82)
    # bw_t50 = prims.where(t24, bw_t76, 0.0)  # bw_t50: "cuda:0 f32[4, 256]"
    # bw_t58 = prims.div(bw_t82, 2.0)  # bw_t58: "cuda:0 f32[256, 256]"
    # bw_t57 = prims.sum(bw_t50, (0,))  # bw_t57: "cuda:0 f32[256]"
  del t24, bw_t76, bw_t82
  ft83 = torch_all_reduce_prim_impl(bw_t58, _DistributedReduceOps_4, _torch_distributed_distributed_c10d_ProcessGroup_6, True, False)  # ft83: "FUTURE thunder.devices.Device(type='cuda', index=0) f32[256, 256]"
  del bw_t58
  [bw_t61] = nvFusion4(bw_t57)
    # bw_t61 = prims.div(bw_t57, 2.0)  # bw_t61: "cuda:0 f32[256]"
  del bw_t57
  ft85 = torch_all_reduce_prim_impl(bw_t61, _DistributedReduceOps_4, _torch_distributed_distributed_c10d_ProcessGroup_6, True, False)  # ft85: "FUTURE thunder.devices.Device(type='cuda', index=0) f32[256]"
  del bw_t61
  bw_t79 = torch_wait_prim_impl(ft78)  # bw_t79: "cuda:0 f32[256, 256]"
  del ft78
  [bw_t64] = nvFusion5(bw_t79)
    # bw_t64 = prims.div(bw_t79, 2.0)  # bw_t64: "cuda:0 f32[256, 256]"
  del bw_t79
  ft87 = torch_all_reduce_prim_impl(bw_t64, _DistributedReduceOps_4, _torch_distributed_distributed_c10d_ProcessGroup_7, True, False)  # ft87: "FUTURE thunder.devices.Device(type='cuda', index=0) f32[256, 256]"
  del bw_t64
  bw_t81 = torch_wait_prim_impl(ft80)  # bw_t81: "cuda:0 f32[256]"
  del ft80
  [bw_t67] = nvFusion6(bw_t81)
    # bw_t67 = prims.div(bw_t81, 2.0)  # bw_t67: "cuda:0 f32[256]"
  del bw_t81
  ft89 = torch_all_reduce_prim_impl(bw_t67, _DistributedReduceOps_4, _torch_distributed_distributed_c10d_ProcessGroup_7, True, False)  # ft89: "FUTURE thunder.devices.Device(type='cuda', index=0) f32[256]"
  del bw_t67
  bw_t84 = torch_wait_prim_impl(ft83)  # bw_t84: "cuda:0 f32[256, 256]"
  del ft83
  [bw_t70] = nvFusion7(bw_t84)
    # bw_t70 = prims.div(bw_t84, 2.0)  # bw_t70: "cuda:0 f32[256, 256]"
  del bw_t84
  ft91 = torch_all_reduce_prim_impl(bw_t70, _DistributedReduceOps_4, _torch_distributed_distributed_c10d_ProcessGroup_7, True, False)  # ft91: "FUTURE thunder.devices.Device(type='cuda', index=0) f32[256, 256]"
  del bw_t70
  bw_t86 = torch_wait_prim_impl(ft85)  # bw_t86: "cuda:0 f32[256]"
  del ft85
  [bw_t73] = nvFusion8(bw_t86)
    # bw_t73 = prims.div(bw_t86, 2.0)  # bw_t73: "cuda:0 f32[256]"
  del bw_t86
  ft93 = torch_all_reduce_prim_impl(bw_t73, _DistributedReduceOps_4, _torch_distributed_distributed_c10d_ProcessGroup_7, True, False)  # ft93: "FUTURE thunder.devices.Device(type='cuda', index=0) f32[256]"
  del bw_t73
  bw_t88 = torch_wait_prim_impl(ft87)  # bw_t88: "cuda:0 f32[256, 256]"
  del ft87
  bw_t90 = torch_wait_prim_impl(ft89)  # bw_t90: "cuda:0 f32[256]"
  del ft89
  bw_t92 = torch_wait_prim_impl(ft91)  # bw_t92: "cuda:0 f32[256, 256]"
  del ft91
  bw_t94 = torch_wait_prim_impl(ft93)  # bw_t94: "cuda:0 f32[256]"
  del ft93
  return (None, bw_t94, bw_t92, bw_t90, bw_t88)


@t-vi (Collaborator, Author) commented May 1, 2025

So the following works on a 4-GPU machine, except for the tracing:

import torch, thunder
import torch.distributed
from torch.testing import assert_close
from thunder.distributed.transforms.fsdp_v2 import FSDPTransform
from thunder.distributed.transforms.ddp_v2 import DDPTransform

torch.manual_seed(1337)

mesh = torch.distributed.device_mesh.init_device_mesh("cuda", (2, 2), mesh_dim_names=("ddp", "fsdp"))
global_rank = mesh.get_rank()
fsdp_rank = mesh.get_local_rank('fsdp')
print(f"{global_rank=}, {fsdp_rank=}")
print(mesh)
print(mesh.get_coordinate())

with torch.device("cuda"):
    m = torch.nn.Sequential(torch.nn.Linear(256, 256), torch.nn.ReLU(), torch.nn.Linear(256, 256))
    inp = torch.randn(4, 256)

jm = thunder.jit(
    m,
    transforms=[
        FSDPTransform(process_group=mesh["fsdp"].get_group()),
        DDPTransform(mesh["ddp"].get_group(), broadcast_from=0, bucket_size_in_mb=25.0),
    ],
)

inp_sharded = inp[global_rank:global_rank + 1]
res = jm(inp)
go = torch.randn_like(res)
grads = torch.autograd.grad(res, jm.parameters(), go)
ref = m(inp)
ref_grads = torch.autograd.grad(ref, m.parameters(), go)
assert_close(res, ref)
for g, rg in zip(grads, ref_grads):
    slice_size = rg.size(0) // 2
    assert_close(g, rg[slice_size * fsdp_rank: slice_size * (fsdp_rank + 1)])

torch.distributed.destroy_process_group()

print("Worked!")

I currently need to patch out the bucketing, which I still need to fix (cope with the waits not being part of the bucketing, do the bucketing per process group, etc.):

diff --git a/thunder/distributed/transforms/ddp.py b/thunder/distributed/transforms/ddp.py
index c78313e4..ad67cb2b 100644
--- a/thunder/distributed/transforms/ddp.py
+++ b/thunder/distributed/transforms/ddp.py
@@ -14,7 +14,7 @@ from thunder.core.proxies import variableify
 from thunder.core.pytree import tree_flatten
 from thunder.core.pytree import tree_unflatten
 from thunder.core.trace import from_trace
-from thunder.core.trace import TraceProvenance
+from thunder.core.trace import TraceProvenance, TraceTag
 from thunder.core.transforms import visitor_transform
 from thunder.core.transforms import VISIT_TYPE
 from thunder.core import utils
@@ -231,6 +231,8 @@ def optimize_allreduce_in_ddp_backward(
         gradient_buckets=gradient_buckets,
         prims_to_filter={dist_prims.PrimIDs.ALL_REDUCE, dist_prims.PrimIDs.WAIT},
     )
+    print("###evil hack###")
+    return backward_trace
     updated_bwd_trace = visitor_transform(
         backward_trace,
         visit_callable,
@@ -303,4 +305,6 @@ def apply_bucketing_to_grad_allreduce(trace: TraceCtx) -> TraceCtx:
     if len(grad_before_after_allreduce._dict) == 0:
         return trace
 
+    if TraceTag.AUGMENTED_FORWARD in trace.tags:
+        return trace
     return optimize_allreduce_in_ddp_backward(trace, compile_data=compile_data)
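
(The hunk in apply_bucketing_to_grad_allreduce is the "skip the call on the forward trace" idea from above: augmented forward traces are returned unchanged, so only backward traces reach the bucketing. The "evil hack" in optimize_allreduce_in_ddp_backward then disables the backward bucketing entirely for now; that is the part that still needs a proper fix.)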

Successfully merging this pull request may close these issues:

Hybrid FSDP and DDP does not work (composable transforms)