support custom data loaders in TorchEstimator (horovod#3787)
Signed-off-by: Lee Yang <leewyang@gmail.com>
leewyang authored Feb 1, 2023
1 parent 25e0f64 commit 88ecd06
Showing 8 changed files with 499 additions and 238 deletions.
3 changes: 2 additions & 1 deletion CHANGELOG.md
@@ -7,8 +7,9 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/).
## [Unreleased] - YYYY-MM-DD

### Added
- Spark Estimator: Added support for custom data loaders in TorchEstimator. ([#3787](https://github.com/horovod/horovod/pull/3787))
- Spark Estimator: Added NVTabular data loader for TorchEstimator. ([#3787](https://github.com/horovod/horovod/pull/3787))
- Added `HOROVOD_SPARK_USE_LOCAL_RANK_GPU_INDEX` environment variable to ignore GPU device indices assigned by Spark and always use local rank GPU device in Spark estimators. ([#3737](https://github.com/horovod/horovod/pull/3737))

- Reducescatter: Added support for prescale_factor and postscale_factor and moved averaging into Horovod backend. ([#3815](https://github.com/horovod/horovod/pull/3815))

### Changed
6 changes: 3 additions & 3 deletions docs/spark.rst
@@ -98,9 +98,9 @@ and local filesystems.

A `Petastorm <https://github.com/uber/petastorm/blob/master/petastorm/pytorch.py#L259>`__ based data loader is used by default,
but users can define a custom data loader by overriding the `BaseDataLoader` interface. An async data loader mixin can also
-be added on top of the data loader. Additionally, the KerasEstimator supports a DataModule argument, similar
-to the Lightning DataModule, which abstracts the data loading and allows for alternative implementations. For example,
-the NVTabularDataModule integrates the `KerasSequenceLoader <https://github.com/NVIDIA-Merlin/NVTabular/blob/main/nvtabular/loader/tensorflow.py>`__
+be added on top of the data loader. Additionally, KerasEstimator and TorchEstimator both support an optional DataModule
+argument, similar to the Lightning DataModule, which abstracts the data loading and allows for alternative implementations.
+For example, the NVTabularDataModule integrates the `KerasSequenceLoader <https://github.com/NVIDIA-Merlin/NVTabular/blob/main/nvtabular/loader/tensorflow.py>`__
from NVTabular to enable GPU-accelerated data loading.

There is an `example Dockerfile <https://github.com/horovod/horovod/blob/master/docker/horovod-nvtabular/Dockerfile>`__
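To make the new hook concrete, here is a rough usage sketch of TorchEstimator with the NVTabular data module. It is illustrative only: the store path, column names, model, and optimizer are placeholders, and omitting `data_module` falls back to the Petastorm-based default.

# Illustrative sketch only; store path, column names, model, and optimizer are placeholders.
import torch
import torch.nn as nn
import torch.nn.functional as F

from horovod.spark.common.store import Store
from horovod.spark.torch import TorchEstimator
from horovod.spark.torch.datamodule import NVTabularDataModule


class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(2, 1)

    def forward(self, x1, x2):
        # feature_cols map positionally onto forward() arguments
        return self.fc(torch.cat([x1.float(), x2.float()], dim=1))


model = Net()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
store = Store.create('/tmp/horovod_work')        # any supported store prefix (local, HDFS, S3, ...)

est = TorchEstimator(
    num_proc=4,
    store=store,
    model=model,
    optimizer=optimizer,
    loss=lambda input, target: F.mse_loss(input, target.float()),
    data_module=NVTabularDataModule,             # omit to use the default PetastormDataModule
    feature_cols=['x1', 'x2'],
    continuous_cols=['x1'],                      # consumed by NVTabularDataModule
    categorical_cols=['x2'],
    label_cols=['y'],
    input_shapes=[[-1, 1], [-1, 1]],
    batch_size=32,
    epochs=4)

torch_model = est.fit(train_df)                  # train_df: a Spark DataFrame prepared elsewhere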
13 changes: 7 additions & 6 deletions horovod/spark/keras/datamodule.py
@@ -83,9 +83,9 @@ def __enter__(self):

    def __exit__(self, type, value, traceback):
        if self.has_val and self.val_reader:
-           self.val_reader.stop()
+           self.val_reader.__exit__(type, value, traceback)
        if self.train_reader:
-           self.train_reader.stop()
+           self.train_reader.__exit__(type, value, traceback)
        super().__exit__(type, value, traceback)

@contextlib.contextmanager
@@ -110,6 +110,7 @@ def __init__(self, label_cols=[], categorical_cols=[], continuous_cols=[], **kwargs):
self.label_cols = label_cols
self.categorical_cols = categorical_cols
self.continuous_cols = continuous_cols
self.kwargs = kwargs

@staticmethod
def seed_fn():
@@ -141,8 +142,8 @@ def to_dense(X, labels):

def train_data(self):
import horovod.tensorflow.keras as hvd
-       from nvtabular.loader.tensorflow import KerasSequenceLoader
-       return KerasSequenceLoader(self.train_dir,
+       from nvtabular.loader.tensorflow import KerasSequenceLoader, Dataset
+       return KerasSequenceLoader(Dataset(self.train_dir, engine="parquet", calculate_divisions=True, **self.kwargs),
batch_size=self.train_batch_size,
label_names=self.label_cols,
cat_names=self.categorical_cols,
@@ -157,8 +158,8 @@ def train_data(self):

def val_data(self):
import horovod.tensorflow.keras as hvd
-       from nvtabular.loader.tensorflow import KerasSequenceLoader
-       return KerasSequenceLoader(self.val_dir,
+       from nvtabular.loader.tensorflow import KerasSequenceLoader, Dataset
+       return KerasSequenceLoader(Dataset(self.val_dir, engine="parquet", calculate_divisions=True, **self.kwargs),
batch_size=self.val_batch_size,
label_names=self.label_cols,
cat_names=self.categorical_cols,
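For reference, the Keras loaders above now wrap the raw Parquet directory in an NVTabular Dataset (forwarding any extra keyword arguments) before handing it to KerasSequenceLoader. Outside the DataModule, the wrapped object corresponds roughly to this sketch, where the path is a placeholder:

# Rough standalone equivalent of the Dataset wrapping added in train_data()/val_data() above.
import nvtabular as nvt

train_dir = '/data/intermediate/train_data'      # placeholder Parquet directory
dataset = nvt.Dataset(train_dir,
                      engine='parquet',          # same options used in the diff above
                      calculate_divisions=True)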
229 changes: 229 additions & 0 deletions horovod/spark/torch/datamodule.py
@@ -0,0 +1,229 @@
# Copyright (C) 2022, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import contextlib
import glob
import horovod.torch as hvd
import os

from horovod.spark.common import constants
from horovod.spark.common.datamodule import DataModule


class PetastormDataModule(DataModule):
"""Default Petastorm-based DataModule for KerasEstimator."""

def __init__(self,
reader_pool_type="thread",
train_reader_worker_count=10,
val_reader_worker_count=10,
random_seed=0,
**kwargs):
from petastorm import TransformSpec, make_reader, make_batch_reader

super().__init__(**kwargs)
self.reader_pool_type = reader_pool_type
self.train_reader_worker_count = train_reader_worker_count
self.val_reader_worker_count = val_reader_worker_count
self.random_seed = random_seed

        # In general, make_batch_reader is faster than make_reader for reading the dataset.
        # However, we found that make_reader performs data transformations much faster than
        # make_batch_reader when using parallel worker processes. Therefore, we default to
        # make_batch_reader unless data transformations are required.
self.transform_spec = TransformSpec(self.transform_fn) if self.transform_fn else None
self.reader_factory_kwargs = dict()
if self.transform_spec:
self.reader_factory = make_reader
self.reader_factory_kwargs['pyarrow_serialize'] = True
else:
self.reader_factory = make_batch_reader

def __enter__(self):
super().__enter__()
        # Petastorm: read data from the store with the correct shard for this rank.
        # Setting num_epochs=None creates an infinite iterator, which lets ranks
        # perform training and validation with unequal numbers of samples.
self.train_reader = self.reader_factory(
self.train_dir,
num_epochs=None,
cur_shard=hvd.rank(),
reader_pool_type=self.reader_pool_type,
workers_count=self.train_reader_worker_count,
shard_count=hvd.size(),
hdfs_driver=constants.PETASTORM_HDFS_DRIVER,
schema_fields=self.schema_fields,
transform_spec=self.transform_spec,
storage_options=self.storage_options,
shuffle_rows=self.shuffle,
shuffle_row_groups=self.shuffle,
seed=self.random_seed,
**self.reader_factory_kwargs
)

self.val_reader = self.reader_factory(
self.val_dir,
num_epochs=None,
cur_shard=hvd.rank(),
reader_pool_type=self.reader_pool_type,
workers_count=self.val_reader_worker_count,
shard_count=hvd.size(),
hdfs_driver=constants.PETASTORM_HDFS_DRIVER,
schema_fields=self.schema_fields,
transform_spec=self.transform_spec,
storage_options=self.storage_options,
shuffle_rows=False,
shuffle_row_groups=False,
seed=self.random_seed,
**self.reader_factory_kwargs
) if self.has_val else self.empty_batch_reader()

return self

def __exit__(self, type, value, traceback):
if self.has_val and self.val_reader:
self.val_reader.__exit__(type, value, traceback)
if self.train_reader:
self.train_reader.__exit__(type, value, traceback)
super().__exit__(type, value, traceback)

@contextlib.contextmanager
def empty_batch_reader(self):
yield None

def train_data(self):
from petastorm.pytorch import BatchedDataLoader, InMemBatchedDataLoader

if self.inmemory_cache_all:
train_loader = InMemBatchedDataLoader(self.train_reader,
batch_size=self.train_batch_size,
num_epochs=self.num_train_epochs,
rows_capacity=self.steps_per_epoch_train*self.train_batch_size,
shuffle=self.shuffle)
else:
train_loader = BatchedDataLoader(self.train_reader,
batch_size=self.train_batch_size,
# No need to shuffle again in dataloader level
shuffling_queue_capacity=0)
return train_loader

def val_data(self):
from petastorm.pytorch import BatchedDataLoader, InMemBatchedDataLoader

if self.inmemory_cache_all:
val_loader = InMemBatchedDataLoader(self.val_reader,
batch_size=self.val_batch_size,
num_epochs=self.num_train_epochs,
rows_capacity=self.steps_per_epoch_val*self.val_batch_size,
shuffle=False)
else:
val_loader = BatchedDataLoader(self.val_reader,
batch_size=self.val_batch_size,
shuffling_queue_capacity=0)
return val_loader


class MapIterable():
"""Wraps an iterable with a user-defined map function for N epochs."""

def __init__(self, data, epochs=None, map_fn=lambda x: x):
self.data = data
self.epochs = epochs
self.map_fn = map_fn

def __iter__(self):
if self.epochs:
for _ in range(self.epochs):
for x in self.data:
yield self.map_fn(x)
else:
while True:
for x in self.data:
yield self.map_fn(x)


class NVTabularDataModule(DataModule):
"""NVTabular-based DataModule for TorchEstimator for GPU-accelerated data loading of tabular datasets.
Note: requires `label_cols`, `categorical_cols`, and `continuous_cols` to be explicitly provided."""

def __init__(self, label_cols=[], categorical_cols=[], continuous_cols=[], **kwargs):
super().__init__(**kwargs)
self.label_cols = label_cols
self.categorical_cols = categorical_cols
self.continuous_cols = continuous_cols
self.kwargs = kwargs

@staticmethod
def seed_fn():
"""
        Generate consistent dataloader shuffle seeds across workers.
        Reseeds each worker's dataloader each epoch to get a fresh shuffle
        that is consistent across workers.
"""
import numpy as np
import torch
hvd.init()
seed = np.random.randint(0, torch.iinfo(torch.int32).max)
seed_tensor = torch.tensor(seed)
root_seed = hvd.broadcast(seed_tensor, name="shuffle_seed", root_rank=0)
return root_seed

def _transform(self, features_and_label):
"""Transform NVTabular value-and-offsets arrays into torch arrays."""
import torch
features, label = features_and_label
for k, v in features.items():
if isinstance(v, tuple): # values and offsets
indices = v[1].flatten().tolist()
features[k] = torch.vstack(v[0].tensor_split(indices[1:]))
return features, label

def train_data(self):
import nvtabular as nvt
from nvtabular.loader.torch import TorchAsyncItr, DLDataLoader

train_dataset = TorchAsyncItr(
nvt.Dataset(self.train_dir, engine='parquet', calculate_divisions=True, **self.kwargs),
batch_size=self.train_batch_size,
cats=self.categorical_cols,
conts=self.continuous_cols,
labels=self.label_cols,
shuffle=self.shuffle,
parts_per_chunk=1,
global_size=hvd.size(),
global_rank=hvd.rank(),
seed_fn=self.seed_fn)

train_dataloader = DLDataLoader(train_dataset, batch_size=None, collate_fn=lambda x: x, pin_memory=False, num_workers=0)
return MapIterable(train_dataloader, epochs=self.num_train_epochs, map_fn=self._transform)

def val_data(self):
import nvtabular as nvt
from nvtabular.loader.torch import TorchAsyncItr, DLDataLoader

val_dataset = TorchAsyncItr(
nvt.Dataset(self.val_dir, engine='parquet', calculate_divisions=True, **self.kwargs),
batch_size=self.val_batch_size,
cats=self.categorical_cols,
conts=self.continuous_cols,
labels=self.label_cols,
shuffle=False,
parts_per_chunk=1,
global_size=hvd.size(),
global_rank=hvd.rank()) if self.has_val else None

val_dataloader = DLDataLoader(val_dataset, batch_size=None, collate_fn=lambda x: x, pin_memory=False, num_workers=0)
return MapIterable(val_dataloader, epochs=self.num_train_epochs, map_fn=self._transform)
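The `_transform` helper above deals with NVTabular's representation of list columns as a `(values, offsets)` pair: it splits the flat values tensor at the offsets and stacks the pieces row-wise. A small standalone illustration with toy data (note that `torch.vstack` requires every row to have the same list length):

# Standalone illustration of the (values, offsets) -> dense tensor conversion in _transform.
import torch

values = torch.tensor([1, 2, 3, 4, 5, 6])    # flat values for 3 rows of length 2
offsets = torch.tensor([[0], [2], [4]])      # row start offsets, as a column tensor
indices = offsets.flatten().tolist()         # [0, 2, 4]
dense = torch.vstack(values.tensor_split(indices[1:]))
print(dense)                                 # tensor([[1, 2], [3, 4], [5, 6]])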
24 changes: 20 additions & 4 deletions horovod/spark/torch/estimator.py
@@ -1,4 +1,5 @@
# Copyright 2019 Uber Technologies, Inc. All Rights Reserved.
# Modifications copyright (C) 2022, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -33,6 +34,7 @@
from horovod.spark.common.serialization import \
HorovodParamsWriter, HorovodParamsReader
from horovod.spark.torch import remote
from horovod.spark.torch.datamodule import PetastormDataModule
from horovod.spark.torch.util import deserialize_fn, serialize_fn, \
save_into_bio

@@ -95,6 +97,7 @@ class TorchEstimator(HorovodEstimator, TorchEstimatorParamsWritable,
Args:
num_proc: Number of Horovod processes. Defaults to `spark.default.parallelism`.
        data_module: (Optional) DataModule class used for training and validation. If not set, defaults to PetastormDataModule.
model: PyTorch model to train.
backend: Optional Backend object for running distributed training function. Defaults to SparkBackend with
`num_proc` worker processes. Cannot be specified if `num_proc` is also provided.
@@ -108,6 +111,8 @@ class TorchEstimator(HorovodEstimator, TorchEstimatorParamsWritable,
gradient_compression: Gradient compression used by `hvd.DistributedOptimizer`.
feature_cols: Column names used as feature inputs to the model. Must be a list with each feature
mapping to a sequential argument in the model's forward() function.
continuous_cols: Column names of all columns with continuous features.
categorical_cols: Column names of all columns with categorical features.
input_shapes: List of shapes for each input tensor to the model.
validation: Optional validation column name (string) where every row in the column is either 1/True or 0/False,
or validation split (float) giving percent of data to be randomly selected for validation.
@@ -156,6 +161,7 @@ class TorchEstimator(HorovodEstimator, TorchEstimatorParamsWritable,
reducing and applying them. Defaults to 1.
"""

data_module = Param(Params._dummy(), 'data_module', 'data module class to use when reading data')
input_shapes = Param(Params._dummy(), 'input_shapes', 'input layer shapes')
loss_constructors = Param(Params._dummy(), 'loss_constructors',
'functions that construct the loss')
@@ -165,6 +171,7 @@ class TorchEstimator(HorovodEstimator, TorchEstimatorParamsWritable,
@keyword_only
def __init__(self,
num_proc=None,
data_module=None,
model=None,
backend=None,
store=None,
@@ -176,6 +183,8 @@ def __init__(self,
sample_weight_col=None,
gradient_compression=None,
feature_cols=None,
continuous_cols=None,
categorical_cols=None,
input_shapes=None,
validation=None,
label_cols=None,
@@ -203,7 +212,8 @@ def __init__(self,
backward_passes_per_step=1):

super(TorchEstimator, self).__init__()
-       self._setDefault(loss_constructors=None,
+       self._setDefault(data_module=PetastormDataModule,
+                        loss_constructors=None,
input_shapes=None,
train_minibatch_fn=None,
transformation_fn=None)
@@ -212,12 +222,18 @@ def __init__(self,

if EstimatorParams.loss.name in kwargs and TorchEstimator.loss_constructors.name in kwargs:
raise ValueError("only one of loss_constructors and loss parameters can be specified.")

if backward_passes_per_step <= 0:
raise ValueError("backward_passes_per_step must be > 0")

self.setParams(**kwargs)

def setDataModule(self, value):
return self._set(data_module=value)

def getDataModule(self):
return self.getOrDefault(self.data_module)

def setTrainMinibatchFn(self, value):
return self._set(train_minibatch_fn=value)

@@ -294,8 +310,8 @@ def _fit_on_prepared_data(self, backend, train_rows, val_rows, metadata, avg_row

def _load_checkpoint(self, run_id):
store = self.getStore()
-       last_ckpt_path = os.path.join(store.get_checkpoint_path(run_id),store.get_checkpoint_filename())
+       last_ckpt_path = os.path.join(store.get_checkpoint_path(run_id), store.get_checkpoint_filename())

if not store.fs.exists(last_ckpt_path):
return None

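Since the point of this change is to let users plug in their own loaders, the new `data_module` parameter (and the `setDataModule`/`getDataModule` accessors) accept any `DataModule` subclass, not just the built-in Petastorm and NVTabular modules. The sketch below is hypothetical and deliberately naive: it reads the materialized Parquet files with pandas, shards by file, and relies on the `DataModule` base class defaults for setup and teardown. The yielded batch structure would need to be matched to what the training loop and model expect.

# Hypothetical custom data loader; sharding and batch format are illustrative only.
import glob
import os

import pandas as pd
import torch

import horovod.torch as hvd
from horovod.spark.common.datamodule import DataModule


class PandasParquetDataModule(DataModule):
    """Toy DataModule that reads the materialized Parquet files with pandas."""

    def _batches(self, data_dir, batch_size):
        files = sorted(glob.glob(os.path.join(data_dir, '*.parquet')))
        # Naive sharding: each Horovod rank takes every size-th file
        # (assumes hvd.init() has already been called in the training process).
        for path in files[hvd.rank()::hvd.size()]:
            df = pd.read_parquet(path)
            for start in range(0, len(df), batch_size):
                chunk = df.iloc[start:start + batch_size]
                yield {col: torch.as_tensor(chunk[col].values) for col in chunk.columns}

    def train_data(self):
        return self._batches(self.train_dir, self.train_batch_size)

    def val_data(self):
        return self._batches(self.val_dir, self.val_batch_size) if self.has_val else None


# Attach to an estimator in place of the default PetastormDataModule, e.g.:
# est.setDataModule(PandasParquetDataModule)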