Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

pb.time_shift supports multiple time shifts #58

Merged
merged 3 commits into from
Feb 27, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion .zenodo.json
Original file line number Diff line number Diff line change
Expand Up @@ -10,5 +10,8 @@
"name": "Rebecca Lin",
"orcid": "0000-0003-4530-4254"
}
]
],

"license": "GPL-3.0-or-later",
"title": "pulsarbat: PULSAR Baseband Analysis Tools"
}
14 changes: 13 additions & 1 deletion README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,19 @@ Citing

``pulsarbat`` has a DOI via Zenodo: https://doi.org/10.5281/zenodo.6934355

This DOI represents all versions, and will always resolve to the latest one. To cite a specific version, follow the link and find the version you want to cite on Zenodo.
This DOI link represents all versions, and will always resolve to the latest one.
Use the following Bibtex entry to cite this work:

.. code-block::

@software{pulsarbat,
author = {Nikhil Mahajan and Rebecca Lin},
title = {pulsarbat: PULSAR Baseband Analysis Tools},
year = {2023},
publisher = {Zenodo},
doi = {10.5281/zenodo.6934355},
url = {https://doi.org/10.5281/zenodo.6934355}
}


License
Expand Down
2 changes: 1 addition & 1 deletion docs/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,4 @@ Release Notes

- Added a frequency-shift function ``pulsarbat.freq_shift`` (:pr:`52`)
- Added a snippet function ``pulsarbat.snippet`` (:pr:`54`)

- ``pulsarbat.time_shift`` can now support multiple shifts (:pr:`58`)
110 changes: 67 additions & 43 deletions pulsarbat/transforms/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ def concatenate(signals, /, axis=0):


def snippet(z, /, t, n):
"""Extracts a snippet of a signal.
"""Extracts a snippet of a signal in time.

If ``t`` corresponds to non-integer number of samples from the
start of ``z``, time-shifting via FFT (by applying a phase gradient
Expand Down Expand Up @@ -197,79 +197,98 @@ def snippet(z, /, t, n):
if (i := int(t)) < t:
shift = i - t

if isinstance(z.data, da.Array):
f = da.fft.fftfreq(len(z), 1, chunks=(-1,))
else:
f = np.fft.fftfreq(len(z), 1)

ph = np.exp(-2j * np.pi * shift * f).astype(np.complex64)

ix = (slice(None),) + (None,) * (z.ndim - 1)
shifted = pb.fft.ifft(pb.fft.fft(z.data, axis=0) * ph[ix], axis=0)
shifted = shifted if np.iscomplexobj(z.data) else shifted.real
shifted = shifted.astype(z.dtype)

if z.start_time is None:
new_start = None
else:
new_start = z.start_time - shift * z.dt

shifted = pb.time_shift(z, shift, crop=True).data
z = type(z).like(z, shifted, start_time=new_start)

return z[i : i + n]


def time_shift(z, /, t):
"""Shift signal by given number of samples or time.
def time_shift(z, /, shift, crop=False):
"""Shift signal data by given number of samples or time.

This function shifts signals in time via FFT by multiplying by
a phase gradient in frequency domain.
This function shifts the signal data in time via FFT by multiplying by
a phase gradient in frequency domain. This usually only makes sense if
``z`` is a :py:class:`.BasebandSignal`. For non-baseband signals,
the output might not be meaningful.

Parameters
----------
z : Signal
Input signal.
t : int, float or Quantity
shift : int, float, array-like or Quantity
Shift amount. If a number (int or float), the signal is shifted
by that number of samples. An astropy Quantity with units of
time can also be passed, in which case the signal will be
shifted by `t * z.sample_rate` samples.
shifted by `dt * z.sample_rate` samples. If an array, must have
shape such that axes with length more than 1 match ``z.sample_shape``.
crop : bool, optional
Whether the returned signal is cropped to eliminate out-of-bounds
data. Default is False.

Returns
-------
out : Signal
Shifted signal.
Shifted signal. If the ``crop`` parameter is ``False``, will have
the same shape and ``start_time`` as input signal. If ``crop`` is
``True``, ``start_time`` will change by ``max(0, shift.max())``.

Notes
-----
Since an FFT is used, it is efficient to provide a signal with a
fast FFT length via :py:func:`pulsarbat.fast_len`.
"""
if t == 0:
if isinstance(shift, u.Quantity):
shift = (shift * z.sample_rate).to_value(u.one)

shift = np.array(shift)

if shift.ndim >= z.ndim:
raise ValueError(
f"shift has too many dimensions. Expected <= {z.ndim - 1} dimensions, "
f"got {shift.ndim} dimensions!"
)

# If shifts are zero, do nothing
if np.allclose(shift, 0):
return z

if isinstance(t, u.Quantity):
n = np.float64((t * z.sample_rate).to_value(u.one))
else:
n = np.float64(t)
if shift.ndim > 0:
ix = (slice(None),) * shift.ndim + (None,) * (z.ndim - shift.ndim - 1)
shift = shift[ix]

f_ix = tuple(slice(None) if j == 0 else None for j in range(z.ndim))
if isinstance(z.data, da.Array):
f = da.fft.fftfreq(len(z), 1)
f = da.fft.fftfreq(len(z), 1, chunks=(-1,))[f_ix]
else:
f = np.fft.fftfreq(len(z), 1)
f = np.fft.fftfreq(len(z), 1)[f_ix]

ix = tuple(slice(None) if i == 0 else None for i in range(z.ndim))
ph = np.exp(-2j * np.pi * n * f).astype(np.complex64)[ix]
ph = np.exp(-2j * np.pi * shift * f).astype(np.complex64)
shifted = pb.fft.ifft(pb.fft.fft(z.data, axis=0) * ph, axis=0)
shifted = shifted if np.iscomplexobj(z.data) else shifted.real

if np.iscomplexobj(z.data):
shifted = shifted.astype(z.dtype)
else:
shifted = shifted.real.astype(z.dtype)
start, stop = 0, 0
it = np.nditer(shift, flags=["multi_index"])
for a in it:
if a < 0:
a = int(np.floor(a))
ix = (np.s_[a:],) + it.multi_index
stop = min(stop, a)
else:
a = int(np.ceil(a))
ix = (np.s_[:a],) + it.multi_index
start = max(start, a)

x = type(z).like(z, shifted, start_time=z.start_time - n * z.dt)
shifted[ix] = 0

if n >= 0:
i = np.int64(np.ceil(n))
x = x[i:]
else:
i = np.int64(np.floor(n))
x = x[:i]
x = type(z).like(z, shifted)

if crop:
x = x[start:len(x) + stop]

return x

Expand All @@ -290,8 +309,8 @@ def freq_shift(z, /, shift):
z : BasebandSignal
Input signal.
shift : Quantity
Shift amount in units of frequency. Should be a scalar or
have shape ``z.sample_shape[:n]`` for ``0 <= n``.
Shift amount in units of frequency. Should be a scalar or have
shape that such that axes with length more than 1 match ``z.sample_shape``.

Returns
-------
Expand All @@ -318,8 +337,13 @@ def freq_shift(z, /, shift):
ix = (slice(None),) * shift.ndim + (None,) * (z.ndim - shift.ndim - 1)
ft = (shift[ix] * z.dt).to_value(u.one)

if isinstance(z.data, da.Array):
n = da.arange(len(z), chunks=(-1,))
else:
n = np.arange(len(z))

ix = tuple(slice(None) if j == 0 else None for j in range(z.ndim))
ph = np.exp(2j * np.pi * ft * np.arange(len(z))[ix]).astype(z.dtype)
ph = np.exp(2j * np.pi * ft * n[ix]).astype(z.dtype)

x = np.fft.fftshift(pb.fft.fft(z.data * ph, axis=0), axes=(0,))

Expand Down
131 changes: 97 additions & 34 deletions tests/test_transforms.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
"""Tests for core signal functions."""

import math
import pytest
import itertools
import numpy as np
Expand All @@ -12,9 +11,13 @@

def assert_equal_signals(x, y):
assert np.allclose(np.array(x), np.array(y), atol=1e-6)
assert Time.isclose(x.start_time, y.start_time)
assert u.isclose(x.sample_rate, y.sample_rate)

if x.start_time is None:
assert y.start_time is None
else:
assert Time.isclose(x.start_time, y.start_time)


def assert_equal_radiosignals(x, y):
assert_equal_signals(x, y)
Expand All @@ -25,7 +28,7 @@ def impulse(N, t0):
"""Generate noisy impulse at t0, with given S/N."""
n = (np.arange(N) - N // 2) / N
x = np.exp(-2j * np.pi * t0 * n)
return np.fft.ifft(np.fft.ifftshift(x)).astype(np.complex128)
return np.fft.ifft(np.fft.ifftshift(x, axes=(-1,))).astype(np.complex128)


def sinusoid(N, f0):
Expand All @@ -34,6 +37,12 @@ def sinusoid(N, f0):
return np.exp(2j * np.pi * f0 * n).astype(np.complex128)


def noise(shape):
"""Generate complex Gaussian noise."""
f = np.random.default_rng().standard_normal
return (f(shape) + 1j * f(shape)) / np.sqrt(2)


class TestConcatenate:
@pytest.mark.parametrize("use_dask", [True, False])
def test_basic(self, use_dask):
Expand Down Expand Up @@ -304,44 +313,84 @@ def test_errors(self):
class TestTimeShift:
@pytest.mark.parametrize("use_dask", [True, False])
@pytest.mark.parametrize("use_complex", [True, False])
def test_int_roll(self, use_dask, use_complex):
if use_dask:
f = da.random.standard_normal
else:
f = np.random.standard_normal
@pytest.mark.parametrize("start_time", [Time.now(), None])
def test_int_scalar(self, use_dask, use_complex, start_time):
kw = dict(sample_rate=1 * u.kHz, start_time=start_time)

for shape in [(4096, 4, 2), (4096, 4), (4096,)]:
if use_complex:
z = pb.Signal(noise(shape), **kw)
else:
z = pb.Signal(noise(shape).real, **kw)

shape = (4096, 4, 2)
if use_dask:
z = z.to_dask_array()

if use_complex:
x = (f(shape) + 1j * f(shape)).astype(np.complex128)
else:
x = f(shape).astype(np.float64)
for n in [-12, -7, -3, 0, 4, 8, 13]:
for s in [n, n * u.ms]:
y = pb.time_shift(z, s)

z = pb.Signal(x, sample_rate=1 * u.kHz, start_time=Time.now())
if z.start_time is None:
assert y.start_time is None
else:
assert Time.isclose(z.start_time, y.start_time)

for n in [1, 10, 55, 211]:
assert_equal_signals(z[:-n], pb.time_shift(z, n))
assert_equal_signals(z[:-n], pb.time_shift(z, n * u.ms))
if n < 0:
assert np.allclose(np.array(y[:n]), np.array(z[-n:]), atol=1E-6)
assert np.allclose(np.array(y[n:]), 0)
elif n > 0:
assert np.allclose(np.array(y[n:]), np.array(z[:-n]), atol=1E-6)
assert np.allclose(np.array(y[:n]), 0)
else:
assert np.allclose(np.array(y), np.array(z))

assert_equal_signals(z[n:], pb.time_shift(z, -n))
assert_equal_signals(z[n:], pb.time_shift(z, -n * u.ms))
@pytest.mark.parametrize("use_dask", [True, False])
def test_advanced(self, use_dask):
N = 4096

def test_subsample_roll(self):
def impulse(N, t0):
"""Generate noisy impulse at t0, with given S/N."""
n = (np.arange(N) - N // 2) / N
x = np.exp(-2j * np.pi * t0 * n)
return np.fft.ifft(np.fft.ifftshift(x)).astype(np.complex128)
for shape in [(4096, 4, 2), (4096, 4), (4096,)]:
shifts = np.concatenate(
[
np.random.uniform(-20, 20, (4,) + shape[1:]),
np.random.uniform(0, 20, (3,) + shape[1:]),
np.random.uniform(-20, 0, (3,) + shape[1:])
],
axis=0
)

N = 1024
for shift in [16.5, 32.25, 50.1, 60.9, 466.666]:
imp1 = impulse(N, shift)
imp2 = impulse(N - math.ceil(shift), 0)
x = pb.Signal(imp1, sample_rate=1 * u.kHz, start_time=Time.now())
y = pb.time_shift(x, -shift)
assert np.allclose(np.array(y), imp2)
z = pb.time_shift(x, -shift * u.ms)
assert np.allclose(np.array(z), imp2)
for shift in shifts:
x = impulse(N, 100 - shift[..., None])
x = np.moveaxis(x, -1, 0)

z = pb.Signal(x, sample_rate=1 * u.kHz)

if use_dask:
z = z.to_dask_array()

y1 = pb.time_shift(z, shift, crop=False)
y2 = pb.time_shift(z, shift, crop=True)

x = np.zeros_like(y1.data)
x[100] = 1.0

a = max(0, int(np.ceil(shift.max())))
b = len(x) + min(0, int(np.floor(shift.min())))

assert np.allclose(np.asarray(y1), x)
assert np.allclose(np.asarray(y2), x[a:b])

def test_shape_errors(self):
x = pb.Signal(noise((4096, 4, 2)), sample_rate=1 * u.kHz)

for shape in [(1,), (1, 2), (4,), (4, 1), (4, 2)]:
shift = np.random.uniform(-20, 20, shape)
_ = pb.time_shift(x, shift)

for shape in [(2,), (5, 2), (4, 5), (1, 4), (2, 1)]:
shift = np.random.uniform(-20, 20, shape)

with pytest.raises(ValueError):
_ = pb.time_shift(x, shift)


class TestFreqShift:
Expand Down Expand Up @@ -444,6 +493,20 @@ def test_zeroing(self, N):
assert np.allclose(a[:s], 0)
assert np.allclose(a[s:], 1)

def test_shape_errors(self):
x = pb.BasebandSignal(noise((4096, 4, 2)), sample_rate=1 * u.kHz,
center_freq=1 * u.MHz)

for shape in [(1,), (1, 2), (4,), (4, 1), (4, 2)]:
shift = np.random.uniform(-20, 20, shape) * u.mHz
_ = pb.freq_shift(x, shift)

for shape in [(2,), (5, 2), (4, 5), (1, 4), (2, 1)]:
shift = np.random.uniform(-20, 20, shape) * u.mHz

with pytest.raises(ValueError):
_ = pb.freq_shift(x, shift)


class TestFastLen:
def test_fast(self):
Expand Down