Skip to content
This repository was archived by the owner on Aug 1, 2025. It is now read-only.

Commit 06b9e42

Browse files
committed
Write to a file instead of tar.gz
1 parent cbfa91e commit 06b9e42

File tree

1 file changed

+125
-25
lines changed

1 file changed

+125
-25
lines changed

torchdynamo/debug_utils.py

Lines changed: 125 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,75 @@ def minifier_dir():
2121
return f"/tmp/minifier_{getpass.getuser()}"
2222

2323

24+
class NNModuleToString:
    """
    Turns an Fx GraphModule into Python source for a standalone ``Repro``
    module, so that a repro can be written to a single .py file.

    ``can_convert_to_string`` reports whether every direct child of the
    GraphModule is of a type listed in ``safe_reprs``; ``convert`` produces the
    source text.
    """

    # Child module types that ``convert`` is willing to emit via ``repr()``.
    # The generated source does ``from torch.nn import *`` and uses the repr
    # text as the right-hand side of an assignment in ``__init__``.
    safe_reprs = [
        torch.nn.Linear,
        torch.nn.Conv1d,
        torch.nn.Conv2d,
        torch.nn.Conv3d,
        torch.nn.BatchNorm1d,
        torch.nn.BatchNorm2d,
        torch.nn.BatchNorm3d,
        torch.nn.LayerNorm,
        torch.nn.Dropout,
        torch.nn.Softmax,
    ]

    @staticmethod
    def can_convert_to_string(gm):
        """
        Return True if every direct child module of ``gm`` has a type listed in
        ``safe_reprs``; otherwise log a warning naming the offending types and
        return False.
        """
        cant_convert = set()
        for _, module in gm.named_children():
            # Exact type match on purpose: only the listed classes are accepted.
            if type(module) not in NNModuleToString.safe_reprs:
                # ``add`` (not ``update``): a class is a single element, not an
                # iterable of elements, so ``update`` would raise TypeError.
                cant_convert.add(type(module))

        if len(cant_convert) > 0:
            logging.warning(
                f"Was not able to save the following children modules as reprs {cant_convert}"
            )
            return False
        return True

    @staticmethod
    def convert(gm):
        """
        Return source code defining ``class Repro(torch.nn.Module)`` that
        mirrors ``gm``.

        The emitted class contains:
          * one ``self.<name> = <repr>`` line per direct child module,
          * one ``register_buffer`` call per non-None entry of ``gm._buffers``
            and one ``torch.nn.Parameter`` per non-None entry of
            ``gm._parameters``, both filled with ``torch.randn`` of the same
            shape and dtype (the original values are not saved),
          * a class-level attribute for every attribute of ``gm`` whose name
            contains ``_tensor_constant``, written with ``repr()``,
          * ``gm.code`` (the generated ``forward``), indented into the class.

        The returned text references ``torch`` but does not import it; the
        caller is expected to prepend the necessary imports.
        """
        from torch.nn.modules.module import _addindent

        tab = " " * 4

        # Class header plus the start of __init__. The lines below append to
        # __init__'s body, so they are indented by two tabs.
        model_str = textwrap.dedent(
            """
            from torch.nn import *
            class Repro(torch.nn.Module):
                def __init__(self):
                    super().__init__()
            """
        )

        for module_name, module in gm.named_children():
            module_str = f"{module.__repr__()}"
            model_str += f"{tab*2}self.{module_name} = {module_str}\n"

        for buffer_name, buffer in gm._buffers.items():
            if buffer is None:
                continue
            tensor_str = f"torch.randn({list(buffer.shape)}, dtype={buffer.dtype})"
            model_str += f"{tab*2}self.register_buffer('{buffer_name}', {tensor_str})\n"

        for param_name, param in gm._parameters.items():
            if param is None:
                continue
            tensor_str = f"torch.nn.Parameter(torch.randn({list(param.shape)}, dtype={param.dtype}))"
            model_str += f"{tab*2}self.{param_name} = {tensor_str}\n"

        # Tensor constants are emitted as class-level attributes, hence a
        # single tab of indentation (class body, outside __init__).
        attrs = dir(gm)
        for attr in attrs:
            if "_tensor_constant" in attr:
                val = getattr(gm, attr)
                model_str += f"{tab}{attr} = {val!r}\n"

        # _addindent leaves the first line untouched and indents the rest by
        # four spaces, placing the generated forward() inside the class body.
        model_str += f"{_addindent(gm.code, 4)}\n"
        return model_str
