From 3e7792ead9a1e85c82befc34937581958abc2412 Mon Sep 17 00:00:00 2001 From: Feiyu Chan Date: Fri, 10 Sep 2021 00:08:37 +0800 Subject: [PATCH] fix argument checking (#28) * fix argument checking --- python/paddle/tensor/fft.py | 136 +++++++++++++++++++++++++++++++++--- 1 file changed, 127 insertions(+), 9 deletions(-) diff --git a/python/paddle/tensor/fft.py b/python/paddle/tensor/fft.py index 4a541c7ff6eeb0..1b3691d7fd6476 100644 --- a/python/paddle/tensor/fft.py +++ b/python/paddle/tensor/fft.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import Sequence import numpy as np import paddle from .attribute import is_complex, is_floating_point, is_interger, _real_to_complex_dtype, _complex_to_real_dtype @@ -28,28 +29,57 @@ def _check_normalization(norm): format(norm)) +def _check_fft_n(n): + if not isinstance(n, int): + raise ValueError( + "Invalid FFT argument n({}), it shoule be an integer.".format(n)) + if n <= 0: + raise ValueError( + "Invalid FFT argument n({}), it should be positive.".format(n)) + + def _check_fft_shape(x, s): ndim = x.ndim + if not isinstance(s, Sequence): + raise ValueError( + "Invaid FFT argument s({}), it should be a sequence of integers.") + if len(s) > ndim: raise ValueError( - "Length of fft sizes should not be larger than the rank of input. " - "Received, len of s: {}, rank of x: {}".format(len(s), ndim)) + "Length of FFT argument s should not be larger than the rank of input. " + "Received s: {}, rank of x: {}".format(s, ndim)) for size in s: if not isinstance(size, int) or size <= 0: raise ValueError("FFT sizes {} contains invalid value ({})".format( s, size)) +def _check_fft_axis(x, axis): + ndim = x.ndim + if not isinstance(axis, int): + raise ValueError( + "Invalid FFT axis ({}), it shoule be an integer.".format(axis)) + if axis < -ndim or axis >= ndim: + raise ValueError( + "Invalid FFT axis ({}), it should be in range [-{}, {})".format( + axis, ndim, ndim)) + + def _check_fft_axes(x, axes): ndim = x.ndim + if not isinstance(axes, Sequence): + raise ValueError( + "Invalid FFT axes ({}), it should be a sequence of integers.". + format(axes)) if len(axes) > ndim: raise ValueError( "Length of fft axes should not be larger than the rank of input. " "Received, len of axes: {}, rank of x: {}".format(len(axes), ndim)) for axis in axes: if not isinstance(axis, int) or axis < -ndim or axis >= ndim: - raise ValueError("FFT axes {} contains invalid value ({})".format( - axes, axis)) + raise ValueError( + "FFT axes {} contains invalid value ({}), it should be in range [-{}, {})". + format(axes, axis, ndim, ndim)) def _resize_fft_input(x, s, axes): @@ -89,6 +119,12 @@ def _normalize_axes(x, axes): return [item if item >= 0 else (item + ndim) for item in axes] +def _check_at_least_ndim(x, rank): + if x.ndim < rank: + raise ValueError("The rank of the input ({}) should >= {}".format( + x.ndim, rank)) + + # public APIs 1d def fft(x, n=None, axis=-1, norm="backward", name=None): if not is_complex(x): @@ -157,26 +193,92 @@ def ihfftn(x, s=None, axes=None, norm="backward", name=None): ## public APIs 2d def fft2(x, s=None, axes=(-2, -1), norm="backward", name=None): + _check_at_least_ndim(x, 2) + if s is not None: + if not isinstance(s, Sequence) or len(s) != 2: + raise ValueError( + "Invalid FFT argument s ({}), it should be a sequence of 2 integers.". + format(s)) + if axes is not None: + if not isinstance(axes, Sequence) or len(s) != 2: + raise ValueError( + "Invalid FFT argument axes ({}), it should be a sequence of 2 integers.". + format(axes)) return fftn(x, s, axes, norm, name) def ifft2(x, s=None, axes=(-2, -1), norm="backward", name=None): + _check_at_least_ndim(x, 2) + if s is not None: + if not isinstance(s, Sequence) or len(s) != 2: + raise ValueError( + "Invalid FFT argument s ({}), it should be a sequence of 2 integers.". + format(s)) + if axes is not None: + if not isinstance(axes, Sequence) or len(s) != 2: + raise ValueError( + "Invalid FFT argument axes ({}), it should be a sequence of 2 integers.". + format(axes)) return ifftn(x, s, axes, norm, name) def rfft2(x, s=None, axes=(-2, -1), norm="backward", name=None): + _check_at_least_ndim(x, 2) + if s is not None: + if not isinstance(s, Sequence) or len(s) != 2: + raise ValueError( + "Invalid FFT argument s ({}), it should be a sequence of 2 integers.". + format(s)) + if axes is not None: + if not isinstance(axes, Sequence) or len(s) != 2: + raise ValueError( + "Invalid FFT argument axes ({}), it should be a sequence of 2 integers.". + format(axes)) return rfftn(x, s, axes, norm, name) def irfft2(x, s=None, axes=(-2, -1), norm="backward", name=None): + _check_at_least_ndim(x, 2) + if s is not None: + if not isinstance(s, Sequence) or len(s) != 2: + raise ValueError( + "Invalid FFT argument s ({}), it should be a sequence of 2 integers.". + format(s)) + if axes is not None: + if not isinstance(axes, Sequence) or len(s) != 2: + raise ValueError( + "Invalid FFT argument axes ({}), it should be a sequence of 2 integers.". + format(axes)) return irfftn(x, s, axes, norm, name) def hfft2(x, s=None, axes=(-2, -1), norm="backward", name=None): + _check_at_least_ndim(x, 2) + if s is not None: + if not isinstance(s, Sequence) or len(s) != 2: + raise ValueError( + "Invalid FFT argument s ({}), it should be a sequence of 2 integers.". + format(s)) + if axes is not None: + if not isinstance(axes, Sequence) or len(s) != 2: + raise ValueError( + "Invalid FFT argument axes ({}), it should be a sequence of 2 integers.". + format(axes)) return hfftn(x, s, axes, norm, name) def ihfft2(x, s=None, axes=(-2, -1), norm="backward", name=None): + _check_at_least_ndim(x, 2) + if s is not None: + if not isinstance(s, Sequence) or len(s) != 2: + raise ValueError( + "Invalid FFT argument s ({}), it should be a sequence of 2 integers.". + format(s)) + if axes is not None: + if not isinstance(axes, Sequence) or len(s) != 2: + raise ValueError( + "Invalid FFT argument axes ({}), it should be a sequence of 2 integers.". + format(axes)) return ihfftn(x, s, axes, norm, name) @@ -232,12 +334,14 @@ def fft_c2c(x, n, axis, norm, forward, name): if is_interger(x): x = paddle.cast(x, _real_to_complex_dtype(paddle.get_default_dtype())) _check_normalization(norm) + axis = axis or -1 + _check_fft_axis(x, axis) axes = [axis] axes = _normalize_axes(x, axes) if n is not None: + _check_fft_n(n) s = [n] - _check_fft_shape(x, s) x = _resize_fft_input(x, s, axes) op_type = 'fft_c2c' @@ -262,11 +366,12 @@ def fft_r2c(x, n, axis, norm, forward, onesided, name): x = paddle.cast(x, paddle.get_default_dtype()) _check_normalization(norm) axis = axis or -1 + _check_fft_axis(x, axis) axes = [axis] axes = _normalize_axes(x, axes) if n is not None: + _check_fft_n(n) s = [n] - _check_fft_shape(x, s) x = _resize_fft_input(x, s, axes) op_type = 'fft_r2c' check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'], op_type) @@ -298,11 +403,12 @@ def fft_c2r(x, n, axis, norm, forward, name): x = paddle.cast(x, _real_to_complex_dtype(paddle.get_default_dtype())) _check_normalization(norm) axis = axis or -1 + _check_fft_axis(x, axis) axes = [axis] axes = _normalize_axes(x, axes) if n is not None: + _check_fft_n(n) s = [n // 2 + 1] - _check_fft_shape(x, s) x = _resize_fft_input(x, s, axes) op_type = 'fft_c2r' check_variable_and_dtype(x, 'x', ['complex64', 'complex128'], op_type) @@ -349,6 +455,10 @@ def fftn_c2c(x, s, axes, norm, forward, name): axes_argsoft = np.argsort(axes).tolist() axes = [axes[i] for i in axes_argsoft] if s is not None: + if len(s) != len(axes): + raise ValueError( + "Length of s ({}) and length of axes ({}) does not match.". + format(len(s), len(axes))) s = [s[i] for i in axes_argsoft] if s is not None: @@ -391,7 +501,11 @@ def fftn_r2c(x, s, axes, norm, forward, onesided, name): axes_argsoft = np.argsort(axes[:-1]).tolist() axes = [axes[i] for i in axes_argsoft] + [axes[-1]] if s is not None: - s = [s[i] for i in axes_argsoft] + s[-1] + if len(s) != len(axes): + raise ValueError( + "Length of s ({}) and length of axes ({}) does not match.". + format(len(s), len(axes))) + s = [s[i] for i in axes_argsoft] + [s[-1]] if s is not None: x = _resize_fft_input(x, s, axes) @@ -442,7 +556,11 @@ def fftn_c2r(x, s, axes, norm, forward, name): axes_argsoft = np.argsort(axes[:-1]).tolist() axes = [axes[i] for i in axes_argsoft] + [axes[-1]] if s is not None: - s = [s[i] for i in axes_argsoft] + s[-1] + if len(s) != len(axes): + raise ValueError( + "Length of s ({}) and length of axes ({}) does not match.". + format(len(s), len(axes))) + s = [s[i] for i in axes_argsoft] + [s[-1]] if s is not None: fft_input_shape = list(s)