|
74 | 74 | import warnings
|
75 | 75 |
|
76 | 76 | import numpy as np
|
77 |
| -from numpy import array, conjugate, prod, sqrt, take |
78 | 77 |
|
79 |
| -from . import _float_utils |
80 | 78 | from . import _pydfti as mkl_fft # pylint: disable=no-name-in-module
|
| 79 | +from ._fft_utils import _check_norm, _compute_fwd_scale |
| 80 | +from ._float_utils import __downcast_float128_array |
81 | 81 |
|
82 | 82 |
|
83 |
| -def _compute_fwd_scale(norm, n, shape): |
84 |
| - _check_norm(norm) |
85 |
| - if norm in (None, "backward"): |
86 |
| - return 1.0 |
87 |
| - |
88 |
| - ss = n if n is not None else shape |
89 |
| - nn = prod(ss) |
90 |
| - fsc = 1 / nn if nn != 0 else 1 |
91 |
| - if norm == "forward": |
92 |
| - return fsc |
93 |
| - else: # norm == "ortho" |
94 |
| - return sqrt(fsc) |
95 |
| - |
96 |
| - |
97 |
| -def _check_norm(norm): |
98 |
| - if norm not in (None, "ortho", "forward", "backward"): |
99 |
| - raise ValueError( |
100 |
| - f"Invalid norm value {norm} should be None, 'ortho', 'forward', " |
101 |
| - "or 'backward'." |
| 83 | +# copied with modifications from: |
| 84 | +# https://github.com/numpy/numpy/blob/main/numpy/fft/_pocketfft.py |
| 85 | +def _cook_nd_args(a, s=None, axes=None, invreal=False): |
| 86 | + if s is None: |
| 87 | + shapeless = True |
| 88 | + if axes is None: |
| 89 | + s = list(a.shape) |
| 90 | + else: |
| 91 | + s = np.take(a.shape, axes) |
| 92 | + else: |
| 93 | + shapeless = False |
| 94 | + s = list(s) |
| 95 | + if axes is None: |
| 96 | + if not shapeless and np.__version__ >= "2.0": |
| 97 | + msg = ( |
| 98 | + "`axes` should not be `None` if `s` is not `None` " |
| 99 | + "(Deprecated in NumPy 2.0). In a future version of NumPy, " |
| 100 | + "this will raise an error and `s[i]` will correspond to " |
| 101 | + "the size along the transformed axis specified by " |
| 102 | + "`axes[i]`. To retain current behaviour, pass a sequence " |
| 103 | + "[0, ..., k-1] to `axes` for an array of dimension k." |
| 104 | + ) |
| 105 | + warnings.warn(msg, DeprecationWarning, stacklevel=3) |
| 106 | + axes = list(range(-len(s), 0)) |
| 107 | + if len(s) != len(axes): |
| 108 | + raise ValueError("Shape and axes have different lengths.") |
| 109 | + if invreal and shapeless: |
| 110 | + s[-1] = (a.shape[axes[-1]] - 1) * 2 |
| 111 | + if None in s and np.__version__ >= "2.0": |
| 112 | + msg = ( |
| 113 | + "Passing an array containing `None` values to `s` is " |
| 114 | + "deprecated in NumPy 2.0 and will raise an error in " |
| 115 | + "a future version of NumPy. To use the default behaviour " |
| 116 | + "of the corresponding 1-D transform, pass the value matching " |
| 117 | + "the default for its `n` parameter. To use the default " |
| 118 | + "behaviour for every axis, the `s` argument can be omitted." |
102 | 119 | )
|
| 120 | + warnings.warn(msg, DeprecationWarning, stacklevel=3) |
| 121 | + # use the whole input array along axis `i` if `s[i] == -1 or None` |
| 122 | + s = [a.shape[_a] if _s in [-1, None] else _s for _s, _a in zip(s, axes)] |
| 123 | + |
| 124 | + return s, axes |
103 | 125 |
|
104 | 126 |
|
105 | 127 | def _swap_direction(norm):
|
@@ -218,7 +240,7 @@ def fft(a, n=None, axis=-1, norm=None):
|
218 | 240 |
|
219 | 241 | """
|
220 | 242 |
|
221 |
| - x = _float_utils.__downcast_float128_array(a) |
| 243 | + x = __downcast_float128_array(a) |
222 | 244 | fsc = _compute_fwd_scale(norm, n, x.shape[axis])
|
223 | 245 |
|
224 | 246 | return trycall(mkl_fft.fft, (x,), {"n": n, "axis": axis, "fwd_scale": fsc})
|
@@ -311,7 +333,7 @@ def ifft(a, n=None, axis=-1, norm=None):
|
311 | 333 |
|
312 | 334 | """
|
313 | 335 |
|
314 |
| - x = _float_utils.__downcast_float128_array(a) |
| 336 | + x = __downcast_float128_array(a) |
315 | 337 | fsc = _compute_fwd_scale(norm, n, x.shape[axis])
|
316 | 338 |
|
317 | 339 | return trycall(mkl_fft.ifft, (x,), {"n": n, "axis": axis, "fwd_scale": fsc})
|
@@ -402,7 +424,7 @@ def rfft(a, n=None, axis=-1, norm=None):
|
402 | 424 |
|
403 | 425 | """
|
404 | 426 |
|
405 |
| - x = _float_utils.__downcast_float128_array(a) |
| 427 | + x = __downcast_float128_array(a) |
406 | 428 | fsc = _compute_fwd_scale(norm, n, x.shape[axis])
|
407 | 429 |
|
408 | 430 | return trycall(mkl_fft.rfft, (x,), {"n": n, "axis": axis, "fwd_scale": fsc})
|
@@ -495,7 +517,7 @@ def irfft(a, n=None, axis=-1, norm=None):
|
495 | 517 |
|
496 | 518 | """
|
497 | 519 |
|
498 |
| - x = _float_utils.__downcast_float128_array(a) |
| 520 | + x = __downcast_float128_array(a) |
499 | 521 | fsc = _compute_fwd_scale(norm, n, 2 * (x.shape[axis] - 1))
|
500 | 522 |
|
501 | 523 | return trycall(
|
@@ -581,9 +603,9 @@ def hfft(a, n=None, axis=-1, norm=None):
|
581 | 603 | """
|
582 | 604 |
|
583 | 605 | norm = _swap_direction(norm)
|
584 |
| - x = _float_utils.__downcast_float128_array(a) |
585 |
| - x = array(x, copy=True, dtype=complex) |
586 |
| - conjugate(x, out=x) |
| 606 | + x = __downcast_float128_array(a) |
| 607 | + x = np.array(x, copy=True, dtype=complex) |
| 608 | + np.conjugate(x, out=x) |
587 | 609 | fsc = _compute_fwd_scale(norm, n, 2 * (x.shape[axis] - 1))
|
588 | 610 |
|
589 | 611 | return trycall(
|
@@ -651,61 +673,18 @@ def ihfft(a, n=None, axis=-1, norm=None):
|
651 | 673 |
|
652 | 674 | # The copy may be required for multithreading.
|
653 | 675 | norm = _swap_direction(norm)
|
654 |
| - x = _float_utils.__downcast_float128_array(a) |
655 |
| - x = array(x, copy=True, dtype=float) |
| 676 | + x = __downcast_float128_array(a) |
| 677 | + x = np.array(x, copy=True, dtype=float) |
656 | 678 | fsc = _compute_fwd_scale(norm, n, x.shape[axis])
|
657 | 679 |
|
658 | 680 | output = trycall(
|
659 | 681 | mkl_fft.rfft, (x,), {"n": n, "axis": axis, "fwd_scale": fsc}
|
660 | 682 | )
|
661 | 683 |
|
662 |
| - conjugate(output, out=output) |
| 684 | + np.conjugate(output, out=output) |
663 | 685 | return output
|
664 | 686 |
|
665 | 687 |
|
666 |
| -# copied from: https://github.com/numpy/numpy/blob/main/numpy/fft/_pocketfft.py |
667 |
| -def _cook_nd_args(a, s=None, axes=None, invreal=False): |
668 |
| - if s is None: |
669 |
| - shapeless = True |
670 |
| - if axes is None: |
671 |
| - s = list(a.shape) |
672 |
| - else: |
673 |
| - s = take(a.shape, axes) |
674 |
| - else: |
675 |
| - shapeless = False |
676 |
| - s = list(s) |
677 |
| - if axes is None: |
678 |
| - if not shapeless and np.__version__ >= "2.0": |
679 |
| - msg = ( |
680 |
| - "`axes` should not be `None` if `s` is not `None` " |
681 |
| - "(Deprecated in NumPy 2.0). In a future version of NumPy, " |
682 |
| - "this will raise an error and `s[i]` will correspond to " |
683 |
| - "the size along the transformed axis specified by " |
684 |
| - "`axes[i]`. To retain current behaviour, pass a sequence " |
685 |
| - "[0, ..., k-1] to `axes` for an array of dimension k." |
686 |
| - ) |
687 |
| - warnings.warn(msg, DeprecationWarning, stacklevel=3) |
688 |
| - axes = list(range(-len(s), 0)) |
689 |
| - if len(s) != len(axes): |
690 |
| - raise ValueError("Shape and axes have different lengths.") |
691 |
| - if invreal and shapeless: |
692 |
| - s[-1] = (a.shape[axes[-1]] - 1) * 2 |
693 |
| - if None in s and np.__version__ >= "2.0": |
694 |
| - msg = ( |
695 |
| - "Passing an array containing `None` values to `s` is " |
696 |
| - "deprecated in NumPy 2.0 and will raise an error in " |
697 |
| - "a future version of NumPy. To use the default behaviour " |
698 |
| - "of the corresponding 1-D transform, pass the value matching " |
699 |
| - "the default for its `n` parameter. To use the default " |
700 |
| - "behaviour for every axis, the `s` argument can be omitted." |
701 |
| - ) |
702 |
| - warnings.warn(msg, DeprecationWarning, stacklevel=3) |
703 |
| - # use the whole input array along axis `i` if `s[i] == -1 or None` |
704 |
| - s = [a.shape[_a] if _s in [-1, None] else _s for _s, _a in zip(s, axes)] |
705 |
| - |
706 |
| - return s, axes |
707 |
| - |
708 |
| - |
709 | 688 | def fftn(a, s=None, axes=None, norm=None):
|
710 | 689 | """
|
711 | 690 | Compute the N-dimensional discrete Fourier Transform.
|
@@ -806,7 +785,7 @@ def fftn(a, s=None, axes=None, norm=None):
|
806 | 785 |
|
807 | 786 | """
|
808 | 787 |
|
809 |
| - x = _float_utils.__downcast_float128_array(a) |
| 788 | + x = __downcast_float128_array(a) |
810 | 789 | s, axes = _cook_nd_args(x, s, axes)
|
811 | 790 | fsc = _compute_fwd_scale(norm, s, x.shape)
|
812 | 791 |
|
@@ -913,7 +892,7 @@ def ifftn(a, s=None, axes=None, norm=None):
|
913 | 892 |
|
914 | 893 | """
|
915 | 894 |
|
916 |
| - x = _float_utils.__downcast_float128_array(a) |
| 895 | + x = __downcast_float128_array(a) |
917 | 896 | s, axes = _cook_nd_args(x, s, axes)
|
918 | 897 | fsc = _compute_fwd_scale(norm, s, x.shape)
|
919 | 898 |
|
@@ -1201,7 +1180,7 @@ def rfftn(a, s=None, axes=None, norm=None):
|
1201 | 1180 |
|
1202 | 1181 | """
|
1203 | 1182 |
|
1204 |
| - x = _float_utils.__downcast_float128_array(a) |
| 1183 | + x = __downcast_float128_array(a) |
1205 | 1184 | s, axes = _cook_nd_args(x, s, axes)
|
1206 | 1185 | fsc = _compute_fwd_scale(norm, s, x.shape)
|
1207 | 1186 |
|
@@ -1345,7 +1324,7 @@ def irfftn(a, s=None, axes=None, norm=None):
|
1345 | 1324 |
|
1346 | 1325 | """
|
1347 | 1326 |
|
1348 |
| - x = _float_utils.__downcast_float128_array(a) |
| 1327 | + x = __downcast_float128_array(a) |
1349 | 1328 | s, axes = _cook_nd_args(x, s, axes, invreal=True)
|
1350 | 1329 | fsc = _compute_fwd_scale(norm, s, x.shape)
|
1351 | 1330 |
|
|
0 commit comments