Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ktr code clean up #638

Merged
merged 3 commits into from
Dec 5, 2021
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 19 additions & 42 deletions orbit/template/ktr.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,44 +296,11 @@ def _set_model_param_names(self):

def _set_default_args(self):
"""Set default attributes for None"""
# default checks for seasonality and seasonality_fs_order will be conducted
# in ktrlite model and we will extract them from ktrlite model directly later
if self.coef_prior_list is not None:
self._coef_prior_list = deepcopy(self.coef_prior_list)

# set default seasonality and related attributes
if self.seasonality is None:
self._seasonality = list()
self._seasonality_fs_order = list()
elif not isinstance(self._seasonality, list) and isinstance(self._seasonality, (int, float)):
self._seasonality = [self.seasonality]

# set some defaults for seasonality_fs_order
if self._seasonality and self._seasonality_fs_order is None:
self._seasonality_fs_order = [2] * len(self._seasonality)
elif not isinstance(self._seasonality_fs_order, list) and isinstance(self._seasonality_fs_order, (int, float)):
self._seasonality_fs_order = [self.seasonality_fs_order]

if len(self._seasonality_fs_order) != len(self._seasonality):
raise IllegalArgument('length of seasonality and fs_order not matching')

seasonality_labels = list()
for idx, order in enumerate(self._seasonality_fs_order):
if 2 * order > self._seasonality[idx] - 1:
raise IllegalArgument('reduce seasonality_fs_order to avoid over-fitting')
seasonality_labels.append('seasonality_{}'.format(self._seasonality[idx]))
self._seasonality_labels = seasonality_labels

# TODO: this is done by KTRLite; we may not need this for now
# if not isinstance(self.seasonal_initial_knot_scale, list) and \
# isinstance(self.seasonal_initial_knot_scale * 1.0, float):
# self._seasonal_initial_knot_scale = [self.seasonal_initial_knot_scale] * len(self._seasonality)
# else:
# self._seasonal_initial_knot_scale = self.seasonal_initial_knot_scale
#
# if not isinstance(self.seasonal_knot_scale, list) and isinstance(self.seasonal_knot_scale * 1.0, float):
# self._seasonal_knot_scale = [self.seasonal_knot_scale] * len(self._seasonality)
# else:
# self._seasonal_knot_scale = self.seasonal_knot_scale

# if no regressors, end here #
if self.regressor_col is None:
# regardless of what args are set for these, if regressor_col is None
Expand Down Expand Up @@ -687,7 +654,8 @@ def _generate_insample_tp(self, training_meta, date_array):
# coefs = np.squeeze(np.matmul(coef_knot, kernel_coef.transpose(1, 0)), axis=0).transpose(1, 0)
# return coefs

def _generate_seas(self, df, training_meta, coef_knot_dates, coef_knots, seasonality, seasonality_fs_order):
def _generate_seas(self, df, training_meta, coef_knot_dates, coef_knots,
seasonality, seasonality_fs_order, seasonality_labels):
"""To calculate the seasonality term based on the _seasonal_knots_input.
Parameters
----------
Expand Down Expand Up @@ -735,7 +703,7 @@ def _generate_seas(self, df, training_meta, coef_knot_dates, coef_knots, seasona
n=df.shape[0],
periods=seasonality,
orders=seasonality_fs_order,
labels=self._seasonality_labels,
labels=seasonality_labels,
shift=start,
)

Expand All @@ -746,7 +714,7 @@ def _generate_seas(self, df, training_meta, coef_knot_dates, coef_knots, seasona
# init of regression matrix depends on length of response vector
total_seas_regression = np.zeros((1, df.shape[0]), dtype=np.double)

for k in self._seasonality_labels:
for k in seasonality_labels:
seas_regresor_matrix = seas_regressors[k]
coef_knot = coef_knots[k]
# time-step x coefficients
Expand Down Expand Up @@ -775,7 +743,7 @@ def _set_levs_and_seas(self, df, training_meta):
level_knot_distance=self.level_knot_distance,
seasonality=self.seasonality,
seasonality_fs_order=self.seasonality_fs_order,
seasonal_initial_knot_scale=self.seasonal_knot_scale,
seasonal_initial_knot_scale=self.seasonal_initial_knot_scale,
seasonal_knot_scale=self.seasonal_knot_scale,
seasonality_segments=self.seasonality_segments,
degree_of_freedom=self.degree_of_freedom,
Expand All @@ -787,6 +755,13 @@ def _set_levs_and_seas(self, df, training_meta):
# self._ktrlite_model = ktrlite
ktrlite_pt_posteriors = ktrlite.get_point_posteriors()
ktrlite_obs_scale = ktrlite_pt_posteriors['map']['obs_scale']

# load _seasonality and _seasonality_fs_order
self._seasonality = ktrlite._model._seasonality
self._seasonality_fs_order = ktrlite._model._seasonality_fs_order
for seas in self._seasonality:
self._seasonality_labels.append('seasonality_{}'.format(seas))

# if input None for upper bound of residuals scale, use data-driven input
if self.residuals_scale_upper is None:
# make it 5 times to have some buffer in case we over-fit in KTRLite
Expand Down Expand Up @@ -836,7 +811,8 @@ def _set_levs_and_seas(self, df, training_meta):
self._seasonality_coef_knot_dates,
self._seasonality_coef_knots,
self._seasonality,
self._seasonality_fs_order)
self._seasonality_fs_order,
self._seasonality_labels)
# remove batch size as an input for models
self._seas_term = np.squeeze(self._seas_term, 0)

Expand All @@ -862,12 +838,12 @@ def _filter_coef_prior(self, df):

def set_dynamic_attributes(self, df, training_meta):
"""Overriding: func: `~orbit.models.BaseETS._set_dynamic_attributes"""
self._set_valid_response_attributes(training_meta)
self._set_regressor_matrix(df, training_meta)
self._set_coefficients_kernel_matrix(df, training_meta)
self._set_knots_scale_matrix(df, training_meta)
self._set_levs_and_seas(df, training_meta)
self._filter_coef_prior(df)
self._set_valid_response_attributes(training_meta)

@staticmethod
def _concat_regression_coefs(pr_beta=None, rr_beta=None):
Expand Down Expand Up @@ -971,7 +947,8 @@ def predict(self, posterior_estimates, df, training_meta, prediction_meta,
self._seasonality_coef_knot_dates,
self._seasonality_coef_knots,
self._seasonality,
self._seasonality_fs_order)
self._seasonality_fs_order,
self._seasonality_labels)

# # seas is 1-d array, add the batch size back
# seas = np.expand_dims(seas, 0)
Expand Down