Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
106 changes: 87 additions & 19 deletions ggml/src/ggml-vulkan/ggml-vulkan.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -501,6 +501,8 @@ struct vk_device_struct {

ggml_backend_buffer_type buffer_type;

bool disable_fusion;

#ifdef GGML_VULKAN_MEMORY_DEBUG
std::unique_ptr<vk_memory_logger> memory_logger;
#endif
Expand Down Expand Up @@ -1091,8 +1093,8 @@ static size_t vk_skip_checks;
static size_t vk_output_tensor;

static void ggml_vk_print_tensor(const ggml_tensor * tensor, const char * name);
static void ggml_vk_check_results_0(ggml_tensor * tensor);
static void ggml_vk_check_results_1(ggml_tensor * tensor);
static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph * cgraph, int tensor_idx);
static void ggml_vk_check_results_1(ggml_backend_vk_context * ctx, ggml_cgraph * cgraph, int tensor_idx);
#endif

typedef void (*ggml_vk_func_t)(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
Expand Down Expand Up @@ -3507,6 +3509,8 @@ static vk_device ggml_vk_get_device(size_t idx) {

device->idx = idx;

device->disable_fusion = getenv("GGML_VK_DISABLE_FUSION") != nullptr;

return device;
}

Expand Down Expand Up @@ -7654,8 +7658,7 @@ static void ggml_vk_group_norm(ggml_backend_vk_context * ctx, vk_context& subctx
ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_GROUP_NORM, { group_size, 0, eps, 0.0f }, dryrun);
}

static void ggml_vk_rms_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
float * op_params = (float *)dst->op_params;
static void ggml_vk_rms_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, float * op_params, bool dryrun = false) {
const uint32_t src0_type_size = ggml_type_size(src0->type);
const uint32_t src1_type_size = ggml_type_size(src1->type);
const uint32_t dst_type_size = ggml_type_size(dst->type);
Expand Down Expand Up @@ -8885,7 +8888,7 @@ static void ggml_vk_preallocate_buffers(ggml_backend_vk_context * ctx) {
}
}

static bool ggml_vk_compute_forward(ggml_backend_vk_context* ctx, ggml_tensor* tensor, int tensor_idx, bool use_fence, bool almost_ready);
static bool ggml_vk_compute_forward(ggml_backend_vk_context* ctx, ggml_cgraph * cgraph, ggml_tensor* tensor, int tensor_idx, bool use_fence, bool almost_ready);

// Returns true if node has enqueued work into the queue, false otherwise
// If submit is true the current all operations queued so far are being submitted to Vulkan to overlap cmdlist creation and GPU execution.
Expand Down Expand Up @@ -9146,9 +9149,9 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
// fused rms_norm + mul
ggml_tensor *mul = cgraph->nodes[node_idx + 1];
ggml_tensor *other_src = mul->src[0] == node ? mul->src[1] : mul->src[0];
ggml_vk_rms_norm(ctx, compute_ctx, src0, other_src, mul, dryrun);
ggml_vk_rms_norm(ctx, compute_ctx, src0, other_src, mul, (float *)node->op_params, dryrun);
} else {
ggml_vk_rms_norm(ctx, compute_ctx, src0, src0, node, dryrun);
ggml_vk_rms_norm(ctx, compute_ctx, src0, src0, node, (float *)node->op_params, dryrun);
}
break;
case GGML_OP_RMS_NORM_BACK:
Expand Down Expand Up @@ -9308,7 +9311,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr

ctx->compute_ctx.reset();

