Skip to content

Commit

Permalink
[ENHANCEMENT] Better latency calculation via averaging (#485)
Browse files Browse the repository at this point in the history
* Take impulse response replicates' mean to denoise for latency calculation

* Reduce safety factor to 1

* Fix test
  • Loading branch information
sdatkinson authored Oct 7, 2024
1 parent f0e27f4 commit 28f9e2e
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 37 deletions.
70 changes: 36 additions & 34 deletions nam/train/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,7 +333,7 @@ class _DataInfo(BaseModel):

_DELAY_CALIBRATION_ABS_THRESHOLD = 0.0003
_DELAY_CALIBRATION_REL_THRESHOLD = 0.001
_DELAY_CALIBRATION_SAFETY_FACTOR = 4
_DELAY_CALIBRATION_SAFETY_FACTOR = 1 # Might be able to make this zero...


def _warn_lookaheads(indices: Sequence[int]) -> str:
Expand Down Expand Up @@ -391,7 +391,7 @@ def report_any_latency_warnings(

lookahead = 1_000
lookback = 10_000
# Calibrate the trigger:
# Calibrate the level for the trigger:
y = y[data_info.first_blips_start : data_info.first_blips_start + data_info.t_blips]
background_level = np.max(
np.abs(
Expand All @@ -407,51 +407,53 @@ def report_any_latency_warnings(
(1.0 + rel_threshold) * background_level,
)

delays = []
y_scans = []
for blip_index, i_abs in enumerate(data_info.blip_locations[0], 1):
# Relative to start of the data
i_rel = i_abs - data_info.first_blips_start
start_looking = i_rel - lookahead
stop_looking = i_rel + lookback
y_scan = y[start_looking:stop_looking]
triggered = np.where(np.abs(y_scan) > trigger_threshold)[0]
if len(triggered) == 0:
msg = (
f"No response activated the trigger in response to blip "
f"{blip_index}. Is something wrong with the reamp?"
)
print(msg)
print("SHARE THIS PLOT IF YOU ASK FOR HELP")
plt.figure()
plt.plot(np.arange(-lookahead, lookback), y_scan, label="Signal")
plt.axvline(x=0, color="C1", linestyle="--", label="Trigger")
plt.axhline(
y=-trigger_threshold, color="k", linestyle="--", label="Threshold"
)
plt.axhline(y=trigger_threshold, color="k", linestyle="--")
plt.xlim((-lookahead, lookback))
plt.xlabel("Samples")
plt.ylabel("Response")
plt.legend()
plt.show()
raise RuntimeError(msg)
else:
j = triggered[0]
delays.append(j + start_looking - i_rel)
y_scans.append(y[start_looking:stop_looking])
y_scan_average = np.mean(np.stack(y_scans), axis=0)
triggered = np.where(np.abs(y_scan_average) > trigger_threshold)[0]
if len(triggered) == 0:
msg = (
"No response activated the trigger in response to input spikes. "
"Is something wrong with the reamp?"
)
print(msg)
print("SHARE THIS PLOT IF YOU ASK FOR HELP")
plt.figure()
plt.plot(np.arange(-lookahead, lookback), y_scan_average, color="C0", label="Signal average")
for y_scan in y_scans:
plt.plot(np.arange(-lookahead, lookback), y_scan, color="C0", alpha=0.2)
plt.axvline(x=0, color="C1", linestyle="--", label="Trigger")
plt.axhline(
y=-trigger_threshold, color="k", linestyle="--", label="Threshold"
)
plt.axhline(y=trigger_threshold, color="k", linestyle="--")
plt.xlim((-lookahead, lookback))
plt.xlabel("Samples")
plt.ylabel("Response")
plt.legend()
plt.title("SHARE THIS PLOT IF YOU ASK FOR HELP")
plt.show()
raise RuntimeError(msg)
else:
j = triggered[0]
delay = j + start_looking - i_rel

print("Delays:")
for i_rel, d in enumerate(delays, 1):
print(f" Blip {i_rel:2d}: {d}")
warnings = report_any_latency_warnings(delays)
print(f"Delay based on average is {delay}")
warnings = report_any_latency_warnings([delay])

delay_post_safety_factor = int(np.min(delays)) - safety_factor
delay_post_safety_factor = delay - safety_factor
print(
f"After aplying safety factor of {safety_factor}, the final delay is "
f"{delay_post_safety_factor}"
)
return metadata.LatencyCalibration(
algorithm_version=1,
delays=delays,
delays=[delay],
safety_factor=safety_factor,
recommended=delay_post_safety_factor,
warnings=warnings,
Expand Down
6 changes: 3 additions & 3 deletions tests/test_nam/test_train/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,9 +161,9 @@ def __exit__(self, *args):
with Capturing() as output:
self._calibrate_delay(y)
# `[0]` -- Only look in the first set of blip locations
expected_warning = core._warn_lookaheads(
list(range(1, len(self._data_info.blip_locations[0]) + 1))
)
# With #485, we average them all together so there's only one index.
# TODO clean this up.
expected_warning = core._warn_lookaheads([1]) # "Blip 1"
assert any(o == expected_warning for o in output), output


Expand Down

1 comment on commit 28f9e2e

@38github
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I really like all these tweaks you have done lately!

Please sign in to comment.