🐛 Bug

According to the documentation, pytorch_lightning.callbacks.BasePredictionWriter is the way to store predictions:
https://pytorch-lightning.readthedocs.io/en/stable/api/pytorch_lightning.callbacks.BasePredictionWriter.html

Using BasePredictionWriter to construct a callback for Trainer.predict returns both predictions and batch_indices when Trainer.predict(return_predictions=True), but does not return the batch_indices when Trainer.predict(return_predictions=False).

This causes a problem when strategy="ddp". In DDP the prediction writer needs to store both the predictions and the batch indices for each batch in order to reconstruct which predictions belong to which original data indices.

e.g. num_devices=2, batch_size=4, dataset_size=8:
device 0 gets batch_indices [0, 2, 4, 6]
device 1 gets batch_indices [1, 3, 5, 7]

Without the batch_indices we cannot write out batches in parallel with DDP. Thus DDP can only be used for small datasets/prediction dimensions where
if self.should_store_predictions:
    self.predictions.append(move_data_to_device(predictions, torch.device("cpu")))
does not crash by using up all the available memory. But for large datasets I need DDP: I have 75 million images and am writing a 2048-dimensional representation for each, which at 4 bytes per value is 75e6 × 2048 × 4 bytes ≈ 572 GiB.
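For context, the callback I want to run under DDP looks roughly like the sketch below: each rank streams its own batches to disk as they are produced, keyed by batch_indices, so nothing accumulates in memory. This is only a sketch of the use case; the class name, file layout, and output_dir are illustrative, not Lightning API. The batch_indices argument is what comes back empty when return_predictions=False.

import os

import torch
from pytorch_lightning.callbacks import BasePredictionWriter


class PerRankPredictionWriter(BasePredictionWriter):
    """Sketch: write each batch of predictions to disk per rank instead of holding them in memory."""

    def __init__(self, output_dir: str):
        super().__init__(write_interval="batch")
        self.output_dir = output_dir

    def write_on_batch_end(
        self, trainer, pl_module, prediction, batch_indices, batch, batch_idx, dataloader_idx
    ):
        # batch_indices maps this rank's batch back to the original dataset rows;
        # with return_predictions=False it currently arrives empty, which is this bug.
        os.makedirs(self.output_dir, exist_ok=True)
        path = os.path.join(self.output_dir, f"rank{trainer.global_rank}_batch{batch_idx}.pt")
        torch.save({"indices": batch_indices, "predictions": prediction.cpu()}, path)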
Fix:
This is caused by the self.should_store_predictions check in pytorch_lightning.loops.epoch.prediction_epoch_loop:
def _get_batch_indices(self, dataloader_idx: int) -> List[List[int]]:
    """Returns a reference to the seen batch indices if the dataloader has a batch sampler wrapped by our
    :class:`~pytorch_lightning.overrides.distributed.IndexBatchSamplerWrapper`."""
    # the batch_sampler is not defined in the case of CombinedDataLoaders
    batch_sampler = getattr(
        self.trainer.predict_dataloaders[dataloader_idx],  # type: ignore[has-type]
        "batch_sampler",
        None,
    )
    if isinstance(batch_sampler, IndexBatchSamplerWrapper) and self.should_store_predictions:
        return batch_sampler.seen_batch_indices
    warning_cache.warn("Lightning couldn't infer the indices fetched for your dataloader.")
    return []
I would modify
if isinstance(batch_sampler, IndexBatchSamplerWrapper) and self.should_store_predictions:
    return batch_sampler.seen_batch_indices
to instead be
if isinstance(batch_sampler, IndexBatchSamplerWrapper):
    return batch_sampler.seen_batch_indices
That is, we want the batch indices whether or not we are appending the predictions to a main list, and I don't see the harm in always returning them.
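As an interim workaround one can sidestep batch_indices entirely by having the dataset return its own index with every sample, so the writer recovers the dataset position from the batch itself. A minimal sketch, assuming a map-style dataset (IndexedDataset and the dict keys are my own names, not Lightning API):

from torch.utils.data import Dataset


class IndexedDataset(Dataset):
    """Sketch: wrap a map-style dataset so each sample carries its own index."""

    def __init__(self, base_dataset):
        self.base_dataset = base_dataset

    def __len__(self):
        return len(self.base_dataset)

    def __getitem__(self, idx):
        # The index travels with the sample, so predict_step (and therefore the
        # prediction writer) can see which dataset row each prediction belongs to.
        return {"index": idx, "sample": self.base_dataset[idx]}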
cc @Borda @rohitgr7