Skip to content
This repository has been archived by the owner on Jun 22, 2024. It is now read-only.

Commit

Permalink
Recalculate input if a module is refitted.
Browse files Browse the repository at this point in the history
  • Loading branch information
benHeid committed Mar 28, 2022
1 parent 9ad44f0 commit 9554f2a
Show file tree
Hide file tree
Showing 6 changed files with 25 additions and 35 deletions.
13 changes: 7 additions & 6 deletions pywatts/core/base_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def __init__(self, input_steps: Optional[Dict[str, "BaseStep"]] = None,
self.transform_time = SummaryObjectList(self.name + " Transform Time", category=SummaryCategory.TransformTime)

def get_result(self, start: pd.Timestamp, end: Optional[pd.Timestamp], buffer_element: str = None,
return_all=False, minimum_data=(0, pd.Timedelta(0))):
return_all=False, minimum_data=(0, pd.Timedelta(0)), recalculate=False):
"""
This method is responsible for providing the result of this step.
Therefore,
Expand All @@ -71,9 +71,10 @@ def get_result(self, start: pd.Timestamp, end: Optional[pd.Timestamp], buffer_el
return None

# Only execute the module if the step is not finished and the results are not yet calculated
if not self.finished and not (end is not None and self._current_end is not None and end <= self._current_end):
if recalculate or (not self.finished and not (
end is not None and self._current_end is not None and end <= self._current_end)):
if not self.buffer or not self._current_end or end > self._current_end:
self._compute(start, end, minimum_data)
self._compute(start, end, minimum_data, recalculate)
self._current_end = end
if not end:
self.finished = True
Expand All @@ -86,7 +87,7 @@ def get_result(self, start: pd.Timestamp, end: Optional[pd.Timestamp], buffer_el

return self._pack_data(start, end, buffer_element, return_all=return_all, minimum_data=minimum_data)

def _compute(self, start, end, minimum_data) -> Dict[str, xr.DataArray]:
def _compute(self, start, end, minimum_data, recalculate=False) -> Dict[str, xr.DataArray]:
pass

def further_elements(self, counter: pd.Timestamp) -> bool:
Expand Down Expand Up @@ -197,10 +198,10 @@ def load(cls, stored_step: dict, inputs, targets, module, file_manager):
:return: The restored step.
"""

def _get_input(self, start, batch, minimum_data=(0, pd.Timedelta(0))):
def _get_input(self, start, batch, minimum_data=(0, pd.Timedelta(0)), recalculate=False):
return None

def _get_target(self, start, batch, minimum_data=(0, pd.Timedelta(0))):
def _get_target(self, start, batch, minimum_data=(0, pd.Timedelta(0)), recalculate=False):
return None

def _should_stop(self, start, end) -> bool:
Expand Down
9 changes: 4 additions & 5 deletions pywatts/core/either_or_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,22 +15,21 @@ def __init__(self, input_steps):
super().__init__(input_steps)
self.name = "EitherOr"

def _compute(self, start, end, minimum_data):
input_data = self._get_input(start, end, minimum_data)
def _compute(self, start, end, minimum_data, recalculate=False):
input_data = self._get_input(start, end, minimum_data, recalculate=recalculate)
return self._transform(input_data)

def _get_input(self, start, batch, minimum_data=(0, pd.Timedelta(0))):
def _get_input(self, start, batch, minimum_data=(0, pd.Timedelta(0)), recalculate=False):
inputs = []
for step in self.input_steps.values():
inp = step.get_result(start, batch, minimum_data=minimum_data)
inp = step.get_result(start, batch, minimum_data=minimum_data, recalculate=recalculate)
inputs.append(inp)
return inputs

def _transform(self, input_step):
# Chooses the first input_step which calculation is not stopped.
for in_step in input_step:
if in_step is not None:
# This buffer is never changed in this step. Consequently, no copy is necessary..
return self._post_transform(in_step)

@classmethod
Expand Down
1 change: 0 additions & 1 deletion pywatts/core/inverse_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,5 +29,4 @@ def _transform(self, input_step):
if not self.module.has_inverse_transform:
raise KindOfTransformDoesNotExistException(f"The module {self.module.name} has no inverse transform",
KindOfTransform.INVERSE_TRANSFORM)

