-
Notifications
You must be signed in to change notification settings - Fork 27.4k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Ner label re alignment #10568
Ner label re alignment #10568
Changes from 3 commits
009409c
ee7aeb0
96f68e4
1017651
62a526a
9052a33
096c433
df641ab
482f325
e728ce1
04925f2
399d713
4338be8
cdfe3ac
324f641
7fcfc4e
031f3ef
909e5c8
30f0658
04ab2ca
32dbb2d
4b72cfd
c1625b3
4bd6b54
38a716c
d7633a4
e3e70f9
b24ead8
ab2cabb
6715e3b
a753caf
ce11318
b03b2a6
7959d83
1d30ec9
0661abc
bc2571e
741d48f
88ac60f
7ceff67
8d43c71
2d27900
c0eb218
3f6add8
9114e51
7d518b8
e108da1
8879e12
5662477
92e6cee
b0074c7
af865e3
cdd1db2
69da7cc
8b64c28
2f3b8b0
1e48070
c27f9eb
a251184
45e1919
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -85,6 +85,7 @@ def __init__( | |
ignore_labels=["O"], | ||
task: str = "", | ||
grouped_entities: bool = False, | ||
subword_label_re_alignment: Union[bool, str] = False, | ||
ignore_subwords: bool = False, | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Should the […] There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I think so, we should remove […] There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Yeah, I would think the test can be repurposed for the new flag. It would also be good to assert correctness, besides execution (it looks like the current test doesn't check the resulting output?). I'm happy to contribute directly to your branch, if it helps. Let me know. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. It would be nice to keep it for backwards compatibility purposes. Can the capabilities enabled by that flag be achieved with the new flag introduced in this PR? |
||
): | ||
super().__init__( | ||
|
@@ -107,6 +108,7 @@ def __init__( | |
self._args_parser = args_parser | ||
self.ignore_labels = ignore_labels | ||
self.grouped_entities = grouped_entities | ||
self.subword_label_re_alignment = subword_label_re_alignment | ||
self.ignore_subwords = ignore_subwords | ||
|
||
if self.ignore_subwords and not self.tokenizer.is_fast: | ||
|
@@ -177,18 +179,12 @@ def __call__(self, inputs: Union[str, List[str]], **kwargs): | |
input_ids = tokens["input_ids"].cpu().numpy()[0] | ||
|
||
score = np.exp(entities) / np.exp(entities).sum(-1, keepdims=True) | ||
labels_idx = score.argmax(axis=-1) | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Because we are going to set the labels according to the strategy, we need the scores for all labels, especially when using the “average” strategy. |
||
|
||
entities = [] | ||
# Filter to labels not in `self.ignore_labels` | ||
# Filter special_tokens | ||
filtered_labels_idx = [ | ||
(idx, label_idx) | ||
for idx, label_idx in enumerate(labels_idx) | ||
if (self.model.config.id2label[label_idx] not in self.ignore_labels) and not special_tokens_mask[idx] | ||
] | ||
filtered_labels_idx = [idx for idx in range(score.shape[0]) if not special_tokens_mask[idx]] | ||
|
||
for idx, label_idx in filtered_labels_idx: | ||
for idx in filtered_labels_idx: | ||
if offset_mapping is not None: | ||
start_ind, end_ind = offset_mapping[idx] | ||
word_ref = sentence[start_ind:end_ind] | ||
|
@@ -206,18 +202,28 @@ def __call__(self, inputs: Union[str, List[str]], **kwargs): | |
|
||
entity = { | ||
"word": word, | ||
"score": score[idx][label_idx].item(), | ||
"entity": self.model.config.id2label[label_idx], | ||
"score": score[idx], | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. We need the scores for all labels. |
||
"index": idx, | ||
"start": start_ind, | ||
"end": end_ind, | ||
} | ||
|
||
if self.grouped_entities and self.ignore_subwords: | ||
if self.grouped_entities or self.subword_label_re_alignment: | ||
entity["is_subword"] = is_subword | ||
|
||
entities += [entity] | ||
|
||
if self.subword_label_re_alignment: | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. We are going to set the labels according to the strategy; if subword_label_re_alignment is false, we will leave the labels as they were predicted. |
||
self.set_subwords_label(entities, self.subword_label_re_alignment) | ||
else: | ||
for entity in entities: | ||
label_idx = entity["score"].argmax() | ||
label = self.model.config.id2label[label_idx] | ||
entity["entity"] = label | ||
entity["score"] = entity["score"][label_idx] | ||
|
||
# I think we should check self.subword_label_re_alignment here too | ||
# because we can't use self.grouped_entities if self.subword_label_re_alignment is false | ||
if self.grouped_entities: | ||
answers += [self.group_entities(entities)] | ||
# Append ungrouped entities | ||
|
@@ -228,6 +234,58 @@ def __call__(self, inputs: Union[str, List[str]], **kwargs): | |
return answers[0] | ||
return answers | ||
|
||
def set_subwords_label(self, entities: List[dict], strategy: str) -> dict: | ||
def sub_words_label(sub_words: List[dict]) -> dict: | ||
score = np.stack([sub["score"] for sub in sub_words]) | ||
if strategy == "default": | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. As @joshdevins said: "If training with padded sub-words/label for first sub-word only, e.g. Max Mustermann → Max Must ##erman ##n → B-PER I-PER X X […]" There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. What happens if strategy is set to […] There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Sorry, this is my bad. You're right. When […] |
||
label_idx = score[0].argmax() | ||
label = self.model.config.id2label[label_idx] | ||
sub_words[0]["entity"] = label | ||
for sub in sub_words[1:]: | ||
sub["entity"] = "O" | ||
return sub_words | ||
|
||
if strategy == "first": | ||
# get label of first sub-word | ||
label_idx = score[0].argmax() | ||
label = self.model.config.id2label[label_idx] | ||
elif strategy == "max": | ||
max_label_idx = np.unravel_index(np.argmax(score, axis=None), score.shape)[1] | ||
label = self.model.config.id2label[max_label_idx] | ||
elif strategy == "average": | ||
max_label_idx = np.mean(score, axis=0).argmax() | ||
label = self.model.config.id2label[max_label_idx] | ||
|
||
for idx, sub in enumerate(sub_words): | ||
sub["entity"] = label | ||
sub["score"] = score[idx][max_label_idx] | ||
|
||
return sub_words | ||
|
||
word_group_disagg = [] | ||
entities_with_label = [] | ||
|
||
for entity in reversed(entities): | ||
is_subword = entity["is_subword"] | ||
word_group_disagg += [entity] | ||
if not is_subword: | ||
begin_sub = word_group_disagg.pop(-1) | ||
if len(word_group_disagg) and word_group_disagg[0]["is_subword"]: | ||
word_group_disagg.reverse() | ||
merged_entity = sub_words_label(sub_words=word_group_disagg) | ||
entities_with_label.extend(merged_entity) | ||
|
||
label_idx = begin_sub["score"].argmax() | ||
label = self.model.config.id2label[label_idx] | ||
begin_sub["entity"] = label | ||
begin_sub["score"] = begin_sub["score"][label_idx] | ||
|
||
entities_with_label.append(begin_sub) | ||
word_group_disagg = [] | ||
|
||
entities_with_label.reverse() | ||
return entities_with_label | ||
|
||
def group_sub_entities(self, entities: List[dict]) -> dict: | ||
""" | ||
Group together the adjacent tokens with the same entity predicted. | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I wonder if
aggregate_subwords
would be a more suitable name?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I would understand
aggregate_subwords
better than subword_label_re_alignment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Or would
aggregate_strategy
be even better, as we're actually prompting for a strategy?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Having this accept enum parameters as value would be great, similar to what we do with
PaddingStrategy:
transformers/src/transformers/file_utils.py
Lines 1657 to 1665 in 9f4e0c2