Bug description
When training the Qwen3 model with ETP=1 + torch.compile + AC (activation checkpointing), the following error occurs.
Error stack
[rank0]:[rank0]:V1007 03:32:05.001000 307436 site-packages/torch/_dynamo/symbolic_convert.py:4419] [0/0] FAILED INLINING <code object forward at 0x7f5699c88190, file "/data05/ziyizhang.zzy/torchtitan/torchtitan/experiments/qwen3/model/model.py", line 307>
[rank0]:[rank0]:I1007 03:32:05.001000 307436 site-packages/torch/_dynamo/variables/higher_order_ops.py:1193] [0/0] speculate_subgraph: while introspecting torch.utils.checkpoint.checkpoint, we were unable to trace function `UnspecializedNNModuleVariable` into a single graph. This means that Dynamo was unable to prove safety for this API and will fall back to eager-mode PyTorch, which could lead to a slowdown.
[rank0]:[rank0]:I1007 03:32:05.001000 307436 site-packages/torch/_dynamo/variables/higher_order_ops.py:1194] [0/0] HigherOrderOperator: Mutating a variable not in the current scope (SideEffects)
[rank0]:[rank0]:I1007 03:32:05.001000 307436 site-packages/torch/_dynamo/variables/higher_order_ops.py:1194] [0/0] Explanation: This is not supported.
[rank0]:[rank0]:I1007 03:32:05.001000 307436 site-packages/torch/_dynamo/variables/higher_order_ops.py:1194] [0/0]
[rank0]:[rank0]:I1007 03:32:05.001000 307436 site-packages/torch/_dynamo/variables/higher_order_ops.py:1194] [0/0]
[rank0]:[rank0]:I1007 03:32:05.001000 307436 site-packages/torch/_dynamo/variables/higher_order_ops.py:1194] [0/0] Developer debug context:
[rank0]:[rank0]:I1007 03:32:05.001000 307436 site-packages/torch/_dynamo/variables/higher_order_ops.py:1194] [0/0]
[rank0]:[rank0]:I1007 03:32:05.001000 307436 site-packages/torch/_dynamo/variables/higher_order_ops.py:1194] [0/0] For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0067.html
[rank0]:[rank0]:V1007 03:32:05.002000 307436 site-packages/torch/_dynamo/symbolic_convert.py:605] [0/0] [__graph_breaks] Graph break in user code at /data05/ziyizhang.zzy/torchtitan/torchtitan/distributed/expert_parallel.py:301
[rank0]:[rank0]:V1007 03:32:05.002000 307436 site-packages/torch/_dynamo/symbolic_convert.py:605] [0/0] [__graph_breaks] Graph Break Reason: HigherOrderOperator: Mutating a variable not in the current scope (SideEffects)
[rank0]:[rank0]:V1007 03:32:05.002000 307436 site-packages/torch/_dynamo/symbolic_convert.py:605] [0/0] [__graph_breaks] Explanation: This is not supported.
This is because torch.utils.checkpoint.checkpoint is a HOP (higher-order operator) in compile mode, and mutation is disallowed under a HOP. However, ReordererSequenceParallel mutates self.num_tokens, which triggers the tracing failure and the fallback to eager.
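A minimal sketch of the failure pattern (the module and names here are hypothetical, not the torchtitan code): any attribute mutation inside a checkpointed region is a side effect that Dynamo cannot prove safe while tracing the HOP, so it falls back to eager for that subgraph.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

class Reorderer(nn.Module):
    def __init__(self):
        super().__init__()
        self.num_tokens = None  # mutated in forward -> side effect under the HOP

    def forward(self, x):
        self.num_tokens = x.shape[0]  # mutates state defined outside the HOP's scope
        return x * 2

class Block(nn.Module):
    def __init__(self):
        super().__init__()
        self.reorderer = Reorderer()

    def forward(self, x):
        # Under torch.compile, checkpoint is traced as a higher-order op (HOP)
        return checkpoint(self.reorderer, x, use_reentrant=False)

m = torch.compile(Block())
# Expected to hit the "Mutating a variable not in the current scope (SideEffects)"
# graph break and run the checkpointed body in eager mode.
out = m(torch.randn(8, 4))
```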
self.num_tokens was introduced in #1586, because top_scores.shape[0] differs between the input and output sides, so the value was stored as a module attribute. To avoid this mutation, could we perhaps make topk a class attribute and infer num_tokens as top_scores.shape[0] / self.topk? A sketch of that idea follows.
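A hedged sketch of the suggested fix (the class name, tensor shapes, and call pattern are assumptions based on the description above, not the actual torchtitan implementation): set topk once at construction and derive num_tokens from the current tensor shape, so forward never writes back to self.

```python
import torch
import torch.nn as nn

class ReordererSequenceParallel(nn.Module):
    def __init__(self, topk: int):
        super().__init__()
        self.topk = topk  # static, set once at init -> no mutation inside forward

    def forward(self, top_scores: torch.Tensor) -> torch.Tensor:
        # Infer instead of storing: on the expanded side, shape[0] is
        # num_tokens * topk, so divide rather than cache it on self.
        num_tokens = top_scores.shape[0] // self.topk
        return top_scores.reshape(num_tokens, self.topk)
```

Since nothing is mutated inside the checkpointed region, Dynamo should be able to trace the HOP body into a single graph.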
Versions
main branch