
[TokenClassification] Label realignment for subword aggregation #11680

Merged
Changes from 2 commits
7 changes: 6 additions & 1 deletion src/transformers/pipelines/__init__.py
@@ -49,7 +49,12 @@
from .text2text_generation import SummarizationPipeline, Text2TextGenerationPipeline, TranslationPipeline
from .text_classification import TextClassificationPipeline
from .text_generation import TextGenerationPipeline
from .token_classification import NerPipeline, TokenClassificationArgumentHandler, TokenClassificationPipeline
from .token_classification import (
AggregationStrategy,
NerPipeline,
TokenClassificationArgumentHandler,
TokenClassificationPipeline,
)
from .zero_shot_classification import ZeroShotClassificationArgumentHandler, ZeroShotClassificationPipeline


198 changes: 151 additions & 47 deletions src/transformers/pipelines/token_classification.py
@@ -1,8 +1,9 @@
import warnings
from typing import TYPE_CHECKING, List, Optional, Union

import numpy as np

from ..file_utils import add_end_docstrings, is_tf_available, is_torch_available
from ..file_utils import ExplicitEnum, add_end_docstrings, is_tf_available, is_torch_available
from ..modelcard import ModelCard
from ..models.bert.tokenization_bert import BasicTokenizer
from ..tokenization_utils import PreTrainedTokenizer
@@ -48,13 +49,43 @@ def __call__(self, inputs: Union[str, List[str]], **kwargs):
return inputs, offset_mapping


class AggregationStrategy(ExplicitEnum):
"""All the valid aggregation strategies for TokenClassificationPipeline"""

NONE = "none"
SIMPLE = "simple"
FIRST = "first"
AVERAGE = "average"
MAX = "max"


@add_end_docstrings(
PIPELINE_INIT_ARGS,
r"""
ignore_labels (:obj:`List[str]`, defaults to :obj:`["O"]`):
A list of labels to ignore.
grouped_entities (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether or not to group the tokens corresponding to the same entity together in the predictions or not.
DEPRECATED, use :obj:`aggregation_strategy` instead. Whether or not to group the tokens corresponding to
the same entity together in the predictions.
aggregation_strategy (:obj:`str`, `optional`, defaults to :obj:`"none"`): The strategy used to fuse (or not) tokens based on the model predictions.
Member: Very nice docstring

- "none" : Will simply not do any aggregation and simply return raw results from the model
- "simple" : Will attempt to group entities following the default schema. (A, B-TAG), (B, I-TAG), (C,
I-TAG), (D, B-TAG2) (E, B-TAG2) will end up being [{"word": ABC, "entity": "TAG"}, {"word": "D",
"entity": "TAG2"}, {"word": "E", "entity": "TAG2"}] Notice that two consecutive B tags will end up as
different entities. On word based languages, we might end up splitting words undesirably : Imagine
Microsoft being tagged as [{"word": "Micro", "entity": "ENTERPRISE"}, {"word": "soft", "entity":
"NAME"}]. Look for FIRST, MAX, AVERAGE for ways to mitigate that and disambiguate words (on languages
that support that meaning, which is basically tokens separated by a space). These mitigations will
only work on real words, "New york" might still be tagged with two different entities.
- "first" : (works only on word based models) Will use the :obj:`SIMPLE` strategy except that words,
cannot end up with different tags. Words will simply use the tag of the first token of the word when
there is ambiguity.
- "average" : (works only on word based models) Will use the :obj:`SIMPLE` strategy except that words,
cannot end up with different tags. scores will be averaged first across tokens, and then the maximum
label is applied.
- "max" : (works only on word based models) Will use the :obj:`SIMPLE` strategy except that words,
cannot end up with different tags. Word entity will simply be the token with the maximum score.
""",
)
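For context, a minimal usage sketch of the new option (checkpoint and output are illustrative; extra keyword arguments such as aggregation_strategy are forwarded by the pipeline() factory to this constructor):

from transformers import pipeline

# Any token-classification checkpoint with a fast tokenizer works here.
ner = pipeline(
    "ner",
    model="dbmdz/bert-large-cased-finetuned-conll03-english",
    aggregation_strategy="first",  # one of "none", "simple", "first", "average", "max"
)
print(ner("Microsoft is based in Redmond."))
# e.g. [{"entity_group": "ORG", "score": 0.99, "word": "Microsoft", "start": 0, "end": 9}, ...]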
class TokenClassificationPipeline(Pipeline):
@@ -84,8 +115,9 @@ def __init__(
binary_output: bool = False,
ignore_labels=["O"],
task: str = "",
grouped_entities: bool = False,
ignore_subwords: bool = False,
grouped_entities: Optional[bool] = None,
ignore_subwords: Optional[bool] = None,
aggregation_strategy: Optional[AggregationStrategy] = None,
):
super().__init__(
model=model,
@@ -106,15 +138,40 @@ def __init__(
self._basic_tokenizer = BasicTokenizer(do_lower_case=False)
self._args_parser = args_parser
self.ignore_labels = ignore_labels
self.grouped_entities = grouped_entities
self.ignore_subwords = ignore_subwords

if self.ignore_subwords and not self.tokenizer.is_fast:
if aggregation_strategy is None:
aggregation_strategy = AggregationStrategy.NONE
if grouped_entities is not None or ignore_subwords is not None:

if grouped_entities and ignore_subwords:
aggregation_strategy = AggregationStrategy.FIRST
elif grouped_entities and not ignore_subwords:
aggregation_strategy = AggregationStrategy.SIMPLE
else:
aggregation_strategy = AggregationStrategy.NONE

if grouped_entities is not None:
warnings.warn(
f'`grouped_entities` is deprecated, defaulted to `aggregation_strategy="{aggregation_strategy}"` instead.'
)
if ignore_subwords is not None:
warnings.warn(
f'`ignore_subwords` is deprecated, defaulted to `aggregation_strategy="{aggregation_strategy}"` instead.'
)
Member: Nice warnings. We could also add "is deprecated and will be removed in version v5.0.0".

Contributor (author): Done

if isinstance(aggregation_strategy, str):
aggregation_strategy = AggregationStrategy[aggregation_strategy.upper()]

if (
aggregation_strategy in {AggregationStrategy.FIRST, AggregationStrategy.MAX, AggregationStrategy.AVERAGE}
and not self.tokenizer.is_fast
):
raise ValueError(
"Slow tokenizers cannot ignore subwords. Please set the `ignore_subwords` option"
"to `False` or use a fast tokenizer."
"Slow tokenizers cannot handle subwords. Please set the `aggregation_strategy` option"
'to `"simple"` or use a fast tokenizer.'
)

self.aggregation_strategy = aggregation_strategy
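For reference, a sketch of how the deprecated flags map onto the new option (the helper name is hypothetical; it mirrors the deprecation branch in __init__ above):

from transformers.pipelines.token_classification import AggregationStrategy

def legacy_to_strategy(grouped_entities, ignore_subwords):
    # Mirrors the deprecation branch in __init__ above.
    if grouped_entities and ignore_subwords:
        return AggregationStrategy.FIRST
    if grouped_entities and not ignore_subwords:
        return AggregationStrategy.SIMPLE
    return AggregationStrategy.NONE

assert legacy_to_strategy(True, True) is AggregationStrategy.FIRST
assert legacy_to_strategy(True, False) is AggregationStrategy.SIMPLE
assert legacy_to_strategy(False, None) is AggregationStrategy.NONE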

def __call__(self, inputs: Union[str, List[str]], **kwargs):
"""
Classify each token of the text(s) given as inputs.
Expand All @@ -125,13 +182,13 @@ def __call__(self, inputs: Union[str, List[str]], **kwargs):

Return:
A list or a list of list of :obj:`dict`: Each result comes as a list of dictionaries (one for each token in
the corresponding input, or each entity if this pipeline was instantiated with
the corresponding input, or each entity if this pipeline was instantiated with an aggregation_strategy
:obj:`grouped_entities=True`) with the following keys:

- **word** (:obj:`str`) -- The token/word classified.
- **score** (:obj:`float`) -- The corresponding probability for :obj:`entity`.
- **entity** (:obj:`str`) -- The entity predicted for that token/word (it is named `entity_group` when
`grouped_entities` is set to True.
`aggregation_strategy` is not :obj:`"none"`).
- **index** (:obj:`int`, only present when ``self.grouped_entities=False``) -- The index of the
corresponding token in the sentence.
- **start** (:obj:`int`, `optional`) -- The index of the start of the corresponding entity in the sentence.
@@ -212,22 +269,83 @@ def __call__(self, inputs: Union[str, List[str]], **kwargs):
"start": start_ind,
"end": end_ind,
}

if self.grouped_entities and self.ignore_subwords:
# These fields will be consumed by self.aggregate and not appear
# in the final result
if self.aggregation_strategy != AggregationStrategy.NONE:
entity["is_subword"] = is_subword
if self.aggregation_strategy == AggregationStrategy.AVERAGE:
# AVERAGE needs to keep intermediate scores.
entity["scores"] = score[idx]

entities += [entity]

if self.grouped_entities:
answers += [self.group_entities(entities)]
# Append ungrouped entities
else:
answers += [entities]
# Might be no-op for NONE strategy
grouped_entities = self.aggregate(entities, self.aggregation_strategy)
answers += [grouped_entities]

if len(answers) == 1:
return answers[0]
return answers

def aggregate(self, entities: List[dict], aggregation_strategy: AggregationStrategy) -> List[dict]:
if aggregation_strategy == AggregationStrategy.NONE:
return entities
if aggregation_strategy != AggregationStrategy.SIMPLE:
entities = self.aggregate_words(entities, aggregation_strategy)
return self.group_entities(entities)

def aggregate_word(self, entities: List[dict], aggregation_strategy: AggregationStrategy) -> dict:
word = self.tokenizer.convert_tokens_to_string([entity["word"] for entity in entities])
if aggregation_strategy == AggregationStrategy.FIRST:
entity = entities[0]["entity"]
score = entities[0]["score"]
elif aggregation_strategy == AggregationStrategy.MAX:
max_entity = max(entities, key=lambda entity: entity["score"])
score = max_entity["score"]
entity = max_entity["entity"]
elif aggregation_strategy == AggregationStrategy.AVERAGE:
scores = np.stack([entity["scores"] for entity in entities])
average_scores = np.nanmean(scores, axis=0)
entity_idx = average_scores.argmax()
entity = self.model.config.id2label[entity_idx]
Contributor (@cceyda): umm, average_scores.argmax() would just return the index where the max score occurs, which != entity_idx in self.model.config.id2label. Am I missing something? Edit: yes, I was missing something.

Contributor (@cceyda): Also not sure what exactly is happening in the average strategy based on the code 😵 In the code, scores looks like a list [0.1, 0.5, 0.1, 0.4], which makes averaging and then argmaxing make no sense.

Contributor: Isn't scores a 2D array of dimensions (N_subwords, N_softmax)? We are averaging along subwords and then taking the overall argmax label. This makes sense if we assume each subword carries some amount of uncorrelated signal about the entity label.

Contributor (@cceyda): Aha, just realized these are pre-argmax 2D scores, it all makes sense now! Nevermind, morning brain 🤦
score = average_scores[entity_idx]
else:
raise ValueError("Invalid aggregation_strategy")
new_entity = {
"entity": entity,
"score": score,
"word": word,
"start": entities[0]["start"],
"end": entities[-1]["end"],
}
return new_entity
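To make the three branches concrete, a toy example with two subword tokens and three labels (scores are made up; "scores" is the per-token softmax vector that the AVERAGE strategy keeps around):

import numpy as np

id2label = {0: "O", 1: "B-ORG", 2: "B-MISC"}  # hypothetical label map
tokens = [
    {"word": "Micro",  "entity": "B-ORG",  "score": 0.60, "scores": np.array([0.10, 0.60, 0.30])},
    {"word": "##soft", "entity": "B-MISC", "score": 0.55, "scores": np.array([0.15, 0.30, 0.55])},
]
# FIRST   -> first token wins:            entity "B-ORG", score 0.60
# MAX     -> highest-scoring token wins:  entity "B-ORG", score 0.60
# AVERAGE -> mean softmax is [0.125, 0.45, 0.425]; argmax 1 -> id2label[1] == "B-ORG", score 0.45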

def aggregate_words(self, entities: List[dict], aggregation_strategy: AggregationStrategy) -> List[dict]:
"""
Override tokens from a given word that disagree to force agreement on word boundaries.

Example: micro|soft| com|pany| B-ENT I-NAME I-ENT I-ENT will be rewritten with the FIRST strategy as
microsoft| company| B-ENT I-ENT
"""
assert aggregation_strategy not in {
AggregationStrategy.NONE,
AggregationStrategy.SIMPLE,
}, "NONE and SIMPLE strategies are invalid"

word_entities = []
word_group = None
for entity in entities:
if word_group is None:
word_group = [entity]
elif entity["is_subword"]:
word_group.append(entity)
else:
word_entities.append(self.aggregate_word(word_group, aggregation_strategy))
word_group = [entity]
# Last item
word_entities.append(self.aggregate_word(word_group, aggregation_strategy))
return word_entities
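A standalone re-implementation of this loop on the docstring's example, for illustration (the real method delegates tag and score selection to aggregate_word above; here words are joined with a naive "##" strip in place of tokenizer.convert_tokens_to_string):

def aggregate_words_sketch(entities):
    """Group tokens into words via is_subword, then apply the FIRST rule."""
    words, group = [], None
    for entity in entities:
        if group is None or entity["is_subword"]:
            group = (group or []) + [entity]
        else:
            words.append(group)
            group = [entity]
    words.append(group)  # last word
    # FIRST: each word takes the tag of its first token
    return [("".join(t["word"].lstrip("#") for t in g), g[0]["entity"]) for g in words]

tokens = [
    {"word": "micro",  "entity": "B-ENT",  "is_subword": False},
    {"word": "##soft", "entity": "I-NAME", "is_subword": True},
    {"word": "com",    "entity": "I-ENT",  "is_subword": False},
    {"word": "##pany", "entity": "I-ENT",  "is_subword": True},
]
print(aggregate_words_sketch(tokens))
# [('microsoft', 'B-ENT'), ('company', 'I-ENT')]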

def group_sub_entities(self, entities: List[dict]) -> dict:
"""
Group together the adjacent tokens with the same entity predicted.
@@ -260,45 +378,31 @@ def group_entities(self, entities: List[dict]) -> List[dict]:
entity_groups = []
entity_group_disagg = []

if entities:
last_idx = entities[-1]["index"]

for entity in entities:

is_last_idx = entity["index"] == last_idx
is_subword = self.ignore_subwords and entity["is_subword"]
if not entity_group_disagg:
entity_group_disagg += [entity]
if is_last_idx:
entity_groups += [self.group_sub_entities(entity_group_disagg)]
entity_group_disagg.append(entity)
continue

# If the current entity is similar and adjacent to the previous entity, append it to the disaggregated entity group
# If the current entity is similar and adjacent to the previous entity,
# append it to the disaggregated entity group
# The split is meant to account for the "B" and "I" suffixes
# Shouldn't merge if both entities are B-type
if (
(
entity["entity"].split("-")[-1] == entity_group_disagg[-1]["entity"].split("-")[-1]
and entity["entity"].split("-")[0] != "B"
)
and entity["index"] == entity_group_disagg[-1]["index"] + 1
) or is_subword:
bi, tag = entity["entity"].split("-")
last_bi, last_tag = entity_group_disagg[-1]["entity"].split("-")
# Index might not be available if we aggregate words first.
index_agree = entity["index"] == entity_group_disagg[-1]["index"] + 1 if "index" in entity else True
Contributor: Just saying, defaulting to True when the index is not available feels dangerous 🧐 Non-consecutive entities with matching B/I tags could happen.

Contributor: We might have to keep track of the first and last subword index of the entity during aggregation. Maybe index could be a tuple 🤔

Contributor (author): You are perfectly correct, and doing the filtering later should actually remove that check anyway (as there can't be any non-consecutive entities anymore).

Contributor: Oh, that's right~ Turns out it's a kill-two-🐛-with-one-🥌 situation.

Contributor (author): It led to a slightly more consequential change, but it should be good now. We are still filtering special_tokens early on, but ignore_labels later, so all contributions should be there.
if (tag == last_tag and bi != "B") and index_agree:
# Modify subword type to be previous_type
if is_subword:
entity["entity"] = entity_group_disagg[-1]["entity"].split("-")[-1]
entity["score"] = np.nan # set ignored scores to nan and use np.nanmean

entity_group_disagg += [entity]
entity_group_disagg.append(entity)
# Group the entities at the last entity
if is_last_idx:
entity_groups += [self.group_sub_entities(entity_group_disagg)]
# If the current entity is different from the previous entity, aggregate the disaggregated entity group
else:
entity_groups += [self.group_sub_entities(entity_group_disagg)]
# If the current entity is different from the previous entity
# aggregate the disaggregated entity group
entity_groups.append(self.group_sub_entities(entity_group_disagg))
entity_group_disagg = [entity]
# If it's the last entity, add it to the entity groups
if is_last_idx:
entity_groups += [self.group_sub_entities(entity_group_disagg)]
if entity_group_disagg:
# it's the last entity, add it to the entity groups
entity_groups.append(self.group_sub_entities(entity_group_disagg))

return entity_groups
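The B/I split rule in isolation, applied to the example from the docstring at the top of the file (a sketch; the real method also merges scores, words, and offsets):

def bio_groups(tagged):
    """Start a new group on any B- tag or on a label change."""
    groups, current = [], []
    for word, tag in tagged:
        bi, _, label = tag.partition("-")
        if current and bi != "B" and label == current[-1][1].partition("-")[2]:
            current.append((word, tag))
        else:
            if current:
                groups.append(current)
            current = [(word, tag)]
    if current:
        groups.append(current)
    return groups

print(bio_groups([("A", "B-TAG"), ("B", "I-TAG"), ("C", "I-TAG"),
                  ("D", "B-TAG2"), ("E", "B-TAG2")]))
# [[('A', 'B-TAG'), ('B', 'I-TAG'), ('C', 'I-TAG')], [('D', 'B-TAG2')], [('E', 'B-TAG2')]]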

4 changes: 3 additions & 1 deletion src/transformers/testing_utils.py
@@ -1207,13 +1207,15 @@ def nested_simplify(obj, decimals=3):
Simplifies an object by rounding float numbers, and downcasting tensors/numpy arrays to get simple equality test
within tests.
"""
import numpy as np

from transformers.tokenization_utils import BatchEncoding

if isinstance(obj, list):
return [nested_simplify(item, decimals) for item in obj]
elif isinstance(obj, (dict, BatchEncoding)):
return {nested_simplify(k, decimals): nested_simplify(v, decimals) for k, v in obj.items()}
elif isinstance(obj, (str, int)):
elif isinstance(obj, (str, int, np.int64)):
return obj
elif is_torch_available() and isinstance(obj, torch.Tensor):
return nested_simplify(obj.tolist())
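The np.int64 case matters because token indices in the new pipeline output are numpy integers; a quick check of the intended behavior (assuming the float branch, elided from this hunk, rounds to `decimals` as the docstring says):

import numpy as np
from transformers.testing_utils import nested_simplify

# np.int64 values (e.g. the "index" field of pipeline outputs) now pass
# through unchanged, so equality tests against plain ints succeed:
assert nested_simplify({"index": np.int64(3), "score": 0.123456}) == {"index": 3, "score": 0.123}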