From 9ba351f8def79570983f1aa2f7c33c87c2dc43a3 Mon Sep 17 00:00:00 2001 From: Nikhil Mahajan Date: Sat, 25 Feb 2023 23:35:40 -0500 Subject: [PATCH 1/3] Fixed time-shift function and edited pb.snippet to use it. --- pulsarbat/transforms/transforms.py | 70 +++++++++++++++--------------- tests/test_transforms.py | 53 ++++++++++++++-------- 2 files changed, 69 insertions(+), 54 deletions(-) diff --git a/pulsarbat/transforms/transforms.py b/pulsarbat/transforms/transforms.py index dcfdb13..7e91c68 100644 --- a/pulsarbat/transforms/transforms.py +++ b/pulsarbat/transforms/transforms.py @@ -149,7 +149,7 @@ def concatenate(signals, /, axis=0): def snippet(z, /, t, n): - """Extracts a snippet of a signal. + """Extracts a snippet of a signal in time. If ``t`` corresponds to non-integer number of samples from the start of ``z``, time-shifting via FFT (by applying a phase gradient @@ -196,25 +196,7 @@ def snippet(z, /, t, n): if (i := int(t)) < t: shift = i - t - - if isinstance(z.data, da.Array): - f = da.fft.fftfreq(len(z), 1, chunks=(-1,)) - else: - f = np.fft.fftfreq(len(z), 1) - - ph = np.exp(-2j * np.pi * shift * f).astype(np.complex64) - - ix = (slice(None),) + (None,) * (z.ndim - 1) - shifted = pb.fft.ifft(pb.fft.fft(z.data, axis=0) * ph[ix], axis=0) - shifted = shifted if np.iscomplexobj(z.data) else shifted.real - shifted = shifted.astype(z.dtype) - - if z.start_time is None: - new_start = None - else: - new_start = z.start_time - shift * z.dt - - z = type(z).like(z, shifted, start_time=new_start) + z = time_shift(z, shift) return z[i : i + n] @@ -223,7 +205,9 @@ def time_shift(z, /, t): """Shift signal by given number of samples or time. This function shifts signals in time via FFT by multiplying by - a phase gradient in frequency domain. + a phase gradient in frequency domain. This usually only makes sense if + ``z`` is a :py:class:`.BasebandSignal`. For non-baseband signals, + the output might not be meaningful. Parameters ---------- @@ -233,42 +217,56 @@ def time_shift(z, /, t): Shift amount. If a number (int or float), the signal is shifted by that number of samples. An astropy Quantity with units of time can also be passed, in which case the signal will be - shifted by `t * z.sample_rate` samples. + shifted by `dt * z.sample_rate` samples. Returns ------- out : Signal - Shifted signal. + Shifted signal, with out-of-bounds data cropped out. The output signal + will usually have a length smaller than the input signal as a result. + See notes for cropping behavior. + + Notes + ----- + Since an FFT is used, it is efficient to provide a signal with a + fast FFT length via :py:func:`pulsarbat.fast_len`. + + The primary use-case for this function is when shift is between + -1 and 0, which is useful for sub-sample shifting of large baseband + signals. A large positive shift (for example, 10) simply returns a + cropped signal (``x[10:]``). A non-integer positive shift, such as 5.6, + has the same effect as a of -0.4, but with more unnecessary cropping of + the signal. """ if t == 0: return z if isinstance(t, u.Quantity): - n = np.float64((t * z.sample_rate).to_value(u.one)) - else: - n = np.float64(t) + t = (t * z.sample_rate).to_value(u.one) if isinstance(z.data, da.Array): - f = da.fft.fftfreq(len(z), 1) + f = da.fft.fftfreq(len(z), 1, chunks=(-1,)) else: f = np.fft.fftfreq(len(z), 1) - ix = tuple(slice(None) if i == 0 else None for i in range(z.ndim)) - ph = np.exp(-2j * np.pi * n * f).astype(np.complex64)[ix] + ix = (slice(None),) + (None,) * (z.ndim - 1) + ph = np.exp(-2j * np.pi * t * f).astype(np.complex64)[ix] + shifted = pb.fft.ifft(pb.fft.fft(z.data, axis=0) * ph, axis=0) + shifted = shifted if np.iscomplexobj(z.data) else shifted.real - if np.iscomplexobj(z.data): - shifted = shifted.astype(z.dtype) + if z.start_time is None: + new_start = None else: - shifted = shifted.real.astype(z.dtype) + new_start = z.start_time - t * z.dt - x = type(z).like(z, shifted, start_time=z.start_time - n * z.dt) + x = type(z).like(z, shifted, start_time=new_start) - if n >= 0: - i = np.int64(np.ceil(n)) + if t >= 0: + i = np.int64(np.ceil(t)) x = x[i:] else: - i = np.int64(np.floor(n)) + i = np.int64(np.floor(t)) x = x[:i] return x diff --git a/tests/test_transforms.py b/tests/test_transforms.py index 6747a47..15528f1 100644 --- a/tests/test_transforms.py +++ b/tests/test_transforms.py @@ -12,9 +12,13 @@ def assert_equal_signals(x, y): assert np.allclose(np.array(x), np.array(y), atol=1e-6) - assert Time.isclose(x.start_time, y.start_time) assert u.isclose(x.sample_rate, y.sample_rate) + if x.start_time is None: + assert y.start_time is None + else: + assert Time.isclose(x.start_time, y.start_time) + def assert_equal_radiosignals(x, y): assert_equal_signals(x, y) @@ -304,7 +308,8 @@ def test_errors(self): class TestTimeShift: @pytest.mark.parametrize("use_dask", [True, False]) @pytest.mark.parametrize("use_complex", [True, False]) - def test_int_roll(self, use_dask, use_complex): + @pytest.mark.parametrize("start_time", [Time.now(), None]) + def test_int_roll(self, use_dask, use_complex, start_time): if use_dask: f = da.random.standard_normal else: @@ -317,31 +322,43 @@ def test_int_roll(self, use_dask, use_complex): else: x = f(shape).astype(np.float64) - z = pb.Signal(x, sample_rate=1 * u.kHz, start_time=Time.now()) + z = pb.Signal(x, sample_rate=1 * u.kHz, start_time=start_time) - for n in [1, 10, 55, 211]: + for n in [1, 10, 55]: assert_equal_signals(z[:-n], pb.time_shift(z, n)) assert_equal_signals(z[:-n], pb.time_shift(z, n * u.ms)) assert_equal_signals(z[n:], pb.time_shift(z, -n)) assert_equal_signals(z[n:], pb.time_shift(z, -n * u.ms)) - def test_subsample_roll(self): - def impulse(N, t0): - """Generate noisy impulse at t0, with given S/N.""" - n = (np.arange(N) - N // 2) / N - x = np.exp(-2j * np.pi * t0 * n) - return np.fft.ifft(np.fft.ifftshift(x)).astype(np.complex128) + assert_equal_signals(z, pb.time_shift(z, 0)) + def test_subsample_roll(self): N = 1024 - for shift in [16.5, 32.25, 50.1, 60.9, 466.666]: - imp1 = impulse(N, shift) - imp2 = impulse(N - math.ceil(shift), 0) - x = pb.Signal(imp1, sample_rate=1 * u.kHz, start_time=Time.now()) - y = pb.time_shift(x, -shift) - assert np.allclose(np.array(y), imp2) - z = pb.time_shift(x, -shift * u.ms) - assert np.allclose(np.array(z), imp2) + + t0 = Time.now() + for shift in [-20.4, -10.666, -4.99, 10.33, 2.9, 5.5]: + if shift < 0: + x = pb.Signal( + impulse(N, -shift), + sample_rate=1 * u.Hz, + start_time=t0 + shift * u.s, + ) + else: + x = pb.Signal( + impulse(N, np.ceil(shift) - shift), + sample_rate=1 * u.Hz, + start_time=t0 - (np.ceil(shift) - shift) * u.s, + ) + + y = pb.time_shift(x, shift) + + z = np.zeros_like(y.data) + z[0] = 1 + + assert np.allclose(np.asarray(y.data), z) + assert x.sample_rate == y.sample_rate + assert Time.isclose(t0, y.start_time) class TestFreqShift: From cbf722c64cfdda30dbb21c1423024fddd6462d22 Mon Sep 17 00:00:00 2001 From: Nikhil Mahajan Date: Mon, 27 Feb 2023 02:08:57 -0500 Subject: [PATCH 2/3] Added multi-shift functionality in pb.time_shift --- .zenodo.json | 5 +- README.rst | 14 +++- pulsarbat/transforms/transforms.py | 102 ++++++++++++++--------- tests/test_transforms.py | 128 ++++++++++++++++++++--------- 4 files changed, 168 insertions(+), 81 deletions(-) diff --git a/.zenodo.json b/.zenodo.json index 9b1f32f..bf61736 100644 --- a/.zenodo.json +++ b/.zenodo.json @@ -10,5 +10,8 @@ "name": "Rebecca Lin", "orcid": "0000-0003-4530-4254" } - ] + ], + + "license": "GPL-3.0-or-later", + "title": "pulsarbat: PULSAR Baseband Analysis Tools" } diff --git a/README.rst b/README.rst index b12e175..1b7ca08 100644 --- a/README.rst +++ b/README.rst @@ -55,7 +55,19 @@ Citing ``pulsarbat`` has a DOI via Zenodo: https://doi.org/10.5281/zenodo.6934355 -This DOI represents all versions, and will always resolve to the latest one. To cite a specific version, follow the link and find the version you want to cite on Zenodo. +This DOI link represents all versions, and will always resolve to the latest one. +Use the following Bibtex entry to cite this work: + +.. code-block:: + + @software{pulsarbat, + author = {Nikhil Mahajan and Rebecca Lin}, + title = {pulsarbat: PULSAR Baseband Analysis Tools}, + year = {2023}, + publisher = {Zenodo}, + doi = {10.5281/zenodo.6934355}, + url = {https://doi.org/10.5281/zenodo.6934355} + } License diff --git a/pulsarbat/transforms/transforms.py b/pulsarbat/transforms/transforms.py index 7e91c68..71826ec 100644 --- a/pulsarbat/transforms/transforms.py +++ b/pulsarbat/transforms/transforms.py @@ -196,15 +196,22 @@ def snippet(z, /, t, n): if (i := int(t)) < t: shift = i - t - z = time_shift(z, shift) + + if z.start_time is None: + new_start = None + else: + new_start = z.start_time - shift * z.dt + + shifted = pb.time_shift(z, shift, crop=True).data + z = type(z).like(z, shifted, start_time=new_start) return z[i : i + n] -def time_shift(z, /, t): - """Shift signal by given number of samples or time. +def time_shift(z, /, shift, crop=False): + """Shift signal data by given number of samples or time. - This function shifts signals in time via FFT by multiplying by + This function shifts the signal data in time via FFT by multiplying by a phase gradient in frequency domain. This usually only makes sense if ``z`` is a :py:class:`.BasebandSignal`. For non-baseband signals, the output might not be meaningful. @@ -213,61 +220,75 @@ def time_shift(z, /, t): ---------- z : Signal Input signal. - t : int, float or Quantity + shift : int, float, array-like or Quantity Shift amount. If a number (int or float), the signal is shifted by that number of samples. An astropy Quantity with units of time can also be passed, in which case the signal will be - shifted by `dt * z.sample_rate` samples. + shifted by `dt * z.sample_rate` samples. If an array, must have + shape such that axes with length more than 1 match ``z.sample_shape``. + crop : bool, optional + Whether the returned signal is cropped to eliminate out-of-bounds + data. Default is False. Returns ------- out : Signal - Shifted signal, with out-of-bounds data cropped out. The output signal - will usually have a length smaller than the input signal as a result. - See notes for cropping behavior. + Shifted signal. If the ``crop`` parameter is ``False``, will have + the same shape and ``start_time`` as input signal. If ``crop`` is + ``True``, ``start_time`` will change by ``max(0, shift.max())``. Notes ----- Since an FFT is used, it is efficient to provide a signal with a fast FFT length via :py:func:`pulsarbat.fast_len`. - - The primary use-case for this function is when shift is between - -1 and 0, which is useful for sub-sample shifting of large baseband - signals. A large positive shift (for example, 10) simply returns a - cropped signal (``x[10:]``). A non-integer positive shift, such as 5.6, - has the same effect as a of -0.4, but with more unnecessary cropping of - the signal. """ - if t == 0: + if isinstance(shift, u.Quantity): + shift = (shift * z.sample_rate).to_value(u.one) + + shift = np.array(shift) + + if shift.ndim >= z.ndim: + raise ValueError( + f"shift has too many dimensions. Expected <= {z.ndim - 1} dimensions, " + f"got {shift.ndim} dimensions!" + ) + + # If shifts are zero, do nothing + if np.allclose(shift, 0): return z - if isinstance(t, u.Quantity): - t = (t * z.sample_rate).to_value(u.one) + if shift.ndim > 0: + ix = (slice(None),) * shift.ndim + (None,) * (z.ndim - shift.ndim - 1) + shift = shift[ix] + f_ix = tuple(slice(None) if j == 0 else None for j in range(z.ndim)) if isinstance(z.data, da.Array): - f = da.fft.fftfreq(len(z), 1, chunks=(-1,)) + f = da.fft.fftfreq(len(z), 1, chunks=(-1,))[f_ix] else: - f = np.fft.fftfreq(len(z), 1) - - ix = (slice(None),) + (None,) * (z.ndim - 1) - ph = np.exp(-2j * np.pi * t * f).astype(np.complex64)[ix] + f = np.fft.fftfreq(len(z), 1)[f_ix] + ph = np.exp(-2j * np.pi * shift * f).astype(np.complex64) shifted = pb.fft.ifft(pb.fft.fft(z.data, axis=0) * ph, axis=0) shifted = shifted if np.iscomplexobj(z.data) else shifted.real - if z.start_time is None: - new_start = None - else: - new_start = z.start_time - t * z.dt + start, stop = 0, 0 + it = np.nditer(shift, flags=["multi_index"]) + for a in it: + if a < 0: + a = int(np.floor(a)) + ix = (np.s_[a:],) + it.multi_index + stop = min(stop, a) + else: + a = int(np.ceil(a)) + ix = (np.s_[:a],) + it.multi_index + start = max(start, a) - x = type(z).like(z, shifted, start_time=new_start) + shifted[ix] = 0 - if t >= 0: - i = np.int64(np.ceil(t)) - x = x[i:] - else: - i = np.int64(np.floor(t)) - x = x[:i] + x = type(z).like(z, shifted) + + if crop: + x = x[start:len(x) + stop] return x @@ -288,8 +309,8 @@ def freq_shift(z, /, shift): z : BasebandSignal Input signal. shift : Quantity - Shift amount in units of frequency. Should be a scalar or - have shape ``z.sample_shape[:n]`` for ``0 <= n``. + Shift amount in units of frequency. Should be a scalar or have + shape that such that axes with length more than 1 match ``z.sample_shape``. Returns ------- @@ -316,8 +337,13 @@ def freq_shift(z, /, shift): ix = (slice(None),) * shift.ndim + (None,) * (z.ndim - shift.ndim - 1) ft = (shift[ix] * z.dt).to_value(u.one) + if isinstance(z.data, da.Array): + n = da.arange(len(z), chunks=(-1,)) + else: + n = np.arange(len(z)) + ix = tuple(slice(None) if j == 0 else None for j in range(z.ndim)) - ph = np.exp(2j * np.pi * ft * np.arange(len(z))[ix]).astype(z.dtype) + ph = np.exp(2j * np.pi * ft * n[ix]).astype(z.dtype) x = np.fft.fftshift(pb.fft.fft(z.data * ph, axis=0), axes=(0,)) diff --git a/tests/test_transforms.py b/tests/test_transforms.py index 15528f1..13b076f 100644 --- a/tests/test_transforms.py +++ b/tests/test_transforms.py @@ -1,6 +1,5 @@ """Tests for core signal functions.""" -import math import pytest import itertools import numpy as np @@ -29,7 +28,7 @@ def impulse(N, t0): """Generate noisy impulse at t0, with given S/N.""" n = (np.arange(N) - N // 2) / N x = np.exp(-2j * np.pi * t0 * n) - return np.fft.ifft(np.fft.ifftshift(x)).astype(np.complex128) + return np.fft.ifft(np.fft.ifftshift(x, axes=(-1,))).astype(np.complex128) def sinusoid(N, f0): @@ -38,6 +37,12 @@ def sinusoid(N, f0): return np.exp(2j * np.pi * f0 * n).astype(np.complex128) +def noise(shape): + """Generate complex Gaussian noise.""" + f = np.random.default_rng().standard_normal + return (f(shape) + 1j * f(shape)) / np.sqrt(2) + + class TestConcatenate: @pytest.mark.parametrize("use_dask", [True, False]) def test_basic(self, use_dask): @@ -309,56 +314,83 @@ class TestTimeShift: @pytest.mark.parametrize("use_dask", [True, False]) @pytest.mark.parametrize("use_complex", [True, False]) @pytest.mark.parametrize("start_time", [Time.now(), None]) - def test_int_roll(self, use_dask, use_complex, start_time): - if use_dask: - f = da.random.standard_normal - else: - f = np.random.standard_normal + def test_int_scalar(self, use_dask, use_complex, start_time): + kw = dict(sample_rate=1 * u.kHz, start_time=start_time) - shape = (4096, 4, 2) + for shape in [(4096, 4, 2), (4096, 4), (4096,)]: + if use_complex: + z = pb.Signal(noise(shape), **kw) + else: + z = pb.Signal(noise(shape).real, **kw) - if use_complex: - x = (f(shape) + 1j * f(shape)).astype(np.complex128) - else: - x = f(shape).astype(np.float64) + if use_dask: + z = z.to_dask_array() - z = pb.Signal(x, sample_rate=1 * u.kHz, start_time=start_time) + for n in [-12, -7, -3, 0, 4, 8, 13]: + for s in [n, n * u.ms]: + y = pb.time_shift(z, s) - for n in [1, 10, 55]: - assert_equal_signals(z[:-n], pb.time_shift(z, n)) - assert_equal_signals(z[:-n], pb.time_shift(z, n * u.ms)) + if z.start_time is None: + assert y.start_time is None + else: + assert Time.isclose(z.start_time, y.start_time) - assert_equal_signals(z[n:], pb.time_shift(z, -n)) - assert_equal_signals(z[n:], pb.time_shift(z, -n * u.ms)) + if n < 0: + assert np.allclose(np.array(y[:n]), np.array(z[-n:]), atol=1E-6) + assert np.allclose(np.array(y[n:]), 0) + elif n > 0: + assert np.allclose(np.array(y[n:]), np.array(z[:-n]), atol=1E-6) + assert np.allclose(np.array(y[:n]), 0) + else: + assert np.allclose(np.array(y), np.array(z)) - assert_equal_signals(z, pb.time_shift(z, 0)) + @pytest.mark.parametrize("use_dask", [True, False]) + def test_advanced(self, use_dask): + N = 4096 - def test_subsample_roll(self): - N = 1024 + for shape in [(4096, 4, 2), (4096, 4), (4096,)]: + shifts = np.concatenate( + [ + np.random.uniform(-20, 20, (4,) + shape[1:]), + np.random.uniform(0, 20, (3,) + shape[1:]), + np.random.uniform(-20, 0, (3,) + shape[1:]) + ], + axis=0 + ) - t0 = Time.now() - for shift in [-20.4, -10.666, -4.99, 10.33, 2.9, 5.5]: - if shift < 0: - x = pb.Signal( - impulse(N, -shift), - sample_rate=1 * u.Hz, - start_time=t0 + shift * u.s, - ) - else: - x = pb.Signal( - impulse(N, np.ceil(shift) - shift), - sample_rate=1 * u.Hz, - start_time=t0 - (np.ceil(shift) - shift) * u.s, - ) + for shift in shifts: + x = impulse(N, 100 - shift[..., None]) + x = np.moveaxis(x, -1, 0) - y = pb.time_shift(x, shift) + z = pb.Signal(x, sample_rate=1 * u.kHz) - z = np.zeros_like(y.data) - z[0] = 1 + if use_dask: + z = z.to_dask_array() - assert np.allclose(np.asarray(y.data), z) - assert x.sample_rate == y.sample_rate - assert Time.isclose(t0, y.start_time) + y1 = pb.time_shift(z, shift, crop=False) + y2 = pb.time_shift(z, shift, crop=True) + + x = np.zeros_like(y1.data) + x[100] = 1.0 + + a = max(0, int(np.ceil(shift.max()))) + b = len(x) + min(0, int(np.floor(shift.min()))) + + assert np.allclose(np.asarray(y1), x) + assert np.allclose(np.asarray(y2), x[a:b]) + + def test_shape_errors(self): + x = pb.Signal(noise((4096, 4, 2)), sample_rate=1 * u.kHz) + + for shape in [(1,), (1, 2), (4,), (4, 1), (4, 2)]: + shift = np.random.uniform(-20, 20, shape) + _ = pb.time_shift(x, shift) + + for shape in [(2,), (5, 2), (4, 5), (1, 4), (2, 1)]: + shift = np.random.uniform(-20, 20, shape) + + with pytest.raises(ValueError): + _ = pb.time_shift(x, shift) class TestFreqShift: @@ -461,6 +493,20 @@ def test_zeroing(self, N): assert np.allclose(a[:s], 0) assert np.allclose(a[s:], 1) + def test_shape_errors(self): + x = pb.BasebandSignal(noise((4096, 4, 2)), sample_rate=1 * u.kHz, + center_freq=1 * u.MHz) + + for shape in [(1,), (1, 2), (4,), (4, 1), (4, 2)]: + shift = np.random.uniform(-20, 20, shape) * u.mHz + _ = pb.freq_shift(x, shift) + + for shape in [(2,), (5, 2), (4, 5), (1, 4), (2, 1)]: + shift = np.random.uniform(-20, 20, shape) * u.mHz + + with pytest.raises(ValueError): + _ = pb.freq_shift(x, shift) + class TestFastLen: def test_fast(self): From 26e1e1f4d56a32593b099ba7b7df67562586dfcf Mon Sep 17 00:00:00 2001 From: Nikhil Mahajan Date: Mon, 27 Feb 2023 02:12:53 -0500 Subject: [PATCH 3/3] Updated changelog. --- docs/changelog.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/changelog.rst b/docs/changelog.rst index 8e4a70d..fbd0ddb 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -7,4 +7,4 @@ Release Notes - Added a frequency-shift function ``pulsarbat.freq_shift`` (:pr:`52`) - Added a snippet function ``pulsarbat.snippet`` (:pr:`54`) - +- ``pulsarbat.time_shift`` can now support multiple shifts (:pr:`58`)