From 22d4254c6f971228048a499e7d7dc0f6208c05e8 Mon Sep 17 00:00:00 2001 From: Jasper Zschiegner Date: Fri, 24 Nov 2023 15:01:28 +0100 Subject: [PATCH 1/2] Torch: Remove double caching of dataset. --- src/gluonts/torch/model/estimator.py | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/src/gluonts/torch/model/estimator.py b/src/gluonts/torch/model/estimator.py index 7cca653a15..372bdb8b9d 100644 --- a/src/gluonts/torch/model/estimator.py +++ b/src/gluonts/torch/model/estimator.py @@ -156,18 +156,16 @@ def train_model( transformation = self.create_transformation() with env._let(max_idle_transforms=max(len(training_data), 100)): - transformed_training_data: Union[ - Cached, TransformedDataset - ] = transformation.apply(training_data, is_train=True) + transformed_training_data: Dataset = transformation.apply( + training_data, is_train=True + ) if cache_data: transformed_training_data = Cached(transformed_training_data) training_network = self.create_lightning_module() training_data_loader = self.create_training_data_loader( - Cached(transformed_training_data) - if cache_data - else transformed_training_data, + transformed_training_data, training_network, shuffle_buffer_length=shuffle_buffer_length, ) @@ -176,9 +174,9 @@ def train_model( if validation_data is not None: with env._let(max_idle_transforms=max(len(validation_data), 100)): - transformed_validation_data: Union[ - Cached, TransformedDataset - ] = transformation.apply(validation_data, is_train=True) + transformed_validation_data: Dataset = transformation.apply( + validation_data, is_train=True + ) if cache_data: transformed_validation_data = Cached( transformed_validation_data From 859c8ed5b4c13ff6dfa79eed33e7fa443c25a8cb Mon Sep 17 00:00:00 2001 From: Jasper Zschiegner Date: Fri, 24 Nov 2023 17:51:04 +0100 Subject: [PATCH 2/2] Fixup. --- src/gluonts/torch/model/estimator.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/gluonts/torch/model/estimator.py b/src/gluonts/torch/model/estimator.py index 372bdb8b9d..b8a1147d44 100644 --- a/src/gluonts/torch/model/estimator.py +++ b/src/gluonts/torch/model/estimator.py @@ -11,7 +11,7 @@ # express or implied. See the License for the specific language governing # permissions and limitations under the License. -from typing import NamedTuple, Optional, Iterable, Dict, Any, Union +from typing import NamedTuple, Optional, Iterable, Dict, Any import logging import numpy as np @@ -24,7 +24,7 @@ from gluonts.itertools import Cached from gluonts.model import Estimator, Predictor from gluonts.torch.model.predictor import PyTorchPredictor -from gluonts.transform import Transformation, TransformedDataset +from gluonts.transform import Transformation logger = logging.getLogger(__name__)