Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

- Added `dataclass` support for `pytorch_lightning.utilities.apply_to_collection` ([#7935](https://github.com/PyTorchLightning/pytorch-lightning/pull/7935))


- Added support to `LightningModule.to_torchscript` for saving to custom filesystems with fsspec ([#7617](https://github.com/PyTorchLightning/pytorch-lightning/pull/7617))


Expand Down
14 changes: 14 additions & 0 deletions pytorch_lightning/utilities/apply_func.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import dataclasses
import operator
from abc import ABC
from collections import OrderedDict
Expand Down Expand Up @@ -60,6 +61,11 @@ def _is_namedtuple(obj: object) -> bool:
return isinstance(obj, tuple) and hasattr(obj, "_asdict") and hasattr(obj, "_fields")


def _is_dataclass_instance(obj: object) -> bool:
    """Return ``True`` if ``obj`` is an instance of a dataclass, not a dataclass type itself.

    ``dataclasses.is_dataclass`` is true for both a dataclass class and its
    instances, so class objects are excluded explicitly with the ``type`` check.

    Args:
        obj: any object to inspect.

    Returns:
        ``True`` only for dataclass instances; ``False`` for dataclass classes
        and for every non-dataclass object.
    """
    # https://docs.python.org/3/library/dataclasses.html#module-level-decorators-classes-and-functions
    return dataclasses.is_dataclass(obj) and not isinstance(obj, type)


def apply_to_collection(
data: Any,
dtype: Union[type, tuple],
Expand Down Expand Up @@ -110,6 +116,14 @@ def apply_to_collection(
out.append(v)
return elem_type(*out) if is_namedtuple else elem_type(out)

if _is_dataclass_instance(data):
out = dict()
for field in data.__dataclass_fields__:
v = apply_to_collection(getattr(data, field), dtype, function, *args, wrong_dtype=wrong_dtype, **kwargs)
if include_none or v is not None:
out[field] = v
return elem_type(**out)

# data is neither of dtype, nor a collection
return data

Expand Down
45 changes: 45 additions & 0 deletions tests/utilities/test_apply_func.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import dataclasses
import numbers
from collections import namedtuple, OrderedDict
from typing import List

import numpy as np
import pytest
Expand All @@ -24,6 +26,17 @@
def test_recursive_application_to_collection():
ntc = namedtuple('Foo', ['bar'])

@dataclasses.dataclass
class Feature:
input_ids: torch.Tensor
segment_ids: np.ndarray

@dataclasses.dataclass
class ModelExample:
example_ids: List[str]
feature: Feature
label: torch.Tensor

to_reduce = {
'a': torch.tensor([1.]), # Tensor
'b': [torch.tensor([2.])], # list
Expand All @@ -32,6 +45,12 @@ def test_recursive_application_to_collection():
'e': np.array([10.]), # numpy array
'f': 'this_is_a_dummy_str', # string
'g': 12., # number
'h': Feature(input_ids=torch.tensor([1., 2., 3.]), segment_ids=np.array([4., 5., 6.])), # dataclass
'i': ModelExample(
example_ids=['i-1', 'i-2', 'i-3'],
feature=Feature(input_ids=torch.tensor([1., 2., 3.]), segment_ids=np.array([4., 5., 6.])),
label=torch.tensor([7., 8., 9.])
) # nested dataclass
}

expected_result = {
Expand All @@ -42,6 +61,12 @@ def test_recursive_application_to_collection():
'e': np.array([20.]),
'f': 'this_is_a_dummy_str',
'g': 24.,
'h': Feature(input_ids=torch.tensor([2., 4., 6.]), segment_ids=np.array([8., 10., 12.])),
'i': ModelExample(
example_ids=['i-1', 'i-2', 'i-3'],
feature=Feature(input_ids=torch.tensor([2., 4., 6.]), segment_ids=np.array([8., 10., 12.])),
label=torch.tensor([14., 16., 18.])
)
}

reduced = apply_to_collection(to_reduce, (torch.Tensor, numbers.Number, np.ndarray), lambda x: x * 2)
Expand Down Expand Up @@ -78,6 +103,26 @@ def test_recursive_application_to_collection():
assert isinstance(reduced['g'], numbers.Number), 'Reduction of a number should result in a number'
assert reduced['g'] == expected_result['g'], 'Reduction of a number did not yield the desired result'

assert dataclasses.is_dataclass(reduced['h']) and not isinstance(reduced['h'], type), \
'Reduction of a dataclass should result in a dataclass'
assert torch.allclose(reduced['h'].input_ids, expected_result['h'].input_ids), \
'Reduction of a dataclass did not yield the desired result'
assert np.allclose(reduced['h'].segment_ids, expected_result['h'].segment_ids), \
'Reduction of a dataclass did not yield the desired result'

assert dataclasses.is_dataclass(reduced['i']) and not isinstance(reduced['i'], type), \
'Reduction of a dataclass should result in a dataclass'
assert dataclasses.is_dataclass(reduced['i'].feature) and not isinstance(reduced['i'].feature, type), \
'Reduction of a nested dataclass should result in a nested dataclass'
assert reduced['i'].example_ids == expected_result['i'].example_ids, \
'Reduction of a nested dataclass did not yield the desired result'
assert torch.allclose(reduced['i'].label, expected_result['i'].label), \
'Reduction of a nested dataclass did not yield the desired result'
assert torch.allclose(reduced['i'].feature.input_ids, expected_result['i'].feature.input_ids), \
'Reduction of a nested dataclass did not yield the desired result'
assert np.allclose(reduced['i'].feature.segment_ids, expected_result['i'].feature.segment_ids), \
'Reduction of a nested dataclass did not yield the desired result'

# mapping support
reduced = apply_to_collection({'a': 1, 'b': 2}, int, lambda x: str(x))
assert reduced == {'a': '1', 'b': '2'}
Expand Down