From 3d9f894b26f5bb16a323219acc5953c2301b1a86 Mon Sep 17 00:00:00 2001
From: chongxiaoc <74630762+chongxiaoc@users.noreply.github.com>
Date: Mon, 12 Jul 2021 14:53:10 -0700
Subject: [PATCH] Estimator: add in-memory dataloader to default pytorch dataloaders (#2991)

Use as the default dataloader if inmemory_cache_all is True for the Lightning estimator.

Signed-off-by: Chongxiao Cao
---
 .../data_loaders/pytorch_data_loaders.py  | 55 +++++++++++++++++--
 horovod/spark/lightning/remote.py         | 41 +++++++++-----
 test/integration/test_spark_lightning.py  | 29 ++++++++++
 3 files changed, 108 insertions(+), 17 deletions(-)

diff --git a/horovod/spark/data_loaders/pytorch_data_loaders.py b/horovod/spark/data_loaders/pytorch_data_loaders.py
index aee1ffae3a..ca11a22d4c 100644
--- a/horovod/spark/data_loaders/pytorch_data_loaders.py
+++ b/horovod/spark/data_loaders/pytorch_data_loaders.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 # ==============================================================================

-from petastorm.pytorch import BatchedDataLoader
+from petastorm.pytorch import BatchedDataLoader, InMemBatchedDataLoader

 from horovod.data import BaseDataLoader, AsyncDataLoaderMixin

@@ -32,7 +32,8 @@ def __init__(self, reader, batch_size, shuffling_queue_capacity, name="",
               f"limit_step_per_epoch={limit_step_per_epoch}")

     def __len__(self):
-        return self.limit_step_per_epoch if self.limit_step_per_epoch != -1 else len(self.reader)
+        # We cannot infer the length from the reader.
+        return self.limit_step_per_epoch if self.limit_step_per_epoch != -1 else 0

     def _iterate(self):
         # Reset the reader if needed.
@@ -83,7 +84,7 @@ def __init__(self, *args, **kwargs):
                                              self.reader,
                                              batch_size=self.batch_size,
                                              shuffling_queue_capacity=self.shuffling_queue_capacity)
-        self.iterater = iter(self.data_loader)
+        self.iterator = iter(self.data_loader)

     def _iterate(self):
         num_steps = 0
@@ -95,7 +96,7 @@ def _iterate(self):
                 break
             num_steps += 1

-            yield next(self.iterater)
+            yield next(self.iterator)


 class PytorchInfiniteAsyncDataLoader(AsyncDataLoaderMixin, PytorchInfiniteDataLoader):
@@ -103,6 +104,52 @@ def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)


+class PytorchInmemDataLoader(BaseDataLoader):
+    def __init__(self, reader, batch_size, num_epochs, name="",
+                 shuffle=False, limit_step_per_epoch=-1, verbose=False):
+        self.batch_size = batch_size
+        self.limit_step_per_epoch = limit_step_per_epoch
+        self.name = name
+        self.verbose = verbose
+
+        if limit_step_per_epoch == -1:
+            raise ValueError('limit_step_per_epoch cannot be -1 for inmem dataloader')
+
+        print(f"[{self.name}]: Initializing petastorm inmem_dataloader with batch_size={batch_size}, "
+              f"num_epochs={num_epochs}, "
+              f"shuffle={shuffle}, "
+              f"limit_step_per_epoch={limit_step_per_epoch}")
+
+        self.dataloader = InMemBatchedDataLoader(reader, batch_size=batch_size, num_epochs=num_epochs,
+                                                 rows_capacity=batch_size*limit_step_per_epoch, shuffle=shuffle)
+        self.iterator = iter(self.dataloader)
+
+    def __len__(self):
+        # We cannot infer the length from the reader.
+        return self.limit_step_per_epoch
+
+    def _iterate(self):
+        num_steps = 0
+        self._print_verbose(f"[{self.name}]: Start to generate batch data. limit_step_per_epoch={self.limit_step_per_epoch}")
+
+        while True:
+            if num_steps == self.limit_step_per_epoch:
+                self._print_verbose(f"[{self.name}]: Reach limit_step_per_epoch. Stop at step {num_steps}.")
+                break
+            num_steps += 1
+
+            yield next(self.iterator)
+
+    def _print_verbose(self, *args, **kwargs):
+        if self.verbose:
+            print(*args, **kwargs)
+
+
+class PytorchAsyncInmemDataLoader(AsyncDataLoaderMixin, PytorchInmemDataLoader):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+
 class PetastormBatchedDataLoader(BatchedDataLoader):
     def __init__(self, name="", limit_step_per_epoch=-1, verbose=False, *args, **kwargs):
         print(f"[{name}]Petastorm BatchedDataLoader will ignore limit_step_per_epoch and verbose.")
diff --git a/horovod/spark/lightning/remote.py b/horovod/spark/lightning/remote.py
index a0d2ddcd90..91483ff390 100644
--- a/horovod/spark/lightning/remote.py
+++ b/horovod/spark/lightning/remote.py
@@ -76,14 +76,15 @@ def RemoteTrainer(estimator, metadata, ckpt_bytes, run_id, dataset_idx, train_ro
     if sample_weight_col:
         schema_fields.append(sample_weight_col)

-    data_loader_cls = _create_dataloader(feature_columns, input_shapes, metadata, data_loader_cls)
+    data_loader_cls = _create_dataloader(feature_columns, input_shapes, metadata, inmemory_cache_all, data_loader_cls)

     # Storage
     store = estimator.getStore()
     remote_store = store.to_remote(run_id, dataset_idx)

     set_data_loader = _set_data_loader_fn(transformation, schema_fields, batch_size,
-                                          data_loader_cls, loader_num_epochs, store, verbose)
+                                          data_loader_cls, loader_num_epochs, store,
+                                          epochs, inmemory_cache_all, verbose)

     def train(serialized_model):
         import horovod.torch as hvd
@@ -218,7 +219,8 @@ def on_sanity_check_end(self, trainer, model):
     return [ResetCallback()]


-def _set_data_loader_fn(transformation, schema_fields, batch_size, data_loader_cls, num_epochs, store, verbose=False):
+def _set_data_loader_fn(transformation, schema_fields, batch_size, data_loader_cls,
+                        loader_num_epochs, store, epochs, inmemory_cache_all=False, verbose=False):
     storage_options = store.storage_options

     @contextlib.contextmanager
@@ -256,8 +258,10 @@ def set_data_loader(model, data_path, dataloader_attr, reader_worker_count, read
         # Setting num_epochs=None will cause an infinite iterator
         # and enables ranks to perform training and validation with
         # unequal number of samples
+        # `loader_num_epochs` is None by default.
+        # This doesn't apply to the inmem dataloader, which loads the whole reader into memory.
         with reader_factory(data_path,
-                            num_epochs=num_epochs,
+                            num_epochs=1 if inmemory_cache_all else loader_num_epochs,
                             cur_shard=hvd.rank(),
                             shard_count=hvd.size(),
                             reader_pool_type=reader_pool_type,
@@ -268,11 +272,17 @@ def set_data_loader(model, data_path, dataloader_attr, reader_worker_count, read
                             storage_options=storage_options,
                             **reader_factory_kwargs) as reader:
             def dataloader_fn():
-                return data_loader_cls(reader=reader, batch_size=batch_size,
-                                       shuffling_queue_capacity=shuffling_queue_capacity,
-                                       name=name,
-                                       limit_step_per_epoch=limit_step_per_epoch,
-                                       verbose=verbose)
+                kwargs = dict(reader=reader, batch_size=batch_size,
+                              name=name,
+                              limit_step_per_epoch=limit_step_per_epoch,
+                              verbose=verbose)
+                if inmemory_cache_all:
+                    # Use the inmem dataloader
+                    kwargs['shuffle'] = shuffling_queue_capacity > 0
+                    kwargs['num_epochs'] = epochs
+                else:
+                    kwargs['shuffling_queue_capacity'] = shuffling_queue_capacity
+                return data_loader_cls(**kwargs)
             try:
                 setattr(model, dataloader_attr, dataloader_fn)
                 yield
@@ -326,11 +336,16 @@ def calculate_shuffle_buffer_size():
     return calculate_shuffle_buffer_size


-def _create_dataloader(feature_columns, input_shapes, metadata, data_loader_cls=None):
+def _create_dataloader(feature_columns, input_shapes, metadata, inmemory_cache_all, data_loader_cls=None):
     if data_loader_cls is None:
-        # set PytorchInfiniteAsyncDataLoader as default
-        from horovod.spark.data_loaders.pytorch_data_loaders import PytorchInfiniteAsyncDataLoader
-        data_loader_cls = PytorchInfiniteAsyncDataLoader
+        if inmemory_cache_all:
+            # set PytorchInmemDataLoader as default
+            from horovod.spark.data_loaders.pytorch_data_loaders import PytorchInmemDataLoader
+            data_loader_cls = PytorchInmemDataLoader
+        else:
+            # set PytorchInfiniteAsyncDataLoader as default
+            from horovod.spark.data_loaders.pytorch_data_loaders import PytorchInfiniteAsyncDataLoader
+            data_loader_cls = PytorchInfiniteAsyncDataLoader

     print(f"Using dataloader: {data_loader_cls}")

diff --git a/test/integration/test_spark_lightning.py b/test/integration/test_spark_lightning.py
index 1f092668ea..20e4824199 100644
--- a/test/integration/test_spark_lightning.py
+++ b/test/integration/test_spark_lightning.py
@@ -856,6 +856,35 @@ def test_train_with_pytorch_infinite_async_data_loader(self):
             assert len(pred) == 1
             assert pred.dtype == torch.float32

+    """
+    Test train model with inmemory_cache_all (using PytorchInmemDataLoader)
+    """
+    def test_train_with_inmemory_cache_all(self):
+        with spark_session('test_fit_model') as spark:
+            df = create_noisy_xor_data(spark)
+            model = create_xor_model()
+
+            with local_store() as store:
+                torch_estimator = hvd_spark.TorchEstimator(
+                    num_proc=1,  # Normally the inmem dataloader is for single-worker training with small data
+                    store=store,
+                    model=model,
+                    input_shapes=[[-1, 2]],
+                    feature_cols=['features'],
+                    label_cols=['y'],
+                    validation=0.2,
+                    batch_size=4,
+                    epochs=2,
+                    verbose=2,
+                    inmemory_cache_all=True)
+
+                torch_model = torch_estimator.fit(df)
+
+                # TODO: Find a way to pass log metrics from remote, and assert based on the logger.
+                trained_model = torch_model.getModel()
+                pred = trained_model(torch.ones([1, 2], dtype=torch.int32))
+                assert len(pred) == 1
+                assert pred.dtype == torch.float32

     """
     Test pytorch lightning trainer with original petastorm reader.
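
For reference, a minimal, hypothetical sketch of driving the new PytorchInmemDataLoader directly against a petastorm reader, outside the estimator. The dataset URL, loader name, batch size, and step counts below are illustrative placeholders and not part of this patch; only the PytorchInmemDataLoader signature comes from the change above.

    # Hypothetical standalone usage of PytorchInmemDataLoader; the dataset path
    # and sizes are made up for illustration.
    from petastorm import make_batch_reader

    from horovod.spark.data_loaders.pytorch_data_loaders import PytorchInmemDataLoader

    batch_size = 4
    limit_step_per_epoch = 8   # must be set explicitly; the loader rejects the -1 (unbounded) default
    num_epochs = 2

    # The reader is opened for a single epoch: InMemBatchedDataLoader caches
    # batch_size * limit_step_per_epoch rows and replays them for num_epochs epochs,
    # mirroring the `num_epochs=1 if inmemory_cache_all else loader_num_epochs`
    # override in remote.py above.
    with make_batch_reader('file:///tmp/xor_dataset', num_epochs=1) as reader:
        loader = PytorchInmemDataLoader(reader,
                                        batch_size=batch_size,
                                        num_epochs=num_epochs,
                                        name='train',
                                        shuffle=True,
                                        limit_step_per_epoch=limit_step_per_epoch,
                                        verbose=True)
        for epoch in range(num_epochs):
            for batch in loader:   # each pass yields exactly limit_step_per_epoch batches
                pass               # feed `batch` to the training/validation step here

Because __len__ returns limit_step_per_epoch, the Lightning trainer sees a finite epoch length even though the length of the underlying reader cannot be inferred, which is why the in-memory loader requires limit_step_per_epoch to be set.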