Skip to content
This repository has been archived by the owner on Jun 22, 2024. It is now read-only.

Commit

Permalink
Replace recalculate with a two buffer solution
Browse files Browse the repository at this point in the history
  • Loading branch information
benHeid committed May 16, 2022
1 parent 4972147 commit b65b387
Show file tree
Hide file tree
Showing 6 changed files with 136 additions and 69 deletions.
104 changes: 74 additions & 30 deletions pywatts/core/base_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,13 +42,32 @@ def __init__(self, input_steps: Optional[Dict[str, "BaseStep"]] = None,
self.id = -1
self.finished = False
self.last = True
self.refitted = True
self._current_end = None
self.buffer: Dict[str, xr.DataArray] = {}
self.current_buffer: Dict[str, xr.DataArray] = {}
self.result_buffer: Dict[str, xr.DataArray] = {}
self.training_time = SummaryObjectList(self.name + " Training Time", category=SummaryCategory.FitTime)
self.transform_time = SummaryObjectList(self.name + " Transform Time", category=SummaryCategory.TransformTime)


def should_recalculate(self):
    """
    Check whether the results of this step have to be calculated again.

    The step has to recalculate if it is flagged as refitted itself or if one of
    its input steps reports that it has to recalculate. In both cases the current
    buffer is renewed via ``renew_current_buffer`` before returning. The input
    steps are queried in order and the check stops at the first one that reports
    a recalculation.

    :return: True if this step has to recalculate its results, else False.
    :rtype: bool
    """
    if self.refitted:
        self.renew_current_buffer()
        return True
    for in_step in self.input_steps.values():
        # The method has to be *called*: the bound method object itself is always
        # truthy, which would force a recalculation for every step with an input.
        if in_step.should_recalculate():
            self.renew_current_buffer()
            return True
    return False


def get_result(self, start: pd.Timestamp, end: Optional[pd.Timestamp], buffer_element: str = None,
return_all=False, minimum_data=(0, pd.Timedelta(0)), recalculate=False):
return_all=False, minimum_data=(0, pd.Timedelta(0)), use_result_buffer=False):
"""
This method is responsible for providing the result of this step.
Therefore,
Expand All @@ -71,10 +90,10 @@ def get_result(self, start: pd.Timestamp, end: Optional[pd.Timestamp], buffer_el
return None

# Only execute the module if the step is not finished and the results are not yet calculated
if recalculate or (not self.finished and not (
end is not None and self._current_end is not None and end <= self._current_end)):
if not self.buffer or not self._current_end or end > self._current_end:
self._compute(start, end, minimum_data, recalculate)
if not self.finished and not (
end is not None and self._current_end is not None and end <= self._current_end):
if not self.current_buffer or not self._current_end or end > self._current_end:
self._compute(start, end, minimum_data)
self._current_end = end
if not end:
self.finished = True
Expand All @@ -83,9 +102,11 @@ def get_result(self, start: pd.Timestamp, end: Optional[pd.Timestamp], buffer_el

# Only call callbacks if the step is finished
if self.finished:
self.update_buffer(self.result_buffer, self.current_buffer)
self._callbacks()

return self._pack_data(start, end, buffer_element, return_all=return_all, minimum_data=minimum_data)
return self._pack_data(start, end, buffer_element, return_all=return_all, minimum_data=minimum_data,
use_result_buffer=use_result_buffer)

def _compute(self, start, end, minimum_data, recalculate=False) -> Dict[str, xr.DataArray]:
pass
Expand All @@ -99,8 +120,10 @@ def further_elements(self, counter: pd.Timestamp) -> bool:
:return: True if there exist further data
:rtype: bool
"""
if not self.buffer or all(
[counter < b.indexes[_get_time_indexes(self.buffer)[0]][-1] for b in self.buffer.values()]):
if not self.current_buffer or all(
[counter < b.indexes[
_get_time_indexes(self.current_buffer)[0]][-1]
for b in self.current_buffer.values()]):
return True
for input_step in self.input_steps.values():
if not input_step.further_elements(counter):
Expand All @@ -110,11 +133,17 @@ def further_elements(self, counter: pd.Timestamp) -> bool:
return False
return True

