2 changes: 1 addition & 1 deletion autoparallel/api.py
@@ -277,7 +277,7 @@ def build_model_graph(self):
        _replace_view_mm_view_with_einsum(gm)
        # now add alias nodes to the graph to
        # give more room for optimizations
-       _add_alias(gm, version="v1")
+       _add_alias(gm, version="v2")
        trace_structured(
            "artifact",
            metadata_fn=lambda: {
50 changes: 44 additions & 6 deletions autoparallel/collective_runtime_estimation.py
@@ -3,6 +3,8 @@
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

from typing import cast

import torch.distributed.tensor._dtensor_spec as dtensor_spec
from torch.distributed.tensor._collective_utils import (
    MeshTopoInfo,
@@ -13,21 +15,24 @@
)
from torch.distributed.tensor.placement_types import Partial, Shard

from .compute_estimation import _get_device_gmem_bandwidth


def all_to_all_cost(bytes_gb: float, mesh_topo: MeshTopoInfo, mesh_dim: int) -> float:
    num_devices_on_mesh_dim = mesh_topo.mesh_dim_devices[mesh_dim]
    mesh_dim_bandwidth = mesh_topo.mesh_dim_bandwidth[mesh_dim]
-   num_hops = num_devices_on_mesh_dim**2
+   num_hops = num_devices_on_mesh_dim - 1
    # base latency + comm latency
    latency = 6.6 + num_hops * mesh_topo.mesh_dim_latency[mesh_dim]  # us
    bw = (bytes_gb * num_hops / num_devices_on_mesh_dim) / mesh_dim_bandwidth  # s
-   return latency + bw * 1e6  # rescale to us
+   total_time = latency + bw * 1e6  # rescale to us
+   # FIXME: this is a hack; the cost model needs more careful tuning
+   total_time *= 5
+   return total_time

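For intuition, here is a minimal numeric sketch of the revised all-to-all model. The mesh parameters below are illustrative assumptions, not values taken from this PR:

# hypothetical mesh dim: 8 devices, 100 GB/s bandwidth, 0.5 us per-hop latency
num_devices = 8
bandwidth_gb_s = 100.0
hop_latency_us = 0.5
bytes_gb = 1.0  # payload in GiB

num_hops = num_devices - 1                    # ring-style hop count (was num_devices**2)
latency_us = 6.6 + num_hops * hop_latency_us  # base latency + comm latency
bw_s = (bytes_gb * num_hops / num_devices) / bandwidth_gb_s
total_us = (latency_us + bw_s * 1e6) * 5      # includes the x5 FIXME fudge factor
print(f"estimated all-to-all time: {total_us:.1f} us")  # ~43800 us for these numbers

At this payload size the bandwidth term dominates the latency term; the x5 multiplier is the acknowledged stopgap until the cost model is tuned.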

# this is a copy-paste from https://github.com/pytorch/pytorch/blob/main/torch/distributed/tensor/_collective_utils.py
# with an iteration order introduced
# TODO: this should be improved; for now we only really use the non-canonical
# order for PP -> S(0)S(0)
def redistribute_cost(
    current_spec: "dtensor_spec.DTensorSpec",
    target_spec: "dtensor_spec.DTensorSpec",
@@ -50,44 +55,77 @@ def redistribute_cost(
    if current_spec.is_replicated():
        # short-cut:
        # comm cost is 0 if the current spec is already fully replicated,
        # except if the output is partial, which doesn't make sense for us
        if any(p.is_partial() for p in target_spec.placements):
            return float("inf")
        return 0.0

    mesh_topo = MeshTopoInfo.build_from_mesh(current_spec.mesh)
    cost = 0.0
    comm_bytes_gb = (
        spec_to_bytes(current_spec) / current_spec.num_shards / 1024 / 1024 / 1024
    )
    gpu_memory_bandwidth = _get_device_gmem_bandwidth() / 1024**3  # GB/s
    # Transformations considered for redistribute cost:
    # 1. allgather 2. alltoall
    # 3. allreduce 4. reduce_scatter
    curr_placements = [current_spec.placements[i] for i in order]
    tgt_placements = [target_spec.placements[i] for i in order]

    # assume 70% efficiency for the non-collective operators
    read_write_efficiency = 0.70
    kernel_launch_overhead = 7  # us
    for i, current, target in zip(order, curr_placements, tgt_placements):
        if current == target:
            continue
        num_devices_on_mesh_dim = mesh_topo.mesh_dim_devices[i]
        if current.is_shard() and target.is_replicate():
            current = cast(Shard, current)
            # allgather gives larger comm bytes
            comm_bytes_gb *= num_devices_on_mesh_dim
            # add up allgather comm cost
            cost += allgather_cost(comm_bytes_gb, mesh_topo, i)
            if current.dim != 0:
                # penalize cases like S(1) -> R, as there is an additional compute
                # cost corresponding to reshuffling the whole output tensor.
                # we multiply the bytes by 2 because the reshuffle both reads the
                # input and writes the output
                compute_cost = comm_bytes_gb * 2 / gpu_memory_bandwidth * 1e6  # us
                compute_cost = max(
                    compute_cost / read_write_efficiency, kernel_launch_overhead
                )
                cost += compute_cost
        elif current.is_shard() and target.is_shard():
-           # should be alltoall comm; since we haven't implemented it yet, add a
-           # penalty to favor allgather instead
-           # cost += all_to_all_cost(comm_bytes_gb, mesh_topo, i)
-           cost += allgather_cost(comm_bytes_gb, mesh_topo, i) * 4.0
+           cost += all_to_all_cost(comm_bytes_gb, mesh_topo, i)  # us
        elif current.is_partial() and target.is_replicate():
            # add up allreduce comm cost
            cost += allreduce_cost(comm_bytes_gb, mesh_topo, i)
        elif current.is_partial() and target.is_shard():
            target = cast(Shard, target)
            # add up reduce_scatter comm cost
            cost += reduce_scatter_cost(comm_bytes_gb, mesh_topo, i)
            if target.dim != 0:
                # penalize cases like P -> S(1), as there is an additional compute
                # cost corresponding to reshuffling the whole input tensor.
                # we multiply the bytes by 2 because the reshuffle both reads the
                # input and writes the output
                compute_cost = comm_bytes_gb * 2 / gpu_memory_bandwidth * 1e6  # us
                compute_cost = max(
                    compute_cost / read_write_efficiency, kernel_launch_overhead
                )
                cost += compute_cost
            # after reduce_scatter the comm bytes for further collectives shrink
            # by the mesh-dim size
            comm_bytes_gb /= num_devices_on_mesh_dim
        elif current.is_shard() and target.is_partial():
            # ban shard -> partial, as it does not make sense to perform
            # this redistribution
            return float("inf")
        elif current.is_replicate() and target.is_partial():
            # ban replicate -> partial, as it does not make sense to perform
            # this redistribution in our case
            return float("inf")

    return cost

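To make the new reshuffle penalty concrete, a back-of-the-envelope sketch with illustrative numbers (the bandwidth figure below is an assumption; the real code queries _get_device_gmem_bandwidth()):

# S(1) -> R on a hypothetical GPU: the reshuffle reads and writes the gathered tensor once each
comm_bytes_gb = 2.0            # gathered tensor size, GiB
gpu_memory_bandwidth = 2000.0  # assumed HBM bandwidth, GB/s
read_write_efficiency = 0.70   # assumed efficiency of the reshuffle kernel
kernel_launch_overhead = 7     # us, floor for tiny tensors

compute_cost = comm_bytes_gb * 2 / gpu_memory_bandwidth * 1e6  # 2000 us of raw HBM traffic
compute_cost = max(compute_cost / read_write_efficiency, kernel_launch_overhead)
print(f"reshuffle penalty: {compute_cost:.0f} us")  # ~2857 us for these numbers

The max() with kernel_launch_overhead keeps the penalty from vanishing for very small tensors, where launch latency, not bandwidth, dominates.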
55 changes: 0 additions & 55 deletions autoparallel/optimize_sharding.py
@@ -531,61 +531,6 @@ def add_default_constraints(self):
        self.add_output_input_consistent_constraint()
        self.add_inf_cost_constraint()

-       self.penalize_inefficient_collectives()
-
-   def penalize_inefficient_collectives(self):
-       """
-       EFFICIENCY CONSTRAINTS (Category 5): Penalize inefficient collective operations,
-       like non-batch-dimension shard-to-replicate conversions, and forbid invalid
-       transitions.
-
-       - Shard(dim≠0) → Replicate: multiply cost by 4
-       - Replicate → Partial: x_{i,a,o,j} = 0 (forbidden)
-       - Partial → Shard(dim≠0): multiply cost by 4
-
-       When performing shard_{n} -> replicate (for n != 0), there is an additional
-       computation cost. Penalize it here while that compute cost is not yet
-       folded into the comm cost.
-       """
-       # return
-       for s_i, node in enumerate(self.graph.nodes):
-           if node.op != "call_function":
-               continue
-           tgt_op_strat = self.strats[node]
-           for counter, parent in enumerate(node.all_input_nodes):
-               curr_op_strat = self.strats[parent]
-
-               for oi, tgt_strat in enumerate(tgt_op_strat.strategies):
-                   spec = tgt_strat.input_specs[counter]
-                   if not isinstance(spec, DTensorSpec):
-                       # TODO: check if this is correct
-                       continue
-
-                   for ii, curr_strat in enumerate(curr_op_strat.strategies):
-                       curr_spec = curr_strat.output_specs
-                       if not isinstance(curr_spec, DTensorSpec):
-                           continue
-                       for tgt_plc, curr_plc in zip(
-                           spec.placements, curr_spec.placements
-                       ):
-                           if (
-                               tgt_plc.is_replicate()
-                               and curr_plc.is_shard()
-                               and curr_plc.dim != 0
-                           ):
-                               # penalize S(1) -> R as there is additional compute cost
-                               # TODO: add proper compute cost in the optimization objective
-                               self.ds[(s_i, counter, oi, ii)]["cost"] *= 4
-                           elif tgt_plc.is_partial() and curr_plc.is_replicate():
-                               # forbid the R -> P case as this doesn't make sense for us
-                               self.prob += self.ds[(s_i, counter, oi, ii)]["va"] == 0
-                           elif (
-                               tgt_plc.is_shard()
-                               and tgt_plc.dim != 0
-                               and curr_plc.is_partial()
-                           ):
-                               # penalize P -> S(1) as there is additional compute cost
-                               self.ds[(s_i, counter, oi, ii)]["cost"] *= 4

    def get_violated_constraints_log(self):
        violated_constraints = [
            (k, c) for k, c in self.prob.constraints.items() if not c.valid()
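For context, a minimal sketch of the two mechanisms the removed method relied on, assuming the solver is PuLP (which the self.prob += ... == 0 idiom suggests); all names below are illustrative:

import pulp

prob = pulp.LpProblem("sharding", pulp.LpMinimize)
x_bad = pulp.LpVariable("x_S1_to_R", cat="Binary")     # hypothetical decision variable
x_banned = pulp.LpVariable("x_R_to_P", cat="Binary")   # hypothetical decision variable

cost = 10.0
cost *= 4              # old soft penalty: inflate the objective coefficient
prob += cost * x_bad   # objective: discourage, but still allow, S(1) -> R
prob += x_banned == 0  # old hard ban: R -> P can never be selected
prob.solve(pulp.PULP_CBC_CMD(msg=False))

This PR replaces both mechanisms inside redistribute_cost itself: the ad-hoc x4 multiplier becomes an estimated reshuffle time, and the nonsensical transitions return float("inf"), so the ILP objective reflects estimated runtime rather than hand-tuned penalties.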
5 changes: 4 additions & 1 deletion examples/example_autoparallel.py
@@ -154,11 +154,14 @@ def input_fn():
if "getitem" in str(n.target):
# getitem nodes are tagged same as their parent
expected = policy_fn(None, n.args[0].target, (), ())
elif "alias" in str(n.target) and "getitem" in str(n.args[0].target):
# alias nodes that depend on getitem are tagged same as their parent
expected = policy_fn(None, n.args[0].args[0].target, (), ())
else:
expected = policy_fn(None, n.target, (), ())
actual = n.meta.get("recompute")
# NOTE: this assert only supports policy_fns on op alone
assert actual == expected
assert actual == expected, f"{n} {actual} {expected}"
seqs.add(n.meta["seq_nr"])
else:
# fwd counterpart should have already populated seqs
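A self-contained sketch of the alias-after-getitem pattern the new branch handles; the toy module below is hypothetical, chosen only because tracing it produces both node kinds:

import torch
import torch.fx as fx

class Toy(torch.nn.Module):
    def forward(self, x):
        out = torch.max(x, dim=0)  # multi-output op -> tuple-like proxy
        values = out[0]            # indexing creates an operator.getitem node
        return torch.ops.aten.alias.default(values)

gm = fx.symbolic_trace(Toy())
for n in gm.graph.nodes:
    if "alias" in str(n.target) and n.args and "getitem" in str(n.args[0].target):
        # walk alias -> getitem -> producing op, mirroring the updated check:
        # the recompute policy applies to the grandparent op, not to the alias
        print(n.args[0].args[0].target)  # torch.max

With _add_alias(gm, version="v2") the traced graph can contain such alias(getitem(...)) chains, which is why the tag lookup now walks two hops back via n.args[0].args[0].target.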