[Bug] operator.gt got an object with type <class 'int'> #320
Hi @GisellWu, do you have a minimal reproducible example of the error?
In addition, I added some functions to register_functions.py so that T5Model could run through:
Hi @GisellWu, I added the missing operators and fixed some bugs for the T5 model in #322. Could you give it another try?
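For readers unfamiliar with how "adding missing operators" works: a frontend like this one keeps a registry that maps each Python callable seen in a traced graph (here `operator.gt`) to a handler, and an error like the one in the title appears when a handler does not accept one of the operand types. The sketch below is a minimal, self-contained illustration of that idea only. `register_function`, `FakeTensor`, and `interpret` are hypothetical names, not hidet's actual API, and the element-wise comparison is a toy stand-in for a real tensor op.

```python
import operator
from numbers import Number
from typing import Any, Callable, Dict, Iterable, List

# Hypothetical registry mapping a Python callable that appears in a traced graph
# (e.g. operator.gt) to the handler that converts it. Not hidet's real API.
_REGISTRY: Dict[Callable, Callable] = {}


def register_function(target: Callable) -> Callable:
    """Decorator that records `fn` as the handler for `target`."""
    def decorator(fn: Callable) -> Callable:
        _REGISTRY[target] = fn
        return fn
    return decorator


class FakeTensor:
    """Hypothetical stand-in for a tensor: just a flat list of values."""
    def __init__(self, data: Iterable[Any]):
        self.data: List[Any] = list(data)


def _values(x: Any, n: int) -> List[Any]:
    """Return x as a list of n values, broadcasting plain Python numbers."""
    if isinstance(x, FakeTensor):
        if len(x.data) != n:
            raise ValueError("operands have different lengths")
        return x.data
    if isinstance(x, Number):
        return [x] * n
    raise TypeError(f"operator.gt got an object with type {type(x)}")


@register_function(operator.gt)
def gt(a: Any, b: Any) -> Any:
    # Both operands are plain numbers (e.g. ints from shape arithmetic):
    # fall back to ordinary Python comparison instead of rejecting `int`.
    if isinstance(a, Number) and isinstance(b, Number):
        return a > b
    tensors = [x for x in (a, b) if isinstance(x, FakeTensor)]
    if not tensors:
        raise TypeError(f"operator.gt got objects with types {type(a)} and {type(b)}")
    n = len(tensors[0].data)
    # Element-wise comparison; a scalar operand is broadcast to every element.
    return FakeTensor(x > y for x, y in zip(_values(a, n), _values(b, n)))


def interpret(target: Callable, *args: Any) -> Any:
    """Dispatch a graph node's target to its registered handler."""
    handler = _REGISTRY.get(target)
    if handler is None:
        raise NotImplementedError(f"no handler registered for {target}")
    return handler(*args)


print(interpret(operator.gt, 3, 2))                           # True
print(interpret(operator.gt, FakeTensor([1, 5, 3]), 2).data)  # [False, True, True]
```

The design choice worth noting is that the handler accepts plain Python numbers explicitly rather than assuming every operand is a tensor, which is the kind of case an `int` operand exposes.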
Thanks for your help! It ran successfully. Closing the issue :)
Hi @yaoyaoding, sorry to bother you again. I'm trying T5Model with float16 and have run into some new unsupported functions. Could you please help me fix them? The error is:
And the example code is:
Hi @GisellWu, I added the missing operators in #343. Could you give it a try? Thanks!
Hi @yaoyaoding, that's OK. Thanks again for your help!
…immutable_list` (#320)

As discussed on our Slack channel, compiling the model `dalle2_pytorch` triggered a Hidet exception in `expand`:

> RuntimeError: 'immutable_list' object does not support mutation. If you are attempting to modify the kwargs or args of a torch.fx.Node object,
> instead create a new copy of it and assign the copy to the node:
> new_args = ... # copy and mutate args
> node.args = new_args
> , occurred when interpreting expand with
> tensor_expand(tensor(...), [-1, 2, -1])

This PR fixes this bug.
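The error message above suggests the general remedy: copy the sizes taken from `node.args` before resolving them, instead of writing into the original list. The sketch below illustrates that pattern for `expand`'s `-1` entries (keep the existing dimension size). It is an assumption-laden illustration, not the code from the PR: `ImmutableList` is a hypothetical stand-in for torch.fx's `immutable_list`, and `resolve_expand_sizes` is a hypothetical helper that follows PyTorch's documented `expand` rules (new dimensions are prepended on the left, and only size-1 dimensions can grow).

```python
from typing import List, Sequence


class ImmutableList(list):
    """Hypothetical stand-in for torch.fx's immutable_list: every mutation raises."""
    def _no_mutation(self, *args, **kwargs):
        raise RuntimeError("'immutable_list' object does not support mutation.")

    __setitem__ = _no_mutation
    __delitem__ = _no_mutation
    __iadd__ = _no_mutation
    append = _no_mutation
    extend = _no_mutation
    insert = _no_mutation
    pop = _no_mutation


def resolve_expand_sizes(input_shape: Sequence[int], sizes: Sequence[int]) -> List[int]:
    """Replace -1 entries in `sizes` with the matching input dimension.

    Works on a copy, so `sizes` may be an immutable sequence taken from node.args.
    """
    if len(sizes) < len(input_shape):
        raise ValueError("expand: fewer sizes than tensor dimensions")
    resolved = list(sizes)  # copy first: never mutate the caller's sequence
    offset = len(sizes) - len(input_shape)  # new dims are prepended on the left
    for i, size in enumerate(sizes):
        if i < offset:
            if size == -1:
                raise ValueError("expand: -1 is not allowed for a new leading dimension")
            continue
        dim = input_shape[i - offset]
        if size == -1:
            resolved[i] = dim  # -1 keeps the existing size
        elif dim != 1 and dim != size:
            raise ValueError(f"expand: cannot expand dimension of size {dim} to {size}")
    return resolved


sizes = ImmutableList([-1, 2, -1])
print(resolve_expand_sizes((4, 1, 8), sizes))  # [4, 2, 8]
print(sizes)                                   # [-1, 2, -1], unchanged
```

Writing `sizes[0] = 4` directly would raise the same kind of `RuntimeError` quoted above; working on `list(sizes)` avoids it while leaving the graph node's arguments untouched.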
Hi @yaoyaoding, I encountered a bug when running T5Model with hidet:
Could you please help me fix it?