This repository was archived by the owner on Aug 1, 2025. It is now read-only.

Commit e35c50b

Write to a file instead of tar.gz
1 parent cbfa91e commit e35c50b

File tree: 1 file changed (+123, -25 lines)

torchdynamo/debug_utils.py: 123 additions & 25 deletions
@@ -21,6 +21,73 @@ def minifier_dir():
     return f"/tmp/minifier_{getpass.getuser()}"


+class NNModuleToString:
+    safe_reprs = [
+        torch.nn.Linear,
+        torch.nn.Conv1d,
+        torch.nn.Conv2d,
+        torch.nn.Conv3d,
+        torch.nn.BatchNorm1d,
+        torch.nn.BatchNorm2d,
+        torch.nn.BatchNorm3d,
+        torch.nn.LayerNorm,
+        torch.nn.Dropout,
+        torch.nn.Softmax,
+    ]
+
+    @staticmethod
+    def can_convert_to_string(gm):
+        cant_convert = set()
+        for _, module in gm.named_children():
+            if type(module) not in NNModuleToString.safe_reprs:
+                cant_convert.update(type(module))
+
+        if len(cant_convert) > 0:
+            logging.warn(
+                f"Was not able to save the following children modules as reprs {cant_convert}"
+            )
+            return False
+        return True
+
+    @staticmethod
+    def convert(gm):
+        from torch.nn.modules.module import _addindent
+
+        tab = " " * 4
+
+        model_str = textwrap.dedent(
+            """
+            from torch.nn import *
+            class Repro(torch.nn.Module):
+                def __init__(self):
+                    super().__init__()
+            """
+        )
+
+        for module_name, module in gm.named_children():
+            module_str = f"{module.__repr__()}"
+            model_str += f"{tab*2}self.{module_name} = {module_str}\n"
+
+        for buffer_name, buffer in gm._buffers.items():
+            if buffer is None:
+                continue
+            model_str += f"{tab*2}self.register_buffer('{buffer_name}', torch.randn({list(buffer.shape)}, dtype={buffer.dtype}))\n"
+
+        for param_name, param in gm._parameters.items():
+            if param is None:
+                continue
+            model_str += f"{tab*2}self.{param_name} = torch.nn.Parameter(torch.randn({list(param.shape)}, dtype={param.dtype}))\n"
+
+        attrs = dir(gm)
+        for attr in attrs:
+            if "_tensor_constant" in attr:
+                val = getattr(gm, attr)
+                model_str += f"    {attr} = {val!r}\n"
+
+        model_str += f"{_addindent(gm.code, 4)}\n"
+        return model_str
+
+
 @functools.lru_cache(None)  # subprocess is expensive
 def _cuda_system_info_comment():
     if not torch.cuda.is_available():
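For context, a minimal usage sketch of the new NNModuleToString helper (not part of this commit; the toy Tiny module and the torchdynamo.debug_utils import path are illustrative assumptions):

# Hypothetical usage sketch, assuming torchdynamo.debug_utils exposes the
# helpers added in this commit. Not part of the diff above.
import torch
import torch.fx
from torchdynamo.debug_utils import NNModuleToString

class Tiny(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = torch.nn.Linear(4, 2)  # Linear is listed in safe_reprs

    def forward(self, x):
        return self.fc(x)

gm = torch.fx.symbolic_trace(Tiny())
if NNModuleToString.can_convert_to_string(gm):
    # Emits a self-contained "class Repro(torch.nn.Module)" source string,
    # with child modules reconstructed from their reprs.
    print(NNModuleToString.convert(gm))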
@@ -48,7 +115,7 @@ def _cuda_system_info_comment():
     return model_str


-def generate_repro_string(gm, args):
+def generate_post_aot_repro_string(gm, args):
     model_str = textwrap.dedent(
         """
         import torch
@@ -65,14 +132,7 @@ def generate_repro_string(gm, args):
     model_str += f"# torch git version: {torch.version.git_version}\n\n\n"
     model_str += _cuda_system_info_comment()

-    model_str += "class Repro(torch.nn.Module):\n"
-    attrs = dir(gm)
-    for attr in attrs:
-        if "_tensor_constant" in attr:
-            val = getattr(gm, attr)
-            model_str += f"    {attr} = {val!r}\n"
-    model_str += textwrap.indent(gm.code, "    ")
-    model_str += "\n"
+    model_str += NNModuleToString.convert(gm)

     model_str += f"args = {[(tuple(arg.shape), tuple(arg.stride()), arg.dtype, arg.device.type) for arg in args]!r}\n"
     model_str += "args = [rand_strided(shape, stride, dtype, device) for shape, stride, dtype, device in args]\n"
@@ -108,7 +168,7 @@ def dump_post_aot_graph_state(gm, args, compiler_name):


 def save_graph_repro(fd, gm, args, compiler_name):
-    fd.write(generate_repro_string(gm, args))
+    fd.write(generate_post_aot_repro_string(gm, args))
     fd.write(COMPILER_REPRO_OPTIONS[compiler_name][0])
     fd.write(
         textwrap.dedent(
@@ -128,7 +188,7 @@ def isolate_fails(fx_g, args, compiler_name: str, env=None):
     os.makedirs(subdir, exist_ok=True)
     file_name = os.path.join(subdir, f"{str(uuid.uuid4())[:5]}.py")
     with open(file_name, "w") as fd:
-        fd.write(generate_repro_string(fx_g, args))
+        fd.write(generate_post_aot_repro_string(fx_g, args))
         fail_fn = COMPILER_REPRO_OPTIONS[compiler_name][2]
         fd.write(
             textwrap.dedent(
@@ -208,7 +268,7 @@ def dump_to_minify(gm, args, compiler_name: str):
     with open(
         os.path.join(torchdynamo.config.base_dir, "minifier_launcher.py"), "w"
     ) as fd:
-        fd.write(generate_repro_string(gm, args))
+        fd.write(generate_post_aot_repro_string(gm, args))
         fd.write("\n")
         fd.write(
             textwrap.dedent(
@@ -288,7 +348,7 @@ def run_fwd_maybe_bwd(gm, args):
     return out


-def generate_backend_repro_string(gm, args, compiler_name):
+def generate_dynamo_fx_repro_string(model_str, args, compiler_name):
     """
     Generate a repro string for backend-agnostic minified version.
     """
@@ -297,7 +357,10 @@ def generate_backend_repro_string(gm, args, compiler_name):
         """
         import torch
         import torchdynamo
+        from torch import tensor, device
+        import torch.fx as fx
         from torchdynamo.testing import rand_strided
+        from math import inf
         from torchdynamo.debug_utils import run_fwd_maybe_bwd

         """
@@ -313,8 +376,7 @@ def generate_backend_repro_string(gm, args, compiler_name):

     setup_module = textwrap.dedent(
         f"""
-        import module
-        mod = module.ReproModule().cuda()
+        mod = Repro().cuda()
         opt_mod = torchdynamo.optimize("{compiler_name}")(mod)

         """
@@ -329,16 +391,33 @@ def generate_backend_repro_string(gm, args, compiler_name):
         """
     )

-    return imports + prep_inputs + setup_module + run_module
+    return imports + model_str + setup_module + prep_inputs + run_module


-def dump_dynamo_gm_state(gm, args, compiler_name):
+def dump_dynamo_gm_to_file(gm, args, compiler_name):
     """
-    Saves the graph module and accompany it with a repro.py script. This is not
-    as clean as the wrap_post_aot_debug, where the repro is limited to one file.
-    This is because post_aot counterpart works on the functionalized graphs with
-    params lifted as graph inputs. Here, we have TorchDynamo produced Fx graphs,
-    which we save using to_folder utility.
+    Saves the repro to a repro.py file
+    """
+    subdir = f"{minifier_dir()}/checkpoints"
+    if not os.path.exists(subdir):
+        os.makedirs(subdir, exist_ok=True)
+    file_name = os.path.join(subdir, f"{len(gm.graph.nodes)}.py")
+    print(f"Writing checkpoint with {len(gm.graph.nodes)} nodes to {file_name}")
+
+    model_str = NNModuleToString.convert(gm)
+    with open(file_name, "w") as fd:
+        fd.write(generate_dynamo_fx_repro_string(model_str, args, compiler_name))
+    repro_path = os.path.join(torchdynamo.config.base_dir, "repro.py")
+    shutil.copyfile(file_name, repro_path)
+
+
+def dump_dynamo_gm_as_tarfile(gm, args, compiler_name):
+    """
+    Saves the repro in repro.tar.gz, as opposed to a file. This is used for
+    cases, where we can't convert a Fx GraphModule to a string, and therefore
+    fallback to to_folder for serialization. We accompany this with a repro.py
+    script that imports the saved module, sets it up and runs the model to repro
+    the error.
     """
     import tarfile

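To make the new file-based flow concrete, a repro.py written by dump_dynamo_gm_to_file would look roughly like the sketch below. This is an illustrative reconstruction from generate_dynamo_fx_repro_string and NNModuleToString.convert, not verbatim output: the Linear layer, input shapes, and the "eager" backend name are placeholders, and the input-prep and run lines are assumed from the post-AOT repro path rather than shown in this hunk.

# Illustrative reconstruction of a generated repro.py; not verbatim output.
import torch
import torchdynamo
from torch import tensor, device
import torch.fx as fx
from torchdynamo.testing import rand_strided
from math import inf
from torchdynamo.debug_utils import run_fwd_maybe_bwd

from torch.nn import *
class Repro(torch.nn.Module):  # emitted by NNModuleToString.convert(gm)
    def __init__(self):
        super().__init__()
        self.fc = Linear(in_features=4, out_features=2, bias=True)

    def forward(self, x):
        fc = self.fc(x)
        return fc

mod = Repro().cuda()
opt_mod = torchdynamo.optimize("eager")(mod)

# Inputs are rebuilt from the recorded (shape, stride, dtype, device) tuples.
args = [((3, 4), (4, 1), torch.float32, "cuda")]
args = [rand_strided(shape, stride, dtype, device) for shape, stride, dtype, device in args]
run_fwd_maybe_bwd(opt_mod, args)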

@@ -355,7 +434,6 @@ def dump_dynamo_gm_state(gm, args, compiler_name):
     gm_dir = os.path.join(tmp_dir, "module")
     if not os.path.exists(gm_dir):
         os.makedirs(gm_dir, exist_ok=True)
-    print(f"Writing checkpoint with {len(gm.graph.nodes)} nodes to {file_name}")
     for node in gm.graph.nodes:
         new_kwargs = {}
         for k, v in node.kwargs.items():
@@ -365,19 +443,39 @@ def dump_dynamo_gm_state(gm, args, compiler_name):
         node.kwargs = new_kwargs
     gm.recompile()

+    print(f"Writing checkpoint with {len(gm.graph.nodes)} nodes to {file_name}")
     with open(file_name, "w") as fd:
         # TODO - Add the readable version of to_folder when available
-        gm.to_folder(gm_dir, "ReproModule")
-        fd.write(generate_backend_repro_string(gm, args, compiler_name))
+        gm.to_folder(gm_dir, "Repro")
+        fd.write(
+            generate_dynamo_fx_repro_string(
+                "from module import Repro", args, compiler_name
+            )
+        )
+
     local_dir = os.path.join(torchdynamo.config.base_dir, "repro")
     if os.path.exists(local_dir):
         shutil.rmtree(local_dir)
     shutil.copytree(tmp_dir, local_dir)
     local_tar_file = os.path.join(torchdynamo.config.base_dir, "repro.tar.gz")
+    print(f"Writing checkpoint with {len(gm.graph.nodes)} locally to {local_tar_file}")
     with tarfile.open(local_tar_file, "w:gz") as tar:
         tar.add(local_dir, arcname=os.path.basename(local_dir))


+def dump_dynamo_gm_state(gm, args, compiler_name):
+    """
+    Dumps the dynamo graph to repro the issue.
+    1) It tries to convert Fx GraphModule to a string. If we can, it writes to a
+    repro.py file.
+    2) If we can't convert Fx GraphModule to a string, we use to_folder to save
+    the module and save a tar file.
+    """
+    if NNModuleToString.can_convert_to_string(gm):
+        return dump_dynamo_gm_to_file(gm, args, compiler_name)
+    return dump_dynamo_gm_as_tarfile(gm, args, compiler_name)
+
+
 def backend_fails(gm, example_inputs, compiler_fn, orig_failure):
     """
     Minifier uses this function to identify if the minified graph module fails
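Finally, a minimal sketch of the new dump_dynamo_gm_state dispatcher in action (hypothetical: the toy module mirrors the earlier sketch, and "eager" stands in for whichever backend was being minified):

# Hypothetical call into the dispatcher added by this commit.
import torch
import torch.fx
from torchdynamo.debug_utils import dump_dynamo_gm_state

class Tiny(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = torch.nn.Linear(4, 2)

    def forward(self, x):
        return self.fc(x)

gm = torch.fx.symbolic_trace(Tiny())
args = [torch.randn(3, 4)]
# Only "safe" child modules -> writes <base_dir>/repro.py; otherwise the
# to_folder fallback produces <base_dir>/repro.tar.gz instead.
dump_dynamo_gm_state(gm, args, "eager")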
