
Conversation

@gabe-l-hart
Collaborator

DRAFT STATUS

This PR will remain in Draft until the items in the discussion section are resolved.

Description

This PR is a draft implementation of the Structured State-Space Duality (SSD) described in the original Mamba-2 paper, which reframes the SSM_SCAN op as a pseudo-attention operation. The paper describes it in great detail, but the short version is that, for a multi-token scan, the recurrent formulation of SSM_SCAN is inefficient because it cannot parallelize over the sequence dimension the way an attention calculation can. With the SSD formulation, the logical attention matrix is decomposed into chunks and the state is updated at the chunk boundaries, allowing prefill to "jump" by the size of the chunk rather than proceed one token at a time.
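For intuition, here is a rough single-head NumPy sketch of the chunked formulation (illustrative only — the shapes, names, default chunk size, and omission of heads/groups are my simplifications, not the ggml graph built in this PR):

```python
import numpy as np

def ssd_chunked(x, dt, A, B, C, chunk=64):
    """Single-head sketch of the chunked SSD formulation (illustrative only).

    x : (T, P)  per-token head inputs
    dt: (T,)    softplus-activated time deltas
    A : scalar  per-head value (negative in practice, so dt*A acts as a log-decay)
    B : (T, N)  input projections
    C : (T, N)  output projections
    """
    T, P = x.shape
    N = B.shape[1]
    y = np.zeros((T, P))
    h = np.zeros((N, P))                     # state carried across chunk boundaries
    for s in range(0, T, chunk):
        e = min(s + chunk, T)
        Bc, Cc = B[s:e], C[s:e]
        xc  = x[s:e] * dt[s:e, None]         # dt-scaled inputs
        seg = np.cumsum(dt[s:e] * A)         # cumulative log-decay inside the chunk
        # intra-chunk: causal pseudo-attention with a triangular decay mask
        L = np.tril(np.exp(seg[:, None] - seg[None, :]))
        y[s:e] = (L * (Cc @ Bc.T)) @ xc
        # contribution of the state entering the chunk
        y[s:e] += np.exp(seg)[:, None] * (Cc @ h)
        # inter-chunk state update: "jump" the state over the whole chunk
        h = np.exp(seg[-1]) * h + (Bc * np.exp(seg[-1] - seg)[:, None]).T @ xc
    return y
```

Checking its output against the plain recurrent update (h = exp(dt*A)*h + dt*B*x^T, y = C*h) is a quick way to sanity-check the inter-chunk state handling.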

Reference Links

Changes

  • Introduce new primitive operations in ggml (rough NumPy analogues of these follow this list):

    • ggml_cumsum / ggml_cumsum_0: Perform a cumulative sum along a given dimension
    • ggml_tri_dims / ggml_tri / ggml_tri_keep: Apply a triangular mask to the given matrix
    • ggml_softplus: Perform the unary softplus operation
  • Implement an alternate path through llm_graph_context_mamba::build_mamba2_layer when a multi-token update is detected

    • This path is the core of the SSD implementation and avoids calling SSM_SCAN in favor of the chunked pseudo-attention formulation
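For readers unfamiliar with the new primitives, here are rough NumPy analogues (the exact axis and variant conventions are whatever the PR defines; this is only meant to convey the semantics):

```python
import numpy as np

x = np.arange(12, dtype=np.float32).reshape(3, 4)

# ggml_cumsum / ggml_cumsum_0: cumulative sum along one dimension
# (axis=-1 here is arbitrary; the actual axis convention is the PR's)
cumsum = np.cumsum(x, axis=-1)

# ggml_tri / ggml_tri_dims / ggml_tri_keep: apply a triangular mask
lower = np.tril(np.ones((4, 4), dtype=np.float32))   # keep the lower triangle

# ggml_softplus: log(1 + exp(x)), used to activate dt
softplus = np.log1p(np.exp(x))
```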

Discussion

There are a number of outstanding discussion points on this work that need to be resolved before moving it forward:

  1. Performance: Currently, this implementation appears to be significantly slower than simply using SSM_SCAN, which roundly defeats the purpose of the change! I suspect that the performance issues are due to the number of ggml_permute / ggml_cont ops that are added to the graph, but I could use assistance figuring out how to eliminate them or identifying other sources of slowness.
  2. To chunk or not to chunk: In this PR I have sub-ubatch chunking implemented; I had it mostly working before the corresponding discussion on Qwen3Next. Since the inter-chunk update would be needed anyway, I didn't strip it out, but doing so would be fairly trivial and might offer some performance improvement.
  3. Handling of repeat_interleave: Similar to the issue that came up when initially implementing NemotronH support, I believe that ggml_repeat behaves differently from mx.repeat, resulting in incorrect results for models with n_groups > 1 (tested with NemotronH).

Testing

I've tested this locally with various members of the Granite 4 family and with nvidia/NVIDIA-Nemotron-Nano-9B-v2. For the Granite 4 models with n_groups == 1, I get nearly identical results to running with the plain SSM_SCAN path, but NemotronH still struggles due to the repeat_interleave issue (see above). I'll flesh out more testing results once we've worked through some of the above issues.

cc @compilade — I know this has been on your TODO list since the original mamba2 implementation.

ggerganov and others added 30 commits October 9, 2025 19:40
* gg/metal-mul-mat-fixes:
metal : fix mul-mm condition + fix mul-mv permuted kernels

Cherry-picked and edited from 7ec2df6
The original commit contained the DELTA_NET op as well, which I've removed
in this cherry-picked version.

Co-Authored-By: Piotr Wilkin <piotr.wilkin@syndatis.com>

This should be using simd operations for better parallelism, but that will
come next.
@pwilkin
Collaborator

pwilkin commented Nov 3, 2025

Yeah, I had an issue with repeat_interleave too. Technically, repeat_interleave is equivalent to permute + repeat, but of course it introduces additional operations.
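For reference, a small NumPy illustration of the two behaviours (not the ggml code in question; shapes are arbitrary):

```python
import numpy as np

x = np.array([[1, 2],
              [3, 4]])                    # (rows, groups)

tiled = np.tile(x, (1, 2))                # plain repeat: 1 2 1 2 / 3 4 3 4
interleaved = np.repeat(x, 2, axis=1)     # repeat_interleave: 1 1 2 2 / 3 3 4 4

# repeat_interleave rebuilt from "repeat along a new inner axis, then flatten",
# which is roughly what a permute + repeat + reshape chain has to do
via_reshape = np.broadcast_to(x[:, :, None], (2, 2, 2)).reshape(2, 4)
assert np.array_equal(via_reshape, interleaved)
```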

@pwilkin
Collaborator

pwilkin commented Nov 3, 2025

Regarding the chunking: won't this explode the graph a lot?

In the case of Delta Net attention, since you have to use a triangular solve there, you don't want the chunk size over 64 or performance drops drastically. But that means that you're going to end up with 8 chunks at a typical ubatch size of 512.

The graph for Qwen3 Next already has 9000 nodes. I'm a bit afraid of doing chunking this way (and I know @ggerganov had strong objections too).

@gabe-l-hart
Collaborator Author

chunking: won't this explode the graph a lot?

Yep, it sure will. I also suspect this is one of the reasons this is slower currently. I don't think SSD has the same need for chunking based on computational complexity, so I think it's mostly there for memory overhead management.
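For rough intuition on the memory point (back-of-envelope, not from the PR itself): materializing the masked $CB^\top$ matrix for a ubatch of $T$ tokens is on the order of $T^2$ entries per head, while chunking into blocks of size $C$ needs only about $\lceil T/C\rceil \cdot C^2 \approx T\,C$ entries, at the cost of the extra inter-chunk state updates.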

This is experimental!

This is used to zero-out the state in build_rs, so it's required to support
F16 cache states for recurrent models. The bias route does not get hit in
that case, but would need to be implemented if used elsewhere.

This will be needed until F16 support is added for SSM_SCAN.
@gabe-l-hart
Collaborator Author

I've pulled the changes to llama-gguf (#17025), llama-eval-callback (#17028), and test-backend-ops (#17029) into separate PRs and will plan to update this PR once they're reviewed.

Unlike Qwen3Next, we don't hit big complexity scaling issues here, so
removing all of the batching gives a big reduction in complexity and a big
boost to performance!
@gabe-l-hart
Collaborator Author

I've been further experimenting with a few tweaks to get more performance out of this.

  • Add F16 and BF16 support to SSM_CONV so that we can reduce the precision of ssm_conv1d.weight when converting to GGUF
    • This one seems to give a small but noticeable perf boost both with and without the SSD version of SSM_SCAN
  • Support F16 cache types for both r and s recurrent caches
    • This requires a ggml_cast before calling the single-token ggml_ssm_scan and a cast back afterwards, since the output is the merge of x and next_state
    • For the SSD formulation, doing the cast only at the end gives a noticeable performance improvement
  • Remove sub-ubatch batching
    • This one gives a fairly significant performance improvement, getting the SSD version close to the raw SSM_SCAN path.

Local notes using variants of the following command:

./bin/llama-batched-bench -m ~/models/ibm-granite/granite-4.0-h-1b/granite-4.0-h-1B-BF16-exp.gguf -c 2048 -b 2048 -ub 512 -npp 128,256 -ntg 128 -npl 1,2,4 -ngl 99

NOTE: granite-4.0-h-1B-BF16-exp.gguf includes my changes to allow ssm_conv1d.weight to be quantized as low as F16 during llama-quantize, so in this model, it's BF16 instead of FP32.


Baseline SSM_SCAN

F16 cache w/ F32 conv

| PP | TG | B | N_KV | T_PP s | S_PP t/s | T_TG s | S_TG t/s | T s | S t/s |
|----|----|---|------|--------|----------|--------|----------|-----|-------|
| 128 | 128 | 1 | 256 | 0.110 | 1167.11 | 2.777 | 46.09 | 2.887 | 88.67 |
| 128 | 128 | 2 | 512 | 0.193 | 1328.97 | 3.187 | 80.34 | 3.379 | 151.51 |
| 128 | 128 | 4 | 1024 | 0.369 | 1386.92 | 4.149 | 123.41 | 4.518 | 226.65 |
| 256 | 128 | 1 | 384 | 0.191 | 1343.18 | 2.781 | 46.03 | 2.971 | 129.23 |
| 256 | 128 | 2 | 768 | 0.365 | 1402.74 | 3.149 | 81.29 | 3.514 | 218.54 |
| 256 | 128 | 4 | 1536 | 0.737 | 1388.79 | 4.150 | 123.37 | 4.887 | 314.27 |

F32 cache / F32 conv

| PP | TG | B | N_KV | T_PP s | S_PP t/s | T_TG s | S_TG t/s | T s | S t/s |
|----|----|---|------|--------|----------|--------|----------|-----|-------|
| 128 | 128 | 1 | 256 | 0.109 | 1173.59 | 2.767 | 46.26 | 2.876 | 89.01 |
| 128 | 128 | 2 | 512 | 0.192 | 1331.47 | 3.193 | 80.18 | 3.385 | 151.25 |
| 128 | 128 | 4 | 1024 | 0.370 | 1384.20 | 4.204 | 121.78 | 4.574 | 223.87 |
| 256 | 128 | 1 | 384 | 0.190 | 1345.33 | 2.734 | 46.81 | 2.925 | 131.30 |
| 256 | 128 | 2 | 768 | 0.365 | 1401.15 | 3.206 | 79.85 | 3.572 | 215.03 |
| 256 | 128 | 4 | 1536 | 0.737 | 1389.54 | 4.245 | 120.60 | 4.982 | 308.30 |

F16 cache w/ BF16 conv

| PP | TG | B | N_KV | T_PP s | S_PP t/s | T_TG s | S_TG t/s | T s | S t/s |
|----|----|---|------|--------|----------|--------|----------|-----|-------|
| 128 | 128 | 1 | 256 | 0.109 | 1176.62 | 2.786 | 45.94 | 2.895 | 88.42 |
| 128 | 128 | 2 | 512 | 0.191 | 1342.10 | 3.106 | 82.41 | 3.297 | 155.28 |
| 128 | 128 | 4 | 1024 | 0.364 | 1406.48 | 4.080 | 125.48 | 4.444 | 230.40 |
| 256 | 128 | 1 | 384 | 0.190 | 1349.40 | 2.796 | 45.77 | 2.986 | 128.60 |
| 256 | 128 | 2 | 768 | 0.363 | 1412.04 | 3.198 | 80.06 | 3.560 | 215.71 |
| 256 | 128 | 4 | 1536 | 0.731 | 1401.63 | 4.093 | 125.09 | 4.824 | 318.43 |

With SSD

F16 cache w/ F32 conv

| PP | TG | B | N_KV | T_PP s | S_PP t/s | T_TG s | S_TG t/s | T s | S t/s |
|----|----|---|------|--------|----------|--------|----------|-----|-------|
| 128 | 128 | 1 | 256 | 0.134 | 957.38 | 2.852 | 44.89 | 2.985 | 85.76 |
| 128 | 128 | 2 | 512 | 0.238 | 1073.98 | 3.241 | 79.00 | 3.479 | 147.17 |
| 128 | 128 | 4 | 1024 | 0.449 | 1141.23 | 4.149 | 123.41 | 4.598 | 222.73 |
| 256 | 128 | 1 | 384 | 0.256 | 999.61 | 2.820 | 45.40 | 3.076 | 124.85 |
| 256 | 128 | 2 | 768 | 0.512 | 1000.77 | 3.223 | 79.43 | 3.735 | 205.64 |
| 256 | 128 | 4 | 1536 | 0.897 | 1141.60 | 4.130 | 123.97 | 5.027 | 305.54 |

F32 cache / F32 conv

| PP | TG | B | N_KV | T_PP s | S_PP t/s | T_TG s | S_TG t/s | T s | S t/s |
|----|----|---|------|--------|----------|--------|----------|-----|-------|
| 128 | 128 | 1 | 256 | 0.132 | 972.95 | 2.843 | 45.03 | 2.974 | 86.07 |
| 128 | 128 | 2 | 512 | 0.237 | 1079.11 | 3.174 | 80.65 | 3.411 | 150.09 |
| 128 | 128 | 4 | 1024 | 0.448 | 1141.80 | 4.211 | 121.58 | 4.660 | 219.76 |
| 256 | 128 | 1 | 384 | 0.255 | 1004.52 | 2.807 | 45.59 | 3.062 | 125.40 |
| 256 | 128 | 2 | 768 | 0.511 | 1001.48 | 3.202 | 79.94 | 3.714 | 206.81 |
| 256 | 128 | 4 | 1536 | 0.897 | 1141.03 | 4.201 | 121.87 | 5.099 | 301.26 |

F16 cache w/ BF16 conv

| PP | TG | B | N_KV | T_PP s | S_PP t/s | T_TG s | S_TG t/s | T s | S t/s |
|----|----|---|------|--------|----------|--------|----------|-----|-------|
| 128 | 128 | 1 | 256 | 0.130 | 985.88 | 2.840 | 45.07 | 2.970 | 86.21 |
| 128 | 128 | 2 | 512 | 0.236 | 1084.81 | 3.195 | 80.12 | 3.431 | 149.22 |
| 128 | 128 | 4 | 1024 | 0.443 | 1155.40 | 4.150 | 123.37 | 4.593 | 222.94 |
| 256 | 128 | 1 | 384 | 0.252 | 1015.24 | 2.820 | 45.39 | 3.072 | 125.00 |
| 256 | 128 | 2 | 768 | 0.507 | 1009.49 | 3.219 | 79.52 | 3.726 | 206.10 |
| 256 | 128 | 4 | 1536 | 0.891 | 1149.51 | 4.101 | 124.84 | 4.992 | 307.69 |

F16 cache w/ BF16 conv and SSD cast at end

| PP | TG | B | N_KV | T_PP s | S_PP t/s | T_TG s | S_TG t/s | T s | S t/s |
|----|----|---|------|--------|----------|--------|----------|-----|-------|
| 128 | 128 | 1 | 256 | 0.125 | 1027.02 | 2.810 | 45.56 | 2.934 | 87.25 |
| 128 | 128 | 2 | 512 | 0.228 | 1122.36 | 3.191 | 80.23 | 3.419 | 149.75 |
| 128 | 128 | 4 | 1024 | 0.492 | 1039.70 | 4.145 | 123.53 | 4.637 | 220.82 |
| 256 | 128 | 1 | 384 | 0.247 | 1037.08 | 2.814 | 45.49 | 3.061 | 125.46 |
| 256 | 128 | 2 | 768 | 0.502 | 1020.66 | 3.232 | 79.22 | 3.733 | 205.71 |
| 256 | 128 | 4 | 1536 | 0.981 | 1043.44 | 4.089 | 125.22 | 5.070 | 302.96 |

F16 cache w/ BF16 conv, SSD cast at end, and no sub-ubatch batching

| PP | TG | B | N_KV | T_PP s | S_PP t/s | T_TG s | S_TG t/s | T s | S t/s |
|----|----|---|------|--------|----------|--------|----------|-----|-------|
| 128 | 128 | 1 | 256 | 0.117 | 1093.26 | 2.754 | 46.48 | 2.871 | 89.17 |
| 128 | 128 | 2 | 512 | 0.211 | 1213.76 | 3.251 | 78.75 | 3.462 | 147.90 |
| 128 | 128 | 4 | 1024 | 0.412 | 1243.73 | 4.161 | 123.04 | 4.573 | 223.92 |
| 256 | 128 | 1 | 384 | 0.230 | 1113.91 | 2.806 | 45.61 | 3.036 | 126.48 |
| 256 | 128 | 2 | 768 | 0.467 | 1097.32 | 3.184 | 80.41 | 3.650 | 210.39 |
| 256 | 128 | 4 | 1536 | 0.817 | 1252.68 | 4.110 | 124.59 | 4.927 | 311.75 |

@ggerganov
Member

To chunk or not to chunk

Probably you have to use a large ubatch and do some chunking in order to get some benefit from the SSD. But I don't have a good estimate of what the optimal sizes would be.

At the default ubatch of 512, you can do the following experiment on master:

make -j && ./bin/llama-bench -m ../models/granite-4-h-tiny/ggml-model-q8_0.gguf -fa 1 -t 1 -p 2048 -ub 512 -n 0
| model | size | params | backend | threads | fa | test | t/s |
|-------|------|--------|---------|---------|----|------|-----|
| granitehybrid 1B Q8_0 | 6.88 GiB | 6.94 B | Metal,BLAS | 1 | 1 | pp2048 | 2116.14 ± 16.32 |
build: 230d116 (6962)

Now make the ssm scan a noop and run the test again:

diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal
index 424c400f2..7881e63e0 100644
--- a/ggml/src/ggml-metal/ggml-metal.metal
+++ b/ggml/src/ggml-metal/ggml-metal.metal
@@ -2129,6 +2129,7 @@ kernel void kernel_ssm_scan_f32(
         ushort  tiisg[[thread_index_in_simdgroup]],
         ushort  sgptg[[simdgroups_per_threadgroup]],
         uint3    tgpg[[threadgroups_per_grid]]) {
+    return;
     constexpr short NW = N_SIMDWIDTH;
 
     shared[tpitg.x] = 0.0f;
| model | size | params | backend | threads | fa | test | t/s |
|-------|------|--------|---------|---------|----|------|-----|
| granitehybrid 1B Q8_0 | 6.88 GiB | 6.94 B | Metal,BLAS | 1 | 1 | pp2048 | 2699.35 ± 1.87 |
build: 230d116 (6962)

This is the upper bound that you would get at this ubatch size, i.e. any SSD implementation will not be faster than this.

Increasing the ubatch size increases the gap, so it gives more room for a good SSD implementation to outperform the ssm scan.

In any case, the first step seems to be to reduce the number of ops, permutations, and conts in the SSD branch as much as possible.

@gabe-l-hart
Collaborator Author

Now make the ssm scan a noop and run the test again:

🤦 I feel really silly for not figuring this trick out. I've been snipping out chunks of the graph and trying to coerce the input/output tensors to the same shape to simulate this upper bound!

Probably you have to use large ubatch and do some chunking in order to get some benefits from the SSD. But I don't have a good estimate about what the optimal sizes would be.

This makes a lot of sense! I'll do some large-ubatch experiments to see if the current code is already at a cross-over point where SSD can start offering better performance with larger batches. The speed advantages are very much supposed to be felt primarily at longer context, which likely also means longer ubatches.

@gabe-l-hart
Collaborator Author

It looks like the current code is not there yet, and in fact it starts to degrade further when the ubatch size gets bigger (e.g. 1024), on my machine at least.

@ggerganov
Member

I guess it's expected since it does not have the chunking logic.

@pwilkin
Collaborator

pwilkin commented Nov 6, 2025

@gabe-l-hart you can look at the discussion in the Qwen3 Next thread, but basically, if you ever use the recurrent update logic (aka solve_triangular), you must chunk because, unlike almost all of the operations used so far, solve_triangular has O(n^3) complexity. That means that performance will degrade rapidly with bigger chunks. For the Qwen3 Next model, I was unable to compute a single chunk in reasonable time with ubatch size 512, while with ubatch size 64 (equal to the chunk size in the reference implementation) it was reasonably fast.
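Back-of-envelope for why the chunk size matters so much there (my arithmetic, not from the thread): splitting $T$ tokens into $T/C$ chunks and doing an $O(C^3)$ triangular solve per chunk costs

$$\frac{T}{C}\cdot O(C^3) = O(T\,C^2),$$

i.e. quadratic in the chunk size for a fixed ubatch, which is why $C \approx 64$ stays tractable while a single 512-token chunk does not.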
