-
Notifications
You must be signed in to change notification settings - Fork 7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
support of float dtypes for draw_segmentation_masks #8150
Conversation
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/vision/8150
Note: Links to docs will display an error until the docs builds have been completed. ✅ You can merge normally! (2 Unrelated Failures)As of commit 3dce586 with merge base c35d385 (): FLAKY - The following jobs failed but were likely due to flakiness present on trunk:
This comment was automatically generated by Dr. CI and updates every 15 minutes. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks a lot for submitting this PR @GsnMithra
This is promising, but allowing support for float will require a little bit more work than just converting the dtype using .to()
, since the convention for float images is that their valuae range is in [0, 1] instead of [0, 255].
Perhaps the easiest way to make progress here will be to focus on draw_segmentation_masks
first, and to add a small unit test in https://github.com/pytorch/vision/blob/main/test/test_utils.py, that checks for equality of result for a given image with different dtypes. Something roughly like this:
from torchvision.transforms.v2.functional import to_dtype
img_uint8 = torch.randing(0, 256, (3, 100, 100), dtype=torch.uint8)
img_float = to_dtype(img_uint8, torch.float32, scale=True)
out_uint8 = draw_segmentation_mask(img_uint8, ...)
out_float = draw_segmentation_mask(img_float, ...)
torch.testing.assert_close(out_uint8, to_dtype(out_float, torch.uint8, scale=True))
Does that make sense?
torchvision/utils.py
Outdated
elif image.dtype not in {torch.uint8, torch.float32}: | ||
raise ValueError(f"The image dtype must be uint8 or float32, got {image.dtype}") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Here and in the other places, let's just check for image.is_floating_point()
instead of checking specifically for float32
. This way we can also support float64, etc.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hey, thanks for replying, this will be added with the upcoming commits.
Co-authored-by: Nicolas Hug <contact@nicolas-hug.com>
Hey @NicolasHug Just dropping a note to mention that I've added some assertion checks for floating-point data types in the unit test method for draw_segmentation_mask() with the recent commit titled "test_draw_sementation_mask". I would love to get your thoughts on it. Thank you for your time. |
Thanks a lot @GsnMithra . It looks like there are a few things missing for now: the import of Also, I think it would be best not to modify the current tests, and instead just create a separate new test for the one suggested in #8150 (review). Thank you! |
Hey @NicolasHug I would like to sincerely apologize for the mistakes I made in my previous contributions. I am still in the learning process and appreciate your guidance. In the latest commit, I have included a new unit test called test_draw_segmentation_masks_dtypes, which has been implemented according to the suggestions provided earlier. I would greatly appreciate your thoughts and feedback on this addition. Once again, I apologize for any inconvenience caused and thank you for taking the time to review my work. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you for the follow-up @GsnMithra, and no problems for the first few hicups.
I can see the new test is passing now, good job! In order to move forward with this PR, I would suggest the following:
- Address the comments I made below (it should be pretty easy)
- Remove the changes that were made to
draw_bounding_boxes()
anddraw_keypoints()
, and only keep the changes and tests relating todraw_segmentation_masks()
. This way, we'll be able to merge this PR straight away, and then you can send 2 separate PR (one for boxes, one for keypoints), should you wish to.
Does that sound good to you?
test/test_utils.py
Outdated
import torchvision.transforms.functional as F | ||
from torchvision.transforms.v2.functional import to_dtype |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You'll need to re-order the imports, check out https://github.com/pytorch/vision/actions/runs/7214131414/job/19676695616?pr=8150, or run the pre-commit hooks locally (check our contributing instructions)
torchvision/utils.py
Outdated
if image.is_floating_point(): | ||
image = (image * 255).to(torch.uint8) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Let's use to_dtype
here instead, like you did in the test.
torchvision/utils.py
Outdated
if original_dtype in {torch.float16, torch.float32, torch.float64}: | ||
out = out.float() / 255.0 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Same: Let's use to_dtype
here instead, like you did in the test.
torchvision/utils.py
Outdated
@@ -315,7 +318,10 @@ def draw_segmentation_masks( | |||
img_to_draw[:, mask] = color[:, None] | |||
|
|||
out = image * (1 - alpha) + img_to_draw * alpha | |||
return out.to(out_dtype) | |||
if original_dtype in {torch.float16, torch.float32, torch.float64}: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Check for is_floating_point()
instead.
Co-authored-by: Nicolas Hug <contact@nicolas-hug.com>
Co-authored-by: Nicolas Hug <contact@nicolas-hug.com>
Co-authored-by: Nicolas Hug <contact@nicolas-hug.com>
Hello @NicolasHug I hope this message finds you well. I wanted to bring up a couple of points regarding the usage of Firstly, when importing Secondly, I've encountered a situation where, when a float dtype image is passed, converting it back to float using the following code results in mismatched elements for the unit test:
It seems to work only when dividing by 255.0:
I appreciate your insights and look forward to any further guidance or suggestions you may have. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks a lot for the follow-up @GsnMithra !
You are right about the need for /= 255
. It was because out
was a float32 image in [0, 255] when the input was a float image. So even doing dtype(scale=True)
would not work as expected, as no conversion was done (since the image was already float32).
I found a simpler way to handle all this by simply converting the color's dtype and scale, instead of converting the input image. I pushed the changes and also reverted some minor things from the other functions.
I'll merge this PR when green, thanks a ton for your work on this!
(Feel free to submit follow-up PRs for the other functions if you wish to. Just giving you a heads-up that I will only be able to review them starting next year in Jan, as I'll be on leaves from tomorrow.)
Thanks again @GsnMithra !
Hey @NicolasHug Wishing you a wonderful break ;) |
Hey @NicolasHug! You merged this PR, but no labels were added. The list of valid labels is available at https://github.com/pytorch/vision/blob/main/.github/process_commit.py |
Summary: Co-authored-by: Nicolas Hug <contact@nicolas-hug.com> Reviewed By: vmoens Differential Revision: D52539003 fbshipit-source-id: e7b9412a496e88749dc6e9c5afdd1b5cf85b4aa0
Fixes: #8138
Hey there!
I've implemented support for draw_* methods to seamlessly handle both uint8 and float32 image types. The processed image will now be returned with the same data type as the input. Your feedback is always appreciated.
Thank you!