MP DataLoader Improvements #742

Merged (25 commits, Apr 17, 2020)

Commits
8f40b94  Dataset iterations are now cache aligned and switched to default Pool…  (Apr 3, 2020)
c7aad21  Added caching support for FileDataset.  (Apr 6, 2020)
73483b3  Refactoring.  (Apr 6, 2020)
b2563c9  Added missing documentation in train loader.  (Apr 6, 2020)
09ecc26  Added more return types.  (Apr 6, 2020)
6e71dc6  mend  (Apr 6, 2020)
f98d8a8  Fixed bug regarding num_batches_for_sampling.  (Apr 7, 2020)
50bffcd  Reverting back to modulo based segmentation for code readability.  (Apr 7, 2020)
4c38ef9  Massively simplified worker_fn due to simplified logic of num_batches…  (Apr 7, 2020)
d75202e  mend  (Apr 7, 2020)
e50813b  User warning in case of mp but not caching.  (Apr 7, 2020)
9ba3b3a  Minor refactoring.  (Apr 16, 2020)
87178c3  Merge branch 'master' into mp_data_loader_updates_V2  (AaronSpieler, Apr 16, 2020)
bcf353c  Merge branch 'master' into mp_data_loader_updates_V2  (lostella, Apr 16, 2020)
f717eeb  Smaller reformatting.  (Apr 16, 2020)
7221ecc  Merge branch 'mp_data_loader_updates_V2' of https://github.com/AaronS…  (Apr 16, 2020)
f5839a8  Updated doc.  (Apr 16, 2020)
957113d  Simplified segmenting, test fix.  (Apr 16, 2020)
23d09db  Yield from improvement.  (Apr 16, 2020)
33ef3e4  Dataset Coverage Test Explicit.  (Apr 16, 2020)
314cdbe  removed print  (AaronSpieler, Apr 16, 2020)
5e21cfc  Removed unused import.  (Apr 16, 2020)
2e44be8  Merge branch 'master' into mp_data_loader_updates_V2  (lostella, Apr 17, 2020)
e6582e7  Disabling windows mp evaluation, lowering required JSonLine throughput.  (Apr 17, 2020)
24988ac  Merge branch 'mp_data_loader_updates_V2' of https://github.com/AaronS…  (Apr 17, 2020)
42 changes: 31 additions & 11 deletions src/gluonts/dataset/common.py
@@ -42,6 +42,7 @@

# Dictionary used for data flowing through the transformations.
DataEntry = Dict[str, Any]
DataBatch = Dict[str, Any]

# TODO: change this maybe to typing_extensions.Protocol
# A Dataset is an iterable of DataEntry.
@@ -185,26 +186,39 @@ class FileDataset(Dataset):
Must be a valid Pandas frequency.
one_dim_target
Whether to accept only univariate target time series.
cache
Indicates whether the dataset should be cached or not.
"""

def __init__(
self, path: Path, freq: str, one_dim_target: bool = True,
self,
path: Path,
freq: str,
one_dim_target: bool = True,
cache: bool = False,
) -> None:
self.cache = cache
self.path = path
self.process = ProcessDataEntry(freq, one_dim_target=one_dim_target)
self._len = None

if not self.files():
raise OSError(f"no valid file found in {path}")

# necessary in order to preserve the cached datasets in case caching was enabled
self._json_line_files = [
jsonl.JsonLinesFile(path=path, cache=cache)
for path in self.files()
]

def __iter__(self) -> Iterator[DataEntry]:
for path in self.files():
for line in jsonl.JsonLinesFile(path=path):
for json_line_file in self._json_line_files:
for line in json_line_file:
data = self.process(line.content)
data["source"] = SourceContext(
source=line.span.path, row=line.span.line
)
yield data
self._burnt_in = True

def __len__(self):
if self._len is None:
@@ -254,17 +268,23 @@ def __init__(
one_dim_target: bool = True,
) -> None:
self.process = ProcessDataEntry(freq, one_dim_target)
self.list_data = list(data_iter)
# TODO: implement caching here
self.list_data = list(data_iter) # dataset always cached

def __iter__(self) -> Iterator[DataEntry]:
source_name = "list_data"
# The basic idea is to split the dataset into roughly equally sized segments
# with lower and upper bounds; each worker is assigned one segment
segment_size = int(len(self) / util.MPWorkerInfo.num_workers)

for row_number, data in enumerate(self.list_data):
# The dataset is equally distributed among the workers
if not (
row_number % util.MPWorkerInfo.num_workers
== util.MPWorkerInfo.worker_id
):
lower_bound = util.MPWorkerInfo.worker_id * segment_size
upper_bound = (
(util.MPWorkerInfo.worker_id + 1) * segment_size
if util.MPWorkerInfo.worker_id + 1
!= util.MPWorkerInfo.num_workers
else len(self)
)
if not lower_bound <= row_number < upper_bound:
continue

data = self.process(data)
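The new ListDataset.__iter__ above assigns each worker a contiguous slice of the data instead of the previous modulo-based round-robin split. A minimal, self-contained sketch of the same bound computation (worker_segment_bounds, dataset_len, worker_id and num_workers are hypothetical stand-ins for len(self) and util.MPWorkerInfo):

    def worker_segment_bounds(dataset_len: int, worker_id: int, num_workers: int):
        """Return the [lower, upper) range of rows owned by one worker.

        The last worker absorbs the remainder so every row is covered exactly once.
        """
        segment_size = dataset_len // num_workers
        lower = worker_id * segment_size
        upper = (
            (worker_id + 1) * segment_size
            if worker_id + 1 != num_workers
            else dataset_len
        )
        return lower, upper

    # Example: 10 rows split over 3 workers -> (0, 3), (3, 6), (6, 10)
    for w in range(3):
        print(worker_segment_bounds(10, w, 3))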
49 changes: 32 additions & 17 deletions src/gluonts/dataset/jsonl.py
@@ -18,6 +18,7 @@

# Third-party imports
import ujson as json
import numpy as np

# First-party imports
from gluonts.core.exception import GluonTSDataError
@@ -55,28 +56,42 @@ class JsonLinesFile:
JSON Lines file.
"""

def __init__(self, path) -> None:
def __init__(self, path: Path, cache: bool = False) -> None:
self.path = path
self.cache = cache
self._len = None
# TODO: implement caching here
self._data_cache: list = []

def __iter__(self):
with open(self.path) as jsonl_file:
for line_number, raw in enumerate(jsonl_file):
# The dataset is equally distributed among the workers
if not (
line_number % MPWorkerInfo.num_workers
== MPWorkerInfo.worker_id
):
continue

span = Span(path=self.path, line=line_number)
try:
yield Line(json.loads(raw), span=span)
except ValueError:
raise GluonTSDataError(
f"Could not read json line {line_number}, {raw}"
# The basic idea is to split the dataset into roughly equally sized segments
# with lower and upper bounds; each worker is assigned one segment
segment_size = int(len(self) / MPWorkerInfo.num_workers)

if not self.cache or (self.cache and not self._data_cache):
with open(self.path) as jsonl_file:
for line_number, raw in enumerate(jsonl_file):
lower_bound = MPWorkerInfo.worker_id * segment_size
upper_bound = (
(MPWorkerInfo.worker_id + 1) * segment_size
if MPWorkerInfo.worker_id + 1
!= MPWorkerInfo.num_workers
else len(self)
)
if not lower_bound <= line_number < upper_bound:
continue

span = Span(path=self.path, line=line_number)
try:
parsed_line = Line(json.loads(raw), span=span)
if self.cache:
self._data_cache.append(parsed_line)
yield parsed_line
except ValueError:
raise GluonTSDataError(
f"Could not read json line {line_number}, {raw}"
)
else:
yield from self._data_cache

def __len__(self):
if self._len is None:
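The caching branch in JsonLinesFile.__iter__ follows a read-through pattern: the first full pass parses the file and fills an in-memory list, and later passes are served from that list. A simplified illustration of the idea under hypothetical names (CachedJsonLines is not the actual class, and the per-worker segmentation is omitted here for brevity):

    import json
    from typing import Iterator, List


    class CachedJsonLines:
        """Iterate over a JSON Lines file, optionally caching parsed lines in memory."""

        def __init__(self, path: str, cache: bool = False) -> None:
            self.path = path
            self.cache = cache
            self._data_cache: List[dict] = []

        def __iter__(self) -> Iterator[dict]:
            if self.cache and self._data_cache:
                # Subsequent passes: serve already-parsed lines from memory.
                yield from self._data_cache
                return
            with open(self.path) as fp:
                for raw in fp:
                    parsed = json.loads(raw)
                    if self.cache:
                        self._data_cache.append(parsed)
                    yield parsed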
27 changes: 18 additions & 9 deletions src/gluonts/dataset/loader.py
@@ -22,12 +22,10 @@

# First-party imports
from gluonts.core.component import DType
from gluonts.dataset.common import DataEntry, Dataset
from gluonts.dataset.common import DataEntry, Dataset, DataBatch
from gluonts.dataset.parallelized_loader import ParallelDataLoader
from gluonts.transform import Transformation

DataBatch = Dict[str, Any]


class DataLoader(Iterable[DataEntry]):
"""
@@ -67,13 +65,14 @@ def __init__(
dataset: Dataset,
*,
transform: Transformation,
cyclic: bool,
is_train: bool,
batch_size: int,
ctx: mx.Context,
dtype: DType = np.float32,
cyclic: bool = False,
num_workers: Optional[int] = None,
num_prefetch: Optional[int] = None,
num_batches_for_shuffling: Optional[int] = None,
**kwargs
) -> None:
self.batch_size = batch_size
@@ -82,17 +81,21 @@ def __init__(
self.is_train = is_train
self.transform = transform
self.cyclic = cyclic
self.num_workers = num_workers
self.num_prefetch = num_prefetch
self.num_batches_for_shuffling = num_batches_for_shuffling

self.parallel_data_loader = ParallelDataLoader(
dataset=dataset,
transformation=self.transform,
cyclic=self.cyclic,
is_train=self.is_train,
batch_size=self.batch_size,
ctx=ctx,
ctx=self.ctx,
dtype=self.dtype,
num_workers=num_workers,
num_prefetch=num_prefetch,
num_workers=self.num_workers,
num_prefetch=self.num_prefetch,
num_batches_for_shuffling=self.num_batches_for_shuffling,
**kwargs,
)

@@ -132,7 +135,12 @@ class TrainDataLoader(DataLoader):
multiple worker processes; try reducing `num_workers` in this case.
It defaults to `num_workers * 2`.
dtype
Floating point type to use.
Floating point type to use. Default is np.float32.
shuffle_for_training
Whether to shuffle the samples.
num_batches_for_shuffling
The effective number of batches among which samples are shuffled. If num_batches_for_shuffling = 8 and
batch_size = 8, then the next batch will be randomly sampled from about 64 samples.
"""

def __init__(
@@ -146,7 +154,7 @@ def __init__(
num_prefetch: Optional[int] = None,
dtype: DType = np.float32,
shuffle_for_training: bool = True,
num_batches_for_shuffling: int = 10, # TODO: this does not work currently
num_batches_for_shuffling: int = 8,
**kwargs
) -> None:
assert dataset, "empty dataset"
@@ -162,6 +170,7 @@
cyclic=True,
num_workers=num_workers,
num_prefetch=num_prefetch,
num_batches_for_shuffling=num_batches_for_shuffling,
**kwargs,
)

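As described in the TrainDataLoader docstring above, num_batches_for_shuffling bounds how far samples can be reordered: with batch_size = 8 and num_batches_for_shuffling = 8, each emitted batch is drawn from a pool of roughly 64 buffered samples. A minimal sketch of that buffered-shuffling idea (illustrative only; shuffled_batches is a hypothetical helper, not the ParallelDataLoader implementation):

    import random
    from typing import Iterable, Iterator, List, TypeVar

    T = TypeVar("T")


    def shuffled_batches(
        samples: Iterable[T], batch_size: int, num_batches_for_shuffling: int
    ) -> Iterator[List[T]]:
        """Yield batches, each drawn from a shuffled buffer of at most
        batch_size * num_batches_for_shuffling samples."""
        pool_size = batch_size * num_batches_for_shuffling
        buffer: List[T] = []
        for sample in samples:
            buffer.append(sample)
            if len(buffer) == pool_size:
                random.shuffle(buffer)
                for i in range(0, pool_size, batch_size):
                    yield buffer[i : i + batch_size]
                buffer = []
        # Flush whatever remains at the end of the stream.
        random.shuffle(buffer)
        for i in range(0, len(buffer), batch_size):
            yield buffer[i : i + batch_size]

    # Example: 20 samples, batch_size=4, num_batches_for_shuffling=2
    # -> each emitted batch mixes samples from a window of 8 consecutive inputs.
    batches = list(shuffled_batches(range(20), batch_size=4, num_batches_for_shuffling=2))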