Estimator: add in-memory dataloader to default pytorch dataloaders (horovod#2991)

Use it as the default dataloader when inmemory_cache_all is True for the
lightning estimator.

Signed-off-by: Chongxiao Cao <chongxiaoc@uber.com>
chongxiaoc authored Jul 12, 2021
1 parent 37d7d97 commit 3d9f894
Showing 3 changed files with 108 additions and 17 deletions.
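
For context, a minimal usage sketch of the new option, modeled on the integration test added below. The Spark DataFrame df, the LightningModule model, and the store are placeholders from your own setup, and the hvd_spark alias is assumed to be horovod.spark.lightning; setting inmemory_cache_all=True makes the estimator pick the new PytorchInmemDataLoader by default:

import horovod.spark.lightning as hvd_spark

# df: Spark DataFrame with 'features' and 'y' columns (placeholder)
# model: a pytorch_lightning.LightningModule (placeholder)
# store: a horovod.spark.common store, e.g. a local store (placeholder)
torch_estimator = hvd_spark.TorchEstimator(
    num_proc=1,                  # in-memory caching targets single-worker, small-data training
    store=store,
    model=model,
    input_shapes=[[-1, 2]],
    feature_cols=['features'],
    label_cols=['y'],
    validation=0.2,
    batch_size=4,
    epochs=2,
    inmemory_cache_all=True)     # selects PytorchInmemDataLoader as the default dataloader

torch_model = torch_estimator.fit(df)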
55 changes: 51 additions & 4 deletions horovod/spark/data_loaders/pytorch_data_loaders.py
@@ -13,7 +13,7 @@
# limitations under the License.
# ==============================================================================

from petastorm.pytorch import BatchedDataLoader
from petastorm.pytorch import BatchedDataLoader, InMemBatchedDataLoader
from horovod.data import BaseDataLoader, AsyncDataLoaderMixin


@@ -32,7 +32,8 @@ def __init__(self, reader, batch_size, shuffling_queue_capacity, name="",
f"limit_step_per_epoch={limit_step_per_epoch}")

def __len__(self):
return self.limit_step_per_epoch if self.limit_step_per_epoch != -1 else len(self.reader)
# We cannot infer length from reader.
return self.limit_step_per_epoch if self.limit_step_per_epoch != -1 else 0

def _iterate(self):
# Reset the reader if needed.
@@ -83,7 +84,7 @@ def __init__(self, *args, **kwargs):
self.reader,
batch_size=self.batch_size,
shuffling_queue_capacity=self.shuffling_queue_capacity)
self.iterater = iter(self.data_loader)
self.iterator = iter(self.data_loader)

def _iterate(self):
num_steps = 0
@@ -95,14 +96,60 @@ def _iterate(self):
break
num_steps += 1

yield next(self.iterater)
yield next(self.iterator)


class PytorchInfiniteAsyncDataLoader(AsyncDataLoaderMixin, PytorchInfiniteDataLoader):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)


class PytorchInmemDataLoader(BaseDataLoader):
def __init__(self, reader, batch_size, num_epochs, name="",
shuffle=False, limit_step_per_epoch=-1, verbose=False):
self.batch_size = batch_size
self.limit_step_per_epoch = limit_step_per_epoch
self.name = name
self.verbose = verbose

if limit_step_per_epoch == -1:
raise ValueError('limit_step_per_epoch cannot be -1 for inmem dataloader')

print(f"[{self.name}]: Initializing petastorm inmem_dataloader with batch_size={batch_size}"
f"num_epochs={num_epochs}, "
f"shuffle={shuffle}"
f"limit_step_per_epoch={limit_step_per_epoch}")

self.dataloader = InMemBatchedDataLoader(reader, batch_size=batch_size, num_epochs=num_epochs,
rows_capacity=batch_size*limit_step_per_epoch, shuffle=shuffle)
self.iterator = iter(self.dataloader)

def __len__(self):
# We cannot infer length from reader.
return self.limit_step_per_epoch

def _iterate(self):
num_steps = 0
self._print_verbose(f"[{self.name}]: Start to generate batch data. limit_step_per_epoch={self.limit_step_per_epoch}")

while True:
if num_steps == self.limit_step_per_epoch:
self._print_verbose(f"[{self.name}]: Reach limit_step_per_epoch. Stop at step {num_steps}.")
break
num_steps += 1

yield next(self.iterator)

def _print_verbose(self, *args, **kwargs):
if self.verbose:
print(*args, **kwargs)


class PytorchAsyncInmemDataLoader(AsyncDataLoaderMixin, PytorchInmemDataLoader):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
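
For reference, a minimal sketch of driving the new loader directly, outside the estimator. The dataset URL is a placeholder, and the loop assumes BaseDataLoader iterates by delegating to _iterate(), as the estimator's training loop does. Note that the loader caches batch_size * limit_step_per_epoch rows, so limit_step_per_epoch must be set explicitly:

from petastorm import make_batch_reader
from horovod.spark.data_loaders.pytorch_data_loaders import PytorchInmemDataLoader

# 'file:///tmp/my_dataset' is a placeholder Parquet dataset URL.
# The reader is opened with num_epochs=1: rows are read once and then
# replayed from memory by InMemBatchedDataLoader for num_epochs epochs.
with make_batch_reader('file:///tmp/my_dataset', num_epochs=1) as reader:
    loader = PytorchInmemDataLoader(reader,
                                    batch_size=4,
                                    num_epochs=2,             # training epochs replayed from the cache
                                    shuffle=True,             # reshuffle cached rows each epoch
                                    limit_step_per_epoch=25,  # -1 raises ValueError for this loader
                                    verbose=True)
    for epoch in range(2):
        for batch in loader:      # yields at most limit_step_per_epoch batches per pass
            pass                  # feed batch to the training step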


class PetastormBatchedDataLoader(BatchedDataLoader):
def __init__(self, name="", limit_step_per_epoch=-1, verbose=False, *args, **kwargs):
print(f"[{name}]Petastorm BatchedDataLoader will ignore limit_step_per_epoch and verbose.")
41 changes: 28 additions & 13 deletions horovod/spark/lightning/remote.py
@@ -76,14 +76,15 @@ def RemoteTrainer(estimator, metadata, ckpt_bytes, run_id, dataset_idx, train_ro
if sample_weight_col:
schema_fields.append(sample_weight_col)

data_loader_cls = _create_dataloader(feature_columns, input_shapes, metadata, data_loader_cls)
data_loader_cls = _create_dataloader(feature_columns, input_shapes, metadata, inmemory_cache_all, data_loader_cls)

# Storage
store = estimator.getStore()
remote_store = store.to_remote(run_id, dataset_idx)

set_data_loader = _set_data_loader_fn(transformation, schema_fields, batch_size,
data_loader_cls, loader_num_epochs, store, verbose)
data_loader_cls, loader_num_epochs, store,
epochs, inmemory_cache_all, verbose)

def train(serialized_model):
import horovod.torch as hvd
@@ -218,7 +219,8 @@ def on_sanity_check_end(self, trainer, model):
return [ResetCallback()]


def _set_data_loader_fn(transformation, schema_fields, batch_size, data_loader_cls, num_epochs, store, verbose=False):
def _set_data_loader_fn(transformation, schema_fields, batch_size, data_loader_cls,
loader_num_epochs, store, epochs, inmemory_cache_all=False, verbose=False):
storage_options = store.storage_options

@contextlib.contextmanager
@@ -256,8 +258,10 @@ def set_data_loader(model, data_path, dataloader_attr, reader_worker_count, read
# Setting num_epochs=None will cause an infinite iterator
# and enables ranks to perform training and validation with
# unequal number of samples
# `loader_num_epochs` is None by default.
# This doesn't apply to the inmem dataloader, which loads the whole reader into memory.
with reader_factory(data_path,
num_epochs=num_epochs,
num_epochs=1 if inmemory_cache_all else loader_num_epochs,
cur_shard=hvd.rank(),
shard_count=hvd.size(),
reader_pool_type=reader_pool_type,
@@ -268,11 +272,17 @@
storage_options=storage_options,
**reader_factory_kwargs) as reader:
def dataloader_fn():
return data_loader_cls(reader=reader, batch_size=batch_size,
shuffling_queue_capacity=shuffling_queue_capacity,
name=name,
limit_step_per_epoch=limit_step_per_epoch,
verbose=verbose)
kwargs = dict(reader=reader, batch_size=batch_size,
name=name,
limit_step_per_epoch=limit_step_per_epoch,
verbose=verbose)
if inmemory_cache_all:
# Use inmem dataloader
kwargs['shuffle'] = shuffling_queue_capacity > 0
kwargs['num_epochs'] = epochs
else:
kwargs['shuffling_queue_capacity'] = shuffling_queue_capacity
return data_loader_cls(**kwargs)
try:
setattr(model, dataloader_attr, dataloader_fn)
yield
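
To summarize the dispatch above in one place: under inmemory_cache_all the petastorm reader is opened with num_epochs=1 (the data is read exactly once), and the loader replays its cache for the training epochs with an optional in-memory shuffle; otherwise the default infinite-reader path keeps using petastorm's shuffling queue. A condensed sketch with a hypothetical helper name:

# Hypothetical helper summarizing the kwargs dispatch in set_data_loader above.
def _loader_kwargs(inmemory_cache_all, epochs, shuffling_queue_capacity, **common):
    if inmemory_cache_all:
        # In-memory path: the reader was opened with num_epochs=1; the loader
        # replays cached rows for `epochs` epochs and shuffles them in memory.
        return dict(common, shuffle=shuffling_queue_capacity > 0, num_epochs=epochs)
    # Default path: stream from an infinite reader through a shuffling queue.
    return dict(common, shuffling_queue_capacity=shuffling_queue_capacity)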
@@ -326,11 +336,16 @@ def calculate_shuffle_buffer_size():
return calculate_shuffle_buffer_size


def _create_dataloader(feature_columns, input_shapes, metadata, data_loader_cls=None):
def _create_dataloader(feature_columns, input_shapes, metadata, inmemory_cache_all, data_loader_cls=None):
if data_loader_cls is None:
# set PytorchInfiniteAsyncDataLoader as default
from horovod.spark.data_loaders.pytorch_data_loaders import PytorchInfiniteAsyncDataLoader
data_loader_cls = PytorchInfiniteAsyncDataLoader
if inmemory_cache_all:
# set PytorchInmemDataLoader as default
from horovod.spark.data_loaders.pytorch_data_loaders import PytorchInmemDataLoader
data_loader_cls = PytorchInmemDataLoader
else:
# set PytorchInfiniteAsyncDataLoader as default
from horovod.spark.data_loaders.pytorch_data_loaders import PytorchInfiniteAsyncDataLoader
data_loader_cls = PytorchInfiniteAsyncDataLoader

print(f"Using dataloader: {data_loader_cls}")

29 changes: 29 additions & 0 deletions test/integration/test_spark_lightning.py
@@ -856,6 +856,35 @@ def test_train_with_pytorch_infinite_async_data_loader(self):
assert len(pred) == 1
assert pred.dtype == torch.float32

"""
Test train model with inmemory_cache_all (using PytorchInmemDataLoader)
"""
def test_train_with_inmemory_cache_all(self):
with spark_session('test_fit_model') as spark:
df = create_noisy_xor_data(spark)
model = create_xor_model()

with local_store() as store:
torch_estimator = hvd_spark.TorchEstimator(
num_proc=1, # The inmem dataloader is normally used for single-worker training on small datasets
store=store,
model=model,
input_shapes=[[-1, 2]],
feature_cols=['features'],
label_cols=['y'],
validation=0.2,
batch_size=4,
epochs=2,
verbose=2,
inmemory_cache_all=True)

torch_model = torch_estimator.fit(df)

# TODO: Find a way to pass log metrics from remote, and assert based on the logger.
trained_model = torch_model.getModel()
pred = trained_model(torch.ones([1, 2], dtype=torch.int32))
assert len(pred) == 1
assert pred.dtype == torch.float32

"""
Test pytorch lightning trainer with the original petastorm reader.
