
Commit 2a94f37

introduce batch sharding strategy

ghstack-source-id: dfdc089
Pull Request resolved: #50

1 parent f37b5ec   commit 2a94f37

File tree: 2 files changed, +245 -3 lines

autoparallel/dtensor_util/utils.py

Lines changed: 135 additions & 2 deletions

@@ -3,19 +3,33 @@
 # This source code is licensed under the BSD license found in the
 # LICENSE file in the root directory of this source tree.

-
+import itertools
 import logging
 from contextlib import contextmanager
-from typing import Callable, TypeVar
+from typing import Callable, Optional, TypeVar

 import torch
 from torch.distributed.tensor import DTensor
+from torch.distributed.tensor._dtensor_spec import DTensorSpec
 from torch.distributed.tensor._op_schema import (
     OpSchema,
+    OpSpec,
+    OpStrategy,
     OutputSharding,
     RuntimeSchemaInfo,
     StrategyType,
 )
+from torch.distributed.tensor._ops.utils import (
+    generate_redistribute_costs,
+    is_tensor_shardable,
+)
+from torch.distributed.tensor.placement_types import Placement, Replicate, Shard
+
+try:
+    from torch.utils._cxx_pytree import tree_leaves
+except ImportError:
+    from torch.utils._pytree import tree_leaves  # type: ignore[no-redef]
+
 from typing_extensions import ParamSpec

 logger = logging.getLogger(__name__)

@@ -30,6 +44,125 @@
 replicate_op_strategy = torch.distributed.tensor._ops.utils.replicate_op_strategy


+def batch_shard_strategy(
+    op_schema: OpSchema,
+    input_shard_dim: list[Optional[int]],
+    output_shard_dim: list[Optional[int]],
+    enable_shard_batch_dim_over_multiple_axis: bool = False,
+) -> OpStrategy:
+    """
+    Shard the input tensors over the specified dimensions. The strategy maps
+    the batch dim of the input/output tensors to the same device mesh axis
+    (or the same set of device axes). Every input must either have exactly
+    one specified batch dim or no batch dim. If an input has no batch dim,
+    the strategy assumes the tensor is broadcast to the batch dim by the
+    operator. For inputs with a batch dim, the user needs to make sure the
+    batch dim sizes are all the same. Outputs should always have a batch dim.
+
+    Args:
+        op_schema (OpSchema): the op schema.
+
+        input_shard_dim (list[Optional[int]]): the list of shard dimensions to
+        consider for each input tensor argument. Use `None` if the input arg
+        has no batch dim. If an arg is List[Tensor], we flatten it first and
+        then match it against input_shard_dim. Since the dim is not tied to a
+        specific device mesh axis, it can be mapped to any combination of
+        device axes. Example 1: input tensors A[1024,64,8], B[1024,64,16],
+        with input_shard_dim = [0, 0], the strategy can shard A's dim 0 over
+        device axis X and B's dim 0 over the same device axis X, where X can
+        be any device axis. The output follows the same sharding as the
+        input. Example 2: input tensors A[64,8], B[64,16,1024], C[64,8], with
+        input_shard_dim = [None, 2, None], the strategy replicates A and C
+        over all device axes and only shards B's dim 2 over the device mesh.
+        Assuming the device mesh has 3 axes and
+        enable_shard_batch_dim_over_multiple_axis=True, tensor B's placement
+        can be (Shard(2), Shard(2), Replicate()),
+        (Shard(2), Replicate(), Shard(2)), or
+        (Replicate(), Shard(2), Shard(2)).
+
+        output_shard_dim (list[Optional[int]]): the list of shard dimensions
+        to consider for each output tensor argument. Use `None` if the output
+        arg has no batch dim. For example, if the output is a single tensor
+        sharded on dim 0, pass in [0].
+
+        enable_shard_batch_dim_over_multiple_axis (bool): if True, the
+        strategy will also try mapping the batch dim to multiple device axes.
+        Default is False.
+
+    Note: It is the user's responsibility to make sure the sharded tensors
+    passed to the op have the correct shapes.
+    """
+    output_type = [str(ret.type) for ret in op_schema.op._schema.returns]
+    # TODO(zpcore): Confirm whether view ops can be handled properly or not.
+    # Prevent handling view ops until confirmed.
+    if op_schema.op.is_view:
+        raise RuntimeError(
+            "fallback strategy is unable to handle view ops until confirmed"
+        )
+    if "List[Tensor]" in output_type:
+        raise RuntimeError(
+            "fallback strategy is unable to handle ops with List[Tensor] output "
+            "because size of the list may depend on the op's input value"
+        )
+    inputs_strategy = tree_leaves(op_schema.args_strategy)
+    assert len(inputs_strategy) == len(input_shard_dim)
+    output_strategy = OpStrategy([])
+    mesh = inputs_strategy[0].mesh
+    device_axis = list(range(mesh.ndim))
+    use_how_many_axis = (
+        [i + 1 for i in range(mesh.ndim)]
+        if enable_shard_batch_dim_over_multiple_axis
+        else [1]
+    )
+    # number of device axes to shard the batch dim over
+    for num_axis in use_how_many_axis:
+        device_combinations = list(itertools.combinations(device_axis, num_axis))
+        # e.g., if num_axis == 2, device_combinations = [(0, 1), (0, 2),
+        # (1, 2), ...]. One feasible strategy is then to shard the batch dim
+        # on both axes (0, 1). We check all combinations in
+        # device_combinations below.
+        for comb in device_combinations:
+            input_specs_list: list[DTensorSpec] = []
+            output_specs_list: list[DTensorSpec] = []
+            is_shardable = True
+            for op_strategy, dim in zip(inputs_strategy, input_shard_dim):
+                # build the placements for this input: shard the batch dim on
+                # the chosen axes, replicate everywhere else
+                new_placements: list[Placement] = [Replicate()] * mesh.ndim
+                for axis in comb:
+                    new_placements[axis] = Shard(dim) if dim is not None else Replicate()
+                tensor_meta = op_strategy.strategies[0].output_spec.tensor_meta
+                new_input_spec = DTensorSpec(
+                    mesh,
+                    tuple(new_placements),
+                    tensor_meta=tensor_meta,
+                )
+                if not is_tensor_shardable(tensor_meta.shape, new_input_spec):
+                    is_shardable = False
+                    break
+                input_specs_list.append(new_input_spec)
+            if not is_shardable:
+                continue
+            for dim in output_shard_dim:
+                new_placements = [Replicate()] * mesh.ndim
+                for axis in comb:
+                    new_placements[axis] = Shard(dim) if dim is not None else Replicate()
+                output_spec = DTensorSpec(
+                    mesh,
+                    tuple(new_placements),
+                )
+                output_specs_list.append(output_spec)
+
+            output_specs = (
+                output_specs_list[0]
+                if len(output_specs_list) == 1
+                else tuple(output_specs_list)
+            )
+            input_specs = input_specs_list
+            redistribute_cost = [
+                generate_redistribute_costs(strat, input_specs_list[i])
+                for i, strat in enumerate(inputs_strategy)
+            ]
+            output_strategy.strategies.append(
+                OpSpec(output_specs, input_specs, redistribute_cost)  # type: ignore
+            )
+    return output_strategy
+
+
 class StrategyPool:
     def __init__(self) -> None:
         # reference to existing strategy from the DTensor upstream
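
For intuition, the core of the new strategy is the placement enumeration: for each candidate set of device-mesh axes, the batch dim is sharded on those axes and the tensor is replicated on every other axis. Below is a minimal, torch-free sketch of that loop; the helper name enumerate_batch_placements is illustrative and not part of this commit.

import itertools
from typing import Optional

def enumerate_batch_placements(
    mesh_ndim: int,
    batch_dim: Optional[int],
    multi_axis: bool = False,
) -> list[tuple[str, ...]]:
    """Illustrative stand-in for the placement loop in batch_shard_strategy."""
    options: list[tuple[str, ...]] = []
    # with the flag off, the batch dim is mapped to exactly one mesh axis
    axis_counts = range(1, mesh_ndim + 1) if multi_axis else [1]
    for num_axes in axis_counts:
        for comb in itertools.combinations(range(mesh_ndim), num_axes):
            placements = ["Replicate()"] * mesh_ndim
            for axis in comb:
                if batch_dim is not None:  # None means replicate the whole tensor
                    placements[axis] = f"Shard({batch_dim})"
            options.append(tuple(placements))
    return options

# Example 2 from the docstring: a 3-axis mesh and tensor B batched on dim 2.
print(enumerate_batch_placements(3, 2))
# [('Shard(2)', 'Replicate()', 'Replicate()'),
#  ('Replicate()', 'Shard(2)', 'Replicate()'),
#  ('Replicate()', 'Replicate()', 'Shard(2)')]
# multi_axis=True additionally yields the two- and three-axis combinations,
# e.g. ('Shard(2)', 'Shard(2)', 'Replicate()'), for 7 candidates in total.
print(len(enumerate_batch_placements(3, 2, multi_axis=True)))  # 7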

tests/test_dtensor.py

Lines changed: 110 additions & 1 deletion

@@ -3,11 +3,18 @@
 # This source code is licensed under the BSD license found in the
 # LICENSE file in the root directory of this source tree.

+import functools
 from contextlib import contextmanager

 import numpy as np
 import torch
-from torch.distributed.tensor import DTensor, Shard, distribute_tensor, init_device_mesh
+from torch.distributed.tensor import (
+    DTensor,
+    Replicate,
+    Shard,
+    distribute_tensor,
+    init_device_mesh,
+)
 from torch.distributed.tensor._dtensor_spec import DTensorSpec, TensorMeta
 from torch.distributed.tensor._op_schema import (
     OpInfo,

@@ -16,6 +23,7 @@
     OpStrategy,
     OutputSharding,
     OutputSpecType,
+    RuntimeSchemaInfo,
     TupleStrategy,
 )
 from torch.testing._internal.common_utils import run_tests

@@ -25,6 +33,7 @@
 )

 from autoparallel.dtensor_util import strategy_pool
+from autoparallel.dtensor_util.utils import batch_shard_strategy

 aten = torch.ops.aten


@@ -394,5 +403,105 @@ def test_implicit_registration(self):
         self._test_op_on_dtensor(test_op, input_x_dt, input_y_dt)


+class DimShardingTest(DTensorTestBase):
+    @with_comms
+    def test_simple_batch_sharding(self):
+        # both input tensors batch on dim 0
+        mesh = init_device_mesh(self.device_type, (2, self.world_size // 2))
+        test_op = torch.ops.mylib.numpy_sin.default
+
+        # 1. strategy that tries to shard dim 0 onto a single device axis.
+        shard_first_dim_strategy = functools.partial(
+            batch_shard_strategy, input_shard_dim=[0, 0], output_shard_dim=[0]
+        )
+        with op_strategy_context(test_op, shard_first_dim_strategy):
+            # dim 0 is the batch dim. Here we shard 16 over one device axis
+            input_x = torch.randn([16, 1, 4], device=self.device_type)
+            input_y = torch.randn([16, 1, 4], device=self.device_type)
+            # any sharding below should work
+            input_x_dt = distribute_tensor(input_x, mesh, [Shard(1), Replicate()])
+            input_y_dt = distribute_tensor(input_y, mesh, [Replicate(), Shard(0)])
+            self._test_op_on_dtensor(test_op, input_x_dt, input_y_dt)
+
+        # 2. strategy that tries to shard dim 0 across multiple device axes.
+        shard_first_dim_to_multiple_devices_strategy = functools.partial(
+            batch_shard_strategy,
+            input_shard_dim=[0, 0],
+            output_shard_dim=[0],
+            enable_shard_batch_dim_over_multiple_axis=True,
+        )
+        with op_strategy_context(test_op, shard_first_dim_to_multiple_devices_strategy):
+            # dim 0 is the batch dim. Here we potentially shard 16 over multiple device axes
+            input_x = torch.randn([16, 1, 4], device=self.device_type)
+            input_y = torch.randn([16, 1, 4], device=self.device_type)
+            # any sharding below should work
+            input_x_dt = distribute_tensor(input_x, mesh, [Shard(1), Replicate()])
+            input_y_dt = distribute_tensor(input_y, mesh, [Replicate(), Shard(0)])
+            self._test_op_on_dtensor(test_op, input_x_dt, input_y_dt)
+
+    @with_comms
+    def test_broadcast_batch_sharding(self):
+        # Not recommended: the user needs to make sure the op supports
+        # broadcast inputs first. If not supported, unsqueeze the inputs
+        # first to match each other's dimensions and use the pattern in the
+        # test_simple_batch_sharding test.
+        mesh = init_device_mesh(self.device_type, (2, self.world_size // 2))
+        test_op = torch.ops.mylib.numpy_sin.default
+        shard_on_first_dim_strategy = functools.partial(
+            batch_shard_strategy, input_shard_dim=[None, 0], output_shard_dim=[0]
+        )
+        with op_strategy_context(test_op, shard_on_first_dim_strategy):
+            input_x = torch.randn([1, 4], device=self.device_type)
+            # input_y's size-16 dim 0 is the batch dim
+            input_y = torch.randn([16, 1, 4], device=self.device_type)
+            # any sharding below should work as long as the tensor dims are shardable
+            input_x_dt = distribute_tensor(input_x, mesh, [Shard(1), Replicate()])
+            input_y_dt = distribute_tensor(input_y, mesh, [Replicate(), Shard(0)])
+
+            output_dt = test_op(input_x_dt, input_y_dt)
+
+            # split the batch dim to test correctness
+            input_y_chunks = torch.chunk(input_y, 4, dim=0)
+            output = torch.cat(
+                [test_op(input_x, input_y_part) for input_y_part in input_y_chunks]
+            )
+            self.assertEqual(output_dt.full_tensor(), output)
+
+            # or we can test directly since the op supports broadcasting.
+            self._test_op_on_dtensor(test_op, input_x_dt, input_y_dt)
+
+    @with_comms
+    def test_simple_tuple_batch_sharding(self):
+        # both input tensors batch on dim 0
+        mesh = init_device_mesh(self.device_type, (2, self.world_size // 2))
+        test_op = torch.ops.mylib.numpy_tuple_sin.default
+
+        # strategy that shards dim 0, potentially across multiple device axes.
+        shard_first_dim_to_multiple_devices_strategy = functools.partial(
+            batch_shard_strategy,
+            input_shard_dim=[0, 0, 0, 0, 0],  # flatten input_y
+            output_shard_dim=[0],
+            enable_shard_batch_dim_over_multiple_axis=True,
+        )
+        with op_strategy_context(
+            test_op,
+            shard_first_dim_to_multiple_devices_strategy,
+            schema_info=RuntimeSchemaInfo(needs_pytree=True),
+        ):
+            # dim 0 is the batch dim. Here we shard 16 over the device mesh
+            input_x = torch.randn([16, 8, 4], device=self.device_type)
+            input_y = [
+                torch.randn([16, 8, 4], device=self.device_type) for _ in range(3)
+            ]
+            input_z = torch.randn([16, 8, 4], device=self.device_type)
+            # any sharding below should work as long as the tensor dims are shardable
+            input_x_dt = distribute_tensor(input_x, mesh, [Shard(0), Shard(1)])
+            input_y_dt = [
+                distribute_tensor(i, mesh, [Shard(1), Shard(1)]) for i in input_y
+            ]
+            input_z_dt = distribute_tensor(input_z, mesh, [Shard(1), Shard(0)])
+            self._test_op_on_dtensor(test_op, input_x_dt, input_y_dt, input_z_dt)
+
+
 if __name__ == "__main__":
     run_tests()
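
The broadcast test above hinges on one property: sharding the batch dim is only valid if applying the op chunk-by-chunk along that dim and concatenating the pieces reproduces the full result. A small eager-mode sketch of that check, using a made-up broadcasting elementwise op as a stand-in for mylib.numpy_sin (which this commit does not define):

import torch

def fake_numpy_sin(x, y):
    # hypothetical stand-in for torch.ops.mylib.numpy_sin; any elementwise op
    # that broadcasts its inputs has the property being checked here
    return torch.sin(x) + torch.sin(y)

x = torch.randn(1, 4)      # broadcast input, no batch dim (input_shard_dim entry: None)
y = torch.randn(16, 1, 4)  # batch dim 0 of size 16 (input_shard_dim entry: 0)

full = fake_numpy_sin(x, y)
chunked = torch.cat([fake_numpy_sin(x, part) for part in torch.chunk(y, 4, dim=0)])
assert torch.allclose(full, chunked)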
