[add] support for half (GPU only) #6
base: espnet_v1.1
Conversation
[fix] pytorch binding case torch::ScalarType::Half
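As a rough illustration of what that commit title refers to, a dtype dispatch with a new `torch::ScalarType::Half` case might look like the sketch below; `compute_rnnt_loss`, `dispatch_rnnt_loss`, and their signatures are assumptions, not the repository's actual binding code:

```cpp
#include <torch/extension.h>
#include <cuda_fp16.h>
#include <stdexcept>

// Assumed templated entry point; the real binding's kernel launcher is
// elided in this sketch.
template <typename T>
void compute_rnnt_loss(torch::Tensor& acts) { /* kernel launch elided */ }

void dispatch_rnnt_loss(torch::Tensor& acts) {
    switch (acts.scalar_type()) {
        case torch::ScalarType::Float:
            compute_rnnt_loss<float>(acts);
            break;
        case torch::ScalarType::Half:  // the case this commit adds
            compute_rnnt_loss<__half>(acts);
            break;
        default:
            throw std::runtime_error("warp_rnnt: unsupported dtype");
    }
}
```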
Thanks for taking on that work!
As we talked off-list, the following changes are needed:
- torch bindings are WIP/borked
- remove `__host__` from the `HOSTDEVICE` macro and re-use it
- subsequently, `log_sum_exp` from `rnnt_helper` can be re-used from that (see the sketch after this comment)

Also, a more general note: this PR drops support for the Kepler and Maxwell archs. I'm not against it, but I have to confirm it won't affect too many people (unlikely though).
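For reference, a minimal sketch of the suggested direction, modeled on warp-transducer's `rnnt_helper.h`; the `__half` specialization and the exact intrinsics chosen here are assumptions, not the PR's actual code:

```cuda
#include <cuda_fp16.h>
#include <math.h>

// With __host__ removed, the macro is device-only, which lets the helpers
// below call GPU-only CUDA math library functions such as hexp()/hlog().
#define HOSTDEVICE __device__

template <typename T>
inline HOSTDEVICE T log_sum_exp(T a, T b) {
    if (a == -INFINITY) return b;
    if (b == -INFINITY) return a;
    if (a > b)
        return log1p(exp(b - a)) + a;
    else
        return log1p(exp(a - b)) + b;
}

// Hypothetical __half specialization re-using the same pattern through the
// CUDA math library's half intrinsics.
template <>
inline HOSTDEVICE __half log_sum_exp(__half a, __half b) {
    const __half neg_inf = __float2half(-INFINITY);
    if (__heq(a, neg_inf)) return b;
    if (__heq(b, neg_inf)) return a;
    const bool a_gt_b = __hgt(a, b);
    const __half hi = a_gt_b ? a : b;
    const __half lo = a_gt_b ? b : a;
    const __half one = __float2half(1.0f);
    return __hadd(hi, hlog(__hadd(one, hexp(__hsub(lo, hi)))));
}
```

The `__h*` intrinsics and `hexp()`/`hlog()` require compute capability 5.3 or newer, which lines up with the note above about dropping the Kepler and Maxwell archs.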
Thank you for your feedback.
Sorry for the delay! Compilation is OK, but import fails with the following error:
ImportError: /home/b-flo/warp-transducer/pytorch_binding/warprnnt_pytorch/warp_rnnt.cpython-310-x86_64-linux-gnu.so: undefined symbol: _ZNK2at10TensorBase8data_ptrI6__halfEEPT_vI
I'll let you figure it out, but feel free to ask off-list for any help/hints on resolving this error!
P.S.: Postponing the full review until it's in a stable state.
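For context, the undefined symbol demangles to `__half* at::TensorBase::data_ptr<__half>() const`: libtorch instantiates `data_ptr<T>()` only for its own scalar types such as `at::Half`, not for CUDA's `__half`. A minimal sketch of the usual workaround, not necessarily the fix this PR ended up with (the `half_data_ptr` helper name is hypothetical):

```cpp
#include <torch/extension.h>
#include <cuda_fp16.h>

// at::Half and __half are both 16-bit IEEE half types with identical layout,
// so fetching the pointer through the type libtorch actually instantiates
// and reinterpreting it avoids the missing symbol.
static __half* half_data_ptr(torch::Tensor& t) {
    return reinterpret_cast<__half*>(t.data_ptr<at::Half>());
}
```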
CMakeLists.txt (outdated diff)

```diff
+# Drop support for old GPU to use CUDA math library functions
+IF(NOT (CUDA_VERSION GREATER 10.2))
-    set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS} -gencode arch=compute_30,code=sm_30 -O2")
-    set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS} -gencode arch=compute_35,code=sm_35")
-    set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS} -gencode arch=compute_50,code=sm_50")
+    # set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS} -gencode arch=compute_30,code=sm_30 -O2")
+    # set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS} -gencode arch=compute_35,code=sm_35")
+    # set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS} -gencode arch=compute_50,code=sm_50")
+ENDIF()

-set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS} -gencode arch=compute_52,code=sm_52")
+#set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS} -gencode arch=compute_52,code=sm_52")
```
By the way, you can remove the condition and the commented-out lines (see the sketch below).
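Applying that suggestion would leave something like the following, assuming the remaining `-gencode` targets are half-capable archs; the specific sm_60/sm_70 lines are illustrative assumptions, not the file's actual contents:

```cmake
# Only archs with native half support (compute capability >= 5.3) remain;
# the CUDA_VERSION guard and the commented-out sm_30/35/50/52 lines are gone.
set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS} -gencode arch=compute_60,code=sm_60")
set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS} -gencode arch=compute_70,code=sm_70")
```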
Force-pushed from a1bb689 to 77ae1f6
Force-pushed from 77ae1f6 to 7f48e85
Support is added for GPU use only.
The half type used comes from the CUDA Math library, and some of its functions are usable only on GPU.
Most of the work consists of specializing templates to avoid ambiguous calls.
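As an illustration of that last point, a minimal sketch of such a specialization; the `exp_op` functor name and shape are assumptions, not the PR's actual code:

```cuda
#include <cuda_fp16.h>

// Generic functor: fine for float/double, where exp() has device overloads.
template <typename T>
struct exp_op {
    __device__ T operator()(T x) const { return exp(x); }
};

// For __half, exp(x) would be ambiguous or force an implicit conversion, so
// the specialization calls the CUDA math library's hexp() directly; hexp()
// is device-only, which is why the PR supports GPU use only.
template <>
struct exp_op<__half> {
    __device__ __half operator()(__half x) const { return hexp(x); }
};
```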