Skip to content

Commit

Permalink
fix fallback gemm
Browse files Browse the repository at this point in the history
  • Loading branch information
Varun Sundar Rabindranath committed Jul 26, 2024
1 parent 4e9a61c commit b251d58
Showing 1 changed file with 6 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,16 @@ namespace vllm {
template <typename InType, typename OutType,
template <typename, typename> typename Epilogue>
struct sm89_int8_fallback_gemm {
// Shared mem requirement : 61440
static_assert(std::is_same<InType, int8_t>());
using TileShape = typename cutlass::gemm::GemmShape<64, 128, 64>;
using WarpShape = typename cutlass::gemm::GemmShape<32, 64, 64>;
using TileShape = cutlass::gemm::GemmShape<32, 64, 128>;
using WarpShape = cutlass::gemm::GemmShape<16, 64, 64>;
using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
using FP8MathOperator = typename cutlass::arch::OpMultiplyAdd;
static int32_t const MainLoopStages = 5;

using Cutlass2xGemm =
cutlass_2x_gemm<cutlass::arch::Sm89, enable_sm89_to_sm90, InType, OutType,
Epilogue, TileShape, WarpShape, InstructionShape, 5,
FP8MathOperator>;
Epilogue, TileShape, WarpShape, InstructionShape, 5>;
};

struct sm89_int8_config_default {
Expand Down

0 comments on commit b251d58

Please sign in to comment.