diff --git a/.zenodo.json b/.zenodo.json index 9b1f32f..bf61736 100644 --- a/.zenodo.json +++ b/.zenodo.json @@ -10,5 +10,8 @@ "name": "Rebecca Lin", "orcid": "0000-0003-4530-4254" } - ] + ], + + "license": "GPL-3.0-or-later", + "title": "pulsarbat: PULSAR Baseband Analysis Tools" } diff --git a/README.rst b/README.rst index b12e175..1b7ca08 100644 --- a/README.rst +++ b/README.rst @@ -55,7 +55,19 @@ Citing ``pulsarbat`` has a DOI via Zenodo: https://doi.org/10.5281/zenodo.6934355 -This DOI represents all versions, and will always resolve to the latest one. To cite a specific version, follow the link and find the version you want to cite on Zenodo. +This DOI link represents all versions, and will always resolve to the latest one. +Use the following Bibtex entry to cite this work: + +.. code-block:: + + @software{pulsarbat, + author = {Nikhil Mahajan and Rebecca Lin}, + title = {pulsarbat: PULSAR Baseband Analysis Tools}, + year = {2023}, + publisher = {Zenodo}, + doi = {10.5281/zenodo.6934355}, + url = {https://doi.org/10.5281/zenodo.6934355} + } License diff --git a/docs/changelog.rst b/docs/changelog.rst index 8e4a70d..fbd0ddb 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -7,4 +7,4 @@ Release Notes - Added a frequency-shift function ``pulsarbat.freq_shift`` (:pr:`52`) - Added a snippet function ``pulsarbat.snippet`` (:pr:`54`) - +- ``pulsarbat.time_shift`` can now support multiple shifts (:pr:`58`) diff --git a/pulsarbat/transforms/transforms.py b/pulsarbat/transforms/transforms.py index dcfdb13..71826ec 100644 --- a/pulsarbat/transforms/transforms.py +++ b/pulsarbat/transforms/transforms.py @@ -149,7 +149,7 @@ def concatenate(signals, /, axis=0): def snippet(z, /, t, n): - """Extracts a snippet of a signal. + """Extracts a snippet of a signal in time. If ``t`` corresponds to non-integer number of samples from the start of ``z``, time-shifting via FFT (by applying a phase gradient @@ -197,79 +197,98 @@ def snippet(z, /, t, n): if (i := int(t)) < t: shift = i - t - if isinstance(z.data, da.Array): - f = da.fft.fftfreq(len(z), 1, chunks=(-1,)) - else: - f = np.fft.fftfreq(len(z), 1) - - ph = np.exp(-2j * np.pi * shift * f).astype(np.complex64) - - ix = (slice(None),) + (None,) * (z.ndim - 1) - shifted = pb.fft.ifft(pb.fft.fft(z.data, axis=0) * ph[ix], axis=0) - shifted = shifted if np.iscomplexobj(z.data) else shifted.real - shifted = shifted.astype(z.dtype) - if z.start_time is None: new_start = None else: new_start = z.start_time - shift * z.dt + shifted = pb.time_shift(z, shift, crop=True).data z = type(z).like(z, shifted, start_time=new_start) return z[i : i + n] -def time_shift(z, /, t): - """Shift signal by given number of samples or time. +def time_shift(z, /, shift, crop=False): + """Shift signal data by given number of samples or time. - This function shifts signals in time via FFT by multiplying by - a phase gradient in frequency domain. + This function shifts the signal data in time via FFT by multiplying by + a phase gradient in frequency domain. This usually only makes sense if + ``z`` is a :py:class:`.BasebandSignal`. For non-baseband signals, + the output might not be meaningful. Parameters ---------- z : Signal Input signal. - t : int, float or Quantity + shift : int, float, array-like or Quantity Shift amount. If a number (int or float), the signal is shifted by that number of samples. An astropy Quantity with units of time can also be passed, in which case the signal will be - shifted by `t * z.sample_rate` samples. + shifted by `dt * z.sample_rate` samples. If an array, must have + shape such that axes with length more than 1 match ``z.sample_shape``. + crop : bool, optional + Whether the returned signal is cropped to eliminate out-of-bounds + data. Default is False. Returns ------- out : Signal - Shifted signal. + Shifted signal. If the ``crop`` parameter is ``False``, will have + the same shape and ``start_time`` as input signal. If ``crop`` is + ``True``, ``start_time`` will change by ``max(0, shift.max())``. + + Notes + ----- + Since an FFT is used, it is efficient to provide a signal with a + fast FFT length via :py:func:`pulsarbat.fast_len`. """ - if t == 0: + if isinstance(shift, u.Quantity): + shift = (shift * z.sample_rate).to_value(u.one) + + shift = np.array(shift) + + if shift.ndim >= z.ndim: + raise ValueError( + f"shift has too many dimensions. Expected <= {z.ndim - 1} dimensions, " + f"got {shift.ndim} dimensions!" + ) + + # If shifts are zero, do nothing + if np.allclose(shift, 0): return z - if isinstance(t, u.Quantity): - n = np.float64((t * z.sample_rate).to_value(u.one)) - else: - n = np.float64(t) + if shift.ndim > 0: + ix = (slice(None),) * shift.ndim + (None,) * (z.ndim - shift.ndim - 1) + shift = shift[ix] + f_ix = tuple(slice(None) if j == 0 else None for j in range(z.ndim)) if isinstance(z.data, da.Array): - f = da.fft.fftfreq(len(z), 1) + f = da.fft.fftfreq(len(z), 1, chunks=(-1,))[f_ix] else: - f = np.fft.fftfreq(len(z), 1) + f = np.fft.fftfreq(len(z), 1)[f_ix] - ix = tuple(slice(None) if i == 0 else None for i in range(z.ndim)) - ph = np.exp(-2j * np.pi * n * f).astype(np.complex64)[ix] + ph = np.exp(-2j * np.pi * shift * f).astype(np.complex64) shifted = pb.fft.ifft(pb.fft.fft(z.data, axis=0) * ph, axis=0) + shifted = shifted if np.iscomplexobj(z.data) else shifted.real - if np.iscomplexobj(z.data): - shifted = shifted.astype(z.dtype) - else: - shifted = shifted.real.astype(z.dtype) + start, stop = 0, 0 + it = np.nditer(shift, flags=["multi_index"]) + for a in it: + if a < 0: + a = int(np.floor(a)) + ix = (np.s_[a:],) + it.multi_index + stop = min(stop, a) + else: + a = int(np.ceil(a)) + ix = (np.s_[:a],) + it.multi_index + start = max(start, a) - x = type(z).like(z, shifted, start_time=z.start_time - n * z.dt) + shifted[ix] = 0 - if n >= 0: - i = np.int64(np.ceil(n)) - x = x[i:] - else: - i = np.int64(np.floor(n)) - x = x[:i] + x = type(z).like(z, shifted) + + if crop: + x = x[start:len(x) + stop] return x @@ -290,8 +309,8 @@ def freq_shift(z, /, shift): z : BasebandSignal Input signal. shift : Quantity - Shift amount in units of frequency. Should be a scalar or - have shape ``z.sample_shape[:n]`` for ``0 <= n``. + Shift amount in units of frequency. Should be a scalar or have + shape that such that axes with length more than 1 match ``z.sample_shape``. Returns ------- @@ -318,8 +337,13 @@ def freq_shift(z, /, shift): ix = (slice(None),) * shift.ndim + (None,) * (z.ndim - shift.ndim - 1) ft = (shift[ix] * z.dt).to_value(u.one) + if isinstance(z.data, da.Array): + n = da.arange(len(z), chunks=(-1,)) + else: + n = np.arange(len(z)) + ix = tuple(slice(None) if j == 0 else None for j in range(z.ndim)) - ph = np.exp(2j * np.pi * ft * np.arange(len(z))[ix]).astype(z.dtype) + ph = np.exp(2j * np.pi * ft * n[ix]).astype(z.dtype) x = np.fft.fftshift(pb.fft.fft(z.data * ph, axis=0), axes=(0,)) diff --git a/tests/test_transforms.py b/tests/test_transforms.py index 6747a47..13b076f 100644 --- a/tests/test_transforms.py +++ b/tests/test_transforms.py @@ -1,6 +1,5 @@ """Tests for core signal functions.""" -import math import pytest import itertools import numpy as np @@ -12,9 +11,13 @@ def assert_equal_signals(x, y): assert np.allclose(np.array(x), np.array(y), atol=1e-6) - assert Time.isclose(x.start_time, y.start_time) assert u.isclose(x.sample_rate, y.sample_rate) + if x.start_time is None: + assert y.start_time is None + else: + assert Time.isclose(x.start_time, y.start_time) + def assert_equal_radiosignals(x, y): assert_equal_signals(x, y) @@ -25,7 +28,7 @@ def impulse(N, t0): """Generate noisy impulse at t0, with given S/N.""" n = (np.arange(N) - N // 2) / N x = np.exp(-2j * np.pi * t0 * n) - return np.fft.ifft(np.fft.ifftshift(x)).astype(np.complex128) + return np.fft.ifft(np.fft.ifftshift(x, axes=(-1,))).astype(np.complex128) def sinusoid(N, f0): @@ -34,6 +37,12 @@ def sinusoid(N, f0): return np.exp(2j * np.pi * f0 * n).astype(np.complex128) +def noise(shape): + """Generate complex Gaussian noise.""" + f = np.random.default_rng().standard_normal + return (f(shape) + 1j * f(shape)) / np.sqrt(2) + + class TestConcatenate: @pytest.mark.parametrize("use_dask", [True, False]) def test_basic(self, use_dask): @@ -304,44 +313,84 @@ def test_errors(self): class TestTimeShift: @pytest.mark.parametrize("use_dask", [True, False]) @pytest.mark.parametrize("use_complex", [True, False]) - def test_int_roll(self, use_dask, use_complex): - if use_dask: - f = da.random.standard_normal - else: - f = np.random.standard_normal + @pytest.mark.parametrize("start_time", [Time.now(), None]) + def test_int_scalar(self, use_dask, use_complex, start_time): + kw = dict(sample_rate=1 * u.kHz, start_time=start_time) + + for shape in [(4096, 4, 2), (4096, 4), (4096,)]: + if use_complex: + z = pb.Signal(noise(shape), **kw) + else: + z = pb.Signal(noise(shape).real, **kw) - shape = (4096, 4, 2) + if use_dask: + z = z.to_dask_array() - if use_complex: - x = (f(shape) + 1j * f(shape)).astype(np.complex128) - else: - x = f(shape).astype(np.float64) + for n in [-12, -7, -3, 0, 4, 8, 13]: + for s in [n, n * u.ms]: + y = pb.time_shift(z, s) - z = pb.Signal(x, sample_rate=1 * u.kHz, start_time=Time.now()) + if z.start_time is None: + assert y.start_time is None + else: + assert Time.isclose(z.start_time, y.start_time) - for n in [1, 10, 55, 211]: - assert_equal_signals(z[:-n], pb.time_shift(z, n)) - assert_equal_signals(z[:-n], pb.time_shift(z, n * u.ms)) + if n < 0: + assert np.allclose(np.array(y[:n]), np.array(z[-n:]), atol=1E-6) + assert np.allclose(np.array(y[n:]), 0) + elif n > 0: + assert np.allclose(np.array(y[n:]), np.array(z[:-n]), atol=1E-6) + assert np.allclose(np.array(y[:n]), 0) + else: + assert np.allclose(np.array(y), np.array(z)) - assert_equal_signals(z[n:], pb.time_shift(z, -n)) - assert_equal_signals(z[n:], pb.time_shift(z, -n * u.ms)) + @pytest.mark.parametrize("use_dask", [True, False]) + def test_advanced(self, use_dask): + N = 4096 - def test_subsample_roll(self): - def impulse(N, t0): - """Generate noisy impulse at t0, with given S/N.""" - n = (np.arange(N) - N // 2) / N - x = np.exp(-2j * np.pi * t0 * n) - return np.fft.ifft(np.fft.ifftshift(x)).astype(np.complex128) + for shape in [(4096, 4, 2), (4096, 4), (4096,)]: + shifts = np.concatenate( + [ + np.random.uniform(-20, 20, (4,) + shape[1:]), + np.random.uniform(0, 20, (3,) + shape[1:]), + np.random.uniform(-20, 0, (3,) + shape[1:]) + ], + axis=0 + ) - N = 1024 - for shift in [16.5, 32.25, 50.1, 60.9, 466.666]: - imp1 = impulse(N, shift) - imp2 = impulse(N - math.ceil(shift), 0) - x = pb.Signal(imp1, sample_rate=1 * u.kHz, start_time=Time.now()) - y = pb.time_shift(x, -shift) - assert np.allclose(np.array(y), imp2) - z = pb.time_shift(x, -shift * u.ms) - assert np.allclose(np.array(z), imp2) + for shift in shifts: + x = impulse(N, 100 - shift[..., None]) + x = np.moveaxis(x, -1, 0) + + z = pb.Signal(x, sample_rate=1 * u.kHz) + + if use_dask: + z = z.to_dask_array() + + y1 = pb.time_shift(z, shift, crop=False) + y2 = pb.time_shift(z, shift, crop=True) + + x = np.zeros_like(y1.data) + x[100] = 1.0 + + a = max(0, int(np.ceil(shift.max()))) + b = len(x) + min(0, int(np.floor(shift.min()))) + + assert np.allclose(np.asarray(y1), x) + assert np.allclose(np.asarray(y2), x[a:b]) + + def test_shape_errors(self): + x = pb.Signal(noise((4096, 4, 2)), sample_rate=1 * u.kHz) + + for shape in [(1,), (1, 2), (4,), (4, 1), (4, 2)]: + shift = np.random.uniform(-20, 20, shape) + _ = pb.time_shift(x, shift) + + for shape in [(2,), (5, 2), (4, 5), (1, 4), (2, 1)]: + shift = np.random.uniform(-20, 20, shape) + + with pytest.raises(ValueError): + _ = pb.time_shift(x, shift) class TestFreqShift: @@ -444,6 +493,20 @@ def test_zeroing(self, N): assert np.allclose(a[:s], 0) assert np.allclose(a[s:], 1) + def test_shape_errors(self): + x = pb.BasebandSignal(noise((4096, 4, 2)), sample_rate=1 * u.kHz, + center_freq=1 * u.MHz) + + for shape in [(1,), (1, 2), (4,), (4, 1), (4, 2)]: + shift = np.random.uniform(-20, 20, shape) * u.mHz + _ = pb.freq_shift(x, shift) + + for shape in [(2,), (5, 2), (4, 5), (1, 4), (2, 1)]: + shift = np.random.uniform(-20, 20, shape) * u.mHz + + with pytest.raises(ValueError): + _ = pb.freq_shift(x, shift) + class TestFastLen: def test_fast(self):