Make extract_batch_size() easier to redefine #10576
Comments
We use batch_size to accumulate the metrics at the epoch level when they are logged.
@rohitgr7 Thanks for your answer! Yes, I tried to add the batch size. I can ignore the warning, but since my tensor sizes change every batch, my console is literally flooded with warnings. ^^
Here is the open PR for it: #10408
I think the underlying problem can be addressed for lists and dictionaries by directly modifying _extract_batch_size() in utilities/data.py. When a batch is an Iterable or Mapping (and not a Tensor or str), each item of the list (or each value in the dictionary) should be presumed to have length batch_size. The current implementation instead performs a recursive check, but there is no reason why the potentially heterogeneous contents of one training example must themselves all be of the same length. More simply:
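A minimal sketch of this idea in plain Python (this is not Lightning's actual implementation, and the helper name is hypothetical): batch size is inferred from the top level of the batch only, never by recursing into the elements of a single training example.

```python
from collections.abc import Iterable, Mapping


def first_level_batch_size(batch):
    """Hypothetical sketch: infer batch_size from the top level of the
    batch only, without recursing into a training example's contents."""
    if hasattr(batch, "size"):       # tensor-like: first dimension is the batch dim
        return batch.size(0)
    if isinstance(batch, str):
        return len(batch)
    if isinstance(batch, Mapping):   # every value is presumed to have length batch_size
        return len(next(iter(batch.values())))
    if isinstance(batch, Iterable):  # every item is presumed to have length batch_size
        return len(next(iter(batch)))
    raise ValueError(f"Could not infer batch size from {type(batch)!r}")
```

With this rule, a batch like `{"x": features, "y": labels}` reports `len(features)` directly, even if a single example inside it is a ragged structure.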
Hi @rohitgr7, another error can occur in the current implementation: it raises an IndexError when batch.size(0) is called on a 0-dimensional tensor such as torch.tensor(1).
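To illustrate, here is a minimal guard (not Lightning's actual fix; safe_batch_size is a hypothetical helper, and treating a scalar as one example is an assumption):

```python
import torch


def safe_batch_size(batch: torch.Tensor) -> int:
    # torch.tensor(1) is 0-dimensional, so batch.size(0) raises IndexError
    # ("dimension specified as 0 but tensor has no dimensions").
    if batch.dim() == 0:
        return 1  # assumption: treat a scalar tensor as a single example
    return batch.size(0)
```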
Proposed refactor

The function extract_batch_size() should be easy for the user to redefine.

Motivation
So I'm using pytorch-lightning and pytorch-geometric together. In pytorch-geometric, graphs are batched together in a peculiar way (there is no batch dimension; more details in their documentation). Using this in pytorch-lightning causes the following warning:

But I know exactly the batch size of my batches! It's just that the tensors don't have a fixed batch dimension.
So to fix this warning, I monkey-patch the code to retrieve the size of the batch:
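The original snippet was not preserved in this copy of the issue; a reconstruction along these lines conveys the idea. The extractor function itself is hypothetical; extract_batch_size lives in pytorch_lightning.utilities.data, and pytorch-geometric Batch objects expose num_graphs.

```python
def my_extract_batch_size(batch):
    # pytorch-geometric Batch objects know how many graphs they contain,
    # even though their tensors carry no explicit batch dimension.
    if hasattr(batch, "num_graphs"):
        return batch.num_graphs
    return len(batch)  # naive fallback for ordinary batches

# The monkey-patch itself (the fragile part discussed in the Pitch):
# import pytorch_lightning.utilities.data as pl_data
# pl_data.extract_batch_size = my_extract_batch_size
```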
Pitch
The monkey-patch approach seems dirty and dangerous, but I couldn't find another way to solve this issue (i.e., specifying a user-defined way to extract the batch size from the batch).