Commit 8df62c4

Support of implicit fallback (#49)
* Support of explicit fallback

* Update on "Support of implicit fallback" (split out of the large PR #46)

Support the implicit replication fallback strategy.

How to use the implicit replication fallback:

```python
from autoparallel.dtensor_util import strategy_pool

with strategy_pool.replicate_for_unsupported_operators():
    ...  # missing ops will use the replicated strategy if possible
```

Note: StrategyPool now reuses `_op_dispatcher.sharding_propagator.op_strategy_funcs` / `op_to_rules` / `op_to_schema_info` by reference.
1 parent b53ad10 commit 8df62c4

File tree

6 files changed: +511 -26 lines
autoparallel/dtensor_util/__init__.py

Lines changed: 19 additions & 0 deletions

@@ -0,0 +1,19 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
+#
+# This source code is licensed under the BSD license found in the
+# LICENSE file in the root directory of this source tree.
+
+# functions to expose
+from .utils import (
+    get_op_strategy,
+    op_strategy_context,
+    replicate_op_strategy,
+    with_implicit_strategies,
+)
+
+__all__ = [
+    "replicate_op_strategy",
+    "get_op_strategy",
+    "with_implicit_strategies",
+    "op_strategy_context",
+]
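
Since these names are listed in `__all__`, downstream code can import the helpers directly from the package root. A minimal sketch of the import surface this `__init__.py` exposes:

```python
# Helpers re-exported by autoparallel/dtensor_util/__init__.py
from autoparallel.dtensor_util import (
    get_op_strategy,
    op_strategy_context,
    replicate_op_strategy,
    with_implicit_strategies,
)
```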

autoparallel/dtensor_util/utils.py

Lines changed: 108 additions & 0 deletions
@@ -0,0 +1,108 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
+#
+# This source code is licensed under the BSD license found in the
+# LICENSE file in the root directory of this source tree.
+
+import logging
+from contextlib import ExitStack, contextmanager
+
+import torch
+from torch.distributed.tensor import DTensor
+from torch.distributed.tensor._op_schema import OpSchema, StrategyType
+from torch.distributed.tensor._ops.utils import register_op_strategy
+
+logger = logging.getLogger(__name__)
+
+aten = torch.ops.aten
+
+# reference to existing sharding_propagator DTensor upstream
+propagator = DTensor._op_dispatcher.sharding_propagator
+
+enable_implicit_replication = False
+_current_stack = None
+
+replicate_op_strategy = torch.distributed.tensor._ops.utils.replicate_op_strategy
+
+
+# TODO: remove and refer to
+# https://github.com/pytorch/pytorch/blob/9c107606629de6383f55e3b48b42e594d23407b1/test/distributed/tensor/test_op_strategy.py#L446
+# once the function is moved outside of the test folder in upstream
+@contextmanager
+def op_strategy_context(op_overload, strategy_func, schema_info=None):
+    """
+    Context manager for setting and clearing op strategies.
+    Args:
+        op_overload: The operator overload to set or clear the strategy for.
+        strategy_func: The strategy function to set for the operator overload.
+        schema_info: Optional schema information for the operator overload.
+    Yields:
+        None
+    """
+    propagator = DTensor._op_dispatcher.sharding_propagator
+    try:
+        # register the op strategy
+        register_op_strategy(op_overload, schema_info=schema_info)(strategy_func)
+        yield
+    finally:
+        # clear this op strategy cache
+        if op_overload in propagator.op_strategy_funcs:
+            del propagator.op_strategy_funcs[op_overload]
+        if op_overload in propagator.op_to_schema_info:
+            del propagator.op_to_schema_info[op_overload]
+        propagator.propagate_op_sharding.cache.cache_clear()
+
+
+def get_op_strategy(op: torch._ops.OpOverload, op_schema: OpSchema) -> StrategyType:
+    global enable_implicit_replication, _current_stack
+
+    if op not in propagator.op_strategy_funcs:
+        if not enable_implicit_replication:
+            raise NotImplementedError(
+                f"Operator {op} does not have a sharding strategy registered."
+            )
+        else:
+            # Use the current stack if available
+            if _current_stack is not None:
+                _current_stack.enter_context(
+                    op_strategy_context(op, replicate_op_strategy)
+                )
+            else:
+                # No stack available, just register permanently
+                register_op_strategy(op)(replicate_op_strategy)
+            logger.warning(
+                f"implicitly registering `{op}` with `{replicate_op_strategy.__name__}`"
+            )
+    return propagator.op_strategy_funcs[op](op_schema)
+
+
+@contextmanager
+def with_implicit_strategies():
+    """Context manager to enable implicit replication and clean up strategies."""
+    global enable_implicit_replication, _current_stack
+
+    # Create a fresh ExitStack for this context
+    with ExitStack() as local_stack:
+        # Store the stack as a global variable
+        old_stack = _current_stack
+        _current_stack = local_stack
+
+        # Enable implicit replication
+        old_value = enable_implicit_replication
+        enable_implicit_replication = True
+        try:
+            yield
+        finally:
+            # Restore the original values
+            _current_stack = old_stack
+            enable_implicit_replication = old_value
+
+
+# TODO: automatic generate redistribute cost for strategies. There exists a
+# `fill_missing_redistribute_cost` in autoparallel/utils.py, which is a hack
+# to generate redistribute cost given input specs, and only tested on
+# certain ops. We can potentially make an improvement.
+def fill_missing_redistribute_cost(op: torch._ops.OpOverload, op_schema: OpSchema):
+    """
+    Fill missing redistribute cost for strategies.
+    """
+    ...
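
Taken together, `with_implicit_strategies` flips the module-level flag and installs an `ExitStack`, and `get_op_strategy` then registers `replicate_op_strategy` on demand for any operator that has no strategy in the propagator. A minimal usage sketch based on the code above; the `aten.polar` overload and the pre-built `op_schema` are illustrative assumptions, since autoparallel normally constructs the `OpSchema` during tracing:

```python
import torch
from autoparallel.dtensor_util import get_op_strategy, with_implicit_strategies

op = torch.ops.aten.polar.default  # assumption: an op without a registered sharding strategy
op_schema = ...  # assumption: an OpSchema built elsewhere by autoparallel's tracing

# Outside the context manager, get_op_strategy raises NotImplementedError for
# operators that have no registered strategy.
with with_implicit_strategies():
    # Inside, the missing op is registered with replicate_op_strategy through the
    # ExitStack, and the registration is undone when the block exits.
    out_strat = get_op_strategy(op, op_schema)
```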

autoparallel/propagation_rules.py

Lines changed: 6 additions & 20 deletions
@@ -36,6 +36,8 @@
 from torch.distributed.tensor._ops.utils import generate_redistribute_costs
 from torch.distributed.tensor.placement_types import Replicate, Shard

+from .dtensor_util import get_op_strategy
+
 # TODO: move this to PyTorch
 dim_maps[torch.t] = lambda input: dim_transpose(input.ndim, -2, -1)

@@ -582,11 +584,7 @@ def index_rule(mesh, op_schema):
 @register_opschema_rule(torch.ops.aten._scaled_dot_product_efficient_attention.default)
 def sdpa_rule(mesh, op_schema):
     op = torch.ops.aten._scaled_dot_product_efficient_attention.default
-    out_strat = torch.distributed.tensor.DTensor._op_dispatcher.sharding_propagator.op_strategy_funcs[
-        op
-    ](
-        op_schema
-    )
+    out_strat = get_op_strategy(op, op_schema)
     # remove wrong context-parallel strategy
     # https://github.com/pytorch/pytorch/pull/131351#discussion_r1716164659
     new_strats = []
@@ -613,11 +611,7 @@ def sdpa_rule(mesh, op_schema):
 @register_opschema_rule(torch.ops.aten.reshape.default)
 def reshape_rule(mesh, op_schema):
     op = torch.ops.aten.reshape.default
-    out_strat = torch.distributed.tensor.DTensor._op_dispatcher.sharding_propagator.op_strategy_funcs[
-        op
-    ](
-        op_schema
-    )
+    out_strat = get_op_strategy(op, op_schema)
     if mesh.ndim == 1:
         # remove duplicate strategy
         # TODO: hack, fixme
@@ -643,11 +637,7 @@ def expand_rule(mesh, op_schema_):
     ]
     if len(expand_dim) != 1:
         assert len(expand_dim) == 0
-        return torch.distributed.tensor.DTensor._op_dispatcher.sharding_propagator.op_strategy_funcs[
-            op
-        ](
-            op_schema
-        )
+        return get_op_strategy(op, op_schema)
     assert len(expand_dim) == 1, f"{expand_dim}"
     expand_dim = expand_dim[0]
     to_remove = []
@@ -661,11 +651,7 @@ def expand_rule(mesh, op_schema_):
     removed = []
     for i in reversed(to_remove):
         removed.append(input_strat.strategies.pop(i))
-    out_strat = torch.distributed.tensor.DTensor._op_dispatcher.sharding_propagator.op_strategy_funcs[
-        op
-    ](
-        op_schema
-    )
+    out_strat = get_op_strategy(op, op_schema)
     for i, ss in enumerate(out_strat.strategies):
         for remov in to_remove:
             ss.redistribute_cost[0].insert(remov, math.inf)
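
With these rules now routing through `get_op_strategy` instead of indexing `op_strategy_funcs` directly, the implicit-replication fallback (when enabled) applies to them as well. For temporarily swapping a single strategy, `op_strategy_context` can also be used on its own; a small sketch, where the `aten.polar` overload is again an illustrative assumption rather than something this commit registers:

```python
import torch
from autoparallel.dtensor_util import op_strategy_context, replicate_op_strategy

# Temporarily register a replicate-only strategy for one operator; the context
# manager removes the registration and clears the sharding-propagation cache on exit.
with op_strategy_context(torch.ops.aten.polar.default, replicate_op_strategy):
    ...  # sharding propagation for aten.polar falls back to replication here
```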

autoparallel/utils.py

Lines changed: 2 additions & 5 deletions
@@ -10,6 +10,7 @@
 from torch.distributed.tensor._ops.utils import generate_redistribute_costs
 from torch.utils._pytree import tree_flatten, tree_map_only

+from .dtensor_util import get_op_strategy
 from .propagation_rules import _op_partial_rules, _op_rules, remove_invalid_configs


@@ -110,11 +111,7 @@ def get_placement_options(mesh, op, specs, user_args, user_kwargs):
     if op in _op_partial_rules:
         out_strat = _op_partial_rules[op](mesh, op_schema)
     else:
-        out_strat = torch.distributed.tensor.DTensor._op_dispatcher.sharding_propagator.op_strategy_funcs[
-            op
-        ](
-            op_schema
-        )
+        out_strat = get_op_strategy(op, op_schema)

     propagate_tensor_meta(op, user_args, user_kwargs, out_strat)
     fill_missing_redistribute_cost(op, specs, out_strat)

requirements-test.txt

Lines changed: 2 additions & 1 deletion
@@ -1,7 +1,8 @@
 torch >= 2.7.0
 numpy
 pulp
-pytest
+pytest >= 8.1
+expecttest

 black == 22.3.0
 flake8 == 6.1.0
