
Commit ec2ea25

introduce batch sharding strategy

ghstack-source-id: 604f1b6
Pull Request resolved: #50

1 parent 84fd139

File tree

3 files changed: +253 -4 lines changed


autoparallel/dtensor_util/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -5,6 +5,7 @@
 
 # functions to expose
 from .utils import (
+    batch_shard_strategy,
     get_op_strategy,
     op_strategy_context,
     replicate_op_strategy,
@@ -13,6 +14,7 @@
 
 __all__ = [
     "replicate_op_strategy",
+    "batch_shard_strategy",
     "get_op_strategy",
     "with_implicit_strategies",
     "op_strategy_context",

autoparallel/dtensor_util/utils.py

Lines changed: 141 additions & 2 deletions
@@ -3,13 +3,32 @@
 # This source code is licensed under the BSD license found in the
 # LICENSE file in the root directory of this source tree.
 
+import itertools
 import logging
 from contextlib import ExitStack, contextmanager
+from typing import Optional
 
 import torch
 from torch.distributed.tensor import DTensor
-from torch.distributed.tensor._op_schema import OpSchema, StrategyType
-from torch.distributed.tensor._ops.utils import register_op_strategy
+from torch.distributed.tensor._dtensor_spec import DTensorSpec
+from torch.distributed.tensor._op_schema import (
+    OpSchema,
+    OpSpec,
+    OpStrategy,
+    StrategyType,
+)
+from torch.distributed.tensor._ops.utils import (
+    generate_redistribute_costs,
+    is_tensor_shardable,
+    register_op_strategy,
+)
+from torch.distributed.tensor.placement_types import Placement, Replicate, Shard
+
+try:
+    from torch.utils._cxx_pytree import tree_leaves
+except ImportError:
+    from torch.utils._pytree import tree_leaves  # type: ignore[no-redef]
+
 
 logger = logging.getLogger(__name__)
 
@@ -52,6 +71,126 @@ def op_strategy_context(op_overload, strategy_func, schema_info=None):
         propagator.propagate_op_sharding.cache.cache_clear()
 
 
+# -------------define universal op strategy-------------
+def batch_shard_strategy(
+    op_schema: OpSchema,
+    input_shard_dim: list[Optional[int]],
+    output_shard_dim: list[Optional[int]],
+    enable_shard_batch_dim_over_multiple_axis: bool = False,
+) -> OpStrategy:
+    """
+    Shard the input tensors over the specified dimensions. The strategy maps
+    the batch dim of the input/output tensors to the same device mesh axis (or
+    the same set of device axes). Every input must have either one specified
+    batch dim or no batch dim. If an input has no batch dim, the strategy
+    assumes the tensor will be broadcast to the batch dim and processed by the
+    operator. For inputs with a specified batch dim, users need to make sure
+    the batch dim sizes are all the same. The output should always have a
+    batch dim.
+
+    Args:
+        op_schema (OpSchema): the op schema.
+
+        input_shard_dim (list[Optional[int]]): the shard dimension to consider
+            for each input tensor argument. Use `None` if the input arg has no
+            batch dim. If an arg is List[Tensor], we flatten it first and then
+            match it against input_shard_dim. Since the dim is not tied to a
+            specific device mesh axis, it can be mapped to any combination of
+            device axes. Example 1: with input tensors A[1024,64,8] and
+            B[1024,64,16] and input_shard_dim = [1,1], the strategy can shard
+            A's dim 1 over device axis X and B's dim 1 over device axis X,
+            where X can be any device axis. The output follows the same
+            sharding as the input. Example 2: with input tensors A[64,8],
+            B[64,16,1024], C[64,8] and input_shard_dim = [None,2,None], the
+            strategy will replicate A and C over all device axes and only
+            shard B's dim 2 over the device mesh. Assuming the device mesh has
+            3 axes, tensor B's placement can be (Shard(2), Shard(2),
+            Replicate()), (Shard(2), Replicate(), Shard(2)), or (Replicate(),
+            Shard(2), Shard(2)).
+
+        output_shard_dim (list[Optional[int]]): the shard dimension to consider
+            for each output tensor. Use `None` if the output has no batch dim.
+            For example, if the output is a single tensor sharded on dim 0,
+            pass in [0].
+
+        enable_shard_batch_dim_over_multiple_axis (bool): if True, the strategy
+            will also try to map the batch dim to multiple device axes.
+            Defaults to False.
+
+    Note: It is the user's responsibility to make sure the sharded tensors
+    passed to the op have the correct shapes.
+    """
+    output_type = [str(ret.type) for ret in op_schema.op._schema.returns]
+    # TODO(zpcore): Confirm whether view ops can be handled properly. Prevent
+    # handling view ops until confirmed.
+    if op_schema.op.is_view:
+        raise RuntimeError(
+            "fallback strategy is unable to handle view ops until confirmed"
+        )
+    if "List[Tensor]" in output_type:
+        raise RuntimeError(
+            "fallback strategy is unable to handle ops with List[Tensor] output "
+            "because the size of the list may depend on the op's input value"
+        )
+    inputs_strategy = tree_leaves(op_schema.args_strategy)
+    assert len(inputs_strategy) == len(input_shard_dim)
+    output_strategy = OpStrategy([])
+    mesh = inputs_strategy[0].mesh
+    device_axis = list(range(mesh.ndim))
+    use_how_many_axis = (
+        [i + 1 for i in range(mesh.ndim)]
+        if enable_shard_batch_dim_over_multiple_axis
+        else [1]
+    )
+    # number of device axes to shard on for the batch dim
+    for num_axis in use_how_many_axis:
+        device_combinations = list(itertools.combinations(device_axis, num_axis))
+        # e.g., if num_axis == 2, device_combinations = [(0,1), (0,2), (1,2),
+        # ...]. One feasible strategy is then to shard the batch dim on both
+        # axes (0,1). We check all combinations in device_combinations below.
+        for comb in device_combinations:
+            input_specs_list: list[DTensorSpec] = []
+            output_specs_list: list[DTensorSpec] = []
+            is_shardable = True
+            for op_strategy, dim in zip(inputs_strategy, input_shard_dim):
+                # build the placements for this shard option; dim may be 0, so
+                # compare against None explicitly
+                new_placements: list[Placement] = [Replicate()] * mesh.ndim
+                for axis in comb:
+                    new_placements[axis] = (
+                        Shard(dim) if dim is not None else Replicate()
+                    )
+                tensor_meta = op_strategy.strategies[0].output_spec.tensor_meta
+                new_input_spec = DTensorSpec(
+                    mesh,
+                    tuple(new_placements),
+                    tensor_meta=tensor_meta,
+                )
+                if not is_tensor_shardable(tensor_meta.shape, new_input_spec):
+                    is_shardable = False
+                    break
+                input_specs_list.append(new_input_spec)
+            if not is_shardable:
+                continue
+            for dim in output_shard_dim:
+                new_placements = [Replicate()] * mesh.ndim
+                for axis in comb:
+                    new_placements[axis] = (
+                        Shard(dim) if dim is not None else Replicate()
+                    )
+                output_spec = DTensorSpec(
+                    mesh,
+                    tuple(new_placements),
+                )
+                output_specs_list.append(output_spec)
+
+            output_specs = (
+                output_specs_list[0]
+                if len(output_specs_list) == 1
+                else tuple(output_specs_list)
+            )
+            input_specs = input_specs_list
+            redistribute_cost = [
+                generate_redistribute_costs(strat, input_specs_list[i])
+                for i, strat in enumerate(inputs_strategy)
+            ]
+            output_strategy.strategies.append(
+                OpSpec(output_specs, input_specs, redistribute_cost)  # type: ignore
+            )
+    return output_strategy
+
+
 def get_op_strategy(op: torch._ops.OpOverload, op_schema: OpSchema) -> StrategyType:
     global enable_implicit_replication, _current_stack
 
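For context, the new strategy is meant to be bound with functools.partial and registered for an op inside op_strategy_context, which is how the tests below exercise it. A minimal sketch of that pattern, assuming a custom op such as torch.ops.mylib.numpy_sin.default has already been defined and registered (as it is in the test suite):

    import functools

    import torch

    from autoparallel.dtensor_util import batch_shard_strategy, op_strategy_context

    # Assumed custom op; the test suite registers torch.ops.mylib.numpy_sin
    # separately, so this line only works once that registration has run.
    test_op = torch.ops.mylib.numpy_sin.default

    # Both inputs carry their batch on dim 0, and so does the single output.
    shard_batch_dim_strategy = functools.partial(
        batch_shard_strategy, input_shard_dim=[0, 0], output_shard_dim=[0]
    )

    with op_strategy_context(test_op, shard_batch_dim_strategy):
        # Within this context, DTensor's sharding propagation resolves test_op
        # with the batch-sharding strategy, so DTensor inputs built with
        # distribute_tensor can be passed to test_op directly.
        ...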

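For intuition about the candidate placements the strategy enumerates, the itertools.combinations loop in batch_shard_strategy produces one option per set of device axes: the batch dim is sharded on the chosen axes and replicated everywhere else. A minimal standalone sketch of that enumeration, using plain strings instead of DTensor placements and assuming a 2-axis mesh with batch dim 0:

    import itertools

    # Mirrors the axis-combination loop in batch_shard_strategy, with strings
    # standing in for real Placement objects. Assumes a 2-axis mesh, batch dim 0.
    mesh_ndim = 2
    batch_dim = 0
    enable_multi_axis = True  # enable_shard_batch_dim_over_multiple_axis

    num_axis_options = range(1, mesh_ndim + 1) if enable_multi_axis else [1]
    for num_axis in num_axis_options:
        for comb in itertools.combinations(range(mesh_ndim), num_axis):
            placements = ["Replicate()"] * mesh_ndim
            for axis in comb:
                placements[axis] = f"Shard({batch_dim})"
            print(tuple(placements))

    # Prints:
    # ('Shard(0)', 'Replicate()')
    # ('Replicate()', 'Shard(0)')
    # ('Shard(0)', 'Shard(0)')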
tests/test_dtensor.py

Lines changed: 110 additions & 2 deletions
@@ -3,10 +3,12 @@
 # This source code is licensed under the BSD license found in the
 # LICENSE file in the root directory of this source tree.
 
+import functools
+
 import numpy as np
 import torch
 from torch.distributed.device_mesh import init_device_mesh
-from torch.distributed.tensor import DTensor, Shard, distribute_tensor
+from torch.distributed.tensor import DTensor, Replicate, Shard, distribute_tensor
 from torch.distributed.tensor._dtensor_spec import DTensorSpec, TensorMeta
 from torch.distributed.tensor._op_schema import (
     OpInfo,
@@ -15,6 +17,7 @@
     OpStrategy,
     OutputSharding,
     OutputSpecType,
+    RuntimeSchemaInfo,
     TupleStrategy,
 )
 from torch.testing._internal.common_utils import run_tests
@@ -23,7 +26,12 @@
     with_comms,
 )
 
-from autoparallel.dtensor_util import get_op_strategy, with_implicit_strategies
+from autoparallel.dtensor_util import (
+    batch_shard_strategy,
+    get_op_strategy,
+    op_strategy_context,
+    with_implicit_strategies,
+)
 
 propagator = DTensor._op_dispatcher.sharding_propagator
 
@@ -370,5 +378,105 @@ def test_implicit_registration(self):
         self._test_op_on_dtensor(test_op, input_x_dt, input_y_dt)
 
 
+class DimShardingTest(DTensorTestBase):
+    @with_comms
+    def test_simple_batch_sharding(self):
+        # both input tensors batch on dim 0
+        mesh = init_device_mesh(self.device_type, (2, self.world_size // 2))
+        test_op = torch.ops.mylib.numpy_sin.default
+
+        # 1. strategy that shards dim 0 onto a single device axis.
+        shard_first_dim_strategy = functools.partial(
+            batch_shard_strategy, input_shard_dim=[0, 0], output_shard_dim=[0]
+        )
+        with op_strategy_context(test_op, shard_first_dim_strategy):
+            # dim 0 is the batch dim. Here we shard 16 over one device axis
+            input_x = torch.randn([16, 1, 4], device=self.device_type)
+            input_y = torch.randn([16, 1, 4], device=self.device_type)
+            # any sharding below should work
+            input_x_dt = distribute_tensor(input_x, mesh, [Shard(1), Replicate()])
+            input_y_dt = distribute_tensor(input_y, mesh, [Replicate(), Shard(0)])
+            self._test_op_on_dtensor(test_op, input_x_dt, input_y_dt)
+
+        # 2. strategy that may shard dim 0 across multiple device axes.
+        shard_first_dim_to_multiple_devices_strategy = functools.partial(
+            batch_shard_strategy,
+            input_shard_dim=[0, 0],
+            output_shard_dim=[0],
+            enable_shard_batch_dim_over_multiple_axis=True,
+        )
+        with op_strategy_context(test_op, shard_first_dim_to_multiple_devices_strategy):
+            # dim 0 is the batch dim. Here we potentially shard 16 over multiple device axes
+            input_x = torch.randn([16, 1, 4], device=self.device_type)
+            input_y = torch.randn([16, 1, 4], device=self.device_type)
+            # any sharding below should work
+            input_x_dt = distribute_tensor(input_x, mesh, [Shard(1), Replicate()])
+            input_y_dt = distribute_tensor(input_y, mesh, [Replicate(), Shard(0)])
+            self._test_op_on_dtensor(test_op, input_x_dt, input_y_dt)
+
+    @with_comms
+    def test_broadcast_batch_sharding(self):
+        # Not recommended; users need to make sure the op supports broadcast
+        # inputs first. If not supported, unsqueeze the inputs to match each
+        # other's dimensions and use the pattern from test_simple_batch_sharding.
+        mesh = init_device_mesh(self.device_type, (2, self.world_size // 2))
+        test_op = torch.ops.mylib.numpy_sin.default
+        shard_on_first_dim_strategy = functools.partial(
+            batch_shard_strategy, input_shard_dim=[None, 0], output_shard_dim=[0]
+        )
+        with op_strategy_context(test_op, shard_on_first_dim_strategy):
+            input_x = torch.randn([1, 4], device=self.device_type)
+            # input_y's dim 0 (size 16) is the batch dim
+            input_y = torch.randn([16, 1, 4], device=self.device_type)
+            # any sharding below should work as long as the tensor dim is shardable
+            input_x_dt = distribute_tensor(input_x, mesh, [Shard(1), Replicate()])
+            input_y_dt = distribute_tensor(input_y, mesh, [Replicate(), Shard(0)])
+
+            output_dt = test_op(input_x_dt, input_y_dt)
+
+            # split the batch dim to test correctness
+            input_y_chunks = torch.chunk(input_y, 4, dim=0)
+            output = torch.cat(
+                [test_op(input_x, input_y_part) for input_y_part in input_y_chunks]
+            )
+            self.assertEqual(output_dt.full_tensor(), output)
+
+            # or we can test directly since the op supports broadcast.
+            self._test_op_on_dtensor(test_op, input_x_dt, input_y_dt)
+
+    @with_comms
+    def test_simple_tuple_batch_sharding(self):
+        # both input tensors batch on dim 0
+        mesh = init_device_mesh(self.device_type, (2, self.world_size // 2))
+        test_op = torch.ops.mylib.numpy_tuple_sin.default
+
+        # strategy that may shard dim 0 across one or more device axes.
+        shard_first_dim_to_multiple_devices_strategy = functools.partial(
+            batch_shard_strategy,
+            input_shard_dim=[0, 0, 0, 0, 0],  # flatten input_y
+            output_shard_dim=[0],
+            enable_shard_batch_dim_over_multiple_axis=True,
+        )
+        with op_strategy_context(
+            test_op,
+            shard_first_dim_to_multiple_devices_strategy,
+            schema_info=RuntimeSchemaInfo(needs_pytree=True),
+        ):
+            # dim 0 is the batch dim. Here we may shard 16 over one or more device axes
+            input_x = torch.randn([16, 8, 4], device=self.device_type)
+            input_y = [
+                torch.randn([16, 8, 4], device=self.device_type) for _ in range(3)
+            ]
+            input_z = torch.randn([16, 8, 4], device=self.device_type)
+            # any sharding below should work as long as the tensor dim is shardable
+            input_x_dt = distribute_tensor(input_x, mesh, [Shard(0), Shard(1)])
+            input_y_dt = [
+                distribute_tensor(i, mesh, [Shard(1), Shard(1)]) for i in input_y
+            ]
+            input_z_dt = distribute_tensor(input_z, mesh, [Shard(1), Shard(0)])
+            self._test_op_on_dtensor(test_op, input_x_dt, input_y_dt, input_z_dt)
+
+
 if __name__ == "__main__":
     run_tests()
