-
Notifications
You must be signed in to change notification settings - Fork 0
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Refactor label re-alignment in NER pipeline and add tests #2
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -111,7 +111,7 @@ def __init__( | |
self.subword_label_re_alignment = subword_label_re_alignment | ||
self.ignore_subwords = ignore_subwords | ||
|
||
if self.ignore_subwords and not self.tokenizer.is_fast: | ||
if (self.ignore_subwords or self.subword_label_re_alignment) and not self.tokenizer.is_fast: | ||
raise ValueError( | ||
"Slow tokenizers cannot ignore subwords. Please set the `ignore_subwords` option" | ||
"to `False` or use a fast tokenizer." | ||
|
@@ -208,13 +208,13 @@ def __call__(self, inputs: Union[str, List[str]], **kwargs): | |
"end": end_ind, | ||
} | ||
|
||
if self.grouped_entities and (self.subword_label_re_alignment or self.ignore_subwords): | ||
if (self.grouped_entities and self.ignore_subwords) or self.subword_label_re_alignment: | ||
entity["is_subword"] = is_subword | ||
|
||
entities += [entity] | ||
|
||
if self.subword_label_re_alignment: | ||
self.set_subwords_label(entities, self.subword_label_re_alignment) | ||
self.set_subwords_label(entities) | ||
else: | ||
for entity in entities: | ||
label_idx = entity["score"].argmax() | ||
|
@@ -234,56 +234,68 @@ def __call__(self, inputs: Union[str, List[str]], **kwargs): | |
return answers[0] | ||
return answers | ||
|
||
def set_subwords_label(self, entities: List[dict], strategy: str) -> dict: | ||
def sub_words_label(sub_words: List[dict]) -> dict: | ||
def set_subwords_label(self, entities: List[dict]) -> List[dict]: | ||
strategy = self.subword_label_re_alignment | ||
|
||
def set_labels(sub_words: List[dict]) -> dict: | ||
score = np.stack([sub["score"] for sub in sub_words]) | ||
if strategy == "default": | ||
if strategy == "default" or strategy is True: | ||
label_idx = score[0].argmax() | ||
label = self.model.config.id2label[label_idx] | ||
sub_words[0]["entity"] = label | ||
sub_words[0]["score"] = score[0][label_idx] | ||
for sub in sub_words[1:]: | ||
sub["entity"] = "O" | ||
return sub_words | ||
sub["score"] = 0.0 # what score should we assign here? | ||
else: | ||
if strategy == "first": | ||
# get label of first sub-word | ||
max_label_idx = score[0].argmax() | ||
label = self.model.config.id2label[max_label_idx] | ||
elif strategy == "max": | ||
max_label_idx = np.unravel_index(np.argmax(score, axis=None), score.shape)[1] | ||
label = self.model.config.id2label[max_label_idx] | ||
elif strategy == "average": | ||
max_label_idx = np.mean(score, axis=0).argmax() | ||
label = self.model.config.id2label[max_label_idx] | ||
else: | ||
raise ValueError(f"Invalid value {strategy} for option `subword_label_re_alignment`") | ||
|
||
if strategy == "first": | ||
# get label of first sub-word | ||
label_idx = score[0].argmax() | ||
label = self.model.config.id2label[label_idx] | ||
elif strategy == "max": | ||
max_label_idx = np.unravel_index(np.argmax(score, axis=None), score.shape)[1] | ||
label = self.model.config.id2label[max_label_idx] | ||
elif strategy == "average": | ||
max_label_idx = np.mean(score, axis=0).argmax() | ||
label = self.model.config.id2label[max_label_idx] | ||
for idx, sub in enumerate(sub_words): | ||
sub["entity"] = label | ||
sub["score"] = score[idx][max_label_idx].item() | ||
|
||
for idx, sub in enumerate(sub_words): | ||
sub["entity"] = label | ||
sub["score"] = score[idx][max_label_idx].item() | ||
if self.ignore_subwords: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think we don't need to check There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The problem is that |
||
sub_words[0]["word"] += "".join([sub["word"].split("##")[1] for sub in sub_words[1:]]) | ||
return [sub_words[0]] | ||
|
||
return sub_words | ||
|
||
word_group_disagg = [] | ||
entities_with_label = [] | ||
|
||
for entity in reversed(entities): | ||
is_subword = entity["is_subword"] | ||
word_group_disagg += [entity] | ||
if not is_subword: | ||
begin_sub = word_group_disagg.pop(-1) | ||
if len(word_group_disagg) and word_group_disagg[0]["is_subword"]: | ||
word_group_disagg.reverse() | ||
merged_entity = sub_words_label(sub_words=word_group_disagg) | ||
entities_with_label.extend(merged_entity) | ||
|
||
label_idx = begin_sub["score"].argmax() | ||
label = self.model.config.id2label[label_idx] | ||
begin_sub["entity"] = label | ||
begin_sub["score"] = begin_sub["score"][label_idx].item() | ||
subword_indices = np.where([entity["is_subword"] for entity in entities])[0] | ||
if subword_indices.size == 0: | ||
adjacent_subwords = [] | ||
else: | ||
# find non-consecutive indices to identify separate clusters of subwords | ||
cluster_edges = np.where(np.diff(subword_indices) != 1)[0] | ||
# Sets of adjacent subwords indices, e.g. | ||
# ['Sir', 'Test', '##y', 'M', '##c', '##T', '##est', 'is', 'test', '##iful'] | ||
# --> [[2],[4,5,6],[9]] | ||
adjacent_subwords = np.split(subword_indices, cluster_edges + 1) # shift edge by 1 | ||
|
||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. love your solution to find adjacent subwords 👍👍. I knew that we can probably do it with numpy but didn't figure it out |
||
word_indices = [] | ||
start = 0 | ||
for subwords in adjacent_subwords: | ||
root_word = subwords[0] - 1 | ||
word_indices += [[idx] for idx in range(start, root_word)] | ||
word_indices += [[root_word] + list(subwords)] | ||
start = subwords[-1] + 1 | ||
word_indices += [[idx] for idx in range(start, len(entities))] | ||
|
||
entities_with_label.append(begin_sub) | ||
word_group_disagg = [] | ||
entities_with_label = [] | ||
for word_idx in word_indices: | ||
subwords = [entities[idx] for idx in word_idx] | ||
entities_with_label += set_labels(subwords) | ||
|
||
entities_with_label.reverse() | ||
return entities_with_label | ||
|
||
def group_sub_entities(self, entities: List[dict]) -> dict: | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Here the score is set to 0, because we are assigning the arbitrary label "O". Alternatively, we could use the corresponding score.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't know which one is better, set to 0 or leave it as is??!
but I think if we don't set the score to 0.0 then we should take for example
sub["score"].max()
or something like this here, becausesub["score"]
is a numpy array with scores for each type of entity.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah, I think taking the max would be confusing, because it wouldn't match the label. How about we set it to -1? So that it's clear it's not a meaningful score?