Skip to content

Commit

Permalink
Correlation via fft implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
AlexanderKalistratov committed Nov 29, 2024
1 parent 6fc7166 commit 5f429be
Show file tree
Hide file tree
Showing 2 changed files with 194 additions and 32 deletions.
125 changes: 101 additions & 24 deletions dpnp/dpnp_iface_statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@
"""

import math

import dpctl.tensor as dpt
import dpctl.utils as dpu
import numpy
Expand All @@ -59,6 +61,8 @@
from .dpnp_utils.dpnp_utils_reduction import dpnp_wrap_reduction_call
from .dpnp_utils.dpnp_utils_statistics import dpnp_cov

min_ = min # pylint: disable=used-before-assignment

__all__ = [
"amax",
"amin",
Expand Down Expand Up @@ -478,17 +482,57 @@ def _get_padding(a_size, v_size, mode):
return l_pad, r_pad


def _run_native_sliding_dot_product1d(a, v, l_pad, r_pad):
def _choose_conv_method(a, v, rdtype):
    """
    Select the algorithm used to correlate/convolve two 1-D sequences.

    Parameters
    ----------
    a, v : arrays
        Input sequences; `a` is expected to be at least as long as `v`.
    rdtype : dtype
        Result data type the computation is performed for.

    Returns
    -------
    str
        Either ``"direct"`` or ``"fft"``.

    Raises
    ------
    ValueError
        If `rdtype` is neither boolean nor a numeric type.

    """
    assert a.size >= v.size

    # Boolean results, and inputs where either sequence has fewer than
    # 10**4 elements, always go through the direct implementation.
    if rdtype == dpnp.bool or v.size < 10**4 or a.size < 10**4:
        return "direct"

    if dpnp.issubdtype(rdtype, dpnp.integer):
        # max|a| * sum|v| bounds the magnitude of every output element.
        peak_a = int(dpnp.max(dpnp.abs(a)))
        total_v = int(dpnp.sum(dpnp.abs(v)))
        bound = peak_a * total_v

        # If the bound does not fit in the mantissa of the device's default
        # float type, stay on the direct path.
        # NOTE(review): presumably this guards against FFT round-off
        # corrupting exact integer results - confirm.
        float_type = dpnp.default_float_type(a.sycl_device)
        mantissa_bits = numpy.finfo(float_type).nmant
        if bound > 2**mantissa_bits - 1:
            return "direct"

    if not dpnp.issubdtype(rdtype, dpnp.number):
        raise ValueError(f"Unsupported dtype: {rdtype}")

    return "fft"


def _run_native_sliding_dot_product1d(a, v, l_pad, r_pad, rdtype):
queue = a.sycl_queue
device = a.sycl_device

supported_types = statistics_ext.sliding_dot_product1d_dtypes()
supported_dtype = to_supported_dtypes(rdtype, supported_types, device)

usm_type = dpu.get_coerced_usm_type([a.usm_type, v.usm_type])
out_size = l_pad + r_pad + a.size - v.size + 1
if supported_dtype is None:
raise ValueError(
f"Unsupported input types ({a.dtype}, {v.dtype}), "
"and the inputs could not be coerced to any "
f"supported types. List of supported types: {supported_types}"
)

a_casted = dpnp.asarray(a, dtype=supported_dtype, order="C")
v_casted = dpnp.asarray(v, dtype=supported_dtype, order="C")

usm_type = dpu.get_coerced_usm_type([a_casted.usm_type, v_casted.usm_type])
out_size = l_pad + r_pad + a_casted.size - v_casted.size + 1
out = dpnp.empty(
shape=out_size, sycl_queue=queue, dtype=a.dtype, usm_type=usm_type
shape=out_size,
sycl_queue=queue,
dtype=supported_dtype,
usm_type=usm_type,
)

a_usm = dpnp.get_usm_ndarray(a)
v_usm = dpnp.get_usm_ndarray(v)
a_usm = dpnp.get_usm_ndarray(a_casted)
v_usm = dpnp.get_usm_ndarray(v_casted)
out_usm = dpnp.get_usm_ndarray(out)

_manager = dpu.SequentialOrderManager[queue]
Expand All @@ -506,7 +550,30 @@ def _run_native_sliding_dot_product1d(a, v, l_pad, r_pad):
return out


def correlate(a, v, mode="valid"):
def _convolve_fft(a, v, l_pad, r_pad, rtype):
    """
    Convolve two 1-D sequences by multiplying their Fourier transforms.

    Parameters
    ----------
    a, v : arrays
        Input sequences; `a` is expected to be at least as long as `v`.
    l_pad, r_pad : int
        Left and right padding of the requested output window; `l_pad`
        is expected to be smaller than ``v.size``.
    rtype : dtype
        Result data type. It only controls post-processing of the inverse
        transform: the real part is taken for floating types, the rounded
        real part for integer and boolean types, and the complex values
        are kept as they are otherwise.

    Returns
    -------
    array
        The slice ``[v.size - 1 - l_pad, a.size + r_pad)`` of the
        convolution of `a` and `v`.

    """
    assert a.size >= v.size
    assert l_pad < v.size

    # +1 is needed to avoid circular convolution
    padded_size = a.size + r_pad + 1
    # The transform length is the next power of two >= padded_size.
    fft_size = 2 ** math.ceil(math.log2(padded_size))

    a_spectrum = dpnp.fft.fft(a, fft_size)  # pylint: disable=no-member
    v_spectrum = dpnp.fft.fft(v, fft_size)  # pylint: disable=no-member

    # Pointwise product in the frequency domain is a convolution in the
    # original domain.
    product = a_spectrum * v_spectrum
    conv = dpnp.fft.ifft(product)  # pylint: disable=no-member

    if dpnp.issubdtype(rtype, dpnp.floating):
        conv = conv.real
    elif dpnp.issubdtype(rtype, dpnp.integer) or rtype == dpnp.bool:
        conv = conv.real.round()

    return conv[v.size - 1 - l_pad : padded_size - 1]


def correlate(a, v, mode="valid", method="auto"):
r"""
Cross-correlation of two 1-dimensional sequences.
Expand All @@ -531,6 +598,20 @@ def correlate(a, v, mode="valid"):
is ``'valid'``, unlike :obj:`dpnp.convolve`, which uses ``'full'``.
Default: ``'valid'``.
method : {'auto', 'direct', 'fft'}, optional
`'direct'`: The correlation is determined directly from sums.
`'fft'`: The Fourier Transform is used to perform the calculations.
This method is faster for long sequences but can have accuracy issues.
`'auto'`: Automatically chooses direct or Fourier method based on
an estimate of which is faster.
Note: Use of the FFT convolution on input containing NAN or INF
will lead to the entire output being NAN or INF.
Use method='direct' when your input contains NAN or INF values.
Default: ``'auto'``.
Notes
-----
Expand All @@ -556,7 +637,6 @@ def correlate(a, v, mode="valid"):
:obj:`dpnp.convolve` : Discrete, linear convolution of two
one-dimensional sequences.
Examples
--------
>>> import dpnp as np
Expand Down Expand Up @@ -598,19 +678,14 @@ def correlate(a, v, mode="valid"):
f"Received shapes: a.shape={a.shape}, v.shape={v.shape}"
)

