From cc28a951d9c200c45ac8e672be249ad860d28f80 Mon Sep 17 00:00:00 2001
From: zpcore
Date: Tue, 29 Jul 2025 11:21:40 -0700
Subject: [PATCH] Introduce the batch sharding strategy

---
 autoparallel/dtensor_util/__init__.py |   2 +
 autoparallel/dtensor_util/utils.py    | 165 +++++++++++++++++++++++++-
 tests/test_dtensor.py                 | 112 ++++++++++++++++-
 3 files changed, 271 insertions(+), 8 deletions(-)

diff --git a/autoparallel/dtensor_util/__init__.py b/autoparallel/dtensor_util/__init__.py
index f7cdbb63..d378a3d5 100644
--- a/autoparallel/dtensor_util/__init__.py
+++ b/autoparallel/dtensor_util/__init__.py
@@ -5,6 +5,7 @@
 # functions to expose
 from .utils import (
+    batch_shard_strategy,
     get_op_strategy,
     op_strategy_context,
     replicate_op_strategy,
@@ -13,6 +14,7 @@
 __all__ = [
     "replicate_op_strategy",
+    "batch_shard_strategy",
     "get_op_strategy",
     "with_implicit_strategies",
     "op_strategy_context",
diff --git a/autoparallel/dtensor_util/utils.py b/autoparallel/dtensor_util/utils.py
index 3d0c5d83..3341e2e9 100644
--- a/autoparallel/dtensor_util/utils.py
+++ b/autoparallel/dtensor_util/utils.py
@@ -3,13 +3,32 @@
 # This source code is licensed under the BSD license found in the
 # LICENSE file in the root directory of this source tree.
 
+import itertools
 import logging
 from contextlib import ExitStack, contextmanager
+from typing import Optional
 
 import torch
 from torch.distributed.tensor import DTensor
-from torch.distributed.tensor._op_schema import OpSchema, StrategyType
-from torch.distributed.tensor._ops.utils import register_op_strategy
+from torch.distributed.tensor._dtensor_spec import DTensorSpec
+from torch.distributed.tensor._op_schema import (
+    OpSchema,
+    OpSpec,
+    OpStrategy,
+    StrategyType,
+)
+from torch.distributed.tensor._ops.utils import (
+    generate_redistribute_costs,
+    is_tensor_shardable,
+    register_op_strategy,
+)
+from torch.distributed.tensor.placement_types import Placement, Replicate, Shard
+
+try:
+    from torch.utils._cxx_pytree import tree_leaves
+except ImportError:
+    from torch.utils._pytree import tree_leaves  # type: ignore[no-redef]
+
 
 logger = logging.getLogger(__name__)
 
@@ -39,19 +58,153 @@ def op_strategy_context(op_overload, strategy_func, schema_info=None):
         None
     """
     propagator = DTensor._op_dispatcher.sharding_propagator
+    _origin_op_strategy_funcs = None
+    _origin_op_strategy_schema = None
     try:
         # register the op strategy
-        register_op_strategy(op_overload, schema_info=schema_info)(strategy_func)
-        yield
-    finally:
-        # clear this op strategy cache
         if op_overload in propagator.op_strategy_funcs:
+            _origin_op_strategy_funcs = propagator.op_strategy_funcs[op_overload]
             del propagator.op_strategy_funcs[op_overload]
         if op_overload in propagator.op_to_schema_info:
+            _origin_op_strategy_schema = propagator.op_to_schema_info[op_overload]
             del propagator.op_to_schema_info[op_overload]
+        register_op_strategy(op_overload, schema_info=schema_info)(strategy_func)
+        yield
+    finally:
+        # restore the original op strategy (if any); otherwise clear the cache
+        if _origin_op_strategy_funcs is None:
+            if op_overload in propagator.op_strategy_funcs:
+                del propagator.op_strategy_funcs[op_overload]
+        else:
+            propagator.op_strategy_funcs[op_overload] = _origin_op_strategy_funcs
+        if _origin_op_strategy_schema is None:
+            if op_overload in propagator.op_to_schema_info:
+                del propagator.op_to_schema_info[op_overload]
+        else:
+            propagator.op_to_schema_info[op_overload] = _origin_op_strategy_schema
         propagator.propagate_op_sharding.cache.cache_clear()
 
 
+# -------------define universal op strategy-------------
+def batch_shard_strategy(
+    op_schema: OpSchema,
+    input_shard_dim: list[Optional[int]],
+    output_shard_dim: list[Optional[int]],
+    enable_shard_batch_dim_over_multiple_axis: bool = False,
+) -> OpStrategy:
+    """
+    Shard the input tensors over the specified batch dimensions. The strategy
+    maps the batch dim of every input/output tensor to the same device mesh
+    axis (or the same set of device axes). Every input must either have exactly
+    one specified batch dim or no batch dim. If an input has no batch dim, the
+    strategy assumes the operator broadcasts that tensor along the batch dim.
+    For inputs that do have a batch dim, the user needs to make sure the batch
+    dim sizes are all the same. The output should always have a batch dim.
+
+    Args:
+        op_schema (OpSchema): the op schema.
+
+        input_shard_dim (list[Optional[int]]): the batch (shard) dimension of
+        each input tensor argument. Use `None` if the input arg has no batch
+        dim. If an arg is a List[Tensor], we flatten it first and then match it
+        against input_shard_dim. Since the batch dim is not tied to a specific
+        device mesh axis, it can be mapped to any combination of device axes.
+        Example 1: for input tensors A[1024,64,8] and B[1024,64,16] with
+        input_shard_dim = [1,1], the strategy shards A's dim 1 and B's dim 1
+        over the same device axis X, where X can be any device axis. The
+        output follows the same sharding as the input. Example 2: for input
+        tensors A[64,8], B[64,16,1024], C[64,8] with input_shard_dim =
+        [None,2,None], the strategy replicates A and C over all device axes
+        and only shards B's dim 2 over the device mesh. Assuming the device
+        mesh has 3 axes, tensor B's placement can be (Shard(2), Shard(2),
+        Replicate()), (Shard(2), Replicate(), Shard(2)), or (Replicate(),
+        Shard(2), Shard(2)).
+
+        output_shard_dim (list[Optional[int]]): the batch (shard) dimension of
+        each output tensor argument. Use `None` if the output arg has no batch
+        dim. For example, if the output is a single tensor sharded on dim 0,
+        pass in [0].
+
+        enable_shard_batch_dim_over_multiple_axis (bool): if True, the strategy
+        also tries to map the batch dim to multiple device axes. Default is
+        False.
+
+    Note: It is the user's responsibility to make sure the sharded tensors have
+    the correct shapes for the op.
+    """
+    output_type = [str(ret.type) for ret in op_schema.op._schema.returns]
+    # TODO(zpcore): Confirm whether view ops can be handled properly. Prevent
+    # handling view ops until confirmed.
+    if op_schema.op.is_view:
+        raise RuntimeError(
+            "batch_shard_strategy is unable to handle view ops until confirmed"
+        )
+    if "List[Tensor]" in output_type:
+        raise RuntimeError(
+            "batch_shard_strategy is unable to handle ops with List[Tensor] "
+            "output because the size of the list may depend on the op's input "
+            "values"
+        )
+    inputs_strategy = tree_leaves(op_schema.args_strategy)
+    assert len(inputs_strategy) == len(input_shard_dim)
+    output_strategy = OpStrategy([])
+    mesh = inputs_strategy[0].mesh
+    device_axis = list(range(mesh.ndim))
+    # number of device axes to shard the batch dim over
+    use_how_many_axis = (
+        [i + 1 for i in range(mesh.ndim)]
+        if enable_shard_batch_dim_over_multiple_axis
+        else [1]
+    )
+    for num_axis in use_how_many_axis:
+        device_combinations = list(itertools.combinations(device_axis, num_axis))
+        # e.g., if num_axis == 2, device_combinations = [(0,1), (0,2), (1,2),
+        # ...]. Then one feasible strategy is to shard the batch dim over both
+        # axes (0,1). We check all combinations in device_combinations below.
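+        # For example, on a 2-axis mesh of shape (2, 4) with num_axis == 1,
+        # device_combinations == [(0,), (1,)]: one candidate OpSpec shards the
+        # batch dim over axis 0, another shards it over axis 1, and any input
+        # without a batch dim stays Replicate() on every axis.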
+        for comb in device_combinations:
+            input_specs_list: list[DTensorSpec] = []
+            output_specs_list: list[DTensorSpec] = []
+            is_shardable = True
+            for op_strategy, dim in zip(inputs_strategy, input_shard_dim):
+                # build the placements for this input: shard the batch dim (if
+                # any) over the chosen device axes, replicate everything else
+                new_placements: list[Placement] = [Replicate()] * mesh.ndim
+                for axis in comb:
+                    new_placements[axis] = (
+                        Shard(dim) if dim is not None else Replicate()
+                    )
+                tensor_meta = op_strategy.strategies[0].output_spec.tensor_meta
+                new_input_spec = DTensorSpec(
+                    mesh,
+                    tuple(new_placements),
+                    tensor_meta=tensor_meta,
+                )
+                if not is_tensor_shardable(tensor_meta.shape, new_input_spec):
+                    is_shardable = False
+                    break
+                input_specs_list.append(new_input_spec)
+            if not is_shardable:
+                continue
+            for dim in output_shard_dim:
+                new_placements = [Replicate()] * mesh.ndim
+                for axis in comb:
+                    new_placements[axis] = (
+                        Shard(dim) if dim is not None else Replicate()
+                    )
+                output_spec = DTensorSpec(
+                    mesh,
+                    tuple(new_placements),
+                )
+                output_specs_list.append(output_spec)
+
+            output_specs = (
+                output_specs_list[0]
+                if len(output_specs_list) == 1
+                else tuple(output_specs_list)
+            )
+            input_specs = input_specs_list
+            redistribute_cost = [
+                generate_redistribute_costs(strat, input_specs_list[i])
+                for i, strat in enumerate(inputs_strategy)
+            ]
+            output_strategy.strategies.append(
+                OpSpec(output_specs, input_specs, redistribute_cost)  # type: ignore
+            )
+    return output_strategy
+
+
 def get_op_strategy(op: torch._ops.OpOverload, op_schema: OpSchema) -> StrategyType:
     global enable_implicit_replication, _current_stack
diff --git a/tests/test_dtensor.py b/tests/test_dtensor.py
index d44ef6de..cd568634 100644
--- a/tests/test_dtensor.py
+++ b/tests/test_dtensor.py
@@ -3,10 +3,12 @@
 # This source code is licensed under the BSD license found in the
 # LICENSE file in the root directory of this source tree.
 
+import functools
+
 import numpy as np
 import torch
 from torch.distributed.device_mesh import init_device_mesh
-from torch.distributed.tensor import DTensor, Shard, distribute_tensor
+from torch.distributed.tensor import DTensor, Replicate, Shard, distribute_tensor
 from torch.distributed.tensor._dtensor_spec import DTensorSpec, TensorMeta
 from torch.distributed.tensor._op_schema import (
     OpInfo,
@@ -15,6 +17,7 @@
     OpStrategy,
     OutputSharding,
     OutputSpecType,
+    RuntimeSchemaInfo,
     TupleStrategy,
 )
 from torch.testing._internal.common_utils import run_tests
@@ -23,7 +26,12 @@
     with_comms,
 )
 
-from autoparallel.dtensor_util import get_op_strategy, with_implicit_strategies
+from autoparallel.dtensor_util import (
+    batch_shard_strategy,
+    get_op_strategy,
+    op_strategy_context,
+    with_implicit_strategies,
+)
 
 propagator = DTensor._op_dispatcher.sharding_propagator
 
@@ -370,5 +378,105 @@ def test_implicit_registration(self):
         self._test_op_on_dtensor(test_op, input_x_dt, input_y_dt)
 
 
+class DimShardingTest(DTensorTestBase):
+    @with_comms
+    def test_simple_batch_sharding(self):
+        # both input tensors batch on dim 0
+        mesh = init_device_mesh(self.device_type, (2, self.world_size // 2))
+        test_op = torch.ops.mylib.numpy_sin.default
+
+        # 1. strategy that will try to shard dim 0 over one device axis.
+        shard_first_dim_strategy = functools.partial(
+            batch_shard_strategy, input_shard_dim=[0, 0], output_shard_dim=[0]
+        )
+        with op_strategy_context(test_op, shard_first_dim_strategy):
+            # dim 0 is the batch dim.
+            # Here we shard 16 over one device axis.
+            input_x = torch.randn([16, 1, 4], device=self.device_type)
+            input_y = torch.randn([16, 1, 4], device=self.device_type)
+            # any sharding below should work
+            input_x_dt = distribute_tensor(input_x, mesh, [Shard(1), Replicate()])
+            input_y_dt = distribute_tensor(input_y, mesh, [Replicate(), Shard(0)])
+            self._test_op_on_dtensor(test_op, input_x_dt, input_y_dt)
+
+        # 2. strategy that will try to shard dim 0 over multiple device axes.
+        shard_first_dim_to_multiple_devices_strategy = functools.partial(
+            batch_shard_strategy,
+            input_shard_dim=[0, 0],
+            output_shard_dim=[0],
+            enable_shard_batch_dim_over_multiple_axis=True,
+        )
+        with op_strategy_context(test_op, shard_first_dim_to_multiple_devices_strategy):
+            # dim 0 is the batch dim. Here we potentially shard 16 over
+            # multiple device axes.
+            input_x = torch.randn([16, 1, 4], device=self.device_type)
+            input_y = torch.randn([16, 1, 4], device=self.device_type)
+            # any sharding below should work
+            input_x_dt = distribute_tensor(input_x, mesh, [Shard(1), Replicate()])
+            input_y_dt = distribute_tensor(input_y, mesh, [Replicate(), Shard(0)])
+            self._test_op_on_dtensor(test_op, input_x_dt, input_y_dt)
+
+    @with_comms
+    def test_broadcast_batch_sharding(self):
+        # Not recommended: the user needs to make sure the op supports
+        # broadcast inputs first. If not supported, unsqueeze the inputs first
+        # to match each other's dimensions and use the pattern in the
+        # test_simple_batch_sharding test.
+        mesh = init_device_mesh(self.device_type, (2, self.world_size // 2))
+        test_op = torch.ops.mylib.numpy_sin.default
+        shard_on_first_dim_strategy = functools.partial(
+            batch_shard_strategy, input_shard_dim=[None, 0], output_shard_dim=[0]
+        )
+        with op_strategy_context(test_op, shard_on_first_dim_strategy):
+            input_x = torch.randn([1, 4], device=self.device_type)
+            # input_y's dim 0 of size 16 is the batch dim
+            input_y = torch.randn([16, 1, 4], device=self.device_type)
+            # any sharding below should work as long as the tensor dims are shardable
+            input_x_dt = distribute_tensor(input_x, mesh, [Shard(1), Replicate()])
+            input_y_dt = distribute_tensor(input_y, mesh, [Replicate(), Shard(0)])
+
+            output_dt = test_op(input_x_dt, input_y_dt)
+
+            # split the batch dim to test correctness
+            input_y_chunks = torch.chunk(input_y, 4, dim=0)
+            output = torch.cat(
+                [test_op(input_x, input_y_part) for input_y_part in input_y_chunks]
+            )
+            self.assertEqual(output_dt.full_tensor(), output)
+
+            # or we can test directly since the op supports broadcast.
+            self._test_op_on_dtensor(test_op, input_x_dt, input_y_dt)
+
+    @with_comms
+    def test_simple_tuple_batch_sharding(self):
+        # all input tensors batch on dim 0
+        mesh = init_device_mesh(self.device_type, (2, self.world_size // 2))
+        test_op = torch.ops.mylib.numpy_tuple_sin.default
+
+        # strategy that may shard dim 0 over one or multiple device axes.
+        shard_first_dim_to_multiple_devices_strategy = functools.partial(
+            batch_shard_strategy,
+            input_shard_dim=[0, 0, 0, 0, 0],  # input_x, flattened input_y, input_z
+            output_shard_dim=[0],
+            enable_shard_batch_dim_over_multiple_axis=True,
+        )
+        with op_strategy_context(
+            test_op,
+            shard_first_dim_to_multiple_devices_strategy,
+            schema_info=RuntimeSchemaInfo(needs_pytree=True),
+        ):
+            # dim 0 is the batch dim.
+            # Here we may shard 16 over multiple device axes.
+            input_x = torch.randn([16, 8, 4], device=self.device_type)
+            input_y = [
+                torch.randn([16, 8, 4], device=self.device_type) for _ in range(3)
+            ]
+            input_z = torch.randn([16, 8, 4], device=self.device_type)
+            # any sharding below should work as long as the tensor dims are shardable
+            input_x_dt = distribute_tensor(input_x, mesh, [Shard(0), Shard(1)])
+            input_y_dt = [
+                distribute_tensor(i, mesh, [Shard(1), Shard(1)]) for i in input_y
+            ]
+            input_z_dt = distribute_tensor(input_z, mesh, [Shard(1), Shard(0)])
+            self._test_op_on_dtensor(test_op, input_x_dt, input_y_dt, input_z_dt)
+
+
 if __name__ == "__main__":
     run_tests()
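
Usage sketch (a minimal example, not part of the diff above): it shows one way
batch_shard_strategy could be registered for an op through op_strategy_context,
mirroring the tests. The custom op torch.ops.mylib.numpy_sin, the mesh shape,
and the tensor sizes are assumptions borrowed from the test file; a distributed
process group must already be initialized (e.g., via torchrun).

    import functools

    import torch
    from torch.distributed.device_mesh import init_device_mesh
    from torch.distributed.tensor import Replicate, Shard, distribute_tensor

    from autoparallel.dtensor_util import batch_shard_strategy, op_strategy_context

    # Map batch dim 0 of both inputs and of the single output to one mesh axis.
    strategy = functools.partial(
        batch_shard_strategy, input_shard_dim=[0, 0], output_shard_dim=[0]
    )

    # Assumed setup: 8 ranks arranged as a 2x4 mesh; adjust to your world size.
    mesh = init_device_mesh("cuda", (2, 4))
    # Assumed custom op, registered the same way as in the tests above.
    op = torch.ops.mylib.numpy_sin.default

    with op_strategy_context(op, strategy):
        x = distribute_tensor(torch.randn(16, 1, 4), mesh, [Shard(0), Replicate()])
        y = distribute_tensor(torch.randn(16, 1, 4), mesh, [Shard(0), Replicate()])
        out = op(x, y)  # sharding propagation now uses batch_shard_strategy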