Skip to content
This repository has been archived by the owner on Jun 22, 2024. It is now read-only.

Commit

Permalink
* rename train_if to refit_condition
Browse files Browse the repository at this point in the history
* Add Changelog entry
  • Loading branch information
benHeid committed Feb 28, 2022
1 parent 632a258 commit 93a611f
Show file tree
Hide file tree
Showing 9 changed files with 58 additions and 51 deletions.
6 changes: 6 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,12 @@
### Added
* Metric can be calculated on cutouts ([#149](https://github.com/KIT-IAI/pyWATTS/issues/149))
* Add a MASE Summary ([#148](https://github.com/KIT-IAI/pyWATTS/issues/148))
* Improve Online Learning Functionality ([#117](https://github.com/KIT-IAI/pyWATTS/issues/117))
* Replace train_if callable by a ConditionObject.
* Rename train_if to refit_condition.
* Separate transform from refit. I.e., the complete pipeline is transformed before any step is refitted.
* Add PeriodicCondition


### Changed
* Retraining is triggered after all steps are transformed ([#117](https://github.com/KIT-IAI/pyWATTS/issues/117))
Expand Down
1 change: 0 additions & 1 deletion pywatts/conditions/periodic_condition.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ def evaluate(self):
self.counter = self.counter % self.num_steps

if self.counter == 0:
print("Periodic Relearn") # TODO logging?
return True
else:
return False
10 changes: 5 additions & 5 deletions pywatts/core/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def __init__(self, name: str):
self.has_inverse_transform = False
self.has_predict_proba = False
# TODO each module needs a function for returning how much past values it needs for executing the transformation.

# SEE Issue 147
@abstractmethod
def get_params(self) -> Dict[str, object]:
"""
Expand Down Expand Up @@ -146,7 +146,7 @@ def __call__(self,
condition: Optional[Callable] = None,
computation_mode: ComputationMode = ComputationMode.Default,
batch_size: Optional[pd.Timedelta] = None,
train_if: Optional[Union[ConditionObject]] = None,
refit_condition: Optional[Union[ConditionObject]] = None,
lag: Optional[int] = pd.Timedelta(hours=0),
retrain_batch: Optional[int] = pd.Timedelta(hours=24),
**kwargs: Union[StepInformation, Tuple[StepInformation, ...]]
Expand All @@ -171,8 +171,8 @@ def __call__(self,
:type use_prob_transform: bool
:param callbacks: Callbacks to use after results are processed.
:type callbacks: List[BaseCallback, Callable[[Dict[str, xr.DataArray]]]]
:param train_if: A callable, which contains a condition that indicates if the module should be trained or not
:type train_if: Optional[Callable]
:param refit_condition: A callable, which contains a condition that indicates if the module should be trained or not
:type refit_condition: Optional[Callable]
:param batch_size: Determines how much data from the past should be used for training
:type batch_size: pd.Timedelta
:param computation_mode: Determines the computation mode of the step. Could be ComputationMode.Train,
Expand All @@ -195,7 +195,7 @@ def __call__(self,
condition=condition,
callbacks=callbacks,
computation_mode=computation_mode, batch_size=batch_size,
train_if=train_if,
refit_condition=refit_condition,
retrain_batch=retrain_batch,
lag=lag
)
Expand Down
4 changes: 2 additions & 2 deletions pywatts/core/base_summary.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,8 +86,8 @@ def __call__(self, **kwargs) -> SummaryInformation:
:rtype: SummaryInformation
"""

non_supported_kwargs = ["use_inverse_transform", "train_if", "callbacks", "condition", "computation_mode",
"batch_size"]
non_supported_kwargs = ["use_inverse_transform", "refit_condition", "callbacks", "condition",
"computation_mode", "batch_size"]

for kwa in non_supported_kwargs:
if kwa in kwargs:
Expand Down
3 changes: 2 additions & 1 deletion pywatts/core/condition_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@ class ConditionObject(ABC):
This module contains a function which returns either True or False. The input of this function is the output of one
or more modules.
A condition object can be passed to the refit_condition argument of steps
:param name: The name of the condition
:type name: str
"""

def __init__(self, name):
Expand Down
47 changes: 24 additions & 23 deletions pywatts/core/step.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,8 @@ class Step(BaseStep):
:type callbacks: List[Union[BaseCallback, Callable[[Dict[str, xr.DataArray]], None]]]
:param condition: A callable which checks if the step should be executed with the current data.
:type condition: Callable[xr.DataArray, xr.DataArray, bool]
:param train_if: A callable which checks if the train_if step should be executed or not.
:type train_if: Callable[xr.DataArray, xr.DataArray, bool] #TODO fix type?
:param refit_condition: A ConditionObject which checks if the step should be refitted or not.
:type refit_condition: ConditionObject
:param lag: Needed for online learning. Determines what data can be used for retraining.
E.g., when 24 hour forecasts are performed, a lag of 24 hours is needed, else the retraining would
use future values as target values.
Expand All @@ -57,7 +57,7 @@ def __init__(self, module: Base, input_steps: Dict[str, BaseStep], file_manager,
callbacks: List[Union[BaseCallback, Callable[[Dict[str, xr.DataArray]], None]]] = [],
condition=None,
batch_size: Optional[None] = None,
train_if=None,
refit_condition=None,
retrain_batch=pd.Timedelta(hours=24),
lag=pd.Timedelta(hours=24)):
super().__init__(input_steps, targets, condition=condition,
Expand All @@ -67,12 +67,13 @@ def __init__(self, module: Base, input_steps: Dict[str, BaseStep], file_manager,
self.retrain_batch = retrain_batch
self.callbacks = callbacks
self.batch_size = batch_size
if self.computation_mode not in [ComputationMode.Refit] and train_if is not None:
message = "You added a train if condition without setting the computation_mode to refit. So train_if will be ignored."
if self.computation_mode not in [ComputationMode.Refit] and refit_condition is not None:
message = "You added a refit_condition without setting the computation_mode to refit." \
" The conditoin will be ignored."
warnings.warn(message)
logger.warn(message)
self.lag = lag
self.train_if = train_if
self.refit_condition = refit_condition
self.result_steps: Dict[str, ResultStep] = {}
self.retrain_batch = retrain_batch

Expand Down Expand Up @@ -115,11 +116,11 @@ def load(cls, stored_step: Dict, inputs, targets, module, file_manager):
condition = cloudpickle.load(pickle_file)
else:
condition = None
if stored_step["train_if"]:
with open(stored_step["train_if"], 'rb') as pickle_file:
train_if = cloudpickle.load(pickle_file)
if stored_step["refit_condition"]:
with open(stored_step["refit_condition"], 'rb') as pickle_file:
refit_condition = cloudpickle.load(pickle_file)
else:
train_if = None
refit_condition = None
callbacks = []
for callback_path in stored_step["callbacks"]:
with open(callback_path, 'rb') as pickle_file:
Expand All @@ -128,7 +129,7 @@ def load(cls, stored_step: Dict, inputs, targets, module, file_manager):
callbacks.append(callback)

step = cls(module, inputs, targets=targets, file_manager=file_manager, condition=condition,
train_if=train_if, callbacks=callbacks, batch_size=stored_step["batch_size"])
refit_condition=refit_condition, callbacks=callbacks, batch_size=stored_step["batch_size"])
step.default_run_setting = RunSetting.load(stored_step["default_run_setting"])
step.current_run_setting = step.default_run_setting.clone()
step.id = stored_step["id"]
Expand All @@ -139,10 +140,10 @@ def load(cls, stored_step: Dict, inputs, targets, module, file_manager):

def refit(self, start: pd.Timestamp, end: pd.Timestamp):
if self.computation_mode in [ComputationMode.Refit] and isinstance(self.module, BaseEstimator):
if self.train_if:
condition_input = {key: value.step.get_result(start, end) for key, value in self.train_if.kwargs.items()}
if self.train_if.evaluate(**condition_input):
# TODO should the same data be used for refitting and for calling the train_if condition?
if self.refit_condition:
condition_input = {key: value.step.get_result(start, end) for key, value in self.refit_condition.kwargs.items()}
if self.refit_condition.evaluate(**condition_input):
# TODO should the same data be used for refitting and for evaluating the refit_condition?
# NOPE -> Make it more flexible... Perhaps something for issue 147
refit_input = self._get_input(end - self.retrain_batch, end)
refit_target = self._get_target(end - self.retrain_batch, end)
Expand Down Expand Up @@ -193,24 +194,24 @@ def _get_input(self, start, batch):
def get_json(self, fm: FileManager):
json = super().get_json(fm)
condition_path = None
train_if_path = None
refit_condition_path = None
callbacks_paths = []
if self.condition:
condition_path = fm.get_path(f"{self.name}_condition.pickle")
with open(condition_path, 'wb') as outfile:
cloudpickle.dump(self.condition, outfile)
if self.train_if:
train_if_path = fm.get_path(f"{self.name}_train_if.pickle")
with open(train_if_path, 'wb') as outfile:
cloudpickle.dump(self.train_if, outfile)
if self.refit_condition:
refit_condition_path = fm.get_path(f"{self.name}_refit_condition.pickle")
with open(refit_condition_path, 'wb') as outfile:
cloudpickle.dump(self.refit_condition, outfile)
for callback in self.callbacks:
callback_path = fm.get_path(f"{self.name}_callback.pickle")
with open(callback_path, 'wb') as outfile:
cloudpickle.dump(callback, outfile)
callbacks_paths.append(callback_path)
json.update({"callbacks": callbacks_paths,
"condition": condition_path,
"train_if": train_if_path,
"refit_condition": refit_condition_path,
"batch_size": self.batch_size})
return json

Expand All @@ -222,10 +223,10 @@ def refit(self, start: pd.Timestamp, end: pd.Timestamp):
"""
if self.current_run_setting.computation_mode in [ComputationMode.Refit] and isinstance(self.module,
BaseEstimator):
if self.train_if:
if self.refit_condition:
input_data = self._get_input(start, end)
target = self._get_target(start, end)
if self.train_if(input_data, target):
if self.refit_condition(input_data, target):
refit_input = self._get_input(end - self.retrain_batch, end)
refit_target = self._get_target(end - self.retrain_batch, end)
self.module.refit(**refit_input, **refit_target)
Expand Down
12 changes: 6 additions & 6 deletions pywatts/core/step_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def create_step(self,
condition,
batch_size,
computation_mode,
train_if,
refit_condition,
retrain_batch,
lag):
"""
Expand All @@ -46,7 +46,7 @@ def create_step(self,
:param condition: A function returning True or False which indicates if the step should be performed
:param batch_size: The size of the past time range which should be used for relearning the module
:param computation_mode: The computation mode of the step
:param train_if: A method for determining if the step should be fitted at a specific timestamp.
:param refit_condition: A method for determining if the step should be fitted at a specific timestamp.
:param retrain_batch: Determines how much past data should be used for relearning.
:param lag: Needed for online learning. Determines what data can be used for retraining.
E.g., when 24 hour forecasts are performed, a lag of 24 hours is needed, else the retraining would
Expand Down Expand Up @@ -74,19 +74,19 @@ def create_step(self,
if isinstance(module, Pipeline):
step = PipelineStep(module, input_steps, pipeline.file_manager, targets=target_steps,
callbacks=callbacks, computation_mode=computation_mode, condition=condition,
batch_size=batch_size, train_if=train_if, retrain_batch=retrain_batch, lag=lag)
batch_size=batch_size, refit_condition=refit_condition, retrain_batch=retrain_batch, lag=lag)
elif use_inverse_transform:
step = InverseStep(module, input_steps, pipeline.file_manager, targets=target_steps,
callbacks=callbacks, computation_mode=computation_mode, condition=condition,
retrain_batch=retrain_batch, lag=lag)
refit_condition=refit_condition, retrain_batch=retrain_batch, lag=lag)
elif use_predict_proba:
step = ProbablisticStep(module, input_steps, pipeline.file_manager, targets=target_steps,
callbacks=callbacks, computation_mode=computation_mode, condition=condition,
retrain_batch=retrain_batch, lag=lag)
refit_condition=refit_condition, retrain_batch=retrain_batch, lag=lag)
else:
step = Step(module, input_steps, pipeline.file_manager, targets=target_steps,
callbacks=callbacks, computation_mode=computation_mode, condition=condition,
batch_size=batch_size, train_if=train_if, retrain_batch=retrain_batch, lag=lag)
batch_size=batch_size, refit_condition=refit_condition, retrain_batch=retrain_batch, lag=lag)

step_id = pipeline.add(module=step,
input_ids=[step.id for step in input_steps.values()],
Expand Down
4 changes: 2 additions & 2 deletions tests/unit/core/test_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@
'module_id': 0,
'name': 'StandardScaler',
'target_ids': {},
'train_if': None},
'refit_condition': None},
{'batch_size': None,
'callbacks': [],
'class': 'Step',
Expand All @@ -67,7 +67,7 @@
'module_id': 1,
'name': 'LinearRegression',
'target_ids': {},
'train_if': None}],
'refit_condition': None}],
'version': 1}


Expand Down
22 changes: 11 additions & 11 deletions tests/unit/core/test_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def test_store_load_of_step_with_condition(self, cloudpickle_mock, open_mock):
"input_ids": {},
"id": -1,
'default_run_setting': {'computation_mode': 4},
"train_if": None,
"refit_condition": None,
"module": "pywatts.core.step",
"class": "Step",
"name": "test",
Expand All @@ -65,19 +65,19 @@ def test_store_load_of_step_with_condition(self, cloudpickle_mock, open_mock):

@patch("builtins.open")
@patch("pywatts.core.step.cloudpickle")
def test_store_load_of_step_with_train_if(self, cloudpickle_mock, open_mock):
train_if_mock = MagicMock()
step = Step(self.module_mock, self.step_mock, None, train_if=train_if_mock)
def test_store_load_of_step_with_refit_condition(self, cloudpickle_mock, open_mock):
refit_condition_mock = MagicMock()
step = Step(self.module_mock, self.step_mock, None, refit_condition=refit_condition_mock)
fm_mock = MagicMock()
fm_mock.get_path.return_value = os.path.join("folder", "test_train_if.pickle")
fm_mock.get_path.return_value = os.path.join("folder", "refit_condition.pickle")
json = step.get_json(fm_mock)
reloaded_step = Step.load(json, [self.step_mock], targets=None, module=self.module_mock,
file_manager=MagicMock())

# One call in load and one in save
open_mock.assert_has_calls(
[call(os.path.join("folder", "test_train_if.pickle"), "wb"),
call(os.path.join("folder", "test_train_if.pickle"), "rb")],
[call(os.path.join("folder", "refit_condition.pickle"), "wb"),
call(os.path.join("folder", "refit_condition.pickle"), "rb")],
any_order=True)
self.assertEqual(json, {
"target_ids": {},
Expand All @@ -86,7 +86,7 @@ def test_store_load_of_step_with_train_if(self, cloudpickle_mock, open_mock):
"id": -1,
'batch_size': None,
'default_run_setting': {'computation_mode': 4},
"train_if": os.path.join("folder", "test_train_if.pickle"),
"refit_condition": os.path.join("folder", "refit_condition.pickle"),
"module": "pywatts.core.step",
"class": "Step",
"name": "test",
Expand All @@ -98,7 +98,7 @@ def test_store_load_of_step_with_train_if(self, cloudpickle_mock, open_mock):
self.assertEqual(reloaded_step.module, self.module_mock)
self.assertEqual(reloaded_step.input_steps, [self.step_mock])
cloudpickle_mock.load.assert_called_once_with(open_mock().__enter__.return_value)
cloudpickle_mock.dump.assert_called_once_with(train_if_mock, open_mock().__enter__.return_value)
cloudpickle_mock.dump.assert_called_once_with(refit_condition_mock, open_mock().__enter__.return_value)

@patch("pywatts.core.base_step._get_time_indexes", return_value=["time"])
@patch("pywatts.core.base_step.xr")
Expand Down Expand Up @@ -215,7 +215,7 @@ def test_load(self):
"module": "pywatts.core.step",
"class": "Step",
"condition": None,
"train_if": None,
"refit_condition": None,
'callbacks': [],
"name": "test",
"last": False,
Expand All @@ -241,7 +241,7 @@ def test_get_json(self):
'module': 'pywatts.core.step',
'name': 'test',
'target_ids': {},
'train_if': None}, json)
'refit_condition': None}, json)

def test_set_run_setting(self):
step = Step(MagicMock(), MagicMock(), MagicMock())
Expand Down

0 comments on commit 93a611f

Please sign in to comment.