CUDA: add fused rms norm #14800

Merged (7 commits, Jul 23, 2025)

Changes from all commits
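This PR fuses a GGML_OP_RMS_NORM node with the GGML_OP_MUL node that consumes its output into a single CUDA kernel launch, so the normalized intermediate no longer has to be written to and re-read from global memory. In rough notation (mine, reconstructed from the kernel in ggml/src/ggml-cuda/norm.cu below), for each row of n columns the fused path computes

    \mathrm{dst}_i = \frac{x_i}{\sqrt{\tfrac{1}{n}\sum_{j=0}^{n-1} x_j^{2} + \varepsilon}} \cdot \mathrm{mul}_{\,i \bmod n_{\mathrm{mul}}}

where mul is the weight operand of the MUL node; the modulo indexing (applied per column, row, channel, and sample) is what implements broadcasting.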
41 changes: 41 additions & 0 deletions ggml/src/ggml-cuda/ggml-cuda.cu
@@ -55,6 +55,7 @@
#include <cstddef>
#include <cstdint>
#include <float.h>
#include <initializer_list>
#include <limits>
#include <map>
#include <memory>
@@ -2765,6 +2766,39 @@ static void update_cuda_graph_executable(ggml_backend_cuda_context * cuda_ctx) {
}
#endif

static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, std::initializer_list<enum ggml_op> ops) {
if (!ggml_can_fuse(cgraph, node_idx, ops)) {
return false;
}

if (ops.size() == 2 && ops.begin()[0] == GGML_OP_RMS_NORM && ops.begin()[1] == GGML_OP_MUL) {
const ggml_tensor *rms_norm = cgraph->nodes[node_idx];
const ggml_tensor *mul = cgraph->nodes[node_idx+1];

GGML_ASSERT(rms_norm->src[0]->type == GGML_TYPE_F32);
GGML_ASSERT(rms_norm->type == GGML_TYPE_F32);

//rms norm only supports F32
if (mul->src[0]->type != GGML_TYPE_F32 ||
mul->src[1]->type != GGML_TYPE_F32 ||
mul->type != GGML_TYPE_F32) {
return false;
}

//if rms norm is the B operand, then we don't handle broadcast
if (rms_norm == mul->src[1] && !ggml_are_same_shape(mul->src[0], rms_norm->src[1])) {
return false;
}

//rms_norm kernel assumes contiguous rows
if (!ggml_is_contiguous_rows(mul->src[0]) || !ggml_is_contiguous_rows(mul->src[1])) {
return false;
}
}

return true;
}

static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph,
bool & graph_evaluated_or_captured, bool & use_cuda_graph, bool & cuda_graph_update_required) {
// flag used to determine whether it is an integrated_gpu
@@ -2774,13 +2808,20 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx,
// Only perform the graph execution if CUDA graphs are not enabled, or we are capturing the graph.
// With the use of CUDA graphs, the execution will be performed by the graph launch.
if (!use_cuda_graph || cuda_graph_update_required) {

for (int i = 0; i < cgraph->n_nodes; i++) {
ggml_tensor * node = cgraph->nodes[i];

if (ggml_is_empty(node) || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_NONE) {
continue;
}

static bool disable_fusion = (getenv("GGML_CUDA_DISABLE_FUSION") != nullptr);
if (!disable_fusion && ggml_cuda_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) {
ggml_cuda_op_rms_norm_fused(*cuda_ctx, node, cgraph->nodes[i+1]);
i++;
continue;
}
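// Note (not part of the diff): setting the GGML_CUDA_DISABLE_FUSION environment variable to any
// value skips this fast path, which makes it easy to A/B-test the fused kernel against the
// separate RMS_NORM + MUL ops; the getenv() result above is cached in a static, so it is read once.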
#ifndef NDEBUG
assert(node->buffer->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device));
for (int j = 0; j < GGML_MAX_SRC; j++) {
97 changes: 92 additions & 5 deletions ggml/src/ggml-cuda/norm.cu
@@ -104,10 +104,12 @@ static __global__ void group_norm_f32(const float * x, float * dst, const int gr
}
}

template <int block_size>
template <int block_size, bool do_multiply = false>
static __global__ void rms_norm_f32(
const float * x, float * dst, const int ncols, const int64_t stride_row, const int64_t stride_channel,
const int64_t stride_sample, const float eps) {
const int64_t stride_sample, const float eps, const float * mul = nullptr, const int64_t mul_stride_row = 0,
const int64_t mul_stride_channel = 0, const int64_t mul_stride_sample = 0, const int mul_ncols = 0,
const int mul_nrows = 0, const int mul_nchannels = 0, const int mul_nsamples = 0) {
const int nrows = gridDim.x;
const int nchannels = gridDim.y;

@@ -119,6 +121,13 @@ static __global__ void rms_norm_f32(
x += sample*stride_sample + channel*stride_channel + row*stride_row;
dst += ((sample*nchannels + channel)*nrows + row)*ncols;

if constexpr (do_multiply) {
const int mul_row = row % mul_nrows;
const int mul_channel = channel % mul_nchannels;
const int mul_sample = sample % mul_nsamples;
mul += mul_sample*mul_stride_sample + mul_channel*mul_stride_channel + mul_row*mul_stride_row;
}

float tmp = 0.0f; // partial sum for thread in warp

for (int col = tid; col < ncols; col += block_size) {
@@ -145,7 +154,12 @@ static __global__ void rms_norm_f32(
const float scale = rsqrtf(mean + eps);

for (int col = tid; col < ncols; col += block_size) {
dst[col] = scale * x[col];
if constexpr (do_multiply) {
const int mul_col = col % mul_ncols;
dst[col] = scale * x[col] * mul[mul_col];
} else {
dst[col] = scale * x[col];
}
}
}
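The fused multiply is gated by the do_multiply template parameter rather than a runtime flag, so the existing rms_norm_f32 instantiations keep their original code path with no per-thread branch. A minimal standalone sketch of the same if constexpr pattern (hypothetical kernel, not from this PR):

template <bool do_multiply>
__global__ void scale_kernel(const float * x, const float * mul, float * dst, const int n) {
    const int i = blockIdx.x*blockDim.x + threadIdx.x;
    if (i >= n) {
        return;
    }
    float v = 2.0f * x[i];           // stand-in for the real normalization arithmetic
    if constexpr (do_multiply) {     // branch is resolved at compile time, not per thread
        v *= mul[i];
    }
    dst[i] = v;
}

Only the <true> instantiation ever dereferences mul, which mirrors why rms_norm_f32 can default its mul arguments to nullptr/0 for the non-fused case.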

@@ -310,10 +324,30 @@ static void rms_norm_f32_cuda(
const dim3 blocks_num(nrows, nchannels, nsamples);
if (ncols < 1024) {
const dim3 block_dims(WARP_SIZE, 1, 1);
rms_norm_f32<WARP_SIZE><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
rms_norm_f32<WARP_SIZE, false><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
} else {
const dim3 block_dims(1024, 1, 1);
rms_norm_f32<1024, false><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
}
}

static void rms_norm_mul_f32_cuda(
const float * x, const float * mul, float * dst, const int ncols, const int nrows, const int nchannels, const int nsamples,
const int64_t stride_row, const int64_t stride_channel, const int64_t stride_sample,
const int64_t mul_stride_row, const int64_t mul_stride_channel, const int64_t mul_stride_sample,
const int mul_ncols, const int mul_nrows, const int mul_nchannels, const int mul_nsamples,
const float eps, cudaStream_t stream) {
const dim3 blocks_num(nrows, nchannels, nsamples);
if (mul == nullptr) {
rms_norm_f32_cuda(x, dst, ncols, nrows, nchannels, nsamples, stride_row, stride_channel, stride_sample, eps, stream);
return;
}
if (ncols < 1024) {
const dim3 block_dims(WARP_SIZE, 1, 1);
rms_norm_f32<WARP_SIZE, true><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel, mul_stride_sample, mul_ncols, mul_nrows, mul_nchannels, mul_nsamples);
} else {
const dim3 block_dims(1024, 1, 1);
rms_norm_f32<1024><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
rms_norm_f32<1024, true><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel, mul_stride_sample, mul_ncols, mul_nrows, mul_nchannels, mul_nsamples);
}
}
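For context, a hypothetical host-side invocation (the buffer names x, w, d, the 4096x32 shape, and stream are illustrative, not from the PR): a contiguous F32 activation tensor of 32 rows by 4096 columns, scaled by a single 4096-element weight row that is broadcast across rows, would be launched roughly as:

// x: 32 rows of 4096 floats, w: one row of 4096 floats (broadcast over rows), d: output buffer
rms_norm_mul_f32_cuda(x, w, d,
                      /*ncols=*/4096, /*nrows=*/32, /*nchannels=*/1, /*nsamples=*/1,
                      /*stride_row=*/4096, /*stride_channel=*/4096*32, /*stride_sample=*/4096*32,
                      /*mul_stride_row=*/4096, /*mul_stride_channel=*/4096, /*mul_stride_sample=*/4096,
                      /*mul_ncols=*/4096, /*mul_nrows=*/1, /*mul_nchannels=*/1, /*mul_nsamples=*/1,
                      /*eps=*/1e-6f, stream);

Inside the kernel the weight indices are taken modulo mul_nrows, mul_nchannels, and mul_nsamples, so the single row of w is reused for every activation row.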

@@ -407,6 +441,59 @@ void ggml_cuda_op_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
rms_norm_f32_cuda(src0_d, dst_d, ne00, ne01, ne02, ne03, s01, s02, s03, eps, stream);
}

void ggml_cuda_op_rms_norm_fused(ggml_backend_cuda_context & ctx, ggml_tensor * dst, ggml_tensor * mul_tensor) {
const ggml_tensor * rms_norm_src = (ggml_tensor *) dst->src[0];
float eps = 0.0f;

memcpy(&eps, dst->op_params, sizeof(float));

const float * src0_d = (const float *) rms_norm_src->data;
const float * mul_d = nullptr;
const ggml_tensor * mul_src = nullptr;

if (mul_tensor->src[0] == dst) {
mul_d = (float *) mul_tensor->src[1]->data;
mul_src = mul_tensor->src[1];
} else if (mul_tensor->src[1] == dst) {
mul_d = (float *) mul_tensor->src[0]->data;
mul_src = mul_tensor->src[0];
} else {
GGML_ASSERT(false);
}

float * dst_d = (float *) mul_tensor->data;
cudaStream_t stream = ctx.stream();

GGML_ASSERT(rms_norm_src->type == GGML_TYPE_F32);
GGML_ASSERT(dst->type == GGML_TYPE_F32);
GGML_ASSERT(mul_tensor->type == GGML_TYPE_F32);
GGML_ASSERT(eps >= 0.0f);

const int64_t ne00 = rms_norm_src->ne[0];
const int64_t ne01 = rms_norm_src->ne[1];
const int64_t ne02 = rms_norm_src->ne[2];
const int64_t ne03 = rms_norm_src->ne[3];

const size_t ts0 = ggml_type_size(rms_norm_src->type);
GGML_ASSERT(rms_norm_src->nb[0] == ts0);
const int64_t s01 = rms_norm_src->nb[1] / ts0;
const int64_t s02 = rms_norm_src->nb[2] / ts0;
const int64_t s03 = rms_norm_src->nb[3] / ts0;

const size_t ts_mul = ggml_type_size(mul_src->type);
GGML_ASSERT(mul_src->nb[0] == ts_mul);
const int64_t mul_s01 = mul_src->nb[1] / ts_mul;
const int64_t mul_s02 = mul_src->nb[2] / ts_mul;
const int64_t mul_s03 = mul_src->nb[3] / ts_mul;

const int mul_ncols = mul_src->ne[0];
const int mul_nrows = mul_src->ne[1];
const int mul_nchannels = mul_src->ne[2];
const int mul_nsamples = mul_src->ne[3];

rms_norm_mul_f32_cuda(src0_d, mul_d, dst_d, ne00, ne01, ne02, ne03, s01, s02, s03, mul_s01, mul_s02, mul_s03, mul_ncols, mul_nrows, mul_nchannels, mul_nsamples, eps, stream);
}
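The graph pattern this targets is the usual RMSNorm-then-scale sequence; a minimal ggml sketch (assuming an existing ggml_context * ctx and F32 tensors x and w, mirroring the test_rms_norm_mul_add case below):

ggml_tensor * normed = ggml_rms_norm(ctx, x, 1e-6f); // becomes GGML_OP_RMS_NORM
ggml_tensor * out    = ggml_mul(ctx, normed, w);     // becomes GGML_OP_MUL; handled by the fused kernel when ggml_cuda_can_fuse() accepts the pair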

void ggml_cuda_op_rms_norm_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * grad = dst->src[0]; // gradients
const ggml_tensor * src0f = dst->src[1]; // src0 from forward pass
2 changes: 2 additions & 0 deletions ggml/src/ggml-cuda/norm.cuh
@@ -6,6 +6,8 @@ void ggml_cuda_op_group_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst)

void ggml_cuda_op_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

void ggml_cuda_op_rms_norm_fused(ggml_backend_cuda_context & ctx, ggml_tensor * dst, ggml_tensor * mul_tensor);

void ggml_cuda_op_rms_norm_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

void ggml_cuda_op_l2_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
13 changes: 9 additions & 4 deletions tests/test-backend-ops.cpp
@@ -2641,6 +2641,7 @@ struct test_rms_norm_mul_add : public test_case {
const ggml_type type;
const std::array<int64_t, 4> ne;
const float eps;
const bool broadcast;

std::string op_desc(ggml_tensor * t) override {
GGML_UNUSED(t);
@@ -2650,18 +2651,21 @@
bool run_whole_graph() override { return true; }

std::string vars() override {
return VARS_TO_STR3(type, ne, eps);
return VARS_TO_STR4(type, ne, eps, broadcast);
}

test_rms_norm_mul_add(ggml_type type = GGML_TYPE_F32,
std::array<int64_t, 4> ne = {64, 5, 4, 3},
float eps = 1e-6f)
: type(type), ne(ne), eps(eps) {}
float eps = 1e-6f, bool broadcast = false)
: type(type), ne(ne), eps(eps), broadcast(broadcast) {}

ggml_tensor * build_graph(ggml_context * ctx) override {
ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
std::array<int64_t, 4> broadcast_dims = {ne[0]*2, ne[1]*3, ne[2]*3, ne[3]*4};

ggml_tensor * a = ggml_new_tensor(ctx, type, 4, broadcast ? broadcast_dims.data() : ne.data());
ggml_tensor * b = ggml_new_tensor(ctx, type, 4, ne.data());
ggml_tensor * c = ggml_new_tensor(ctx, type, 4, ne.data());

ggml_set_param(a);
ggml_set_name(a, "a");
ggml_set_param(b);
@@ -5354,6 +5358,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
}
for (float eps : {0.0f, 1e-6f, 1e-4f, 1e-1f, 1.0f}) {
test_cases.emplace_back(new test_rms_norm_mul_add(GGML_TYPE_F32, {64, 5, 4, 3}, eps));
test_cases.emplace_back(new test_rms_norm_mul_add(GGML_TYPE_F32, {64, 5, 4, 3}, eps, true));
}

test_cases.emplace_back(new test_l2_norm(GGML_TYPE_F32, {64, 5, 4, 3}, 1e-12f));