Skip to content

ggml : ggml_set_rows bcast, quantization and metal support #3

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions examples/eval-callback/eval-callback.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,8 @@ static void ggml_print_tensor(uint8_t * data, ggml_type type, const int64_t * ne
v = ggml_fp16_to_fp32(*(ggml_fp16_t *) &data[i]);
} else if (type == GGML_TYPE_F32) {
v = *(float *) &data[i];
} else if (type == GGML_TYPE_I64) {
v = (float) *(int64_t *) &data[i];
} else if (type == GGML_TYPE_I32) {
v = (float) *(int32_t *) &data[i];
} else if (type == GGML_TYPE_I16) {
Expand Down
1 change: 1 addition & 0 deletions ggml/include/ggml-cpu.h
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,7 @@ extern "C" {

GGML_BACKEND_API ggml_backend_reg_t ggml_backend_cpu_reg(void);

GGML_BACKEND_API void ggml_cpu_fp32_to_fp32(const float *, float *, int64_t);
GGML_BACKEND_API void ggml_cpu_fp32_to_fp16(const float *, ggml_fp16_t *, int64_t);
GGML_BACKEND_API void ggml_cpu_fp16_to_fp32(const ggml_fp16_t *, float *, int64_t);
GGML_BACKEND_API void ggml_cpu_fp32_to_bf16(const float *, ggml_bf16_t *, int64_t);
Expand Down
14 changes: 14 additions & 0 deletions ggml/include/ggml.h
Original file line number Diff line number Diff line change
Expand Up @@ -688,6 +688,9 @@ extern "C" {
// true for tensor that is stored in memory as CxWxHxN and has been permuted to WxHxCxN
GGML_API bool ggml_is_contiguous_channels(const struct ggml_tensor * tensor);

// true if the elements in dimension 0 are contiguous, or there is just 1 block of elements
GGML_API bool ggml_is_contiguous_rows(const struct ggml_tensor * tensor);

GGML_API bool ggml_are_same_shape (const struct ggml_tensor * t0, const struct ggml_tensor * t1);
GGML_API bool ggml_are_same_stride(const struct ggml_tensor * t0, const struct ggml_tensor * t1);

Expand Down Expand Up @@ -1376,6 +1379,17 @@ extern "C" {
struct ggml_tensor * b, // row indices
struct ggml_tensor * c); // data for ggml_get_rows, only used for its shape

// a TD [n_embd, ne1, ne2, ne3]
// b TS [n_embd, n_rows, ne02, ne03] | ne02 == ne2, ne03 == ne3
// c I64 [n_rows, ne11, ne12, 1] | c[i] in [0, ne1)
//
// undefined behavior if destination rows overlap
//
// broadcast:
// ne2 % ne11 == 0
// ne3 % ne12 == 0
//
// return view(a)
GGML_API struct ggml_tensor * ggml_set_rows(
struct ggml_context * ctx,
struct ggml_tensor * a, // destination
Expand Down
5 changes: 5 additions & 0 deletions ggml/src/ggml-cpu/ggml-cpu.c
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,7 @@ typedef pthread_t ggml_thread_t;

static const struct ggml_type_traits_cpu type_traits_cpu[GGML_TYPE_COUNT] = {
[GGML_TYPE_F32] = {
.from_float = (ggml_from_float_t) ggml_cpu_fp32_to_fp32,
.vec_dot = (ggml_vec_dot_t) ggml_vec_dot_f32,
.vec_dot_type = GGML_TYPE_F32,
.nrows = 1,
Expand Down Expand Up @@ -3126,6 +3127,10 @@ enum ggml_status ggml_graph_compute_with_ctx(struct ggml_context * ctx, struct g
return ggml_graph_compute(cgraph, &cplan);
}

void ggml_cpu_fp32_to_fp32(const float * x, float * y, int64_t n) {
memcpy(y, x, n * sizeof(float));
}

void ggml_cpu_fp32_to_fp16(const float * x, ggml_fp16_t * y, int64_t n) {
int64_t i = 0;
#if defined(__F16C__)
Expand Down
85 changes: 72 additions & 13 deletions ggml/src/ggml-cpu/ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2282,6 +2282,52 @@ static void ggml_compute_forward_repeat_f16(
}
}

static void ggml_compute_forward_repeat_i64(
const ggml_compute_params * params,
ggml_tensor * dst) {

const ggml_tensor * src0 = dst->src[0];

if (params->ith != 0) {
return;
}

GGML_ASSERT(ggml_can_repeat(src0, dst));

GGML_TENSOR_UNARY_OP_LOCALS

// guaranteed to be an integer due to the check in ggml_can_repeat
const int nr0 = (int)(ne0/ne00);
const int nr1 = (int)(ne1/ne01);
const int nr2 = (int)(ne2/ne02);
const int nr3 = (int)(ne3/ne03);

// TODO: support for transposed / permuted tensors
GGML_ASSERT(nb0 == sizeof(int64_t));
GGML_ASSERT(nb00 == sizeof(int64_t));

// TODO: maybe this is not optimal?
for (int i3 = 0; i3 < nr3; i3++) {
for (int k3 = 0; k3 < ne03; k3++) {
for (int i2 = 0; i2 < nr2; i2++) {
for (int k2 = 0; k2 < ne02; k2++) {
for (int i1 = 0; i1 < nr1; i1++) {
for (int k1 = 0; k1 < ne01; k1++) {
for (int i0 = 0; i0 < nr0; i0++) {
int64_t * y = (int64_t *) ((char *) dst->data + (i3*ne03 + k3)*nb3 + (i2*ne02 + k2)*nb2 + (i1*ne01 + k1)*nb1 + (i0*ne00)*nb0);
int64_t * x = (int64_t *) ((char *) src0->data + ( k3)*nb03 + ( k2)*nb02 + ( k1)*nb01);
for (int i = 0; i < ne00; ++i) {
y[i] = x[i];
}
}
}
}
}
}
}
}
}

void ggml_compute_forward_repeat(
const ggml_compute_params * params,
ggml_tensor * dst) {
Expand All @@ -2300,6 +2346,10 @@ void ggml_compute_forward_repeat(
{
ggml_compute_forward_repeat_f32(params, dst);
} break;
case GGML_TYPE_I64:
{
ggml_compute_forward_repeat_i64(params, dst);
} break;
default:
{
GGML_ABORT("fatal error");
Expand Down Expand Up @@ -4480,12 +4530,14 @@ static void ggml_compute_forward_set_rows_f32(
GGML_TENSOR_BINARY_OP_LOCALS

const int64_t nc = ne00;
const int64_t nr = ggml_nelements(src1);
const int64_t nr = ne01;

assert(ne0 == nc);
assert(ne02 == ne11);
assert(nb00 == sizeof(float));
assert(ggml_nrows(src0) == nr);
assert(ne2 == ne02);
assert(ne3 == ne03);
assert(src0->type == GGML_TYPE_F32);
assert(ne02 % ne11 == 0);
assert(ne03 % ne12 == 0);

const int ith = params->ith;
const int nth = params->nth;
Expand All @@ -4497,17 +4549,24 @@ static void ggml_compute_forward_set_rows_f32(
const int ir0 = dr*ith;
const int ir1 = MIN(ir0 + dr, nr);

for (int64_t i = ir0; i < ir1; ++i) {
const int64_t i12 = i/(ne11*ne10);
const int64_t i11 = (i - i12*ne11*ne10)/ne10;
const int64_t i10 = (i - i12*ne11*ne10 - i11*ne10);
const int64_t i01 = *(int64_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12);
ggml_from_float_t const from_float = ggml_get_type_traits_cpu(dst->type)->from_float;

for (int64_t i03 = 0; i03 < ne03; ++i03) {
for (int64_t i02 = 0; i02 < ne02; ++i02) {
for (int64_t i = ir0; i < ir1; ++i) {
const int64_t i12 = i03%ne12;
const int64_t i11 = i02%ne11;
const int64_t i10 = i;

GGML_ASSERT(i01 >= 0 && i01 < ne1);
const int64_t i1 = *(int64_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12);

ggml_cpu_fp32_to_fp16(
(const float *) ((char *) src0->data + i10*nb01 + i11*nb02 + i12*nb03),
(ggml_fp16_t *) ((char *) dst->data + i01*nb1 + i11*nb2 + i12*nb3), nc);
GGML_ASSERT(i1 >= 0 && i1 < ne1);

from_float(
(const float *) ((char *) src0->data + i*nb01 + i02*nb02 + i03*nb03),
((char *) dst->data + i1*nb1 + i02*nb2 + i03*nb3), nc);
}
}
}
}

Expand Down
16 changes: 16 additions & 0 deletions ggml/src/ggml-metal/ggml-metal-impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -521,6 +521,22 @@ typedef struct {
uint64_t nb2;
} ggml_metal_kargs_get_rows;

typedef struct {
int32_t nk0;
int32_t ne01;
uint64_t nb01;
uint64_t nb02;
uint64_t nb03;
int32_t ne11;
int32_t ne12;
uint64_t nb10;
uint64_t nb11;
uint64_t nb12;
uint64_t nb1;
uint64_t nb2;
uint64_t nb3;
} ggml_metal_kargs_set_rows;

typedef struct {
int64_t ne00;
int64_t ne01;
Expand Down
110 changes: 105 additions & 5 deletions ggml/src/ggml-metal/ggml-metal.m
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,15 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL,
GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS,
GGML_METAL_KERNEL_TYPE_GET_ROWS_I32,
GGML_METAL_KERNEL_TYPE_SET_ROWS_F32,
GGML_METAL_KERNEL_TYPE_SET_ROWS_F16,
GGML_METAL_KERNEL_TYPE_SET_ROWS_BF16,
GGML_METAL_KERNEL_TYPE_SET_ROWS_Q8_0,
GGML_METAL_KERNEL_TYPE_SET_ROWS_Q4_0,
GGML_METAL_KERNEL_TYPE_SET_ROWS_Q4_1,
GGML_METAL_KERNEL_TYPE_SET_ROWS_Q5_0,
GGML_METAL_KERNEL_TYPE_SET_ROWS_Q5_1,
GGML_METAL_KERNEL_TYPE_SET_ROWS_IQ4_NL,
GGML_METAL_KERNEL_TYPE_RMS_NORM,
GGML_METAL_KERNEL_TYPE_L2_NORM,
GGML_METAL_KERNEL_TYPE_GROUP_NORM,
Expand Down Expand Up @@ -1166,6 +1175,15 @@ @implementation GGMLMetalClass
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL, get_rows_iq4_nl, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS, get_rows_iq4_xs, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_I32, get_rows_i32, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_ROWS_F32, set_rows_f32, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_ROWS_F16, set_rows_f16, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_ROWS_BF16, set_rows_bf16, use_bfloat);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_ROWS_Q8_0, set_rows_q8_0, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_ROWS_Q4_0, set_rows_q4_0, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_ROWS_Q4_1, set_rows_q4_1, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_ROWS_Q5_0, set_rows_q5_0, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_ROWS_Q5_1, set_rows_q5_1, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_ROWS_IQ4_NL, set_rows_iq4_nl, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM, rms_norm, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_L2_NORM, l2_norm, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GROUP_NORM, group_norm, has_simdgroup_reduction);
Expand Down Expand Up @@ -1630,7 +1648,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex

if (!use_bfloat) {
for (size_t i = 0, n = 3; i < n; ++i) {
if (op->src[i] != NULL && op->src[i]->type == GGML_TYPE_BF16) {
if (op->src[i] != NULL && (op->src[i]->type == GGML_TYPE_BF16 || op->type == GGML_TYPE_BF16)) {
return false;
}
}
Expand Down Expand Up @@ -1798,6 +1816,27 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
{
return op->ne[3] == 1;
}
case GGML_OP_SET_ROWS:
{
if (op->src[0]->type != GGML_TYPE_F32) {
return false;
}

switch (op->type) {
case GGML_TYPE_F32:
case GGML_TYPE_F16:
case GGML_TYPE_BF16:
case GGML_TYPE_Q8_0:
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
case GGML_TYPE_Q5_0:
case GGML_TYPE_Q5_1:
case GGML_TYPE_IQ4_NL:
return true;
default:
return false;
};
}
default:
return false;
}
Expand Down Expand Up @@ -3757,13 +3796,74 @@ static bool ggml_metal_encode_node(
};

[encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
[encoder setBytes:&args length:sizeof(args) atIndex:3];
[encoder setBytes:&args length:sizeof(args) atIndex:0];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:2];
[encoder setBuffer:id_dst offset:offs_dst atIndex:3];

[encoder dispatchThreadgroups:MTLSizeMake(ne10, ne11, 1) threadsPerThreadgroup:MTLSizeMake(32, 1, 1)];
} break;
case GGML_OP_SET_ROWS:
{
id<MTLComputePipelineState> pipeline = nil;

switch (dst->type) {
case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SET_ROWS_F32 ].pipeline; break;
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SET_ROWS_F16 ].pipeline; break;
case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SET_ROWS_BF16 ].pipeline; break;
case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SET_ROWS_Q8_0 ].pipeline; break;
case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SET_ROWS_Q4_0 ].pipeline; break;
case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SET_ROWS_Q4_1 ].pipeline; break;
case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SET_ROWS_Q5_0 ].pipeline; break;
case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SET_ROWS_Q5_1 ].pipeline; break;
case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SET_ROWS_IQ4_NL].pipeline; break;
default: GGML_ABORT("not implemented");
}

const int32_t nk0 = ne0/ggml_blck_size(dst->type);

int nth = 32; // SIMD width

while (nth < nk0 && nth < (int) pipeline.maxTotalThreadsPerThreadgroup) {
nth *= 2;
}

int nrptg = 1;
if (nth > nk0) {
nrptg = (nth + nk0 - 1)/nk0;
nth = nk0;

if (nrptg*nth > (int) pipeline.maxTotalThreadsPerThreadgroup) {
nrptg--;
}
}

nth = MIN(nth, nk0);

ggml_metal_kargs_set_rows args = {
/*.nk0 =*/ nk0,
/*.ne01 =*/ ne01,
/*.nb01 =*/ nb01,
/*.nb02 =*/ nb02,
/*.nb03 =*/ nb03,
/*.ne11 =*/ ne11,
/*.ne12 =*/ ne12,
/*.nb10 =*/ nb10,
/*.nb11 =*/ nb11,
/*.nb12 =*/ nb12,
/*.nb1 =*/ nb1,
/*.nb2 =*/ nb2,
/*.nb3 =*/ nb3,
};

[encoder setComputePipelineState:pipeline];
[encoder setBytes:&args length:sizeof(args) atIndex:0];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:2];
[encoder setBuffer:id_dst offset:offs_dst atIndex:3];

[encoder dispatchThreadgroups:MTLSizeMake((ne01 + nrptg - 1)/nrptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, nrptg, 1)];
} break;
case GGML_OP_RMS_NORM:
{
GGML_ASSERT(ne00 % 4 == 0);
Expand Down
Loading
Loading