Skip to content

Commit 65250eb

Browse files
committed
Update base for Update on "introduce batch sharding strategy"
(Split out the large PR from #46) Introduce the batch sharding strategy: ```python from torch.distributed.tensor._op_schema import RuntimeSchemaInfo from autoparallel.dtensor_util.utils import batch_shard_strategy from autoparallel.dtensor_util import strategy_pool # create strategy with input tensor 1 replicated, input tensor 2 sharded on dim 0. Output tensor sharded on dim 0: custom_shard_strategy = functools.partial(batch_shard_strategy, input_shard_dim=[None, 0], output_shard_dim=[0]) # register the strategy: strategy_pool.register_op_strategy(new_op)(custom_shard_strategy) ``` For details, check the function description in autoparallel/dtensor_util/utils.py and the example usage in tests/test_dtensor.py. [ghstack-poisoned]
1 parent 88b5db9 commit 65250eb

File tree

5 files changed

+89
-120
lines changed

5 files changed

+89
-120
lines changed

autoparallel/dtensor_util/__init__.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,17 @@
33
# This source code is licensed under the BSD license found in the
44
# LICENSE file in the root directory of this source tree.
55

6+
# functions to expose
7+
from .utils import (
8+
get_op_strategy,
9+
op_strategy_context,
10+
replicate_op_strategy,
11+
with_implicit_strategies,
12+
)
613

7-
from . import utils
8-
9-
strategy_pool = utils.StrategyPool()
14+
__all__ = [
15+
"replicate_op_strategy",
16+
"get_op_strategy",
17+
"with_implicit_strategies",
18+
"op_strategy_context",
19+
]

autoparallel/dtensor_util/utils.py

Lines changed: 59 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -3,23 +3,25 @@
33
# This source code is licensed under the BSD license found in the
44
# LICENSE file in the root directory of this source tree.
55

6-
76
import logging
87
from contextlib import ExitStack, contextmanager
9-
from typing import Callable, TypeVar
108

119
import torch
1210
from torch.distributed.tensor import DTensor
13-
from torch.distributed.tensor._op_schema import OpSchema, OutputSharding, StrategyType
11+
from torch.distributed.tensor._op_schema import OpSchema, StrategyType
1412
from torch.distributed.tensor._ops.utils import register_op_strategy
15-
from typing_extensions import ParamSpec
1613

1714
logger = logging.getLogger(__name__)
1815

1916
aten = torch.ops.aten
2017

21-
_T = TypeVar("_T")
22-
_P = ParamSpec("_P")
18+
# reference to existing sharding_propagator DTensor upstream
19+
propagator = DTensor._op_dispatcher.sharding_propagator
20+
21+
enable_implicit_replication = False
22+
_current_stack = None
23+
24+
replicate_op_strategy = torch.distributed.tensor._ops.utils.replicate_op_strategy
2325

2426

2527
# TODO: remove and refer to
@@ -50,76 +52,57 @@ def op_strategy_context(op_overload, strategy_func, schema_info=None):
5052
propagator.propagate_op_sharding.cache.cache_clear()
5153

5254

53-
# -------------define universal op strategy-------------
54-
replicate_op_strategy = torch.distributed.tensor._ops.utils.replicate_op_strategy
55-
55+
def get_op_strategy(op: torch._ops.OpOverload, op_schema: OpSchema) -> StrategyType:
56+
global enable_implicit_replication, _current_stack
5657

57-
class StrategyPool:
58-
def __init__(self) -> None:
59-
# reference to existing strategy from the DTensor upstream
60-
self.op_strategy_funcs: dict[
61-
torch._ops.OpOverload, Callable[[OpSchema], StrategyType]
62-
] = DTensor._op_dispatcher.sharding_propagator.op_strategy_funcs
63-
# reference to existing rules
64-
self.op_to_rules: dict[
65-
torch._ops.OpOverload, Callable[[OpSchema], OutputSharding]
66-
] = DTensor._op_dispatcher.sharding_propagator.op_to_rules
67-
# we probably don't need to care about existing op_to_schema_info for AP
68-
self.op_to_schema_info = (
69-
DTensor._op_dispatcher.sharding_propagator.op_to_schema_info
70-
)
71-
self.enable_implicit_replication: bool = False
72-
self._current_stack = None
73-
74-
def get_op_strategy(
75-
self, op: torch._ops.OpOverload, op_schema: OpSchema
76-
) -> StrategyType:
77-
if op not in self.op_strategy_funcs:
78-
if not self.enable_implicit_replication:
79-
raise NotImplementedError(
80-
f"Operator {op} does not have a sharding strategy registered."
58+
if op not in propagator.op_strategy_funcs:
59+
if not enable_implicit_replication:
60+
raise NotImplementedError(
61+
f"Operator {op} does not have a sharding strategy registered."
62+
)
63+
else:
64+
# Use the current stack if available
65+
if _current_stack is not None:
66+
_current_stack.enter_context(
67+
op_strategy_context(op, replicate_op_strategy)
8168
)
8269
else:
83-
# Use the instance's current stack if available
84-
if self._current_stack is not None:
85-
self._current_stack.enter_context(
86-
op_strategy_context(op, replicate_op_strategy)
87-
)
88-
else:
89-
# No stack available, just register permanently
90-
register_op_strategy(op)(replicate_op_strategy)
91-
logger.warning(
92-
f"implicitly registering `{op}` with `{replicate_op_strategy.__name__}`"
93-
)
94-
return self.op_strategy_funcs[op](op_schema)
95-
96-
@contextmanager
97-
def with_implicit_strategies(self):
98-
"""Context manager to enable implicit replication and clean up strategies."""
99-
# Create a fresh ExitStack for this context
100-
with ExitStack() as local_stack:
101-
# Store the stack as an instance attribute
102-
old_stack = self._current_stack
103-
self._current_stack = local_stack
104-
105-
# Enable implicit replication
106-
old_value = self.enable_implicit_replication
107-
self.enable_implicit_replication = True
108-
try:
109-
yield
110-
finally:
111-
# Restore the original values
112-
self._current_stack = old_stack
113-
self.enable_implicit_replication = old_value
114-
115-
# TODO: automatically generate redistribute cost for strategies. There exists a
116-
# `fill_missing_redistribute_cost` in autoparallel/utils.py, which is a hack
117-
# to generate redistribute cost given input specs, and only tested on
118-
# certain ops. We can potentially make an improvement.
119-
def fill_missing_redistribute_cost(
120-
self, op: torch._ops.OpOverload, op_schema: OpSchema
121-
):
122-
"""
123-
Fill missing redistribute cost for strategies.
124-
"""
125-
...
70+
# No stack available, just register permanently
71+
register_op_strategy(op)(replicate_op_strategy)
72+
logger.warning(
73+
f"implicitly registering `{op}` with `{replicate_op_strategy.__name__}`"
74+
)
75+
return propagator.op_strategy_funcs[op](op_schema)
76+
77+
78+
@contextmanager
79+
def with_implicit_strategies():
80+
"""Context manager to enable implicit replication and clean up strategies."""
81+
global enable_implicit_replication, _current_stack
82+
83+
# Create a fresh ExitStack for this context
84+
with ExitStack() as local_stack:
85+
# Store the stack as a global variable
86+
old_stack = _current_stack
87+
_current_stack = local_stack
88+
89+
# Enable implicit replication
90+
old_value = enable_implicit_replication
91+
enable_implicit_replication = True
92+
try:
93+
yield
94+
finally:
95+
# Restore the original values
96+
_current_stack = old_stack
97+
enable_implicit_replication = old_value
98+
99+
100+
# TODO: automatically generate redistribute cost for strategies. There exists a
101+
# `fill_missing_redistribute_cost` in autoparallel/utils.py, which is a hack
102+
# to generate redistribute cost given input specs, and only tested on
103+
# certain ops. We can potentially make an improvement.
104+
def fill_missing_redistribute_cost(op: torch._ops.OpOverload, op_schema: OpSchema):
105+
"""
106+
Fill missing redistribute cost for strategies.
107+
"""
108+
...

autoparallel/propagation_rules.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@
3636
from torch.distributed.tensor._ops.utils import generate_redistribute_costs
3737
from torch.distributed.tensor.placement_types import Replicate, Shard
3838

39-
from .dtensor_util import strategy_pool
39+
from .dtensor_util import get_op_strategy
4040

4141
# TODO: move this to PyTorch
4242
dim_maps[torch.t] = lambda input: dim_transpose(input.ndim, -2, -1)
@@ -584,7 +584,7 @@ def index_rule(mesh, op_schema):
584584
@register_opschema_rule(torch.ops.aten._scaled_dot_product_efficient_attention.default)
585585
def sdpa_rule(mesh, op_schema):
586586
op = torch.ops.aten._scaled_dot_product_efficient_attention.default
587-
out_strat = strategy_pool.get_op_strategy(op, op_schema)
587+
out_strat = get_op_strategy(op, op_schema)
588588
# remove wrong context-parallel strategy
589589
# https://github.com/pytorch/pytorch/pull/131351#discussion_r1716164659
590590
new_strats = []
@@ -611,7 +611,7 @@ def sdpa_rule(mesh, op_schema):
611611
@register_opschema_rule(torch.ops.aten.reshape.default)
612612
def reshape_rule(mesh, op_schema):
613613
op = torch.ops.aten.reshape.default
614-
out_strat = strategy_pool.get_op_strategy(op, op_schema)
614+
out_strat = get_op_strategy(op, op_schema)
615615
if mesh.ndim == 1:
616616
# remove duplicate strategy
617617
# TODO: hack, fixme
@@ -637,7 +637,7 @@ def expand_rule(mesh, op_schema_):
637637
]
638638
if len(expand_dim) != 1:
639639
assert len(expand_dim) == 0
640-
return strategy_pool.get_op_strategy(op, op_schema)
640+
return get_op_strategy(op, op_schema)
641641
assert len(expand_dim) == 1, f"{expand_dim}"
642642
expand_dim = expand_dim[0]
643643
to_remove = []
@@ -651,7 +651,7 @@ def expand_rule(mesh, op_schema_):
651651
removed = []
652652
for i in reversed(to_remove):
653653
removed.append(input_strat.strategies.pop(i))
654-
out_strat = strategy_pool.get_op_strategy(op, op_schema)
654+
out_strat = get_op_strategy(op, op_schema)
655655
for i, ss in enumerate(out_strat.strategies):
656656
for remov in to_remove:
657657
ss.redistribute_cost[0].insert(remov, math.inf)

autoparallel/utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from torch.distributed.tensor._ops.utils import generate_redistribute_costs
1111
from torch.utils._pytree import tree_flatten, tree_map_only
1212

13-
from .dtensor_util import strategy_pool
13+
from .dtensor_util import get_op_strategy
1414
from .propagation_rules import _op_partial_rules, _op_rules, remove_invalid_configs
1515

1616

@@ -111,7 +111,7 @@ def get_placement_options(mesh, op, specs, user_args, user_kwargs):
111111
if op in _op_partial_rules:
112112
out_strat = _op_partial_rules[op](mesh, op_schema)
113113
else:
114-
out_strat = strategy_pool.get_op_strategy(op, op_schema)
114+
out_strat = get_op_strategy(op, op_schema)
115115

116116
propagate_tensor_meta(op, user_args, user_kwargs, out_strat)
117117
fill_missing_redistribute_cost(op, specs, out_strat)

tests/test_dtensor.py

Lines changed: 10 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,10 @@
33
# This source code is licensed under the BSD license found in the
44
# LICENSE file in the root directory of this source tree.
55

6-
from contextlib import contextmanager
7-
86
import numpy as np
97
import torch
10-
from torch.distributed.tensor import DTensor, Shard, distribute_tensor, init_device_mesh
8+
from torch.distributed.device_mesh import init_device_mesh
9+
from torch.distributed.tensor import DTensor, Shard, distribute_tensor
1110
from torch.distributed.tensor._dtensor_spec import DTensorSpec, TensorMeta
1211
from torch.distributed.tensor._op_schema import (
1312
OpInfo,
@@ -24,7 +23,9 @@
2423
with_comms,
2524
)
2625

27-
from autoparallel.dtensor_util import strategy_pool
26+
from autoparallel.dtensor_util import get_op_strategy, with_implicit_strategies
27+
28+
propagator = DTensor._op_dispatcher.sharding_propagator
2829

2930
aten = torch.ops.aten
3031

@@ -94,31 +95,6 @@ def _fw_tuple(x, y, z):
9495
)
9596

