Enabling TF on image-classification pipeline. #15030

Conversation
probs = model_outputs.logits.softmax(-1)[0]
scores, ids = probs.topk(top_k)

if self.framework == "pt":
Would it make more sense to make the post-processing framework-agnostic, similar to the text-classification pipeline?
def softmax(_outputs):
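For reference, a framework-agnostic helper in the style of the text-classification pipeline would be a numerically stable numpy softmax. Only the signature above is quoted from the source; the body below is a reconstructed sketch:

```python
import numpy as np

def softmax(_outputs):
    # Numerically stable softmax over the last axis; operates on plain
    # numpy arrays, so one code path can serve pt and tf outputs alike.
    maxes = np.max(_outputs, axis=-1, keepdims=True)
    shifted_exp = np.exp(_outputs - maxes)
    return shifted_exp / shifted_exp.sum(axis=-1, keepdims=True)
```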
There's no top_k with a nice API in numpy, so I would stay away from doing native -> np -> native, or from having the ugly topk version in np. Overall this looks to me like the least worst of the options. Do you know a clean np version of topk? (It would indeed make things cleaner.)
I think one can do a top_k using np.argpartition and/or np.argsort. Do you mean that there's no np.topk when you say there's no clean version?
I mean np.argsort/np.argpartition is not as clean as torch.topk. Same complexity: top_k can be done in O(n + k log(k)), whereas a full sort is O(n log(n)). With numpy, apparently

ind = np.argpartition(a, -top_k)[-top_k:]
indices = ind[np.argsort(a[ind])]
values = a[indices]

is also pretty unreadable.
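Packaged as a runnable helper (the name np_topk is hypothetical, not part of the PR), the snippet above becomes:

```python
import numpy as np

def np_topk(a, k):
    # Hypothetical helper wrapping the argpartition snippet above:
    # top-k values and indices of a 1-D array, in descending order.
    ind = np.argpartition(a, -k)[-k:]    # O(n): the k largest, unordered
    ind = ind[np.argsort(a[ind])[::-1]]  # O(k log k): sort just those k
    return a[ind], ind

values, ids = np_topk(np.array([0.1, 0.7, 0.05, 0.15]), 2)
# values -> [0.7, 0.15], ids -> [1, 3]
```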
I don't think having a single if framework == "tf" in the decoding is that bad.
That's true, but if we want to enable pipelines for jax, or re-use parts of them in optimum, I think we should try to have an aligned way of doing things, and not use framework-specific (pt, tf) processing for one pipeline but not for another.
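As an illustration of that alignment argument (a sketch, not the PR's actual code): if logits are converted to numpy first, a single post-processing path can serve pt, tf, and jax alike. The function name and the simple full-argsort top-k here are assumptions for illustration:

```python
import numpy as np

def postprocess(logits, top_k=5):
    # Hypothetical framework-agnostic decoding: callers convert pt/tf/jax
    # logits to a numpy array first, so no framework branch is needed.
    shifted = np.exp(logits - logits.max(-1, keepdims=True))  # stable softmax
    probs = (shifted / shifted.sum(-1, keepdims=True))[0]
    ids = np.argsort(-probs)[:top_k]  # simple O(n log n) top-k
    return [{"score": float(probs[i]), "label": int(i)} for i in ids]

results = postprocess(np.array([[2.0, 0.0, 1.0]]), top_k=2)
```

The trade-off relative to the framework-specific branch is readability versus using each backend's native, optimized topk.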
Ok then LGTM
Thanks for working on it @Narsil
What does this PR do?
Fixes # (issue)
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.
@philschmid
@LysandreJik