diff --git a/autoparallel/api.py b/autoparallel/api.py
index 1310ce20..3b766f45 100644
--- a/autoparallel/api.py
+++ b/autoparallel/api.py
@@ -277,7 +277,7 @@ def build_model_graph(self):
         _replace_view_mm_view_with_einsum(gm)
         # now add aliases nodes to the graph to
         # give more room for optimizations
-        _add_alias(gm, version="v1")
+        _add_alias(gm, version="v2")
         trace_structured(
             "artifact",
             metadata_fn=lambda: {
diff --git a/autoparallel/collective_runtime_estimation.py b/autoparallel/collective_runtime_estimation.py
index 2023c12b..901a883a 100644
--- a/autoparallel/collective_runtime_estimation.py
+++ b/autoparallel/collective_runtime_estimation.py
@@ -3,6 +3,8 @@
 # This source code is licensed under the BSD license found in the
 # LICENSE file in the root directory of this source tree.
 
+from typing import cast
+
 import torch.distributed.tensor._dtensor_spec as dtensor_spec
 from torch.distributed.tensor._collective_utils import (
     MeshTopoInfo,
@@ -13,21 +15,24 @@
 )
 from torch.distributed.tensor.placement_types import Partial, Shard
 
+from .compute_estimation import _get_device_gmem_bandwidth
+
 
 def all_to_all_cost(bytes_gb: float, mesh_topo: MeshTopoInfo, mesh_dim: int) -> float:
     num_devices_on_mesh_dim = mesh_topo.mesh_dim_devices[mesh_dim]
     mesh_dim_bandwidth = mesh_topo.mesh_dim_bandwidth[mesh_dim]
-    num_hops = num_devices_on_mesh_dim**2
+    num_hops = num_devices_on_mesh_dim - 1
     # base latency + comm latency
     latency = 6.6 + num_hops * mesh_topo.mesh_dim_latency[mesh_dim]  # us
     bw = (bytes_gb * num_hops / num_devices_on_mesh_dim) / mesh_dim_bandwidth  # s
-    return latency + bw * 1e6  # rescale to us
+    total_time = latency + bw * 1e6  # rescale to us
+    # FIXME: this is a hack, we need to spend some more effort on the cost model
+    total_time *= 5
+    return total_time
 
 
 # this is a copy-paste from https://github.com/pytorch/pytorch/blob/main/torch/distributed/tensor/_collective_utils.py
 # with iteration order introduced
-# TODO: this should be improved, as we just really use the non-canonical order for
-# PP->S(0)S(0) for now
 def redistribute_cost(
     current_spec: "dtensor_spec.DTensorSpec",
     target_spec: "dtensor_spec.DTensorSpec",
@@ -50,6 +55,9 @@ def redistribute_cost(
     if current_spec.is_replicated():
         # short-cut:
         # comm cost is 0 if current spec is already full replication
+        # except if the target is partial, which doesn't make sense for us
+        if any(p.is_partial() for p in target_spec.placements):
+            return float("inf")
         return 0.0
 
     mesh_topo = MeshTopoInfo.build_from_mesh(current_spec.mesh)
@@ -57,37 +65,67 @@ def redistribute_cost(
     comm_bytes_gb = (
         spec_to_bytes(current_spec) / current_spec.num_shards / 1024 / 1024 / 1024
     )
+    gpu_memory_bandwidth = _get_device_gmem_bandwidth() / 1024**3  # GB/s
     # Transformation that considered for redistribute cost:
     # 1. allgather 2. alltoall
     # 3. allreduce 4. reduce_scatter
     curr_placements = [current_spec.placements[i] for i in order]
     tgt_placements = [target_spec.placements[i] for i in order]
+
+    # assume 70% efficiency for the non-collective operators
+    read_write_efficiency = 0.70
+    kernel_launch_overhead = 7  # us
     for i, current, target in zip(order, curr_placements, tgt_placements):
         if current == target:
             continue
         num_devices_on_mesh_dim = mesh_topo.mesh_dim_devices[i]
         if current.is_shard() and target.is_replicate():
+            current = cast(Shard, current)
             # allgather gives larger comm bytes
             comm_bytes_gb *= num_devices_on_mesh_dim
             # add up allgather comm cost
             cost += allgather_cost(comm_bytes_gb, mesh_topo, i)
+            if current.dim != 0:
+                # penalize cases like S(1) -> R: there is an additional compute cost
+                # that corresponds to reshuffling the whole output tensor.
+                # We multiply the cost by 2 because the reshuffle reads the input
+                # and writes the output
+                compute_cost = comm_bytes_gb * 2 / gpu_memory_bandwidth * 1e6  # us
+                compute_cost = max(
+                    compute_cost / read_write_efficiency, kernel_launch_overhead
+                )
+                cost += compute_cost
         elif current.is_shard() and target.is_shard():
             # should be alltoall comm, since we haven't implement it yet, add penalty
             # to favor allgather instead
-            # cost += all_to_all_cost(comm_bytes_gb, mesh_topo, i)
-            cost += allgather_cost(comm_bytes_gb, mesh_topo, i) * 4.0
+            cost += all_to_all_cost(comm_bytes_gb, mesh_topo, i)  # us
         elif current.is_partial() and target.is_replicate():
             # add up allreduce comm cost
             cost += allreduce_cost(comm_bytes_gb, mesh_topo, i)
         elif current.is_partial() and target.is_shard():
+            target = cast(Shard, target)
             # add up reduce_scatter comm cost
             cost += reduce_scatter_cost(comm_bytes_gb, mesh_topo, i)
+            if target.dim != 0:
+                # penalize cases like P -> S(1): there is an additional compute cost
+                # that corresponds to reshuffling the whole input tensor.
+                # We multiply the cost by 2 because the reshuffle reads the input
+                # and writes the output
+                compute_cost = comm_bytes_gb * 2 / gpu_memory_bandwidth * 1e6  # us
+                compute_cost = max(
+                    compute_cost / read_write_efficiency, kernel_launch_overhead
+                )
+                cost += compute_cost
             # after reduce_scatter the comm bytes for further collectives halved.
             comm_bytes_gb /= num_devices_on_mesh_dim
         elif current.is_shard() and target.is_partial():
             # ban shard -> partial as it does not make sense to perform
             # this redistribute
             return float("inf")
+        elif current.is_replicate() and target.is_partial():
+            # ban replicate -> partial as it does not make sense to perform
+            # this redistribute in our case
+            return float("inf")
 
     return cost
 
diff --git a/autoparallel/optimize_sharding.py b/autoparallel/optimize_sharding.py
index 3aa97dbd..dd828434 100644
--- a/autoparallel/optimize_sharding.py
+++ b/autoparallel/optimize_sharding.py
@@ -531,61 +531,6 @@ def add_default_constraints(self):
         self.add_output_input_consistent_constraint()
         self.add_inf_cost_constraint()
-        self.penalize_inefficient_collectives()
-
-    def penalize_inefficient_collectives(self):
-        """
-        EFFICIENCY CONSTRAINTS (Category 5): Penalize inefficient collective operations like
-        non-batch dimension shard-to-replicate conversions and forbid invalid transitions.
-
-        - Shard(dim≠0) → Replicate: multiply cost by 4
-        - Replicate → Partial: x_{i,a,o,j} = 0 (forbidden)
-        - Partial → Shard(dim≠0): multiply cost by 4
-
-        When performing shard_{n} -> replicate (for n != 0), there is additional
-        computation cost associated. Let's penalize it here while we don't add
-        the computation cost together in the comm cost
-        """
-        # return
-        for s_i, node in enumerate(self.graph.nodes):
-            if node.op != "call_function":
-                continue
-            tgt_op_strat = self.strats[node]
-            for counter, parent in enumerate(node.all_input_nodes):
-                curr_op_strat = self.strats[parent]
-
-                for oi, tgt_strat in enumerate(tgt_op_strat.strategies):
-                    spec = tgt_strat.input_specs[counter]
-                    if not isinstance(spec, DTensorSpec):
-                        # TODO: check if this is correct
-                        continue
-
-                    for ii, curr_strat in enumerate(curr_op_strat.strategies):
-                        curr_spec = curr_strat.output_specs
-                        if not isinstance(curr_spec, DTensorSpec):
-                            continue
-                        for tgt_plc, curr_plc in zip(
-                            spec.placements, curr_spec.placements
-                        ):
-                            if (
-                                tgt_plc.is_replicate()
-                                and curr_plc.is_shard()
-                                and curr_plc.dim != 0
-                            ):
-                                # penalize case S(1) -> R as there are additional compute cost
-                                # TODO: add proper compute cost in the optimization objective
-                                self.ds[(s_i, counter, oi, ii)]["cost"] *= 4
-                            elif tgt_plc.is_partial() and curr_plc.is_replicate():
-                                # forbit R -> P case as this doesn't make sense for us
-                                self.prob += self.ds[(s_i, counter, oi, ii)]["va"] == 0
-                            elif (
-                                tgt_plc.is_shard()
-                                and tgt_plc.dim != 0
-                                and curr_plc.is_partial()
-                            ):
-                                # penalize case P -> S(1) as there are additional compute cost
-                                self.ds[(s_i, counter, oi, ii)]["cost"] *= 4
-
 
     def get_violated_constraints_log(self):
         violated_constraints = [
             (k, c) for k, c in self.prob.constraints.items() if not c.valid()
diff --git a/examples/example_autoparallel.py b/examples/example_autoparallel.py
index 5826b17e..9932ca1b 100644
--- a/examples/example_autoparallel.py
+++ b/examples/example_autoparallel.py
@@ -154,11 +154,14 @@ def input_fn():
             if "getitem" in str(n.target):
                 # getitem nodes are tagged same as their parent
                 expected = policy_fn(None, n.args[0].target, (), ())
+            elif "alias" in str(n.target) and "getitem" in str(n.args[0].target):
+                # alias nodes on top of getitem are tagged the same as the getitem's parent
+                expected = policy_fn(None, n.args[0].args[0].target, (), ())
             else:
                 expected = policy_fn(None, n.target, (), ())
             actual = n.meta.get("recompute")
             # NOTE: this assert only supports policy_fns on op alone
-            assert actual == expected
+            assert actual == expected, f"{n} {actual} {expected}"
             seqs.add(n.meta["seq_nr"])
         else:
             # fwd counterpart should have already populated seqs
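For reference, the penalty terms added to `redistribute_cost` model one full read of the input plus one full write of the output at 70% of peak device memory bandwidth, with a 7 us kernel-launch floor. Below is a minimal sketch of that arithmetic (not part of the patch), assuming a hypothetical 1 GiB shard and a hypothetical ~2 TiB/s of HBM bandwidth:

```python
# Illustrative only: mirrors the reshuffle penalty added in redistribute_cost
# for non-dim-0 cases such as S(1) -> R or P -> S(1).
comm_bytes_gb = 1.0            # assumed shard size in GiB (hypothetical)
gpu_memory_bandwidth = 2048.0  # assumed device memory bandwidth in GiB/s (hypothetical)
read_write_efficiency = 0.70   # same efficiency factor as the patch
kernel_launch_overhead = 7     # us, same floor as the patch

# factor of 2: the reshuffle reads the input and writes the output
compute_cost = comm_bytes_gb * 2 / gpu_memory_bandwidth * 1e6  # us
compute_cost = max(compute_cost / read_write_efficiency, kernel_launch_overhead)
print(f"extra reshuffle cost: {compute_cost:.0f} us")  # ~1395 us for these numbers
```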