Skip to content

Commit 59d0c65

Browse files
dalek-whopre-commit-ci[bot]carmoccaBorda
authored
Add dataclass support to apply_to_collection (#7935)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com> Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
1 parent cdd01f3 commit 59d0c65

File tree

3 files changed

+62
-0
lines changed

3 files changed

+62
-0
lines changed

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
99

1010
### Added
1111

12+
- Added `dataclass` support for `pytorch_lightning.utilities.apply_to_collection` ([#7935](https://github.com/PyTorchLightning/pytorch-lightning/pull/7935))
13+
14+
1215
- Added support to `LightningModule.to_torchscript` for saving to custom filesystems with fsspec ([#7617](https://github.com/PyTorchLightning/pytorch-lightning/pull/7617))
1316

1417

pytorch_lightning/utilities/apply_func.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
import dataclasses
1415
import operator
1516
from abc import ABC
1617
from collections import OrderedDict
@@ -60,6 +61,11 @@ def _is_namedtuple(obj: object) -> bool:
6061
return isinstance(obj, tuple) and hasattr(obj, "_asdict") and hasattr(obj, "_fields")
6162

6263

64+
def _is_dataclass_instance(obj: object) -> bool:
    """Return ``True`` only when ``obj`` is an instance of a dataclass.

    ``dataclasses.is_dataclass`` is true for both a dataclass type and its
    instances, so types are ruled out explicitly with the ``isinstance`` check.

    Args:
        obj: any object to inspect.

    Returns:
        ``True`` if ``obj`` is a dataclass instance, ``False`` if it is a
        dataclass type or not a dataclass at all.
    """
    # https://docs.python.org/3/library/dataclasses.html#module-level-decorators-classes-and-functions
    return dataclasses.is_dataclass(obj) and not isinstance(obj, type)
67+
68+
6369
def apply_to_collection(
6470
data: Any,
6571
dtype: Union[type, tuple],
@@ -110,6 +116,14 @@ def apply_to_collection(
110116
out.append(v)
111117
return elem_type(*out) if is_namedtuple else elem_type(out)
112118

119+
if _is_dataclass_instance(data):
120+
out = dict()
121+
for field in data.__dataclass_fields__:
122+
v = apply_to_collection(getattr(data, field), dtype, function, *args, wrong_dtype=wrong_dtype, **kwargs)
123+
if include_none or v is not None:
124+
out[field] = v
125+
return elem_type(**out)
126+
113127
# data is neither of dtype, nor a collection
114128
return data
115129

tests/utilities/test_apply_func.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,10 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
import dataclasses
1415
import numbers
1516
from collections import namedtuple, OrderedDict
17+
from typing import List
1618

1719
import numpy as np
1820
import pytest
@@ -24,6 +26,17 @@
2426
def test_recursive_application_to_collection():
2527
ntc = namedtuple('Foo', ['bar'])
2628

29+
@dataclasses.dataclass
30+
class Feature:
31+
input_ids: torch.Tensor
32+
segment_ids: np.ndarray
33+
34+
@dataclasses.dataclass
35+
class ModelExample:
36+
example_ids: List[str]
37+
feature: Feature
38+
label: torch.Tensor
39+
2740
to_reduce = {
2841
'a': torch.tensor([1.]), # Tensor
2942
'b': [torch.tensor([2.])], # list
@@ -32,6 +45,12 @@ def test_recursive_application_to_collection():
3245
'e': np.array([10.]), # numpy array
3346
'f': 'this_is_a_dummy_str', # string
3447
'g': 12., # number
48+
'h': Feature(input_ids=torch.tensor([1., 2., 3.]), segment_ids=np.array([4., 5., 6.])), # dataclass
49+
'i': ModelExample(
50+
example_ids=['i-1', 'i-2', 'i-3'],
51+
feature=Feature(input_ids=torch.tensor([1., 2., 3.]), segment_ids=np.array([4., 5., 6.])),
52+
label=torch.tensor([7., 8., 9.])
53+
) # nested dataclass
3554
}
3655

3756
expected_result = {
@@ -42,6 +61,12 @@ def test_recursive_application_to_collection():
4261
'e': np.array([20.]),
4362
'f': 'this_is_a_dummy_str',
4463
'g': 24.,
64+
'h': Feature(input_ids=torch.tensor([2., 4., 6.]), segment_ids=np.array([8., 10., 12.])),
65+
'i': ModelExample(
66+
example_ids=['i-1', 'i-2', 'i-3'],
67+
feature=Feature(input_ids=torch.tensor([2., 4., 6.]), segment_ids=np.array([8., 10., 12.])),
68+
label=torch.tensor([14., 16., 18.])
69+
)
4570
}
4671

4772
reduced = apply_to_collection(to_reduce, (torch.Tensor, numbers.Number, np.ndarray), lambda x: x * 2)
@@ -78,6 +103,26 @@ def test_recursive_application_to_collection():
78103
assert isinstance(reduced['g'], numbers.Number), 'Reduction of a number should result in a number'
79104
assert reduced['g'] == expected_result['g'], 'Reduction of a number did not yield the desired result'
80105

106+
assert dataclasses.is_dataclass(reduced['h']) and not isinstance(reduced['h'], type), \
107+
'Reduction of a dataclass should result in a dataclass'
108+
assert torch.allclose(reduced['h'].input_ids, expected_result['h'].input_ids), \
109+
'Reduction of a dataclass did not yield the desired result'
110+
assert np.allclose(reduced['h'].segment_ids, expected_result['h'].segment_ids), \
111+
'Reduction of a dataclass did not yield the desired result'
112+
113+
assert dataclasses.is_dataclass(reduced['i']) and not isinstance(reduced['i'], type), \
114+
'Reduction of a dataclass should result in a dataclass'
115+
assert dataclasses.is_dataclass(reduced['i'].feature) and not isinstance(reduced['i'].feature, type), \
116+
'Reduction of a nested dataclass should result in a nested dataclass'
117+
assert reduced['i'].example_ids == expected_result['i'].example_ids, \
118+
'Reduction of a nested dataclass did not yield the desired result'
119+
assert torch.allclose(reduced['i'].label, expected_result['i'].label), \
120+
'Reduction of a nested dataclass did not yield the desired result'
121+
assert torch.allclose(reduced['i'].feature.input_ids, expected_result['i'].feature.input_ids), \
122+
'Reduction of a nested dataclass did not yield the desired result'
123+
assert np.allclose(reduced['i'].feature.segment_ids, expected_result['i'].feature.segment_ids), \
124+
'Reduction of a nested dataclass did not yield the desired result'
125+
81126
# mapping support
82127
reduced = apply_to_collection({'a': 1, 'b': 2}, int, lambda x: str(x))
83128
assert reduced == {'a': '1', 'b': '2'}

0 commit comments

Comments
 (0)