From afa005063b108b6f9045b79cb15312d1221d8a7e Mon Sep 17 00:00:00 2001
From: Brian Hirsh
Date: Wed, 9 Jul 2025 17:14:34 -0700
Subject: [PATCH 1/2] preserve tensor striding during compute estimation

---
 autoparallel/compute_estimation.py | 15 +++++++++++----
 1 file changed, 11 insertions(+), 4 deletions(-)

diff --git a/autoparallel/compute_estimation.py b/autoparallel/compute_estimation.py
index e6e9bcb2..37513105 100644
--- a/autoparallel/compute_estimation.py
+++ b/autoparallel/compute_estimation.py
@@ -164,11 +164,15 @@ def _get_sharded_shape(spec):
     # TODO: find a better heuristic other than
     # running DTensor
     new_tensor_shape = list(tensor_shape)
+    new_tensor_stride = list(spec.tensor_meta.stride)
     for mesh_size, placement in zip(mesh.shape, placements):
         if placement.is_shard():
             dim = placement.dim
             new_tensor_shape[dim] = (new_tensor_shape[dim] + mesh_size - 1) // mesh_size
-    return new_tensor_shape
+            new_tensor_stride[dim] = (
+                new_tensor_stride[dim] + mesh_size - 1
+            ) // mesh_size
+    return new_tensor_shape, new_tensor_stride
 
 
 def estimate_strategy_runtime_cost(node, strategy):
@@ -191,15 +195,18 @@ def estimate_strategy_runtime_cost(node, strategy):
     if len(kwargs) > 0:
         for k, v in kwargs.items():
             assert not isinstance(v, torch.Tensor), f"{node} {v}"
-    args_shapes = tuple(_get_sharded_shape(spec) for spec in strategy.input_specs)
+    args_sizes_strides = tuple(
+        _get_sharded_shape(spec) for spec in strategy.input_specs
+    )
 
     counter = 0
     args = list(args)
     for i, arg in enumerate(args):
         if isinstance(arg, torch.Tensor):
             with fake_mode:
-                args[i] = torch.empty(
-                    args_shapes[counter], device=arg.device, dtype=arg.dtype
+                sizes, strides = args_sizes_strides[counter]
+                args[i] = torch.empty_strided(
+                    sizes, strides, device=arg.device, dtype=arg.dtype
                 )
             counter += 1
 

From 02086566ed11ee996e46d77fd7d57bf33a1492fc Mon Sep 17 00:00:00 2001
From: Will Constable
Date: Thu, 17 Jul 2025 15:41:26 -0700
Subject: [PATCH 2/2] Fix other callsites

---
 autoparallel/compute_estimation.py | 4 ++--
 autoparallel/optimize_sharding.py  | 7 +++++--
 2 files changed, 7 insertions(+), 4 deletions(-)

diff --git a/autoparallel/compute_estimation.py b/autoparallel/compute_estimation.py
index 37513105..c104a42e 100644
--- a/autoparallel/compute_estimation.py
+++ b/autoparallel/compute_estimation.py
@@ -155,7 +155,7 @@ def _get_device_tflops(dtype):
     return device_limit.gemm_tflops[dtype]
 
 
-def _get_sharded_shape(spec):
+def _get_sharded_shape_stride(spec):
     mesh = spec.mesh
     tensor_shape = spec.tensor_meta.shape
     # TODO: take dtype into account as well
@@ -196,7 +196,7 @@ def estimate_strategy_runtime_cost(node, strategy):
         for k, v in kwargs.items():
             assert not isinstance(v, torch.Tensor), f"{node} {v}"
     args_sizes_strides = tuple(
-        _get_sharded_shape(spec) for spec in strategy.input_specs
+        _get_sharded_shape_stride(spec) for spec in strategy.input_specs
    )
 
     counter = 0
diff --git a/autoparallel/optimize_sharding.py b/autoparallel/optimize_sharding.py
index 56051360..9eb025de 100644
--- a/autoparallel/optimize_sharding.py
+++ b/autoparallel/optimize_sharding.py
@@ -11,7 +11,10 @@
 from torch.distributed.tensor.placement_types import Replicate, Shard
 from torch.utils._pytree import tree_flatten, tree_map_only
 
-from .compute_estimation import _get_sharded_shape, estimate_strategy_runtime_cost
+from .compute_estimation import (
+    _get_sharded_shape_stride,
+    estimate_strategy_runtime_cost,
+)
 from .propagation_rules import _create_all_options
 from .utils import get_placement_options
 
@@ -488,7 +491,7 @@ def add_parameter_memory_constraint(self, memory_factor_low, memory_factor_high)
             data = self.ds[(s_i, 0, ii, 0)]
             spec = data["inp_strat"]
             tensor_shape = spec.tensor_meta.shape
-            new_tensor_shape = _get_sharded_shape(spec)
+            new_tensor_shape, _ = _get_sharded_shape_stride(spec)
             new_size = math.prod(new_tensor_shape)
             old_size = math.prod(tensor_shape)
             elms.append(data["va"] * new_size / old_size)
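
For reference, a rough standalone sketch of the per-shard size/stride heuristic these
patches introduce. The SpecSketch/TensorMetaSketch classes below are hypothetical
stand-ins for the DTensor spec objects autoparallel actually passes in; only the
ceil-division logic is meant to mirror _get_sharded_shape_stride.

# Hypothetical stand-ins for spec.tensor_meta / spec.mesh / placements;
# only the ceil-division heuristic mirrors the patch.
from dataclasses import dataclass
from typing import Optional, Tuple


@dataclass
class TensorMetaSketch:
    shape: Tuple[int, ...]
    stride: Tuple[int, ...]


@dataclass
class SpecSketch:
    tensor_meta: TensorMetaSketch
    mesh_shape: Tuple[int, ...]            # stands in for spec.mesh.shape
    shard_dims: Tuple[Optional[int], ...]  # Shard(dim) per mesh dim, None = Replicate


def sharded_shape_stride_sketch(spec: SpecSketch):
    # Same heuristic as the patched helper: for every mesh dim that shards a
    # tensor dim, ceil-divide that dim's size (and, per the patch, its stride)
    # by the mesh dim's size.
    new_shape = list(spec.tensor_meta.shape)
    new_stride = list(spec.tensor_meta.stride)
    for mesh_size, dim in zip(spec.mesh_shape, spec.shard_dims):
        if dim is not None:
            new_shape[dim] = (new_shape[dim] + mesh_size - 1) // mesh_size
            new_stride[dim] = (new_stride[dim] + mesh_size - 1) // mesh_size
    return new_shape, new_stride


# A (1024, 512) row-major tensor sharded on dim 0 over a mesh dim of size 8:
# the local size becomes (128, 512) and the dim-0 stride estimate 512 -> 64,
# which is what estimate_strategy_runtime_cost would feed to torch.empty_strided.
sizes, strides = sharded_shape_stride_sketch(
    SpecSketch(TensorMetaSketch((1024, 512), (512, 1)), (8,), (0,))
)
print(sizes, strides)  # [128, 512] [64, 1]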