Start on composable FSDP + DDP #1981
base: main
Conversation
The goal is to get this to work:

```python
import torch, thunder
import torch.distributed
from thunder.distributed.transforms.fsdp_v2 import FSDPTransform
from thunder.distributed.transforms.ddp_v2 import DDPTransform

mesh = torch.distributed.device_mesh.init_device_mesh("cuda", (2, 2), mesh_dim_names=("ddp", "fsdp"))

with torch.device("cuda"):
    m = torch.nn.Sequential(torch.nn.Linear(256, 256), torch.nn.ReLU(), torch.nn.Linear(256, 256))
    inp = torch.randn(4, 256)

jm = thunder.jit(
    m,
    transforms=[
        FSDPTransform(process_group=mesh["fsdp"].get_group()),
        DDPTransform(mesh["ddp"].get_group(), broadcast_from=0, bucket_size_in_mb=25.0),
    ],
)

res = jm(inp)
res.sum().backward()

torch.distributed.destroy_process_group()
```

Currently, this hits a problem in the …
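For reference, the 2×2 mesh needs a world size of 4, so this would be launched with something like `torchrun --nproc_per_node=4 repro.py` on a 4-GPU machine (the script name is illustrative). A quick sanity check that the FSDP transform actually sharded the parameters (a sketch; the expected half-shard shapes match the trace further below):

```python
# Each rank should hold a dim-0 half-shard of every parameter:
# biases 256 -> 128, weights (256, 256) -> (128, 256).
for name, p in jm.named_parameters():
    print(torch.distributed.get_rank(), name, tuple(p.shape))
```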
Would a problem be `thunder/distributed/transforms/ddp.py`, lines 293 to 302 (at 3aa706a)?
I didn't think about hybrid sharded data parallel at all when writing it. I think it'd be worth replacing the `utils.check(...)` there with an `if bsym_of_allreduce.sym.id is dist_prims.PrimIDs.ALL_REDUCE:` condition.
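A minimal sketch of that suggestion (the bucketing context around it is hypothetical, not the actual lines 293 to 302):

```python
# Hypothetical: instead of utils.check(...)-ing that every communication
# bsym is an all-reduce (which trips on the all-gathers/reduce-scatters
# that FSDP adds in a hybrid setup), only bucket actual all-reduces.
if bsym_of_allreduce.sym.id is dist_prims.PrimIDs.ALL_REDUCE:
    bucket_candidates.append(bsym_of_allreduce)  # hypothetical bucketing step
# else: leave the bsym alone instead of raising
```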
Thanks! So I have some changes elsewhere, too.
On my branch, the trouble is that the bucketing drops the wait for FSDP's collectives, but it should probably be conditioned on the process group instead(?); one way that conditioning could look is sketched below.
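(Hypothetical variable names throughout; the only grounded detail is that the process group is the third positional argument of the collective, as in the traces below.)

```python
# Hypothetical: only fold a wait into the bucketing if the collective that
# produced its FUTURE is an all-reduce on the DDP group; keep FSDP's waits.
producer = producer_map[wait_bsym.flat_proxy_args[0]]  # bsym producing the FUTURE
is_ddp_allreduce = (
    producer.sym.id is dist_prims.PrimIDs.ALL_REDUCE
    and producer.args[2] is ddp_process_group  # process-group argument
)
if not is_ddp_allreduce:
    new_bsyms.append(wait_bsym)  # leave FSDP's wait in place
```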
If I comment out / skip `thunder/distributed/transforms/ddp.py`, lines 227 to 238 (at 5f17876), I get the following for a two-layer MLP (any comments, aside from the lack of bucketing?):

```python
# Constructed by Delete Last Used (took 0 milliseconds)
import torch
import torch.nn.functional
from thunder.executors.torchex import no_autocast
@torch.no_grad()
@no_autocast
def computation(input, t_0_bias, t_0_weight, t_2_bias, t_2_weight):
    # input: "cuda:0 f32[4, 256]"
    # t_0_bias: "cuda:0 f32[128]"
    # t_0_weight: "cuda:0 f32[128, 256]"
    # t_2_bias: "cuda:0 f32[128]"
    # t_2_weight: "cuda:0 f32[128, 256]"
    ft31 = torch_all_gather_prim_impl(t_0_bias, _torch_distributed_distributed_c10d_ProcessGroup_8, True, None) # ft31: "FUTURE thunder.devices.Device(type='cuda', index=0) f32[256]"
    ft33 = torch_all_gather_prim_impl(t_0_weight, _torch_distributed_distributed_c10d_ProcessGroup_8, True, None) # ft33: "FUTURE thunder.devices.Device(type='cuda', index=0) f32[256, 256]"
    ft36 = torch_all_gather_prim_impl(t_2_bias, _torch_distributed_distributed_c10d_ProcessGroup_8, True, None) # ft36: "FUTURE thunder.devices.Device(type='cuda', index=0) f32[256]"
    ft45 = torch_all_gather_prim_impl(t_2_weight, _torch_distributed_distributed_c10d_ProcessGroup_8, True, None) # ft45: "FUTURE thunder.devices.Device(type='cuda', index=0) f32[256, 256]"
    t32 = torch_wait_prim_impl(ft31) # t32: "cuda:0 f32[256]"
    del ft31
    t34 = torch_wait_prim_impl(ft33) # t34: "cuda:0 f32[256, 256]"
    del ft33
    # /home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/nn/modules/linear.py:125: return F.linear(input, self.weight, self.bias)
    t35 = torch.nn.functional.linear(input, t34, t32) # t35: "cuda:0 f32[4, 256]"
        # t35 = ltorch.linear(input, t34, t32) # t35: "cuda:0 f32[4, 256]"
        # t35 = prims.linear(input, t34, t32) # t35: "cuda:0 f32[4, 256]"
    del t34, t32
    [t24, t6] = nvFusion0(t35)
        # t24 = prims.gt(t35, 0.0) # t24: "cuda:0 b8[4, 256]"
        # t6 = prims.where(t24, t35, 0.0) # t6: "cuda:0 f32[4, 256]"
    del t35
    t44 = torch_wait_prim_impl(ft36) # t44: "cuda:0 f32[256]"
    del ft36
    t47 = torch_wait_prim_impl(ft45) # t47: "cuda:0 f32[256, 256]"
    del ft45
    # /home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/nn/modules/linear.py:125: return F.linear(input, self.weight, self.bias)
    t49 = torch.nn.functional.linear(t6, t47, t44) # t49: "cuda:0 f32[4, 256]"
        # t49 = ltorch.linear(t6, t47, t44) # t49: "cuda:0 f32[4, 256]"
        # t49 = prims.linear(t6, t47, t44) # t49: "cuda:0 f32[4, 256]"
    del t44
    return {'output': (t49,), 'flat_args': [input, t_0_bias, t_0_weight, t_2_bias, t_2_weight], 'flat_output': (t49,)}, ((input, t24, t47, t6), ())
```

and the backward trace:

```python
# Constructed by Delete Last Used (took 0 milliseconds)
import torch
from thunder.executors.torchex import no_autocast
@torch.no_grad()
@no_autocast
def backward_fn(saved_for_backward, cotangents):
    # saved_for_backward: "Collection"
    # cotangents: "Collection"
    C0, _, = saved_for_backward
    # C0: "Collection"
    # None
    clear_mutable_collection(saved_for_backward)
    del saved_for_backward
    t0, = cotangents
    # t0: "cuda:0 f32[4, 256]"
    clear_mutable_collection(cotangents)
    del cotangents
    input, t24, t47, t6, = C0
    # input: "cuda:0 f32[4, 256]"
    # t24: "cuda:0 b8[4, 256]"
    # t47: "cuda:0 f32[256, 256]"
    # t6: "cuda:0 f32[4, 256]"
    clear_mutable_collection(C0)
    del C0
    bw_t95 = torch.reshape(t0, (-1, 256)) # bw_t95: "cuda:0 f32[4, 256]"
        # bw_t95 = ltorch.reshape(t0, (-1, 256)) # bw_t95: "cuda:0 f32[4, 256]"
        # bw_t95 = prims.reshape(t0, (4, 256)) # bw_t95: "cuda:0 f32[4, 256]"
    bw_t76 = torch.matmul(bw_t95, t47) # bw_t76: "cuda:0 f32[4, 256]"
        # bw_t76 = ltorch.matmul(bw_t95, t47) # bw_t76: "cuda:0 f32[4, 256]"
        # bw_t76 = prims.matmul(bw_t95, t47) # bw_t76: "cuda:0 f32[4, 256]"
    del t47
    bw_t96 = torch.permute(bw_t95, (1, 0)) # bw_t96: "cuda:0 f32[256, 4]"
        # bw_t96 = ltorch.permute(bw_t95, (1, 0)) # bw_t96: "cuda:0 f32[256, 4]"
        # bw_t96 = prims.transpose(bw_t95, (1, 0)) # bw_t96: "cuda:0 f32[256, 4]"
    del bw_t95
    bw_t97 = torch.reshape(t6, (-1, 256)) # bw_t97: "cuda:0 f32[4, 256]"
        # bw_t97 = ltorch.reshape(t6, (-1, 256)) # bw_t97: "cuda:0 f32[4, 256]"
        # bw_t97 = prims.reshape(t6, (4, 256)) # bw_t97: "cuda:0 f32[4, 256]"
    del t6
    bw_t77 = torch.matmul(bw_t96, bw_t97) # bw_t77: "cuda:0 f32[256, 256]"
        # bw_t77 = ltorch.matmul(bw_t96, bw_t97) # bw_t77: "cuda:0 f32[256, 256]"
        # bw_t77 = prims.matmul(bw_t96, bw_t97) # bw_t77: "cuda:0 f32[256, 256]"
    del bw_t96, bw_t97
    [bw_t42, bw_t43] = nvFusion0(t0, bw_t77)
        # bw_t42 = prims.sum(t0, (0,)) # bw_t42: "cuda:0 f32[256]"
        # bw_t43 = prims.div(bw_t77, 2.0) # bw_t43: "cuda:0 f32[256, 256]"
    del t0, bw_t77
    ft78 = torch_all_reduce_prim_impl(bw_t43, _DistributedReduceOps_4, _torch_distributed_distributed_c10d_ProcessGroup_6, True, False) # ft78: "FUTURE thunder.devices.Device(type='cuda', index=0) f32[256, 256]"
    del bw_t43
    [bw_t46] = nvFusion1(bw_t42)
        # bw_t46 = prims.div(bw_t42, 2.0) # bw_t46: "cuda:0 f32[256]"
    del bw_t42
    ft80 = torch_all_reduce_prim_impl(bw_t46, _DistributedReduceOps_4, _torch_distributed_distributed_c10d_ProcessGroup_6, True, False) # ft80: "FUTURE thunder.devices.Device(type='cuda', index=0) f32[256]"
    del bw_t46
    [bw_t54, bw_t55] = nvFusion2(t24, bw_t76, input)
        # bw_t50 = prims.where(t24, bw_t76, 0.0) # bw_t50: "cuda:0 f32[4, 256]"
        # bw_t53 = prims.reshape(bw_t50, (4, 256)) # bw_t53: "cuda:0 f32[4, 256]"
        # bw_t54 = prims.transpose(bw_t53, (1, 0)) # bw_t54: "cuda:0 f32[256, 4]"
        # bw_t55 = prims.reshape(input, (4, 256)) # bw_t55: "cuda:0 f32[4, 256]"
    del input
    bw_t82 = torch.matmul(bw_t54, bw_t55) # bw_t82: "cuda:0 f32[256, 256]"
        # bw_t82 = ltorch.matmul(bw_t54, bw_t55) # bw_t82: "cuda:0 f32[256, 256]"
        # bw_t82 = prims.matmul(bw_t54, bw_t55) # bw_t82: "cuda:0 f32[256, 256]"
    del bw_t54, bw_t55
    [bw_t58, bw_t57] = nvFusion3(t24, bw_t76, bw_t82)
        # bw_t50 = prims.where(t24, bw_t76, 0.0) # bw_t50: "cuda:0 f32[4, 256]"
        # bw_t58 = prims.div(bw_t82, 2.0) # bw_t58: "cuda:0 f32[256, 256]"
        # bw_t57 = prims.sum(bw_t50, (0,)) # bw_t57: "cuda:0 f32[256]"
    del t24, bw_t76, bw_t82
    ft83 = torch_all_reduce_prim_impl(bw_t58, _DistributedReduceOps_4, _torch_distributed_distributed_c10d_ProcessGroup_6, True, False) # ft83: "FUTURE thunder.devices.Device(type='cuda', index=0) f32[256, 256]"
    del bw_t58
    [bw_t61] = nvFusion4(bw_t57)
        # bw_t61 = prims.div(bw_t57, 2.0) # bw_t61: "cuda:0 f32[256]"
    del bw_t57
    ft85 = torch_all_reduce_prim_impl(bw_t61, _DistributedReduceOps_4, _torch_distributed_distributed_c10d_ProcessGroup_6, True, False) # ft85: "FUTURE thunder.devices.Device(type='cuda', index=0) f32[256]"
    del bw_t61
    bw_t79 = torch_wait_prim_impl(ft78) # bw_t79: "cuda:0 f32[256, 256]"
    del ft78
    [bw_t64] = nvFusion5(bw_t79)
        # bw_t64 = prims.div(bw_t79, 2.0) # bw_t64: "cuda:0 f32[256, 256]"
    del bw_t79
    ft87 = torch_all_reduce_prim_impl(bw_t64, _DistributedReduceOps_4, _torch_distributed_distributed_c10d_ProcessGroup_7, True, False) # ft87: "FUTURE thunder.devices.Device(type='cuda', index=0) f32[256, 256]"
    del bw_t64
    bw_t81 = torch_wait_prim_impl(ft80) # bw_t81: "cuda:0 f32[256]"
    del ft80
    [bw_t67] = nvFusion6(bw_t81)
        # bw_t67 = prims.div(bw_t81, 2.0) # bw_t67: "cuda:0 f32[256]"
    del bw_t81
    ft89 = torch_all_reduce_prim_impl(bw_t67, _DistributedReduceOps_4, _torch_distributed_distributed_c10d_ProcessGroup_7, True, False) # ft89: "FUTURE thunder.devices.Device(type='cuda', index=0) f32[256]"
    del bw_t67
    bw_t84 = torch_wait_prim_impl(ft83) # bw_t84: "cuda:0 f32[256, 256]"
    del ft83
    [bw_t70] = nvFusion7(bw_t84)
        # bw_t70 = prims.div(bw_t84, 2.0) # bw_t70: "cuda:0 f32[256, 256]"
    del bw_t84
    ft91 = torch_all_reduce_prim_impl(bw_t70, _DistributedReduceOps_4, _torch_distributed_distributed_c10d_ProcessGroup_7, True, False) # ft91: "FUTURE thunder.devices.Device(type='cuda', index=0) f32[256, 256]"
    del bw_t70
    bw_t86 = torch_wait_prim_impl(ft85) # bw_t86: "cuda:0 f32[256]"
    del ft85
    [bw_t73] = nvFusion8(bw_t86)
        # bw_t73 = prims.div(bw_t86, 2.0) # bw_t73: "cuda:0 f32[256]"
    del bw_t86
    ft93 = torch_all_reduce_prim_impl(bw_t73, _DistributedReduceOps_4, _torch_distributed_distributed_c10d_ProcessGroup_7, True, False) # ft93: "FUTURE thunder.devices.Device(type='cuda', index=0) f32[256]"
    del bw_t73
    bw_t88 = torch_wait_prim_impl(ft87) # bw_t88: "cuda:0 f32[256, 256]"
    del ft87
    bw_t90 = torch_wait_prim_impl(ft89) # bw_t90: "cuda:0 f32[256]"
    del ft89
    bw_t92 = torch_wait_prim_impl(ft91) # bw_t92: "cuda:0 f32[256, 256]"
    del ft91
    bw_t94 = torch_wait_prim_impl(ft93) # bw_t94: "cuda:0 f32[256]"
    del ft93
    return (None, bw_t94, bw_t92, bw_t90, bw_t88)
```
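Reading the backward trace: each gradient is divided by 2.0 and all-reduced on `_torch_distributed_distributed_c10d_ProcessGroup_6`, then divided by 2.0 again and all-reduced on `_torch_distributed_distributed_c10d_ProcessGroup_7`, i.e. a two-stage average over both mesh dimensions. A rough eager-mode equivalent (a sketch; which numbered group corresponds to the fsdp vs. ddp mesh dimension is an assumption):

```python
import torch
import torch.distributed as dist

def two_stage_grad_average(grad: torch.Tensor, pg_a, pg_b) -> torch.Tensor:
    # Divide by each group's size before the SUM all-reduce, so the net
    # effect is an average over all 4 ranks of the 2x2 mesh.
    grad = grad / dist.get_world_size(group=pg_a)
    dist.all_reduce(grad, group=pg_a)
    grad = grad / dist.get_world_size(group=pg_b)
    dist.all_reduce(grad, group=pg_b)
    return grad
```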
So the following works on a 4-GPU machine, except for the tracing:

```python
import torch, thunder
import torch.distributed
from torch.testing import assert_close
from thunder.distributed.transforms.fsdp_v2 import FSDPTransform
from thunder.distributed.transforms.ddp_v2 import DDPTransform

torch.manual_seed(1337)

mesh = torch.distributed.device_mesh.init_device_mesh("cuda", (2, 2), mesh_dim_names=("ddp", "fsdp"))
global_rank = mesh.get_rank()
fsdp_rank = mesh.get_local_rank('fsdp')
print(f"{global_rank=}, {fsdp_rank=}")
print(mesh)
print(mesh.get_coordinate())

with torch.device("cuda"):
    m = torch.nn.Sequential(torch.nn.Linear(256, 256), torch.nn.ReLU(), torch.nn.Linear(256, 256))
    inp = torch.randn(4, 256)

jm = thunder.jit(
    m,
    transforms=[
        FSDPTransform(process_group=mesh["fsdp"].get_group()),
        DDPTransform(mesh["ddp"].get_group(), broadcast_from=0, bucket_size_in_mb=25.0),
    ],
)

inp_sharded = inp[global_rank:global_rank + 1]
res = jm(inp)
go = torch.randn_like(res)
grads = torch.autograd.grad(res, jm.parameters(), go)

ref = m(inp)
ref_grads = torch.autograd.grad(ref, m.parameters(), go)

assert_close(res, ref)
for g, rg in zip(grads, ref_grads):
    slice_size = rg.size(0) // 2
    assert_close(g, rg[slice_size * fsdp_rank: slice_size * (fsdp_rank + 1)])

torch.distributed.destroy_process_group()
print("Worked!")
```

I currently need to patch out the bucketing, which I still need to fix (cope with waits not being part of the bucketing, and do the bucketing by process group, etc.).
Work in progress on composable FSDP + DDP.
Intended to fix #1980.