49 changes: 49 additions & 0 deletions ggml/src/ggml-cuda/ggml-cuda.cu
@@ -2964,6 +2964,36 @@ static void update_cuda_graph_executable(ggml_backend_cuda_context * cuda_ctx) {
}
#endif

static bool ggml_cuda_should_fuse_rope_set_rows(const ggml_tensor * rope,
const ggml_tensor * view,
const ggml_tensor * set_rows) {
// ne3 not tested
if (rope->src[0]->ne[3] != 1) {
return false;
}

if (set_rows->type != GGML_TYPE_F32 && set_rows->type != GGML_TYPE_F16) {
return false;
}

if (set_rows->src[1]->type != GGML_TYPE_I64) {
return false;
}

// The view should flatten the first two dims of the rope output into one dim
if (!ggml_is_contiguous(view) || view->ne[0] != rope->ne[0] * rope->ne[1]) {
return false;
}

// Only the norm/neox kernels have the fusion code
const int mode = ((const int32_t *) rope->op_params)[2];
if (mode != GGML_ROPE_TYPE_NORMAL && mode != GGML_ROPE_TYPE_NEOX) {
return false;
}

return true;
}
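For context, ggml_cuda_should_fuse_rope_set_rows matches the KV-cache store pattern ROPE -> VIEW -> SET_ROWS. A hedged sketch of how such a subgraph is typically built (hypothetical tensor names and shapes, not taken from this PR; see ggml.h for the exact API — here k is [head_dim, n_head, n_tokens] and row_ids is an I64 tensor with one destination cache row per token):

ggml_tensor * k_rope = ggml_rope(ctx, k, pos, head_dim, GGML_ROPE_TYPE_NEOX);
// flatten head dim and head count into one row per token, as the predicate requires
ggml_tensor * k_flat = ggml_view_2d(ctx, k_rope, head_dim*n_head, n_tokens,
                                    ggml_row_size(k_rope->type, head_dim*n_head), 0);
// scatter each token's row into the destination cache row given by row_ids
ggml_tensor * k_out  = ggml_set_rows(ctx, k_cache, k_flat, row_ids);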

static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, std::initializer_list<enum ggml_op> ops, std::initializer_list<enum ggml_unary_op> unary_ops) {
#ifndef NDEBUG
const size_t num_unary = std::count(ops.begin(), ops.end(), GGML_OP_UNARY);
@@ -3039,6 +3069,16 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
}
}

if (ops.size() == 3 && ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 2 })) {
const ggml_tensor * rope = cgraph->nodes[node_idx];
const ggml_tensor * view = cgraph->nodes[node_idx + 1];
const ggml_tensor * set_rows = cgraph->nodes[node_idx + 2];

if (ggml_cuda_should_fuse_rope_set_rows(rope, view, set_rows)) {
return true;
}
}

if (!ggml_can_fuse(cgraph, node_idx, ops)) {
return false;
}
@@ -3170,6 +3210,15 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
continue;
}

if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_ROPE, GGML_OP_VIEW, GGML_OP_SET_ROWS }, {})) {
ggml_tensor * rope = cgraph->nodes[i];
ggml_tensor * set_rows = cgraph->nodes[i + 2];

ggml_cuda_op_rope_fused(*cuda_ctx, rope, set_rows);
i += 2;
continue;
}

if (node->op == GGML_OP_MUL) {
int current_node = i + 1;
int num_views = 0;
213 changes: 154 additions & 59 deletions ggml/src/ggml-cuda/rope.cu
@@ -1,3 +1,5 @@
#include "ggml-cuda/common.cuh"
#include "ggml.h"
#include "rope.cuh"

struct rope_corr_dims {
@@ -37,11 +39,23 @@ static __device__ void rope_yarn(
}
}

template<bool forward, bool has_ff, typename T>
static __global__ void rope_norm(
const T * x, T * dst, const int ne0, const int ne1, const int s1, const int s2, const int n_dims,
const int32_t * pos, const float freq_scale, const float ext_factor, const float attn_factor,
const rope_corr_dims corr_dims, const float theta_scale, const float * freq_factors) {
template <bool forward, bool has_ff, typename T, typename D>
static __global__ void rope_norm(const T * x,
D * dst,
const int ne0,
const int ne1,
const int s1,
const int s2,
const int n_dims,
const int32_t * pos,
const float freq_scale,
const float ext_factor,
const float attn_factor,
const rope_corr_dims corr_dims,
const float theta_scale,
const float * freq_factors,
const int64_t * row_indices,
const int set_rows_stride) {
const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y);

if (i0 >= ne0) {
@@ -53,12 +67,19 @@ static __global__ void rope_norm(
const int row_x = row_dst % ne1;
const int channel_x = row_dst / ne1;

const int idst = row_dst*ne0 + i0;
int idst = row_dst * ne0 + i0;
const int ix = channel_x*s2 + row_x*s1 + i0;

// Fusion optimization: ROPE + VIEW + SET_ROWS.
// The rope output is viewed as a 1D tensor and offset based on a row index in row_indices.
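// Example (illustrative values): with head dim ne0 = 128 and ne1 = 8 heads, the
// view flattens each token's rope output into one row of 128*8 = 1024 elements;
// row_indices[channel_x] then selects the destination row for token channel_x,
// and set_rows_stride is the destination row stride in elements.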
if (set_rows_stride != 0) {
idst = row_x * ne0 + i0;
idst += row_indices[channel_x] * set_rows_stride;
}

if (i0 >= n_dims) {
dst[idst + 0] = x[ix + 0];
dst[idst + 1] = x[ix + 1];
dst[idst + 0] = D(x[ix + 0]);
dst[idst + 1] = D(x[ix + 1]);
[Review comment on lines +81 to +82 — JohannesGaessler (Collaborator), Nov 2, 2025]
I would suggest you use ggml_cuda_cast defined in convert.cuh. Otherwise there will potentially be issues with FP16 <-> BF16 conversions.
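A minimal sketch of that suggestion applied to this store, assuming ggml_cuda_cast<D>(x) in convert.cuh is the conversion helper the comment refers to:

if (i0 >= n_dims) {
    dst[idst + 0] = ggml_cuda_cast<D>(x[ix + 0]);
    dst[idst + 1] = ggml_cuda_cast<D>(x[ix + 1]);
    return;
}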


return;
}
@@ -75,15 +96,27 @@ static __global__ void rope_norm(
const float x0 = x[ix + 0];
const float x1 = x[ix + 1];

dst[idst + 0] = x0*cos_theta - x1*sin_theta;
dst[idst + 1] = x0*sin_theta + x1*cos_theta;
dst[idst + 0] = D(x0 * cos_theta - x1 * sin_theta);
dst[idst + 1] = D(x0 * sin_theta + x1 * cos_theta);
[Review comment on lines +99 to +100 — Collaborator]
When you're already working on optimizing RoPE: I think the memory access pattern here is suboptimal because there are gaps between each thread and I don't know whether the compiler is smart enough to combine the first and second write into a single one. I would suggest grouping the values as float2/half2 and either casting dst to that type or using ggml_cuda_memcpy_1.
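A rough sketch of the grouped store for the D = float case (assumes dst + idst is 8-byte aligned; a D = half build would use half2, or ggml_cuda_memcpy_1 as suggested):

const float x0 = x[ix + 0];
const float x1 = x[ix + 1];

float2 v;
v.x = x0*cos_theta - x1*sin_theta;
v.y = x0*sin_theta + x1*cos_theta;

// one 8-byte vectorized store instead of two 4-byte stores per thread
*reinterpret_cast<float2 *>(dst + idst) = v;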

}

template<bool forward, bool has_ff, typename T>
static __global__ void rope_neox(
const T * x, T * dst, const int ne0, const int ne1, const int s1, const int s2, const int n_dims,
const int32_t * pos, const float freq_scale, const float ext_factor, const float attn_factor,
const rope_corr_dims corr_dims, const float theta_scale, const float * freq_factors) {
template <bool forward, bool has_ff, typename T, typename D>
static __global__ void rope_neox(const T * x,
D * dst,
const int ne0,
const int ne1,
const int s1,
const int s2,
const int n_dims,
const int32_t * pos,
const float freq_scale,
const float ext_factor,
const float attn_factor,
const rope_corr_dims corr_dims,
const float theta_scale,
const float * freq_factors,
const int64_t * row_indices,
const int set_rows_stride) {
const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y);

if (i0 >= ne0) {
@@ -95,12 +128,19 @@ static __global__ void rope_neox(
const int row_x = row_dst % ne1;
const int channel_x = row_dst / ne1;

const int idst = row_dst*ne0 + i0/2;
int idst = row_dst * ne0 + i0 / 2;
const int ix = channel_x*s2 + row_x*s1 + i0/2;

// Fusion optimization: ROPE + VIEW + SET_ROWS.
// The rope output is viewed as a 1D tensor and offset based on a row index in row_indices.
if (set_rows_stride != 0) {
idst = row_x * ne0 + i0 / 2;
idst += row_indices[channel_x] * set_rows_stride;
}

if (i0 >= n_dims) {
dst[idst + i0/2 + 0] = x[ix + i0/2 + 0];
dst[idst + i0/2 + 1] = x[ix + i0/2 + 1];
dst[idst + i0 / 2 + 0] = D(x[ix + i0 / 2 + 0]);
dst[idst + i0 / 2 + 1] = D(x[ix + i0 / 2 + 1]);

return;
}
@@ -117,8 +157,8 @@ static __global__ void rope_neox(
const float x0 = x[ix + 0];
const float x1 = x[ix + n_dims/2];

dst[idst + 0] = x0*cos_theta - x1*sin_theta;
dst[idst + n_dims/2] = x0*sin_theta + x1*cos_theta;
dst[idst + 0] = D(x0 * cos_theta - x1 * sin_theta);
dst[idst + n_dims / 2] = D(x0 * sin_theta + x1 * cos_theta);
}

template<bool forward, bool has_ff, typename T>
@@ -238,11 +278,25 @@ static __global__ void rope_vision(
dst[idst + n_dims] = x0*sin_theta + x1*cos_theta;
}

template<bool forward, typename T>
static void rope_norm_cuda(
const T * x, T * dst, const int ne0, const int ne1, const int s1, const int s2, const int n_dims, const int nr,
const int32_t * pos, const float freq_scale, const float freq_base, const float ext_factor, const float attn_factor,
const rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) {
template <bool forward, typename T, typename D>
static void rope_norm_cuda(const T * x,
D * dst,
const int ne0,
const int ne1,
const int s1,
const int s2,
const int n_dims,
const int nr,
const int32_t * pos,
const float freq_scale,
const float freq_base,
const float ext_factor,
const float attn_factor,
const rope_corr_dims corr_dims,
const float * freq_factors,
const int64_t * row_indices,
const int set_rows_stride,
cudaStream_t stream) {
GGML_ASSERT(ne0 % 2 == 0);
const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
@@ -252,20 +306,34 @@ static void rope_norm_cuda(

if (freq_factors == nullptr) {
rope_norm<forward, false><<<block_nums, block_dims, 0, stream>>>(
x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor,
attn_factor, corr_dims, theta_scale, freq_factors);
x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims, theta_scale,
freq_factors, row_indices, set_rows_stride);
} else {
rope_norm<forward, true><<<block_nums, block_dims, 0, stream>>>(
x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor,
attn_factor, corr_dims, theta_scale, freq_factors);
x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims, theta_scale,
freq_factors, row_indices, set_rows_stride);
}
}

template<bool forward, typename T>
static void rope_neox_cuda(
const T * x, T * dst, const int ne0, const int ne1, const int s1, const int s2, const int n_dims, const int nr,
const int32_t * pos, const float freq_scale, const float freq_base, const float ext_factor, const float attn_factor,
const rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) {
template <bool forward, typename T, typename D>
static void rope_neox_cuda(const T * x,
D * dst,
const int ne0,
const int ne1,
const int s1,
const int s2,
const int n_dims,
const int nr,
const int32_t * pos,
const float freq_scale,
const float freq_base,
const float ext_factor,
const float attn_factor,
const rope_corr_dims corr_dims,
const float * freq_factors,
const int64_t * row_indices,
const int set_rows_stride,
cudaStream_t stream) {
GGML_ASSERT(ne0 % 2 == 0);
const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
@@ -274,13 +342,13 @@ static void rope_neox_cuda(
const float theta_scale = powf(freq_base, -2.0f/n_dims);

if (freq_factors == nullptr) {
rope_neox<forward, false, T><<<block_nums, block_dims, 0, stream>>>(
x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor,
attn_factor, corr_dims, theta_scale, freq_factors);
rope_neox<forward, false><<<block_nums, block_dims, 0, stream>>>(
x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims, theta_scale,
freq_factors, row_indices, set_rows_stride);
} else {
rope_neox<forward, true, T><<<block_nums, block_dims, 0, stream>>>(
x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor,
attn_factor, corr_dims, theta_scale, freq_factors);
rope_neox<forward, true><<<block_nums, block_dims, 0, stream>>>(
x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims, theta_scale,
freq_factors, row_indices, set_rows_stride);
}
}

@@ -333,20 +401,35 @@ static void rope_vision_cuda(
}

template <bool forward>
void ggml_cuda_op_rope_impl(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
void ggml_cuda_op_rope_impl(ggml_backend_cuda_context & ctx,
ggml_tensor * dst,
const ggml_tensor * set_rows = nullptr) {
const ggml_tensor * src0 = dst->src[0];
const ggml_tensor * src1 = dst->src[1];
const ggml_tensor * src2 = dst->src[2];

const float * src0_d = (const float *)src0->data;
const float * src1_d = (const float *)src1->data;

float * dst_d = (float *)dst->data;
void * dst_d = dst->data;
const int64_t * row_indices = nullptr;
ggml_type dst_type = dst->type;
int set_rows_stride = 0;

if (set_rows != nullptr) {
GGML_ASSERT(forward);
dst_d = set_rows->data;
row_indices = (const int64_t *) set_rows->src[1]->data;
dst_type = set_rows->type;
set_rows_stride = set_rows->nb[1] / ggml_type_size(set_rows->type);
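// nb[1] is the byte stride between destination rows; dividing by the element
// size gives the element stride that the fused kernels add per row index.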
}
cudaStream_t stream = ctx.stream();

GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
GGML_ASSERT( dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
GGML_ASSERT(src0->type == dst->type);
// When not fused, src0 and dst types must match
// When fused (ROPE+VIEW+SET_ROWS), src0 may be F32 and dst may be F16
GGML_ASSERT(src0->type == dst->type || dst->type == GGML_TYPE_F16);

const int64_t ne00 = src0->ne[0]; // head dims
const int64_t ne01 = src0->ne[1]; // num heads
@@ -404,14 +487,18 @@ void ggml_cuda_op_rope_impl(ggml_backend_cuda_context & ctx, ggml_tensor * dst)

// compute
if (is_neox) {
if (src0->type == GGML_TYPE_F32) {
rope_neox_cuda<forward>(
(const float *) src0_d, (float *) dst_d, ne00, ne01, s01, s02, n_dims, nr, pos, freq_scale,
freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
} else if (src0->type == GGML_TYPE_F16) {
rope_neox_cuda<forward>(
(const half *) src0_d, (half *) dst_d, ne00, ne01, s01, s02, n_dims, nr, pos, freq_scale,
freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
if (src0->type == GGML_TYPE_F32 && dst_type == GGML_TYPE_F32) {
rope_neox_cuda<forward, float, float>((const float *) src0_d, (float *) dst_d, ne00, ne01, s01, s02, n_dims,
nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims,
freq_factors, row_indices, set_rows_stride, stream);
} else if (src0->type == GGML_TYPE_F32 && dst_type == GGML_TYPE_F16) {
rope_neox_cuda<forward, float, half>((const float *) src0_d, (half *) dst_d, ne00, ne01, s01, s02, n_dims,
nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims,
freq_factors, row_indices, set_rows_stride, stream);
} else if (src0->type == GGML_TYPE_F16 && dst_type == GGML_TYPE_F16) {
rope_neox_cuda<forward, half, half>((const half *) src0_d, (half *) dst_d, ne00, ne01, s01, s02, n_dims, nr,
pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims,
freq_factors, row_indices, set_rows_stride, stream);
} else {
GGML_ABORT("fatal error");
}
@@ -440,14 +527,18 @@ void ggml_cuda_op_rope_impl(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
GGML_ABORT("fatal error");
}
} else {
if (src0->type == GGML_TYPE_F32) {
rope_norm_cuda<forward>(
(const float *) src0_d, (float *) dst_d, ne00, ne01, s01, s02, n_dims, nr, pos, freq_scale,
freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
} else if (src0->type == GGML_TYPE_F16) {
rope_norm_cuda<forward>(
(const half *) src0_d, (half *) dst_d, ne00, ne01, s01, s02, n_dims, nr, pos, freq_scale,
freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
if (src0->type == GGML_TYPE_F32 && dst_type == GGML_TYPE_F32) {
rope_norm_cuda<forward, float, float>((const float *) src0_d, (float *) dst_d, ne00, ne01, s01, s02, n_dims,
nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims,
freq_factors, row_indices, set_rows_stride, stream);
} else if (src0->type == GGML_TYPE_F32 && dst_type == GGML_TYPE_F16) {
rope_norm_cuda<forward, float, half>((const float *) src0_d, (half *) dst_d, ne00, ne01, s01, s02, n_dims,
nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims,
freq_factors, row_indices, set_rows_stride, stream);
} else if (src0->type == GGML_TYPE_F16 && dst_type == GGML_TYPE_F16) {
rope_norm_cuda<forward, half, half>((const half *) src0_d, (half *) dst_d, ne00, ne01, s01, s02, n_dims, nr,
pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims,
freq_factors, row_indices, set_rows_stride, stream);
} else {
GGML_ABORT("fatal error");
}
@@ -461,3 +552,7 @@ void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
void ggml_cuda_op_rope_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
ggml_cuda_op_rope_impl<false>(ctx, dst);
}

void ggml_cuda_op_rope_fused(ggml_backend_cuda_context & ctx, ggml_tensor * rope, ggml_tensor * set_rows) {
ggml_cuda_op_rope_impl<true>(ctx, rope, set_rows);
}
2 changes: 2 additions & 0 deletions ggml/src/ggml-cuda/rope.cuh
@@ -5,3 +5,5 @@
void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

void ggml_cuda_op_rope_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

void ggml_cuda_op_rope_fused(ggml_backend_cuda_context & ctx, ggml_tensor * dst, ggml_tensor * set_rows);