4 changes: 4 additions & 0 deletions docs/backend/CANN.md
@@ -322,3 +322,7 @@ Maximum number of compiled CANN graphs kept in the LRU cache, default is 12. Whe
### GGML_CANN_PREFILL_USE_GRAPH

Enable ACL graph execution during the prefill stage, default is false. This option is only effective when FA is enabled.

### GGML_CANN_HIGH_PERF_MODE

Enable high-performance mode, default is false. When enabled, intermediate computation results are stored in FP16, which improves speed but may slightly reduce precision.
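For reference, the `get_env_as_bool` helper this change adds to `ggml/src/ggml.c` accepts `on`, `1`, `yes`, `y`, `enable`, or `true` (case-insensitive) as truthy values; anything else, or an unset variable, leaves the mode off. Below is a minimal C sketch of that check, for illustration only (the `env_flag_enabled` name is made up here and is not part of the change):

```c
#include <ctype.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

// Minimal sketch of the truthy-value check used for GGML_CANN_HIGH_PERF_MODE;
// mirrors the get_env_as_bool() helper added to ggml/src/ggml.c in this change.
static int env_flag_enabled(const char * name) {
    const char * val = getenv(name);
    if (val == NULL) {
        return 0; // unset -> disabled (the documented default)
    }

    // lower-case a bounded copy of the value
    char buf[64];
    size_t len = strlen(val);
    if (len >= sizeof(buf)) {
        len = sizeof(buf) - 1;
    }
    for (size_t i = 0; i < len; i++) {
        buf[i] = (char) tolower((unsigned char) val[i]);
    }
    buf[len] = '\0';

    // values treated as "enabled"
    const char * truthy[] = {"on", "1", "yes", "y", "enable", "true"};
    for (size_t i = 0; i < sizeof(truthy) / sizeof(truthy[0]); i++) {
        if (strcmp(buf, truthy[i]) == 0) {
            return 1;
        }
    }
    return 0;
}

int main(void) {
    printf("GGML_CANN_HIGH_PERF_MODE enabled: %d\n",
           env_flag_enabled("GGML_CANN_HIGH_PERF_MODE"));
    return 0;
}
```

Setting the variable to `1` (or any of the values above) before running a CANN build should therefore switch the mode on.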
156 changes: 90 additions & 66 deletions ggml/src/ggml-cann/aclnn_ops.cpp
@@ -1765,35 +1765,31 @@ void ggml_cann_get_rows(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
ggml_tensor* src0 = dst->src[0]; // src
ggml_tensor* src1 = dst->src[1]; // index

switch (src0->type) {
case GGML_TYPE_F32: {
if(src0->type == dst->type) {
GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
aclnn_index_select_4d(ctx, src0->data, src0->ne, src0->nb,
dst->data, dst->ne, dst->nb,
src1, dst->type);
break;
}
case GGML_TYPE_F16: {
} else if(src0->type == GGML_TYPE_F16) {
aclTensor* acl_src0 = ggml_cann_create_tensor(src0);
ggml_cann_pool_alloc src_buffer_allocator(
ctx.pool(), ggml_nelements(src0) * sizeof(float));
ctx.pool(), ggml_nelements(src0) * ggml_element_size(dst));
void* src_trans_buffer = src_buffer_allocator.get();
size_t src_trans_nb[GGML_MAX_DIMS];
src_trans_nb[0] = sizeof(float);
src_trans_nb[0] = dst->nb[0];
for (int i = 1; i < GGML_MAX_DIMS; i++) {
src_trans_nb[i] = src_trans_nb[i - 1] * src0->ne[i - 1];
}
aclTensor* src_trans_tensor = ggml_cann_create_tensor(
src_trans_buffer, ACL_FLOAT, ggml_type_size(dst->type),
src_trans_buffer, ggml_cann_type_mapping(dst->type), ggml_type_size(dst->type),
src0->ne, src_trans_nb, GGML_MAX_DIMS);
aclnn_cast(ctx, acl_src0, src_trans_tensor, ggml_cann_type_mapping(dst->type));
aclnn_index_select_4d(ctx, src_trans_buffer, src0->ne, src_trans_nb,
dst->data, dst->ne, dst->nb,
src1, dst->type);
ggml_cann_release_resources(ctx, acl_src0, src_trans_tensor);
break;
}
case GGML_TYPE_Q8_0: {
// add 1 dim for bcast mul.
} else if (src0->type == GGML_TYPE_Q8_0){
// add 1 dim for bcast mul.
size_t weight_nb[GGML_MAX_DIMS + 1], scale_nb[GGML_MAX_DIMS + 1],
dequant_nb[GGML_MAX_DIMS + 1];
int64_t weight_ne[GGML_MAX_DIMS + 1], scale_ne[GGML_MAX_DIMS + 1],
@@ -1854,11 +1850,8 @@ void ggml_cann_get_rows(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
src1, dst->type);

ggml_cann_release_resources(ctx, dequant_tensor);
break;
}
default:
} else {
GGML_ABORT("Unsupported tensor type for GGML_OP_GET_ROWS");
break;
}
}

@@ -3178,7 +3171,6 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst){
aclTensor* acl_src0_f16_tensor = nullptr;
aclTensor* acl_src1_f16_tensor = nullptr;
aclTensor* acl_src2_f16_tensor = nullptr;
aclTensor* acl_dst_f16_tensor = nullptr;

// Step 1: cast the src0 (Query) to fp16 if needed
ggml_cann_pool_alloc src0_f16_allocator(ctx.pool());
@@ -3216,22 +3208,6 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst){
acl_src2_f16_tensor = ggml_cann_create_tensor(src2, src2_bsnd_ne,
src2_bsnd_nb, GGML_MAX_DIMS);

ggml_cann_pool_alloc out_f16_allocator(ctx.pool());
void* out_f16_buffer = out_f16_allocator.alloc(
ggml_nelements(dst) * faElemSize);

int64_t* out_f16_ne = src0_bsnd_ne;
size_t out_f16_nb[GGML_MAX_DIMS];
out_f16_nb[0] = faElemSize;
for(int i = 1; i < GGML_MAX_DIMS; ++i){
out_f16_nb[i] = out_f16_nb[i - 1] * out_f16_ne[i - 1];
}

acl_dst_f16_tensor = ggml_cann_create_tensor(
out_f16_buffer, faDataType, faElemSize,
out_f16_ne, out_f16_nb, GGML_MAX_DIMS
);

// Step 3: create the PSEShift tensor if needed
// this tensor is considered as mask (f16) in the llama.cpp
aclTensor* bcast_pse_tensor = nullptr;
@@ -3336,40 +3312,88 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst){

// Step 5: launch the FusedInferAttentionScoreV2 kernel.
// Refer to https://gitee.com/ascend/cann-ops-adv/blob/master/docs/FusedInferAttentionScoreV2.md
if (dst->type == GGML_TYPE_F16) {
aclTensor* acl_dst_tensor = ggml_cann_create_tensor(dst);

GGML_CANN_CALL_ACLNN_OP(ctx, FusedInferAttentionScoreV2,
acl_q_tensor, acl_k_tensor_list, acl_v_tensor_list, // q, k, v
bcast_pse_tensor, nullptr, // pse, mask
nullptr, nullptr, // actSeqLen, actSeqLenkv
nullptr, nullptr, // deqScale1, quantScale1
nullptr, nullptr, nullptr, // deqScale2, quantScale2, quantOffset2
nullptr, nullptr, // antiquantScale, antiquantOffset
nullptr, // blockTable
nullptr, nullptr, // qPadSize, kvPadSize
nullptr, nullptr, // kAntiquantScale, kAntiQuantOffset
nullptr, nullptr, // vAntiquantScale, vAntiQuantOffset
nullptr, nullptr, nullptr, // kSharedPrefix, vSharedPrefix, actSharedLen
numHeads, scaleValue, // heads, scaleValue
preTokens, nextTokens, // preTokens, nextTokens
layout, // inputLayout
numKeyValueHeads, // numKVHeads
sparseMode, innerPrecise, // sparseMode, innerPrecise
blockSize, antiquantMode, // blockSize, antiquantMode
softmaxLseFlag, // softmaxLseFlag
keyAntiquantMode, valueAntiquantMode, // keyAntiqMode, valueAntiqMode
acl_dst_tensor, // attentionOut
nullptr // softmaxLse
);

ggml_cann_release_resources(ctx, acl_src0_f16_tensor,
acl_src1_f16_tensor,
acl_src2_f16_tensor,
acl_dst_tensor);
} else {
aclTensor* acl_dst_f16_tensor = nullptr;
ggml_cann_pool_alloc out_f16_allocator(ctx.pool());
void* out_f16_buffer = out_f16_allocator.alloc(
ggml_nelements(dst) * faElemSize);

int64_t* out_f16_ne = src0_bsnd_ne;
size_t out_f16_nb[GGML_MAX_DIMS];
out_f16_nb[0] = faElemSize;
for(int i = 1; i < GGML_MAX_DIMS; ++i){
out_f16_nb[i] = out_f16_nb[i - 1] * out_f16_ne[i - 1];
}

acl_dst_f16_tensor = ggml_cann_create_tensor(
out_f16_buffer, faDataType, faElemSize,
out_f16_ne, out_f16_nb, GGML_MAX_DIMS
);
GGML_CANN_CALL_ACLNN_OP(ctx, FusedInferAttentionScoreV2,
acl_q_tensor, acl_k_tensor_list, acl_v_tensor_list, // q, k, v
bcast_pse_tensor, nullptr, // pse, mask
nullptr, nullptr, // actSeqLen, actSeqLenkv
nullptr, nullptr, // deqScale1, quantScale1
nullptr, nullptr, nullptr, // deqScale2, quantScale2, quantOffset2
nullptr, nullptr, // antiquantScale, antiquantOffset
nullptr, // blockTable
nullptr, nullptr, // qPadSize, kvPadSize
nullptr, nullptr, // kAntiquantScale, kAntiQuantOffset
nullptr, nullptr, // vAntiquantScale, vAntiQuantOffset
nullptr, nullptr, nullptr, // kSharedPrefix, vSharedPrefix, actSharedLen
numHeads, scaleValue, // heads, scaleValue
preTokens, nextTokens, // preTokens, nextTokens
layout, // inputLayout
numKeyValueHeads, // numKVHeads
sparseMode, innerPrecise, // sparseMode, innerPrecise
blockSize, antiquantMode, // blockSize, antiquantMode
softmaxLseFlag, // softmaxLseFlag
keyAntiquantMode, valueAntiquantMode, // keyAntiqMode, valueAntiqMode
acl_dst_f16_tensor, // attentionOut
nullptr // softmaxLse
);
// Step 6: post-processing, permute and cast to f32
aclTensor* acl_dst_tensor = ggml_cann_create_tensor(dst);
// TODO: when dst is fp16, don't need cast
aclnn_cast(ctx, acl_dst_f16_tensor, acl_dst_tensor, ggml_cann_type_mapping(dst->type));
ggml_cann_release_resources(ctx, acl_src0_f16_tensor,
acl_src1_f16_tensor,
acl_src2_f16_tensor,
acl_dst_f16_tensor,
acl_dst_tensor);
}

GGML_CANN_CALL_ACLNN_OP(ctx, FusedInferAttentionScoreV2,
acl_q_tensor, acl_k_tensor_list, acl_v_tensor_list, // q, k, v
bcast_pse_tensor, nullptr, // pse, mask
nullptr, nullptr, // actSeqLen, actSeqLenkv
nullptr, nullptr, // deqScale1, quantScale1
nullptr, nullptr, nullptr, // deqScale2, quantScale2, quantOffset2
nullptr, nullptr, // antiquantScale, antiquantOffset
nullptr, // blockTable
nullptr, nullptr, // qPadSize, kvPadSize
nullptr, nullptr, // kAntiquantScale, kAntiQuantOffset
nullptr, nullptr, // vAntiquantScale, vAntiQuantOffset
nullptr, nullptr, nullptr, // kSharedPrefix, vSharedPrefix, actSharedLen
numHeads, scaleValue, // heads, scaleValue
preTokens, nextTokens, // preTokens, nextTokens
layout, // inputLayout
numKeyValueHeads, // numKVHeads
sparseMode, innerPrecise, // sparseMode, innerPrecise
blockSize, antiquantMode, // blockSize, antiquantMode
softmaxLseFlag, // softmaxLseFlag
keyAntiquantMode, valueAntiquantMode, // keyAntiqMode, valueAntiqMode
acl_dst_f16_tensor, // attentionOut
nullptr // softmaxLse
);

// Step 6: post-processing, permute and cast to f32
aclTensor* acl_dst_tensor = ggml_cann_create_tensor(dst);
// TODO: when dst is fp16, don't need cast
aclnn_cast(ctx, acl_dst_f16_tensor, acl_dst_tensor, ggml_cann_type_mapping(dst->type));
ggml_cann_release_resources(ctx, acl_src0_f16_tensor,
acl_src1_f16_tensor,
acl_src2_f16_tensor,
acl_dst_f16_tensor,
acl_dst_tensor);
if(src3 != nullptr){
ggml_cann_release_resources(ctx, bcast_pse_tensor);
}
11 changes: 8 additions & 3 deletions ggml/src/ggml-cpu/ops.cpp
@@ -5357,9 +5357,14 @@ static void ggml_compute_forward_get_rows_f16(

GGML_ASSERT(i01 >= 0 && i01 < ne01);

ggml_cpu_fp16_to_fp32(
(const ggml_fp16_t*) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03),
(float *) ((char *) dst->data + i10*nb1 + i11*nb2 + i12*nb3), nc);
if (dst->type == GGML_TYPE_F16)
ggml_vec_cpy_f16(nc,
(ggml_fp16_t *) ((char *) dst->data + i10*nb1 + i11*nb2 + i12*nb3),
(ggml_fp16_t *) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03));
else
ggml_cpu_fp16_to_fp32(
(const ggml_fp16_t*) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03),
(float *) ((char *) dst->data + i10*nb1 + i11*nb2 + i12*nb3), nc);
}
}

1 change: 1 addition & 0 deletions ggml/src/ggml-cpu/vec.h
@@ -88,6 +88,7 @@ inline static void ggml_vec_sub_f16 (const int n, ggml_fp16_t * z, const ggml_fp
inline static void ggml_vec_set_f32 (const int n, float * x, const float v) { for (int i = 0; i < n; ++i) x[i] = v; }
inline static void ggml_vec_cpy_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = x[i]; }
inline static void ggml_vec_neg_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = -x[i]; }
inline static void ggml_vec_cpy_f16 (const int n, ggml_fp16_t * y, const ggml_fp16_t * x) { for (int i = 0; i < n; ++i) y[i] = x[i]; }
inline static void ggml_vec_neg_f16 (const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {
for (int i = 0; i < n; ++i) {
y[i] = GGML_CPU_FP32_TO_FP16(-GGML_CPU_FP16_TO_FP32(x[i]));
47 changes: 44 additions & 3 deletions ggml/src/ggml.c
@@ -32,6 +32,7 @@
#include <float.h>
#include <limits.h>
#include <stdarg.h>
#include <ctype.h>
#include <signal.h>
#if defined(__gnu_linux__)
#include <syscall.h>
@@ -3006,6 +3007,32 @@ struct ggml_tensor * ggml_l2_norm_inplace(
return ggml_l2_norm_impl(ctx, a, eps, true);
}

static int get_env_as_bool(const char *name) {
const char *val = getenv(name);
if (val == NULL) {
return 0;
}

char buf[64];
size_t len = strlen(val);
if (len >= sizeof(buf)) {
len = sizeof(buf) - 1;
}
for (size_t i = 0; i < len; i++) {
buf[i] = (char)tolower((unsigned char)val[i]);
}
buf[len] = '\0';

const char *truthy[] = {"on", "1", "yes", "y", "enable", "true"};
for (size_t i = 0; i < sizeof(truthy) / sizeof(truthy[0]); i++) {
if (strcmp(buf, truthy[i]) == 0) {
return 1; // true
}
}

return 0; // false
}

// ggml_mul_mat

static inline bool ggml_can_mul_mat(const struct ggml_tensor * t0, const struct ggml_tensor * t1) {
@@ -3024,7 +3051,12 @@ struct ggml_tensor * ggml_mul_mat(
GGML_ASSERT(!ggml_is_transposed(a));

const int64_t ne[4] = { a->ne[1], b->ne[1], b->ne[2], b->ne[3] };
struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
struct ggml_tensor * result;
if(get_env_as_bool("GGML_CANN_HIGH_PERF_MODE") && b->type == GGML_TYPE_F16){
result = ggml_new_tensor(ctx, b->type, 4, ne);
} else {
result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
}

result->op = GGML_OP_MUL_MAT;
result->src[0] = a;
@@ -3629,6 +3661,9 @@ struct ggml_tensor * ggml_get_rows(

// TODO: implement non F32 return
enum ggml_type type = GGML_TYPE_F32;
if(get_env_as_bool("GGML_CANN_HIGH_PERF_MODE") && a->type == GGML_TYPE_F16){
type = a->type;
}
if (a->type == GGML_TYPE_I32) {
type = a->type;
}
@@ -3676,7 +3711,7 @@ struct ggml_tensor * ggml_set_rows(
GGML_ASSERT(b->ne[2] % c->ne[1] == 0);
GGML_ASSERT(b->ne[3] % c->ne[2] == 0);
GGML_ASSERT(c->ne[3] == 1);
GGML_ASSERT(b->type == GGML_TYPE_F32);
// GGML_ASSERT(b->type == GGML_TYPE_F32);
GGML_ASSERT(c->type == GGML_TYPE_I64);

GGML_ASSERT(ggml_is_contiguous_rows(a));
@@ -5003,7 +5038,13 @@ struct ggml_tensor * ggml_flash_attn_ext(

// permute(0, 2, 1, 3)
int64_t ne[4] = { v->ne[0], q->ne[2], q->ne[1], q->ne[3] };
struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);

struct ggml_tensor * result;
if(get_env_as_bool("GGML_CANN_HIGH_PERF_MODE") && q->type == GGML_TYPE_F16){
result = ggml_new_tensor(ctx, q->type, 4, ne);
} else {
result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
}

float params[] = { scale, max_bias, logit_softcap };
ggml_set_op_params(result, params, sizeof(params));
3 changes: 3 additions & 0 deletions src/llama-model.cpp
@@ -8530,6 +8530,9 @@ struct llm_build_qwen2 : public llm_graph_context {

// lm_head
cur = build_lora_mm(model.output, cur);
if (cur->type != GGML_TYPE_F32) {
cur = ggml_cast(ctx0, cur, GGML_TYPE_F32);
}

if (model.output_b != nullptr) {
cur = ggml_add(ctx0, cur, model.output_b);