def _pack_data(self, start, end, buffer_element=None, return_all=False, minimum_data=(0, pd.Timedelta(0))):
def _pack_data(self, start, end, buffer_element=None, return_all=False, minimum_data=(0, pd.Timedelta(0)),
use_result_buffer=False):
# Provide requested data
time_index = _get_time_indexes(self.buffer)
if use_result_buffer:
buffer = self.result_buffer
self.update_buffer(self.result_buffer, self.current_buffer)
else:
buffer = self.current_buffer
time_index = _get_time_indexes(buffer)
if start:
index = list(self.buffer.values())[0].indexes[time_index[0]]
index = list(buffer.values())[0].indexes[time_index[0]]
if len(index) > 1:
freq = index[1] - index[0]
else:
Expand All @@ -124,22 +153,20 @@ def _pack_data(self, start, end, buffer_element=None, return_all=False, minimum_
end = end.to_numpy() if end is not None else (index[-1] + pd.Timedelta(nanoseconds=1)).to_numpy()
# After sel copy is not needed, since it returns a new array.
if buffer_element is not None:
return self.buffer[buffer_element].sel(
**{time_index[0]: index[(index >= start) & (index < end)]})
return buffer[buffer_element].sel(**{time_index[0]: index[(index >= start) & (index < end)]})
elif return_all:
return {key: b.sel(**{time_index[0]: index[(index >= start) & (index < end)]}) for
key, b in self.buffer.items()}
key, b in buffer.items()}
else:
return list(self.buffer.values())[0].sel(
**{time_index[0]: index[(index >= start) & (index < end)]})
return list(buffer.values())[0].sel(**{time_index[0]: index[(index >= start) & (index < end)]})
else:
self.finished = True
if buffer_element is not None:
return self.buffer[buffer_element].copy()
return buffer[buffer_element].copy()
elif return_all:
return copy.deepcopy(self.buffer)
return copy.deepcopy(buffer)
else:
return list(self.buffer.values())[0].copy()
return list(buffer.values())[0].copy()

def _transform(self, input_step):
pass
Expand All @@ -154,14 +181,10 @@ def _post_transform(self, result):
if not isinstance(result, dict):
result = {self.name: result}

if not self.buffer:
self.buffer = result
if not self.current_buffer:
self.current_buffer = result
else:
# Time dimension is mandatory, consequently there dim has to exist
dim = _get_time_indexes(result)[0]
for key in self.buffer.keys():
last = self.buffer[key][dim].values[-1]
self.buffer[key] = xr.concat([self.buffer[key], result[key][result[key][dim] > last]], dim=dim)
self.update_buffer(self.current_buffer, result)
return result

def get_json(self, fm: FileManager) -> Dict:
Expand Down Expand Up @@ -198,10 +221,10 @@ def load(cls, stored_step: dict, inputs, targets, module, file_manager):
:return: The restored step.
"""

def _get_input(self, start, batch, minimum_data=(0, pd.Timedelta(0)), recalculate=False):
def _get_input(self, start, batch, minimum_data=(0, pd.Timedelta(0)), use_result_buffer=False):
return None

def _get_target(self, start, batch, minimum_data=(0, pd.Timedelta(0)), recalculate=False):
def _get_target(self, start, batch, minimum_data=(0, pd.Timedelta(0))):
return None

def _should_stop(self, start, end) -> bool:
Expand All @@ -223,7 +246,8 @@ def reset(self, keep_buffer=False):
:param keep_buffer: Flag indicating if the buffer should be resetted too.
"""
if not keep_buffer:
self.buffer = {}
self.current_buffer = {}
self.result_buffer = {}
self.finished = False
self.current_run_setting = self.default_run_setting.clone()

