Skip to content

Commit

Permalink
Merge pull request #751 from picarro-yren/yure/pre-compensation
Browse files Browse the repository at this point in the history
Fix bugs in broadbean element and sequence
  • Loading branch information
jenshnielsen authored Dec 19, 2023
2 parents 1c1116b + fe7f74f commit a6b45e6
Show file tree
Hide file tree
Showing 4 changed files with 65 additions and 9 deletions.
16 changes: 8 additions & 8 deletions broadbean/element.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,14 +216,14 @@ def getArrays(self, includetime: bool = False) -> dict[int, dict[str, np.ndarray

outdict = {}
for channel, signal in self._data.items():
if 'array' in signal.keys():
outdict[channel] = signal['array']
if includetime and 'time' not in signal['array'].keys():
N = len(signal['array']['wfm'])
dur = N/signal['SR']
outdict[channel]['array']['time'] = np.linspace(0, dur, N)
elif 'blueprint' in signal.keys():
bp = signal['blueprint']
if "array" in signal.keys():
outdict[channel] = signal["array"]
if includetime and "time" not in signal["array"].keys():
N = len(signal["array"]["wfm"])
dur = N / signal["SR"]
outdict[channel]["time"] = np.linspace(0, dur, N)
elif "blueprint" in signal.keys():
bp = signal["blueprint"]
durs = bp.durations
SR = bp.SR
forged_bp = _subelementBuilder(bp, SR, durs)
Expand Down
2 changes: 1 addition & 1 deletion broadbean/sequence.py
Original file line number Diff line number Diff line change
Expand Up @@ -873,7 +873,7 @@ def _prepareForOutputting(self) -> list[dict[int, np.ndarray]]:
element.addFlags(chan, flags)

else:
arrays = element[chan]['array']
arrays = element._data[chan]["array"]
for name, arr in arrays.items():
pre_wait = np.zeros(int(delay/self.SR))
post_wait = np.zeros(int((maxdelay-delay)/self.SR))
Expand Down
5 changes: 5 additions & 0 deletions broadbean/tests/test_element.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,11 @@ def test_addArray():

elem = Element()
elem.addArray(1, wfm, SR, m1=m1, m2=m2)
output = elem.getArrays(includetime=True)
assert np.all(output[1]["m1"] == m1)
assert np.all(output[1]["wfm"] == wfm)
assert np.all(output[1]["time"] == np.linspace(0, N / SR, N))

elem.addArray('2', wfm, SR, m1=m1)
elem.addArray('readout_channel', wfm, SR, m2=m2)

Expand Down
51 changes: 51 additions & 0 deletions broadbean/tests/test_ripasso.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
# Test suite for the Ripasso module of the broadband package

import numpy as np
import pytest

import broadbean as bb
from broadbean.ripasso import applyInverseRCFilter, applyRCFilter

NUM_POINTS = 2500


@pytest.fixture
def squarewave():
periods = 5
periods = int(periods)
array = np.zeros(NUM_POINTS)

for n in range(periods):
array[
int(n * NUM_POINTS / periods) : int((2 * n + 1) * NUM_POINTS / 2 / periods)
] = 1

return array


def test_rc_filter(squarewave):
# Test RC filter and pre-compensation of filter
# Check that after filtering and pre-compensation, processed signal differs from the original one by a constant value
SR = int(10e3)
for filter_type in ["HP", "LP"]:
signal1_filtered = applyRCFilter(squarewave, SR, filter_type, f_cut=12, order=1)
signal1_filtered2 = applyInverseRCFilter(
signal1_filtered, SR, filter_type, f_cut=12, order=1
)
difference = np.abs(np.diff(squarewave - signal1_filtered2))
assert np.all(np.isclose(difference, 0))


def test_output_seqx_file(squarewave):
SR = int(10e3)
elem1 = bb.Element()
signal1_filtered = applyInverseRCFilter(squarewave, SR, "HP", f_cut=12, order=1)
elem1.addArray(
1, signal1_filtered, SR, m1=np.zeros(NUM_POINTS), m2=np.zeros(NUM_POINTS)
)
seq1 = bb.Sequence()
seq1.addElement(1, elem1)
seq1.setSR(elem1.SR)
seq1.setChannelAmplitude(1, 2.5)
seqx_input = seq1.outputForSEQXFile()
assert np.all(seqx_input[5][0][0][0] == signal1_filtered)

0 comments on commit a6b45e6

Please sign in to comment.