Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP][CodeStyle] replace self-defined parameterized function with third-party library #48701

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 3 additions & 28 deletions python/paddle/fluid/tests/unittests/autograd/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import typing

import numpy as np
from parameterized import parameterized_class

import paddle
from paddle.incubate.autograd.utils import as_tensors
Expand Down Expand Up @@ -281,7 +282,7 @@ def square(x):
# Parameterized Test Utils.
##########################################################

TEST_CASE_NAME = 'suffix'
TEST_CASE_NAME = 'name'


def place(devices, key='place'):
Expand Down Expand Up @@ -320,33 +321,7 @@ def parameterize(fields, values=None):
values (Sequence, optional): The test cases sequence. Defaults to None.

"""
fields = [fields] if isinstance(fields, str) else fields
params = [dict(zip(fields, vals)) for vals in values]

def decorate(cls):
test_cls_module = sys.modules[cls.__module__].__dict__
for i, values in enumerate(params):
test_cls = dict(cls.__dict__)
values = {
k: staticmethod(v) if callable(v) else v
for k, v in values.items()
}
test_cls.update(values)
name = cls.__name__ + str(i)
name = (
name + '.' + values.get('suffix')
if values.get('suffix')
else name
)

test_cls_module[name] = type(name, (cls,), test_cls)

for m in list(cls.__dict__):
if m.startswith("test"):
delattr(cls, m)
return cls

return decorate
return parameterized_class(fields, values)


##########################################################
Expand Down
2 changes: 1 addition & 1 deletion python/paddle/fluid/tests/unittests/distribution/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

DEFAULT_DTYPE = 'float64'

TEST_CASE_NAME = 'suffix'
TEST_CASE_NAME = 'name'
# All test case will use float64 for compare percision, refs:
# https://github.com/PaddlePaddle/Paddle/wiki/Upgrade-OP-Precision-to-Float64
RTOL = {
Expand Down
199 changes: 7 additions & 192 deletions python/paddle/fluid/tests/unittests/distribution/parameterize.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,17 +11,14 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import collections
import functools
import inspect
import re

import sys
from unittest import SkipTest

import config
import numpy as np
from parameterized import parameterized, parameterized_class

TEST_CASE_NAME = 'suffix'
TEST_CASE_NAME = 'name'


def xrand(shape=(10, 10, 10), dtype=config.DEFAULT_DTYPE, min=1.0, max=10.0):
Expand All @@ -47,190 +44,8 @@ def decorate(cls):
return decorate


def parameterize_cls(fields, values=None):
fields = [fields] if isinstance(fields, str) else fields
params = [dict(zip(fields, vals)) for vals in values]

def decorate(cls):
test_cls_module = sys.modules[cls.__module__].__dict__
for k, v in enumerate(params):
test_cls = dict(cls.__dict__)
test_cls.update(v)
name = cls.__name__ + str(k)
name = name + '.' + v.get('suffix') if v.get('suffix') else name

test_cls_module[name] = type(name, (cls,), test_cls)

for m in list(cls.__dict__):
if m.startswith("test"):
delattr(cls, m)
return cls

return decorate


def parameterize_func(
input, name_func=None, doc_func=None, skip_on_empty=False
):
name_func = name_func or default_name_func

def wrapper(f, instance=None):
frame_locals = inspect.currentframe().f_back.f_locals

parameters = input_as_callable(input)()

if not parameters:
if not skip_on_empty:
raise ValueError(
"Parameters iterable is empty (hint: use "
"`parameterized.expand([], skip_on_empty=True)` to skip "
"this test when the input is empty)"
)
return functools.wraps(f)(skip_on_empty_helper)

digits = len(str(len(parameters) - 1))
for num, p in enumerate(parameters):
name = name_func(
f, "{num:0>{digits}}".format(digits=digits, num=num), p
)
# If the original function has patches applied by 'mock.patch',
# re-construct all patches on the just former decoration layer
# of param_as_standalone_func so as not to share
# patch objects between new functions
nf = reapply_patches_if_need(f)
frame_locals[name] = param_as_standalone_func(p, nf, name)
frame_locals[name].__doc__ = f.__doc__

# Delete original patches to prevent new function from evaluating
# original patching object as well as re-constrfucted patches.
delete_patches_if_need(f)

f.__test__ = False

return wrapper


def reapply_patches_if_need(func):
def dummy_wrapper(orgfunc):
@functools.wraps(orgfunc)
def dummy_func(*args, **kwargs):
return orgfunc(*args, **kwargs)

return dummy_func

if hasattr(func, 'patchings'):
func = dummy_wrapper(func)
tmp_patchings = func.patchings
delattr(func, 'patchings')
for patch_obj in tmp_patchings:
func = patch_obj.decorate_callable(func)
return func


def delete_patches_if_need(func):
if hasattr(func, 'patchings'):
func.patchings[:] = []


def default_name_func(func, num, p):
base_name = func.__name__
name_suffix = "_%s" % (num,)

if len(p.args) > 0 and isinstance(p.args[0], str):
name_suffix += "_" + to_safe_name(p.args[0])
return base_name + name_suffix


def param_as_standalone_func(p, func, name):
@functools.wraps(func)
def standalone_func(*a):
return func(*(a + p.args), **p.kwargs)

standalone_func.__name__ = name

# place_as is used by py.test to determine what source file should be
# used for this test.
standalone_func.place_as = func

# Remove __wrapped__ because py.test will try to look at __wrapped__
# to determine which parameters should be used with this test case,
# and obviously we don't need it to do any parameterization.
try:
del standalone_func.__wrapped__
except AttributeError:
pass
return standalone_func


def input_as_callable(input):
if callable(input):
return lambda: check_input_values(input())
input_values = check_input_values(input)
return lambda: input_values


def check_input_values(input_values):
if not isinstance(input_values, list):
input_values = list(input_values)
return [param.from_decorator(p) for p in input_values]


def skip_on_empty_helper(*a, **kw):
raise SkipTest("parameterized input is empty")


_param = collections.namedtuple("param", "args kwargs")


class param(_param):
def __new__(cls, *args, **kwargs):
return _param.__new__(cls, args, kwargs)

@classmethod
def explicit(cls, args=None, kwargs=None):
"""Creates a ``param`` by explicitly specifying ``args`` and
``kwargs``::
>>> param.explicit([1,2,3])
param(*(1, 2, 3))
>>> param.explicit(kwargs={"foo": 42})
param(*(), **{"foo": "42"})
"""
args = args or ()
kwargs = kwargs or {}
return cls(*args, **kwargs)

@classmethod
def from_decorator(cls, args):
"""Returns an instance of ``param()`` for ``@parameterized`` argument
``args``::
>>> param.from_decorator((42, ))
param(args=(42, ), kwargs={})
>>> param.from_decorator("foo")
param(args=("foo", ), kwargs={})
"""
if isinstance(args, param):
return args
elif isinstance(args, str):
args = (args,)
try:
return cls(*args)
except TypeError as e:
if "after * must be" not in str(e):
raise
raise TypeError(
"Parameters must be tuples, but %r is not (hint: use '(%r, )')"
% (args, args),
)

def __repr__(self):
return "param(*%r, **%r)" % self


def to_safe_name(s):
return str(re.sub("[^a-zA-Z0-9_]+", "_", s))


# alias
parameterize = parameterize_func
param_cls = parameterize_cls
param_func = parameterize_func
parameterize = parameterized.expand
param_func = parameterized.expand
parameterize_cls = parameterized_class
param_cls = parameterized_class
Loading