[TokenClassification] Label realignment for subword aggregation #11680
Conversation
Tentative to replace https://github.com/huggingface/transformers/pull/11622/files

- Added `AggregationStrategy`.
- `ignore_subwords` and `grouped_entities` arguments are now fused into `aggregation_strategy`. It makes more sense anyway, because `ignore_subwords=True` with `grouped_entities=False` did not have a meaning.
- Added 2 new ways to aggregate: MAX and AVERAGE.
- AVERAGE requires a bit more information than the others; for now this case is slightly specific, and we should keep that in mind for future changes.
- Testing has been modified to reflect the new argument, and to check the correct deprecation and the new `aggregation_strategy`.
- Put the testing arguments and testing results for `aggregation_strategy` close together, so that readers can understand what is supposed to happen.
- `aggregate` is now only tested on a small model, as it does not mean anything to test it globally for all models.
- Previous tests are unchanged in desired output.
- Added a new test case that showcases better the difference between the FIRST, MAX and AVERAGE strategies.
```python
model_name = self.small_models[0]
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
nlp = pipeline(task="ner", model=model_name, tokenizer=tokenizer, framework="pt")
example3 = [
```
@francescorubbo Can you confirm that this respects what was in your PR? I tried to obtain exactly the same results, but I might have misread some things. I simplified the example as much as possible to make it more self-contained and easier to read.
Does this implementation support a scenario where we need to aggregate subwords into words, but not words into entities (e.g., when the model labels do not follow the BIO convention, we don't want to merge multiple entities with the same class)? I think it doesn't, but I might have missed some details of the implementation logic.
You mean:

```
Micro ##soft Com ##pany
TAG,  TAG,   TAG, TAG    (without B-, I- ?)
->
Microsoft Company
TAG       TAG
```

?
Exactly!
OK, you're right that this currently does not support that. Is it correct that, in the absence of '-' within the tag name, we can assume everything is 'B-'?
Maybe it's more robust to check explicitly in the label dict whether the labels match the BIO pattern. Conditioning on the absence of '-' seems brittle, as it could yield a false negative if a label is e.g. 'over-the-counter' (imagine a model classifying medications in prescription notes).
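A minimal sketch of the kind of explicit check being suggested, assuming labels live in `model.config.id2label` (the helper name and pattern are illustrative, not the PR's actual code):

```python
import re

BIO_PATTERN = re.compile(r"^[BI]-.+$")

def labels_follow_bio(id2label: dict) -> bool:
    # Only treat the label set as BIO if every non-"O" label matches B-XXX / I-XXX,
    # so a free-form label containing '-' is not misparsed.
    return all(BIO_PATTERN.match(label) for label in id2label.values() if label != "O")

assert labels_follow_bio({0: "O", 1: "B-PER", 2: "I-PER"})
assert not labels_follow_bio({0: "O", 1: "over-the-counter", 2: "prescription"})
```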
I've changed the code to be more restrictive.
Thank you @Narsil, I'll take a look! @francescorubbo, @elk-cloner, this PR originated from yours and is based on the same approach. If you approve of this PR, can I add you as co-authors, since you've greatly contributed to its current shape?
Sure!
Sure!
Great, I will! Pinging also @joshdevins and @cceyda for feedback.
```python
scores = np.stack([entity["scores"] for entity in entities])
average_scores = np.nanmean(scores, axis=0)
entity_idx = average_scores.argmax()
entity = self.model.config.id2label[entity_idx]
```
Umm, `average_scores.argmax()` would just return the index where the max score occurs, which `!= entity_idx` on `self.model.config.id2label`. Am I missing something? Shouldn't it be:

Edit: yes, I was missing something.
Also not sure what exactly is happening in the average strategy based on the code 😵. In the code, `scores` looks like a list `[0.1, 0.5, 0.1, 0.4]`, which makes averaging and then argmaxing make no sense.
Isn't `scores` a 2D array of dimensions (N_subwords, N_softmax)? We are averaging along subwords and then taking the overall argmax label. This makes sense if we assume each subword carries some amount of uncorrelated signal about the entity label.
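For concreteness, a small numpy sketch of that reading, condensed to four labels for brevity (the arrays are illustrative, loosely borrowing the "Ramazotti" scores from the tests further down):

```python
import numpy as np

# Per-subword softmax scores over the label set, shape (N_subwords, N_labels).
# Row i is the full probability distribution predicted for subword i.
scores = np.array([
    [0.0, 0.55, 0.0, 0.45],   # "Ra"
    [0.0, 0.20, 0.0, 0.80],   # "##ma"
    [0.0, 0.40, 0.6, 0.00],   # "##zotti"
])

# Average along the subword axis, then argmax over labels:
average_scores = np.nanmean(scores, axis=0)   # shape (N_labels,)
entity_idx = average_scores.argmax()          # index into id2label
print(average_scores)  # [0.    0.383 0.2   0.417] -> label 3 wins on average
```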
Aha, just realized these are pre-argmax 2D scores, it all makes sense now!
Nevermind, morning brain 🤦
Other than the avg strategy, it all looks good! Much cleaner than before, too.
Good job @elk-cloner @Narsil @francescorubbo 🥳
Wait a second, I'm seeing something:
Edit: reached the right conclusion by the wrong means 🤣
The above error is due to that subword getting predicted as "O" (the ignored default) and being ignored early on, in this bit of code:
```python
filtered_labels_idx = [
    (idx, label_idx)
    for idx, label_idx in enumerate(labels_idx)
    if (self.model.config.id2label[label_idx] not in self.ignore_labels) and not special_tokens_mask[idx]
]
```
We should probably do the `ignore_labels` check later, after we merge the subwords, or we will end up wrongfully ignoring some subwords' contributions in the max/avg strategies.
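If it helps, a tiny sketch of the reordering proposed here (illustrative names, not the pipeline's actual internals): aggregate subwords first, and only drop `ignore_labels` entries at the very end.

```python
ignore_labels = ["O"]

def filter_aggregated(aggregated_entities):
    # aggregated_entities: dicts with an "entity_group" key, produced *after*
    # subwords have been merged, so "O"-predicted subwords still contributed
    # their score rows to the max/average aggregation of their word.
    return [e for e in aggregated_entities if e["entity_group"] not in ignore_labels]

print(filter_aggregated([
    {"entity_group": "PER", "word": "Max Mustermann"},
    {"entity_group": "O", "word": "is"},
]))  # -> only the PER group survives
```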
```python
bi, tag = entity["entity"].split("-")
last_bi, last_tag = entity_group_disagg[-1]["entity"].split("-")
# Index might not be available if we aggregate words first.
index_agree = entity["index"] == entity_group_disagg[-1]["index"] + 1 if "index" in entity else True
```
Just sayin', defaulting to `True` when `index` is not available feels dangerous 🧐. Non-consecutive entities with matching B-/I- tags could happen.
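To make the danger concrete, a hypothetical failure mode (illustrative data, not from the test suite):

```python
# Two *non-consecutive* I-LOC words; the "O" tokens between them were already
# filtered out earlier. A grouping pass that only compares tags would merge
# them into a single entity.
entities = [
    {"entity": "I-LOC", "word": "Paris", "index": 3},
    # indexes 4-6 were predicted "O" and dropped before grouping
    {"entity": "I-LOC", "word": "London", "index": 7},
]

# The consecutiveness check catches this: 7 != 3 + 1, so the words are split
# into two groups. Defaulting the check to True when "index" is missing would
# wrongly merge "Paris London" into one LOC entity.
index_agree = entities[1]["index"] == entities[0]["index"] + 1
print(index_agree)  # False
```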
We might have to keep track of the first and last subword index of the entity during aggregation. Maybe `index` could be a tuple 🤔
You are perfectly correct, and doing the filtering later should actually remove that check anyway (as there can't be any non-consecutive entities anymore).
Oh, that's right~ Turns out it's a kill-two-🐛-with-one-🥌 situation.
It led to a somewhat more consequential change, but it should be good now.
We are still filtering special tokens early on, but `ignore_labels` later, so all contributions should be there.
1. Tags might not follow the B-, I- convention, so any tag should work now (assumed as B-TAG).
2. Fixed an issue with average that leads to a substantial code change.
3. The testing suite was not checking for the "index" key for the "none" strategy. This is now fixed.

The issue is that "O" could not be chosen by the AVERAGE strategy because those tokens were filtered out beforehand, so their relative scores were not counted in the average. Now filtering on `ignore_labels` happens at the very end of the pipeline, fixing that issue. It's a bit hard to make sure this stays like that because we do not have an end-to-end test for that behavior.
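For reference, a minimal sketch of the tag handling described in point 1 (the actual helper in the pipeline may differ in details):

```python
def get_tag(entity_name: str):
    # Labels with a B-/I- prefix are split as usual; anything else is treated
    # as the beginning of a new entity of that tag (assumed as B-TAG).
    if entity_name.startswith("B-"):
        return "B", entity_name[2:]
    if entity_name.startswith("I-"):
        return "I", entity_name[2:]
    return "B", entity_name

assert get_tag("I-PER") == ("I", "PER")
assert get_tag("over-the-counter") == ("B", "over-the-counter")  # no false split
```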
Looking good, thank you for all the work! I'm wondering if you can include test cases for the original 3 examples provided in #10263 (comment)? The new test examples here look correct, but I'm not sure they cover the scope of the first examples. Maybe just stub out a real model and test with the labels for each sub-word token as provided in the example. This will exercise just the aggregation logic, since in theory a model could output any of these example labels.
LGTM, thank you for this refactor! Let me know when you're satisfied and I'll operate the merge, while adding the co-authors as mentioned above.
```
Whether or not to group the tokens corresponding to the same entity together in the predictions or not.
DEPRECATED, use :obj:`aggregation_strategy` instead. Whether or not to group the tokens corresponding to
the same entity together in the predictions or not.
aggregation_strategy (:obj:`str`, `optional`, defaults to :obj:`"none"`): The strategy to fuse (or not) tokens based on the model prediction.
```
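A hedged usage sketch of the new argument (assumes a transformers release containing this PR; the model name is one used earlier in this thread):

```python
from transformers import pipeline

token_classifier = pipeline(
    task="ner",
    model="elastic/distilbert-base-cased-finetuned-conll03-english",
    aggregation_strategy="average",  # or "none", "first", "max"
)
print(token_classifier("Max Mustermann works at Elasticsearch."))
```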
Very nice docstring
```python
if grouped_entities is not None:
    warnings.warn(
        f'`grouped_entities` is deprecated, defaulted to `aggregation_strategy="{aggregation_strategy}"` instead.'
    )
if ignore_subwords is not None:
    warnings.warn(
        f'`ignore_subwords` is deprecated, defaulted to `aggregation_strategy="{aggregation_strategy}"` instead.'
    )
```
Nice warnings. We can also add `is deprecated and will be removed in version v5.0.0`.
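Something like the following, as a sketch of the suggested wording (the version string and warning class are per the comment above, not taken from the final diff):

```python
import warnings

aggregation_strategy = "none"  # illustrative value

warnings.warn(
    f'`grouped_entities` is deprecated and will be removed in version v5.0.0, '
    f'defaulted to `aggregation_strategy="{aggregation_strategy}"` instead.',
    FutureWarning,
)
```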
Done
```python
if offset_mapping is not None:
    start_ind, end_ind = offset_mapping[idx]
    word_ref = sentence[start_ind:end_ind]
    is_subword = len(word_ref) != len(word)
```
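To see why the length comparison detects subwords, a standalone sketch using the thread's "Ramazotti" example (offsets as a fast tokenizer's `offset_mapping` would give them):

```python
sentence = "Ramazotti"

# First token: the token string matches the sentence slice exactly.
word, (start_ind, end_ind) = "Ra", (0, 2)
word_ref = sentence[start_ind:end_ind]   # "Ra"
print(len(word_ref) != len(word))        # False -> not a subword

# Continuation token: the "##" WordPiece prefix makes the token string longer
# than the raw sentence slice, which is what flags it as a subword.
word, (start_ind, end_ind) = "##ma", (2, 4)
word_ref = sentence[start_ind:end_ind]   # "ma"
print(len(word_ref) != len(word))        # True -> subword
```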
👍
```python
tokenizer = AutoTokenizer.from_pretrained(NER_MODEL, use_fast=True)
sentence = """Consuelo Araújo Noguera, ministra de cultura del presidente Andrés Pastrana (1998.2002) fue asesinada por las Farc luego de haber permanecido secuestrada por algunos meses."""

nlp_ner = pipeline("ner", model=model, tokenizer=tokenizer)
```
We're in the process of renaming all `nlp` variables to more precise names (for example, `token_classifier` here), so if you want to take the opportunity of this refactor to rename those, that would be great!
Yes, I've done it!
```python
nlp.model.config.id2label,
{0: "O", 1: "B-MISC", 2: "I-MISC", 3: "B-PER", 4: "I-PER", 5: "B-ORG", 6: "I-ORG", 7: "B-LOC", 8: "I-LOC"},
)
example = [
```
This example could be put on a single line to improve readability, for example like it's done in transformers/tests/test_tokenization_tapas.py, lines 1188 to 1190 in 113eaa7:
```python
# fmt: off
expected_results = {'input_ids': [...], 'column_ids': [...], 'row_ids': [...], 'segment_ids': [...]}  # noqa: E231  (long literal kept on one line, elided here)
# fmt: on
```
Perfect.
```python
        )
        self._test_pipeline(nlp)

    @require_torch
    def test_pt_ignore_subwords_slow_tokenizer_raises(self):
        model_name = self.small_models[0]
        tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
```
👍
Do you mind giving an example displaying the issue? I tried this, but I don't think it exhibits what you mention in the original issue (#10263 (comment)):

```python
NER_MODEL = "elastic/distilbert-base-cased-finetuned-conll03-english"
model = AutoModelForTokenClassification.from_pretrained(NER_MODEL)
tokenizer = AutoTokenizer.from_pretrained(NER_MODEL, use_fast=True)
sentence = """Accenture is a company. Max Mustermann is someone, Elasticsearch is something."""
nlp_ner = pipeline("ner", model=model, tokenizer=tokenizer)
output = nlp_ner(sentence)
print(output)
self.assertEqual(
nested_simplify(output),
[
{"entity": "B-PER", "score": 0.9953969, "index": 9, "word": "Max", "start": 24, "end": 27},
{"entity": "I-PER", "score": 0.9773876, "index": 10, "word": "Must", "start": 28, "end": 32},
{"entity": "I-PER", "score": 0.9924896, "index": 11, "word": "##erman", "start": 32, "end": 37},
{"entity": "I-PER", "score": 0.9860034, "index": 12, "word": "##n", "start": 37, "end": 38},
{"entity": "B-ORG", "score": 0.99201995, "index": 16, "word": "El", "start": 51, "end": 53},
{"entity": "B-ORG", "score": 0.99391395, "index": 17, "word": "##astic", "start": 53, "end": 58},
{"entity": "B-ORG", "score": 0.9962443, "index": 18, "word": "##sea", "start": 58, "end": 61},
{"entity": "B-ORG", "score": 0.9924281, "index": 19, "word": "##rch", "start": 61, "end": 64},
],
)
```
I think we should just wait for the test with
@Narsil I've since retrained that model (by labelling all sub-word tokens instead of padding) and it appears to work better with new domain data like "Elasticsearch". The point was more to decouple the fix from a specific model, and to be robust to possible model outputs, particularly for words/sub-words that are out-of-domain for a model, since sub-word token classification can be less predictable (in my experience). This is why I suggested stubbing out the model and putting the sub-word labels directly into the aggregator, to see if the expected behaviour matches the actual new behaviour.
@joshdevins Yes, I think I already added those tests here:

```python
def test_aggregation_strategy_example2(self):
model_name = self.small_models[0]
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
nlp = pipeline(task="ner", model=model_name, tokenizer=tokenizer, framework="pt")
# Just to understand scores indexes in this test
self.assertEqual(
nlp.model.config.id2label,
{0: "O", 1: "B-MISC", 2: "I-MISC", 3: "B-PER", 4: "I-PER", 5: "B-ORG", 6: "I-ORG", 7: "B-LOC", 8: "I-LOC"},
)
example = [
{
# Necessary for AVERAGE
"scores": np.array([0, 0.55, 0, 0.45, 0, 0, 0, 0, 0, 0]),
"is_subword": False,
"index": 1,
"word": "Ra",
"start": 0,
"end": 2,
},
{
"scores": np.array([0, 0, 0, 0.2, 0, 0, 0, 0.8, 0, 0]),
"is_subword": True,
"word": "##ma",
"start": 2,
"end": 4,
"index": 2,
},
{
                # The 4th score (index 3) will have the highest average.
                # The 4th score is B-PER for this model.
                # It does not correspond to the argmax of any of the subtokens.
"scores": np.array([0, 0, 0, 0.4, 0, 0, 0.6, 0, 0, 0]),
"is_subword": True,
"word": "##zotti",
"start": 11,
"end": 13,
"index": 3,
},
]
self.assertEqual(
nlp.aggregate(example, AggregationStrategy.NONE),
[
{"end": 2, "entity": "B-MISC", "score": 0.55, "start": 0, "word": "Ra", "index": 1},
{"end": 4, "entity": "B-LOC", "score": 0.8, "start": 2, "word": "##ma", "index": 2},
{"end": 13, "entity": "I-ORG", "score": 0.6, "start": 11, "word": "##zotti", "index": 3},
],
)
self.assertEqual(
nlp.aggregate(example, AggregationStrategy.FIRST),
[{"entity_group": "MISC", "score": 0.55, "word": "Ramazotti", "start": 0, "end": 13}],
)
self.assertEqual(
nlp.aggregate(example, AggregationStrategy.MAX),
[{"entity_group": "LOC", "score": 0.8, "word": "Ramazotti", "start": 0, "end": 13}],
)
self.assertEqual(
nested_simplify(nlp.aggregate(example, AggregationStrategy.AVERAGE)),
[{"entity_group": "PER", "score": 0.35, "word": "Ramazotti", "start": 0, "end": 13}],
)
```
@Narsil Ah cool, I missed those examples in my read-through. LGTM 🎉
Co-authored-by: Francesco Rubbo <rubbo.francesco@gmail.com> Co-authored-by: elk-cloner <rezakakhki.rk@gmail.com>
@LysandreJik I changed the co-authors, I'll merge after you check I've done it correctly.
Looks good to me, feel free to merge!
…ingface#11680)

* [TokenClassification] Label realignment for subword aggregation

  Tentative to replace https://github.com/huggingface/transformers/pull/11622/files

  - Added `AggregationStrategy`
  - `ignore_subwords` and `grouped_entities` arguments are now fused into `aggregation_strategy`. It makes more sense anyway because `ignore_subwords=True` with `grouped_entities=False` did not have a meaning.
  - Added 2 new ways to aggregate: MAX and AVERAGE
  - AVERAGE requires a bit more information than the others; for now this case is slightly specific, we should keep that in mind for future changes.
  - Testing has been modified to reflect the new argument, and to check the correct deprecation and the new aggregation_strategy.
  - Put the testing arguments and testing results for aggregation_strategy close together, so that readers can understand what is supposed to happen.
  - `aggregate` is now only tested on a small model as it does not mean anything to test it globally for all models.
  - Previous tests are unchanged in desired output.
  - Added a new test case that showcases better the difference between the FIRST, MAX and AVERAGE strategies.

* Wrong framework.

* Addressing three issues.

  1. Tags might not follow the B-, I- convention, so any tag should work now (assumed as B-TAG)
  2. Fixed an issue with average that leads to a substantial code change.
  3. The testing suite was not checking for the "index" key for the "none" strategy. This is now fixed.

  The issue is that "O" could not be chosen by the AVERAGE strategy because those tokens were filtered out beforehand, so their relative scores were not counted in the average. Now filtering on ignore_labels happens at the very end of the pipeline, fixing that issue. It's a bit hard to make sure this stays like that because we do not have an end-to-end test for that behavior.

* Formatting.

* Adding formatting to code + cleaner handling of B-, I- tags.

Co-authored-by: Francesco Rubbo <rubbo.francesco@gmail.com>
Co-authored-by: elk-cloner <rezakakhki.rk@gmail.com>
What does this PR do?

Tentative to replace #11622

- Added `AggregationStrategy`.
- `ignore_subwords` and `grouped_entities` arguments are now fused into `aggregation_strategy`. It makes more sense anyway, because `ignore_subwords=True` with `grouped_entities=False` did not have a meaning.
- Added 2 new ways to aggregate: MAX and AVERAGE.
- AVERAGE requires a bit more information than the others; for now this case is slightly specific, and we should keep that in mind for future changes.
- Testing has been modified to reflect the new argument, and to check the correct deprecation and the new `aggregation_strategy`.
- Put the testing arguments and testing results for `aggregation_strategy` close together, so that readers can understand what is supposed to happen.
- `aggregate` is now only tested on a small model, as it does not mean anything to test it globally for all models.
- Previous tests are unchanged in desired output.
- Added a new test case that showcases better the difference between the FIRST, MAX and AVERAGE strategies.
Fixes #10263, #10763
See also #10568
Who can review?
@LysandreJik