Skip to content
This repository has been archived by the owner on Jun 22, 2024. It is now read-only.

Commit

Permalink
Add improvements for online_learning
Browse files Browse the repository at this point in the history
  • Loading branch information
benHeid committed Feb 28, 2022
1 parent 93a611f commit 2e41481
Show file tree
Hide file tree
Showing 18 changed files with 196 additions and 64 deletions.
5 changes: 5 additions & 0 deletions pywatts/core/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ def __init__(self, name: str):
self.has_predict_proba = False
# TODO each module needs a function for returning how much past values it needs for executing the transformation.
# SEE Issue 147

@abstractmethod
def get_params(self) -> Dict[str, object]:
"""
Expand Down Expand Up @@ -139,6 +140,10 @@ def refit(self, **kwargs):
"""
return self.fit(**kwargs)

def get_min_data(self):
    """
    Return the amount of past data this module needs to perform one
    transformation.

    NOTE(review): placeholder implementation — always reports zero history
    (the original marks this as a hacky solution); subclasses that need a
    lookback window should override this.
    """
    return pd.Timedelta(hours=0)

def __call__(self,
use_inverse_transform: bool = False,
use_prob_transform: bool = False,
Expand Down
37 changes: 32 additions & 5 deletions pywatts/core/base_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,12 @@ def __init__(self, input_steps: Optional[Dict[str, "BaseStep"]] = None,
self.buffer: Dict[str, xr.DataArray] = {}
self.training_time = SummaryObjectList(self.name + " Training Time", category=SummaryCategory.FitTime)
self.transform_time = SummaryObjectList(self.name + " Transform Time", category=SummaryCategory.TransformTime)
self.refit_time = SummaryObjectList(self.name + " Refit Time", category=SummaryCategory.RefitTime)
self.additional_summary = SummaryObjectList(self.name + " Additional Information",
category=SummaryCategory.AdditionalModuleInformation)

def get_summaries(self):
    """
    Return the summary objects tracked by this step: training time,
    transform time, and refit time.

    Fixes a copy-paste bug: the original returned ``self.transform_time``
    twice and omitted ``self.training_time`` (the code this replaces
    summarized both transform_time and training_time).
    """
    # NOTE(review): self.additional_summary is created in __init__ but is
    # not returned here — confirm whether it should be included as well.
    return [self.training_time, self.transform_time, self.refit_time]

def get_result(self, start: pd.Timestamp, end: Optional[pd.Timestamp], buffer_element: str = None,
return_all=False):
Expand Down Expand Up @@ -91,8 +97,10 @@ def get_result(self, start: pd.Timestamp, end: Optional[pd.Timestamp], buffer_el
# Check if the cached results fits to the request, if yes return it.
if self.cached_result["cached"] is not None and self.cached_result["start"] == start and self.cached_result[
"end"] == end:
return copy.deepcopy(self.cached_result["cached"]) if return_all else copy.deepcopy(self.cached_result["cached"][
buffer_element]) if buffer_element is not None else copy.deepcopy(list(self.cached_result["cached"].values())[
return copy.deepcopy(self.cached_result["cached"]) if return_all else copy.deepcopy(
self.cached_result["cached"][
buffer_element]) if buffer_element is not None else copy.deepcopy(
list(self.cached_result["cached"].values())[
0])
return self._pack_data(start, end, buffer_element, return_all=return_all)

Expand Down Expand Up @@ -121,6 +129,7 @@ def further_elements(self, counter: pd.Timestamp) -> bool:

def _pack_data(self, start, end, buffer_element=None, return_all=False):
# Provide requested data
# TODO Refactor
time_index = _get_time_indexes(self.buffer)
if end and start and end > start:
index = list(self.buffer.values())[0].indexes[time_index[0]]
Expand All @@ -135,6 +144,19 @@ def _pack_data(self, start, end, buffer_element=None, return_all=False):
else:
return list(self.buffer.values())[0].sel(
**{time_index[0]: index[(index >= start) & (index < end.to_numpy())]})
elif start and end is None:
index = list(self.buffer.values())[0].indexes[time_index[0]]
start = max(index[0], start.to_numpy())
# After sel copy is not needed, since it returns a new array.
if buffer_element is not None:
return self.buffer[buffer_element].sel(
**{time_index[0]: index[(index >= start)]})
elif return_all:
return {key: b.sel(**{time_index[0]: index[(index >= start)]}) for
key, b in self.buffer.items()}
else:
return list(self.buffer.values())[0].sel(
**{time_index[0]: index[(index >= start)]})
else:
self.finished = True
if buffer_element is not None:
Expand All @@ -144,6 +166,9 @@ def _pack_data(self, start, end, buffer_element=None, return_all=False):
else:
return list(self.buffer.values())[0].copy()

def create_summary(self):
    # No-op hook: concrete steps may override this to contribute extra
    # summary content (PipelineStep._transform appends the result of
    # module.create_summary() to its refit summary).
    pass

def _transform(self, input_step):
    # No-op default: subclasses implement the actual transformation of the
    # given input here.
    pass

Expand All @@ -163,7 +188,8 @@ def _post_transform(self, result):
# Time dimension is mandatory, consequently there dim has to exist
dim = _get_time_indexes(result)[0]
for key in self.buffer.keys():
self.buffer[key] = xr.concat([self.buffer[key], result[key]], dim=dim)
last = self.buffer[key][dim].values[-1]
self.buffer[key] = xr.concat([self.buffer[key], result[key][result[key][dim] > last]], dim=dim)
return result

def get_json(self, fm: FileManager) -> Dict:
Expand Down Expand Up @@ -225,11 +251,12 @@ def _should_stop(self, start, end) -> bool:
def _input_stopped(input_data):
    """
    Return True if at least one of the given inputs signalled a stop by
    providing ``None`` as its value; False for missing or empty input.
    """
    if not input_data:
        return False
    return any(value is None for value in input_data.values())

def reset(self):
def reset(self, keep_buffer=False):
"""
Resets all information of the step concerning a specific run.
"""
self.buffer = {}
if not keep_buffer:
self.buffer = {}
self.finished = False
self.current_run_setting = self.default_run_setting.clone()

Expand Down
2 changes: 1 addition & 1 deletion pywatts/core/either_or_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def __init__(self, input_steps):

def _compute(self, start, end):
input_data = self._get_input(start, end)
return self._transform(input_data)
return self._transform(input_data, start, end)

def _get_input(self, start, batch):
inputs = []
Expand Down
84 changes: 60 additions & 24 deletions pywatts/core/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,8 @@ class Pipeline(BaseTransformer):
:type batch: Optional[pd.Timedelta]
"""

def __init__(self, path: Optional[str] = ".", batch=pd.Timedelta(hours=0), name="Pipeline"):
def __init__(self, path: Optional[str] = ".", batch: Optional[pd.Timedelta] = pd.Timedelta(hours=0),
name="Pipeline"):
super().__init__(name)
self.batch = batch
self.counter = None
Expand All @@ -71,24 +72,37 @@ def transform(self, **x: xr.DataArray) -> xr.DataArray:
:return:The transformed data
:rtype: xr.DataArray
"""
if self.current_run_setting.online_start is not None:
index_name = _get_time_indexes(x)[0]
time_index = list(x.values())[0][_get_time_indexes(x)[0]]
if time_index[0].values < self.current_run_setting.online_start:
return self._transform(x, None)
else:
return self._comp(x, self.current_run_setting.summary_formatter, self.batch)[0]
else:
return self._comp(x, self.current_run_setting.summary_formatter, self.batch)[0]

def _transform(self, x, batch=None):
for key, (start_step, _) in self.start_steps.items():
start_step.buffer = {key: x[key].copy()}
if not start_step.buffer:
start_step.buffer = {key: x[key].copy()}
else:
dim = _get_time_indexes(x)[0]
last = start_step.buffer[key][dim].values[-1]
start_step.buffer[key] = xr.concat([start_step.buffer[key], x[key][x[key][dim] > last]], dim=dim)
start_step.finished = True

time_index = _get_time_indexes(x)
self.counter = list(x.values())[0].indexes[time_index[0]][0] # The start date of the input time series.

last_steps = list(filter(lambda x: x.last, self.id_to_step.values()))

if not self.batch:
if not batch:
return self._collect_results(last_steps)
return self._collect_batches(last_steps, time_index)
return self._collect_batches(last_steps, time_index, batch)

def _collect_batches(self, last_steps, time_index):
def _collect_batches(self, last_steps, time_index, batch):
result = {}
while all(map(lambda step: step.further_elements(self.counter), last_steps)):
print(self.counter)
res = self._collect_results(last_steps)
res = self._collect_results(last_steps, batch=batch)
if res is not None:
for key in res:
result[key] = xr.concat([result[key], res[key]], dim=time_index[0]) if key in result else res[key]
Expand All @@ -105,9 +119,9 @@ def _collect_batches(self, last_steps, time_index):
self.counter += self.batch
return result

def _collect_results(self, inputs):
def _collect_results(self, inputs, batch=None):
# Note the return value is None if none of the inputs provide a result for this step...
end = None if not self.batch else self.counter + self.batch
end = None if not batch else self.counter + batch
result = dict()
for i, step in enumerate(inputs):
if not isinstance(step, SummaryStep):
Expand Down Expand Up @@ -154,7 +168,7 @@ def draw(self):
# TODO built the graph which should be drawn by starting with the last steps...

def test(self, data: Union[pd.DataFrame, xr.Dataset], summary: bool = False,
summary_formatter: SummaryFormatter = SummaryMarkdown()):
summary_formatter: SummaryFormatter = SummaryMarkdown(), online_start=None):
"""
Executes all modules in the pipeline in the correct order. This method call only transform on every module
if the ComputationMode is Default. I.e. if no computationMode is specified during the addition of the module to
Expand All @@ -169,7 +183,7 @@ def test(self, data: Union[pd.DataFrame, xr.Dataset], summary: bool = False,
:return: The result of all end points of the pipeline
:rtype: Dict[xr.DataArray]
"""
return self._run(data, ComputationMode.Transform, summary, summary_formatter)
return self._run(data, ComputationMode.Transform, summary, summary_formatter, online_start)

def train(self, data: Union[pd.DataFrame, xr.Dataset], summary: bool = False,
summary_formatter: SummaryFormatter = SummaryMarkdown()):
Expand All @@ -191,19 +205,33 @@ def train(self, data: Union[pd.DataFrame, xr.Dataset], summary: bool = False,
return self._run(data, ComputationMode.FitTransform, summary, summary_formatter)

def _run(self, data: Union[pd.DataFrame, xr.Dataset], mode: ComputationMode, summary: bool,
summary_formatter: SummaryFormatter):
summary_formatter: SummaryFormatter, online_start=None):

self.current_run_setting = RunSetting(computation_mode=mode,
summary_formatter=summary_formatter,
online_start=online_start)
for step in self.id_to_step.values():
step.reset()
step.set_run_setting(RunSetting(computation_mode=mode, summary_formatter=summary_formatter))
step.set_run_setting(run_setting=self.current_run_setting)

if isinstance(data, pd.DataFrame):
data = data.to_xarray()

if isinstance(data, xr.Dataset):
result = self.transform(**{key: data[key] for key in data.data_vars})
sum = self._create_summary(summary_formatter)
return (result, sum) if summary else result
if self.current_run_setting.online_start is not None:
index_name = _get_time_indexes(data)[0]
self._transform({key: data[key].sel(
**{index_name: data[key][index_name] < online_start}) for key in data.data_vars}, False)
for step in self.id_to_step.values():
step.reset(keep_buffer=True)
step.set_run_setting(RunSetting(computation_mode=mode,
summary_formatter=summary_formatter,
online_start=online_start))
return self._comp({key: data[key].sel(
**{index_name: data[key][index_name] >= online_start}) for key in data.data_vars}
, summary_formatter, self.batch, start=online_start)
else:
return self._comp({key: data[key] for key in data.data_vars}, summary_formatter, self.batch)
elif isinstance(data, dict):
for key in data:
if not isinstance(data[key], xr.DataArray):
Expand All @@ -212,16 +240,24 @@ def _run(self, data: Union[pd.DataFrame, xr.Dataset], mode: ComputationMode, sum
"Make sure to pass Dict[str, xr.DataArray].",
self.name
)
result = self.transform(**data)
sum = self._create_summary(summary_formatter)
return (result, sum) if summary else result
if online_start is not None:
self._transform(data[:online_start], self.batch)
return self._compute(data[online_start:], summary_formatter, self.batch)

else:
return self._compute(data, summary_formatter, self.batch)

raise WrongParameterException(
"Unkown data type to pass to pipeline steps.",
"Make sure to use pandas DataFrames, xarray Datasets, or Dict[str, xr.DataArray].",
self.name
)

def _comp(self, data, summary_formatter, batch, start=None, end=None):
    """
    Execute the pipeline on ``data`` and build the run summary.

    :return: ``(result, summary)`` when a (truthy) summary was produced,
             otherwise just the transformed result.
    """
    transformed = self._transform(data, batch)
    summary = self._create_summary(summary_formatter, start, end)
    if summary:
        return transformed, summary
    return transformed

def add(self, *,
module: Union[BaseStep],
input_ids: List[int] = None,
Expand Down Expand Up @@ -415,12 +451,12 @@ def __getitem__(self, item: str):
start_step.id = self.add(module=start_step, input_ids=[], target_ids=[])
return self.start_steps[item][-1]

def _create_summary(self, summary_formatter):
def _create_summary(self, summary_formatter, start=None, end=None):
summaries = []
for step in self.id_to_step.values():
if isinstance(step, SummaryStep):
summaries.append(step.get_summary())
summaries.extend([step.transform_time, step.training_time])
summaries.append(step.get_summary(start, end))
summaries.extend(step.get_summaries())
return summary_formatter.create_summary(summaries, self.file_manager)

def refit(self, start, end):
Expand Down
24 changes: 17 additions & 7 deletions pywatts/core/pipeline_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,10 @@
from pywatts.core.exceptions import NotFittedException
from pywatts.core.pipeline import Pipeline
from pywatts.core.step import Step

from pywatts.core.base_summary import BaseSummary
logger = logging.getLogger(__name__)


class PipelineStep(Step):
"""
This step is necessary for subpipelining. Since it contains functionality for adding a pipeline as a
Expand Down Expand Up @@ -45,26 +46,35 @@ def set_run_setting(self, run_setting: RunSetting):
for step in self.module.id_to_step.values():
step.set_run_setting(run_setting)

self.module.current_run_setting = self.current_run_setting

def _post_transform(self, result):
    # Create the sub-pipeline's summary before the regular post-processing,
    # so the inner pipeline's summary information is not lost.
    self.module._create_summary(self.current_run_setting.summary_formatter)
    return super()._post_transform(result)

def reset(self):
def reset(self, keep_buffer=False):
"""
Resets all information of the step concerning a specific run. Furthermore, it resets also all steps
of the subpipeline.
"""
super().reset()
super().reset(keep_buffer=keep_buffer)
for step in self.module.id_to_step.values():
step.reset()
step.reset(keep_buffer=keep_buffer)

def _transform(self, input_step):
    # Guard: an unfitted estimator sub-pipeline must not be executed.
    if isinstance(self.module, BaseEstimator) and not self.module.is_fitted:
        message = f"Try to call transform in {self.name} on not fitted module {self.module.name}"
        logger.error(message)
        raise NotFittedException(message, self.name, self.module.name)
    result = self.module.transform(**input_step)
    # Lazily initialise the refit-summary header, then append the
    # sub-pipeline's summary text.
    # NOTE(review): indentation was ambiguous in the reviewed diff — confirm
    # whether the append belongs inside the initialisation branch, and that
    # module.create_summary() returns a string (BaseStep's default returns
    # None, which would make the += raise TypeError).
    if self.refit_summary is None:
        self.refit_summary = f"# {self.name} Refit Summary:\n"
    self.refit_summary += (self.module.create_summary())
    # self.additional_summary = self.module.create_summary()
    return self._post_transform(result)

def get_summaries(self):
    """
    Collect the summary objects of every step in the wrapped sub-pipeline.

    Overrides the base implementation, which returns this step's own
    timing summaries.
    """
    summaries = []
    for m in self.module.id_to_step.values():
        # Summary steps additionally contribute their own summary result.
        # NOTE(review): the outer pipeline calls get_summary(start, end)
        # elsewhere — confirm the argument-less call here is intended.
        if isinstance(m, BaseSummary):
            summaries.append(m.get_summary())
        summaries.extend(m.get_summaries())

    return summaries
8 changes: 6 additions & 2 deletions pywatts/core/run_setting.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,11 @@ class RunSetting:
:type summary_formatter: SummaryFormatter
"""

def __init__(self, computation_mode: ComputationMode, summary_formatter: SummaryFormatter = SummaryMarkdown()):
def __init__(self, computation_mode: ComputationMode, summary_formatter: SummaryFormatter = SummaryMarkdown(),
online_start=False):
self.computation_mode = computation_mode
self.summary_formatter = summary_formatter
self.online_start = online_start

def update(self, run_setting: 'RunSetting') -> 'RunSetting':
"""
Expand All @@ -31,6 +33,7 @@ def update(self, run_setting: 'RunSetting') -> 'RunSetting':
if setting.computation_mode == ComputationMode.Default:
setting.computation_mode = run_setting.computation_mode
setting.summary_formatter = run_setting.summary_formatter
setting.online_start = run_setting.online_start
return setting

def clone(self) -> 'RunSetting':
    """
    Clone this run setting.

    :return: A new RunSetting with the same values.
    :rtype: RunSetting
    """
    # Reconstructed from a garbled diff: the scraped text contained both the
    # old line (missing trailing comma -> syntax error) and the new one;
    # also normalises the PEP8 keyword-argument spacing of online_start.
    return RunSetting(
        computation_mode=self.computation_mode,
        summary_formatter=self.summary_formatter,
        online_start=self.online_start,
    )

def save(self) -> Dict:
Expand Down
Loading

0 comments on commit 2e41481

Please sign in to comment.