91+
92+
2493
@functools.lru_cache(None) # subprocess is expensive
2594
def _cuda_system_info_comment():
2695
if not torch.cuda.is_available():
@@ -48,7 +117,7 @@ def _cuda_system_info_comment():
48117
return model_str
49118

50119

51-
def generate_repro_string(gm, args):
120+
def generate_post_aot_repro_string(gm, args):
52121
model_str = textwrap.dedent(
53122
"""
54123
import torch
@@ -65,14 +134,7 @@ def generate_repro_string(gm, args):
65134
model_str += f"# torch git version: {torch.version.git_version}\n\n\n"
66135
model_str += _cuda_system_info_comment()
67136

68-
model_str += "class Repro(torch.nn.Module):\n"
69-
attrs = dir(gm)
70-
for attr in attrs:
71-
if "_tensor_constant" in attr:
72-
val = getattr(gm, attr)
73-
model_str += f" {attr} = {val!r}\n"
74-
model_str += textwrap.indent(gm.code, " ")
75-
model_str += "\n"
137+
model_str += NNModuleToString.convert(gm)
76138

77139
model_str += f"args = {[(tuple(arg.shape), tuple(arg.stride()), arg.dtype, arg.device.type) for arg in args]!r}\n"
78140
model_str += "args = [rand_strided(shape, stride, dtype, device) for shape, stride, dtype, device in args]\n"
@@ -108,7 +170,7 @@ def dump_post_aot_graph_state(gm, args, compiler_name):
108170

109171

110172
def save_graph_repro(fd, gm, args, compiler_name):
111-
fd.write(generate_repro_string(gm, args))
173+
fd.write(generate_post_aot_repro_string(gm, args))
112174
fd.write(COMPILER_REPRO_OPTIONS[compiler_name][0])
113175
fd.write(
114176
textwrap.dedent(
@@ -128,7 +190,7 @@ def isolate_fails(fx_g, args, compiler_name: str, env=None):
128190
os.makedirs(subdir, exist_ok=True)
129191
file_name = os.path.join(subdir, f"{str(uuid.uuid4())[:5]}.py")
130192
with open(file_name, "w") as fd:
131-
fd.write(generate_repro_string(fx_g, args))
193+
fd.write(generate_post_aot_repro_string(fx_g, args))
132194
fail_fn = COMPILER_REPRO_OPTIONS[compiler_name][2]
133195
fd.write(
134196
textwrap.dedent(
@@ -208,7 +270,7 @@ def dump_to_minify(gm, args, compiler_name: str):
208270
with open(
209271
os.path.join(torchdynamo.config.base_dir, "minifier_launcher.py"), "w"
210272
) as fd:
211-
fd.write(generate_repro_string(gm, args))
273+
fd.write(generate_post_aot_repro_string(gm, args))
212274
fd.write("\n")
213275
fd.write(
214276
textwrap.dedent(
@@ -288,7 +350,7 @@ def run_fwd_maybe_bwd(gm, args):
288350
return out
289351

290352

291-
def generate_backend_repro_string(gm, args, compiler_name):
353+
def generate_dynamo_fx_repro_string(model_str, args, compiler_name):
292354
"""
293355
Generate a repro string for backend-agnostic minified version.
294356
"""
@@ -297,7 +359,10 @@ def generate_backend_repro_string(gm, args, compiler_name):
297359
"""
298360
import torch
299361
import torchdynamo
362+
from torch import tensor, device
363+
import torch.fx as fx
300364
from torchdynamo.testing import rand_strided
365+
from math import inf
301366
from torchdynamo.debug_utils import run_fwd_maybe_bwd
302367
303368
"""
@@ -313,8 +378,7 @@ def generate_backend_repro_string(gm, args, compiler_name):
313378

314379
setup_module = textwrap.dedent(
315380
f"""
316-
import module
317-
mod = module.ReproModule().cuda()
381+
mod = Repro().cuda()
318382
opt_mod = torchdynamo.optimize("{compiler_name}")(mod)
319383
320384
"""
@@ -329,16 +393,33 @@ def generate_backend_repro_string(gm, args, compiler_name):
329393
"""
330394
)
331395

332-
return imports + prep_inputs + setup_module + run_module
396+
return imports + model_str + setup_module + prep_inputs + run_module
333397

334398

