Skip to content

Commit

Permalink
upd
Browse files Browse the repository at this point in the history
upgrade cutlass to 3.5

upd

upd

upd

upd

upd

upd

upd

upd

upd

upd

upd
  • Loading branch information
yzh119 authored and xslingcn committed Oct 7, 2024
1 parent 3613a5b commit 1c53a42
Show file tree
Hide file tree
Showing 7 changed files with 483 additions and 28 deletions.
2 changes: 1 addition & 1 deletion include/flashinfer/gemm/group_gemm.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -116,4 +116,4 @@ cudaError_t CutlassSegmentGEMMRun(void* workspace_buffer, size_t workspace_buffe

} // namespace flashinfer

#endif // FLASHINFER_GEMM_GROUP_GEMM_CUH_
#endif // FLASHINFER_GEMM_GROUP_GEMM_CUH_
57 changes: 45 additions & 12 deletions include/flashinfer/gemm/group_gemm_cutlass.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,16 @@
#ifndef FLASHINFER_GROUP_GEMM_CUTLASS_CUH_
#define FLASHINFER_GROUP_GEMM_CUTLASS_CUH_

#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <cuda_fp8.h>

#include "cutlass/cutlass.h"
#include "cutlass/gemm/device/gemm_grouped.h"
#include "cutlass/gemm/kernel/default_gemm_grouped.h"
#include "cutlass/layout/matrix.h"
#include "cutlass/numeric_types.h"
#include "cutlass/util/packed_stride.hpp"

namespace flashinfer {

Expand All @@ -41,21 +46,49 @@ struct cutlass_dtype<nv_bfloat16> {
using type = cutlass::bfloat16_t;
};

template <typename T>
__global__ void compute_cutlass_group_gemm_args(cutlass::gemm::GemmCoord* all_problems, T** ptr_x,
T** ptr_w, T** ptr_y, int64_t* ld_x, int64_t* ld_w,
int64_t* ld_y, T* x, T* w, T* y, int64_t* xy_indptr,
int64_t* w_indices, size_t d_in, size_t d_out,
bool w_column_major) {
// Maps the CUDA fp8 e4m3 type (cuda_fp8.h) to its CUTLASS counterpart, so kernels
// templated on CUDA dtypes can resolve the matching CUTLASS numeric type.
template <>
struct cutlass_dtype<__nv_fp8_e4m3> {
  using type = cutlass::float_e4m3_t;
};

// Maps the CUDA fp8 e5m2 type (cuda_fp8.h) to its CUTLASS counterpart, so kernels
// templated on CUDA dtypes can resolve the matching CUTLASS numeric type.
template <>
struct cutlass_dtype<__nv_fp8_e5m2> {
  using type = cutlass::float_e5m2_t;
};

// Fills the per-problem argument arrays consumed by the SM80 CUTLASS grouped-GEMM
// kernel: problem shapes, A/B/C pointers, and leading dimensions.
//
// Launch layout: one thread block per GEMM problem — blockIdx.x selects the group.
// Every thread of a block writes the same values to the same slots, so the writes
// are redundant but race-free; launching with 1 thread per block is sufficient.
//
// Inputs:
//   x          — row-major activations, segmented by xy_indptr (segment i spans
//                rows [xy_indptr[i], xy_indptr[i+1])), each row of width d_in.
//   w          — weight matrices of size d_in * d_out each; group i uses matrix
//                w_indices[i], or matrix i when w_indices == nullptr.
//   y          — output rows, segmented like x, each row of width d_out.
//   w_column_major — selects the leading dimension of w (k if column-major, n if
//                row-major); the pointer arithmetic is layout-independent.
// NOTE(review): m is computed as a 64-bit difference narrowed to int — assumes
// each segment has fewer than 2^31 rows.
template <typename DTypeIn, typename DTypeOut>
__global__ void compute_sm80_cutlass_group_gemm_args(
    cutlass::gemm::GemmCoord* all_problems, DTypeIn** x_ptr, DTypeIn** w_ptr, DTypeOut** y_ptr,
    int64_t* x_ld, int64_t* w_ld, int64_t* y_ld, DTypeIn* x, DTypeIn* w, DTypeOut* y,
    int64_t* xy_indptr, int64_t* w_indices, size_t d_in, size_t d_out, bool w_column_major) {
  int i = blockIdx.x;
  int m = xy_indptr[i + 1] - xy_indptr[i], k = d_in, n = d_out;
  all_problems[i] = cutlass::gemm::GemmCoord(m, n, k);
  w_ptr[i] = w + (w_indices == nullptr ? i : w_indices[i]) * k * n;
  x_ptr[i] = x + xy_indptr[i] * k;
  y_ptr[i] = y + xy_indptr[i] * n;
  x_ld[i] = k;                      // x is m * k, row-major
  w_ld[i] = w_column_major ? k : n; // k * n if column major, n * k if row major
  y_ld[i] = n;                      // y is m * n, row-major
}

// Fills the per-problem argument arrays consumed by the SM90 (CUTLASS 3.x)
// grouped-GEMM kernel: problem shapes, A/B/C pointers, and CuTe packed strides.
//
// Launch layout: one thread block per GEMM problem — blockIdx.x selects the
// group; all threads of a block write identical values, so the writes are
// redundant but race-free.
//
// Group i owns rows [xy_indptr[i], xy_indptr[i+1]) of x (width d_in) and of y
// (width d_out), and multiplies by weight matrix w_indices[i] (or matrix i when
// w_indices == nullptr), each weight matrix holding d_in * d_out elements.
template <typename DTypeIn, typename DTypeOut, typename ProblemShape, typename StrideA,
          typename StrideB, typename StrideCD>
__global__ void compute_sm90_cutlass_group_gemm_args(
    ProblemShape* all_problems, DTypeIn** x_ptr, DTypeIn** w_ptr, DTypeOut** y_ptr,
    StrideA* x_stride, StrideB* w_stride, StrideCD* y_stride, DTypeIn* x, DTypeIn* w, DTypeOut* y,
    int64_t* xy_indptr, int64_t* w_indices, size_t d_in, size_t d_out, bool w_column_major) {
  const int group = blockIdx.x;
  const int k = d_in;
  const int n = d_out;
  const int m = xy_indptr[group + 1] - xy_indptr[group];
  all_problems[group] = ProblemShape(m, n, k);

  // Select which weight matrix this group multiplies by, then compute base pointers.
  const int64_t w_sel = (w_indices == nullptr) ? group : w_indices[group];
  w_ptr[group] = w + w_sel * k * n;
  x_ptr[group] = x + xy_indptr[group] * k;
  y_ptr[group] = y + xy_indptr[group] * n;

  x_stride[group] = cutlass::make_cute_packed_stride(StrideA{}, {m, k, 1});
  if (w_column_major) {
    w_stride[group] = cutlass::make_cute_packed_stride(StrideB{}, {k, n, 1});
  } else {
    w_stride[group] = cutlass::make_cute_packed_stride(StrideB{}, {n, k, 1});
  }
  y_stride[group] = cutlass::make_cute_packed_stride(StrideCD{}, {m, n, 1});
}

} // namespace group_gemm
Expand Down
Loading

0 comments on commit 1c53a42

Please sign in to comment.