
Proposed refactor
The function `extract_batch_size()` should be easy for the user to redefine.
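
For context, `extract_batch_size` is currently a free function in `pytorch_lightning.utilities.data` that tries to infer the batch size by inspecting the tensors inside the batch. A minimal sketch of its behavior on an ordinary batch (illustrative values; exact internals vary by version):

```python
import torch
from pytorch_lightning.utilities.data import extract_batch_size

# For a plain tensor batch, the size is inferred from the leading dimension.
batch = torch.zeros(32, 3, 224, 224)
print(extract_batch_size(batch))  # -> 32
```

For collections without a common leading dimension (like pytorch-geometric batches), this inference becomes ambiguous, hence the warning below.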
Motivation
So I'm using pytorch-lightning and pytorch-geometric together. In pytorch-geometric, graphs are batched together in a peculiar way: there is no batch dimension (more details in their documentation). Using such batches in pytorch-lightning causes the following warning:
UserWarning: Trying to infer the `batch_size` from an ambiguous collection. The batch size we found is 615. To avoid any miscalculations, use `self.log(..., batch_size=batch_size)`.
But I know the batch size of my batches exactly! It's just that the tensors don't have a fixed batch dimension.
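
For reference, the per-call workaround suggested by the warning would look something like the sketch below (`compute_loss` is a stand-in for the model's actual loss computation, and `num_graphs` is the pytorch-geometric `Batch` attribute giving the number of graphs). It works, but it has to be repeated at every `self.log` call site:

```python
import pytorch_lightning as pl

class GraphModel(pl.LightningModule):
    def training_step(self, batch, batch_idx):
        loss = self.compute_loss(batch)  # stand-in for the actual loss computation
        # Pass the known batch size explicitly so Lightning does not try to
        # infer it from the flat PyG tensors.
        self.log("train_loss", loss, batch_size=batch.num_graphs)
        return loss
```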
So, to fix this warning in one place, I monkey-patch the code that retrieves the batch size:
```python
import pytorch_lightning as pl

# Monkey-patch `extract_batch_size` to not raise a warning on weird tensor sizes
def extract_bs(self, batch):
    try:
        # Modification is here: trust an explicit "batch_size" entry if the
        # batch provides one, otherwise fall back to Lightning's inference.
        if "batch_size" in batch:
            batch_size = batch["batch_size"]
        else:
            batch_size = pl.utilities.data.extract_batch_size(batch)
    except RecursionError:
        batch_size = 1
    self.batch_size = batch_size
    return batch_size

pl.trainer.connectors.logger_connector.result.ResultCollection.extract_batch_size = extract_bs
```
Pitch
The monkey-patch approach feels dirty and dangerous, but I couldn't find another way to solve this issue, i.e. to specify a user-defined way of extracting the batch size from the batch.
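
What I have in mind is something like an overridable hook, e.g. on the `LightningModule`. The hook name below is invented for illustration; this is a sketch of a possible API, not an existing one:

```python
import pytorch_lightning as pl

class MyGraphModel(pl.LightningModule):
    # Hypothetical hook: if Lightning exposed something like this, the logger
    # connector could call it instead of guessing from tensor shapes.
    def extract_batch_size(self, batch):
        # PyG batches know exactly how many graphs they contain.
        return batch.num_graphs
```

Lightning would fall back to the current inference logic whenever the hook is not overridden.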