Skip to content

Commit

Permalink
Fix issue with no-init dataclass fields in move_to_device (#9963)
Browse files Browse the repository at this point in the history
Co-authored-by: ronif <ronif@users.noreply.github.com>
Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
  • Loading branch information
3 people authored Oct 17, 2021
1 parent e5dfdf3 commit 7b4df7b
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 12 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -544,6 +544,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed use of `LightningCLI` in computer_vision_fine_tuning.py example ([#9934](https://github.com/PyTorchLightning/pytorch-lightning/pull/9934))


- Fixed issue with non-init dataclass fields in `apply_to_collection` ([#9963](https://github.com/PyTorchLightning/pytorch-lightning/issues/9963))


## [1.4.9] - 2021-09-30

- Fixed `lr_find` to generate same results on multiple calls ([#9704](https://github.com/PyTorchLightning/pytorch-lightning/pull/9704))
Expand Down
25 changes: 13 additions & 12 deletions pytorch_lightning/utilities/apply_func.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,18 +118,19 @@ def apply_to_collection(

if _is_dataclass_instance(data):
out_dict = {}
for field in data.__dataclass_fields__:
v = apply_to_collection(
getattr(data, field),
dtype,
function,
*args,
wrong_dtype=wrong_dtype,
include_none=include_none,
**kwargs,
)
if include_none or v is not None:
out_dict[field] = v
for field in dataclasses.fields(data):
if field.init:
v = apply_to_collection(
getattr(data, field.name),
dtype,
function,
*args,
wrong_dtype=wrong_dtype,
include_none=include_none,
**kwargs,
)
if include_none or v is not None:
out_dict[field.name] = v
return elem_type(**out_dict)

# data is neither of dtype, nor a collection
Expand Down
4 changes: 4 additions & 0 deletions tests/utilities/test_apply_func.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,10 @@ class ModelExample:
example_ids: List[str]
feature: Feature
label: torch.Tensor
some_constant: int = dataclasses.field(init=False)

def __post_init__(self):
self.some_constant = 7

to_reduce = {
"a": torch.tensor([1.0]), # Tensor
Expand Down

0 comments on commit 7b4df7b

Please sign in to comment.