# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

+import itertools
import logging
from contextlib import ExitStack, contextmanager
+from typing import Optional

import torch
from torch.distributed.tensor import DTensor
-from torch.distributed.tensor._op_schema import OpSchema, StrategyType
-from torch.distributed.tensor._ops.utils import register_op_strategy
+from torch.distributed.tensor._dtensor_spec import DTensorSpec
+from torch.distributed.tensor._op_schema import (
+    OpSchema,
+    OpSpec,
+    OpStrategy,
+    StrategyType,
+)
+from torch.distributed.tensor._ops.utils import (
+    generate_redistribute_costs,
+    is_tensor_shardable,
+    register_op_strategy,
+)
+from torch.distributed.tensor.placement_types import Placement, Replicate, Shard
+
+try:
+    from torch.utils._cxx_pytree import tree_leaves
+except ImportError:
+    from torch.utils._pytree import tree_leaves  # type: ignore[no-redef]
+

logger = logging.getLogger(__name__)

@@ -39,19 +58,153 @@ def op_strategy_context(op_overload, strategy_func, schema_info=None): |
        None
    """
    propagator = DTensor._op_dispatcher.sharding_propagator
+    _origin_op_strategy_funcs = None
+    _origin_op_strategy_schema = None
    try:
-        # register the op strategy
-        register_op_strategy(op_overload, schema_info=schema_info)(strategy_func)
-        yield
-    finally:
-        # clear this op strategy cache
+        # stash any existing registration before registering the new strategy
        if op_overload in propagator.op_strategy_funcs:
+            _origin_op_strategy_funcs = propagator.op_strategy_funcs[op_overload]
            del propagator.op_strategy_funcs[op_overload]
        if op_overload in propagator.op_to_schema_info:
+            _origin_op_strategy_schema = propagator.op_to_schema_info[op_overload]
            del propagator.op_to_schema_info[op_overload]
+        # register the op strategy
+        register_op_strategy(op_overload, schema_info=schema_info)(strategy_func)
+        yield
+    finally:
+        # restore the original op strategy (if any); otherwise clear the entry
+        if _origin_op_strategy_funcs is None:
+            if op_overload in propagator.op_strategy_funcs:
+                del propagator.op_strategy_funcs[op_overload]
+        else:
+            propagator.op_strategy_funcs[op_overload] = _origin_op_strategy_funcs
+        if _origin_op_strategy_schema is None:
+            if op_overload in propagator.op_to_schema_info:
+                del propagator.op_to_schema_info[op_overload]
+        else:
+            propagator.op_to_schema_info[op_overload] = _origin_op_strategy_schema
        propagator.propagate_op_sharding.cache.cache_clear()


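As a usage sketch (not part of this diff): the context manager temporarily swaps in a custom strategy for one op and, with this change, puts back whatever was registered before instead of just dropping the entry. The op and the strategy function below are purely illustrative.

```python
# Illustrative only: temporarily override the sharding strategy for an op,
# then restore the previous registration (if any) on exit.
import torch
from torch.distributed.tensor._op_schema import OpSchema, StrategyType

def my_mul_strategy(op_schema: OpSchema) -> StrategyType:
    # hypothetical custom strategy body
    ...

with op_strategy_context(torch.ops.aten.mul.Tensor, my_mul_strategy):
    ...  # DTensor ops dispatched here use my_mul_strategy for aten.mul.Tensor
# on exit, the original strategy (or its absence) is restored and the
# sharding-propagation cache is cleared
```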
+# -------------define universal op strategy-------------
+def batch_shard_strategy(
+    op_schema: OpSchema,
+    input_shard_dim: list[Optional[int]],
+    output_shard_dim: list[Optional[int]],
+    enable_shard_batch_dim_over_multiple_axis: bool = False,
+) -> OpStrategy:
+    """
+    Shard the input tensors over the specified batch dimensions. The strategy
+    maps the batch dim of every input/output tensor to the same device mesh
+    axis (or to the same set of device axes). Every input must either have one
+    specified batch dim or no batch dim. If an input has no batch dim, the
+    strategy assumes the tensor is broadcast along the batch dim when the
+    operator runs. For inputs that do have a batch dim, the user must make
+    sure the batch dim sizes are all the same. The output should always have a
+    batch dim.
+
+    Args:
+        op_schema (OpSchema): the op schema.
+
+        input_shard_dim (list[Optional[int]]): the shard (batch) dimension to
+            consider for each input tensor argument. Use `None` if the input
+            arg has no batch dim. If an arg is a List[Tensor], it is flattened
+            first and then matched against input_shard_dim. Since the dim is
+            not tied to a specific device mesh axis, it can be mapped to any
+            combination of device axes.
+            Example 1: with input tensors A[1024, 64, 8] and B[1024, 64, 16]
+            and input_shard_dim = [1, 1], the strategy can shard A's dim 1
+            over device axis X and B's dim 1 over the same device axis X,
+            where X can be any device axis. The output follows the same
+            sharding as the input.
+            Example 2: with input tensors A[64, 8], B[64, 16, 1024], C[64, 8]
+            and input_shard_dim = [None, 2, None], the strategy replicates A
+            and C over all device axes and only shards B's dim 2 over the
+            device mesh. Assuming the device mesh has 3 axes, tensor B's
+            placement can be (Shard(2), Shard(2), Replicate()),
+            (Shard(2), Replicate(), Shard(2)), or
+            (Replicate(), Shard(2), Shard(2)).
+
+        output_shard_dim (list[Optional[int]]): the shard (batch) dimension to
+            consider for each output tensor. Use `None` if the output has no
+            batch dim. For example, if the output is a single tensor sharded
+            on dim 0, pass in [0].
+
+        enable_shard_batch_dim_over_multiple_axis (bool): if True, the
+            strategy also tries to map the batch dim to multiple device axes.
+            Default is False.
+
+    Note: It is the user's responsibility to make sure the sharded tensors
+    have the correct shape for the op to process.
+    """
+    output_type = [str(ret.type) for ret in op_schema.op._schema.returns]
+    # TODO(zpcore): Confirm whether view ops can be handled properly. Prevent
+    # handling view ops until confirmed.
+    if op_schema.op.is_view:
+        raise RuntimeError(
+            "fallback strategy is unable to handle view ops until confirmed"
+        )
+    if "List[Tensor]" in output_type:
+        raise RuntimeError(
+            "fallback strategy is unable to handle ops with List[Tensor] output "
+            "because size of the list may depend on the op's input value"
+        )
+    inputs_strategy = tree_leaves(op_schema.args_strategy)
+    assert len(inputs_strategy) == len(input_shard_dim)
+    output_strategy = OpStrategy([])
+    mesh = inputs_strategy[0].mesh
+    device_axis = list(range(mesh.ndim))
+    use_how_many_axis = (
+        [i + 1 for i in range(mesh.ndim)]
+        if enable_shard_batch_dim_over_multiple_axis
+        else [1]
+    )
+    # number of device axes to shard on for the batch dim
+    for num_axis in use_how_many_axis:
+        device_combinations = list(itertools.combinations(device_axis, num_axis))
+        # e.g., if num_axis == 2, device_combinations = [(0, 1), (0, 2), (1, 2),
+        # ...]. Then one feasible strategy is to shard the tensor dim on both
+        # axes (0, 1). We check all combinations in device_combinations below.
+        for comb in device_combinations:
+            input_specs_list: list[DTensorSpec] = []
+            output_specs_list: list[DTensorSpec] = []
+            is_shardable = True
+            for op_strategy, dim in zip(inputs_strategy, input_shard_dim):
+                # create a new list of placements for this shard-dim option
+                new_placements: list[Placement] = [Replicate()] * mesh.ndim
+                for axis in comb:
+                    new_placements[axis] = (
+                        Shard(dim) if dim is not None else Replicate()
+                    )
+                tensor_meta = op_strategy.strategies[0].output_spec.tensor_meta
+                new_input_spec = DTensorSpec(
+                    mesh,
+                    tuple(new_placements),
+                    tensor_meta=tensor_meta,
+                )
+                if not is_tensor_shardable(tensor_meta.shape, new_input_spec):
+                    is_shardable = False
+                    break
+                input_specs_list.append(new_input_spec)
+            if not is_shardable:
+                continue
+            for dim in output_shard_dim:
+                new_placements = [Replicate()] * mesh.ndim
+                for axis in comb:
+                    new_placements[axis] = (
+                        Shard(dim) if dim is not None else Replicate()
+                    )
+                output_spec = DTensorSpec(
+                    mesh,
+                    tuple(new_placements),
+                )
+                output_specs_list.append(output_spec)
+
+            output_specs = (
+                output_specs_list[0]
+                if len(output_specs_list) == 1
+                else tuple(output_specs_list)
+            )
+            input_specs = input_specs_list
+            redistribute_cost = [
+                generate_redistribute_costs(strat, input_specs_list[i])
+                for i, strat in enumerate(inputs_strategy)
+            ]
+            output_strategy.strategies.append(
+                OpSpec(output_specs, input_specs, redistribute_cost)  # type: ignore
+            )
+    return output_strategy
+
+
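A sketch of how `batch_shard_strategy` might be attached to an op, assuming a hypothetical custom op `mylib::fused_op` whose two inputs and single output are all batched on dim 0; `functools.partial` binds the dim arguments since the registered strategy function only receives the op schema. The op name and dim choices are illustrative, not part of this patch.

```python
from functools import partial

import torch

# Hypothetical op and dim choices, for illustration only.
fused_batch_strategy = partial(
    batch_shard_strategy,
    input_shard_dim=[0, 0],   # both inputs carry the batch on dim 0
    output_shard_dim=[0],     # single output carries the batch on dim 0
)

with op_strategy_context(torch.ops.mylib.fused_op.default, fused_batch_strategy):
    # within this block, sharding for mylib.fused_op is propagated by mapping
    # the batch dim of every input/output to the same device mesh axis
    ...
```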
def get_op_strategy(op: torch._ops.OpOverload, op_schema: OpSchema) -> StrategyType:
    global enable_implicit_replication, _current_stack
