Skip to content

Commit f78bafd

Browse files
tamarPaltamarPal
authored andcommitted
sycl: fix trailing whitespace and minor safety casts in ssm_conv
1 parent 2c78b4b commit f78bafd

File tree

2 files changed

+76
-47
lines changed

2 files changed

+76
-47
lines changed

ggml/src/ggml-sycl/ssm_conv.cpp

Lines changed: 69 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
#include "ssm_conv.hpp"
22
#include "common.hpp"
33

4+
#include <cstdio>
5+
46
using namespace sycl;
57

68
static void kernel_ssm_conv(
@@ -18,35 +20,46 @@ static void kernel_ssm_conv(
1820
int dst_stride_token,
1921
int dst_stride_seq
2022
) {
21-
const size_t total_work = d_inner * n_t * n_s;
23+
const size_t total_work = static_cast<size_t>(d_inner) * static_cast<size_t>(n_t) * static_cast<size_t>(n_s);
2224
const size_t work_group_size = 256;
2325
const size_t num_work_groups = (total_work + work_group_size - 1) / work_group_size;
26+
2427
const range<1> global_range(num_work_groups * work_group_size);
2528
const range<1> local_range(work_group_size);
2629

2730
q.submit([&](handler &h) {
28-
h.parallel_for(nd_range<1>(global_range, local_range), [=](nd_item<1> item) {
29-
const size_t idx = item.get_global_id(0);
30-
31-
if (idx >= total_work) return;
32-
33-
const int channel = idx % d_inner;
34-
const int token = (idx / d_inner) % n_t;
35-
const int seq = idx / (d_inner * n_t);
36-
37-
const float *s = src_data + seq * src_stride_seq + channel * src_stride_inner + token;
38-
const float *c = weights + channel * d_conv;
39-
40-
float sumf = 0.0f;
41-
for (int i0 = 0; i0 < d_conv; ++i0) {
42-
sumf += s[i0] * c[i0];
31+
h.parallel_for(
32+
nd_range<1>(global_range, local_range),
33+
[=](nd_item<1> item) {
34+
const size_t idx = item.get_global_id(0);
35+
if (idx >= total_work) {
36+
return;
37+
}
38+
39+
const int channel = static_cast<int>(idx % d_inner);
40+
const int token = static_cast<int>((idx / d_inner) % n_t);
41+
const int seq = static_cast<int>(idx / (static_cast<size_t>(d_inner) * static_cast<size_t>(n_t)));
42+
43+
const float *s = src_data
44+
+ static_cast<size_t>(seq) * static_cast<size_t>(src_stride_seq)
45+
+ static_cast<size_t>(channel) * static_cast<size_t>(src_stride_inner)
46+
+ static_cast<size_t>(token);
47+
48+
const float *c = weights + static_cast<size_t>(channel) * static_cast<size_t>(d_conv);
49+
50+
float sumf = 0.0f;
51+
for (int i0 = 0; i0 < d_conv; ++i0) {
52+
sumf += s[i0] * c[i0];
53+
}
54+
55+
const size_t dst_idx =
56+
static_cast<size_t>(seq) * static_cast<size_t>(dst_stride_seq) +
57+
static_cast<size_t>(token) * static_cast<size_t>(dst_stride_token) +
58+
static_cast<size_t>(channel);
59+
60+
dst_data[dst_idx] = sumf;
4361
}
44-
45-
const size_t dst_idx = seq * dst_stride_seq +
46-
token * dst_stride_token +
47-
channel;
48-
dst_data[dst_idx] = sumf;
49-
});
62+
);
5063
});
5164
}
5265

@@ -56,45 +69,57 @@ void ggml_sycl_ssm_conv(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
5669

5770
GGML_ASSERT(src0->type == GGML_TYPE_F32);
5871
GGML_ASSERT(src1->type == GGML_TYPE_F32);
59-
GGML_ASSERT( dst->type == GGML_TYPE_F32);
60-
61-
const int d_conv = src1->ne[0];
62-
const int ncs = src0->ne[0];
63-
const int d_inner = src0->ne[1];
64-
const int n_t = dst->ne[1];
65-
const int n_s = dst->ne[2];
66-
72+
GGML_ASSERT(dst->type == GGML_TYPE_F32);
73+
74+
const int d_conv = src1->ne[0];
75+
const int ncs = src0->ne[0];
76+
const int d_inner = src0->ne[1];
77+
const int n_t = dst->ne[1];
78+
const int n_s = dst->ne[2];
79+
6780
GGML_ASSERT(src0->ne[0] == d_conv - 1 + n_t);
6881
GGML_ASSERT(src0->ne[1] == d_inner);
6982
GGML_ASSERT(src1->ne[1] == d_inner);
83+
7084
GGML_ASSERT(dst->ne[0] == d_inner);
7185
GGML_ASSERT(dst->ne[1] == n_t);
7286
GGML_ASSERT(dst->ne[2] == n_s);
73-
87+
7488
GGML_ASSERT(src0->nb[0] == sizeof(float));
7589
GGML_ASSERT(src1->nb[0] == sizeof(float));
76-
GGML_ASSERT(src0->nb[1] == src0->ne[0] * sizeof(float));
77-
90+
91+
GGML_ASSERT(src0->nb[1] == src0->ne[0] * static_cast<int>(sizeof(float)));
92+
7893
const int src_stride_inner = ncs;
79-
const int src_stride_seq = ncs * d_inner;
94+
const int src_stride_seq = ncs * d_inner;
8095
const int dst_stride_token = d_inner;
81-
const int dst_stride_seq = d_inner * n_t;
96+
const int dst_stride_seq = d_inner * n_t;
8297

8398
try {
8499
queue *q = ctx.stream();
85100

86-
const float *src_data = (const float *) src0->data;
87-
const float *weights = (const float *) src1->data;
88-
float *dst_data = (float *) dst->data;
89-
101+
const float *src_data = static_cast<const float *>(src0->data);
102+
const float *weights = static_cast<const float *>(src1->data);
103+
float *dst_data = static_cast<float *>(dst->data);
104+
90105
GGML_ASSERT(src_data && weights && dst_data);
106+
91107
kernel_ssm_conv(
92-
*q, src_data, weights, dst_data,
93-
d_conv, d_inner, n_t, n_s, ncs,
94-
src_stride_inner, src_stride_seq,
95-
dst_stride_token, dst_stride_seq
108+
*q,
109+
src_data,
110+
weights,
111+
dst_data,
112+
d_conv,
113+
d_inner,
114+
n_t,
115+
n_s,
116+
ncs,
117+
src_stride_inner,
118+
src_stride_seq,
119+
dst_stride_token,
120+
dst_stride_seq
96121
);
97-
122+
98123
} catch (const std::exception &e) {
99124
std::fprintf(stderr, "[SYCL-SSM_CONV] ERROR: %s\n", e.what());
100125
throw;

ggml/src/ggml-sycl/ssm_conv.hpp

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,9 @@
1-
#pragma once
1+
#pragma once#pragma once
22

3-
#include "common.hpp"
43

5-
void ggml_sycl_ssm_conv(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
4+
5+
#include "common.hpp"#include "common.hpp"
6+
7+
8+
9+
void ggml_sycl_ssm_conv(ggml_backend_sycl_context & ctx, ggml_tensor * dst);void ggml_sycl_ssm_conv(ggml_backend_sycl_context & ctx, ggml_tensor * dst);

0 commit comments

Comments
 (0)