Skip to content

Commit

Permalink
remove ListDataset from tests
Browse files Browse the repository at this point in the history
  • Loading branch information
lostella committed Sep 27, 2023
1 parent 41fdd72 commit aa8e3a8
Showing 1 changed file with 24 additions and 21 deletions.
45 changes: 24 additions & 21 deletions test/model/npts/test_npts.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,17 @@
import pandas as pd
import pytest

from gluonts.dataset.common import DataEntry, Dataset, ListDataset
from gluonts.dataset.common import DataEntry, Dataset
from gluonts.exceptions import GluonTSDataError
from gluonts.model.npts import KernelType, NPTSPredictor
from gluonts.model.npts._weighted_sampler import WeightedSampler


def get_test_data(history_length: int, freq: str) -> pd.Series:
def get_test_data(
history_length: int, freq: str, dtype=np.float32
) -> pd.Series:
index = pd.date_range("1/1/2011", periods=history_length, freq=freq)
return pd.Series(np.arange(len(index)), index=index)
return pd.Series(np.arange(len(index), dtype=dtype), index=index)


@pytest.mark.parametrize(
Expand Down Expand Up @@ -100,10 +102,12 @@ def test_climatological_forecaster(
kernel_type=KernelType.uniform,
)

dataset = ListDataset(
[{"start": str(train_ts.index[0]), "target": train_ts.values}],
freq=freq,
)
dataset = [
{
"start": pd.Period(train_ts.index[0], freq=freq),
"target": train_ts.values,
}
]

# validate that the predictor works with targets with NaNs
_test_nans_in_target(predictor, dataset)
Expand Down Expand Up @@ -263,10 +267,12 @@ def test_npts_forecaster(
use_seasonal_model=use_seasonal_model,
)

dataset = ListDataset(
[{"start": str(train_ts.index[0]), "target": train_ts.values}],
freq=freq,
)
dataset = [
{
"start": pd.Period(train_ts.index[0], freq=freq),
"target": train_ts.values,
}
]

# validate that the predictor works with targets with NaNs
_test_nans_in_target(predictor, dataset)
Expand Down Expand Up @@ -427,16 +433,13 @@ def test_npts_custom_features(
use_default_time_features=False, # disable default time features
)

dataset = ListDataset(
[
{
"start": str(train_ts.index[0]),
"target": train_ts.values,
"feat_dynamic_real": feat_dynamic_real,
}
],
freq=freq,
)
dataset = [
{
"start": pd.Period(train_ts.index[0], freq=freq),
"target": train_ts.values,
"feat_dynamic_real": np.array(feat_dynamic_real),
}
]

# validate that the predictor works with targets with NaNs
_test_nans_in_target(predictor, dataset)
Expand Down

0 comments on commit aa8e3a8

Please sign in to comment.