Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add dataclass support to _extract_batch_size #12573

Merged
merged 6 commits into from
Apr 15, 2022
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Support `strategy` argument being case insensitive ([#12528](https://github.com/PyTorchLightning/pytorch-lightning/pull/12528))


- Added dataclass support to `_extract_batch_size` ([#12573](https://github.com/PyTorchLightning/pytorch-lightning/pull/12573))


- Changed checkpoints save path in the case of one logger and user-provided weights_save_path from `weights_save_path/name/version/checkpoints` to `weights_save_path/checkpoints` ([#12372](https://github.com/PyTorchLightning/pytorch-lightning/pull/12372))


Expand Down
5 changes: 5 additions & 0 deletions pytorch_lightning/utilities/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import inspect
import os
from contextlib import contextmanager
from dataclasses import fields
from functools import partial
from itertools import chain
from typing import Any, Callable, Dict, Generator, Iterable, Mapping, Optional, Set, Type, Union
Expand All @@ -25,6 +26,7 @@
import pytorch_lightning as pl
from pytorch_lightning.overrides.distributed import IndexBatchSamplerWrapper
from pytorch_lightning.trainer.states import RunningStage
from pytorch_lightning.utilities.apply_func import _is_dataclass_instance
from pytorch_lightning.utilities.auto_restart import CaptureIterableDataset, CaptureMapDataset, FastForwardSampler
from pytorch_lightning.utilities.enums import _FaultTolerantMode
from pytorch_lightning.utilities.exceptions import MisconfigurationException
Expand All @@ -49,6 +51,9 @@ def _extract_batch_size(batch: BType) -> Generator[int, None, None]:

for sample in batch:
yield from _extract_batch_size(sample)
elif _is_dataclass_instance(batch):
for field in fields(batch):
yield from _extract_batch_size(getattr(batch, field.name))
else:
yield None

Expand Down
14 changes: 14 additions & 0 deletions tests/utilities/test_data.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
from dataclasses import dataclass

import pytest
import torch
from torch import Tensor
from torch.utils.data.dataloader import DataLoader

from pytorch_lightning import Trainer
Expand Down Expand Up @@ -36,6 +39,11 @@ def _check_error_raised(data):
with pytest.raises(MisconfigurationException, match="We could not infer the batch_size"):
extract_batch_size(batch)

@dataclass
class CustomDataclass:
a: Tensor
b: Tensor

# Warning not raised
batch = torch.zeros(11, 10, 9, 8)
_check_warning_not_raised(batch, 11)
Expand All @@ -46,13 +54,19 @@ def _check_error_raised(data):
batch = [torch.zeros(11, 10)]
_check_warning_not_raised(batch, 11)

batch = CustomDataclass(torch.zeros(11, 10), torch.zeros(11, 10))
_check_warning_not_raised(batch, 11)

batch = {"test": [{"test": [torch.zeros(11, 10)]}]}
_check_warning_not_raised(batch, 11)

# Warning raised
batch = {"a": [torch.tensor(1), torch.tensor(2)], "b": torch.tensor([1, 2, 3, 4])}
_check_warning_raised(batch, 1)

batch = CustomDataclass(torch.zeros(11, 10), torch.zeros(1))
_check_warning_raised(batch, 11)

batch = {"test": [{"test": [torch.zeros(11, 10), torch.zeros(10, 10)]}]}
_check_warning_raised(batch, 11)

Expand Down