Skip to content

Commit

Permalink
add validation option to TrainDataset
Browse files Browse the repository at this point in the history
  • Loading branch information
Pedro Eduardo Mercado Lopez committed Nov 30, 2023
1 parent 7dbe5e6 commit 8cfc8a9
Showing 1 changed file with 24 additions and 2 deletions.
26 changes: 24 additions & 2 deletions src/gluonts/dataset/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,11 +73,12 @@ class SourceContext(NamedTuple):
class TrainDatasets(NamedTuple):
"""
A dataset containing two subsets, one to be used for training purposes, and
the other for testing purposes, as well as metadata.
the other for validation and testing purposes, as well as metadata.
"""

metadata: MetaData
train: Dataset
validation: Optional[Dataset] = None
test: Optional[Dataset] = None

def save(
Expand Down Expand Up @@ -114,6 +115,11 @@ def save(
test.mkdir(parents=True)
writer.write_to_folder(self.test, test)

if self.validation is not None:
validation = path / "validation"
validation.mkdir(parents=True)
writer.write_to_folder(self.validation, validation)


def infer_file_type(path):
suffix = "".join(path.suffixes)
Expand Down Expand Up @@ -427,6 +433,7 @@ def __call__(self, data: DataEntry) -> DataEntry:
def load_datasets(
metadata: Path,
train: Path,
validation: Optional[Path],
test: Optional[Path],
one_dim_target: bool = True,
cache: bool = False,
Expand All @@ -442,6 +449,8 @@ def load_datasets(
Path to the training dataset files.
test
Path to the test dataset files.
validation
Path to the validation dataset files.
one_dim_target
Whether to load FileDatasets as univariate target time series.
cache
Expand All @@ -467,4 +476,17 @@ def load_datasets(
else None
)

return TrainDatasets(metadata=meta, train=train_ds, test=test_ds)
validation_ds = (
FileDataset(
path=validation,
freq=meta.freq,
one_dim_target=one_dim_target,
cache=cache,
)
if validation
else None
)

return TrainDatasets(
metadata=meta, train=train_ds, validation=validation_ds, test=test_ds
)

0 comments on commit 8cfc8a9

Please sign in to comment.