From 4f55e3623ed34f95042f37b2388c823ac5f7c618 Mon Sep 17 00:00:00 2001 From: dalek-who <839991169@qq.com> Date: Fri, 11 Jun 2021 11:00:47 +0800 Subject: [PATCH 01/11] add dataclass support on apply_to_collection --- pytorch_lightning/utilities/apply_func.py | 10 +++++ tests/utilities/test_apply_func.py | 47 +++++++++++++++++++++++ 2 files changed, 57 insertions(+) diff --git a/pytorch_lightning/utilities/apply_func.py b/pytorch_lightning/utilities/apply_func.py index 2f46ff05699df..b60370bfcf730 100644 --- a/pytorch_lightning/utilities/apply_func.py +++ b/pytorch_lightning/utilities/apply_func.py @@ -18,6 +18,8 @@ from copy import copy from functools import partial from typing import Any, Callable, Optional, Union +import dataclasses + import numpy as np import torch @@ -110,6 +112,14 @@ def apply_to_collection( out.append(v) return elem_type(*out) if is_namedtuple else elem_type(out) + if dataclasses.is_dataclass(data) and not isinstance(data, type): + out = [] + for field in data.__dataclass_fields__: + v = apply_to_collection(getattr(data, field), dtype, function, *args, wrong_dtype=wrong_dtype, **kwargs) + if include_none or v is not None: + out.append((field, v)) + return elem_type(**OrderedDict(out)) + # data is neither of dtype, nor a collection return data diff --git a/tests/utilities/test_apply_func.py b/tests/utilities/test_apply_func.py index 2457cf998c2cd..31cd3c8e6a5de 100644 --- a/tests/utilities/test_apply_func.py +++ b/tests/utilities/test_apply_func.py @@ -11,8 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+from typing import Union, List, Dict, Tuple, Optional import numbers from collections import namedtuple, OrderedDict +import dataclasses import numpy as np import pytest @@ -24,6 +26,17 @@ def test_recursive_application_to_collection(): ntc = namedtuple('Foo', ['bar']) + @dataclasses.dataclass + class Feature: + input_ids: torch.Tensor + segment_ids: np.ndarray + + @dataclasses.dataclass + class ModelExample: + example_ids: List[str] + feature: Feature + label: torch.Tensor + to_reduce = { 'a': torch.tensor([1.]), # Tensor 'b': [torch.tensor([2.])], # list @@ -32,6 +45,13 @@ def test_recursive_application_to_collection(): 'e': np.array([10.]), # numpy array 'f': 'this_is_a_dummy_str', # string 'g': 12., # number + 'h': Feature( + input_ids=torch.tensor([1., 2., 3.]), segment_ids=np.array([4., 5., 6.])), # dataclass + "i": ModelExample( + example_ids=["i-1", "i-2", "i-3"], + feature=Feature(input_ids=torch.tensor([1., 2., 3.]), segment_ids=np.array([4., 5., 6.])), + label=torch.tensor([7., 8., 9.]) + ) # nested dataclass } expected_result = { @@ -42,6 +62,13 @@ def test_recursive_application_to_collection(): 'e': np.array([20.]), 'f': 'this_is_a_dummy_str', 'g': 24., + 'h': Feature( + input_ids=torch.tensor([2., 4., 6.]), segment_ids=np.array([8., 10., 12.])), + "i": ModelExample( + example_ids=["i-1", "i-2", "i-3"], + feature=Feature(input_ids=torch.tensor([2., 4., 6.]), segment_ids=np.array([8., 10., 12.])), + label=torch.tensor([14., 16., 18.]) + ) } reduced = apply_to_collection(to_reduce, (torch.Tensor, numbers.Number, np.ndarray), lambda x: x * 2) @@ -78,6 +105,26 @@ def test_recursive_application_to_collection(): assert isinstance(reduced['g'], numbers.Number), 'Reduction of a number should result in a number' assert reduced['g'] == expected_result['g'], 'Reduction of a number did not yield the desired result' + assert dataclasses.is_dataclass(reduced['h']) and not isinstance(reduced['h'], type), \ + 'Reduction of a dataclass should result in a dataclass' + 
assert torch.allclose(reduced['h'].input_ids, expected_result['h'].input_ids), \ + 'Reduction of a dataclass did not yield the desired result' + assert np.allclose(reduced['h'].segment_ids, expected_result['h'].segment_ids), \ + 'Reduction of a dataclass did not yield the desired result' + + assert dataclasses.is_dataclass(reduced['i']) and not isinstance(reduced['i'], type), \ + 'Reduction of a dataclass should result in a dataclass' + assert dataclasses.is_dataclass(reduced['i'].feature) and not isinstance(reduced['i'].feature, type), \ + 'Reduction of a nested dataclass should result in a nested dataclass' + assert reduced['i'].example_ids == expected_result['i'].example_ids, \ + 'Reduction of a nested dataclass did not yield the desired result' + assert torch.allclose(reduced['i'].label, expected_result['i'].label), \ + 'Reduction of a nested dataclass did not yield the desired result' + assert torch.allclose(reduced['i'].feature.input_ids, expected_result['i'].feature.input_ids), \ + 'Reduction of a dataclass did not yield the desired result' + assert np.allclose(reduced['i'].feature.segment_ids, expected_result['i'].feature.segment_ids), \ + 'Reduction of a dataclass did not yield the desired result' + # mapping support reduced = apply_to_collection({'a': 1, 'b': 2}, int, lambda x: str(x)) assert reduced == {'a': '1', 'b': '2'} From bc5bfd0207d8daeeeb29567f67f58debba60e7d8 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 11 Jun 2021 03:47:00 +0000 Subject: [PATCH 02/11] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- pytorch_lightning/utilities/apply_func.py | 3 +-- tests/utilities/test_apply_func.py | 10 ++++------ 2 files changed, 5 insertions(+), 8 deletions(-) diff --git a/pytorch_lightning/utilities/apply_func.py b/pytorch_lightning/utilities/apply_func.py index b60370bfcf730..e3fcfbe737055 100644 --- 
a/pytorch_lightning/utilities/apply_func.py +++ b/pytorch_lightning/utilities/apply_func.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import dataclasses import operator from abc import ABC from collections import OrderedDict @@ -18,8 +19,6 @@ from copy import copy from functools import partial from typing import Any, Callable, Optional, Union -import dataclasses - import numpy as np import torch diff --git a/tests/utilities/test_apply_func.py b/tests/utilities/test_apply_func.py index 31cd3c8e6a5de..ec857d7f407cf 100644 --- a/tests/utilities/test_apply_func.py +++ b/tests/utilities/test_apply_func.py @@ -11,10 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Union, List, Dict, Tuple, Optional +import dataclasses import numbers from collections import namedtuple, OrderedDict -import dataclasses +from typing import Dict, List, Optional, Tuple, Union import numpy as np import pytest @@ -45,8 +45,7 @@ class ModelExample: 'e': np.array([10.]), # numpy array 'f': 'this_is_a_dummy_str', # string 'g': 12., # number - 'h': Feature( - input_ids=torch.tensor([1., 2., 3.]), segment_ids=np.array([4., 5., 6.])), # dataclass + 'h': Feature(input_ids=torch.tensor([1., 2., 3.]), segment_ids=np.array([4., 5., 6.])), # dataclass "i": ModelExample( example_ids=["i-1", "i-2", "i-3"], feature=Feature(input_ids=torch.tensor([1., 2., 3.]), segment_ids=np.array([4., 5., 6.])), @@ -62,8 +61,7 @@ class ModelExample: 'e': np.array([20.]), 'f': 'this_is_a_dummy_str', 'g': 24., - 'h': Feature( - input_ids=torch.tensor([2., 4., 6.]), segment_ids=np.array([8., 10., 12.])), + 'h': Feature(input_ids=torch.tensor([2., 4., 6.]), segment_ids=np.array([8., 10., 12.])), "i": ModelExample( 
example_ids=["i-1", "i-2", "i-3"], feature=Feature(input_ids=torch.tensor([2., 4., 6.]), segment_ids=np.array([8., 10., 12.])), From b53932b883c265df57da1a675337785b9e0d0168 Mon Sep 17 00:00:00 2001 From: dalek-who <839991169@qq.com> Date: Fri, 11 Jun 2021 14:50:05 +0800 Subject: [PATCH 03/11] remove unused type-hints --- tests/utilities/test_apply_func.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/utilities/test_apply_func.py b/tests/utilities/test_apply_func.py index 31cd3c8e6a5de..edd4cc82930e6 100644 --- a/tests/utilities/test_apply_func.py +++ b/tests/utilities/test_apply_func.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Union, List, Dict, Tuple, Optional +from typing import List import numbers from collections import namedtuple, OrderedDict import dataclasses From 57bc2abf56e40fbeda1606fce80192578918e08e Mon Sep 17 00:00:00 2001 From: dalek-who <839991169@qq.com> Date: Fri, 11 Jun 2021 14:54:04 +0800 Subject: [PATCH 04/11] update changelog --- CHANGELOG.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 54f9ad91db2b1..bb53a8121788b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). 
### Added +- Add `dataclass` support for `pytorch_lightning.utilities.apply_func` ([#7935](https://github.com/PyTorchLightning/pytorch-lightning/pull/7935)) + + - Added support to `LightningModule.to_torchscript` for saving to custom filesystems with fsspec ([#7617](https://github.com/PyTorchLightning/pytorch-lightning/pull/7617)) From 7760542583b33e65aedecbeb711411fa7e34f891 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 11 Jun 2021 06:59:14 +0000 Subject: [PATCH 05/11] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/utilities/test_apply_func.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/tests/utilities/test_apply_func.py b/tests/utilities/test_apply_func.py index edd4cc82930e6..0f1894a9d530d 100644 --- a/tests/utilities/test_apply_func.py +++ b/tests/utilities/test_apply_func.py @@ -11,10 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import List +import dataclasses import numbers from collections import namedtuple, OrderedDict -import dataclasses +from typing import List import numpy as np import pytest @@ -45,8 +45,7 @@ class ModelExample: 'e': np.array([10.]), # numpy array 'f': 'this_is_a_dummy_str', # string 'g': 12., # number - 'h': Feature( - input_ids=torch.tensor([1., 2., 3.]), segment_ids=np.array([4., 5., 6.])), # dataclass + 'h': Feature(input_ids=torch.tensor([1., 2., 3.]), segment_ids=np.array([4., 5., 6.])), # dataclass "i": ModelExample( example_ids=["i-1", "i-2", "i-3"], feature=Feature(input_ids=torch.tensor([1., 2., 3.]), segment_ids=np.array([4., 5., 6.])), @@ -62,8 +61,7 @@ class ModelExample: 'e': np.array([20.]), 'f': 'this_is_a_dummy_str', 'g': 24., - 'h': Feature( - input_ids=torch.tensor([2., 4., 6.]), segment_ids=np.array([8., 10., 12.])), + 'h': Feature(input_ids=torch.tensor([2., 4., 6.]), segment_ids=np.array([8., 10., 12.])), "i": ModelExample( example_ids=["i-1", "i-2", "i-3"], feature=Feature(input_ids=torch.tensor([2., 4., 6.]), segment_ids=np.array([8., 10., 12.])), From 6ba479539cab0ede47fed1927adba546311fd395 Mon Sep 17 00:00:00 2001 From: dalek-who <839991169@qq.com> Date: Fri, 11 Jun 2021 20:19:57 +0800 Subject: [PATCH 06/11] change OrderedDict to dict --- pytorch_lightning/utilities/apply_func.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/utilities/apply_func.py b/pytorch_lightning/utilities/apply_func.py index e3fcfbe737055..adb7eb3988b38 100644 --- a/pytorch_lightning/utilities/apply_func.py +++ b/pytorch_lightning/utilities/apply_func.py @@ -112,12 +112,12 @@ def apply_to_collection( return elem_type(*out) if is_namedtuple else elem_type(out) if dataclasses.is_dataclass(data) and not isinstance(data, type): - out = [] + out = dict() for field in data.__dataclass_fields__: v = apply_to_collection(getattr(data, field), dtype, function, *args, wrong_dtype=wrong_dtype, **kwargs) if include_none or 
v is not None: - out.append((field, v)) - return elem_type(**OrderedDict(out)) + out[field] = v + return elem_type(**out) # data is neither of dtype, nor a collection return data From 5e3e28935b9b1861eb043b2aec62c08d6ad9ce4a Mon Sep 17 00:00:00 2001 From: dalek-who <839991169@qq.com> Date: Fri, 11 Jun 2021 20:20:34 +0800 Subject: [PATCH 07/11] unify quote style --- tests/utilities/test_apply_func.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/tests/utilities/test_apply_func.py b/tests/utilities/test_apply_func.py index edd4cc82930e6..88a5f05b07435 100644 --- a/tests/utilities/test_apply_func.py +++ b/tests/utilities/test_apply_func.py @@ -47,8 +47,8 @@ class ModelExample: 'g': 12., # number 'h': Feature( input_ids=torch.tensor([1., 2., 3.]), segment_ids=np.array([4., 5., 6.])), # dataclass - "i": ModelExample( - example_ids=["i-1", "i-2", "i-3"], + 'i': ModelExample( + example_ids=['i-1', 'i-2', 'i-3'], feature=Feature(input_ids=torch.tensor([1., 2., 3.]), segment_ids=np.array([4., 5., 6.])), label=torch.tensor([7., 8., 9.]) ) # nested dataclass @@ -64,8 +64,8 @@ class ModelExample: 'g': 24., 'h': Feature( input_ids=torch.tensor([2., 4., 6.]), segment_ids=np.array([8., 10., 12.])), - "i": ModelExample( - example_ids=["i-1", "i-2", "i-3"], + 'i': ModelExample( + example_ids=['i-1', 'i-2', 'i-3'], feature=Feature(input_ids=torch.tensor([2., 4., 6.]), segment_ids=np.array([8., 10., 12.])), label=torch.tensor([14., 16., 18.]) ) @@ -121,9 +121,9 @@ class ModelExample: assert torch.allclose(reduced['i'].label, expected_result['i'].label), \ 'Reduction of a nested dataclass did not yield the desired result' assert torch.allclose(reduced['i'].feature.input_ids, expected_result['i'].feature.input_ids), \ - 'Reduction of a dataclass did not yield the desired result' + 'Reduction of a nested dataclass did not yield the desired result' assert np.allclose(reduced['i'].feature.segment_ids, expected_result['i'].feature.segment_ids), \ - 
'Reduction of a dataclass did not yield the desired result' + 'Reduction of a nested dataclass did not yield the desired result' # mapping support reduced = apply_to_collection({'a': 1, 'b': 2}, int, lambda x: str(x)) From 6507e9e289806fb4c8b117ec9cd9a73f19ce161c Mon Sep 17 00:00:00 2001 From: dalek-who <839991169@qq.com> Date: Fri, 11 Jun 2021 20:21:13 +0800 Subject: [PATCH 08/11] update changelog function --- CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index bb53a8121788b..33ea9dd752730 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,7 +9,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Added -- Add `dataclass` support for `pytorch_lightning.utilities.apply_func` ([#7935](https://github.com/PyTorchLightning/pytorch-lightning/pull/7935)) +- Added `dataclass` support for `pytorch_lightning.utilities.apply_to_collection` ([#7935](https://github.com/PyTorchLightning/pytorch-lightning/pull/7935)) - Added support to `LightningModule.to_torchscript` for saving to custom filesystems with fsspec ([#7617](https://github.com/PyTorchLightning/pytorch-lightning/pull/7617)) From 2213d0a11b17bc469f661850e331f009d2d16ddd Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 11 Jun 2021 12:28:31 +0000 Subject: [PATCH 09/11] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/utilities/test_apply_func.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/tests/utilities/test_apply_func.py b/tests/utilities/test_apply_func.py index 5359faabe51d4..8959a3283d639 100644 --- a/tests/utilities/test_apply_func.py +++ b/tests/utilities/test_apply_func.py @@ -45,8 +45,7 @@ class ModelExample: 'e': np.array([10.]), # numpy array 'f': 'this_is_a_dummy_str', # string 'g': 12., # number - 'h': Feature( - input_ids=torch.tensor([1., 2., 3.]), 
segment_ids=np.array([4., 5., 6.])), # dataclass + 'h': Feature(input_ids=torch.tensor([1., 2., 3.]), segment_ids=np.array([4., 5., 6.])), # dataclass 'i': ModelExample( example_ids=['i-1', 'i-2', 'i-3'], feature=Feature(input_ids=torch.tensor([1., 2., 3.]), segment_ids=np.array([4., 5., 6.])), @@ -62,8 +61,7 @@ class ModelExample: 'e': np.array([20.]), 'f': 'this_is_a_dummy_str', 'g': 24., - 'h': Feature( - input_ids=torch.tensor([2., 4., 6.]), segment_ids=np.array([8., 10., 12.])), + 'h': Feature(input_ids=torch.tensor([2., 4., 6.]), segment_ids=np.array([8., 10., 12.])), 'i': ModelExample( example_ids=['i-1', 'i-2', 'i-3'], feature=Feature(input_ids=torch.tensor([2., 4., 6.]), segment_ids=np.array([8., 10., 12.])), From 85d76601d020061296bc9a85b5fe9ad1427219df Mon Sep 17 00:00:00 2001 From: dalek-who <839991169@qq.com> Date: Sat, 12 Jun 2021 11:13:54 +0800 Subject: [PATCH 10/11] extract _is_dataclass_instance function --- pytorch_lightning/utilities/apply_func.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/pytorch_lightning/utilities/apply_func.py b/pytorch_lightning/utilities/apply_func.py index adb7eb3988b38..59e959d1d0f5c 100644 --- a/pytorch_lightning/utilities/apply_func.py +++ b/pytorch_lightning/utilities/apply_func.py @@ -61,6 +61,11 @@ def _is_namedtuple(obj: object) -> bool: return isinstance(obj, tuple) and hasattr(obj, "_asdict") and hasattr(obj, "_fields") +def _is_dataclass_instance(obj): + # https://docs.python.org/3/library/dataclasses.html#module-level-decorators-classes-and-functions + return dataclasses.is_dataclass(obj) and not isinstance(obj, type) + + def apply_to_collection( data: Any, dtype: Union[type, tuple], @@ -103,6 +108,7 @@ def apply_to_collection( is_namedtuple = _is_namedtuple(data) is_sequence = isinstance(data, Sequence) and not isinstance(data, str) + is_dataclass = _is_dataclass_instance(data) if is_namedtuple or is_sequence: out = [] for d in data: @@ -111,7 +117,7 @@ def apply_to_collection( 
out.append(v) return elem_type(*out) if is_namedtuple else elem_type(out) - if dataclasses.is_dataclass(data) and not isinstance(data, type): + if is_dataclass: out = dict() for field in data.__dataclass_fields__: v = apply_to_collection(getattr(data, field), dtype, function, *args, wrong_dtype=wrong_dtype, **kwargs) From 632ebdb523f80bd870ac967f3467bd2e34046cc7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlos=20Mochol=C3=AD?= Date: Sat, 12 Jun 2021 13:17:59 +0200 Subject: [PATCH 11/11] Apply suggestions from code review Co-authored-by: Jirka Borovec --- pytorch_lightning/utilities/apply_func.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/pytorch_lightning/utilities/apply_func.py b/pytorch_lightning/utilities/apply_func.py index 59e959d1d0f5c..42a694ebad9ba 100644 --- a/pytorch_lightning/utilities/apply_func.py +++ b/pytorch_lightning/utilities/apply_func.py @@ -108,7 +108,6 @@ def apply_to_collection( is_namedtuple = _is_namedtuple(data) is_sequence = isinstance(data, Sequence) and not isinstance(data, str) - is_dataclass = _is_dataclass_instance(data) if is_namedtuple or is_sequence: out = [] for d in data: @@ -117,7 +116,7 @@ def apply_to_collection( out.append(v) return elem_type(*out) if is_namedtuple else elem_type(out) - if is_dataclass: + if _is_dataclass_instance(data): out = dict() for field in data.__dataclass_fields__: v = apply_to_collection(getattr(data, field), dtype, function, *args, wrong_dtype=wrong_dtype, **kwargs)