Pass multiprocessing parameters through the shell #952

Merged: 9 commits, Jul 27, 2020

45 changes: 27 additions & 18 deletions src/gluonts/dataset/parallelized_loader.py
@@ -13,14 +13,12 @@


# Standard library imports
import collections
import functools
import io
import itertools
import logging
import multiprocessing
import multiprocessing.queues
import pathlib
import pickle
import random
import sys
@@ -30,7 +28,7 @@
from multiprocessing.pool import Pool
from multiprocessing.reduction import ForkingPickler
from queue import Queue
from typing import Any, Callable, Iterable, Iterator, List, Optional, Union
from typing import Any, Callable, Iterator, List, Optional, Union

import mxnet as mx

@@ -86,7 +84,7 @@ def reduce_ndarray(data):


def _is_stackable(
arrays: List[Union[np.ndarray, mx.nd.NDArray, Any]], axis: int = 0,
arrays: List[Union[np.ndarray, mx.nd.NDArray, Any]], axis: int = 0
) -> bool:
"""
Check if elements are scalars, have too few dimensions, or their
@@ -99,7 +97,7 @@ def _pad_arrays(


def _pad_arrays(
data: List[Union[np.ndarray, mx.nd.NDArray]], axis: int = 0,
data: List[Union[np.ndarray, mx.nd.NDArray]], axis: int = 0
) -> List[Union[np.ndarray, mx.nd.NDArray]]:
assert isinstance(data[0], (np.ndarray, mx.nd.NDArray))
is_mx = isinstance(data[0], mx.nd.NDArray)
@@ -239,9 +237,7 @@ def _sequential_sample_generator(
cyclic: bool,
) -> Iterator[DataEntry]:
while True:
yield from transformation(
data_it=dataset, is_train=is_train,
)
yield from transformation(data_it=dataset, is_train=is_train)
# Don't cycle if not training time
if not cyclic:
return
@@ -365,7 +361,7 @@ class ShuffleIter(Iterator[DataEntry]):

def __init__(
self, base_iterator: Iterator[DataEntry], shuffle_buffer_length: int
):
) -> None:
self.shuffle_buffer: list = []
self.shuffle_buffer_length = shuffle_buffer_length
self.base_iterator = base_iterator
@@ -418,7 +414,7 @@ def __init__(
dataset_len: int,
timeout: int,
shuffle_buffer_length: Optional[int],
):
) -> None:
self._worker_pool = worker_pool
self._batchify_fn = batchify_fn
self._data_buffer: dict = (
@@ -471,6 +467,7 @@ def _push_next(self) -> None:
def __next__(self) -> DataBatch:
# Try to get a batch; sometimes it's possible that an iterator was
# exhausted and thus we don't get a new batch
logger = logging.getLogger(__name__)
success = False
while not success:
try:
@@ -508,15 +505,17 @@ def __next__(self) -> DataBatch:
# or return with the right context straight away
return _as_in_context(batch, self._ctx)
except multiprocessing.context.TimeoutError:
print(
logger.error(
f"Worker timed out after {self._timeout} seconds. This might be caused by "
"\n - Slow transform. Please increase timeout to allow slower data loading in each worker. "
"\n - Insufficient shared_memory if `timeout` is large enough. "
"\n Please consider to reduce `num_workers` or increase shared_memory in system."
)
raise
except Exception:
print("An unexpected error occurred in the WorkerIterator.")
except Exception as e:
logger.error(
f"An unexpected error occurred in the WorkerIterator: {e}."
)
self._worker_pool.terminate()
raise
return {}
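
Note: this timeout path relies on standard multiprocessing behaviour: `AsyncResult.get(timeout=...)` raises `multiprocessing.context.TimeoutError` when a worker takes too long. A minimal standalone sketch of that error path (illustration only, not GluonTS code; `slow_transform` and the 1-second timeout are made up):

import logging
import multiprocessing
import time

logger = logging.getLogger(__name__)

def slow_transform(_):
    # deliberately slower than the timeout used below
    time.sleep(5)
    return 1

if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    with multiprocessing.Pool(processes=1) as pool:
        result = pool.apply_async(slow_transform, (None,))
        try:
            result.get(timeout=1)
        except multiprocessing.context.TimeoutError:
            # same exception type handled by the WorkerIterator above
            logger.error("Worker timed out after 1 second.")
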
@@ -580,10 +579,11 @@ def __init__(
num_prefetch: Optional[int] = None,
num_workers: Optional[int] = None,
shuffle_buffer_length: Optional[int] = None,
):
) -> None:
self.logger = logging.getLogger(__name__)
# Some windows error with the ForkingPickler prevents usage currently:
if sys.platform == "win32":
logging.warning(
self.logger.warning(
"You have set `num_workers` to a non zero value, "
"however, currently multiprocessing is not supported on windows and therefore"
"`num_workers will be set to 0."
@@ -593,7 +593,7 @@
if num_workers is not None and num_workers > 0:
if isinstance(dataset, FileDataset):
if not dataset.cache:
logging.warning(
self.logger.warning(
"You have set `num_workers` to a non zero value, "
"however, you have not enabled caching for your FileDataset. "
"To improve training performance you can enable caching for the FileDataset. "
@@ -637,11 +637,17 @@ def __init__(
num_workers if num_workers is not None else default_num_workers,
self.dataset_len,
) # cannot have more than dataset entries
self.logger.info(
f"gluonts[multiprocessing]: num_workers={self.num_workers}"
)
self.num_prefetch = (
num_prefetch if num_prefetch is not None else 2 * self.num_workers
)
self.logger.info(
f"gluonts[multiprocessing]: num_prefetch={self.num_prefetch}"
)
if self.num_prefetch < self.num_workers:
logging.warning(
self.logger.warning(
"You have set `num_prefetch` to less than `num_workers`, which is counter productive."
"If you want to reduce load, reduce `num_workers`."
)
@@ -652,6 +658,9 @@
# In order to recycle unused but pre-calculated batches from last epoch for training:
self.multi_worker_cache: Optional[Iterator[DataBatch]] = None
self.shuffle_buffer_length: Optional[int] = shuffle_buffer_length
self.logger.info(
f"gluonts[multiprocessing]: shuffle_buffer_length={self.shuffle_buffer_length}"
)

if self.num_workers > 0:
# generate unique ids for processes
@@ -678,7 +687,7 @@ def __iter__(self) -> Iterator[DataBatch]:
self.cycle_num += 1
if self.num_workers == 0:
generator = _sequential_sample_generator(
self.dataset, self.transformation, self.is_train, self.cyclic,
self.dataset, self.transformation, self.is_train, self.cyclic
)
if self.shuffle_buffer_length is not None:
generator = ShuffleIter(
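
Aside on `shuffle_buffer_length` (now logged above): it bounds a buffer used for pseudo-shuffling of the sample stream. A simplified, standalone sketch of that idea (an illustration of the technique, not the ShuffleIter implementation itself):

import random
from typing import Iterator, List, TypeVar

T = TypeVar("T")

def buffered_shuffle(source: Iterator[T], buffer_length: int) -> Iterator[T]:
    # Keep at most `buffer_length` items in memory and yield a random one
    # each step; larger buffers approximate a full shuffle more closely.
    buffer: List[T] = []
    for item in source:
        buffer.append(item)
        if len(buffer) >= buffer_length:
            yield buffer.pop(random.randrange(len(buffer)))
    while buffer:
        yield buffer.pop(random.randrange(len(buffer)))

# e.g. list(buffered_shuffle(iter(range(10)), buffer_length=4))
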
11 changes: 8 additions & 3 deletions src/gluonts/shell/sagemaker/__init__.py
@@ -14,13 +14,14 @@
# Standard library imports
from distutils.util import strtobool
import json
import logging
import os
from pathlib import Path
from pydantic import BaseModel
from typing import Dict, Optional

# First party imports
from gluonts.dataset.common import FileDataset, ListDataset, MetaData
from gluonts.dataset.common import Dataset, FileDataset, ListDataset, MetaData
from gluonts.model.forecast import Config as ForecastConfig
from gluonts.support.util import map_dct_values

@@ -132,9 +133,11 @@ def _get_current_host(resourceconfig: Path) -> str:

def _load_datasets(
hyperparameters: dict, channels: Dict[str, Path]
) -> Dict[str, FileDataset]:
) -> Dict[str, Dataset]:
logger = logging.getLogger(__name__)
freq = hyperparameters["freq"]
listify_dataset = strtobool(hyperparameters.get("listify_dataset", "no"))
logger.info(f"gluonts[cached]: listify_dataset = {listify_dataset}")
dataset_dict = {}
for name in DATASET_NAMES:
if name in channels:
@@ -144,5 +147,7 @@
if listify_dataset
else file_dataset
)

logger.info(
f"gluonts[cached]: Type of {name} dataset is {type(dataset_dict[name])}"
)
return dataset_dict
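
Since shell hyperparameters arrive as strings, `listify_dataset` is parsed with `distutils.util.strtobool`. A short standalone sketch of how that flag behaves (example values only):

from distutils.util import strtobool

# Example hyperparameters as they would arrive from the shell (string-valued).
hyperparameters = {"freq": "D", "listify_dataset": "yes"}

# strtobool maps "yes"/"true"/"on"/"1" to 1, "no"/"false"/"off"/"0" to 0,
# and raises ValueError for anything else.
listify_dataset = strtobool(hyperparameters.get("listify_dataset", "no"))
assert listify_dataset == 1
assert strtobool("no") == 0
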
31 changes: 29 additions & 2 deletions src/gluonts/shell/train.py
@@ -13,8 +13,12 @@

# Standard library imports
import logging
import multiprocessing
from typing import Any, Optional, Type, Union

# Third-party imports
import numpy as np

# First-party imports
import gluonts
from gluonts.core import fqname_for
@@ -82,7 +86,10 @@ def run_train_and_test(
predictor = forecaster
else:
predictor = run_train(
forecaster, env.datasets["train"], env.datasets.get("validation")
forecaster=forecaster,
train_dataset=env.datasets["train"],
validation_dataset=env.datasets.get("validation"),
hyperparameters=env.hyperparameters,
)

predictor.serialize(env.path.model)
@@ -94,10 +101,30 @@ def run_train(
def run_train(
forecaster: Estimator,
train_dataset: Dataset,
hyperparameters: dict,
validation_dataset: Optional[Dataset],
) -> Predictor:
num_workers = (
int(hyperparameters["num_workers"])
if "num_workers" in hyperparameters.keys()
else None
)
shuffle_buffer_length = (
int(hyperparameters["shuffle_buffer_length"])
if "shuffle_buffer_length" in hyperparameters.keys()
else None
)
num_prefetch = (
int(hyperparameters["num_prefetch"])
if "num_prefetch" in hyperparameters.keys()
else None
)
return forecaster.train(
training_data=train_dataset, validation_data=validation_dataset
training_data=train_dataset,
validation_data=validation_dataset,
num_workers=num_workers,
num_prefetch=num_prefetch,
shuffle_buffer_length=shuffle_buffer_length,
)


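
The three optional hyperparameters above are parsed with the same pattern: if the key is present, cast it to int, otherwise pass None so the estimator falls back to its own defaults. A more compact equivalent, sketched with a hypothetical helper (not part of this PR):

from typing import Optional

def _optional_int(hyperparameters: dict, key: str) -> Optional[int]:
    # Missing keys map to None so downstream code can apply its own defaults.
    value = hyperparameters.get(key)
    return int(value) if value is not None else None

assert _optional_int({"num_workers": "3"}, "num_workers") == 3
assert _optional_int({}, "num_prefetch") is None
assert _optional_int({"shuffle_buffer_length": "256"}, "shuffle_buffer_length") == 256
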
4 changes: 0 additions & 4 deletions src/gluonts/testutil/dummy_datasets.py
@@ -16,10 +16,6 @@
from random import randint
from typing import List, Tuple

# Third-party imports
import numpy as np
import pytest

# First-party imports
from gluonts.dataset.common import ListDataset
from gluonts.dataset.field_names import FieldName
2 changes: 0 additions & 2 deletions test/model/seq2seq/test_model.py
@@ -176,10 +176,8 @@ def test_backwards_compatibility():
"use_past_feat_dynamic_real": True,
"enable_encoder_dynamic_feature": True,
"enable_decoder_dynamic_feature": True,
"num_workers": 0,
"scaling": True,
"scaling_decoder_dynamic_feature": True,
"num_batches_shuffle": 8,
}

dataset_train, dataset_test = make_dummy_datasets_with_features(
48 changes: 28 additions & 20 deletions test/shell/test_shell.py
@@ -13,8 +13,9 @@

# Standard library imports
import json
from typing import ContextManager, Optional
from typing import ContextManager
import sys
from distutils.util import strtobool

# Third-party imports
import numpy as np
@@ -24,6 +25,7 @@
from gluonts.core.component import equals
from gluonts.dataset.common import FileDataset, ListDataset
from gluonts.model.trivial.mean import MeanPredictor
from gluonts.model.seq2seq import MQCNNEstimator
from gluonts.shell.sagemaker import ServeEnv, TrainEnv
from gluonts.shell.train import run_train_and_test

@@ -51,6 +53,10 @@ def train_env(listify_dataset) -> ContextManager[TrainEnv]:
"prediction_length": prediction_length,
"num_samples": num_samples,
"listify_dataset": listify_dataset,
"num_workers": 3,
"num_prefetch": 4,
"shuffle_buffer_length": 256,
"epochs": 3,
}
with testutil.temporary_train_env(hyperparameters, "constant") as env:
yield env
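
For context, SageMaker passes hyperparameters to the container as a JSON object of string values (which is why run_train casts them to int). A hedged illustration of the equivalent input for the fixture above, using the same example values:

import json

# All values are strings, mirroring how a SageMaker training job supplies them.
hyperparameters_json = json.dumps(
    {
        "num_workers": "3",
        "num_prefetch": "4",
        "shuffle_buffer_length": "256",
        "listify_dataset": "yes",
        "epochs": "3",
    },
    indent=4,
)
print(hyperparameters_json)
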
@@ -98,34 +104,36 @@ def batch_transform(monkeypatch, train_env):
return inference_config


@pytest.mark.parametrize("listify_dataset", [True, False])
@pytest.mark.parametrize("listify_dataset", ["yes", "no"])
def test_listify_dataset(train_env: TrainEnv, listify_dataset):
for dataset_name in train_env.datasets.keys():
assert (
isinstance(train_env.datasets[dataset_name], ListDataset)
if listify_dataset
if strtobool(listify_dataset)
else isinstance(train_env.datasets[dataset_name], FileDataset)
)


@pytest.mark.parametrize("listify_dataset", [True, False])
def test_train_shell(train_env: TrainEnv, caplog) -> None:
run_train_and_test(env=train_env, forecaster_type=MeanPredictor)
@pytest.mark.parametrize("listify_dataset", ["yes", "no"])
@pytest.mark.parametrize("forecaster_type", [MeanPredictor, MQCNNEstimator])
def test_train_shell(train_env: TrainEnv, caplog, forecaster_type) -> None:
run_train_and_test(env=train_env, forecaster_type=forecaster_type)

for _, _, line in caplog.record_tuples:
if "#test_score (local, QuantileLoss" in line:
assert line.endswith("0.0")
if "local, wQuantileLoss" in line:
assert line.endswith("0.0")
if "local, Coverage" in line:
assert line.endswith("0.0")
if "MASE" in line or "MSIS" in line:
assert line.endswith("0.0")
if "abs_target_sum" in line:
assert line.endswith("270.0")
if forecaster_type == MeanPredictor:
for _, _, line in caplog.record_tuples:
if "#test_score (local, QuantileLoss" in line:
assert line.endswith("0.0")
if "local, wQuantileLoss" in line:
assert line.endswith("0.0")
if "local, Coverage" in line:
assert line.endswith("0.0")
if "MASE" in line or "MSIS" in line:
assert line.endswith("0.0")
if "abs_target_sum" in line:
assert line.endswith("270.0")


@pytest.mark.parametrize("listify_dataset", [True, False])
@pytest.mark.parametrize("listify_dataset", ["yes", "no"])
def test_server_shell(
train_env: TrainEnv, static_server: "testutil.ServerFacade", caplog
) -> None:
@@ -167,7 +175,7 @@ def test_server_shell(
assert equals(exp_samples, act_samples)


@pytest.mark.parametrize("listify_dataset", [True, False])
@pytest.mark.parametrize("listify_dataset", ["yes", "no"])
def test_dynamic_shell(
train_env: TrainEnv, dynamic_server: "testutil.ServerFacade", caplog
) -> None:
@@ -210,7 +218,7 @@
assert equals(exp_samples, act_samples)


@pytest.mark.parametrize("listify_dataset", [True, False])
@pytest.mark.parametrize("listify_dataset", ["yes", "no"])
def test_dynamic_batch_shell(
batch_transform,
train_env: TrainEnv,