
Commit fbec8ab

introduce batch sharding strategy
ghstack-source-id: 6753840 Pull Request resolved: #50
1 parent f6a0212 commit fbec8ab

File tree

2 files changed: +251 -31 lines changed


autoparallel/dtensor_util/utils.py

Lines changed: 141 additions & 4 deletions
@@ -3,15 +3,33 @@
 # This source code is licensed under the BSD license found in the
 # LICENSE file in the root directory of this source tree.

-
+import itertools
 import logging
 from contextlib import ExitStack, contextmanager
-from typing import Callable, TypeVar
+from typing import Callable, Optional, TypeVar

 import torch
 from torch.distributed.tensor import DTensor
-from torch.distributed.tensor._op_schema import OpSchema, OutputSharding, StrategyType
-from torch.distributed.tensor._ops.utils import register_op_strategy
+from torch.distributed.tensor._dtensor_spec import DTensorSpec
+from torch.distributed.tensor._op_schema import (
+    OpSchema,
+    OpSpec,
+    OpStrategy,
+    OutputSharding,
+    StrategyType,
+)
+from torch.distributed.tensor._ops.utils import (
+    generate_redistribute_costs,
+    is_tensor_shardable,
+    register_op_strategy,
+)
+from torch.distributed.tensor.placement_types import Placement, Replicate, Shard
+
+try:
+    from torch.utils._cxx_pytree import tree_leaves
+except ImportError:
+    from torch.utils._pytree import tree_leaves  # type: ignore[no-redef]
+
 from typing_extensions import ParamSpec

 logger = logging.getLogger(__name__)
@@ -54,6 +72,125 @@ def op_strategy_context(op_overload, strategy_func, schema_info=None):
 replicate_op_strategy = torch.distributed.tensor._ops.utils.replicate_op_strategy


+def batch_shard_strategy(
+    op_schema: OpSchema,
+    input_shard_dim: list[Optional[int]],
+    output_shard_dim: list[Optional[int]],
+    enable_shard_batch_dim_over_multiple_axis: bool = False,
+) -> OpStrategy:
+    """
+    Shard the input tensors over the specified dimensions. The strategy maps the
+    batch dim of the input/output tensors to the same device mesh axis (or the
+    same set of device axes). Every input must either have one specified batch
+    dim or no batch dim. If an input has no batch dim, the strategy assumes the
+    tensor will be broadcast to the batch dim and processed by the operator.
+    For inputs with a specified batch dim, users need to make sure the batch dim
+    sizes are all the same. The output should always have a batch dim.
+
+    Args:
+        op_schema (OpSchema): the op schema.
+
+        input_shard_dim (list[Optional[int]]): the list of shard dimensions to
+            consider for each input tensor argument. Use `None` if an input arg
+            has no batch dim. If an arg is List[Tensor], we flatten it first and
+            then match it against input_shard_dim. Since the dim is not tied to
+            a specific device mesh axis, it can be mapped to any combination of
+            device axes. Example 1: input tensors A[1024, 64, 8] and
+            B[1024, 64, 16] with input_shard_dim = [1, 1] can shard A's dim 1
+            over device axis X and B's dim 1 over device axis X, where X can be
+            any device axis. The output follows the same sharding as the input.
+            Example 2: input tensors A[64, 8], B[64, 16, 1024], C[64, 8] with
+            input_shard_dim = [None, 2, None] will replicate A and C over all
+            device axes and only shard B's dim 2 over the device mesh. Assuming
+            the device mesh has 3 axes, tensor B's placement can be (Shard(2),
+            Shard(2), Replicate()), (Shard(2), Replicate(), Shard(2)), or
+            (Replicate(), Shard(2), Shard(2)).
+
+        output_shard_dim (list[Optional[int]]): the list of shard dimensions to
+            consider for each output tensor argument. Use `None` if an output
+            arg has no batch dim. For example, if the output is a single tensor
+            sharded on dim 0, pass in [0].
+
+        enable_shard_batch_dim_over_multiple_axis (bool): if True, the strategy
+            will also try to map the batch dim to multiple device axes. Default
+            is False.
+
+    Note: It is the user's responsibility to make sure the sharded tensor for
+    processing is correct in shape.
+    """
+    output_type = [str(ret.type) for ret in op_schema.op._schema.returns]
+    # TODO(zpcore): Confirm whether view ops can be handled properly or not.
+    # Prevent handling view ops until confirmed.
+    if op_schema.op.is_view:
+        raise RuntimeError(
+            "fallback strategy is unable to handle view ops until confirmed"
+        )
+    if "List[Tensor]" in output_type:
+        raise RuntimeError(
+            "fallback strategy is unable to handle ops with List[Tensor] output "
+            "because size of the list may depend on the op's input value"
+        )
+    inputs_strategy = tree_leaves(op_schema.args_strategy)
+    assert len(inputs_strategy) == len(input_shard_dim)
+    output_strategy = OpStrategy([])
+    mesh = inputs_strategy[0].mesh
+    device_axis = list(range(mesh.ndim))
+    use_how_many_axis = (
+        [i + 1 for i in range(mesh.ndim)]
+        if enable_shard_batch_dim_over_multiple_axis
+        else [1]
+    )
+    # number of device axes to shard the batch dim over
+    for num_axis in use_how_many_axis:
+        device_combinations = list(itertools.combinations(device_axis, num_axis))
+        # e.g., if num_axis == 2, device_combinations = [(0, 1), (0, 2), (1, 2),
+        # ...]. One feasible strategy is then to shard the batch dim over both
+        # axes (0, 1). We check all combinations in device_combinations below.
+        for comb in device_combinations:
+            input_specs_list: list[DTensorSpec] = []
+            output_specs_list: list[DTensorSpec] = []
+            is_shardable = True
+            for op_strategy, dim in zip(inputs_strategy, input_shard_dim):
+                # create a new list of placements for this shard-dim option
+                new_placements: list[Placement] = [Replicate()] * mesh.ndim
+                for axis in comb:
+                    new_placements[axis] = (
+                        Shard(dim) if dim is not None else Replicate()
+                    )
+                tensor_meta = op_strategy.strategies[0].output_spec.tensor_meta
+                new_input_spec = DTensorSpec(
+                    mesh,
+                    tuple(new_placements),
+                    tensor_meta=tensor_meta,
+                )
+                if not is_tensor_shardable(tensor_meta.shape, new_input_spec):
+                    is_shardable = False
+                    break
+                input_specs_list.append(new_input_spec)
+            if not is_shardable:
+                continue
+            for dim in output_shard_dim:
+                new_placements = [Replicate()] * mesh.ndim
+                for axis in comb:
+                    new_placements[axis] = (
+                        Shard(dim) if dim is not None else Replicate()
+                    )
+                output_spec = DTensorSpec(
+                    mesh,
+                    tuple(new_placements),
+                )
+                output_specs_list.append(output_spec)
+
+            output_specs = (
+                output_specs_list[0]
+                if len(output_specs_list) == 1
+                else tuple(output_specs_list)
+            )
+            input_specs = input_specs_list
+            redistribute_cost = [
+                generate_redistribute_costs(strat, input_specs_list[i])
+                for i, strat in enumerate(inputs_strategy)
+            ]
+            output_strategy.strategies.append(
+                OpSpec(output_specs, input_specs, redistribute_cost)  # type: ignore
+            )
+    return output_strategy
+
+
 class StrategyPool:
     def __init__(self) -> None:
         # reference to existing strategy from the DTensor upstream
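The core of batch_shard_strategy is the placement enumeration: for every combination of mesh axes that may carry the batch dim, each input with a batch dim gets Shard(batch_dim) on those axes and Replicate() everywhere else. Below is a minimal standalone sketch of that enumeration (not taken from the diff), assuming a 2-axis mesh, batch dim 0 for every input, and enable_shard_batch_dim_over_multiple_axis=True:

import itertools

from torch.distributed.tensor import Replicate, Shard

mesh_ndim = 2  # assumed: a 2-axis device mesh
batch_dim = 0  # assumed: all inputs batch on dim 0
device_axes = list(range(mesh_ndim))

# with enable_shard_batch_dim_over_multiple_axis=True the batch dim may be
# carried by 1..mesh_ndim axes; otherwise only single-axis combinations are tried
for num_axes in range(1, mesh_ndim + 1):
    for comb in itertools.combinations(device_axes, num_axes):
        placements = [Replicate()] * mesh_ndim
        for axis in comb:
            placements[axis] = Shard(batch_dim)
        print(comb, tuple(placements))
# (0,) (Shard(dim=0), Replicate())
# (1,) (Replicate(), Shard(dim=0))
# (0, 1) (Shard(dim=0), Shard(dim=0))

Each combination that passes the is_tensor_shardable check becomes one OpSpec, so the batch dim always lands on the same mesh axes for every input and output tensor involved.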

tests/test_dtensor.py

Lines changed: 110 additions & 27 deletions
@@ -3,11 +3,17 @@
 # This source code is licensed under the BSD license found in the
 # LICENSE file in the root directory of this source tree.

-from contextlib import contextmanager
+import functools

 import numpy as np
 import torch
-from torch.distributed.tensor import DTensor, Shard, distribute_tensor, init_device_mesh
+from torch.distributed.tensor import (
+    DTensor,
+    Replicate,
+    Shard,
+    distribute_tensor,
+    init_device_mesh,
+)
 from torch.distributed.tensor._dtensor_spec import DTensorSpec, TensorMeta
 from torch.distributed.tensor._op_schema import (
     OpInfo,
@@ -16,6 +22,7 @@
     OpStrategy,
     OutputSharding,
     OutputSpecType,
+    RuntimeSchemaInfo,
     TupleStrategy,
 )
 from torch.testing._internal.common_utils import run_tests
@@ -25,6 +32,7 @@
 )

 from autoparallel.dtensor_util import strategy_pool
+from autoparallel.dtensor_util.utils import batch_shard_strategy, op_strategy_context

 aten = torch.ops.aten

@@ -94,31 +102,6 @@ def _fw_tuple(x, y, z):
     )


-@contextmanager
-def op_strategy_context(op_overload, strategy_func, schema_info=None):
-    """
-    Context manager for setting and clearing op strategies in unit tests.
-    Args:
-        op_overload: The operator overload to set or clear the strategy for.
-        strategy_func: The strategy function to set for the operator overload.
-        schema_info: Optional schema information for the operator overload.
-    Yields:
-        None
-    """
-    try:
-        # register the op strategy
-        strategy_pool.register_op_strategy(op_overload, schema_info=schema_info)(
-            strategy_func
-        )
-        yield
-    finally:
-        # clear this op strategy cache
-        if op_overload in strategy_pool.op_strategy_funcs:
-            del strategy_pool.op_strategy_funcs[op_overload]
-        if op_overload in strategy_pool.op_to_schema_info:
-            del strategy_pool.op_to_schema_info[op_overload]
-
-
 # Overwrite upstream `_op_dispatcher.sharding_propagator` with customized
 # sharding_propagator. This is for testing purpose under eager mode and
 # AutoParallel won't use the propagate function. The main changes are 1) Skip
@@ -394,5 +377,105 @@ def test_implicit_registration(self):
         self._test_op_on_dtensor(test_op, input_x_dt, input_y_dt)


+class DimShardingTest(DTensorTestBase):
+    @with_comms
+    def test_simple_batch_sharding(self):
+        # both input tensors batch on dim 0
+        mesh = init_device_mesh(self.device_type, (2, self.world_size // 2))
+        test_op = torch.ops.mylib.numpy_sin.default
+
+        # 1. strategy that will try to shard dim 0 onto a single device axis.
+        shard_first_dim_strategy = functools.partial(
+            batch_shard_strategy, input_shard_dim=[0, 0], output_shard_dim=[0]
+        )
+        with op_strategy_context(test_op, shard_first_dim_strategy):
+            # dim 0 is the batch dim. Here we shard 16 over one device axis
+            input_x = torch.randn([16, 1, 4], device=self.device_type)
+            input_y = torch.randn([16, 1, 4], device=self.device_type)
+            # any sharding below should work
+            input_x_dt = distribute_tensor(input_x, mesh, [Shard(1), Replicate()])
+            input_y_dt = distribute_tensor(input_y, mesh, [Replicate(), Shard(0)])
+            self._test_op_on_dtensor(test_op, input_x_dt, input_y_dt)
+
+        # 2. strategy that will try to shard dim 0 over multiple device axes.
+        shard_first_dim_to_multiple_devices_strategy = functools.partial(
+            batch_shard_strategy,
+            input_shard_dim=[0, 0],
+            output_shard_dim=[0],
+            enable_shard_batch_dim_over_multiple_axis=True,
+        )
+        with op_strategy_context(test_op, shard_first_dim_to_multiple_devices_strategy):
+            # dim 0 is the batch dim. Here we potentially shard 16 over multiple device axes
+            input_x = torch.randn([16, 1, 4], device=self.device_type)
+            input_y = torch.randn([16, 1, 4], device=self.device_type)
+            # any sharding below should work
+            input_x_dt = distribute_tensor(input_x, mesh, [Shard(1), Replicate()])
+            input_y_dt = distribute_tensor(input_y, mesh, [Replicate(), Shard(0)])
+            self._test_op_on_dtensor(test_op, input_x_dt, input_y_dt)
+
+    @with_comms
+    def test_broadcast_batch_sharding(self):
+        # Not recommended: users need to make sure the op supports broadcasted
+        # inputs first. If not supported, unsqueeze the inputs first to match
+        # each other's dimensions and use the example in the
+        # test_simple_batch_sharding test.
+        mesh = init_device_mesh(self.device_type, (2, self.world_size // 2))
+        test_op = torch.ops.mylib.numpy_sin.default
+        shard_on_first_dim_strategy = functools.partial(
+            batch_shard_strategy, input_shard_dim=[None, 0], output_shard_dim=[0]
+        )
+        with op_strategy_context(test_op, shard_on_first_dim_strategy):
+            input_x = torch.randn([1, 4], device=self.device_type)
+            # input_y's dim of size 16 is the batch dim
+            input_y = torch.randn([16, 1, 4], device=self.device_type)
+            # any sharding below should work as long as the tensor dims are shardable
+            input_x_dt = distribute_tensor(input_x, mesh, [Shard(1), Replicate()])
+            input_y_dt = distribute_tensor(input_y, mesh, [Replicate(), Shard(0)])
+
+            output_dt = test_op(input_x_dt, input_y_dt)
+
+            # split the batch dim to test correctness
+            input_y_chunks = torch.chunk(input_y, 4, dim=0)
+            output = torch.cat(
+                [test_op(input_x, input_y_part) for input_y_part in input_y_chunks]
+            )
+            self.assertEqual(output_dt.full_tensor(), output)
+
+            # or we can test directly since the op supports broadcasting.
+            self._test_op_on_dtensor(test_op, input_x_dt, input_y_dt)
+
+    @with_comms
+    def test_simple_tuple_batch_sharding(self):
+        # both input tensors batch on dim 0
+        mesh = init_device_mesh(self.device_type, (2, self.world_size // 2))
+        test_op = torch.ops.mylib.numpy_tuple_sin.default
+
+        # strategy that may shard dim 0 over one or more device axes.
+        shard_first_dim_to_multiple_devices_strategy = functools.partial(
+            batch_shard_strategy,
+            input_shard_dim=[0, 0, 0, 0, 0],  # flatten input_y
+            output_shard_dim=[0],
+            enable_shard_batch_dim_over_multiple_axis=True,
+        )
+        with op_strategy_context(
+            test_op,
+            shard_first_dim_to_multiple_devices_strategy,
+            schema_info=RuntimeSchemaInfo(needs_pytree=True),
+        ):
+            # dim 0 is the batch dim. Here we potentially shard 16 over one or more device axes
+            input_x = torch.randn([16, 8, 4], device=self.device_type)
+            input_y = [
+                torch.randn([16, 8, 4], device=self.device_type) for _ in range(3)
+            ]
+            input_z = torch.randn([16, 8, 4], device=self.device_type)
+            # any sharding below should work as long as the tensor dims are shardable
+            input_x_dt = distribute_tensor(input_x, mesh, [Shard(0), Shard(1)])
+            input_y_dt = [
+                distribute_tensor(i, mesh, [Shard(1), Shard(1)]) for i in input_y
+            ]
+            input_z_dt = distribute_tensor(input_z, mesh, [Shard(1), Shard(0)])
+            self._test_op_on_dtensor(test_op, input_x_dt, input_y_dt, input_z_dt)
+
+
 if __name__ == "__main__":
     run_tests()
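In test_simple_tuple_batch_sharding, input_shard_dim has five entries even though the op takes three arguments: with RuntimeSchemaInfo(needs_pytree=True) the strategy sees the pytree-flattened argument list, so the List[Tensor] argument contributes one leaf per tensor. A small illustration of that flattening, using placeholder strings instead of tensors:

# tree_leaves flattens nested containers the same way batch_shard_strategy
# flattens op_schema.args_strategy before matching it against input_shard_dim
try:
    from torch.utils._cxx_pytree import tree_leaves
except ImportError:
    from torch.utils._pytree import tree_leaves

# stand-ins for input_x, input_y (a list of 3 tensors), and input_z
args = ("x", ["y0", "y1", "y2"], "z")
print(tree_leaves(args))  # ['x', 'y0', 'y1', 'y2', 'z'] -> 5 leaves
# hence input_shard_dim=[0, 0, 0, 0, 0]: one batch-dim entry per flattened tensor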
