11import logging
22from typing import Sequence
33import torch
4- from functools import partial
4+ from torch ._dynamo .utils import detect_fake_mode
5+ import unittest
56import torch ._dynamo as td
67from torch ._guards import TracingContext
78
1617 partition ,
1718 get_submod_inputs ,
1819)
19- from torch_tensorrt .dynamo .lowering ._freeze_aot_graph import freeze_autograd_gm
2020from torch_tensorrt .dynamo .utils import parse_dynamo_kwargs
2121from torch_tensorrt .dynamo .conversion import (
2222 convert_module ,
2323 repair_long_or_double_inputs ,
2424)
2525
26- from torch ._functorch .aot_autograd import make_boxed_compiler
27- from .aot_module import aot_module
26+ from torch ._functorch .aot_autograd import aot_export_joint_simple
2827
2928
3029logger = logging .getLogger (__name__ )
@@ -36,8 +35,6 @@ def torch_tensorrt_backend(
3635):
3736 DEFAULT_BACKEND = aot_torch_tensorrt_aten_backend
3837
39- TracingContext .get ().fake_mode .allow_non_fake_inputs = True
40-
4138 return DEFAULT_BACKEND (gm , sample_inputs , ** kwargs )
4239
4340
@@ -47,21 +44,25 @@ def aot_torch_tensorrt_aten_backend(
4744):
4845 settings = parse_dynamo_kwargs (kwargs )
4946
50- custom_backend = partial (
51- _pretraced_backend ,
52- settings = settings ,
53- )
54-
5547 # Perform Pre-AOT Lowering for Module-Level Replacement
5648 gm = pre_aot_substitutions (gm )
5749
58- # Invoke AOTAutograd to translate operators to aten
59- return aot_module (
60- gm ,
61- sample_inputs ,
62- fw_compiler = make_boxed_compiler (custom_backend ),
63- decompositions = get_decompositions (),
64- )
50+ fake_mode = detect_fake_mode (sample_inputs )
51+
52+ # Run backend tracing inside the FakeTensor mode, temporarily allowing non-fake input tensors
53+ with unittest .mock .patch .object (
54+ fake_mode , "allow_non_fake_inputs" , True
55+ ), fake_mode :
56+
57+ # Invoke AOTAutograd to translate operators to aten
58+ graph_module = aot_export_joint_simple (
59+ gm ,
60+ sample_inputs ,
61+ trace_joint = False ,
62+ decompositions = get_decompositions (),
63+ )
64+
65+ return _pretraced_backend (graph_module , sample_inputs , settings )
6566
6667
6768def _pretraced_backend (
@@ -81,16 +82,9 @@ def _pretraced_backend(
8182 try :
8283 logger .debug ("Post-AOT Autograd graph:\n " + str (gm .graph ))
8384
84- frozen_gm , unfrozen_indices = freeze_autograd_gm (gm , sample_inputs )
85- nonfrozen_inputs = [sample_inputs [idx ] for idx in unfrozen_indices ]
86-
87- frozen_gm .graph .eliminate_dead_code ()
88- frozen_gm .graph .lint ()
89- frozen_gm .recompile ()
90-
9185 trt_compiled = _compile_module (
92- frozen_gm ,
93- nonfrozen_inputs ,
86+ gm ,
87+ sample_inputs ,
9488 settings = settings ,
9589 )
9690 return trt_compiled
0 commit comments