2 changes: 2 additions & 0 deletions autoparallel/dtensor_util/__init__.py
@@ -5,6 +5,7 @@

# functions to expose
from .utils import (
batch_shard_strategy,
get_op_strategy,
op_strategy_context,
replicate_op_strategy,
@@ -13,6 +14,7 @@

__all__ = [
"replicate_op_strategy",
"batch_shard_strategy",
"get_op_strategy",
"with_implicit_strategies",
"op_strategy_context",
165 changes: 159 additions & 6 deletions autoparallel/dtensor_util/utils.py
@@ -3,13 +3,32 @@
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

import itertools
import logging
from contextlib import ExitStack, contextmanager
from typing import Optional

import torch
from torch.distributed.tensor import DTensor
from torch.distributed.tensor._dtensor_spec import DTensorSpec
from torch.distributed.tensor._op_schema import (
OpSchema,
OpSpec,
OpStrategy,
StrategyType,
)
from torch.distributed.tensor._ops.utils import (
generate_redistribute_costs,
is_tensor_shardable,
register_op_strategy,
)
from torch.distributed.tensor.placement_types import Placement, Replicate, Shard

try:
from torch.utils._cxx_pytree import tree_leaves
except ImportError:
from torch.utils._pytree import tree_leaves # type: ignore[no-redef]


logger = logging.getLogger(__name__)

@@ -39,19 +58,153 @@ def op_strategy_context(op_overload, strategy_func, schema_info=None):
None
"""
propagator = DTensor._op_dispatcher.sharding_propagator
_origin_op_strategy_funcs = None
_origin_op_strategy_schema = None
try:
# save any existing strategy registration for this op, then register the new strategy
if op_overload in propagator.op_strategy_funcs:
_origin_op_strategy_funcs = propagator.op_strategy_funcs[op_overload]
del propagator.op_strategy_funcs[op_overload]
if op_overload in propagator.op_to_schema_info:
_origin_op_strategy_schema = propagator.op_to_schema_info[op_overload]
del propagator.op_to_schema_info[op_overload]
register_op_strategy(op_overload, schema_info=schema_info)(strategy_func)
yield
finally:
# restore the original op strategy if one existed; otherwise clear the temporary registration
if _origin_op_strategy_funcs is None:
if op_overload in propagator.op_strategy_funcs:
del propagator.op_strategy_funcs[op_overload]
else:
propagator.op_strategy_funcs[op_overload] = _origin_op_strategy_funcs
if _origin_op_strategy_schema is None:
if op_overload in propagator.op_to_schema_info:
del propagator.op_to_schema_info[op_overload]
else:
propagator.op_to_schema_info[op_overload] = _origin_op_strategy_schema
propagator.propagate_op_sharding.cache.cache_clear()
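
# A minimal usage sketch of the context manager above ("some_op" and
# "my_strategy" are placeholder names, not defined in this file). Any strategy
# previously registered for the op is restored when the context exits:
#
#     with op_strategy_context(some_op, my_strategy):
#         ...  # sharding propagation for some_op uses my_strategy here
#     # afterwards, some_op's previous strategy/schema registration (if any)
#     # is back in place and the propagation cache has been cleared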


# -------------define universal op strategy-------------
def batch_shard_strategy(
op_schema: OpSchema,

Contributor: Is there a reason this takes an OpSchema and not an OpStrategy?

Contributor Author: The type definition of OpSchema is very unclear in DTensor. Here the OpSchema is a collection of OpStrategy objects for all the input args, while for the output we usually have just one tensor, so an OpStrategy is sufficient.

input_shard_dim: list[Optional[int]],
output_shard_dim: list[Optional[int]],

Contributor: The terminology is confusing here. Is the shard dim the batch dim, or is it something else? If it is the batch dim, the comments below seem to imply there is only one batch dim, so why is this a list?

Contributor Author: It's the batch dim. There can be multiple input tensors; each element in input_shard_dim maps to one input tensor. The same applies to the output tensor(s).

enable_shard_batch_dim_over_multiple_axis: bool = False,
) -> OpStrategy:
"""
Shard the input tensors over the specified batch dimensions. The strategy maps
the batch dim of the input/output tensors to the same device mesh axis (or the
same set of device axes). Every input must either have one specified batch dim
or no batch dim. If an input has no batch dim, the strategy assumes the tensor
is broadcast along the batch dim when processed by the operator. For inputs
that specify a batch dim, the user needs to make sure the batch dim sizes are
all the same. The output should always have a batch dim.

Args:
op_schema (OpSchema): the op schema.

input_shard_dim (list[Optional[int]]): the list of batch (shard) dimensions,
one per input tensor argument. Use `None` if the input arg has no batch dim.
If an arg is a List[Tensor], we flatten it first and then match it against
input_shard_dim. Since the batch dim is not tied to a specific device mesh
axis, it can be mapped to any combination of device axes. Example 1: input
tensors A[1024,64,8] and B[1024,64,16] with input_shard_dim = [0,0]: the
strategy can shard A's dim 0 over device axis X and B's dim 0 over the same
device axis X, where X can be any device axis. The output follows the same
sharding as the input. Example 2: input tensors A[64,8], B[64,16,1024],
C[64,8] with input_shard_dim = [None,2,None]: the strategy replicates A and C
over all device mesh axes and only shards B's dim 2 over the device mesh.
Assuming the device mesh has 3 axes, tensor B's placement can be (Shard(2),
Shard(2), Replicate()), (Shard(2), Replicate(), Shard(2)), or (Replicate(),
Shard(2), Shard(2)).

output_shard_dim (list[Optional[int]]): the list of batch (shard) dimensions,
one per output tensor. Use `None` if the output has no batch dim. For example,
if the output is a single tensor sharded on dim 0, pass in [0].

enable_shard_batch_dim_over_multiple_axis (bool): if True, the strategy will
also try to map the batch dim to multiple device axes. Default is False.

Note: It is the user's responsibility to make sure the sharded tensors passed
to the operator have the correct shapes.
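
Example (illustrative sketch; ``my_op`` stands in for a real op overload and
is not defined in this module)::

    import functools

    # shard the batch dim (dim 0) of both inputs and of the single output
    strategy = functools.partial(
        batch_shard_strategy, input_shard_dim=[0, 0], output_shard_dim=[0]
    )
    with op_strategy_context(my_op, strategy):
        out_dt = my_op(x_dt, y_dt)  # both DTensor inputs batched on dim 0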
"""
output_type = [str(ret.type) for ret in op_schema.op._schema.returns]

Contributor: Running str on the type, very suspicious!

Contributor Author (@zpcore, Jul 29, 2025): What is the recommended way to check the tensor output type?

# TODO(zpcore): Confirm whether view ops can be handled properly. Prevent
# handling view ops until confirmed.
if op_schema.op.is_view:
raise RuntimeError(
"fallback strategy is unable to handle view ops until confirmed"
)
if "List[Tensor]" in output_type:
raise RuntimeError(
"fallback strategy is unable to handle ops with List[Tensor] output "
"because size of the list may depend on the op's input value"
)
inputs_strategy = tree_leaves(op_schema.args_strategy)
assert len(inputs_strategy) == len(input_shard_dim)
output_strategy = OpStrategy([])
mesh = inputs_strategy[0].mesh
device_axis = list(range(mesh.ndim))
use_how_many_axis = (
[i + 1 for i in range(mesh.ndim)]
if enable_shard_batch_dim_over_multiple_axis
else [1]
)
# number of device axes to shard on for the batch dim
for num_axis in use_how_many_axis:
device_combinations = list(itertools.combinations(device_axis, num_axis))
# e.g., if num_axis == 2, device_combinations = [(0,1), (0,2), (1,2),
# ...]. Then One feasible strategy is to shard tensor dim on both axis
# (0,1). We check all combinations in device_combinations below.
for comb in device_combinations:
input_specs_list: list[DTensorSpec] = []
output_specs_list: list[DTensorSpec] = []
is_shardable = True
for op_strategy, dim in zip(inputs_strategy, input_shard_dim):
# build the placements for this input: shard its batch dim on the chosen axes
new_placements: list[Placement] = [Replicate()] * mesh.ndim
for axis in comb:
# dim may be 0, so compare against None explicitly
new_placements[axis] = Shard(dim) if dim is not None else Replicate()
tensor_meta = op_strategy.strategies[0].output_spec.tensor_meta
new_input_spec = DTensorSpec(
mesh,
tuple(new_placements),
tensor_meta=tensor_meta,
)
if not is_tensor_shardable(tensor_meta.shape, new_input_spec):
is_shardable = False
break
input_specs_list.append(new_input_spec)
if not is_shardable:
continue
for dim in output_shard_dim:
new_placements = [Replicate()] * mesh.ndim
for axis in comb:
new_placements[axis] = Shard(dim) if dim is not None else Replicate()
output_spec = DTensorSpec(
mesh,
tuple(new_placements),
)
output_specs_list.append(output_spec)

output_specs = (
output_specs_list[0]
if len(output_specs_list) == 1
else tuple(output_specs_list)
)
input_specs = input_specs_list
redistribute_cost = [
generate_redistribute_costs(strat, input_specs_list[i])
for i, strat in enumerate(inputs_strategy)
]
output_strategy.strategies.append(
OpSpec(output_specs, input_specs, redistribute_cost) # type: ignore
)
return output_strategy
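
# Illustrative note (assuming a 2-D device mesh, input_shard_dim=[0, 0],
# output_shard_dim=[0], and enable_shard_batch_dim_over_multiple_axis=False):
# the loop above yields one OpSpec per single mesh axis, i.e. inputs/output
# placed as (Shard(0), Replicate()) for axis 0 and (Replicate(), Shard(0)) for
# axis 1. With enable_shard_batch_dim_over_multiple_axis=True it additionally
# emits the two-axis combination (Shard(0), Shard(0)).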


def get_op_strategy(op: torch._ops.OpOverload, op_schema: OpSchema) -> StrategyType:
global enable_implicit_replication, _current_stack

112 changes: 110 additions & 2 deletions tests/test_dtensor.py
@@ -3,10 +3,12 @@
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

import functools

import numpy as np
import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import DTensor, Replicate, Shard, distribute_tensor
from torch.distributed.tensor._dtensor_spec import DTensorSpec, TensorMeta
from torch.distributed.tensor._op_schema import (
OpInfo,
@@ -15,6 +17,7 @@
OpStrategy,
OutputSharding,
OutputSpecType,
RuntimeSchemaInfo,
TupleStrategy,
)
from torch.testing._internal.common_utils import run_tests
@@ -23,7 +26,12 @@
with_comms,
)

from autoparallel.dtensor_util import (
batch_shard_strategy,
get_op_strategy,
op_strategy_context,
with_implicit_strategies,
)

propagator = DTensor._op_dispatcher.sharding_propagator

@@ -370,5 +378,105 @@ def test_implicit_registration(self):
self._test_op_on_dtensor(test_op, input_x_dt, input_y_dt)


class DimShardingTest(DTensorTestBase):
@with_comms
def test_simple_batch_sharding(self):
# both input tensors are batched on dim 0
mesh = init_device_mesh(self.device_type, (2, self.world_size // 2))
test_op = torch.ops.mylib.numpy_sin.default

# 1. strategy that will try to shard dim 0 onto one device axis.
shard_first_dim_strategy = functools.partial(
batch_shard_strategy, input_shard_dim=[0, 0], output_shard_dim=[0]
)
with op_strategy_context(test_op, shard_first_dim_strategy):
# dim 0 is the batch dim. Here we shard 16 over one device axis
input_x = torch.randn([16, 1, 4], device=self.device_type)
input_y = torch.randn([16, 1, 4], device=self.device_type)
# any sharding below should work
input_x_dt = distribute_tensor(input_x, mesh, [Shard(1), Replicate()])
input_y_dt = distribute_tensor(input_y, mesh, [Replicate(), Shard(0)])
self._test_op_on_dtensor(test_op, input_x_dt, input_y_dt)

# 2. strategy that will try to shard dim 0 across multiple device axes.
shard_first_dim_to_multiple_devices_strategy = functools.partial(
batch_shard_strategy,
input_shard_dim=[0, 0],
output_shard_dim=[0],
enable_shard_batch_dim_over_multiple_axis=True,
)
with op_strategy_context(test_op, shard_first_dim_to_multiple_devices_strategy):
# dim 0 is the batch dim. Here we potentially shard 16 over multiple device axes
input_x = torch.randn([16, 1, 4], device=self.device_type)
input_y = torch.randn([16, 1, 4], device=self.device_type)
# any sharding below should work
input_x_dt = distribute_tensor(input_x, mesh, [Shard(1), Replicate()])
input_y_dt = distribute_tensor(input_y, mesh, [Replicate(), Shard(0)])
self._test_op_on_dtensor(test_op, input_x_dt, input_y_dt)

@with_comms
def test_broadcast_batch_sharding(self):
# Not recommended: the user needs to make sure the op supports broadcast
# inputs first. If not supported, unsqueeze the inputs first to match each
# other's dimensions and use the example in the test_simple_batch_sharding
# test.
mesh = init_device_mesh(self.device_type, (2, self.world_size // 2))
test_op = torch.ops.mylib.numpy_sin.default
shard_on_first_dim_strategy = functools.partial(
batch_shard_strategy, input_shard_dim=[None, 0], output_shard_dim=[0]
)
with op_strategy_context(test_op, shard_on_first_dim_strategy):
input_x = torch.randn([1, 4], device=self.device_type)
# input_y's size-16 dim is the batch dim
input_y = torch.randn([16, 1, 4], device=self.device_type)
# any sharding below should work as long as the tensor dim is shardable
input_x_dt = distribute_tensor(input_x, mesh, [Shard(1), Replicate()])
input_y_dt = distribute_tensor(input_y, mesh, [Replicate(), Shard(0)])

output_dt = test_op(input_x_dt, input_y_dt)

# split the batch dim to test correctness
input_y_chunks = torch.chunk(input_y, 4, dim=0)
output = torch.cat(
[test_op(input_x, input_y_part) for input_y_part in input_y_chunks]
)
self.assertEqual(output_dt.full_tensor(), output)

# or we can test directly since the op supports broadcast.
self._test_op_on_dtensor(test_op, input_x_dt, input_y_dt)

@with_comms
def test_simple_tuple_batch_sharding(self):
# all input tensors are batched on dim 0
mesh = init_device_mesh(self.device_type, (2, self.world_size // 2))
test_op = torch.ops.mylib.numpy_tuple_sin.default

# strategy that will try to shard dim 0, potentially over multiple device axes.
shard_first_dim_to_multiple_devices_strategy = functools.partial(
batch_shard_strategy,
input_shard_dim=[0, 0, 0, 0, 0],  # input_x, the 3 tensors of flattened input_y, input_z
output_shard_dim=[0],
enable_shard_batch_dim_over_multiple_axis=True,
)
with op_strategy_context(
test_op,
shard_first_dim_to_multiple_devices_strategy,
schema_info=RuntimeSchemaInfo(needs_pytree=True),
):
# dim 0 is the batch dim. Here we shard 16 over one or more device axes
input_x = torch.randn([16, 8, 4], device=self.device_type)
input_y = [
torch.randn([16, 8, 4], device=self.device_type) for _ in range(3)
]
input_z = torch.randn([16, 8, 4], device=self.device_type)
# any sharding below should work as long as the tensor dim is shardable
input_x_dt = distribute_tensor(input_x, mesh, [Shard(0), Shard(1)])
input_y_dt = [
distribute_tensor(i, mesh, [Shard(1), Shard(1)]) for i in input_y
]
input_z_dt = distribute_tensor(input_z, mesh, [Shard(1), Shard(0)])
self._test_op_on_dtensor(test_op, input_x_dt, input_y_dt, input_z_dt)


if __name__ == "__main__":
run_tests()