Skip to content

Commit

Permalink
Cfs refactor (#512)
Browse files Browse the repository at this point in the history
* Move CFS functionality to subclass of AveragingPeriod
  • Loading branch information
RemingtonRohel authored Nov 13, 2024
1 parent a44e417 commit e288a7e
Show file tree
Hide file tree
Showing 8 changed files with 243 additions and 230 deletions.
6 changes: 3 additions & 3 deletions src/data_write.py
Original file line number Diff line number Diff line change
Expand Up @@ -349,7 +349,7 @@ def output_data(
log.info(
"wrote record",
write_time=write_time * 1e3,
time_units="ms",
time_unit="ms",
dataset_name=self.timestamp,
)

Expand Down Expand Up @@ -670,7 +670,7 @@ def main():
expected_type=int,
)
log.debug(
"Received CFS sequence, increasing expected_sqn_num",
"received CFS sequence, increasing expected_sqn_num",
cfs_sqn_num=cfs_sqn_num,
)
cfs_nums.append(cfs_sqn_num)
Expand Down Expand Up @@ -750,7 +750,7 @@ def main():
log.info(
f"parsed sequence {pd.sequence_num}",
parse_time=parse_time * 1e3,
time_units="ms",
time_unit="ms",
slice_ids=[dset.slice_id for dset in pd.output_datasets],
)

Expand Down
7 changes: 2 additions & 5 deletions src/experiment_prototype/experiment_slice.py
Original file line number Diff line number Diff line change
Expand Up @@ -624,15 +624,12 @@ def check_rx_antenna_pattern(cls, rx_antenna_pattern, values):
rx_antenna_pattern(
values["beam_angle"],
values["freq"],
options.main_antenna_count,
options.main_antenna_spacing,
options.main_antenna_locations,
),
rx_antenna_pattern(
values["beam_angle"],
values["freq"],
options.intf_antenna_count,
options.intf_antenna_spacing,
offset=-100,
options.intf_antenna_locations,
),
]
for index in range(0, len(antenna_pattern)):
Expand Down
214 changes: 149 additions & 65 deletions src/experiment_prototype/interface_classes/averaging_periods.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
:author: Marci Detwiller
"""

import datetime

# built-in
import inspect
from pathlib import Path
Expand Down Expand Up @@ -50,7 +52,7 @@ class AveragingPeriod(InterfaceClassBase):
"""
Set up the AveragingPeriods.
An averagingperiod contains sequences and integrates one or multiple pulse sequences together in
An averaging period contains sequences and integrates one or multiple pulse sequences together in
a given time frame or in a given number of averages, if that is the preferred limiter.
**The unique members of the averagingperiod are (not a member of the interfaceclassbase):**
Expand All @@ -61,10 +63,6 @@ class AveragingPeriod(InterfaceClassBase):
slice_to_beamdir
passed in by the scan that this AveragingPeriod instance is contained in. A dictionary of
slice: beamdir(s) for all slices contained in this aveperiod.
cfs_flag
Boolean, True if clrfrqsearch should be performed.
cfs_range
The range of frequency to search if cfs_flag is True. Otherwise empty.
intt
The priority limitation. The time limit (ms) at which time the aveperiod will end. If None,
we will use intn to end the aveperiod (a number of sequences).
Expand Down Expand Up @@ -98,29 +96,6 @@ def __init__(
self.slice_to_beamorder = slice_to_beamorder_dict
self.slice_to_beamdir = slice_to_beamdir_dict

# Metadata for an AveragingPeriod: clear frequency search, integration time, number of averages goal
self.cfs_flag = False
self.cfs_always_run = False
self.cfs_sequence = None
self.cfs_slice_ids = []
self.cfs_scan_order = []
self.cfs_stable_time = 0
self.cfs_pwr_threshold = 0
self.cfs_fft_n = 0
# there may be multiple slices in this averaging period at different frequencies so
# we may have to search multiple ranges.
self.cfs_range = []
for slice_id in self.slice_ids:
if self.slice_dict[slice_id].cfs_flag:
self.cfs_stable_time = self.slice_dict[slice_id].cfs_stable_time
self.cfs_pwr_threshold = self.slice_dict[slice_id].cfs_pwr_threshold
self.cfs_fft_n = self.slice_dict[slice_id].cfs_fft_n
self.cfs_flag = True
self.cfs_slice_ids.append(slice_id)
self.cfs_range.append(self.slice_dict[slice_id].cfs_range)
if self.slice_dict[slice_id].cfs_always_run:
self.cfs_always_run = True

self.intt = self.slice_dict[self.slice_ids[0]].intt
self.intn = self.slice_dict[self.slice_ids[0]].intn
self.txctrfreq = self.slice_dict[self.slice_ids[0]].txctrfreq
Expand Down Expand Up @@ -166,43 +141,16 @@ def __init__(
" interfaced and do not have the same rxctrfreq"
)
raise ExperimentException(errmsg)
for slice_id in self.cfs_slice_ids:
if self.slice_dict[slice_id].cfs_pwr_threshold != self.cfs_pwr_threshold:
errmsg = (
f"Slices {self.slice_ids[0]} and {slice_id} are SEQUENCE or CONCURRENT"
" interfaced and do not have the same cfs_power_threshold"
)
raise ExperimentException(errmsg)
if self.slice_dict[slice_id].cfs_fft_n != self.cfs_fft_n:
errmsg = (
f"Slices {self.slice_ids[0]} and {slice_id} are SEQUENCE or CONCURRENT"
" interfaced and do not have the same cfs_fft_n"
)
raise ExperimentException(errmsg)
if self.slice_dict[slice_id].cfs_stable_time != self.cfs_stable_time:
errmsg = (
f"Slices {self.slice_ids[0]} and {slice_id} are SEQUENCE or CONCURRENT"
" interfaced and do not have the same cfs_stable_time"
)
raise ExperimentException(errmsg)

self.num_beams_in_scan = len(self.slice_dict[self.slice_ids[0]].rx_beam_order)

# NOTE: Do not need beam information inside the AveragingPeriod, this is in Scan.

if self.cfs_flag:
self.build_cfs_sequence()

# Determine how this averaging period is made by separating out the SEQUENCE interfaced.
self.nested_slice_list = self.get_nested_slice_ids()
self.sequences = []

self.cfs_sequences = []
for params in self.prep_for_nested_interface_class():
new_sequence = Sequence(*params)
if new_sequence.cfs_flag:
self.cfs_sequences.append(new_sequence)
self.sequences.append(new_sequence)
self.sequences.append(Sequence(*params))

self.one_pulse_only = False

Expand Down Expand Up @@ -242,7 +190,111 @@ def set_beamdirdict(self, beamiter):

return slice_to_beamdir_dict

def select_cfs_freqs(self, cfs_spectrum):

class CFSAveragingPeriod(AveragingPeriod):
"""
A variation of AveragingPeriod that conducts a clear frequency search to determine the frequency to use
for some or all of the slices within the averaging period.
"""

def __init__(
self,
ave_keys,
ave_slice_dict,
ave_interface,
transmit_metadata,
slice_to_beamorder_dict,
slice_to_beamdir_dict,
):
super().__init__(
ave_keys,
ave_slice_dict,
ave_interface,
transmit_metadata,
slice_to_beamorder_dict,
slice_to_beamdir_dict,
)

# Metadata for an AveragingPeriod: clear frequency search, integration time, number of averages goal
self.cfs_always_run = False
self.cfs_sequence = None
self.cfs_slice_ids = []
self.cfs_scan_order = []
self.cfs_stable_time = 0
self.cfs_pwr_threshold = 0
self.cfs_fft_n = 0

# {slice_id : np.ndarray of shape [num_freqs]}
self.cfs_freq = dict()

# {slice_id : [np.ndarray of shape [num_freqs]] * num_beams}
self.cfs_mags = dict()

# {slice_id : list of [lower_freq, upper_freq]}
self.cfs_range = dict()

# {slice_id : [np.ndarray of shape [num_freqs]] * num_beams}
self.cfs_masks = dict()

# [datetime] * num_beams
self.last_cfs_set_time = list()

# {slice_id : [float] * num_beams}
self.beam_frequency = dict()

# {slice_id : [bool] * num_beams}
self.set_new_freq = dict()

# there may be multiple slices in this averaging period at different frequencies so
# we may have to search multiple ranges.
for slice_id in self.slice_ids:
if self.slice_dict[slice_id].cfs_flag:
self.cfs_stable_time = self.slice_dict[slice_id].cfs_stable_time
self.cfs_pwr_threshold = self.slice_dict[slice_id].cfs_pwr_threshold
self.cfs_fft_n = self.slice_dict[slice_id].cfs_fft_n
self.cfs_flag = True
self.cfs_slice_ids.append(slice_id)
self.cfs_freq[slice_id] = None
self.cfs_mags[slice_id] = [None] * self.num_beams_in_scan
self.cfs_masks[slice_id] = [None] * self.num_beams_in_scan
self.beam_frequency[slice_id] = [None] * self.num_beams_in_scan
self.cfs_range[slice_id] = self.slice_dict[slice_id].cfs_range
self.set_new_freq[slice_id] = [True] * self.num_beams_in_scan
if self.slice_dict[slice_id].cfs_always_run:
self.cfs_always_run = True

for slice_id in self.cfs_slice_ids:
if self.slice_dict[slice_id].cfs_pwr_threshold != self.cfs_pwr_threshold:
errmsg = (
f"Slices {self.slice_ids[0]} and {slice_id} are SEQUENCE or CONCURRENT"
" interfaced and do not have the same cfs_power_threshold"
)
raise ExperimentException(errmsg)
if self.slice_dict[slice_id].cfs_fft_n != self.cfs_fft_n:
errmsg = (
f"Slices {self.slice_ids[0]} and {slice_id} are SEQUENCE or CONCURRENT"
" interfaced and do not have the same cfs_fft_n"
)
raise ExperimentException(errmsg)
if self.slice_dict[slice_id].cfs_stable_time != self.cfs_stable_time:
errmsg = (
f"Slices {self.slice_ids[0]} and {slice_id} are SEQUENCE or CONCURRENT"
" interfaced and do not have the same cfs_stable_time"
)
raise ExperimentException(errmsg)

# Set to a time in the past that is guaranteed to trigger a clear frequency search on the
# first averaging period run
self.last_cfs_set_time = [
datetime.datetime.utcnow()
- datetime.timedelta(seconds=self.cfs_stable_time)
] * len(self.slice_dict[self.slice_ids[0]].rx_beam_order)

self.build_cfs_sequence()

self.cfs_sequences = [sqn for sqn in self.sequences if sqn.cfs_flag]

def select_cfs_freqs(self, cfs_packet):
"""
Accepts the analysis results of the clear frequency search and uses the passed frequencies and powers
to determine what frequencies to set each clear frequency search slice to.
Expand All @@ -263,11 +315,11 @@ def select_cfs_freqs(self, cfs_spectrum):
* Builds each CFS sequence
* Return the frequency masks
:param cfs_spectrum: Analyzed CFS sequence data
:type cfs_spectrum: ProcessedSequenceMessage
:param cfs_packet: Analyzed CFS sequence data
:type cfs_packet: ProcessedSequenceMessage
"""
cfs_freq_hz = np.array(cfs_spectrum.cfs_freq) # at baseband
cfs_data = [dset.cfs_data for dset in cfs_spectrum.output_datasets]
cfs_freq_hz = np.array(cfs_packet.cfs_freq) # at baseband
cfs_data = [dset.cfs_data for dset in cfs_packet.output_datasets]
# Sort measured frequencies based on measured power at each freq
slice_masks = dict()
slice_used_freqs = dict()
Expand Down Expand Up @@ -356,6 +408,7 @@ def select_cfs_freqs(self, cfs_spectrum):
slice_masks[slice_id] = mask
ind = np.argmin(cfs_data[i][mask])
cfs_set_freq[slice_id] = int(np.round(shifted_cfs_khz[ind]))
self.beam_frequency[slice_id][self.beam_iter] = cfs_set_freq[slice_id]

for sqn in self.cfs_sequences:
if slice_id in sqn.slice_ids:
Expand All @@ -370,14 +423,12 @@ def select_cfs_freqs(self, cfs_spectrum):
)
# Set cfs slice frequency and add frequency to used_freqs for all other concurrent slices

