[Bug] Some hidet tensor methods do not support symbolic tensors? #213
Comments
Hi @eric8607242, thanks for bringing this up. We have partially fixed this issue in #214. With this PR, we can run your example:

```python
import torch
from torch import nn
import hidet


class TestMode(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Linear(10, 10)

    def forward(self, x):
        z = x.unsqueeze(0).expand(4, 4, 512).to(torch.device("cuda"))
        return z


if __name__ == "__main__":
    model = TestMode()
    model = model.eval().half()
    device = torch.device("cpu")
    model = model.to(device)

    hidet.torch.dynamo_config.search_space(2)
    hidet.torch.dynamo_config.use_fp16()
    model_opt = torch.compile(model, backend='hidet')

    tokens = torch.zeros(4, 512).cuda()
    model_opt(tokens)
```

The limitation is: for the tensor that is dependent on the model input (e.g., ...). See the tests for more examples of what is supported and what is not.
Hi @yaoyaoding, thanks for your kind response and quick fix. It is very helpful. Sorry for two more silly questions.
Hi @eric8607242, yes, it is as you said, and you asked a good question. This is a temporary limitation of our current IR and runtime system. The direct reason is that we do not have an operator like "to_device". We currently do not have a C++ runtime, but rely on CUDA graphs to get rid of the framework-level overhead. It is not trivial to track both CPU kernels and GPU kernels in the same CUDA graph. So, before we have an efficient C++ runtime, we will not support mixing CPU and GPU kernels in a single computation graph. Of course, if there are some important DNNs that rely on this feature, we would like to give it a higher priority. Currently, we are focusing on dynamic shape support.
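To make this concrete, here is a minimal sketch of the usual way around this constraint (the module and variable names below are made up for illustration): move the weights and the inputs to the GPU before calling the compiled model, so that no `.to(...)`, `.cuda()`, or `.cpu()` is needed inside `forward` and the whole traced graph stays on one device.

```python
import torch
from torch import nn
import hidet  # imported so the 'hidet' backend is registered with torch.compile


class TinyModel(nn.Module):  # hypothetical stand-in model
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(512, 512)

    def forward(self, x):
        # No .to(...)/.cuda()/.cpu() here: the traced graph stays on one
        # device, so no CPU/GPU kernel mixing is required.
        return self.fc(x)


if __name__ == "__main__":
    device = torch.device("cuda")
    model = TinyModel().eval().half().to(device)      # move the weights to the GPU first
    model_opt = torch.compile(model, backend="hidet")
    tokens = torch.zeros(4, 512, dtype=torch.float16, device=device)  # inputs already on the GPU
    out = model_opt(tokens)
    print(out.shape)  # torch.Size([4, 512])
```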
Hi @yaoyaoding, thanks for the very clear answer, and thanks again for this amazing work.
Fixing all issues mentioned in CentML/hidet#212. [pull](CentML/hidet#209) should be merged before this pull in order to avoid issues with GPU memory. Now Mistral 7B can be compiled with hidet, and it takes around 23 GiB of GPU memory.

### Copy-pasting the issue description here for more clarity:

**Describe the bug**

The Mistral 7B model crashed due to missing `torch.all` (if transformers==4.37) and `torch.ones_like` (if transformers==4.41). Added those ops.

Hidet's reduction kernel for `torch.all` with a boolean tensor generates a kernel with statements:

```
rv = (rv && shfl_down_sync(mask, rv, 16, 32))
```

However, ***nvcc*** optimizes that into something like:

```
if rv is true:
    rv = shfl_down_sync(mask, rv, 16, 32)
```

This ***nvcc*** optimization causes some threads to jump over the `shfl_down_sync` instruction while other threads of the same warp call `shfl_down_sync`, so the GPU kernel hangs. Fixed by swapping the order as below:

```
rv = (shfl_down_sync(mask, rv, 16, 32) && rv)
```

**To Reproduce**
This script is used to test the Mistral model: https://drive.google.com/file/d/11ovSzoiHGG2f_qWucwoRxhRjCRRuCH72/view?usp=sharing

**Expected behavior**
The Mistral 7B model compiles with hidet.

**Environment**
- OS: Ubuntu 22.04
- GPU: RTX 3090 with 24 GiB memory

---------

Co-authored-by: Zhumakhan <nazirzhumakhan@gmail.com>
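As a companion to the linked reproduction script, here is a minimal hypothetical sketch (not the script from the PR; the module name is made up) that exercises the two newly added ops through the hidet backend, assuming the fixes in this PR are applied:

```python
import torch
from torch import nn
import hidet  # registers the 'hidet' backend for torch.compile


class AllReduceCheck(nn.Module):
    def forward(self, mask: torch.Tensor) -> torch.Tensor:
        # torch.ones_like and torch.all are the two ops this PR registers;
        # torch.all over a boolean tensor lowers to the warp reduction whose
        # shfl_down_sync ordering is fixed above.
        ones = torch.ones_like(mask)
        return torch.all(mask == ones)


if __name__ == "__main__":
    model = AllReduceCheck().eval().cuda()
    model_opt = torch.compile(model, backend="hidet")
    mask = torch.ones(4, 512, dtype=torch.bool, device="cuda")
    print(model_opt(mask))  # expected to print a scalar boolean tensor (True here)
```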
Hi, thanks for the great work!

I am wondering why some hidet tensor methods (e.g., `to`, `cuda`, and `cpu`) do not support symbolic tensors. In the above test case, the exception

```
NotImplementedError: hidet: Tensor.to(..., device=...) is not supported for symbolic tensors., occurred when calling tensor_to(Tensor(shape=(4, 4, 512), dtype='bool', device='cuda:0'), device(type='cuda'))
```

is raised. I think `.to(device)` is a common operation for deep learning models, as in the implementation of HuggingFace Llama (a sketch of this pattern is shown below). Are there any concerns or limitations regarding these operations for symbolic tracing?
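For illustration, a minimal sketch of this pattern (the module, names, and shapes here are made up, not taken from the Llama source): a mask is built inside `forward` and then moved to the input's device with `.to(...)`, which is the kind of call that can raise the error quoted above when traced symbolically.

```python
import torch
from torch import nn


class MaskedLinear(nn.Module):
    """Hypothetical HF-style module: a mask is created in forward and moved to the input's device."""

    def __init__(self, dim: int = 512):
        super().__init__()
        self.fc = nn.Linear(dim, dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # The mask is created on CPU and then moved with .to(...); under
        # symbolic tracing this device move is the problematic call.
        mask = torch.ones(x.shape, dtype=torch.bool)
        mask = mask.to(x.device)
        return self.fc(x) * mask


if __name__ == "__main__":
    m = MaskedLinear()
    print(m(torch.zeros(4, 512)).shape)  # torch.Size([4, 512])
```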
Looking forward to your response. Thanks!