
Commit 6bd7193

naming suggestions from Horace

1 parent 06b9e42 commit 6bd7193

5 files changed, +30 -30 lines

torchdynamo/config.py

Lines changed: 1 addition & 2 deletions
@@ -110,8 +110,7 @@
 # 0: Nothing printed out when compilation fails
 # 1: Dump the initial graph to repro.tar.gz
 # 2/3: Minifies and Dumps the minified repro to repro.tar.gz
-dynamo_repro_level = int(os.environ.get("TORCHDYNAMO_REPRO_LEVEL", 0))
-
+backend_repro_level = int(os.environ.get("BACKEND_REPRO_LEVEL", 0))
 
 # Not all backends support scalars. Some calls on torch.Tensor (like .item()) return a scalar type.
 # When this flag is set to False, we introduce a graph break instead of capturing.
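
For reference, a minimal sketch of how the renamed setting could be exercised after this change; the assumption (implied by the line above) is that the environment variable is read when torchdynamo.config is first imported, so it must be set beforehand:

import os

# BACKEND_REPRO_LEVEL has to be exported before torchdynamo is imported,
# because torchdynamo/config.py reads it at import time.
os.environ["BACKEND_REPRO_LEVEL"] = "2"  # 2/3: minify and dump repro.tar.gz on failure

import torchdynamo.config as config

print(config.backend_repro_level)  # -> 2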

torchdynamo/debug_utils.py

Lines changed: 23 additions & 22 deletions
@@ -33,14 +33,16 @@ class NNModuleToString:
         torch.nn.LayerNorm,
         torch.nn.Dropout,
         torch.nn.Softmax,
+        torch.nn.ReLU,
+        torch.nn.MaxPool2d,
     ]
 
     @staticmethod
     def can_convert_to_string(gm):
         cant_convert = set()
         for _, module in gm.named_children():
             if type(module) not in NNModuleToString.safe_reprs:
-                cant_convert.update(type(module))
+                cant_convert.add(module)
 
         if len(cant_convert) > 0:
             logging.warn(
@@ -117,7 +119,7 @@ def _cuda_system_info_comment():
     return model_str
 
 
-def generate_post_aot_repro_string(gm, args):
+def generate_compiler_repro_string(gm, args):
     model_str = textwrap.dedent(
         """
         import torch
@@ -157,7 +159,7 @@ def generate_post_aot_repro_string(gm, args):
 }
 
 
-def dump_post_aot_graph_state(gm, args, compiler_name):
+def dump_compiler_graph_state(gm, args, compiler_name):
     subdir = f"{minifier_dir()}/checkpoints"
     if not os.path.exists(subdir):
         os.makedirs(subdir, exist_ok=True)
@@ -170,7 +172,7 @@ def dump_post_aot_graph_state(gm, args, compiler_name):
 
 
 def save_graph_repro(fd, gm, args, compiler_name):
-    fd.write(generate_post_aot_repro_string(gm, args))
+    fd.write(generate_compiler_repro_string(gm, args))
     fd.write(COMPILER_REPRO_OPTIONS[compiler_name][0])
     fd.write(
         textwrap.dedent(
@@ -190,7 +192,7 @@ def isolate_fails(fx_g, args, compiler_name: str, env=None):
     os.makedirs(subdir, exist_ok=True)
     file_name = os.path.join(subdir, f"{str(uuid.uuid4())[:5]}.py")
     with open(file_name, "w") as fd:
-        fd.write(generate_post_aot_repro_string(fx_g, args))
+        fd.write(generate_compiler_repro_string(fx_g, args))
         fail_fn = COMPILER_REPRO_OPTIONS[compiler_name][2]
         fd.write(
             textwrap.dedent(
@@ -270,15 +272,15 @@ def dump_to_minify(gm, args, compiler_name: str):
     with open(
         os.path.join(torchdynamo.config.base_dir, "minifier_launcher.py"), "w"
     ) as fd:
-        fd.write(generate_post_aot_repro_string(gm, args))
+        fd.write(generate_compiler_repro_string(gm, args))
         fd.write("\n")
         fd.write(
             textwrap.dedent(
                 f"""
                 from functools import partial
                 from torchdynamo.debug_utils import (
                     isolate_fails,
-                    dump_post_aot_graph_state,
+                    dump_compiler_graph_state,
                 )
                 from functorch.compile import minifier
@@ -288,15 +290,15 @@ def dump_to_minify(gm, args, compiler_name: str):
                     mod,
                     args,
                     module_fails=partial(isolate_fails, env=env_variables, compiler_name="{compiler_name}"),
-                    dump_state=partial(dump_post_aot_graph_state, compiler_name="{compiler_name}"),
+                    dump_state=partial(dump_compiler_graph_state, compiler_name="{compiler_name}"),
                 )
                 """
             )
         )
     print("wrote out to minifier_launcher.py")
 
 
-def wrap_post_aot_debug(compiler, compiler_name: str):
+def wrap_compiler_debug(compiler, compiler_name: str):
     """
     Minifier for Fx Graph modules after Aot Autograd has finished. We wrap both
     forward and backward call separately with the backend compiler - like
@@ -319,7 +321,7 @@ def debug_wrapper(gm, example_inputs, **kwargs):
                 compiled_fn(*example_inputs)
         except Exception as e:
             if config.repro_level == 1:
-                dump_post_aot_graph_state(
+                dump_compiler_graph_state(
                     fx.GraphModule(gm, orig_graph), example_inputs, compiler_name
                 )
             elif config.repro_level == 2:
@@ -384,7 +386,6 @@ def generate_dynamo_fx_repro_string(model_str, args, compiler_name):
         """
     )
 
-    # TODO - Figure out the amp state
     run_module = textwrap.dedent(
         f"""
         with torch.cuda.amp.autocast(enabled={torch.is_autocast_enabled()}):
@@ -396,7 +397,7 @@ def generate_dynamo_fx_repro_string(model_str, args, compiler_name):
     return imports + model_str + setup_module + prep_inputs + run_module
 
 
-def dump_dynamo_gm_to_file(gm, args, compiler_name):
+def dump_backend_repro_as_file(gm, args, compiler_name):
     """
     Saves the repro to a repro.py file
     """
@@ -413,7 +414,7 @@ def dump_dynamo_gm_to_file(gm, args, compiler_name):
     shutil.copyfile(file_name, repro_path)
 
 
-def dump_dynamo_gm_as_tarfile(gm, args, compiler_name):
+def dump_backend_repro_as_tarfile(gm, args, compiler_name):
     """
     Saves the repro in repro.tar.gz, as opposed to a file. This is used for
     cases, where we can't convert a Fx GraphModule to a string, and therefore
@@ -465,7 +466,7 @@ def dump_dynamo_gm_as_tarfile(gm, args, compiler_name):
         tar.add(local_dir, arcname=os.path.basename(local_dir))
 
 
-def dump_dynamo_gm_state(gm, args, compiler_name):
+def dump_backend_state(gm, args, compiler_name):
     """
     Dumps the dynamo graph to repro the issue.
     1) It tries to convert Fx GraphModule to a string. If we can, it writes to a
@@ -474,8 +475,8 @@ def dump_dynamo_gm_state(gm, args, compiler_name):
     the module and save a tar file.
     """
     if NNModuleToString.can_convert_to_string(gm):
-        return dump_dynamo_gm_to_file(gm, args, compiler_name)
-    return dump_dynamo_gm_as_tarfile(gm, args, compiler_name)
+        return dump_backend_repro_as_file(gm, args, compiler_name)
+    return dump_backend_repro_as_tarfile(gm, args, compiler_name)
 
 
 def backend_fails(gm, example_inputs, compiler_fn, orig_failure):
@@ -503,10 +504,10 @@ def backend_fails(gm, example_inputs, compiler_fn, orig_failure):
     return False
 
 
-def wrap_dynamo_gm_debug(compiler_fn, compiler_name: str):
+def wrap_backend_debug(compiler_fn, compiler_name: str):
     """
     A minifier decorator that wraps the TorchDynamo produced Fx graph modules.
-    As opposed to wrap_post_aot_debug, this wrapper intercepts at the
+    As opposed to wrap_compiler_debug, this wrapper intercepts at the
     TorchDynamo produced Fx Graph Module. This makes it backend-agnostic to some
     level, e.g., it is useful for minifying issues related to Aot Autograd
     tracing. If an error is found, we minify and save the minified repro in
@@ -517,7 +518,7 @@ def wrap_dynamo_gm_debug(compiler_fn, compiler_name: str):
     @functools.wraps(compiler_fn)
    def debug_wrapper(gm, example_inputs, **kwargs):
        compiled_gm = compiler_fn(gm, example_inputs, **kwargs)
-        if config.dynamo_repro_level > 0:
+        if config.backend_repro_level > 0:
            # Ensure that we fail when backend fails
            config.raise_on_backend_error = True
            try:
@@ -528,15 +529,15 @@ def debug_wrapper(gm, example_inputs, **kwargs):
                    f"Compiled Fx GraphModule failed with {orig_failure}. Starting minifier."
                )
                dump_state_fn = functools.partial(
-                    dump_dynamo_gm_state, compiler_name=compiler_name
+                    dump_backend_state, compiler_name=compiler_name
                )
-                if config.dynamo_repro_level == 1:
+                if config.backend_repro_level == 1:
                    dump_state_fn(
                        fx.GraphModule(gm, copy.deepcopy(gm.graph)), example_inputs
                    )
                else:
                    # As opposed to using dump_to_minify, like we do in
-                    # wrap_post_aot_debug, we directly run minifier here. This
+                    # wrap_compiler_debug, we directly run minifier here. This
                    # is because we can't serialize compiler_fn here.
 
                    # The minified version uses
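
As a rough usage sketch of the renamed decorator (the pass-through backend below is hypothetical and only for illustration): wrap_backend_debug takes a backend callable plus a compiler name, and when config.backend_repro_level > 0 it dumps or minifies the TorchDynamo-produced Fx graph if the compiled module fails.

from torchdynamo.debug_utils import wrap_backend_debug

def my_backend(gm, example_inputs):
    # Hypothetical backend: compile nothing, just run the graph as-is.
    return gm.forward

# The wrapper behaves like my_backend, but with backend_repro_level > 0 it
# forces raise_on_backend_error and hands failing graphs to dump_backend_state,
# which writes repro.py or repro.tar.gz depending on whether the module can be
# converted to a string (see dump_backend_state above).
debug_backend = wrap_backend_debug(my_backend, compiler_name="my_backend")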

torchdynamo/eval_frame.py

Lines changed: 2 additions & 2 deletions
@@ -13,7 +13,7 @@
 import torch.utils._pytree as pytree
 
 import torchdynamo
-from torchdynamo.debug_utils import wrap_dynamo_gm_debug
+from torchdynamo.debug_utils import wrap_backend_debug
 from torchdynamo.utils import checkpoint_params
 from torchdynamo.utils import clone_inputs
 from torchdynamo.utils import compile_times
@@ -290,7 +290,7 @@ def get_compiler_fn(compiler_fn):
 
         compiler_fn = BACKENDS[compiler_fn]
 
-    return wrap_dynamo_gm_debug(compiler_fn, compiler_str)
+    return wrap_backend_debug(compiler_fn, compiler_str)
 
 
 class _NullDecorator(contextlib.nullcontext):
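
In practice this means a backend requested by name is wrapped automatically. A small sketch under the assumption that torchdynamo.optimize routes its backend argument through get_compiler_fn (the decorated function is arbitrary):

import torchdynamo
from torchdynamo import config

config.backend_repro_level = 1  # dump the failing graph as repro.py / repro.tar.gz

# "eager" is resolved via get_compiler_fn, so it is already wrapped by
# wrap_backend_debug; no manual wrapping is needed at the call site.
@torchdynamo.optimize("eager")
def fn(x):
    return x.relu() + 1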

torchdynamo/optimizations/training.py

Lines changed: 2 additions & 2 deletions
@@ -12,7 +12,7 @@
 
 import torchdynamo
 from torchdynamo import config
-from torchdynamo.debug_utils import wrap_post_aot_debug
+from torchdynamo.debug_utils import wrap_compiler_debug
 from torchdynamo.utils import clone_inputs
 from torchdynamo.utils import count_calls
 from torchdynamo.utils import counters
@@ -249,7 +249,7 @@ def candidate(self):
         return BACKENDS["aot_autograd"](
             self.gm,
             self.example_inputs,
-            fw_compiler=wrap_post_aot_debug(self.nvfuser, "nvfuser"),
+            fw_compiler=wrap_compiler_debug(self.nvfuser, "nvfuser"),
             partition_fn=self.min_cut_rematerialization_partition,
             hasher_type="StaticShapeHasher",
             decompositions=self.aten2aten_decompositions,
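
For context, wrap_compiler_debug is applied per compiler rather than per backend: here it wraps the nvfuser forward compiler handed to AOT Autograd. A rough sketch of wrapping one's own forward compiler the same way (noop_fw_compiler is hypothetical):

from torchdynamo.debug_utils import wrap_compiler_debug

def noop_fw_compiler(gm, example_inputs):
    # Hypothetical forward compiler: run the traced forward graph unchanged.
    return gm.forward

# Mirrors the candidate() call above: AOT Autograd invokes this for the forward
# graph, and failures are dumped/minified at the compiler level rather than at
# the TorchDynamo backend level.
debug_fw_compiler = wrap_compiler_debug(noop_fw_compiler, "noop_fw_compiler")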

torchinductor/debug.py

Lines changed: 2 additions & 2 deletions
@@ -21,7 +21,7 @@
 
 import torchinductor
 from torchdynamo.debug_utils import save_graph_repro
-from torchdynamo.debug_utils import wrap_post_aot_debug
+from torchdynamo.debug_utils import wrap_compiler_debug
 from torchdynamo.utils import init_logging
 
 from . import config
@@ -181,7 +181,7 @@ def inner(*args, **kwargs):
             with DebugContext():
                 return fn(*args, **kwargs)
 
-        return wrap_post_aot_debug(inner, compiler_name="inductor")
+        return wrap_compiler_debug(inner, compiler_name="inductor")
 
     @staticmethod
     def create_debug_dir():
