Commit bee8468

Add ggml_check_edges
1 parent b2d689a commit bee8468

2 files changed: +79 -119 lines changed

ggml/src/ggml-impl.h

Lines changed: 16 additions & 0 deletions
@@ -682,6 +682,7 @@ static inline bool ggml_can_fuse_subgraph(const struct ggml_cgraph * cgraph,
 #endif

 #ifdef __cplusplus
+#include <array>
 #include <initializer_list>
 #include <vector>

@@ -697,6 +698,21 @@ inline bool ggml_can_fuse_subgraph(const struct ggml_cgraph * cgraph,
     return ggml_can_fuse_subgraph(cgraph, start_idx, ops.size(), ops.begin(), outputs.begin(), outputs.size());
 }

+// Return true if the edges in the graph match expectations.
+inline bool ggml_check_edges(const struct ggml_cgraph * cgraph,
+                             int start_idx,
+                             std::initializer_list<std::array<int, 3>> edges) {
+    for (const auto &edge : edges) {
+        int dst_node = edge[0];
+        int src_idx  = edge[1];
+        int src_node = edge[2];
+        if (cgraph->nodes[start_idx + dst_node]->src[src_idx] != cgraph->nodes[start_idx + src_node]) {
+            return false;
+        }
+    }
+    return true;
+}
+
 // expose GGUF internals for test code
 GGML_API size_t gguf_type_size(enum gguf_type type);
 GGML_API struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_params params);

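For reference, each edge entry is read as { dst_node, src_idx, src_node }, with both node offsets relative to start_idx. A minimal usage sketch (the pattern and names below are illustrative only, not part of this commit):

// Hypothetical 3-op chain starting at node i: MUL_MAT -> RESHAPE -> SOFT_MAX,
// where each node consumes the previous node's output through src[0].
static constexpr std::initializer_list<std::array<int, 3>> example_edges {
    { 1, 0, 0 }, // reshape->src[0]  == mul_mat
    { 2, 0, 1 }, // soft_max->src[0] == reshape
};

// In backend code that already has a ggml_cgraph * cgraph and a node index i:
//     if (ggml_check_edges(cgraph, i, example_edges)) {
//         // nodes i..i+2 are wired the way the fused kernel expects
//     }
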
ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 63 additions & 119 deletions
@@ -396,6 +396,51 @@ static constexpr std::initializer_list<ggml_op> topk_moe_late_softmax { GGM
                                                                        GGML_OP_GET_ROWS, GGML_OP_RESHAPE,
                                                                        GGML_OP_SOFT_MAX, GGML_OP_RESHAPE };

+//node #963 ( SOFT_MAX): ffn_moe_probs-15 ( 64K) [Vulka ] use=2: ffn_moe_logits-15 ( 64K) [Vulka ]
+//node #964 ( RESHAPE): ffn_moe_probs-15 (re ( 64K) [Vulka ] use=1: ffn_moe_probs-15 ( 64K) [Vulka ]
+//node #965 ( ARGSORT): ffn_moe_argsort-15 ( 64K) [Vulka ] use=1: ffn_moe_probs-15 ( 64K) [Vulka ]
+//node #966 ( VIEW): ffn_moe_topk-15 ( 63K) [Vulka ] use=4: ffn_moe_argsort-15 ( 64K) [Vulka ]
+//node #967 ( GET_ROWS): ffn_moe_weights-15 ( 4K) [Vulka ] use=1: ffn_moe_probs-15 (re ( 64K) [Vulka ] ffn_moe_topk-15 ( 63K) [Vulka ]
+//node #968 ( RESHAPE): ffn_moe_weights-15 ( ( 4K) [Vulka ] use=2: ffn_moe_weights-15 ( 4K) [Vulka ]
+//node #969 ( SUM_ROWS): ffn_moe_weights_sum- ( 0K) [Vulka ] use=1: ffn_moe_weights-15 ( ( 4K) [Vulka ]
+//node #970 ( DIV): ffn_moe_weights_norm ( 4K) [Vulka ] use=1: ffn_moe_weights-15 ( ( 4K) [Vulka ] ffn_moe_weights_sum- ( 0K) [Vulka ]
+//node #971 ( RESHAPE): ffn_moe_weights_norm ( 4K) [Vulka ] use=1: ffn_moe_weights_norm ( 4K) [Vulka ]
+static constexpr std::initializer_list<std::array<int, 3>> topk_moe_early_softmax_norm_edges {
+    { 1, 0, 0 }, // reshape->src[0] == softmax
+    { 2, 0, 0 }, // argsort->src[0] == softmax
+    { 3, 0, 2 }, // view->src[0] == argsort
+    { 4, 0, 1 }, // get_rows->src[0] == reshape
+    { 4, 1, 3 }, // get_rows->src[1] == view
+    { 5, 0, 4 }, // reshape->src[0] == get_rows
+    { 6, 0, 5 }, // sum_rows->src[0] == reshape
+    { 7, 0, 5 }, // div->src[0] == reshape
+    { 7, 1, 6 }, // div->src[1] == sum_rows
+    { 8, 0, 7 }, // reshape->src[0] == div
+};
+
+// same as early_softmax_norm but ending after the get_rows
+static constexpr std::initializer_list<std::array<int, 3>> topk_moe_early_softmax_edges {
+    { 1, 0, 0 }, // reshape->src[0] == softmax
+    { 2, 0, 0 }, // argsort->src[0] == softmax
+    { 3, 0, 2 }, // view->src[0] == argsort
+    { 4, 0, 1 }, // get_rows->src[0] == reshape
+    { 4, 1, 3 }, // get_rows->src[1] == view
+};
+
+//node #652 ( ARGSORT): ffn_moe_argsort-11 ( 0K) [Vulka ] use=1: ffn_moe_probs-11 ( 0K) [Vulka ]
+//node #653 ( VIEW): ffn_moe_topk-11 ( 0K) [Vulka ] use=7: ffn_moe_argsort-11 ( 0K) [Vulka ]
+//node #654 ( GET_ROWS): ffn_moe_weights-11 ( 0K) [Vulka ] use=1: ffn_moe_probs-11 (re ( 0K) [Vulka ] ffn_moe_topk-11 ( 0K) [Vulka ]
+//node #655 ( RESHAPE): ffn_moe_weights-11 ( ( 0K) [Vulka ] use=1: ffn_moe_weights-11 ( 0K) [Vulka ]
+//node #656 ( SOFT_MAX): node_656 ( 0K) [Vulka ] use=1: ffn_moe_weights-11 ( ( 0K) [Vulka ]
+//node #657 ( RESHAPE): ffn_moe_weights_soft ( 0K) [Vulka ] use=1: node_656 ( 0K) [Vulka ]
+static constexpr std::initializer_list<std::array<int, 3>> topk_moe_late_softmax_edges {
+    { 1, 0, 0 }, // view->src[0] == argsort
+    { 2, 1, 1 }, // get_rows->src[1] == view
+    { 3, 0, 2 }, // reshape->src[0] == get_rows
+    { 4, 0, 3 }, // soft_max->src[0] == reshape
+    { 5, 0, 4 }, // reshape->src[0] == soft_max
+};
+
 enum topk_moe_mode {
     TOPK_MOE_EARLY_SOFTMAX,
     TOPK_MOE_EARLY_SOFTMAX_NORM,
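Worked reading of the tables above (commentary, not part of the diff), showing how one entry maps onto the node dump and onto the hand-written checks removed further down:

// { 1, 0, 0 } in topk_moe_early_softmax_norm_edges requires
//     cgraph->nodes[start_idx + 1]->src[0] == cgraph->nodes[start_idx + 0]
// i.e. the RESHAPE at offset 1 (node #964) must read the SOFT_MAX output
// ffn_moe_probs-15 at offset 0 (node #963), matching the removed
// `reshape1->src[0] != softmax` check in ggml_vk_can_fuse_topk_moe.
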
@@ -12291,38 +12336,14 @@ static bool ggml_vk_can_fuse_topk_moe(ggml_backend_vk_context * ctx, const struc

     switch (mode) {
     case TOPK_MOE_EARLY_SOFTMAX_NORM:
-        if (node_idx + (int)topk_moe_early_softmax_norm.size() > cgraph->n_nodes) {
-            return false;
-        }
-        for (size_t i = 0; i < topk_moe_early_softmax_norm.size(); ++i) {
-            if (cgraph->nodes[node_idx + i]->op != topk_moe_early_softmax_norm.begin()[i]) {
-                return false;
-            }
-        }
         softmax = cgraph->nodes[node_idx + 0];
         weights = cgraph->nodes[node_idx + 8];
         break;
     case TOPK_MOE_EARLY_SOFTMAX:
-        if (node_idx + (int)topk_moe_early_softmax.size() > cgraph->n_nodes) {
-            return false;
-        }
-        for (size_t i = 0; i < topk_moe_early_softmax.size(); ++i) {
-            if (cgraph->nodes[node_idx + i]->op != topk_moe_early_softmax.begin()[i]) {
-                return false;
-            }
-        }
         softmax = cgraph->nodes[node_idx + 0];
         weights = cgraph->nodes[node_idx + 4];
         break;
     case TOPK_MOE_LATE_SOFTMAX:
-        if (node_idx + (int)topk_moe_late_softmax.size() > cgraph->n_nodes) {
-            return false;
-        }
-        for (size_t i = 0; i < topk_moe_late_softmax.size(); ++i) {
-            if (cgraph->nodes[node_idx + i]->op != topk_moe_late_softmax.begin()[i]) {
-                return false;
-            }
-        }
         softmax = cgraph->nodes[node_idx + 4];
         weights = cgraph->nodes[node_idx + 5];
         break;
@@ -12354,95 +12375,6 @@ static bool ggml_vk_can_fuse_topk_moe(ggml_backend_vk_context * ctx, const struc
         return false;
     }

-    // Check that the nodes don't have any unexpected uses
-    if (mode == TOPK_MOE_LATE_SOFTMAX) {
-        const ggml_tensor * argsort  = cgraph->nodes[node_idx + 0];
-        const ggml_tensor * view     = cgraph->nodes[node_idx + 1];
-        const ggml_tensor * get_rows = cgraph->nodes[node_idx + 2];
-        const ggml_tensor * reshape3 = cgraph->nodes[node_idx + 3];
-        // softmax is 4
-        const ggml_tensor * reshape5 = cgraph->nodes[node_idx + 5];
-
-        // argsort is used by view
-        if (ggml_node_get_use_count(cgraph, node_idx + 0) != 1 ||
-            view->src[0] != argsort) {
-            return false;
-        }
-        // view is written, we can skip checking it
-
-        // get_rows is used by reshape3
-        if (ggml_node_get_use_count(cgraph, node_idx + 2) != 1 ||
-            reshape3->src[0] != get_rows) {
-            return false;
-        }
-
-        // reshape3 is used by softmax
-        if (ggml_node_get_use_count(cgraph, node_idx + 3) != 1 ||
-            softmax->src[0] != reshape3) {
-            return false;
-        }
-
-        // softmax is used by reshape5
-        if (ggml_node_get_use_count(cgraph, node_idx + 4) != 1 ||
-            reshape5->src[0] != softmax) {
-            return false;
-        }
-    } else {
-        bool with_norm = mode == TOPK_MOE_EARLY_SOFTMAX_NORM;
-        const ggml_tensor * reshape1 = cgraph->nodes[node_idx + 1];
-        const ggml_tensor * argsort  = cgraph->nodes[node_idx + 2];
-        const ggml_tensor * view     = cgraph->nodes[node_idx + 3];
-        const ggml_tensor * get_rows = cgraph->nodes[node_idx + 4];
-        const ggml_tensor * reshape5 = with_norm ? cgraph->nodes[node_idx + 5] : nullptr;
-        const ggml_tensor * sum_rows = with_norm ? cgraph->nodes[node_idx + 6] : nullptr;
-        const ggml_tensor * div      = with_norm ? cgraph->nodes[node_idx + 7] : nullptr;
-        const ggml_tensor * reshape8 = with_norm ? cgraph->nodes[node_idx + 8] : nullptr;
-
-        // softmax is used by reshape and argsort
-        if (ggml_node_get_use_count(cgraph, node_idx) != 2 ||
-            reshape1->src[0] != softmax ||
-            argsort->src[0] != softmax) {
-            return false;
-        }
-        // reshape is used by get_rows
-        if (ggml_node_get_use_count(cgraph, node_idx + 1) != 1 ||
-            get_rows->src[0] != reshape1) {
-            return false;
-        }
-        // argsort is used by view
-        if (ggml_node_get_use_count(cgraph, node_idx + 2) != 1 ||
-            view->src[0] != argsort) {
-            return false;
-        }
-        // view is written (via argsort), we can skip checking it
-
-        if (with_norm) {
-            // get_rows is used by reshape
-            if (ggml_node_get_use_count(cgraph, node_idx + 4) != 1 ||
-                reshape5->src[0] != get_rows) {
-                return false;
-            }
-
-            // reshape is used by sum_rows and div
-            if (ggml_node_get_use_count(cgraph, node_idx + 5) != 2 ||
-                sum_rows->src[0] != reshape5 ||
-                div->src[0] != reshape5) {
-                return false;
-            }
-
-            // sum_rows is used by div
-            if (ggml_node_get_use_count(cgraph, node_idx + 6) != 1 ||
-                div->src[1] != sum_rows) {
-                return false;
-            }
-
-            // div/reshape are written
-            if (reshape8->src[0] != div) {
-                return false;
-            }
-        }
-    }
-
     if (!ctx->device->subgroup_arithmetic ||
         !ctx->device->subgroup_shuffle ||
         !ctx->device->subgroup_require_full_support ||
@@ -12528,11 +12460,17 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
                 ctx->num_additional_fused_ops = num_adds - 1;
             } else if (ggml_vk_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) {
                 ctx->num_additional_fused_ops = 1;
-            } else if (ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_EARLY_SOFTMAX_NORM)) {
+            } else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_early_softmax_norm, { i + 3, i + 8 }) &&
+                       ggml_check_edges(cgraph, i, topk_moe_early_softmax_norm_edges) &&
+                       ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_EARLY_SOFTMAX_NORM)) {
                 ctx->num_additional_fused_ops = topk_moe_early_softmax_norm.size() - 1;
-            } else if (ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_EARLY_SOFTMAX)) {
+            } else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_early_softmax, { i + 3, i + 4 }) &&
+                       ggml_check_edges(cgraph, i, topk_moe_early_softmax_edges) &&
+                       ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_EARLY_SOFTMAX)) {
                 ctx->num_additional_fused_ops = topk_moe_early_softmax.size() - 1;
-            } else if (ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_LATE_SOFTMAX)) {
+            } else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_late_softmax, { i + 1, i + 5 }) &&
+                       ggml_check_edges(cgraph, i, topk_moe_late_softmax_edges) &&
+                       ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_LATE_SOFTMAX)) {
                 ctx->num_additional_fused_ops = topk_moe_late_softmax.size() - 1;
             }
         }
@@ -12631,11 +12569,17 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
                 ctx->num_additional_fused_ops = num_adds - 1;
             } else if (ggml_vk_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) {
                 ctx->num_additional_fused_ops = 1;
-            } else if (ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_EARLY_SOFTMAX_NORM)) {
+            } else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_early_softmax_norm, { i + 3, i + 8 }) &&
+                       ggml_check_edges(cgraph, i, topk_moe_early_softmax_norm_edges) &&
+                       ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_EARLY_SOFTMAX_NORM)) {
                 ctx->num_additional_fused_ops = topk_moe_early_softmax_norm.size() - 1;
-            } else if (ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_EARLY_SOFTMAX)) {
+            } else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_early_softmax, { i + 3, i + 4 }) &&
+                       ggml_check_edges(cgraph, i, topk_moe_early_softmax_edges) &&
+                       ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_EARLY_SOFTMAX)) {
                 ctx->num_additional_fused_ops = topk_moe_early_softmax.size() - 1;
-            } else if (ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_LATE_SOFTMAX)) {
+            } else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_late_softmax, { i + 1, i + 5 }) &&
+                       ggml_check_edges(cgraph, i, topk_moe_late_softmax_edges) &&
+                       ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_LATE_SOFTMAX)) {
                 ctx->num_additional_fused_ops = topk_moe_late_softmax.size() - 1;
             }
         }
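Both call sites follow the same three-step test: ggml_can_fuse_subgraph checks the op sequence and, via its outputs argument, which nodes may be used outside the fused subgraph; ggml_check_edges checks the internal wiring; and ggml_vk_can_fuse_topk_moe is left with only the Vulkan-specific conditions. A sketch of how another fusion could reuse the pattern (the MUL_MAT/ADD pair and backend_supports_fusion below are hypothetical, not from this commit):

// static constexpr std::initializer_list<ggml_op> mul_mat_add { GGML_OP_MUL_MAT, GGML_OP_ADD };
// static constexpr std::initializer_list<std::array<int, 3>> mul_mat_add_edges {
//     { 1, 0, 0 }, // add->src[0] == mul_mat
// };
//
// } else if (ggml_can_fuse_subgraph(cgraph, i, mul_mat_add, { i + 1 }) &&   // op sequence + external uses
//            ggml_check_edges(cgraph, i, mul_mat_add_edges) &&              // internal edges
//            backend_supports_fusion(ctx, cgraph, i)) {                     // backend-specific checks
//     ctx->num_additional_fused_ops = mul_mat_add.size() - 1;
// }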
