-
Notifications
You must be signed in to change notification settings - Fork 27.4k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
TensorFlow MobileViT #18555
TensorFlow MobileViT #18555
Conversation
The documentation is not available anymore as the PR was closed or merged. |
Co-authored-by: Amy <aeroberts4444@gmail.com>
Coo-authored-by: Yih <2521628+ydshieh@users.noreply.github.com>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you for the TF implementation of MobileViT 🙌 The TFLite demo is great, especially because it should be covered by our doctests 🚀
In addition to the comments throughout the code, I have the following notes:
- I will share the instructions to open the Hub PR when this PR is approved by all (everyone has permission to do it now 🎉 )
- The
training
argument is missing in the layer'scall
(and in places like thedropout
calls)
Didn't realize that re-requesting a review from @gante would result in removing @amyeroberts and @sgugger from the reviewer list. Please know that it was completely unintentional. |
@sayakpaul no worries :) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you for the changes 🔥
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM! 📱
patch_width, patch_height = self.patch_width, self.patch_height | ||
patch_area = tf.cast(patch_width * patch_height, "int32") | ||
|
||
batch_size, orig_height, orig_width, channels = shape_list(features) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
batch_size, orig_height, orig_width, channels = shape_list(features) | |
batch_size, orig_height, orig_width, channels = tf.shape(features) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Having it in one line leads to:
OperatorNotAllowedInGraphError: Iterating over a symbolic
tf.Tensor
is not allowed in Graph execution. Use Eager execution or decorate this function with @tf.function.
That's why I separated it.
def test_attention_outputs(self): | ||
pass | ||
|
||
@unittest.skip("Test was written for TF 1.x and isn't really relevant here") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If this is the case - should it even be in the test suite? cc @gante
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nope, me and @Rocketknight1 talked about it a few weeks ago. We should remove this test, it's heavy and the only new thing it tests is that we can build a functional TF model with the model class (which it's kinda obvious we can)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I assume it will be phased out in a separate PR from the main TF testing suite?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ofc, as a separate PR 👍 Leave it be as it is in this PR :)
Thanks for another great model addition @sayakpaul ! |
@sayakpaul assuming it is passing the slow tests, it is ready for the TF weights. The super complex instructions to do it are as follows:
|
Super simple (complex?) question: What is the format of |
The same as the model name on the hub, e.g. this model would be P.S.: I edited the comment above with a 3rd step :D |
This might be due to the |
Very well! If we need to deal with the inconsistencies between |
@gante WDYT? |
@sayakpaul regarding the PR, all good on my end, but we still need approval from @sgugger :D As for the |
On it, sir! |
Please take note of the changes in 32cfd30. Initially, when I tested TFLite conversion it didn't require any spec for SELECT operations but now they're failing with a specification for the SELECT ops. What is more surprising is that the TFLite interpreter is treating |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks a lot for your PR! Left a couple of nits then we can merge this.
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
(retriggered failing job, seems like a spurious failure) |
Yeah probably nothing related to the PR? |
The build doc job failure is not spurious. There seems to be a problem with an example bloc introduced by this PR. |
Let me see if removing comments from the example block does the trick. Because when the job wasn't failing the example block didn't have any comments. |
No, it didn't help :( Any suggestions to try out? |
You can use the following code to convert a MobileViT checkpoint (be it image classification or semantic segmentation) to generate a | ||
TensorFlow Lite model: | ||
|
||
```py |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
could it be because the example is indented?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks like it, for some reason. The failure seems to disappear locally when I remove it. In any case its place is probably closer to the TF models doc?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This was the culprit it seems :3
My bad :D read the failure bottom to top, so I didn't notice the |
* initial implementation. * add: working model till image classification. * add: initial implementation that passes intg tests. Co-authored-by: Amy <aeroberts4444@gmail.com> * chore: formatting. * add: tests (still breaking because of config mismatch). Coo-authored-by: Yih <2521628+ydshieh@users.noreply.github.com> * add: corrected tests and remaning changes. * fix code style and repo consistency. * address PR comments. * address Amy's comments. * chore: remove from_pt argument. * chore: add full-stop. * fix: TFLite model conversion in the doc. * Update src/transformers/models/mobilevit/modeling_tf_mobilevit.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Update src/transformers/models/mobilevit/modeling_tf_mobilevit.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Update src/transformers/models/mobilevit/modeling_tf_mobilevit.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Update src/transformers/models/mobilevit/modeling_tf_mobilevit.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Update src/transformers/models/mobilevit/modeling_tf_mobilevit.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * apply formatting. * chore: remove comments from the example block. * remove identation in the example. Co-authored-by: Amy <aeroberts4444@gmail.com> Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This PR implements the MobileViT model in TensorFlow.
Interesting points
TODOs
from_pt
wherever needed.@amyeroberts @gante @sgugger up for review!