Commit 20e53b0
Support tuple of tensors in estimate_strategy_runtime_cost (#102)
* Support tuple of tensors in estimate_strategy_runtime_cost

  Previously, if we had a tuple of tensors as an argument to a function, we
  wouldn't apply any sharding on it. This is split from #26, where I
  originally found this issue.

* Fix bad copy-paste
1 parent ba73b2e commit 20e53b0
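
As a minimal sketch of the issue this commit fixes (hypothetical example, not code from the repository): when an argument is itself a tuple of tensors, scanning `node.args` at the top level sees one non-tensor object and never reaches the tensors nested inside it, whereas pytree flattening exposes every tensor leaf.

```python
# Hypothetical illustration (not from the repository) of the bug: iterating
# over args directly treats a nested tuple as a single non-tensor argument,
# so the tensors inside it never get their sharded shapes applied.
import torch
from torch.utils._pytree import tree_flatten

args = (torch.randn(4, 4), (torch.randn(2, 8), torch.randn(8, 2)))

# Old approach: only top-level arguments are inspected.
top_level = [a for a in args if isinstance(a, torch.Tensor)]
print(len(top_level))  # 1 -- the two tensors inside the inner tuple are skipped

# Pytree flattening exposes every tensor leaf, however deeply nested.
flat_args, treespec = tree_flatten(args)
leaves = [a for a in flat_args if isinstance(a, torch.Tensor)]
print(len(leaves))  # 3
```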

autoparallel/compute_estimation.py

Lines changed: 8 additions & 10 deletions
```diff
@@ -213,25 +213,23 @@ def estimate_strategy_runtime_cost(node, strategy):
     args = tree_map_only(torch.fx.Node, lambda x: x.meta["val"], node.args)
     kwargs = tree_map_only(torch.fx.Node, lambda x: x.meta["val"], node.kwargs)
 
-    fake_mode = torch._guards.detect_fake_mode(args)
-
     if len(kwargs) > 0:
         for k, v in kwargs.items():
             assert not isinstance(v, torch.Tensor), f"{node} {v}"
     args_sizes_strides = tuple(
         _get_sharded_shape_stride(spec) for spec in strategy.input_specs
     )
 
+    flat_args, treespec = tree_flatten(args)
+    new_flat_args = []
     counter = 0
-    args = list(args)
-    for i, arg in enumerate(args):
-        if isinstance(arg, torch.Tensor):
-            with fake_mode:
-                sizes, strides = args_sizes_strides[counter]
-                args[i] = torch.empty_strided(
-                    sizes, strides, device=arg.device, dtype=arg.dtype
-                )
+    for x in flat_args:
+        if isinstance(x, torch.Tensor):
+            sizes, strides = args_sizes_strides[counter]
+            x = torch.empty_strided(sizes, strides, device=x.device, dtype=x.dtype)
             counter += 1
+        new_flat_args.append(x)
+    args = treespec.unflatten(new_flat_args)
 
     # TODO: maybe cache the flop_counter to avoid recreating it
     # all the time
```
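
A short sketch of the flatten/unflatten round-trip the new code relies on, assuming `tree_flatten` comes from `torch.utils._pytree` (the import is outside this hunk; the diff itself confirms `treespec.unflatten`): each tensor leaf is swapped for an empty strided placeholder, then the original argument nesting is restored.

```python
# Sketch of the pytree round-trip used by the new code. Assumption: in the
# real function the sizes/strides come from strategy.input_specs; here we
# simply reuse each leaf's own sizes/strides for illustration.
import torch
from torch.utils._pytree import tree_flatten

args = (torch.randn(4, 4), (torch.randn(8, 2), torch.randn(2, 8)), "scale")

flat_args, treespec = tree_flatten(args)  # leaves in a deterministic order
new_flat_args = [
    # Replace every tensor leaf with a shape-only placeholder; non-tensor
    # leaves (like the string) pass through unchanged.
    torch.empty_strided(x.size(), x.stride(), device=x.device, dtype=x.dtype)
    if isinstance(x, torch.Tensor)
    else x
    for x in flat_args
]
args = treespec.unflatten(new_flat_args)  # same (tensor, (tensor, tensor), str) nesting
```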
