@@ -67,25 +67,36 @@ def _mamba_chunk_scan_combined_fwd(x,
         D = D.contiguous()
     if initial_states is not None:
         assert initial_states.shape == (batch, nheads, headdim, dstate)
-    # # (batch, nchunks, chunk_size, chunk_size) or (batch, nchunks, nheads, chunk_size, chunk_size)
-    # dA_cumsum_tmp0, dt_tmp0 = _chunk_cumsum_fwd(dt[:, :147], A, chunk_size, dt_bias=dt_bias, dt_softplus=dt_softplus)
-    # dA_cumsum_tmp1, dt_tmp1 = _chunk_cumsum_fwd(dt[:, 147:], A, chunk_size, dt_bias=dt_bias, dt_softplus=dt_softplus)
-    # dA_cumsum_tmp2, dt_tmp2 = _chunk_cumsum_fwd(dt[:, 147:256], A, chunk_size, dt_bias=dt_bias, dt_softplus=dt_softplus)
+
+    # This function executes five sub-functions to compute mamba
+    # - a good resource is the blog https://goombalab.github.io/blog/2024/mamba2-part3-algorithm/
+    #   which has a minimal implementation to help understand the operations below
+    # - as explained in the blog, mamba is a special case of causal attention
+    # - the idea is to chunk the attention matrix and compute each
+    #   submatrix separately using different optimizations.
+    # - see the blog and paper for a visualization of the submatrices
+    #   which we refer to in the comments below
+
+    # 1. Compute chunked cumsum of A * dt
+    # - here dt may go through a softplus activation
     dA_cumsum, dt = _chunk_cumsum_fwd(dt,
                                       A,
                                       chunk_size,
                                       dt_bias=dt_bias,
                                       dt_softplus=dt_softplus,
                                       dt_limit=dt_limit)
+
+    # 2. Compute the state for each intra-chunk
+    # (right term of low-rank factorization of off-diagonal blocks; B terms)
     states = _chunk_state_fwd(B,
                               x,
                               dt,
                               dA_cumsum,
                               seq_idx=seq_idx,
                               states_in_fp32=True)
-    # states_tmp0 = _chunk_state_fwd(B[:, :147], x[:, :147], dt_tmp0, dA_cumsum_tmp0, states_in_fp32=True)
-    # states_tmp1 = _chunk_state_fwd(B[:, 147:], x[:, 147:], dt_tmp1, dA_cumsum_tmp1, states_in_fp32=True)
-    # states_tmp2 = _chunk_state_fwd(B[:, 147:256], x[:, 147:256], dt_tmp2, dA_cumsum_tmp2, states_in_fp32=True)
+
+    # 3. Compute the inter-chunk SSM recurrence; produces correct SSM states at chunk boundaries
+    # (middle term of factorization of off-diag blocks; A terms)
     states, final_states = _state_passing_fwd(
         rearrange(states, "... p n -> ... (p n)"),
         dA_cumsum[:, :, :, -1],
@@ -96,13 +107,16 @@ def _mamba_chunk_scan_combined_fwd(x,
         out_dtype=C.dtype)
     states, final_states = (rearrange(t, "... (p n) -> ... p n", n=dstate)
                             for t in [states, final_states])
-    # states_tmp0 = rearrange(_state_passing_fwd(rearrange(states_tmp0, "... p n -> ... (p n)"), dA_cumsum_tmp0[:, :, :, -1], chunk_size=chunk_size), "... (p n) -> ... p n", n=dstate)
-    # states_tmp1 = rearrange(_state_passing_fwd(rearrange(states_tmp1, "... p n -> ... (p n)"), dA_cumsum_tmp1[:, :, :, -1], chunk_size=chunk_size), "... (p n) -> ... p n", n=dstate)
+
+    # 4. Compute batched matrix multiply for C_j^T B_i terms
     CB = _bmm_chunk_fwd(C,
                         B,
                         chunk_size,
                         seq_idx=seq_idx,
                         output_dtype=torch.float32)
+
+    # 5. Scan and compute the diagonal blocks, taking into
+    #    account past causal states.
     out, out_x = _chunk_scan_fwd(CB,
                                  x,
                                  dt,
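
For reference, the chunked algorithm that the five kernels above implement can also be written in a few dozen lines of plain PyTorch. The sketch below is adapted from the minimal SSD implementation that accompanies the linked blog post; it is a readability aid only, and the helper names, shapes, and step numbering in its comments are an approximate mapping onto the Triton kernels in this diff, not the code in this file.

import torch
import torch.nn.functional as F
from einops import rearrange, repeat


def segsum(x):
    # Stable "segment sum": entry (i, j) holds sum(x[j+1 .. i]) on and below
    # the diagonal and -inf above it, so exp(segsum(x)) is a causal decay matrix.
    T = x.size(-1)
    x = repeat(x, "... d -> ... d e", e=T)
    mask = torch.tril(torch.ones(T, T, dtype=torch.bool, device=x.device), diagonal=-1)
    x = x.masked_fill(~mask, 0)
    x_segsum = torch.cumsum(x, dim=-2)
    mask = torch.tril(torch.ones(T, T, dtype=torch.bool, device=x.device), diagonal=0)
    return x_segsum.masked_fill(~mask, float("-inf"))


def ssd_minimal_discrete(X, A, B, C, block_len, initial_states=None):
    # X: (batch, seqlen, nheads, headdim), A: (batch, seqlen, nheads),
    # B, C: (batch, seqlen, nheads, dstate); A is assumed to already be A * dt.
    assert X.shape[1] % block_len == 0

    # Split the sequence into chunks of length block_len.
    X, A, B, C = (rearrange(t, "b (c l) ... -> b c l ...", l=block_len)
                  for t in (X, A, B, C))
    A = rearrange(A, "b c l h -> b h c l")

    # Roughly step 1 above: chunked cumulative sum of A * dt.
    A_cumsum = torch.cumsum(A, dim=-1)

    # Diagonal (intra-chunk) output blocks; the chunk-local C^T B products in
    # this einsum are roughly what step 4 above materializes and step 5 consumes.
    L = torch.exp(segsum(A))
    Y_diag = torch.einsum("bclhn,bcshn,bhcls,bcshp->bclhp", C, B, L, X)

    # Roughly step 2 above: per-chunk states (B terms of the off-diagonal blocks).
    decay_states = torch.exp(A_cumsum[:, :, :, -1:] - A_cumsum)
    states = torch.einsum("bclhn,bhcl,bclhp->bchpn", B, decay_states, X)

    # Roughly step 3 above: inter-chunk recurrence producing the correct state
    # at every chunk boundary (A terms of the off-diagonal blocks).
    if initial_states is None:
        initial_states = torch.zeros_like(states[:, :1])
    states = torch.cat([initial_states, states], dim=1)
    decay_chunk = torch.exp(segsum(F.pad(A_cumsum[:, :, :, -1], (1, 0))))
    new_states = torch.einsum("bhzc,bchpn->bzhpn", decay_chunk, states)
    states, final_state = new_states[:, :-1], new_states[:, -1]

    # Off-diagonal contribution: past chunk states read out through C
    # (folded into step 5 above).
    state_decay_out = torch.exp(A_cumsum)
    Y_off = torch.einsum("bclhn,bchpn,bhcl->bclhp", C, states, state_decay_out)

    Y = rearrange(Y_diag + Y_off, "b c l h p -> b (c l) h p")
    return Y, final_state

The chunking is what makes the layout efficient: the quadratic, attention-like work (L and the C^T B products) is confined to chunk_size x chunk_size blocks, while information crosses chunk boundaries only through the small per-chunk states carried by the linear recurrence in step 3.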