100 changes: 100 additions & 0 deletions convert_hf_to_gguf.py
@@ -654,6 +654,9 @@ def get_vocab_base_pre(self, tokenizer) -> str:
        if chkhsh == "81212dc7cdb7e0c1074ca62c5aeab0d43c9f52b8a737be7b12a777c953027890":
            # ref: https://huggingface.co/moonshotai/Kimi-K2-Base
            res = "kimi-k2"
        if chkhsh == "9b1be57e70d20d9501b2b3186e792d81181ae36ada3903c26f9fea418cf87206":
            # ref: https://huggingface.co/inclusionAI/Ling-mini-base-2.0
            res = "bailingmoe2"

        if res is None:
            logger.warning("\n")
@@ -4461,6 +4464,103 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
        name = name.removeprefix("transformer.")
        return [(self.map_tensor_name(name), data_torch)]

@Model.register("BailingMoeV2ForCausalLM")
class BailingMoeV2Model(Model):
    model_arch = gguf.MODEL_ARCH.BAILINGMOE2

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        if nextn_layers := self.hparams.get("num_nextn_predict_layers", 0):
            self.block_count = self.hparams["num_hidden_layers"] + nextn_layers
            self.tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count)

    def set_vocab(self):
        self._set_vocab_gpt2()

    def set_gguf_parameters(self):
        super().set_gguf_parameters()
        hparams = self.hparams
        if (rope_dim := hparams.get("head_dim")) is None:
            rope_dim = hparams["hidden_size"] // hparams["num_attention_heads"]

        self.gguf_writer.add_rope_dimension_count(int(rope_dim * self.hparams.get("partial_rotary_factor", 0.5)))
        rope_scaling = self.hparams.get("rope_scaling") or {}
        if rope_scaling.get("rope_type", rope_scaling.get("type")) == "yarn" and "factor" in rope_scaling:
            self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.YARN)
            self.gguf_writer.add_rope_scaling_factor(rope_scaling["factor"])
            self.gguf_writer.add_rope_scaling_orig_ctx_len(rope_scaling["original_max_position_embeddings"])
        else:
            self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.NONE)
        self.gguf_writer.add_leading_dense_block_count(hparams["first_k_dense_replace"])
        self.gguf_writer.add_vocab_size(hparams["vocab_size"])
        self.gguf_writer.add_expert_feed_forward_length(hparams["moe_intermediate_size"])
        self.gguf_writer.add_expert_shared_feed_forward_length(hparams["moe_shared_expert_intermediate_size"])
        self.gguf_writer.add_expert_weights_scale(hparams["routed_scaling_factor"])
        self.gguf_writer.add_expert_count(hparams["num_experts"])
        self.gguf_writer.add_expert_shared_count(hparams["num_shared_experts"])
        self.gguf_writer.add_expert_group_count(hparams["n_group"])
        self.gguf_writer.add_expert_group_used_count(hparams["topk_group"])
        self.gguf_writer.add_expert_weights_norm(hparams["norm_topk_prob"])

        if hparams["score_function"] == "sigmoid":
            self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SIGMOID)
        elif hparams["score_function"] == "softmax":
            self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SOFTMAX)
        else:
            raise ValueError(f"Unsupported score_function value: {hparams['score_function']}")

        if (nextn_layers := self.hparams.get("num_nextn_predict_layers")) is not None:
            self.gguf_writer.add_nextn_predict_layers(nextn_layers)

    _experts: list[dict[str, Tensor]] | None = None

    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
        if "mlp.experts" in name:
            n_experts = self.hparams["num_experts"]
            assert bid is not None

            tensors: list[tuple[str, Tensor]] = []

            if self._experts is None:
                self._experts = [{} for _ in range(self.block_count)]

            self._experts[bid][name] = data_torch

            if len(self._experts[bid]) >= n_experts * 3:
                # merge the experts into a single 3d tensor
                for w_name in ["down_proj", "gate_proj", "up_proj"]:
                    datas: list[Tensor] = []

                    for xid in range(n_experts):
                        ename = f"model.layers.{bid}.mlp.experts.{xid}.{w_name}.weight"
                        datas.append(self._experts[bid][ename])
                        del self._experts[bid][ename]

                    data_torch = torch.stack(datas, dim=0)

                    merged_name = f"model.layers.{bid}.mlp.experts.{w_name}.weight"

                    new_name = self.map_tensor_name(merged_name)

                    tensors.append((new_name, data_torch))

            return tensors

        if name.endswith(".expert_bias"):
            name = name.replace(".expert_bias", ".expert_bias.bias")

        return [(self.map_tensor_name(name), data_torch)]

    def prepare_tensors(self):
        super().prepare_tensors()

        if self._experts is not None:
            # flatten `list[dict[str, Tensor]]` into `list[str]`
            experts = [k for d in self._experts for k in d.keys()]
            if len(experts) > 0:
                raise ValueError(f"Unprocessed experts: {experts}")


###### CONVERSION LOGIC ######


31 changes: 31 additions & 0 deletions ggml/include/ggml.h
@@ -630,6 +630,7 @@ extern "C" {
        GGML_OP_TRANSPOSE,
        GGML_OP_GET_ROWS,
        GGML_OP_GET_ROWS_BACK,
        GGML_OP_SET_ROWS,
        GGML_OP_DIAG,
        GGML_OP_DIAG_MASK_INF,
        GGML_OP_DIAG_MASK_ZERO,
@@ -1559,6 +1560,19 @@ extern "C" {
            struct ggml_tensor * a,
            float s);

    // x = s * a + b
    GGML_API struct ggml_tensor * ggml_scale_bias(
            struct ggml_context * ctx,
            struct ggml_tensor * a,
            float s,
            float b);

    GGML_API struct ggml_tensor * ggml_scale_bias_inplace(
            struct ggml_context * ctx,
            struct ggml_tensor * a,
            float s,
            float b);
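
A minimal usage sketch (editor's illustration, not part of the header): the new op fuses a multiply by a scalar and an add of a scalar into one graph node. It assumes an existing ggml_context `ctx` and an F32 tensor `a` are already set up.

// y = 2.0f * a + 1.0f in a single node, instead of ggml_scale() followed by a separate add
struct ggml_tensor * y = ggml_scale_bias(ctx, a, 2.0f, 1.0f);

// same computation, but written into a view of `a` when the graph is evaluated
struct ggml_tensor * y_inplace = ggml_scale_bias_inplace(ctx, a, 2.0f, 1.0f);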

    GGML_API struct ggml_tensor * ggml_softcap(
            struct ggml_context * ctx,
            struct ggml_tensor * a,
@@ -1781,6 +1795,23 @@ extern "C" {
            struct ggml_tensor * b,
            struct ggml_tensor * c);

    // a TD [n_embd, ne1, ne2, ne3]
    // b TS [n_embd, n_rows, ne02, ne03] | ne02 == ne2, ne03 == ne3
    // c I64 [n_rows, ne11, ne12, 1] | c[i] in [0, ne1)
    //
    // undefined behavior if destination rows overlap
    //
    // broadcast:
    //   ne2 % ne11 == 0
    //   ne3 % ne12 == 0
    //
    // return view(a)
    GGML_API struct ggml_tensor * ggml_set_rows(
            struct ggml_context * ctx,
            struct ggml_tensor * a, // destination
            struct ggml_tensor * b, // source
            struct ggml_tensor * c); // row indices
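
A minimal usage sketch (illustration only, following the shape comment above): scatter two F32 source rows into chosen rows of an 8-row destination. It assumes an existing ggml_context `ctx`; the index tensor would be filled with e.g. {1, 5} before the graph is evaluated.

struct ggml_tensor * dst_rows = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 16, 8); // a: 8 rows of 16 values
struct ggml_tensor * src_rows = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 16, 2); // b: 2 rows to write
struct ggml_tensor * row_idx  = ggml_new_tensor_1d(ctx, GGML_TYPE_I64, 2);     // c: the target row of each source row
// records "dst_rows[1,:] = src_rows[0,:]; dst_rows[5,:] = src_rows[1,:]" and returns a view of dst_rows
struct ggml_tensor * out = ggml_set_rows(ctx, dst_rows, src_rows, row_idx);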

    GGML_API struct ggml_tensor * ggml_diag(
            struct ggml_context * ctx,
            struct ggml_tensor * a);
18 changes: 18 additions & 0 deletions ggml/src/ggml-cuda.cu
@@ -44,6 +44,8 @@
#include "ggml-cuda/topk-moe.cuh"
#include "ggml-cuda/conv2d.cuh"
#include "ggml-cuda/conv2d-dw.cuh"
#include "ggml-cuda/set-rows.cuh"
#include "ggml-cuda/argmax.cuh"

#include <algorithm>
#include <array>
@@ -3105,12 +3107,18 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
    auto next = i < cgraph->n_nodes - 1 ? cgraph->nodes[i+1] : nullptr;

    switch (dst->op) {
        case GGML_OP_ARGMAX:
            ggml_cuda_argmax(ctx, dst);
            break;
        case GGML_OP_REPEAT:
            ggml_cuda_op_repeat(ctx, dst);
            break;
        case GGML_OP_GET_ROWS:
            ggml_cuda_op_get_rows(ctx, dst);
            break;
        case GGML_OP_SET_ROWS:
            ggml_cuda_op_set_rows(ctx, dst);
            break;
        case GGML_OP_DUP:
            ggml_cuda_dup(ctx, dst);
            break;
@@ -4204,6 +4212,14 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
                    return false;
                }
            } break;
        case GGML_OP_SET_ROWS:
            {
                return (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16 || op->type == GGML_TYPE_BF16 ||
                        op->type == GGML_TYPE_Q4_0 || op->type == GGML_TYPE_Q4_1 || op->type == GGML_TYPE_Q5_0 ||
                        op->type == GGML_TYPE_Q5_1 || op->type == GGML_TYPE_Q8_0 || op->type == GGML_TYPE_IQ4_NL) &&
                       op->src[0]->type == GGML_TYPE_F32 &&
                       (op->src[1]->type == GGML_TYPE_I64 || op->src[1]->type == GGML_TYPE_I32);
            } break;
        case GGML_OP_CPY:
            {
                ggml_type src0_type = op->src[0]->type;
@@ -4260,6 +4276,8 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
                }
                return false;
            } break;
        case GGML_OP_ARGMAX:
            return true;
        case GGML_OP_DUP:
        case GGML_OP_REPEAT:
        case GGML_OP_CONCAT:
90 changes: 90 additions & 0 deletions ggml/src/ggml-cuda/argmax.cu
@@ -0,0 +1,90 @@
#include <algorithm>
#include <cstdint>

#include "argmax.cuh"
#include "common.cuh"

static __global__ void argmax_f32(const float * __restrict__ x, int32_t * __restrict__ dst, const int64_t ncols) {
    const int64_t row = blockIdx.x;

    float maxval = -FLT_MAX;
    int argmax = -1;
    const float * rowx = x + row * ncols;

    for (int32_t col = threadIdx.x; col < ncols; col += blockDim.x) {
        const float val = rowx[col];
        if (val > maxval) {
            maxval = val;
            argmax = col;
        }
    }

#pragma unroll
    for (int offset = 16; offset > 0; offset >>= 1) {
        const float val = __shfl_xor_sync(0xFFFFFFFF, maxval, offset, WARP_SIZE);
        const int col = __shfl_xor_sync(0xFFFFFFFF, argmax, offset, WARP_SIZE);
        if (val > maxval) {
            maxval = val;
            argmax = col;
        }
    }

    const int n_warps = blockDim.x / WARP_SIZE;
    const int lane_id = threadIdx.x % WARP_SIZE;
    const int warp_id = threadIdx.x / WARP_SIZE;
    if (n_warps > 1) {
        constexpr int max_warps = 1024 / WARP_SIZE;
        __shared__ float shared_maxval[max_warps];
        __shared__ int shared_argmax[max_warps];
        if (lane_id == 0) {
            shared_maxval[warp_id] = maxval;
            shared_argmax[warp_id] = argmax;
        }

        __syncthreads();

        if (warp_id == 0) {
            if (lane_id < n_warps) {
                maxval = shared_maxval[lane_id];
                argmax = shared_argmax[lane_id];
            }
#pragma unroll
            for (int offset = 16; offset > 0; offset >>= 1) {
                const float val = __shfl_xor_sync(0xFFFFFFFF, maxval, offset, WARP_SIZE);
                const int col = __shfl_xor_sync(0xFFFFFFFF, argmax, offset, WARP_SIZE);
                if (val > maxval) {
                    maxval = val;
                    argmax = col;
                }
            }
        }
    }

    if (warp_id == 0 && lane_id == 0) {
        dst[row] = argmax;
    }
}

void ggml_cuda_argmax(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * src0 = dst->src[0];

    GGML_ASSERT(src0->type == GGML_TYPE_F32);
    GGML_ASSERT( dst->type == GGML_TYPE_I32);

    GGML_ASSERT(ggml_is_contiguous(src0));

    const int64_t ne00 = src0->ne[0];
    const int64_t nrows = ggml_nrows(src0);

    const float * src0_d = (const float *) src0->data;
    int32_t * dst_d = (int32_t *) dst->data;

    cudaStream_t stream = ctx.stream();

    const int64_t num_blocks = nrows;
    const int64_t num_threads = std::min<int64_t>(1024, (ne00 + WARP_SIZE - 1) / WARP_SIZE * WARP_SIZE);
    const dim3 blocks_dim(num_threads, 1, 1);
    const dim3 blocks_num(num_blocks, 1, 1);

    argmax_f32<<<blocks_num, blocks_dim, 0, stream>>>(src0_d, dst_d, ne00);
}
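
A small smoke-test sketch for the kernel above (editor's illustration, not part of the diff). It assumes it is compiled in the same translation unit as the static argmax_f32 kernel, and that CUDA_CHECK and WARP_SIZE from common.cuh are available; the default stream is used for simplicity.

static void argmax_f32_smoke_test() {
    // two rows of 256 floats; the largest value of row 0 sits at column 37, of row 1 at column 200
    const int64_t ncols = 256, nrows = 2;
    float h_x[2 * 256] = { 0.0f };
    h_x[0 * ncols +  37] = 1.0f;
    h_x[1 * ncols + 200] = 2.0f;

    float   * d_x   = nullptr;
    int32_t * d_dst = nullptr;
    CUDA_CHECK(cudaMalloc(&d_x,   sizeof(h_x)));
    CUDA_CHECK(cudaMalloc(&d_dst, nrows * sizeof(int32_t)));
    CUDA_CHECK(cudaMemcpy(d_x, h_x, sizeof(h_x), cudaMemcpyHostToDevice));

    // one block per row, a single warp per block; the kernel strides over the remaining columns
    argmax_f32<<<dim3(nrows, 1, 1), dim3(WARP_SIZE, 1, 1)>>>(d_x, d_dst, ncols);

    int32_t h_dst[2];
    CUDA_CHECK(cudaMemcpy(h_dst, d_dst, sizeof(h_dst), cudaMemcpyDeviceToHost));
    // expected: h_dst[0] == 37, h_dst[1] == 200

    CUDA_CHECK(cudaFree(d_x));
    CUDA_CHECK(cudaFree(d_dst));
}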
3 changes: 3 additions & 0 deletions ggml/src/ggml-cuda/argmax.cuh
@@ -0,0 +1,3 @@
#include "common.cuh"

void ggml_cuda_argmax(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
54 changes: 44 additions & 10 deletions ggml/src/ggml-cuda/argsort.cu
@@ -13,9 +13,20 @@ static inline __device__ void ggml_cuda_swap(T & a, T & b) {
    b = tmp;
}

template<ggml_sort_order order>
static __global__ void k_argsort_f32_i32(const float * x, int * dst, const int ncols, int ncols_pad,
                                         int min_experts, float thresh_experts) {
struct store_ser {
    constexpr static bool has_thresh = true;
    int min_experts;
    float thresh_experts;
    store_ser(int min, float thresh) : min_experts(min), thresh_experts(thresh) {}
};

struct store {
    constexpr static bool has_thresh = false;
};

template<ggml_sort_order order, typename Store>
static __global__ void k_argsort_f32_i32(const float * x, int * dst, const int ncols, int ncols_pad, Store s) {
    // bitonic sort
    int col = threadIdx.x;
    int row = blockIdx.y;
@@ -58,19 +69,30 @@ static __global__ void k_argsort_f32_i32(const float * x, int * dst, const int n
        }
    }

    if (min_experts >= 0 && min_experts < ncols && thresh_experts > 0) {
    if constexpr (Store::has_thresh) {
        __syncthreads();
        float max_val = x_row[dst_row[0]];
        if (col < ncols) {
            dst[row * ncols + col] = col < min_experts || x_row[dst_row[col]] >= thresh_experts*max_val ? dst_row[col] : -1;
            dst[row * ncols + col] = col < s.min_experts || x_row[dst_row[col]] >= s.thresh_experts*max_val ? dst_row[col] : -1;
        }
    }
    else {
        // copy the result to dst without the padding
    } else {
        if (col < ncols) {
            dst[row * ncols + col] = dst_row[col];
        }
    }
}

static int next_power_of_2(int x) {
Expand All @@ -94,9 +116,21 @@ static void argsort_f32_i32_cuda(const float * x, int * dst, const int ncols, co
GGML_ASSERT(shared_mem <= ggml_cuda_info().devices[ggml_cuda_get_device()].smpb);

if (order == GGML_SORT_ORDER_ASC) {
k_argsort_f32_i32<GGML_SORT_ORDER_ASC><<<block_nums, block_dims, shared_mem, stream>>>(x, dst, ncols, ncols_pad, min_experts, thresh_experts);
if (min_experts >= 0 && min_experts < ncols && thresh_experts > 0) {
k_argsort_f32_i32<GGML_SORT_ORDER_ASC, store_ser><<<block_nums, block_dims, shared_mem, stream>>>(x, dst, ncols, ncols_pad,
{min_experts, thresh_experts});
} else {
k_argsort_f32_i32<GGML_SORT_ORDER_ASC, store><<<block_nums, block_dims, shared_mem, stream>>>(x, dst, ncols, ncols_pad, {});
}
//k_argsort_f32_i32<GGML_SORT_ORDER_ASC><<<block_nums, block_dims, shared_mem, stream>>>(x, dst, ncols, ncols_pad, min_experts, thresh_experts);
} else if (order == GGML_SORT_ORDER_DESC) {
k_argsort_f32_i32<GGML_SORT_ORDER_DESC><<<block_nums, block_dims, shared_mem, stream>>>(x, dst, ncols, ncols_pad, min_experts, thresh_experts);
if (min_experts >= 0 && min_experts < ncols && thresh_experts > 0) {
k_argsort_f32_i32<GGML_SORT_ORDER_DESC, store_ser><<<block_nums, block_dims, shared_mem, stream>>>(x, dst, ncols, ncols_pad,
{min_experts, thresh_experts});
} else {
k_argsort_f32_i32<GGML_SORT_ORDER_DESC, store><<<block_nums, block_dims, shared_mem, stream>>>(x, dst, ncols, ncols_pad, {});
}
//k_argsort_f32_i32<GGML_SORT_ORDER_DESC><<<block_nums, block_dims, shared_mem, stream>>>(x, dst, ncols, ncols_pad, min_experts, thresh_experts);
} else {
GGML_ABORT("fatal error");
}
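
The Store refactor above replaces a runtime threshold check inside the kernel with compile-time dispatch: one instantiation with store_ser compiles the threshold filtering in, and one with store contains no filtering code at all. A small host-side C++ illustration of the same `if constexpr` tag-dispatch pattern (editor's sketch, not part of the PR):

#include <cstdio>

struct with_thresh { constexpr static bool has_thresh = true;  float thresh; };
struct no_thresh   { constexpr static bool has_thresh = false; };

template <typename Store>
void emit(float val, Store s) {
    if constexpr (Store::has_thresh) {
        // this branch only exists in the with_thresh instantiation
        std::printf("%g\n", val >= s.thresh ? val : -1.0f);
    } else {
        // the no_thresh instantiation compiles down to just this
        std::printf("%g\n", val);
    }
}

int main() {
    emit(0.3f, with_thresh{0.5f}); // prints -1
    emit(0.3f, no_thresh{});       // prints 0.3
    return 0;
}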