# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

+from typing import cast
+
import torch.distributed.tensor._dtensor_spec as dtensor_spec
from torch.distributed.tensor._collective_utils import (
    MeshTopoInfo,
)
from torch.distributed.tensor.placement_types import Partial, Shard

+from .compute_estimation import _get_device_gmem_bandwidth
+


def all_to_all_cost(bytes_gb: float, mesh_topo: MeshTopoInfo, mesh_dim: int) -> float:
    num_devices_on_mesh_dim = mesh_topo.mesh_dim_devices[mesh_dim]
    mesh_dim_bandwidth = mesh_topo.mesh_dim_bandwidth[mesh_dim]
-    num_hops = num_devices_on_mesh_dim**2
+    num_hops = num_devices_on_mesh_dim - 1
    # base latency + comm latency
    latency = 6.6 + num_hops * mesh_topo.mesh_dim_latency[mesh_dim]  # us
    bw = (bytes_gb * num_hops / num_devices_on_mesh_dim) / mesh_dim_bandwidth  # s
-    return latency + bw * 1e6  # rescale to us
+    total_time = latency + bw * 1e6  # rescale to us
+    # FIXME: this is a hack, we need to spend some more effort on the cost model
+    total_time *= 5
+    return total_time

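# A rough worked example of the updated model above, with assumed (not measured)
# numbers: for bytes_gb = 1 on a mesh dim of 8 devices, num_hops = 7; with a
# hypothetical mesh_dim_bandwidth of 100 GB/s and mesh_dim_latency of 0.5 us,
#   latency ≈ 6.6 + 7 * 0.5 ≈ 10.1 us
#   bw      ≈ (1 * 7 / 8) / 100 ≈ 8.75 ms
# i.e. ~8.76 ms before the 5x fudge factor and ~43.8 ms after it.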

# this is a copy-paste from https://github.com/pytorch/pytorch/blob/main/torch/distributed/tensor/_collective_utils.py
# with iteration order introduced
-# TODO: this should be improved, as we just really use the non-canonical order for
-# PP->S(0)S(0) for now
def redistribute_cost(
    current_spec: "dtensor_spec.DTensorSpec",
    target_spec: "dtensor_spec.DTensorSpec",
@@ -50,44 +55,77 @@ def redistribute_cost(
    if current_spec.is_replicated():
        # short-cut:
        # comm cost is 0 if current spec is already full replication
+        # except if output is partial, which doesn't make sense for us
+        if any(p.is_partial() for p in target_spec.placements):
+            return float("inf")
        return 0.0

    mesh_topo = MeshTopoInfo.build_from_mesh(current_spec.mesh)
    cost = 0.0
    comm_bytes_gb = (
        spec_to_bytes(current_spec) / current_spec.num_shards / 1024 / 1024 / 1024
    )
+    gpu_memory_bandwidth = _get_device_gmem_bandwidth() / 1024**3  # GB/s
    # Transformations considered for redistribute cost:
    # 1. allgather 2. alltoall
    # 3. allreduce 4. reduce_scatter
    curr_placements = [current_spec.placements[i] for i in order]
    tgt_placements = [target_spec.placements[i] for i in order]
+
+    # assume 70% efficiency for the non-collective operators
+    read_write_efficiency = 0.70
+    kernel_launch_overhead = 7  # us
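    # (the reshuffle penalty applied in the loop below can be read as
    #  max(2 * bytes / gpu_memory_bandwidth / read_write_efficiency, kernel_launch_overhead),
    #  i.e. one memory-bandwidth-bound pass over input + output, floored at a kernel launch)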
    for i, current, target in zip(order, curr_placements, tgt_placements):
        if current == target:
            continue
        num_devices_on_mesh_dim = mesh_topo.mesh_dim_devices[i]
        if current.is_shard() and target.is_replicate():
+            current = cast(Shard, current)
            # allgather gives larger comm bytes
            comm_bytes_gb *= num_devices_on_mesh_dim
            # add up allgather comm cost
            cost += allgather_cost(comm_bytes_gb, mesh_topo, i)
+            if current.dim != 0:
+                # penalize cases like S(1) -> R, as there is additional compute cost
+                # corresponding to reshuffling the whole output tensor
+                # we multiply the cost by 2 because we need to count both the input
+                # and output reads/writes of the reshuffle
+                compute_cost = comm_bytes_gb * 2 / gpu_memory_bandwidth * 1e6  # us
+                compute_cost = max(
+                    compute_cost / read_write_efficiency, kernel_launch_overhead
+                )
+                cost += compute_cost
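            # (with assumed numbers: ~1 GB gathered at ~1.5 TB/s device memory
            #  bandwidth, the reshuffle term is ~1 * 2 / 1500 * 1e6 ≈ 1.3 ms, or
            #  ~1.9 ms at 70% efficiency, well above the 7 us launch overhead)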
        elif current.is_shard() and target.is_shard():
            # should be alltoall comm; since we haven't implemented it yet, add a penalty
            # to favor allgather instead
-            # cost += all_to_all_cost(comm_bytes_gb, mesh_topo, i)
-            cost += allgather_cost(comm_bytes_gb, mesh_topo, i) * 4.0
+            cost += all_to_all_cost(comm_bytes_gb, mesh_topo, i)  # us
        elif current.is_partial() and target.is_replicate():
            # add up allreduce comm cost
            cost += allreduce_cost(comm_bytes_gb, mesh_topo, i)
        elif current.is_partial() and target.is_shard():
+            target = cast(Shard, target)
            # add up reduce_scatter comm cost
            cost += reduce_scatter_cost(comm_bytes_gb, mesh_topo, i)
+            if target.dim != 0:
+                # penalize cases like P -> S(1), as there is additional compute cost
+                # corresponding to reshuffling the whole input tensor
+                # we multiply the cost by 2 because we need to count both the input
+                # and output reads/writes of the reshuffle
+                compute_cost = comm_bytes_gb * 2 / gpu_memory_bandwidth * 1e6  # us
+                compute_cost = max(
+                    compute_cost / read_write_efficiency, kernel_launch_overhead
+                )
+                cost += compute_cost
            # after reduce_scatter the comm bytes for further collectives are halved.
            comm_bytes_gb /= num_devices_on_mesh_dim
        elif current.is_shard() and target.is_partial():
            # ban shard -> partial as it does not make sense to perform
            # this redistribute
            return float("inf")
+        elif current.is_replicate() and target.is_partial():
+            # ban replicate -> partial as it does not make sense to perform
+            # this redistribute in our case
+            return float("inf")

    return cost

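# A minimal usage sketch with assumed specifics (a 1-D 8-GPU mesh and a tensor_meta
# built elsewhere, Replicate imported from placement_types, and the elided part of
# the signature above taking the mesh-dim iteration order as its last argument):
#
#   s0 = dtensor_spec.DTensorSpec(mesh, (Shard(0),), tensor_meta=meta)
#   s1 = dtensor_spec.DTensorSpec(mesh, (Shard(1),), tensor_meta=meta)
#   rep = dtensor_spec.DTensorSpec(mesh, (Replicate(),), tensor_meta=meta)
#   par = dtensor_spec.DTensorSpec(mesh, (Partial(),), tensor_meta=meta)
#
#   redistribute_cost(s1, rep, [0]) > redistribute_cost(s0, rep, [0])  # S(1) -> R pays the reshuffle penalty
#   redistribute_cost(rep, par, [0]) == float("inf")  # R -> P is rejected outright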