From 8cb6aac1393073ec5a7e4faacc493e379ea94c8c Mon Sep 17 00:00:00 2001 From: lauracarlton Date: Mon, 2 Sep 2024 11:59:48 -0400 Subject: [PATCH 1/2] fix to PSP metric to account for unbiased cross-correlation --- src/cedalion/sigproc/quality.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/src/cedalion/sigproc/quality.py b/src/cedalion/sigproc/quality.py index 3c3e58f..5639565 100644 --- a/src/cedalion/sigproc/quality.py +++ b/src/cedalion/sigproc/quality.py @@ -106,6 +106,7 @@ def psp( # Vectorized signal extraction and correlation sig = windows.transpose("channel", "time", "wavelength", "window").values psp = np.zeros((sig.shape[0], sig.shape[1])) + lags = np.arange(-nsamples + 1, nsamples) for w in range(sig.shape[1]): # loop over windows sig_temp = sig[:,w,:,:] @@ -118,21 +119,22 @@ def psp( ) # FIXME assumes 2 wavelengths - norm_factor = [ - np.sqrt(np.sum(sig_temp[ch, 0, :] ** 2) * np.sum(sig_temp[ch, 1, :] ** 2)) - for ch in range(sig.shape[0]) - ] + corr = corr /(nsamples - np.abs(lags)) + # norm_factor = [ + # np.sqrt(np.sum(sig_temp[ch, 0, :] ** 2) * np.sum(sig_temp[ch, 1, :] ** 2)) + # for ch in range(sig.shape[0]) + # ] - corr /= np.tile(norm_factor, (corr.shape[1],1)).T + # corr /= np.tile(norm_factor, (corr.shape[1],1)).T for ch in range(sig.shape[0]): window = signal.windows.hamming(len(corr[ch,:])) - f, pxx = signal.periodogram( + f, pxx = signal.welch( corr[ch, :], window=window, nfft=len(corr[ch, :]), fs=fs, - scaling="spectrum", + scaling="density", ) psp[ch, w] = np.max(pxx) From c8069fa679a22af8965d32fc504639f4d2e4a2a9 Mon Sep 17 00:00:00 2001 From: lauracarlton Date: Tue, 3 Sep 2024 13:55:19 -0400 Subject: [PATCH 2/2] try implementation using np.fft.rfft --- src/cedalion/sigproc/quality.py | 41 ++++++++++++++++++--------------- 1 file changed, 22 insertions(+), 19 deletions(-) diff --git a/src/cedalion/sigproc/quality.py b/src/cedalion/sigproc/quality.py index 5639565..9d958c4 100644 --- a/src/cedalion/sigproc/quality.py +++ b/src/cedalion/sigproc/quality.py @@ -98,11 +98,11 @@ def psp( # the first sample in the window. Setting the stride size to the same value as the # window length will result in non-overlapping windows. windows = amp.rolling(time=nsamples).construct("window", stride=nsamples) - + windows = windows.fillna(1e-6) fs = amp.cd.sampling_rate psp = np.zeros([len(windows["channel"]), len(windows["time"])]) - + # Vectorized signal extraction and correlation sig = windows.transpose("channel", "time", "wavelength", "window").values psp = np.zeros((sig.shape[0], sig.shape[1])) @@ -120,26 +120,29 @@ def psp( # FIXME assumes 2 wavelengths corr = corr /(nsamples - np.abs(lags)) - # norm_factor = [ - # np.sqrt(np.sum(sig_temp[ch, 0, :] ** 2) * np.sum(sig_temp[ch, 1, :] ** 2)) - # for ch in range(sig.shape[0]) - # ] - - # corr /= np.tile(norm_factor, (corr.shape[1],1)).T - - for ch in range(sig.shape[0]): - window = signal.windows.hamming(len(corr[ch,:])) - f, pxx = signal.welch( - corr[ch, :], - window=window, - nfft=len(corr[ch, :]), - fs=fs, - scaling="density", - ) - psp[ch, w] = np.max(pxx) + nperseg = corr.shape[1] + window = np.hamming(nperseg) + window_seg = corr * window + + fft_out = np.fft.rfft(window_seg, axis=1) + psd = (np.abs(fft_out) ** 2) / (fs * np.sum(window ** 2)) + freqs = np.fft.rfftfreq(nperseg, 1/fs) + + # for ch in range(sig.shape[0]): + # window = signal.windows.hamming(len(corr[ch,:])) + # f, pxx = signal.welch( + # corr[ch, :], + # window=window, + # nfft=len(corr[ch, :]), + # fs=fs, + # scaling="density", + # ) + + psp[:, w] = np.max(psd, 1) # keep dims channel and time + psp_xr = windows.isel(wavelength=0, window=0).drop_vars("wavelength").copy(data=psp) # Apply threshold mask