
Commit 7d0ede7

Support of implicit fallback (#61)
ghstack-source-id: f0db91b Pull Request resolved: #49
1 parent 22e663f commit 7d0ede7
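
In short, this change lets AutoParallel fall back to a replicate sharding strategy for operators that have no strategy registered, instead of failing when the propagator has no entry for the op. A minimal usage sketch of the new helpers follows (not part of the diff; `some_op` and `some_op_schema` are hypothetical placeholders):

from autoparallel.dtensor_util import get_op_strategy, with_implicit_strategies

# Inside the context, an operator without a registered sharding strategy is
# implicitly registered with `replicate_op_strategy`; the temporary
# registration is cleaned up when the context exits.
with with_implicit_strategies():
    out_strat = get_op_strategy(some_op, some_op_schema)  # placeholder args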

6 files changed: 511 additions, 26 deletions
autoparallel/dtensor_util/__init__.py

Lines changed: 19 additions & 0 deletions (new file)

# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

# functions to expose
from .utils import (
    get_op_strategy,
    op_strategy_context,
    replicate_op_strategy,
    with_implicit_strategies,
)

__all__ = [
    "replicate_op_strategy",
    "get_op_strategy",
    "with_implicit_strategies",
    "op_strategy_context",
]

autoparallel/dtensor_util/utils.py

Lines changed: 108 additions & 0 deletions (new file)

# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

import logging
from contextlib import ExitStack, contextmanager

import torch
from torch.distributed.tensor import DTensor
from torch.distributed.tensor._op_schema import OpSchema, StrategyType
from torch.distributed.tensor._ops.utils import register_op_strategy

logger = logging.getLogger(__name__)

aten = torch.ops.aten

# reference to the existing sharding_propagator in upstream DTensor
propagator = DTensor._op_dispatcher.sharding_propagator

enable_implicit_replication = False
_current_stack = None

replicate_op_strategy = torch.distributed.tensor._ops.utils.replicate_op_strategy


# TODO: remove and refer to
# https://github.com/pytorch/pytorch/blob/9c107606629de6383f55e3b48b42e594d23407b1/test/distributed/tensor/test_op_strategy.py#L446
# once the function is moved outside of the test folder in upstream
@contextmanager
def op_strategy_context(op_overload, strategy_func, schema_info=None):
    """
    Context manager for setting and clearing op strategies.
    Args:
        op_overload: The operator overload to set or clear the strategy for.
        strategy_func: The strategy function to set for the operator overload.
        schema_info: Optional schema information for the operator overload.
    Yields:
        None
    """
    propagator = DTensor._op_dispatcher.sharding_propagator
    try:
        # register the op strategy
        register_op_strategy(op_overload, schema_info=schema_info)(strategy_func)
        yield
    finally:
        # clear this op strategy cache
        if op_overload in propagator.op_strategy_funcs:
            del propagator.op_strategy_funcs[op_overload]
        if op_overload in propagator.op_to_schema_info:
            del propagator.op_to_schema_info[op_overload]
        propagator.propagate_op_sharding.cache.cache_clear()


def get_op_strategy(op: torch._ops.OpOverload, op_schema: OpSchema) -> StrategyType:
    global enable_implicit_replication, _current_stack

    if op not in propagator.op_strategy_funcs:
        if not enable_implicit_replication:
            raise NotImplementedError(
                f"Operator {op} does not have a sharding strategy registered."
            )
        else:
            # Use the current stack if available
            if _current_stack is not None:
                _current_stack.enter_context(
                    op_strategy_context(op, replicate_op_strategy)
                )
            else:
                # No stack available, just register permanently
                register_op_strategy(op)(replicate_op_strategy)
            logger.warning(
                f"implicitly registering `{op}` with `{replicate_op_strategy.__name__}`"
            )
    return propagator.op_strategy_funcs[op](op_schema)


@contextmanager
def with_implicit_strategies():
    """Context manager to enable implicit replication and clean up strategies."""
    global enable_implicit_replication, _current_stack

    # Create a fresh ExitStack for this context
    with ExitStack() as local_stack:
        # Store the stack as a global variable
        old_stack = _current_stack
        _current_stack = local_stack

        # Enable implicit replication
        old_value = enable_implicit_replication
        enable_implicit_replication = True
        try:
            yield
        finally:
            # Restore the original values
            _current_stack = old_stack
            enable_implicit_replication = old_value


# TODO: automatically generate redistribute costs for strategies. There exists a
# `fill_missing_redistribute_cost` in autoparallel/utils.py, which is a hack
# to generate redistribute costs given input specs, and is only tested on
# certain ops. We can potentially make an improvement.
def fill_missing_redistribute_cost(op: torch._ops.OpOverload, op_schema: OpSchema):
    """
    Fill missing redistribute cost for strategies.
    """
    ...

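As a side note, here is a small sketch (an assumption for illustration, not taken from this commit) of using `op_strategy_context` on its own to scope a strategy registration; `torch.ops.aten.nonzero.default` is an arbitrary example op assumed to have no strategy already registered:

import torch
from autoparallel.dtensor_util import op_strategy_context, replicate_op_strategy

# Temporarily register the replicate strategy for one op; the context manager's
# finally block removes the registration and clears the propagator's
# sharding-propagation cache on exit.
with op_strategy_context(torch.ops.aten.nonzero.default, replicate_op_strategy):
    ...  # sharding propagation for aten.nonzero uses replication in this block
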
autoparallel/propagation_rules.py

Lines changed: 6 additions & 20 deletions

@@ -39,6 +39,8 @@
 )
 from torch.distributed.tensor.placement_types import Replicate, Shard

+from .dtensor_util import get_op_strategy
+
 # TODO: move this to PyTorch
 dim_maps[torch.t] = lambda input: dim_transpose(input.ndim, -2, -1)

@@ -654,11 +656,7 @@ def index_rule(mesh, op_schema):


 def sdpa_rule(op, mesh, op_schema):
-    out_strat = torch.distributed.tensor.DTensor._op_dispatcher.sharding_propagator.op_strategy_funcs[
-        op
-    ](
-        op_schema
-    )
+    out_strat = get_op_strategy(op, op_schema)
     # remove wrong context-parallel strategy
     # https://github.com/pytorch/pytorch/pull/131351#discussion_r1716164659
     new_strats = []

@@ -687,11 +685,7 @@ def _(mesh, op_schema):
 @register_opschema_rule(torch.ops.aten.reshape.default)
 def reshape_rule(mesh, op_schema):
     op = torch.ops.aten.reshape.default
-    out_strat = torch.distributed.tensor.DTensor._op_dispatcher.sharding_propagator.op_strategy_funcs[
-        op
-    ](
-        op_schema
-    )
+    out_strat = get_op_strategy(op, op_schema)
     if mesh.ndim == 1:
         # remove duplicate strategy
         # TODO: hack, fixme

@@ -717,11 +711,7 @@ def expand_rule(mesh, op_schema_):
     ]
     if len(expand_dim) != 1:
         assert len(expand_dim) == 0
-        return torch.distributed.tensor.DTensor._op_dispatcher.sharding_propagator.op_strategy_funcs[
-            op
-        ](
-            op_schema
-        )
+        return get_op_strategy(op, op_schema)
     assert len(expand_dim) == 1, f"{expand_dim}"
     expand_dim = expand_dim[0]
     to_remove = []

@@ -735,11 +725,7 @@ def expand_rule(mesh, op_schema_):
     removed = []
     for i in reversed(to_remove):
         removed.append(input_strat.strategies.pop(i))
-    out_strat = torch.distributed.tensor.DTensor._op_dispatcher.sharding_propagator.op_strategy_funcs[
-        op
-    ](
-        op_schema
-    )
+    out_strat = get_op_strategy(op, op_schema)
     for i, ss in enumerate(out_strat.strategies):
         for remov in to_remove:
             ss.redistribute_cost[0].insert(remov, math.inf)

autoparallel/utils.py

Lines changed: 2 additions & 5 deletions

@@ -17,6 +17,7 @@
 from torch.distributed.tensor.placement_types import Replicate
 from torch.utils._pytree import tree_flatten, tree_map_only

+from .dtensor_util import get_op_strategy
 from .propagation_rules import (
     TENSOR_FACTORY_OPS,
     _op_partial_rules,

@@ -162,11 +163,7 @@ def get_placement_options(mesh, op, specs, user_args, user_kwargs):
     if op in _op_partial_rules:
         out_strat = _op_partial_rules[op](mesh, op_schema)
     else:
-        out_strat = torch.distributed.tensor.DTensor._op_dispatcher.sharding_propagator.op_strategy_funcs[
-            op
-        ](
-            op_schema
-        )
+        out_strat = get_op_strategy(op, op_schema)

     propagate_tensor_meta(op, user_args, user_kwargs, out_strat)
     fill_missing_redistribute_cost(op, specs, out_strat)

requirements-test.txt

Lines changed: 2 additions & 1 deletion

@@ -1,7 +1,8 @@
 torch >= 2.7.0
 numpy
 pulp
-pytest
+pytest >= 8.1
+expecttest

 black == 22.3.0
 flake8 == 6.1.0
