diff --git a/python/paddle/fluid/tests/unittests/autograd/utils.py b/python/paddle/fluid/tests/unittests/autograd/utils.py index f3f79822537ad..c7a1a3a4185bd 100644 --- a/python/paddle/fluid/tests/unittests/autograd/utils.py +++ b/python/paddle/fluid/tests/unittests/autograd/utils.py @@ -17,6 +17,7 @@ import typing import numpy as np +from parameterized import parameterized_class import paddle from paddle.incubate.autograd.utils import as_tensors @@ -281,7 +282,7 @@ def square(x): # Parameterized Test Utils. ########################################################## -TEST_CASE_NAME = 'suffix' +TEST_CASE_NAME = 'name' def place(devices, key='place'): @@ -320,33 +321,7 @@ def parameterize(fields, values=None): values (Sequence, optional): The test cases sequence. Defaults to None. """ - fields = [fields] if isinstance(fields, str) else fields - params = [dict(zip(fields, vals)) for vals in values] - - def decorate(cls): - test_cls_module = sys.modules[cls.__module__].__dict__ - for i, values in enumerate(params): - test_cls = dict(cls.__dict__) - values = { - k: staticmethod(v) if callable(v) else v - for k, v in values.items() - } - test_cls.update(values) - name = cls.__name__ + str(i) - name = ( - name + '.' + values.get('suffix') - if values.get('suffix') - else name - ) - - test_cls_module[name] = type(name, (cls,), test_cls) - - for m in list(cls.__dict__): - if m.startswith("test"): - delattr(cls, m) - return cls - - return decorate + return parameterized_class(fields, values) ########################################################## diff --git a/python/paddle/fluid/tests/unittests/distribution/config.py b/python/paddle/fluid/tests/unittests/distribution/config.py index 29a27890bad2c..4aa3cd907f74d 100644 --- a/python/paddle/fluid/tests/unittests/distribution/config.py +++ b/python/paddle/fluid/tests/unittests/distribution/config.py @@ -19,7 +19,7 @@ DEFAULT_DTYPE = 'float64' -TEST_CASE_NAME = 'suffix' +TEST_CASE_NAME = 'name' # All test case will use float64 for compare percision, refs: # https://github.com/PaddlePaddle/Paddle/wiki/Upgrade-OP-Precision-to-Float64 RTOL = { diff --git a/python/paddle/fluid/tests/unittests/distribution/parameterize.py b/python/paddle/fluid/tests/unittests/distribution/parameterize.py index 324a5e4c6a0cd..182fd5b86ab6d 100644 --- a/python/paddle/fluid/tests/unittests/distribution/parameterize.py +++ b/python/paddle/fluid/tests/unittests/distribution/parameterize.py @@ -11,17 +11,14 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import collections -import functools -import inspect -import re + import sys -from unittest import SkipTest import config import numpy as np +from parameterized import parameterized, parameterized_class -TEST_CASE_NAME = 'suffix' +TEST_CASE_NAME = 'name' def xrand(shape=(10, 10, 10), dtype=config.DEFAULT_DTYPE, min=1.0, max=10.0): @@ -47,190 +44,8 @@ def decorate(cls): return decorate -def parameterize_cls(fields, values=None): - fields = [fields] if isinstance(fields, str) else fields - params = [dict(zip(fields, vals)) for vals in values] - - def decorate(cls): - test_cls_module = sys.modules[cls.__module__].__dict__ - for k, v in enumerate(params): - test_cls = dict(cls.__dict__) - test_cls.update(v) - name = cls.__name__ + str(k) - name = name + '.' + v.get('suffix') if v.get('suffix') else name - - test_cls_module[name] = type(name, (cls,), test_cls) - - for m in list(cls.__dict__): - if m.startswith("test"): - delattr(cls, m) - return cls - - return decorate - - -def parameterize_func( - input, name_func=None, doc_func=None, skip_on_empty=False -): - name_func = name_func or default_name_func - - def wrapper(f, instance=None): - frame_locals = inspect.currentframe().f_back.f_locals - - parameters = input_as_callable(input)() - - if not parameters: - if not skip_on_empty: - raise ValueError( - "Parameters iterable is empty (hint: use " - "`parameterized.expand([], skip_on_empty=True)` to skip " - "this test when the input is empty)" - ) - return functools.wraps(f)(skip_on_empty_helper) - - digits = len(str(len(parameters) - 1)) - for num, p in enumerate(parameters): - name = name_func( - f, "{num:0>{digits}}".format(digits=digits, num=num), p - ) - # If the original function has patches applied by 'mock.patch', - # re-construct all patches on the just former decoration layer - # of param_as_standalone_func so as not to share - # patch objects between new functions - nf = reapply_patches_if_need(f) - frame_locals[name] = param_as_standalone_func(p, nf, name) - frame_locals[name].__doc__ = f.__doc__ - - # Delete original patches to prevent new function from evaluating - # original patching object as well as re-constrfucted patches. - delete_patches_if_need(f) - - f.__test__ = False - - return wrapper - - -def reapply_patches_if_need(func): - def dummy_wrapper(orgfunc): - @functools.wraps(orgfunc) - def dummy_func(*args, **kwargs): - return orgfunc(*args, **kwargs) - - return dummy_func - - if hasattr(func, 'patchings'): - func = dummy_wrapper(func) - tmp_patchings = func.patchings - delattr(func, 'patchings') - for patch_obj in tmp_patchings: - func = patch_obj.decorate_callable(func) - return func - - -def delete_patches_if_need(func): - if hasattr(func, 'patchings'): - func.patchings[:] = [] - - -def default_name_func(func, num, p): - base_name = func.__name__ - name_suffix = "_%s" % (num,) - - if len(p.args) > 0 and isinstance(p.args[0], str): - name_suffix += "_" + to_safe_name(p.args[0]) - return base_name + name_suffix - - -def param_as_standalone_func(p, func, name): - @functools.wraps(func) - def standalone_func(*a): - return func(*(a + p.args), **p.kwargs) - - standalone_func.__name__ = name - - # place_as is used by py.test to determine what source file should be - # used for this test. - standalone_func.place_as = func - - # Remove __wrapped__ because py.test will try to look at __wrapped__ - # to determine which parameters should be used with this test case, - # and obviously we don't need it to do any parameterization. - try: - del standalone_func.__wrapped__ - except AttributeError: - pass - return standalone_func - - -def input_as_callable(input): - if callable(input): - return lambda: check_input_values(input()) - input_values = check_input_values(input) - return lambda: input_values - - -def check_input_values(input_values): - if not isinstance(input_values, list): - input_values = list(input_values) - return [param.from_decorator(p) for p in input_values] - - -def skip_on_empty_helper(*a, **kw): - raise SkipTest("parameterized input is empty") - - -_param = collections.namedtuple("param", "args kwargs") - - -class param(_param): - def __new__(cls, *args, **kwargs): - return _param.__new__(cls, args, kwargs) - - @classmethod - def explicit(cls, args=None, kwargs=None): - """Creates a ``param`` by explicitly specifying ``args`` and - ``kwargs``:: - >>> param.explicit([1,2,3]) - param(*(1, 2, 3)) - >>> param.explicit(kwargs={"foo": 42}) - param(*(), **{"foo": "42"}) - """ - args = args or () - kwargs = kwargs or {} - return cls(*args, **kwargs) - - @classmethod - def from_decorator(cls, args): - """Returns an instance of ``param()`` for ``@parameterized`` argument - ``args``:: - >>> param.from_decorator((42, )) - param(args=(42, ), kwargs={}) - >>> param.from_decorator("foo") - param(args=("foo", ), kwargs={}) - """ - if isinstance(args, param): - return args - elif isinstance(args, str): - args = (args,) - try: - return cls(*args) - except TypeError as e: - if "after * must be" not in str(e): - raise - raise TypeError( - "Parameters must be tuples, but %r is not (hint: use '(%r, )')" - % (args, args), - ) - - def __repr__(self): - return "param(*%r, **%r)" % self - - -def to_safe_name(s): - return str(re.sub("[^a-zA-Z0-9_]+", "_", s)) - - # alias -parameterize = parameterize_func -param_cls = parameterize_cls -param_func = parameterize_func +parameterize = parameterized.expand +param_func = parameterized.expand +parameterize_cls = parameterized_class +param_cls = parameterized_class diff --git a/python/paddle/fluid/tests/unittests/fft/test_fft.py b/python/paddle/fluid/tests/unittests/fft/test_fft.py index 1b42badd1481a..e676ada6bdc0d 100644 --- a/python/paddle/fluid/tests/unittests/fft/test_fft.py +++ b/python/paddle/fluid/tests/unittests/fft/test_fft.py @@ -17,6 +17,7 @@ import numpy as np import scipy.fft +from parameterized import parameterized_class import paddle @@ -24,7 +25,7 @@ if paddle.is_compiled_with_cuda(): DEVICES.append(paddle.CUDAPlace(0)) -TEST_CASE_NAME = 'suffix' +TEST_CASE_NAME = 'name' # All test case will use float64 for compare percision, refs: # https://github.com/PaddlePaddle/Paddle/wiki/Upgrade-OP-Precision-to-Float64 RTOL = { @@ -67,31 +68,8 @@ def decorate(cls): return decorate -def parameterize(fields, values=None): - - fields = [fields] if isinstance(fields, str) else fields - params = [dict(zip(fields, vals)) for vals in values] - - def decorate(cls): - test_cls_module = sys.modules[cls.__module__].__dict__ - for k, v in enumerate(params): - test_cls = dict(cls.__dict__) - test_cls.update(v) - name = cls.__name__ + str(k) - name = name + '.' + v.get('suffix') if v.get('suffix') else name - - test_cls_module[name] = type(name, (cls,), test_cls) - - for m in list(cls.__dict__): - if m.startswith("test"): - delattr(cls, m) - return cls - - return decorate - - @place(DEVICES) -@parameterize( +@parameterized_class( (TEST_CASE_NAME, 'x', 'n', 'axis', 'norm'), [ ('test_x_float64', rand_x(5, np.float64), None, -1, 'backward'), @@ -130,7 +108,7 @@ def test_fft(self): @place(DEVICES) -@parameterize( +@parameterized_class( (TEST_CASE_NAME, 'x', 'n', 'axis', 'norm'), [ ('test_x_float64', rand_x(5, np.float64), None, -1, 'backward'), @@ -169,7 +147,7 @@ def test_fft(self): @place(DEVICES) -@parameterize( +@parameterized_class( (TEST_CASE_NAME, 'x', 'n', 'axis', 'norm', 'expect_exception'), [ ('test_n_nagative', rand_x(2), -1, -1, 'backward', ValueError), @@ -209,7 +187,7 @@ def test_fft(self): @place(DEVICES) -@parameterize( +@parameterized_class( (TEST_CASE_NAME, 'x', 'n', 'axis', 'norm'), [ ('test_x_float64', rand_x(5), None, (0, 1), 'backward'), @@ -255,7 +233,7 @@ def test_fft2(self): @place(DEVICES) -@parameterize( +@parameterized_class( (TEST_CASE_NAME, 'x', 'n', 'axis', 'norm', 'expect_exception'), [ ( @@ -323,7 +301,7 @@ def test_fft2(self): @place(DEVICES) -@parameterize( +@parameterized_class( (TEST_CASE_NAME, 'x', 'n', 'axis', 'norm'), [ ('test_x_float64', rand_x(5, np.float64), None, None, 'backward'), @@ -362,7 +340,7 @@ def test_fftn(self): @place(DEVICES) -@parameterize( +@parameterized_class( (TEST_CASE_NAME, 'x', 'n', 'axis', 'norm'), [ ('test_x_float64', rand_x(5, np.float64), None, None, 'backward'), @@ -401,7 +379,7 @@ def test_ifftn(self): @place(DEVICES) -@parameterize( +@parameterized_class( (TEST_CASE_NAME, 'x', 'n', 'axis', 'norm'), [ ( @@ -465,7 +443,7 @@ def test_hfft(self): @place(DEVICES) -@parameterize( +@parameterized_class( (TEST_CASE_NAME, 'x', 'n', 'axis', 'norm'), [ ( @@ -529,7 +507,7 @@ def test_irfft(self): @place(DEVICES) -@parameterize( +@parameterized_class( (TEST_CASE_NAME, 'x', 'n', 'axis', 'norm'), [ ( @@ -593,7 +571,7 @@ def test_irfftn(self): @place(DEVICES) -@parameterize( +@parameterized_class( (TEST_CASE_NAME, 'x', 'n', 'axis', 'norm'), [ ( @@ -657,7 +635,7 @@ def test_hfftn(self): @place(DEVICES) -@parameterize( +@parameterized_class( (TEST_CASE_NAME, 'x', 's', 'axis', 'norm'), [ ( @@ -715,7 +693,7 @@ def test_hfft2(self): @place(DEVICES) -@parameterize( +@parameterized_class( (TEST_CASE_NAME, 'x', 's', 'axis', 'norm'), [ ( @@ -772,7 +750,7 @@ def test_irfft2(self): @place(DEVICES) -@parameterize( +@parameterized_class( (TEST_CASE_NAME, 'x', 'n', 'axis', 'norm', 'expect_exception'), [ ( @@ -854,7 +832,7 @@ def test_hfft(self): @place(DEVICES) -@parameterize( +@parameterized_class( (TEST_CASE_NAME, 'x', 'n', 'axis', 'norm', 'expect_exception'), [ ( @@ -926,7 +904,7 @@ def test_irfft(self): @place(DEVICES) -@parameterize( +@parameterized_class( (TEST_CASE_NAME, 'x', 'n', 'axis', 'norm', 'expect_exception'), [ ( @@ -1017,7 +995,7 @@ def test_hfft2(self): @place(DEVICES) -@parameterize( +@parameterized_class( (TEST_CASE_NAME, 'x', 'n', 'axis', 'norm', 'expect_exception'), [ ( @@ -1106,7 +1084,7 @@ def test_irfft2(self): @place(DEVICES) -@parameterize( +@parameterized_class( (TEST_CASE_NAME, 'x', 'n', 'axis', 'norm', 'expect_exception'), [ ( @@ -1196,7 +1174,7 @@ def test_hfftn(self): @place(DEVICES) -@parameterize( +@parameterized_class( (TEST_CASE_NAME, 'x', 'n', 'axis', 'norm', 'expect_exception'), [ ( @@ -1275,7 +1253,7 @@ def test_irfftn(self): @place(DEVICES) -@parameterize( +@parameterized_class( (TEST_CASE_NAME, 'x', 'n', 'axis', 'norm'), [ ('test_x_float64', rand_x(5, np.float64), None, -1, 'backward'), @@ -1313,7 +1291,7 @@ def test_rfft(self): @place(DEVICES) -@parameterize( +@parameterized_class( (TEST_CASE_NAME, 'x', 'n', 'axis', 'norm', 'expect_exception'), [ ('test_n_nagative', rand_x(2), -1, -1, 'backward', ValueError), @@ -1354,7 +1332,7 @@ def test_rfft(self): @place(DEVICES) -@parameterize( +@parameterized_class( (TEST_CASE_NAME, 'x', 'n', 'axis', 'norm'), [ ('test_x_float64', rand_x(5), None, (0, 1), 'backward'), @@ -1393,7 +1371,7 @@ def test_rfft2(self): @place(DEVICES) -@parameterize( +@parameterized_class( (TEST_CASE_NAME, 'x', 'n', 'axis', 'norm', 'expect_exception'), [ ( @@ -1453,7 +1431,7 @@ def test_rfft2(self): @place(DEVICES) -@parameterize( +@parameterized_class( (TEST_CASE_NAME, 'x', 'n', 'axis', 'norm'), [ ('test_x_float64', rand_x(5, np.float64), None, None, 'backward'), @@ -1491,7 +1469,7 @@ def test_rfftn(self): @place(DEVICES) -@parameterize( +@parameterized_class( (TEST_CASE_NAME, 'x', 'n', 'axis', 'norm', 'expect_exception'), [ ( @@ -1540,7 +1518,7 @@ def test_rfftn(self): @place(DEVICES) -@parameterize( +@parameterized_class( (TEST_CASE_NAME, 'x', 'n', 'axis', 'norm'), [ ('test_x_float64', rand_x(5, np.float64), None, -1, 'backward'), @@ -1578,7 +1556,7 @@ def test_ihfft(self): @place(DEVICES) -@parameterize( +@parameterized_class( (TEST_CASE_NAME, 'x', 'n', 'axis', 'norm', 'expect_exception'), [ ('test_n_nagative', rand_x(2), -1, -1, 'backward', ValueError), @@ -1618,7 +1596,7 @@ def test_ihfft(self): @place(DEVICES) -@parameterize( +@parameterized_class( (TEST_CASE_NAME, 'x', 'n', 'axis', 'norm'), [ ('test_x_float64', rand_x(5), None, (0, 1), 'backward'), @@ -1657,7 +1635,7 @@ def test_ihfft2(self): @place(DEVICES) -@parameterize( +@parameterized_class( (TEST_CASE_NAME, 'x', 'n', 'axis', 'norm', 'expect_exception'), [ ( @@ -1725,7 +1703,7 @@ def test_ihfft2(self): @place(DEVICES) -@parameterize( +@parameterized_class( (TEST_CASE_NAME, 'x', 'n', 'axis', 'norm'), [ ('test_x_float64', rand_x(5, np.float64), None, None, 'backward'), @@ -1763,7 +1741,7 @@ def test_ihfftn(self): @place(DEVICES) -@parameterize( +@parameterized_class( (TEST_CASE_NAME, 'x', 'n', 'axis', 'norm', 'expect_exception'), [ ( @@ -1804,7 +1782,7 @@ def test_ihfftn(self): @place(DEVICES) -@parameterize( +@parameterized_class( (TEST_CASE_NAME, 'n', 'd', 'dtype'), [ ('test_without_d', 20, 1, 'float32'), @@ -1824,7 +1802,7 @@ def test_fftfreq(self): @place(DEVICES) -@parameterize( +@parameterized_class( (TEST_CASE_NAME, 'n', 'd', 'dtype'), [ ('test_without_d', 20, 1, 'float32'), @@ -1844,7 +1822,7 @@ def test_rfftfreq(self): @place(DEVICES) -@parameterize( +@parameterized_class( (TEST_CASE_NAME, 'x', 'axes', 'dtype'), [ ('test_1d', np.random.randn(10), (0,), 'float64'), @@ -1873,7 +1851,7 @@ def test_fftshift(self): @place(DEVICES) -@parameterize( +@parameterized_class( (TEST_CASE_NAME, 'x', 'axes'), [ ('test_1d', np.random.randn(10), (0,), 'float64'), diff --git a/python/paddle/fluid/tests/unittests/fft/test_fft_with_static_graph.py b/python/paddle/fluid/tests/unittests/fft/test_fft_with_static_graph.py index 79b8fb2798252..652eff19dbc80 100644 --- a/python/paddle/fluid/tests/unittests/fft/test_fft_with_static_graph.py +++ b/python/paddle/fluid/tests/unittests/fft/test_fft_with_static_graph.py @@ -17,15 +17,8 @@ import numpy as np import scipy.fft -from test_fft import ( - ATOL, - DEVICES, - RTOL, - TEST_CASE_NAME, - parameterize, - place, - rand_x, -) +from parameterized import parameterized_class +from test_fft import ATOL, DEVICES, RTOL, TEST_CASE_NAME, place, rand_x import paddle @@ -47,7 +40,7 @@ def stgraph(func, place, x, n, axes, norm): @place(DEVICES) -@parameterize( +@parameterized_class( (TEST_CASE_NAME, 'x', 'n', 'axis', 'norm'), [ ('test_x_float64', rand_x(5, np.float64), None, -1, 'backward'), @@ -91,7 +84,7 @@ def test_static_rfft(self): @place(DEVICES) -@parameterize( +@parameterized_class( (TEST_CASE_NAME, 'x', 'n', 'axis', 'norm', 'expect_exception'), [ ('test_n_nagative', rand_x(2), -1, -1, 'backward', ValueError), @@ -125,7 +118,7 @@ def test_fft(self): @place(DEVICES) -@parameterize( +@parameterized_class( (TEST_CASE_NAME, 'x', 'n', 'axis', 'norm'), [ ('test_x_float64', rand_x(5), None, (0, 1), 'backward'), @@ -170,7 +163,7 @@ def test_static_fft2(self): @place(DEVICES) -@parameterize( +@parameterized_class( (TEST_CASE_NAME, 'x', 'n', 'axis', 'norm', 'expect_exception'), [ # ('test_x_not_tensor', [0, 1], None, (0, 1), 'backward', ValueError), @@ -219,7 +212,7 @@ def test_static_fft2(self): @place(DEVICES) -@parameterize( +@parameterized_class( (TEST_CASE_NAME, 'x', 'n', 'axis', 'norm'), [ ('test_x_float64', rand_x(5, np.float64), None, None, 'backward'), @@ -263,7 +256,7 @@ def test_static_fftn(self): @place(DEVICES) -@parameterize( +@parameterized_class( (TEST_CASE_NAME, 'x', 'n', 'axis', 'norm', 'expect_exception'), [ ( @@ -310,7 +303,7 @@ def test_static_rfftn(self): @place(DEVICES) -@parameterize( +@parameterized_class( (TEST_CASE_NAME, 'x', 'n', 'axis', 'norm'), [ ( @@ -375,7 +368,7 @@ def test_hfft(self): @place(DEVICES) -@parameterize( +@parameterized_class( (TEST_CASE_NAME, 'x', 'n', 'axis', 'norm'), [ ( @@ -440,7 +433,7 @@ def test_irfft(self): @place(DEVICES) -@parameterize( +@parameterized_class( (TEST_CASE_NAME, 'x', 'n', 'axis', 'norm'), [ ( @@ -505,7 +498,7 @@ def test_static_irfftn(self): @place(DEVICES) -@parameterize( +@parameterized_class( (TEST_CASE_NAME, 'x', 'n', 'axis', 'norm'), [ ( @@ -570,7 +563,7 @@ def test_static_hfftn(self): @place(DEVICES) -@parameterize( +@parameterized_class( (TEST_CASE_NAME, 'x', 's', 'axis', 'norm'), [ ( @@ -635,7 +628,7 @@ def test_static_hfft2(self): @place(DEVICES) -@parameterize( +@parameterized_class( (TEST_CASE_NAME, 'x', 's', 'axis', 'norm'), [ ( @@ -693,7 +686,7 @@ def test_static_irfft2(self): @place(DEVICES) -@parameterize( +@parameterized_class( (TEST_CASE_NAME, 'x', 'n', 'axis', 'norm', 'expect_exception'), [ ( @@ -787,7 +780,7 @@ def test_static_hfft(self): @place(DEVICES) -@parameterize( +@parameterized_class( (TEST_CASE_NAME, 'x', 'n', 'axis', 'norm', 'expect_exception'), [ ( @@ -882,7 +875,7 @@ def test_static_irfft(self): @place(DEVICES) -@parameterize( +@parameterized_class( (TEST_CASE_NAME, 'x', 'n', 'axis', 'norm', 'expect_exception'), [ ( @@ -985,7 +978,7 @@ def test_static_hfft2(self): @place(DEVICES) -@parameterize( +@parameterized_class( (TEST_CASE_NAME, 'x', 'n', 'axis', 'norm', 'expect_exception'), [ ( @@ -1088,7 +1081,7 @@ def test_static_irfft2(self): @place(DEVICES) -@parameterize( +@parameterized_class( (TEST_CASE_NAME, 'x', 'n', 'axis', 'norm', 'expect_exception'), [ ( @@ -1191,7 +1184,7 @@ def test_static_hfftn(self): @place(DEVICES) -@parameterize( +@parameterized_class( (TEST_CASE_NAME, 'x', 'n', 'axis', 'norm', 'expect_exception'), [ ( @@ -1287,7 +1280,7 @@ def test_static_irfftn(self): @place(DEVICES) -@parameterize( +@parameterized_class( (TEST_CASE_NAME, 'x', 'n', 'axis', 'norm'), [ ('test_x_float64', rand_x(5, np.float64), None, -1, 'backward'), @@ -1324,7 +1317,7 @@ def test_static_rfft(self): @place(DEVICES) -@parameterize( +@parameterized_class( (TEST_CASE_NAME, 'x', 'n', 'axis', 'norm', 'expect_exception'), [ ('test_n_nagative', rand_x(2), -1, -1, 'backward', ValueError), @@ -1363,7 +1356,7 @@ def test_rfft(self): @place(DEVICES) -@parameterize( +@parameterized_class( (TEST_CASE_NAME, 'x', 'n', 'axis', 'norm'), [ ('test_x_float64', rand_x(5), None, (0, 1), 'backward'), @@ -1401,7 +1394,7 @@ def test_static_rfft2(self): @place(DEVICES) -@parameterize( +@parameterized_class( (TEST_CASE_NAME, 'x', 'n', 'axis', 'norm', 'expect_exception'), [ ( @@ -1458,7 +1451,7 @@ def test_static_rfft(self): @place(DEVICES) -@parameterize( +@parameterized_class( (TEST_CASE_NAME, 'x', 'n', 'axis', 'norm'), [ ('test_x_float64', rand_x(5, np.float64), None, None, 'backward'), @@ -1495,7 +1488,7 @@ def test_static_rfft(self): @place(DEVICES) -@parameterize( +@parameterized_class( (TEST_CASE_NAME, 'x', 'n', 'axis', 'norm', 'expect_exception'), [ ( @@ -1542,7 +1535,7 @@ def test_static_rfftn(self): @place(DEVICES) -@parameterize( +@parameterized_class( (TEST_CASE_NAME, 'x', 'n', 'axis', 'norm'), [ ('test_x_float64', rand_x(5, np.float64), None, -1, 'backward'), @@ -1579,7 +1572,7 @@ def test_static_ihfft(self): @place(DEVICES) -@parameterize( +@parameterized_class( (TEST_CASE_NAME, 'x', 'n', 'axis', 'norm', 'expect_exception'), [ ('test_n_nagative', rand_x(2), -1, -1, 'backward', ValueError), @@ -1618,7 +1611,7 @@ def test_static_ihfft(self): @place(DEVICES) -@parameterize( +@parameterized_class( (TEST_CASE_NAME, 'x', 'n', 'axis', 'norm'), [ ('test_x_float64', rand_x(5), None, (0, 1), 'backward'), @@ -1656,7 +1649,7 @@ def test_static_ihfft2(self): @place(DEVICES) -@parameterize( +@parameterized_class( (TEST_CASE_NAME, 'x', 'n', 'axis', 'norm', 'expect_exception'), [ ( @@ -1721,7 +1714,7 @@ def test_static_ihfft2(self): @place(DEVICES) -@parameterize( +@parameterized_class( (TEST_CASE_NAME, 'x', 'n', 'axis', 'norm'), [ ('test_x_float64', rand_x(5, np.float64), None, None, 'backward'), @@ -1758,7 +1751,7 @@ def test_static_ihfftn(self): @place(DEVICES) -@parameterize( +@parameterized_class( (TEST_CASE_NAME, 'x', 'n', 'axis', 'norm', 'expect_exception'), [ ( @@ -1797,7 +1790,7 @@ def test_static_ihfftn(self): @place(DEVICES) -@parameterize( +@parameterized_class( (TEST_CASE_NAME, 'x', 'axes', 'dtype'), [ ('test_1d', np.random.randn(10), (0,), 'float64'), @@ -1829,7 +1822,7 @@ def test_fftshift(self): @place(DEVICES) -@parameterize( +@parameterized_class( (TEST_CASE_NAME, 'x', 'axes'), [ ('test_1d', np.random.randn(10), (0,), 'float64'), diff --git a/python/paddle/fluid/tests/unittests/fft/test_spectral_op.py b/python/paddle/fluid/tests/unittests/fft/test_spectral_op.py index 783baf6527ec9..eab29210d31af 100644 --- a/python/paddle/fluid/tests/unittests/fft/test_spectral_op.py +++ b/python/paddle/fluid/tests/unittests/fft/test_spectral_op.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import re import sys import numpy as np @@ -30,51 +29,11 @@ sys.path.append("../") from op_test import OpTest +from parameterized import parameterized_class paddle.enable_static() -TEST_CASE_NAME = 'test_case' - - -def parameterize(attrs, input_values=None): - - if isinstance(attrs, str): - attrs = [attrs] - input_dicts = ( - attrs - if input_values is None - else [dict(zip(attrs, vals)) for vals in input_values] - ) - - def decorator(base_class): - test_class_module = sys.modules[base_class.__module__].__dict__ - for idx, input_dict in enumerate(input_dicts): - test_class_dict = dict(base_class.__dict__) - test_class_dict.update(input_dict) - - name = class_name(base_class, idx, input_dict) - - test_class_module[name] = type(name, (base_class,), test_class_dict) - - for method_name in list(base_class.__dict__): - if method_name.startswith("test"): - delattr(base_class, method_name) - return base_class - - return decorator - - -def to_safe_name(s): - return str(re.sub("[^a-zA-Z0-9_]+", "_", s)) - - -def class_name(cls, num, params_dict): - suffix = to_safe_name( - next((v for v in params_dict.values() if isinstance(v, str)), "") - ) - if TEST_CASE_NAME in params_dict: - suffix = to_safe_name(params_dict["test_case"]) - return "{}_{}{}".format(cls.__name__, num, suffix and "_" + suffix) +TEST_CASE_NAME = 'name' def fft_c2c_python_api(x, axes, norm, forward): @@ -89,7 +48,7 @@ def fft_c2r_python_api(x, axes, norm, forward, last_dim_size=0): return _C_ops.fft_c2r(x, axes, norm, forward, last_dim_size) -@parameterize( +@parameterized_class( (TEST_CASE_NAME, 'x', 'axes', 'norm', 'forward'), [ ( @@ -175,7 +134,7 @@ def test_check_grad(self): ) -@parameterize( +@parameterized_class( (TEST_CASE_NAME, 'x', 'axes', 'norm', 'forward', 'last_dim_size'), [ ( @@ -272,7 +231,7 @@ def test_check_grad(self): ) -@parameterize( +@parameterized_class( (TEST_CASE_NAME, 'x', 'axes', 'norm', 'forward', 'onesided'), [ ( diff --git a/python/paddle/fluid/tests/unittests/test_signal.py b/python/paddle/fluid/tests/unittests/test_signal.py index 19a0dd433ce2c..bac292b493f0e 100644 --- a/python/paddle/fluid/tests/unittests/test_signal.py +++ b/python/paddle/fluid/tests/unittests/test_signal.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import re import sys import unittest @@ -20,6 +19,7 @@ import scipy.signal from numpy import fft from numpy.lib.stride_tricks import as_strided +from parameterized import parameterized_class import paddle @@ -28,7 +28,7 @@ DEVICES = [paddle.CPUPlace()] if paddle.is_compiled_with_cuda(): DEVICES.append(paddle.CUDAPlace(0)) -TEST_CASE_NAME = 'test_case' +TEST_CASE_NAME = 'name' # Constrain STFT block sizes to 256 KB MAX_MEM_BLOCK = 2**8 * 2**10 @@ -631,49 +631,8 @@ def rand_x( return np.random.randn(*shape).astype(dtype) -def parameterize(attrs, input_values=None): - - if isinstance(attrs, str): - attrs = [attrs] - input_dicts = ( - attrs - if input_values is None - else [dict(zip(attrs, vals)) for vals in input_values] - ) - - def decorator(base_class): - test_class_module = sys.modules[base_class.__module__].__dict__ - for idx, input_dict in enumerate(input_dicts): - test_class_dict = dict(base_class.__dict__) - test_class_dict.update(input_dict) - - name = class_name(base_class, idx, input_dict) - - test_class_module[name] = type(name, (base_class,), test_class_dict) - - for method_name in list(base_class.__dict__): - if method_name.startswith("test"): - delattr(base_class, method_name) - return base_class - - return decorator - - -def class_name(cls, num, params_dict): - suffix = to_safe_name( - next((v for v in params_dict.values() if isinstance(v, str)), "") - ) - if TEST_CASE_NAME in params_dict: - suffix = to_safe_name(params_dict["test_case"]) - return "{}_{}{}".format(cls.__name__, num, suffix and "_" + suffix) - - -def to_safe_name(s): - return str(re.sub("[^a-zA-Z0-9_]+", "_", s)) - - @place(DEVICES) -@parameterize( +@parameterized_class( (TEST_CASE_NAME, 'x', 'frame_length', 'hop_length', 'axis'), [ ('test_1d_input1', rand_x(1, np.float64, shape=[150]), 50, 15, 0), @@ -701,7 +660,7 @@ def test_frame(self): @place(DEVICES) -@parameterize( +@parameterized_class( (TEST_CASE_NAME, 'x', 'frame_length', 'hop_length', 'axis'), [ ('test_1d_input1', rand_x(1, np.float64, shape=[150]), 50, 15, 0), @@ -740,7 +699,7 @@ def test_frame_static(self): @place(DEVICES) -@parameterize( +@parameterized_class( (TEST_CASE_NAME, 'x', 'frame_length', 'hop_length', 'axis', 'expect_exception'), [ ('test_axis', rand_x(1, np.float64, shape=[150]), 50, 15, 2, ValueError), @@ -760,7 +719,7 @@ def test_frame(self): @place(DEVICES) -@parameterize( +@parameterized_class( (TEST_CASE_NAME, 'x', 'hop_length', 'axis'), [ ('test_2d_input1', rand_x(2, np.float64, shape=[3, 50]), 4, 0), @@ -783,7 +742,7 @@ def test_overlap_add(self): @place(DEVICES) -@parameterize( +@parameterized_class( (TEST_CASE_NAME, 'x', 'hop_length', 'axis'), [ ('test_2d_input1', rand_x(2, np.float64, shape=[3, 50]), 4, 0), @@ -818,7 +777,7 @@ def test_overlap_add_static(self): @place(DEVICES) -@parameterize( +@parameterized_class( (TEST_CASE_NAME, 'x', 'hop_length', 'axis', 'expect_exception'), [ ('test_axis', rand_x(2, np.float64, shape=[3, 50]), 4, 2, ValueError), @@ -863,7 +822,7 @@ def test_overlap_add(self): @place(DEVICES) -@parameterize( +@parameterized_class( (TEST_CASE_NAME, 'x', 'n_fft', 'hop_length', 'win_length', 'window', 'center', 'pad_mode', 'normalized', 'onesided'), [ ('test_1d_input', rand_x(1, np.float64, shape=[160000]), 512, @@ -915,7 +874,7 @@ def test_stft(self): @place(DEVICES) -@parameterize( +@parameterized_class( (TEST_CASE_NAME, 'x', 'n_fft', 'hop_length', 'win_length', 'window', 'center', 'pad_mode', 'normalized', 'onesided', 'expect_exception'), [ ('test_dims', rand_x(1, np.float64, shape=[1, 2, 3]), 512, @@ -957,7 +916,7 @@ def test_stft(self): @place(DEVICES) -@parameterize( +@parameterized_class( (TEST_CASE_NAME, 'x', 'n_fft', 'hop_length', 'win_length', 'window', 'center', 'normalized', 'onesided', 'length', 'return_complex'), [ ('test_2d_input', rand_x(2, np.float64, shape=[257, 471], complex=True), 512, @@ -1011,7 +970,7 @@ def test_istft(self): @place(DEVICES) -@parameterize( +@parameterized_class( (TEST_CASE_NAME, 'x', 'n_fft', 'hop_length', 'win_length', 'window', 'center', 'normalized', 'onesided', 'length', 'return_complex', 'expect_exception'), [ ('test_dims', rand_x(4, np.float64, shape=[1, 2, 3, 4], complex=True), 512,