diff --git a/ggml.c b/ggml.c index 9dd2faca119b61..e9fee20f102548 100644 --- a/ggml.c +++ b/ggml.c @@ -5352,6 +5352,8 @@ static void ggml_compute_forward_add_q_f32( const int ir0 = dr*ith; const int ir1 = MIN(ir0 + dr, nr); + float * wdata = (float*) params->wdata + ne00 * ith; + for (int ir = ir0; ir < ir1; ++ir) { // src0 indices const int i03 = ir/(ne02*ne01); @@ -5374,12 +5376,11 @@ static void ggml_compute_forward_add_q_f32( assert(ne00 % 32 == 0); // unquantize row from src0 to temp buffer - float tmp[ne00]; - dequantize_row_q(src0_row, tmp, ne00); + dequantize_row_q(src0_row, wdata, ne00); // add src1 - ggml_vec_acc_f32(ne00, tmp, src1_row); + ggml_vec_acc_f32(ne00, wdata, src1_row); // quantize row to dst - quantize_row_q(tmp, dst_row, ne00); + quantize_row_q(wdata, dst_row, ne00); } } @@ -9568,6 +9569,14 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph) case GGML_OP_ADD: { node->n_tasks = n_threads; + + size_t cur = 0; + + if (node->src0->type == GGML_TYPE_Q4_0 || node->src0->type == GGML_TYPE_Q4_1) { + cur = GGML_TYPE_SIZE[GGML_TYPE_F32] * node->src0->ne[0] * n_threads; + } + + work_size = MAX(work_size, cur); } break; case GGML_OP_SUB: case GGML_OP_MUL: