Question about the parameter dt(delta) and its initialization #5

Open

zzzack66 opened this issue Nov 17, 2024 · 7 comments

@zzzack66

Thanks for your awesome work.
Looking through the code of Mamba and Mamba2, I'm really confused about the dimension of the parameter dt. I understand that delta is used to discretize A and B in the SSM. However, I don't understand why dt is first projected into (b, l, dt_rank) and then projected into (b, l, d_inner), as in Algorithm 2 of the Mamba paper. As in the code:

```python
self.x_proj = nn.Linear(
    self.d_inner, self.dt_rank + self.d_state * 2, bias=False, **factory_kwargs
)
```

What is the purpose of 'dt_rank'?

As for the initialization of dt, I found the code in your project:
```python
# Initialize log dt bias
dt = torch.exp(
    torch.rand(self.nheads, **factory_kwargs) * (math.log(dt_max) - math.log(dt_min))
    + math.log(dt_min)
)
dt = torch.clamp(dt, min=dt_init_floor)
# Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
inv_dt = dt + torch.log(-torch.expm1(-dt))
self.dt_bias = nn.Parameter(inv_dt)
# Just to be explicit. Without this, we already don't put wd on dt_bias because of the check
# name.endswith("bias") in param_grouping.py
self.dt_bias._no_weight_decay = True
```

I don't understand this initialization process. I found some explanation in the Mamba paper:

[screenshot from the Mamba paper]

I found a similar question here, but I still don't understand the time step in the discretization process. Could you help me with this, or point me to papers or tutorials to read? I'm looking forward to your reply.

@Hprairie
Owner

For the first question: the low rank can just be thought of as parameterizing the delta projection as a low-rank matrix, rather than a full linear layer. In Mamba2, delta is essentially a scalar for each head, meaning that if we let $i$ index the head, then for a given timestep $t$ we have $\Delta_{t,i} \in \mathbb{R}$. Since we are parameterizing the projection matrix in a low-rank fashion, if dt_rank=1 we first take our input $u_t$ and project it into a scalar $\hat \Delta_{t} = w_{\Delta}^\top u_t$, where $u_t, w_\Delta \in \mathbb{R}^n$. We then up-project it to the number of heads as $\Delta_{t,i} = \hat w_i \hat \Delta_{t}$. I hope that makes sense.

Having it be low rank is just parameter efficient. Rather than having a single matrix $w \in \mathbb{R}^{i \times n}$, where again $i$ is the number of heads and $n$ is the input dimension, we parameterize it by two low-rank factors $w = w_1 w_2$, where $w_1 \in \mathbb{R}^{i \times k}$ and $w_2 \in \mathbb{R}^{k \times n}$, with $k$ = dt_rank.
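To make that concrete, here is a rough sketch of the low-rank delta projection in PyTorch. The sizes and names (d_model, nheads, dt_rank) are made up for illustration and are not the exact ones used in the repo:

```python
import torch
import torch.nn as nn

# Illustrative sizes only
d_model, nheads, dt_rank = 64, 8, 1

# Low-rank parameterization of the delta projection:
# instead of one (nheads x d_model) matrix, use two small factors.
down_proj = nn.Linear(d_model, dt_rank, bias=False)  # u_t -> dt_hat (the w_2 factor)
up_proj = nn.Linear(dt_rank, nheads, bias=False)     # dt_hat -> one delta per head (the w_1 factor)

u = torch.randn(2, 16, d_model)   # (batch, length, d_model)
dt_hat = down_proj(u)             # (batch, length, dt_rank)
delta = up_proj(dt_hat)           # (batch, length, nheads): one scalar per head per timestep

# Parameter count: d_model*dt_rank + dt_rank*nheads, versus d_model*nheads for a full matrix.
print(delta.shape)
```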

The initialization of delta in the way above is just convenient for training dynamics. As $\Delta$ essentially controls the amount of information entering our hidden state, we want it initialized in a "good way", meaning that at the very beginning of training it is neither overriding the hidden state at every timestep nor adding no information to the hidden state at every timestep.
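Here is a small standalone sketch of what that initialization does, assuming softplus is applied to dt_bias in the forward pass; the dt_min/dt_max values below are just placeholders:

```python
import math
import torch
import torch.nn.functional as F

dt_min, dt_max, nheads = 1e-3, 1e-1, 8  # placeholder values

# Sample dt log-uniformly in [dt_min, dt_max]
dt = torch.exp(
    torch.rand(nheads) * (math.log(dt_max) - math.log(dt_min)) + math.log(dt_min)
)

# Store softplus^{-1}(dt) as the bias, so the forward-pass softplus gives back dt
inv_dt = dt + torch.log(-torch.expm1(-dt))

print(torch.allclose(F.softplus(inv_dt), dt, atol=1e-6))  # True (up to float error)
```

Sampling dt log-uniformly and storing its inverse softplus as the bias means the effective timesteps start spread across [dt_min, dt_max] rather than collapsed at either extreme.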

Finally, if you want to understand what discretization actually is, think back to calculus when you were learning about the derivative. We can use Euler's approximation as an example of discretization:

$$\begin{align*} \frac{x(t + \Delta) - x(t)}{\Delta} &= A x(t) + B u(t) \\ x(t + \Delta) &= \Delta A x(t) + \Delta B u(t) + x(t) \\ x_{t+1} &= (I + \Delta A) x_t + \Delta B u_t \\ x_{t+1} &= \bar A x_t + \bar B u_t \end{align*}$$

This isn't the discretization method used in Mamba; however, it is along the same lines and it paints a picture of what discretization is doing. At the end of the day these things are just weight matrices. $\Delta$ could be absorbed into the $A$ and $B$ parameters; however, there are some nice theoretical justifications for not doing this, i.e. gating. Since we have a closed-form solution for discretization, i.e. $x_{t+1} = (I + \Delta A) x_t + \Delta B u_t$, we can just plug in $\Delta$, $A$, and $B$ at every timestep to get our hidden state at the next timestep.
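Here is a tiny numerical sketch of that Euler-discretized recurrence; the shapes and values are made up purely for illustration:

```python
import torch

# Toy sizes, purely illustrative
d_state, d_in, seq_len = 4, 1, 10

A = -torch.rand(d_state, d_state)   # continuous-time state matrix
B = torch.randn(d_state, d_in)      # continuous-time input matrix
u = torch.randn(seq_len, d_in)      # input sequence
delta = torch.rand(seq_len) * 0.1   # step size per timestep (input-dependent in Mamba)

x = torch.zeros(d_state)
I = torch.eye(d_state)
for t in range(seq_len):
    A_bar = I + delta[t] * A        # Euler discretization: A_bar = I + delta * A
    B_bar = delta[t] * B            # B_bar = delta * B
    x = A_bar @ x + B_bar @ u[t]    # x_{t+1} = A_bar x_t + B_bar u_t
print(x)
```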

I hope this makes sense. There are a lot of fascinating little tricks in Mamba's parameterization, which IMO is the actual reason why these things work so well. If you want to get deeper into the math of parameterization, I would recommend reading Albert Gu's paper S4D, which focuses solely on the parameterization of SSMs.

@zzzack66
Author

Thanks for your reply. I will try to understand it and read S4D.

@oggyfaker

@Hprairie what about the NaN loss in some experiments issue? Do you have a plan to update it? I am really looking forward to your progress!!

@Hprairie
Owner

Hey, yes, it's on my bucket list! I'm currently swamped with grad apps, classes, and my research; however, I was planning on fixing the bugs during winter break. Sorry for being slow on this :(

Also might try to play around with ThunderKittens and create some optimized causal and bidirectional kernels!

@oggyfaker

Oh, it's really great to hear that!! Hope to see your update soon @Hprairie

@Hprairie
Owner

@oggyfaker I'm starting to take a look back into this and have been trying to reproduce the error, but I'm struggling. I wrote a small training pipeline and haven't had any NaNs appear. Do you have a small reproducible script that produces NaNs? I will keep trying to recreate them, but I would greatly appreciate it if you do. Thanks!

@oggyfaker

@Hprairie Actually, I intend to use your lib to replace the original Hydra block. But I saw some previous comments in another issue about applying it to a ViT, so I want to make sure your repo works before I start my project. If you are making progress on it, I will implement my project and give you feedback as soon as possible, maybe before the end of this year :D
