@@ -33,14 +33,16 @@ class NNModuleToString:
         torch.nn.LayerNorm,
         torch.nn.Dropout,
         torch.nn.Softmax,
+        torch.nn.ReLU,
+        torch.nn.MaxPool2d,
     ]

     @staticmethod
     def can_convert_to_string(gm):
         cant_convert = set()
         for _, module in gm.named_children():
             if type(module) not in NNModuleToString.safe_reprs:
-                cant_convert.update(type(module))
+                cant_convert.add(module)

         if len(cant_convert) > 0:
             logging.warn(
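
Note on the set fix above: set.update() iterates its argument, so the old
cant_convert.update(type(module)) raised "TypeError: 'type' object is not
iterable" on the first unsupported child; cant_convert.add(module) records the
single offending module instead. A minimal sketch of the difference, using a
hypothetical DummyModule so it runs without torch:

    class DummyModule:
        pass

    unsupported = set()
    try:
        # Old pattern: update() tries to iterate the class object and fails.
        unsupported.update(type(DummyModule()))
    except TypeError as exc:
        print(f"update() fails: {exc}")

    # New pattern: add() stores the single module instance.
    unsupported.add(DummyModule())
    print(unsupported)
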
@@ -117,7 +119,7 @@ def _cuda_system_info_comment():
     return model_str


-def generate_post_aot_repro_string(gm, args):
+def generate_compiler_repro_string(gm, args):
     model_str = textwrap.dedent(
         """
         import torch
@@ -157,7 +159,7 @@ def generate_post_aot_repro_string(gm, args):
 }


-def dump_post_aot_graph_state(gm, args, compiler_name):
+def dump_compiler_graph_state(gm, args, compiler_name):
     subdir = f"{minifier_dir()}/checkpoints"
     if not os.path.exists(subdir):
         os.makedirs(subdir, exist_ok=True)
@@ -170,7 +172,7 @@ def dump_post_aot_graph_state(gm, args, compiler_name):


 def save_graph_repro(fd, gm, args, compiler_name):
-    fd.write(generate_post_aot_repro_string(gm, args))
+    fd.write(generate_compiler_repro_string(gm, args))
     fd.write(COMPILER_REPRO_OPTIONS[compiler_name][0])
     fd.write(
         textwrap.dedent(
@@ -190,7 +192,7 @@ def isolate_fails(fx_g, args, compiler_name: str, env=None):
         os.makedirs(subdir, exist_ok=True)
     file_name = os.path.join(subdir, f"{str(uuid.uuid4())[:5]}.py")
     with open(file_name, "w") as fd:
-        fd.write(generate_post_aot_repro_string(fx_g, args))
+        fd.write(generate_compiler_repro_string(fx_g, args))
         fail_fn = COMPILER_REPRO_OPTIONS[compiler_name][2]
         fd.write(
             textwrap.dedent(
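
isolate_fails above writes the generated repro to a short uuid-named script so
the failure can be re-checked in a clean child process. A sketch of that
isolation pattern, assuming the script is run in a fresh interpreter with the
captured env (the script_fails helper and its exact invocation are
illustrative, not the torchdynamo implementation):

    import os
    import subprocess
    import sys

    def script_fails(file_name, env=None):
        # Run the standalone repro in a fresh interpreter; a non-zero exit
        # code means the failure still reproduces in isolation.
        child_env = {**os.environ, **(env or {})}
        result = subprocess.run([sys.executable, file_name], env=child_env)
        return result.returncode != 0
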
@@ -270,15 +272,15 @@ def dump_to_minify(gm, args, compiler_name: str):
     with open(
         os.path.join(torchdynamo.config.base_dir, "minifier_launcher.py"), "w"
     ) as fd:
-        fd.write(generate_post_aot_repro_string(gm, args))
+        fd.write(generate_compiler_repro_string(gm, args))
         fd.write("\n")
         fd.write(
             textwrap.dedent(
                 f"""
                 from functools import partial
                 from torchdynamo.debug_utils import (
                     isolate_fails,
-                    dump_post_aot_graph_state,
+                    dump_compiler_graph_state,
                 )
                 from functorch.compile import minifier

@@ -288,15 +290,15 @@ def dump_to_minify(gm, args, compiler_name: str):
                     mod,
                     args,
                     module_fails=partial(isolate_fails, env=env_variables, compiler_name="{compiler_name}"),
-                    dump_state=partial(dump_post_aot_graph_state, compiler_name="{compiler_name}"),
+                    dump_state=partial(dump_compiler_graph_state, compiler_name="{compiler_name}"),
                 )
                 """
             )
         )
     print("wrote out to minifier_launcher.py")


-def wrap_post_aot_debug(compiler, compiler_name: str):
+def wrap_compiler_debug(compiler, compiler_name: str):
     """
     Minifier for Fx Graph modules after Aot Autograd has finished. We wrap both
     forward and backward call separately with the backend compiler - like
@@ -319,7 +321,7 @@ def debug_wrapper(gm, example_inputs, **kwargs):
             compiled_fn(*example_inputs)
         except Exception as e:
             if config.repro_level == 1:
-                dump_post_aot_graph_state(
+                dump_compiler_graph_state(
                     fx.GraphModule(gm, orig_graph), example_inputs, compiler_name
                 )
             elif config.repro_level == 2:
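
The hunk above shows the repro_level dispatch inside wrap_compiler_debug:
level 1 dumps the failing graph verbatim. A compressed sketch of that control
flow, assuming the elif branch (cut off by the hunk) routes level 2 to
dump_to_minify, whose helpers this commit also renames:

    from torchdynamo.debug_utils import dump_compiler_graph_state, dump_to_minify

    def on_compile_failure(gm, example_inputs, compiler_name, repro_level):
        # gm is the Fx GraphModule whose compiled version failed to run.
        if repro_level == 1:
            # Level 1: save the whole failing graph as a runnable repro.
            dump_compiler_graph_state(gm, example_inputs, compiler_name)
        elif repro_level == 2:
            # Level 2 (assumed): write minifier_launcher.py for offline minification.
            dump_to_minify(gm, example_inputs, compiler_name)
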
@@ -384,7 +386,6 @@ def generate_dynamo_fx_repro_string(model_str, args, compiler_name):
         """
     )

-    # TODO - Figure out the amp state
     run_module = textwrap.dedent(
         f"""
         with torch.cuda.amp.autocast(enabled={torch.is_autocast_enabled()}):
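
Dropping the amp TODO works because the generated repro now pins the autocast
state at dump time: torch.is_autocast_enabled() is evaluated inside the
f-string when the repro is written, not when it is replayed. A small
demonstration of that baking-in (the pass body is a placeholder for the
repro's real module call):

    import textwrap
    import torch

    run_module = textwrap.dedent(
        f"""
        with torch.cuda.amp.autocast(enabled={torch.is_autocast_enabled()}):
            pass  # placeholder; the generated repro invokes the module here
        """
    )
    # Prints enabled=False unless the dump happened inside an autocast region.
    print(run_module)
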
@@ -396,7 +397,7 @@ def generate_dynamo_fx_repro_string(model_str, args, compiler_name):
     return imports + model_str + setup_module + prep_inputs + run_module


-def dump_dynamo_gm_to_file(gm, args, compiler_name):
+def dump_backend_repro_as_file(gm, args, compiler_name):
     """
     Saves the repro to a repro.py file
     """
@@ -413,7 +414,7 @@ def dump_dynamo_gm_to_file(gm, args, compiler_name):
     shutil.copyfile(file_name, repro_path)


-def dump_dynamo_gm_as_tarfile(gm, args, compiler_name):
+def dump_backend_repro_as_tarfile(gm, args, compiler_name):
     """
     Saves the repro in repro.tar.gz, as opposed to a file. This is used for
     cases, where we can't convert a Fx GraphModule to a string, and therefore
@@ -465,7 +466,7 @@ def dump_dynamo_gm_as_tarfile(gm, args, compiler_name):
         tar.add(local_dir, arcname=os.path.basename(local_dir))


-def dump_dynamo_gm_state(gm, args, compiler_name):
+def dump_backend_state(gm, args, compiler_name):
     """
     Dumps the dynamo graph to repro the issue.
     1) It tries to convert Fx GraphModule to a string. If we can, it writes to a
@@ -474,8 +475,8 @@ def dump_dynamo_gm_state(gm, args, compiler_name):
     the module and save a tar file.
     """
     if NNModuleToString.can_convert_to_string(gm):
-        return dump_dynamo_gm_to_file(gm, args, compiler_name)
-    return dump_dynamo_gm_as_tarfile(gm, args, compiler_name)
+        return dump_backend_repro_as_file(gm, args, compiler_name)
+    return dump_backend_repro_as_tarfile(gm, args, compiler_name)


 def backend_fails(gm, example_inputs, compiler_fn, orig_failure):
@@ -503,10 +504,10 @@ def backend_fails(gm, example_inputs, compiler_fn, orig_failure):
     return False


-def wrap_dynamo_gm_debug(compiler_fn, compiler_name: str):
+def wrap_backend_debug(compiler_fn, compiler_name: str):
     """
     A minifier decorator that wraps the TorchDynamo produced Fx graph modules.
-    As opposed to wrap_post_aot_debug, this wrapper intercepts at the
+    As opposed to wrap_compiler_debug, this wrapper intercepts at the
     TorchDynamo produced Fx Graph Module. This makes it backend-agnostic to some
     level, e.g., it is useful for minifying issues related to Aot Autograd
     tracing. If an error is found, we minify and save the minified repro in
@@ -517,7 +518,7 @@ def wrap_dynamo_gm_debug(compiler_fn, compiler_name: str):
     @functools.wraps(compiler_fn)
     def debug_wrapper(gm, example_inputs, **kwargs):
         compiled_gm = compiler_fn(gm, example_inputs, **kwargs)
-        if config.dynamo_repro_level > 0:
+        if config.backend_repro_level > 0:
             # Ensure that we fail when backend fails
             config.raise_on_backend_error = True
             try:
@@ -528,15 +529,15 @@ def debug_wrapper(gm, example_inputs, **kwargs):
                     f"Compiled Fx GraphModule failed with {orig_failure}. Starting minifier."
                 )
                 dump_state_fn = functools.partial(
-                    dump_dynamo_gm_state, compiler_name=compiler_name
+                    dump_backend_state, compiler_name=compiler_name
                 )
-                if config.dynamo_repro_level == 1:
+                if config.backend_repro_level == 1:
                     dump_state_fn(
                         fx.GraphModule(gm, copy.deepcopy(gm.graph)), example_inputs
                     )
                 else:
                     # As opposed to using dump_to_minify, like we do in
-                    # wrap_post_aot_debug, we directly run minifier here. This
+                    # wrap_compiler_debug, we directly run minifier here. This
                     # is because we can't serialize compiler_fn here.

                     # The minified version uses
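
With the renames settled, wrap_compiler_debug guards the post-AOT compilers
while wrap_backend_debug guards whole TorchDynamo backends. A usage sketch for
the latter against the old torchdynamo API; my_backend and its body are
illustrative placeholders, not a real backend:

    import torchdynamo
    from torchdynamo.debug_utils import wrap_backend_debug

    def my_backend(gm, example_inputs, **kwargs):
        # A stand-in backend: a real one would lower and compile gm here.
        return gm.forward

    torchdynamo.config.backend_repro_level = 1  # dump a repro if the backend fails
    debug_backend = wrap_backend_debug(my_backend, "my_backend")

    with torchdynamo.optimize(debug_backend):
        ...  # run the model; failing graphs get dumped by the wrapper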