9697

97-
@contextmanager
98-
def op_strategy_context(op_overload, strategy_func, schema_info=None):
99-
"""
100-
Context manager for setting and clearing op strategies in unit tests.
101-
Args:
102-
op_overload: The operator overload to set or clear the strategy for.
103-
strategy_func: The strategy function to set for the operator overload.
104-
schema_info: Optional schema information for the operator overload.
105-
Yields:
106-
None
107-
"""
108-
try:
109-
# register the op strategy
110-
strategy_pool.register_op_strategy(op_overload, schema_info=schema_info)(
111-
strategy_func
112-
)
113-
yield
114-
finally:
115-
# clear this op strategy cache
116-
if op_overload in strategy_pool.op_strategy_funcs:
117-
del strategy_pool.op_strategy_funcs[op_overload]
118-
if op_overload in strategy_pool.op_to_schema_info:
119-
del strategy_pool.op_to_schema_info[op_overload]
120-
121-
12298
# Overwrite upstream `_op_dispatcher.sharding_propagator` with customized
12399
# sharding_propagator. This is for testing purpose under eager mode and
124100
# AutoParallel won't use the propagate function. The main changes are 1) Skip
@@ -132,9 +108,9 @@ class CustomShardingPropagator(
132108
def __init__(self):
133109
super().__init__()
134110
self.propagate_op_sharding.cache.cache_clear()
135-
self.op_to_rules = strategy_pool.op_to_rules
136-
self.op_strategy_funcs = strategy_pool.op_strategy_funcs
137-
self.op_to_schema_info = strategy_pool.op_to_schema_info
111+
self.op_to_rules = propagator.op_to_rules
112+
self.op_strategy_funcs = propagator.op_strategy_funcs
113+
self.op_to_schema_info = propagator.op_to_schema_info
138114

139115
def propagate(self, op_info: OpInfo) -> None:
140116
op_info.output_sharding = self.propagate_op_sharding_non_cached(op_info.schema)
@@ -199,7 +175,7 @@ def propagate_op_sharding_non_cached(self, op_schema: OpSchema) -> OutputShardin
199175
strategy_schema.schema_info = op_schema.schema_info
200176

201177
# assign implicit strategy if enabled
202-
strategy_pool.get_op_strategy(strategy_schema.op, strategy_schema)
178+
get_op_strategy(strategy_schema.op, strategy_schema)
203179

204180
# run sharding strategy propagation/generation
205181
op_strategy = self.op_strategy_funcs[op_schema.op](strategy_schema)
@@ -383,7 +359,7 @@ def test_implicit_registration(self):
383359
self._test_op_on_dtensor(test_op, input_x_dt, input_y_dt)
384360

385361
# 2. test_op strategy implicitly registered under context manager
386-
with strategy_pool.with_implicit_strategies():
362+
with with_implicit_strategies():
387363
self._test_op_on_dtensor(test_op, input_x_dt, input_y_dt)
388364

389365
# 3. remove registration after exiting the context manager

0 commit comments

Comments
 (0)