Skip to content

Commit

Permalink
Add missing Maskformer dataclass decorator, add dataclass check in Mo…
Browse files Browse the repository at this point in the history
…delOutput for subclasses (huggingface#25638)

* Add @DataClass to MaskFormerPixelDecoderOutput

* Add dataclass check if subclass of ModelOutout

* Use unittest assertRaises rather than pytest per contribution doc

* Update src/transformers/utils/generic.py per suggested change

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

---------

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
  • Loading branch information
2 people authored and blbadger committed Nov 8, 2023
1 parent 19863c0 commit ce547c9
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 1 deletion.
1 change: 1 addition & 0 deletions src/transformers/models/maskformer/modeling_maskformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,7 @@ class MaskFormerPixelLevelModuleOutput(ModelOutput):
decoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None


@dataclass
class MaskFormerPixelDecoderOutput(ModelOutput):
"""
MaskFormer's pixel decoder module output, practically a Feature Pyramid Network. It returns the last hidden state
Expand Down
21 changes: 20 additions & 1 deletion src/transformers/utils/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from collections import OrderedDict, UserDict
from collections.abc import MutableMapping
from contextlib import ExitStack, contextmanager
from dataclasses import fields
from dataclasses import fields, is_dataclass
from enum import Enum
from typing import Any, ContextManager, List, Tuple

Expand Down Expand Up @@ -314,7 +314,26 @@ def __init_subclass__(cls) -> None:
lambda values, context: cls(**torch.utils._pytree._dict_unflatten(values, context)),
)

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

# Subclasses of ModelOutput must use the @dataclass decorator
# This check is done in __init__ because the @dataclass decorator operates after __init_subclass__
# issubclass() would return True for issubclass(ModelOutput, ModelOutput) when False is needed
# Just need to check that the current class is not ModelOutput
is_modeloutput_subclass = self.__class__ != ModelOutput

if is_modeloutput_subclass and not is_dataclass(self):
raise TypeError(
f"{self.__module__}.{self.__class__.__name__} is not a dataclasss."
" This is a subclass of ModelOutput and so must use the @dataclass decorator."
)

def __post_init__(self):
"""Check the ModelOutput dataclass.
Only occurs if @dataclass decorator has been used.
"""
class_fields = fields(self)

# Safety and consistency checks
Expand Down
20 changes: 20 additions & 0 deletions tests/utils/test_model_output.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,3 +143,23 @@ def test_torch_pytree(self):

unflattened_x = torch.utils._pytree.tree_unflatten(actual_flat_outs, actual_tree_spec)
self.assertEqual(x, unflattened_x)


class ModelOutputTestNoDataclass(ModelOutput):
"""Invalid test subclass of ModelOutput where @dataclass decorator is not used"""

a: float
b: Optional[float] = None
c: Optional[float] = None


class ModelOutputSubclassTester(unittest.TestCase):
def test_direct_model_output(self):
# Check that direct usage of ModelOutput instantiates without errors
ModelOutput({"a": 1.1})

def test_subclass_no_dataclass(self):
# Check that a subclass of ModelOutput without @dataclass is invalid
# A valid subclass is inherently tested other unit tests above.
with self.assertRaises(TypeError):
ModelOutputTestNoDataclass(a=1.1, b=2.2, c=3.3)

0 comments on commit ce547c9

Please sign in to comment.