Currently, T.gemm is lowered to the gemm_ss/gemm_rs templates in tl_templates/cuda/gemm_smXX.h, which leverage cutlass templates for tensor core PTX instructions. The gemm_ss template functions are implemented separately in different header files and dispatched in gemm.h based on __CUDA_ARCH__.
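For reference, the per-arch dispatch described above looks roughly like the following. This is only a sketch of the pattern, not the actual contents of gemm.h; the guard values and comments are assumptions:

```cpp
// Rough sketch of the per-arch dispatch in tl_templates/cuda/gemm.h;
// the exact guards and file set in the repository may differ.
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
  #include "gemm_sm90.h"  // Hopper-specific templates
#elif defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)
  #include "gemm_sm80.h"  // Ampere-era mma.sync templates
#else
  #include "gemm_sm70.h"  // older architectures, cutlass 2.x style
#endif
```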
The current implementation has the following drawbacks:
- We need to write different files for devices with different compute capabilities.
- GPUs with higher compute capabilities can naturally execute PTX instructions of lower compute capabilities, which means most code in the different files could be reused.
- The current implementation, with three gemm_smX0.h files, follows different implementation logic in each file, making it difficult to maintain and reducing readability.
For example, when attempting to support the fp8 tensor core on RTX 4090 (sm89), it is challenging to determine where to add the code, because the relevant cutlass templates for fp8 inputs live in a separate header: cutlass/include/cutlass/arch/mma_sm89.h. That file is written in cutlass 2.x style, which makes it easier to support in gemm_sm70.h; however, we still need templates from gemm_sm80.h for fp16 input gemm on sm89 devices.
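For illustration, the sm89 fp8 instruction in that header is exposed through the cutlass 2.x arch::Mma interface, roughly as below. The shape and type parameters are recalled from the cutlass source and should be double-checked against mma_sm89.h:

```cpp
#include <cutlass/arch/mma.h>
#include <cutlass/gemm/gemm.h>
#include <cutlass/layout/matrix.h>
#include <cutlass/float8.h>

// cutlass 2.x style instruction-level MMA for fp8 on sm89
// (mma.sync.aligned.m16n8k32 with e4m3 inputs and fp32 accumulate).
using Sm89Fp8Mma = cutlass::arch::Mma<
    cutlass::gemm::GemmShape<16, 8, 32>,                  // instruction shape
    32,                                                   // one full warp
    cutlass::float_e4m3_t, cutlass::layout::RowMajor,     // operand A
    cutlass::float_e4m3_t, cutlass::layout::ColumnMajor,  // operand B
    float, cutlass::layout::RowMajor,                     // accumulator C
    cutlass::arch::OpMultiplyAdd>;
```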
Considering the above dilemma, I recommend refactoring the gemm templates to align with the structure and design of the cutlass library. Specifically, we should consider reorganizing the current three gemm_smX0.h files into gemm_cute.h and gemm_cutlass.h (plus possibly an additional gemm_cute_wgmma.h for Hopper support, which may require special handling). These would correspond to the cutlass cute templates and the cutlass 2.x templates, respectively. Since most MMA instructions are already supported by cute, this approach could unify the implementation in most cases. For the special cases where an MMA instruction is not supported by cute, we could implement it in gemm_cutlass.h in cutlass 2.x style (e.g., fp8xfp8 on sm89). Finally, we need a dispatch mechanism that falls back based on compute capability, as sketched below.
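To make the proposal concrete, here is a minimal sketch of what the split could look like. All names below (gemm_cute.h, gemm_cutlass.h, the warp tiling layout) are hypothetical and only illustrate the structure; the TiledMMA template parameters also vary slightly across cute versions:

```cpp
// gemm_cute.h (hypothetical): one unified implementation built on cute
// MMA atoms, which already cover most architectures and data types.
#include <cute/tensor.hpp>
#include <cute/atom/mma_atom.hpp>

using HalfMmaSm80 = cute::TiledMMA<
    cute::MMA_Atom<cute::SM80_16x8x16_F32F16F16F32_TN>,       // fp16 in, fp32 out
    cute::Layout<cute::Shape<cute::_2, cute::_2, cute::_1>>>; // 2x2 warp tiling

// gemm.h (hypothetical): dispatch with a fallback to cutlass 2.x style
// templates only where cute has no atom, e.g. fp8 on sm89.
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ == 890)
  #include "gemm_cutlass.h"  // fp8 path via cutlass/arch/mma_sm89.h
#endif
#include "gemm_cute.h"       // default path for cute-supported MMAs
```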