@@ -252,10 +252,10 @@ struct webgpu_context_struct {
252252 webgpu_pipeline get_rows_pipeline[30 ];
253253 webgpu_pipeline get_rows_f32_no_vec_pipeline;
254254 webgpu_pipeline cpy_pipeline[2 ][2 ]; // src type, dst type
255- webgpu_pipeline add_pipeline[2 ][2 ]; // type, inplace
256- webgpu_pipeline sub_pipeline[2 ][2 ]; // type, inplace
257- webgpu_pipeline mul_pipeline[2 ][2 ]; // type, inplace
258- webgpu_pipeline div_pipeline[2 ][2 ]; // type, inplace
255+ webgpu_pipeline add_pipeline[2 ][2 ][ 2 ] ; // type, inplace, overlap
256+ webgpu_pipeline sub_pipeline[2 ][2 ][ 2 ] ; // type, inplace, overlap
257+ webgpu_pipeline mul_pipeline[2 ][2 ][ 2 ] ; // type, inplace, overlap
258+ webgpu_pipeline div_pipeline[2 ][2 ][ 2 ] ; // type, inplace, overlap
259259 webgpu_pipeline rms_norm_pipeline[2 ]; // inplace
260260 webgpu_pipeline rope_pipeline[2 ][2 ][2 ]; // type, ff, inplace
261261 webgpu_pipeline glu_pipeline[7 ][2 ][2 ]; // glu-op, type, split
@@ -677,9 +677,12 @@ static size_t ggml_webgpu_tensor_align_offset(webgpu_context & ctx, ggml_tensor
677677 return offset & ~(ctx->limits .minStorageBufferOffsetAlignment - 1 );
678678}
679679
// Rounds `size` (in bytes) up to the next multiple of WEBGPU_STORAGE_BUF_BINDING_MULT.
// NOTE(review): the add-then-mask form is only a correct round-up when the constant is a
// power of two; its definition is not visible here, confirm.
680+ static size_t ggml_webgpu_tensor_align_binding_size (size_t size) {
681+ return (size + WEBGPU_STORAGE_BUF_BINDING_MULT - 1 ) & ~(WEBGPU_STORAGE_BUF_BINDING_MULT - 1 );
682+ }
683+
// Returns the size to use when binding tensor `t`: ggml_nbytes(t) plus the tensor's
// misalignment, rounded up to a multiple of WEBGPU_STORAGE_BUF_BINDING_MULT.
// The removed lines did the rounding inline; the added line delegates it to
// ggml_webgpu_tensor_align_binding_size with the same operand.
// NOTE(review): ggml_webgpu_tensor_misalignment is not visible here; presumably it is the
// distance between the tensor's real offset and its aligned binding offset, confirm.
680684static size_t ggml_webgpu_tensor_binding_size (webgpu_context & ctx, ggml_tensor * t) {
681- return (ggml_nbytes (t) + ggml_webgpu_tensor_misalignment (ctx, t) + WEBGPU_STORAGE_BUF_BINDING_MULT - 1 ) &
682- ~(WEBGPU_STORAGE_BUF_BINDING_MULT - 1 );
685+ return ggml_webgpu_tensor_align_binding_size (ggml_nbytes (t) + ggml_webgpu_tensor_misalignment (ctx, t));
683686}
684687
685688// Used to determine if two tensors are the same for in-place operations
@@ -688,6 +691,12 @@ static bool ggml_webgpu_tensor_equal(ggml_tensor * a, ggml_tensor * b) {
688691 (ggml_webgpu_tensor_offset (a) == ggml_webgpu_tensor_offset (b));
689692}
690693
// Returns true when `a` and `b` live in the same underlying buffer object and their byte
// ranges [offset, offset + ggml_nbytes) intersect.
// The test is the usual two-sided interval check: each range must start before the other ends.
694+ static bool ggml_webgpu_tensor_overlap (ggml_tensor * a, ggml_tensor * b) {
695+ return (ggml_webgpu_tensor_buf (a).Get () == ggml_webgpu_tensor_buf (b).Get ()) &&
696+ ggml_webgpu_tensor_offset (a) < (ggml_webgpu_tensor_offset (b) + ggml_nbytes (b)) &&
697+ ggml_webgpu_tensor_offset (b) < (ggml_webgpu_tensor_offset (a) + ggml_nbytes (a));
698+ }
699+
691700static webgpu_command ggml_webgpu_cpy (webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
692701 uint32_t ne = (uint32_t ) ggml_nelements (dst);
693702
@@ -870,16 +879,27 @@ static webgpu_command ggml_webgpu_mul_mat(webgpu_context & ctx,
870879 return ggml_backend_webgpu_build (ctx, ctx->mul_mat_pipeline [src0->type ][src1->type ], params, entries, wg_x);
871880}
872881
// Builds the WebGPU command for a two-input element-wise op; the node dispatcher calls it
// for GGML_OP_ADD / SUB / MUL / DIV with the matching pipeline table.
// The removed lines are the previous signature, where the caller selected the pipeline and
// passed `inplace`. The added version receives the whole [a][b][c] pipeline table by
// reference and selects the variant itself from:
//   - dst->type,
//   - inplace: src0 and dst are the same tensor (ggml_webgpu_tensor_equal),
//   - overlap: src0 and src1 share bytes in the same buffer (ggml_webgpu_tensor_overlap).
873- static webgpu_command ggml_webgpu_binary_op (webgpu_context & ctx,
874- ggml_tensor * src0,
875- ggml_tensor * src1,
876- ggml_tensor * dst,
877- webgpu_pipeline & pipeline,
878- bool inplace) {
882+ template <size_t a, size_t b, size_t c>
883+ static webgpu_command ggml_webgpu_binary_op (webgpu_context & ctx,
884+ ggml_tensor * src0,
885+ ggml_tensor * src1,
886+ ggml_tensor * dst,
887+ webgpu_pipeline (&pipelines)[a][b][c]) {
888+ int inplace = ggml_webgpu_tensor_equal (src0, dst);
889+ int overlap = ggml_webgpu_tensor_overlap (src0, src1);
// The selected table entry is copied into a local by value, not bound by reference.
// NOTE(review): dst->type is used directly as the first index with no bounds check against `a`.
890+ webgpu_pipeline pipeline = pipelines[dst->type ][inplace][overlap];
891+
// Default src1 offset param: src1's misalignment expressed in elements of src1's type.
892+ uint32_t src1_offset = ggml_webgpu_tensor_misalignment (ctx, src1) / ggml_type_size (src1->type );
893+ if (overlap) {
894+ // when overlapped, bind a single buffer covering both src0 and src1
895+ // TODO: Do other operations need this?
// In the overlap case src1 is addressed relative to src0's aligned binding offset instead.
// NOTE(review): if src1 starts below src0's aligned offset this subtraction would go
// negative (or wrap, if the offsets are unsigned); nothing here checks that. Confirm.
896+ src1_offset = (uint32_t ) ((ggml_webgpu_tensor_offset (src1) - ggml_webgpu_tensor_align_offset (ctx, src0)) /
897+ ggml_type_size (src1->type ));
898+ }
// Parameter list handed to ggml_backend_webgpu_build: element count, per-tensor element
// offsets, then src1 strides/extents (part of the list falls in a diff gap and is not visible).
879899 std::vector<uint32_t > params = {
880900 (uint32_t ) ggml_nelements (dst),
881901 (uint32_t ) (ggml_webgpu_tensor_misalignment (ctx, src0) / ggml_type_size (src0->type )),
882- ( uint32_t ) ( ggml_webgpu_tensor_misalignment (ctx, src1) / ggml_type_size (src1-> type )) ,
902+ src1_offset ,
883903 (uint32_t ) (ggml_webgpu_tensor_misalignment (ctx, dst) / ggml_type_size (dst->type )),
884904 (uint32_t ) (src1->nb [0 ] / ggml_type_size (src1->type )),
885905 (uint32_t ) (src1->nb [1 ] / ggml_type_size (src1->type )),
@@ -894,25 +914,36 @@ static webgpu_command ggml_webgpu_binary_op(webgpu_context & ctx,
894914 (uint32_t ) src1->ne [3 ],
895915 };
896916
// Binding 0 always starts at src0's aligned offset. Without overlap it spans src0 only;
// with overlap it is widened to reach the end of src1 so one binding serves both inputs.
917+ size_t src0_binding_size = ggml_webgpu_tensor_binding_size (ctx, src0);
918+ if (overlap) {
919+ const uint64_t base_align = ggml_webgpu_tensor_align_offset (ctx, src0);
920+ // assume end of src1 is >= end of src0
// NOTE(review): the assumption above is not verified in this function; if src0 ends after
// src1 the binding computed below would not cover all of src0.
921+ const uint64_t max_end = ggml_webgpu_tensor_offset (src1) + ggml_nbytes (src1);
922+ src0_binding_size = ggml_webgpu_tensor_align_binding_size (max_end - base_align);
923+ }
897924 std::vector<wgpu::BindGroupEntry> entries = {
898925 { .binding = 0 ,
899926 .buffer = ggml_webgpu_tensor_buf (src0),
900927 .offset = ggml_webgpu_tensor_align_offset (ctx, src0),
901- .size = ggml_webgpu_tensor_binding_size (ctx, src0) },
902- { .binding = 1 ,
903- .buffer = ggml_webgpu_tensor_buf (src1),
904- .offset = ggml_webgpu_tensor_align_offset (ctx, src1),
905- .size = ggml_webgpu_tensor_binding_size (ctx, src1) }
928+ .size = src0_binding_size }
906929 };
// Remaining bindings are numbered densely: src1 takes the next slot only when it is not
// merged into binding 0, and dst takes the slot after that only when the op is not inplace.
930+ uint32_t binding_num = 1 ;
931+ if (!overlap) {
932+ entries.push_back ({ .binding = binding_num,
933+ .buffer = ggml_webgpu_tensor_buf (src1),
934+ .offset = ggml_webgpu_tensor_align_offset (ctx, src1),
935+ .size = ggml_webgpu_tensor_binding_size (ctx, src1) });
936+ binding_num++;
937+ }
907938 if (!inplace) {
908- entries.push_back ({ .binding = 2 ,
939+ entries.push_back ({ .binding = binding_num ,
909940 .buffer = ggml_webgpu_tensor_buf (dst),
910941 .offset = ggml_webgpu_tensor_align_offset (ctx, dst),
911942 .size = ggml_webgpu_tensor_binding_size (ctx, dst) });
912943 }
913944
// Workgroup count: number of dst elements divided by max_wg_size_x, rounded up.
914- size_t max_wg_size = ctx->max_wg_size_x ;
915- uint32_t wg_x = (ggml_nelements (dst) + max_wg_size - 1 ) / max_wg_size;
945+ size_t max_wg_size = ctx->max_wg_size_x ;
946+ uint32_t wg_x = (ggml_nelements (dst) + max_wg_size - 1 ) / max_wg_size;
916947 return ggml_backend_webgpu_build (ctx, pipeline, params, entries, wg_x);
917948}
918949
@@ -1232,25 +1263,13 @@ static std::optional<webgpu_command> ggml_webgpu_encode_node(webgpu_context ctx,
12321263 case GGML_OP_MUL_MAT:
12331264 return ggml_webgpu_mul_mat (ctx, src0, src1, node);
12341265 case GGML_OP_ADD:
1235- {
1236- int inplace = ggml_webgpu_tensor_equal (src0, node);
1237- return ggml_webgpu_binary_op (ctx, src0, src1, node, ctx->add_pipeline [node->type ][inplace], inplace);
1238- }
1266+ return ggml_webgpu_binary_op (ctx, src0, src1, node, ctx->add_pipeline );
12391267 case GGML_OP_SUB:
1240- {
1241- int inplace = ggml_webgpu_tensor_equal (src0, node);
1242- return ggml_webgpu_binary_op (ctx, src0, src1, node, ctx->sub_pipeline [node->type ][inplace], inplace);
1243- }
1268+ return ggml_webgpu_binary_op (ctx, src0, src1, node, ctx->sub_pipeline );
12441269 case GGML_OP_MUL:
1245- {
1246- int inplace = ggml_webgpu_tensor_equal (src0, node);
1247- return ggml_webgpu_binary_op (ctx, src0, src1, node, ctx->mul_pipeline [node->type ][inplace], inplace);
1248- }
1270+ return ggml_webgpu_binary_op (ctx, src0, src1, node, ctx->mul_pipeline );
12491271 case GGML_OP_DIV:
1250- {
1251- int inplace = ggml_webgpu_tensor_equal (src0, node);
1252- return ggml_webgpu_binary_op (ctx, src0, src1, node, ctx->div_pipeline [node->type ][inplace], inplace);
1253- }
1272+ return ggml_webgpu_binary_op (ctx, src0, src1, node, ctx->div_pipeline );
12541273 case GGML_OP_RMS_NORM:
12551274 return ggml_webgpu_rms_norm (ctx, src0, node);
12561275 case GGML_OP_ROPE:
@@ -1700,50 +1719,82 @@ static void ggml_webgpu_init_cpy_pipeline(webgpu_context & webgpu_ctx) {
17001719
// Creates the add pipelines and stores them in webgpu_ctx->add_pipeline.
// After this change the table is indexed [type][inplace][overlap]: the added lines fill all
// eight slots (F32/F16 x inplace 0/1 x overlap 0/1); the removed lines are the old
// two-index [type][inplace] form.
// Every pipeline gets the same constants, built from max_wg_size_x.
// NOTE(review): the wgsl_add_* symbols are defined elsewhere; presumably the WGSL shader
// sources for each variant, confirm.
17011720static void ggml_webgpu_init_add_pipeline (webgpu_context & webgpu_ctx) {
17021721 std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry (webgpu_ctx->max_wg_size_x );
1703- ggml_webgpu_create_pipeline (webgpu_ctx->device , webgpu_ctx->add_pipeline [GGML_TYPE_F32][0 ], wgsl_add_f32, " add_f32 " ,
1704- constants);
1705- ggml_webgpu_create_pipeline (webgpu_ctx->device , webgpu_ctx->add_pipeline [GGML_TYPE_F16][0 ], wgsl_add_f16, " add_f16 " ,
1706- constants);
1707- ggml_webgpu_create_pipeline (webgpu_ctx->device , webgpu_ctx->add_pipeline [GGML_TYPE_F32][1 ], wgsl_add_f32_inplace,
1722+ ggml_webgpu_create_pipeline (webgpu_ctx->device , webgpu_ctx->add_pipeline [GGML_TYPE_F32][0 ][ 0 ] , wgsl_add_f32,
1723+ " add_f32 " , constants);
1724+ ggml_webgpu_create_pipeline (webgpu_ctx->device , webgpu_ctx->add_pipeline [GGML_TYPE_F16][0 ][ 0 ] , wgsl_add_f16,
1725+ " add_f16 " , constants);
1726+ ggml_webgpu_create_pipeline (webgpu_ctx->device , webgpu_ctx->add_pipeline [GGML_TYPE_F32][1 ][ 0 ] , wgsl_add_f32_inplace,
17081727 " add_f32_inplace" , constants);
1709- ggml_webgpu_create_pipeline (webgpu_ctx->device , webgpu_ctx->add_pipeline [GGML_TYPE_F16][1 ], wgsl_add_f16_inplace,
1728+ ggml_webgpu_create_pipeline (webgpu_ctx->device , webgpu_ctx->add_pipeline [GGML_TYPE_F16][1 ][ 0 ] , wgsl_add_f16_inplace,
17101729 " add_f16_inplace" , constants);
1730+ ggml_webgpu_create_pipeline (webgpu_ctx->device , webgpu_ctx->add_pipeline [GGML_TYPE_F32][0 ][1 ], wgsl_add_f32_overlap,
1731+ " add_f32_overlap" , constants);
1732+ ggml_webgpu_create_pipeline (webgpu_ctx->device , webgpu_ctx->add_pipeline [GGML_TYPE_F32][1 ][1 ],
1733+ wgsl_add_f32_inplace_overlap, " add_f32_inplace_overlap" , constants);
1734+ ggml_webgpu_create_pipeline (webgpu_ctx->device , webgpu_ctx->add_pipeline [GGML_TYPE_F16][0 ][1 ], wgsl_add_f16_overlap,
1735+ " add_f16_overlap" , constants);
1736+ ggml_webgpu_create_pipeline (webgpu_ctx->device , webgpu_ctx->add_pipeline [GGML_TYPE_F16][1 ][1 ],
1737+ wgsl_add_f16_inplace_overlap, " add_f16_inplace_overlap" , constants);
17111738}
17121739
// Creates the sub pipelines and stores them in webgpu_ctx->sub_pipeline.
// After this change the table is indexed [type][inplace][overlap]: the added lines fill all
// eight slots (F32/F16 x inplace 0/1 x overlap 0/1); the removed lines are the old
// two-index [type][inplace] form.
// Every pipeline gets the same constants, built from max_wg_size_x.
// NOTE(review): the wgsl_sub_* symbols are defined elsewhere; presumably the WGSL shader
// sources for each variant, confirm.
17131740static void ggml_webgpu_init_sub_pipeline (webgpu_context & webgpu_ctx) {
17141741 std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry (webgpu_ctx->max_wg_size_x );
1715- ggml_webgpu_create_pipeline (webgpu_ctx->device , webgpu_ctx->sub_pipeline [GGML_TYPE_F32][0 ], wgsl_sub_f32, " sub_f32 " ,
1716- constants);
1717- ggml_webgpu_create_pipeline (webgpu_ctx->device , webgpu_ctx->sub_pipeline [GGML_TYPE_F16][0 ], wgsl_sub_f16, " sub_f16 " ,
1718- constants);
1719- ggml_webgpu_create_pipeline (webgpu_ctx->device , webgpu_ctx->sub_pipeline [GGML_TYPE_F32][1 ], wgsl_sub_f32_inplace,
1742+ ggml_webgpu_create_pipeline (webgpu_ctx->device , webgpu_ctx->sub_pipeline [GGML_TYPE_F32][0 ][ 0 ] , wgsl_sub_f32,
1743+ " sub_f32 " , constants);
1744+ ggml_webgpu_create_pipeline (webgpu_ctx->device , webgpu_ctx->sub_pipeline [GGML_TYPE_F16][0 ][ 0 ] , wgsl_sub_f16,
1745+ " sub_f16 " , constants);
1746+ ggml_webgpu_create_pipeline (webgpu_ctx->device , webgpu_ctx->sub_pipeline [GGML_TYPE_F32][1 ][ 0 ] , wgsl_sub_f32_inplace,
17201747 " sub_f32_inplace" , constants);
1721- ggml_webgpu_create_pipeline (webgpu_ctx->device , webgpu_ctx->sub_pipeline [GGML_TYPE_F16][1 ], wgsl_sub_f16_inplace,
1748+ ggml_webgpu_create_pipeline (webgpu_ctx->device , webgpu_ctx->sub_pipeline [GGML_TYPE_F16][1 ][ 0 ] , wgsl_sub_f16_inplace,
17221749 " sub_f16_inplace" , constants);
1750+ ggml_webgpu_create_pipeline (webgpu_ctx->device , webgpu_ctx->sub_pipeline [GGML_TYPE_F32][0 ][1 ], wgsl_sub_f32_overlap,
1751+ " sub_f32_overlap" , constants);
1752+ ggml_webgpu_create_pipeline (webgpu_ctx->device , webgpu_ctx->sub_pipeline [GGML_TYPE_F32][1 ][1 ],
1753+ wgsl_sub_f32_inplace_overlap, " sub_f32_inplace_overlap" , constants);
1754+ ggml_webgpu_create_pipeline (webgpu_ctx->device , webgpu_ctx->sub_pipeline [GGML_TYPE_F16][0 ][1 ], wgsl_sub_f16_overlap,
1755+ " sub_f16_overlap" , constants);
1756+ ggml_webgpu_create_pipeline (webgpu_ctx->device , webgpu_ctx->sub_pipeline [GGML_TYPE_F16][1 ][1 ],
1757+ wgsl_sub_f16_inplace_overlap, " sub_f16_inplace_overlap" , constants);
17231758}
17241759
// Creates the mul pipelines and stores them in webgpu_ctx->mul_pipeline.
// After this change the table is indexed [type][inplace][overlap]: the added lines fill all
// eight slots (F32/F16 x inplace 0/1 x overlap 0/1); the removed lines are the old
// two-index [type][inplace] form.
// Every pipeline gets the same constants, built from max_wg_size_x.
// NOTE(review): the wgsl_mul_* symbols are defined elsewhere; presumably the WGSL shader
// sources for each variant, confirm.
17251760static void ggml_webgpu_init_mul_pipeline (webgpu_context & webgpu_ctx) {
17261761 std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry (webgpu_ctx->max_wg_size_x );
1727- ggml_webgpu_create_pipeline (webgpu_ctx->device , webgpu_ctx->mul_pipeline [GGML_TYPE_F32][0 ], wgsl_mul_f32, " mul_f32 " ,
1728- constants);
1729- ggml_webgpu_create_pipeline (webgpu_ctx->device , webgpu_ctx->mul_pipeline [GGML_TYPE_F16][0 ], wgsl_mul_f16, " mul_f16 " ,
1730- constants);
1731- ggml_webgpu_create_pipeline (webgpu_ctx->device , webgpu_ctx->mul_pipeline [GGML_TYPE_F32][1 ], wgsl_mul_f32_inplace,
1762+ ggml_webgpu_create_pipeline (webgpu_ctx->device , webgpu_ctx->mul_pipeline [GGML_TYPE_F32][0 ][ 0 ] , wgsl_mul_f32,
1763+ " mul_f32 " , constants);
1764+ ggml_webgpu_create_pipeline (webgpu_ctx->device , webgpu_ctx->mul_pipeline [GGML_TYPE_F16][0 ][ 0 ] , wgsl_mul_f16,
1765+ " mul_f16 " , constants);
1766+ ggml_webgpu_create_pipeline (webgpu_ctx->device , webgpu_ctx->mul_pipeline [GGML_TYPE_F32][1 ][ 0 ] , wgsl_mul_f32_inplace,
17321767 " mul_f32_inplace" , constants);
1733- ggml_webgpu_create_pipeline (webgpu_ctx->device , webgpu_ctx->mul_pipeline [GGML_TYPE_F16][1 ], wgsl_mul_f16_inplace,
1768+ ggml_webgpu_create_pipeline (webgpu_ctx->device , webgpu_ctx->mul_pipeline [GGML_TYPE_F16][1 ][ 0 ] , wgsl_mul_f16_inplace,
17341769 " mul_f16_inplace" , constants);
1770+ ggml_webgpu_create_pipeline (webgpu_ctx->device , webgpu_ctx->mul_pipeline [GGML_TYPE_F32][0 ][1 ], wgsl_mul_f32_overlap,
1771+ " mul_f32_overlap" , constants);
1772+ ggml_webgpu_create_pipeline (webgpu_ctx->device , webgpu_ctx->mul_pipeline [GGML_TYPE_F32][1 ][1 ],
1773+ wgsl_mul_f32_inplace_overlap, " mul_f32_inplace_overlap" , constants);
1774+ ggml_webgpu_create_pipeline (webgpu_ctx->device , webgpu_ctx->mul_pipeline [GGML_TYPE_F16][0 ][1 ], wgsl_mul_f16_overlap,
1775+ " mul_f16_overlap" , constants);
1776+ ggml_webgpu_create_pipeline (webgpu_ctx->device , webgpu_ctx->mul_pipeline [GGML_TYPE_F16][1 ][1 ],
1777+ wgsl_mul_f16_inplace_overlap, " mul_f16_inplace_overlap" , constants);
17351778}
17361779
// Creates the div pipelines and stores them in webgpu_ctx->div_pipeline.
// After this change the table is indexed [type][inplace][overlap]: the added lines fill all
// eight slots (F32/F16 x inplace 0/1 x overlap 0/1); the removed lines are the old
// two-index [type][inplace] form.
// Every pipeline gets the same constants, built from max_wg_size_x.
// NOTE(review): the wgsl_div_* symbols are defined elsewhere; presumably the WGSL shader
// sources for each variant, confirm.
17371780static void ggml_webgpu_init_div_pipeline (webgpu_context & webgpu_ctx) {
17381781 std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry (webgpu_ctx->max_wg_size_x );
1739- ggml_webgpu_create_pipeline (webgpu_ctx->device , webgpu_ctx->div_pipeline [GGML_TYPE_F32][0 ], wgsl_div_f32, " div_f32 " ,
1740- constants);
1741- ggml_webgpu_create_pipeline (webgpu_ctx->device , webgpu_ctx->div_pipeline [GGML_TYPE_F16][0 ], wgsl_div_f16, " div_f16 " ,
1742- constants);
1743- ggml_webgpu_create_pipeline (webgpu_ctx->device , webgpu_ctx->div_pipeline [GGML_TYPE_F32][1 ], wgsl_div_f32_inplace,
1782+ ggml_webgpu_create_pipeline (webgpu_ctx->device , webgpu_ctx->div_pipeline [GGML_TYPE_F32][0 ][ 0 ] , wgsl_div_f32,
1783+ " div_f32 " , constants);
1784+ ggml_webgpu_create_pipeline (webgpu_ctx->device , webgpu_ctx->div_pipeline [GGML_TYPE_F16][0 ][ 0 ] , wgsl_div_f16,
1785+ " div_f16 " , constants);
1786+ ggml_webgpu_create_pipeline (webgpu_ctx->device , webgpu_ctx->div_pipeline [GGML_TYPE_F32][1 ][ 0 ] , wgsl_div_f32_inplace,
17441787 " div_f32_inplace" , constants);
1745- ggml_webgpu_create_pipeline (webgpu_ctx->device , webgpu_ctx->div_pipeline [GGML_TYPE_F16][1 ], wgsl_div_f16_inplace,
1788+ ggml_webgpu_create_pipeline (webgpu_ctx->device , webgpu_ctx->div_pipeline [GGML_TYPE_F16][1 ][ 0 ] , wgsl_div_f16_inplace,
17461789 " div_f16_inplace" , constants);
1790+ ggml_webgpu_create_pipeline (webgpu_ctx->device , webgpu_ctx->div_pipeline [GGML_TYPE_F32][0 ][1 ], wgsl_div_f32_overlap,
1791+ " div_f32_overlap" , constants);
1792+ ggml_webgpu_create_pipeline (webgpu_ctx->device , webgpu_ctx->div_pipeline [GGML_TYPE_F32][1 ][1 ],
1793+ wgsl_div_f32_inplace_overlap, " div_f32_inplace_overlap" , constants);
1794+ ggml_webgpu_create_pipeline (webgpu_ctx->device , webgpu_ctx->div_pipeline [GGML_TYPE_F16][0 ][1 ], wgsl_div_f16_overlap,
1795+ " div_f16_overlap" , constants);
1796+ ggml_webgpu_create_pipeline (webgpu_ctx->device , webgpu_ctx->div_pipeline [GGML_TYPE_F16][1 ][1 ],
1797+ wgsl_div_f16_inplace_overlap, " div_f16_inplace_overlap" , constants);
17471798}
17481799
17491800static void ggml_webgpu_init_rms_norm_pipeline (webgpu_context & webgpu_ctx) {
@@ -2152,9 +2203,9 @@ static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t
21522203 // TODO: Don't enable for WASM builds, they won't have an effect anyways
21532204 // TODO: Maybe WebGPU needs a "fast" mode where you can request compilers skip adding checks like these,
21542205 // only for native performance?
2155- const char * const deviceEnabledToggles[] = { " skip_validation " , " disable_robustness" , " disable_workgroup_init" ,
2156- " disable_polyfills_on_integer_div_and_mod" };
2157- const char * const deviceDisabledToggles[] = { " timestamp_quantization" };
2206+ const char * const deviceEnabledToggles[] = { " disable_robustness" , " disable_workgroup_init" ,
2207+ " disable_polyfills_on_integer_div_and_mod" };
2208+ const char * const deviceDisabledToggles[] = { " timestamp_quantization" };
21582209 wgpu::DawnTogglesDescriptor deviceTogglesDesc;
21592210 deviceTogglesDesc.enabledToggles = deviceEnabledToggles;
21602211 deviceTogglesDesc.enabledToggleCount = 4 ;
0 commit comments