@@ -21,6 +21,73 @@ def minifier_dir():
     return f"/tmp/minifier_{getpass.getuser()}"


+class NNModuleToString:
+    safe_reprs = [
+        torch.nn.Linear,
+        torch.nn.Conv1d,
+        torch.nn.Conv2d,
+        torch.nn.Conv3d,
+        torch.nn.BatchNorm1d,
+        torch.nn.BatchNorm2d,
+        torch.nn.BatchNorm3d,
+        torch.nn.LayerNorm,
+        torch.nn.Dropout,
+        torch.nn.Softmax,
+    ]
+
+    @staticmethod
+    def can_convert_to_string(gm):
+        cant_convert = set()
+        for _, module in gm.named_children():
+            if type(module) not in NNModuleToString.safe_reprs:
+                cant_convert.add(type(module))
+
+        if len(cant_convert) > 0:
+            logging.warning(
+                f"Was not able to save the following children modules as reprs {cant_convert}"
+            )
+            return False
+        return True
+
+    @staticmethod
+    def convert(gm):
+        from torch.nn.modules.module import _addindent
+
+        tab = " " * 4
+
+        model_str = textwrap.dedent(
+            """
+            from torch.nn import *
+            class Repro(torch.nn.Module):
+                def __init__(self):
+                    super().__init__()
+            """
+        )
+
+        for module_name, module in gm.named_children():
+            module_str = f"{module.__repr__()}"
+            model_str += f"{tab*2}self.{module_name} = {module_str}\n"
+
+        for buffer_name, buffer in gm._buffers.items():
+            if buffer is None:
+                continue
+            model_str += f"{tab*2}self.{buffer_name} = torch.randn({list(buffer.shape)}, dtype={buffer.dtype})\n"
+
+        for param_name, param in gm._parameters.items():
+            if param is None:
+                continue
+            model_str += f"{tab*2}self.{param_name} = torch.nn.Parameter(torch.randn({list(param.shape)}, dtype={param.dtype}))\n"
+
+        attrs = dir(gm)
+        for attr in attrs:
+            if "_tensor_constant" in attr:
+                val = getattr(gm, attr)
+                model_str += f"    {attr} = {val!r}\n"
+
+        model_str += f"{_addindent(gm.code, 4)}\n"
+        return model_str
+
+
 @functools.lru_cache(None)  # subprocess is expensive
 def _cuda_system_info_comment():
     if not torch.cuda.is_available():
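A quick sketch of how the new helper is expected to be used. The tiny traced module below and its shapes are made up for illustration, and it assumes NNModuleToString from this patch is importable:

    import torch
    import torch.fx

    class Tiny(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.lin = torch.nn.Linear(4, 4)

        def forward(self, x):
            return torch.relu(self.lin(x))

    gm = torch.fx.symbolic_trace(Tiny())

    # Linear is in safe_reprs, so the string path is available.
    assert NNModuleToString.can_convert_to_string(gm)

    # Emits a self-contained "class Repro(torch.nn.Module)" whose __init__
    # re-creates self.lin from its repr and whose forward body is gm.code.
    print(NNModuleToString.convert(gm))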
@@ -48,7 +115,7 @@ def _cuda_system_info_comment():
     return model_str


-def generate_repro_string(gm, args):
+def generate_post_aot_repro_string(gm, args):
     model_str = textwrap.dedent(
         """
         import torch
@@ -65,14 +132,7 @@ def generate_repro_string(gm, args):
     model_str += f"# torch git version: {torch.version.git_version}\n\n\n"
     model_str += _cuda_system_info_comment()

-    model_str += "class Repro(torch.nn.Module):\n"
-    attrs = dir(gm)
-    for attr in attrs:
-        if "_tensor_constant" in attr:
-            val = getattr(gm, attr)
-            model_str += f"    {attr} = {val!r}\n"
-    model_str += textwrap.indent(gm.code, "    ")
-    model_str += "\n"
+    model_str += NNModuleToString.convert(gm)

     model_str += f"args = {[(tuple(arg.shape), tuple(arg.stride()), arg.dtype, arg.device.type) for arg in args]!r}\n"
     model_str += "args = [rand_strided(shape, stride, dtype, device) for shape, stride, dtype, device in args]\n"
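A side note on the input serialization used just above: each arg is reduced to a (shape, stride, dtype, device) tuple and later rebuilt with rand_strided, so only tensor metadata, never the original data, ends up in the repro. A small sketch with a made-up tensor:

    import torch
    from torchdynamo.testing import rand_strided

    x = torch.empty(2, 3).transpose(0, 1)   # deliberately non-contiguous
    meta = (tuple(x.shape), tuple(x.stride()), x.dtype, x.device.type)
    # meta == ((3, 2), (1, 3), torch.float32, 'cpu')

    rebuilt = rand_strided(*meta)            # same layout, random contents
    assert rebuilt.shape == x.shape and rebuilt.stride() == x.stride()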
@@ -108,7 +168,7 @@ def dump_post_aot_graph_state(gm, args, compiler_name):


 def save_graph_repro(fd, gm, args, compiler_name):
-    fd.write(generate_repro_string(gm, args))
+    fd.write(generate_post_aot_repro_string(gm, args))
     fd.write(COMPILER_REPRO_OPTIONS[compiler_name][0])
     fd.write(
         textwrap.dedent(
@@ -128,7 +188,7 @@ def isolate_fails(fx_g, args, compiler_name: str, env=None):
         os.makedirs(subdir, exist_ok=True)
     file_name = os.path.join(subdir, f"{str(uuid.uuid4())[:5]}.py")
     with open(file_name, "w") as fd:
-        fd.write(generate_repro_string(fx_g, args))
+        fd.write(generate_post_aot_repro_string(fx_g, args))
         fail_fn = COMPILER_REPRO_OPTIONS[compiler_name][2]
         fd.write(
             textwrap.dedent(
@@ -208,7 +268,7 @@ def dump_to_minify(gm, args, compiler_name: str):
     with open(
         os.path.join(torchdynamo.config.base_dir, "minifier_launcher.py"), "w"
     ) as fd:
-        fd.write(generate_repro_string(gm, args))
+        fd.write(generate_post_aot_repro_string(gm, args))
         fd.write("\n")
         fd.write(
             textwrap.dedent(
@@ -288,7 +348,7 @@ def run_fwd_maybe_bwd(gm, args):
         return out


-def generate_backend_repro_string(gm, args, compiler_name):
+def generate_dynamo_fx_repro_string(model_str, args, compiler_name):
     """
     Generate a repro string for backend-agnostic minified version.
     """
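The signature change above is the crux of this hunk: the function now receives a ready-made model_str instead of a GraphModule, so both dump paths further down can reuse the same template. Roughly (gm, args, and compiler_name are placeholders here):

    # Path 1: the generated Repro class is inlined into the repro script.
    repro_src = generate_dynamo_fx_repro_string(
        NNModuleToString.convert(gm), args, compiler_name
    )

    # Path 2: the module was serialized with gm.to_folder(gm_dir, "Repro"), so
    # an import statement stands in for the class definition.
    repro_src = generate_dynamo_fx_repro_string(
        "from module import Repro", args, compiler_name
    )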
@@ -297,7 +357,10 @@ def generate_backend_repro_string(gm, args, compiler_name):
         """
         import torch
         import torchdynamo
+        from torch import tensor, device
+        import torch.fx as fx
         from torchdynamo.testing import rand_strided
+        from math import inf
         from torchdynamo.debug_utils import run_fwd_maybe_bwd

         """
@@ -313,8 +376,7 @@ def generate_backend_repro_string(gm, args, compiler_name):

     setup_module = textwrap.dedent(
         f"""
-        import module
-        mod = module.ReproModule().cuda()
+        mod = Repro().cuda()
         opt_mod = torchdynamo.optimize("{compiler_name}")(mod)

         """
@@ -329,16 +391,33 @@ def generate_backend_repro_string(gm, args, compiler_name):
         """
     )

-    return imports + prep_inputs + setup_module + run_module
+    return imports + model_str + setup_module + prep_inputs + run_module


-def dump_dynamo_gm_state(gm, args, compiler_name):
+def dump_dynamo_gm_to_file(gm, args, compiler_name):
     """
-    Saves the graph module and accompany it with a repro.py script. This is not
-    as clean as the wrap_post_aot_debug, where the repro is limited to one file.
-    This is because post_aot counterpart works on the functionalized graphs with
-    params lifted as graph inputs. Here, we have TorchDynamo produced Fx graphs,
-    which we save using to_folder utility.
+    Saves the repro to a repro.py file
+    """
+    subdir = f"{minifier_dir()}/checkpoints"
+    if not os.path.exists(subdir):
+        os.makedirs(subdir, exist_ok=True)
+    file_name = os.path.join(subdir, f"{len(gm.graph.nodes)}.py")
+    print(f"Writing checkpoint with {len(gm.graph.nodes)} nodes to {file_name}")
+
+    model_str = NNModuleToString.convert(gm)
+    with open(file_name, "w") as fd:
+        fd.write(generate_dynamo_fx_repro_string(model_str, args, compiler_name))
+    repro_path = os.path.join(torchdynamo.config.base_dir, "repro.py")
+    shutil.copyfile(file_name, repro_path)
+
+
+def dump_dynamo_gm_as_tarfile(gm, args, compiler_name):
+    """
+    Saves the repro in repro.tar.gz, as opposed to a file. This is used for
+    cases where we can't convert an Fx GraphModule to a string and therefore
+    fall back to to_folder for serialization. We accompany this with a repro.py
+    script that imports the saved module, sets it up, and runs the model to
+    reproduce the error.
     """

     import tarfile
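For reference, the assembled repro.py follows the ordering in the return statement above (imports + model_str + setup_module + prep_inputs + run_module) and ends up roughly like the sketch below; the backend name is a stand-in and the final prep/run portion is not shown in this hunk:

    import torch
    import torchdynamo
    from torch import tensor, device
    import torch.fx as fx
    from torchdynamo.testing import rand_strided
    from math import inf
    from torchdynamo.debug_utils import run_fwd_maybe_bwd

    class Repro(torch.nn.Module):   # or "from module import Repro" in the tarball case
        ...                         # body produced by NNModuleToString.convert(gm)

    mod = Repro().cuda()
    opt_mod = torchdynamo.optimize("eager")(mod)   # "eager" stands in for compiler_name

    # prep_inputs presumably rebuilds the inputs via rand_strided from recorded
    # (shape, stride, dtype, device) metadata, and run_module then exercises
    # mod / opt_mod through run_fwd_maybe_bwd to trigger the original failure.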
@@ -355,7 +434,6 @@ def dump_dynamo_gm_state(gm, args, compiler_name):
     gm_dir = os.path.join(tmp_dir, "module")
     if not os.path.exists(gm_dir):
         os.makedirs(gm_dir, exist_ok=True)
-    print(f"Writing checkpoint with {len(gm.graph.nodes)} nodes to {file_name}")
     for node in gm.graph.nodes:
         new_kwargs = {}
         for k, v in node.kwargs.items():
@@ -365,19 +443,39 @@ def dump_dynamo_gm_state(gm, args, compiler_name):
         node.kwargs = new_kwargs
     gm.recompile()

+    print(f"Writing checkpoint with {len(gm.graph.nodes)} nodes to {file_name}")
     with open(file_name, "w") as fd:
         # TODO - Add the readable version of to_folder when available
-        gm.to_folder(gm_dir, "ReproModule")
-        fd.write(generate_backend_repro_string(gm, args, compiler_name))
+        gm.to_folder(gm_dir, "Repro")
+        fd.write(
+            generate_dynamo_fx_repro_string(
+                "from module import Repro", args, compiler_name
+            )
+        )
+
     local_dir = os.path.join(torchdynamo.config.base_dir, "repro")
     if os.path.exists(local_dir):
         shutil.rmtree(local_dir)
     shutil.copytree(tmp_dir, local_dir)
     local_tar_file = os.path.join(torchdynamo.config.base_dir, "repro.tar.gz")
+    print(f"Writing checkpoint with {len(gm.graph.nodes)} nodes to {local_tar_file}")
     with tarfile.open(local_tar_file, "w:gz") as tar:
         tar.add(local_dir, arcname=os.path.basename(local_dir))


+def dump_dynamo_gm_state(gm, args, compiler_name):
+    """
+    Dumps the dynamo graph to reproduce the issue.
+    1) It tries to convert the Fx GraphModule to a string. If it can, it writes
+    the repro to a repro.py file.
+    2) If it can't convert the Fx GraphModule to a string, it falls back to
+    to_folder to save the module and packs it into a tar file.
+    """
+    if NNModuleToString.can_convert_to_string(gm):
+        return dump_dynamo_gm_to_file(gm, args, compiler_name)
+    return dump_dynamo_gm_as_tarfile(gm, args, compiler_name)
+
+
 def backend_fails(gm, example_inputs, compiler_fn, orig_failure):
     """
     Minifier uses this function to identify if the minified graph module fails
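Coming back to the new dump_dynamo_gm_state entry point above: callers no longer choose a serialization path themselves. A hypothetical call site (model, example_inputs, and the backend name are illustrative):

    gm = torch.fx.symbolic_trace(model)
    dump_dynamo_gm_state(gm, example_inputs, "eager")
    # Writes repro.py when every child module is in NNModuleToString.safe_reprs,
    # otherwise falls back to repro.tar.gz via gm.to_folder.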