return self._post_transform(self.module.inverse_transform(**input_step))
1 change: 0 additions & 1 deletion pywatts/core/probabilistic_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,5 +33,4 @@ def _transform(self, input_step):
if not self.module.has_predict_proba:
raise KindOfTransformDoesNotExistException(f"The module {self.module.name} has no probablisitic transform",
KindOfTransform.PROBABILISTIC_TRANSFORM)

return self._post_transform(self.module.predict_proba(input_step))
2 changes: 1 addition & 1 deletion pywatts/core/result_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def __init__(self, input_steps, buffer_element: str):
self.buffer_element = buffer_element

def get_result(self, start: pd.Timestamp, end: Optional[pd.Timestamp], buffer_element: str = None,
return_all=False, minimum_data=(0, pd.Timedelta(0))):
return_all=False, minimum_data=(0, pd.Timedelta(0)), recalculate=False):
"""
Returns the specified result of the previous step.
"""
Expand Down
34 changes: 13 additions & 21 deletions pywatts/core/step.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,22 +145,14 @@ def load(cls, stored_step: Dict, inputs, targets, module, file_manager):

return step

def _compute(self, start, end, minimum_data):
input_data = self._get_input(start, end, minimum_data)
target = self._get_target(start, end, minimum_data)
def _compute(self, start, end, minimum_data, recalculate=False):
input_data = self._get_input(start, end, minimum_data, recalculate=recalculate)
target = self._get_target(start, end, minimum_data, recalculate=recalculate)
if self.current_run_setting.computation_mode in [ComputationMode.Default, ComputationMode.FitTransform,
ComputationMode.Train]:
# Fetch input_data and target data
if self.batch_size:
input_batch = self._get_input(end - self.batch_size, end, minimum_data)
target_batch = self._get_target(end - self.batch_size, end, minimum_data)
start_time = time.time()
self._fit(input_batch, target_batch)
self.training_time.set_kv("", time.time() - start_time)
else:
start_time = time.time()
self._fit(input_data, target)
self.training_time.set_kv("", time.time() - start_time)
start_time = time.time()
self._fit(input_data, target)
self.training_time.set_kv("", time.time() - start_time)
elif self.module is BaseEstimator:
logger.info("%s not fitted in Step %s", self.module.name, self.name)

Expand All @@ -174,7 +166,7 @@ def _compute(self, start, end, minimum_data):
result_dict[key] = res.sel(**{_get_time_indexes(res)[0]: index[(index >= start)]})
return result_dict

def _get_target(self, start, batch, minimum_data=(0, pd.Timedelta(0))):
def _get_target(self, start, batch, minimum_data=(0, pd.Timedelta(0)), recalculate=False):
min_data_module = self.module.get_min_data()
if isinstance(min_data_module, (int, np.integer)):
minimum_data = minimum_data[0] + min_data_module, minimum_data[1]
Expand All @@ -185,14 +177,14 @@ def _get_target(self, start, batch, minimum_data=(0, pd.Timedelta(0))):
for key, target in self.targets.items()
}

def _get_input(self, start, batch, minimum_data=(0, pd.Timedelta(0))):
def _get_input(self, start, batch, minimum_data=(0, pd.Timedelta(0)), recalculate=False):
min_data_module = self.module.get_min_data()
if isinstance(min_data_module, (int, np.integer)):
minimum_data = minimum_data[0] + min_data_module, minimum_data[1]
else:
minimum_data = minimum_data[0], minimum_data[1] + min_data_module
return {
key: input_step.get_result(start, batch, minimum_data=minimum_data) for
key: input_step.get_result(start, batch, minimum_data=minimum_data, recalculate=False) for
key, input_step in self.input_steps.items()
}

Expand Down Expand Up @@ -237,15 +229,15 @@ def refit(self, start: pd.Timestamp, end: pd.Timestamp):
self._refit(end)
break
elif isinstance(refit_condition, Callable):
input_data = self._get_input(start, end)
target = self._get_target(start, end)
input_data = self._get_input(start, end, recalculate=False)
target = self._get_target(start, end, recalculate=False)
if refit_condition(input_data, target):
self._refit(end)
break

def _refit(self, end):
refit_input = self._get_input(end - self.retrain_batch, end)
refit_target = self._get_target(end - self.retrain_batch, end)
refit_input = self._get_input(end - self.retrain_batch, end, recalculate=True)
refit_target = self._get_target(end - self.retrain_batch, end, recalculate=True)
self.module.refit(**refit_input, **refit_target)

def get_result_step(self, item: str):
Expand Down

0 comments on commit 9554f2a

Please sign in to comment.