
Commit cc28a95

Introduce the batch sharding strategy
1 parent 7ed4feb commit cc28a95

File tree

3 files changed: 271 additions & 8 deletions


autoparallel/dtensor_util/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -5,6 +5,7 @@
 
 # functions to expose
 from .utils import (
+    batch_shard_strategy,
     get_op_strategy,
     op_strategy_context,
     replicate_op_strategy,
@@ -13,6 +14,7 @@
 
 __all__ = [
     "replicate_op_strategy",
+    "batch_shard_strategy",
     "get_op_strategy",
     "with_implicit_strategies",
     "op_strategy_context",

autoparallel/dtensor_util/utils.py

Lines changed: 159 additions & 6 deletions
@@ -3,13 +3,32 @@
 # This source code is licensed under the BSD license found in the
 # LICENSE file in the root directory of this source tree.
 
+import itertools
 import logging
 from contextlib import ExitStack, contextmanager
+from typing import Optional
 
 import torch
 from torch.distributed.tensor import DTensor
-from torch.distributed.tensor._op_schema import OpSchema, StrategyType
-from torch.distributed.tensor._ops.utils import register_op_strategy
+from torch.distributed.tensor._dtensor_spec import DTensorSpec
+from torch.distributed.tensor._op_schema import (
+    OpSchema,
+    OpSpec,
+    OpStrategy,
+    StrategyType,
+)
+from torch.distributed.tensor._ops.utils import (
+    generate_redistribute_costs,
+    is_tensor_shardable,
+    register_op_strategy,
+)
+from torch.distributed.tensor.placement_types import Placement, Replicate, Shard
+
+try:
+    from torch.utils._cxx_pytree import tree_leaves
+except ImportError:
+    from torch.utils._pytree import tree_leaves  # type: ignore[no-redef]
+
 
 logger = logging.getLogger(__name__)
 
@@ -39,19 +58,153 @@ def op_strategy_context(op_overload, strategy_func, schema_info=None):
         None
     """
     propagator = DTensor._op_dispatcher.sharding_propagator
+    _origin_op_strategy_funcs = None
+    _origin_op_strategy_schema = None
     try:
         # register the op strategy
-        register_op_strategy(op_overload, schema_info=schema_info)(strategy_func)
-        yield
-    finally:
-        # clear this op strategy cache
         if op_overload in propagator.op_strategy_funcs:
+            _origin_op_strategy_funcs = propagator.op_strategy_funcs[op_overload]
             del propagator.op_strategy_funcs[op_overload]
         if op_overload in propagator.op_to_schema_info:
+            _origin_op_strategy_schema = propagator.op_to_schema_info[op_overload]
             del propagator.op_to_schema_info[op_overload]
+        register_op_strategy(op_overload, schema_info=schema_info)(strategy_func)
+        yield
+    finally:
+        # clear this op strategy cache
+        if _origin_op_strategy_funcs is None:
+            if op_overload in propagator.op_strategy_funcs:
+                del propagator.op_strategy_funcs[op_overload]
+        else:
+            propagator.op_strategy_funcs[op_overload] = _origin_op_strategy_funcs
+        if _origin_op_strategy_schema is None:
+            if op_overload in propagator.op_to_schema_info:
+                del propagator.op_to_schema_info[op_overload]
+        else:
+            propagator.op_to_schema_info[op_overload] = _origin_op_strategy_schema
         propagator.propagate_op_sharding.cache.cache_clear()
 
 
+# -------------define universal op strategy-------------
+def batch_shard_strategy(
+    op_schema: OpSchema,
+    input_shard_dim: list[Optional[int]],
+    output_shard_dim: list[Optional[int]],
+    enable_shard_batch_dim_over_multiple_axis: bool = False,
+) -> OpStrategy:
+    """
+    Shard the input tensors over the specified dimensions. The strategy maps
+    the batch dim of the input/output tensors to the same device mesh axis
+    (or the same set of device axes). Every input must either have exactly
+    one specified batch dim or no batch dim. If an input has no batch dim,
+    the strategy assumes the tensor will be broadcast to the batch dim and
+    processed by the operator. For inputs that do specify a batch dim, the
+    user needs to make sure the batch dim sizes are all the same. The output
+    should always have a batch dim.
+
+    Args:
+        op_schema (OpSchema): the op schema.
+
+        input_shard_dim (list[Optional[int]]): the shard (batch) dimension to
+        consider for each input tensor argument. Use `None` if the input arg
+        has no batch dim. If an arg is List[Tensor], we flatten it first and
+        then match it against input_shard_dim. Since the dim is not tied to a
+        specific device mesh axis, it can be mapped to any combination of
+        device axes. Example 1: input tensors A[1024,64,8] and B[1024,64,16]
+        with input_shard_dim = [1,1] can shard A's dim 1 over device axis X
+        and B's dim 1 over the same device axis X, where X can be any device
+        axis. The output follows the same sharding as the input. Example 2:
+        input tensors A[64,8], B[64,16,1024], C[64,8] with input_shard_dim =
+        [None,2,None] replicate A and C over all device axes and only shard
+        B's dim 2 over the device mesh. Assuming the device mesh has 3 axes,
+        tensor B's placement can be (Shard(2), Shard(2), Replicate()),
+        (Shard(2), Replicate(), Shard(2)), or (Replicate(), Shard(2),
+        Shard(2)).
+
+        output_shard_dim (list[Optional[int]]): the shard (batch) dimension
+        to consider for each output tensor argument. Use `None` if the output
+        arg has no batch dim. For example, if the output is a single tensor
+        sharded on dim 0, pass in [0].
+
+        enable_shard_batch_dim_over_multiple_axis (bool): if True, the
+        strategy will also try to map the batch dim to multiple device axes.
+        Default is False.
+
+    Note: It is the user's responsibility to make sure the sharded tensors
+    passed to the operator have the correct shapes.
+    """
+    output_type = [str(ret.type) for ret in op_schema.op._schema.returns]
+    # TODO(zpcore): Confirm whether view ops can be handled properly. Prevent
+    # handling view ops until confirmed.
+    if op_schema.op.is_view:
+        raise RuntimeError(
+            "fallback strategy is unable to handle view ops until confirmed"
+        )
+    if "List[Tensor]" in output_type:
+        raise RuntimeError(
+            "fallback strategy is unable to handle ops with List[Tensor] output "
+            "because size of the list may depend on the op's input value"
+        )
+    inputs_strategy = tree_leaves(op_schema.args_strategy)
+    assert len(inputs_strategy) == len(input_shard_dim)
+    output_strategy = OpStrategy([])
+    mesh = inputs_strategy[0].mesh
+    device_axis = list(range(mesh.ndim))
+    use_how_many_axis = (
+        [i + 1 for i in range(mesh.ndim)]
+        if enable_shard_batch_dim_over_multiple_axis
+        else [1]
+    )
+    # number of device axes to shard on for the batch dim
+    for num_axis in use_how_many_axis:
+        device_combinations = list(itertools.combinations(device_axis, num_axis))
+        # e.g., if num_axis == 2, device_combinations = [(0,1), (0,2), (1,2),
+        # ...]. One feasible strategy is then to shard the batch dim on both
+        # axes (0,1). We check all combinations in device_combinations below.
+        for comb in device_combinations:
+            input_specs_list: list[DTensorSpec] = []
+            output_specs_list: list[DTensorSpec] = []
+            is_shardable = True
+            for op_strategy, dim in zip(inputs_strategy, input_shard_dim):
+                # create a new list of placements for this shard option
+                new_placements: list[Placement] = [Replicate()] * mesh.ndim
+                for axis in comb:
+                    new_placements[axis] = (
+                        Shard(dim) if dim is not None else Replicate()
+                    )
+                tensor_meta = op_strategy.strategies[0].output_spec.tensor_meta
+                new_input_spec = DTensorSpec(
+                    mesh,
+                    tuple(new_placements),
+                    tensor_meta=tensor_meta,
+                )
+                if not is_tensor_shardable(tensor_meta.shape, new_input_spec):
+                    is_shardable = False
+                    break
+                input_specs_list.append(new_input_spec)
+            if not is_shardable:
+                continue
+            for dim in output_shard_dim:
+                new_placements = [Replicate()] * mesh.ndim
+                for axis in comb:
+                    new_placements[axis] = (
+                        Shard(dim) if dim is not None else Replicate()
+                    )
+                output_spec = DTensorSpec(
+                    mesh,
+                    tuple(new_placements),
+                )
+                output_specs_list.append(output_spec)
+
+            output_specs = (
+                output_specs_list[0]
+                if len(output_specs_list) == 1
+                else tuple(output_specs_list)
+            )
+            input_specs = input_specs_list
+            redistribute_cost = [
+                generate_redistribute_costs(strat, input_specs_list[i])
+                for i, strat in enumerate(inputs_strategy)
+            ]
+            output_strategy.strategies.append(
+                OpSpec(output_specs, input_specs, redistribute_cost)  # type: ignore
+            )
+    return output_strategy
+
+
 def get_op_strategy(op: torch._ops.OpOverload, op_schema: OpSchema) -> StrategyType:
     global enable_implicit_replication, _current_stack
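
A note on the op_strategy_context change above: the context manager now snapshots any strategy that was already registered for the op and restores it on exit, instead of unconditionally deleting whatever is registered. A minimal sketch of the resulting behavior; the inner strategy function is a hypothetical stand-in, and aten.nonzero.default is only an illustrative op choice:

import torch
from torch.distributed.tensor import DTensor

from autoparallel.dtensor_util import op_strategy_context, replicate_op_strategy

propagator = DTensor._op_dispatcher.sharding_propagator
op = torch.ops.aten.nonzero.default  # any op overload works here


def inner_strategy(*args, **kwargs):
    # hypothetical stand-in that just delegates to the replicate fallback
    return replicate_op_strategy(*args, **kwargs)


had_strategy = op in propagator.op_strategy_funcs

with op_strategy_context(op, replicate_op_strategy):
    outer = propagator.op_strategy_funcs[op]
    with op_strategy_context(op, inner_strategy):
        # the inner registration temporarily shadows the outer one
        assert propagator.op_strategy_funcs[op] is not outer
    # ... and the outer registration is restored when the inner context exits
    assert propagator.op_strategy_funcs[op] is outer

# whatever was (or was not) registered before the outermost context is back
assert (op in propagator.op_strategy_funcs) == had_strategy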

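For orientation before the tests: a minimal sketch of how batch_shard_strategy is meant to be wired up for a single op, mirroring the tests below. The op mylib.numpy_sin is the custom op used by this test suite (defined outside this diff), and the device type and mesh shape are assumptions. On a 2-axis mesh with enable_shard_batch_dim_over_multiple_axis=True, the strategy enumerates candidates that place Shard(0) on axis 0, on axis 1, or on both, replicating the remaining axes.

import functools

import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import Replicate, Shard, distribute_tensor

from autoparallel.dtensor_util import batch_shard_strategy, op_strategy_context

# both inputs carry the batch on dim 0; the single output is batched on dim 0
strategy = functools.partial(
    batch_shard_strategy,
    input_shard_dim=[0, 0],
    output_shard_dim=[0],
    enable_shard_batch_dim_over_multiple_axis=True,
)

mesh = init_device_mesh("cuda", (2, 4))  # assumed 8-GPU 2x4 mesh
test_op = torch.ops.mylib.numpy_sin.default  # custom op from the test suite

with op_strategy_context(test_op, strategy):
    x = distribute_tensor(torch.randn(16, 1, 4), mesh, [Shard(1), Replicate()])
    y = distribute_tensor(torch.randn(16, 1, 4), mesh, [Replicate(), Shard(0)])
    # sharding propagation redistributes both inputs so the batch dim (dim 0)
    # lands on the chosen mesh axis/axes, runs the op locally, and shards the
    # output on dim 0 as requested
    out = test_op(x, y)
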
tests/test_dtensor.py

Lines changed: 110 additions & 2 deletions
@@ -3,10 +3,12 @@
 # This source code is licensed under the BSD license found in the
 # LICENSE file in the root directory of this source tree.
 
+import functools
+
 import numpy as np
 import torch
 from torch.distributed.device_mesh import init_device_mesh
-from torch.distributed.tensor import DTensor, Shard, distribute_tensor
+from torch.distributed.tensor import DTensor, Replicate, Shard, distribute_tensor
 from torch.distributed.tensor._dtensor_spec import DTensorSpec, TensorMeta
 from torch.distributed.tensor._op_schema import (
     OpInfo,
@@ -15,6 +17,7 @@
     OpStrategy,
     OutputSharding,
     OutputSpecType,
+    RuntimeSchemaInfo,
     TupleStrategy,
 )
 from torch.testing._internal.common_utils import run_tests
@@ -23,7 +26,12 @@
     with_comms,
 )
 
-from autoparallel.dtensor_util import get_op_strategy, with_implicit_strategies
+from autoparallel.dtensor_util import (
+    batch_shard_strategy,
+    get_op_strategy,
+    op_strategy_context,
+    with_implicit_strategies,
+)
 
 propagator = DTensor._op_dispatcher.sharding_propagator
 
@@ -370,5 +378,105 @@ def test_implicit_registration(self):
         self._test_op_on_dtensor(test_op, input_x_dt, input_y_dt)
 
 
+class DimShardingTest(DTensorTestBase):
+    @with_comms
+    def test_simple_batch_sharding(self):
+        # both input tensors batch on dim 0
+        mesh = init_device_mesh(self.device_type, (2, self.world_size // 2))
+        test_op = torch.ops.mylib.numpy_sin.default
+
+        # 1. strategy that tries to shard dim 0 onto a single device axis.
+        shard_first_dim_strategy = functools.partial(
+            batch_shard_strategy, input_shard_dim=[0, 0], output_shard_dim=[0]
+        )
+        with op_strategy_context(test_op, shard_first_dim_strategy):
+            # dim 0 is the batch dim. Here we shard 16 over one device axis
+            input_x = torch.randn([16, 1, 4], device=self.device_type)
+            input_y = torch.randn([16, 1, 4], device=self.device_type)
+            # any sharding below should work
+            input_x_dt = distribute_tensor(input_x, mesh, [Shard(1), Replicate()])
+            input_y_dt = distribute_tensor(input_y, mesh, [Replicate(), Shard(0)])
+            self._test_op_on_dtensor(test_op, input_x_dt, input_y_dt)
+
+        # 2. strategy that may shard dim 0 across multiple device axes.
+        shard_first_dim_to_multiple_devices_strategy = functools.partial(
+            batch_shard_strategy,
+            input_shard_dim=[0, 0],
+            output_shard_dim=[0],
+            enable_shard_batch_dim_over_multiple_axis=True,
+        )
+        with op_strategy_context(test_op, shard_first_dim_to_multiple_devices_strategy):
+            # dim 0 is the batch dim. Here we potentially shard 16 over multiple device axes
+            input_x = torch.randn([16, 1, 4], device=self.device_type)
+            input_y = torch.randn([16, 1, 4], device=self.device_type)
+            # any sharding below should work
+            input_x_dt = distribute_tensor(input_x, mesh, [Shard(1), Replicate()])
+            input_y_dt = distribute_tensor(input_y, mesh, [Replicate(), Shard(0)])
+            self._test_op_on_dtensor(test_op, input_x_dt, input_y_dt)
+
+    @with_comms
+    def test_broadcast_batch_sharding(self):
+        # Not recommended; the user needs to make sure the op supports
+        # broadcast inputs first. If it does not, unsqueeze the inputs to
+        # match each other's dimensions and use the pattern from the
+        # test_simple_batch_sharding test.
+        mesh = init_device_mesh(self.device_type, (2, self.world_size // 2))
+        test_op = torch.ops.mylib.numpy_sin.default
+        shard_on_first_dim_strategy = functools.partial(
+            batch_shard_strategy, input_shard_dim=[None, 0], output_shard_dim=[0]
+        )
+        with op_strategy_context(test_op, shard_on_first_dim_strategy):
+            input_x = torch.randn([1, 4], device=self.device_type)
+            # input_y's 16 sits on the batch dim
+            input_y = torch.randn([16, 1, 4], device=self.device_type)
+            # any sharding below should work as long as the tensor dim is shardable
+            input_x_dt = distribute_tensor(input_x, mesh, [Shard(1), Replicate()])
+            input_y_dt = distribute_tensor(input_y, mesh, [Replicate(), Shard(0)])
+
+            output_dt = test_op(input_x_dt, input_y_dt)
+
+            # split the batch dim to test correctness
+            input_y_chunks = torch.chunk(input_y, 4, dim=0)
+            output = torch.cat(
+                [test_op(input_x, input_y_part) for input_y_part in input_y_chunks]
+            )
+            self.assertEqual(output_dt.full_tensor(), output)
+
+            # or we can test directly since the op supports broadcast.
+            self._test_op_on_dtensor(test_op, input_x_dt, input_y_dt)
+
+    @with_comms
+    def test_simple_tuple_batch_sharding(self):
+        # both input tensors batch on dim 0
+        mesh = init_device_mesh(self.device_type, (2, self.world_size // 2))
+        test_op = torch.ops.mylib.numpy_tuple_sin.default
+
+        # strategy that may shard dim 0 across multiple device axes.
+        shard_first_dim_to_multiple_devices_strategy = functools.partial(
+            batch_shard_strategy,
+            input_shard_dim=[0, 0, 0, 0, 0],  # flatten input_y
+            output_shard_dim=[0],
+            enable_shard_batch_dim_over_multiple_axis=True,
+        )
+        with op_strategy_context(
+            test_op,
+            shard_first_dim_to_multiple_devices_strategy,
+            schema_info=RuntimeSchemaInfo(needs_pytree=True),
+        ):
+            # dim 0 is the batch dim. Here we shard 16 over the device mesh
+            input_x = torch.randn([16, 8, 4], device=self.device_type)
+            input_y = [
+                torch.randn([16, 8, 4], device=self.device_type) for _ in range(3)
+            ]
+            input_z = torch.randn([16, 8, 4], device=self.device_type)
+            # any sharding below should work as long as the tensor dim is shardable
+            input_x_dt = distribute_tensor(input_x, mesh, [Shard(0), Shard(1)])
+            input_y_dt = [
+                distribute_tensor(i, mesh, [Shard(1), Shard(1)]) for i in input_y
+            ]
+            input_z_dt = distribute_tensor(input_z, mesh, [Shard(1), Shard(0)])
+            self._test_op_on_dtensor(test_op, input_x_dt, input_y_dt, input_z_dt)
+
+
 if __name__ == "__main__":
     run_tests()
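
The tests reference torch.ops.mylib.numpy_sin and torch.ops.mylib.numpy_tuple_sin, custom test ops whose definitions are not part of this diff. For readers unfamiliar with them, a rough, hypothetical sketch of how such an op could be registered with torch.library.custom_op; the actual signatures and implementations in the test suite may differ:

import numpy as np
import torch


# hypothetical definition; the real mylib.numpy_sin may differ
@torch.library.custom_op("mylib::numpy_sin", mutates_args=())
def numpy_sin(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    # compute sin(x + y) through numpy so the op is opaque to the compiler
    out = np.sin((x.cpu() + y.cpu()).numpy())
    return torch.from_numpy(out).to(x.device)


@numpy_sin.register_fake
def _(x, y):
    # shape/dtype propagation only; follows broadcasting semantics
    return torch.empty_like(torch.broadcast_tensors(x, y)[0])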
