-
Notifications
You must be signed in to change notification settings - Fork 328
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fix dtype support for SegmentAnythingModel #2207
Fix dtype support for SegmentAnythingModel #2207
Conversation
Thanks for the PR! LGTM! |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
lgtm!
def test_end_to_end_model_predict(self, dtype_policy): | ||
import threading | ||
|
||
with threading.Lock(): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
what's with this? are we running our cv testing multi-processed ever?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It can be multi-processed with the -n <num_threads>
argumment in pytest
. PyTest uses multi-processing and not multi-threading so locking should not be necessary here. I just added this as a safeguard if anyone ever tries to run these tests using Python threads.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Long term, we could move towards Model(dtype=policy)
support, so that these tests can run effectively without mutating global state.
# Check the number of parameters | ||
num_parameters = np.sum([np.prod(x.shape) for x in model.weights]) | ||
self.assertEqual(num_parameters, 89_670_912 + 6_476 + 4_058_340) | ||
@parameterized.named_parameters( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
why is this test marked as large? (just for my own learning)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The model initialized here is a ViT Base model with 130M parameters. Creating and evaluating it takes about 15-20 seconds which is significantly more than small unit tests in KerasCV.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Gotcha, thanks! No need on this PR, but in general it will be good to separate small checks (like dtype stuff) into fast running tests, and keep the large test only for the things that must inherently by large parameter count and slow (like preset tests).
Did a big rewrite of KerasNLP backbones to this effect a bit ago. e.g. https://github.com/keras-team/keras-nlp/blob/a05f411a27eab437e71a1651c97e9addf26298ef/keras_nlp/models/bert/bert_backbone_test.py#L38-L80
@divyashreepathihalli This is ready from my side. Feel free to merge if everything looks good to you now! |
* Fix dtype support for SAM * Update keras_cv/models/segmentation/segment_anything/sam_test.py * Fix Keras 2 failures * Fix F401 lint error; remove unused import
* Fix Keras 3 version check (#2191) * Fix Keras 3 version check * Fix Keras 3 version check * Fix Keras 3 version check * Raise error if Keras is not compatible with TF * Fix bug when upranking passthrough inputs to RandAugment (#2194) - RandAugment sometimes will choose a "no augmentation" option and passthrough inputs unaltered. - Preprocessing normalization routines were not making copies of inputs and sometimes mutating layer input directly (mutating the input dict to cast dtypes and uprank tensors). - RandAugment under the passthrough option would return these inputs directly. The net effect was sometimes attempting to uprank during a passthrough call, breaking tf.map_fn * fix stable diffusion rank error (#2208) * Simplify running KerasCV with Keras 3 (#2179) * remove keras_core dependency * update init * update readme * fix model None error (#2176) (#2177) * Update pycoco_callback.py * Update waymo_evaluation_callback.py * fix model None error (#2176) (#2178) * Update pycoco_callback.py * Update waymo_evaluation_callback.py * update readme and conftest * update readme * update citation list * fix mix transformer tests * fix lint error * fix all failing tests * Fix dtype support for SegmentAnythingModel (#2207) * Fix dtype support for SAM * Update keras_cv/models/segmentation/segment_anything/sam_test.py * Fix Keras 2 failures * Fix F401 lint error; remove unused import * Version bump to r0.7.2.dev0 --------- Co-authored-by: Matt Watson <1389937+mattdangerw@users.noreply.github.com> Co-authored-by: Divyashree Sreepathihalli <divyashreepathihalli@gmail.com> Co-authored-by: Tirth Patel <tirthasheshpatel@gmail.com>
* Fix dtype support for SAM * Update keras_cv/models/segmentation/segment_anything/sam_test.py * Fix Keras 2 failures * Fix F401 lint error; remove unused import
What does this PR do?
Segment Anything model can use
keras.mixed_precision.set_dype_policy
for quick optimizations. This PR fixed the model so that it can be run with any dtype policy set. Also added a test forfloat32
(default in keras),mixed_float16
, andbfloat16
.Before submitting
Pull Request section?
to it if that's the case.
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.