Estimator/Lightning: use lightning datamodule (horovod#3084)
Using a context manager is likely to stop the reader earlier than the async
dataloader worker thread, which would raise a runtime error in petastorm.

Introduce a Lightning datamodule to close the petastorm reader explicitly.

Signed-off-by: Chongxiao Cao <chongxiaoc@uber.com>
chongxiaoc authored Aug 17, 2021
1 parent 1bb42f1 commit 1a672ed
Showing 8 changed files with 223 additions and 305 deletions.
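To make the failure mode concrete, here is a minimal, self-contained sketch of the race the commit message describes (`FakeReader` and `worker_loop` are hypothetical stand-ins, not petastorm or Horovod code): a `with` block stops the reader as soon as the main thread leaves it, while the async dataloader's worker thread may still be pulling batches from it; stopping the reader only after the worker has finished, as the new datamodule's `teardown()` does, avoids the runtime error.

```python
import threading
import time


class FakeReader:
    """Hypothetical stand-in for a petastorm reader; not the real petastorm API."""

    def __init__(self):
        self.stopped = False

    def next_batch(self):
        if self.stopped:
            raise RuntimeError("reader was stopped while a worker was still reading")
        time.sleep(0.01)
        return [0, 1, 2]

    def stop(self):
        self.stopped = True

    # Context-manager protocol: stops the reader as soon as the `with` block exits.
    def __enter__(self):
        return self

    def __exit__(self, *exc):
        self.stop()


def worker_loop(reader, n_batches=20):
    """Plays the role of the async dataloader's background worker thread."""
    for _ in range(n_batches):
        reader.next_batch()


# 1) Context-manager lifecycle: the block exits (stopping the reader) right after
#    the thread starts, so the worker soon hits the stopped reader and raises.
reader = FakeReader()
with reader:
    worker = threading.Thread(target=worker_loop, args=(reader,))
    worker.start()
worker.join()  # returns after the worker has died with RuntimeError

# 2) Explicit lifecycle (what the datamodule's teardown() does for petastorm):
#    let the worker finish first, then stop the reader.
reader = FakeReader()
worker = threading.Thread(target=worker_loop, args=(reader,))
worker.start()
worker.join()
reader.stop()
```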
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -21,6 +21,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/).
### Fixed

- Fix Horovod develop/editable install mode and incremental builds. ([#3074](https://github.com/horovod/horovod/pull/3074))
- Estimator/Lightning: use lightning datamodule ([#3084](https://github.com/horovod/horovod/pull/3084))

## [v0.22.1] - 2021-06-10

2 changes: 1 addition & 1 deletion docs/spark.rst
@@ -42,7 +42,7 @@ Horovod Spark Estimators additionally require at least one of these combinations

* ``tensorflow-gpu >= 1.12.0`` or ``tensorflow >= 1.12.0`` (for ``KerasEstimator``)
* ``torch >= 1.0.0`` and ``tensorboard >= 1.14.0`` (for ``TorchEstimator``)
* ``torch >= 1.4.0`` and ``pytorch_lightning >= 1.2.9`` (for ``LightningEstimator``)
* ``torch >= 1.4.0`` and ``pytorch_lightning >= 1.4.1`` (for ``LightningEstimator``)


Horovod Spark Estimators
2 changes: 1 addition & 1 deletion horovod/spark/data_loaders/pytorch_data_loaders.py
@@ -145,7 +145,7 @@ def _print_verbose(self, *args, **kwargs):
print(*args, **kwargs)


class PytorchAsyncInmemDataLoader(AsyncDataLoaderMixin, PytorchInmemDataLoader):
class PytorchInmemAsyncDataLoader(AsyncDataLoaderMixin, PytorchInmemDataLoader):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

113 changes: 113 additions & 0 deletions horovod/spark/lightning/datamodule.py
@@ -0,0 +1,113 @@
import pytorch_lightning as pl

from horovod.spark.common import constants
from horovod.spark.data_loaders.pytorch_data_loaders import (
PytorchInfiniteAsyncDataLoader,
PytorchInmemAsyncDataLoader)
from petastorm import TransformSpec, make_reader, make_batch_reader

PETASTORM_HDFS_DRIVER = constants.PETASTORM_HDFS_DRIVER

class PetastormDataModule(pl.LightningDataModule):
"""Default DataModule for Lightning Estimator"""
def __init__(self, train_dir: str, val_dir: str, num_train_epochs: int=1, has_val: bool=True,
train_batch_size: int=32, val_batch_size: int=32, shuffle_size: int=1000,
num_reader_epochs=None, reader_pool_type: str="process", reader_worker_count: int=2,
transform_spec=None, inmemory_cache_all=False,
cur_shard: int=0, shard_count: int=1, schema_fields=None, storage_options=None,
steps_per_epoch_train: int=1, steps_per_epoch_val: int=1, verbose=True, **kwargs):
super().__init__()
self.train_dir = train_dir
self.val_dir = val_dir
self.num_train_epochs = num_train_epochs
self.has_val = has_val
self.train_batch_size = train_batch_size
self.val_batch_size = val_batch_size
self.shuffle_size = shuffle_size
self.num_reader_epochs = num_reader_epochs
self.reader_pool_type = reader_pool_type
self.reader_worker_count = reader_worker_count
self.transform_spec = transform_spec
self.inmemory_cache_all = inmemory_cache_all
self.cur_shard = cur_shard
self.shard_count = shard_count
self.schema_fields = schema_fields
self.storage_options = storage_options
self.steps_per_epoch_train = steps_per_epoch_train
self.steps_per_epoch_val = steps_per_epoch_val
self.verbose = verbose

def setup(self, stage=None):
# Assign train/val datasets for use in dataloaders
if stage == 'fit' or stage is None:
transform_spec = TransformSpec(self.transform_spec) if self.transform_spec else None
            # In general, make_batch_reader is faster than make_reader for reading the dataset.
            # However, we found that make_reader performs data transformations much faster than
            # make_batch_reader with parallel worker processes. Therefore, we default to
            # make_batch_reader unless a data transformation is supplied.
if transform_spec:
reader_factory = make_reader
else:
reader_factory = make_batch_reader

self.train_reader = reader_factory(self.train_dir, num_epochs=self.num_reader_epochs,
cur_shard=self.cur_shard, shard_count=self.shard_count,
hdfs_driver=PETASTORM_HDFS_DRIVER,
schema_fields=self.schema_fields,
storage_options=self.storage_options)
if self.has_val:
self.val_reader = reader_factory(self.val_dir, num_epochs=self.num_reader_epochs,
cur_shard=self.cur_shard, shard_count=self.shard_count,
hdfs_driver=PETASTORM_HDFS_DRIVER,
schema_fields=self.schema_fields,
storage_options=self.storage_options)

def teardown(self, stage=None):
if stage == "fit" or stage is None:
if self.verbose:
print("Tear down petastorm readers")
if not self.inmemory_cache_all:
                # Reader was loaded once and already stopped for the in-memory dataloader.
self.train_reader.stop()
self.train_reader.join()
if self.has_val:
self.val_reader.stop()
self.val_reader.join()

def train_dataloader(self):
if self.verbose:
print("Setup train dataloader")
kwargs = dict(reader=self.train_reader, batch_size=self.train_batch_size,
name="train dataloader",
limit_step_per_epoch=self.steps_per_epoch_train,
verbose=self.verbose)
if self.inmemory_cache_all:
# Use inmem dataloader
dataloader_class = PytorchInmemAsyncDataLoader
kwargs['shuffle'] = self.shuffle_size > 0
kwargs['num_epochs'] = self.num_train_epochs
else:
dataloader_class = PytorchInfiniteAsyncDataLoader
kwargs['shuffling_queue_capacity'] = self.shuffle_size

return dataloader_class(**kwargs)

def val_dataloader(self):
if not self.has_val:
return None
if self.verbose:
print("setup val dataloader")
kwargs = dict(reader=self.val_reader, batch_size=self.val_batch_size,
name="val dataloader",
limit_step_per_epoch=self.steps_per_epoch_val,
verbose=self.verbose)
if self.inmemory_cache_all:
# Use inmem dataloader
dataloader_class = PytorchInmemAsyncDataLoader
kwargs['shuffle'] = False
kwargs['num_epochs'] = self.num_train_epochs
else:
dataloader_class = PytorchInfiniteAsyncDataLoader
kwargs['shuffling_queue_capacity'] = 0

return dataloader_class(**kwargs)
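For orientation, the sketch below spells out the reader lifecycle this module encapsulates, using only the constructor arguments shown above. The dataset URLs are hypothetical, and in normal use Lightning invokes `setup()`/`teardown()` itself during `trainer.fit(model, datamodule=dm)` rather than the caller doing it by hand.

```python
from horovod.spark.lightning.datamodule import PetastormDataModule

# Hypothetical local dataset URLs; the keyword arguments match the constructor above.
dm = PetastormDataModule(
    train_dir='file:///tmp/train_data',   # assumed petastorm-readable dataset URL
    val_dir='file:///tmp/val_data',
    has_val=True,
    train_batch_size=32,
    val_batch_size=32,
    shuffle_size=1000,
    cur_shard=0,                          # this worker's shard (0 of 1 here)
    shard_count=1,
    steps_per_epoch_train=100,            # becomes limit_step_per_epoch of the train loader
    steps_per_epoch_val=10,
    verbose=True)

# Lightning normally calls these hooks for you; they are spelled out here to show
# the explicit open/close lifecycle that replaces the old context-manager approach.
dm.setup(stage='fit')                     # opens the petastorm reader(s)
train_loader = dm.train_dataloader()      # async dataloader wrapping the train reader
val_loader = dm.val_dataloader()
# ... the trainer consumes the loaders here ...
dm.teardown(stage='fit')                  # stops and joins the reader(s) explicitly
```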
22 changes: 11 additions & 11 deletions horovod/spark/lightning/estimator.py
@@ -46,7 +46,7 @@
import torch
import torch.utils.data

MIN_PL_VERSION = "1.2.9"
MIN_PL_VERSION = "1.4.1"


def _torch_param_serialize(param_name, param_val):
@@ -106,8 +106,8 @@ class TorchEstimator(HorovodEstimator, TorchEstimatorParamsWritable,
Defaults to SparkBackend with `num_proc` worker processes. Cannot be specified
if `num_proc` is also provided.
batch_size: Number of rows from the DataFrame per batch.
data_loader_class: (Optional) Class of the custom data loader, if not set, lightning
trainer will use PythonAsyncDataLoader as default.
                       data_module: (Optional) Lightning DataModule used for training and validation. If not set,
                                    the Lightning trainer will use PetastormDataModule as the default.
epochs: Number of epochs to train.
feature_cols: Column names used as feature inputs to the model. Must be a list with
each feature mapping to a sequential argument in the model's forward()
@@ -194,8 +194,8 @@ class TorchEstimator(HorovodEstimator, TorchEstimatorParamsWritable,
log_every_n_steps = Param(Params._dummy(), 'log_every_n_steps', 'control the frequency of logging',
typeConverter=TypeConverters.toInt)

data_loader_class = Param(Params._dummy(), 'data_loader_class',
'Class of the custom data loader, if not set, lightning trainer will use PythonAsyncDataLoader as default.')
data_module = Param(Params._dummy(), 'data_module',
                              '(Optional) Lightning DataModule used for training and validation. If not set, the Lightning trainer will use PetastormDataModule as the default.')

loader_num_epochs = Param(Params._dummy(), 'loader_num_epochs',
                                  'An epoch is a single pass over all rows in the dataset. Defaults to None, which means the reader will be in infinite-loop mode and generate unlimited data as needed.')
@@ -240,7 +240,7 @@ def __init__(self,
num_gpus=None,
logger=None,
log_every_n_steps=50,
data_loader_class=None,
data_module=None,
loader_num_epochs=None,
terminate_on_nan=False):

@@ -253,7 +253,7 @@ def __init__(self,
num_gpus=None,
logger=None,
log_every_n_steps=50,
data_loader_class=None,
data_module=None,
loader_num_epochs=None,
terminate_on_nan=False)

@@ -310,11 +310,11 @@ def setLogEveryNSteps(self, value):
def getLogEveryNSteps(self):
return self.getOrDefault(self.log_every_n_steps)

def setDataLoaderClass(self, value):
return self._set(data_loader_class=value)
def setDataModule(self, value):
return self._set(data_module=value)

def getDataLoaderClass(self):
return self.getOrDefault(self.data_loader_class)
def getDataModule(self):
return self.getOrDefault(self.data_module)

def setLoaderNumEpochs(self, value):
return self._set(loader_num_epochs=value)
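To round out the estimator-side change, here is a hedged sketch of the renamed parameter: `data_module` replaces `data_loader_class`, and `MyDataModule` below is a hypothetical `pl.LightningDataModule`. The import path and the other constructor keywords are illustrative (a real run also needs a model, a store, feature/label columns, and so on).

```python
import pytorch_lightning as pl

from horovod.spark.lightning import TorchEstimator  # assumed public import for estimator.py


class MyDataModule(pl.LightningDataModule):
    """Hypothetical custom datamodule standing in for the default PetastormDataModule."""

    def train_dataloader(self):
        ...  # return a DataLoader over your own data source

    def val_dataloader(self):
        ...


# `data_module` replaces the old `data_loader_class` parameter; leaving it unset keeps
# the default PetastormDataModule defined in horovod/spark/lightning/datamodule.py.
est = TorchEstimator(num_proc=2, batch_size=32, epochs=1, data_module=MyDataModule)

# The Params-style accessors added above behave as expected:
est.setDataModule(MyDataModule)
assert est.getDataModule() is MyDataModule
```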