@@ -396,6 +396,51 @@ static constexpr std::initializer_list<ggml_op> topk_moe_late_softmax { GGM
396396 GGML_OP_GET_ROWS, GGML_OP_RESHAPE,
397397 GGML_OP_SOFT_MAX, GGML_OP_RESHAPE };
398398
399+ //node #963 ( SOFT_MAX): ffn_moe_probs-15 ( 64K) [Vulka ] use=2: ffn_moe_logits-15 ( 64K) [Vulka ]
400+ //node #964 ( RESHAPE): ffn_moe_probs-15 (re ( 64K) [Vulka ] use=1: ffn_moe_probs-15 ( 64K) [Vulka ]
401+ //node #965 ( ARGSORT): ffn_moe_argsort-15 ( 64K) [Vulka ] use=1: ffn_moe_probs-15 ( 64K) [Vulka ]
402+ //node #966 ( VIEW): ffn_moe_topk-15 ( 63K) [Vulka ] use=4: ffn_moe_argsort-15 ( 64K) [Vulka ]
403+ //node #967 ( GET_ROWS): ffn_moe_weights-15 ( 4K) [Vulka ] use=1: ffn_moe_probs-15 (re ( 64K) [Vulka ] ffn_moe_topk-15 ( 63K) [Vulka ]
404+ //node #968 ( RESHAPE): ffn_moe_weights-15 ( ( 4K) [Vulka ] use=2: ffn_moe_weights-15 ( 4K) [Vulka ]
405+ //node #969 ( SUM_ROWS): ffn_moe_weights_sum- ( 0K) [Vulka ] use=1: ffn_moe_weights-15 ( ( 4K) [Vulka ]
406+ //node #970 ( DIV): ffn_moe_weights_norm ( 4K) [Vulka ] use=1: ffn_moe_weights-15 ( ( 4K) [Vulka ] ffn_moe_weights_sum- ( 0K) [Vulka ]
407+ //node #971 ( RESHAPE): ffn_moe_weights_norm ( 4K) [Vulka ] use=1: ffn_moe_weights_norm ( 4K) [Vulka ]
408+ static constexpr std::initializer_list<std::array<int, 3>> topk_moe_early_softmax_norm_edges {
409+ { 1, 0, 0 }, // reshape->src[0] == softmax
410+ { 2, 0, 0 }, // argsort->src[0] == softmax
411+ { 3, 0, 2 }, // view->src[0] == argsort
412+ { 4, 0, 1 }, // get_rows->src[0] == reshape
413+ { 4, 1, 3 }, // get_rows->src[1] == view
414+ { 5, 0, 4 }, // reshape->src[0] == get_rows
415+ { 6, 0, 5 }, // sum_rows->src[0] == reshape
416+ { 7, 0, 5 }, // div->src[0] == reshape
417+ { 7, 1, 6 }, // div->src[1] == sum_rows
418+ { 8, 0, 7 }, // reshape->src[0] == div
419+ };
420+
421+ // same as early_softmax_norm but ending after the get_rows
422+ static constexpr std::initializer_list<std::array<int, 3>> topk_moe_early_softmax_edges {
423+ { 1, 0, 0 }, // reshape->src[0] == softmax
424+ { 2, 0, 0 }, // argsort->src[0] == softmax
425+ { 3, 0, 2 }, // view->src[0] == argsort
426+ { 4, 0, 1 }, // get_rows->src[0] == reshape
427+ { 4, 1, 3 }, // get_rows->src[1] == view
428+ };
429+
430+ //node #652 ( ARGSORT): ffn_moe_argsort-11 ( 0K) [Vulka ] use=1: ffn_moe_probs-11 ( 0K) [Vulka ]
431+ //node #653 ( VIEW): ffn_moe_topk-11 ( 0K) [Vulka ] use=7: ffn_moe_argsort-11 ( 0K) [Vulka ]
432+ //node #654 ( GET_ROWS): ffn_moe_weights-11 ( 0K) [Vulka ] use=1: ffn_moe_probs-11 (re ( 0K) [Vulka ] ffn_moe_topk-11 ( 0K) [Vulka ]
433+ //node #655 ( RESHAPE): ffn_moe_weights-11 ( ( 0K) [Vulka ] use=1: ffn_moe_weights-11 ( 0K) [Vulka ]
434+ //node #656 ( SOFT_MAX): node_656 ( 0K) [Vulka ] use=1: ffn_moe_weights-11 ( ( 0K) [Vulka ]
435+ //node #657 ( RESHAPE): ffn_moe_weights_soft ( 0K) [Vulka ] use=1: node_656 ( 0K) [Vulka ]
436+ static constexpr std::initializer_list<std::array<int, 3>> topk_moe_late_softmax_edges {
437+ { 1, 0, 0 }, // view->src[0] == argsort
438+ { 2, 1, 1 }, // get_rows->src[1] == view
439+ { 3, 0, 2 }, // reshape->src[0] == get_rows
440+ { 4, 0, 3 }, // soft_max->src[0] == reshape
441+ { 5, 0, 4 }, // reshape->src[0] == soft_max
442+ };
443+
399444enum topk_moe_mode {
400445 TOPK_MOE_EARLY_SOFTMAX,
401446 TOPK_MOE_EARLY_SOFTMAX_NORM,
@@ -12291,38 +12336,14 @@ static bool ggml_vk_can_fuse_topk_moe(ggml_backend_vk_context * ctx, const struc
1229112336
1229212337 switch (mode) {
1229312338 case TOPK_MOE_EARLY_SOFTMAX_NORM:
12294- if (node_idx + (int)topk_moe_early_softmax_norm.size() > cgraph->n_nodes) {
12295- return false;
12296- }
12297- for (size_t i = 0; i < topk_moe_early_softmax_norm.size(); ++i) {
12298- if (cgraph->nodes[node_idx + i]->op != topk_moe_early_softmax_norm.begin()[i]) {
12299- return false;
12300- }
12301- }
1230212339 softmax = cgraph->nodes[node_idx + 0];
1230312340 weights = cgraph->nodes[node_idx + 8];
1230412341 break;
1230512342 case TOPK_MOE_EARLY_SOFTMAX:
12306- if (node_idx + (int)topk_moe_early_softmax.size() > cgraph->n_nodes) {
12307- return false;
12308- }
12309- for (size_t i = 0; i < topk_moe_early_softmax.size(); ++i) {
12310- if (cgraph->nodes[node_idx + i]->op != topk_moe_early_softmax.begin()[i]) {
12311- return false;
12312- }
12313- }
1231412343 softmax = cgraph->nodes[node_idx + 0];
1231512344 weights = cgraph->nodes[node_idx + 4];
1231612345 break;
1231712346 case TOPK_MOE_LATE_SOFTMAX:
12318- if (node_idx + (int)topk_moe_late_softmax.size() > cgraph->n_nodes) {
12319- return false;
12320- }
12321- for (size_t i = 0; i < topk_moe_late_softmax.size(); ++i) {
12322- if (cgraph->nodes[node_idx + i]->op != topk_moe_late_softmax.begin()[i]) {
12323- return false;
12324- }
12325- }
1232612347 softmax = cgraph->nodes[node_idx + 4];
1232712348 weights = cgraph->nodes[node_idx + 5];
1232812349 break;
@@ -12354,95 +12375,6 @@ static bool ggml_vk_can_fuse_topk_moe(ggml_backend_vk_context * ctx, const struc
1235412375 return false;
1235512376 }
1235612377
12357- // Check that the nodes don't have any unexpected uses
12358- if (mode == TOPK_MOE_LATE_SOFTMAX) {
12359- const ggml_tensor * argsort = cgraph->nodes[node_idx + 0];
12360- const ggml_tensor * view = cgraph->nodes[node_idx + 1];
12361- const ggml_tensor * get_rows = cgraph->nodes[node_idx + 2];
12362- const ggml_tensor * reshape3 = cgraph->nodes[node_idx + 3];
12363- // softmax is 4
12364- const ggml_tensor * reshape5 = cgraph->nodes[node_idx + 5];
12365-
12366- // argsort is used by view
12367- if (ggml_node_get_use_count(cgraph, node_idx + 0) != 1 ||
12368- view->src[0] != argsort) {
12369- return false;
12370- }
12371- // view is written, we can skip checking it
12372-
12373- // get_rows is used by reshape3
12374- if (ggml_node_get_use_count(cgraph, node_idx + 2) != 1 ||
12375- reshape3->src[0] != get_rows) {
12376- return false;
12377- }
12378-
12379- // reshape3 is used by softmax
12380- if (ggml_node_get_use_count(cgraph, node_idx + 3) != 1 ||
12381- softmax->src[0] != reshape3) {
12382- return false;
12383- }
12384-
12385- // softmax is used by reshape5
12386- if (ggml_node_get_use_count(cgraph, node_idx + 4) != 1 ||
12387- reshape5->src[0] != softmax) {
12388- return false;
12389- }
12390- } else {
12391- bool with_norm = mode == TOPK_MOE_EARLY_SOFTMAX_NORM;
12392- const ggml_tensor * reshape1 = cgraph->nodes[node_idx + 1];
12393- const ggml_tensor * argsort = cgraph->nodes[node_idx + 2];
12394- const ggml_tensor * view = cgraph->nodes[node_idx + 3];
12395- const ggml_tensor * get_rows = cgraph->nodes[node_idx + 4];
12396- const ggml_tensor * reshape5 = with_norm ? cgraph->nodes[node_idx + 5] : nullptr;
12397- const ggml_tensor * sum_rows = with_norm ? cgraph->nodes[node_idx + 6] : nullptr;
12398- const ggml_tensor * div = with_norm ? cgraph->nodes[node_idx + 7] : nullptr;
12399- const ggml_tensor * reshape8 = with_norm ? cgraph->nodes[node_idx + 8] : nullptr;
12400-
12401- // softmax is used by reshape and argsort
12402- if (ggml_node_get_use_count(cgraph, node_idx) != 2 ||
12403- reshape1->src[0] != softmax ||
12404- argsort->src[0] != softmax) {
12405- return false;
12406- }
12407- // reshape is used by get_rows
12408- if (ggml_node_get_use_count(cgraph, node_idx + 1) != 1 ||
12409- get_rows->src[0] != reshape1) {
12410- return false;
12411- }
12412- // argsort is used by view
12413- if (ggml_node_get_use_count(cgraph, node_idx + 2) != 1 ||
12414- view->src[0] != argsort) {
12415- return false;
12416- }
12417- // view is written (via argsort), we can skip checking it
12418-
12419- if (with_norm) {
12420- // get_rows is used by reshape
12421- if (ggml_node_get_use_count(cgraph, node_idx + 4) != 1 ||
12422- reshape5->src[0] != get_rows) {
12423- return false;
12424- }
12425-
12426- // reshape is used by sum_rows and div
12427- if (ggml_node_get_use_count(cgraph, node_idx + 5) != 2 ||
12428- sum_rows->src[0] != reshape5 ||
12429- div->src[0] != reshape5) {
12430- return false;
12431- }
12432-
12433- // sum_rows is used by div
12434- if (ggml_node_get_use_count(cgraph, node_idx + 6) != 1 ||
12435- div->src[1] != sum_rows) {
12436- return false;
12437- }
12438-
12439- // div/reshape are written
12440- if (reshape8->src[0] != div) {
12441- return false;
12442- }
12443- }
12444- }
12445-
1244612378 if (!ctx->device->subgroup_arithmetic ||
1244712379 !ctx->device->subgroup_shuffle ||
1244812380 !ctx->device->subgroup_require_full_support ||
@@ -12528,11 +12460,17 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
1252812460 ctx->num_additional_fused_ops = num_adds - 1;
1252912461 } else if (ggml_vk_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) {
1253012462 ctx->num_additional_fused_ops = 1;
12531- } else if (ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_EARLY_SOFTMAX_NORM)) {
12463+ } else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_early_softmax_norm, { i + 3, i + 8 }) &&
12464+ ggml_check_edges(cgraph, i, topk_moe_early_softmax_norm_edges) &&
12465+ ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_EARLY_SOFTMAX_NORM)) {
1253212466 ctx->num_additional_fused_ops = topk_moe_early_softmax_norm.size() - 1;
12533- } else if (ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_EARLY_SOFTMAX)) {
12467+ } else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_early_softmax, { i + 3, i + 4 }) &&
12468+ ggml_check_edges(cgraph, i, topk_moe_early_softmax_edges) &&
12469+ ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_EARLY_SOFTMAX)) {
1253412470 ctx->num_additional_fused_ops = topk_moe_early_softmax.size() - 1;
12535- } else if (ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_LATE_SOFTMAX)) {
12471+ } else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_late_softmax, { i + 1, i + 5 }) &&
12472+ ggml_check_edges(cgraph, i, topk_moe_late_softmax_edges) &&
12473+ ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_LATE_SOFTMAX)) {
1253612474 ctx->num_additional_fused_ops = topk_moe_late_softmax.size() - 1;
1253712475 }
1253812476 }
@@ -12631,11 +12569,17 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
1263112569 ctx->num_additional_fused_ops = num_adds - 1;
1263212570 } else if (ggml_vk_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) {
1263312571 ctx->num_additional_fused_ops = 1;
12634- } else if (ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_EARLY_SOFTMAX_NORM)) {
12572+ } else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_early_softmax_norm, { i + 3, i + 8 }) &&
12573+ ggml_check_edges(cgraph, i, topk_moe_early_softmax_norm_edges) &&
12574+ ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_EARLY_SOFTMAX_NORM)) {
1263512575 ctx->num_additional_fused_ops = topk_moe_early_softmax_norm.size() - 1;
12636- } else if (ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_EARLY_SOFTMAX)) {
12576+ } else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_early_softmax, { i + 3, i + 4 }) &&
12577+ ggml_check_edges(cgraph, i, topk_moe_early_softmax_edges) &&
12578+ ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_EARLY_SOFTMAX)) {
1263712579 ctx->num_additional_fused_ops = topk_moe_early_softmax.size() - 1;
12638- } else if (ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_LATE_SOFTMAX)) {
12580+ } else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_late_softmax, { i + 1, i + 5 }) &&
12581+ ggml_check_edges(cgraph, i, topk_moe_late_softmax_edges) &&
12582+ ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_LATE_SOFTMAX)) {
1263912583 ctx->num_additional_fused_ops = topk_moe_late_softmax.size() - 1;
1264012584 }
1264112585 }
0 commit comments