forked from state-spaces/mamba
-
Notifications
You must be signed in to change notification settings - Fork 0
/
selective_scan_interface.py
357 lines (336 loc) · 16.4 KB
/
selective_scan_interface.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
# Copyright (c) 2023, Tri Dao, Albert Gu.
import torch
import torch.nn.functional as F
from mamba_ssm.utils.torch import custom_bwd, custom_fwd
from einops import rearrange, repeat
try:
from causal_conv1d import causal_conv1d_fn
import causal_conv1d_cuda
except ImportError:
causal_conv1d_fn = None
causal_conv1d_cuda = None
import selective_scan_cuda
class SelectiveScanFn(torch.autograd.Function):
    """Autograd wrapper around the fused ``selective_scan_cuda`` forward/backward kernels.

    ``forward`` returns the kernel output (or, when ``z`` is given, the kernel's z-gated
    output), optionally paired with the final scan state of shape (batch, dim, dstate).
    Any gradient arriving for that last state is ignored by ``backward``.
    """

    @staticmethod
    def forward(ctx, u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False,
                return_last_state=False):
        # The kernel is handed tensors whose innermost dimension is unit-stride.
        u = u if u.stride(-1) == 1 else u.contiguous()
        delta = delta if delta.stride(-1) == 1 else delta.contiguous()
        if D is not None:
            D = D.contiguous()
        B = B if B.stride(-1) == 1 else B.contiguous()
        C = C if C.stride(-1) == 1 else C.contiguous()
        if z is not None and z.stride(-1) != 1:
            z = z.contiguous()
        # A 3-D B/C (b, dstate, l) is lifted to the 4-D grouped layout with a single group;
        # the ctx flag tells backward to squeeze that axis out of the gradient again.
        if B.dim() == 3:
            B = rearrange(B, "b dstate l -> b 1 dstate l")
            ctx.squeeze_B = True
        if C.dim() == 3:
            C = rearrange(C, "b dstate l -> b 1 dstate l")
            ctx.squeeze_C = True
        out, x, *rest = selective_scan_cuda.fwd(u, delta, A, B, C, D, z, delta_bias, delta_softplus)
        ctx.delta_softplus = delta_softplus
        ctx.has_z = z is not None
        last_state = x[:, :, -1, 1::2]  # (batch, dim, dstate)
        if ctx.has_z:
            ctx.save_for_backward(u, delta, A, B, C, D, z, delta_bias, x, out)
            result = rest[0]
        else:
            ctx.save_for_backward(u, delta, A, B, C, D, delta_bias, x)
            result = out
        return (result, last_state) if return_last_state else result

    @staticmethod
    def backward(ctx, dout, *args):
        # Gradients for the optional last_state output (in *args) are not propagated.
        saved = ctx.saved_tensors
        if ctx.has_z:
            u, delta, A, B, C, D, z, delta_bias, x, out = saved
        else:
            u, delta, A, B, C, D, delta_bias, x = saved
            z, out = None, None
        dout = dout if dout.stride(-1) == 1 else dout.contiguous()
        # The kernel accepts a pre-allocated dz buffer (useful when fusing with the backward
        # of a preceding chunk); passing None lets the C++ side allocate it. The trailing
        # False turns off the kernel's optional out_z recomputation, which is unused here.
        du, ddelta, dA, dB, dC, dD, ddelta_bias, *extra = selective_scan_cuda.bwd(
            u, delta, A, B, C, D, z, delta_bias, dout, x, out, None, ctx.delta_softplus,
            False
        )
        dz = extra[0] if ctx.has_z else None
        if getattr(ctx, "squeeze_B", False):
            dB = dB.squeeze(1)
        if getattr(ctx, "squeeze_C", False):
            dC = dC.squeeze(1)
        if D is None:
            dD = None
        if delta_bias is None:
            ddelta_bias = None
        # One entry per forward input; delta_softplus and return_last_state get None.
        return du, ddelta, dA, dB, dC, dD, dz, ddelta_bias, None, None
def selective_scan_fn(u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False,
                      return_last_state=False):
    """Run the fused CUDA selective scan through ``SelectiveScanFn`` (autograd-enabled).

    Returns ``out``, or ``(out, last_state)`` when ``return_last_state`` is True, where
    ``last_state`` has shape (batch, dim, dstate). The gradient of ``last_state`` is not
    taken into account in the backward pass.
    """
    scan_args = (u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state)
    return SelectiveScanFn.apply(*scan_args)
