Skip to content

Commit

Permalink
[Dy2St] replace deprecated load_module with exec_module (PaddlePa…
Browse files Browse the repository at this point in the history
  • Loading branch information
SigureMo authored Dec 7, 2022
1 parent dbe0595 commit fb12bde
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 5 deletions.
10 changes: 6 additions & 4 deletions python/paddle/jit/dy2static/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import copy
from paddle.utils import gast
import inspect
import importlib.util
import os
import sys
import shutil
Expand All @@ -32,6 +33,7 @@
from paddle.fluid.layer_helper import LayerHelper
from paddle.fluid.layers import assign
from functools import reduce
from importlib.machinery import SourceFileLoader
import warnings


Expand Down Expand Up @@ -71,9 +73,6 @@ def visit(self, node):
return ret


# imp is deprecated in python3
from importlib.machinery import SourceFileLoader

dygraph_class_to_static_api = {
"CosineDecay": "cosine_decay",
"ExponentialDecay": "exponential_decay",
Expand Down Expand Up @@ -586,7 +585,10 @@ def func_prefix(func):
DEL_TEMP_DIR = False

func_name = dyfunc.__name__
module = SourceFileLoader(module_name, f.name).load_module()
loader = SourceFileLoader(module_name, f.name)
spec = importlib.util.spec_from_loader(loader.name, loader)
module = importlib.util.module_from_spec(spec)
loader.exec_module(module)
# The 'forward' or 'another_forward' of 'TranslatedLayer' cannot be obtained
# through 'func_name'. So set the special function name '__i_m_p_l__'.
if hasattr(module, '__i_m_p_l__'):
Expand Down
5 changes: 4 additions & 1 deletion python/paddle/utils/cpp_extension/extension_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import collections
import glob
import hashlib
import importlib.util
import json
import logging
import os
Expand Down Expand Up @@ -1070,7 +1071,9 @@ def _load_module_from_file(api_file_path, module_name, verbose=False):

# load module with RWLock
loader = machinery.SourceFileLoader(ext_name, api_file_path)
module = loader.load_module()
spec = importlib.util.spec_from_loader(loader.name, loader)
module = importlib.util.module_from_spec(spec)
loader.exec_module(module)

return module

Expand Down

0 comments on commit fb12bde

Please sign in to comment.