@@ -3,23 +3,25 @@
 # This source code is licensed under the BSD license found in the
 # LICENSE file in the root directory of this source tree.
 
-
 import logging
 from contextlib import ExitStack, contextmanager
-from typing import Callable, TypeVar
 
 import torch
 from torch.distributed.tensor import DTensor
-from torch.distributed.tensor._op_schema import OpSchema, OutputSharding, StrategyType
+from torch.distributed.tensor._op_schema import OpSchema, StrategyType
 from torch.distributed.tensor._ops.utils import register_op_strategy
-from typing_extensions import ParamSpec
 
 logger = logging.getLogger(__name__)
 
 aten = torch.ops.aten
 
-_T = TypeVar("_T")
-_P = ParamSpec("_P")
+# reference to existing sharding_propagator DTensor upstream
+propagator = DTensor._op_dispatcher.sharding_propagator
+
+enable_implicit_replication = False
+_current_stack = None
+
+replicate_op_strategy = torch.distributed.tensor._ops.utils.replicate_op_strategy
 
 
 # TODO: remove and refer to
@@ -50,76 +52,57 @@ def op_strategy_context(op_overload, strategy_func, schema_info=None): |
         propagator.propagate_op_sharding.cache.cache_clear()
 
 
-# -------------define universal op strategy-------------
-replicate_op_strategy = torch.distributed.tensor._ops.utils.replicate_op_strategy
-
+def get_op_strategy(op: torch._ops.OpOverload, op_schema: OpSchema) -> StrategyType:
+    global enable_implicit_replication, _current_stack
 
-class StrategyPool:
-    def __init__(self) -> None:
-        # reference to existing strategy from the DTensor upstream
-        self.op_strategy_funcs: dict[
-            torch._ops.OpOverload, Callable[[OpSchema], StrategyType]
-        ] = DTensor._op_dispatcher.sharding_propagator.op_strategy_funcs
-        # reference to existing rules
-        self.op_to_rules: dict[
-            torch._ops.OpOverload, Callable[[OpSchema], OutputSharding]
-        ] = DTensor._op_dispatcher.sharding_propagator.op_to_rules
-        # we probably don't need to care about existing op_to_schema_info for AP
-        self.op_to_schema_info = (
-            DTensor._op_dispatcher.sharding_propagator.op_to_schema_info
-        )
-        self.enable_implicit_replication: bool = False
-        self._current_stack = None
-
-    def get_op_strategy(
-        self, op: torch._ops.OpOverload, op_schema: OpSchema
-    ) -> StrategyType:
-        if op not in self.op_strategy_funcs:
-            if not self.enable_implicit_replication:
-                raise NotImplementedError(
-                    f"Operator {op} does not have a sharding strategy registered."
+    if op not in propagator.op_strategy_funcs:
+        if not enable_implicit_replication:
+            raise NotImplementedError(
+                f"Operator {op} does not have a sharding strategy registered."
+            )
+        else:
+            # Use the current stack if available
+            if _current_stack is not None:
+                _current_stack.enter_context(
+                    op_strategy_context(op, replicate_op_strategy)
                 )
             else:
-                # Use the instance's current stack if available
-                if self._current_stack is not None:
-                    self._current_stack.enter_context(
-                        op_strategy_context(op, replicate_op_strategy)
-                    )
-                else:
-                    # No stack available, just register permanently
-                    register_op_strategy(op)(replicate_op_strategy)
-                    logger.warning(
-                        f"implicitly registering `{op}` with `{replicate_op_strategy.__name__}`"
-                    )
-        return self.op_strategy_funcs[op](op_schema)
-
-    @contextmanager
-    def with_implicit_strategies(self):
-        """Context manager to enable implicit replication and clean up strategies."""
-        # Create a fresh ExitStack for this context
-        with ExitStack() as local_stack:
-            # Store the stack as an instance attribute
-            old_stack = self._current_stack
-            self._current_stack = local_stack
-
-            # Enable implicit replication
-            old_value = self.enable_implicit_replication
-            self.enable_implicit_replication = True
-            try:
-                yield
-            finally:
-                # Restore the original values
-                self._current_stack = old_stack
-                self.enable_implicit_replication = old_value
-
-    # TODO: automatic generate redistribute cost for strategies. There exists a
-    # `fill_missing_redistribute_cost` in autoparallel/utils.py, which is a hack
-    # to generate redistribute cost given input specs, and only tested on
-    # certain ops. We can potentially make an improvement.
-    def fill_missing_redistribute_cost(
-        self, op: torch._ops.OpOverload, op_schema: OpSchema
-    ):
-        """
-        Fill missing redistribute cost for strategies.
-        """
-        ...
+                # No stack available, just register permanently
+                register_op_strategy(op)(replicate_op_strategy)
+                logger.warning(
+                    f"implicitly registering `{op}` with `{replicate_op_strategy.__name__}`"
+                )
+    return propagator.op_strategy_funcs[op](op_schema)
+
+
+@contextmanager
+def with_implicit_strategies():
+    """Context manager to enable implicit replication and clean up strategies."""
+    global enable_implicit_replication, _current_stack
+
+    # Create a fresh ExitStack for this context
+    with ExitStack() as local_stack:
+        # Store the stack as a global variable
+        old_stack = _current_stack
+        _current_stack = local_stack
+
+        # Enable implicit replication
+        old_value = enable_implicit_replication
+        enable_implicit_replication = True
+        try:
+            yield
+        finally:
+            # Restore the original values
+            _current_stack = old_stack
+            enable_implicit_replication = old_value
+
+
+# TODO: automatic generate redistribute cost for strategies. There exists a
+# `fill_missing_redistribute_cost` in autoparallel/utils.py, which is a hack
+# to generate redistribute cost given input specs, and only tested on
+# certain ops. We can potentially make an improvement.
+def fill_missing_redistribute_cost(op: torch._ops.OpOverload, op_schema: OpSchema):
+    """
+    Fill missing redistribute cost for strategies.
+    """
+    ...
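
For context on the pattern introduced above: `with_implicit_strategies` installs a module-level `ExitStack`, and `get_op_strategy` lazily enters an `op_strategy_context(op, replicate_op_strategy)` on that stack the first time it sees an op with no registered sharding strategy, so every temporary registration is undone when the `with` block exits. The sketch below illustrates that mechanism with a plain dict and hypothetical names (`registry`, `register_temporarily`); it is not code from this commit and does not touch DTensor.

```python
from contextlib import ExitStack, contextmanager

# Hypothetical stand-in for propagator.op_strategy_funcs.
registry: dict[str, str] = {}


@contextmanager
def register_temporarily(key: str, value: str):
    # Analogous to op_strategy_context: add the entry on enter, remove it on exit.
    registry[key] = value
    try:
        yield
    finally:
        del registry[key]


# Analogous to with_implicit_strategies: the ExitStack keeps every
# lazily-entered registration alive until the block ends.
with ExitStack() as stack:
    stack.enter_context(register_temporarily("aten::nonzero", "replicate"))
    assert "aten::nonzero" in registry

# All contexts entered on the stack are now closed, so the entry is gone.
assert "aten::nonzero" not in registry
```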