Add a selectable GEMM backend to SegmentGEMMWrapper.

* SegmentGEMMWrapper.__init__ gains a `backend` argument (default "auto").
  run() resolves "auto" through the new flashinfer.utils.determine_gemm_backend
  helper and raises ValueError for any name other than "sm90" / "sm80".
* The dispatch in run() is written as if/elif/else instead of match/case so
  that gemm.py still parses on Python interpreters older than 3.10.
* tests/test_group_gemm.py is parametrized over "auto", "sm90" and "sm80".

NOTE(review): the operand of the `#include` removed from group_gemm_sm90.cuh is
not visible in this copy of the patch; restore it before applying.
NOTE(review): the post-image hashes on the `index` lines for gemm.py and
utils.py predate the if/elif and docstring edits; regenerate them if needed.

diff --git a/include/flashinfer/gemm/group_gemm_sm90.cuh b/include/flashinfer/gemm/group_gemm_sm90.cuh
index 7c164414..07e60d97 100644
--- a/include/flashinfer/gemm/group_gemm_sm90.cuh
+++ b/include/flashinfer/gemm/group_gemm_sm90.cuh
@@ -16,8 +16,6 @@
 #ifndef FLASHINFER_GEMM_GROUP_GEMM_SM90_CUH_
 #define FLASHINFER_GEMM_GROUP_GEMM_SM90_CUH_
 
-#include
-
 #include "../allocator.h"
 #include "../cutlass_utils.cuh"
 #include "../utils.cuh"
diff --git a/python/flashinfer/gemm.py b/python/flashinfer/gemm.py
index 046fe36f..2069ccd4 100644
--- a/python/flashinfer/gemm.py
+++ b/python/flashinfer/gemm.py
@@ -24,7 +24,7 @@
 from .jit import FLASHINFER_CSRC_DIR, has_prebuilt_ops, load_cuda_ops
 from .utils import (
     _get_cache_buf,
-    get_compute_capability,
+    determine_gemm_backend,
     get_cuda_stream,
     get_indptr,
     register_custom_op,
@@ -480,7 +480,9 @@ class SegmentGEMMWrapper:
     True
     """
 
-    def __init__(self, float_workspace_buffer: torch.Tensor) -> None:
+    def __init__(
+        self, float_workspace_buffer: torch.Tensor, backend: str = "auto"
+    ) -> None:
         r"""Initialize the wrapper.
 
         Parameters
@@ -493,6 +495,7 @@ def __init__(self, float_workspace_buffer: torch.Tensor) -> None:
             (1024 * 1024,), dtype=torch.int8, device=float_workspace_buffer.device
         )
         self._float_workspace_buffer = float_workspace_buffer
