Commit cd27579

Account for compute cost in collectives during redistribution (#125)
* Account for compute cost in collectives during redistribution

  This removes a long-standing hack to tell the solver that S(1) -> R is more
  expensive than S(0) -> R because of an additional data movement.

* Add better cost model to compute part

* Tweak a2a cost for now

  This needs to be improved.

* Switch to using alias_v2

* Adapt example with new alias policy
1 parent bb3252b commit cd27579
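
To make the first bullet concrete, here is a minimal standalone sketch (not part of the commit) of why S(1) -> R needs an extra data movement that S(0) -> R does not: a gather delivers the shards as contiguous blocks stacked along the leading dimension, which is already the S(0) -> R layout, while the S(1) -> R result needs the blocks interleaved along dim 1, costing one full read and write of the output tensor.

import torch

world_size = 4
full = torch.randn(8, 8)

# S(0) -> R: shards are contiguous row blocks; concatenating along dim 0
# matches how the gather already lays out its output buffer, so no extra
# pass over the data is needed.
s0_shards = list(full.chunk(world_size, dim=0))
r_from_s0 = torch.cat(s0_shards, dim=0)

# S(1) -> R: the collective still delivers shards as contiguous blocks, so
# interleaving them along dim 1 costs one extra read+write of the whole
# output tensor, which is the compute cost this commit adds to the model.
s1_shards = list(full.chunk(world_size, dim=1))
r_from_s1 = torch.cat(s1_shards, dim=1)

assert torch.equal(r_from_s0, full) and torch.equal(r_from_s1, full)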

File tree: 4 files changed (+49, -63 lines)


autoparallel/api.py

Lines changed: 1 addition & 1 deletion
@@ -277,7 +277,7 @@ def build_model_graph(self):
         _replace_view_mm_view_with_einsum(gm)
         # now add aliases nodes to the graph to
         # give more room for optimizations
-        _add_alias(gm, version="v1")
+        _add_alias(gm, version="v2")
         trace_structured(
             "artifact",
             metadata_fn=lambda: {

autoparallel/collective_runtime_estimation.py

Lines changed: 44 additions & 6 deletions
@@ -3,6 +3,8 @@
 # This source code is licensed under the BSD license found in the
 # LICENSE file in the root directory of this source tree.
 
+from typing import cast
+
 import torch.distributed.tensor._dtensor_spec as dtensor_spec
 from torch.distributed.tensor._collective_utils import (
     MeshTopoInfo,
@@ -13,21 +15,24 @@
 )
 from torch.distributed.tensor.placement_types import Partial, Shard
 
+from .compute_estimation import _get_device_gmem_bandwidth
+
 
 def all_to_all_cost(bytes_gb: float, mesh_topo: MeshTopoInfo, mesh_dim: int) -> float:
     num_devices_on_mesh_dim = mesh_topo.mesh_dim_devices[mesh_dim]
     mesh_dim_bandwidth = mesh_topo.mesh_dim_bandwidth[mesh_dim]
-    num_hops = num_devices_on_mesh_dim**2
+    num_hops = num_devices_on_mesh_dim - 1
     # base latency + comm latency
     latency = 6.6 + num_hops * mesh_topo.mesh_dim_latency[mesh_dim]  # us
     bw = (bytes_gb * num_hops / num_devices_on_mesh_dim) / mesh_dim_bandwidth  # s
-    return latency + bw * 1e6  # rescale to us
+    total_time = latency + bw * 1e6  # rescale to us
+    # FIXME: this is a hack, we need to spend some more effort on the cost model
+    total_time *= 5
+    return total_time
 
 
 # this is a copy-paste from https://github.com/pytorch/pytorch/blob/main/torch/distributed/tensor/_collective_utils.py
 # with iteration order introduced
-# TODO: this should be improved, as we just really use the non-canonical order for
-# PP->S(0)S(0) for now
 def redistribute_cost(
     current_spec: "dtensor_spec.DTensorSpec",
     target_spec: "dtensor_spec.DTensorSpec",
@@ -50,44 +55,77 @@ def redistribute_cost(
     if current_spec.is_replicated():
         # short-cut:
         # comm cost is 0 if current spec is already full replication
+        # except if output is partial, which doesn't make sense for us
+        if any(p.is_partial() for p in target_spec.placements):
+            return float("inf")
         return 0.0
 
     mesh_topo = MeshTopoInfo.build_from_mesh(current_spec.mesh)
     cost = 0.0
     comm_bytes_gb = (
         spec_to_bytes(current_spec) / current_spec.num_shards / 1024 / 1024 / 1024
     )
+    gpu_memory_bandwidth = _get_device_gmem_bandwidth() / 1024**3  # GB/s
     # Transformation that considered for redistribute cost:
     # 1. allgather 2. alltoall
     # 3. allreduce 4. reduce_scatter
     curr_placements = [current_spec.placements[i] for i in order]
     tgt_placements = [target_spec.placements[i] for i in order]
+
+    # suppose 70% efficiency for the non-collective operators
+    read_write_efficiency = 0.70
+    kernel_launch_overhead = 7  # us
     for i, current, target in zip(order, curr_placements, tgt_placements):
         if current == target:
             continue
         num_devices_on_mesh_dim = mesh_topo.mesh_dim_devices[i]
         if current.is_shard() and target.is_replicate():
+            current = cast(Shard, current)
             # allgather gives larger comm bytes
             comm_bytes_gb *= num_devices_on_mesh_dim
             # add up allgather comm cost
             cost += allgather_cost(comm_bytes_gb, mesh_topo, i)
+            if current.dim != 0:
+                # penalize cases like S(1) -> R as there are additional compute cost
+                # which corresponds to reshuffling the whole output tensor
+                # we multiply the cost by 2 because we need to count input and output
+                # reads for the reshuffle
+                compute_cost = comm_bytes_gb * 2 / gpu_memory_bandwidth * 1e6  # us
+                compute_cost = max(
+                    compute_cost / read_write_efficiency, kernel_launch_overhead
+                )
+                cost += compute_cost
         elif current.is_shard() and target.is_shard():
             # should be alltoall comm, since we haven't implement it yet, add penalty
             # to favor allgather instead
-            # cost += all_to_all_cost(comm_bytes_gb, mesh_topo, i)
-            cost += allgather_cost(comm_bytes_gb, mesh_topo, i) * 4.0
+            cost += all_to_all_cost(comm_bytes_gb, mesh_topo, i)  # us
         elif current.is_partial() and target.is_replicate():
             # add up allreduce comm cost
             cost += allreduce_cost(comm_bytes_gb, mesh_topo, i)
         elif current.is_partial() and target.is_shard():
+            target = cast(Shard, target)
             # add up reduce_scatter comm cost
             cost += reduce_scatter_cost(comm_bytes_gb, mesh_topo, i)
+            if target.dim != 0:
+                # penalize cases like P -> S(1) as there are additional compute cost
+                # which corresponds to reshuffling the whole input tensor
+                # we multiply the cost by 2 because we need to count input and output
+                # reads for the reshuffle
+                compute_cost = comm_bytes_gb * 2 / gpu_memory_bandwidth * 1e6  # us
+                compute_cost = max(
+                    compute_cost / read_write_efficiency, kernel_launch_overhead
+                )
+                cost += compute_cost
             # after reduce_scatter the comm bytes for further collectives halved.
             comm_bytes_gb /= num_devices_on_mesh_dim
         elif current.is_shard() and target.is_partial():
             # ban shard -> partial as it does not make sense to perform
             # this redistribute
             return float("inf")
+        elif current.is_replicate() and target.is_partial():
+            # ban replicate -> partial as it does not make sense to perform
+            # this redistribute in our case
+            return float("inf")
 
     return cost
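
The new compute term is identical in both branches above; the sketch below pulls it out as a standalone function with an assumed 2000 GB/s memory bandwidth (an H100-class figure; the real code queries the device via _get_device_gmem_bandwidth).

def reshuffle_cost_us(comm_bytes_gb, gmem_bandwidth_gbps=2000.0):
    read_write_efficiency = 0.70  # assume 70% of peak for the copy kernel
    kernel_launch_overhead = 7.0  # us; floor so tiny tensors are not free
    # factor 2: the reshuffle reads the whole tensor once and writes it once
    compute_cost = comm_bytes_gb * 2 / gmem_bandwidth_gbps * 1e6  # us
    return max(compute_cost / read_write_efficiency, kernel_launch_overhead)

print(reshuffle_cost_us(1.0))    # ~1428.6 us for a 1 GB gathered tensor
print(reshuffle_cost_us(0.001))  # ~1.4 us raw, clamped to the 7 us overhead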

autoparallel/optimize_sharding.py

Lines changed: 0 additions & 55 deletions
@@ -531,61 +531,6 @@ def add_default_constraints(self):
         self.add_output_input_consistent_constraint()
         self.add_inf_cost_constraint()
 
-        self.penalize_inefficient_collectives()
-
-    def penalize_inefficient_collectives(self):
-        """
-        EFFICIENCY CONSTRAINTS (Category 5): Penalize inefficient collective operations like
-        non-batch dimension shard-to-replicate conversions and forbid invalid transitions.
-
-        - Shard(dim≠0) → Replicate: multiply cost by 4
-        - Replicate → Partial: x_{i,a,o,j} = 0 (forbidden)
-        - Partial → Shard(dim≠0): multiply cost by 4
-
-        When performing shard_{n} -> replicate (for n != 0), there is additional
-        computation cost associated. Let's penalize it here while we don't add
-        the computation cost together in the comm cost
-        """
-        # return
-        for s_i, node in enumerate(self.graph.nodes):
-            if node.op != "call_function":
-                continue
-            tgt_op_strat = self.strats[node]
-            for counter, parent in enumerate(node.all_input_nodes):
-                curr_op_strat = self.strats[parent]
-
-                for oi, tgt_strat in enumerate(tgt_op_strat.strategies):
-                    spec = tgt_strat.input_specs[counter]
-                    if not isinstance(spec, DTensorSpec):
-                        # TODO: check if this is correct
-                        continue
-
-                    for ii, curr_strat in enumerate(curr_op_strat.strategies):
-                        curr_spec = curr_strat.output_specs
-                        if not isinstance(curr_spec, DTensorSpec):
-                            continue
-                        for tgt_plc, curr_plc in zip(
-                            spec.placements, curr_spec.placements
-                        ):
-                            if (
-                                tgt_plc.is_replicate()
-                                and curr_plc.is_shard()
-                                and curr_plc.dim != 0
-                            ):
-                                # penalize case S(1) -> R as there are additional compute cost
-                                # TODO: add proper compute cost in the optimization objective
-                                self.ds[(s_i, counter, oi, ii)]["cost"] *= 4
-                            elif tgt_plc.is_partial() and curr_plc.is_replicate():
-                                # forbit R -> P case as this doesn't make sense for us
-                                self.prob += self.ds[(s_i, counter, oi, ii)]["va"] == 0
-                            elif (
-                                tgt_plc.is_shard()
-                                and tgt_plc.dim != 0
-                                and curr_plc.is_partial()
-                            ):
-                                # penalize case P -> S(1) as there are additional compute cost
-                                self.ds[(s_i, counter, oi, ii)]["cost"] *= 4
-
     def get_violated_constraints_log(self):
         violated_constraints = [
             (k, c) for k, c in self.prob.constraints.items() if not c.valid()
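
For contrast, a hypothetical before/after of how an S(1) -> R edge is priced once this solver-side penalty is gone; the numbers are illustrative only, with the compute term taken from the sketch shown earlier.

comm_cost_us = 500.0   # say, allgather time for some S(1) -> R edge
reshuffle_us = 1428.6  # compute term for a 1 GB tensor (sketch above)

old_edge_cost = comm_cost_us * 4             # removed: blanket 4x multiplier
new_edge_cost = comm_cost_us + reshuffle_us  # added: size-dependent compute term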

examples/example_autoparallel.py

Lines changed: 4 additions & 1 deletion
@@ -154,11 +154,14 @@ def input_fn():
             if "getitem" in str(n.target):
                 # getitem nodes are tagged same as their parent
                 expected = policy_fn(None, n.args[0].target, (), ())
+            elif "alias" in str(n.target) and "getitem" in str(n.args[0].target):
+                # alias nodes that depend on getitem are tagged same as their parent
+                expected = policy_fn(None, n.args[0].args[0].target, (), ())
             else:
                 expected = policy_fn(None, n.target, (), ())
             actual = n.meta.get("recompute")
             # NOTE: this assert only supports policy_fns on op alone
-            assert actual == expected
+            assert actual == expected, f"{n} {actual} {expected}"
             seqs.add(n.meta["seq_nr"])
         else:
             # fwd counterpart should have already populated seqs
