Commit 2c4cd82

Update on "Support of implicit fallback"
(Split out the large PR #46) Support the implicit replication fallback strategy.

How to use the implicit replication fallback:

```python
from autoparallel.dtensor_util import strategy_pool

with strategy_pool.replicate_for_unsupported_operators():
    ...  # missing ops will use the replicated strategy if possible
```

Note: StrategyPool now reuses `_op_dispatcher.sharding_propagator.op_strategy_funcs` / `op_to_rules` / `op_to_schema_info` by reference.

[ghstack-poisoned]
1 parent 78ed7e0 commit 2c4cd82
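
Note for reviewers: the diff below renames this context manager to `with_implicit_strategies()` (the test is updated accordingly). A minimal usage sketch of the renamed API, assuming the same `strategy_pool` singleton import as in the example above:

```python
from autoparallel.dtensor_util import strategy_pool

# Inside the block, ops without a registered sharding strategy fall back to
# the replicated strategy; the temporary registrations are cleaned up on exit.
with strategy_pool.with_implicit_strategies():
    ...  # run DTensor ops that may hit unsupported operators
```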

2 files changed: +58 −52 lines changed

autoparallel/dtensor_util/utils.py

Lines changed: 57 additions & 51 deletions
```diff
@@ -5,17 +5,13 @@
 
 
 import logging
-from contextlib import contextmanager
+from contextlib import ExitStack, contextmanager
 from typing import Callable, TypeVar
 
 import torch
 from torch.distributed.tensor import DTensor
-from torch.distributed.tensor._op_schema import (
-    OpSchema,
-    OutputSharding,
-    RuntimeSchemaInfo,
-    StrategyType,
-)
+from torch.distributed.tensor._op_schema import OpSchema, OutputSharding, StrategyType
+from torch.distributed.tensor._ops.utils import register_op_strategy
 from typing_extensions import ParamSpec
 
 logger = logging.getLogger(__name__)
@@ -26,6 +22,34 @@
 _P = ParamSpec("_P")
 
 
+# TODO: remove and refer to
+# https://github.com/pytorch/pytorch/blob/9c107606629de6383f55e3b48b42e594d23407b1/test/distributed/tensor/test_op_strategy.py#L446
+# once the function is moved outside of the test folder in upstream
+@contextmanager
+def op_strategy_context(op_overload, strategy_func, schema_info=None):
+    """
+    Context manager for setting and clearing op strategies.
+    Args:
+        op_overload: The operator overload to set or clear the strategy for.
+        strategy_func: The strategy function to set for the operator overload.
+        schema_info: Optional schema information for the operator overload.
+    Yields:
+        None
+    """
+    propagator = DTensor._op_dispatcher.sharding_propagator
+    try:
+        # register the op strategy
+        register_op_strategy(op_overload, schema_info=schema_info)(strategy_func)
+        yield
+    finally:
+        # clear this op strategy cache
+        if op_overload in propagator.op_strategy_funcs:
+            del propagator.op_strategy_funcs[op_overload]
+        if op_overload in propagator.op_to_schema_info:
+            del propagator.op_to_schema_info[op_overload]
+        propagator.propagate_op_sharding.cache.cache_clear()
+
+
 # -------------define universal op strategy-------------
 replicate_op_strategy = torch.distributed.tensor._ops.utils.replicate_op_strategy
 
@@ -44,9 +68,8 @@ def __init__(self) -> None:
         self.op_to_schema_info = (
             DTensor._op_dispatcher.sharding_propagator.op_to_schema_info
         )
-
         self.enable_implicit_replication: bool = False
-        self.implicit_strategy_op_tracker: list[torch._ops.OpOverload] = []
+        self._current_stack = None
 
     def get_op_strategy(
         self, op: torch._ops.OpOverload, op_schema: OpSchema
@@ -57,54 +80,37 @@ def get_op_strategy(
                     f"Operator {op} does not have a sharding strategy registered."
                 )
             else:
-                self.implicit_strategy_op_tracker.append(op)
+                # Use the instance's current stack if available
+                if self._current_stack is not None:
+                    self._current_stack.enter_context(
+                        op_strategy_context(op, replicate_op_strategy)
+                    )
+                else:
+                    # No stack available, just register permanently
+                    register_op_strategy(op)(replicate_op_strategy)
                 logger.warning(
-                    f"implicitly register sharding strategy op {op.name()} using {replicate_op_strategy.__name__}"
+                    f"implicitly registering `{op}` with `{replicate_op_strategy.__name__}`"
                 )
-                self.register_op_strategy(op)(replicate_op_strategy)
         return self.op_strategy_funcs[op](op_schema)
 
-    def register_op_strategy(
-        self,
-        op: torch._ops.OpOverload,
-        schema_info=RuntimeSchemaInfo(needs_pytree=True),
-    ) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]:
-        # pyre-fixme[2]: Parameter must be annotated.
-        # always enable pytree as dispatching overhead is not a concern in AP.
-        def wrapper(impl):
-            if isinstance(op, list):
-                overloads = op
-            else:
-                overloads = [op]
-
-            for overload in overloads:
-                self.op_strategy_funcs[overload] = impl
-                self.op_to_schema_info[overload] = schema_info
-            return impl
-
-        return wrapper
-
     @contextmanager
-    def replicate_for_unsupported_operators(self):
-        """
-        Context manager for setting and clearing implicit strategy.
-        """
-        try:
-            if self.enable_implicit_replication:
-                raise RuntimeError(
-                    "Implicit strategy is already enabled. Cannot enable it again."
-                )
+    def with_implicit_strategies(self):
+        """Context manager to enable implicit replication and clean up strategies."""
+        # Create a fresh ExitStack for this context
+        with ExitStack() as local_stack:
+            # Store the stack as an instance attribute
+            old_stack = self._current_stack
+            self._current_stack = local_stack
+
+            # Enable implicit replication
+            old_value = self.enable_implicit_replication
             self.enable_implicit_replication = True
-            yield
-        finally:
-            self.enable_implicit_replication = False
-            op_to_remove = self.implicit_strategy_op_tracker
-            for op_overload in op_to_remove:
-                if op_overload in self.op_strategy_funcs:
-                    del self.op_strategy_funcs[op_overload]
-                if op_overload in self.op_to_schema_info:
-                    del self.op_to_schema_info[op_overload]
-            self.implicit_strategy_op_tracker.clear()
+            try:
+                yield
+            finally:
+                # Restore the original values
+                self._current_stack = old_stack
+                self.enable_implicit_replication = old_value
 
 # TODO: automatic generate redistribute cost for strategies. There exists a
 # `fill_missing_redistribute_cost` in autoparallel/utils.py, which is a hack
```
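
For context, a minimal sketch of how the new `op_strategy_context` helper behaves on its own; the operator chosen here (`aten.nonzero.default`) and the import path are illustrative assumptions, not part of this PR:

```python
import torch
from torch.distributed.tensor import DTensor

from autoparallel.dtensor_util.utils import op_strategy_context, replicate_op_strategy

# Hypothetical operator for illustration; any OpOverload without a strategy
# would be handled the same way.
op = torch.ops.aten.nonzero.default
propagator = DTensor._op_dispatcher.sharding_propagator

with op_strategy_context(op, replicate_op_strategy):
    # The strategy is registered for the duration of the block...
    assert op in propagator.op_strategy_funcs
# ...and removed again (with the sharding-propagation cache cleared) on exit.
assert op not in propagator.op_strategy_funcs
```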

tests/test_dtensor.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -383,7 +383,7 @@ def test_implicit_registration(self):
         self._test_op_on_dtensor(test_op, input_x_dt, input_y_dt)
 
         # 2. test_op strategy implicitly registered under context manager
-        with strategy_pool.replicate_for_unsupported_operators():
+        with strategy_pool.with_implicit_strategies():
             self._test_op_on_dtensor(test_op, input_x_dt, input_y_dt)
 
         # 3. remove registration after exiting the context manager
```
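
Illustrative only (not the test's actual assertion): step 3 amounts to checking that the implicit registration is gone from the shared tables once the `with` block has exited, e.g.:

```python
# Hypothetical check, using `test_op` and `strategy_pool` as in this test file:
# after exiting the context manager the implicit strategy should be removed.
assert test_op not in strategy_pool.op_strategy_funcs
```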
