 
 
 import logging
-from contextlib import contextmanager
+from contextlib import ExitStack, contextmanager
 from typing import Callable, TypeVar
 
 import torch
 from torch.distributed.tensor import DTensor
-from torch.distributed.tensor._op_schema import (
-    OpSchema,
-    OutputSharding,
-    RuntimeSchemaInfo,
-    StrategyType,
-)
+from torch.distributed.tensor._op_schema import OpSchema, OutputSharding, StrategyType
+from torch.distributed.tensor._ops.utils import register_op_strategy
 from typing_extensions import ParamSpec
 
 logger = logging.getLogger(__name__)
@@ -26,6 +22,34 @@
 _P = ParamSpec("_P")
 
 
+# TODO: remove and refer to
+# https://github.com/pytorch/pytorch/blob/9c107606629de6383f55e3b48b42e594d23407b1/test/distributed/tensor/test_op_strategy.py#L446
+# once the function is moved out of the test folder upstream
+@contextmanager
+def op_strategy_context(op_overload, strategy_func, schema_info=None):
+    """
+    Context manager for setting and clearing op strategies.
+    Args:
+        op_overload: The operator overload to set or clear the strategy for.
+        strategy_func: The strategy function to set for the operator overload.
+        schema_info: Optional schema information for the operator overload.
+    Yields:
+        None
+    """
+    propagator = DTensor._op_dispatcher.sharding_propagator
+    try:
+        # register the op strategy
+        register_op_strategy(op_overload, schema_info=schema_info)(strategy_func)
+        yield
+    finally:
+        # clear this op strategy cache
+        if op_overload in propagator.op_strategy_funcs:
+            del propagator.op_strategy_funcs[op_overload]
+        if op_overload in propagator.op_to_schema_info:
+            del propagator.op_to_schema_info[op_overload]
+        propagator.propagate_op_sharding.cache.cache_clear()
+
+
 # -------------define universal op strategy-------------
 replicate_op_strategy = torch.distributed.tensor._ops.utils.replicate_op_strategy
 
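For orientation, here is a minimal usage sketch of the `op_strategy_context` helper added above. The operator overload and the body of the `with` block are hypothetical placeholders, not part of this patch; the sketch assumes the module above is importable so that `op_strategy_context` and `replicate_op_strategy` are in scope.

    import torch

    # Hypothetical example: temporarily register a replicate strategy for aten.nonzero.default.
    # Inside the block, DTensor sharding propagation for that op uses replicate_op_strategy;
    # on exit the registration and the propagation cache are cleared again.
    with op_strategy_context(torch.ops.aten.nonzero.default, replicate_op_strategy):
        ...  # run DTensor code that dispatches the op here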
@@ -44,9 +68,8 @@ def __init__(self) -> None:
         self.op_to_schema_info = (
             DTensor._op_dispatcher.sharding_propagator.op_to_schema_info
         )
-
         self.enable_implicit_replication: bool = False
-        self.implicit_strategy_op_tracker: list[torch._ops.OpOverload] = []
+        self._current_stack = None
 
     def get_op_strategy(
         self, op: torch._ops.OpOverload, op_schema: OpSchema
@@ -57,54 +80,37 @@ def get_op_strategy(
5780 f"Operator { op } does not have a sharding strategy registered."
5881 )
5982 else :
60- self .implicit_strategy_op_tracker .append (op )
83+ # Use the instance's current stack if available
84+ if self ._current_stack is not None :
85+ self ._current_stack .enter_context (
86+ op_strategy_context (op , replicate_op_strategy )
87+ )
88+ else :
89+ # No stack available, just register permanently
90+ register_op_strategy (op )(replicate_op_strategy )
6191 logger .warning (
62- f"implicitly register sharding strategy op { op . name () } using { replicate_op_strategy .__name__ } "
92+ f"implicitly registering ` { op } ` with ` { replicate_op_strategy .__name__ } ` "
6393 )
64- self .register_op_strategy (op )(replicate_op_strategy )
6594 return self .op_strategy_funcs [op ](op_schema )
6695
-    def register_op_strategy(
-        self,
-        op: torch._ops.OpOverload,
-        schema_info=RuntimeSchemaInfo(needs_pytree=True),
-    ) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]:
-        # pyre-fixme[2]: Parameter must be annotated.
-        # always enable pytree as dispatching overhead is not a concern in AP.
-        def wrapper(impl):
-            if isinstance(op, list):
-                overloads = op
-            else:
-                overloads = [op]
-
-            for overload in overloads:
-                self.op_strategy_funcs[overload] = impl
-                self.op_to_schema_info[overload] = schema_info
-            return impl
-
-        return wrapper
-
     @contextmanager
-    def replicate_for_unsupported_operators(self):
-        """
-        Context manager for setting and clearing implicit strategy.
-        """
-        try:
-            if self.enable_implicit_replication:
-                raise RuntimeError(
-                    "Implicit strategy is already enabled. Cannot enable it again."
-                )
+    def with_implicit_strategies(self):
+        """Context manager to enable implicit replication and clean up strategies."""
+        # Create a fresh ExitStack for this context
+        with ExitStack() as local_stack:
+            # Store the stack as an instance attribute
+            old_stack = self._current_stack
+            self._current_stack = local_stack
+
+            # Enable implicit replication
+            old_value = self.enable_implicit_replication
             self.enable_implicit_replication = True
-            yield
-        finally:
-            self.enable_implicit_replication = False
-            op_to_remove = self.implicit_strategy_op_tracker
-            for op_overload in op_to_remove:
-                if op_overload in self.op_strategy_funcs:
-                    del self.op_strategy_funcs[op_overload]
-                if op_overload in self.op_to_schema_info:
-                    del self.op_to_schema_info[op_overload]
-            self.implicit_strategy_op_tracker.clear()
+            try:
+                yield
+            finally:
+                # Restore the original values
+                self._current_stack = old_stack
+                self.enable_implicit_replication = old_value
 
     # TODO: automatic generate redistribute cost for strategies. There exists a
     # `fill_missing_redistribute_cost` in autoparallel/utils.py, which is a hack
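To close the loop, here is a hedged sketch of how the reworked context manager above might be used. The name `sharding_prop` stands in for an instance of the propagator class being modified in this hunk (its class name is not visible here), and `run_dtensor_model` is a placeholder, so treat both identifiers as assumptions rather than part of the patch.

    # Hypothetical usage of the ExitStack-based with_implicit_strategies() above.
    with sharding_prop.with_implicit_strategies():
        # Ops without a registered strategy get an implicit replicate_op_strategy
        # via op_strategy_context, scoped to this block by the ExitStack.
        run_dtensor_model()  # placeholder for DTensor computation
    # Leaving the block unwinds the ExitStack, removing each implicit registration
    # and restoring _current_stack and enable_implicit_replication to their prior values.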