
Conversation

@sanketpurandare (Contributor) commented Nov 5, 2025

Based on #227; will rebase once it lands.
Test file: examples/example_pp_graph_passes.py

Made some modifications to Ivan’s pass (#201) as follows:

The split_fsdp_prefetch pass required all of the inputs (params + buffers + microbatch) to be passed to the prefetch graph, and all of the prefetch graph's outputs to be passed to the fwd_graph. We wanted something slightly different: only the sharded_params should go into the prefetch_graph, which yields the unsharded_params. The fwd_graph then takes these unsharded_params (obtained from the prefetch_graph) plus the buffers and the microbatch. This is important because the prefetch graph must be independent of the microbatch, since the unshard action has no associated microbatch.
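
For concreteness, a minimal runnable sketch of the intended calling conventions; prefetch_gm and fwd_gm are illustrative stand-ins for the graph modules the pass produces, not the actual signatures:

import torch

# Stand-in for the prefetch graph: takes only the sharded params and
# returns the unsharded params. Crucially, it has no microbatch inputs.
def prefetch_gm(*sharded_params):
    # a real graph would all-gather here; this just pretends
    return tuple(p.clone() for p in sharded_params)

# Stand-in for the forward graph: consumes the unsharded params produced
# by prefetch_gm, plus the buffers and the microbatch.
def fwd_gm(unsharded_params, buffers, microbatch):
    (w,) = unsharded_params
    return microbatch @ w

sharded_params = (torch.randn(4, 4),)
unsharded_params = prefetch_gm(*sharded_params)  # schedulable once, no microbatch
out = fwd_gm(unsharded_params, (), torch.randn(2, 4))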

Similarly, the reduce_grad graph should take only the unsharded_grads from the bwd_graph and produce the sharded_grads. We cannot pass all of the bwd_graph's outputs into the reduce_grad graph, since that would make it microbatch-dependent, and we want to call the reduce_grad action only once, after accumulating the unsharded_grads across microbatches.
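
And the matching sketch for the backward side, showing why the reduce_grad graph must see only the unsharded grads: it runs exactly once, after the microbatch loop. bwd_gm and reduce_grad_gm are again illustrative stand-ins:

import torch

# Stand-in for the backward graph: runs per microbatch and produces
# unsharded grads (a real bwd_graph also produces microbatch-specific
# outputs such as grad_inputs, which must NOT flow into reduce_grad).
def bwd_gm(grad_out):
    return (grad_out.sum() * torch.ones(4, 4),)

# Stand-in for the reduce_grad graph: unsharded grads in, sharded grads out.
def reduce_grad_gm(*unsharded_grads):
    # a real graph would reduce-scatter here
    return tuple(g[:2] for g in unsharded_grads)

accum = None
for grad_out in (torch.randn(2, 4), torch.randn(2, 4)):
    (g,) = bwd_gm(grad_out)
    accum = g if accum is None else accum + g  # accumulate across microbatches

sharded_grads = reduce_grad_gm(accum)  # called exactly once, after the loop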

Brian’s pass (#232) worked fine as-is; I just integrated it differently in the end-to-end workflow.


@IvanKobzarev (Contributor) left a comment


split_fsdp_prefetch lgtm

    gm: torch.fx.GraphModule,
    num_params: int,
) -> tuple[torch.fx.GraphModule, torch.fx.GraphModule]:
    g = deepcopy(gm.graph)

@IvanKobzarev (Contributor)


Curious, why do you want to keep the original graph unchanged? Will it be further used?

@sanketpurandare (Author)


Since we are using its container graph module to initialize the two new graphs, I just thought it would be safer this way; we can remove this later.
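
For reference, a self-contained sketch of the pattern being discussed, with a toy module standing in for the real stage module: mutate a deepcopy of gm.graph and keep gm itself as the root when materializing new GraphModules, so the original graph stays intact:

import copy
import torch
import torch.fx as fx

class Toy(torch.nn.Module):
    def forward(self, x):
        return torch.relu(x + 1)

gm = fx.symbolic_trace(Toy())

# Work on a copy so gm.graph is left untouched; gm can then safely serve
# as the root/container for GraphModules built from derived graphs.
g = copy.deepcopy(gm.graph)
# (the real pass splits `g` into two graphs here; this just rebuilds one)
new_gm = fx.GraphModule(gm, g)
print(new_gm(torch.randn(3)))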

    g_ins = g.find_nodes(op="placeholder")

def split_fsdp_prefetch(
    gm: torch.fx.GraphModule,
    num_params: int,

@IvanKobzarev (Contributor) Nov 5, 2025


nit: If we use export with descriptors, num_params could potentially be taken from metadata.

@sanketpurandare (Author)


Yeah, for now this is easily obtainable from the graph meta. Same as above, this can be removed later.
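
One possible version of that, sketched with plain torch.export (whether the descriptor-based export surfaces this the same way is an assumption; the exported graph signature does record which placeholders are parameters):

import torch
from torch.export import export

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.lin = torch.nn.Linear(4, 4)

    def forward(self, x):
        return self.lin(x)

ep = export(M(), (torch.randn(2, 4),))
# The graph signature records which placeholders are parameters, so
# num_params would not need to be passed in explicitly.
num_params = len(ep.graph_signature.parameters)  # 2: weight and bias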

@sanketpurandare merged commit 7fb094d into meta-pytorch:main on Nov 5, 2025 (6 checks passed).
