Mamba2 SSD #16982
base: master
Conversation
This reverts commit 00f115f.
* gg/metal-mul-mat-fixes: metal : fix mul-mm condition + fix mul-mv permuted kernels
Cherry-picked and edited from 7ec2df6. The original commit contained the DELTA_NET op as well, which I've removed in this cherry-picked version. Co-Authored-By: Piotr Wilkin <piotr.wilkin@syndatis.com> Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
…sors Branch: Mamba2Perf Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
This should be using SIMD operations for better parallelism, but that will come next. Branch: Mamba2SSD Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
* origin/master: (32 commits)
  metal : FA support F32 K and V and head size = 32 (ggml-org#16531)
  graph : support cacheless embeddings with FA and iSWA (ggml-org#16528)
  opencl: fix build targeting CL 2 (ggml-org#16554)
  CUDA: fix numerical issues in tile FA kernel (ggml-org#16540)
  ggml : fix build broken with -march=armv9-a on MacOS (ggml-org#16520)
  CANN: fix CPU memory leak in CANN backend (ggml-org#16549)
  fix: add remark plugin to render raw HTML as literal text (ggml-org#16505)
  metal: add support for opt_step_sgd (ggml-org#16539)
  ggml : fix scalar path for computing norm (ggml-org#16558)
  CANN: Update several operators to support FP16 data format (ggml-org#16251)
  metal : add opt_step_adamw and op_sum (ggml-org#16529)
  webui: remove client-side context pre-check and rely on backend for limits (ggml-org#16506)
  [SYCL] fix UT fault cases: count-equal, argsort, pad OPs (ggml-org#16521)
  ci : add Vulkan on Ubuntu with default packages build (ggml-org#16532)
  common : handle unicode during partial json parsing (ggml-org#16526)
  common : update presets (ggml-org#16504)
  ggml : Fix FP16 ELU positive branch (ggml-org#16519)
  hparams : add check for layer index in is_recurrent (ggml-org#16511)
  ggml: Correct SVE implementation in ggml_vec_dot_f16_unroll (ggml-org#16518)
  CUDA: faster tile FA, add oob checks, more HSs (ggml-org#16492)
  ...
* origin/master:
  Add server-driven parameter defaults and syncing (ggml-org#16515)
  metal: optimise `GGML_OP_SUM` (ggml-org#16559)
  server : fix img token logs (ggml-org#16595)
  llama-quant: add support for mmproj (ggml-org#16592)
  CUDA: Changing the CUDA scheduling strategy to spin (ggml-org#16585)
  server : fix mtmd checkpoints (ggml-org#16591)
  metal : avoid using Metal's gpuAddress property (ggml-org#16576)
  vulkan: Add ACC_TYPE_VEC2 implementation (ggml-org#16203)
  CUDA + openCL: fix bug in accessing rms_norm->src while doing fusion (ggml-org#16577)
  vulkan: Support FA with K/V in F32 (ggml-org#16543)
  vulkan: Improve build time for MSVC (ggml-org#16545)
  CUDA: enable FA for FP32 KV cache (ggml-org#16546)
  CUDA: use fastdiv + ggml_cuda_mad for mmvf (ggml-org#16557)
  CUDA: add fp kernel for larger batch size MoE (ggml-org#16512)
  cuda : remove legacy copy-op pointer indirection code (ggml-org#16485)
  server : dynamic token limit for prompt cache (ggml-org#16560)
Yeah, I had an issue with
Regarding the chunking: won't this explode the graph a lot? In the case of Delta Net attention, since you have to use a triangular solve there, you don't want the chunk size over 64 or performance drops drastically. But that means you're going to go up to 8 chunks for a typical ubatch size of 512. The graph for Qwen3 Next already has 9000 nodes. I'm a bit afraid of doing chunking this way (and I know @ggerganov had strong objections too).
Yep, it sure will. I also suspect this is one of the reasons it's currently slower. I don't think SSD has the same need for chunking based on computational complexity, so I think it's mostly there for memory overhead management.
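For a rough sense of the numbers in this exchange, here's a minimal arithmetic sketch of how chunking multiplies graph size; the per-chunk node count and layer count are illustrative assumptions, not values measured from this PR:

```cpp
#include <cstdio>

// Illustrative only: how many extra nodes per graph chunking could add.
// nodes_per_chunk and n_layers are hypothetical placeholders.
int main() {
    const int n_ubatch        = 512; // typical ubatch size
    const int chunk_size      = 64;  // kept small, e.g. for the triangular solve in Delta Net
    const int n_chunks        = (n_ubatch + chunk_size - 1) / chunk_size; // -> 8
    const int nodes_per_chunk = 20;  // hypothetical nodes emitted per chunk, per layer
    const int n_layers        = 48;  // hypothetical layer count

    printf("chunks per ubatch : %d\n", n_chunks);
    printf("extra graph nodes : %d\n", n_chunks * nodes_per_chunk * n_layers);
    return 0;
}
```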
This is experimental! Branch: Mamba2SSD Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
This is used to zero-out the state in build_rs, so it's required to support F16 cache states for recurrent models. The bias route does not get hit in that case, but would need to be implemented if used elsewhere. Branch: Mamba2SSD Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
This will be needed until F16 support is added for SSM_SCAN Branch: Mamba2SSD Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
Unlike Qwen3Next, we don't hit big complexity scaling issues here, so removing all of the batching gives a big reduction in complexity and a big boost to performance! Branch: Mamba2SSD Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
* origin/master: (21 commits)
  vulkan: Fix GGML_VULKAN_CHECK_RESULTS to better handle fusion (ggml-org#16919)
  examples(gguf): GGUF example outputs (ggml-org#17025)
  mtmd: allow QwenVL to process larger image by default (ggml-org#17020)
  server : do not default to multiple slots with speculative decoding (ggml-org#17017)
  mtmd: improve struct initialization (ggml-org#16981)
  docs: Clarify the endpoint that webui uses (ggml-org#17001)
  model : add openPangu-Embedded (ggml-org#16941)
  ggml webgpu: minor set rows optimization (ggml-org#16810)
  sync : ggml
  ggml : fix conv2d_dw SVE path (ggml/1380)
  CUDA: update ops.md (ggml-org#17005)
  opencl: update doc (ggml-org#17011)
  refactor: replace sprintf with snprintf for safer string handling in dump functions (ggml-org#16913)
  vulkan: remove the need for the dryrun (ggml-org#16826)
  server : do context shift only while generating (ggml-org#17000)
  readme : update hot topics (ggml-org#17002)
  ggml-cpu : bicubic interpolation (ggml-org#16891)
  ci : apply model label to models (ggml-org#16994)
  chore : fix models indent after refactor (ggml-org#16992)
  Fix garbled output with REPACK at high thread counts (ggml-org#16956)
  ...
I've been further experimenting with a few tweaks to get more performance out of this.
Local notes using variants of the following command:

```
./bin/llama-batched-bench -m ~/models/ibm-granite/granite-4.0-h-1b/granite-4.0-h-1B-BF16-exp.gguf -c 2048 -b 2048 -ub 512 -npp 128,256 -ntg 128 -npl 1,2,4 -ngl 99
```

Baseline `SSM_SCAN` (result tables omitted):
- F16 cache w/ F32 conv
- F32 cache / F32 conv
- F16 cache w/ BF16 conv

With SSD (result tables omitted):
- F16 cache w/ F32 conv
- F32 cache / F32 conv
- F16 cache w/ BF16 conv
- F16 cache w/ BF16 conv and SSD cast at end
- F16 cache w/ BF16 conv, SSD cast at end, and no sub-ubatch batching
Probably you have to use a large ubatch and do some chunking in order to get some benefits from the SSD. But I don't have a good estimate of what the optimal sizes would be. At the default ubatch of 512, you can do the following experiment:

```
make -j && ./bin/llama-bench -m ../models/granite-4-h-tiny/ggml-model-q8_0.gguf -fa 1 -t 1 -p 2048 -ub 512 -n 0
```

Now make the ssm scan a noop and run the test again:

```diff
diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal
index 424c400f2..7881e63e0 100644
--- a/ggml/src/ggml-metal/ggml-metal.metal
+++ b/ggml/src/ggml-metal/ggml-metal.metal
@@ -2129,6 +2129,7 @@ kernel void kernel_ssm_scan_f32(
ushort tiisg[[thread_index_in_simdgroup]],
ushort sgptg[[simdgroups_per_threadgroup]],
uint3 tgpg[[threadgroups_per_grid]]) {
+ return;
constexpr short NW = N_SIMDWIDTH;
shared[tpitg.x] = 0.0f;
```

This is the upper bound that you would get at this ubatch size, i.e. any SSD implementation will not be faster than this. Increasing the ubatch size increases the gap, so it gives more room for a good SSD implementation to outperform the ssm scan. In any case, the first step seems to be to reduce the number of ops, permutations, and conts in the SSD branch as much as possible.
🤦 I feel really silly for not figuring out this trick. I've been snipping out chunks of the graph and trying to coerce the input/output tensors to the same shape to simulate this upper bound!
This makes a lot of sense! I'll do some large-ubatch experiments to see if the current code may already be at a cross-over point where SSD can start offering better performance with larger batches. The speed advantages are very much supposed to be felt primarily at longer context, which likely also means longer ubatches.
It looks like the current code is not there yet and in fact starts to degrade further when the ubatch size is increased.
I guess it's expected since it does not have the chunking logic.
@gabe-l-hart you can look at the discussion in the Qwen3 Next thread, but basically, if you ever use the recurrent update logic aka
DRAFT STATUS

This PR will remain in Draft until the items in the discussion section are resolved.

Description
This PR is a draft implementation of the Structured State Space Duality (SSD) described in the original Mamba2 paper, which reframes the `SSM_SCAN` op as a pseudo-attention operation. The paper describes it in great detail, but the short version is that when performing a multi-token scan, the recurrent formulation of `SSM_SCAN` is inefficient because it cannot parallelize over the sequence dimension the way an attention calculation can. With the SSD formulation, the logical attention matrix is decomposed into chunks and the state is updated at the chunk boundaries, allowing prefill to "jump" by the size of the chunk rather than proceed with tokens one-at-a-time.
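For readers who haven't gone through the paper, a rough sketch of the idea (notation loosely follows the Mamba2 paper; this is not the exact tensor layout used in this PR):

```latex
% Per head/channel: the discretized recurrence and its unrolled "attention" form.
\begin{aligned}
  h_t &= \bar{A}_t\, h_{t-1} + \bar{B}_t\, x_t, \qquad y_t = C_t^{\top} h_t \\
  y_t &= \sum_{s \le t} \underbrace{C_t^{\top} \Big( \prod_{r=s+1}^{t} \bar{A}_r \Big) \bar{B}_s}_{M_{t,s}}\, x_s
\end{aligned}
```

The lower-triangular matrix M is what SSD chunks: the diagonal blocks are evaluated like masked attention (parallel over the tokens of a chunk), while the contribution of earlier chunks is folded into the state carried across chunk boundaries, which is what lets prefill advance a chunk at a time.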
Reference Links

- mlx-lm: https://github.com/ml-explore/mlx-lm/blob/main/mlx_lm/models/ssm.py

Changes
- Introduce new primitive operations in `ggml` (reference semantics for these are sketched below):
  - `ggml_cumsum` / `ggml_cumsum_0`: Perform a cumulative sum along a given dimension
  - `ggml_tri_dims` / `ggml_tri` / `ggml_tri_keep`: Apply a triangular mask to the given matrix
  - `ggml_softplus`: Perform the unary `softplus` operation
- Implement an alternate path through `llm_graph_context_mamba::build_mamba2_layer` when a multi-token update is detected, avoiding `SSM_SCAN` in favor of the chunked pseudo-attention formulation
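Since the list above only names the new primitives, here is a small plain-C++ sketch of the semantics each one computes; it deliberately avoids guessing at the actual ggml signatures added by this PR, which may differ:

```cpp
#include <cmath>
#include <cstdio>
#include <vector>

int main() {
    // cumsum: running sum along a dimension, e.g. [1, 2, 3, 4] -> [1, 3, 6, 10]
    std::vector<float> x = {1, 2, 3, 4}, cs(x.size());
    float acc = 0.0f;
    for (size_t i = 0; i < x.size(); ++i) cs[i] = (acc += x[i]);

    // tri: keep the lower triangle of a matrix, zero everything above the diagonal
    const int n = 3;
    std::vector<float> m(n * n, 1.0f);
    for (int r = 0; r < n; ++r)
        for (int c = 0; c < n; ++c)
            if (c > r) m[r * n + c] = 0.0f;

    // softplus: log(1 + exp(v)), a smooth, always-positive version of relu
    auto softplus = [](float v) { return std::log1p(std::exp(v)); };

    printf("cumsum last  = %.1f\n", cs.back());      // 10.0
    printf("tri(0,2)     = %.1f\n", m[0 * n + 2]);   // 0.0 (above the diagonal)
    printf("softplus(0)  = %.3f\n", softplus(0.0f)); // ~0.693
    return 0;
}
```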
Discussion

There are a number of outstanding discussion points on this work that need to be resolved before moving it forward:
- Performance: The SSD path is currently slower than plain `SSM_SCAN`, which roundly defeats the purpose of the change! I suspect that the performance issues are due to the number of `ggml_permute` / `ggml_cont` ops that are added to the graph, but could use assistance figuring out how to eliminate them or identifying other sources of slowness.
- `ubatch` chunking: This does not currently have sub-`ubatch` chunking implemented. I had it mostly working before the corresponding discussion on Qwen3Next. The inter-chunk update would be needed anyway, so I didn't strip it out, but it would be fairly trivial to do so and might offer some performance improvements.
- `repeat_interleave`: Similar to the issue that came up when initially implementing `NemotronH` support, I believe that `ggml_repeat` behaves differently than `mx.repeat`, resulting in incorrect results for models with `n_groups > 1` (tested with `NemotronH`); see the sketch after this list.
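To make the `repeat_interleave` concern concrete, here is a plain-C++ sketch of the two repeat semantics in question; it is illustrative only and does not claim which of the two `ggml_repeat` or `mx.repeat` actually implement:

```cpp
#include <cstdio>
#include <vector>

int main() {
    const std::vector<int> groups = {10, 20}; // two "groups" of B/C values
    const int repeats = 3;                    // e.g. heads per group

    std::vector<int> tiled, interleaved;
    for (int r = 0; r < repeats; ++r)         // tile-style repeat:
        for (int g : groups) tiled.push_back(g);                    // 10 20 10 20 10 20
    for (int g : groups)                      // interleave-style repeat:
        for (int r = 0; r < repeats; ++r) interleaved.push_back(g); // 10 10 10 20 20 20

    for (int v : tiled)       printf("%d ", v);
    printf("(tiled)\n");
    for (int v : interleaved) printf("%d ", v);
    printf("(interleaved)\n");
    return 0;
}
```

If the graph repeats the grouped B/C tensors with one semantics while the reference implementation uses the other, each head reads the wrong group's values, which would explain incorrect results only for models with more than one group.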
I've tested this locally with various members of the Granite 4 family and with nvidia/NVIDIA-Nemotron-Nano-9B-v2. For the Granite 4 models with `n_groups == 1`, I get nearly identical results to running with purely `SSM_SCAN`, but `NemotronH` still struggles due to `repeat_interleave` issues (see above). I'll flesh out more testing results once we've worked through some of the above issues.

cc @compilade since I know this has been on your TODO list since the original `mamba2` implementation.