Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/PSP-fix' into main
Browse files Browse the repository at this point in the history
  • Loading branch information
emiddell committed Sep 7, 2024
2 parents bd68146 + c8069fa commit e802f72
Showing 1 changed file with 24 additions and 19 deletions.
43 changes: 24 additions & 19 deletions src/cedalion/sigproc/quality.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,14 +98,15 @@ def psp(
# the first sample in the window. Setting the stride size to the same value as the
# window length will result in non-overlapping windows.
windows = amp.rolling(time=nsamples).construct("window", stride=nsamples)

windows = windows.fillna(1e-6)
fs = amp.cd.sampling_rate

psp = np.zeros([len(windows["channel"]), len(windows["time"])])

# Vectorized signal extraction and correlation
sig = windows.transpose("channel", "time", "wavelength", "window").values
psp = np.zeros((sig.shape[0], sig.shape[1]))
lags = np.arange(-nsamples + 1, nsamples)

for w in range(sig.shape[1]): # loop over windows
sig_temp = sig[:,w,:,:]
Expand All @@ -118,26 +119,30 @@ def psp(
)

# FIXME assumes 2 wavelengths
norm_factor = [
np.sqrt(np.sum(sig_temp[ch, 0, :] ** 2) * np.sum(sig_temp[ch, 1, :] ** 2))
for ch in range(sig.shape[0])
]

corr /= np.tile(norm_factor, (corr.shape[1],1)).T

for ch in range(sig.shape[0]):
window = signal.windows.hamming(len(corr[ch,:]))
f, pxx = signal.periodogram(
corr[ch, :],
window=window,
nfft=len(corr[ch, :]),
fs=fs,
scaling="spectrum",
)
corr = corr /(nsamples - np.abs(lags))

nperseg = corr.shape[1]
window = np.hamming(nperseg)
window_seg = corr * window

psp[ch, w] = np.max(pxx)
fft_out = np.fft.rfft(window_seg, axis=1)
psd = (np.abs(fft_out) ** 2) / (fs * np.sum(window ** 2))
#freqs = np.fft.rfftfreq(nperseg, 1/fs)

# for ch in range(sig.shape[0]):
# window = signal.windows.hamming(len(corr[ch,:]))
# f, pxx = signal.welch(
# corr[ch, :],
# window=window,
# nfft=len(corr[ch, :]),
# fs=fs,
# scaling="density",
# )

psp[:, w] = np.max(psd, 1)

# keep dims channel and time

psp_xr = windows.isel(wavelength=0, window=0).drop_vars("wavelength").copy(data=psp)

# Apply threshold mask
Expand Down Expand Up @@ -200,7 +205,7 @@ def sci(amplitudes: NDTimeSeries, window_length: Quantity, sci_thresh: float):
def snr(amplitudes: cdt.NDTimeSeries, snr_thresh: float = 2.0):
"""Calculates signal-to-noise ratio for each channel and other dimension.
SNR is here the ratio of the average signal over time divided by its standard deviation.
SNR is the ratio of the average signal over time divided by its standard deviation.
Args:
amplitudes (:class:`NDTimeSeries`, (time, *)): the input time series
Expand Down

0 comments on commit e802f72

Please sign in to comment.