self.update_cfs_freqs(cfs_set_freq)

return slice_masks, cfs_set_freq

def update_cfs_freqs(self, cfs_set_freq):
def update_cfs_freqs(self):
for i, slice_id in enumerate(self.cfs_slice_ids):
slice_obj = self.slice_dict[slice_id]
slice_obj.freq = cfs_set_freq[slice_id]
slice_obj.freq = self.beam_frequency[slice_id][self.beam_iter]
log.verbose(
"selecting cfs slice freq",
slice_id=slice_obj.slice_id,
Expand All @@ -386,6 +437,39 @@ def update_cfs_freqs(self, cfs_set_freq):
for sequence in self.cfs_sequences:
sequence.build_sequence_pulses()

def check_update_freq(self, cfs_spectrum, cfs_slices, threshold, beam_iter):
"""
Checks if any scanned frequencies have power levels that
exceed the current power of each cfs slice based on a threshold
:params cfs_packet: Results of the CFS analysis
:type cfs_spectrum: OutputDataset dataclass from message_formats.py
:params cfs_slices: Slice ids of each cfs slice to be checked
:type cfs_slices: list
:params threshold: Power threshold (dB) used in check
:type threshold: float
:params beam_iter: current beam index
:type beam_iter: int
"""
cfs_data = [dset.cfs_data for dset in cfs_spectrum.output_datasets]
for i, slice_id in enumerate(cfs_slices):
# Shift the current frequency down to baseband and then use the
# result to determine the index in the measured frequency
# spectrum that the current frequency is from
shifted_frequency = self.beam_frequency[slice_id][beam_iter] - int(
sum(self.cfs_range[slice_id]) / 2
)
idx = (
np.abs(np.asarray(self.cfs_freq) - shifted_frequency * 1000)
).argmin()

# calculate the ratio of the current freq power over all other freqs
pwr_ratio = cfs_data[i][idx] - np.asarray(
cfs_data[i][self.cfs_masks[slice_id][beam_iter]]
)
if any(pwr_ratio > threshold):
self.set_new_freq[slice_id][beam_iter] = True

def build_cfs_sequence(self):
"""
Builds an empty rx only pulse sequence to collect clear frequency search data
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,7 @@ def get_nested_slice_ids(self):
"AveragingPeriod": [
"CONCURRENT"
], # Combine everything CONCURRENT interfaced
"CFSAveragingPeriod": ["CONCURRENT"], # Same as AveragingPeriod
"Sequence": [], # All slices in a Sequence are already CONCURRENT and should be combined already
}

Expand Down
10 changes: 8 additions & 2 deletions src/experiment_prototype/interface_classes/scans.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,10 @@
import structlog

# local
from experiment_prototype.interface_classes.averaging_periods import AveragingPeriod
from experiment_prototype.interface_classes.averaging_periods import (
AveragingPeriod,
CFSAveragingPeriod,
)
from experiment_prototype.interface_classes.interface_class_base import (
InterfaceClassBase,
)
Expand Down Expand Up @@ -76,7 +79,10 @@ def __init__(self, scan_keys, scan_slice_dict, scan_interface, transmit_metadata
self.nested_slice_list = self.get_nested_slice_ids()

for params in self.prep_for_nested_interface_class():
self.aveperiods.append(AveragingPeriod(*params))
if any([s.cfs_flag for s in params[1].values()]):
self.aveperiods.append(CFSAveragingPeriod(*params))
else:
self.aveperiods.append(AveragingPeriod(*params))

# determine how many beams in scan:
num_unique_aveperiods = 0
Expand Down
Loading

0 comments on commit e288a7e

Please sign in to comment.