diff --git a/pywatts/core/base_step.py b/pywatts/core/base_step.py
index f14f1a01..213d25d4 100644
--- a/pywatts/core/base_step.py
+++ b/pywatts/core/base_step.py
@@ -42,13 +42,27 @@ def __init__(self, input_steps: Optional[Dict[str, "BaseStep"]] = None,
         self.id = -1
         self.finished = False
         self.last = True
+        self.refitted = True
         self._current_end = None
-        self.buffer: Dict[str, xr.DataArray] = {}
+        self.current_buffer: Dict[str, xr.DataArray] = {}
+        self.result_buffer: Dict[str, xr.DataArray] = {}
         self.training_time = SummaryObjectList(self.name + " Training Time", category=SummaryCategory.FitTime)
         self.transform_time = SummaryObjectList(self.name + " Transform Time",
                                                 category=SummaryCategory.TransformTime)
+
+    def should_recalculate(self):
+        # Recalculate if this step was refitted or if any of its input steps has to be recalculated.
+        if self.refitted:
+            self.renew_current_buffer()
+            return True
+        for in_step in self.input_steps.values():
+            if in_step.should_recalculate():
+                self.renew_current_buffer()
+                return True
+        return False
+
 
     def get_result(self, start: pd.Timestamp, end: Optional[pd.Timestamp], buffer_element: str = None,
-                   return_all=False, minimum_data=(0, pd.Timedelta(0)), recalculate=False):
+                   return_all=False, minimum_data=(0, pd.Timedelta(0)), use_result_buffer=False):
         """
         This method is responsible for providing the result of this step. Therefore,
@@ -71,10 +90,10 @@ def get_result(self, start: pd.Timestamp, end: Optional[pd.Timestamp], buffer_el
             return None
 
         # Only execute the module if the step is not finished and the results are not yet calculated
-        if recalculate or (not self.finished and not (
-                end is not None and self._current_end is not None and end <= self._current_end)):
-            if not self.buffer or not self._current_end or end > self._current_end:
-                self._compute(start, end, minimum_data, recalculate)
+        if not self.finished and not (
+                end is not None and self._current_end is not None and end <= self._current_end):
+            if not self.current_buffer or not self._current_end or end > self._current_end:
+                self._compute(start, end, minimum_data)
             self._current_end = end
             if not end:
                 self.finished = True
@@ -83,9 +102,11 @@ def get_result(self, start: pd.Timestamp, end: Optional[pd.Timestamp], buffer_el
 
         # Only call callbacks if the step is finished
         if self.finished:
+            self.update_buffer(self.result_buffer, self.current_buffer)
             self._callbacks()
 
-        return self._pack_data(start, end, buffer_element, return_all=return_all, minimum_data=minimum_data)
+        return self._pack_data(start, end, buffer_element, return_all=return_all, minimum_data=minimum_data,
+                               use_result_buffer=use_result_buffer)
 
     def _compute(self, start, end, minimum_data, recalculate=False) -> Dict[str, xr.DataArray]:
         pass
@@ -99,8 +120,10 @@ def further_elements(self, counter: pd.Timestamp) -> bool:
         :return: True if there exist further data
         :rtype: bool
         """
-        if not self.buffer or all(
-                [counter < b.indexes[_get_time_indexes(self.buffer)[0]][-1] for b in self.buffer.values()]):
+        if not self.current_buffer or all(
+                [counter < b.indexes[
+                    _get_time_indexes(self.current_buffer)[0]][-1]
+                 for b in self.current_buffer.values()]):
             return True
         for input_step in self.input_steps.values():
             if not input_step.further_elements(counter):
@@ -110,11 +133,17 @@ def further_elements(self, counter: pd.Timestamp) -> bool:
                 return False
         return True
 
-    def _pack_data(self, start, end, buffer_element=None, return_all=False, minimum_data=(0, pd.Timedelta(0))):
+    def _pack_data(self, start, end, buffer_element=None, return_all=False, minimum_data=(0, pd.Timedelta(0)),
+                   use_result_buffer=False):
         # Provide requested data
-        time_index = _get_time_indexes(self.buffer)
+        if use_result_buffer:
+            buffer = self.result_buffer
+            self.update_buffer(self.result_buffer, self.current_buffer)
+        else:
+            buffer = self.current_buffer
+        time_index = _get_time_indexes(buffer)
         if start:
-            index = list(self.buffer.values())[0].indexes[time_index[0]]
+            index = list(buffer.values())[0].indexes[time_index[0]]
             if len(index) > 1:
                 freq = index[1] - index[0]
             else:
@@ -124,22 +153,20 @@ def _pack_data(self, start, end, buffer_element=None, return_all=False, minimum_
             end = end.to_numpy() if end is not None else (index[-1] + pd.Timedelta(nanoseconds=1)).to_numpy()
             # After sel copy is not needed, since it returns a new array.
             if buffer_element is not None:
-                return self.buffer[buffer_element].sel(
-                    **{time_index[0]: index[(index >= start) & (index < end)]})
+                return buffer[buffer_element].sel(**{time_index[0]: index[(index >= start) & (index < end)]})
             elif return_all:
                 return {key: b.sel(**{time_index[0]: index[(index >= start) & (index < end)]}) for
-                        key, b in self.buffer.items()}
+                        key, b in buffer.items()}
             else:
-                return list(self.buffer.values())[0].sel(
-                    **{time_index[0]: index[(index >= start) & (index < end)]})
+                return list(buffer.values())[0].sel(**{time_index[0]: index[(index >= start) & (index < end)]})
         else:
             self.finished = True
             if buffer_element is not None:
-                return self.buffer[buffer_element].copy()
+                return buffer[buffer_element].copy()
             elif return_all:
-                return copy.deepcopy(self.buffer)
+                return copy.deepcopy(buffer)
             else:
-                return list(self.buffer.values())[0].copy()
+                return list(buffer.values())[0].copy()
 
     def _transform(self, input_step):
         pass
@@ -154,14 +181,10 @@ def _post_transform(self, result):
         if not isinstance(result, dict):
             result = {self.name: result}
-        if not self.buffer:
-            self.buffer = result
+        if not self.current_buffer:
+            self.current_buffer = result
         else:
-            # Time dimension is mandatory, consequently there dim has to exist
-            dim = _get_time_indexes(result)[0]
-            for key in self.buffer.keys():
-                last = self.buffer[key][dim].values[-1]
-                self.buffer[key] = xr.concat([self.buffer[key], result[key][result[key][dim] > last]], dim=dim)
+            self.update_buffer(self.current_buffer, result)
         return result
 
     def get_json(self, fm: FileManager) -> Dict:
