diff --git a/common/arg.cpp b/common/arg.cpp
index 78cf6ab3058b4..205177d4695bd 100644
--- a/common/arg.cpp
+++ b/common/arg.cpp
@@ -1163,14 +1163,14 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
         [](common_params & params, int value) {
             params.grp_attn_n = value;
         }
-    ).set_env("LLAMA_ARG_GRP_ATTN_N"));
+    ).set_env("LLAMA_ARG_GRP_ATTN_N").set_examples({LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_PASSKEY}));
     add_opt(common_arg(
         {"-gaw", "--grp-attn-w"}, "N",
-        string_format("group-attention width (default: %.1f)", (double)params.grp_attn_w),
+        string_format("group-attention width (default: %d)", params.grp_attn_w),
         [](common_params & params, int value) {
             params.grp_attn_w = value;
         }
-    ).set_env("LLAMA_ARG_GRP_ATTN_W"));
+    ).set_env("LLAMA_ARG_GRP_ATTN_W").set_examples({LLAMA_EXAMPLE_MAIN}));
     add_opt(common_arg(
         {"-dkvc", "--dump-kv-cache"},
         "verbose print of the KV cache",
diff --git a/examples/server/README.md b/examples/server/README.md
index 52ccd9f5ee0ab..caffbac527306 100644
--- a/examples/server/README.md
+++ b/examples/server/README.md
@@ -60,8 +60,6 @@ The project is under active development, and we are [looking for feedback and co
 | `--yarn-attn-factor N` | YaRN: scale sqrt(t) or attention magnitude (default: 1.0)<br/>(env: LLAMA_ARG_YARN_ATTN_FACTOR) |
 | `--yarn-beta-slow N` | YaRN: high correction dim or alpha (default: 1.0)<br/>(env: LLAMA_ARG_YARN_BETA_SLOW) |
 | `--yarn-beta-fast N` | YaRN: low correction dim or beta (default: 32.0)<br/>(env: LLAMA_ARG_YARN_BETA_FAST) |
-| `-gan, --grp-attn-n N` | group-attention factor (default: 1)<br/>(env: LLAMA_ARG_GRP_ATTN_N) |
-| `-gaw, --grp-attn-w N` | group-attention width (default: 512.0)<br/>(env: LLAMA_ARG_GRP_ATTN_W) |
 | `-dkvc, --dump-kv-cache` | verbose print of the KV cache |
 | `-nkvo, --no-kv-offload` | disable KV offload<br/>(env: LLAMA_ARG_NO_KV_OFFLOAD) |
 | `-ctk, --cache-type-k TYPE` | KV cache data type for K (default: f16)<br/>(env: LLAMA_ARG_CACHE_TYPE_K) |
diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index 42b57d9c4c4dd..0dd2fc8b204cc 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -193,21 +193,15 @@ struct server_slot {
 
     llama_token sampled;
 
-    int32_t ga_i = 0;   // group-attention state
-    int32_t ga_n = 1;   // group-attention factor
-    int32_t ga_w = 512; // group-attention width
-
-    int32_t n_past_se = 0; // self-extend
-
     // stats
-    size_t n_sent_text = 0; // number of sent text character
+    size_t n_sent_text        = 0; // number of sent text character
     size_t n_sent_token_probs = 0;
 
     int64_t t_start_process_prompt;
     int64_t t_start_generation;
 
     double t_prompt_processing; // ms
-    double t_token_generation; // ms
+    double t_token_generation;  // ms
 
     std::function<void(int)> callback_on_release;
 
@@ -225,8 +219,6 @@ struct server_slot {
         n_sent_text        = 0;
         n_sent_token_probs = 0;
         cmpl_type          = SERVER_TASK_CMPL_TYPE_NORMAL;
-        ga_i               = 0;
-        n_past_se          = 0;
 
         generated_token_probs.clear();
     }
@@ -705,22 +697,6 @@ struct server_context {
 
             SLT_INF(slot, "new slot n_ctx_slot = %d\n", slot.n_ctx);
 
-            const int ga_n = params.grp_attn_n;
-            const int ga_w = params.grp_attn_w;
-
-            if (ga_n != 1) {
-                GGML_ASSERT(ga_n > 0                    && "ga_n must be positive");                       // NOLINT
-                GGML_ASSERT(ga_w % ga_n == 0            && "ga_w must be a multiple of ga_n");             // NOLINT
-                //GGML_ASSERT(n_ctx_train % ga_w == 0     && "n_ctx_train must be a multiple of ga_w");    // NOLINT
-                //GGML_ASSERT(n_ctx >= n_ctx_train * ga_n && "n_ctx must be at least n_ctx_train * ga_n"); // NOLINT
-
-                SLT_INF(slot, "slot self-extend: ga_n = %d, ga_w = %d\n", ga_n, ga_w);
-            }
-
-            slot.ga_i = 0;
-            slot.ga_n = ga_n;
-            slot.ga_w = ga_w;
-
             slot.sparams = params.sparams;
 
             slot.callback_on_release = [this](int) {
@@ -906,19 +882,14 @@ struct server_context {
         }
         if (data.contains("json_schema") && !data.contains("grammar")) {
             try {
-                auto schema                = json_value(data, "json_schema", json::object());
-                slot.sparams.grammar       = json_schema_to_grammar(schema);
+                auto schema          = json_value(data, "json_schema", json::object());
+                slot.sparams.grammar = json_schema_to_grammar(schema);
             } catch (const std::exception & e) {
                 send_error(task, std::string("\"json_schema\": ") + e.what(), ERROR_TYPE_INVALID_REQUEST);
                 return false;
             }
         } else {
-            slot.sparams.grammar       = json_value(data, "grammar",           default_sparams.grammar);
-        }
-
-        if (slot.params.cache_prompt && slot.ga_n != 1) {
-            slot.params.cache_prompt = false;
-            SLT_WRN(slot, "%s", "group-attention is not supported with prompt caching. disabling cache\n");
+            slot.sparams.grammar = json_value(data, "grammar", default_sparams.grammar);
         }
 
         if (slot.n_predict > 0 && slot.params.n_predict > slot.n_predict) {
@@ -1131,12 +1102,13 @@ struct server_context {
         }
 
         // if context shift is disabled, we stop when it reaches the context limit
-        if (slot.n_decoded >= slot.n_ctx) {
+        if (slot.n_past >= slot.n_ctx) {
             slot.truncated      = true;
             slot.stopped_limit  = true;
             slot.has_next_token = false;
 
-            SLT_DBG(slot, "stopped due to running out of context capacity, n_decoded = %d, n_ctx = %d\n", slot.n_decoded, slot.n_ctx);
+            SLT_DBG(slot, "stopped due to running out of context capacity, n_past = %d, n_prompt_tokens = %d, n_decoded = %d, n_ctx = %d\n",
+                    slot.n_decoded, slot.n_prompt_tokens, slot.n_past, slot.n_ctx);
         }
 
         if (llama_token_is_eog(model, result.tok)) {
@@ -1148,13 +1120,13 @@ struct server_context {
 
         const auto n_ctx_train = llama_n_ctx_train(model);
 
-        if (slot.params.n_predict < 1 && slot.n_predict < 1 && slot.ga_n == 1 && slot.n_prompt_tokens + slot.n_decoded >= n_ctx_train) {
+        if (slot.params.n_predict < 1 && slot.n_predict < 1 && slot.n_prompt_tokens + slot.n_decoded >= n_ctx_train) {
             slot.truncated      = true;
             slot.stopped_limit  = true;
             slot.has_next_token = false; // stop prediction
 
             SLT_WRN(slot,
-                    "n_predict (%d) is not set and self-context extend is disabled. "
+                    "n_predict (%d) is set for infinite generation. "
                     "Limiting generated tokens to n_ctx_train (%d) to avoid EOS-less generation infinite loop\n",
                     slot.params.n_predict, n_ctx_train);
         }
@@ -1826,38 +1798,36 @@ struct server_context {
         // apply context-shift if needed
         // TODO: simplify and improve
         for (server_slot & slot : slots) {
-            if (slot.ga_n == 1) {
-                if (slot.is_processing() && slot.n_past >= slot.n_ctx - 1) {
-                    if (!params.ctx_shift) {
-                        // this check is redundant (for good)
-                        // we should never get here, because generation should already stopped in process_token()
-                        slot.release();
-                        send_error(slot, "context shift is disabled", ERROR_TYPE_SERVER);
-                        continue;
-                    }
-
-                    // Shift context
-                    const int n_keep    = slot.params.n_keep + add_bos_token;
-                    const int n_left    = slot.n_past - n_keep;
-                    const int n_discard = slot.params.n_discard ? slot.params.n_discard : (n_left / 2);
+            if (slot.is_processing() && slot.n_past + 1 >= slot.n_ctx) {
+                if (!params.ctx_shift) {
+                    // this check is redundant (for good)
+                    // we should never get here, because generation should already stopped in process_token()
+                    slot.release();
+                    send_error(slot, "context shift is disabled", ERROR_TYPE_SERVER);
+                    continue;
+                }
 
-                    SLT_WRN(slot, "slot context shift, n_keep = %d, n_left = %d, n_discard = %d\n", n_keep, n_left, n_discard);
+                // Shift context
+                const int n_keep    = slot.params.n_keep + add_bos_token;
+                const int n_left    = slot.n_past - n_keep;
+                const int n_discard = slot.params.n_discard ? slot.params.n_discard : (n_left / 2);
 
-                    llama_kv_cache_seq_rm (ctx, slot.id + 1, n_keep            , n_keep + n_discard);
-                    llama_kv_cache_seq_add(ctx, slot.id + 1, n_keep + n_discard, slot.n_past,        -n_discard);
+                SLT_WRN(slot, "slot context shift, n_keep = %d, n_left = %d, n_discard = %d\n", n_keep, n_left, n_discard);
 
-                    if (slot.params.cache_prompt) {
-                        for (size_t i = n_keep + n_discard; i < slot.cache_tokens.size(); i++) {
-                            slot.cache_tokens[i - n_discard] = slot.cache_tokens[i];
-                        }
+                llama_kv_cache_seq_rm (ctx, slot.id + 1, n_keep            , n_keep + n_discard);
+                llama_kv_cache_seq_add(ctx, slot.id + 1, n_keep + n_discard, slot.n_past,        -n_discard);
 
-                        slot.cache_tokens.resize(slot.cache_tokens.size() - n_discard);
+                if (slot.params.cache_prompt) {
+                    for (size_t i = n_keep + n_discard; i < slot.cache_tokens.size(); i++) {
+                        slot.cache_tokens[i - n_discard] = slot.cache_tokens[i];
                     }
 
-                    slot.n_past -= n_discard;
-
-                    slot.truncated = true;
+                    slot.cache_tokens.resize(slot.cache_tokens.size() - n_discard);
                 }
+
+                slot.n_past -= n_discard;
+
+                slot.truncated = true;
             }
         }
 
@@ -1872,9 +1842,7 @@ struct server_context {
 
             slot.i_batch = batch.n_tokens;
 
-            const int32_t slot_npast = slot.n_past_se > 0 ? slot.n_past_se : slot.n_past;
-
-            common_batch_add(batch, slot.sampled, slot_npast, { slot.id + 1 }, true);
+            common_batch_add(batch, slot.sampled, slot.n_past, { slot.id + 1 }, true);
 
             slot.n_past += 1;
 
@@ -1993,6 +1961,8 @@ struct server_context {
                         } else {
                             if (!params.ctx_shift) {
                                 // if context shift is disabled, we make sure prompt size is smaller than KV size
+                                // TODO: there should be a separate parameter that control prompt truncation
+                                //       context shift should be applied only during the generation phase
                                 if (slot.n_prompt_tokens >= slot.n_ctx) {
                                     slot.release();
                                     send_error(slot, "the request exceeds the available context size. try increasing the context size or enable context shift", ERROR_TYPE_INVALID_REQUEST);
@@ -2005,7 +1975,7 @@ struct server_context {
                             slot.params.n_keep = std::min(slot.n_ctx - 4, slot.params.n_keep);
 
                             // if input prompt is too big, truncate it (if group attention self-extend is disabled)
-                            if (slot.ga_n == 1 && slot.n_prompt_tokens >= slot.n_ctx) {
+                            if (slot.n_prompt_tokens >= slot.n_ctx) {
                                 const int n_left = slot.n_ctx - slot.params.n_keep;
 
                                 const int n_block_size = n_left / 2;
@@ -2032,12 +2002,7 @@ struct server_context {
 
                             common_sampler_reset(slot.smpl);
 
-                            if (!slot.params.cache_prompt) {
-                                slot.n_past_se = 0;
-                                slot.ga_i      = 0;
-                            } else {
-                                GGML_ASSERT(slot.ga_n == 1);
-
+                            if (slot.params.cache_prompt) {
                                 // reuse any previously computed tokens that are common with the new prompt
                                 slot.n_past = common_part(slot.cache_tokens, prompt_tokens);
 
@@ -2053,9 +2018,6 @@ struct server_context {
                             SLT_WRN(slot, "need to evaluate at least 1 token to generate logits, n_past = %d, n_prompt_tokens = %d\n", slot.n_past, slot.n_prompt_tokens);
 
                             slot.n_past--;
-                            if (slot.ga_i > 0) {
-                                slot.n_past_se--;
-                            }
                         }
 
                         slot.n_prompt_tokens_processed = 0;
@@ -2081,52 +2043,31 @@ struct server_context {
                     }
 
                     // keep only the common part
-                    int p0 = slot.n_past;
-
-                    if (!llama_kv_cache_seq_rm(ctx, slot.id + 1, p0, -1)) {
+                    if (!llama_kv_cache_seq_rm(ctx, slot.id + 1, slot.n_past, -1)) {
                         // could not partially delete (likely using a non-Transformer model)
                         llama_kv_cache_seq_rm(ctx, slot.id + 1, -1, -1);
 
-                        p0 = 0;
-
                         // there is no common part left
                         slot.n_past = 0;
-                        slot.n_past_se = 0;
-                        slot.ga_i = 0;
 
                         common_sampler_reset(slot.smpl);
                     }
 
+                    SLT_INF(slot, "kv cache rm [%d, end)\n", slot.n_past);
+
                     // remove the non-common part from the cache
                     slot.cache_tokens.resize(slot.n_past);
 
-                    SLT_INF(slot, "kv cache rm [%d, end)\n", p0);
-
-                    int32_t slot_npast = slot.n_past_se > 0 ? slot.n_past_se : slot.n_past;
-
-                    int32_t ga_i = slot.ga_i;
-                    int32_t ga_n = slot.ga_n;
-                    int32_t ga_w = slot.ga_w;
-
                     // add prompt tokens for processing in the current batch
-                    // TODO: the self-extend stuff here is a mess - simplify and/or abstract it somehow
-                    for (; slot.n_past < slot.n_prompt_tokens && batch.n_tokens < n_batch; ++slot.n_past) {
-                        if (slot.ga_n != 1) {
-                            while (slot_npast >= ga_i + ga_w) {
-                                const int bd = (ga_w/ga_n)*(ga_n - 1);
-                                slot_npast -= bd;
-                                ga_i += ga_w/ga_n;
-                            }
-                        }
-
-                        common_batch_add(batch, prompt_tokens[slot.n_past], slot_npast, { slot.id + 1 }, false);
+                    while (slot.n_past < slot.n_prompt_tokens && batch.n_tokens < n_batch) {
+                        common_batch_add(batch, prompt_tokens[slot.n_past], slot.n_past, { slot.id + 1 }, false);
 
                         if (slot.params.cache_prompt) {
                             slot.cache_tokens.push_back(prompt_tokens[slot.n_past]);
                         }
 
                         slot.n_prompt_tokens_processed++;
-                        slot_npast++;
+                        slot.n_past++;
                     }
 
                     SLT_INF(slot, "prompt processing progress, n_past = %d, n_tokens = %d, progress = %f\n", slot.n_past, batch.n_tokens, (float) slot.n_prompt_tokens_processed / slot.n_prompt_tokens);
@@ -2167,34 +2108,6 @@ struct server_context {
         for (int32_t i = 0; i < batch.n_tokens; i += n_batch) {
             const int32_t n_tokens = std::min(n_batch, batch.n_tokens - i);
 
-            for (auto & slot : slots) {
-                if (slot.ga_n != 1) {
-                    // context extension via Self-Extend
-                    // TODO: simplify and/or abstract this
-                    while (slot.n_past_se >= slot.ga_i + slot.ga_w) {
-                        const int ib = (slot.ga_n * slot.ga_i) / slot.ga_w;
-                        const int bd = (slot.ga_w / slot.ga_n) * (slot.ga_n - 1);
-                        const int dd = (slot.ga_w / slot.ga_n) - ib * bd - slot.ga_w;
-
-                        SLT_DBG(slot, "shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", slot.ga_i, slot.n_past_se, ib * bd, slot.ga_i + ib * bd, slot.n_past_se + ib * bd);
-                        SLT_DBG(slot, "div:   [%6d, %6d] / %6d -> [%6d, %6d]\n", slot.ga_i + ib * bd, slot.ga_i + ib * bd + slot.ga_w, slot.ga_n, (slot.ga_i + ib * bd) / slot.ga_n, (slot.ga_i + ib * bd + slot.ga_w) / slot.ga_n);
-                        SLT_DBG(slot, "shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", slot.ga_i + ib * bd + slot.ga_w, slot.n_past_se + ib * bd, dd, slot.ga_i + ib * bd + slot.ga_w + dd, slot.n_past_se + ib * bd + dd);
-
-                        llama_kv_cache_seq_add(ctx, slot.id + 1, slot.ga_i, slot.n_past_se, ib * bd);
-                        llama_kv_cache_seq_div(ctx, slot.id + 1, slot.ga_i + ib * bd, slot.ga_i + ib * bd + slot.ga_w, slot.ga_n);
-                        llama_kv_cache_seq_add(ctx, slot.id + 1, slot.ga_i + ib * bd + slot.ga_w, slot.n_past_se + ib * bd, dd);
-
-                        slot.n_past_se -= bd;
-
-                        slot.ga_i += slot.ga_w / slot.ga_n;
-
-                        SLT_DBG(slot, "\nn_past_old = %d, n_past = %d, ga_i = %d\n\n", slot.n_past_se + bd, slot.n_past_se, slot.ga_i);
-                    }
-
-                    slot.n_past_se += n_tokens;
-                }
-            }
-
             llama_batch batch_view = {
                 n_tokens,
                 batch.token    + i,
diff --git a/examples/server/tests/features/ctx_shift.feature b/examples/server/tests/features/ctx_shift.feature
index ba3afcf060506..ae6c6b01b0221 100644
--- a/examples/server/tests/features/ctx_shift.feature
+++ b/examples/server/tests/features/ctx_shift.feature
@@ -13,6 +13,10 @@ Feature: llama.cpp server
     And   32 as batch size
     And   2 slots
 
+    # the prompt is 301 tokens
+    # the slot context is 256/2 = 128 tokens
+    # the prompt is truncated to keep the last 109 tokens
+    # 64 tokens are generated thanks to shifting the context when it gets full
   Scenario: Inference with context shift
     And   64 server max tokens to predict
     Then  the server is starting