Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optimizing 6-step FFT algorithm #967

Open
NimaSarajpoor opened this issue May 13, 2024 · 15 comments
Open

Optimizing 6-step FFT algorithm #967

NimaSarajpoor opened this issue May 13, 2024 · 15 comments
Labels
enhancement New feature or request

Comments

@NimaSarajpoor
Copy link
Collaborator

This issue is about optimizing the 6-step FFT algorithm, as initially discussed in #938. We will try to improve the performance of each block of the algorithm. Each result MUST come with end-to-end code, and the code MUST include assertions to verify that the output is correct.

@NimaSarajpoor
Copy link
Collaborator Author

NimaSarajpoor commented May 13, 2024

RFFT (Real FFT)

We start with an RFFT algorithm that uses the 6-step FFT algorithm as provided here. We test it by comparing its output against numpy.fft.rfft, and we show its performance relative to scipy.fft.rfft.

code

# saved in file: rfft_optimized_v0

import math
import numpy as np
from numba import njit


@njit(fastmath=True)
def _fft_block(n, s, eo, x, y):
    """
    A recursive function that is used as part of fft algorithm

    Performs one radix-2 pass (butterfly plus twiddle multiplication) from
    `x` into `y`, then recurses with half the size, double the stride and the
    roles of `x` and `y` swapped. When the top-level call uses `eo=False`,
    the final result ends up in the array originally passed as `x`, and `y`
    is only used as scratch space. Only sizes that are powers of two are
    handled: a level with n < 2 or n == 3 does nothing.

    n : int
        Size of the sub-transform at this recursion level.
    s : int
        Stride (1 on the top-level call, doubled at every level).
    eo: bool
        Flipped on every recursive call. It selects the output buffer of the
        final (n == 2) pass so that the result lands in the top-level `x`.
    x : numpy.array 1D
        Source buffer of this pass.
    y : numpy.array 1D
        Destination buffer of this pass.
    """
    if n == 2:
        # Final pass: butterfly without a twiddle factor. `z` is the
        # top-level `x`. When z is x, both inputs are read before either
        # output is written.
        if eo:
            z = y
        else:
            z = x

        for i in range(s):
            j = i + s
            a = x[i]
            b = x[j]
            z[i] = a + b
            z[j] = a - b

    elif n >= 4:
        m = n // 2
        sm = s * m

        theta = math.pi / m
        for p in range(m):
            sp = s * p
            # twiddle factor exp(-1j * p * pi / m), recomputed for every p
            w = math.cos(p * theta) - 1j * math.sin(p * theta)
            for q in range(s):
                idx = sp + q
                a = x[idx]
                b = x[idx + sm]

                y[idx + sp] = a + b
                y[idx + sp + s] = (a - b) * w

        # swap source/destination for the next, half-size level
        _fft_block(m, 2*s, not eo, y, x)

    else:
        # n < 2 (or n == 3): nothing to do
        pass


@njit(fastmath=True)
def _eightstep_fft(x, y):
    """
    Overwrite `x` with its FFT using one radix-2 split and two 6-step FFTs.

    The sum of the two halves and their twiddled difference are each
    transformed with `_sixstep_fft`. The two resulting spectra are then
    interleaved back into `x`: even outputs come from the first half and odd
    outputs from the second. `y` (same length as `x`) is scratch space.
    """
    half = len(x) // 2
    angle_step = math.pi / half

    # Split into y[:half] (sums) and y[half:] (twiddled differences).
    for k in range(half):
        twiddle = math.cos(k * angle_step) - 1j * math.sin(k * angle_step)
        y[k] = x[k] + x[k + half]
        y[k + half] = (x[k] - x[k + half]) * twiddle

    # Transform each half of y, borrowing the matching half of x as helper.
    _sixstep_fft(y[:half], x[:half])
    _sixstep_fft(y[half:], x[half:])

    # Interleave the two half-spectra.
    for k in range(half):
        x[2 * k] = y[k]
        x[2 * k + 1] = y[k + half]


@njit(fastmath=True)
def _sixstep_fft(x, y):
    """
    Apply 6-step FFT algorithm and update x in-place.

    `x` is treated as an `n_sqrt x n_sqrt` row-major matrix, where
    `n_sqrt = int(sqrt(len(x)))`, so `len(x)` is expected to be a perfect
    square whose root is a power of two (the row length `_fft_block`
    handles). `y` has the same length and is used as scratch space.
    """
    n = len(x)
    n_sqrt = int(np.sqrt(n))

    # step 1: in-place matrix transpose (swap elements across the diagonal)
    for k in range(n_sqrt):
        kn = k * n_sqrt
        for p in range(k + 1, n_sqrt):
            i = k + p * n_sqrt
            j = p + kn
            x[i], x[j] = x[j], x[i]

    # step 2: length-n_sqrt FFT of every row (result stays in x)
    for start in range(0, n, n_sqrt):
        _fft_block(n_sqrt, 1, False, x[start:], y[start:])

    # step 3 and 4: transpose while multiplying elements (p, k) and (k, p)
    # by the twiddle factor exp(-2j * pi * p * k / n)
    for p in range(n_sqrt):
        theta0 = 2 * p * math.pi / n
        for k in range(p, n_sqrt):
            theta = k * theta0
            w = math.cos(theta) - 1j * math.sin(theta)

            if k == p:
                i = p * n_sqrt + p
                x[i] = x[i] * w
            else:
                i = p * n_sqrt + k
                j = k * n_sqrt + p
                x[j], x[i] = x[i] * w, x[j] * w

    # step 5: length-n_sqrt FFT of every row
    for start in range(0, n, n_sqrt):
        _fft_block(n_sqrt, 1, False, x[start:], y[start:])

    # step 6: matrix transpose
    for k in range(n_sqrt):
        kn = k * n_sqrt
        for p in range(k + 1, n_sqrt):
            i = k + p * n_sqrt
            j = p + kn
            x[i], x[j] = x[j], x[i]

    return


@njit(fastmath=True)
def _compute_fft(x, y):
    """
    Overwrite `x` with its FFT, using `y` as scratch space.

    Dispatches on log2(len(x)):
    - 1 uses a single butterfly;
    - an odd exponent uses the 8-step algorithm;
    - an even exponent uses the 6-step algorithm.
    """
    log2_len = int(np.log2(len(x)))

    if log2_len == 1:
        x[0], x[1] = x[0] + x[1], x[0] - x[1]
    elif log2_len % 2 == 1:
        _eightstep_fft(x, y)
    else:
        _sixstep_fft(x, y)


@njit(fastmath=True)
def _compute_rfft(T):
    """
    Compute the FFT of the real array `T` (same output as `np.fft.rfft`).

    The length-n real signal is packed into a length-n/2 complex signal
    `x[i] = T[2i] + 1j * T[2i + 1]`. It is transformed with `_compute_fft`,
    and the n/2 + 1 non-redundant bins of the real FFT are recovered from
    the result.

    Parameters
    ----------
    T : numpy.ndarray
        1D real array. Its length is assumed to be a power of two >= 2.

    Returns
    -------
    y : numpy.ndarray
        Complex array of length `len(T) // 2 + 1`.
    """
    n = len(T)
    half_n = n // 2

    # np.complex_ was removed in NumPy 2.0; np.complex128 is the same dtype.
    x = np.empty(half_n, dtype=np.complex128)
    for i in range(half_n):
        x[i] = T[2 * i] + 1j * T[2 * i + 1]

    y = np.empty(half_n + 1, dtype=np.complex128)
    # y[:half_n] is only scratch space here; the FFT result lands in x.
    _compute_fft(x, y[:half_n])

    # Bin n // 4 is written before bin 0: for n == 2 both indices are 0,
    # and the DC term must win.
    y[n // 4] = x[n // 4].conjugate()
    y[0] = x[0].real + x[0].imag
    y[half_n] = x[0].real - x[0].imag

    # Remaining bins: separate the spectra of even and odd samples.
    # w == 0.5j * exp(-1j * pi * k / half_n), updated incrementally.
    w = 0.5j
    factor = math.cos(math.pi / half_n) - 1j * math.sin(math.pi / half_n)
    for k in range(1, n // 4):
        w = w * factor
        v = (x[k] - x[half_n - k].conjugate()) * (0.5 + w)

        y[k] = x[k] - v
        y[half_n - k] = x[half_n - k] + v.conjugate()

    return y

Performance (+ Assertion)

import time

import numpy as np
import scipy

from matplotlib import pyplot as plt
from rfft_optimized_v0 import _compute_rfft

funcs_dict = {
    'rfft_v0': _compute_rfft,
    'numpy_rfft': np.fft.rfft,
    'scipy_rfft': scipy.fft.rfft
}

# Every implementation is checked against this reference before it is timed.
ref_func_key = 'numpy_rfft'
ref_func = funcs_dict[ref_func_key]

seed = 0
np.random.seed(seed)
data = np.random.rand(2 ** 23)

p_vals = np.arange(2, 24)  # input lengths 2**2 ... 2**23

n_iter = 2
performance = {}
for func_name, func_obj in funcs_dict.items():
    print(f'============ {func_name} ============')
    running_time = []
    for p in p_vals:
        print(f'{p}', end='-->')
        T = data[:2 ** p]

        # Correctness check. For numba functions, this call also triggers
        # JIT compilation, which keeps compile time out of the timings below.
        F_ref = ref_func(T)
        F_comp = func_obj(T)
        np.testing.assert_allclose(F_ref, F_comp, atol=1e-7)

        lst = []
        for _ in range(n_iter):
            # perf_counter is monotonic and high-resolution; time.time() is
            # neither, which matters for sub-millisecond calls.
            t1 = time.perf_counter()
            func_obj(T)
            t2 = time.perf_counter()
            lst.append(t2 - t1)

        running_time.append(lst)

    performance[func_name] = running_time
    print('Done!')


### plotting
plt.figure(figsize=(20, 5))
plt.title('Comparing performances of different functions for computing rfft')

colors = ['cyan', 'r', 'b', 'orange', 'k', 'm', 'g', 'yellow', 'brown']

# Running times are plotted relative to this function (medians over n_iter).
baseline_key = 'scipy_rfft'
baseline = np.array([np.median(lst) for lst in performance[baseline_key]])

for i, key in enumerate(performance):
    if key == baseline_key:
        continue

    arr = np.array([np.median(lst) for lst in performance[key]])
    r = arr / baseline

    plt.plot(r, color=colors[i], marker='o', label=key)

plt.axhline(y=1, color='k', linestyle='--', label=f'y=1 (baseline: {baseline_key})')
plt.xlabel('The log2 of input array length', fontsize=13)
plt.ylabel("Running time ratio w.r.t baseline", fontsize=13)
plt.xticks(ticks=np.arange(len(p_vals)), labels=p_vals, fontsize=13)
plt.yticks(fontsize=13)
plt.grid()
plt.legend(fontsize=13)
plt.show()

On my macOS machine, the result is:
image

This Colab notebook contains the code.


[Update]
I ran it again with 100 iterations, and got this:

image

@NimaSarajpoor
Copy link
Collaborator Author

NimaSarajpoor commented May 13, 2024

Matrix Transposition

In the RFFT code provided in the previous comment, the function _compute_fft does most of the computational work. It calls the 6-step or the 8-step algorithm, depending on whether the log2 of its input length is even or odd. We start with the blocks of the 6-step algorithm. Its first step transforms the 1D array x into x.reshape(n, n).T.ravel(), where $n=\sqrt{\mathrm{len}(x)}$.

Code

Currently, this is happening via the following code:

@njit(fastmath=True)
def _tranpose_v0(x, n_sqrt):
    """
    Transpose, in place, the `n_sqrt x n_sqrt` row-major matrix stored in `x`.

    Every element above the diagonal is swapped with its mirror below it.
    """
    for row in range(n_sqrt):
        for col in range(row + 1, n_sqrt):
            upper = row * n_sqrt + col
            lower = col * n_sqrt + row
            x[lower], x[upper] = x[upper], x[lower]

However, as discussed in #938, we can use a cache-oblivious algorithm (see #965) for matrix transposition, as follows (the same function is named `_tranpose_v1` in the benchmark below):

@njit(fastmath=True)
def _tranpose_v2(x, n_sqrt, x_transpose):
    """
    Write the transpose of the `n_sqrt x n_sqrt` row-major matrix held in
    `x` into `x_transpose`, one tile of at most 32 x 32 at a time.

    Both arrays are 1D of length `n_sqrt ** 2`; `x` is only read.
    """
    tile = min(32, n_sqrt)
    src = x.reshape(n_sqrt, n_sqrt)
    dst = x_transpose.reshape(n_sqrt, n_sqrt)
    for r0 in range(0, n_sqrt, tile):
        r1 = r0 + tile
        for c0 in range(0, n_sqrt, tile):
            c1 = c0 + tile
            dst[r0:r1, c0:c1] = src[c0:c1, r0:r1].T

Performance( + Assertion)

import time
import numpy as np

from matplotlib import pyplot as plt
from numba import njit

@njit(fastmath=True)
def _tranpose_v0(x, n_sqrt):
    """In-place transpose of the n_sqrt-by-n_sqrt row-major matrix held in `x`."""
    # The last row has nothing to its right of the diagonal, so it is skipped.
    for k in range(n_sqrt - 1):
        for p in range(k + 1, n_sqrt):
            a = p * n_sqrt + k  # element (p, k), below the diagonal
            b = k * n_sqrt + p  # element (k, p), above the diagonal
            x[a], x[b] = x[b], x[a]


@njit(fastmath=True)
def _tranpose_v1(x, n_sqrt, x_transpose):
    """
    Blocked out-of-place transpose.

    `x_transpose` receives the transpose of the `n_sqrt x n_sqrt` row-major
    matrix stored in `x`; `x` itself is not written.
    """
    bs = n_sqrt if n_sqrt < 32 else 32
    a = x.reshape(n_sqrt, n_sqrt)
    at = x_transpose.reshape(n_sqrt, n_sqrt)
    for i in range(0, n_sqrt, bs):
        for j in range(0, n_sqrt, bs):
            block = a[j:j + bs, i:i + bs]
            at[i:i + bs, j:j + bs] = np.transpose(block)

funcs_dict = {
    'tranpose_v0': _tranpose_v0,
    'tranpose_v1': _tranpose_v1,
}

# _tranpose_v0 transposes in place; _tranpose_v1 writes into a separate
# output array. Both are checked against the in-place reference.
ref_func_key = 'tranpose_v0'
ref_func = funcs_dict[ref_func_key]

seed = 0
np.random.seed(seed)
data = np.random.rand(2 ** 23)
data = data[::2] + 1j * data[1::2]  # 2**22 complex values

# Only even powers of two, so that every len(T) is a perfect square.
p_vals = np.arange(2, 22 + 1, 2)

n_iter = 500
performance = {}
for func_name, func_obj in funcs_dict.items():
    print(f'============ {func_name} ============')
    running_time = []
    for p in p_vals:
        print(f'{p}', end='-->')
        T = data[:2 ** p]
        n = len(T)
        n_sqrt = int(np.sqrt(n))

        # np.complex_ was removed in NumPy 2.0; np.complex128 is the same dtype.
        y = np.empty(n, dtype=np.complex128)

        ref = T.copy()
        ref_func(ref, n_sqrt)

        if func_name == 'tranpose_v0':
            comp = T.copy()
            func_obj(comp, n_sqrt)
        else:
            comp = y.copy()
            func_obj(T, n_sqrt, comp)

        np.testing.assert_allclose(ref, comp, atol=1e-7)

        lst = []
        for _ in range(n_iter):
            x = T.copy()

            # perf_counter: monotonic, high-resolution clock for benchmarks
            if func_name == 'tranpose_v0':
                t1 = time.perf_counter()
                func_obj(x, n_sqrt)
                t2 = time.perf_counter()
            else:
                t1 = time.perf_counter()
                func_obj(x, n_sqrt, y)
                t2 = time.perf_counter()
            lst.append(t2 - t1)

        running_time.append(lst)

    performance[func_name] = running_time
    print('Done!')


### plotting
plt.figure(figsize=(20, 5))
plt.title('Comparing performances of different functions\n for performing matrix transpose')

colors = ['cyan', 'r', 'b', 'orange', 'k', 'm', 'g', 'yellow', 'brown']

# Running times are plotted relative to this function (medians over n_iter).
baseline_key = 'tranpose_v0'
baseline = np.array([np.median(lst) for lst in performance[baseline_key]])

for i, key in enumerate(performance):
    if key == baseline_key:
        continue

    arr = np.array([np.median(lst) for lst in performance[key]])
    r = arr / baseline

    plt.plot(r, color=colors[i], marker='o', label=key)

plt.axhline(y=1, color='k', linestyle='--', label=f'y=1 (baseline: {baseline_key})')
plt.xlabel('The log2 of input array length', fontsize=13)
plt.ylabel("Running time ratio w.r.t baseline", fontsize=13)
plt.xticks(ticks=np.arange(len(p_vals)), labels=p_vals, fontsize=13)
plt.yticks(fontsize=13)
plt.grid()
plt.legend(fontsize=13)
plt.show()
image

This Colab notebook contains the code.

@NimaSarajpoor
Copy link
Collaborator Author

NimaSarajpoor commented May 13, 2024

fft_block

In steps 2 and 5 of the 6-step FFT algorithm, we use the following recursive function:

# fft_block_v0

@njit(fastmath=True)
def _fft_block(n, s, eo, x, y):
    """
    Recursive radix-2 FFT pass used by the 6-step algorithm.

    n : int   -- size of the sub-transform handled at this level
    s : int   -- stride (1 on the top-level call, doubled per level)
    eo: bool  -- flipped on each level; picks the buffer written by the
                 final pass so the result lands in the top-level `x`
    x : numpy.array 1D -- source buffer of this pass
    y : numpy.array 1D -- destination buffer of this pass
    """
    if n == 2:
        out = y if eo else x
        for lo in range(s):
            hi = lo + s
            u, v = x[lo], x[hi]
            out[lo] = u + v
            out[hi] = u - v
        return

    if n < 4:
        # n < 2 or n == 3: nothing to do
        return

    half = n // 2
    offset = s * half
    angle = math.pi / half
    for p in range(half):
        base = s * p
        twiddle = math.cos(p * angle) - 1j * math.sin(p * angle)
        for q in range(s):
            src = base + q
            u = x[src]
            v = x[src + offset]
            y[src + base] = u + v
            y[src + base + s] = (u - v) * twiddle

    # next level: half the size, double the stride, buffers swapped
    _fft_block(half, 2 * s, not eo, y, x)

We can think of the following ways to speed this up:

(1) Avoid calling the math.cos and math.sin functions in each iteration of the outer for-loop. We can compute the factor c once, outside the outer for-loop, and then update w incrementally, as follows:

# will be used in new version: fft_block_v1

theta = math.pi / m
# ratio between consecutive twiddle factors: c == exp(-1j * pi / m)
c = math.cos(theta) - 1j * math.sin(theta)
w = 1.0  # twiddle factor for p == 0
for p in range(m):
    sp = s * p
    for q in range(s):
        idx = sp + q
        a = x[idx]
        b = x[idx + sm]

        y[idx + sp] = a + b
        y[idx + sp + s] = (a - b) * w
    
    # advance incrementally instead of calling cos/sin: w == c**(p + 1)
    w = w * c

(2) In (1), we have c = math.cos(theta) - 1j * math.sin(theta), where theta = math.pi / m. So, we still call math.cos and math.sin once in each call of the recursive fft_block function. We can avoid this by adding c as a parameter to the function's signature. At the next recursion level m is halved, so the factor there is simply c * c.

# will be used in new version: fft_block_v2

@njit(fastmath=True)
def _fft_block(n, s, eo, x, y, c):
    # Sketch only: the butterfly pass from `x` into `y` goes here. It
    # advances the twiddle factor with `w = w * c` instead of calling
    # math.cos / math.sin.
    if n < 4:
        return  # the final (n == 2) pass is omitted from this sketch
    # The next level has half the size and double the stride. Its twiddle
    # ratio is c * c, so it is passed down rather than recomputed.
    _fft_block(n // 2, 2 * s, not eo, y, x, c * c)

An initial c needs to be computed and passed as argument to the _fft_block.

(3) The factor w in y[idx + sp + s] = (a - b) * w needs to be updated m times in each call of the recursive function. Now we try to reduce the number of times this parameter needs to be updated within the for-loop for p in range(m):.

For a given m, the factor w in y[idx + sp + s] = (a - b) * w has the following relationship with p (see # fft_block_v0 provided at the top of this comment):

$w_{p} = cos(\theta_p) - 1j * sin(\theta_p)$, where $\theta_p = \frac{p * \pi}{m}$

$w_{m - p} = cos(\theta_{m-p}) - 1j * sin(\theta_{m-p})$, where $\theta_{m-p} = \frac{(m - p) * \pi}{m}$

Note that $\theta_p + \theta_{m-p} = \pi$. Therefore:

$w_{m - p} = cos(\pi - \theta_{p}) - 1j * sin(\pi - \theta_{p})$
$w_{m - p} = - cos(\theta_p) - 1j * sin(\theta_p)$
$w_{m - p} = - w_{p}^{*}$

Therefore, in this version, in addition to (2), we can replace the following for-loop

w = 1.0  # w == c**p at the start of iteration p
for p in range(m):
    sp = s * p
    for q in range(s):
        idx = sp + q
        a = x[idx]
        b = x[idx + sm]

        y[idx + sp] = a + b
        y[idx + sp + s] = (a - b) * w
    
    # one complex multiplication per p: m updates of w per call
    w = w * c

with this code:

# will be used in new version: fft_block_v3
# Replaces the `for p in range(m)` loop of version (2). Because
# w_{m-p} == -conj(w_p), each update of w serves both group p and group
# m - p, roughly halving the number of updates of w. Relies on m being even.

        # p = 0: the twiddle factor is 1, so no multiplication is needed
        for i in range(s):
            j = i + sm
            y[i] = x[i] + x[j]
            y[i + s] = x[i] - x[j]

        w = 1.0
        for p in range(1, m // 2):
            # p  --> 1, 2, 3, ..., m//2 - 1
            w = w * c  # w == c**p
            sp = s * p
            for i in range(sp, sp + s):
                b = x[i + sm]
                k = i + sp
                y[k] = x[i] + b
                y[k + s] = (x[i] - b) * w

            # p = m - p --> m - 1, m - 2, m - 3, ..., m - m // 2 + 1
            # (x[i] - b) * (-conj(w)) == (b - x[i]) * conj(w)
            sp = sm - sp 
            w_conj = w.conjugate()
            for i in range(sp, sp + s):
                b = x[i + sm]
                k = i + sp
                y[k] = x[i] + b
                y[k + s] = (b - x[i]) * w_conj

        # p = m // 2: the group that is its own mirror
        w = w * c  # w == c**(m // 2)
        sp = sm // 2
        for i in range(sp, sp + s):
            b = x[i + sm]
            k = i + sp
            y[k] = x[i] + b
            y[k + s] = (x[i] - b) * w

(4) We can precompute all the factors and pass the precomputed array as an argument to the recursive function. We can use the following code to compute the array:

# will be used in new version: fft_block_v4

@njit(fastmath=True)
def _fill_c_array(c_arr, n, c):
    """
    Write c**0, c**1, ..., c**(n//2 - 1) into c_arr[:n//2].

    While n//2 > 2, recurse into c_arr[n//2:] with arguments (n//2, c * c).
    """
    half = n // 2
    power = 1.0
    for k in range(half):
        c_arr[k] = power
        power = power * c

    if half > 2:
        _fill_c_array(c_arr[half:], half, c * c)

@njit(fastmath=True)
def fill_c_array(n):
    """
    Precompute the twiddle factors used by every level of `_fft_block`.

    n is square root of length of input array in 6-step function.

    Returns a complex array of length n. Its first n//2 entries are
    exp(-2j*pi*k/n) for k = 0 .. n//2 - 1, followed by the factors of each
    subsequent level as written by `_fill_c_array`. For n >= 4 the last two
    entries are not written and hold arbitrary values.
    """
    theta = 2 * math.pi / n
    c = math.cos(theta) - 1j * math.sin(theta)
    # np.complex_ was removed in NumPy 2.0; np.complex128 is the same dtype.
    c_arr = np.empty(n, dtype=np.complex128)

    _fill_c_array(c_arr, n, c)

    return c_arr

Performance (+ Assertion)

In this Colab notebook, the performance of these four new versions is checked and compared with the baseline. In the 6-step FFT, the recursive function _fft_block(n, s, eo, x, y) is called $\sqrt{n}$ times in each of steps 2 and 5, where n is the length of the input array. So, to better reflect the impact of the versions above on the 6-step FFT algorithm, we benchmark them with the same number of calls. Overall, version (2) of fft_block seems to perform better than the others.

@NimaSarajpoor
Copy link
Collaborator Author

transpose + twiddle factor

This is regarding step 3 and 4 of the 6-step algorithm. Originally, we have the following code:

# version 0
# Visits only the upper triangle (k >= p): each off-diagonal pair
# (p, k) / (k, p) is swapped and scaled in a single step.

    for p in range(n_sqrt):
        theta0 = 2 * p * math.pi / n
        for k in range(p, n_sqrt):
            theta = k * theta0
            # twiddle factor exp(-2j * pi * p * k / n), recomputed per element
            w = math.cos(theta) - 1j * math.sin(theta)
    
            if k == p:
                i = p * n_sqrt + p
                x[i] = x[i] * w
            else:
                i = p * n_sqrt + k
                j = k * n_sqrt + p
                x[j], x[i] = x[i] * w, x[j] * w

which transposes the matrix and multiplies its elements by the twiddle factor. The twiddle factor $w$ is $\exp(-j\theta)$, where $\theta = p \cdot k \cdot \frac{2\pi}{n}$ and $n = \mathrm{len}(x)$. This factor multiplies both element (p, k) and element (k, p) of x.reshape(n_sqrt, n_sqrt), where n_sqrt is $\sqrt{\mathrm{len}(x)}$.

We can avoid calling math.cos and math.sin in the inner for-loop by creating an initial value for w, and keep updating it within the inner for-loop as shown below:

# version 1

    theta_init = 2 * math.pi / n
    for p in range(n_sqrt):
        theta0 = theta_init * p

        # c: ratio between consecutive factors along row p, exp(-1j * theta0)
        c = math.cos(theta0) - 1j * math.sin(theta0)
        # w: factor of the diagonal element (p, p), exp(-1j * theta0 * p)
        w = math.cos(theta0 * p) - 1j * math.sin(theta0 * p)
        for k in range(p, n_sqrt):
            i = p * n_sqrt + k

            if p == k:
                x[i] = x[i] * w
            else:
                j = k * n_sqrt + p
                x[j], x[i] = x[i] * w, x[j] * w

            # factor for column k + 1 of row p
            w = w * c

We still need to call math.cos and math.sin twice in each iteration of the outer for-loop. We can again move them to the outside of the outer for-loop. We can start with an initial value, and keep updating them as follows:

# version 2

    # wp: factor of the diagonal element (p, p), exp(-2j * pi * p * p / n)
    # cp: ratio between consecutive factors along row p, exp(-2j * pi * p / n)
    wp = 1.0
    cp = 1.0

    theta = 2 * math.pi / n
    factor = math.cos(theta) - 1j * math.sin(theta)  # exp(-2j * pi / n)
    for p in range(n_sqrt):
        pns = p * n_sqrt
        c = cp
        w = wp
        x[pns + p] = x[pns + p] * w
        for q in range(p + 1, n_sqrt):
            w = w * c  # w == exp(-2j * pi * p * q / n)
            i = pns + q
            j = q * n_sqrt + p
            x[j], x[i] = x[i] * w, x[j] * w

        # advance to row p + 1 without calling cos/sin:
        # cp -> exp(-2j*pi*(p+1)/n), wp -> wp * exp(-2j*pi*(2p+1)/n)
        cp_new = factor * cp
        wp = wp * cp_new * cp
        cp = cp_new

Performance ( + Assertion)

This Colab Notebook contains the code that shows the performance of these three versions.

image

As observed, the last version mentioned above outperforms the others.

@NimaSarajpoor
Copy link
Collaborator Author

NimaSarajpoor commented May 15, 2024

Function _sixstep_fft (Let's put the optimized blocks together)

Now, let's compare the original function _sixstep_fft (used in our fft algorithm as provided in this comment) with the new version where each block of code is replaced with its optimized version based on the results provided in previous comments. The new version of this function is provided below:

%%writefile sixstep_fft_v1.py
import math
import numpy as np
from numba import njit


@njit(fastmath=True)
def _tranpose(x, n_sqrt, x_transpose):
    """
    Copy the transpose of the n_sqrt x n_sqrt row-major matrix in `x` into
    `x_transpose`, tile by tile, using tiles of at most 32 x 32.
    """
    block = min(n_sqrt, 32)
    mat = x.reshape(n_sqrt, n_sqrt)
    out = x_transpose.reshape(n_sqrt, n_sqrt)
    for row0 in range(0, n_sqrt, block):
        for col0 in range(0, n_sqrt, block):
            tile = mat[col0:col0 + block, row0:row0 + block]
            out[row0:row0 + block, col0:col0 + block] = tile.T


@njit(fastmath=True)
def _fft_block(n, s, eo, x, y, c):
    """
    A recursive function that is used as part of fft algorithm

    Performs one radix-2 pass (butterfly plus twiddle) from `x` into `y`,
    then recurses with half the size, double the stride and the roles of
    `x` and `y` swapped. With `eo=False` on the top-level call, the final
    result lands in the top-level `x`; `y` is scratch space. A level with
    n < 2 or n == 3 does nothing.

    n : int
        Size of the sub-transform at this recursion level.
    s : int
        Stride (1 on the top-level call).
    eo: bool
        Flipped on every recursive call; selects the buffer written by the
        final (n == 2) pass.
    x : numpy.array 1D
        Source buffer of this pass.
    y : numpy.array 1D
        Destination buffer of this pass.
    c : complex
        Ratio between consecutive twiddle factors at this level, presumably
        exp(-2j * pi / n) as computed by the caller. The next level receives
        c * c.
    """
    if n == 2:
        # Final pass: butterfly without a twiddle factor. When z is x, both
        # inputs are read before either output is written.
        if eo:
            z = y
        else:
            z = x

        for i in range(s):
            j = i + s
            a = x[i]
            b = x[j]
            z[i] = a + b
            z[j] = a - b

    elif n >= 4:
        m = n // 2
        sm = s * m

        w = 1.0  # w == c**p during iteration p
        for p in range(m):
            sp = s * p
            for q in range(s):
                idx = sp + q
                a = x[idx]
                b = x[idx + sm]

                y[idx + sp] = a + b
                y[idx + sp + s] = (a - b) * w

            w = w * c

        # swap source/destination for the next, half-size level
        _fft_block(m, 2*s, not eo, y, x, c * c)

    else:
        # n < 2 (or n == 3): nothing to do
        pass


@njit(fastmath=True)
def _sixstep_fft(x, y):
    """
    Apply 6-step FFT algorithm and update x in-place.

    `x` is treated as an `n_sqrt x n_sqrt` row-major matrix, where
    `n_sqrt = int(sqrt(len(x)))`, so `len(x)` is expected to be a perfect
    square. `y` must have the same length; it is used as scratch space and
    its contents are overwritten.
    """
    n = len(x)
    n_sqrt = int(np.sqrt(n))

    # step 1: matrix transpose, x -> y
    _tranpose(x, n_sqrt, y)

    # step 2: FFT of every row of y (x is the helper; result stays in y)
    theta = 2 * math.pi / n_sqrt
    c_theta = math.cos(theta) - 1j * math.sin(theta)
    for start in range(0, n, n_sqrt):
        _fft_block(n_sqrt, 1, False, y[start:], x[start:], c_theta)

    # step 3 and 4: transpose y in place while multiplying elements (p, q)
    # and (q, p) by the twiddle factor exp(-2j * pi * p * q / n).
    # wp is the factor of (p, p); cp is the ratio between consecutive
    # factors along row p. Both are advanced incrementally from row to row.
    wp = 1.0
    cp = 1.0

    theta_twiddle = 2 * math.pi / n
    factor = math.cos(theta_twiddle) - 1j * math.sin(theta_twiddle)
    for p in range(n_sqrt):
        pns = p * n_sqrt
        c = cp
        w = wp
        y[pns + p] = y[pns + p] * w
        for q in range(p + 1, n_sqrt):
            w = w * c
            i = pns + q
            j = q * n_sqrt + p
            y[j], y[i] = y[i] * w, y[j] * w

        # cp -> exp(-2j*pi*(p+1)/n); wp -> exp(-2j*pi*(p+1)**2/n)
        cp_new = factor * cp
        wp = wp * cp_new * cp
        cp = cp_new

    # step 5: FFT of every row of y
    for start in range(0, n, n_sqrt):
        _fft_block(n_sqrt, 1, False, y[start:], x[start:], c_theta)

    # step 6: matrix transpose, y -> x
    _tranpose(y, n_sqrt, x)

    return

Performance (+ Assertion)

The code is provided in this Colab notebook. The result is provided below:

image

As observed, we see a 50-80% improvement (except for inputs of length 2^2).

@NimaSarajpoor
Copy link
Collaborator Author

NimaSarajpoor commented May 15, 2024

Eight-step function

We now work on the eight-step function as provided in this comment

@njit(fastmath=True)
def _eightstep_fft(x, y):
    """
    Overwrite `x` with its FFT by splitting it into two half-length
    transforms, each handled by `_sixstep_fft`.

    `y` (same length as `x`) is scratch space and is overwritten.
    """
    n = len(x)
    m = n // 2
    theta = math.pi / m

    # Butterfly between the two halves. The difference is scaled by
    # exp(-1j * i * pi / m).
    for i in range(m):
        j = m + i
        angle = i * theta
        y[i] = x[i] + x[j]
        y[j] = (x[i] - x[j]) * (math.cos(angle) - 1j * math.sin(angle))

    # Transform each half of y, with the matching half of x as helper.
    _sixstep_fft(y[:m], x[:m])
    _sixstep_fft(y[m:], x[m:])

    # Even-indexed outputs come from the first half, odd from the second.
    for k in range(m):
        x[2 * k] = y[k]
        x[2 * k + 1] = y[m + k]

We focus on the following part of the code:

# version 0

m = len(x) // 2

theta = math.pi / m
for i in range(m):
    # twiddle factor exp(-1j * i * pi / m), recomputed with cos/sin each time
    w = math.cos(i * theta) - 1j * math.sin(i * theta)
    j = i + m
    y[i] = x[i] + x[j]
    y[j] = (x[i] - x[j]) * w

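Note 1: Instead of computing w = math.cos(i * theta) - 1j * math.sin(i * theta) in every iteration, we can compute factor = math.cos(theta) - 1j * math.sin(theta) once, outside the for-loop, and update w = w * factor in each iteration (# version 1).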

Note 2: The factor w of the i-th iteration is the negative conjugate of the factor w of the (m-i)-th iteration, so both iterations can share one update of w (# version 2).

Performance (+ Assertion)

This Colab Notebook contains the code. The following result is obtained:

image

As observed, versions 1 and 2 show similar performance. However, the code in version 2 is more complicated, so we go with version 1:

m = len(x) // 2

theta = math.pi / m
# ratio between consecutive twiddle factors: exp(-1j * pi / m)
factor = math.cos(theta) - 1j * math.sin(theta)
w = 1.0  # w == exp(-1j * i * pi / m) at the start of iteration i
for i in range(m):
    j = i + m
    y[i] = x[i] + x[j]
    y[j] = (x[i] - x[j]) * w

    w = w * factor

@NimaSarajpoor
Copy link
Collaborator Author

NimaSarajpoor commented May 19, 2024

Put pieces together...

It is now the time to put the pieces together, and compare our so-far-optimized RFFT with the one provided in this comment. Code is available in this Colab notebook. The running time of RFFT is recorded for input with length 2^2...2^20. For each length, the function is called 5000 times. Since the running time is small, small deviation may result in considerable difference when calculating the speed-up ratio (w.r.t the running time of scipy.fft.rfft). To this end, I tried to calculate a range. Out of 5000 samples (of running time), I removed the ones that are outside of the range $[\mu - 2\sigma, \mu + 2\sigma]$. I then got the min, max, and mean. Therefore, the max and min speed-up for RFFT w.r.t scipy, can be computed as follows:

max_speed_up = max_running_time_of_RFFT / min_running_time_of_Scipy
min_speed_up = min_running_time_of_RFFT / max_running_time_of_Scipy

and, we can calculate the mean-based speed-up as follows:

mean-based speed_up = mean_running_time_of_RFFT / mean_running_time_of_Scipy

I am not sure whether this approach is statistically rigorous. Still, it gives us some idea about the range. In Colab, I got the following result (lower is better):

image

And, in my MacOS, I got this:
image

@NimaSarajpoor
Copy link
Collaborator Author

NimaSarajpoor commented May 19, 2024

As a follow-up to my previous comment, I would like to check the performance of numpy.fft.rfft as well. The result (from running the code on my macOS machine) is shown below. Lower is better.
image

@seanlaw seanlaw added the enhancement New feature or request label May 22, 2024
@NimaSarajpoor
Copy link
Collaborator Author

[Recap]
In the previous comment, I showed that the optimized version of our FFT implementation (shown in red) is significantly better than the initial implementation. However, it is still outperformed by Numpy and Scipy when the size of input is large. Our FFT implementation calls a recursive function which was optimized according to the study described in this comment.

[Now]
We replace the recursive function with a for-loop. Furthermore, I call different functions for input with different sizes. We call it rfft_v2. I noticed a considerable performance gain. It is NOT a clean approach and I need to work on it more. But.... it gives us some hope! :)

We are now getting closer to Scipy's performance!

image

The code is available in this colab notebook and it has assertion to make sure the output is correct.

@NimaSarajpoor
Copy link
Collaborator Author

NimaSarajpoor commented Jun 2, 2024

@seanlaw
I have been working on cleaning the _sixstep_fft function (a part of our RFFT implementation). I decided to get a 2D view of x and y right in the beginning of the function. This helped me to remove some arithmetic operations (no improvement in the performance though)

@njit(fastmath=True)
def _sixstep_fft(x, y):
    """
    Apply 6-step FFT algorithm and update x in-place.

    `x` and `y` are viewed as `n_sqrt x n_sqrt` row-major matrices, where
    `n_sqrt = int(sqrt(len(x)))`. `y` is scratch space and is overwritten.
    """
    n = len(x)
    n_sqrt = int(np.sqrt(n))
    x_2D = x.reshape(n_sqrt, n_sqrt)
    y_2D = y.reshape(n_sqrt, n_sqrt)

    # step 1: blocked matrix transpose, x_2D -> y_2D
    blocksize = 32
    blocksize = min(blocksize, n_sqrt)
    for i in range(0, n_sqrt, blocksize):
        for j in range(0, n_sqrt, blocksize):
            y_2D[i:i + blocksize, j:j + blocksize] = np.transpose(
                x_2D[j:j + blocksize, i:i + blocksize]
            )

    # step 2: FFT of each row of y_2D. x_2D[i] is only a helper array; the
    # result stays in y_2D[i].
    theta = 2 * math.pi / n_sqrt
    c_theta = math.cos(theta) - 1j * math.sin(theta)
    for i in range(n_sqrt):
        _fft_block(n_sqrt, 1, False, y_2D[i], x_2D[i], c_theta)

    # step 3 and 4: transpose and multiply elements (p, q) and (q, p) by the
    # twiddle factor exp(-2j * pi * p * q / n), advanced incrementally
    wp = 1.0
    cp = 1.0

    theta = 2 * math.pi / n
    factor = math.cos(theta) - 1j * math.sin(theta)
    for p in range(n_sqrt):
        c = cp
        w = wp
        y_2D[p, p] = y_2D[p, p] * w
        for q in range(p + 1, n_sqrt):
            w = w * c
            y_2D[q, p], y_2D[p, q] = y_2D[p, q] * w, y_2D[q, p] * w

        cp_new = factor * cp
        wp = wp * cp_new * cp
        cp = cp_new

    # step 5: FFT of each row of y_2D
    for i in range(n_sqrt):
        _fft_block(n_sqrt, 1, False, y_2D[i], x_2D[i], c_theta)

    # step 6: blocked matrix transpose, y_2D -> x_2D
    for i in range(0, n_sqrt, blocksize):
        for j in range(0, n_sqrt, blocksize):
            x_2D[i:i + blocksize, j:j + blocksize] = np.transpose(
                y_2D[j:j + blocksize, i:i + blocksize]
            )

    return

Notes:
(1) In steps 1 and 6, we do the matrix transpose using a cache-oblivious algorithm (COA), so I assume these two steps already perform well.

(2) In the combined steps 3 & 4, we do matrix transpose but we do not use COA. So, IMO, there might be some opportunity for performance boost here. I am going to provide the original version below to better understand this combined step:

theta = 2 * math.pi / n
# Upper triangle only (q >= p); each off-diagonal pair is swapped and scaled
# by the same factor in one step.
for p in range(n_sqrt):
    for q in range(p, n_sqrt): 
        # exp(-2j * pi * p * q / n), recomputed with cos/sin for every element
        twiddle_factor = math.cos(theta * p * q) - 1j * math.sin(theta * p * q)   
        if p == q:
            y_2D[p, p] = y_2D[p, p] * twiddle_factor
        else:
            y_2D[q, p], y_2D[p, q] = y_2D[p, q] * twiddle_factor, y_2D[q, p] * twiddle_factor

(3) In each of steps 2 and 5, the function _fft_block is applied to each row of y_2D. Originally, this function was recursive, as follows:

@njit(fastmath=True)
def _fft_block(n, s, eo, x, y, c):
    """
    One radix-2 pass followed by a recursive call on half the size.

    n : int, current transform size (len(x), len(x)//2, ... down to 2)
    s : int, current stride (1, 2, 4, ...)
    eo: bool, flipped on every level; selects the output buffer of the final
        pass so that the result lands in the top-level `x`
    x : chunk of interest (source of this pass)
    y : helper array (destination of this pass)
    c : complex, ratio between consecutive twiddle factors at this level
    """
    if n == 2:
        dst = y if eo else x
        for lo in range(s):
            hi = lo + s
            u, v = x[lo], x[hi]
            dst[lo] = u + v
            dst[hi] = u - v
        return

    if n < 4:
        return

    half = n // 2
    offset = s * half
    twiddle = 1.0
    for p in range(half):
        base = s * p
        for q in range(base, base + s):
            u, v = x[q], x[q + offset]
            y[q + base] = u + v
            y[q + base + s] = (u - v) * twiddle
        twiddle = twiddle * c

    # swap source/destination roles for the next (half-size) level
    _fft_block(half, 2 * s, not eo, y, x, c * c)

and, in the _sixstep_fft function, it is called like _fft_block(n_sqrt, 1, False, y_2D[i], x_2D[i], c_theta)

We can make this function cleaner by removing the need for the parameter n:

@njit(fastmath=True)
def _fft_block(s, eo, x, y, c, x_length_half):
    """
    Recursive radix-2 FFT pass that does not need the size parameter `n`.

    At stride `s`, the sub-transform size is `2 * x_length_half / s`, so the
    recursion ends when `s == x_length_half`.

    s : int, current stride (1 on the top-level call)
    eo: bool, flipped on every recursive call; with `eo=False` on the
        top-level call, the final result lands in the top-level `x`
    x : source array of this pass
    y : helper array (destination of this pass)
    c : complex, ratio between consecutive twiddle factors at this level
    x_length_half : int, half the length of the top-level `x`
    """
    if s == x_length_half:
        # final pass (sub-transform size 2): butterfly without a twiddle
        if eo:
            z = y
        else:
            z = x

        for i in range(s):
            j = i + s
            a = x[i]
            b = x[j]
            z[i] = a + b
            z[j] = a - b

    elif s == 1:
        # unit-stride pass, specialised for speed
        w = 1.0
        for i in range(x_length_half):
            a = x[i]
            b = x[i + x_length_half]

            y[2 * i] = a + b
            y[2 * i + 1] = (a - b) * w
            w = w * c

        # `not eo` (instead of a hard-coded True) keeps the buffer
        # bookkeeping consistent with the general branch for any caller `eo`.
        _fft_block(2, not eo, y, x, c * c, x_length_half)

    elif s < x_length_half:
        w = 1.0  # w == c**p for the group starting at i == s * p
        for i in range(0, x_length_half, s):
            for j in range(i, i + s):
                a = x[j]
                b = x[j + x_length_half]

                y[j + i] = a + b
                y[j + i + s] = (a - b) * w

            w = w * c

        _fft_block(2 * s, not eo, y, x, c * c, x_length_half)

    else:
        pass

Still, this does not improve the performance. We can also convert the recursive function to a for-loop:

# when len(x) is 2
@njit(fastmath=True)
def _fft_block_2(x):
    """Length-2 FFT of `x`, in place: (x0, x1) -> (x0 + x1, x0 - x1)."""
    x[0], x[1] = x[0] + x[1], x[0] - x[1]

# x_2D = x.reshape(n_sqrt, n_sqrt)
# Number of radix-2 stages for a row of length n_sqrt (assumed a power of 2).
R = int(np.log2(n_sqrt))
# Number of unrolled stage pairs: R // 2 - 1 when R is even and (R - 3) // 2
# when R is odd. Both cases are captured by (R - 2 - R % 2) // 2.
R_iter = (R - 2 - R % 2) // 2

# when len(x) is > 2, and log2(len(x)) is even
@njit(fastmath=True)
def _fft_block_forloop_log2nsqrtEVEN(x, y, c_theta, n_sqrt, x_len_half, R_iter):
    """Loop-based (non-recursive) radix-2 stages for an even number of stages.

    The data ping-pongs between ``x`` and ``y``:

    - one stride-1 stage (x -> y);
    - ``R_iter`` pairs of stages (y -> x, then x -> y), doubling the stride
      and squaring the twiddle step each time;
    - a final size-2 butterfly stage (y -> x).

    The result ends up in ``x``; ``y`` is used as scratch.

    Parameters
    ----------
    x, y : numpy.ndarray
        Data buffer and scratch buffer.
    c_theta : complex
        Twiddle step of the first stage. It is squared for every later stage.
    n_sqrt : int
        Not used by the body; kept for call-site compatibility.
    x_len_half : int
        Pairing distance: element ``i`` is combined with ``i + x_len_half``.
    R_iter : int
        Number of stage pairs. The caller computes it as ``R // 2 - 1`` with
        ``R = log2(n_sqrt)``.
    """
    # The body is written in terms of ``c``. Bind it to the argument, since
    # ``c`` was previously referenced without ever being defined.
    c = c_theta

    # Stride-1 stage: x -> y.
    w = 1.0
    for p in range(x_len_half):
        a = x[p]
        b = x[p + x_len_half]

        y[2 * p] = a + b
        y[2 * p + 1] = (a - b) * w
        w = w * c

    s = 1
    for _ in range(R_iter):
        # Stage with doubled stride: y -> x. Group p = 0 has twiddle 1.
        s = 2 * s
        for idx in range(s):
            a = y[idx]
            b = y[idx + x_len_half]

            x[idx] = a + b
            x[idx + s] = a - b

        c = c * c
        w = c
        for sp in range(s, x_len_half, s):
            for idx in range(sp, sp + s):
                a = y[idx]
                b = y[idx + x_len_half]

                x[idx + sp] = a + b
                x[idx + sp + s] = (a - b) * w

            w = w * c

        # Stage with doubled stride: x -> y.
        s = s * 2
        for idx in range(s):
            a = x[idx]
            b = x[idx + x_len_half]

            y[idx] = a + b
            y[idx + s] = a - b

        c = c * c
        w = c
        for sp in range(s, x_len_half, s):
            for idx in range(sp, sp + s):
                a = x[idx]
                b = x[idx + x_len_half]

                y[idx + sp] = a + b
                y[idx + sp + s] = (a - b) * w

            w = w * c

    # Final size-2 butterflies (stride x_len_half): y -> x.
    for i in range(x_len_half):
        j = i + x_len_half
        a = y[i]
        b = y[j]
        x[i] = a + b
        x[j] = a - b

    return


# when len(x) is > 2, and log2(len(x)) is odd
@njit(fastmath=True)
def _fft_block_forloop_log2nsqrtODD(x, y, c, n_sqrt, sm, R_iter):
    """Loop-based (non-recursive) radix-2 stages for an odd number of stages.

    The data ping-pongs between ``x`` and ``y``:

    - one stride-1 stage (x -> y);
    - ``R_iter`` pairs of stages (y -> x, then x -> y);
    - one extra stage (y -> x);
    - a final size-2 butterfly stage applied in place on ``x``.

    The stride doubles and the twiddle step is squared for every stage after
    the first. The result ends up in ``x``; ``y`` is used as scratch.

    Parameters
    ----------
    x, y : numpy.ndarray
        Data buffer and scratch buffer.
    c : complex
        Twiddle step of the first stage. It is squared for every later stage.
    n_sqrt : int
        Not used by the body; kept for call-site compatibility.
    sm : int
        Pairing distance: element ``i`` is combined with ``i + sm``.
        It is referred to as ``x_len_half`` in the body.
    R_iter : int
        Number of stage pairs. The caller computes it as ``(R - 3) // 2``
        with ``R = log2(n_sqrt)``.
    """
    # The body is written in terms of ``x_len_half``. Bind it to ``sm``, since
    # ``x_len_half`` was previously referenced without ever being defined.
    x_len_half = sm

    # Stride-1 stage: x -> y.
    w = 1.0
    for p in range(x_len_half):
        a = x[p]
        b = x[p + x_len_half]

        y[2 * p] = a + b
        y[2 * p + 1] = (a - b) * w

        w = w * c

    s = 1
    for _ in range(R_iter):
        # Stage with doubled stride: y -> x. Group p = 0 has twiddle 1.
        s = s * 2
        for idx in range(s):
            a = y[idx]
            b = y[idx + x_len_half]

            x[idx] = a + b
            x[idx + s] = a - b

        c = c * c
        w = c
        for sp in range(s, x_len_half, s):
            for idx in range(sp, sp + s):
                a = y[idx]
                b = y[idx + x_len_half]

                x[idx + sp] = a + b
                x[idx + sp + s] = (a - b) * w

            w = w * c

        # Stage with doubled stride: x -> y.
        s = s * 2
        for idx in range(s):
            a = x[idx]
            b = x[idx + x_len_half]

            y[idx] = a + b
            y[idx + s] = a - b

        c = c * c
        w = c
        for sp in range(s, x_len_half, s):
            for idx in range(sp, sp + s):
                a = x[idx]
                b = x[idx + x_len_half]

                y[idx + sp] = a + b
                y[idx + sp + s] = (a - b) * w

            w = w * c

    # Extra stage required by the odd stage count: y -> x.
    s = 2 * s

    for idx in range(s):
        a = y[idx]
        b = y[idx + x_len_half]

        x[idx] = a + b
        x[idx + s] = a - b

    c = c * c
    w = c
    for sp in range(s, x_len_half, s):
        for idx in range(sp, sp + s):
            a = y[idx]
            b = y[idx + x_len_half]

            x[idx + sp] = a + b
            x[idx + sp + s] = (a - b) * w

        w = w * c

    # Final size-2 butterflies (stride x_len_half), in place on x. This is
    # safe because both inputs of a pair are read before either is written.
    for i in range(x_len_half):
        j = i + x_len_half
        a = x[i]
        b = x[j]
        x[i] = a + b
        x[j] = a - b

    return

And still, there is no boost in performance. However, if we remove the outer for-loop for _ in range(R_iter):, and, instead, just write the inner for-loop R_iter times, then we can see a significant improvement in the performance as shown in my previous comment. The problem is that we need to create different functions, each for a different value of n_sqrt.


I will update this post by adding a Colab notebook that contains the code.

@seanlaw
Copy link
Contributor

seanlaw commented Jun 3, 2024

However, if we remove the outer for-loop for _ in range(R_iter):, and, instead, just write the inner for-loop R_iter times, then we can see a significant improvement

Is this the same as loop unrolling?

@NimaSarajpoor
Copy link
Collaborator Author

However, if we remove the outer for-loop for _ in range(R_iter):, and, instead, just write the inner for-loop R_iter times, then we can see a significant improvement
Is this the same as loop unrolling?

Yes! Thank you!! I am going to go through that wikipedia page and see if I can find more resources to get more insight (I recently read in Itamar's book that one shouldn't rely TOO much on the compiler's capabilities in optimizing the performance 😅)

@NimaSarajpoor
Copy link
Collaborator Author

NimaSarajpoor commented Jul 30, 2024

[Update]

However, if we remove the outer for-loop for _ in range(R_iter):, and, instead, just write the inner for-loop R_iter times, then we can see a significant improvement
Is this the same as loop unrolling?

Yes! Thank you!! I am going to go through that wikipedia page and see if I can find more resources to get more insight (I recently read in Itamar's book that one shouldn't rely TOO much on the compiler's capabilities in optimizing the performance 😅)

I couldn't find a way to tell the compiler to unroll a loop! Still looking for a clean way to get that performance!

I will update this post by adding a Colab notebook that contains the code.

I have been experimenting with different versions of my code. I need to prune them and keep just two or three versions; we can move forward from there.


[NOTE] Please ignore the following, as the performance gain is different once we add the njit decorator!

Recall that steps 3 & 4 of the six-step FFT algorithm consist of a transpose and a twiddle-factor multiplication. Originally, the code was:

  # step 3 and 4: transpose with twiddle factor, in place on the flat array x.
  # x is viewed as an n_sqrt x n_sqrt row-major matrix. Only the upper
  # triangle (k >= p) is visited. Each off-diagonal pair is swapped and both
  # entries are scaled by the same twiddle, because w depends only on p * k.
  for p in range(n_sqrt):
      theta_p = (2 * math.pi / n) * p
      for k in range(p, n_sqrt):
          theta_pk =  theta_p * k
          # w = exp(-1j * 2 * pi * p * k / n). cos/sin are evaluated once per
          # visited element.
          w = math.cos(theta_pk) - 1j * math.sin(theta_pk)

          if k == p:
              # diagonal element (p, p): scale only
              i = p * n_sqrt + p
              x[i] = x[i] * w
          else:
              # i -> (p, k), j -> (k, p): swap the pair while scaling both
              i = p * n_sqrt + k
              j = k * n_sqrt + p
              x[j], x[i] = x[i] * w, x[j] * w

However, it was slow. I then tried a trigonometric trick: a multiplicative recurrence for the twiddle factors. With it, math.cos and math.sin are no longer called in the inner loop, but only once, outside the for-loops.

    # step 3 and 4: transpose with twiddle factor, using multiplicative
    # recurrences so that cos/sin are evaluated only once, outside the loops.
    # Invariants at the top of row p: cp == factor**p, wp == factor**(p*p).
    wp = 1.0
    cp = 1.0

    theta = 2 * math.pi / n
    factor = math.cos(theta) - 1j * math.sin(theta)  # exp(-1j * theta)
    for p in range(n_sqrt):
        pns = p * n_sqrt  # flat offset of row p
        c = cp
        w = wp
        y[pns + p] = y[pns + p] * w  # diagonal element (p, p)
        for q in range(p + 1, n_sqrt):
            w = w * c  # now w == factor**(p*q)
            i = pns + q
            j = q * n_sqrt + p
            # swap (p, q) <-> (q, p), scaling both by the shared twiddle
            y[j], y[i] = y[i] * w, y[j] * w

        cp_new = factor * cp
        # factor**((p+1)**2) == factor**(p*p) * factor**(p+1) * factor**p
        wp = wp * cp_new * cp
        cp = cp_new

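I think it is still slow! So, I tried a blocked (tiled) matrix transpose AND a pre-computed twiddle-factor matrix. Strictly speaking, this transpose is cache-aware rather than cache-oblivious, since it uses a fixed block size of 32. I compared the performance using the following script: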

import math
import numpy as np
from numba import njit
import time

def func_ref(x, y):
    """Reference version of steps 3 & 4: an in-place transpose fused with twiddles.

    ``y`` is viewed as an ``n_sqrt x n_sqrt`` matrix ``Y``, where
    ``n = len(x)`` and ``n_sqrt = isqrt(n)``. It is overwritten in place with
    ``(Y * W).T``, where ``W[p, q] = exp(-2j * pi * p * q / n)``. The
    twiddles are produced by complex-multiplication recurrences rather than
    by calling cos/sin inside the loops.

    Parameters
    ----------
    x : numpy.ndarray
        Only its length ``n`` is used (``n`` should be a perfect square).
    y : numpy.ndarray
        1-D complex array of length ``n``; modified in place.
    """
    n = len(x)
    n_sqrt = math.isqrt(n)
    # Reshaping a 1-D array yields a view, so the writes below land in ``y``.
    y = y.reshape(n_sqrt, n_sqrt)

    theta = 2 * math.pi / n
    factor = math.cos(theta) - 1j * math.sin(theta)  # exp(-1j * theta)

    # Invariants at the top of row p: cp == factor**p, wp == factor**(p*p).
    wp = 1.0
    cp = 1.0
    for p in range(n_sqrt):
        c = cp
        w = wp
        y[p, p] = y[p, p] * w
        for q in range(p + 1, n_sqrt):
            w = w * c  # now w == factor**(p*q)
            # Swap the pair (p, q) <-> (q, p), scaling both by the shared twiddle.
            y[q, p], y[p, q] = y[p, q] * w, y[q, p] * w

        cp_new = factor * cp
        # factor**((p+1)**2) == factor**(p*p) * factor**(p+1) * factor**p
        wp = wp * cp_new * cp
        cp = cp_new


def get_w(n):
    """Return the ``(n_sqrt, n_sqrt)`` twiddle matrix for an ``n``-point six-step FFT.

    ``w[i, j] = cos(i*j*theta) - 1j*sin(i*j*theta)``, with ``theta = 2*pi/n``
    and ``n_sqrt = isqrt(n)``. That is, ``w[i, j] = exp(-2j*pi*i*j/n)``.
    The matrix is computed with vectorised NumPy instead of an
    ``n_sqrt**2`` Python double loop.

    Parameters
    ----------
    n : int
        Transform length (expected to be a perfect square).

    Returns
    -------
    numpy.ndarray
        complex128 array of shape ``(n_sqrt, n_sqrt)``.
    """
    n_sqrt = math.isqrt(n)
    theta = 2 * math.pi / n
    k = np.arange(n_sqrt)
    # Integer product i*j first, then scale by theta (same order as before).
    angle = np.outer(k, k) * theta
    return np.cos(angle) - 1j * np.sin(angle)


def func_comp(x, y, w, blocksize=32):
    """Blocked transpose fused with a pre-computed twiddle matrix (steps 3 & 4).

    ``x`` and ``y`` are viewed as ``n_sqrt x n_sqrt`` matrices ``X`` and ``Y``
    (``n_sqrt = isqrt(len(x))``). The function writes ``X = (Y * w).T`` one
    ``blocksize x blocksize`` tile at a time. Tiling is intended to keep each
    tile of ``Y`` and ``w`` cache-resident while it is transposed. ``y`` and
    ``w`` are not modified.

    Parameters
    ----------
    x : numpy.ndarray
        1-D destination array of length ``n`` (a perfect square); overwritten.
    y : numpy.ndarray
        1-D source array of length ``n``.
    w : numpy.ndarray
        ``(n_sqrt, n_sqrt)`` twiddle matrix, e.g. as built by ``get_w(n)``.
    blocksize : int, optional
        Tile edge length. The default of 32 is the previously hard-coded value.

    Raises
    ------
    ValueError
        If ``blocksize`` is smaller than 1.
    """
    if blocksize < 1:
        raise ValueError("blocksize must be a positive integer")

    n = len(x)
    n_sqrt = math.isqrt(n)
    x = x.reshape(n_sqrt, n_sqrt)
    y = y.reshape(n_sqrt, n_sqrt)

    # Clamp to the matrix size, but keep the step >= 1 so empty input is a no-op.
    step = max(1, min(blocksize, n_sqrt))
    for i in range(0, n_sqrt, step):
        for j in range(0, n_sqrt, step):
            rows = slice(i, i + step)
            cols = slice(j, j + step)
            # Slicing clamps at the edges, so partial tiles are handled too.
            x[rows, cols] = np.transpose(y[cols, rows] * w[cols, rows])


if __name__ == '__main__':
    # Benchmark func_ref (recurrence-based, element by element) against
    # func_comp (blocked, with a pre-computed twiddle matrix) for
    # len(x) = 2**4 ... 2**20, and check that both give the same result.
    # time.perf_counter is used rather than time.time: it is the
    # highest-resolution clock for short durations, which matters for
    # timings in the microsecond range.
    n_iter = 100

    for p in range(5, 21 + 1, 2):
        x = np.random.rand(2 ** p)
        # Pack pairs of reals into complex numbers, so len(x) == 2**(p - 1).
        x = x[::2] + 1j * x[1::2]
        print(f'============ log2_len(x): {p - 1} ============')

        y = x.copy()

        lst_ref = []
        for _ in range(n_iter):
            # func_ref updates y in place, so every run gets fresh copies.
            x_ref = x.copy()
            y_ref = y.copy()
            t1 = time.perf_counter()
            func_ref(x_ref, y_ref)
            t2 = time.perf_counter()
            lst_ref.append(t2 - t1)

        # The first run is discarded as a warm-up.
        mu_ref = np.mean(lst_ref[1:])
        std_ref = np.std(lst_ref[1:])

        x_comp = x.copy()
        y_comp = y.copy()
        # The twiddle matrix is built once, outside the timed region.
        w = get_w(len(x_comp))

        lst_comp = []
        for _ in range(n_iter):
            # func_comp only writes x_comp, so repeated calls are equivalent.
            t1 = time.perf_counter()
            func_comp(x_comp, y_comp, w)
            t2 = time.perf_counter()
            lst_comp.append(t2 - t1)
        mu_comp = np.mean(lst_comp[1:])
        std_comp = np.std(lst_comp[1:])

        np.testing.assert_allclose(y_ref, x_comp, atol=1e-7)

        print(f'ref --> mean: {mu_ref:4f}, std: {std_ref:4f}')
        print(f'comp --> mean: {mu_comp:4f}, std: {std_comp:4f}')
        print(f'Speedup percentage: {100 * (mu_ref - mu_comp) / mu_ref:.2f}%')

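And the result shows the following. Note that the first row (log2_len(x): 2) means this run started the loop at p = 3; the script as written starts at p = 5, i.e. at log2_len(x): 4.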

============ log2_len(x): 2 ============
ref --> mean: 0.000002, std: 0.000000
comp --> mean: 0.000002, std: 0.000000
Speedup percentage: -6.83%
============ log2_len(x): 4 ============
ref --> mean: 0.000005, std: 0.000001
comp --> mean: 0.000002, std: 0.000000
Speedup percentage: 55.48%
============ log2_len(x): 6 ============
ref --> mean: 0.000014, std: 0.000001
comp --> mean: 0.000002, std: 0.000000
Speedup percentage: 84.63%
============ log2_len(x): 8 ============
ref --> mean: 0.000052, std: 0.000003
comp --> mean: 0.000002, std: 0.000001
Speedup percentage: 95.30%
============ log2_len(x): 10 ============
ref --> mean: 0.000198, std: 0.000052
comp --> mean: 0.000003, std: 0.000001
Speedup percentage: 98.31%
============ log2_len(x): 12 ============
ref --> mean: 0.000755, std: 0.000017
comp --> mean: 0.000018, std: 0.000000
Speedup percentage: 97.59%
============ log2_len(x): 14 ============
ref --> mean: 0.002827, std: 0.000096
comp --> mean: 0.000066, std: 0.000006
Speedup percentage: 97.65%
============ log2_len(x): 16 ============
ref --> mean: 0.011109, std: 0.000128
comp --> mean: 0.000273, std: 0.000003
Speedup percentage: 97.54%
============ log2_len(x): 18 ============
ref --> mean: 0.045712, std: 0.000397
comp --> mean: 0.001443, std: 0.000032
Speedup percentage: 96.84%
============ log2_len(x): 20 ============
ref --> mean: 0.190583, std: 0.002836
comp --> mean: 0.008274, std: 0.000205
Speedup percentage: 95.66%

@seanlaw
Copy link
Contributor

seanlaw commented Jul 30, 2024

@NimaSarajpoor Thanks for the update. Have you tried using profila? Based on what you are seeing, do you feel like we can actually get close to/beat the performance of FFTW?

@NimaSarajpoor
Copy link
Collaborator Author

NimaSarajpoor commented Dec 22, 2024

According to this open issue in Numba, providing optional argument(s) in an njit-decorated function can result in a considerable performance gain.

@NimaSarajpoor
Before reporting the performance, it is better to revisit the implemented FFT code and see if such opportunity exists.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
enhancement New feature or request
Projects
None yet
Development

No branches or pull requests

2 participants