Fix Inconsistent NER Grouping (Pipeline) #4987

Merged 22 commits on Jul 8, 2020
Commits
7b20245
Add B I handling to grouping
enzoampil Jun 14, 2020
562bd7c
Add fix to include separate entity as last token
enzoampil Jun 14, 2020
9f0936c
move last_idx definition outside loop
enzoampil Jun 14, 2020
d3c4838
Use first entity in entity group as reference for entity type
enzoampil Jun 14, 2020
9a182ea
Add test cases
enzoampil Jun 17, 2020
7de9685
Take out extra class accidentally added
enzoampil Jun 17, 2020
4a7a483
Return tf ner grouped test to original
enzoampil Jun 17, 2020
010b784
Take out redundant last entity
enzoampil Jul 4, 2020
e1b2d38
Get last_idx safely
enzoampil Jul 4, 2020
0775ef5
Fix first entity comment
enzoampil Jul 4, 2020
1b097fb
Create separate functions for group_sub_entities and group_entities (…
enzoampil Jul 4, 2020
1eb4989
Take out unnecessary last_idx
enzoampil Jul 4, 2020
f3cc9a4
Remove additional forward pass test
enzoampil Jul 4, 2020
b500617
Move token classification basic tests to separate class
enzoampil Jul 4, 2020
ff91c62
Move token classification basic tests back to monocolumninputtestcase
enzoampil Jul 4, 2020
9533bf7
Move base ner tests to nerpipelinetests
enzoampil Jul 4, 2020
e719f81
Take out unused kwargs
enzoampil Jul 4, 2020
f8d0a76
Add back mandatory_keys argument
enzoampil Jul 4, 2020
f71b178
Add unitary tests for group_entities in _test_ner_pipeline
enzoampil Jul 4, 2020
4a98747
Fix last entity handling
enzoampil Jul 4, 2020
8f29ef9
Fix grouping function used
enzoampil Jul 4, 2020
05f50d9
Add typing to group_sub_entities and group_entities
enzoampil Jul 7, 2020
79 changes: 46 additions & 33 deletions src/transformers/pipelines.py
@@ -1003,8 +1003,6 @@ def __call__(self, *args, **kwargs):
             labels_idx = score.argmax(axis=-1)

             entities = []
-            entity_groups = []
-            entity_group_disagg = []
             # Filter to labels not in `self.ignore_labels`
             filtered_labels_idx = [
                 (idx, label_idx)
@@ -1020,50 +1018,26 @@ def __call__(self, *args, **kwargs):
                     "entity": self.model.config.id2label[label_idx],
                     "index": idx,
                 }
-                last_idx, _ = filtered_labels_idx[-1]
-                if self.grouped_entities:
-                    if not entity_group_disagg:
-                        entity_group_disagg += [entity]
-                        if idx == last_idx:
-                            entity_groups += [self.group_entities(entity_group_disagg)]
-                        continue
-
-                    # If the current entity is similar and adjacent to the previous entity, append it to the disaggregated entity group
-                    if (
-                        entity["entity"] == entity_group_disagg[-1]["entity"]
-                        and entity["index"] == entity_group_disagg[-1]["index"] + 1
-                    ):
-                        entity_group_disagg += [entity]
-                        # Group the entities at the last entity
-                        if idx == last_idx:
-                            entity_groups += [self.group_entities(entity_group_disagg)]
-                    # If the current entity is different from the previous entity, aggregate the disaggregated entity group
-                    else:
-                        entity_groups += [self.group_entities(entity_group_disagg)]
-                        entity_group_disagg = [entity]
-
                 entities += [entity]

-            # Ensure if an entity is the latest one in the sequence it gets appended to the output
-            if len(entity_group_disagg) > 0:
-                entity_groups.append(self.group_entities(entity_group_disagg))
-
-            # Append
+            # Append grouped entities
             if self.grouped_entities:
-                answers += [entity_groups]
+                answers += [self.group_entities(entities)]
+            # Append ungrouped entities
             else:
                 answers += [entities]

         if len(answers) == 1:
             return answers[0]
         return answers
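The net effect of the hunk above: grouping is no longer interleaved with the token loop. `__call__` only accumulates per-token entity dicts, and grouping becomes a single post-processing call over the finished list. A minimal standalone sketch of that new shape (the stub below is a hypothetical stand-in, not the method added later in this diff):

```python
from typing import List

def group_entities_stub(entities: List[dict]) -> List[dict]:
    # Hypothetical stand-in for TokenClassificationPipeline.group_entities.
    return [{"entity_group": e["entity"], "score": e["score"], "word": e["word"]} for e in entities]

entities = [{"entity": "B-PER", "index": 1, "score": 0.99, "word": "Cons"}]
grouped_entities = True  # mirrors the pipeline's self.grouped_entities flag
answers = []
if grouped_entities:
    answers += [group_entities_stub(entities)]  # grouping happens once, after the loop
else:
    answers += [entities]
print(answers)
```

Grouping over the flat `entities` list is also what makes the unit tests added further down possible: `group_entities` can now be called directly on fixture data without a forward pass.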

-    def group_entities(self, entities):
+    def group_sub_entities(self, entities: List[dict]) -> dict:
         """
-        Returns grouped entities
+        Returns grouped sub entities
         """
-        # Get the last entity in the entity group
-        entity = entities[-1]["entity"]
+        # Get the first entity in the entity group
+        entity = entities[0]["entity"]
         scores = np.mean([entity["score"] for entity in entities])
         tokens = [entity["word"] for entity in entities]

@@ -1074,6 +1048,45 @@ def group_entities(self, entities):
         }
         return entity_group
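For reference, a standalone sketch of what `group_sub_entities` computes on the per-token dicts used in this file: the group takes its label from the first sub-token (the fix in this PR; previously the last sub-token's label was used), the mean of the sub-token scores, and the detokenized word. The word join below is a simplification; the real method delegates to `self.tokenizer.convert_tokens_to_string`.

```python
from typing import List

import numpy as np

def group_sub_entities_sketch(entities: List[dict]) -> dict:
    # Simplified WordPiece join: "##" marks a continuation piece that
    # attaches to the previous token.
    text = ""
    for w in (e["word"] for e in entities):
        text += w[2:] if w.startswith("##") else ((" " if text else "") + w)
    return {
        "entity_group": entities[0]["entity"],  # first sub-token names the group
        "score": float(np.mean([e["score"] for e in entities])),  # mean sub-token score
        "word": text,
    }

print(group_sub_entities_sketch([
    {"entity": "B-PER", "score": 0.99, "word": "Cons"},
    {"entity": "B-PER", "score": 0.80, "word": "##uelo"},
]))
# {'entity_group': 'B-PER', 'score': 0.895, 'word': 'Consuelo'}
```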

+    def group_entities(self, entities: List[dict]) -> List[dict]:
+        """
+        Returns grouped entities
+        """
+
+        entity_groups = []
+        entity_group_disagg = []
+
+        if entities:
+            last_idx = entities[-1]["index"]
+
+        for entity in entities:
+            is_last_idx = entity["index"] == last_idx
+            if not entity_group_disagg:
+                entity_group_disagg += [entity]
+                if is_last_idx:
+                    entity_groups += [self.group_sub_entities(entity_group_disagg)]
+                continue
+
+            # If the current entity is similar and adjacent to the previous entity, append it to the disaggregated entity group
+            # The split is meant to account for the "B" and "I" prefixes
+            if (
+                entity["entity"].split("-")[-1] == entity_group_disagg[-1]["entity"].split("-")[-1]
+                and entity["index"] == entity_group_disagg[-1]["index"] + 1
+            ):
+                entity_group_disagg += [entity]
+                # Group the entities at the last entity
+                if is_last_idx:
+                    entity_groups += [self.group_sub_entities(entity_group_disagg)]
+            # If the current entity is different from the previous entity, aggregate the disaggregated entity group
+            else:
+                entity_groups += [self.group_sub_entities(entity_group_disagg)]
+                entity_group_disagg = [entity]
+                # If it's the last entity, add it to the entity groups
+                if is_last_idx:
+                    entity_groups += [self.group_sub_entities(entity_group_disagg)]
+
+        return entity_groups


 NerPipeline = TokenClassificationPipeline
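The behavioral core of the fix is the comparison on the split label. A self-contained check of that grouping criterion (the token dicts mimic the test fixtures below); under the old exact-label comparison, a `B-PER` token followed by its `I-PER` continuation started a new group, which is the inconsistency this PR removes:

```python
def same_group(prev: dict, curr: dict) -> bool:
    # Strip the IOB prefix so "B-PER" and "I-PER" both reduce to "PER",
    # and require the two tokens to be adjacent in the input sequence.
    return (
        curr["entity"].split("-")[-1] == prev["entity"].split("-")[-1]
        and curr["index"] == prev["index"] + 1
    )

assert same_group({"entity": "B-PER", "index": 2}, {"entity": "I-PER", "index": 3})
assert not same_group({"entity": "I-PER", "index": 7}, {"entity": "B-PER", "index": 15})  # index gap
```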
123 changes: 95 additions & 28 deletions tests/test_pipelines.py
@@ -38,6 +38,7 @@
         {"sequence": "<s>The largest city in France is Lyon</s>", "score": 0.21112334728240967, "token": 12790},
     ],
 ]
+
 SUMMARIZATION_KWARGS = dict(num_beams=2, min_length=2, max_length=5)


@@ -156,34 +157,6 @@ def _test_mono_column_pipeline(

         self.assertRaises(Exception, nlp, invalid_inputs)

-    @require_torch
-    def test_torch_ner(self):
-        mandatory_keys = {"entity", "word", "score"}
-        for model_name in NER_FINETUNED_MODELS:
-            nlp = pipeline(task="ner", model=model_name, tokenizer=model_name)
-            self._test_mono_column_pipeline(nlp, VALID_INPUTS, mandatory_keys)
-
-    @require_torch
-    def test_ner_grouped(self):
-        mandatory_keys = {"entity_group", "word", "score"}
-        for model_name in NER_FINETUNED_MODELS:
-            nlp = pipeline(task="ner", model=model_name, tokenizer=model_name, grouped_entities=True)
-            self._test_mono_column_pipeline(nlp, VALID_INPUTS, mandatory_keys)
-
-    @require_tf
-    def test_tf_ner(self):
-        mandatory_keys = {"entity", "word", "score"}
-        for model_name in NER_FINETUNED_MODELS:
-            nlp = pipeline(task="ner", model=model_name, tokenizer=model_name, framework="tf")
-            self._test_mono_column_pipeline(nlp, VALID_INPUTS, mandatory_keys)
-
-    @require_tf
-    def test_tf_ner_grouped(self):
-        mandatory_keys = {"entity_group", "word", "score"}
-        for model_name in NER_FINETUNED_MODELS:
-            nlp = pipeline(task="ner", model=model_name, tokenizer=model_name, framework="tf", grouped_entities=True)
-            self._test_mono_column_pipeline(nlp, VALID_INPUTS, mandatory_keys)
-
     @require_torch
     def test_torch_sentiment_analysis(self):
         mandatory_keys = {"label", "score"}
@@ -393,6 +366,100 @@ def test_tf_question_answering(self):
         self._test_qa_pipeline(nlp)


+class NerPipelineTests(unittest.TestCase):
+    def _test_ner_pipeline(
+        self, nlp: Pipeline, output_keys: Iterable[str],
+    ):
+
+        ungrouped_ner_inputs = [
+            [
+                {"entity": "B-PER", "index": 1, "score": 0.9994944930076599, "word": "Cons"},
+                {"entity": "B-PER", "index": 2, "score": 0.8025449514389038, "word": "##uelo"},
+                {"entity": "I-PER", "index": 3, "score": 0.9993102550506592, "word": "Ara"},
+                {"entity": "I-PER", "index": 4, "score": 0.9993743896484375, "word": "##új"},
+                {"entity": "I-PER", "index": 5, "score": 0.9992871880531311, "word": "##o"},
+                {"entity": "I-PER", "index": 6, "score": 0.9993029236793518, "word": "No"},
+                {"entity": "I-PER", "index": 7, "score": 0.9981776475906372, "word": "##guera"},
+                {"entity": "B-PER", "index": 15, "score": 0.9998136162757874, "word": "Andrés"},
+                {"entity": "I-PER", "index": 16, "score": 0.999740719795227, "word": "Pas"},
+                {"entity": "I-PER", "index": 17, "score": 0.9997414350509644, "word": "##tran"},
+                {"entity": "I-PER", "index": 18, "score": 0.9996136426925659, "word": "##a"},
+                {"entity": "B-ORG", "index": 28, "score": 0.9989739060401917, "word": "Far"},
+                {"entity": "I-ORG", "index": 29, "score": 0.7188422083854675, "word": "##c"},
+            ],
+            [
+                {"entity": "I-PER", "index": 1, "score": 0.9968166351318359, "word": "En"},
+                {"entity": "I-PER", "index": 2, "score": 0.9957635998725891, "word": "##zo"},
+                {"entity": "I-ORG", "index": 7, "score": 0.9986497163772583, "word": "UN"},
+            ],
+        ]
+        expected_grouped_ner_results = [
+            [
+                {"entity_group": "B-PER", "score": 0.9710702640669686, "word": "Consuelo Araújo Noguera"},
+                {"entity_group": "B-PER", "score": 0.9997273534536362, "word": "Andrés Pastrana"},
+                {"entity_group": "B-ORG", "score": 0.8589080572128296, "word": "Farc"},
+            ],
+            [
+                {"entity_group": "I-PER", "score": 0.9962901175022125, "word": "Enzo"},
+                {"entity_group": "I-ORG", "score": 0.9986497163772583, "word": "UN"},
+            ],
+        ]
+
+        self.assertIsNotNone(nlp)
+
+        mono_result = nlp(VALID_INPUTS[0])
+        self.assertIsInstance(mono_result, list)
+        self.assertIsInstance(mono_result[0], (dict, list))
+
+        if isinstance(mono_result[0], list):
+            mono_result = mono_result[0]
+
+        for key in output_keys:
+            self.assertIn(key, mono_result[0])
+
+        multi_result = [nlp(input) for input in VALID_INPUTS]
+        self.assertIsInstance(multi_result, list)
+        self.assertIsInstance(multi_result[0], (dict, list))
+
+        if isinstance(multi_result[0], list):
+            multi_result = multi_result[0]
+
+        for result in multi_result:
+            for key in output_keys:
+                self.assertIn(key, result)
+
+        for ungrouped_input, grouped_result in zip(ungrouped_ner_inputs, expected_grouped_ner_results):
+            self.assertEqual(nlp.group_entities(ungrouped_input), grouped_result)
+
+    @require_torch
+    def test_torch_ner(self):
+        mandatory_keys = {"entity", "word", "score"}
+        for model_name in NER_FINETUNED_MODELS:
+            nlp = pipeline(task="ner", model=model_name, tokenizer=model_name)
+            self._test_ner_pipeline(nlp, mandatory_keys)
+
+    @require_torch
+    def test_ner_grouped(self):
+        mandatory_keys = {"entity_group", "word", "score"}
+        for model_name in NER_FINETUNED_MODELS:
+            nlp = pipeline(task="ner", model=model_name, tokenizer=model_name, grouped_entities=True)
+            self._test_ner_pipeline(nlp, mandatory_keys)
+
+    @require_tf
+    def test_tf_ner(self):
+        mandatory_keys = {"entity", "word", "score"}
+        for model_name in NER_FINETUNED_MODELS:
+            nlp = pipeline(task="ner", model=model_name, tokenizer=model_name, framework="tf")
+            self._test_ner_pipeline(nlp, mandatory_keys)
+
+    @require_tf
+    def test_tf_ner_grouped(self):
+        mandatory_keys = {"entity_group", "word", "score"}
+        for model_name in NER_FINETUNED_MODELS:
+            nlp = pipeline(task="ner", model=model_name, tokenizer=model_name, framework="tf", grouped_entities=True)
+            self._test_ner_pipeline(nlp, mandatory_keys)
+
+
 class PipelineCommonTests(unittest.TestCase):

     pipelines = SUPPORTED_TASKS.keys()
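For completeness, a hypothetical end-to-end call matching what these tests exercise (the concrete entries of `NER_FINETUNED_MODELS` are not visible in this diff, so the snippet falls back to the task's default model):

```python
from transformers import pipeline

# grouped_entities=True routes the raw token predictions through group_entities()
nlp = pipeline(task="ner", grouped_entities=True)
print(nlp("Consuelo Araújo Noguera met Andrés Pastrana."))
# Expected shape, per the fixtures above: adjacent B-/I- tokens of one type merge, e.g.
# [{'entity_group': 'B-PER', 'score': ..., 'word': 'Consuelo Araújo Noguera'},
#  {'entity_group': 'B-PER', 'score': ..., 'word': 'Andrés Pastrana'}]
```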