At line 285, the code sets dx = thread_reverse_data[i].y, but according to my calculation of the gradient with respect to x it should be dx = dout * C, which also seems to be what the code computes earlier.
dx in the code might actually be the gradient with respect to the hidden states, so dh may be a better variable name.
At some point in the process of writing up the paper we changed the notation.
Sorry, I was not clear. I understand that dx is the gradient with respect to the hidden states. My question is whether thread_reverse_data[i].y still contains dy * C after the reverse scan op. It is initialized to dy * C in these lines:
mamba/csrc/selective_scan/selective_scan_bwd_kernel.cuh, line 285 (commit bc84fb1)
But according to my calculation of the gradients, it should be dx = dout * C. That also seems to be the case, because this is what the file computes at lines 260-263:
https://github.com/state-spaces/mamba/blob/bc84fb1172e6dea04a7dc402118ed19985349e95/csrc/selective_scan/selective_scan_bwd_kernel.cuh#L260C8-L264C8
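The chain-rule step this question turns on can be written out explicitly. The derivation below is a sketch for a single channel and a single state dimension, which is an assumption made for readability. In the kernel, each state dimension has its own recurrence, and y sums over them. Here \bar{A}_t and \bar{B}_t denote the discretized parameters, \mathcal{L} is the loss, the convention \partial\mathcal{L}/\partial h_{L+1} = 0 holds past the last step, and an empty product equals 1.

```latex
\begin{aligned}
h_t &= \bar{A}_t\, h_{t-1} + \bar{B}_t\, x_t, \qquad y_t = C_t\, h_t, \qquad t = 1,\dots,L,\\
\frac{\partial \mathcal{L}}{\partial h_t}
  &= C_t\,\frac{\partial \mathcal{L}}{\partial y_t} + \bar{A}_{t+1}\,\frac{\partial \mathcal{L}}{\partial h_{t+1}}
   = \sum_{s=t}^{L} \Big(\prod_{k=t+1}^{s} \bar{A}_k\Big)\, C_s\,\frac{\partial \mathcal{L}}{\partial y_s}.
\end{aligned}
```

Under these assumptions, dout * C is only the s = t term, i.e. the initial value of the recurrence. It equals the full hidden-state gradient only at the last step, t = L. A reverse scan over the pairs (\bar{A}_{t+1}, C_t \partial\mathcal{L}/\partial y_t) accumulates the remaining terms. The following is an inference from the initialization described in this thread and has not been verified against the kernel: if thread_reverse_data is laid out that way, its .y field after the reverse scan would hold this accumulated sum rather than dy * C alone.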
But a reverse_scan operation is applied to thread_reverse_data after that initialization, so does thread_reverse_data still contain dy * C after the reverse scan op?
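The same question can also be checked numerically with a plain-Python sketch rather than CUDA. It rests on several assumptions:

- It models a single channel and a single state dimension.
- The helper names (combine, reverse_scan_grad_h, loss_with_bump, max_fd_error) are hypothetical.
- The pair layout (next-step decay, dy * C) is modeled on the .x/.y fields discussed in this thread but is not copied from selective_scan_bwd_kernel.cuh.
- The kernel's chunking and cross-block carry are ignored.

The scan runs sequentially. Because combine is associative, a parallel block-wide reverse scan would produce the same values.

```python
# Hypothetical sketch: scalar (one channel, one state) model of the backward
# reverse scan. Not the Mamba kernel's code; names here are illustrative.
import math
import random


def combine(earlier, later):
    """Associative operator for g_t = a * g_{t+1} + b.

    `earlier` is the (a, b) pair for one step; `later` is the already-reduced suffix.
    """
    a1, b1 = earlier
    a2, b2 = later
    return (a1 * a2, a1 * b2 + b1)


def reverse_scan_grad_h(a_bar, c, dy):
    """Return dL/dh_t for every t using a sequential reverse inclusive scan.

    The second slot of element t starts as dy[t] * c[t] (the "dy * C" value);
    the first slot holds the decay of the *next* step, a_bar[t + 1].
    """
    n = len(dy)
    elems = [(a_bar[t + 1] if t + 1 < n else 1.0, dy[t] * c[t]) for t in range(n)]
    out = [0.0] * n
    suffix = (1.0, 0.0)  # identity of `combine`; dL/dh past the last step is 0
    for t in reversed(range(n)):
        suffix = combine(elems[t], suffix)
        out[t] = suffix[1]  # accumulated gradient, not just dy[t] * c[t]
    return out


def loss_with_bump(a_bar, b_bar, c, x, dy, k, eps):
    """L = sum_t dy[t] * y[t], with h_k shifted by eps right after it is computed."""
    h, total = 0.0, 0.0
    for t in range(len(x)):
        h = a_bar[t] * h + b_bar[t] * x[t]
        if t == k:
            h += eps
        total += dy[t] * c[t] * h  # y_t = c_t * h_t, so dL/dy_t = dy[t]
    return total


def max_fd_error(n=8, seed=0, eps=1e-6):
    """Largest gap between the scan result and central finite differences in h_t."""
    rng = random.Random(seed)
    a_bar = [math.exp(-rng.uniform(0.1, 1.0)) for _ in range(n)]  # decay in (0, 1)
    b_bar = [rng.uniform(0.1, 1.0) for _ in range(n)]
    c = [rng.uniform(-1.0, 1.0) for _ in range(n)]
    x = [rng.uniform(-1.0, 1.0) for _ in range(n)]
    dy = [rng.uniform(-1.0, 1.0) for _ in range(n)]
    dh = reverse_scan_grad_h(a_bar, c, dy)
    errs = []
    for k in range(n):
        fd = (loss_with_bump(a_bar, b_bar, c, x, dy, k, eps)
              - loss_with_bump(a_bar, b_bar, c, x, dy, k, -eps)) / (2 * eps)
        errs.append(abs(fd - dh[k]))
    return max(errs)


if __name__ == "__main__":
    print(reverse_scan_grad_h([0.5, 0.5, 0.5], [1.0] * 3, [1.0] * 3))  # [1.75, 1.5, 1.0]
    print(max_fd_error() < 1e-6)  # True
```

In the toy input, dy * C is 1.0 at every step, yet only the last output equals 1.0. Earlier entries also carry the decayed contributions of later steps, and the finite-difference check confirms that these values are dL/dh_t. Under the stated assumptions, the slot that starts as dy * C holds the full hidden-state gradient once the scan finishes.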
Thank you very much for your help.