@@ -198,10 +221,10 @@ def load(cls, stored_step: dict, inputs, targets, module, file_manager):
         :return: The restored step.
         """
 
-    def _get_input(self, start, batch, minimum_data=(0, pd.Timedelta(0)), recalculate=False):
+    def _get_input(self, start, batch, minimum_data=(0, pd.Timedelta(0)), use_result_buffer=False):
         return None
 
-    def _get_target(self, start, batch, minimum_data=(0, pd.Timedelta(0)), recalculate=False):
+    def _get_target(self, start, batch, minimum_data=(0, pd.Timedelta(0))):
         return None
 
     def _should_stop(self, start, end) -> bool:
@@ -223,7 +246,8 @@ def reset(self, keep_buffer=False):
         :param keep_buffer: Flag indicating if the buffer should be resetted too.
""" if not keep_buffer: - self.buffer = {} + self.current_buffer = {} + self.result_buffer = {} self.finished = False self.current_run_setting = self.default_run_setting.clone() @@ -237,3 +261,23 @@ def set_run_setting(self, run_setting: RunSetting): :type computation_mode: ComputationMode """ self.current_run_setting = self.default_run_setting.update(run_setting) + + def renew_current_buffer(self): + """ + Fill data from current buffer in to the result buffer reset the current buffer. + """ + self.update_buffer(self.result_buffer, self.current_buffer) + self.current_buffer = {} + + def update_buffer(self,buffer, new_data): + """ + TODO + """ + if not buffer: + for key in new_data: + buffer[key] = new_data[key] + time_index = _get_time_indexes(buffer)[0] + index = list(buffer.values())[0].indexes[time_index] + last = index[-1] + for key in buffer.keys(): + buffer[key] = xr.concat([buffer[key],new_data[key][new_data[key][time_index] > last]], dim=time_index) diff --git a/pywatts/core/pipeline.py b/pywatts/core/pipeline.py index defefefa..169ee46e 100644 --- a/pywatts/core/pipeline.py +++ b/pywatts/core/pipeline.py @@ -84,22 +84,23 @@ def transform(self, **x: xr.DataArray) -> xr.DataArray: else: return self._comp(x, False, self.current_run_setting.summary_formatter, self.batch) - def _transform(self, x, batch=None): + def _transform(self, x, batch=None, use_result_buffer=False): for step in self.id_to_step.values(): step.finished = False for key, (start_step, _) in self.start_steps.items(): - if not start_step.buffer: - start_step.buffer = {key: x[key].copy()} + if not start_step.current_buffer: + start_step.current_buffer = {key: x[key].copy()} else: - dim = _get_time_indexes(start_step.buffer[key])[0] - last = start_step.buffer[key][dim].values[-1] - start_step.buffer[key] = xr.concat([start_step.buffer[key], x[key][x[key][dim] > last]], dim=dim) + dim = _get_time_indexes(start_step.current_buffer[key])[0] + last = start_step.current_buffer[key][dim].values[-1] + start_step.current_buffer[key] = xr.concat([start_step.current_buffer[key], x[key][x[key][dim] > last]], + dim=dim) start_step.finished = True time_index = _get_time_indexes(x) self.counter = list(x.values())[0].indexes[time_index[0]][0] # The start date of the input time series. last_steps = list(filter(lambda x: x.last, self.id_to_step.values())) if not batch: - return self._collect_results(last_steps) + return self._collect_results(last_steps, use_result_buffer=use_result_buffer) return self._collect_batches(last_steps) def _collect_batches(self, last_steps): @@ -122,13 +123,13 @@ def _collect_batches(self, last_steps): self.counter += self.batch return result - def _collect_results(self, inputs, use_batch=False): + def _collect_results(self, inputs, use_batch=False, use_result_buffer=False): # Note the return value is None if none of the inputs provide a result for this step... 
         end = None if not use_batch else self.counter + self.batch
         result = dict()
         for i, step in enumerate(inputs):
             if not isinstance(step, SummaryStep):
-                res = step.get_result(self.counter, end, return_all=True)
+                res = step.get_result(self.counter, end, return_all=True, use_result_buffer=use_result_buffer)
                 for key, value in res.items():
                     result = self._add_to_result(i, key, value, result)
         return result
@@ -246,14 +247,17 @@ def _run(self, data: Union[pd.DataFrame, xr.Dataset], mode: ComputationMode, sum
             for step in self.id_to_step.values():
                 step.reset(keep_buffer=True)
                 step.set_run_setting(self.current_run_setting.clone())
-            return self._comp({key: data[key].sel(**{index_name: data[key][index_name] >= self.current_run_setting.online_start}) for key in data},
-                              self.current_run_setting.return_summary, summary_formatter, self.batch, start=self.current_run_setting.online_start)
+            online_data = {key: data[key].sel(
+                **{index_name: data[key][index_name] >= self.current_run_setting.online_start}) for key in data}
+            return self._comp(online_data, self.current_run_setting.return_summary, summary_formatter, self.batch,
+                              start=self.current_run_setting.online_start, use_result_buffer=True)
         else:
-            return self._comp(data, self.current_run_setting.return_summary, summary_formatter, self.batch)
+            return self._comp(data, self.current_run_setting.return_summary, summary_formatter, self.batch,
+                              use_result_buffer=True)
 
-    def _comp(self, data, return_summary, summary_formatter, batch, start=None):
-        result = self._transform(data, batch)
+    def _comp(self, data, return_summary, summary_formatter, batch, start=None, use_result_buffer=False):
+        result = self._transform(data, batch, use_result_buffer=use_result_buffer)
         summary = self._create_summary(summary_formatter, start)
         return (result, summary) if return_summary else result
 
diff --git a/pywatts/core/result_step.py b/pywatts/core/result_step.py
index d400ab40..89aa94a6 100644
--- a/pywatts/core/result_step.py
+++ b/pywatts/core/result_step.py
@@ -16,15 +16,19 @@ def __init__(self, input_steps, buffer_element: str):
         self.buffer_element = buffer_element
 
     def get_result(self, start: pd.Timestamp, end: Optional[pd.Timestamp], buffer_element: str = None,
-                   return_all=False, minimum_data=(0, pd.Timedelta(0)), recalculate=False):
+                   return_all=False, minimum_data=(0, pd.Timedelta(0)), use_result_buffer=False):
         """
         Returns the specified result of the previous step.