bool ok = ggml_vk_compute_forward(ctx, node_begin, node_idx_begin, false, almost_ready);
bool ok = ggml_vk_compute_forward(ctx, cgraph, node_begin, node_idx_begin, false, almost_ready);
if (!ok) {
if (node->op == GGML_OP_UNARY) {
std::cerr << __func__ << ": error: op not supported UNARY " << node->name << " (" << ggml_unary_op_name(static_cast<ggml_unary_op>(node->op_params[0])) << ")" << std::endl;
Expand All @@ -9323,7 +9326,8 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
return true;
}

static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor * tensor, int tensor_idx, bool use_fence = true, bool almost_ready = false) {
static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_cgraph * cgraph, ggml_tensor * tensor, int tensor_idx, bool use_fence = true, bool almost_ready = false) {
GGML_UNUSED(cgraph);
ggml_backend_buffer * buf = nullptr;

switch (tensor->op) {
Expand Down Expand Up @@ -9433,7 +9437,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor *
// Only run if ctx hasn't been submitted yet
if (!subctx->seqs.empty()) {
#ifdef GGML_VULKAN_CHECK_RESULTS
ggml_vk_check_results_0(tensor);
ggml_vk_check_results_0(ctx, cgraph, tensor_idx);
use_fence = true;
#endif

Expand All @@ -9453,7 +9457,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor *
ggml_vk_wait_for_fence(ctx);
}
#ifdef GGML_VULKAN_CHECK_RESULTS
ggml_vk_check_results_1(tensor);
ggml_vk_check_results_1(ctx, cgraph, tensor_idx);
#endif
}

Expand Down Expand Up @@ -9900,6 +9904,37 @@ static bool ggml_vk_is_empty(ggml_tensor * node) {
return ggml_is_empty(node) || node->op == GGML_OP_NONE || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE;
}

static bool ggml_vk_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, std::initializer_list<enum ggml_op> ops) {
if (!ggml_can_fuse(cgraph, node_idx, ops)) {
return false;
}

if (ops.size() == 2 && ops.begin()[0] == GGML_OP_RMS_NORM && ops.begin()[1] == GGML_OP_MUL) {
// additional constraints specific to this fusion
const ggml_tensor *rms_norm = cgraph->nodes[node_idx];
const ggml_tensor *mul = cgraph->nodes[node_idx + 1];

GGML_ASSERT(rms_norm->src[0]->type == GGML_TYPE_F32);
GGML_ASSERT(rms_norm->type == GGML_TYPE_F32);
// rms_norm only supports f32
if (mul->src[0]->type != GGML_TYPE_F32 ||
mul->src[1]->type != GGML_TYPE_F32 ||
mul->type != GGML_TYPE_F32) {
return false;
}
// if rms_norm is the B operand, then we don't handle broadcast
if (rms_norm == mul->src[1] &&
mul->src[0]->ne[1] != rms_norm->ne[1]) {
return false;
}
// rms_norm shader assumes contiguous rows
if (!ggml_is_contiguous_rows(mul->src[0]) || !ggml_is_contiguous_rows(mul->src[1])) {
return false;
}
}
return true;
}

static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
VK_LOG_DEBUG("ggml_backend_vk_graph_compute(" << cgraph->n_nodes << " nodes)");
ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context;
Expand All @@ -9913,7 +9948,7 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg

