Support tuple of tensors in estimate_strategy_runtime_cost #102
Conversation
Previously, if we had a tuple of tensors as an argument to a function, we wouldn't apply any sharding to it. This is split from #26, where I originally found this issue.
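As a hedged illustration (not the actual AutoParallel code), the sketch below shows why tensors nested inside a tuple argument are missed by a shallow scan of the arguments, and how pytree flattening exposes them so their sharding can be handled:

```python
import torch
from torch.utils._pytree import tree_flatten

# Illustrative only: an op such as torch.cat receives a tuple of tensors
# as a single argument.
args = ((torch.empty(4, 8), torch.empty(4, 8)), 0)

# A shallow scan over args sees one tuple and one int -- no tensors, so no
# sharding would be applied to them.
shallow_tensors = [a for a in args if isinstance(a, torch.Tensor)]
assert shallow_tensors == []

# Flattening the pytree exposes both tensors.
flat_args, _ = tree_flatten(args)
deep_tensors = [a for a in flat_args if isinstance(a, torch.Tensor)]
assert len(deep_tensors) == 2
```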
args = tree_map_only(torch.fx.Node, lambda x: x.meta["val"], node.args)
kwargs = tree_map_only(torch.fx.Node, lambda x: x.meta["val"], node.kwargs)

fake_mode = torch._guards.detect_fake_mode(args)
is this because we are already under a fake mode now, but we weren't in the initial AutoParallel?
Yes, that's right: the whole of AutoParallel now runs under fake mode, so we can remove it.
this is removed because we're already running this in a fake mode?
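To make the answer above concrete, here is a minimal sketch (illustrative only, not code from this PR) of why the explicit `detect_fake_mode` call becomes redundant once everything already runs under a `FakeTensorMode`: the ambient mode is picked up automatically, so re-detecting and re-entering it adds nothing.

```python
import torch
from torch._subclasses.fake_tensor import FakeTensorMode

# Illustrative only: when the surrounding pass already executes under an
# active FakeTensorMode, tensors created inside it are FakeTensors and
# torch._guards.detect_fake_mode recovers that same mode.
with FakeTensorMode() as fake_mode:
    x = torch.empty(4, 8)  # created under the active fake mode
    detected = torch._guards.detect_fake_mode((x,))
    assert detected is fake_mode
```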
_get_sharded_shape_stride(spec) for spec in strategy.input_specs
)

flat_args, treespec = tree_flatten(args)
shouldn't we just call tree_map_only(Tensor, torch.empty) here instead of doing the for loop?
We need to get the sizes from args_sizes_strides (which come from the spec), so I think we need this indirection.
But if there is a cleaner way of doing this, I'm happy to change the code!
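To spell out that indirection, here is a hedged sketch (the helper name `materialize_sharded_args` and the exact plumbing are illustrative, not this PR's code): the sharded (size, stride) pairs come from the strategy's input specs, so each flattened tensor argument is rebuilt with `torch.empty_strided` rather than mapped in place with `tree_map_only`.

```python
import torch
from torch.utils._pytree import tree_flatten, tree_unflatten

def materialize_sharded_args(args, args_sizes_strides):
    # args_sizes_strides: iterable of (size, stride) pairs derived from the
    # strategy's input specs, in the same order as the tensor leaves of args.
    flat_args, treespec = tree_flatten(args)
    sizes_strides = iter(args_sizes_strides)
    new_flat = []
    for a in flat_args:
        if isinstance(a, torch.Tensor):
            size, stride = next(sizes_strides)
            # Rebuild the tensor with its sharded size/stride instead of
            # reusing the original (unsharded) metadata.
            new_flat.append(
                torch.empty_strided(size, stride, dtype=a.dtype, device=a.device)
            )
        else:
            new_flat.append(a)
    return tree_unflatten(new_flat, treespec)
```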