Skip to content

Commit

Permalink
metal : fix GELU kernel numerical stability by using precise::tanh
Browse files Browse the repository at this point in the history
  • Loading branch information
ggerganov committed Aug 23, 2023
1 parent b693000 commit 0a85ae7
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 3 deletions.
4 changes: 2 additions & 2 deletions ggml-metal.m
Original file line number Diff line number Diff line change
Expand Up @@ -539,8 +539,8 @@ void ggml_metal_graph_compute(

id<MTLComputeCommandEncoder> encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];

const int node_start = (cb_idx + 0) * n_nodes_per_cb;
const int node_end = (cb_idx == n_cb - 1) ? n_nodes : (cb_idx + 1) * n_nodes_per_cb;
const int node_start = (cb_idx + 0) * n_nodes_per_cb;
const int node_end = MIN((cb_idx == n_cb - 1) ? n_nodes : (cb_idx + 1) * n_nodes_per_cb, n_nodes);

for (int ind = node_start; ind < node_end; ++ind) {
const int i = has_concur ? ctx->concur_list[ind] : ind;
Expand Down
7 changes: 6 additions & 1 deletion ggml-metal.metal
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,12 @@ kernel void kernel_gelu(
device float * dst,
uint tpig[[thread_position_in_grid]]) {
float x = src0[tpig];
dst[tpig] = 0.5f*x*(1.0f + tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));

// BEWARE !!!
// Simply using "tanh" instead of "precise::tanh" will sometimes results in NaNs!
// This was observed with Falcon 7B and 40B models
//
dst[tpig] = 0.5f*x*(1.0f + precise::tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
}

kernel void kernel_soft_max(
Expand Down

0 comments on commit 0a85ae7

Please sign in to comment.