335-
def dump_dynamo_gm_state(gm, args, compiler_name):
399+
def dump_dynamo_gm_to_file(gm, args, compiler_name):
    """
    Write a single-file repro for ``gm``.

    The repro is first written to ``<minifier_dir()>/checkpoints/<N>.py``,
    where N is the number of nodes in ``gm.graph``, and then copied to
    ``repro.py`` under ``torchdynamo.config.base_dir``.
    """
    num_nodes = len(gm.graph.nodes)

    checkpoint_dir = f"{minifier_dir()}/checkpoints"
    if not os.path.exists(checkpoint_dir):
        os.makedirs(checkpoint_dir, exist_ok=True)

    checkpoint_path = os.path.join(checkpoint_dir, f"{num_nodes}.py")
    print(f"Writing checkpoint with {num_nodes} nodes to {checkpoint_path}")

    # Serialize the module to source first, then wrap it into a full script.
    module_source = NNModuleToString.convert(gm)
    with open(checkpoint_path, "w") as fd:
        script = generate_dynamo_fx_repro_string(module_source, args, compiler_name)
        fd.write(script)

    # Keep a copy at a fixed, easy-to-find location.
    local_repro = os.path.join(torchdynamo.config.base_dir, "repro.py")
    shutil.copyfile(checkpoint_path, local_repro)
414+
415+
416+
def dump_dynamo_gm_as_tarfile(gm, args, compiler_name):
417+
"""
418+
Saves the repro in repro.tar.gz, as opposed to a file. This is used for
419+
cases, where we can't convert a Fx GraphModule to a string, and therefore
420+
fallback to to_folder for serialization. We accompany this with a repro.py
421+
script that imports the saved module, sets it up and runs the model to repro
422+
the error.
342423
"""
343424
import tarfile
344425

@@ -355,7 +436,6 @@ def dump_dynamo_gm_state(gm, args, compiler_name):
355436
gm_dir = os.path.join(tmp_dir, "module")
356437
if not os.path.exists(gm_dir):
357438
os.makedirs(gm_dir, exist_ok=True)
358-
print(f"Writing checkpoint with {len(gm.graph.nodes)} nodes to {file_name}")
359439
for node in gm.graph.nodes:
360440
new_kwargs = {}
361441
for k, v in node.kwargs.items():
@@ -365,19 +445,39 @@ def dump_dynamo_gm_state(gm, args, compiler_name):
365445
node.kwargs = new_kwargs
366446
gm.recompile()
367447

448+
print(f"Writing checkpoint with {len(gm.graph.nodes)} nodes to {file_name}")
368449
with open(file_name, "w") as fd:
369450
# TODO - Add the readable version of to_folder when available
370-
gm.to_folder(gm_dir, "ReproModule")
371-
fd.write(generate_backend_repro_string(gm, args, compiler_name))
451+
gm.to_folder(gm_dir, "Repro")
452+
fd.write(
453+
generate_dynamo_fx_repro_string(
454+
"from module import Repro", args, compiler_name
455+
)
456+
)
457+
372458
local_dir = os.path.join(torchdynamo.config.base_dir, "repro")
373459
if os.path.exists(local_dir):
374460
shutil.rmtree(local_dir)
375461
shutil.copytree(tmp_dir, local_dir)
376462
local_tar_file = os.path.join(torchdynamo.config.base_dir, "repro.tar.gz")
463+
print(f"Writing checkpoint with {len(gm.graph.nodes)} locally to {local_tar_file}")
377464
with tarfile.open(local_tar_file, "w:gz") as tar:
378465
tar.add(local_dir, arcname=os.path.basename(local_dir))
379466

380467

468+
def dump_dynamo_gm_state(gm, args, compiler_name):
    """
    Dump the dynamo graph so the issue can be reproduced.

    If the Fx GraphModule can be converted to a string (as reported by
    NNModuleToString.can_convert_to_string), the repro is written to a
    repro.py file. Otherwise the module is saved with to_folder and packaged
    as a tar file.
    """
    if not NNModuleToString.can_convert_to_string(gm):
        # Some child module has no usable string form: use the tarfile path.
        return dump_dynamo_gm_as_tarfile(gm, args, compiler_name)
    return dump_dynamo_gm_to_file(gm, args, compiler_name)
479+
480+
381481
def backend_fails(gm, example_inputs, compiler_fn, orig_failure):
382482
"""
383483
Minifier uses this function to identify if the minified graph module fails

0 commit comments

Comments
 (0)