[Contrib] Support fp16 input in cpu sort #8672
Conversation
LGTM. Just a nit.
@@ -25,15 +25,14 @@
 @tvm.testing.uses_gpu
 def test_sort():
-    def verify_sort(shape, axis, is_ascend, is_dyn=False):
+    def verify_sort(shape, axis, is_ascend, is_dyn=False, in_dtype="float32"):
We could also do a small refactoring for this file:
- Use `@pytest.mark.parametrize("in_dtype", ["float32", "float16"])` in each unit test.
- Let pytest collect tests in the main function:
  `if __name__ == "__main__": pytest.main([__file__])`
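A minimal sketch of what that refactoring could look like. This is illustrative, not the actual tvm test file: `np.sort` stands in for the relay sort op, and the test body only checks the NumPy reference, since the real TVM comparison isn't available here.

```python
import numpy as np
import pytest


# Parametrize dtype (and sort direction) instead of looping inside the
# test body, so pytest reports each combination as a separate test case.
@pytest.mark.parametrize("in_dtype", ["float32", "float16"])
@pytest.mark.parametrize("is_ascend", [True, False])
def test_sort(in_dtype, is_ascend):
    data = np.random.uniform(size=(5, 7)).astype(in_dtype)
    ref = np.sort(data, axis=-1)  # placeholder for the relay sort op
    if not is_ascend:
        ref = ref[..., ::-1]
    # A real test would compare TVM output against this NumPy reference;
    # here we only check that the dtype is preserved end to end.
    assert ref.dtype == np.dtype(in_dtype)


if __name__ == "__main__":
    pytest.main([__file__])
```

With this shape, selectively skipping the slower fp16 cases later (e.g. via a custom marker) is a one-line change to the parametrize list.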
Wouldn't that double the testing time? This file is fairly slow to test (more than 3 min according to https://ci.tlcpack.ai/job/tvm/job/main/1384/testReport/ctypes.tests.python.relay/test_op_level6/), and I don't think fp16 tests need to run as often as fp32.
Hmm that's a fair concern.
Thanks @masahi
A solution for the issue discussed in #8296 (comment)
With this change, PT Faster RCNN can be converted to fp16. Converting MaskRCNN requires support for ADT (tensor array) in the mixed precision pass.
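To illustrate what fp16 input support in the sort kernel buys the mixed precision pass, here is a hypothetical NumPy sketch (not the TVM implementation): without fp16 support, a model converted to half precision would have to cast sort inputs back to fp32 around the op; with it, the sort consumes fp16 directly.

```python
import numpy as np


def sort_fp16_via_cast(data):
    # Old workaround: upcast to fp32 for the sort, then downcast again.
    return np.sort(data.astype("float32")).astype("float16")


def sort_fp16_native(data):
    # With fp16 input support, the sort operates on half precision directly,
    # avoiding the two extra casts inserted around the op.
    return np.sort(data)


x = np.array([2.5, 0.5, 1.5], dtype="float16")
assert (sort_fp16_via_cast(x) == sort_fp16_native(x)).all()
```

Since sorting only compares values and never does arithmetic on them, the native fp16 path loses no precision relative to the cast-based workaround.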
@comaniac @AndrewZhaoLuo