uint64_t total_mat_mul_bytes = 0;
for (int i = 0; i < cgraph->n_nodes; i++) {
if (ggml_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) {
if (!ctx->device->disable_fusion && ggml_vk_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) {
ctx->num_additional_fused_ops = 1;
}
ggml_vk_build_graph(ctx, cgraph, i, nullptr, 0, true, false, false, false);
Expand Down Expand Up @@ -9983,7 +10018,7 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
mul_mat_bytes += ggml_nbytes(cgraph->nodes[i]->src[0]);
}

if (ggml_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) {
if (!ctx->device->disable_fusion && ggml_vk_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) {
ctx->num_additional_fused_ops = 1;
}

Expand Down Expand Up @@ -10760,11 +10795,21 @@ void * comp_result;
size_t comp_size;
size_t comp_nb[GGML_MAX_DIMS];
size_t check_counter = 0;
static void ggml_vk_check_results_0(ggml_tensor * tensor) {
static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph * cgraph, int tensor_idx) {
ggml_tensor * tensor = cgraph->nodes[tensor_idx];
if (tensor->op == GGML_OP_TRANSPOSE) {
return;
}

bool fused_rms_norm_mul = false;
int rms_norm_idx = -1;
if (ctx->num_additional_fused_ops == 1 &&
tensor->op == GGML_OP_RMS_NORM &&
cgraph->nodes[tensor_idx + 1]->op == GGML_OP_MUL) {
fused_rms_norm_mul = true;
tensor = cgraph->nodes[tensor_idx + 1];
}

check_counter++;
if (!(vk_output_tensor > 0 && vk_output_tensor == check_counter) && check_counter <= vk_skip_checks) {
return;
Expand Down Expand Up @@ -10792,6 +10837,15 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {

for (int i = 0; i < 6; i++) {
ggml_tensor * srci = tensor->src[i];
if (fused_rms_norm_mul) {
rms_norm_idx = tensor->src[0]->op == GGML_OP_RMS_NORM ? 0 : 1;
ggml_tensor *rms_norm = tensor->src[rms_norm_idx];
switch (i) {
case 0: srci = rms_norm->src[0]; break;
case 1: srci = tensor->src[1 - rms_norm_idx]; break;
default: continue;
}
}
if (srci == nullptr) {
continue;
}
Expand Down Expand Up @@ -10849,7 +10903,12 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
} else if (tensor->op == GGML_OP_SUB) {
tensor_clone = ggml_sub(ggml_ctx, src_clone[0], src_clone[1]);
} else if (tensor->op == GGML_OP_MUL) {
tensor_clone = ggml_mul(ggml_ctx, src_clone[0], src_clone[1]);
if (fused_rms_norm_mul) {
tensor_clone = ggml_rms_norm(ggml_ctx, src_clone[0], *(float *)tensor->src[rms_norm_idx]->op_params);
tensor_clone = ggml_mul(ggml_ctx, tensor_clone, src_clone[1 - rms_norm_idx]);
} else {
tensor_clone = ggml_mul(ggml_ctx, src_clone[0], src_clone[1]);
}
} else if (tensor->op == GGML_OP_DIV) {
tensor_clone = ggml_div(ggml_ctx, src_clone[0], src_clone[1]);
} else if (tensor->op == GGML_OP_CONCAT) {
Expand Down Expand Up @@ -11040,10 +11099,10 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
GGML_ABORT("fatal error");
}

ggml_cgraph * cgraph = ggml_new_graph(ggml_ctx);
ggml_build_forward_expand(cgraph, tensor_clone);
ggml_cgraph * cgraph_cpu = ggml_new_graph(ggml_ctx);
ggml_build_forward_expand(cgraph_cpu, tensor_clone);

ggml_graph_compute_with_ctx(ggml_ctx, cgraph, 8);
ggml_graph_compute_with_ctx(ggml_ctx, cgraph_cpu, 8);

if (vk_output_tensor > 0 && vk_output_tensor == check_counter) {
ggml_vk_print_tensor(tensor_clone, "tensor_clone");
Expand All @@ -11066,10 +11125,19 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
VK_LOG_DEBUG("END ggml_vk_check_results_0(" << tensor->name << ")");
}

static void ggml_vk_check_results_1(ggml_tensor * tensor) {
static void ggml_vk_check_results_1(ggml_backend_vk_context * ctx, ggml_cgraph * cgraph, int tensor_idx) {
ggml_tensor * tensor = cgraph->nodes[tensor_idx];
if (tensor->op == GGML_OP_TRANSPOSE) {
return;
}
bool fused_rms_norm_mul = false;
if (ctx->num_additional_fused_ops == 1 &&
tensor->op == GGML_OP_RMS_NORM &&
cgraph->nodes[tensor_idx + 1]->op == GGML_OP_MUL) {
fused_rms_norm_mul = true;
tensor = cgraph->nodes[tensor_idx + 1];
}

if (!(vk_output_tensor > 0 && vk_output_tensor == check_counter) && check_counter <= vk_skip_checks) {
return;
}
Expand Down
6 changes: 1 addition & 5 deletions tests/test-backend-ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2583,10 +2583,6 @@ struct test_rms_norm_mul : public test_case {
}
}

double max_nmse_err() override {
return 1e-6;
}

float grad_eps() override {
return 1.0f;
}
Expand Down Expand Up @@ -5058,7 +5054,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
test_cases.emplace_back(new test_rms_norm_back(GGML_TYPE_F32, {64, 5, 4, 3}, eps));
test_cases.emplace_back(new test_l2_norm (GGML_TYPE_F32, {64, 5, 4, 3}, eps));
}
for (float eps : {0.0f, 1e-6f, 1e-4f, 1e-1f}) {
for (float eps : {0.0f, 1e-6f, 1e-4f, 1e-1f, 1.0f}) {
test_cases.emplace_back(new test_rms_norm_mul(GGML_TYPE_F32, {64, 5, 4, 3}, eps));
}

Expand Down
Loading