""" - + # There exist only one input step in a result buffer + input_step = list(self.input_steps.values())[0] if not return_all: - return list(self.input_steps.values())[0].get_result(start, end, self.buffer_element, minimum_data=minimum_data) + return input_step.get_result(start, end, self.buffer_element, minimum_data=minimum_data, + use_result_buffer=use_result_buffer) else: - return {self.buffer_element: list(self.input_steps.values())[0].get_result(start, end, self.buffer_element, minimum_data=minimum_data)} + return {self.buffer_element: input_step.get_result(start, end, self.buffer_element, + minimum_data=minimum_data, + use_result_buffer=use_result_buffer)} def get_json(self, fm: FileManager) -> Dict: """ diff --git a/pywatts/core/start_step.py b/pywatts/core/start_step.py index b815444c..7443b153 100644 --- a/pywatts/core/start_step.py +++ b/pywatts/core/start_step.py @@ -42,9 +42,9 @@ def further_elements(self, counter): :return: True if there exist further data :rtype: bool """ - indexes = _get_time_indexes(self.buffer) + indexes = _get_time_indexes(self.current_buffer) if len(indexes) == 0 or not all( - [counter < b.indexes[_get_time_indexes(self.buffer)[0]][-1] for b in self.buffer.values()]): + [counter < b.indexes[_get_time_indexes(self.current_buffer)[0]][-1] for b in self.current_buffer.values()]): return False else: return True @@ -56,3 +56,9 @@ def get_json(self, fm: FileManager) -> Dict: json = super().get_json(fm) json["index"] = self.index return json + + def should_recalculate(self): + return False + + def renew_current_buffer(self): + return \ No newline at end of file diff --git a/pywatts/core/step.py b/pywatts/core/step.py index 540f506c..4f679b08 100644 --- a/pywatts/core/step.py +++ b/pywatts/core/step.py @@ -84,18 +84,18 @@ def _fit(self, inputs: Dict[str, BaseStep], target_step): def _callbacks(self): # plots and writs the data if the step is finished. 
         for callback in self.callbacks:
-            dim = _get_time_indexes(self.buffer)[0]
+            dim = _get_time_indexes(self.result_buffer)[0]
             if self.current_run_setting.online_start is not None:
-                to_plot = {k: self.buffer[k][self.buffer[k][dim] >= self.current_run_setting.online_start] for k in
-                           self.buffer.keys()}
+                to_plot = {k: self.result_buffer[k][self.result_buffer[k][dim] >= self.current_run_setting.online_start] for k in
+                           self.result_buffer.keys()}
             else:
-                to_plot = self.buffer
+                to_plot = self.result_buffer
             if isinstance(callback, BaseCallback):
                 callback.set_filemanager(self.file_manager)
-            if isinstance(self.buffer, xr.DataArray) or isinstance(self.buffer, xr.Dataset):
+            if isinstance(self.result_buffer, xr.DataArray) or isinstance(self.result_buffer, xr.Dataset):
                 # DEPRECATED: direct DataArray or Dataset passing is depricated
-                callback({"deprecated": self.buffer})
+                callback({"deprecated": self.result_buffer})
             else:
                 callback(to_plot)
 
@@ -146,8 +146,8 @@ def load(cls, stored_step: Dict, inputs, targets, module, file_manager):
         return step
 
     def _compute(self, start, end, minimum_data, recalculate=False):
-        input_data = self._get_input(start, end, minimum_data, recalculate=recalculate)
-        target = self._get_target(start, end, minimum_data, recalculate=recalculate)
+        input_data = self._get_input(start, end, minimum_data)
+        target = self._get_target(start, end, minimum_data)
         if self.current_run_setting.computation_mode in [ComputationMode.Default, ComputationMode.FitTransform,
                                                          ComputationMode.Train]:
             start_time = time.time()
@@ -166,7 +166,7 @@ def _compute(self, start, end, minimum_data, recalculate=False):
             result_dict[key] = res.sel(**{_get_time_indexes(res)[0]: index[(index >= start)]})
         return result_dict
 
-    def _get_target(self, start, batch, minimum_data=(0, pd.Timedelta(0)), recalculate=False):
+    def _get_target(self, start, batch, minimum_data=(0, pd.Timedelta(0))):
         min_data_module = self.module.get_min_data()
         if isinstance(min_data_module, (int, np.integer)):
             minimum_data = minimum_data[0] + min_data_module, minimum_data[1]
@@ -177,14 +177,14 @@ def _get_target(self, start, batch, minimum_data=(0, pd.Timedelta(0)), recalcula
             for key, target in self.targets.items()
         }
 
-    def _get_input(self, start, batch, minimum_data=(0, pd.Timedelta(0)), recalculate=False):
+    def _get_input(self, start, batch, minimum_data=(0, pd.Timedelta(0)), use_result_buffer=False):
         min_data_module = self.module.get_min_data()
         if isinstance(min_data_module, (int, np.integer)):
             minimum_data = minimum_data[0] + min_data_module, minimum_data[1]
         else:
            minimum_data = minimum_data[0], minimum_data[1] + min_data_module
         return {
-            key: input_step.get_result(start, batch, minimum_data=minimum_data, recalculate=False) for
+            key: input_step.get_result(start, batch, minimum_data=minimum_data, use_result_buffer=use_result_buffer) for
             key, input_step in self.input_steps.items()
         }
 
@@ -227,17 +227,18 @@ def refit(self, start: pd.Timestamp, end: pd.Timestamp):
                                    refit_condition.kwargs.items()}
                 if refit_condition.evaluate(**condition_input):
                     self._refit(end)
+                    self.refitted = True
                     break
             elif isinstance(refit_condition, Callable):
-                input_data = self._get_input(start, end, recalculate=False)
-                target = self._get_target(start, end, recalculate=False)
+                input_data = self._get_input(start, end)
+                target = self._get_target(start, end)
                 if refit_condition(input_data, target):
                     self._refit(end)
                     break
 
     def _refit(self, end):
-        refit_input = self._get_input(end - self.retrain_batch, end, recalculate=True)
-        refit_target = self._get_target(end - self.retrain_batch, end, recalculate=True)
+        refit_input = self._get_input(end - self.retrain_batch, end)
+        refit_target = self._get_target(end - self.retrain_batch, end)
         self.module.refit(**refit_input, **refit_target)
 
     def get_result_step(self, item: str):
diff --git a/tests/unit/core/test_step.py b/tests/unit/core/test_step.py
index d9de8f1c..b8dac9ba 100644
--- a/tests/unit/core/test_step.py
+++ b/tests/unit/core/test_step.py
@@ -134,9 +134,10 @@ def test_transform_batch_with_existing_buffer(self, xr_mock, *args):
                                              coords={'time': time2}),
                                 list(xr_mock.concat.call_args_list[0])[0][0][1])
         assert {'dim': 'time'} == list(xr_mock.concat.call_args_list[0])[1]
 
-    def test_get_result(self):
-        # Tests if the get_result method calls correctly the previous step and the module
+    def test_get_result_recalculate(self):
+        pass
+    def test_get_result(self):
         input_step = MagicMock()
         input_step_result_mock = MagicMock()
         input_step.get_result.return_value = input_step_result_mock
@@ -151,9 +152,11 @@ def test_get_result(self):
 
         # Two calls, once in should_stop and once in _transform
         input_step.get_result.assert_has_calls(
             [call(pd.Timestamp("2000-01-01"), pd.Timestamp('2020-12-12 '),
-                  minimum_data=(0, self.module_mock.get_min_data().__radd__())),
+                  minimum_data=(0, self.module_mock.get_min_data().__radd__()),
+                  use_result_buffer=False),
              call(pd.Timestamp('2000-01-01 '), pd.Timestamp('2020-12-12 '),
-                  minimum_data=(0, self.module_mock.get_min_data().__radd__()))])
+                  minimum_data=(0, self.module_mock.get_min_data().__radd__()),
+                  use_result_buffer=False)])
 
         self.module_mock.transform.assert_called_once_with(x=input_step_result_mock)
@@ -306,3 +309,8 @@ def test_multiple_refit_conditions(self, isinstance_mock, get_input_mock, get_ta
         step.refit(pd.Timestamp("2000.01.01"), pd.Timestamp("2020.01.01"))
         self.module_mock.refit.assert_called_once_with(x=2, target=1)
 
+    def test_has_to_renew_buffer(self):
+        self.fail()
+
+    def test_renew_buffer(self):
+        self.fail()
\ No newline at end of file