[TokenClassification] Label realignment for subword aggregation #11680
@@ -1,8 +1,9 @@
+import warnings
 from typing import TYPE_CHECKING, List, Optional, Union

 import numpy as np

-from ..file_utils import add_end_docstrings, is_tf_available, is_torch_available
+from ..file_utils import ExplicitEnum, add_end_docstrings, is_tf_available, is_torch_available
 from ..modelcard import ModelCard
 from ..models.bert.tokenization_bert import BasicTokenizer
 from ..tokenization_utils import PreTrainedTokenizer
@@ -48,13 +49,43 @@ def __call__(self, inputs: Union[str, List[str]], **kwargs):
         return inputs, offset_mapping


+class AggregationStrategy(ExplicitEnum):
+    """All the valid aggregation strategies for TokenClassificationPipeline"""
+
+    NONE = "none"
+    SIMPLE = "simple"
+    FIRST = "first"
+    AVERAGE = "average"
+    MAX = "max"
+
+
 @add_end_docstrings(
     PIPELINE_INIT_ARGS,
     r"""
         ignore_labels (:obj:`List[str]`, defaults to :obj:`["O"]`):
             A list of labels to ignore.
         grouped_entities (:obj:`bool`, `optional`, defaults to :obj:`False`):
-            Whether or not to group the tokens corresponding to the same entity together in the predictions or not.
+            DEPRECATED, use :obj:`aggregation_strategy` instead. Whether or not to group the tokens corresponding to
+            the same entity together in the predictions.
+        aggregation_strategy (:obj:`str`, `optional`, defaults to :obj:`"none"`):
+            The strategy to fuse (or not) tokens based on the model prediction.
+
+            - "none": Will not aggregate anything and simply return the raw results from the model.
+            - "simple": Will attempt to group entities following the default schema: (A, B-TAG), (B, I-TAG), (C,
+              I-TAG), (D, B-TAG2), (E, B-TAG2) will end up as [{"word": "ABC", "entity": "TAG"}, {"word": "D",
+              "entity": "TAG2"}, {"word": "E", "entity": "TAG2"}]. Notice that two consecutive B tags end up as
+              different entities. On word-based languages, we might end up splitting words undesirably: imagine
+              Microsoft being tagged as [{"word": "Micro", "entity": "ENTERPRISE"}, {"word": "soft", "entity":
+              "NAME"}]. Look at FIRST, MAX, and AVERAGE for ways to mitigate that and disambiguate words (on
+              languages that support that notion, which is basically tokens separated by a space). These mitigations
+              will only work on real words; "New york" might still be tagged with two different entities.
+            - "first": (works only on word-based models) Will use the :obj:`SIMPLE` strategy, except that words
+              cannot end up with different tags. Words will simply use the tag of the first token of the word when
+              there is ambiguity.
+            - "average": (works only on word-based models) Will use the :obj:`SIMPLE` strategy, except that words
+              cannot end up with different tags. Scores are first averaged across tokens, and then the maximum
+              label is applied.
+            - "max": (works only on word-based models) Will use the :obj:`SIMPLE` strategy, except that words
+              cannot end up with different tags. The word entity will simply be the token with the maximum score.
     """,
 )
 class TokenClassificationPipeline(Pipeline):
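A minimal usage sketch of the strategies documented above. The input sentence is made up and the pipeline's default NER checkpoint is assumed; any token-classification model with a fast tokenizer behaves the same way.

```python
from transformers import pipeline

# Sketch only: relies on the pipeline factory forwarding kwargs to
# TokenClassificationPipeline.__init__ and on the default "ner" checkpoint.
token_classifier = pipeline("ner", aggregation_strategy="first")
print(token_classifier("Microsoft was founded in Albuquerque"))
# "first"/"max"/"average" force every subword token of "Microsoft" onto one
# tag; "none" returns raw tokens ("Micro", "##soft", ...) separately, and
# "simple" groups tokens but can still split a word across two entities.
```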
@@ -84,8 +115,9 @@ def __init__(
         binary_output: bool = False,
         ignore_labels=["O"],
         task: str = "",
-        grouped_entities: bool = False,
-        ignore_subwords: bool = False,
+        grouped_entities: Optional[bool] = None,
+        ignore_subwords: Optional[bool] = None,
+        aggregation_strategy: Optional[AggregationStrategy] = None,
     ):
         super().__init__(
             model=model,
@@ -106,15 +138,40 @@ def __init__(
         self._basic_tokenizer = BasicTokenizer(do_lower_case=False)
         self._args_parser = args_parser
         self.ignore_labels = ignore_labels
-        self.grouped_entities = grouped_entities
-        self.ignore_subwords = ignore_subwords

-        if self.ignore_subwords and not self.tokenizer.is_fast:
+        if aggregation_strategy is None:
+            aggregation_strategy = AggregationStrategy.NONE
+        if grouped_entities is not None or ignore_subwords is not None:
+            if grouped_entities and ignore_subwords:
+                aggregation_strategy = AggregationStrategy.FIRST
+            elif grouped_entities and not ignore_subwords:
+                aggregation_strategy = AggregationStrategy.SIMPLE
+            else:
+                aggregation_strategy = AggregationStrategy.NONE
+
+            if grouped_entities is not None:
+                warnings.warn(
+                    f'`grouped_entities` is deprecated, defaulted to `aggregation_strategy="{aggregation_strategy}"` instead.'
+                )
+            if ignore_subwords is not None:
+                warnings.warn(
+                    f'`ignore_subwords` is deprecated, defaulted to `aggregation_strategy="{aggregation_strategy}"` instead.'
+                )
Review comment: Nice warnings. We can also add […]
Reply: Done
+        if isinstance(aggregation_strategy, str):
+            aggregation_strategy = AggregationStrategy[aggregation_strategy.upper()]
+
+        if (
+            aggregation_strategy in {AggregationStrategy.FIRST, AggregationStrategy.MAX, AggregationStrategy.AVERAGE}
+            and not self.tokenizer.is_fast
+        ):
             raise ValueError(
-                "Slow tokenizers cannot ignore subwords. Please set the `ignore_subwords` option"
-                "to `False` or use a fast tokenizer."
+                "Slow tokenizers cannot handle subwords. Please set the `aggregation_strategy` option "
+                'to `"simple"` or use a fast tokenizer.'
             )

+        self.aggregation_strategy = aggregation_strategy

     def __call__(self, inputs: Union[str, List[str]], **kwargs):
         """
         Classify each token of the text(s) given as inputs.
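For reference, a sketch of what callers of the deprecated flags now see under the mapping above; the warning text is paraphrased from the `warnings.warn` calls, and the model choice is left to the pipeline default.

```python
from transformers import pipeline

# Deprecation mapping implemented in __init__ above:
#   grouped_entities=True,  ignore_subwords=True  -> AggregationStrategy.FIRST
#   grouped_entities=True,  ignore_subwords=False -> AggregationStrategy.SIMPLE
#   any other combination                         -> AggregationStrategy.NONE
ner = pipeline("ner", grouped_entities=True, ignore_subwords=True)
# -> UserWarning: `grouped_entities` is deprecated, defaulted to
#    `aggregation_strategy="..."` instead.
```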
@@ -125,13 +182,13 @@ def __call__(self, inputs: Union[str, List[str]], **kwargs):

         Return:
             A list or a list of list of :obj:`dict`: Each result comes as a list of dictionaries (one for each token in
-            the corresponding input, or each entity if this pipeline was instantiated with
-            :obj:`grouped_entities=True`) with the following keys:
+            the corresponding input, or each entity if this pipeline was instantiated with an aggregation_strategy)
+            with the following keys:

             - **word** (:obj:`str`) -- The token/word classified.
             - **score** (:obj:`float`) -- The corresponding probability for :obj:`entity`.
             - **entity** (:obj:`str`) -- The entity predicted for that token/word (it is named `entity_group` when
-              `grouped_entities` is set to True.
+              `aggregation_strategy` is not :obj:`"none"`).
             - **index** (:obj:`int`, only present when ``self.grouped_entities=False``) -- The index of the
               corresponding token in the sentence.
             - **start** (:obj:`int`, `optional`) -- The index of the start of the corresponding entity in the sentence.
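To illustrate the return contract, here are the two output shapes with made-up scores and a hypothetical ORG entity; only the keys come from the docstring above.

```python
# aggregation_strategy="none": one dict per token, "index" present.
[
    {"word": "Micro", "score": 0.99, "entity": "B-ORG", "index": 1, "start": 0, "end": 5},
    {"word": "##soft", "score": 0.98, "entity": "I-ORG", "index": 2, "start": 5, "end": 9},
]

# Any aggregating strategy: one dict per entity group, keyed "entity_group".
[
    {"word": "Microsoft", "score": 0.985, "entity_group": "ORG", "start": 0, "end": 9},
]
```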
@@ -212,22 +269,83 @@ def __call__(self, inputs: Union[str, List[str]], **kwargs):
                         "start": start_ind,
                         "end": end_ind,
                     }

-                    if self.grouped_entities and self.ignore_subwords:
+                    # These fields will be consumed by self.aggregate and not appear
+                    # in the final result
+                    if self.aggregation_strategy != AggregationStrategy.NONE:
                         entity["is_subword"] = is_subword
+                    if self.aggregation_strategy == AggregationStrategy.AVERAGE:
+                        # AVERAGE needs to keep intermediate scores.
+                        entity["scores"] = score[idx]

                     entities += [entity]

-            if self.grouped_entities:
-                answers += [self.group_entities(entities)]
-            # Append ungrouped entities
-            else:
-                answers += [entities]
+            # Might be no-op for NONE strategy
+            grouped_entities = self.aggregate(entities, self.aggregation_strategy)
+            answers += [grouped_entities]

         if len(answers) == 1:
             return answers[0]
         return answers

+    def aggregate(self, entities: List[dict], aggregation_strategy: AggregationStrategy) -> List[dict]:
+        if aggregation_strategy == AggregationStrategy.NONE:
+            return entities
+        if aggregation_strategy != AggregationStrategy.SIMPLE:
+            entities = self.aggregate_words(entities, aggregation_strategy)
+        return self.group_entities(entities)
+
+    def aggregate_word(self, entities: List[dict], aggregation_strategy: AggregationStrategy) -> dict:
+        word = self.tokenizer.convert_tokens_to_string([entity["word"] for entity in entities])
+        if aggregation_strategy == AggregationStrategy.FIRST:
+            entity = entities[0]["entity"]
+            score = entities[0]["score"]
+        elif aggregation_strategy == AggregationStrategy.MAX:
+            max_entity = max(entities, key=lambda entity: entity["score"])
+            score = max_entity["score"]
+            entity = max_entity["entity"]
+        elif aggregation_strategy == AggregationStrategy.AVERAGE:
+            scores = np.stack([entity["scores"] for entity in entities])
+            average_scores = np.nanmean(scores, axis=0)
+            entity_idx = average_scores.argmax()
+            entity = self.model.config.id2label[entity_idx]
Review comment: umm […]
Review comment: Isn't […]
Reply: aha, just realized these are pre-argmax 2d scores, it all makes sense now!
+            score = average_scores[entity_idx]
+        else:
+            raise ValueError("Invalid aggregation_strategy")
+        new_entity = {
+            "entity": entity,
+            "score": score,
+            "word": word,
+            "start": entities[0]["start"],
+            "end": entities[-1]["end"],
+        }
+        return new_entity
+
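To ground the exchange above: AVERAGE consumes the full pre-argmax probability rows stored in `entity["scores"]`, not the per-token winning scores. A toy sketch with a hypothetical three-label set:

```python
import numpy as np

# Pre-argmax rows for one word made of two subword tokens, over the
# hypothetical label set ["O", "B-ORG", "B-MISC"] (made-up numbers).
scores = np.array(
    [
        [0.10, 0.50, 0.40],  # token 1 alone would pick B-ORG (0.50)
        [0.10, 0.20, 0.70],  # token 2 alone would pick B-MISC (0.70)
    ]
)
average_scores = np.nanmean(scores, axis=0)  # [0.10, 0.35, 0.55]
entity_idx = int(average_scores.argmax())
print(entity_idx, average_scores[entity_idx])  # 2 0.55 -> the word gets B-MISC
# Averaging only the post-argmax scalars (0.50, 0.70) would lose the label
# identities, which is why the 2d scores are kept through the pipeline.
```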
+    def aggregate_words(self, entities: List[dict], aggregation_strategy: AggregationStrategy) -> List[dict]:
+        """
+        Override tokens from a given word that disagree to force agreement on word boundaries.
+
+        Example: micro|soft| com|pany| B-ENT I-NAME I-ENT I-ENT will be rewritten with first strategy as microsoft|
+        company| B-ENT I-ENT
+        """
+        assert aggregation_strategy not in {
+            AggregationStrategy.NONE,
+            AggregationStrategy.SIMPLE,
+        }, "NONE and SIMPLE strategies are invalid"
+
+        word_entities = []
+        word_group = None
+        for entity in entities:
+            if word_group is None:
+                word_group = [entity]
+            elif entity["is_subword"]:
+                word_group.append(entity)
+            else:
+                word_entities.append(self.aggregate_word(word_group, aggregation_strategy))
+                word_group = [entity]
+        # Last item
+        word_entities.append(self.aggregate_word(word_group, aggregation_strategy))
+        return word_entities
+
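A self-contained re-run of the loop above on the docstring's example with the FIRST strategy; the entity dicts are made up, and joining on stripped "##" prefixes is a crude stand-in for `tokenizer.convert_tokens_to_string`.

```python
entities = [
    {"word": "micro", "entity": "B-ENT", "score": 0.9, "is_subword": False, "start": 0, "end": 5},
    {"word": "##soft", "entity": "I-NAME", "score": 0.8, "is_subword": True, "start": 5, "end": 9},
    {"word": "com", "entity": "I-ENT", "score": 0.7, "is_subword": False, "start": 10, "end": 13},
    {"word": "##pany", "entity": "I-ENT", "score": 0.6, "is_subword": True, "start": 13, "end": 18},
]

def first_strategy(group):
    # FIRST: the whole word takes the entity/score of its first token.
    return {
        "entity": group[0]["entity"],
        "score": group[0]["score"],
        "word": "".join(e["word"].lstrip("#") for e in group),
        "start": group[0]["start"],
        "end": group[-1]["end"],
    }

word_entities = []
word_group = None
for entity in entities:
    if word_group is None:
        word_group = [entity]
    elif entity["is_subword"]:
        word_group.append(entity)
    else:
        word_entities.append(first_strategy(word_group))
        word_group = [entity]
word_entities.append(first_strategy(word_group))  # last word

print([(e["word"], e["entity"]) for e in word_entities])
# [('microsoft', 'B-ENT'), ('company', 'I-ENT')]
```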
     def group_sub_entities(self, entities: List[dict]) -> dict:
         """
         Group together the adjacent tokens with the same entity predicted.
@@ -260,45 +378,31 @@ def group_entities(self, entities: List[dict]) -> List[dict]:
         entity_groups = []
         entity_group_disagg = []

-        if entities:
-            last_idx = entities[-1]["index"]
-
         for entity in entities:
-
-            is_last_idx = entity["index"] == last_idx
-            is_subword = self.ignore_subwords and entity["is_subword"]
             if not entity_group_disagg:
-                entity_group_disagg += [entity]
-                if is_last_idx:
-                    entity_groups += [self.group_sub_entities(entity_group_disagg)]
+                entity_group_disagg.append(entity)
                 continue

-            # If the current entity is similar and adjacent to the previous entity, append it to the disaggregated entity group
+            # If the current entity is similar and adjacent to the previous entity,
+            # append it to the disaggregated entity group
             # The split is meant to account for the "B" and "I" suffixes
             # Shouldn't merge if both entities are B-type
-            if (
-                (
-                    entity["entity"].split("-")[-1] == entity_group_disagg[-1]["entity"].split("-")[-1]
-                    and entity["entity"].split("-")[0] != "B"
-                )
-                and entity["index"] == entity_group_disagg[-1]["index"] + 1
-            ) or is_subword:
+            bi, tag = entity["entity"].split("-")
+            last_bi, last_tag = entity_group_disagg[-1]["entity"].split("-")
+            # Index might not be available if we aggregate words first.
+            index_agree = entity["index"] == entity_group_disagg[-1]["index"] + 1 if "index" in entity else True
Review comment: just sayin […]
Review comment: we might have to keep track of the first subword index and last subword index of the entity during aggregation.
Reply: You are perfectly correct, and doing the filtering later should actually remove that check anyway (as there can't be any more non-consecutive entities anymore).
Reply: oh that's right~ turns out it's a "kill two 🐛 with one 🥌" situation.
Reply: It led to a bit more consequential change, but should be good now. We are still filtering special_tokens early on, but […]
-                # Modify subword type to be previous_type
-                if is_subword:
-                    entity["entity"] = entity_group_disagg[-1]["entity"].split("-")[-1]
-                    entity["score"] = np.nan  # set ignored scores to nan and use np.nanmean
-
-                entity_group_disagg += [entity]
-                # Group the entities at the last entity
-                if is_last_idx:
-                    entity_groups += [self.group_sub_entities(entity_group_disagg)]
-            # If the current entity is different from the previous entity, aggregate the disaggregated entity group
+            if (tag == last_tag and bi != "B") and index_agree:
+                entity_group_disagg.append(entity)
             else:
-                entity_groups += [self.group_sub_entities(entity_group_disagg)]
+                # If the current entity is different from the previous entity
+                # aggregate the disaggregated entity group
+                entity_groups.append(self.group_sub_entities(entity_group_disagg))
                 entity_group_disagg = [entity]
-            # If it's the last entity, add it to the entity groups
-            if is_last_idx:
-                entity_groups += [self.group_sub_entities(entity_group_disagg)]
+        if entity_group_disagg:
+            # it's the last entity, add it to the entity groups
+            entity_groups.append(self.group_sub_entities(entity_group_disagg))

         return entity_groups
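Finally, the merge rule above in isolation: a tiny runnable check of the `(tag == last_tag and bi != "B")` logic with hypothetical tags. Word-level entities carry no "index" key after aggregation, so `index_agree` degenerates to True and the tag test decides everything.

```python
def merges_with_previous(entity_tag: str, previous_tag: str) -> bool:
    # Mirrors the group_entities test above for word-level entities
    # (no "index" key, so the adjacency check is vacuously True).
    bi, tag = entity_tag.split("-")
    _, last_tag = previous_tag.split("-")
    return tag == last_tag and bi != "B"

assert merges_with_previous("I-ORG", "B-ORG")      # "Face" joins "Hugging"
assert not merges_with_previous("B-ORG", "I-ORG")  # a new B- starts a new group
assert not merges_with_previous("I-PER", "I-ORG")  # different tag, new group
```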
Review comment: Very nice docstring