@@ -310,25 +310,23 @@ def wrap_compiler_debug(compiler, compiler_name: str):
310310 @functools .wraps (compiler )
311311 def debug_wrapper (gm , example_inputs , ** kwargs ):
312312 orig_graph = copy .deepcopy (gm .graph )
313- if config .repro_level == 3 :
314- dump_to_minify (
315- fx .GraphModule (gm , orig_graph ), example_inputs , compiler_name
316- )
317-
318- try :
319- compiled_fn = compiler (gm , example_inputs , ** kwargs )
320- if config .repro_level > 0 :
313+ assert config .repro_after in ("dynamo" , "aot" , None )
314+ if config .repro_after == "aot" :
315+ try :
316+ compiled_fn = compiler (gm , example_inputs , ** kwargs )
321317 compiled_fn (* example_inputs )
322- except Exception as e :
323- if config .repro_level == 1 :
324- dump_compiler_graph_state (
325- fx .GraphModule (gm , orig_graph ), example_inputs , compiler_name
326- )
327- elif config .repro_level == 2 :
328- dump_to_minify (
329- fx .GraphModule (gm , orig_graph ), example_inputs , compiler_name
330- )
331- raise e
318+ except Exception as e :
319+ if config .repro_level == 1 :
320+ dump_compiler_graph_state (
321+ fx .GraphModule (gm , orig_graph ), example_inputs , compiler_name
322+ )
323+ elif config .repro_level == 2 :
324+ dump_to_minify (
325+ fx .GraphModule (gm , orig_graph ), example_inputs , compiler_name
326+ )
327+ raise e
328+ else :
329+ compiled_fn = compiler (gm , example_inputs , ** kwargs )
332330
333331 return compiled_fn
334332
@@ -517,11 +515,12 @@ def wrap_backend_debug(compiler_fn, compiler_name: str):
517515
518516 @functools .wraps (compiler_fn )
519517 def debug_wrapper (gm , example_inputs , ** kwargs ):
520- compiled_gm = compiler_fn ( gm , example_inputs , ** kwargs )
521- if config .backend_repro_level > 0 :
518+ assert config . repro_after in ( "dynamo" , "aot" , None )
519+ if config .repro_after == "dynamo" :
522520 # Ensure that we fail when backend fails
523521 config .raise_on_backend_error = True
524522 try :
523+ compiled_gm = compiler_fn (gm , example_inputs , ** kwargs )
525524 run_fwd_maybe_bwd (compiled_gm , clone_inputs (example_inputs ))
526525 except Exception as exc :
527526 orig_failure = str (exc )
@@ -531,7 +530,7 @@ def debug_wrapper(gm, example_inputs, **kwargs):
531530 dump_state_fn = functools .partial (
532531 dump_backend_state , compiler_name = compiler_name
533532 )
534- if config .backend_repro_level == 1 :
533+ if config .repro_level == 1 :
535534 dump_state_fn (
536535 fx .GraphModule (gm , copy .deepcopy (gm .graph )), example_inputs
537536 )
@@ -558,6 +557,8 @@ def debug_wrapper(gm, example_inputs, **kwargs):
558557 dump_state = dump_state_fn ,
559558 )
560559 raise exc
560+ else :
561+ compiled_gm = compiler_fn (gm , example_inputs , ** kwargs )
561562
562563 return compiled_gm
563564
0 commit comments