|
8 | 8 | import warnings |
9 | 9 | from contextlib import ExitStack, contextmanager |
10 | 10 | from types import MethodType |
11 | | -from typing import Optional, Union |
| 11 | +from typing import Any, Optional, Union |
12 | 12 |
|
13 | 13 | import torch |
14 | | -from torch._dynamo.functional_export import dynamo_graph_capture_for_export |
| 14 | +from torch._dynamo.functional_export import _dynamo_graph_capture_for_export |
15 | 15 | from torch._functorch.aot_autograd import ( |
16 | 16 | aot_compile_joint_with_descriptors, |
17 | 17 | aot_export_joint_with_descriptors, |
|
23 | 23 | from torch._subclasses import FakeTensorMode |
24 | 24 | from torch.distributed.fsdp import MixedPrecisionPolicy |
25 | 25 | from torch.distributed.tensor import DeviceMesh |
| 26 | +from torch.export._trace import _restore_state_dict |
26 | 27 | from torch.export._unlift import _assign_attr |
27 | 28 | from torch.export.unflatten import _AttrKind |
28 | 29 | from torch.fx.experimental.symbolic_shapes import ShapeEnv |
@@ -165,6 +166,21 @@ def enable_local_map_wrapping(): |
165 | 166 | yield |
166 | 167 |
|
167 | 168 |
|
def _export(model: torch.nn.Module, inputs: tuple[Any, ...]) -> torch.nn.Module:
    """
    Thin wrapper around graph capture output that restores the
    original calling convention and attribute fqn.

    Args:
        model: the ``nn.Module`` to capture a graph from.
        inputs: example positional inputs; unpacked as ``*inputs`` into the
            traced call (hence the variadic ``tuple[Any, ...]`` annotation —
            the original ``tuple[Any]`` incorrectly meant a 1-tuple).

    Returns:
        The captured graph module (an ``nn.Module``) with the original
        module's state-dict attribute names restored onto it.

    TODO:
    1) Use bytecode for calling convention instead of pytree for more
       seamless UX.
    2) Attach guards
    3) Be more careful about tensor constants names.
    """
    # NOTE(review): install_free_tensors presumably makes dynamo install
    # free (non-parameter) tensors as attributes on the captured module
    # instead of inlining them — confirm against torch._dynamo.config docs.
    with torch._dynamo.config.patch(install_free_tensors=True):
        gm = _dynamo_graph_capture_for_export(model)(*inputs)
    # Re-attach the source model's parameter/buffer FQNs onto the graph
    # module so downstream consumers see the original names.
    _restore_state_dict(model, gm)
    return gm
168 | 184 | class AutoParallel: |
169 | 185 | """ |
170 | 186 | Args: |
@@ -289,7 +305,7 @@ def build_model_graph(self): |
289 | 305 | with set_dtype_cast( |
290 | 306 | True |
291 | 307 | ), enable_local_map_wrapping(), torch._dynamo.utils._disable_saved_tensors_hooks_during_tracing(): |
292 | | - torch_ir_with_fqn = dynamo_graph_capture_for_export(self.model)(*inputs) |
| 308 | + torch_ir_with_fqn = _export(self.model, inputs) |
293 | 309 | # TODO Can't use fake mode here because it clashes with the user level |
294 | 310 | # fake mode. Ideally dynamo should reuse the user level fake mode. |
295 | 311 | self.joint_with_descriptors = aot_export_joint_with_descriptors( |
|
0 commit comments