supported_types = statistics_ext.sliding_dot_product1d_dtypes()
supported_methods = ["auto", "direct", "fft"]
if method not in supported_methods:
raise ValueError(
f"Unknown method: {method}. Supported methods: {supported_methods}"
)

device = a.sycl_device
rdtype = result_type_for_device([a.dtype, v.dtype], device)
supported_dtype = to_supported_dtypes(rdtype, supported_types, device)

if supported_dtype is None:
raise ValueError(
f"function '{correlate}' does not support input types "
f"({a.dtype}, {v.dtype}), "
"and the inputs could not be coerced to any "
f"supported types. List of supported types: {supported_types}"
)

if dpnp.issubdtype(v.dtype, dpnp.complexfloating):
v = dpnp.conj(v)
Expand All @@ -622,13 +697,15 @@ def correlate(a, v, mode="valid"):

l_pad, r_pad = _get_padding(a.size, v.size, mode)

a_casted = dpnp.asarray(a, dtype=supported_dtype, order="C")
v_casted = dpnp.asarray(v, dtype=supported_dtype, order="C")

if v.size > a.size:
a_casted, v_casted = v_casted, a_casted
if method == "auto":
method = _choose_conv_method(a, v, rdtype)

r = _run_native_sliding_dot_product1d(a_casted, v_casted, l_pad, r_pad)
if method == "direct":
r = _run_native_sliding_dot_product1d(a, v, l_pad, r_pad, rdtype)
elif method == "fft":
r = _convolve_fft(a, v[::-1], l_pad, r_pad, rdtype)
else:
raise ValueError(f"Unknown method: {method}")

if revert:
r = r[::-1]
Expand Down
101 changes: 93 additions & 8 deletions dpnp/tests/test_statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -629,26 +629,104 @@ def test_corrcoef_scalar(self):


class TestCorrelate:
def setup_method(self):
numpy.random.seed(0)

@pytest.mark.parametrize(
"a, v", [([1], [1, 2, 3]), ([1, 2, 3], [1]), ([1, 2, 3], [1, 2])]
)
@pytest.mark.parametrize("mode", [None, "full", "valid", "same"])
@pytest.mark.parametrize("dtype", get_all_dtypes(no_bool=True))
def test_correlate(self, a, v, mode, dtype):
@pytest.mark.parametrize("method", [None, "auto", "direct", "fft"])
def test_correlate(self, a, v, mode, dtype, method):
an = numpy.array(a, dtype=dtype)
vn = numpy.array(v, dtype=dtype)
ad = dpnp.array(an)
vd = dpnp.array(vn)

