autoparallel/compute_estimation.py (17 changes: 12 additions & 5 deletions)
@@ -155,7 +155,7 @@ def _get_device_tflops(dtype):
     return device_limit.gemm_tflops[dtype]


-def _get_sharded_shape(spec):
+def _get_sharded_shape_stride(spec):
     mesh = spec.mesh
     tensor_shape = spec.tensor_meta.shape
     # TODO: take dtype into account as well
@@ -164,11 +164,15 @@ def _get_sharded_shape(spec):
     # TODO: find a better heuristic other than
     # running DTensor
     new_tensor_shape = list(tensor_shape)
+    new_tensor_stride = list(spec.tensor_meta.stride)
     for mesh_size, placement in zip(mesh.shape, placements):
         if placement.is_shard():
             dim = placement.dim
             new_tensor_shape[dim] = (new_tensor_shape[dim] + mesh_size - 1) // mesh_size
-    return new_tensor_shape
+            new_tensor_stride[dim] = (
+                new_tensor_stride[dim] + mesh_size - 1
+            ) // mesh_size
+    return new_tensor_shape, new_tensor_stride


 def estimate_strategy_runtime_cost(node, strategy):
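
To make the hunk above concrete, here is a standalone sketch of the ceil-division heuristic with hypothetical toy values. The real function iterates over a DTensorSpec's mesh dimensions and placements; shard_dims below stands in for the Shard placements, and (as the function's own TODO notes) the result is an approximation rather than the exact local layout DTensor would compute.

# Toy illustration of the heuristic; all values are hypothetical.
tensor_shape = [1024, 512]   # unsharded logical shape
tensor_stride = [512, 1]     # contiguous strides
mesh_shape = [4]             # 1-D device mesh of size 4
shard_dims = [0]             # Shard(0): split dim 0 across the mesh

new_shape = list(tensor_shape)
new_stride = list(tensor_stride)
for mesh_size, dim in zip(mesh_shape, shard_dims):
    # Ceil-divide so uneven shards round up, matching the diff above.
    new_shape[dim] = (new_shape[dim] + mesh_size - 1) // mesh_size
    new_stride[dim] = (new_stride[dim] + mesh_size - 1) // mesh_size

print(new_shape, new_stride)  # [256, 512] [128, 1]

The ceil division rounds uneven shards up, so the estimate never undercounts elements on any rank.
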
@@ -191,15 +195,18 @@ def estimate_strategy_runtime_cost(node, strategy):
     if len(kwargs) > 0:
         for k, v in kwargs.items():
             assert not isinstance(v, torch.Tensor), f"{node} {v}"
-    args_shapes = tuple(_get_sharded_shape(spec) for spec in strategy.input_specs)
+    args_sizes_strides = tuple(
+        _get_sharded_shape_stride(spec) for spec in strategy.input_specs
+    )

     counter = 0
     args = list(args)
     for i, arg in enumerate(args):
         if isinstance(arg, torch.Tensor):
             with fake_mode:
-                args[i] = torch.empty(
-                    args_shapes[counter], device=arg.device, dtype=arg.dtype
-                )
+                sizes, strides = args_sizes_strides[counter]
+                args[i] = torch.empty_strided(
+                    sizes, strides, device=arg.device, dtype=arg.dtype
+                )
             counter += 1

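The switch from torch.empty to torch.empty_strided means the fake tensors handed to the cost model carry the sharded stride as well as the sharded shape, so the estimator sees the memory layout the local shard would have rather than assuming contiguity. A minimal sketch of the pattern, with hypothetical sizes and strides and CPU standing in for whatever device the real arguments use:

import torch
from torch._subclasses.fake_tensor import FakeTensorMode

# A fake tensor records shape, stride, dtype and device without
# allocating real storage, which is all the cost model needs.
fake_mode = FakeTensorMode()
sizes, strides = [256, 512], [1, 256]  # e.g. a column-major sharded input
with fake_mode:
    t = torch.empty_strided(sizes, strides, device="cpu", dtype=torch.float32)

print(t.shape, t.stride(), t.is_contiguous())
# torch.Size([256, 512]) (1, 256) False

With torch.empty alone every benchmark input would be contiguous, so strategies that leave an input non-contiguous would be costed as if the layout were free; presumably that is the motivation for this change.
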
autoparallel/optimize_sharding.py (7 changes: 5 additions & 2 deletions)
@@ -11,7 +11,10 @@
 from torch.distributed.tensor.placement_types import Replicate, Shard
 from torch.utils._pytree import tree_flatten, tree_map_only

-from .compute_estimation import _get_sharded_shape, estimate_strategy_runtime_cost
+from .compute_estimation import (
+    _get_sharded_shape_stride,
+    estimate_strategy_runtime_cost,
+)
 from .propagation_rules import _create_all_options
 from .utils import get_placement_options
@@ -488,7 +491,7 @@ def add_parameter_memory_constraint(self, memory_factor_low, memory_factor_high)
             data = self.ds[(s_i, 0, ii, 0)]
             spec = data["inp_strat"]
             tensor_shape = spec.tensor_meta.shape
-            new_tensor_shape = _get_sharded_shape(spec)
+            new_tensor_shape, _ = _get_sharded_shape_stride(spec)
             new_size = math.prod(new_tensor_shape)
             old_size = math.prod(tensor_shape)
             elms.append(data["va"] * new_size / old_size)
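In this call site only the sharded shape matters: the stride is discarded via the underscore, since the memory constraint scales each parameter by the element-count ratio between sharded and unsharded shapes. A small worked example with hypothetical numbers:

import math

# Hypothetical parameter sharded as Shard(0) on a mesh of size 4.
tensor_shape = (1024, 512)       # unsharded shape
new_tensor_shape = (256, 512)    # sharded shape from _get_sharded_shape_stride
ratio = math.prod(new_tensor_shape) / math.prod(tensor_shape)
print(ratio)  # 0.25: this option contributes a quarter of the full memory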