Skip to content

Commit

Permalink
fix Document.annotation_fields() (#384)
Browse files Browse the repository at this point in the history
  • Loading branch information
ArneBinder authored Dec 8, 2023
1 parent 234f059 commit ac03161
Show file tree
Hide file tree
Showing 3 changed files with 69 additions and 6 deletions.
28 changes: 24 additions & 4 deletions src/pytorch_ie/core/document.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,11 @@ def _get_reference_fields_and_container_types(


def _get_annotation_fields(fields: List[dataclasses.Field]) -> Set[dataclasses.Field]:
return {field for field in fields if typing.get_origin(field.type) is AnnotationLayer}
# this was broken, so we raise an exception for now
# return {f for f in fields if typing.get_origin(f.type) is AnnotationLayer}
raise Exception(
"_get_annotation_fields() is broken, please use Document.annotation_fields() instead"
)


def annotation_field(
Expand Down Expand Up @@ -482,8 +486,24 @@ def fields(cls):
]

@classmethod
def annotation_fields(cls):
return _get_annotation_fields(list(dataclasses.fields(cls)))
def field_types(cls) -> Dict[str, typing.Type]:
result = {}
for f in cls.fields():
# If we got just the string representation of the type, we resolve the whole class.
# But this may be slow, so we only do it if necessary.
if not isinstance(f.type, type):
return typing.get_type_hints(cls)
result[f.name] = f.type
return result

@classmethod
def annotation_fields(cls) -> Set[dataclasses.Field]:
ann_field_types = cls.field_types()
return {
f
for f in cls.fields()
if typing.get_origin(ann_field_types[f.name]) is AnnotationLayer
}

def __getitem__(self, key: str) -> AnnotationLayer:
if key not in self._annotation_fields:
Expand Down Expand Up @@ -583,7 +603,7 @@ def asdict(self):
@classmethod
def fromdict(cls, dct):
fields = dataclasses.fields(cls)
annotation_fields = _get_annotation_fields(fields)
annotation_fields = cls.annotation_fields()

cls_kwargs = {}
for field in fields:
Expand Down
11 changes: 11 additions & 0 deletions tests/core/test_document.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,3 +334,14 @@ class MyDocument(Document):
assert (
str(document.attributes[1]) == "Attribute(annotation=Span(start=6, end=11), label=label)"
)


def test_document_annotation_fields():
@dataclasses.dataclass
class MyDocument(Document):
text: str
words: AnnotationLayer[Span] = annotation_field(target="text")

annotation_fields = MyDocument.annotation_fields()
annotation_field_names = {field.name for field in annotation_fields}
assert annotation_field_names == {"words"}
36 changes: 34 additions & 2 deletions tests/test_document.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,17 @@
import dataclasses
import re
from typing import Optional
from typing import Any, Dict, List, Optional, Set

import pytest

from pytorch_ie.annotations import BinaryRelation, Label, LabeledSpan, Span
from pytorch_ie.core import AnnotationLayer, annotation_field
from pytorch_ie.core.document import Annotation, Document, _enumerate_dependencies
from pytorch_ie.documents import TextDocument, TokenBasedDocument
from pytorch_ie.documents import (
TextDocument,
TextDocumentWithSpansBinaryRelationsAndLabeledPartitions,
TokenBasedDocument,
)


def test_text_document():
Expand Down Expand Up @@ -716,3 +720,31 @@ def test_document_extend_from_other_remove(text_document):
assert len(doc_new.relations) == 0
assert len(doc_new.labels) == 1
assert len(doc_new.relation_attributes) == 0


def test_document_field_types():
@dataclasses.dataclass
class MyDocument(Document):
text: str
words: AnnotationLayer[Span] = annotation_field(target="text")

annotation_fields = MyDocument.annotation_fields()
annotation_field_names = {field.name for field in annotation_fields}
assert annotation_field_names == {"words"}

field_types = MyDocument.field_types()
assert field_types == {"text": str, "words": AnnotationLayer[Span]}

# this requires to resolve the field types with typing.get_type_hints() because field.type is
# a string at this point (externally defined document class)
field_types = TextDocumentWithSpansBinaryRelationsAndLabeledPartitions.field_types()
assert field_types == {
"_annotation_fields": Set[str],
"_annotation_graph": Dict[str, List[str]],
"binary_relations": AnnotationLayer[BinaryRelation],
"id": Optional[str],
"labeled_partitions": AnnotationLayer[LabeledSpan],
"metadata": Dict[str, Any],
"spans": AnnotationLayer[Span],
"text": str,
}

0 comments on commit ac03161

Please sign in to comment.