[WIP] Ner pipeline grouped_entities fixes #5970
File 1 of 2 (the NER pipeline implementation):
@@ -1358,6 +1358,29 @@ def __call__(self, *args, targets=None, top_k: Optional[int] = None, **kwargs):
         return results


+class TokenClassificationArgumentHandler(ArgumentHandler):
+    """
+    Handles arguments for token classification.
+    """
+
+    def __call__(self, *args, **kwargs):
+
+        if args is not None and len(args) > 0:
+            if isinstance(args, str):
+                inputs = [args]
+            else:
+                inputs = args
+            batch_size = len(inputs)
+
+        offset_mapping = kwargs.get("offset_mapping", None)
+        if offset_mapping:
+            if isinstance(offset_mapping, list) and isinstance(offset_mapping[0], tuple):
+                offset_mapping = [offset_mapping]
+            if len(offset_mapping) != batch_size:
+                raise ValueError("offset_mapping should have the same batch size as the input")
+        return inputs, offset_mapping
+
+
 @add_end_docstrings(
     PIPELINE_INIT_ARGS,
     r"""
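For orientation, here is a minimal usage sketch of the handler added above; the sentences and character offsets are placeholders, not values from this PR:

```python
handler = TokenClassificationArgumentHandler()

# A single sentence and a batch both come back as a sequence of inputs.
inputs, offset_mapping = handler("Enzo works at the UN.")
inputs, offset_mapping = handler("First sentence.", "Second sentence.")

# With a slow tokenizer, callers can pass precomputed (start, end) character
# offsets per token; a single example's tuple list is wrapped into a batch of one.
inputs, offset_mapping = handler(
    "Enzo works.",
    offset_mapping=[(0, 4), (5, 10), (10, 11)],
)
```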
@@ -1395,13 +1418,14 @@ def __init__(
         ignore_labels=["O"],
         task: str = "",
         grouped_entities: bool = False,
+        ignore_subwords: bool = True,
     ):
         super().__init__(
             model=model,
             tokenizer=tokenizer,
             modelcard=modelcard,
             framework=framework,
-            args_parser=args_parser,
+            args_parser=TokenClassificationArgumentHandler(),
             device=device,
             binary_output=binary_output,
             task=task,
@@ -1416,6 +1440,7 @@ def __init__(
         self._basic_tokenizer = BasicTokenizer(do_lower_case=False)
         self.ignore_labels = ignore_labels
         self.grouped_entities = grouped_entities
+        self.ignore_subwords = ignore_subwords

     def __call__(self, *args, **kwargs):
         """
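As a usage sketch of the new flag (the checkpoint name is illustrative, not something this PR prescribes):

```python
from transformers import pipeline

# grouped_entities merges B-/I- pieces into entity spans; ignore_subwords
# (new here, default True) additionally folds "##" word pieces into their head token.
ner = pipeline(
    "ner",
    model="dbmdz/bert-large-cased-finetuned-conll03-english",
    grouped_entities=True,
    ignore_subwords=True,
)
print(ner("Consuelo Araújo Noguera met Andrés Pastrana."))
# e.g. [{"entity_group": "PER", "score": ..., "word": "Consuelo Araújo Noguera"}, ...]
```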
@@ -1436,9 +1461,11 @@ def __call__(self, *args, **kwargs):
             - **index** (:obj:`int`, only present when ``self.grouped_entities=False``) -- The index of the
               corresponding token in the sentence.
         """
-        inputs = self._args_parser(*args, **kwargs)
+
+        inputs, offset_mappings = self._args_parser(*args, **kwargs)
         answers = []
-        for sentence in inputs:
+
+        for i, sentence in enumerate(inputs):
+
             # Manage correct placement of the tensors
             with self.device_placement():
@@ -1448,7 +1475,18 @@ def __call__(self, *args, **kwargs):
                     return_attention_mask=False,
                     return_tensors=self.framework,
                     truncation=True,
+                    return_special_tokens_mask=True,
+                    return_offsets_mapping=self.tokenizer.is_fast,
                 )
+                if self.tokenizer.is_fast:
+                    offset_mapping = tokens["offset_mapping"].cpu().numpy()[0]
+                    del tokens["offset_mapping"]
+                elif offset_mappings:
+                    offset_mapping = offset_mappings[i]
+                else:
+                    raise Exception("To decode [UNK] tokens use a fast tokenizer or provide offset_mapping parameter")
+                special_tokens_mask = tokens["special_tokens_mask"].cpu().numpy()[0]
+                del tokens["special_tokens_mask"]

                 # Forward
                 if self.framework == "tf":
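The fast-tokenizer branch above relies on offset mappings and the special-tokens mask; a small illustration of what those carry (tokenizer choice and exact printed values are indicative only):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-cased", use_fast=True)
enc = tokenizer("Enzo ☃", return_offsets_mapping=True, return_special_tokens_mask=True)

# (start, end) character spans per token; special tokens map to (0, 0).
print(enc["offset_mapping"])       # e.g. [(0, 0), (0, 2), (2, 4), (5, 6), (0, 0)]
print(enc["special_tokens_mask"])  # e.g. [1, 0, 0, 0, 1] for [CLS] ... [SEP]

# The (5, 6) span is what lets the pipeline recover "☃" from the sentence
# even though the model itself only sees [UNK].
```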
@@ -1465,24 +1503,35 @@ def __call__(self, *args, **kwargs):

             entities = []
             # Filter to labels not in `self.ignore_labels`
+            # Filter special_tokens
             filtered_labels_idx = [
                 (idx, label_idx)
                 for idx, label_idx in enumerate(labels_idx)
-                if self.model.config.id2label[label_idx] not in self.ignore_labels
+                if (self.model.config.id2label[label_idx] not in self.ignore_labels) and not special_tokens_mask[idx]
             ]

             for idx, label_idx in filtered_labels_idx:
+                start_ind, end_ind = offset_mapping[idx]
+                word_ref = sentence[start_ind:end_ind]
+                word = self.tokenizer.convert_ids_to_tokens([int(input_ids[idx])])[0]
+                is_subword = len(word_ref) != len(word)
+
+                if int(input_ids[idx]) == self.tokenizer.unk_token_id:
+                    word = word_ref
+                    is_subword = False
+
                 entity = {
-                    "word": self.tokenizer.convert_ids_to_tokens(int(input_ids[idx])),
+                    "word": word,
                     "score": score[idx][label_idx].item(),
                     "entity": self.model.config.id2label[label_idx],
                     "index": idx,
                 }

+                if self.grouped_entities and self.ignore_subwords:
+                    entity["is_subword"] = is_subword
+
                 entities += [entity]

             # Append grouped entities
             if self.grouped_entities:
                 answers += [self.group_entities(entities)]
             # Append ungrouped entities
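The is_subword test above works purely on string lengths; a self-contained illustration with values mirroring the test data below:

```python
sentence = "Consuelo"
offset_mapping = [(0, 4), (4, 8)]  # character spans covering "Cons" and "uelo"
tokens = ["Cons", "##uelo"]        # what convert_ids_to_tokens would return

for (start, end), token in zip(offset_mapping, tokens):
    word_ref = sentence[start:end]
    is_subword = len(word_ref) != len(token)  # "uelo" is 4 chars, "##uelo" is 6
    print(token, is_subword)
# Cons False
# ##uelo True
```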
@@ -1501,8 +1550,8 @@ def group_sub_entities(self, entities: List[dict]) -> dict:
             entities (:obj:`dict`): The entities predicted by the pipeline.
         """
         # Get the first entity in the entity group
-        entity = entities[0]["entity"]
-        scores = np.mean([entity["score"] for entity in entities])
+        entity = entities[0]["entity"].split("-")[-1]
+        scores = np.nanmean([entity["score"] for entity in entities])
         tokens = [entity["word"] for entity in entities]

         entity_group = {
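The switch to np.nanmean matters because group_entities (below) sets the score of ignored subwords to np.nan; a quick illustration with made-up scores:

```python
import numpy as np

scores = [0.999, np.nan, 0.997]  # nan marks an ignored subword score
print(np.mean(scores))           # nan: a plain mean is poisoned by the nan
print(np.nanmean(scores))        # 0.998: averages only the real scores
```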
@@ -1527,7 +1576,9 @@ def group_entities(self, entities: List[dict]) -> List[dict]:
         last_idx = entities[-1]["index"]

         for entity in entities:
+
             is_last_idx = entity["index"] == last_idx
+            is_subword = self.ignore_subwords and entity["is_subword"]
             if not entity_group_disagg:
                 entity_group_disagg += [entity]
                 if is_last_idx:
@@ -1536,10 +1587,19 @@ def group_entities(self, entities: List[dict]) -> List[dict]:

             # If the current entity is similar and adjacent to the previous entity, append it to the disaggregated entity group
             # The split is meant to account for the "B" and "I" prefixes
+            # Shouldn't merge if both entities are B-type
+
             if (
-                entity["entity"].split("-")[-1] == entity_group_disagg[-1]["entity"].split("-")[-1]
+                (
+                    entity["entity"].split("-")[-1] == entity_group_disagg[-1]["entity"].split("-")[-1]
+                    and entity["entity"].split("-")[0] != "B"
+                )
                 and entity["index"] == entity_group_disagg[-1]["index"] + 1
-            ):
+            ) or is_subword:
+                # Modify subword type to be previous_type
+                if is_subword:
+                    entity["entity"] = entity_group_disagg[-1]["entity"].split("-")[-1]
+                    entity["score"] = np.nan  # set ignored scores to nan and use np.nanmean
+
                 entity_group_disagg += [entity]
                 # Group the entities at the last entity
                 if is_last_idx:
File 2 of 2 (the NER pipeline tests):
@@ -1,8 +1,8 @@
 import unittest

-from transformers import pipeline
+from transformers import AutoTokenizer, pipeline
 from transformers.pipelines import Pipeline
-from transformers.testing_utils import require_tf
+from transformers.testing_utils import require_tf, require_torch

 from .test_pipelines_common import CustomInputPipelineCommonMixin
@@ -19,38 +19,54 @@ class NerPipelineTests(CustomInputPipelineCommonMixin, unittest.TestCase):

     def _test_pipeline(self, nlp: Pipeline):
         output_keys = {"entity", "word", "score"}
+        if nlp.grouped_entities:
+            output_keys = {"entity_group", "word", "score"}
+
         ungrouped_ner_inputs = [
             [
-                {"entity": "B-PER", "index": 1, "score": 0.9994944930076599, "word": "Cons"},
-                {"entity": "B-PER", "index": 2, "score": 0.8025449514389038, "word": "##uelo"},
-                {"entity": "I-PER", "index": 3, "score": 0.9993102550506592, "word": "Ara"},
-                {"entity": "I-PER", "index": 4, "score": 0.9993743896484375, "word": "##új"},
-                {"entity": "I-PER", "index": 5, "score": 0.9992871880531311, "word": "##o"},
-                {"entity": "I-PER", "index": 6, "score": 0.9993029236793518, "word": "No"},
-                {"entity": "I-PER", "index": 7, "score": 0.9981776475906372, "word": "##guera"},
-                {"entity": "B-PER", "index": 15, "score": 0.9998136162757874, "word": "Andrés"},
-                {"entity": "I-PER", "index": 16, "score": 0.999740719795227, "word": "Pas"},
-                {"entity": "I-PER", "index": 17, "score": 0.9997414350509644, "word": "##tran"},
-                {"entity": "I-PER", "index": 18, "score": 0.9996136426925659, "word": "##a"},
-                {"entity": "B-ORG", "index": 28, "score": 0.9989739060401917, "word": "Far"},
-                {"entity": "I-ORG", "index": 29, "score": 0.7188422083854675, "word": "##c"},
+                {"entity": "B-PER", "index": 1, "score": 0.9994944930076599, "is_subword": False, "word": "Cons"},
+                {"entity": "B-PER", "index": 2, "score": 0.8025449514389038, "is_subword": True, "word": "##uelo"},
+                {"entity": "I-PER", "index": 3, "score": 0.9993102550506592, "is_subword": False, "word": "Ara"},
+                {"entity": "I-PER", "index": 4, "score": 0.9993743896484375, "is_subword": True, "word": "##új"},
+                {"entity": "I-PER", "index": 5, "score": 0.9992871880531311, "is_subword": True, "word": "##o"},
+                {"entity": "I-PER", "index": 6, "score": 0.9993029236793518, "is_subword": False, "word": "No"},
+                {"entity": "I-PER", "index": 7, "score": 0.9981776475906372, "is_subword": True, "word": "##guera"},
+                {"entity": "B-PER", "index": 15, "score": 0.9998136162757874, "is_subword": False, "word": "Andrés"},
+                {"entity": "I-PER", "index": 16, "score": 0.999740719795227, "is_subword": False, "word": "Pas"},
+                {"entity": "I-PER", "index": 17, "score": 0.9997414350509644, "is_subword": True, "word": "##tran"},
+                {"entity": "I-PER", "index": 18, "score": 0.9996136426925659, "is_subword": True, "word": "##a"},
+                {"entity": "B-ORG", "index": 28, "score": 0.9989739060401917, "is_subword": False, "word": "Far"},
+                {"entity": "I-ORG", "index": 29, "score": 0.7188422083854675, "is_subword": True, "word": "##c"},
             ],
             [
-                {"entity": "I-PER", "index": 1, "score": 0.9968166351318359, "word": "En"},
-                {"entity": "I-PER", "index": 2, "score": 0.9957635998725891, "word": "##zo"},
-                {"entity": "I-ORG", "index": 7, "score": 0.9986497163772583, "word": "UN"},
+                {"entity": "I-PER", "index": 1, "score": 0.9968166351318359, "is_subword": False, "word": "En"},
+                {"entity": "I-PER", "index": 2, "score": 0.9957635998725891, "is_subword": True, "word": "##zo"},
+                {"entity": "I-ORG", "index": 7, "score": 0.9986497163772583, "is_subword": False, "word": "UN"},
             ],
         ]

         expected_grouped_ner_results = [
             [
-                {"entity_group": "B-PER", "score": 0.9710702640669686, "word": "Consuelo Araújo Noguera"},
-                {"entity_group": "B-PER", "score": 0.9997273534536362, "word": "Andrés Pastrana"},
-                {"entity_group": "B-ORG", "score": 0.8589080572128296, "word": "Farc"},
+                {"entity_group": "PER", "score": 0.999369223912557, "word": "Consuelo Araújo Noguera"},
+                {"entity_group": "PER", "score": 0.9997771680355072, "word": "Andrés Pastrana"},
+                {"entity_group": "ORG", "score": 0.9989739060401917, "word": "Farc"},
             ],
             [
-                {"entity_group": "I-PER", "score": 0.9962901175022125, "word": "Enzo"},
-                {"entity_group": "I-ORG", "score": 0.9986497163772583, "word": "UN"},
+                {"entity_group": "PER", "score": 0.9968166351318359, "word": "Enzo"},
+                {"entity_group": "ORG", "score": 0.9986497163772583, "word": "UN"},
             ],
         ]

+        expected_grouped_ner_results_w_subword = [
+            [
+                {"entity_group": "PER", "score": 0.9994944930076599, "word": "Cons"},
+                {"entity_group": "PER", "score": 0.9663328925768534, "word": "##uelo Araújo Noguera"},
+                {"entity_group": "PER", "score": 0.9997273534536362, "word": "Andrés Pastrana"},
+                {"entity_group": "ORG", "score": 0.8589080572128296, "word": "Farc"},
+            ],
+            [
+                {"entity_group": "PER", "score": 0.9962901175022125, "word": "Enzo"},
+                {"entity_group": "ORG", "score": 0.9986497163772583, "word": "UN"},
+            ],
+        ]
@@ -77,12 +93,80 @@ def _test_pipeline(self, nlp: Pipeline):
             for key in output_keys:
                 self.assertIn(key, result)

-        for ungrouped_input, grouped_result in zip(ungrouped_ner_inputs, expected_grouped_ner_results):
-            self.assertEqual(nlp.group_entities(ungrouped_input), grouped_result)
+        if nlp.grouped_entities:
+            if nlp.ignore_subwords:
+                for ungrouped_input, grouped_result in zip(ungrouped_ner_inputs, expected_grouped_ner_results):
+                    self.assertEqual(nlp.group_entities(ungrouped_input), grouped_result)
+            else:
+                for ungrouped_input, grouped_result in zip(
+                    ungrouped_ner_inputs, expected_grouped_ner_results_w_subword
+                ):
+                    self.assertEqual(nlp.group_entities(ungrouped_input), grouped_result)
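Concretely, these assertions feed pre-tagged tokens straight into group_entities; a hedged sketch of one such call, where `nlp` is built with grouped_entities=True and ignore_subwords=True (the joined word depends on the tokenizer's convert_tokens_to_string):

```python
grouped = nlp.group_entities([
    {"entity": "B-PER", "index": 1, "score": 0.999, "is_subword": False, "word": "Cons"},
    {"entity": "B-PER", "index": 2, "score": 0.802, "is_subword": True, "word": "##uelo"},
])
print(grouped)
# e.g. [{"entity_group": "PER", "score": 0.999, "word": "Consuelo"}]
# The subword's 0.802 is replaced with nan and so does not drag down the group score.
```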
     @require_tf
     def test_tf_only(self):
         model_name = "Narsil/small"  # This model only has a TensorFlow version
         # We test that if we don't specify framework='tf', it gets detected automatically
-        nlp = pipeline(task="ner", model=model_name, tokenizer=model_name)
+        tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
+        nlp = pipeline(task="ner", model=model_name, tokenizer=tokenizer)
         self._test_pipeline(nlp)

+    # offset = tokenizer(VALID_INPUTS[0], return_offsets_mapping=True)["offset_mapping"]
+    # pipeline_running_kwargs = {"offset_mapping"}  # Additional kwargs to run the pipeline with
+    @require_tf
+    def test_tf_defaults(self):
+        for model_name in self.small_models:
+            tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
+            nlp = pipeline(task="ner", model=model_name, tokenizer=tokenizer, framework="tf")
+            self._test_pipeline(nlp)
+
+    @require_tf
+    def test_tf_small(self):
+        for model_name in self.small_models:
+            print(model_name)
+            tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
+            nlp = pipeline(
+                task="ner",
+                model=model_name,
+                tokenizer=tokenizer,
+                framework="tf",
+                grouped_entities=True,
+                ignore_subwords=True,
+            )
+            self._test_pipeline(nlp)
+
+        for model_name in self.small_models:
+            tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
+            nlp = pipeline(
+                task="ner",
+                model=model_name,
+                tokenizer=tokenizer,
+                framework="tf",
+                grouped_entities=True,
+                ignore_subwords=False,
+            )
+            self._test_pipeline(nlp)
+    @require_torch
+    def test_pt_defaults(self):
+        for model_name in self.small_models:
+            tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
+            nlp = pipeline(task="ner", model=model_name, tokenizer=tokenizer)
+            self._test_pipeline(nlp)
+
+    @require_torch
+    def test_torch_small(self):
+        for model_name in self.small_models:
+            tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
+            nlp = pipeline(
+                task="ner", model=model_name, tokenizer=tokenizer, grouped_entities=True, ignore_subwords=True
+            )
+            self._test_pipeline(nlp)
+
+        for model_name in self.small_models:
+            tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
+            nlp = pipeline(
+                task="ner", model=model_name, tokenizer=tokenizer, grouped_entities=True, ignore_subwords=False
+            )
+            self._test_pipeline(nlp)
Review comment: added this to check offset_mapping if provided (does a simple batch_size check).
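A sketch of the failure mode that check guards against (placeholder offsets; two inputs but only one offset list):

```python
handler = TokenClassificationArgumentHandler()
try:
    handler("One.", "Two.", offset_mapping=[[(0, 3)]])
except ValueError as err:
    print(err)  # offset_mapping should have the same batch size as the input
```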