Skip to content

Commit 0208656

Browse files
committed
Fix other callsites
1 parent afa0050 commit 0208656

File tree

2 files changed

+7
-4
lines changed

2 files changed

+7
-4
lines changed

autoparallel/compute_estimation.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -155,7 +155,7 @@ def _get_device_tflops(dtype):
155155
return device_limit.gemm_tflops[dtype]
156156

157157

158-
def _get_sharded_shape(spec):
158+
def _get_sharded_shape_stride(spec):
159159
mesh = spec.mesh
160160
tensor_shape = spec.tensor_meta.shape
161161
# TODO: take dtype into account as well
@@ -196,7 +196,7 @@ def estimate_strategy_runtime_cost(node, strategy):
196196
for k, v in kwargs.items():
197197
assert not isinstance(v, torch.Tensor), f"{node} {v}"
198198
args_sizes_strides = tuple(
199-
_get_sharded_shape(spec) for spec in strategy.input_specs
199+
_get_sharded_shape_stride(spec) for spec in strategy.input_specs
200200
)
201201

202202
counter = 0

autoparallel/optimize_sharding.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,10 @@
1111
from torch.distributed.tensor.placement_types import Replicate, Shard
1212
from torch.utils._pytree import tree_flatten, tree_map_only
1313

14-
from .compute_estimation import _get_sharded_shape, estimate_strategy_runtime_cost
14+
from .compute_estimation import (
15+
_get_sharded_shape_stride,
16+
estimate_strategy_runtime_cost,
17+
)
1518
from .propagation_rules import _create_all_options
1619
from .utils import get_placement_options
1720

@@ -488,7 +491,7 @@ def add_parameter_memory_constraint(self, memory_factor_low, memory_factor_high)
488491
data = self.ds[(s_i, 0, ii, 0)]
489492
spec = data["inp_strat"]
490493
tensor_shape = spec.tensor_meta.shape
491-
new_tensor_shape = _get_sharded_shape(spec)
494+
new_tensor_shape, _ = _get_sharded_shape_stride(spec)
492495
new_size = math.prod(new_tensor_shape)
493496
old_size = math.prod(tensor_shape)
494497
elms.append(data["va"] * new_size / old_size)

0 commit comments

Comments
 (0)