From 3b1894ca9fc8285f3ccaf1b52a6e8fdc71c1cb4a Mon Sep 17 00:00:00 2001 From: SigureMo Date: Sun, 1 Oct 2023 01:34:53 +0800 Subject: [PATCH] move all patch code to base/__init__.py --- python/paddle/__init__.py | 10 ---------- python/paddle/base/__init__.py | 6 ++++++ python/paddle/framework/__init__.py | 3 --- 3 files changed, 6 insertions(+), 13 deletions(-) diff --git a/python/paddle/__init__.py b/python/paddle/__init__.py index 5810f6df3ef24..b6a4278e4d3ef 100644 --- a/python/paddle/__init__.py +++ b/python/paddle/__init__.py @@ -28,16 +28,6 @@ from .base import core # noqa: F401 from .batch import batch -from .framework import ( - monkey_patch_variable, - monkey_patch_tensor, - monkey_patch_math_tensor, -) - -monkey_patch_variable() -monkey_patch_tensor() -monkey_patch_math_tensor() - from .framework import ( disable_signal_handler, get_flags, diff --git a/python/paddle/base/__init__.py b/python/paddle/base/__init__.py index 24117377dab51..1a509b9bb4b80 100644 --- a/python/paddle/base/__init__.py +++ b/python/paddle/base/__init__.py @@ -34,6 +34,9 @@ raise e from . import core +from .layers.math_op_patch import monkey_patch_variable +from .dygraph.math_op_patch import monkey_patch_math_tensor +from .dygraph.tensor_patch_methods import monkey_patch_tensor # import all class inside framework into base module from . import framework @@ -202,6 +205,9 @@ def remove_flag_if_exists(name): # TODO(panyx0718): Avoid doing complex initialization logic in __init__.py. # Consider paddle.init(args) or paddle.main(args) __bootstrap__() +monkey_patch_variable() +monkey_patch_tensor() +monkey_patch_math_tensor() # NOTE(Aurelius84): clean up ExecutorCacheInfo in advance manually. atexit.register(core.clear_executor_cache) diff --git a/python/paddle/framework/__init__.py b/python/paddle/framework/__init__.py index fb151b5005d7e..e19f62dec222e 100755 --- a/python/paddle/framework/__init__.py +++ b/python/paddle/framework/__init__.py @@ -44,9 +44,6 @@ from .io_utils import _unpack_saved_dict from .io_utils import _load_program_scope -from ..base.layers.math_op_patch import monkey_patch_variable -from ..base.dygraph.math_op_patch import monkey_patch_math_tensor -from ..base.dygraph.tensor_patch_methods import monkey_patch_tensor from ..base.framework import disable_signal_handler # noqa: F401 from ..base.framework import get_flags # noqa: F401 from ..base.framework import set_flags # noqa: F401