@@ -21,6 +21,75 @@ def minifier_dir():
2121 return f"/tmp/minifier_{ getpass .getuser ()} "
2222
2323
24+ class NNModuleToString :
25+ safe_reprs = [
26+ torch .nn .Linear ,
27+ torch .nn .Conv1d ,
28+ torch .nn .Conv2d ,
29+ torch .nn .Conv3d ,
30+ torch .nn .BatchNorm1d ,
31+ torch .nn .BatchNorm2d ,
32+ torch .nn .BatchNorm3d ,
33+ torch .nn .LayerNorm ,
34+ torch .nn .Dropout ,
35+ torch .nn .Softmax ,
36+ ]
37+
38+ @staticmethod
39+ def can_convert_to_string (gm ):
40+ cant_convert = set ()
41+ for _ , module in gm .named_children ():
42+ if type (module ) not in NNModuleToString .safe_reprs :
43+ cant_convert .update (type (module ))
44+
45+ if len (cant_convert ) > 0 :
46+ logging .warn (
47+ f"Was not able to save the following children modules as reprs { cant_convert } "
48+ )
49+ return False
50+ return True
51+
52+ @staticmethod
53+ def convert (gm ):
54+ from torch .nn .modules .module import _addindent
55+
56+ tab = " " * 4
57+
58+ model_str = textwrap .dedent (
59+ """
60+ from torch.nn import *
61+ class Repro(torch.nn.Module):
62+ def __init__(self):
63+ super().__init__()
64+ """
65+ )
66+
67+ for module_name , module in gm .named_children ():
68+ module_str = f"{ module .__repr__ ()} "
69+ model_str += f"{ tab * 2 } self.{ module_name } = { module_str } \n "
70+
71+ for buffer_name , buffer in gm ._buffers .items ():
72+ if buffer is None :
73+ continue
74+ tensor_str = f"torch.randn({ list (buffer .shape )} , dtype={ buffer .dtype } )"
75+ model_str += f"{ tab * 2 } self.register_buffer('{ buffer_name } ', { tensor_str } )\n "
76+
77+ for param_name , param in gm ._parameters .items ():
78+ if param is None :
79+ continue
80+ tensor_str = f"torch.nn.Parameter(torch.randn({ list (param .shape )} , dtype={ param .dtype } ))"
81+ model_str += f"{ tab * 2 } self.{ param_name } = { tensor_str } \n "
82+
83+ attrs = dir (gm )
84+ for attr in attrs :
85+ if "_tensor_constant" in attr :
86+ val = getattr (gm , attr )
87+ model_str += f" { attr } = { val !r} \n "
88+
89+ model_str += f"{ _addindent (gm .code , 4 )} \n "
90+ return model_str
91+
92+
2493@functools .lru_cache (None ) # subprocess is expensive
2594def _cuda_system_info_comment ():
2695 if not torch .cuda .is_available ():
@@ -48,7 +117,7 @@ def _cuda_system_info_comment():
48117 return model_str
49118
50119
51- def generate_repro_string (gm , args ):
120+ def generate_post_aot_repro_string (gm , args ):
52121 model_str = textwrap .dedent (
53122 """
54123 import torch
@@ -65,14 +134,7 @@ def generate_repro_string(gm, args):
65134 model_str += f"# torch git version: { torch .version .git_version } \n \n \n "
66135 model_str += _cuda_system_info_comment ()
67136
68- model_str += "class Repro(torch.nn.Module):\n "
69- attrs = dir (gm )
70- for attr in attrs :
71- if "_tensor_constant" in attr :
72- val = getattr (gm , attr )
73- model_str += f" { attr } = { val !r} \n "
74- model_str += textwrap .indent (gm .code , " " )
75- model_str += "\n "
137+ model_str += NNModuleToString .convert (gm )
76138
77139 model_str += f"args = { [(tuple (arg .shape ), tuple (arg .stride ()), arg .dtype , arg .device .type ) for arg in args ]!r} \n "
78140 model_str += "args = [rand_strided(shape, stride, dtype, device) for shape, stride, dtype, device in args]\n "
@@ -108,7 +170,7 @@ def dump_post_aot_graph_state(gm, args, compiler_name):
108170
109171
110172def save_graph_repro (fd , gm , args , compiler_name ):
111- fd .write (generate_repro_string (gm , args ))
173+ fd .write (generate_post_aot_repro_string (gm , args ))
112174 fd .write (COMPILER_REPRO_OPTIONS [compiler_name ][0 ])
113175 fd .write (
114176 textwrap .dedent (
@@ -128,7 +190,7 @@ def isolate_fails(fx_g, args, compiler_name: str, env=None):
128190 os .makedirs (subdir , exist_ok = True )
129191 file_name = os .path .join (subdir , f"{ str (uuid .uuid4 ())[:5 ]} .py" )
130192 with open (file_name , "w" ) as fd :
131- fd .write (generate_repro_string (fx_g , args ))
193+ fd .write (generate_post_aot_repro_string (fx_g , args ))
132194 fail_fn = COMPILER_REPRO_OPTIONS [compiler_name ][2 ]
133195 fd .write (
134196 textwrap .dedent (
@@ -208,7 +270,7 @@ def dump_to_minify(gm, args, compiler_name: str):
208270 with open (
209271 os .path .join (torchdynamo .config .base_dir , "minifier_launcher.py" ), "w"
210272 ) as fd :
211- fd .write (generate_repro_string (gm , args ))
273+ fd .write (generate_post_aot_repro_string (gm , args ))
212274 fd .write ("\n " )
213275 fd .write (
214276 textwrap .dedent (
@@ -288,7 +350,7 @@ def run_fwd_maybe_bwd(gm, args):
288350 return out
289351
290352
291- def generate_backend_repro_string ( gm , args , compiler_name ):
353+ def generate_dynamo_fx_repro_string ( model_str , args , compiler_name ):
292354 """
293355 Generate a repro string for backend-agnostic minified version.
294356 """
@@ -297,7 +359,10 @@ def generate_backend_repro_string(gm, args, compiler_name):
297359 """
298360 import torch
299361 import torchdynamo
362+ from torch import tensor, device
363+ import torch.fx as fx
300364 from torchdynamo.testing import rand_strided
365+ from math import inf
301366 from torchdynamo.debug_utils import run_fwd_maybe_bwd
302367
303368 """
@@ -313,8 +378,7 @@ def generate_backend_repro_string(gm, args, compiler_name):
313378
314379 setup_module = textwrap .dedent (
315380 f"""
316- import module
317- mod = module.ReproModule().cuda()
381+ mod = Repro().cuda()
318382 opt_mod = torchdynamo.optimize("{ compiler_name } ")(mod)
319383
320384 """
@@ -329,16 +393,33 @@ def generate_backend_repro_string(gm, args, compiler_name):
329393 """
330394 )
331395
332- return imports + prep_inputs + setup_module + run_module
396+ return imports + model_str + setup_module + prep_inputs + run_module
333397
334398
335- def dump_dynamo_gm_state (gm , args , compiler_name ):
399+ def dump_dynamo_gm_to_file (gm , args , compiler_name ):
336400 """
337- Saves the graph module and accompany it with a repro.py script. This is not
338- as clean as the wrap_post_aot_debug, where the repro is limited to one file.
339- This is because post_aot counterpart works on the functionalized graphs with
340- params lifted as graph inputs. Here, we have TorchDynamo produced Fx graphs,
341- which we save using to_folder utility.
401+ Saves the repro to a repro.py file
402+ """
403+ subdir = f"{ minifier_dir ()} /checkpoints"
404+ if not os .path .exists (subdir ):
405+ os .makedirs (subdir , exist_ok = True )
406+ file_name = os .path .join (subdir , f"{ len (gm .graph .nodes )} .py" )
407+ print (f"Writing checkpoint with { len (gm .graph .nodes )} nodes to { file_name } " )
408+
409+ model_str = NNModuleToString .convert (gm )
410+ with open (file_name , "w" ) as fd :
411+ fd .write (generate_dynamo_fx_repro_string (model_str , args , compiler_name ))
412+ repro_path = os .path .join (torchdynamo .config .base_dir , "repro.py" )
413+ shutil .copyfile (file_name , repro_path )
414+
415+
416+ def dump_dynamo_gm_as_tarfile (gm , args , compiler_name ):
417+ """
418+ Saves the repro in repro.tar.gz, as opposed to a file. This is used for
419+ cases, where we can't convert a Fx GraphModule to a string, and therefore
420+ fallback to to_folder for serialization. We accompany this with a repro.py
421+ script that imports the saved module, sets it up and runs the model to repro
422+ the error.
342423 """
343424 import tarfile
344425
@@ -355,7 +436,6 @@ def dump_dynamo_gm_state(gm, args, compiler_name):
355436 gm_dir = os .path .join (tmp_dir , "module" )
356437 if not os .path .exists (gm_dir ):
357438 os .makedirs (gm_dir , exist_ok = True )
358- print (f"Writing checkpoint with { len (gm .graph .nodes )} nodes to { file_name } " )
359439 for node in gm .graph .nodes :
360440 new_kwargs = {}
361441 for k , v in node .kwargs .items ():
@@ -365,19 +445,39 @@ def dump_dynamo_gm_state(gm, args, compiler_name):
365445 node .kwargs = new_kwargs
366446 gm .recompile ()
367447
448+ print (f"Writing checkpoint with { len (gm .graph .nodes )} nodes to { file_name } " )
368449 with open (file_name , "w" ) as fd :
369450 # TODO - Add the readable version of to_folder when available
370- gm .to_folder (gm_dir , "ReproModule" )
371- fd .write (generate_backend_repro_string (gm , args , compiler_name ))
451+ gm .to_folder (gm_dir , "Repro" )
452+ fd .write (
453+ generate_dynamo_fx_repro_string (
454+ "from module import Repro" , args , compiler_name
455+ )
456+ )
457+
372458 local_dir = os .path .join (torchdynamo .config .base_dir , "repro" )
373459 if os .path .exists (local_dir ):
374460 shutil .rmtree (local_dir )
375461 shutil .copytree (tmp_dir , local_dir )
376462 local_tar_file = os .path .join (torchdynamo .config .base_dir , "repro.tar.gz" )
463+ print (f"Writing checkpoint with { len (gm .graph .nodes )} locally to { local_tar_file } " )
377464 with tarfile .open (local_tar_file , "w:gz" ) as tar :
378465 tar .add (local_dir , arcname = os .path .basename (local_dir ))
379466
380467
468+ def dump_dynamo_gm_state (gm , args , compiler_name ):
469+ """
470+ Dumps the dynamo graph to repro the issue.
471+ 1) It tries to convert Fx GraphModule to a string. If we can, it writes to a
472+ repro.py file.
473+ 2) If we can't convert Fx GraphModule to a string, we use to_folder to save
474+ the module and save a tar file.
475+ """
476+ if NNModuleToString .can_convert_to_string (gm ):
477+ return dump_dynamo_gm_to_file (gm , args , compiler_name )
478+ return dump_dynamo_gm_as_tarfile (gm , args , compiler_name )
479+
480+
381481def backend_fails (gm , example_inputs , compiler_fn , orig_failure ):
382482 """
383483 Minifier uses this function to identify if the minified graph module fails
0 commit comments