55import os
66import pkgutil
77import traceback
8- from functools import partial
8+ import types
99from typing import TYPE_CHECKING , Any , Dict , List , Optional , Set , Tuple , Union
1010
1111import torch
1919class TracingHook (abc .ABC ):
2020 """钩子的抽象基类"""
2121
22- def __init__ (self , serializer : "ConfigSerializer" , level : int ):
22+ def __init__ (
23+ self ,
24+ serializer : "ConfigSerializer" ,
25+ level : int ,
26+ dialect : Optional ["FrameworkDialect" ] = None ,
27+ ):
2328 self .serializer = serializer
2429 self .level = level
30+ self .dialect = dialect
2531
2632 @abc .abstractmethod
2733 def install (self ):
@@ -40,8 +46,7 @@ def __init__(
4046 level : int ,
4147 dialect : "FrameworkDialect" ,
4248 ):
43- super ().__init__ (serializer , level )
44- self .dialect = dialect
49+ super ().__init__ (serializer , level , dialect )
4550 self ._original_apis : Dict [str , Any ] = {}
4651 self ._module_cache : Dict [str , Any ] = {}
4752
@@ -85,9 +90,12 @@ def wrapper(*args, **kwargs):
8590 return wrapper
8691
8792 def install (self ):
93+ if self .dialect is None :
94+ return
8895 api_list = self .dialect .discover_apis () + self .dialect .discover_custom_ops ()
8996
90- # with open(os.path.join(os.path.dirname(__file__), "trace_output", "api_list.yaml"), "w") as f:
97+ # api_path = os.path.join(os.path.dirname(__file__), "trace_output/api_list.yaml")
98+ # with open(api_path, "w") as f:
9199 # yaml.dump(api_list, f)
92100
93101 print (f"[SetattrHook] Attempting to patch { len (api_list )} APIs..." )
@@ -106,29 +114,40 @@ def install(self):
106114 original_api = getattr (parent_obj , func_name )
107115 wrapper = None
108116
109- if isinstance (original_api , property ):
110- if original_api .fget and original_api .fset :
111- wrapped_getter = self ._create_wrapper (
112- f"{ api_name } .fget" ,
113- original_api .fget ,
114- self .serializer ,
115- self .level ,
116- )
117- wrapper = property (
118- wrapped_getter ,
119- original_api .fset ,
120- original_api .fdel ,
121- original_api .__doc__ ,
122- )
117+ if isinstance (
118+ original_api ,
119+ (
120+ types .FunctionType ,
121+ types .BuiltinFunctionType ,
122+ types .MethodType ,
123+ types .BuiltinMethodType ,
124+ ),
125+ ):
126+ wrapped_func = self ._create_wrapper (
127+ api_name , original_api , self .serializer , self .level
128+ )
123129 elif isinstance (original_api , (classmethod , staticmethod )):
124130 original_func = original_api .__func__
125131 wrapped_func = self ._create_wrapper (
126132 api_name , original_func , self .serializer , self .level
127133 )
128134 wrapper = type (original_api )(wrapped_func )
129- elif callable (original_api ):
130- wrapper = self ._create_wrapper (
131- api_name , original_api , self .serializer , self .level
135+ elif (
136+ isinstance (original_api , property )
137+ and original_api .fget
138+ and original_api .fset
139+ ):
140+ wrapped_getter = self ._create_wrapper (
141+ f"{ api_name } .fget" ,
142+ original_api .fget ,
143+ self .serializer ,
144+ self .level ,
145+ )
146+ wrapper = property (
147+ wrapped_getter ,
148+ original_api .fset ,
149+ original_api .fdel ,
150+ original_api .__doc__ ,
132151 )
133152
134153 if wrapper :
@@ -153,7 +172,8 @@ def install(self):
153172 f"[SetattrHook] Patched { patched_count } APIs. Skipped { skipped_count } non-writable APIs."
154173 )
155174
156- # with open(os.path.join(os.path.dirname(__file__), "trace_output", "api_list_wrap.yaml"), "w") as f:
175+ # api_path = os.path.join(os.path.dirname(__file__), "trace_output/api_list_wrap.yaml")
176+ # with open(api_path, "w") as f:
157177 # yaml.dump(list(self._original_apis.keys()), f)
158178
159179 def uninstall (self ):
@@ -171,9 +191,13 @@ def uninstall(self):
171191
172192
173193class TorchFunctionModeTracer (torch .overrides .TorchFunctionMode ):
174- def __init__ (self , serializer : "ConfigSerializer" , level : int ):
194+ def __init__ (
195+ self , serializer : "ConfigSerializer" , level : int , dialect : "FrameworkDialect"
196+ ):
175197 self .serializer = serializer
176198 self .level = level
199+ self .disable_torch_api_list = getattr (dialect , "disable_torch_api_list" , False )
200+ self .target_apis = getattr (dialect , "target_apis" , [])
177201
178202 # skip these for duplicate property access of paddle.Tensor in SetattrHook
179203 # (SetattrHook and TorchFunctionHook are installed at the same time)
@@ -209,9 +233,11 @@ def __torch_function__(self, func, types, args=(), kwargs=None):
209233
210234
211235class TorchFunctionHook (TracingHook ):
212- def __init__ (self , serializer : "ConfigSerializer" , level : int ):
213- super ().__init__ (serializer , level )
214- self .tracing_mode = TorchFunctionModeTracer (serializer , level )
236+ def __init__ (
237+ self , serializer : "ConfigSerializer" , level : int , dialect : "FrameworkDialect"
238+ ):
239+ super ().__init__ (serializer , level , dialect )
240+ self .tracing_mode = TorchFunctionModeTracer (serializer , level , dialect )
215241
216242 def install (self ):
217243 print (f"[TorchFunctionHook] Enabling __torch_function__ tracing mode..." )
@@ -225,9 +251,13 @@ def uninstall(self):
225251
226252
227253class TorchDispatchModeTracer (TorchDispatchMode ):
228- def __init__ (self , serializer : "ConfigSerializer" , level : int ):
254+ def __init__ (
255+ self , serializer : "ConfigSerializer" , level : int , dialect : "FrameworkDialect"
256+ ):
229257 self .serializer = serializer
230258 self .level = level
259+ self .disable_torch_api_list = getattr (dialect , "disable_torch_api_list" , False )
260+ self .target_apis = getattr (dialect , "target_apis" , [])
231261
232262 def __torch_dispatch__ (self , func , types , args = (), kwargs = None ):
233263 kwargs = kwargs or {}
@@ -240,9 +270,11 @@ def __torch_dispatch__(self, func, types, args=(), kwargs=None):
240270
241271
242272class TorchDispatchHook (TracingHook ):
243- def __init__ (self , serializer : "ConfigSerializer" , level : int ):
244- super ().__init__ (serializer , level )
245- self .tracing_mode = TorchDispatchModeTracer (serializer , level )
273+ def __init__ (
274+ self , serializer : "ConfigSerializer" , level : int , dialect : "FrameworkDialect"
275+ ):
276+ super ().__init__ (serializer , level , dialect )
277+ self .tracing_mode = TorchDispatchModeTracer (serializer , level , dialect )
246278
247279 def install (self ):
248280 print (f"[TorchDispatchHook] Enabling __torch_dispatch__ tracing mode..." )
@@ -394,13 +426,14 @@ class PyTorchDialect(FrameworkDialect):
394426 # recommended to skip
395427 "__call__" ,
396428 "__format__" ,
397- "__instancecheck__" ,
398429 "__iter__" ,
399430 "__repr__" ,
400431 "__str__" ,
432+ "__instancecheck__" ,
401433 "__subclasscheck__" ,
402434 "__subclasshook__" ,
403- # optional to skip
435+ "__getstate__" ,
436+ "__setstate__" ,
404437 "__enter__" ,
405438 "__exit__" ,
406439 }
@@ -416,10 +449,11 @@ class PyTorchDialect(FrameworkDialect):
416449 "torch.TypedStorage" ,
417450 # methods
418451 "torch.autograd.function._is_setup_context_defined" ,
452+ "torch.distributed.reduce_op" ,
419453 "torch.fx.experimental.unification.multipledispatch.dispatcher.str_signature" ,
420454 "torch.nn.functional.handle_torch_function" ,
421455 "torch.nn.functional.has_torch_function_unary" ,
422- "torch.distributed.reduce_op " ,
456+ "torch.optim.Optimizer.profile_hook_step " ,
423457 }
424458
425459 def get_framework_name (self ) -> str :
@@ -480,7 +514,17 @@ def discover_apis(self) -> List[str]:
480514 continue
481515 if full_name in self .IGNORE_CLASSES_OR_METHODS :
482516 continue
483- if callable (obj ) and not inspect .isclass (obj ):
517+ if isinstance (
518+ obj ,
519+ (
520+ types .FunctionType ,
521+ types .BuiltinFunctionType ,
522+ types .MethodType ,
523+ types .BuiltinMethodType ,
524+ staticmethod ,
525+ classmethod ,
526+ ),
527+ ):
484528 api_set .add (full_name )
485529 elif inspect .isclass (obj ):
486530 # custom op class should be skip
@@ -490,23 +534,22 @@ def discover_apis(self) -> List[str]:
490534 if cls_member_name in self .IGNORE_ATTRIBUTES :
491535 continue
492536 full_cls_name = f"{ full_name } .{ cls_member_name } "
493- if inspect .ismethod (cls_member ) or inspect .isfunction (
494- cls_member
537+ if isinstance (
538+ cls_member ,
539+ (
540+ types .FunctionType ,
541+ types .BuiltinFunctionType ,
542+ types .MethodType ,
543+ types .BuiltinMethodType ,
544+ staticmethod ,
545+ classmethod ,
546+ ),
495547 ):
496548 api_set .add (full_cls_name )
497- elif isinstance (cls_member , (staticmethod , classmethod )):
498- api_set .add (full_cls_name )
499- elif isinstance (cls_member , property ):
500- if cls_member .fget and cls_member .fset :
501- api_set .add (full_cls_name )
502- elif isinstance (cls_member , partial ):
503- if hasattr (
504- cls_member .func , "__module__"
505- ) and cls_member .func .__module__ .startswith ("torch" ):
506- api_set .add (full_cls_name )
507549 elif (
508- hasattr (cls_member , "__isabstractmethod__" )
509- and cls_member .__isabstractmethod__
550+ isinstance (cls_member , property )
551+ and cls_member .fget
552+ and cls_member .fset
510553 ):
511554 api_set .add (full_cls_name )
512555 except Exception as e :
@@ -584,10 +627,7 @@ def get_hooks(self, serializer, levels: List[int], **kwargs) -> List[TracingHook
584627 for level in levels :
585628 hook_class = hook_map .get (level )
586629 if hook_class :
587- if level == 0 :
588- hooks .append (hook_class (serializer , level , self ))
589- else :
590- hooks .append (hook_class (serializer , level ))
630+ hooks .append (hook_class (serializer , level , self ))
591631 else :
592632 raise ValueError (f"Invalid level: { level } " )
593633 return hooks
0 commit comments