
[Enhancement] Refactor gemm templates #207

@zqh-wz

Description


Currently, T.gemm is lowered to the gemm_ss/gemm_rs templates in tl_templates/cuda/gemm_smXX.h, which leverage cutlass templates wrapping tensor core PTX instructions. The template functions are implemented separately in per-architecture header files and dispatched in gemm.h based on the CUDA_ARCH value (the device's compute capability).
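As a rough illustration (not the actual gemm.h code), the per-architecture dispatch amounts to selecting one header per compute-capability range; the helper below mirrors such an #if chain as an ordinary host function, with hypothetical thresholds:

```cpp
#include <string>

// Hedged sketch of the current dispatch style: one header per
// compute capability, chosen by a chain of range checks over the
// CUDA_ARCH value. The thresholds here are illustrative assumptions.
inline std::string select_gemm_header(int cuda_arch) {
    if (cuda_arch >= 900) return "gemm_sm90.h";  // Hopper
    if (cuda_arch >= 800) return "gemm_sm80.h";  // Ampere / Ada
    return "gemm_sm70.h";                        // Volta / Turing
}
```

Note how sm89 (RTX 4090) falls into the sm80 bucket under this scheme, which is exactly where the fp8 support question below becomes awkward.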

The current implementation has the following drawbacks:

  • A separate file must be written for each compute capability.
  • GPUs with higher compute capabilities can natively execute PTX instructions of lower compute capabilities, so most of the code could be shared across these files, yet it is currently duplicated.
  • The three gemm_smX0.h files follow different implementation logic, which makes them hard to maintain and reduces readability.

For example, when attempting to support the fp8 tensor core on the RTX 4090 (sm89), it is hard to decide where the code should go: the relevant cutlass templates for fp8 inputs live in a separate header, cutlass/include/cutlass/arch/mma_sm89.h, which is written in cutlass 2.x style and would therefore fit most naturally into gemm_sm70.h. At the same time, fp16 gemm on sm89 devices still needs the templates from gemm_sm80.h.

Given this dilemma, I recommend refactoring the gemm templates to align with the structure and design of the cutlass library. Specifically, the current three gemm_smX0.h files could be reorganized into gemm_cute.h and gemm_cutlass.h (plus, possibly, a gemm_cute_wgmma.h for Hopper support, which may require special considerations). These would correspond to the cutlass cute templates and the cutlass 2.x templates, respectively. Since most MMA instructions are already supported by cute, this approach would unify the implementation in most cases; the special cases where an MMA instruction is not supported by cute (e.g., fp8xfp8 on sm89) could be implemented in gemm_cutlass.h in cutlass 2.x style. Finally, we would need a dispatch mechanism that falls back appropriately based on compute capability.
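The proposed fallback could be sketched as follows. The header names come from the proposal above; the selection rules (and the DType enum) are illustrative assumptions, not a committed design:

```cpp
#include <string>

// Hypothetical operand type tag for the sketch below.
enum class DType { FP16, FP8 };

// Hedged sketch of the proposed dispatch: prefer the unified cute-based
// templates, and fall back to cutlass 2.x-style templates only where
// cute lacks the needed MMA instruction. Rules here are assumptions.
inline std::string select_gemm_template(int cuda_arch, DType a, DType b) {
    if (cuda_arch >= 900)
        return "gemm_cute_wgmma.h";  // Hopper: wgmma may need special handling
    if (cuda_arch == 890 && a == DType::FP8 && b == DType::FP8)
        return "gemm_cutlass.h";     // fp8xfp8 on sm89: cutlass 2.x style
    return "gemm_cute.h";            // common path: cute-supported MMA atoms
}
```

With this shape, adding a new special case means adding one branch and one cutlass 2.x implementation, rather than a whole new per-architecture header.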


Labels: enhancement (New feature or request)