Expand All @@ -237,3 +261,23 @@ def set_run_setting(self, run_setting: RunSetting):
:type computation_mode: ComputationMode
"""
self.current_run_setting = self.default_run_setting.update(run_setting)

def renew_current_buffer(self):
    """
    Move the content of the current buffer into the result buffer.

    The data of the current buffer is handed to ``update_buffer`` together with
    the result buffer; afterwards the current buffer is replaced by a new, empty
    dictionary.
    """
    pending = self.current_buffer
    self.update_buffer(self.result_buffer, pending)
    self.current_buffer = {}

def update_buffer(self, buffer, new_data):
    """
    Extend ``buffer`` in place with the data in ``new_data``.

    If ``buffer`` is empty, it is filled with the entries of ``new_data``.
    Otherwise, for every key of ``buffer``, only those entries of ``new_data``
    whose time index is later than the last time index of the first element in
    ``buffer`` are concatenated along the time dimension.

    :param buffer: The buffer that should be extended (modified in place).
    :type buffer: Dict[str, xr.DataArray]
    :param new_data: The data that should be appended to the buffer.
    :type new_data: Dict[str, xr.DataArray]
    """
    if not new_data:
        # Nothing to append, e.g. right after the current buffer was renewed.
        # Without this guard an empty buffer raises an IndexError and a filled
        # buffer raises a KeyError below.
        return
    if not buffer:
        # A fresh buffer simply takes over the new data; there is no older data
        # to concatenate with.
        buffer.update(new_data)
        return
    time_index = _get_time_indexes(buffer)[0]
    last = list(buffer.values())[0].indexes[time_index][-1]
    for key in buffer:
        newer = new_data[key][new_data[key][time_index] > last]
        buffer[key] = xr.concat([buffer[key], newer], dim=time_index)
32 changes: 18 additions & 14 deletions pywatts/core/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,22 +84,23 @@ def transform(self, **x: xr.DataArray) -> xr.DataArray:
else:
return self._comp(x, False, self.current_run_setting.summary_formatter, self.batch)

def _transform(self, x, batch=None):
def _transform(self, x, batch=None, use_result_buffer=False):
for step in self.id_to_step.values():
step.finished = False
for key, (start_step, _) in self.start_steps.items():
if not start_step.buffer:
start_step.buffer = {key: x[key].copy()}
if not start_step.current_buffer:
start_step.current_buffer = {key: x[key].copy()}
else:
dim = _get_time_indexes(start_step.buffer[key])[0]
last = start_step.buffer[key][dim].values[-1]
start_step.buffer[key] = xr.concat([start_step.buffer[key], x[key][x[key][dim] > last]], dim=dim)
dim = _get_time_indexes(start_step.current_buffer[key])[0]
last = start_step.current_buffer[key][dim].values[-1]
start_step.current_buffer[key] = xr.concat([start_step.current_buffer[key], x[key][x[key][dim] > last]],
dim=dim)
start_step.finished = True
time_index = _get_time_indexes(x)
self.counter = list(x.values())[0].indexes[time_index[0]][0] # The start date of the input time series.
last_steps = list(filter(lambda x: x.last, self.id_to_step.values()))
if not batch:
return self._collect_results(last_steps)
return self._collect_results(last_steps, use_result_buffer=use_result_buffer)
return self._collect_batches(last_steps)

def _collect_batches(self, last_steps):
Expand All @@ -122,13 +123,13 @@ def _collect_batches(self, last_steps):
self.counter += self.batch
return result

def _collect_results(self, inputs, use_batch=False):
def _collect_results(self, inputs, use_batch=False, use_result_buffer=False):
# Note the return value is None if none of the inputs provide a result for this step...
end = None if not use_batch else self.counter + self.batch
result = dict()
for i, step in enumerate(inputs):
if not isinstance(step, SummaryStep):
res = step.get_result(self.counter, end, return_all=True)
res = step.get_result(self.counter, end, return_all=True, use_result_buffer=use_result_buffer)
for key, value in res.items():
result = self._add_to_result(i, key, value, result)
return result
Expand Down Expand Up @@ -246,14 +247,17 @@ def _run(self, data: Union[pd.DataFrame, xr.Dataset], mode: ComputationMode, sum
for step in self.id_to_step.values():
step.reset(keep_buffer=True)
step.set_run_setting(self.current_run_setting.clone())
return self._comp({key: data[key].sel(**{index_name: data[key][index_name] >= self.current_run_setting.online_start}) for key in data},
self.current_run_setting.return_summary, summary_formatter, self.batch, start=self.current_run_setting.online_start)
online_data = {key: data[key].sel(
**{index_name: data[key][index_name] >= self.current_run_setting.online_start}) for key in data}
return self._comp(online_data, self.current_run_setting.return_summary, summary_formatter, self.batch,
start=self.current_run_setting.online_start, use_result_buffer=True)
else:
return self._comp(data, self.current_run_setting.return_summary, summary_formatter, self.batch)
return self._comp(data, self.current_run_setting.return_summary, summary_formatter, self.batch,
use_result_buffer=True)


def _comp(self, data, return_summary, summary_formatter, batch, start=None):
result = self._transform(data, batch)
def _comp(self, data, return_summary, summary_formatter, batch, start=None, use_result_buffer=False):
result = self._transform(data, batch, use_result_buffer=use_result_buffer)
summary = self._create_summary(summary_formatter, start)
return (result, summary) if return_summary else result

Expand Down
12 changes: 8 additions & 4 deletions pywatts/core/result_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,19 @@ def __init__(self, input_steps, buffer_element: str):
self.buffer_element = buffer_element

def get_result(self, start: pd.Timestamp, end: Optional[pd.Timestamp], buffer_element: str = None,
return_all=False, minimum_data=(0, pd.Timedelta(0)), recalculate=False):
return_all=False, minimum_data=(0, pd.Timedelta(0)), use_result_buffer=False):
"""
Returns the specified result of the previous step.
"""

# There exist only one input step in a result buffer
input_step = list(self.input_steps.values())[0]
if not return_all:
return list(self.input_steps.values())[0].get_result(start, end, self.buffer_element, minimum_data=minimum_data)
return input_step.get_result(start, end, self.buffer_element, minimum_data=minimum_data,
use_result_buffer=use_result_buffer)
else:
return {self.buffer_element: list(self.input_steps.values())[0].get_result(start, end, self.buffer_element, minimum_data=minimum_data)}
return {self.buffer_element: input_step.get_result(start, end, self.buffer_element,
minimum_data=minimum_data,
use_result_buffer=use_result_buffer)}

def get_json(self, fm: FileManager) -> Dict:
"""
Expand Down
10 changes: 8 additions & 2 deletions pywatts/core/start_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,9 @@ def further_elements(self, counter):
:return: True if there exist further data
:rtype: bool
"""
indexes = _get_time_indexes(self.buffer)
indexes = _get_time_indexes(self.current_buffer)
if len(indexes) == 0 or not all(
[counter < b.indexes[_get_time_indexes(self.buffer)[0]][-1] for b in self.buffer.values()]):
[counter < b.indexes[_get_time_indexes(self.current_buffer)[0]][-1] for b in self.current_buffer.values()]):
return False
else:
return True
Expand All @@ -56,3 +56,9 @@ def get_json(self, fm: FileManager) -> Dict:
json = super().get_json(fm)
json["index"] = self.index
return json

def should_recalculate(self):
    """
    Report whether this step has to be recalculated.

    :return: Always False; this step never requests a recalculation.
    :rtype: bool
    """
    needs_recalculation = False
    return needs_recalculation

def renew_current_buffer(self):
    """
    Do nothing: for this step renewing the current buffer is a no-op.

    :return: None
    """
    return None
Loading

0 comments on commit b65b387

Please sign in to comment.