-
Notifications
You must be signed in to change notification settings - Fork 27.4k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
infer entailment label id on zero shot pipeline #8059
infer entailment label id on zero shot pipeline #8059
Conversation
Wouldn't it be better to:
I feel this PR is adding a feature that works around an issue when we should be fixing the root issue (also fixes the displayed labels, other future features, etc.). Wdyt? |
@julien-c Well, to be clear this PR
If I understand correctly, your issue is just with (2). I think it's a fair point. I don't think we'll be able to ensure that all NLI models have a clearly defined label mapping though. But instead of an override arg, I think it might be better to just add a warning if the entailment label ID can't be found in the config. |
@joeddav Ok yes 1/ is great.
Why not? |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is great! Do you mind running make fixup
to patch the style/quality issues?
@julien-c Just because there are almost 100 results for "NLI" on the model hub and I'd guess from a quick sampling that the majority don't have a label mapping defined. For each model we'd have to figure out which label is which, which would mean either getting the author to look it up and tell us or else running tests on the correct dataset to figure it ourselves. Do you think it'd be worthwhile to warn the user when uploading or creating configs with generic/missing label mappings (or with any other important fields missing) going forward? Defining a label2id seems like a rather obscure property that I would assume is purely cosmetic if I were uploading a model, i.e. I wouldn't expect it to actually impact code behavior for someone using my model. |
* add entailment dim argument * rename dim -> id * fix last name change, style * rm arg, auto-infer only * typo * rm superfluous import
What does this PR do?
Adds an optional argument to the zero shot pipeline constructor to specify the label id of the NLI model that corresponds to "entailment", which it needs to calculate each candidate label's score. Most models in the hub use the last last label id, but some differ (e.g. the recent ynie/roberta-large-snli_mnli_fever_anli_R1_R2_R3-nli).
If the argument is not passed, the pipeline will attempt to look up the entailment dimension in the model config's id2label mapping. If the config does not specify the entailment dimension, the value will be set to
-1
, indicating the last dimension of the model output.With this logic in place, the arg only needs to be passed when both (1) the model's entailment label id is not the last id and (2) when the model config's
label2id
doesn't specify the entailment id.