llama : add Mixtral support #4406

Status: Merged (47 commits, merged Dec 13, 2023)
Changes from 1 commit

Commits (47)
dff8cbe  convert : support Mixtral as LLAMA arch  (ggerganov, Dec 9, 2023)
d38e41e  convert : fix n_ff typo  (ggerganov, Dec 9, 2023)
a3eefe9  llama : model loading  (ggerganov, Dec 9, 2023)
861cd67  ggml : sync latest ggml_mul_mat_id  (ggerganov, Dec 9, 2023)
aedfad1  llama : update graph to support MoE  (ggerganov, Dec 9, 2023)
af1a096  llama : fix cur -> cur_expert  (ggerganov, Dec 9, 2023)
7ea3695  llama : first working version  (ggerganov, Dec 9, 2023)
8b185b7  llama : fix expert weighting in the FFN  (ggerganov, Dec 9, 2023)
7372b62  ggml : ggml_get_rows support 2D indexing [n_tokens, n_experts] (cpu o…  (ggerganov, Dec 9, 2023)
ee8fb39  ggml : add n_as argument to ggml_mul_mat_id  (slaren, Dec 9, 2023)
9064b1c  ggml : fix ggml_get_rows to take into account ne02 / ne11  (ggerganov, Dec 9, 2023)
2cbcba8  metal : add more general support for ggml_get_rows + tests  (ggerganov, Dec 9, 2023)
06dfde3  llama : add basic support for offloading moe with CUDA  (slaren, Dec 9, 2023)
7e2006b  metal : add/mul/div use general kernel when src1 not cont  (ggerganov, Dec 9, 2023)
8c5b66e  metal : reduce the kernel launches for ggml_mul_mat_id  (ggerganov, Dec 9, 2023)
ac3f7d8  ggml : get_rows : support non-contiguos tensors with gaps, generalize…  (slaren, Dec 9, 2023)
2e4db48  ggml : update get_rows f16 and q  (slaren, Dec 9, 2023)
62b95f9  cuda : support non-contiguous src1 in get_rows  (slaren, Dec 9, 2023)
0710b0f  llama : offload missing ffn_moe_silu  (slaren, Dec 9, 2023)
016f9bb  metal : fix ggml_get_rows to work with non-cont src1  (ggerganov, Dec 10, 2023)
6cfb31f  metal : add indirect mat-vec kernels for all quantization types  (ggerganov, Dec 10, 2023)
d1259b7  llama : do not quantize expert gating tensors  (ggerganov, Dec 10, 2023)
e640cbe  llama : add n_expert and n_expert_used to hparams + change quants  (ggerganov, Dec 10, 2023)
cefebb3  test-backend-ops : add moe test  (slaren, Dec 10, 2023)
8614aa7  cuda : fix get_rows when ncols is odd  (slaren, Dec 10, 2023)
65923a8  convert : determine n_ctx correctly  (ggerganov, Dec 10, 2023)
b0b83dd  metal : fix ggml_mul_mat_id for F32  (ggerganov, Dec 10, 2023)
54ba263  test-backend-ops : make experts more evenly probable (test_moe)  (ggerganov, Dec 10, 2023)
54d254b  test-backend-ops : cleanup, add moe test for batches  (slaren, Dec 10, 2023)
f1380d7  test-backend-ops : add cpy from f32 -> all types test  (slaren, Dec 10, 2023)
b002981  test-backend-ops : fix dequantize block offset  (slaren, Dec 11, 2023)
8cbaed1  llama : fix hard-coded number of experts  (ggerganov, Dec 11, 2023)
ffda94c  test-backend-ops : simplify and disable slow tests to avoid CI timeout  (slaren, Dec 11, 2023)
33e50f1  test-backend-ops : disable MOE test with thread sanitizer  (slaren, Dec 11, 2023)
296c945  cuda : fix mul_mat_id with multi gpu  (slaren, Dec 11, 2023)
7dc75e3  convert : use 1e6 rope_freq_base for mixtral  (slaren, Dec 11, 2023)
f1cbfab  convert : fix style  (slaren, Dec 11, 2023)
6a419f4  convert : support safetensors format  (ggerganov, Dec 12, 2023)
a742d9f  gguf-py : bump version  (slaren, Dec 12, 2023)
08eb991  metal : add cpy f16 -> f32 kernel  (ggerganov, Dec 12, 2023)
a51bc0c  metal : fix binary ops for ne10 % 4 != 0  (ggerganov, Dec 12, 2023)
ea4402b  test-backend-ops : add one more sum_rows test  (ggerganov, Dec 12, 2023)
90c12e6  ggml : do not use BLAS with ggml_mul_mat_id  (ggerganov, Dec 12, 2023)
82e4f64  convert-hf : support for mixtral-instruct (#4428)  (Mrkvak, Dec 12, 2023)
ab558ac  metal : fix soft_max kernels  (ggerganov, Dec 13, 2023)
109e7aa  metal : limit kernels to not use more than the allowed threads  (ggerganov, Dec 13, 2023)
e1241d9  metal : switch to execution barriers + fix one of the barriers  (ggerganov, Dec 13, 2023)
metal : add cpy f16 -> f32 kernel
ggerganov committed Dec 12, 2023
commit 08eb99179a301850ed7aaaf1143e0e20ca50c234
10 changes: 5 additions & 5 deletions convert.py
@@ -63,10 +63,10 @@ class UnquantizedDataType(DataType):
pass


-DT_F16 = UnquantizedDataType('F16', dtype = np.dtype(np.float16), valid_conversions = ['F32', 'Q8_0'])
-DT_F32 = UnquantizedDataType('F32', dtype = np.dtype(np.float32), valid_conversions = ['F16', 'Q8_0'])
-DT_I32 = UnquantizedDataType('I32', dtype = np.dtype(np.int16), valid_conversions = [])
-DT_BF16 = UnquantizedDataType('BF16', dtype = np.dtype(np.uint16), valid_conversions = ['F32', 'F16', 'Q8_0'])
+DT_F16 = UnquantizedDataType('F16', dtype = np.dtype(np.float16), valid_conversions = ['F32', 'Q8_0'])
+DT_F32 = UnquantizedDataType('F32', dtype = np.dtype(np.float32), valid_conversions = ['F16', 'Q8_0'])
+DT_I32 = UnquantizedDataType('I32', dtype = np.dtype(np.int16), valid_conversions = [])
+DT_BF16 = UnquantizedDataType('BF16', dtype = np.dtype(np.uint16), valid_conversions = ['F32', 'F16', 'Q8_0'])


@dataclass(frozen=True)
@@ -996,7 +996,7 @@ def write_all(fname_out: Path, ftype: GGMLFileType, params: Params, model: LazyM


def pick_output_type(model: LazyModel, output_type_str: str | None) -> GGMLFileType:
-wq_type = model[gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.ATTN_Q].format(bid=0) +".weight"].data_type
+wq_type = model[gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.ATTN_Q].format(bid=0) + ".weight"].data_type

if output_type_str == "f32" or (output_type_str is None and wq_type == DT_F32):
return GGMLFileType.AllF32
36 changes: 32 additions & 4 deletions ggml-metal.m
@@ -155,6 +155,7 @@
//GGML_METAL_DECL_KERNEL(cpy_f32_q5_0);
//GGML_METAL_DECL_KERNEL(cpy_f32_q5_1);
GGML_METAL_DECL_KERNEL(cpy_f16_f16);
+GGML_METAL_DECL_KERNEL(cpy_f16_f32);
GGML_METAL_DECL_KERNEL(concat);
GGML_METAL_DECL_KERNEL(sqr);
GGML_METAL_DECL_KERNEL(sum_rows);
@@ -424,6 +425,7 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){
//GGML_METAL_ADD_KERNEL(cpy_f32_q5_0);
//GGML_METAL_ADD_KERNEL(cpy_f32_q5_1);
GGML_METAL_ADD_KERNEL(cpy_f16_f16);
+GGML_METAL_ADD_KERNEL(cpy_f16_f32);
GGML_METAL_ADD_KERNEL(concat);
GGML_METAL_ADD_KERNEL(sqr);
GGML_METAL_ADD_KERNEL(sum_rows);
@@ -539,6 +541,7 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
//GGML_METAL_DEL_KERNEL(cpy_f32_q5_0);
//GGML_METAL_DEL_KERNEL(cpy_f32_q5_1);
GGML_METAL_DEL_KERNEL(cpy_f16_f16);
+GGML_METAL_DEL_KERNEL(cpy_f16_f32);
GGML_METAL_DEL_KERNEL(concat);
GGML_METAL_DEL_KERNEL(sqr);
GGML_METAL_DEL_KERNEL(sum_rows);
@@ -867,12 +870,37 @@ static bool ggml_metal_supports_op(const struct ggml_tensor * op) {
case GGML_OP_ROPE:
case GGML_OP_IM2COL:
case GGML_OP_ARGSORT:
-case GGML_OP_DUP:
-case GGML_OP_CPY:
-case GGML_OP_CONT:
case GGML_OP_MUL_MAT:
case GGML_OP_MUL_MAT_ID:
return true;
+case GGML_OP_CPY:
+case GGML_OP_DUP:
+case GGML_OP_CONT:
+{
+switch (op->src[0]->type) {
+case GGML_TYPE_F32:
+switch (op->type) {
+case GGML_TYPE_F16:
+case GGML_TYPE_F32:
+case GGML_TYPE_Q8_0:
+case GGML_TYPE_Q4_0:
+case GGML_TYPE_Q4_1:
+return true;
+default:
+return false;
+}
+case GGML_TYPE_F16:
+switch (op->type) {
+case GGML_TYPE_F16:
+case GGML_TYPE_F32:
+return true;
+default:
+return false;
+}
+default:
+return false;
+};
+}
case GGML_OP_DIAG_MASK_INF:
{
return op->ne[0] % 4 == 0;
@@ -2021,7 +2049,7 @@ void ggml_metal_graph_compute(
{
switch (dstt) {
case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_cpy_f16_f16]; break;
-case GGML_TYPE_F32: GGML_ASSERT(false && "cpy_f16_f32 not implemented"); break;
+case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_cpy_f16_f32]; break;
default: GGML_ASSERT(false && "not implemented");
};
} break;
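
For reference, the nested switch added to ggml_metal_supports_op above boils down to a small source-to-destination type table for the CPY/DUP/CONT ops. A minimal Python sketch of that table, illustrative only and not part of the patch:

# Which destination types each source type can be copied to on the Metal backend,
# as encoded by the switch added in this commit (F16 -> F32 is the newly added kernel).
SUPPORTED_CPY = {
    "F32": {"F16", "F32", "Q8_0", "Q4_0", "Q4_1"},
    "F16": {"F16", "F32"},
}

def metal_supports_cpy(src_type: str, dst_type: str) -> bool:
    # Mirrors ggml_metal_supports_op for GGML_OP_CPY / GGML_OP_DUP / GGML_OP_CONT.
    return dst_type in SUPPORTED_CPY.get(src_type, set())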
45 changes: 43 additions & 2 deletions ggml-metal.metal
@@ -1698,8 +1698,8 @@ template [[host_name("kernel_argsort_f32_i32_asc")]] kernel argsort_t kernel_ar
template [[host_name("kernel_argsort_f32_i32_desc")]] kernel argsort_t kernel_argsort_f32_i32<GGML_SORT_DESC>;

kernel void kernel_cpy_f16_f16(
-device const half * src0,
-device half * dst,
+device const half * src0,
+device half * dst,
constant int64_t & ne00,
constant int64_t & ne01,
constant int64_t & ne02,
@@ -1738,6 +1738,47 @@ kernel void kernel_cpy_f16_f16(
}
}

+kernel void kernel_cpy_f16_f32(
+device const half * src0,
+device float * dst,
+constant int64_t & ne00,
+constant int64_t & ne01,
+constant int64_t & ne02,
+constant int64_t & ne03,
+constant uint64_t & nb00,
+constant uint64_t & nb01,
+constant uint64_t & nb02,
+constant uint64_t & nb03,
+constant int64_t & ne0,
+constant int64_t & ne1,
+constant int64_t & ne2,
+constant int64_t & ne3,
+constant uint64_t & nb0,
+constant uint64_t & nb1,
+constant uint64_t & nb2,
+constant uint64_t & nb3,
+uint3 tgpig[[threadgroup_position_in_grid]],
+uint3 tpitg[[thread_position_in_threadgroup]],
+uint3 ntg[[threads_per_threadgroup]]) {
+const int64_t i03 = tgpig[2];
+const int64_t i02 = tgpig[1];
+const int64_t i01 = tgpig[0];
+
+const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
+
+const int64_t i3 = n / (ne2*ne1*ne0);
+const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
+const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
+const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
+
+device float * dst_data = (device float *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+
+for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
+device const half * src = (device half *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
+dst_data[i00] = src[0];
+}
+}
+
kernel void kernel_cpy_f32_f16(
device const float * src0,
device half * dst,
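
The new kernel_cpy_f16_f32 above follows the same pattern as the existing copy kernels: each threadgroup handles one source row (i01, i02, i03), computes the flat element index n of that row, and decomposes n into destination coordinates so the copy also works when the destination has a different shape with the same number of elements. A rough Python sketch of that index arithmetic, for illustration only:

def dst_coords(i01, i02, i03, ne00, ne01, ne02, ne0, ne1, ne2):
    # Flat index of the first element of source row (i01, i02, i03).
    n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00
    # Decompose the flat index into destination coordinates i0..i3.
    i3 = n // (ne2*ne1*ne0)
    i2 = (n - i3*ne2*ne1*ne0) // (ne1*ne0)
    i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) // ne0
    i0 = n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0
    return i0, i1, i2, i3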
8 changes: 4 additions & 4 deletions llama.cpp
@@ -4277,23 +4277,23 @@ struct llm_build_context {
ggml_tensor * logits = ggml_mul_mat(ctx0, model.layers[il].ffn_gate_inp, cur); // [n_tokens, num_experts]
cb(logits, "ffn_moe_logits", il);

-ggml_tensor * probs = ggml_soft_max(ctx0, logits); // [n_tokens, num_experts]
+ggml_tensor * probs = ggml_soft_max(ctx0, logits); // [n_tokens, num_experts]
cb(probs, "ffn_moe_probs", il);

// select experts
-ggml_tensor * selected_experts = ggml_top_k(ctx0, probs, n_expert_used); // [n_tokens, num_experts_per_tok]
+ggml_tensor * selected_experts = ggml_top_k(ctx0, probs, n_expert_used); // [n_tokens, num_experts_per_tok]
cb(selected_experts->src[0], "ffn_moe_argsort", il);

ggml_tensor * weights = ggml_get_rows(ctx0,
ggml_reshape_3d(ctx0, probs, 1, n_expert, n_tokens), selected_experts);
cb(weights, "ffn_moe_weights", il);

-weights = ggml_reshape_2d(ctx0, weights, n_expert_used, n_tokens); // [n_tokens, num_experts_per_tok]
+weights = ggml_reshape_2d(ctx0, weights, n_expert_used, n_tokens); // [n_tokens, num_experts_per_tok]

ggml_tensor * weights_sum = ggml_sum_rows(ctx0, weights);
cb(weights_sum, "ffn_moe_weights_sum", il);

-weights = ggml_div(ctx0, weights, weights_sum); // [n_tokens, num_experts_per_tok]
+weights = ggml_div(ctx0, weights, weights_sum); // [n_tokens, num_experts_per_tok]
cb(weights, "ffn_moe_weights_norm", il);

// compute expert outputs
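
The hunk above is the Mixtral MoE routing path: the router logits are softmaxed per token, the top n_expert_used experts are selected, their probabilities are gathered with ggml_get_rows, and the selected weights are renormalized before the expert FFN outputs are mixed. A minimal NumPy sketch of the same per-token routing, illustrative only and not the actual ggml graph:

import numpy as np

def route_tokens(logits, n_expert_used):
    # logits: [n_tokens, n_expert] router outputs (ffn_gate_inp applied to cur).
    probs = np.exp(logits - logits.max(axis=-1, keepdims=True))
    probs /= probs.sum(axis=-1, keepdims=True)                  # ffn_moe_probs
    selected = np.argsort(-probs, axis=-1)[:, :n_expert_used]   # top-k experts per token
    weights = np.take_along_axis(probs, selected, axis=-1)      # ffn_moe_weights
    weights /= weights.sum(axis=-1, keepdims=True)              # ffn_moe_weights_norm
    return selected, weights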