MoE updates

vinx13 committed Dec 13, 2023
1 parent 82bc8a3 commit e6554d5
Showing 8 changed files with 218 additions and 9 deletions.
2 changes: 1 addition & 1 deletion .gitmodules
@@ -24,7 +24,7 @@
url = https://github.com/rogersce/cnpy.git
[submodule "3rdparty/cutlass_fpA_intB_gemm"]
path = 3rdparty/cutlass_fpA_intB_gemm
-url = https://github.com/tlc-pack/cutlass_fpA_intB_gemm
+url = https://github.com/vinx13/cutlass_fpA_intB_gemm
[submodule "3rdparty/libflash_attn"]
path = 3rdparty/libflash_attn
url = https://github.com/tlc-pack/libflash_attn
2 changes: 1 addition & 1 deletion 3rdparty/cutlass_fpA_intB_gemm
Submodule cutlass_fpA_intB_gemm updated 27 files
+86 −0 cutlass_extensions/include/cutlass_extensions/gemm/kernel/gemm_moe_problem_visitor.h
+532 −0 cutlass_extensions/include/cutlass_extensions/gemm/kernel/moe_cutlass_kernel.h
+358 −0 cutlass_extensions/include/cutlass_extensions/gemm/kernel/moe_problem_visitor.h
+16 −1 cutlass_kernels/CMakeLists.txt
+235 −93 cutlass_kernels/cutlass_preprocessors.cc
+10 −2 cutlass_kernels/cutlass_preprocessors.h
+54 −11 cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm_impl.h
+71 −0 cutlass_kernels/moe_gemm/moe_gemm_kernels.h
+27 −0 cutlass_kernels/moe_gemm/moe_gemm_kernels_fp16_fp16.cu
+27 −0 cutlass_kernels/moe_gemm/moe_gemm_kernels_fp16_uint4.cu
+27 −0 cutlass_kernels/moe_gemm/moe_gemm_kernels_fp16_uint8.cu
+500 −0 cutlass_kernels/moe_gemm/moe_gemm_kernels_template.h
+14 −0 utils/activation_types.h
+84 −0 weightOnlyBatchedGemv/common.h
+91 −0 weightOnlyBatchedGemv/enabled.h
+435 −0 weightOnlyBatchedGemv/kernel.h
+224 −0 weightOnlyBatchedGemv/kernelLauncher.cu
+27 −0 weightOnlyBatchedGemv/kernelLauncher.h
+99 −0 weightOnlyBatchedGemv/utility.h
+98 −0 weightOnlyBatchedGemv/weightOnlyBatchedGemvBs1Int4b.cu
+98 −0 weightOnlyBatchedGemv/weightOnlyBatchedGemvBs1Int8b.cu
+97 −0 weightOnlyBatchedGemv/weightOnlyBatchedGemvBs2Int4b.cu
+97 −0 weightOnlyBatchedGemv/weightOnlyBatchedGemvBs2Int8b.cu
+98 −0 weightOnlyBatchedGemv/weightOnlyBatchedGemvBs3Int4b.cu
+98 −0 weightOnlyBatchedGemv/weightOnlyBatchedGemvBs3Int8b.cu
+97 −0 weightOnlyBatchedGemv/weightOnlyBatchedGemvBs4Int4b.cu
+98 −0 weightOnlyBatchedGemv/weightOnlyBatchedGemvBs4Int8b.cu
5 changes: 5 additions & 0 deletions cmake/modules/contrib/CUTLASS.cmake
@@ -22,7 +22,12 @@ if(USE_CUDA AND USE_CUTLASS)
set(CUTLASS_DIR ${PROJECT_SOURCE_DIR}/3rdparty/cutlass)
add_subdirectory(${PROJECT_SOURCE_DIR}/3rdparty/cutlass_fpA_intB_gemm)
add_subdirectory(${PROJECT_SOURCE_DIR}/3rdparty/libflash_attn)

include_directories(3rdparty/cutlass_fpA_intB_gemm
3rdparty/cutlass_fpA_intB_gemm/cutlass/include) # FIXME
list(APPEND RUNTIME_SRCS src/runtime/contrib/cutlass/weight_preprocess.cc)
list(APPEND RUNTIME_SRCS src/runtime/contrib/cutlass/moe_gemm.cc)
list(APPEND RUNTIME_SRCS src/runtime/contrib/cutlass/moe_compute_rows.cu)

message(STATUS "Build with CUTLASS")
endif()
2 changes: 1 addition & 1 deletion python/tvm/dlight/gpu/fallback.py
@@ -63,7 +63,7 @@ def apply( # pylint: disable=too-many-locals,missing-docstring
{"S": s_loops, "R": r_loops, "O": o_loops}[iter_type].append(loop)

if not s_loops:
-s_loops.append(sch.add_unit_loop(block.block))
+s_loops.append(sch.add_unit_loop(block))
sch.reorder(*s_loops, *r_loops, *o_loops)
bx, tx = sch.split( # pylint: disable=invalid-name
sch.fuse(*s_loops),
45 changes: 45 additions & 0 deletions src/runtime/contrib/cutlass/moe_compute_rows.cu
@@ -0,0 +1,45 @@
// https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/kernels/moe_kernels.cu

#include <cuda_runtime.h>

#include <algorithm>
#include <cstdint>

__device__ inline int find_total_elts_leq_target(const int* sorted_indices, const int arr_length, const int target)
{
int64_t low = 0, high = arr_length - 1, target_location = -1;
while (low <= high) {
int64_t mid = (low + high) / 2;

if (sorted_indices[mid] > target) {
high = mid - 1;
}
else {
low = mid + 1;
target_location = mid;
}
}
return target_location + 1;
}
__global__ void compute_total_rows_before_expert_kernel(const int* sorted_experts,
const int sorted_experts_len,
const int64_t num_experts,
int64_t* total_rows_before_expert)
{

// First, compute the global tid. We only need 1 thread per expert.
const int expert = blockIdx.x * blockDim.x + threadIdx.x;
if (expert >= num_experts)
return;

// This should construct the last index where each expert occurs.
total_rows_before_expert[expert] = find_total_elts_leq_target(sorted_experts, sorted_experts_len, expert);
}

void compute_total_rows_before_expert(const int* sorted_indices,
const int total_indices,
const int num_experts,
int64_t* total_rows_before_expert,
cudaStream_t stream)
{

const int threads = std::min(1024, num_experts);
const int blocks = (num_experts + threads - 1) / threads;

compute_total_rows_before_expert_kernel<<<blocks, threads, 0, stream>>>(
sorted_indices, total_indices, num_experts, total_rows_before_expert);
}
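
For reference, the kernel above amounts to a per-expert "searchsorted" over the already-sorted expert indices: entry e of total_rows_before_expert is the number of rows assigned to experts 0..e, i.e. the exclusive end offset of expert e's contiguous block of rows. A minimal NumPy sketch of that semantics (hypothetical inputs, not part of this commit):

import numpy as np

def total_rows_before_expert_ref(sorted_experts: np.ndarray, num_experts: int) -> np.ndarray:
    # For each expert e, count how many entries of the sorted array are <= e.
    # This mirrors find_total_elts_leq_target: index of the last value <= e, plus one.
    return np.searchsorted(sorted_experts, np.arange(num_experts), side="right").astype(np.int64)

# Example: 6 rows routed to 3 experts, already sorted by expert id.
sorted_experts = np.array([0, 0, 1, 1, 1, 2], dtype=np.int32)
print(total_rows_before_expert_ref(sorted_experts, 3))  # [2 5 6]
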
101 changes: 101 additions & 0 deletions src/runtime/contrib/cutlass/moe_gemm.cc
@@ -0,0 +1,101 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

#include <cuda.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <tvm/runtime/ndarray.h>
#include <tvm/runtime/packed_func.h>
#include <tvm/runtime/registry.h>

#include <optional>
#include <string>

#include "../../../3rdparty/cutlass_fpA_intB_gemm/cutlass/include/cutlass/half.h"
// clang-format off
// these headers can't be reordered
#include "../../../3rdparty/cutlass_fpA_intB_gemm/cutlass/include/cutlass/numeric_types.h"
#include "../../../3rdparty/cutlass_fpA_intB_gemm/cutlass/include/cutlass/integer_subbyte.h"
// clang-format on
#include "../../../3rdparty/cutlass_fpA_intB_gemm/cutlass_kernels/moe_gemm/moe_gemm_kernels.h"

void compute_total_rows_before_expert(const int* sorted_indices, const int total_indices,
const int num_experts, int64_t* total_rows_before_expert,
cudaStream_t stream);

namespace fastertransformer {

template <typename T, typename WeightType>
void moe_gemm_bias_act(const T* A, const WeightType* B, const T* weight_scales, const T* biases,
T* C, int64_t* total_rows_before_expert, int64_t total_rows, int64_t gemm_n,
int64_t gemm_k, int num_experts, std::optional<std::string> activation,
cudaStream_t stream);
}

namespace tvm {
namespace runtime {

TVM_REGISTER_GLOBAL("cutlass.moe_gemm_f16f16")
.set_body_typed([](NDArray x, NDArray weight, NDArray total_rows_before_expert,
int64_t total_rows, int64_t n, int64_t k, int64_t num_experts, NDArray out) {
auto func = tvm::runtime::Registry::Get("runtime.get_cuda_stream");
ICHECK(func != nullptr);
cudaStream_t stream = static_cast<cudaStream_t>((*func)().operator void*());

fastertransformer::moe_gemm_bias_act<half, half>(
reinterpret_cast<half*>(x->data), reinterpret_cast<half*>(weight->data), nullptr, nullptr,
reinterpret_cast<half*>(out->data),
reinterpret_cast<int64_t*>(total_rows_before_expert->data), total_rows, n, k, num_experts,
std::nullopt, stream);
});

TVM_REGISTER_GLOBAL("cutlass.moe_gemm_s4f16")
.set_body_typed([](NDArray x, NDArray weight, NDArray scales, NDArray total_rows_before_expert,
int64_t total_rows, int64_t n, int64_t k, int64_t num_experts, NDArray out) {
auto func = tvm::runtime::Registry::Get("runtime.get_cuda_stream");
ICHECK(func != nullptr);
cudaStream_t stream = static_cast<cudaStream_t>((*func)().operator void*());

fastertransformer::moe_gemm_bias_act<half, cutlass::uint4b_t>(
reinterpret_cast<half*>(x->data), reinterpret_cast<cutlass::uint4b_t*>(weight->data),
reinterpret_cast<half*>(scales->data), nullptr, reinterpret_cast<half*>(out->data),
reinterpret_cast<int64_t*>(total_rows_before_expert->data), total_rows, n, k, num_experts,
std::nullopt, stream);
});

TVM_REGISTER_GLOBAL("moe_compute_rows_before")
.set_body_typed([](NDArray sorted_indices, NDArray total_rows_before_expert) {
CHECK(sorted_indices->dtype.code == kDLInt && sorted_indices->dtype.bits == 32);
CHECK(total_rows_before_expert->dtype.code == kDLInt &&
total_rows_before_expert->dtype.bits == 64);
CHECK(sorted_indices->ndim == 1);
CHECK(total_rows_before_expert->ndim == 1);

auto func = tvm::runtime::Registry::Get("runtime.get_cuda_stream");
ICHECK(func != nullptr);
cudaStream_t stream = static_cast<cudaStream_t>((*func)().operator void*());

int num_experts = total_rows_before_expert->shape[0];
compute_total_rows_before_expert(
reinterpret_cast<int*>(sorted_indices->data), sorted_indices->shape[0], num_experts,
reinterpret_cast<int64_t*>(total_rows_before_expert->data), stream);
});

} // namespace runtime
} // namespace tvm
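
Conceptually, cutlass.moe_gemm_f16f16 performs a grouped GEMM: rows of x are grouped contiguously by expert, total_rows_before_expert holds each group's exclusive end offset (as produced by moe_compute_rows.cu), and each group is multiplied by its expert's weight. A rough NumPy sketch of that contract, assuming a [num_experts, k, n] weight layout for illustration (the CUTLASS kernel's internal layout differs; this is not part of the commit):

import numpy as np

def moe_gemm_ref(x, weight, total_rows_before_expert):
    # x: [total_rows, k]; weight: [num_experts, k, n] (assumed layout for this sketch)
    # total_rows_before_expert: [num_experts] exclusive end offsets per expert.
    out = np.empty((x.shape[0], weight.shape[2]), dtype=x.dtype)
    start = 0
    for e, end in enumerate(total_rows_before_expert):
        out[start:end] = x[start:end] @ weight[e]
        start = end
    return out
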
41 changes: 35 additions & 6 deletions src/runtime/contrib/cutlass/weight_preprocess.cc
@@ -17,6 +17,7 @@
* under the License.
*/

#include <cuda_fp16.h>
#include <tvm/runtime/ndarray.h>
#include <tvm/runtime/packed_func.h>
#include <tvm/runtime/registry.h>
@@ -37,22 +38,50 @@ namespace runtime {
// The preprocessing functions are defined in C++, so we need to copy the input weight to CPU.
TVM_REGISTER_GLOBAL("cutlass.ft_preprocess_weight")
.set_body_typed([](NDArray packed_weight, int sm, bool is_int4) {
-int rows = packed_weight->shape[0];
-int cols = packed_weight->shape[1];
-std::vector<int8_t> input_cpu(rows * cols);
-std::vector<int8_t> output_cpu(rows * cols);
+bool is_2d = packed_weight->ndim == 2;
+int num_experts = is_2d ? 1 : packed_weight->shape[0];
+int rows = packed_weight->shape[is_2d ? 0 : 1];
+int cols = packed_weight->shape[is_2d ? 1 : 2];
+
+std::vector<int8_t> input_cpu(num_experts * rows * cols);
+std::vector<int8_t> output_cpu(num_experts * rows * cols);
packed_weight.CopyToBytes(input_cpu.data(), input_cpu.size());
// multiply cols by 2 since the "col" params in preprocess_weights refers to the column of
// the unpacked weight.
if (is_int4) {
cols *= 2;
}
-fastertransformer::preprocess_weights(output_cpu.data(), input_cpu.data(), rows, cols,
-is_int4, sm);
+fastertransformer::preprocess_weights(output_cpu.data(), input_cpu.data(),
+is_2d ? -1 : num_experts, rows, cols, is_int4, sm);
auto out = NDArray::Empty(packed_weight.Shape(), packed_weight->dtype, packed_weight->device);
out.CopyFromBytes(output_cpu.data(), output_cpu.size());
return out;
});

TVM_REGISTER_GLOBAL("cutlass.symmetric_quantize").set_body_typed([](NDArray weight, bool is_int4) {
CHECK(is_int4);
CHECK(weight->dtype.code == kDLFloat && weight->dtype.bits == 16);
CHECK(weight->ndim == 3);
CHECK(weight->device.device_type == kDLCPU);
int64_t num_experts = weight->shape[0];
int64_t rows = weight->shape[1];
int64_t cols = weight->shape[2];

ShapeTuple out_weight_shape{num_experts, rows, cols / 2};
ShapeTuple out_scale_shape{num_experts, cols};
auto out_weight = NDArray::Empty(
out_weight_shape, DLDataType{.code = kDLInt, .bits = 8, .lanes = 1}, weight->device);
auto out_scale = NDArray::Empty(
out_scale_shape, DLDataType{.code = kDLFloat, .bits = 16, .lanes = 1}, weight->device);

fastertransformer::symmetric_quantize<half, half>(
reinterpret_cast<int8_t*>(out_weight->data), reinterpret_cast<half*>(out_scale->data),
reinterpret_cast<const half*>(weight->data),
std::vector<size_t>{static_cast<size_t>(num_experts), static_cast<size_t>(rows),
static_cast<size_t>(cols)},
true);
return Array<NDArray>{out_weight, out_scale};
});

} // namespace runtime
} // namespace tvm
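
As a reference point, cutlass.symmetric_quantize returns per-expert int4 weights packed two to an int8 byte plus one fp16 scale per (expert, column). The sketch below only illustrates the general shape of such symmetric per-column quantization; the exact rounding and interleaved packing order of fastertransformer::symmetric_quantize differ, and this code is not part of the commit:

import numpy as np

def symmetric_quantize_ref(w):
    # w: [num_experts, rows, cols] float16
    # returns packed int4 weights [num_experts, rows, cols // 2] and scales [num_experts, cols]
    scales = np.maximum(np.abs(w).max(axis=1), 1e-5) / 7.0   # one scale per (expert, column)
    q = np.clip(np.rint(w / scales[:, None, :]), -7, 7).astype(np.int8)
    nib = q.astype(np.uint8) & 0x0F                          # two's-complement low nibbles
    packed = (nib[..., 0::2] | (nib[..., 1::2] << 4)).view(np.int8)
    return packed, scales.astype(np.float16)
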
29 changes: 29 additions & 0 deletions src/runtime/contrib/thrust/thrust.cu
@@ -21,6 +21,7 @@
* \file Use external Thrust library call
*/

#include <cuda_fp16.h>
#include <thrust/device_ptr.h>
#include <thrust/device_vector.h>
#include <thrust/sort.h>
@@ -140,6 +141,18 @@ void thrust_sort_common(DLTensor* input,
} else {
LOG(FATAL) << "Unsupported output dtype: " << out_dtype;
}
} else if (data_dtype == "float16") {
if (out_dtype == "int32") {
thrust_sort<half, int32_t>(input, values_out, indices_out, is_ascend, sort_len);
} else if (out_dtype == "int64") {
thrust_sort<half, int64_t>(input, values_out, indices_out, is_ascend, sort_len);
} else if (out_dtype == "float32") {
thrust_sort<half, float>(input, values_out, indices_out, is_ascend, sort_len);
} else if (out_dtype == "float64") {
thrust_sort<half, double>(input, values_out, indices_out, is_ascend, sort_len);
} else {
LOG(FATAL) << "Unsupported output dtype: " << out_dtype;
}
} else if (data_dtype == "int32") {
if (out_dtype == "int32") {
thrust_sort<int32_t, int32_t>(input, values_out, indices_out, is_ascend, sort_len);
@@ -185,6 +198,22 @@ TVM_REGISTER_GLOBAL("tvm.contrib.thrust.sort")
data_dtype, out_dtype);
});

TVM_REGISTER_GLOBAL("tvm.contrib.thrust.sort_dps")
.set_body([](TVMArgs args, TVMRetValue* ret) {
ICHECK_GE(args.num_args, 4);
DLTensor* input = args[0];
DLTensor* values_out = args[2];
DLTensor* indices_out = args[3];
bool is_ascend = args[1];

auto data_dtype = DLDataType2String(input->dtype);
auto out_dtype = DLDataType2String(indices_out->dtype);

int n_values = input->shape[input->ndim - 1];
thrust_sort_common(input, values_out, indices_out, is_ascend, n_values,
data_dtype, out_dtype);
});

template<typename KeyType, typename ValueType>
void thrust_stable_sort_by_key(DLTensor* keys_in,
DLTensor* values_in,
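
For intuition, the new tvm.contrib.thrust.sort_dps entry follows destination-passing style: the caller supplies preallocated values_out and indices_out buffers, which are filled with the values sorted along the last axis and the corresponding source indices. A hypothetical NumPy analogue (not part of this commit):

import numpy as np

def sort_dps_ref(data, is_ascend, values_out, indices_out):
    # Sort along the last axis, writing into caller-provided buffers (destination-passing style).
    order = np.argsort(data, axis=-1, kind="stable")
    if not is_ascend:
        order = order[..., ::-1]
    indices_out[...] = order
    values_out[...] = np.take_along_axis(data, order, axis=-1)
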
