11#include " ssm_conv.hpp"
22#include " common.hpp"
33
4+ #include < cstdio>
5+
46using namespace sycl ;
57
68static void kernel_ssm_conv (
@@ -18,35 +20,46 @@ static void kernel_ssm_conv(
1820 int dst_stride_token,
1921 int dst_stride_seq
2022) {
21- const size_t total_work = d_inner * n_t * n_s;
23+ const size_t total_work = static_cast < size_t >( d_inner) * static_cast < size_t >( n_t ) * static_cast < size_t >( n_s) ;
2224 const size_t work_group_size = 256 ;
2325 const size_t num_work_groups = (total_work + work_group_size - 1 ) / work_group_size;
26+
2427 const range<1 > global_range (num_work_groups * work_group_size);
2528 const range<1 > local_range (work_group_size);
2629
2730 q.submit ([&](handler &h) {
28- h.parallel_for (nd_range<1 >(global_range, local_range), [=](nd_item<1 > item) {
29- const size_t idx = item.get_global_id (0 );
30-
31- if (idx >= total_work) return ;
32-
33- const int channel = idx % d_inner;
34- const int token = (idx / d_inner) % n_t ;
35- const int seq = idx / (d_inner * n_t );
36-
37- const float *s = src_data + seq * src_stride_seq + channel * src_stride_inner + token;
38- const float *c = weights + channel * d_conv;
39-
40- float sumf = 0 .0f ;
41- for (int i0 = 0 ; i0 < d_conv; ++i0) {
42- sumf += s[i0] * c[i0];
31+ h.parallel_for (
32+ nd_range<1 >(global_range, local_range),
33+ [=](nd_item<1 > item) {
34+ const size_t idx = item.get_global_id (0 );
35+ if (idx >= total_work) {
36+ return ;
37+ }
38+
39+ const int channel = static_cast <int >(idx % d_inner);
40+ const int token = static_cast <int >((idx / d_inner) % n_t );
41+ const int seq = static_cast <int >(idx / (static_cast <size_t >(d_inner) * static_cast <size_t >(n_t )));
42+
43+ const float *s = src_data
44+ + static_cast <size_t >(seq) * static_cast <size_t >(src_stride_seq)
45+ + static_cast <size_t >(channel) * static_cast <size_t >(src_stride_inner)
46+ + static_cast <size_t >(token);
47+
48+ const float *c = weights + static_cast <size_t >(channel) * static_cast <size_t >(d_conv);
49+
50+ float sumf = 0 .0f ;
51+ for (int i0 = 0 ; i0 < d_conv; ++i0) {
52+ sumf += s[i0] * c[i0];
53+ }
54+
55+ const size_t dst_idx =
56+ static_cast <size_t >(seq) * static_cast <size_t >(dst_stride_seq) +
57+ static_cast <size_t >(token) * static_cast <size_t >(dst_stride_token) +
58+ static_cast <size_t >(channel);
59+
60+ dst_data[dst_idx] = sumf;
4361 }
44-
45- const size_t dst_idx = seq * dst_stride_seq +
46- token * dst_stride_token +
47- channel;
48- dst_data[dst_idx] = sumf;
49- });
62+ );
5063 });
5164}
5265
@@ -56,45 +69,57 @@ void ggml_sycl_ssm_conv(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
5669
5770 GGML_ASSERT (src0->type == GGML_TYPE_F32);
5871 GGML_ASSERT (src1->type == GGML_TYPE_F32);
59- GGML_ASSERT ( dst->type == GGML_TYPE_F32);
60-
61- const int d_conv = src1->ne [0 ];
62- const int ncs = src0->ne [0 ];
63- const int d_inner = src0->ne [1 ];
64- const int n_t = dst->ne [1 ];
65- const int n_s = dst->ne [2 ];
66-
72+ GGML_ASSERT (dst->type == GGML_TYPE_F32);
73+
74+ const int d_conv = src1->ne [0 ];
75+ const int ncs = src0->ne [0 ];
76+ const int d_inner = src0->ne [1 ];
77+ const int n_t = dst->ne [1 ];
78+ const int n_s = dst->ne [2 ];
79+
6780 GGML_ASSERT (src0->ne [0 ] == d_conv - 1 + n_t );
6881 GGML_ASSERT (src0->ne [1 ] == d_inner);
6982 GGML_ASSERT (src1->ne [1 ] == d_inner);
83+
7084 GGML_ASSERT (dst->ne [0 ] == d_inner);
7185 GGML_ASSERT (dst->ne [1 ] == n_t );
7286 GGML_ASSERT (dst->ne [2 ] == n_s);
73-
87+
7488 GGML_ASSERT (src0->nb [0 ] == sizeof (float ));
7589 GGML_ASSERT (src1->nb [0 ] == sizeof (float ));
76- GGML_ASSERT (src0->nb [1 ] == src0->ne [0 ] * sizeof (float ));
77-
90+
91+ GGML_ASSERT (src0->nb [1 ] == src0->ne [0 ] * static_cast <int >(sizeof (float )));
92+
7893 const int src_stride_inner = ncs;
79- const int src_stride_seq = ncs * d_inner;
94+ const int src_stride_seq = ncs * d_inner;
8095 const int dst_stride_token = d_inner;
81- const int dst_stride_seq = d_inner * n_t ;
96+ const int dst_stride_seq = d_inner * n_t ;
8297
8398 try {
8499 queue *q = ctx.stream ();
85100
86- const float *src_data = ( const float *) src0->data ;
87- const float *weights = ( const float *) src1->data ;
88- float *dst_data = ( float *) dst->data ;
89-
101+ const float *src_data = static_cast < const float *>( src0->data ) ;
102+ const float *weights = static_cast < const float *>( src1->data ) ;
103+ float *dst_data = static_cast < float *>( dst->data ) ;
104+
90105 GGML_ASSERT (src_data && weights && dst_data);
106+
91107 kernel_ssm_conv (
92- *q, src_data, weights, dst_data,
93- d_conv, d_inner, n_t , n_s, ncs,
94- src_stride_inner, src_stride_seq,
95- dst_stride_token, dst_stride_seq
108+ *q,
109+ src_data,
110+ weights,
111+ dst_data,
112+ d_conv,
113+ d_inner,
114+ n_t ,
115+ n_s,
116+ ncs,
117+ src_stride_inner,
118+ src_stride_seq,
119+ dst_stride_token,
120+ dst_stride_seq
96121 );
97-
122+
98123 } catch (const std::exception &e) {
99124 std::fprintf (stderr, " [SYCL-SSM_CONV] ERROR: %s\n " , e.what ());
100125 throw ;
0 commit comments