
Commit 78ed7e0

Support of explicit fallback
[ghstack-poisoned]
1 parent b53ad10 commit 78ed7e0

File tree

6 files changed: +536 additions, -26 deletions

Lines changed: 9 additions & 0 deletions
@@ -0,0 +1,9 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.


from . import utils

strategy_pool = utils.StrategyPool()

autoparallel/dtensor_util/utils.py

Lines changed: 119 additions & 0 deletions
@@ -0,0 +1,119 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.


import logging
from contextlib import contextmanager
from typing import Callable, TypeVar

import torch
from torch.distributed.tensor import DTensor
from torch.distributed.tensor._op_schema import (
    OpSchema,
    OutputSharding,
    RuntimeSchemaInfo,
    StrategyType,
)
from typing_extensions import ParamSpec

logger = logging.getLogger(__name__)

aten = torch.ops.aten

_T = TypeVar("_T")
_P = ParamSpec("_P")


# -------------define universal op strategy-------------
replicate_op_strategy = torch.distributed.tensor._ops.utils.replicate_op_strategy


class StrategyPool:
    def __init__(self) -> None:
        # reference to existing strategies from the DTensor upstream
        self.op_strategy_funcs: dict[
            torch._ops.OpOverload, Callable[[OpSchema], StrategyType]
        ] = DTensor._op_dispatcher.sharding_propagator.op_strategy_funcs
        # reference to existing rules
        self.op_to_rules: dict[
            torch._ops.OpOverload, Callable[[OpSchema], OutputSharding]
        ] = DTensor._op_dispatcher.sharding_propagator.op_to_rules
        # we probably don't need to care about existing op_to_schema_info for AP
        self.op_to_schema_info = (
            DTensor._op_dispatcher.sharding_propagator.op_to_schema_info
        )

        self.enable_implicit_replication: bool = False
        self.implicit_strategy_op_tracker: list[torch._ops.OpOverload] = []

    def get_op_strategy(
        self, op: torch._ops.OpOverload, op_schema: OpSchema
    ) -> StrategyType:
        if op not in self.op_strategy_funcs:
            if not self.enable_implicit_replication:
                raise NotImplementedError(
                    f"Operator {op} does not have a sharding strategy registered."
                )
            else:
                self.implicit_strategy_op_tracker.append(op)
                logger.warning(
                    f"implicitly registering sharding strategy for op {op.name()} using {replicate_op_strategy.__name__}"
                )
                self.register_op_strategy(op)(replicate_op_strategy)
        return self.op_strategy_funcs[op](op_schema)

    def register_op_strategy(
        self,
        op: torch._ops.OpOverload,
        schema_info=RuntimeSchemaInfo(needs_pytree=True),
    ) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]:
        # pyre-fixme[2]: Parameter must be annotated.
        # always enable pytree, as dispatching overhead is not a concern in AP.
        def wrapper(impl):
            if isinstance(op, list):
                overloads = op
            else:
                overloads = [op]

            for overload in overloads:
                self.op_strategy_funcs[overload] = impl
                self.op_to_schema_info[overload] = schema_info
            return impl

        return wrapper

    @contextmanager
    def replicate_for_unsupported_operators(self):
        """
        Context manager for setting and clearing the implicit replication strategy.
        """
        try:
            if self.enable_implicit_replication:
                raise RuntimeError(
                    "Implicit strategy is already enabled. Cannot enable it again."
                )
            self.enable_implicit_replication = True
            yield
        finally:
            self.enable_implicit_replication = False
            op_to_remove = self.implicit_strategy_op_tracker
            for op_overload in op_to_remove:
                if op_overload in self.op_strategy_funcs:
                    del self.op_strategy_funcs[op_overload]
                if op_overload in self.op_to_schema_info:
                    del self.op_to_schema_info[op_overload]
            self.implicit_strategy_op_tracker.clear()

    # TODO: automatically generate redistribute costs for strategies. There is a
    # `fill_missing_redistribute_cost` in autoparallel/utils.py, which is a hack
    # to generate redistribute costs given input specs and is only tested on
    # certain ops. We can potentially improve it.
    def fill_missing_redistribute_cost(
        self, op: torch._ops.OpOverload, op_schema: OpSchema
    ):
        """
        Fill missing redistribute cost for strategies.
        """
        ...
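For reference, a minimal usage sketch of the pool introduced above (not part of the commit; torch.ops.aten.polar.default is only an illustrative operator and may already have an upstream strategy):

import torch

from autoparallel.dtensor_util import strategy_pool
from autoparallel.dtensor_util.utils import replicate_op_strategy

# Explicit fallback: register a strategy for one op overload (a list of overloads
# is also accepted by the decorator).
strategy_pool.register_op_strategy(torch.ops.aten.polar.default)(replicate_op_strategy)

# Implicit fallback: inside the context manager, get_op_strategy() registers
# replicate_op_strategy for any op that has no strategy (logging a warning) and
# removes those implicit registrations again when the block exits.
with strategy_pool.replicate_for_unsupported_operators():
    ...  # run sharding propagation / the autoparallel solver here

# Outside the context manager, an unregistered op raises NotImplementedError.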

autoparallel/propagation_rules.py

Lines changed: 6 additions & 20 deletions
@@ -36,6 +36,8 @@
 from torch.distributed.tensor._ops.utils import generate_redistribute_costs
 from torch.distributed.tensor.placement_types import Replicate, Shard

+from .dtensor_util import strategy_pool
+
 # TODO: move this to PyTorch
 dim_maps[torch.t] = lambda input: dim_transpose(input.ndim, -2, -1)

@@ -582,11 +584,7 @@ def index_rule(mesh, op_schema):
 @register_opschema_rule(torch.ops.aten._scaled_dot_product_efficient_attention.default)
 def sdpa_rule(mesh, op_schema):
     op = torch.ops.aten._scaled_dot_product_efficient_attention.default
-    out_strat = torch.distributed.tensor.DTensor._op_dispatcher.sharding_propagator.op_strategy_funcs[
-        op
-    ](
-        op_schema
-    )
+    out_strat = strategy_pool.get_op_strategy(op, op_schema)
     # remove wrong context-parallel strategy
     # https://github.com/pytorch/pytorch/pull/131351#discussion_r1716164659
     new_strats = []
@@ -613,11 +611,7 @@ def sdpa_rule(mesh, op_schema):
 @register_opschema_rule(torch.ops.aten.reshape.default)
 def reshape_rule(mesh, op_schema):
     op = torch.ops.aten.reshape.default
-    out_strat = torch.distributed.tensor.DTensor._op_dispatcher.sharding_propagator.op_strategy_funcs[
-        op
-    ](
-        op_schema
-    )
+    out_strat = strategy_pool.get_op_strategy(op, op_schema)
     if mesh.ndim == 1:
         # remove duplicate strategy
         # TODO: hack, fixme
@@ -643,11 +637,7 @@ def expand_rule(mesh, op_schema_):
     ]
     if len(expand_dim) != 1:
         assert len(expand_dim) == 0
-        return torch.distributed.tensor.DTensor._op_dispatcher.sharding_propagator.op_strategy_funcs[
-            op
-        ](
-            op_schema
-        )
+        return strategy_pool.get_op_strategy(op, op_schema)
     assert len(expand_dim) == 1, f"{expand_dim}"
     expand_dim = expand_dim[0]
     to_remove = []
@@ -661,11 +651,7 @@ def expand_rule(mesh, op_schema_):
     removed = []
     for i in reversed(to_remove):
         removed.append(input_strat.strategies.pop(i))
-    out_strat = torch.distributed.tensor.DTensor._op_dispatcher.sharding_propagator.op_strategy_funcs[
-        op
-    ](
-        op_schema
-    )
+    out_strat = strategy_pool.get_op_strategy(op, op_schema)
     for i, ss in enumerate(out_strat.strategies):
         for remov in to_remove:
             ss.redistribute_cost[0].insert(remov, math.inf)

autoparallel/utils.py

Lines changed: 2 additions & 5 deletions
@@ -10,6 +10,7 @@
 from torch.distributed.tensor._ops.utils import generate_redistribute_costs
 from torch.utils._pytree import tree_flatten, tree_map_only

+from .dtensor_util import strategy_pool
 from .propagation_rules import _op_partial_rules, _op_rules, remove_invalid_configs


@@ -110,11 +111,7 @@ def get_placement_options(mesh, op, specs, user_args, user_kwargs):
     if op in _op_partial_rules:
         out_strat = _op_partial_rules[op](mesh, op_schema)
     else:
-        out_strat = torch.distributed.tensor.DTensor._op_dispatcher.sharding_propagator.op_strategy_funcs[
-            op
-        ](
-            op_schema
-        )
+        out_strat = strategy_pool.get_op_strategy(op, op_schema)

     propagate_tensor_meta(op, user_args, user_kwargs, out_strat)
     fill_missing_redistribute_cost(op, specs, out_strat)
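
Note the behavioral change for get_placement_options: the removed code indexed op_strategy_funcs directly, so an operator without a registered strategy surfaced as a bare KeyError, whereas strategy_pool.get_op_strategy raises a descriptive NotImplementedError, or falls back to replication while replicate_for_unsupported_operators() is active. A sketch of how a caller could check this up front (aten.polar.default is again just an illustrative operator, not part of this commit):

import torch

from autoparallel.dtensor_util import strategy_pool

op = torch.ops.aten.polar.default  # hypothetical example op
if op not in strategy_pool.op_strategy_funcs:
    # get_op_strategy(op, ...) would raise NotImplementedError here unless
    # replicate_for_unsupported_operators() is active
    print(f"no sharding strategy registered for {op}")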

requirements-test.txt

Lines changed: 2 additions & 1 deletion
@@ -1,7 +1,8 @@
 torch >= 2.7.0
 numpy
 pulp
-pytest
+pytest >= 8.1
+expecttest

 black == 22.3.0
 flake8 == 6.1.0