if mode is None:
expected = numpy.correlate(an, vn)
result = dpnp.correlate(ad, vd)
else:
expected = numpy.correlate(an, vn, mode=mode)
result = dpnp.correlate(ad, vd, mode=mode)
dpnp_kwargs = {}
numpy_kwargs = {}
if mode is not None:
dpnp_kwargs["mode"] = mode
numpy_kwargs["mode"] = mode
if method is not None:
dpnp_kwargs["method"] = method

expected = numpy.correlate(an, vn, **numpy_kwargs)
result = dpnp.correlate(ad, vd, **dpnp_kwargs)

assert_dtype_allclose(result, expected)

@pytest.mark.parametrize("a_size", [1, 100, 10000])
@pytest.mark.parametrize("v_size", [1, 100, 10000])
@pytest.mark.parametrize("mode", ["full", "valid", "same"])
@pytest.mark.parametrize("dtype", get_all_dtypes(no_none=True))
@pytest.mark.parametrize("method", ["auto", "direct", "fft"])
def test_correlate_random(self, a_size, v_size, mode, dtype, method):
    """
    Compare ``dpnp.correlate`` with ``numpy.correlate`` on random inputs.

    Integer and boolean inputs must match exactly unless ``method="fft"``.
    The ``"direct"`` method is checked with the standard dtype-aware
    tolerance. Remaining cases tolerate a small number of outliers before
    falling back to a strict comparison.
    """
    # The order of the random draws is kept fixed so the seeded data
    # stays reproducible.
    if dtype == dpnp.bool:
        an = numpy.random.rand(a_size) > 0.9
        vn = numpy.random.rand(v_size) > 0.9
    else:
        an = (100 * numpy.random.rand(a_size)).astype(dtype)
        vn = (100 * numpy.random.rand(v_size)).astype(dtype)

        if dpnp.issubdtype(dtype, dpnp.complexfloating):
            an = an + 1j * (100 * numpy.random.rand(a_size)).astype(dtype)
            vn = vn + 1j * (100 * numpy.random.rand(v_size)).astype(dtype)

    ad = dpnp.array(an)
    vd = dpnp.array(vn)

    # `mode` and `method` are never None for this parametrization, so both
    # are passed explicitly; only `mode` is forwarded to numpy.correlate.
    result = dpnp.correlate(ad, vd, mode=mode, method=method)
    expected = numpy.correlate(an, vn, mode=mode)

    rdtype = result.dtype
    if dpnp.issubdtype(rdtype, dpnp.integer):
        rdtype = dpnp.default_float_type(ad.device)

    if method != "fft" and (
        dpnp.issubdtype(dtype, dpnp.integer) or dtype == dpnp.bool
    ):
        # For 'direct' and 'auto' methods, we expect exact results
        # for integer types
        assert_array_equal(result, expected)
        return

    result = result.astype(rdtype)
    if method == "direct":
        # For 'direct' method we can use standard validation.
        # `expected` computed above is reused instead of being recomputed.
        assert_dtype_allclose(result, expected, factor=30)
        return

    rtol = 1e-3
    atol = 1e-10

    if rdtype == dpnp.float64 or rdtype == dpnp.complex128:
        rtol = 1e-6
        atol = 1e-12
    elif rdtype == dpnp.bool:
        result = result.astype(dpnp.int32)
        rdtype = result.dtype

    expected = expected.astype(rdtype)

    diff = numpy.abs(result.asnumpy() - expected)
    invalid = diff > atol + rtol * numpy.abs(expected)

    # When using the 'fft' method, we might encounter outliers.
    # This usually happens when the resulting array contains values
    # close to zero. For these outliers, the relative error can be
    # significant. We can tolerate a few such outliers.
    max_outliers = 8 if expected.size > 1 else 0
    if invalid.sum() > max_outliers:
        assert_dtype_allclose(result, expected, factor=1000)

def test_correlate_mode_error(self):
a = dpnp.arange(5)
v = dpnp.arange(3)
Expand Down Expand Up @@ -689,7 +767,7 @@ def test_correlate_different_sizes(self, size):
vd = dpnp.array(v)

expected = numpy.correlate(a, v)
result = dpnp.correlate(ad, vd)
result = dpnp.correlate(ad, vd, method="direct")

assert_dtype_allclose(result, expected, factor=20)

Expand All @@ -700,6 +778,13 @@ def test_correlate_another_sycl_queue(self):
with pytest.raises(ValueError):
dpnp.correlate(a, v)

def test_correlate_unkown_method(self):
    """An unsupported ``method`` value must be rejected with ValueError."""
    a = dpnp.arange(5)
    v = dpnp.arange(3)

    # Match on the message so the test cannot pass because of an unrelated
    # ValueError (e.g. a shape or queue validation failure).
    with pytest.raises(ValueError, match="Unknown method"):
        dpnp.correlate(a, v, method="unknown")


@pytest.mark.parametrize(
"dtype", get_all_dtypes(no_bool=True, no_none=True, no_complex=True)
Expand Down

0 comments on commit 5f429be

Please sign in to comment.