4646from  .lists  import  SizeVariable 
4747
4848
49+ def  _run_node (node , args , kwargs , nnmodule ):
50+     op  =  node .op 
51+     if  op  ==  "call_function" :
52+         return  node .target (* args , ** kwargs )
53+     elif  op  ==  "call_method" :
54+         return  getattr (args [0 ], node .target )(* args [1 :], ** kwargs )
55+     elif  op  ==  "call_module" :
56+         assert  nnmodule  is  not None 
57+         return  nnmodule (* args , ** kwargs )
58+     assert  False , op 
59+ 
60+ 
61+ def  _get_real_value (node , output_graph ):
62+     """ 
63+     Run the actual computation represented by `node` and return the result. 
64+     This will execute any dependent nodes in the graph as well. 
65+     """ 
66+     cache  =  output_graph .real_value_cache 
67+     if  node  in  cache :
68+         return  cache [node ]
69+ 
70+     op  =  node .op 
71+     args , kwargs  =  torch .fx .node .map_arg (
72+         (node .args , node .kwargs ),
73+         lambda  n : _get_real_value (n , output_graph ),
74+     )
75+ 
76+     if  op  ==  "call_module" :
77+         nn_module  =  output_graph .nn_modules [node .target ]
78+         if  not  is_lazy_module (nn_module ):
79+             nn_module  =  copy .deepcopy (nn_module )
80+         else :
81+             # In the case of a lazy module, we want to run 
82+             # the pre-hooks which initialize it 
83+             nn_module (* args , ** kwargs )
84+     else :
85+         nn_module  =  None 
86+ 
87+     try :
88+         real_value  =  _run_node (node , args , kwargs , nn_module )
89+         cache [node ] =  real_value 
90+     except  RuntimeError  as  e :
91+         raise  TorchRuntimeError () from  e 
92+     return  real_value 
93+ 
94+ 
95+ def  _get_fake_value (node , tx ):
96+     """ 
97+     Run the computation represented by `node` using fake tensors and return the result. 
98+     """ 
99+     op  =  node .op 
100+     fake_wrapper  =  functools .partial (wrap_to_fake_tensor , fake_mode = tx .fake_mode )
101+     from  ..utils  import  wrap_fake_exception 
102+ 
103+     def  visit (n : torch .fx .Node ):
104+         return  n .meta ["example_value" ]
105+ 
106+     args , kwargs  =  torch .fx .node .map_arg ((node .args , node .kwargs ), visit )
107+     args  =  tree_map (fake_wrapper , args )
108+     kwargs  =  tree_map (fake_wrapper , kwargs )
109+ 
110+     nnmodule  =  None 
111+     if  op  ==  "call_module" :
112+         nnmodule  =  tx .output .nn_modules [node .target ]
113+ 
114+         if  not  is_lazy_module (nnmodule ):
115+             nnmodule  =  deepcopy_to_fake_tensor (nnmodule , tx .fake_mode )
116+ 
117+     def  context ():
118+         if  hasattr (py_dispatch , "enable_torch_dispatch_mode" ):
119+             return  py_dispatch .enable_torch_dispatch_mode (tx .fake_mode )
120+         else :
121+             return  tx .fake_mode 
122+ 
123+     if  op  ==  "call_module"  and  is_lazy_module (nnmodule ):
124+         assert  nnmodule  is  not None 
125+         # In the case of a lazy module, we want to run 
126+         # the pre-hooks which initialize it 
127+         nnmodule (* args , ** kwargs )
128+     try :
129+         with  context ():
130+             return  wrap_fake_exception (lambda : _run_node (node , args , kwargs , nnmodule ))
131+     except  Unsupported :
132+         raise 
133+     except  RuntimeError  as  e :
134+         if  isinstance (e , DataDependentOutputException ):
135+             if  config .capture_scalar_outputs  and  node .target  ==  "item" :
136+                 return  torch .zeros (size = (), dtype = args [0 ].dtype ).item ()
137+             else :
138+                 unimplemented (f"data dependent operator: { e .func }  )
139+         elif  isinstance (e , DynamicOutputShapeException ):
140+             unimplemented (f"dynamic shape operator: { e .func }  )
141+         else :
142+             raise  TorchRuntimeError () from  e 
143+ 
144+ 
145+ def  _clone_input (value ):
146+     if  isinstance (value , torch .Tensor ):
147+         use_fake_tensors  =  fake_tensors_available  and  config .fake_tensor_propagation 
148+         # tensor subclasses will not be converted to FakeTensors and need to be cloned 
149+         if  not  use_fake_tensors  or  not  isinstance (value , FakeTensor ):
150+             # NB: ensure strides are preserved 
151+             value  =  clone_input (value )
152+ 
153+     return  value 
154+ 
155+ 
49156class  TensorVariable (VariableTracker ):
50157    """A torch.Tensor input or an intermediate value in the FX graph""" 
51158
@@ -61,27 +168,18 @@ class TensorVariable(VariableTracker):
61168        "is_contiguous" ,
62169    ]
63170
64-     @staticmethod  
65-     def  propagate_args_kwargs (node ):
66-         def  visit (n : torch .fx .Node ):
67-             return  n .meta ["example_value" ]
68- 
69-         return  torch .fx .node .map_arg ((node .args , node .kwargs ), visit )
171+     def  get_real_value (self ):
172+         """ 
173+         Get the actual value represented by this variable if computation is run 
174+         using the user-provided inputs. 
70175
71-     @staticmethod  
72-     def  run_proxy (proxy , args , kwargs , nnmodule ):
73-         op  =  proxy .node .op 
74-         if  op  ==  "call_function" :
75-             return  proxy .node .target (* args , ** kwargs )
76-         elif  op  ==  "call_method" :
77-             return  getattr (args [0 ], proxy .node .target )(* args [1 :], ** kwargs )
78-         elif  op  ==  "call_module" :
79-             assert  nnmodule  is  not None 
80-             return  nnmodule (* args , ** kwargs )
81-         assert  False , op 
176+         NOTE: this runs actual tensor computation and may be 
177+         slow and memory-intensive. 
178+         """ 
179+         return  _get_real_value (self .proxy .node , self .proxy .tracer )
82180
83181    @classmethod  
84-     def  create (cls , tx , proxy , example_value = None , nnmodule = None ,  ** options ):
182+     def  create (cls , tx , proxy , example_value = None , ** options ):
85183        if  "guards"  in  options  and  options ["guards" ] is  not None :
86184            tx .output .guards .update (options ["guards" ])
87185
@@ -92,82 +190,29 @@ def create(cls, tx, proxy, example_value=None, nnmodule=None, **options):
92190            return  cls (proxy , ** options )
93191
94192        use_fake_tensors  =  fake_tensors_available  and  config .fake_tensor_propagation 
95-         if  use_fake_tensors :
96-             fake_wrapper  =  functools .partial (
97-                 wrap_to_fake_tensor , fake_mode = tx .fake_mode 
98-             )
99-             # python errors if the import isnt here 
100-             from  ..utils  import  wrap_fake_exception 
101-         else :
102193
103-             def  wrap_fake_exception (func ):
104-                 return  func ()
105- 
106-         args  =  kwargs  =  None 
107194        initial_example_value  =  example_value 
108195
109196        with  preserve_rng_state ():
110197            if  example_value  is  None :
111-                 op  =  proxy .node .op 
112-                 args , kwargs  =  cls .propagate_args_kwargs (proxy .node )
113198                if  use_fake_tensors :
114-                     args  =  tree_map (fake_wrapper , args )
115-                     kwargs  =  tree_map (fake_wrapper , kwargs )
116-                     if  op  ==  "call_module"  and  not  is_lazy_module (nnmodule ):
117-                         nnmodule  =  deepcopy_to_fake_tensor (nnmodule , tx .fake_mode )
118- 
119-                     def  context ():
120-                         if  hasattr (py_dispatch , "enable_torch_dispatch_mode" ):
121-                             return  py_dispatch .enable_torch_dispatch_mode (tx .fake_mode )
122-                         else :
123-                             return  tx .fake_mode 
124- 
199+                     example_value  =  _get_fake_value (proxy .node , tx )
125200                else :
126-                     context  =  contextlib .nullcontext 
127-                     if  op  ==  "call_module"  and  not  is_lazy_module (nnmodule ):
128-                         nnmodule  =  copy .deepcopy (nnmodule )
129- 
130-                 if  op  ==  "call_module"  and  is_lazy_module (nnmodule ):
131-                     assert  nnmodule  is  not None 
132-                     # In the case of a lazy module, we want to run 
133-                     # the pre-hooks which initialize it 
134-                     example_value  =  nnmodule (* args , ** kwargs )
135-                 try :
136-                     with  context ():
137-                         example_value  =  wrap_fake_exception (
138-                             lambda : cls .run_proxy (proxy , args , kwargs , nnmodule )
139-                         )
140-                 except  Unsupported :
141-                     raise 
142-                 except  RuntimeError  as  e :
143-                     if  use_fake_tensors  and  isinstance (e , DataDependentOutputException ):
144-                         if  (
145-                             config .capture_scalar_outputs 
146-                             and  proxy .node .target  ==  "item" 
147-                         ):
148-                             example_value  =  torch .zeros (
149-                                 size = (), dtype = args [0 ].dtype 
150-                             ).item ()
151-                         else :
152-                             unimplemented (f"data dependent operator: { e .func }  )
153-                     elif  use_fake_tensors  and  isinstance (
154-                         e , DynamicOutputShapeException 
155-                     ):
156-                         unimplemented (f"dynamic shape operator: { e .func }  )
157-                     else :
158-                         raise  TorchRuntimeError () from  e 
201+                     example_value  =  _get_real_value (proxy .node , tx .output )
202+ 
159203            else :
204+                 proxy .tracer .real_value_cache [proxy .node ] =  _clone_input (example_value )
160205                if  use_fake_tensors :
206+                     fake_wrapper  =  functools .partial (
207+                         wrap_to_fake_tensor , fake_mode = tx .fake_mode 
208+                     )
161209                    example_value  =  fake_wrapper (example_value )
162210
163211        if  isinstance (example_value , torch .Tensor ):
164212            is_parameter  =  isinstance (example_value , torch .nn .Parameter )
165213            parameter_value  =  initial_example_value  if  is_parameter  else  None 
166214
167-             # tensor subclasses will not be converted to FakeTensors and need to be cloned 
168-             if  not  use_fake_tensors  or  not  isinstance (example_value , FakeTensor ):
169-                 # NB: ensure strides are preserved 
170-                 example_value  =  clone_input (example_value )
215+             example_value  =  _clone_input (example_value )
171216            proxy .node .meta ["example_value" ] =  example_value 
172217            specialized_props  =  cls .specialize (example_value )
173218            if  use_fake_tensors  and  isinstance (example_value , FakeTensor ):
0 commit comments