I have a working context parallel implementation, forked from this repo, for the forward and backward passes. It required two modifications:

1. Padding each GPU's conv layer input chunk with the last `N_padding` tokens from the previous GPU, then discarding the output indices corresponding to the padding tokens (sketch below).
2. Transferring the final states in state passing point-to-point between GPUs, sequentially (see the second sketch further down).
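For concreteness, here is roughly what the halo padding for the conv layer looks like. This is a minimal sketch, not the fork's actual code: it assumes `(batch, seqlen, dim)` chunks, an initialized `torch.distributed` process group, and illustrative names like `conv_halo_exchange` and `n_padding`.

```python
import torch
import torch.distributed as dist

def conv_halo_exchange(x, n_padding, group=None):
    """Prepend the last `n_padding` tokens of the previous rank's chunk
    to this rank's chunk, so the causal conv sees its left context.

    x: (batch, seqlen, dim) local chunk (the layout is an assumption).
    """
    rank = dist.get_rank(group)
    world_size = dist.get_world_size(group)

    ops, recv_buf = [], None
    if rank < world_size - 1:
        # Send my trailing tokens to the next rank as its left halo.
        send_buf = x[:, -n_padding:].contiguous()
        ops.append(dist.P2POp(dist.isend, send_buf, rank + 1, group=group))
    if rank > 0:
        # Receive the previous rank's trailing tokens as my left halo.
        recv_buf = torch.empty(
            x.shape[0], n_padding, x.shape[2],
            dtype=x.dtype, device=x.device,
        )
        ops.append(dist.P2POp(dist.irecv, recv_buf, rank - 1, group=group))
    if ops:
        for req in dist.batch_isend_irecv(ops):
            req.wait()

    # Rank 0 has no left neighbor, so its chunk is unchanged.
    return x if recv_buf is None else torch.cat([recv_buf, x], dim=1)
```

After the conv, each rank with a left neighbor drops the outputs at the padded positions, e.g. `y = conv(x_padded)[:, n_padding:]`, so only outputs for its own tokens remain. Since this is a neighbor exchange rather than a chain, all ranks can do it concurrently.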
And then vice versa for the backward pass. I believe I've also worked out a way to do this without sequential point-to-point transfers.
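The sequential state passing I mean is the straightforward chain below. Again a minimal sketch, assuming `torch.distributed` and a placeholder `local_fn(x_local, initial_state) -> (y_local, final_state)` for the per-chunk scan; that signature is hypothetical, not the repo's API.

```python
import torch
import torch.distributed as dist

def sequential_state_passing(local_fn, x_local, state_shape, dtype, device,
                             group=None):
    """Chain the SSM state across ranks: wait for the previous rank's
    final state, scan the local chunk from it, forward the result."""
    rank = dist.get_rank(group)
    world_size = dist.get_world_size(group)

    if rank == 0:
        # The first chunk starts from the zero state.
        initial_state = torch.zeros(state_shape, dtype=dtype, device=device)
    else:
        # Block until the previous rank hands over its final state.
        initial_state = torch.empty(state_shape, dtype=dtype, device=device)
        dist.recv(initial_state, src=rank - 1, group=group)

    y_local, final_state = local_fn(x_local, initial_state)

    if rank < world_size - 1:
        dist.send(final_state.contiguous(), dst=rank + 1, group=group)
    return y_local, final_state
```

The backward pass runs the same chain in reverse: each rank receives the state gradient from rank + 1, backpropagates through its chunk, and sends the resulting gradient on to rank - 1.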
Would this be useful to contribute? If so, I'd like to know the best way to do so, since it requires modifying the core wrapper around the Mamba 2 Triton code.