Skip to content

Commit

Permalink
Fix gluonts.mx.trainer.Trainer in case of empty data loader (#2228)
Browse files Browse the repository at this point in the history
  • Loading branch information
lostella committed Aug 26, 2022
1 parent 300e1fe commit a67738b
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 1 deletion.
12 changes: 11 additions & 1 deletion src/gluonts/mx/trainer/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,9 +320,12 @@ def loop( # todo call run epoch
)

batch_iter = itertools.islice(batch_iter, num_batches_to_use)

it = tqdm(batch_iter, total=num_batches_to_use)
any_batches = False

for batch_no, batch in enumerate(it, start=1):
any_batches = True

# `batch` here is expected to be a dictionary whose fields
# should correspond 1-to-1 with the network inputs
# see below how `batch.values()` is fed into the network
Expand Down Expand Up @@ -406,6 +409,13 @@ def loop( # todo call run epoch
)
it.close()

if not any_batches:
raise GluonTSDataError(
"No training data batch could be constructed; "
"this usually indicates that the training dataset "
"is empty, or consists of too short series."
)

# mark epoch end time and log time cost of current epoch
toc = time.time()
logger.info(
Expand Down
30 changes: 30 additions & 0 deletions test/mx/test_no_batches.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.

import pytest

from gluonts.exceptions import GluonTSDataError
from gluonts.mx.model.deepar import DeepAREstimator
from gluonts.mx.trainer import Trainer


@pytest.mark.parametrize("dataset", [[]])
def test_deepar_no_batches(dataset):
estimator = DeepAREstimator(
prediction_length=10,
freq="H",
trainer=Trainer(epochs=1, num_batches_per_epoch=1),
)

with pytest.raises(GluonTSDataError):
estimator.train(dataset)

0 comments on commit a67738b

Please sign in to comment.