def selective_scan_ref(u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False,
                      return_last_state=False):
    """Pure-PyTorch reference selective scan, evaluated with an explicit loop over time.

    Shapes (r = real, c = complex; B = batch, D = dim, N = dstate, L = seqlen, G = groups):
        u: r(B D L)
        delta: r(B D L)
        A: c(D N) or r(D N)
        B: c(D N) or r(B N L) or r(B N 2L) or r(B G N L) or (B G N L)
        C: c(D N) or r(B N L) or r(B N 2L) or r(B G N L) or (B G N L)
        D: r(D)
        z: r(B D L)
        delta_bias: r(D), fp32

    Returns:
        out: r(B D L), cast back to the dtype of ``u``; or ``(out, last_state)`` when
        ``return_last_state`` is True, with last_state: r(B D dstate) or c(B D dstate).
    """
    in_dtype = u.dtype
    u = u.float()
    delta = delta.float()
    if delta_bias is not None:
        delta = delta + delta_bias[..., None].float()
    if delta_softplus:
        delta = F.softplus(delta)
    n_batch = u.shape[0]
    n_dim, n_state = A.shape[0], A.shape[1]
    seqlen = u.shape[2]
    B_varies = B.dim() >= 3
    C_varies = C.dim() >= 3
    if not A.is_complex():
        B = B.float()
        C = C.float()
    else:
        # Input-dependent complex B/C arrive as interleaved (real, imag) pairs on the last axis.
        if B_varies:
            B = torch.view_as_complex(rearrange(B.float(), "... (L two) -> ... L two", two=2))
        if C_varies:
            C = torch.view_as_complex(rearrange(C.float(), "... (L two) -> ... L two", two=2))
    # exp(delta * A) for every (batch, dim, step, state).
    deltaA = torch.exp(torch.einsum('bdl,dn->bdln', delta, A))
    if not B_varies:
        deltaB_u = torch.einsum('bdl,dn,bdl->bdln', delta, B, u)
    elif B.dim() == 3:
        deltaB_u = torch.einsum('bdl,bnl,bdl->bdln', delta, B, u)
    else:
        # Share each of the G groups of B across dim // G consecutive channels.
        B = repeat(B, "B G N L -> B (G H) N L", H=n_dim // B.shape[1])
        deltaB_u = torch.einsum('bdl,bdnl,bdl->bdln', delta, B, u)
    if C_varies and C.dim() == 4:
        C = repeat(C, "B G N L -> B (G H) N L", H=n_dim // C.shape[1])
    state = A.new_zeros((n_batch, n_dim, n_state))
    outputs = []
    for t in range(seqlen):
        state = deltaA[:, :, t] * state + deltaB_u[:, :, t]
        if not C_varies:
            y_t = torch.einsum('bdn,dn->bd', state, C)
        elif C.dim() == 3:
            y_t = torch.einsum('bdn,bn->bd', state, C[:, :, t])
        else:
            y_t = torch.einsum('bdn,bdn->bd', state, C[:, :, :, t])
        if y_t.is_complex():
            # Complex outputs are reported as 2 * Re(y) (presumably accounting for the
            # conjugate partner of each state — confirm against the kernel).
            y_t = y_t.real * 2
        outputs.append(y_t)
    last_state = state
    y = torch.stack(outputs, dim=2)  # (batch, dim, L)
    out = y if D is None else y + u * rearrange(D, "d -> d 1")
    if z is not None:
        out = out * F.silu(z)
    out = out.to(dtype=in_dtype)
    if return_last_state:
        return out, last_state
    return out
class MambaInnerFn(torch.autograd.Function):
    """Fused Mamba inner block as one autograd Function.

    Pipeline: causal conv1d on the ``x`` half of ``xz`` -> ``x_proj`` -> ``delta_proj`` ->
    ``selective_scan_cuda`` (gated by the ``z`` half) -> ``out_proj``. Intermediates are
    laid out so the scan kernel receives delta as (batch, dim, seqlen) without transposes.
    """

    @staticmethod
    @custom_fwd
    def forward(ctx, xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
                out_proj_weight, out_proj_bias,
                A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None,
                C_proj_bias=None, delta_softplus=True, checkpoint_lvl=1):
        """
        xz: (batch, dim, seqlen); split in half along dim into the conv input x and gate z.
        conv1d_weight: (d, 1, width); rearranged to (d, width) for the conv kernel.
        B / C: None means input-dependent: sliced per token from the x_proj output and
            optionally shifted by B_proj_bias / C_proj_bias. Otherwise used as given.
        checkpoint_lvl: 0 saves conv1d_out and delta for backward; 1 recomputes them there.

        Returns:
            (batch, seqlen, out_features): out_proj applied to the gated scan output.
        """
        assert causal_conv1d_cuda is not None, "causal_conv1d_cuda is not available. Please install causal-conv1d."
        assert checkpoint_lvl in [0, 1]
        L = xz.shape[-1]
        delta_rank = delta_proj_weight.shape[1]
        # With complex A, each state occupies a (real, imag) pair in the B/C slices.
        d_state = A.shape[-1] * (1 if not A.is_complex() else 2)
        if torch.is_autocast_enabled():
            # Cast the projection weights to the autocast dtype by hand; the raw kernels
            # called below are presumably not covered by autocast.
            x_proj_weight = x_proj_weight.to(dtype=torch.get_autocast_gpu_dtype())
            delta_proj_weight = delta_proj_weight.to(dtype=torch.get_autocast_gpu_dtype())
            out_proj_weight = out_proj_weight.to(dtype=torch.get_autocast_gpu_dtype())
            out_proj_bias = (out_proj_bias.to(dtype=torch.get_autocast_gpu_dtype())
                             if out_proj_bias is not None else None)
        if xz.stride(-1) != 1:
            xz = xz.contiguous()
        conv1d_weight = rearrange(conv1d_weight, "d 1 w -> d w")
        x, z = xz.chunk(2, dim=1)
        conv1d_bias = conv1d_bias.contiguous() if conv1d_bias is not None else None
        conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd(
            x, conv1d_weight, conv1d_bias, None, None, None, True
        )
        # We're being very careful here about the layout, to avoid extra transposes.
        # We want delta to have d as the slowest moving dimension
        # and L as the fastest moving dimension, since those are what the ssm_scan kernel expects.
        x_dbl = F.linear(rearrange(conv1d_out, 'b d l -> (b l) d'), x_proj_weight)  # (bl d)
        delta = rearrange(delta_proj_weight @ x_dbl[:, :delta_rank].t(), "d (b l) -> b d l", l=L)
        ctx.is_variable_B = B is None
        ctx.is_variable_C = C is None
        ctx.B_proj_bias_is_None = B_proj_bias is None
        ctx.C_proj_bias_is_None = C_proj_bias is None
        if B is None:  # variable B
            B = x_dbl[:, delta_rank:delta_rank + d_state]  # (bl dstate)
            if B_proj_bias is not None:
                B = B + B_proj_bias.to(dtype=B.dtype)
            if not A.is_complex():
                B = rearrange(B, "(b l) dstate -> b 1 dstate l", l=L).contiguous()
            else:
                B = rearrange(B, "(b l) (dstate two) -> b 1 dstate (l two)", l=L, two=2).contiguous()
        else:
            if B.stride(-1) != 1:
                B = B.contiguous()
        if C is None:  # variable C
            C = x_dbl[:, -d_state:]  # (bl dstate)
            if C_proj_bias is not None:
                C = C + C_proj_bias.to(dtype=C.dtype)
            if not A.is_complex():
                C = rearrange(C, "(b l) dstate -> b 1 dstate l", l=L).contiguous()
            else:
                C = rearrange(C, "(b l) (dstate two) -> b 1 dstate (l two)", l=L, two=2).contiguous()
        else:
            if C.stride(-1) != 1:
                C = C.contiguous()
        if D is not None:
            D = D.contiguous()
        out, scan_intermediates, out_z = selective_scan_cuda.fwd(
            conv1d_out, delta, A, B, C, D, z, delta_bias, delta_softplus
        )
        ctx.delta_softplus = delta_softplus
        ctx.out_proj_bias_is_None = out_proj_bias is None
        ctx.checkpoint_lvl = checkpoint_lvl
        if checkpoint_lvl >= 1:  # Will recompute conv1d_out and delta in the backward pass
            conv1d_out, delta = None, None
        ctx.save_for_backward(xz, conv1d_weight, conv1d_bias, x_dbl, x_proj_weight,
                              delta_proj_weight, out_proj_weight, conv1d_out, delta,
                              A, B, C, D, delta_bias, scan_intermediates, out)
        return F.linear(rearrange(out_z, "b d l -> b l d"), out_proj_weight, out_proj_bias)

    @staticmethod
    @custom_bwd
    def backward(ctx, dout):
        """Return one gradient per forward input (16 entries).

        dout: (batch, seqlen, dim). Trailing None entries cover delta_softplus and
        checkpoint_lvl; when apply() was called without checkpoint_lvl, PyTorch drops the
        surplus trailing None.
        """
        assert causal_conv1d_cuda is not None, "causal_conv1d_cuda is not available. Please install causal-conv1d."
        (xz, conv1d_weight, conv1d_bias, x_dbl, x_proj_weight, delta_proj_weight, out_proj_weight,
         conv1d_out, delta, A, B, C, D, delta_bias, scan_intermediates, out) = ctx.saved_tensors
        L = xz.shape[-1]
        delta_rank = delta_proj_weight.shape[1]
        d_state = A.shape[-1] * (1 if not A.is_complex() else 2)
        x, z = xz.chunk(2, dim=1)
        if dout.stride(-1) != 1:
            dout = dout.contiguous()
        if ctx.checkpoint_lvl == 1:
            # conv1d_out and delta were not saved; recompute them exactly as in forward.
            conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd(
                x, conv1d_weight, conv1d_bias, None, None, None, True
            )
            delta = rearrange(delta_proj_weight @ x_dbl[:, :delta_rank].t(),
                              "d (b l) -> b d l", l=L)
        # The kernel supports passing in a pre-allocated dz (e.g., in case we want to fuse the
        # backward of selective_scan_cuda with the backward of chunk).
        dxz = torch.empty_like(xz)  # (batch, dim, seqlen)
        dx, dz = dxz.chunk(2, dim=1)
        dout = rearrange(dout, "b l e -> e (b l)")
        dout_y = rearrange(out_proj_weight.t() @ dout, "d (b l) -> b d l", l=L)
        dconv1d_out, ddelta, dA, dB, dC, dD, ddelta_bias, dz, out_z = selective_scan_cuda.bwd(
            conv1d_out, delta, A, B, C, D, z, delta_bias, dout_y, scan_intermediates, out, dz,
            ctx.delta_softplus,
            True  # option to recompute out_z
        )
        dout_proj_weight = torch.einsum("eB,dB->ed", dout, rearrange(out_z, "b d l -> d (b l)"))
        # dout is (e, b*l) here, so the bias gradient sums over the token axis only -> (e,).
        dout_proj_bias = dout.sum(dim=1) if not ctx.out_proj_bias_is_None else None
        dD = dD if D is not None else None
        dx_dbl = torch.empty_like(x_dbl)
        dB_proj_bias = None
        if ctx.is_variable_B:
            if not A.is_complex():
                dB = rearrange(dB, "b 1 dstate l -> (b l) dstate").contiguous()
            else:
                dB = rearrange(dB, "b 1 dstate (l two) -> (b l) (dstate two)", two=2).contiguous()
            dB_proj_bias = dB.sum(0) if not ctx.B_proj_bias_is_None else None
            dx_dbl[:, delta_rank:delta_rank + d_state] = dB  # (bl d)
            dB = None  # B was derived from x_dbl, not passed in; no gradient for the input
        dC_proj_bias = None
        if ctx.is_variable_C:
            if not A.is_complex():
                dC = rearrange(dC, "b 1 dstate l -> (b l) dstate").contiguous()
            else:
                dC = rearrange(dC, "b 1 dstate (l two) -> (b l) (dstate two)", two=2).contiguous()
            dC_proj_bias = dC.sum(0) if not ctx.C_proj_bias_is_None else None
            dx_dbl[:, -d_state:] = dC  # (bl d)
            dC = None  # C was derived from x_dbl, not passed in; no gradient for the input
        ddelta = rearrange(ddelta, "b d l -> d (b l)")
        ddelta_proj_weight = torch.einsum("dB,Br->dr", ddelta, x_dbl[:, :delta_rank])
        dx_dbl[:, :delta_rank] = torch.einsum("dB,dr->Br", ddelta, delta_proj_weight)
        dconv1d_out = rearrange(dconv1d_out, "b d l -> d (b l)")
        dx_proj_weight = torch.einsum("Br,Bd->rd", dx_dbl, rearrange(conv1d_out, "b d l -> (b l) d"))
        # Accumulate the x_proj contribution into dconv1d_out in place.
        dconv1d_out = torch.addmm(dconv1d_out, x_proj_weight.t(), dx_dbl.t(), out=dconv1d_out)
        dconv1d_out = rearrange(dconv1d_out, "d (b l) -> b d l", b=x.shape[0], l=x.shape[-1])
        # The kernel supports passing in a pre-allocated dx (e.g., in case we want to fuse the
        # backward of conv1d with the backward of chunk).
        dx, dconv1d_weight, dconv1d_bias, *_ = causal_conv1d_cuda.causal_conv1d_bwd(
            x, conv1d_weight, conv1d_bias, dconv1d_out, None, None, None, dx, False, True
        )
        dconv1d_bias = dconv1d_bias if conv1d_bias is not None else None
        dconv1d_weight = rearrange(dconv1d_weight, "d w -> d 1 w")
        return (dxz, dconv1d_weight, dconv1d_bias, dx_proj_weight, ddelta_proj_weight,
                dout_proj_weight, dout_proj_bias,
                dA, dB, dC, dD,
                ddelta_bias if delta_bias is not None else None,
                dB_proj_bias, dC_proj_bias,
                None,  # delta_softplus
                None)  # checkpoint_lvl
def mamba_inner_fn(
    xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
    out_proj_weight, out_proj_bias,
    A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None,
    C_proj_bias=None, delta_softplus=True
):
    """Functional entry point for the fused Mamba inner block (``MambaInnerFn``).

    ``checkpoint_lvl`` is not forwarded, so ``MambaInnerFn.forward``'s default (1) applies.
    Returns the block output of shape (batch, seqlen, out_features).
    """
    conv_params = (conv1d_weight, conv1d_bias)
    proj_params = (x_proj_weight, delta_proj_weight, out_proj_weight, out_proj_bias)
    ssm_params = (A, B, C, D, delta_bias, B_proj_bias, C_proj_bias, delta_softplus)
    return MambaInnerFn.apply(xz, *conv_params, *proj_params, *ssm_params)
def mamba_inner_ref(
    xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
    out_proj_weight, out_proj_bias,
    A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None,
    C_proj_bias=None, delta_softplus=True
):
    """Unfused composition of the Mamba inner block, for checking ``mamba_inner_fn``.

    Runs ``causal_conv1d_fn`` (SiLU activation), the x/delta projections,
    ``selective_scan_fn`` and the output projection as separate ops.

    Args:
        xz: (batch, dim, seqlen); split in half along dim into the conv input x and gate z.
        conv1d_weight: (d, 1, width); rearranged to (d, width) for ``causal_conv1d_fn``.
        B / C: None means input-dependent: sliced per token from the x_proj output and
            optionally shifted by B_proj_bias / C_proj_bias. Otherwise passed through.
        delta_softplus: forwarded to ``selective_scan_fn`` (previously ignored and
            always treated as True).

    Returns:
        (batch, seqlen, out_features): out_proj applied to the gated scan output.
    """
    assert causal_conv1d_fn is not None, "causal_conv1d_fn is not available. Please install causal-conv1d."
    L = xz.shape[-1]
    delta_rank = delta_proj_weight.shape[1]
    # With complex A, each state occupies a (real, imag) pair in the B/C slices.
    d_state = A.shape[-1] * (1 if not A.is_complex() else 2)
    x, z = xz.chunk(2, dim=1)
    x = causal_conv1d_fn(x, rearrange(conv1d_weight, "d 1 w -> d w"), conv1d_bias, activation="silu")
    # We're being very careful here about the layout, to avoid extra transposes.
    # We want delta to have d as the slowest moving dimension
    # and L as the fastest moving dimension, since those are what the ssm_scan kernel expects.
    x_dbl = F.linear(rearrange(x, 'b d l -> (b l) d'), x_proj_weight)  # (bl d)
    delta = delta_proj_weight @ x_dbl[:, :delta_rank].t()
    delta = rearrange(delta, "d (b l) -> b d l", l=L)
    if B is None:  # variable B
        B = x_dbl[:, delta_rank:delta_rank + d_state]  # (bl d)
        if B_proj_bias is not None:
            B = B + B_proj_bias.to(dtype=B.dtype)
        if not A.is_complex():
            B = rearrange(B, "(b l) dstate -> b dstate l", l=L).contiguous()
        else:
            B = rearrange(B, "(b l) (dstate two) -> b dstate (l two)", l=L, two=2).contiguous()
    if C is None:  # variable C
        C = x_dbl[:, -d_state:]  # (bl d)
        if C_proj_bias is not None:
            C = C + C_proj_bias.to(dtype=C.dtype)
        if not A.is_complex():
            C = rearrange(C, "(b l) dstate -> b dstate l", l=L).contiguous()
        else:
            C = rearrange(C, "(b l) (dstate two) -> b dstate (l two)", l=L, two=2).contiguous()
    y = selective_scan_fn(x, delta, A, B, C, D, z=z, delta_bias=delta_bias,
                          delta_softplus=delta_softplus)
    return F.linear(rearrange(y, "b d l -> b l d"), out_proj_weight, out_proj_bias)