
Trainer.predict(return_predictions=False) does not track batch_indices for BasePredictionWriter, which are necessary for ddp #13580

Closed
georgestein opened this issue Jul 8, 2022 · 0 comments · Fixed by #13629
Labels
bug Something isn't working callback: prediction writer
Milestone

Comments


georgestein commented Jul 8, 2022

🐛 Bug

According to the documentation, pytorch_lightning.callbacks.BasePredictionWriter is the way to store predictions:
https://pytorch-lightning.readthedocs.io/en/stable/api/pytorch_lightning.callbacks.BasePredictionWriter.html

Using BasePredictionWriter to construct a callback for Trainer.predict returns both the predictions and the batch_indices when Trainer.predict(return_predictions=True), but does not track the batch_indices when Trainer.predict(return_predictions=False).

This causes a problem when strategy="ddp": in DDP the prediction writer needs to store both the predictions and the batch indices for each batch in order to reconstruct which predictions belong to which original data indices.

e.g. num_devices = 2, batch_size=4, dataset_size=8:
device 0 gets batch_indices [0, 2, 4, 6]
device 1 gets batch_indices [1, 3, 5, 7]
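The interleaved assignment above mirrors what PyTorch's DistributedSampler does by default. As a minimal sketch (shard_indices is a hypothetical helper for illustration, not a Lightning or PyTorch API):

```python
def shard_indices(dataset_size: int, world_size: int, rank: int) -> list:
    """Return the sample indices one rank sees, assuming the default
    DistributedSampler-style round-robin (interleaved) split with no padding."""
    return list(range(dataset_size))[rank::world_size]

# num_devices = 2, batch_size = 4, dataset_size = 8
print(shard_indices(8, 2, 0))  # [0, 2, 4, 6]
print(shard_indices(8, 2, 1))  # [1, 3, 5, 7]
```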

Without the batch_indices we cannot write out batches in parallel with DDP. As a result, DDP can only be used for datasets/prediction dimensions small enough that

    if self.should_store_predictions:
        self.predictions.append(move_data_to_device(predictions, torch.device("cpu")))

does not crash by using up all the available memory. For large datasets I need DDP (I have 75 million images and am writing a 2048-dimensional representation for each, which totals 572 GB).
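The 572 GB figure is consistent with float32 (4-byte) outputs:

```python
# 75 million images, one 2048-dimensional float32 vector per image
n_images, dim, bytes_per_float = 75_000_000, 2048, 4
total_bytes = n_images * dim * bytes_per_float
print(total_bytes / 2**30)  # roughly 572 GiB
```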

Fix:

This is caused by the self.should_store_predictions check in pytorch_lightning.loops.epoch.prediction_epoch_loop:

def _get_batch_indices(self, dataloader_idx: int) -> List[List[int]]:
    """Returns a reference to the seen batch indices if the dataloader has a batch sampler wrapped by our
    :class:`~pytorch_lightning.overrides.distributed.IndexBatchSamplerWrapper`."""
    # the batch_sampler may not be defined in the case of CombinedDataLoaders
    batch_sampler = getattr(
        self.trainer.predict_dataloaders[dataloader_idx],  # type: ignore[has-type]
        "batch_sampler",
        None,
    )
    if isinstance(batch_sampler, IndexBatchSamplerWrapper) and self.should_store_predictions:
        return batch_sampler.seen_batch_indices

    warning_cache.warn("Lightning couldn't infer the indices fetched for your dataloader.")
    return []

I would modify

    if isinstance(batch_sampler, IndexBatchSamplerWrapper) and self.should_store_predictions:
        return batch_sampler.seen_batch_indices

to instead be

    if isinstance(batch_sampler, IndexBatchSamplerWrapper):
        return batch_sampler.seen_batch_indices

That is, we want the batch indices whether or not we are also appending the predictions to a main list. I don't see any harm in always returning them.
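With the batch indices always available, each rank can record which original sample every prediction belongs to, and the per-rank outputs can be stitched back into dataset order afterwards. A minimal pure-Python sketch of that merge step (merge_rank_outputs is a hypothetical helper, not part of Lightning):

```python
def merge_rank_outputs(rank_results):
    """Restore dataset order for predictions gathered from multiple DDP ranks.

    rank_results: iterable of (batch_indices, predictions) pairs, one per
    rank, where batch_indices[i] is the original dataset index of
    predictions[i].
    """
    merged = {}
    for batch_indices, predictions in rank_results:
        for idx, pred in zip(batch_indices, predictions):
            merged[idx] = pred
    # sort by original dataset index to undo the interleaved DDP split
    return [merged[i] for i in sorted(merged)]

# rank 0 saw samples [0, 2, 4, 6]; rank 1 saw [1, 3, 5, 7]
out = merge_rank_outputs([
    ([0, 2, 4, 6], ["p0", "p2", "p4", "p6"]),
    ([1, 3, 5, 7], ["p1", "p3", "p5", "p7"]),
])
print(out)  # ['p0', 'p1', 'p2', 'p3', 'p4', 'p5', 'p6', 'p7']
```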

cc @Borda @rohitgr7

@georgestein georgestein added the needs triage Waiting to be triaged by maintainers label Jul 8, 2022
georgestein added a commit to georgestein/lightning that referenced this issue Jul 12, 2022
@rohitgr7 rohitgr7 added feature Is an improvement or enhancement callback: prediction writer and removed needs triage Waiting to be triaged by maintainers labels Jul 13, 2022
@rohitgr7 rohitgr7 added this to the pl:1.7 milestone Jul 13, 2022
@carmocca carmocca modified the milestones: pl:1.7, pl:1.6.x Jul 15, 2022
@carmocca carmocca added bug Something isn't working and removed feature Is an improvement or enhancement labels Jul 15, 2022
Borda pushed a commit that referenced this issue Jul 18, 2022
…_indices` (#13629)

* Pull request for fixing issue #13580
* chlog and test
* disable track for epoch

Co-authored-by: rohitgr7 <rohitgr1998@gmail.com>