Skip to content

Commit cb3059e

Browse files
authored
Revert "Update export api to be the latest version. (#216)" (#218)
This reverts commit 26909fd. stack-info: PR: #218, branch: xmfan/stack/15
1 parent 2f2664a commit cb3059e

File tree

1 file changed

+19
-3
lines changed

1 file changed

+19
-3
lines changed

autoparallel/api.py

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,10 @@
88
import warnings
99
from contextlib import ExitStack, contextmanager
1010
from types import MethodType
11-
from typing import Optional, Union
11+
from typing import Any, Optional, Union
1212

1313
import torch
14-
from torch._dynamo.functional_export import dynamo_graph_capture_for_export
14+
from torch._dynamo.functional_export import _dynamo_graph_capture_for_export
1515
from torch._functorch.aot_autograd import (
1616
aot_compile_joint_with_descriptors,
1717
aot_export_joint_with_descriptors,
@@ -23,6 +23,7 @@
2323
from torch._subclasses import FakeTensorMode
2424
from torch.distributed.fsdp import MixedPrecisionPolicy
2525
from torch.distributed.tensor import DeviceMesh
26+
from torch.export._trace import _restore_state_dict
2627
from torch.export._unlift import _assign_attr
2728
from torch.export.unflatten import _AttrKind
2829
from torch.fx.experimental.symbolic_shapes import ShapeEnv
@@ -165,6 +166,21 @@ def enable_local_map_wrapping():
165166
yield
166167

167168

169+
def _export(model: torch.nn.Module, inputs: tuple[Any]) -> torch.nn.Module:
170+
"""
171+
Thin wrapper around graph capture output that restores the
172+
original calling convention and attribute fqn. TODO:
173+
1) Use bytecode for calling convention instead of pytree for more
174+
seamless UX.
175+
2) Attach guards
176+
3) Be more careful about tensor constants names.
177+
"""
178+
with torch._dynamo.config.patch(install_free_tensors=True):
179+
gm = _dynamo_graph_capture_for_export(model)(*inputs)
180+
_restore_state_dict(model, gm)
181+
return gm
182+
183+
168184
class AutoParallel:
169185
"""
170186
Args:
@@ -289,7 +305,7 @@ def build_model_graph(self):
289305
with set_dtype_cast(
290306
True
291307
), enable_local_map_wrapping(), torch._dynamo.utils._disable_saved_tensors_hooks_during_tracing():
292-
torch_ir_with_fqn = dynamo_graph_capture_for_export(self.model)(*inputs)
308+
torch_ir_with_fqn = _export(self.model, inputs)
293309
# TODO Cna't use fake mode here because it clashes with the user level
294310
# fake mode. Ideally dynamo should reuse the user level fake mode.
295311
self.joint_with_descriptors = aot_export_joint_with_descriptors(

0 commit comments

Comments
 (0)