Skip to content

Commit

Permalink
Fix leaderboard with static features (open-mmlab#2398)
Browse files Browse the repository at this point in the history
  • Loading branch information
shchur authored Nov 15, 2022
1 parent 6b5f851 commit 7815dd7
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 1 deletion.
2 changes: 1 addition & 1 deletion timeseries/src/autogluon/timeseries/learner.py
Original file line number Diff line number Diff line change
Expand Up @@ -375,7 +375,7 @@ def score(
return self.load_trainer().score(data=data, model=model, metric=metric)

def leaderboard(self, data: Optional[TimeSeriesDataFrame] = None) -> pd.DataFrame:
if self.static_feature_pipeline.is_fit():
if data is not None and self.static_feature_pipeline.is_fit():
fix_message = (
"Please make sure that data has static_features with columns and dtypes exactly matching "
"train_data.static_features. "
Expand Down
34 changes: 34 additions & 0 deletions timeseries/tests/unittests/test_learner.py
Original file line number Diff line number Diff line change
Expand Up @@ -394,3 +394,37 @@ def test_when_ignore_index_is_true_and_known_covariates_available_then_learner_c
known_covariates = get_data_frame_with_variable_lengths(ITEM_ID_TO_LENGTH, known_covariates_names=["X", "Y"])
preds = learner.predict(train_data, known_covariates=known_covariates)
assert preds.item_ids.equals(train_data.item_ids)


@pytest.mark.parametrize("pred_data_present", [True, False])
@pytest.mark.parametrize("static_features_present", [True, False])
@pytest.mark.parametrize("known_covariates_present", [True, False])
def test_when_train_data_has_static_or_dynamic_feat_then_leaderboard_works(
temp_model_path, pred_data_present, static_features_present, known_covariates_present
):
if static_features_present:
static_features = get_static_features(["B", "A"], feature_names=["f1", "f2"])
else:
static_features = None

if known_covariates_present:
known_covariates_names = ["X", "Y"]
else:
known_covariates_names = None

train_data = get_data_frame_with_variable_lengths(
{"B": 20, "A": 15}, static_features=static_features, known_covariates_names=known_covariates_names
)

if pred_data_present:
pred_data = get_data_frame_with_variable_lengths(
{"B": 20, "A": 15}, static_features=static_features, known_covariates_names=known_covariates_names
)
else:
pred_data = None

learner = TimeSeriesLearner(path_context=temp_model_path)
learner.fit(train_data=train_data, hyperparameters=HYPERPARAMETERS_DUMMY)
leaderboard = learner.leaderboard(data=pred_data)
assert len(leaderboard) > 0
assert ("score_test" in leaderboard.columns) == pred_data_present

0 comments on commit 7815dd7

Please sign in to comment.