+        self.backend = backend
 
     def reset_workspace_buffer(
         self, float_workspace_buffer: torch.Tensor, int_workspace_buffer: torch.Tensor
@@ -584,75 +587,82 @@ def run(
         if weight_indices is None:
             # create an empty CPU tensor as placeholder
             weight_indices = torch.empty(0, dtype=torch.int64)
-        major, _ = get_compute_capability(x.device)
         cumulative_batch_size = x.size(0)
         d_out = weights.size(1) if weight_column_major else weights.size(2)
         y = torch.zeros((cumulative_batch_size, d_out), dtype=x.dtype, device=x.device)
         empty_x_data = torch.empty(0, dtype=x.dtype, device=x.device)
 
-        if major >= 9:
-            (
-                all_problems,
-                x_data,
-                w_data,
-                y_data,
-                x_stride_data,
-                w_stride_data,
-                y_stride_data,
-            ) = launch_compute_sm90_group_gemm_args(
-                x,
-                weights,
-                y,
-                weight_column_major,
-                batch_size,
-                seg_indptr,
-                weight_indices,
-            )
-            get_gemm_sm90_module().cutlass_segment_gemm_sm90(
-                self._float_workspace_buffer,
-                self._int_workspace_buffer,
-                all_problems,
-                x_data,
-                w_data,
-                y_data,
-                x_stride_data,
-                w_stride_data,
-                y_stride_data,
-                y,  # for torch compile mutates_args
-                empty_x_data,  # for kernel type dispatch
-                weight_column_major,
-            )
+        if self.backend == "auto":
+            backend = determine_gemm_backend(x.device)
         else:
-            (
-                all_problems,
-                x_data,
-                w_data,
-                y_data,
-                x_ld_data,
-                w_ld_data,
-                y_ld_data,
-            ) = launch_compute_sm80_group_gemm_args(
-                x,
-                weights,
-                y,
-                weight_column_major,
-                batch_size,
-                seg_indptr,
-                weight_indices,
-            )
-            get_gemm_module().cutlass_segment_gemm(
-                self._int_workspace_buffer,
-                all_problems,
-                x_data,
-                w_data,
-                y_data,
-                x_ld_data,
-                w_ld_data,
-                y_ld_data,
-                y,
-                empty_x_data,
-                weight_column_major,
-            )
+            backend = self.backend
+
+        # Plain if/elif (not match/case) so the module parses on Python < 3.10.
+        if backend == "sm90":
+            (
+                all_problems,
+                x_data,
+                w_data,
+                y_data,
+                x_stride_data,
+                w_stride_data,
+                y_stride_data,
+            ) = launch_compute_sm90_group_gemm_args(
+                x,
+                weights,
+                y,
+                weight_column_major,
+                batch_size,
+                seg_indptr,
+                weight_indices,
+            )
+            get_gemm_sm90_module().cutlass_segment_gemm_sm90(
+                self._float_workspace_buffer,
+                self._int_workspace_buffer,
+                all_problems,
+                x_data,
+                w_data,
+                y_data,
+                x_stride_data,
+                w_stride_data,
+                y_stride_data,
+                y,  # for torch compile mutates_args
+                empty_x_data,  # for kernel type dispatch
+                weight_column_major,
+            )
+        elif backend == "sm80":
+            (
+                all_problems,
+                x_data,
+                w_data,
+                y_data,
+                x_ld_data,
+                w_ld_data,
+                y_ld_data,
+            ) = launch_compute_sm80_group_gemm_args(
+                x,
+                weights,
+                y,
+                weight_column_major,
+                batch_size,
+                seg_indptr,
+                weight_indices,
+            )
+            get_gemm_module().cutlass_segment_gemm(
+                self._int_workspace_buffer,
+                all_problems,
+                x_data,
+                w_data,
+                y_data,
+                x_ld_data,
+                w_ld_data,
+                y_ld_data,
+                y,
+                empty_x_data,
+                weight_column_major,
+            )
+        else:
+            raise ValueError(f"Unsupported gemm backend: {backend}")
         return y
 
     forward = run
diff --git a/python/flashinfer/utils.py b/python/flashinfer/utils.py
index 69ef8921..4abce374 100644
--- a/python/flashinfer/utils.py
+++ b/python/flashinfer/utils.py
@@ -252,3 +252,12 @@ def register_fake_op(
 
 def get_cuda_stream(device: torch.device) -> int:
     return torch.cuda.current_stream(device).cuda_stream
+
+
+def determine_gemm_backend(device: torch.device) -> str:
+    """Return "sm90" if the device's compute-capability major is >= 9, else "sm80"."""
+    major, _ = get_compute_capability(device)
+    if major >= 9:
+        return "sm90"
+    else:
+        return "sm80"
diff --git a/tests/test_group_gemm.py b/tests/test_group_gemm.py
index 5c67c306..777c0adf 100644
--- a/tests/test_group_gemm.py
+++ b/tests/test_group_gemm.py
@@ -18,6 +18,7 @@
 import torch
 
 import flashinfer
+from flashinfer.utils import determine_gemm_backend
 
 DTYPES = [torch.float16]
 CUDA_DEVICES = ["cuda:0"]
@@ -31,6 +32,7 @@
 @pytest.mark.parametrize("column_major", [False, True])
 @pytest.mark.parametrize("dtype", DTYPES)
 @pytest.mark.parametrize("device", CUDA_DEVICES)
+@pytest.mark.parametrize("backend", ["auto", "sm90", "sm80"])
 def test_segment_gemm(
     batch_size,
     num_rows_per_batch,
@@ -40,12 +42,16 @@ def test_segment_gemm(
     column_major,
     dtype,
     device,
+    backend,
 ):
     if batch_size * num_rows_per_batch > 8192:
         pytest.skip("batch_size * num_rows_per_batch too large for test.")
+    latest_supported_backend = determine_gemm_backend(torch.device(device))
+    if backend == "sm90" and latest_supported_backend == "sm80":
+        pytest.skip("sm90 backend not supported on this device.")
     torch.manual_seed(42)
     workspace_buffer = torch.empty(32 * 1024 * 1024, dtype=torch.int8).to(device)
-    segment_gemm = flashinfer.gemm.SegmentGEMMWrapper(workspace_buffer)
+    segment_gemm = flashinfer.gemm.SegmentGEMMWrapper(workspace_buffer, backend=backend)
     x = torch.randn(batch_size * num_rows_per_batch, d_in, dtype=dtype).to(device)
     if use_weight_indices:
         num_weights = 1024
@@ -99,7 +105,7 @@ def test_segment_gemm(
 
 
 if __name__ == "__main__":
-    test_segment_gemm(199, 17, 128, 1024, False, False, torch.float16, "cuda:0")
-    test_segment_gemm(199, 17, 128, 1024, False, True, torch.float16, "cuda:0")
-    test_segment_gemm(199, 17, 128, 1024, True, False, torch.float16, "cuda:0")
-    test_segment_gemm(199, 17, 128, 1024, True, True, torch.float16, "cuda:0")
+    test_segment_gemm(199, 17, 128, 1024, False, False, torch.float16, "cuda:0", "auto")
+    test_segment_gemm(199, 17, 128, 1024, False, True, torch.float16, "cuda:0", "auto")
+    test_segment_gemm(199, 17, 128, 1024, True, False, torch.float16, "cuda:0", "auto")
+    test_segment_gemm(199, 17, 128, 1024, True, True, torch.float16, "cuda:0", "auto")