-
Notifications
You must be signed in to change notification settings - Fork 3.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add is_floating_point()
test and better type support in verify_model_vm()
#7134
Conversation
verify_model_vm(script_module, ishapes, idtype=idtype, targets=targets) | ||
else: | ||
verify_model_vm(script_module, ishapes, targets=targets) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't think you need this if/else, just verify_model_vm(script_module, ishapes, idtype=idtype, targets=targets)
even if idtype
is None should work.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done. I've cleaned up verify_script_model()
and changed the handling of default arguments in verify_model_vm()
so that passing in None
will result in a torch.float
being used. Let me know if the new approach looks good.
Thanks @TylerADavis |
…el_vm()` (apache#7134) * Add div_ and is_floating_point operators * Add handling of exprs to op, update tests * add test + supporting functions * Revert whitespace changes * Properly assign dtype to random integers * Reformat with black * Switched default dtype logic, removed extra line
…el_vm()` (apache#7134) * Add div_ and is_floating_point operators * Add handling of exprs to op, update tests * add test + supporting functions * Revert whitespace changes * Properly assign dtype to random integers * Reformat with black * Switched default dtype logic, removed extra line
…el_vm()` (apache#7134) * Add div_ and is_floating_point operators * Add handling of exprs to op, update tests * add test + supporting functions * Revert whitespace changes * Properly assign dtype to random integers * Reformat with black * Switched default dtype logic, removed extra line
…el_vm()` (apache#7134) * Add div_ and is_floating_point operators * Add handling of exprs to op, update tests * add test + supporting functions * Revert whitespace changes * Properly assign dtype to random integers * Reformat with black * Switched default dtype logic, removed extra line
…el_vm()` (apache#7134) * Add div_ and is_floating_point operators * Add handling of exprs to op, update tests * add test + supporting functions * Revert whitespace changes * Properly assign dtype to random integers * Reformat with black * Switched default dtype logic, removed extra line
This PR builds upon #7128 , adding a test for
is_floating_point()
following masahi's recommendation to look atverify_script_model()
.In addition to the test, this PR makes the following changes:
verify_script_model()
can now pass a dtype toverify_model_vm()
verify_model_vm()
can now generate random inputs for bool and int dtypes.verify_model_vm()
includes TVM dtype information ininput_shapes
, providing additional type information to operators such asis_floating_point()
.These additional changes were made because
verify_model_vm()
does not currently pass type information torelay.frontend.from_pytorch()
, preventingis_floating_point()
from working correctly. The changes to random tensor generation were required astorch.randn()
does not support bool or int dtypes, and my test requires inputs with these dtypes.If it would be helpful, I can split the addition of the test and the changes to testing infrastructure out into two separate PRs.