Skip to content

Commit

Permalink
Add ETT datasets (#3149)
Browse files Browse the repository at this point in the history
*Description of changes:* Add electricity transformer datasets from
https://github.com/zhouhaoyi/ETDataset


By submitting this pull request, I confirm that you can use, modify,
copy, and redistribute this contribution, under the terms of your
choice.


**Please tag this pr with at least one of these labels to make our
release process faster:** BREAKING, new feature, bug fix, other change,
dev setup
  • Loading branch information
lostella authored Mar 22, 2024
1 parent 462d2ed commit 9f33c1b
Show file tree
Hide file tree
Showing 2 changed files with 82 additions and 0 deletions.
69 changes: 69 additions & 0 deletions src/gluonts/dataset/repository/_ett_small.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.

from pathlib import Path

import pandas as pd
from gluonts.dataset import DatasetWriter
from gluonts.dataset.common import MetaData, TrainDatasets


# Currently data from only two regions are made public.
NUM_REGIONS = 2


def generate_ett_small_dataset(
dataset_path: Path,
dataset_writer: DatasetWriter,
base_file_name: str,
freq: str,
prediction_length: int,
):
dfs = []
for i in range(NUM_REGIONS):
df = pd.read_csv(
f"https://raw.githubusercontent.com/zhouhaoyi/ETDataset"
f"/main/ETT-small/{base_file_name}{i+1}.csv"
)
df["date"] = df["date"].astype("datetime64[ms]")
dfs.append(df)

test = []
for df in dfs:
start = pd.Period(df["date"][0], freq=freq)
for col in df.columns:
if col in ["date"]:
continue
test.append(
{
"start": start,
"target": df[col].values,
}
)

train = []
for df in dfs:
start = pd.Period(df["date"][0], freq=freq)
for col in df.columns:
if col in ["date"]:
continue
train.append(
{
"start": start,
"target": df[col].values[:-prediction_length],
}
)

metadata = MetaData(freq=freq, prediction_length=prediction_length)
dataset = TrainDatasets(metadata=metadata, train=train, test=test)
dataset.save(str(dataset_path), writer=dataset_writer, overwrite=True)
13 changes: 13 additions & 0 deletions src/gluonts/dataset/repository/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from ._artificial import generate_artificial_dataset
from ._airpassengers import generate_airpassengers_dataset
from ._ercot import generate_ercot_dataset
from ._ett_small import generate_ett_small_dataset
from ._gp_copula_2019 import generate_gp_copula_dataset
from ._lstnet import generate_lstnet_dataset
from ._m3 import generate_m3_dataset
Expand Down Expand Up @@ -243,6 +244,18 @@ def get_download_path() -> Path:
dataset_name="vehicle_trips_without_missing",
),
"ercot": partial(generate_ercot_dataset),
"ett_small_15min": partial(
generate_ett_small_dataset,
base_file_name="ETTm",
freq="15min",
prediction_length=24,
),
"ett_small_1h": partial(
generate_ett_small_dataset,
base_file_name="ETTh",
freq="1h",
prediction_length=24,
),
}


Expand Down

0 comments on commit 9f33c1b

Please sign in to comment.