Commit 6c4d1b5

[reland] Update the export API to the latest version.
Summary: The issue from torch nightly has been fixed for the new export API. Relanding.

Test Plan: Also tested on #227

```
=================================================================================== test session starts ===================================================================================
platform linux -- Python 3.12.11, pytest-7.3.2, pluggy-1.6.0
rootdir: /data/users/zhxchen17/autoparallel
plugins: xdoctest-1.1.0, hypothesis-5.35.1, xdist-3.3.1, subtests-0.13.1, rerunfailures-14.0, flakefinder-1.1.0, cpp-2.3.0, anyio-4.10.0
collected 21 items

tests/test_aot_eager.py ..x                                                      [ 14%]
tests/test_api.py ....                                                           [ 33%]
tests/test_dtensor.py ....                                                       [ 52%]
tests/test_optimize_placement.py ........                                        [ 90%]
tests/test_ordered_sharding.py ..                                                [100%]

======================================================================== 20 passed, 1 xfailed in 86.30s (0:01:26) =========================================================================
```
1 parent: c379849

File tree: 1 file changed (+3, −19)

autoparallel/api.py (3 additions, 19 deletions)
```diff
@@ -8,10 +8,10 @@
 import warnings
 from contextlib import ExitStack, contextmanager
 from types import MethodType
-from typing import Any, Optional, Union
+from typing import Optional, Union

 import torch
-from torch._dynamo.functional_export import _dynamo_graph_capture_for_export
+from torch._dynamo.functional_export import dynamo_graph_capture_for_export
 from torch._functorch.aot_autograd import (
     aot_compile_joint_with_descriptors,
     aot_export_joint_with_descriptors,
@@ -23,7 +23,6 @@
 from torch._subclasses import FakeTensorMode
 from torch.distributed.fsdp import MixedPrecisionPolicy
 from torch.distributed.tensor import DeviceMesh
-from torch.export._trace import _restore_state_dict
 from torch.export._unlift import _assign_attr
 from torch.export.unflatten import _AttrKind
 from torch.fx.experimental.symbolic_shapes import ShapeEnv
@@ -159,21 +158,6 @@ def enable_local_map_wrapping():
     yield


-def _export(model: torch.nn.Module, inputs: tuple[Any]) -> torch.nn.Module:
-    """
-    Thin wrapper around graph capture output that restores the
-    original calling convention and attribute fqn. TODO:
-    1) Use bytecode for calling convention instead of pytree for more
-    seamless UX.
-    2) Attach guards
-    3) Be more careful about tensor constants names.
-    """
-    with torch._dynamo.config.patch(install_free_tensors=True):
-        gm = _dynamo_graph_capture_for_export(model)(*inputs)
-    _restore_state_dict(model, gm)
-    return gm
-
-
 class AutoParallel:
     """
     Args:
@@ -298,7 +282,7 @@ def build_model_graph(self):
         with set_dtype_cast(
             True
         ), enable_local_map_wrapping(), torch._dynamo.utils._disable_saved_tensors_hooks_during_tracing():
-            torch_ir_with_fqn = _export(self.model, inputs)
+            torch_ir_with_fqn = dynamo_graph_capture_for_export(self.model)(*inputs)
             # TODO Cna't use fake mode here because it clashes with the user level
             # fake mode. Ideally dynamo should reuse the user level fake mode.
             self.joint_with_descriptors = aot_export_joint_with_descriptors(
```
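For context, the diff swaps the private `_export` wrapper for a direct call to the new `dynamo_graph_capture_for_export` entry point. Below is a minimal sketch of the new call pattern, based only on the usage visible in this diff; the toy module and input shapes are illustrative assumptions, not code from the repository.

```python
import torch
from torch._dynamo.functional_export import dynamo_graph_capture_for_export


class ToyModel(torch.nn.Module):  # hypothetical module, for illustration only
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(8, 8)

    def forward(self, x):
        return torch.relu(self.linear(x))


model = ToyModel()
inputs = (torch.randn(2, 8),)

# New entry point: the capture function returns a callable; invoking it with
# example inputs yields the captured torch IR graph module. This replaces the
# old private _export() wrapper, which also patched install_free_tensors and
# called _restore_state_dict() on the result.
torch_ir_with_fqn = dynamo_graph_capture_for_export(model)(*inputs)
```

Per the removed docstring, the old wrapper existed to restore the original calling convention and attribute FQNs on the captured graph; the switch suggests the new API handles this internally, though the diff itself does not show that implementation.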
