 from .lists import SizeVariable


+def _run_node(node, args, kwargs, nnmodule):
+    op = node.op
+    if op == "call_function":
+        return node.target(*args, **kwargs)
+    elif op == "call_method":
+        return getattr(args[0], node.target)(*args[1:], **kwargs)
+    elif op == "call_module":
+        assert nnmodule is not None
+        return nnmodule(*args, **kwargs)
+    assert False, op
+
+
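For readers unfamiliar with FX internals: each `torch.fx.Node` carries an opcode
(`call_function`, `call_method`, `call_module`, ...) plus a `target`, so re-executing
a node is just a dispatch on that opcode, which is exactly what `_run_node` above does.
A minimal sketch of where those opcodes come from (the `Toy` module is illustrative,
not part of this patch):

    import torch
    import torch.fx

    class Toy(torch.nn.Module):
        def forward(self, x):
            return torch.relu(x).sum()

    gm = torch.fx.symbolic_trace(Toy())
    for node in gm.graph.nodes:
        # prints e.g.: placeholder x / call_function relu / call_method sum / output output
        print(node.op, node.target)
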
+def _get_real_value(node, output_graph):
+    """
+    Run the actual computation represented by `node` and return the result.
+    This will execute any dependent nodes in the graph as well.
+    """
+    cache = output_graph.real_value_cache
+    if node in cache:
+        return cache[node]
+
+    op = node.op
+    args, kwargs = torch.fx.node.map_arg(
+        (node.args, node.kwargs),
+        lambda n: _get_real_value(n, output_graph),
+    )
+
+    if op == "call_module":
+        nn_module = output_graph.nn_modules[node.target]
+        if not is_lazy_module(nn_module):
+            nn_module = copy.deepcopy(nn_module)
+        else:
+            # In the case of a lazy module, we want to run
+            # the pre-hooks which initialize it
+            nn_module(*args, **kwargs)
+    else:
+        nn_module = None
+
+    try:
+        real_value = _run_node(node, args, kwargs, nn_module)
+        cache[node] = real_value
+    except RuntimeError as e:
+        raise TorchRuntimeError() from e
+    return real_value
+
+
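`_get_real_value` is a memoized depth-first interpreter: `torch.fx.node.map_arg`
replaces every `Node` appearing in `args`/`kwargs` with its recursively computed
value, and `real_value_cache` ensures shared subgraphs execute only once. A
stripped-down sketch of the same pattern over a plain `GraphModule` (standalone;
`eval_node` and `env` are illustrative names, and only `call_function` nodes are
handled):

    import torch
    import torch.fx

    def eval_node(node, env, cache):
        # `env` maps placeholder names to real inputs; `cache` memoizes node results.
        if node in cache:
            return cache[node]
        if node.op == "placeholder":
            result = env[node.target]
        else:
            # Depth-first: materialize every Node referenced in args/kwargs.
            args, kwargs = torch.fx.node.map_arg(
                (node.args, node.kwargs), lambda n: eval_node(n, env, cache)
            )
            assert node.op == "call_function", f"sketch only handles call_function, got {node.op}"
            result = node.target(*args, **kwargs)
        cache[node] = result
        return result

    gm = torch.fx.symbolic_trace(lambda x: torch.relu(x) + torch.relu(x))
    out = next(n for n in gm.graph.nodes if n.op == "output")
    env, cache = {"x": torch.randn(3)}, {}
    result = torch.fx.node.map_arg(out.args[0], lambda n: eval_node(n, env, cache))
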
+def _get_fake_value(node, tx):
+    """
+    Run the computation represented by `node` using fake tensors and return the result.
+    """
+    op = node.op
+    fake_wrapper = functools.partial(wrap_to_fake_tensor, fake_mode=tx.fake_mode)
+    from ..utils import wrap_fake_exception
+
+    def visit(n: torch.fx.Node):
+        return n.meta["example_value"]
+
+    args, kwargs = torch.fx.node.map_arg((node.args, node.kwargs), visit)
+    args = tree_map(fake_wrapper, args)
+    kwargs = tree_map(fake_wrapper, kwargs)
+
+    nnmodule = None
+    if op == "call_module":
+        nnmodule = tx.output.nn_modules[node.target]
+
+        if not is_lazy_module(nnmodule):
+            nnmodule = deepcopy_to_fake_tensor(nnmodule, tx.fake_mode)
+
+    def context():
+        if hasattr(py_dispatch, "enable_torch_dispatch_mode"):
+            return py_dispatch.enable_torch_dispatch_mode(tx.fake_mode)
+        else:
+            return tx.fake_mode
+
+    if op == "call_module" and is_lazy_module(nnmodule):
+        assert nnmodule is not None
+        # In the case of a lazy module, we want to run
+        # the pre-hooks which initialize it
+        nnmodule(*args, **kwargs)
+    try:
+        with context():
+            return wrap_fake_exception(lambda: _run_node(node, args, kwargs, nnmodule))
+    except Unsupported:
+        raise
+    except RuntimeError as e:
+        if isinstance(e, DataDependentOutputException):
+            if config.capture_scalar_outputs and node.target == "item":
+                return torch.zeros(size=(), dtype=args[0].dtype).item()
+            else:
+                unimplemented(f"data dependent operator: {e.func}")
+        elif isinstance(e, DynamicOutputShapeException):
+            unimplemented(f"dynamic shape operator: {e.func}")
+        else:
+            raise TorchRuntimeError() from e
+
+
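The fake-tensor path exists so shape and dtype propagation can happen without
allocating or reading real data. A minimal illustration, assuming a PyTorch build
where `torch._subclasses.FakeTensorMode` is importable (this is the mode that the
`context()` shim above activates, via `enable_torch_dispatch_mode` on older
dispatch APIs):

    import torch
    from torch._subclasses import FakeTensorMode

    with FakeTensorMode():
        x = torch.randn(4, 8)  # fake: carries metadata, no real storage
        y = x @ x.t()          # shape/dtype propagate through the op
    print(type(y).__name__, y.shape)  # FakeTensor torch.Size([4, 4])

Data-dependent ops are the exception: calling `.item()` on a fake tensor raises
`DataDependentOutputException`, which is why the handler above either synthesizes a
scalar (under `config.capture_scalar_outputs`) or bails out with `unimplemented`.
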
+def _clone_input(value):
+    if isinstance(value, torch.Tensor):
+        use_fake_tensors = fake_tensors_available and config.fake_tensor_propagation
+        # tensor subclasses will not be converted to FakeTensors and need to be cloned
+        if not use_fake_tensors or not isinstance(value, FakeTensor):
+            # NB: ensure strides are preserved
+            value = clone_input(value)
+
+    return value
+
+
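On the `NB: ensure strides are preserved` comment inside `clone_input`: a plain
`Tensor.clone()` only keeps the input's strides for dense, non-overlapping layouts
and otherwise falls back to contiguous output, which would change observable
metadata. A hypothetical stride-preserving copy to illustrate the idea (a sketch,
not dynamo's actual `clone_input` helper):

    import torch

    def clone_preserving_strides(t: torch.Tensor) -> torch.Tensor:
        # Allocate a buffer with identical sizes/strides, then copy the data in.
        out = torch.empty_strided(t.size(), t.stride(), dtype=t.dtype, device=t.device)
        out.copy_(t)
        return out

    base = torch.randn(4, 4)
    view = base[::2, :]          # size (2, 4), stride (8, 1): non-dense
    copy = clone_preserving_strides(view)
    assert copy.stride() == view.stride() == (8, 1)
    # By contrast, view.clone() falls back to a contiguous (4, 1) layout here.
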
 class TensorVariable(VariableTracker):
     """A torch.Tensor input or an intermediate value in the FX graph"""

@@ -61,27 +168,18 @@ class TensorVariable(VariableTracker):
         "is_contiguous",
     ]

-    @staticmethod
-    def propagate_args_kwargs(node):
-        def visit(n: torch.fx.Node):
-            return n.meta["example_value"]
-
-        return torch.fx.node.map_arg((node.args, node.kwargs), visit)
+    def get_real_value(self):
+        """
+        Get the actual value represented by this variable if computation is run
+        using the user-provided inputs.

-    @staticmethod
-    def run_proxy(proxy, args, kwargs, nnmodule):
-        op = proxy.node.op
-        if op == "call_function":
-            return proxy.node.target(*args, **kwargs)
-        elif op == "call_method":
-            return getattr(args[0], proxy.node.target)(*args[1:], **kwargs)
-        elif op == "call_module":
-            assert nnmodule is not None
-            return nnmodule(*args, **kwargs)
-        assert False, op
+        NOTE: this runs actual tensor computation and may be
+        slow and memory-intensive.
+        """
+        return _get_real_value(self.proxy.node, self.proxy.tracer)

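A note on the new method: because `_get_real_value` memoizes into
`proxy.tracer.real_value_cache`, repeated calls on the same variable execute the
traced subgraph only once. A hypothetical call site (`var` is not a name from this
patch):

    real = var.get_real_value()        # runs the traced subgraph on real inputs
    real_again = var.get_real_value()  # served from real_value_cache
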
     @classmethod
-    def create(cls, tx, proxy, example_value=None, nnmodule=None, **options):
+    def create(cls, tx, proxy, example_value=None, **options):
         if "guards" in options and options["guards"] is not None:
             tx.output.guards.update(options["guards"])

@@ -92,82 +190,29 @@ def create(cls, tx, proxy, example_value=None, nnmodule=None, **options):
             return cls(proxy, **options)

         use_fake_tensors = fake_tensors_available and config.fake_tensor_propagation
-        if use_fake_tensors:
-            fake_wrapper = functools.partial(
-                wrap_to_fake_tensor, fake_mode=tx.fake_mode
-            )
-            # python errors if the import isnt here
-            from ..utils import wrap_fake_exception
-        else:

-            def wrap_fake_exception(func):
-                return func()
-
-        args = kwargs = None
         initial_example_value = example_value

         with preserve_rng_state():
             if example_value is None:
-                op = proxy.node.op
-                args, kwargs = cls.propagate_args_kwargs(proxy.node)
                 if use_fake_tensors:
-                    args = tree_map(fake_wrapper, args)
-                    kwargs = tree_map(fake_wrapper, kwargs)
-                    if op == "call_module" and not is_lazy_module(nnmodule):
-                        nnmodule = deepcopy_to_fake_tensor(nnmodule, tx.fake_mode)
-
-                    def context():
-                        if hasattr(py_dispatch, "enable_torch_dispatch_mode"):
-                            return py_dispatch.enable_torch_dispatch_mode(tx.fake_mode)
-                        else:
-                            return tx.fake_mode
-
+                    example_value = _get_fake_value(proxy.node, tx)
                 else:
-                    context = contextlib.nullcontext
-                    if op == "call_module" and not is_lazy_module(nnmodule):
-                        nnmodule = copy.deepcopy(nnmodule)
-
-                if op == "call_module" and is_lazy_module(nnmodule):
-                    assert nnmodule is not None
-                    # In the case of a lazy module, we want to run
-                    # the pre-hooks which initialize it
-                    example_value = nnmodule(*args, **kwargs)
-                try:
-                    with context():
-                        example_value = wrap_fake_exception(
-                            lambda: cls.run_proxy(proxy, args, kwargs, nnmodule)
-                        )
-                except Unsupported:
-                    raise
-                except RuntimeError as e:
-                    if use_fake_tensors and isinstance(e, DataDependentOutputException):
-                        if (
-                            config.capture_scalar_outputs
-                            and proxy.node.target == "item"
-                        ):
-                            example_value = torch.zeros(
-                                size=(), dtype=args[0].dtype
-                            ).item()
-                        else:
-                            unimplemented(f"data dependent operator: {e.func}")
-                    elif use_fake_tensors and isinstance(
-                        e, DynamicOutputShapeException
-                    ):
-                        unimplemented(f"dynamic shape operator: {e.func}")
-                    else:
-                        raise TorchRuntimeError() from e
+                    example_value = _get_real_value(proxy.node, tx.output)
+
             else:
+                proxy.tracer.real_value_cache[proxy.node] = _clone_input(example_value)
                 if use_fake_tensors:
+                    fake_wrapper = functools.partial(
+                        wrap_to_fake_tensor, fake_mode=tx.fake_mode
+                    )
                     example_value = fake_wrapper(example_value)

         if isinstance(example_value, torch.Tensor):
             is_parameter = isinstance(example_value, torch.nn.Parameter)
             parameter_value = initial_example_value if is_parameter else None

-            # tensor subclasses will not be converted to FakeTensors and need to be cloned
-            if not use_fake_tensors or not isinstance(example_value, FakeTensor):
-                # NB: ensure strides are preserved
-                example_value = clone_input(example_value)
+            example_value = _clone_input(example_value)
             proxy.node.meta["example_value"] = example_value
             specialized_props = cls.specialize(example_value)
             if use_fake_tensors and isinstance(example_value, FakeTensor):