NER label re-alignment #10568

Closed
60 commits
009409c
add subword_label_re_alignment strategies
elk-cloner Mar 6, 2021
ee7aeb0
remove redundent import
elk-cloner Mar 6, 2021
96f68e4
pass local test
elk-cloner Mar 13, 2021
1017651
Restore compatibility with existing NER pipeline tests
francescorubbo Mar 20, 2021
62a526a
Merge pull request #1 from francescorubbo/ner_label_re_alignment
elk-cloner Mar 21, 2021
9052a33
Refactor label re-alignment in NER pipeline and add tests
francescorubbo Mar 22, 2021
096c433
Merge pull request #2 from francescorubbo/ner_label_re_alignment
elk-cloner Mar 24, 2021
df641ab
Import numpy, used for arrays in test input
francescorubbo Mar 26, 2021
482f325
Bugfix: ensure entities are updated with aligned labels.
francescorubbo Apr 4, 2021
e728ce1
Update tests to probe more scenarios and with bugfix.
francescorubbo Apr 4, 2021
04925f2
Define and use AggregationStrategy enum as argument
francescorubbo Apr 25, 2021
399d713
Use AggregationStrategy.FIRST as default
francescorubbo Apr 25, 2021
4338be8
Updated expected test results and move to fixtures
francescorubbo Apr 25, 2021
cdfe3ac
Use score corresponding to chosen label.
francescorubbo Apr 25, 2021
324f641
Fill entity attributes only if they exist.
francescorubbo Apr 25, 2021
7fcfc4e
Style fixes
francescorubbo Apr 25, 2021
031f3ef
Merge branch 'master' of https://github.com/huggingface/transformers …
francescorubbo Apr 25, 2021
909e5c8
Cleanup leftover from solving conflicts
francescorubbo Apr 25, 2021
30f0658
updating the checkpoint for GPT2ForSequence Classification to one wit…
abiolaTresor Apr 26, 2021
04ab2ca
add pooling layer support (#11439)
thevasudevgupta Apr 26, 2021
32dbb2d
make style (#11442)
patrickvonplaten Apr 26, 2021
4b72cfd
Pin black to 20.8.b1
sgugger Apr 26, 2021
c1625b3
With style
sgugger Apr 26, 2021
4bd6b54
Pin black to 21.4b0
sgugger Apr 26, 2021
38a716c
TF BART models - Add `cross_attentions` to model output and fix cross…
stancld Apr 26, 2021
d7633a4
Add basic support for FP16 in SageMaker model parallelism (#11407)
sgugger Apr 26, 2021
e3e70f9
docs(examples): fix link to TPU launcher script (#11427)
Apr 26, 2021
b24ead8
fix some typos in docs, comments, logging/errors (#11432)
LSinev Apr 26, 2021
ab2cabb
Pass along seed to DistributedSampler (#11406)
sgugger Apr 26, 2021
6715e3b
Clarify description of the is_split_into_words argument (#11449)
kstathou Apr 26, 2021
a753caf
[docs] fix invalid class name (#11438)
stas00 Apr 26, 2021
ce11318
make sure to test against the local checkout (#11437)
stas00 Apr 26, 2021
b03b2a6
Style
sgugger Apr 26, 2021
7959d83
Give each test a different repo name (#11453)
sgugger Apr 26, 2021
1d30ec9
[Examples] Fixes inconsistency around eval vs val and predict vs test…
bhadreshpsavani Apr 26, 2021
0661abc
Variable Correction for Consistency in Distillation Example (#11444)
jaimeenahn Apr 26, 2021
bc2571e
[Deepspeed] ZeRO-Infinity integration plus config revamp (#11418)
stas00 Apr 26, 2021
741d48f
Remove max length beam scorer (#11378)
GeetDsa Apr 26, 2021
88ac60f
update QuickTour docs to reflect model output object (#11462)
Apr 27, 2021
7ceff67
Finish Making Quick Tour respect the model object (#11467)
Apr 27, 2021
8d43c71
fix docs for decoder_input_ids (#11466)
patil-suraj Apr 27, 2021
2d27900
Update min versions in README and add Flax (#11472)
sgugger Apr 28, 2021
c0eb218
Update `PreTrainedTokenizerBase` to check/handle batch length for `te…
hamelsmu Apr 28, 2021
3f6add8
fix #1149 (#11493)
hamelsmu Apr 28, 2021
9114e51
add subword_label_re_alignment strategies
elk-cloner Mar 6, 2021
7d518b8
remove redundent import
elk-cloner Mar 6, 2021
e108da1
pass local test
elk-cloner Mar 13, 2021
8879e12
Restore compatibility with existing NER pipeline tests
francescorubbo Mar 20, 2021
5662477
Refactor label re-alignment in NER pipeline and add tests
francescorubbo Mar 22, 2021
92e6cee
Import numpy, used for arrays in test input
francescorubbo Mar 26, 2021
b0074c7
Bugfix: ensure entities are updated with aligned labels.
francescorubbo Apr 4, 2021
af865e3
Update tests to probe more scenarios and with bugfix.
francescorubbo Apr 4, 2021
cdd1db2
Define and use AggregationStrategy enum as argument
francescorubbo Apr 25, 2021
69da7cc
Use AggregationStrategy.FIRST as default
francescorubbo Apr 25, 2021
8b64c28
Updated expected test results and move to fixtures
francescorubbo Apr 25, 2021
2f3b8b0
Use score corresponding to chosen label.
francescorubbo Apr 25, 2021
1e48070
Fill entity attributes only if they exist.
francescorubbo Apr 25, 2021
c27f9eb
Style fixes
francescorubbo Apr 25, 2021
a251184
Merge branch 'ner_label_re_alignment' of https://github.com/elk-clone…
francescorubbo Apr 29, 2021
45e1919
Remove duplicated definition caused by rebasing after merging
francescorubbo Apr 29, 2021
80 changes: 69 additions & 11 deletions src/transformers/pipelines/token_classification.py
@@ -85,6 +85,7 @@ def __init__(
ignore_labels=["O"],
task: str = "",
grouped_entities: bool = False,
subword_label_re_alignment: Union[bool, str] = False,
Contributor:
I wonder if aggregate_subwords would be a more suitable name?

Member:
I would understand aggregate_subwords better than subword_label_re_alignment

Member:
Or would aggregate_strategy be even better, as we're actually prompting for a strategy?

Member:
Having this accept an enum value would be great, similar to what we do with PaddingStrategy:

class PaddingStrategy(ExplicitEnum):
    """
    Possible values for the ``padding`` argument in :meth:`PreTrainedTokenizerBase.__call__`. Useful for tab-completion
    in an IDE.
    """

    LONGEST = "longest"
    MAX_LENGTH = "max_length"
    DO_NOT_PAD = "do_not_pad"
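A later commit in this PR does introduce an AggregationStrategy enum along these lines. Below is a minimal sketch of what such an enum could look like, modeled on PaddingStrategy; the ExplicitEnum stand-in and the member names here are illustrative, not the exact library code:

```python
from enum import Enum


class ExplicitEnum(str, Enum):
    """Simplified stand-in for transformers' ExplicitEnum: it raises a
    clearer error message when an invalid value is looked up."""

    @classmethod
    def _missing_(cls, value):
        raise ValueError(
            f"{value!r} is not a valid {cls.__name__}, "
            f"please select one of {[m.value for m in cls]}"
        )


class AggregationStrategy(ExplicitEnum):
    """Possible strategies for re-aligning sub-word labels (illustrative names)."""

    FIRST = "first"
    MAX = "max"
    AVERAGE = "average"


# Plain strings round-trip through the enum, so callers can pass either form:
strategy = AggregationStrategy("max")
assert strategy is AggregationStrategy.MAX
```

As with PaddingStrategy, the value of the str-mixin enum is tab-completable in an IDE while still accepting the raw string at the call site.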

ignore_subwords: bool = False,
Contributor:
Should the ignore_subwords flag be removed then?

Contributor (author):

I think so; we should remove the ignore_subwords flag. @LysandreJik, @Narsil, I left some of the old code just to pass the tests (😅 I'm new to tests). Can you help me with the tests? (For example, should I remove this test or change it somehow?)

Contributor:

Yeah, I would think the test can be repurposed for the new flag. It would also be good to assert correctness, besides execution (it looks like the current test doesn't check the resulting output?). I'm happy to contribute directly to your branch if it helps. Let me know.

Member:
It would be nice to keep it for backwards compatibility purposes. Can the capabilities enabled by that flag be achieved with the new flag introduced in this PR?

):
super().__init__(
@@ -107,6 +108,7 @@ def __init__(
self._args_parser = args_parser
self.ignore_labels = ignore_labels
self.grouped_entities = grouped_entities
self.subword_label_re_alignment = subword_label_re_alignment
self.ignore_subwords = ignore_subwords

if self.ignore_subwords and not self.tokenizer.is_fast:
@@ -177,18 +179,12 @@ def __call__(self, inputs: Union[str, List[str]], **kwargs):
input_ids = tokens["input_ids"].cpu().numpy()[0]

score = np.exp(entities) / np.exp(entities).sum(-1, keepdims=True)
labels_idx = score.argmax(axis=-1)
Contributor (author):

Because we are going to set the labels according to the strategy, we need the scores for all labels, especially when using the "average" strategy.
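The scoring step under discussion can be sketched with toy numbers; the logits and the three-label set below are invented for illustration:

```python
import numpy as np

# Toy logits for 4 tokens over 3 labels; values are made up for illustration.
entities_logits = np.array([
    [2.0, 0.1, 0.1],
    [0.2, 3.0, 0.3],
    [0.1, 0.5, 2.5],
    [1.5, 0.2, 0.2],
])

# Softmax over the label axis, as in the pipeline line above:
score = np.exp(entities_logits) / np.exp(entities_logits).sum(-1, keepdims=True)

# Keeping the full (n_tokens, n_labels) matrix instead of calling
# score.argmax(axis=-1) right away is what lets a later re-alignment step
# average or compare scores across a word's sub-words.
assert score.shape == (4, 3)
assert np.allclose(score.sum(-1), 1.0)  # each row is a probability distribution
```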


entities = []
# Filter to labels not in `self.ignore_labels`
# Filter special_tokens
filtered_labels_idx = [
(idx, label_idx)
for idx, label_idx in enumerate(labels_idx)
if (self.model.config.id2label[label_idx] not in self.ignore_labels) and not special_tokens_mask[idx]
]
filtered_labels_idx = [idx for idx in range(score.shape[0]) if not special_tokens_mask[idx]]

for idx, label_idx in filtered_labels_idx:
for idx in filtered_labels_idx:
if offset_mapping is not None:
start_ind, end_ind = offset_mapping[idx]
word_ref = sentence[start_ind:end_ind]
@@ -206,18 +202,28 @@ def __call__(self, inputs: Union[str, List[str]], **kwargs):

entity = {
"word": word,
"score": score[idx][label_idx].item(),
"entity": self.model.config.id2label[label_idx],
"score": score[idx],
Contributor (author):

We need the scores for all labels here.

"index": idx,
"start": start_ind,
"end": end_ind,
}

if self.grouped_entities and self.ignore_subwords:
if self.grouped_entities or self.subword_label_re_alignment:
entity["is_subword"] = is_subword

entities += [entity]

if self.subword_label_re_alignment:
Contributor (author):

We are going to set the labels according to the strategy; if subword_label_re_alignment is false, we leave the labels as they were predicted.

self.set_subwords_label(entities, self.subword_label_re_alignment)
else:
for entity in entities:
label_idx = entity["score"].argmax()
label = self.model.config.id2label[label_idx]
entity["entity"] = label
entity["score"] = entity["score"][label_idx]

# I think we should check self.subword_label_re_alignment here too
# because we can't use self.grouped_entities if self.subword_label_re_alignment is false
if self.grouped_entities:
answers += [self.group_entities(entities)]
# Append ungrouped entities
@@ -228,6 +234,58 @@ def __call__(self, inputs: Union[str, List[str]], **kwargs):
return answers[0]
return answers

def set_subwords_label(self, entities: List[dict], strategy: str) -> dict:
def sub_words_label(sub_words: List[dict]) -> dict:
score = np.stack([sub["score"] for sub in sub_words])
if strategy == "default":
Contributor (author):

As @joshdevins said: "If training with padded sub-words/labels for the first sub-word only, e.g. Max Mustermann → Max Must ##erman ##n → B-PER I-PER X X, use the label from the first sub-word (default)."

Contributor:
What happens if strategy is set to True?

Contributor (author):

Sorry, my bad; you're right. When strategy is True, we should use the "default" strategy. I'll fix this.

label_idx = score[0].argmax()
label = self.model.config.id2label[label_idx]
sub_words[0]["entity"] = label
for sub in sub_words[1:]:
sub["entity"] = "O"
return sub_words

if strategy == "first":
# get label of first sub-word
label_idx = score[0].argmax()
label = self.model.config.id2label[label_idx]
elif strategy == "max":
max_label_idx = np.unravel_index(np.argmax(score, axis=None), score.shape)[1]
label = self.model.config.id2label[max_label_idx]
elif strategy == "average":
max_label_idx = np.mean(score, axis=0).argmax()
label = self.model.config.id2label[max_label_idx]

for idx, sub in enumerate(sub_words):
sub["entity"] = label
sub["score"] = score[idx][max_label_idx]

return sub_words

word_group_disagg = []
entities_with_label = []

for entity in reversed(entities):
is_subword = entity["is_subword"]
word_group_disagg += [entity]
if not is_subword:
begin_sub = word_group_disagg.pop(-1)
if len(word_group_disagg) and word_group_disagg[0]["is_subword"]:
word_group_disagg.reverse()
merged_entity = sub_words_label(sub_words=word_group_disagg)
entities_with_label.extend(merged_entity)

label_idx = begin_sub["score"].argmax()
label = self.model.config.id2label[label_idx]
begin_sub["entity"] = label
begin_sub["score"] = begin_sub["score"][label_idx]

entities_with_label.append(begin_sub)
word_group_disagg = []

entities_with_label.reverse()
return entities_with_label
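The three strategies implemented above can be compared on a single word's sub-word scores. This sketch assumes a toy id2label map and invented score values:

```python
import numpy as np

id2label = {0: "O", 1: "B-PER", 2: "I-PER"}  # toy label map, for illustration

# Scores for the three sub-words of one word, shape (n_subwords, n_labels).
score = np.array([
    [0.10, 0.70, 0.20],  # first sub-word favors B-PER
    [0.05, 0.15, 0.80],  # second sub-word holds the single highest score
    [0.30, 0.40, 0.30],
])

# "first": take the label of the first sub-word
first_label = id2label[score[0].argmax()]

# "max": take the label holding the highest score anywhere in the word
max_label = id2label[np.unravel_index(np.argmax(score), score.shape)[1]]

# "average": take the label with the highest mean score across sub-words
avg_label = id2label[score.mean(axis=0).argmax()]

print(first_label, max_label, avg_label)  # B-PER I-PER I-PER
```

Note that "first" and "max" disagree here: the first sub-word favors B-PER, while the single highest score in the word (0.80) belongs to I-PER, which "average" also selects since I-PER has the largest mean column.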

def group_sub_entities(self, entities: List[dict]) -> dict:
"""
Group together the adjacent tokens with the same entity predicted.