From 3c41e6432707ef8701b0e21fe6d64b444a9de063 Mon Sep 17 00:00:00 2001 From: vilim Date: Sun, 6 Mar 2022 19:00:00 +0100 Subject: [PATCH] progress on estimators --- stytra/collectors/accumulators.py | 44 +++++++++++----------- stytra/experiments/tracking_experiments.py | 8 +++- stytra/stimulation/estimator_process.py | 16 +++++++- stytra/stimulation/estimators.py | 4 ++ 4 files changed, 48 insertions(+), 24 deletions(-) diff --git a/stytra/collectors/accumulators.py b/stytra/collectors/accumulators.py index 90254326..a25e3329 100644 --- a/stytra/collectors/accumulators.py +++ b/stytra/collectors/accumulators.py @@ -14,19 +14,35 @@ class Accumulator(QObject): - def __init__(self, experiment, name="", max_history_if_not_running=1000): + def __init__(self, name="", max_trimmed_len=1000, trim = False): super().__init__() self.name = name - #self.exp = experiment self.stored_data = [] self.times = [] - self.max_history_if_not_running = max_history_if_not_running + self.max_trimmed_len = max_trimmed_len + self._trim = trim # + + @property + def trim(self) -> bool: + return self._trim + + def trim_data(self): + if self.trim and len(self.times) > self.max_trimmed_len * 1.5: + self.times[: -self.max_trimmed_len] = [] + self.stored_data[: -self.max_trimmed_len] = [] + + @property + def t0(self) -> float: + raise NotImplementedError + + def is_empty(self) -> bool: + return len(self.stored_data) == 0 class DataFrameAccumulator(Accumulator): """Abstract class for accumulating streams of data. - It is use to save or plot in real time data from stimulus logs or + It is used to save or plot in real time data from stimulus logs or behavior tracking. Data is stored in a list in the stored_data attribute. @@ -134,14 +150,6 @@ def reset(self, monitored_headers=None): self._header_dict = None - def trim_data(self): - if ( - not self.exp.protocol_runner.running - and len(self.times) > self.max_history_if_not_running * 1.5 - ): - self.times[: -self.max_history_if_not_running] = [] - self.stored_data[: -self.max_history_if_not_running] = [] - def get_fps(self): """ """ try: @@ -229,9 +237,6 @@ def save(self, path, format="csv"): saved_filename = save_df(df, path, format) return basename(saved_filename) - def is_empty(self): - return len(self.stored_data) == 0 - class QueueDataAccumulator(DataFrameAccumulator): """General class for retrieving data from a Queue. @@ -248,9 +253,9 @@ class QueueDataAccumulator(DataFrameAccumulator): data_queue : NamedTupleQueue queue from witch to retrieve data. output_queue:Optional[NamedTupleQueue] - an optinal queue to forward the data to + an optional queue to forward the data to header_list : list of str - headers for the data to stored. + headers for the data to be stored. """ @@ -307,11 +312,6 @@ def __init__(self, *args, goal_framerate=None, **kwargs): super().__init__(*args, **kwargs) self.goal_framerate = goal_framerate - def trim_data(self): - if len(self.times) > self.max_history_if_not_running * 1.5: - self.times[: -self.max_history_if_not_running] = [] - self.stored_data[: -self.max_history_if_not_running] = [] - def reset(self): self.times = [] self.stored_data = [] diff --git a/stytra/experiments/tracking_experiments.py b/stytra/experiments/tracking_experiments.py index bf9649ef..16ed83ae 100644 --- a/stytra/experiments/tracking_experiments.py +++ b/stytra/experiments/tracking_experiments.py @@ -246,9 +246,15 @@ def __init__(self, *args, tracking, recording=None, second_output_queue=None, ** if est is not None: self.estimator_process = EstimatorProcess(est_type, self.tracking_output_queue, self.finished_sig) self.estimator_log = EstimatorLog(experiment=self) - self.estimator = est(self.acc_tracking, experiment=self, **tracking.get("estimator_params", {})) + self.estimator = est(self.acc_tracking, experiment=self) + first_est_params = tracking.get("estimator_params", None) + if first_est_params is not None: + self.estimator_process.estimator_parameter_queue.put(first_est_params) + self.estimator_log.sig_acc_init.connect(self.refresh_plots) tracking_output_queue = self.estimator_process.tracking_output_queue + self.protocol_runner.attach_estimator_queue(self.est) + self.estimator_process.start() else: self.estimator = None tracking_output_queue = self.tracking_output_queue diff --git a/stytra/stimulation/estimator_process.py b/stytra/stimulation/estimator_process.py index 994915d5..fe95dbba 100644 --- a/stytra/stimulation/estimator_process.py +++ b/stytra/stimulation/estimator_process.py @@ -1,4 +1,5 @@ -from multiprocessing import Event, Process +from multiprocessing import Event, Process, Queue +from queue import Empty from typing import Type from stytra.collectors import QueueDataAccumulator @@ -11,19 +12,32 @@ def __init__( self, estimator_cls: Type[Estimator], tracking_queue: NamedTupleQueue, + estimator_parameter_queue: Queue, finished_signal: Event, ): super().__init__() self.tracking_queue = tracking_queue self.tracking_output_queue = NamedTupleQueue() + self.estimator_parameter_queue = estimator_parameter_queue self.estimator_queue = NamedTupleQueue() self.tracking_accumulator = QueueDataAccumulator(self.tracking_queue, self.tracking_output_queue) self.finished_signal = finished_signal self.estimator_cls = estimator_cls + + def update_estimator_params(self, estimator): + while True: + try: + param_dict = self.estimator_parameter_queue.get(timeout=0.0001) + estimator.update_params(param_dict) + except Empty: + break + + def run(self): estimator = self.estimator_cls(self.tracking_accumulator, self.estimator_queue) while not self.finished_signal.is_set(): + self.update_estimator_params(estimator) self.tracking_accumulator.update_list() estimator.update() diff --git a/stytra/stimulation/estimators.py b/stytra/stimulation/estimators.py index ed0c6f6b..47d726b3 100644 --- a/stytra/stimulation/estimators.py +++ b/stytra/stimulation/estimators.py @@ -22,6 +22,10 @@ def __init__(self, acc_tracking: QueueDataAccumulator, output_queue: NamedTupleQ self.cam_to_proj = cam_to_proj self._output_type = None + def update_params(self, **params): + for key, value in params.items(): + setattr(self, key, value) + def update(self): raise NotImplementedError