Refactor label re-alignment in NER pipeline and add tests #2

Merged
92 changes: 52 additions & 40 deletions src/transformers/pipelines/token_classification.py
@@ -111,7 +111,7 @@ def __init__(
self.subword_label_re_alignment = subword_label_re_alignment
self.ignore_subwords = ignore_subwords

if self.ignore_subwords and not self.tokenizer.is_fast:
if (self.ignore_subwords or self.subword_label_re_alignment) and not self.tokenizer.is_fast:
raise ValueError(
"Slow tokenizers cannot ignore subwords. Please set the `ignore_subwords` option"
"to `False` or use a fast tokenizer."
@@ -208,13 +208,13 @@ def __call__(self, inputs: Union[str, List[str]], **kwargs):
"end": end_ind,
}

if self.grouped_entities and (self.subword_label_re_alignment or self.ignore_subwords):
if (self.grouped_entities and self.ignore_subwords) or self.subword_label_re_alignment:
entity["is_subword"] = is_subword

entities += [entity]

if self.subword_label_re_alignment:
self.set_subwords_label(entities, self.subword_label_re_alignment)
self.set_subwords_label(entities)
else:
for entity in entities:
label_idx = entity["score"].argmax()
@@ -234,56 +234,68 @@ def __call__(self, inputs: Union[str, List[str]], **kwargs):
return answers[0]
return answers

def set_subwords_label(self, entities: List[dict], strategy: str) -> dict:
def sub_words_label(sub_words: List[dict]) -> dict:
def set_subwords_label(self, entities: List[dict]) -> List[dict]:
strategy = self.subword_label_re_alignment

def set_labels(sub_words: List[dict]) -> dict:
score = np.stack([sub["score"] for sub in sub_words])
if strategy == "default":
if strategy == "default" or strategy is True:
label_idx = score[0].argmax()
label = self.model.config.id2label[label_idx]
sub_words[0]["entity"] = label
sub_words[0]["score"] = score[0][label_idx]
for sub in sub_words[1:]:
sub["entity"] = "O"
return sub_words
sub["score"] = 0.0 # what score should we assign here?
Collaborator Author:
Here the score is set to 0 because we are assigning the arbitrary label "O". Alternatively, we could use the corresponding score.

Owner:
I don't know which is better, setting it to 0 or leaving it as is. But if we don't set the score to 0.0, then we should take, for example, sub["score"].max() or something like it here, because sub["score"] is a numpy array with a score for each entity type.

Collaborator Author:
Yeah, I think taking the max would be confusing, because it wouldn't match the label. How about we set it to -1, so that it's clear it isn't a meaningful score?
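As a side note for this thread, here is a minimal toy sketch of what the "default" branch above does with a word's trailing sub-words; the label map and score values are invented purely for illustration.

import numpy as np

id2label = {0: "O", 1: "B-PER", 2: "I-PER"}  # hypothetical label map
score = np.array([
    [0.1, 0.7, 0.2],   # first sub-word: argmax -> B-PER
    [0.3, 0.3, 0.4],   # trailing sub-word: relabelled "O"
])
sub_words = [{"word": "Must"}, {"word": "##erman"}]

label_idx = score[0].argmax()
sub_words[0]["entity"] = id2label[label_idx]        # "B-PER"
sub_words[0]["score"] = float(score[0][label_idx])  # 0.7
for sub in sub_words[1:]:
    sub["entity"] = "O"
    sub["score"] = 0.0  # placeholder; -1 was also proposed to flag it as not meaningful

print(sub_words)
# [{'word': 'Must', 'entity': 'B-PER', 'score': 0.7},
#  {'word': '##erman', 'entity': 'O', 'score': 0.0}]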

else:
if strategy == "first":
# get label of first sub-word
max_label_idx = score[0].argmax()
label = self.model.config.id2label[max_label_idx]
elif strategy == "max":
max_label_idx = np.unravel_index(np.argmax(score, axis=None), score.shape)[1]
label = self.model.config.id2label[max_label_idx]
elif strategy == "average":
max_label_idx = np.mean(score, axis=0).argmax()
label = self.model.config.id2label[max_label_idx]
else:
raise ValueError(f"Invalid value {strategy} for option `subword_label_re_alignment`")

if strategy == "first":
# get label of first sub-word
label_idx = score[0].argmax()
label = self.model.config.id2label[label_idx]
elif strategy == "max":
max_label_idx = np.unravel_index(np.argmax(score, axis=None), score.shape)[1]
label = self.model.config.id2label[max_label_idx]
elif strategy == "average":
max_label_idx = np.mean(score, axis=0).argmax()
label = self.model.config.id2label[max_label_idx]
for idx, sub in enumerate(sub_words):
sub["entity"] = label
sub["score"] = score[idx][max_label_idx].item()

for idx, sub in enumerate(sub_words):
sub["entity"] = label
sub["score"] = score[idx][max_label_idx].item()
if self.ignore_subwords:
Owner:
I don't think we need to check self.ignore_subwords here; we just want to set labels for subwords in the set_labels function. Later, according to self.grouped_entities, we'll decide whether or not to merge subwords. Take a look at this: as joshdevins said, I also don't get what the purpose of self.ignore_subwords is in the first place. What are your reasons for checking self.ignore_subwords here?

Collaborator Author:
The problem is that grouped_entities merges all consecutive tokens with the same label, not just subwords. In my use case, I need to merge subwords without aggregating multiple words into the same entity.
Consider the example from that thread:
Max Mustermann → Max Must ##erman ##n
With grouped_entities=True and ignore_subwords=True we get a single entity, "Max Mustermann".
With ignore_subwords=True and subword_label_re_alignment=True we can keep the two words as separate entities, ["Max", "Mustermann"].
This is useful when you need to relabel with additional attributes (in this case they could be FIRST_NAME and LAST_NAME, for example), or more generally to perform subsequent tasks at the word level.
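For concreteness, here is a hedged usage sketch of the two option combinations described above; the pipeline call and option names follow this branch, while the tokenization and outputs are illustrative rather than verified.

from transformers import pipeline

# "Max Mustermann" -> tokens: ["Max", "Must", "##erman", "##n"]

# grouped_entities merges all consecutive tokens with the same label:
ner = pipeline("ner", grouped_entities=True, ignore_subwords=True)
# -> a single entity covering both words: "Max Mustermann"

# subword_label_re_alignment only folds sub-words back into their root word:
ner = pipeline("ner", ignore_subwords=True, subword_label_re_alignment=True)
# -> two word-level entities: "Max" and "Mustermann"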

sub_words[0]["word"] += "".join([sub["word"].split("##")[1] for sub in sub_words[1:]])
return [sub_words[0]]

return sub_words

word_group_disagg = []
entities_with_label = []

for entity in reversed(entities):
is_subword = entity["is_subword"]
word_group_disagg += [entity]
if not is_subword:
begin_sub = word_group_disagg.pop(-1)
if len(word_group_disagg) and word_group_disagg[0]["is_subword"]:
word_group_disagg.reverse()
merged_entity = sub_words_label(sub_words=word_group_disagg)
entities_with_label.extend(merged_entity)

label_idx = begin_sub["score"].argmax()
label = self.model.config.id2label[label_idx]
begin_sub["entity"] = label
begin_sub["score"] = begin_sub["score"][label_idx].item()
subword_indices = np.where([entity["is_subword"] for entity in entities])[0]
if subword_indices.size == 0:
adjacent_subwords = []
else:
# find non-consecutive indices to identify separate clusters of subwords
cluster_edges = np.where(np.diff(subword_indices) != 1)[0]
# Sets of adjacent subwords indices, e.g.
# ['Sir', 'Test', '##y', 'M', '##c', '##T', '##est', 'is', 'test', '##iful']
# --> [[2],[4,5,6],[9]]
adjacent_subwords = np.split(subword_indices, cluster_edges + 1) # shift edge by 1

Owner:
Love your solution for finding adjacent subwords 👍👍. I knew we could probably do it with numpy but couldn't figure it out.
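As a standalone illustration of the trick discussed here, the following minimal sketch reproduces the example from the code comment above; only numpy is required, and the token list is taken from that comment.

import numpy as np

tokens = ['Sir', 'Test', '##y', 'M', '##c', '##T', '##est', 'is', 'test', '##iful']
is_subword = [t.startswith("##") for t in tokens]

subword_indices = np.where(is_subword)[0]                   # [2 4 5 6 9]
# a gap between consecutive indices marks a boundary between clusters
cluster_edges = np.where(np.diff(subword_indices) != 1)[0]
adjacent_subwords = np.split(subword_indices, cluster_edges + 1)

print([cluster.tolist() for cluster in adjacent_subwords])  # [[2], [4, 5, 6], [9]]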

word_indices = []
start = 0
for subwords in adjacent_subwords:
root_word = subwords[0] - 1
word_indices += [[idx] for idx in range(start, root_word)]
word_indices += [[root_word] + list(subwords)]
start = subwords[-1] + 1
word_indices += [[idx] for idx in range(start, len(entities))]

entities_with_label.append(begin_sub)
word_group_disagg = []
entities_with_label = []
for word_idx in word_indices:
subwords = [entities[idx] for idx in word_idx]
entities_with_label += set_labels(subwords)

entities_with_label.reverse()
return entities_with_label

def group_sub_entities(self, entities: List[dict]) -> dict: