
Make extract_batch_size() easier to redefine #10576

Closed
@ghost

Description

Proposed refactor

The function extract_batch_size() should be easy for the user to redefine.

Motivation

I'm using pytorch-lightning and pytorch-geometric together. In pytorch-geometric, graphs are batched in a peculiar way: the graphs in a batch are merged into one large graph, so the resulting tensors have no batch dimension (more details in their documentation).

Feeding these batches to pytorch-lightning triggers the following warning:

UserWarning: Trying to infer the `batch_size` from an ambiguous collection. The batch size we found is 615. To avoid any miscalculations, use `self.log(..., batch_size=batch_size)`.

But I know the exact batch size of my batches! It's just that the tensors don't have a fixed batch dimension.
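
For illustration, here is a minimal sketch of where the inferred number comes from (assuming torch_geometric is installed): the node features of all graphs in a batch are stacked along dimension 0, so the first dimension counts nodes, not graphs.

import torch
from torch_geometric.data import Batch, Data

# Two tiny graphs with 3 and 5 nodes and 16 node features each.
g1 = Data(x=torch.randn(3, 16), edge_index=torch.tensor([[0, 1], [1, 2]]))
g2 = Data(x=torch.randn(5, 16), edge_index=torch.tensor([[0, 1], [1, 2]]))

batch = Batch.from_data_list([g1, g2])
print(batch.num_graphs)  # 2 <- the true batch size
print(batch.x.shape)     # torch.Size([8, 16]) <- dim 0 counts nodes, not graphs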

So to silence this warning, I monkey-patch the code that retrieves the batch size:

import pytorch_lightning as pl

# Monkey-patch `extract_batch_size` so Lightning does not warn about
# tensors whose first dimension is not the batch dimension.
def extract_bs(self, batch):
    try:
        # Modification is here: if the batch carries an explicit
        # `batch_size` entry, trust it instead of inferring one.
        if "batch_size" in batch:
            batch_size = batch["batch_size"]
        else:
            # Otherwise fall back to Lightning's own inference.
            batch_size = pl.utilities.data.extract_batch_size(batch)
    except RecursionError:
        batch_size = 1
    self.batch_size = batch_size
    return batch_size

pl.trainer.connectors.logger_connector.result.ResultCollection.extract_batch_size = extract_bs
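
For reference, the workaround the warning itself suggests is to pass the batch size explicitly to every self.log() call. A minimal sketch (the linear layer and loss are made up for illustration; batch.num_graphs is the pytorch-geometric attribute holding the true number of graphs):

import torch
import pytorch_lightning as pl

class GraphModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.lin = torch.nn.Linear(16, 1)

    def training_step(self, batch, batch_idx):
        # batch.x stacks the node features of every graph in the batch.
        loss = self.lin(batch.x).pow(2).mean()  # dummy loss for illustration
        # Pass the true batch size explicitly so Lightning does not infer it.
        self.log("train_loss", loss, batch_size=batch.num_graphs)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)

This works, but it has to be repeated in every self.log() call, which is exactly why a single overridable extraction point would be nicer.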

Pitch

The monkey-patch approach feels dirty and dangerous, but I couldn't find another way to solve this issue, i.e. to specify a user-defined way of extracting the batch size from a batch.
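
One possible shape for such an API (purely hypothetical, this hook does not exist in pytorch-lightning today) would be an overridable method on LightningModule:

import pytorch_lightning as pl

class MyGraphModel(pl.LightningModule):
    # Hypothetical hook (not part of Lightning's API): called whenever
    # Lightning needs a batch size for logging.
    def extract_batch_size(self, batch):
        return batch.num_graphs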
