layernorm.py
import torch

eps = 1e-5

class LayerNorm:

    @staticmethod
    def forward(x, w, b):
        B, T, C = x.size()
        mean = x.sum(-1, keepdim=True) / C # B,T,1
        xshift = x - mean # B,T,C
        var = (xshift**2).sum(-1, keepdim=True) / C # B,T,1
        rstd = (var + eps) ** -0.5 # B,T,1
        norm = xshift * rstd # B,T,C
        out = norm * w + b # B,T,C
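        # cache x, w and the per-position statistics; backward recomputes norm
        # from these instead of storing it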
        cache = (x, w, mean, rstd)
        return out, cache
    @staticmethod
    def backward(dout, cache):
        x, w, mean, rstd = cache
        # recompute the norm (save memory at the cost of compute)
        norm = (x - mean) * rstd
        # gradients for weights, bias
        db = dout.sum((0, 1))
        dw = (dout * norm).sum((0, 1))
        # gradients for input
        dnorm = dout * w
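        # differentiating norm = (x - mean) * rstd through mean and var gives
        #   dx = rstd * (dnorm - mean_C(dnorm) - norm * mean_C(dnorm * norm))
        # where mean_C(.) averages over the channel dimension C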
        dx = dnorm - dnorm.mean(-1, keepdim=True) - norm * (dnorm * norm).mean(-1, keepdim=True)
        dx *= rstd
        return dx, dw, db
# create a small dummy example and check against PyTorch's autograd backward
B = 2
T = 3
C = 4
x = torch.randn(B, T, C, requires_grad=True)
w = torch.randn(C, requires_grad=True)
b = torch.randn(C, requires_grad=True)
out, cache = LayerNorm.forward(x, w, b)
dout = torch.randn(B, T, C)
dx, dw, db = LayerNorm.backward(dout, cache)
# compare to PyTorch autograd
fakeloss = (out * dout).sum()
fakeloss.backward()
print("dx error:", (x.grad - dx).abs().max().item())
print("dw error:", (w.grad - dw).abs().max().item())
print("db error:", (b.grad - db).abs().max().item())
# for reference checking in C also
x, w, mean, rstd = cache
def write(tensor, handle):
    handle.write(tensor.detach().numpy().astype("float32").tobytes())

# Write to file
with open('ln.bin', 'wb') as file:
    write(x, file) # (B, T, C)
    write(w, file) # (C, )
    write(b, file) # (C, )
    write(out, file) # (B, T, C)
    write(mean, file) # (B, T)
    write(rstd, file) # (B, T)
    write(dout, file) # (B, T, C)
    write(dx, file) # (B, T, C)
    write(dw, file) # (C, )
    write(db, file) # (C, )
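
# Optional sanity check: a minimal sketch that reads the first tensor back with
# numpy and confirms it round-trips bit-exactly; assumes ln.bin was just written
# by the block above and that all tensors were stored as float32.
import numpy as np
with open('ln.bin', 'rb') as file:
    x_read = np.fromfile(file, dtype=np.float32, count=B * T * C).reshape(B, T, C)
assert np.array_equal(x_read, x.detach().numpy().astype("float32"))
print("ln.bin round-trip OK for x")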