Support tuple of tensors in estimate_strategy_runtime_cost #102
Changes from all commits
```diff
@@ -213,25 +213,23 @@ def estimate_strategy_runtime_cost(node, strategy):
     args = tree_map_only(torch.fx.Node, lambda x: x.meta["val"], node.args)
     kwargs = tree_map_only(torch.fx.Node, lambda x: x.meta["val"], node.kwargs)
 
-    fake_mode = torch._guards.detect_fake_mode(args)
-
```
**Member** commented on the removed `fake_mode` line:

> this is removed because we're already running this in a fake mode?
```diff
     if len(kwargs) > 0:
         for k, v in kwargs.items():
             assert not isinstance(v, torch.Tensor), f"{node} {v}"
     args_sizes_strides = tuple(
         _get_sharded_shape_stride(spec) for spec in strategy.input_specs
     )
 
+    flat_args, treespec = tree_flatten(args)
```
**Contributor** commented on the `tree_flatten(args)` line:

> shouldn't we just call […]?

**Author** replied:

> We need to get the size from the […]
>
> But if there is a cleaner way of doing this I'm happy to change the code!
```diff
+    new_flat_args = []
     counter = 0
-    args = list(args)
-    for i, arg in enumerate(args):
-        if isinstance(arg, torch.Tensor):
-            with fake_mode:
-                sizes, strides = args_sizes_strides[counter]
-                args[i] = torch.empty_strided(
-                    sizes, strides, device=arg.device, dtype=arg.dtype
-                )
+    for x in flat_args:
+        if isinstance(x, torch.Tensor):
+            sizes, strides = args_sizes_strides[counter]
+            x = torch.empty_strided(sizes, strides, device=x.device, dtype=x.dtype)
             counter += 1
+        new_flat_args.append(x)
+    args = treespec.unflatten(new_flat_args)
 
     # TODO: maybe cache the flop_counter to avoid recreating it
     # all the time
```
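For context, here is a minimal, self-contained sketch (not code from this PR) of the flatten/rebuild pattern the new version uses: nested args such as a tuple of tensors are flattened to a list of leaves, each tensor leaf is replaced by an `empty_strided` tensor built from a parallel list of `(sizes, strides)`, and the original nesting is restored with the treespec. The `sharded_shapes` list below is only a stand-in for what `_get_sharded_shape_stride(spec)` would produce from `strategy.input_specs`.

```python
import torch
from torch.utils._pytree import tree_flatten

# Hypothetical nested args: a tensor plus a tuple of tensors, mirroring the
# "tuple of tensors" case this PR adds support for.
args = (torch.randn(4, 8), (torch.randn(2, 3), torch.randn(5)))

# Stand-in for _get_sharded_shape_stride(spec) over strategy.input_specs:
# one (sizes, strides) pair per tensor leaf, in flattening order.
sharded_shapes = [((2, 8), (8, 1)), ((1, 3), (3, 1)), ((5,), (1,))]

flat_args, treespec = tree_flatten(args)
new_flat_args = []
counter = 0
for x in flat_args:
    if isinstance(x, torch.Tensor):
        # Replace each tensor leaf with an empty tensor of the sharded shape.
        sizes, strides = sharded_shapes[counter]
        x = torch.empty_strided(sizes, strides, device=x.device, dtype=x.dtype)
        counter += 1
    new_flat_args.append(x)

# unflatten restores the original nesting, so the tuple of tensors survives.
args = treespec.unflatten(new_flat_args)
```

Compared to the previous `enumerate(args)` loop, which only handled top-level tensor arguments, this round-trips arbitrary nesting while still consuming the sharded shapes in leaf order.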
On the removal of the `fake_mode` handling:

> is this bc we are already under a fake mode now, but we weren't in the initial autop?

> Yes, that's right, now the whole AutoParallel is running under fake mode, so we can remove it
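To illustrate why the explicit `detect_fake_mode` / `with fake_mode:` wrapping becomes unnecessary: when the caller already runs under an active `FakeTensorMode`, tensor factories such as `torch.empty_strided` are intercepted by that mode and return fake tensors directly. A rough standalone sketch (not AutoParallel code), assuming a plain `FakeTensorMode` stands in for the mode AutoParallel runs under:

```python
import torch
from torch._subclasses.fake_tensor import FakeTensorMode

# With a FakeTensorMode already active, factory calls produce FakeTensors with
# the requested shape and stride but no real storage, so the call site does not
# need to detect or re-enter the fake mode itself.
with FakeTensorMode():
    t = torch.empty_strided((4, 8), (8, 1), dtype=torch.float32)
    print(type(t).__name__, t.shape, t.stride())  # FakeTensor torch.Size([4, 8]) (8, 1)
```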