@@ -245,6 +245,7 @@ struct vk_device_struct {
     vk_pipeline pipeline_im2col_f32, pipeline_im2col_f32_f16;
     vk_pipeline pipeline_timestep_embedding_f32;
     vk_pipeline pipeline_pool2d_f32;
+    vk_pipeline pipeline_rwkv_wkv6_f32;
 
     // [2][2][2] is for {f16acc,f32acc}x{large,small_rows}x{unaligned, aligned}
     vk_pipeline pipeline_flash_attn_f32_f16_D64[GGML_TYPE_COUNT][2][2][2];
@@ -528,6 +529,14 @@ struct vk_op_pool2d_push_constants {
     int32_t p0; int32_t p1;
 };
 
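+// Push constants for the WKV6 kernel: B = batch size (n_seqs), T = sequence length, C = channels (n_embed), H = head count.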
+struct vk_op_rwkv_wkv6_push_constants {
+    uint32_t B;
+    uint32_t T;
+    uint32_t C;
+    uint32_t H;
+};
+
 // Allow pre-recording command buffers
 struct vk_staging_memcpy {
     vk_staging_memcpy(void * _dst, const void * _src, size_t _n) : dst(_dst), src(_src), n(_n) {}
@@ -2014,6 +2022,9 @@ static void ggml_vk_load_shaders(vk_device& device) {
 
     ggml_vk_create_pipeline(device, device->pipeline_pool2d_f32, "pool2d_f32", pool2d_f32_len, pool2d_f32_data, "main", 2, sizeof(vk_op_pool2d_push_constants), {512, 1, 1}, {}, 1);
 
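+    // WKV6 pipeline: 7 storage-buffer bindings (k, v, r, tf, td, state, dst); the workgroup size is taken from the device subgroup size via a specialization constant.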
+    ggml_vk_create_pipeline(device, device->pipeline_rwkv_wkv6_f32, "rwkv_wkv6_f32", rwkv_wkv6_f32_len, rwkv_wkv6_f32_data, "main", 7, sizeof(vk_op_rwkv_wkv6_push_constants), {1, 1, 1}, {device->subgroup_size}, 1);
+
     for (auto &c : compiles) {
         c.wait();
     }
@@ -5022,6 +5032,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
             return ctx->device->pipeline_pool2d_f32;
         }
         return nullptr;
+    case GGML_OP_RWKV_WKV6:
+        if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
+            return ctx->device->pipeline_rwkv_wkv6_f32;
+        }
+        return nullptr;
     case GGML_OP_LEAKY_RELU:
         if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
             return ctx->device->pipeline_leaky_relu_f32;
@@ -5424,6 +5439,140 @@ static void ggml_vk_div(ggml_backend_vk_context * ctx, vk_context& subctx, const
     }, dryrun);
 }
 
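+// Records a WKV6 dispatch: validates the six sources, resolves each tensor to a device buffer and offset (UMA-aware), and launches the kernel.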
+static void ggml_vk_op_f32_rwkv6(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, const vk_op_rwkv_wkv6_push_constants&& pc, bool dryrun = false) {
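+    // Source order matches ggml_rwkv_wkv6: k, v, r, tf, td (presumably RWKV's time_first/time_decay), then the recurrent state.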
+    const ggml_tensor * k = dst->src[0];
+    const ggml_tensor * v = dst->src[1];
+    const ggml_tensor * r = dst->src[2];
+    const ggml_tensor * tf = dst->src[3];
+    const ggml_tensor * td = dst->src[4];
+    const ggml_tensor * state = dst->src[5];
+
+    GGML_ASSERT(!ggml_is_quantized(k->type));
+    GGML_ASSERT(!ggml_is_quantized(v->type));
+    GGML_ASSERT(!ggml_is_quantized(r->type));
+    GGML_ASSERT(!ggml_is_quantized(tf->type));
+    GGML_ASSERT(!ggml_is_quantized(td->type));
+    GGML_ASSERT(!ggml_is_quantized(state->type));
+    GGML_ASSERT(dst->buffer != nullptr);
+
+    vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, k, v, r, dst, GGML_OP_RWKV_WKV6);
+    GGML_ASSERT(pipeline != nullptr);
+
+    if (dryrun) {
+        ggml_pipeline_request_descriptor_sets(ctx->device, pipeline, 1);
+        return;
+    }
+
+    ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context;
+    ggml_backend_vk_buffer_context * k_buf_ctx = (ggml_backend_vk_buffer_context *)k->buffer->context;
+    ggml_backend_vk_buffer_context * v_buf_ctx = (ggml_backend_vk_buffer_context *)v->buffer->context;
+    ggml_backend_vk_buffer_context * r_buf_ctx = (ggml_backend_vk_buffer_context *)r->buffer->context;
+    ggml_backend_vk_buffer_context * tf_buf_ctx = (ggml_backend_vk_buffer_context *)tf->buffer->context;
+    ggml_backend_vk_buffer_context * td_buf_ctx = (ggml_backend_vk_buffer_context *)td->buffer->context;
+    ggml_backend_vk_buffer_context * state_buf_ctx = (ggml_backend_vk_buffer_context *)state->buffer->context;
+
+    ggml_vk_sync_buffers(subctx);
+
+    vk_buffer d_D, d_K, d_V, d_R, d_TF, d_TD, d_State;
+    uint64_t k_offset, v_offset, r_offset, tf_offset, td_offset, state_offset, dst_offset;
+    bool K_uma = false, V_uma = false, R_uma = false, TF_uma = false, TD_uma = false, STATE_uma = false, DST_uma = false;
+
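+    // On UMA devices the tensor data may already be host-visible; try to resolve buffers and offsets directly.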
+    if (ctx->device->uma) {
+        ggml_vk_host_get(ctx->device, k->data, d_K, k_offset);
+        ggml_vk_host_get(ctx->device, v->data, d_V, v_offset);
+        ggml_vk_host_get(ctx->device, r->data, d_R, r_offset);
+        ggml_vk_host_get(ctx->device, tf->data, d_TF, tf_offset);
+        ggml_vk_host_get(ctx->device, td->data, d_TD, td_offset);
+        ggml_vk_host_get(ctx->device, state->data, d_State, state_offset);
+        ggml_vk_host_get(ctx->device, dst->data, d_D, dst_offset);
+
+        K_uma = d_K != nullptr;
+        V_uma = d_V != nullptr;
+        R_uma = d_R != nullptr;
+        TF_uma = d_TF != nullptr;
+        TD_uma = d_TD != nullptr;
+        STATE_uma = d_State != nullptr;
+        DST_uma = d_D != nullptr;
+    }
+
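+    // Anything not resolved via UMA falls back to its device-local buffer and byte offset.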
+    if (!K_uma) {
+        d_K = k_buf_ctx->dev_buffer;
+        k_offset = vk_tensor_offset(k) + k->view_offs;
+    }
+    if (!V_uma) {
+        d_V = v_buf_ctx->dev_buffer;
+        v_offset = vk_tensor_offset(v) + v->view_offs;
+    }
+    if (!R_uma) {
+        d_R = r_buf_ctx->dev_buffer;
+        r_offset = vk_tensor_offset(r) + r->view_offs;
+    }
+    if (!TF_uma) {
+        d_TF = tf_buf_ctx->dev_buffer;
+        tf_offset = vk_tensor_offset(tf) + tf->view_offs;
+    }
+    if (!TD_uma) {
+        d_TD = td_buf_ctx->dev_buffer;
+        td_offset = vk_tensor_offset(td) + td->view_offs;
+    }
+    if (!STATE_uma) {
+        d_State = state_buf_ctx->dev_buffer;
+        state_offset = vk_tensor_offset(state) + state->view_offs;
+    }
+    if (!DST_uma) {
+        d_D = dst_buf_ctx->dev_buffer;
+        dst_offset = vk_tensor_offset(dst) + dst->view_offs;
+    }
+
+    const uint64_t k_size = ggml_nbytes(k);
+    const uint64_t v_size = ggml_nbytes(v);
+    const uint64_t r_size = ggml_nbytes(r);
+    const uint64_t tf_size = ggml_nbytes(tf);
+    const uint64_t td_size = ggml_nbytes(td);
+    const uint64_t state_size = ggml_nbytes(state);
+    const uint64_t dst_size = ggml_nbytes(dst);
+
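+    // One workgroup per (sequence, head) pair; B and H come from the push constants.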
+    std::array<uint32_t, 3> elements = {
+        (uint32_t)(pc.B * pc.H),
+        1,
+        1
+    };
+
+    ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, {
+        vk_subbuffer{ d_K, k_offset, k_size },
+        vk_subbuffer{ d_V, v_offset, v_size },
+        vk_subbuffer{ d_R, r_offset, r_size },
+        vk_subbuffer{ d_TF, tf_offset, tf_size },
+        vk_subbuffer{ d_TD, td_offset, td_size },
+        vk_subbuffer{ d_State, state_offset, state_size },
+        vk_subbuffer{ d_D, dst_offset, dst_size }
+    }, sizeof(vk_op_rwkv_wkv6_push_constants), &pc, elements);
+}
+
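+// Pulls B/T/C/H from the tensor shapes and forwards them to the dispatch helper as push constants.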
+static void ggml_vk_rwkv_wkv6(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, bool dryrun = false) {
+    const size_t seq_length = dst->src[0]->ne[3];
+    const size_t n_embed = dst->ne[0];
+    const size_t n_heads = dst->src[0]->ne[2];
+    const size_t n_seqs = dst->src[5]->ne[1];
+
+    ggml_vk_op_f32_rwkv6(
+        ctx, subctx, dst,
+        {
+            (uint32_t)n_seqs,
+            (uint32_t)seq_length,
+            (uint32_t)n_embed,
+            (uint32_t)n_heads,
+        },
+        dryrun
+    );
+}
+
 static void ggml_vk_concat(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
     int * op_params = (int *)dst->op_params;
 
@@ -6569,6 +6712,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
     case GGML_OP_IM2COL:
     case GGML_OP_TIMESTEP_EMBEDDING:
     case GGML_OP_POOL_2D:
+    case GGML_OP_RWKV_WKV6:
     case GGML_OP_LEAKY_RELU:
     case GGML_OP_FLASH_ATTN_EXT:
         break;
@@ -6768,6 +6912,11 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
     case GGML_OP_FLASH_ATTN_EXT:
         ggml_vk_flash_attn(ctx, compute_ctx, src0, src1, src2, src3, node, dryrun);
 
+        break;
+
+    case GGML_OP_RWKV_WKV6:
+        ggml_vk_rwkv_wkv6(ctx, compute_ctx, node, dryrun);
+
         break;
     default:
         return false;
@@ -6848,6 +6997,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor *
     case GGML_OP_IM2COL:
     case GGML_OP_TIMESTEP_EMBEDDING:
     case GGML_OP_POOL_2D:
+    case GGML_OP_RWKV_WKV6:
     case GGML_OP_LEAKY_RELU:
     case GGML_OP_REPEAT:
         buf = tensor->buffer;
@@ -7724,6 +7874,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
     case GGML_OP_IM2COL:
     case GGML_OP_TIMESTEP_EMBEDDING:
     case GGML_OP_POOL_2D:
+    case GGML_OP_RWKV_WKV6:
     case GGML_OP_LEAKY_RELU:
         return true;
     default:
@@ -8300,7 +8451,11 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
     } else if (tensor->op == GGML_OP_LEAKY_RELU) {
         const float * op_params = (const float *)tensor->op_params;
         tensor_clone = ggml_leaky_relu(ggml_ctx, src0_clone, op_params[0], false);
-    } else {
+    } else if (tensor->op == GGML_OP_RWKV_WKV6) {
+        tensor_clone = ggml_rwkv_wkv6(ggml_ctx, tensor->src[0], tensor->src[1], tensor->src[2], tensor->src[3],
+                                      tensor->src[4], tensor->src[5]);
+    }
+    else {
         std::cerr << "Missing vk_check_results OP: " << ggml_op_name(tensor->op) << std::endl;
         GGML_ABORT("fatal error");
     }