Commit b353068

vulkan: Fuse rope+set_rows
This pattern appears in a lot of models: the rope operation is applied right before storing into the KV cache (usually on the K tensor).

Add a path to some of the rope shaders that computes the destination address based on the set_rows tensor. Compile variants of the shaders with a D_TYPE of f16 (the usual KV cache type). Add a src3 operand to ggml_vk_op_f32: rope sometimes already uses three srcs and needs a fourth slot for the row indices.

Add fused_ops_write_mask to indicate which intermediate tensors need to write their results to memory. Skipping the write of the roped K value allows more nodes to run concurrently.

Add logic to ggml_vk_graph_optimize to make ROPE+VIEW+SET_ROWS consecutive; it rarely starts out that way in the graph.

Add new backend tests.
1 parent 55945d2 commit b353068
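For context, the pattern being fused looks roughly like this when expressed with the ggml API. This is a minimal sketch modeled on the test_rope_set_rows::build_graph test added below in tests/test-backend-ops.cpp; the function, tensor names, and shapes are illustrative and not taken from any particular model or from this commit's backend code.

#include "ggml.h"

// Sketch of the ROPE -> VIEW -> SET_ROWS pattern that this commit fuses:
// rope the current K tensor, view it as one row per token, then scatter
// those rows into the KV cache at the rows selected by row_ids.
static ggml_tensor * build_rope_set_rows(ggml_context * ctx,
                                         ggml_tensor  * k_cur,   // [head_dim, n_head, n_tokens, 1], F32
                                         ggml_tensor  * pos,     // [n_tokens], I32 positions
                                         ggml_tensor  * k_cache, // [head_dim*n_head, n_cache_rows, 1, 1], often F16
                                         ggml_tensor  * row_ids, // [n_tokens], I64 destination row indices
                                         int            n_dims,
                                         int            mode) {
    // 1) apply rope to the current K tensor
    ggml_tensor * k_rope = ggml_rope(ctx, k_cur, pos, n_dims, mode);

    // 2) view the roped tensor as one contiguous row per token
    ggml_tensor * k_rows = ggml_view_2d(ctx, k_rope,
                                        k_rope->ne[0] * k_rope->ne[1], k_rope->ne[2],
                                        k_rope->nb[2], 0);

    // 3) scatter those rows into the KV cache at the positions given by row_ids
    return ggml_set_rows(ctx, k_cache, k_rows, row_ids);
}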

7 files changed: +386 additions, -117 deletions

ggml/src/ggml-impl.h

Lines changed: 16 additions & 0 deletions
@@ -682,6 +682,7 @@ static inline bool ggml_can_fuse_subgraph(const struct ggml_cgraph * cgraph,
 #endif
 
 #ifdef __cplusplus
+#include <array>
 #include <initializer_list>
 #include <vector>
 
@@ -697,6 +698,21 @@ inline bool ggml_can_fuse_subgraph(const struct ggml_cgraph * cgraph,
     return ggml_can_fuse_subgraph(cgraph, start_idx, ops.size(), ops.begin(), outputs.begin(), outputs.size());
 }
 
+// Return true if the edges in the graph match expectations.
+inline bool ggml_check_edges(const struct ggml_cgraph * cgraph,
+                             int start_idx,
+                             std::initializer_list<std::array<int, 3>> edges) {
+    for (const auto &edge : edges) {
+        int dst_node = edge[0];
+        int src_idx = edge[1];
+        int src_node = edge[2];
+        if (cgraph->nodes[start_idx + dst_node]->src[src_idx] != cgraph->nodes[start_idx + src_node]) {
+            return false;
+        }
+    }
+    return true;
+}
+
 // expose GGUF internals for test code
 GGML_API size_t gguf_type_size(enum gguf_type type);
 GGML_API struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_params params);
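The call sites for this helper live in ggml-vulkan.cpp, whose diff is not rendered above. Purely as an illustration of how the edge list is meant to be read, a backend could verify the ROPE -> VIEW -> SET_ROWS topology as below; the wrapper name and its use here are assumptions, not code from this commit.

// Illustrative sketch (assumes ggml-impl.h is included, e.g. inside a backend .cpp).
// Starting at node i, require that the VIEW (node i+1) reads the ROPE output
// (node i) through src[0], and that SET_ROWS (node i+2) reads the VIEW through src[1].
static bool rope_view_set_rows_edges_ok(const struct ggml_cgraph * cgraph, int i) {
    return ggml_check_edges(cgraph, i, {
        { 1, 0, 0 },  // nodes[i+1]->src[0] == nodes[i]   (VIEW of the ROPE result)
        { 2, 1, 1 },  // nodes[i+2]->src[1] == nodes[i+1] (SET_ROWS stores the viewed rows)
    });
}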

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 246 additions & 86 deletions
Large diffs are not rendered by default.
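The fused_ops_write_mask mentioned in the commit message is also part of this un-rendered diff. A conceptual sketch of the idea, with assumed names and layout (not the actual implementation): each bit corresponds to one node of the fused group, and a set bit means that node must still write its result to memory.

#include <cstdint>

// Conceptual sketch only. For a fused ROPE+VIEW+SET_ROWS group, only the final
// SET_ROWS output (node 2) needs to land in memory, so the intermediate roped K
// tensor (node 0) is never written out, which lets more nodes run concurrently.
static uint32_t rope_set_rows_write_mask() {
    uint32_t fused_ops_write_mask = 0;
    fused_ops_write_mask |= 1u << 2;   // node 2 (SET_ROWS) writes; nodes 0-1 (ROPE, VIEW) do not
    return fused_ops_write_mask;
}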

ggml/src/ggml-vulkan/vulkan-shaders/rope_head.glsl

Lines changed: 2 additions & 0 deletions
@@ -10,6 +10,7 @@ layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
 layout (binding = 1) readonly buffer Y {int data_pos[];};
 layout (binding = 2) readonly buffer Z {float data_ff[];};
 layout (binding = 3) writeonly buffer D {D_TYPE data_d[];};
+layout (binding = 4) readonly buffer I {uvec2 data_i[];}; // indices for set_rows
 
 layout (push_constant) uniform parameter {
     uint ncols;
@@ -27,6 +28,7 @@ layout (push_constant) uniform parameter {
     uint s2;
     int sections[4];
     uint is_back;
+    uint set_rows_stride;
 } p;
 
 float rope_yarn_ramp(const float low, const float high, const uint i0) {

ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp

Lines changed: 10 additions & 3 deletions
@@ -16,12 +16,19 @@ void main() {
     const uint row_x = row_dst % ne1;
     const uint channel_x = row_dst / ne1;
 
-    const uint idst = row_dst*ne0 + i0/2;
+    uint idst = row_dst*ne0 + i0/2;
     const uint ix = channel_x*p.s2 + row_x*p.s1 + i0/2;
 
+    // Fusion optimization: ROPE + VIEW + SET_ROWS.
+    // The rope output is viewed as a 1D tensor and offset based on a row index in data_i.
+    if (p.set_rows_stride != 0) {
+        idst = row_x*ne0 + i0/2;
+        idst += data_i[channel_x].x * p.set_rows_stride;
+    }
+
     if (i0 >= p.n_dims) {
-        data_d[idst + i0/2 + 0] = data_a[ix + i0/2 + 0];
-        data_d[idst + i0/2 + 1] = data_a[ix + i0/2 + 1];
+        data_d[idst + i0/2 + 0] = D_TYPE(data_a[ix + i0/2 + 0]);
+        data_d[idst + i0/2 + 1] = D_TYPE(data_a[ix + i0/2 + 1]);
 
         return;
     }

ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp

Lines changed: 10 additions & 3 deletions
@@ -16,12 +16,19 @@ void main() {
     const uint row_x = row_dst % ne1;
     const uint channel_x = row_dst / ne1;
 
-    const uint idst = row_dst*ne0 + i0;
+    uint idst = row_dst*ne0 + i0;
     const uint ix = channel_x*p.s2 + row_x*p.s1 + i0;
 
+    // Fusion optimization: ROPE + VIEW + SET_ROWS.
+    // The rope output is viewed as a 1D tensor and offset based on a row index in data_i.
+    if (p.set_rows_stride != 0) {
+        idst = row_x*ne0 + i0;
+        idst += data_i[channel_x].x * p.set_rows_stride;
+    }
+
     if (i0 >= p.n_dims) {
-        data_d[idst + 0] = data_a[ix + 0];
-        data_d[idst + 1] = data_a[ix + 1];
+        data_d[idst + 0] = D_TYPE(data_a[ix + 0]);
+        data_d[idst + 1] = D_TYPE(data_a[ix + 1]);
 
         return;
     }
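Together with the new data_i binding and set_rows_stride push constant from rope_head.glsl, the fused path redirects each roped row to the cache row chosen by SET_ROWS. The host-side restatement below is only illustrative (the function and its parameters are not part of the commit); rope_norm.comp passes i0 where rope_neox.comp passes i0/2, but the addressing is otherwise identical.

#include <cstdint>

// Restatement of the fused-store index math from the rope shaders.
// data_i holds the destination row index for each channel (token); the shader
// reads only the low 32 bits of the 64-bit index (uvec2 .x).
static uint32_t fused_rope_set_rows_idst(uint32_t row_x,           // row within the current channel
                                         uint32_t channel_x,       // channel (token) index
                                         uint32_t i0,              // element offset within the row
                                         uint32_t ne0,             // row length of the rope output
                                         uint32_t set_rows_stride, // dst row stride in elements of D_TYPE
                                         const uint64_t * data_i)  // row indices provided by SET_ROWS
{
    // Unfused path: idst = row_dst*ne0 + i0, a dense write into the rope output tensor.
    // Fused path: write into the destination row selected by data_i[channel_x], so the
    // roped value lands directly in the KV cache and the intermediate tensor is skipped.
    uint32_t idst = row_x * ne0 + i0;
    idst += static_cast<uint32_t>(data_i[channel_x]) * set_rows_stride;
    return idst;
}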

ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp

Lines changed: 4 additions & 0 deletions
@@ -841,10 +841,14 @@ void process_shaders() {
     string_to_spv("rope_norm_f32", "rope_norm.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
     string_to_spv("rope_norm_f16", "rope_norm.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
     string_to_spv("rope_norm_f16_rte", "rope_norm.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", "1"}});
+    string_to_spv("rope_norm_f32_f16", "rope_norm.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}});
+    string_to_spv("rope_norm_f32_f16_rte", "rope_norm.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}, {"RTE16", "1"}});
 
     string_to_spv("rope_neox_f32", "rope_neox.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
     string_to_spv("rope_neox_f16", "rope_neox.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
     string_to_spv("rope_neox_f16_rte", "rope_neox.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", "1"}});
+    string_to_spv("rope_neox_f32_f16", "rope_neox.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}});
+    string_to_spv("rope_neox_f32_f16_rte", "rope_neox.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}, {"RTE16", "1"}});
 
     string_to_spv("rope_multi_f32", "rope_multi.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
     string_to_spv("rope_multi_f16", "rope_multi.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});

tests/test-backend-ops.cpp

Lines changed: 98 additions & 25 deletions
@@ -2105,6 +2105,35 @@ struct test_get_rows_back : public test_case {
     }
 };
 
+static void init_set_rows_row_ids(ggml_tensor * t, int num_rows)
+{
+    std::random_device rd;
+    std::default_random_engine rng(rd());
+    for (int i2 = 0; i2 < t->ne[2]; i2++) {
+        for (int i1 = 0; i1 < t->ne[1]; i1++) {
+            // generate a shuffled subset of row indices
+            std::vector<int64_t> data(num_rows);
+            for (int i = 0; i < num_rows; i++) {
+                data[i] = i;
+            }
+            std::shuffle(data.begin(), data.end(), rng);
+            data.resize(t->ne[0]);
+
+            const size_t offs = i1*t->nb[1] + i2*t->nb[2];
+            if (t->type == GGML_TYPE_I32) {
+                // TODO: Make a template or something
+                std::vector<int32_t> data_i32(t->ne[0]);
+                for (int i = 0; i < t->ne[0]; i++) {
+                    data_i32[i] = static_cast<int32_t>(data[i]);
+                }
+                ggml_backend_tensor_set(t, data_i32.data(), offs, t->ne[0]*sizeof(int32_t));
+            } else {
+                ggml_backend_tensor_set(t, data.data(), offs, t->ne[0]*sizeof(int64_t));
+            }
+        }
+    }
+}
+
 // GGML_OP_SET_ROWS
 struct test_set_rows : public test_case {
     const ggml_type type;
@@ -2148,37 +2177,13 @@ struct test_set_rows : public test_case {
     }
 
     void initialize_tensors(ggml_context * ctx) override {
-        std::random_device rd;
-        std::default_random_engine rng(rd());
         for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
             if (t->type == GGML_TYPE_I64 || t->type == GGML_TYPE_I32) {
                 if (ggml_is_view_op(t->op)) {
                     continue;
                 }
 
-                for (int i2 = 0; i2 < t->ne[2]; i2++) {
-                    for (int i1 = 0; i1 < t->ne[1]; i1++) {
-                        // generate a shuffled subset of row indices
-                        std::vector<int64_t> data(ne[1]);
-                        for (int i = 0; i < ne[1]; i++) {
-                            data[i] = i;
-                        }
-                        std::shuffle(data.begin(), data.end(), rng);
-                        data.resize(t->ne[0]);
-
-                        const size_t offs = i1*t->nb[1] + i2*t->nb[2];
-                        if (t->type == GGML_TYPE_I32) {
-                            // TODO: Make a template or something
-                            std::vector<int32_t> data_i32(t->ne[0]);
-                            for (int i = 0; i < t->ne[0]; i++) {
-                                data_i32[i] = static_cast<int32_t>(data[i]);
-                            }
-                            ggml_backend_tensor_set(t, data_i32.data(), offs, t->ne[0]*sizeof(int32_t));
-                        } else {
-                            ggml_backend_tensor_set(t, data.data(), offs, t->ne[0]*sizeof(int64_t));
-                        }
-                    }
-                }
+                init_set_rows_row_ids(t, ne[1]);
             } else {
                 init_tensor_uniform(t);
             }
@@ -2207,6 +2212,67 @@ struct test_set_rows : public test_case {
     }
 };
 
+// GGML_OP_ROPE + GGML_OP_VIEW + GGML_OP_SET_ROWS
+struct test_rope_set_rows : public test_case {
+    const ggml_type type;
+    const ggml_type type_idx;
+    const std::array<int64_t, 4> ne;
+    int mode;
+
+    std::string vars() override {
+        return VARS_TO_STR4(type, type_idx, ne, mode);
+    }
+
+    std::string op_desc(ggml_tensor * t) override {
+        GGML_UNUSED(t);
+        return "ROPE_SET_ROWS";
+    }
+
+    bool run_whole_graph() override { return true; }
+
+    test_rope_set_rows(ggml_type type,
+            ggml_type type_idx,
+            std::array<int64_t, 4> ne,
+            int mode)
+        : type(type), type_idx(type_idx), ne(ne), mode(mode) {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        ggml_tensor * src = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, ne[0], ne[1], ne[2], 1);
+        ggml_set_name(src, "src");
+
+        ggml_tensor * pos = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, ne[2]);
+
+        ggml_tensor * rope = ggml_rope(ctx, src, pos, ne[0], mode);
+
+        ggml_tensor * view = ggml_view_2d(ctx, rope, ne[0] * ne[1], ne[2], rope->nb[2], 0);
+
+        ggml_tensor * dst = ggml_new_tensor_4d(ctx, type, ne[0] * ne[1], ne[2] * ne[3], 1, 1);
+        ggml_set_name(dst, "dst");
+
+        ggml_tensor * row_idxs = ggml_new_tensor_3d(ctx, type_idx, ne[2], 1, 1);
+        ggml_set_name(row_idxs, "row_idxs");
+
+        ggml_tensor * out = ggml_set_rows(ctx, dst, view, row_idxs);
+        ggml_set_name(out, "out");
+
+        return out;
+    }
+
+    void initialize_tensors(ggml_context * ctx) override {
+        for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
+            if (t->type == GGML_TYPE_I64 || t->type == GGML_TYPE_I32) {
+                if (ggml_is_view_op(t->op)) {
+                    continue;
+                }
+
+                init_set_rows_row_ids(t, ne[2]);
+            } else {
+                init_tensor_uniform(t);
+            }
+        }
+    }
+};
+
 // GGML_OP_ARGMAX
 struct test_argmax : public test_case {
     const ggml_type type;
@@ -6008,6 +6074,13 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
         }
     }
 
+    for (int mode : { GGML_ROPE_TYPE_NORMAL, GGML_ROPE_TYPE_NEOX }) {
+        for (ggml_type type : {GGML_TYPE_F16, GGML_TYPE_F32}) {
+            test_cases.emplace_back(new test_rope_set_rows(type, GGML_TYPE_I64, { 128, 32, 1, 100 }, mode));
+            test_cases.emplace_back(new test_rope_set_rows(type, GGML_TYPE_I64, { 128, 32, 512, 1 }, mode));
+        }
+    }
+
     for (ggml_type type_input : {GGML_TYPE_F32}) {
         for (ggml_op_pool pool_type : {GGML_OP_POOL_AVG, GGML_OP_POOL_MAX}) {
             for (int k0 : {1, 3}) {
