
Make extract_batch_size() easier to redefine #10576

Closed
ghost opened this issue Nov 17, 2021 · 5 comments · Fixed by #10408


ghost commented Nov 17, 2021

Proposed refactor

The function extract_batch_size() should be easy for the user to redefine.

Motivation

I'm using pytorch-lightning and pytorch-geometric together. In pytorch-geometric, graphs are batched in a peculiar way (there is no batch dimension; more details are in their documentation).

Using this in pytorch-lightning causes the following warning:

UserWarning: Trying to infer the `batch_size` from an ambiguous collection. The batch size we found is 615. To avoid any miscalculations, use `self.log(..., batch_size=batch_size)`.

But I know exactly what the batch size of my batches is! It's just that the tensors don't have a fixed batch dimension.

So to fix this warning, I monkey-patch the code that retrieves the size of the batch:

import pytorch_lightning as pl
from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection

# Monkey-patch `extract_batch_size` so that unusual tensor shapes do not trigger the warning
def extract_bs(self, batch):
    try:
        # Modification is here: trust an explicit `batch_size` entry if the batch
        # provides one, otherwise fall back to Lightning's own inference
        if "batch_size" in batch:
            batch_size = batch["batch_size"]
        else:
            batch_size = pl.utilities.data.extract_batch_size(batch)
    except RecursionError:
        batch_size = 1
    self.batch_size = batch_size
    return batch_size

ResultCollection.extract_batch_size = extract_bs

Pitch

The monkey-patch approach seems dirty and dangerous, but I couldn't find another way to solve this issue (i.e., to specify a user-defined way of extracting the batch size from a batch).
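Purely as an illustration of the kind of hook I mean (this is not an existing Lightning API, and the method name here is hypothetical), something like the following would remove the need for monkey-patching:

import pytorch_lightning as pl

# Hypothetical override hook, NOT an existing Lightning API: the module itself
# tells Lightning how to read the batch size from a batch.
class MyModel(pl.LightningModule):
    def extract_batch_size(self, batch) -> int:
        # e.g. for pytorch-geometric batches, return the known number of samples
        return batch["batch_size"]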

ghost added the refactor label on Nov 17, 2021
@rohitgr7
Contributor

We use batch_size to accumulate the metrics at the epoch level when they are logged with self.log(..., on_epoch=True). If you know your batch_size, you can pass it in like self.log(..., batch_size=your_batch_size). For now you will still get this warning when you do that, but you can ignore it since it will be fixed soon.
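For reference, a minimal sketch of what that looks like in a training step (the loss computation is a placeholder, and using PyG's num_graphs attribute is just one way to get the known batch size):

import pytorch_lightning as pl

class LitModel(pl.LightningModule):
    def training_step(self, batch, batch_idx):
        loss = self.compute_loss(batch)  # placeholder for your actual loss
        # Pass the known batch size explicitly so Lightning does not try to
        # infer it from the (ambiguous) batch structure.
        self.log("train_loss", loss, on_epoch=True, batch_size=batch.num_graphs)
        return loss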

ghost (Author) commented Nov 17, 2021

@rohitgr7 Thanks for your answer!

Yes, I tried adding the batch size to the self.log(..., batch_size=your_batch_size) calls in the LightningModule, but somehow I still got the warnings. Could it be because I use a custom LightningDataModule?

I can ignore the warning, but since my tensor sizes change every batch, my console is literally flooded with warnings ^^

@rohitgr7
Contributor

Here is the open PR for it: #10408.
Once merged, it will resolve this issue.


mpvenkatesh commented Nov 19, 2021

I think the underlying problem can be addressed for lists and dictionaries by directly modifying _extract_batch_size() in utilities/data.py. When a batch is an Iterable or Mapping and not a Tensor or str, each item of the list (or each value in the dictionary) should be presumed to be of length batch_size. Instead, the current implementation recurses into each item, but there is no reason why the potentially heterogeneous contents of one training example must themselves again be of the same length. More simply: yield from _extract_batch_size(sample) could be replaced with yield len(sample).
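A simplified sketch of what that change might look like (this is not the actual Lightning source, just an illustration of the proposed behaviour):

from collections.abc import Iterable, Mapping
import torch

def _extract_batch_size(batch):
    """Simplified illustration of the proposal, not the real Lightning helper."""
    if isinstance(batch, torch.Tensor):
        yield batch.size(0)
    elif isinstance(batch, str):
        yield len(batch)
    elif isinstance(batch, Mapping):
        # Proposed: treat each value as a collection of length batch_size
        # instead of recursing into it.
        for sample in batch.values():
            yield len(sample)
    elif isinstance(batch, Iterable):
        for sample in batch:
            yield len(sample)
    else:
        yield None  # size could not be inferred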

@ninginthecloud
Contributor

Hi @rohitgr7, another error that can occur in the current extract_batch_size() implementation is when the user's batch input is

batch = {'a': torch.tensor([1, 2, 3, 4, 5]), 'b': [torch.tensor(1), torch.tensor(2)]}

It raises an IndexError when .size(0) is called on the zero-dimensional tensor torch.tensor(1).
I think users can pass very flexible values for batch, and this is something we can't control. Could we allow users to define their own batch_size for metric aggregation?
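A quick, standalone reproduction of just the failing call (outside of Lightning):

import torch

scalar = torch.tensor(1)   # zero-dimensional tensor
print(scalar.dim())        # 0 -> there is no dimension 0 to query
try:
    scalar.size(0)
except IndexError as err:
    print(f"IndexError: {err}")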
