[Feature] No torch.sqrt support in Hidet ? #386

Closed
xxzh12 opened this issue Dec 5, 2023 · 2 comments
Labels
enhancement New feature or request

Comments

@xxzh12

xxzh12 commented Dec 5, 2023

I'm trying to optimize a SelfAttention module, but there is no support for the torch.sqrt function. The code is as follows:

import torch
import hidet

hidet.option.cache_dir('./outs/cache')

# SelfAttention is my own module (definition not shown here).
model = SelfAttention(num_attention_heads=12, input_size=768, hidden_size=768,
                      attention_probs_dropout_prob=0.5, hidden_dropout_prob=0.5).cuda().eval()
x = torch.rand(1, 128, 768).cuda()
model_opt = torch.compile(model, backend='hidet')
y = model_opt(x)

where I use

x = (x - u) / torch.sqrt(s + self.variance_epsilon)

in the LayerNorm module.
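For context, a minimal sketch of a BERT-style LayerNorm containing that line (the class name, shapes, and epsilon value are illustrative, not taken from the original module):

import torch
import torch.nn as nn

class LayerNorm(nn.Module):
    # Hand-rolled LayerNorm; u and s are the per-feature mean and variance.
    def __init__(self, hidden_size, variance_epsilon=1e-12):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.bias = nn.Parameter(torch.zeros(hidden_size))
        self.variance_epsilon = variance_epsilon

    def forward(self, x):
        u = x.mean(-1, keepdim=True)
        s = (x - u).pow(2).mean(-1, keepdim=True)
        # This torch.sqrt call is what Hidet's torch frontend fails to map.
        x = (x - u) / torch.sqrt(s + self.variance_epsilon)
        return self.weight * x + self.bias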
The error information is as follows:

The following modules/functions are not supported by hidet yet: torch.sqrt

I'm wondering if there is any way to support the torch.sqrt function. I noticed that there is a relevant abstraction in the IR for the sqrt function; however, the sqrt function in hidet\python\hidet\ir\primitives\math.py just raises NotImplementedError().

@xxzh12 xxzh12 added the enhancement New feature or request label Dec 5, 2023
@xxzh12 xxzh12 changed the title from "[Feature] No" to "[Feature] No torch.sqrt support in Hidet ?" Dec 5, 2023
@yaoyaoding
Member

Hi @xxzh12,

#387 adds the operator mapping for torch.sqrt. I do not have the definition of SelfAttention, so I could not test your exact use case. Feel free to open another issue if other operators are not mapped. Thanks.
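For reference, operator mappings in Hidet's torch frontend follow a register-and-delegate pattern. A rough sketch of what such a mapping looks like (module paths are from memory and may differ between Hidet versions; see #387 for the actual change):

import torch
from hidet.graph import ops
from hidet.graph.tensor import Tensor
from hidet.graph.frontend.torch.registry import register_function

@register_function(torch.sqrt)
def torch_sqrt(x: Tensor) -> Tensor:
    # Delegate torch.sqrt to Hidet's elementwise sqrt operator.
    return ops.sqrt(x)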

@xxzh12
Author

xxzh12 commented Dec 8, 2023

Hi @yaoyaoding,

Thanks for your kind reply! I will give it a try.

@xxzh12 xxzh12 closed this as completed Dec 8, 2023
vadiklyutiy pushed a commit that referenced this issue Dec 19, 2024
In the gpt-neo model (related issue:
CentML/hidet#338), torch.where accepts tensors
with different dtypes. Added type casting to fix the above issue.

---------

Co-authored-by: Zhumakhan <nazirzhumakhan@gmail.com>
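For illustration, a minimal plain-PyTorch sketch of the pattern that commit describes, where torch.where receives operands of different dtypes and an explicit cast aligns them (names and values are made up):

import torch

cond = torch.rand(4) > 0.5
scores = torch.rand(4, dtype=torch.float32)
mask_value = torch.tensor(-1e9, dtype=torch.float64)

# torch.where itself promotes mixed dtypes, but a frontend that requires
# matching dtypes needs an explicit cast like this one.
out = torch.where(cond, scores, mask_value.to(scores.dtype))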