Skip to content

Commit

Permalink
move all patch code to base/__init__.py
Browse files Browse the repository at this point in the history
  • Loading branch information
SigureMo committed Sep 30, 2023
1 parent 574d9ca commit 3b1894c
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 13 deletions.
10 changes: 0 additions & 10 deletions python/paddle/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,16 +28,6 @@
from .base import core # noqa: F401
from .batch import batch

from .framework import (
monkey_patch_variable,
monkey_patch_tensor,
monkey_patch_math_tensor,
)

monkey_patch_variable()
monkey_patch_tensor()
monkey_patch_math_tensor()

from .framework import (
disable_signal_handler,
get_flags,
Expand Down
6 changes: 6 additions & 0 deletions python/paddle/base/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@
raise e

from . import core
from .layers.math_op_patch import monkey_patch_variable
from .dygraph.math_op_patch import monkey_patch_math_tensor
from .dygraph.tensor_patch_methods import monkey_patch_tensor

# import all class inside framework into base module
from . import framework
Expand Down Expand Up @@ -202,6 +205,9 @@ def remove_flag_if_exists(name):
# TODO(panyx0718): Avoid doing complex initialization logic in __init__.py.
# Consider paddle.init(args) or paddle.main(args)
__bootstrap__()
monkey_patch_variable()
monkey_patch_tensor()
monkey_patch_math_tensor()

# NOTE(Aurelius84): clean up ExecutorCacheInfo in advance manually.
atexit.register(core.clear_executor_cache)
Expand Down
3 changes: 0 additions & 3 deletions python/paddle/framework/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,6 @@
from .io_utils import _unpack_saved_dict
from .io_utils import _load_program_scope

from ..base.layers.math_op_patch import monkey_patch_variable
from ..base.dygraph.math_op_patch import monkey_patch_math_tensor
from ..base.dygraph.tensor_patch_methods import monkey_patch_tensor
from ..base.framework import disable_signal_handler # noqa: F401
from ..base.framework import get_flags # noqa: F401
from ..base.framework import set_flags # noqa: F401
Expand Down

0 comments on commit 3b1894c

Please sign in to comment.