From 28f9e2e94c2b50101c2a2a24cc0cf8a7693f2267 Mon Sep 17 00:00:00 2001 From: Steven Atkinson Date: Sun, 6 Oct 2024 21:11:19 -0700 Subject: [PATCH] [ENHANCEMENT] Better latency calculation via averaging (#485) * Take impulse response replicates' mean to denoise for latency calculation * Reduce safety factor to 1 * Fix test --- nam/train/core.py | 70 +++++++++++++------------- tests/test_nam/test_train/test_core.py | 6 +-- 2 files changed, 39 insertions(+), 37 deletions(-) diff --git a/nam/train/core.py b/nam/train/core.py index 73b773e..b719a3b 100644 --- a/nam/train/core.py +++ b/nam/train/core.py @@ -333,7 +333,7 @@ class _DataInfo(BaseModel): _DELAY_CALIBRATION_ABS_THRESHOLD = 0.0003 _DELAY_CALIBRATION_REL_THRESHOLD = 0.001 -_DELAY_CALIBRATION_SAFETY_FACTOR = 4 +_DELAY_CALIBRATION_SAFETY_FACTOR = 1 # Might be able to make this zero... def _warn_lookaheads(indices: Sequence[int]) -> str: @@ -391,7 +391,7 @@ def report_any_latency_warnings( lookahead = 1_000 lookback = 10_000 - # Calibrate the trigger: + # Calibrate the level for the trigger: y = y[data_info.first_blips_start : data_info.first_blips_start + data_info.t_blips] background_level = np.max( np.abs( @@ -407,51 +407,53 @@ def report_any_latency_warnings( (1.0 + rel_threshold) * background_level, ) - delays = [] + y_scans = [] for blip_index, i_abs in enumerate(data_info.blip_locations[0], 1): # Relative to start of the data i_rel = i_abs - data_info.first_blips_start start_looking = i_rel - lookahead stop_looking = i_rel + lookback - y_scan = y[start_looking:stop_looking] - triggered = np.where(np.abs(y_scan) > trigger_threshold)[0] - if len(triggered) == 0: - msg = ( - f"No response activated the trigger in response to blip " - f"{blip_index}. Is something wrong with the reamp?" - ) - print(msg) - print("SHARE THIS PLOT IF YOU ASK FOR HELP") - plt.figure() - plt.plot(np.arange(-lookahead, lookback), y_scan, label="Signal") - plt.axvline(x=0, color="C1", linestyle="--", label="Trigger") - plt.axhline( - y=-trigger_threshold, color="k", linestyle="--", label="Threshold" - ) - plt.axhline(y=trigger_threshold, color="k", linestyle="--") - plt.xlim((-lookahead, lookback)) - plt.xlabel("Samples") - plt.ylabel("Response") - plt.legend() - plt.show() - raise RuntimeError(msg) - else: - j = triggered[0] - delays.append(j + start_looking - i_rel) + y_scans.append(y[start_looking:stop_looking]) + y_scan_average = np.mean(np.stack(y_scans), axis=0) + triggered = np.where(np.abs(y_scan_average) > trigger_threshold)[0] + if len(triggered) == 0: + msg = ( + "No response activated the trigger in response to input spikes. " + "Is something wrong with the reamp?" + ) + print(msg) + print("SHARE THIS PLOT IF YOU ASK FOR HELP") + plt.figure() + plt.plot(np.arange(-lookahead, lookback), y_scan_average, color="C0", label="Signal average") + for y_scan in y_scans: + plt.plot(np.arange(-lookahead, lookback), y_scan, color="C0", alpha=0.2) + plt.axvline(x=0, color="C1", linestyle="--", label="Trigger") + plt.axhline( + y=-trigger_threshold, color="k", linestyle="--", label="Threshold" + ) + plt.axhline(y=trigger_threshold, color="k", linestyle="--") + plt.xlim((-lookahead, lookback)) + plt.xlabel("Samples") + plt.ylabel("Response") + plt.legend() + plt.title("SHARE THIS PLOT IF YOU ASK FOR HELP") + plt.show() + raise RuntimeError(msg) + else: + j = triggered[0] + delay = j + start_looking - i_rel - print("Delays:") - for i_rel, d in enumerate(delays, 1): - print(f" Blip {i_rel:2d}: {d}") - warnings = report_any_latency_warnings(delays) + print(f"Delay based on average is {delay}") + warnings = report_any_latency_warnings([delay]) - delay_post_safety_factor = int(np.min(delays)) - safety_factor + delay_post_safety_factor = delay - safety_factor print( f"After aplying safety factor of {safety_factor}, the final delay is " f"{delay_post_safety_factor}" ) return metadata.LatencyCalibration( algorithm_version=1, - delays=delays, + delays=[delay], safety_factor=safety_factor, recommended=delay_post_safety_factor, warnings=warnings, diff --git a/tests/test_nam/test_train/test_core.py b/tests/test_nam/test_train/test_core.py index 21ac709..9544312 100644 --- a/tests/test_nam/test_train/test_core.py +++ b/tests/test_nam/test_train/test_core.py @@ -161,9 +161,9 @@ def __exit__(self, *args): with Capturing() as output: self._calibrate_delay(y) # `[0]` -- Only look in the first set of blip locations - expected_warning = core._warn_lookaheads( - list(range(1, len(self._data_info.blip_locations[0]) + 1)) - ) + # With #485, we average them all together so there's only one index. + # TODO clean this up. + expected_warning = core._warn_lookaheads([1]) # "Blip 1" assert any(o == expected_warning for o in output), output