
Commit 0264bca
[dynamo] Introduce get_real_value API to TensorVariable
Right now, `example_value` is doing two jobs:

- We use it to propagate metadata (e.g. return type, shapes, etc.) throughout the graph.
- We use it to satisfy queries for the actual value (e.g. torch.cond, `assume_constant_result`).

This is further complicated by the fact that we have two modes: one where `example_value` is a fake tensor, and one where it is a real tensor (controlled by the `fake_tensor_propagation` config flag). This leads to scenarios where we don't support every combination of job + mode, e.g. if `fake_tensor_propagation=False`, `assume_constant_result` is broken. This is made worse by the fact that fake tensor mode is the default and is required for dynamic shapes to work.

So, this PR introduces a `get_real_value` API that just runs the graph up to `node` in order to get a concrete value. This API is orthogonal to `example_value`, so it doesn't care about `fake_tensor_propagation`.

- When `fake_tensor_propagation=True`: `example_value` is a fake tensor, and you must use the `get_real_value` API to get a concrete value. This will be the only configuration in the future.
- When `fake_tensor_propagation=False`: `example_value` and `get_real_value` produce the same value. This is redundant, but we will be removing this config soon.

To support this, I introduce a cache for computed real values, to memoize the work involved if we're asking for real values a lot. I attached this state to `OutputGraph` because it seems to be what has historically managed `example_value` lifetimes, though I'm not certain that's the right home.
1 parent cc0882d · commit 0264bca

File tree

5 files changed: +132 −82 lines
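To make the idea concrete, here is a minimal, self-contained sketch of the memoized graph re-execution that `get_real_value` performs: walk an FX node's dependencies recursively, run each one for real, and memoize per node. The helper name `run_node_for_real` and the `inputs` dict are illustrative, not part of this commit.

import torch
import torch.fx as fx

def run_node_for_real(node, inputs, cache):
    # Memoize per node, mirroring OutputGraph.real_value_cache.
    if node in cache:
        return cache[node]
    # Recursively materialize every dependency first.
    args = fx.node.map_arg(node.args, lambda n: run_node_for_real(n, inputs, cache))
    kwargs = fx.node.map_arg(node.kwargs, lambda n: run_node_for_real(n, inputs, cache))
    if node.op == "placeholder":
        value = inputs[node.target]
    elif node.op == "call_function":
        value = node.target(*args, **kwargs)
    elif node.op == "call_method":
        value = getattr(args[0], node.target)(*args[1:], **kwargs)
    else:
        raise NotImplementedError(node.op)
    cache[node] = value
    return value

# Usage: concrete value of the final compute node of a traced graph.
def f(x):
    return torch.relu(x) + 1

gm = fx.symbolic_trace(f)
out_node = next(n for n in gm.graph.nodes if n.op == "output")
real = run_node_for_real(out_node.args[0], {"x": torch.randn(3)}, {})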

torchdynamo/output_graph.py

Lines changed: 6 additions & 0 deletions
@@ -99,6 +99,8 @@ def __init__(
         self.side_effects = SideEffects()
         self.code_options = dict(code_options)
         self.output_instructions = []
+        # Node => computed real value (see TensorVariable.get_real_value)
+        self.real_value_cache = {}

         # Not checkpointed
         self.compiler_fn = compiler_fn
@@ -146,6 +148,7 @@ def restore_graphstate(self, state):
             if "example_value" in node.meta:
                 del node.meta["example_value"]
             self.graph.erase_node(node)
+            self.real_value_cache.pop(node, None)

     def count_calls(self):
         return count_calls(self.graph)
@@ -387,6 +390,7 @@ def compile_and_call_fx_graph(self, tx, rv, root):
         for node in self.graph.nodes:
             if "example_value" in node.meta:
                 del node.meta["example_value"]
+        self.real_value_cache.clear()

         gm = fx.GraphModule(root, self.graph)
         gm.recompile()
@@ -459,6 +463,7 @@ def remove_unused_graphargs(self):
                 if "example_value" in node.meta:
                     del node.meta["example_value"]
                 self.graph.erase_node(node)
+                self.real_value_cache.pop(node, None)

         self.graphargs = [arg for arg in self.graphargs if arg.uses > 0]

@@ -493,6 +498,7 @@ def cleanup(self):
         for node in self.graph.nodes:
             if "example_value" in node.meta:
                 del node.meta["example_value"]
+        self.real_value_cache.clear()

     def create_proxy(
         self,
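A small standalone sketch of the cache-lifetime rule these hunks enforce (variable names hypothetical): the cache is keyed by `fx.Node`, so erasing a node must evict its entry, and whole-graph events (compile, checkpoint restore, cleanup) drop everything.

import torch
import torch.fx as fx

graph = fx.Graph()
x = graph.placeholder("x")
y = graph.call_function(torch.relu, (x,))
real_value_cache = {y: torch.zeros(2)}  # pretend y's real value was computed

# Mirrors restore_graphstate/remove_unused_graphargs: an erased node must
# leave the cache too, or the dict would pin the dead Node and its tensor.
graph.erase_node(y)
real_value_cache.pop(y, None)

# Mirrors compile_and_call_fx_graph/cleanup: once the graph is compiled, the
# cached tensors are no longer needed, so the whole cache is dropped.
real_value_cache.clear()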

torchdynamo/variables/functions.py

Lines changed: 1 addition & 1 deletion
@@ -269,7 +269,7 @@ def call_function(
 def invoke_and_store_as_constant(tx, fn, name, options, args, kwargs):
     def convert(x):
         if isinstance(x, variables.TensorVariable):
-            return x.proxy.node.meta["example_value"]
+            return x.get_real_value()
         return x.as_python_constant()

     args = [convert(x) for x in args]
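For context, `invoke_and_store_as_constant` backs `assume_constant_result`: the decorated function is executed once at trace time and its result is burned into the graph as a constant. A hedged sketch of user code that hits this path; the decorator's exact import location at this commit is an assumption (in later PyTorch it is `torch._dynamo.assume_constant_result`).

import torch
import torchdynamo  # assumed import path at this commit

@torchdynamo.assume_constant_result  # assumed spelling; see lead-in
def get_scale():
    return torch.tensor(2.0)

def fn(x):
    # get_scale() is run for real at trace time via get_real_value() and
    # stored as a constant, even when example_value is a fake tensor.
    return x * get_scale()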

torchdynamo/variables/nn_module.py

Lines changed: 0 additions & 1 deletion
@@ -204,7 +204,6 @@ def record_nn_module_stack():
                     *proxy_args_kwargs(args, kwargs),
                     current_tx=tx,
                 ),
-                nnmodule=mod,
                 **options,
             )
         else:

torchdynamo/variables/tensor.py

Lines changed: 124 additions & 79 deletions
@@ -46,6 +46,113 @@
 from .lists import SizeVariable


+def _run_node(node, args, kwargs, nnmodule):
+    op = node.op
+    if op == "call_function":
+        return node.target(*args, **kwargs)
+    elif op == "call_method":
+        return getattr(args[0], node.target)(*args[1:], **kwargs)
+    elif op == "call_module":
+        assert nnmodule is not None
+        return nnmodule(*args, **kwargs)
+    assert False, op
+
+
+def _get_real_value(node, output_graph):
+    """
+    Run the actual computation represented by `node` and return the result.
+    This will execute any dependent nodes in the graph as well.
+    """
+    cache = output_graph.real_value_cache
+    if node in cache:
+        return cache[node]
+
+    op = node.op
+    args, kwargs = torch.fx.node.map_arg(
+        (node.args, node.kwargs),
+        lambda n: _get_real_value(n, output_graph),
+    )
+
+    if op == "call_module":
+        nn_module = output_graph.nn_modules[node.target]
+        if not is_lazy_module(nn_module):
+            nn_module = copy.deepcopy(nn_module)
+        else:
+            # In the case of a lazy module, we want to run
+            # the pre-hooks which initialize it
+            nn_module(*args, **kwargs)
+    else:
+        nn_module = None
+
+    try:
+        real_value = _run_node(node, args, kwargs, nn_module)
+        cache[node] = real_value
+    except RuntimeError as e:
+        raise TorchRuntimeError() from e
+    return real_value
+
+
+def _get_fake_value(node, tx):
+    """
+    Run the computation represented by `node` using fake tensors and return the result.
+    """
+    op = node.op
+    fake_wrapper = functools.partial(wrap_to_fake_tensor, fake_mode=tx.fake_mode)
+    from ..utils import wrap_fake_exception
+
+    def visit(n: torch.fx.Node):
+        return n.meta["example_value"]
+
+    args, kwargs = torch.fx.node.map_arg((node.args, node.kwargs), visit)
+    args = tree_map(fake_wrapper, args)
+    kwargs = tree_map(fake_wrapper, kwargs)
+
+    nnmodule = None
+    if op == "call_module":
+        nnmodule = tx.output.nn_modules[node.target]
+
+        if not is_lazy_module(nnmodule):
+            nnmodule = deepcopy_to_fake_tensor(nnmodule, tx.fake_mode)
+
+    def context():
+        if hasattr(py_dispatch, "enable_torch_dispatch_mode"):
+            return py_dispatch.enable_torch_dispatch_mode(tx.fake_mode)
+        else:
+            return tx.fake_mode
+
+    if op == "call_module" and is_lazy_module(nnmodule):
+        assert nnmodule is not None
+        # In the case of a lazy module, we want to run
+        # the pre-hooks which initialize it
+        nnmodule(*args, **kwargs)
+    try:
+        with context():
+            return wrap_fake_exception(lambda: _run_node(node, args, kwargs, nnmodule))
+    except Unsupported:
+        raise
+    except RuntimeError as e:
+        if isinstance(e, DataDependentOutputException):
+            if config.capture_scalar_outputs and node.target == "item":
+                return torch.zeros(size=(), dtype=args[0].dtype).item()
+            else:
+                unimplemented(f"data dependent operator: {e.func}")
+        elif isinstance(e, DynamicOutputShapeException):
+            unimplemented(f"dynamic shape operator: {e.func}")
+        else:
+            raise TorchRuntimeError() from e
+
+
+def _clone_input(value):
+    if isinstance(value, torch.Tensor):
+        use_fake_tensors = fake_tensors_available and config.fake_tensor_propagation
+        # tensor subclasses will not be converted to FakeTensors and need to be cloned
+        if not use_fake_tensors or not isinstance(value, FakeTensor):
+            # NB: ensure strides are preserved
+            value = clone_input(value)
+
+    return value
+
+
 class TensorVariable(VariableTracker):
     """A torch.Tensor input or an intermediate value in the FX graph"""

@@ -61,27 +168,18 @@ class TensorVariable(VariableTracker):
         "is_contiguous",
     ]

-    @staticmethod
-    def propagate_args_kwargs(node):
-        def visit(n: torch.fx.Node):
-            return n.meta["example_value"]
-
-        return torch.fx.node.map_arg((node.args, node.kwargs), visit)
+    def get_real_value(self):
+        """
+        Get the actual value represented by this variable if computation is run
+        using the user-provided inputs.

-    @staticmethod
-    def run_proxy(proxy, args, kwargs, nnmodule):
-        op = proxy.node.op
-        if op == "call_function":
-            return proxy.node.target(*args, **kwargs)
-        elif op == "call_method":
-            return getattr(args[0], proxy.node.target)(*args[1:], **kwargs)
-        elif op == "call_module":
-            assert nnmodule is not None
-            return nnmodule(*args, **kwargs)
-        assert False, op
+        NOTE: this runs actual tensor computation and may be
+        slow and memory-intensive.
+        """
+        return _get_real_value(self.proxy.node, self.proxy.tracer)

     @classmethod
-    def create(cls, tx, proxy, example_value=None, nnmodule=None, **options):
+    def create(cls, tx, proxy, example_value=None, **options):
         if "guards" in options and options["guards"] is not None:
             tx.output.guards.update(options["guards"])

@@ -92,82 +190,29 @@ def create(cls, tx, proxy, example_value=None, nnmodule=None, **options):
             return cls(proxy, **options)

         use_fake_tensors = fake_tensors_available and config.fake_tensor_propagation
-        if use_fake_tensors:
-            fake_wrapper = functools.partial(
-                wrap_to_fake_tensor, fake_mode=tx.fake_mode
-            )
-            # python errors if the import isnt here
-            from ..utils import wrap_fake_exception
-        else:

-            def wrap_fake_exception(func):
-                return func()
-
-        args = kwargs = None
         initial_example_value = example_value

         with preserve_rng_state():
             if example_value is None:
-                op = proxy.node.op
-                args, kwargs = cls.propagate_args_kwargs(proxy.node)
                 if use_fake_tensors:
-                    args = tree_map(fake_wrapper, args)
-                    kwargs = tree_map(fake_wrapper, kwargs)
-                    if op == "call_module" and not is_lazy_module(nnmodule):
-                        nnmodule = deepcopy_to_fake_tensor(nnmodule, tx.fake_mode)
-
-                    def context():
-                        if hasattr(py_dispatch, "enable_torch_dispatch_mode"):
-                            return py_dispatch.enable_torch_dispatch_mode(tx.fake_mode)
-                        else:
-                            return tx.fake_mode
-
+                    example_value = _get_fake_value(proxy.node, tx)
                 else:
-                    context = contextlib.nullcontext
-                    if op == "call_module" and not is_lazy_module(nnmodule):
-                        nnmodule = copy.deepcopy(nnmodule)
-
-                if op == "call_module" and is_lazy_module(nnmodule):
-                    assert nnmodule is not None
-                    # In the case of a lazy module, we want to run
-                    # the pre-hooks which initialize it
-                    example_value = nnmodule(*args, **kwargs)
-                try:
-                    with context():
-                        example_value = wrap_fake_exception(
-                            lambda: cls.run_proxy(proxy, args, kwargs, nnmodule)
-                        )
-                except Unsupported:
-                    raise
-                except RuntimeError as e:
-                    if use_fake_tensors and isinstance(e, DataDependentOutputException):
-                        if (
-                            config.capture_scalar_outputs
-                            and proxy.node.target == "item"
-                        ):
-                            example_value = torch.zeros(
-                                size=(), dtype=args[0].dtype
-                            ).item()
-                        else:
-                            unimplemented(f"data dependent operator: {e.func}")
-                    elif use_fake_tensors and isinstance(
-                        e, DynamicOutputShapeException
-                    ):
-                        unimplemented(f"dynamic shape operator: {e.func}")
-                    else:
-                        raise TorchRuntimeError() from e
+                    example_value = _get_real_value(proxy.node, tx.output)

             else:
+                proxy.tracer.real_value_cache[proxy.node] = _clone_input(example_value)
                 if use_fake_tensors:
+                    fake_wrapper = functools.partial(
+                        wrap_to_fake_tensor, fake_mode=tx.fake_mode
+                    )
                     example_value = fake_wrapper(example_value)

         if isinstance(example_value, torch.Tensor):
             is_parameter = isinstance(example_value, torch.nn.Parameter)
             parameter_value = initial_example_value if is_parameter else None

-            # tensor subclasses will not be converted to FakeTensors and need to be cloned
-            if not use_fake_tensors or not isinstance(example_value, FakeTensor):
-                # NB: ensure strides are preserved
-                example_value = clone_input(example_value)
+            example_value = _clone_input(example_value)
             proxy.node.meta["example_value"] = example_value
             specialized_props = cls.specialize(example_value)
             if use_fake_tensors and isinstance(example_value, FakeTensor):
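To illustrate what `_get_fake_value` buys: under fake tensor propagation, per-node metadata (shape, dtype, device) is computed without running real kernels. A minimal sketch using PyTorch's `FakeTensorMode` directly; this is the modern `torch._subclasses` spelling, whereas at the time of this commit the mode was entered via `py_dispatch.enable_torch_dispatch_mode`, as the diff shows.

import torch
from torch._subclasses import FakeTensorMode

fake_mode = FakeTensorMode()
with fake_mode:
    x = torch.empty(4, 8)  # a FakeTensor: metadata only, no storage
    y = torch.relu(x)      # dispatches to meta kernels; no real compute

assert y.shape == (4, 8) and y.dtype == torch.float32
# y.item() would raise a data-dependent-output error: that is the
# DataDependentOutputException branch _get_fake_value handles for `item`
# under config.capture_scalar_outputs.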

torchdynamo/variables/torch.py

Lines changed: 1 addition & 1 deletion
@@ -551,7 +551,7 @@ def call_function(

         def unwrap_real(arg):
             if isinstance(arg, TensorVariable):
-                return arg.as_proxy().node.meta["example_value"]
+                return arg.get_real_value()
             if isinstance(arg, UserFunctionVariable):
                 return arg.fn
             if arg.has_unpack_var_sequence(tx):
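`unwrap_real` feeds the cond handling mentioned in the commit message: tracing a conditional requires concrete tensor arguments. A hedged example of the kind of user code involved; `torch.cond` is the modern entry point, and at the time of this commit cond lived in functorch's experimental namespace, so the spelling here is an assumption.

import torch

def true_fn(x):
    return x.cos()

def false_fn(x):
    return x.sin()

def fn(pred, x):
    # Tracing this needs the real tensors behind the cond arguments, which
    # unwrap_real now obtains via get_real_value() even under fake tensors.
    return torch.cond(pred, true_fn, false_fn, (x,))

out = fn(torch.tensor(True), torch.randn(4))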
