Will hidet launch all cuda kernels on the same cudaStream? #388

Closed
VincentXWD opened this issue Dec 7, 2023 · 2 comments
Comments

@VincentXWD

Hi, I noticed that the generated CUDA kernels have launch functions like this:

void launch(float * __restrict__ b, float * __restrict__ y, float * __restrict__ data, float * __restrict__ c) {
  batch_matmul_kernel<<<dim3(196, 1, 1), dim3(64, 1, 1), 0, (cudaStream_t)get_cuda_stream()>>>(b, data, y, c);
  {cudaError_t err = cudaGetLastError(); if (err != cudaSuccess) LOG(ERROR) << "CUDA error: " << cudaGetErrorString(err) << "\n";}
}

The cudaStream parameter is set to (cudaStream_t)get_cuda_stream(), which is returned by a singleton, CudaContext::global(). Is this same stream used for all kernels, with no way to change or arrange it?

Thanks.

@yaoyaoding
Member

Hi @VincentXWD,

Yes, it is used for all CUDA kernels. Hidet comes with two libraries: libhidet.so and libhidet_runtime.so (see also here). The former implements the compilation part (if we find anything that is inefficient to implement in Python, we can implement it in C++). The latter implements the runtime that supports Hidet's compiled models (e.g., it tracks the current CUDA stream that the next CUDA kernel should be launched on: get_cuda_stream()).

We can use this API to change the current CUDA stream on which the kernels are launched.
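To illustrate the pattern described above, here is a minimal sketch of a process-wide "current stream" that every launch function reads at call time. It is not Hidet's actual runtime code: `RuntimeContext`, `StreamHandle`, `current_stream()`, `set_current_stream()`, `launch_log()` and this `launch()` are hypothetical names, and an integer stands in for `cudaStream_t` so the sketch builds without the CUDA toolkit.

```cpp
#include <cstdint>
#include <vector>

// Hypothetical stand-in for cudaStream_t; 0 plays the role of the default stream.
using StreamHandle = std::uintptr_t;

// Hypothetical singleton that mirrors the idea of CudaContext::global().
// This is an illustration only, not Hidet's implementation.
class RuntimeContext {
public:
    static RuntimeContext& global() {
        static RuntimeContext ctx;  // single instance shared by the whole process
        return ctx;
    }
    StreamHandle stream = 0;  // the stream the next launch will use

private:
    RuntimeContext() = default;
};

// Hypothetical getter, analogous in spirit to get_cuda_stream() in the generated code.
StreamHandle current_stream() { return RuntimeContext::global().stream; }

// Hypothetical setter: the user decides which stream later launches use.
void set_current_stream(StreamHandle s) { RuntimeContext::global().stream = s; }

// Records the stream used by each simulated launch, in call order.
std::vector<StreamHandle>& launch_log() {
    static std::vector<StreamHandle> log;
    return log;
}

// Stand-in for a generated launch(): it does not take a stream argument,
// it looks up the current stream each time it is called.
void launch() { launch_log().push_back(current_stream()); }
```

In this sketch, calling `set_current_stream(7)` and then `launch()` records stream 7, and every later `launch()` keeps using 7 until the stream is changed again. That is the sense in which all kernels share one stream while the choice of stream stays with the user.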

@VincentXWD
Author

I see. So the stream is determined by the user. Thanks, @yaoyaoding.

vadiklyutiy pushed a commit that referenced this issue Dec 19, 2024
Cast return value of `get_parallel_num_workers` to float. Fix #388
vadiklyutiy pushed a commit that referenced this issue Dec 20, 2024
Cast return value of `get_parallel_num_workers` to float. Fix #388
vadiklyutiy pushed a commit that referenced this issue Dec 26, 2024
Cast return value of `get_parallel_num_workers` to float. Fix #388
2 participants