From b563cb2aea29d45a0ca5da022a1abb4c7dfb5cba Mon Sep 17 00:00:00 2001 From: Arne Binder Date: Wed, 27 Sep 2023 14:52:57 +0200 Subject: [PATCH 1/3] upgrade LabeledSpanLengthCollector to SpanLengthCollector: make labels optional and inferrable, add tokenization capability --- src/pytorch_ie/metrics/statistics.py | 120 ++++++++++++++++++++++++--- tests/core/test_statistic.py | 43 +++++++++- 2 files changed, 147 insertions(+), 16 deletions(-) diff --git a/src/pytorch_ie/metrics/statistics.py b/src/pytorch_ie/metrics/statistics.py index 05d474c0..77280a89 100644 --- a/src/pytorch_ie/metrics/statistics.py +++ b/src/pytorch_ie/metrics/statistics.py @@ -4,8 +4,13 @@ from transformers import AutoTokenizer, PreTrainedTokenizer +from pytorch_ie import tokenize_document +from pytorch_ie.annotations import Span from pytorch_ie.core import Document, DocumentStatistic -from pytorch_ie.documents import TextBasedDocument +from pytorch_ie.documents import TextBasedDocument, TokenBasedDocument +from pytorch_ie.utils.hydra import resolve_optional_document_type + +logger = logging.getLogger(__name__) logger = logging.getLogger(__name__) @@ -72,24 +77,113 @@ def _collect(self, doc: Document) -> List[int]: return lengths -class LabeledSpanLengthCollector(DocumentStatistic): - """Collects the length of spans in a field per label, e.g. to collect the length of entities per type. +class SpanLengthCollector(DocumentStatistic): + """Collects the lengths of Span annotations. If labels are provided, the lengths are collected per + label. - The field should be a list of elements with a label, a start and end attribute. + If tokenize is True (which requires a tokenizer), the span length is calculated in terms of tokens, otherwise in + terms of characters. 
""" - DEFAULT_AGGREGATION_FUNCTIONS = ["mean", "std", "min", "max", "len"] + DEFAULT_AGGREGATION_FUNCTIONS = ["len", "mean", "std", "min", "max"] - def __init__(self, field: str, **kwargs): + def __init__( + self, + layer: str, + tokenize: bool = False, + tokenizer: Optional[Union[str, PreTrainedTokenizer]] = None, + tokenized_document_type: Optional[Union[str, Type[TokenBasedDocument]]] = None, + labels: Optional[Union[List[str], str]] = None, + label_attribute: str = "label", + tokenize_kwargs: Optional[Dict[str, Any]] = None, + **kwargs, + ): super().__init__(**kwargs) - self.field = field + self.layer = layer + if isinstance(labels, str) and labels != "INFERRED": + raise ValueError("labels must be a list of strings or 'INFERRED'") + if labels == "INFERRED": + logger.warning( + f"Inferring labels with {self.__class__.__name__} from data produces wrong results " + f"for certain aggregation functions (e.g. 'mean', 'std', 'min') because zero values " + f"are not included in the calculation. We remove these aggregation functions from " + f"this collector, but be aware that the results may be wrong for your own aggregation " + f"functions that rely on zero values." 
+ ) + self.aggregation_functions = { + name: func + for name, func in self.aggregation_functions.items() + if name not in ["mean", "std", "min"] + } + self.labels = labels + self.label_field = label_attribute + self.tokenize = tokenize + if self.tokenize: + if tokenizer is None: + raise ValueError( + "tokenizer must be provided to calculate the span length in means of tokens" + ) + if isinstance(tokenizer, str): + tokenizer = AutoTokenizer.from_pretrained(tokenizer) + self.tokenizer = tokenizer + resolved_tokenized_document_type = resolve_optional_document_type( + tokenized_document_type + ) + if resolved_tokenized_document_type is None: + raise ValueError( + "tokenized_document_type must be provided to calculate the span length in means of tokens" + ) + if not ( + isinstance(resolved_tokenized_document_type, type) + and issubclass(resolved_tokenized_document_type, TokenBasedDocument) + ): + raise TypeError( + f"tokenized_document_type must be a subclass of TokenBasedDocument, but it is: " + f"{resolved_tokenized_document_type}" + ) + self.tokenized_document_type = resolved_tokenized_document_type + self.tokenize_kwargs = tokenize_kwargs or {} + + def _collect(self, doc: Document) -> Union[List[int], Dict[str, List[int]]]: + docs: Union[List[Document], List[TokenBasedDocument]] + if self.tokenize: + if not isinstance(doc, TextBasedDocument): + raise ValueError( + "doc must be a TextBasedDocument to calculate the span length in means of tokens" + ) + if not isinstance(doc, TextBasedDocument): + raise ValueError( + "doc must be a TextBasedDocument to calculate the span length in means of tokens" + ) + docs = tokenize_document( + doc, + tokenizer=self.tokenizer, + result_document_type=self.tokenized_document_type, + **self.tokenize_kwargs, + ) + else: + docs = [doc] - def _collect(self, doc: Document) -> Dict[str, List[int]]: - field_obj = getattr(doc, self.field) - counts = defaultdict(list) - for elem in field_obj: - counts[elem.label].append(elem.end - elem.start) 
- return dict(counts) + values: Dict[str, List[int]] + if isinstance(self.labels, str): + values = defaultdict(list) + else: + values = {label: [] for label in self.labels or ["ALL"]} + for doc in docs: + layer_obj = getattr(doc, self.layer) + for span in layer_obj: + if not isinstance(span, Span): + raise TypeError( + f"span length calculation is not yet supported for {type(span)}" + ) + length = span.end - span.start + if self.labels is None: + label = "ALL" + else: + label = getattr(span, self.label_field) + values[label].append(length) + + return values if self.labels is not None else values["ALL"] class DummyCollector(DocumentStatistic): diff --git a/tests/core/test_statistic.py b/tests/core/test_statistic.py index 172e2dca..7ca2aacd 100644 --- a/tests/core/test_statistic.py +++ b/tests/core/test_statistic.py @@ -5,12 +5,12 @@ from pytorch_ie import DatasetDict from pytorch_ie.annotations import LabeledSpan from pytorch_ie.core import AnnotationList, annotation_field -from pytorch_ie.documents import TextBasedDocument +from pytorch_ie.documents import TextBasedDocument, TokenBasedDocument from pytorch_ie.metrics.statistics import ( DummyCollector, FieldLengthCollector, LabelCountCollector, - LabeledSpanLengthCollector, + SpanLengthCollector, SubFieldLengthCollector, TokenCountCollector, ) @@ -113,7 +113,21 @@ def test_statistics(dataset): "validation": {"max": 187, "mean": 89.66666666666667, "min": 17, "std": 71.5603863103665}, } - statistic = LabeledSpanLengthCollector(field="entities") + statistic = SpanLengthCollector(layer="entities") + values = statistic(dataset) + assert values == { + "train": {"len": 5, "mean": 7.6, "std": 4.223742416388575, "min": 2, "max": 15}, + "validation": { + "len": 6, + "mean": 10.833333333333334, + "std": 2.9674156357941426, + "min": 6, + "max": 14, + }, + "test": {"len": 5, "mean": 9.4, "std": 5.748043145279966, "min": 5, "max": 20}, + } + + statistic = SpanLengthCollector(layer="entities", labels="INFERRED") values = 
statistic(dataset) assert values == { "train": { @@ -140,6 +154,29 @@ def test_statistics(dataset): }, } + @dataclasses.dataclass + class TokenDocumentWithLabeledEntities(TokenBasedDocument): + entities: AnnotationList[LabeledSpan] = annotation_field(target="tokens") + + statistic = SpanLengthCollector( + layer="entities", + tokenize=True, + tokenizer="bert-base-uncased", + tokenized_document_type=TokenDocumentWithLabeledEntities, + ) + values = statistic(dataset) + assert values == { + "test": {"len": 5, "max": 4, "mean": 2.4, "min": 1, "std": 1.2000000000000002}, + "train": {"len": 5, "max": 2, "mean": 1.2, "min": 1, "std": 0.4}, + "validation": { + "len": 6, + "max": 2, + "mean": 1.3333333333333333, + "min": 1, + "std": 0.4714045207910317, + }, + } + # this is not super useful, we just collect teh lengths of the labels, but it is enough to test the code statistic = SubFieldLengthCollector(field="entities", subfield="label") values = statistic(dataset) From ad4797ec6d78d81da41b85f889985c26beaf0e87 Mon Sep 17 00:00:00 2001 From: Arne Binder Date: Wed, 27 Sep 2023 14:56:38 +0200 Subject: [PATCH 2/3] remove duplicated logger (from rebase) --- src/pytorch_ie/metrics/statistics.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/pytorch_ie/metrics/statistics.py b/src/pytorch_ie/metrics/statistics.py index 77280a89..293f798b 100644 --- a/src/pytorch_ie/metrics/statistics.py +++ b/src/pytorch_ie/metrics/statistics.py @@ -12,8 +12,6 @@ logger = logging.getLogger(__name__) -logger = logging.getLogger(__name__) - class TokenCountCollector(DocumentStatistic): """Collects the token count of a field when tokenizing its content with a Huggingface tokenizer. 
From 6cc0619d166c6314b3687ee627c39e422c8f4461 Mon Sep 17 00:00:00 2001 From: Arne Binder Date: Wed, 27 Sep 2023 15:15:17 +0200 Subject: [PATCH 3/3] fix tests and mark test for SpanLengthCollector with tokenize=True as slow --- tests/core/test_statistic.py | 67 +++++++++++++++++------------------- 1 file changed, 32 insertions(+), 35 deletions(-) diff --git a/tests/core/test_statistic.py b/tests/core/test_statistic.py index 7ca2aacd..cde7bfe0 100644 --- a/tests/core/test_statistic.py +++ b/tests/core/test_statistic.py @@ -131,49 +131,23 @@ def test_statistics(dataset): values = statistic(dataset) assert values == { "train": { - "ORG": {"mean": 2.0, "std": 0.0, "min": 2, "max": 2, "len": 1}, - "MISC": {"mean": 6.5, "std": 0.5, "min": 6, "max": 7, "len": 2}, - "PER": {"mean": 15.0, "std": 0.0, "min": 15, "max": 15, "len": 1}, - "LOC": {"mean": 8.0, "std": 0.0, "min": 8, "max": 8, "len": 1}, + "ORG": {"max": 2, "len": 1}, + "MISC": {"max": 7, "len": 2}, + "PER": {"max": 15, "len": 1}, + "LOC": {"max": 8, "len": 1}, }, "test": { "LOC": { - "mean": 10.333333333333334, - "std": 6.847546194724712, - "min": 5, "max": 20, "len": 3, }, - "PER": {"mean": 8.0, "std": 3.0, "min": 5, "max": 11, "len": 2}, + "PER": {"max": 11, "len": 2}, }, "validation": { - "ORG": {"mean": 12.0, "std": 2.8284271247461903, "min": 8, "max": 14, "len": 3}, - "LOC": {"mean": 6.0, "std": 0.0, "min": 6, "max": 6, "len": 1}, - "MISC": {"mean": 11.0, "std": 0.0, "min": 11, "max": 11, "len": 1}, - "PER": {"mean": 12.0, "std": 0.0, "min": 12, "max": 12, "len": 1}, - }, - } - - @dataclasses.dataclass - class TokenDocumentWithLabeledEntities(TokenBasedDocument): - entities: AnnotationList[LabeledSpan] = annotation_field(target="tokens") - - statistic = SpanLengthCollector( - layer="entities", - tokenize=True, - tokenizer="bert-base-uncased", - tokenized_document_type=TokenDocumentWithLabeledEntities, - ) - values = statistic(dataset) - assert values == { - "test": {"len": 5, "max": 4, "mean": 2.4, "min": 1, 
"std": 1.2000000000000002}, - "train": {"len": 5, "max": 2, "mean": 1.2, "min": 1, "std": 0.4}, - "validation": { - "len": 6, - "max": 2, - "mean": 1.3333333333333333, - "min": 1, - "std": 0.4714045207910317, + "ORG": {"max": 14, "len": 3}, + "LOC": {"max": 6, "len": 1}, + "MISC": {"max": 11, "len": 1}, + "PER": {"max": 12, "len": 1}, }, } @@ -200,3 +174,26 @@ def test_statistics_with_tokenize(dataset): "train": {"max": 9, "mean": 5.666666666666667, "min": 2, "std": 2.8674417556808756}, "validation": {"max": 38, "mean": 18.333333333333332, "min": 6, "std": 14.055445761538678}, } + + @dataclasses.dataclass + class TokenDocumentWithLabeledEntities(TokenBasedDocument): + entities: AnnotationList[LabeledSpan] = annotation_field(target="tokens") + + statistic = SpanLengthCollector( + layer="entities", + tokenize=True, + tokenizer="bert-base-uncased", + tokenized_document_type=TokenDocumentWithLabeledEntities, + ) + values = statistic(dataset) + assert values == { + "test": {"len": 5, "max": 4, "mean": 2.4, "min": 1, "std": 1.2000000000000002}, + "train": {"len": 5, "max": 2, "mean": 1.2, "min": 1, "std": 0.4}, + "validation": { + "len": 6, + "max": 2, + "mean": 1.3333333333333333, + "min": 1, + "std": 0.4714045207910317, + }, + }