11from dataclasses import dataclass
2- from typing import Any , Callable , Dict , Type
2+ from typing import Any , Callable , Dict , Optional , Type , Union
33import torch
44import logging
55
88
99
1010@dataclass (frozen = True )
11- class ModuleReplacement :
11+ class Substitution :
1212 """Class to store key functionality for module replacement"""
1313
1414 # torch.ops.___ name for replacement function for module
1515 new_operator : torch ._ops .OpOverload
1616
17- # Function taking a containing graph, a submodule, and a 'call_module' node and returning
18- # a replacement node, with type 'call_function', or raising an Error if incompatibility is detected
17+ # Function taking a containing graph, a node, and optionally a submodule (if replacing a module)
18+ # and returning a replacement node, with type 'call_function', or raising an Error if
19+ # incompatibility is detected
1920 # Note: subgraph_insertion_fn should NOT delete nodes or recompile the graph
2021 subgraph_insertion_fn : Callable [
21- [torch .fx .GraphModule , torch .nn . Module , torch .fx . Node ], torch .fx .Node
22+ [torch .fx .GraphModule , torch .fx . Node , Optional [ torch .nn . Module ] ], torch .fx .Node
2223 ]
2324
2425
25- # Dictionary mapping module to ModuleReplacement instance
26- MODULE_SUBSTITUTION_REGISTRY : Dict [Type [torch .nn .Module ], ModuleReplacement ] = dict ()
26+ # Dictionary mapping module to Substitution instance
27+ SUBSTITUTION_REGISTRY : Dict [
28+ Union [Type [torch .nn .Module ], Callable ], Substitution
29+ ] = dict ()
2730
2831
29- def module_substitution (
30- module_to_replace : Type [torch .nn .Module ],
32+ def register_substitution (
33+ module_or_function_to_replace : Union [ Type [torch .nn .Module ], Callable ],
3134 new_operator : torch ._ops .OpOverload ,
3235 enabled : bool = True ,
3336) -> Callable [[Any ], Any ]:
3437 """Decorator to register subgraph insertion functions
3538
3639 Args:
37- module_to_replace : nn.Module to replace
40+ module_or_function_to_replace : nn.Module or node target Callable to replace
3841 new_operator: Custom torch operator to replace with
3942 enabled: Whether the substitution is enabled or disabled
4043 Returns:
4144 torch.fx.GraphModule
4245 """
4346
44- def register_substitution (subgraph_insertion_fn ):
47+ def enable_substitution (subgraph_insertion_fn ):
4548 """Function for use if substitution is enabled"""
46- module_replacement = ModuleReplacement (
49+ replacement = Substitution (
4750 new_operator = new_operator , subgraph_insertion_fn = subgraph_insertion_fn
4851 )
49- MODULE_SUBSTITUTION_REGISTRY [ module_to_replace ] = module_replacement
52+ SUBSTITUTION_REGISTRY [ module_or_function_to_replace ] = replacement
5053 return subgraph_insertion_fn
5154
5255 def disable_substitution (subgraph_insertion_fn ):
5356 """Function for use if substitution is disabled"""
5457 return subgraph_insertion_fn
5558
56- return register_substitution if enabled else disable_substitution
59+ return enable_substitution if enabled else disable_substitution
5760
5861
59- def pre_aot_module_replacement (gm : torch .fx .GraphModule ):
60- """Perform module-level graph replacement prior to AOT tracing
62+ def pre_aot_substitutions (gm : torch .fx .GraphModule ):
63+ """Perform graph substitutions prior to AOT tracing
6164
6265 Args:
63- gm: FX GraphModule to perform module replacement on
66+ gm: FX GraphModule to perform substitution on
6467 Returns:
6568 torch.fx.GraphModule
6669
@@ -71,48 +74,58 @@ def pre_aot_module_replacement(gm: torch.fx.GraphModule):
7174
7275 # Iterate over graph nodes, extracting module calls, to check for interceptions
7376 for n in gm .graph .nodes :
77+ exists_in_registry = False
78+ to_replace = None
79+
7480 if n .op == "call_module" :
75- # Extract submodule from graph
81+ # Extract submodule from graph, validate in registry
7682 submodule = gm .get_submodule (n .target )
77-
78- # If submodule is a member of the substitution registry, replace it
79- if type (submodule ) in MODULE_SUBSTITUTION_REGISTRY :
80-
81- try :
82- replacement = MODULE_SUBSTITUTION_REGISTRY [type (submodule )]
83- op , insertion_fn = (
84- replacement .new_operator ,
85- replacement .subgraph_insertion_fn ,
86- )
87- logger .debug (
88- f"Replacing module of type { type (submodule )} with { op } "
83+ to_replace = type (submodule )
84+ exists_in_registry = to_replace in SUBSTITUTION_REGISTRY
85+ elif n .op == "call_function" :
86+ # Extract function from graph, validate in registry
87+ to_replace = n .target
88+ exists_in_registry = n .target in SUBSTITUTION_REGISTRY
89+
90+ # If submodule/function is a member of the substitution registry, replace it
91+ if exists_in_registry :
92+ try :
93+ replacement = SUBSTITUTION_REGISTRY [to_replace ]
94+ op , insertion_fn = (
95+ replacement .new_operator ,
96+ replacement .subgraph_insertion_fn ,
97+ )
98+ logger .debug (f"Replacing node of type { to_replace } with { op } " )
99+
100+ # Insert new node prior to older node
101+ with gm .graph .inserting_before (n ):
102+ new_node = insertion_fn (
103+ gm , n , submodule if n .op == "call_module" else None
89104 )
90105
91- # Insert new node prior to older node
92- with gm .graph .inserting_before (n ):
93- new_node = insertion_fn (gm , submodule , n )
94-
95- # If submodule is not a native torch.nn module, it must be manually excluded
96- # from Dynamo tracing
97- if not type (submodule ).__module__ .startswith ("torch.nn" ):
98- torch ._dynamo .allowed_functions ._allowed_function_ids .add (
99- id (type (submodule ))
100- )
101-
102- # Replace all original node uses and clean up graph
103- n .replace_all_uses_with (new_node )
104- gm .graph .eliminate_dead_code ()
105- gm .graph .lint ()
106- gm .recompile ()
107-
108- # A module replacement can fail in the event that the specific instance of the submodule cannot
109- # be replaced
110- except Exception :
111- logger .debug (
112- f"Encountered error while replacing { type (submodule )} " ,
113- exc_info = True ,
106+ # If submodule is not a native torch.nn module, it must be manually excluded
107+ # from Dynamo tracing
108+ if n .op == "call_module" and not type (submodule ).__module__ .startswith (
109+ "torch.nn"
110+ ):
111+ torch ._dynamo .allowed_functions ._allowed_function_ids .add (
112+ id (to_replace )
114113 )
115- continue
114+
115+ # Replace all original node uses and clean up graph
116+ n .replace_all_uses_with (new_node )
117+ gm .graph .eliminate_dead_code ()
118+ gm .graph .lint ()
119+ gm .recompile ()
120+
121+ # A replacement can fail in the event that the specific instance of the submodule/function
122+ # cannot be replaced
123+ except Exception :
124+ logger .debug (
125+ f"Encountered error while replacing { to_replace } " ,
126+ exc_info = True ,
127+ )
128+ continue
116129
117130 # Perform cleanup and recompilation before returning module
118131 gm .graph .eliminate_dead_code ()
0 commit comments