Skip to content

Commit

Permalink
ktr code clean up (#638)
Browse files Browse the repository at this point in the history
* ktr code clean up

* seasonality label clean up

* seasonality label clean up
  • Loading branch information
wangzhishi authored Dec 5, 2021
1 parent ba46035 commit 3c10089
Showing 1 changed file with 19 additions and 42 deletions.
61 changes: 19 additions & 42 deletions orbit/template/ktr.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,44 +296,11 @@ def _set_model_param_names(self):

def _set_default_args(self):
"""Set default attributes for None"""
# default checks for seasonality and seasonality_fs_order will be conducted
# in ktrlite model and we will extract them from ktrlite model directly later
if self.coef_prior_list is not None:
self._coef_prior_list = deepcopy(self.coef_prior_list)

# set default seasonality and related attributes
if self.seasonality is None:
self._seasonality = list()
self._seasonality_fs_order = list()
elif not isinstance(self._seasonality, list) and isinstance(self._seasonality, (int, float)):
self._seasonality = [self.seasonality]

# set some defaults for seasonality_fs_order
if self._seasonality and self._seasonality_fs_order is None:
self._seasonality_fs_order = [2] * len(self._seasonality)
elif not isinstance(self._seasonality_fs_order, list) and isinstance(self._seasonality_fs_order, (int, float)):
self._seasonality_fs_order = [self.seasonality_fs_order]

if len(self._seasonality_fs_order) != len(self._seasonality):
raise IllegalArgument('length of seasonality and fs_order not matching')

seasonality_labels = list()
for idx, order in enumerate(self._seasonality_fs_order):
if 2 * order > self._seasonality[idx] - 1:
raise IllegalArgument('reduce seasonality_fs_order to avoid over-fitting')
seasonality_labels.append('seasonality_{}'.format(self._seasonality[idx]))
self._seasonality_labels = seasonality_labels

# TODO: this is done by KTRLite; we may not need this for now
# if not isinstance(self.seasonal_initial_knot_scale, list) and \
# isinstance(self.seasonal_initial_knot_scale * 1.0, float):
# self._seasonal_initial_knot_scale = [self.seasonal_initial_knot_scale] * len(self._seasonality)
# else:
# self._seasonal_initial_knot_scale = self.seasonal_initial_knot_scale
#
# if not isinstance(self.seasonal_knot_scale, list) and isinstance(self.seasonal_knot_scale * 1.0, float):
# self._seasonal_knot_scale = [self.seasonal_knot_scale] * len(self._seasonality)
# else:
# self._seasonal_knot_scale = self.seasonal_knot_scale

# if no regressors, end here #
if self.regressor_col is None:
# regardless of what args are set for these, if regressor_col is None
Expand Down Expand Up @@ -687,7 +654,8 @@ def _generate_insample_tp(self, training_meta, date_array):
# coefs = np.squeeze(np.matmul(coef_knot, kernel_coef.transpose(1, 0)), axis=0).transpose(1, 0)
# return coefs

def _generate_seas(self, df, training_meta, coef_knot_dates, coef_knots, seasonality, seasonality_fs_order):
def _generate_seas(self, df, training_meta, coef_knot_dates, coef_knots,
seasonality, seasonality_fs_order, seasonality_labels):
"""To calculate the seasonality term based on the _seasonal_knots_input.
Parameters
----------
Expand Down Expand Up @@ -735,7 +703,7 @@ def _generate_seas(self, df, training_meta, coef_knot_dates, coef_knots, seasona
n=df.shape[0],
periods=seasonality,
orders=seasonality_fs_order,
labels=self._seasonality_labels,
labels=seasonality_labels,
shift=start,
)

Expand All @@ -746,7 +714,7 @@ def _generate_seas(self, df, training_meta, coef_knot_dates, coef_knots, seasona
# init of regression matrix depends on length of response vector
total_seas_regression = np.zeros((1, df.shape[0]), dtype=np.double)

for k in self._seasonality_labels:
for k in seasonality_labels:
seas_regresor_matrix = seas_regressors[k]
coef_knot = coef_knots[k]
# time-step x coefficients
Expand Down Expand Up @@ -775,7 +743,7 @@ def _set_levs_and_seas(self, df, training_meta):
level_knot_distance=self.level_knot_distance,
seasonality=self.seasonality,
seasonality_fs_order=self.seasonality_fs_order,
seasonal_initial_knot_scale=self.seasonal_knot_scale,
seasonal_initial_knot_scale=self.seasonal_initial_knot_scale,
seasonal_knot_scale=self.seasonal_knot_scale,
seasonality_segments=self.seasonality_segments,
degree_of_freedom=self.degree_of_freedom,
Expand All @@ -787,6 +755,13 @@ def _set_levs_and_seas(self, df, training_meta):
# self._ktrlite_model = ktrlite
ktrlite_pt_posteriors = ktrlite.get_point_posteriors()
ktrlite_obs_scale = ktrlite_pt_posteriors['map']['obs_scale']

# load _seasonality and _seasonality_fs_order
self._seasonality = ktrlite._model._seasonality
self._seasonality_fs_order = ktrlite._model._seasonality_fs_order
for seas in self._seasonality:
self._seasonality_labels.append('seasonality_{}'.format(seas))

# if input None for upper bound of residuals scale, use data-driven input
if self.residuals_scale_upper is None:
# make it 5 times to have some buffer in case we over-fit in KTRLite
Expand Down Expand Up @@ -836,7 +811,8 @@ def _set_levs_and_seas(self, df, training_meta):
self._seasonality_coef_knot_dates,
self._seasonality_coef_knots,
self._seasonality,
self._seasonality_fs_order)
self._seasonality_fs_order,
self._seasonality_labels)
# remove batch size as an input for models
self._seas_term = np.squeeze(self._seas_term, 0)

Expand All @@ -862,12 +838,12 @@ def _filter_coef_prior(self, df):

def set_dynamic_attributes(self, df, training_meta):
"""Overriding: func: `~orbit.models.BaseETS._set_dynamic_attributes"""
self._set_valid_response_attributes(training_meta)
self._set_regressor_matrix(df, training_meta)
self._set_coefficients_kernel_matrix(df, training_meta)
self._set_knots_scale_matrix(df, training_meta)
self._set_levs_and_seas(df, training_meta)
self._filter_coef_prior(df)
self._set_valid_response_attributes(training_meta)

@staticmethod
def _concat_regression_coefs(pr_beta=None, rr_beta=None):
Expand Down Expand Up @@ -971,7 +947,8 @@ def predict(self, posterior_estimates, df, training_meta, prediction_meta,
self._seasonality_coef_knot_dates,
self._seasonality_coef_knots,
self._seasonality,
self._seasonality_fs_order)
self._seasonality_fs_order,
self._seasonality_labels)

# # seas is 1-d array, add the batch size back
# seas = np.expand_dims(seas, 0)
Expand Down

0 comments on commit 3c10089

Please sign in to comment.