[add] support for half (GPU only) #6

Open
wants to merge 11 commits into
base: espnet_v1.1

Conversation

mangel9742

Support is added for GPU use only.
The half type comes from the CUDA Math library, and some of its functions are usable only on the GPU.
Most of the work consists of specializing templates to avoid ambiguous calls.
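
To illustrate the specialization approach described above, here is a minimal sketch in plain C++. `Half` is a hypothetical stand-in for the CUDA `__half` type so that the example builds without nvcc, and `helper_exp` is an illustrative name, not a function taken from this PR.

```cpp
#include <cmath>

// Hypothetical stand-in for CUDA's __half (assumption: the real type
// lives in cuda_fp16.h and is only fully usable in device code).
struct Half {
    float value;
    explicit Half(float v) : value(v) {}
    explicit operator float() const { return value; }
};

// Generic version: works for float and double, where std::exp has an
// exact overload and overload resolution is unambiguous.
template <typename T>
inline T helper_exp(T x) {
    return std::exp(x);
}

// Explicit specialization for the half type: the conversion is spelled
// out, so the compiler never has to choose among std::exp overloads.
template <>
inline Half helper_exp<Half>(Half x) {
    return Half(std::exp(static_cast<float>(x)));
}
```

With the real `__half` type in device code, the specialization body would more likely call a `cuda_fp16.h` intrinsic such as `hexp` instead of converting through `float`.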

Owner

@b-flo b-flo left a comment

Thanks for taking on that work!

As we discussed off-list, the following changes are needed:

  • torch bindings are WIP/borked
  • remove __host__ from HOSTDEVICE macro and re-use it.
  • Subsequently, log_sum_exp from rnnt_helper can be re-used from that.
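
The last two items can be sketched as follows. This is an assumption about how the macro and helper might look after the change, not the code from the repository: the macro layout, the `neg_inf` helper, and the body of `log_sum_exp` are illustrative. Outside nvcc the macro expands to nothing, so the sketch also builds as plain C++.

```cpp
#include <cmath>

// Assumed macro layout: __host__ removed, so under nvcc the annotated
// functions are compiled for the device only.
#ifdef __CUDACC__
#define HOSTDEVICE __device__
#else
#define HOSTDEVICE
#endif

// Illustrative helper returning negative infinity in type T.
template <typename T>
HOSTDEVICE inline T neg_inf() {
    return -T(INFINITY);
}

// Numerically stable log(exp(a) + exp(b)): the larger argument is
// factored out so that exp never overflows.
template <typename T>
HOSTDEVICE inline T log_sum_exp(T a, T b) {
    if (a == neg_inf<T>()) return b;
    if (b == neg_inf<T>()) return a;
    if (a > b) return std::log1p(std::exp(b - a)) + a;
    return std::log1p(std::exp(a - b)) + b;
}
```

A single definition annotated this way could then be shared by every kernel that needs it, instead of keeping a separate copy in rnnt_helper.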

Also, a more general note: this PR drops support for the Kepler and Maxwell architectures. I'm not against it, but I have to confirm it won't affect too many people (unlikely, though).
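
For context, the half-precision arithmetic functions in `cuda_fp16.h` are documented as requiring compute capability 5.3 or higher, which would explain why the older targets are dropped. The check below is a small sketch of that rule; `supports_half_math` is a hypothetical helper, not part of the PR.

```cpp
// Hypothetical helper: true when a device of compute capability
// major.minor meets the 5.3 minimum for half-precision arithmetic.
constexpr bool supports_half_math(int major, int minor) {
    return major > 5 || (major == 5 && minor >= 3);
}

// The targets disabled in CMakeLists.txt all fall below the minimum.
static_assert(!supports_half_math(3, 0), "sm_30 is below 5.3");
static_assert(!supports_half_math(3, 5), "sm_35 is below 5.3");
static_assert(!supports_half_math(5, 0), "sm_50 is below 5.3");
static_assert(!supports_half_math(5, 2), "sm_52 is below 5.3");
```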

@mangel9742
Author

Thank you for your feedback.
I've made the requested changes and fixed the torch bindings.
Please let me know if you have any further remarks.

Owner

@b-flo b-flo left a comment

Sorry for the delay! Compilation is OK but import will fail with the following error:

ImportError: /home/b-flo/warp-transducer/pytorch_binding/warprnnt_pytorch/warp_rnnt.cpython-310-x86_64-linux-gnu.so: undefined symbol: _ZNK2at10TensorBase8data_ptrI6__halfEEPT_vI

I'll let you figure it out but feel free to ask off the list for any help/hints on resolving this error!

P.S.: Postponing the full review until it's in a stable state.
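
As a starting point for reading the error above, the mangled name can be decoded with the Itanium ABI demangler that ships with GCC and Clang. The sketch below uses the symbol without its trailing `I`, on the assumption that the trailing character is not part of the name; `demangle` is an illustrative helper.

```cpp
#include <cxxabi.h>
#include <cstdlib>
#include <string>

// Returns the human-readable form of an Itanium-mangled symbol, or the
// input unchanged if it cannot be demangled.
std::string demangle(const char* mangled) {
    int status = 0;
    char* out = abi::__cxa_demangle(mangled, nullptr, nullptr, &status);
    if (status != 0 || out == nullptr) {
        return mangled;
    }
    std::string result(out);
    std::free(out);  // __cxa_demangle allocates the buffer with malloc
    return result;
}
```

Calling `demangle("_ZNK2at10TensorBase8data_ptrI6__halfEEPT_v")` should show that the missing symbol is an instantiation of `at::TensorBase::data_ptr` for `__half`. One possible reading, stated here as an assumption, is that the binding requests `data_ptr<__half>()` while the installed torch library only provides the instantiation for its own half type (`at::Half`).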

CMakeLists.txt Outdated
Comment on lines 47 to 54
+ # Drop support for old GPU to use CUDA math library functions
  IF(NOT (CUDA_VERSION GREATER 10.2))
- set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS} -gencode arch=compute_30,code=sm_30 -O2")
- set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS} -gencode arch=compute_35,code=sm_35")
- set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS} -gencode arch=compute_50,code=sm_50")
+ # set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS} -gencode arch=compute_30,code=sm_30 -O2")
+ # set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS} -gencode arch=compute_35,code=sm_35")
+ # set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS} -gencode arch=compute_50,code=sm_50")
  ENDIF()

- set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS} -gencode arch=compute_52,code=sm_52")
+ #set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS} -gencode arch=compute_52,code=sm_52")
Owner

Btw, you can remove the condition + commented lines

@mangel9742 mangel9742 marked this pull request as draft July 7, 2023 07:38
@mangel9742 mangel9742 force-pushed the espnet_v1.1_half_support branch from a1bb689 to 77ae1f6 Compare July 7, 2023 13:39
@mangel9742 mangel9742 force-pushed the espnet_v1.1_half_support branch from 77ae1f6 to 7f48e85 Compare July 7, 2023 13:48
@mangel9742 mangel9742 marked this pull request as ready for review July 7, 2023 13:53
2 participants