optimizers.py
import mlx.core as mx

from metrics import rmse


def sgd(X, T, W, forward, gradient_f, eta, epochs):
    """Plain gradient descent: W <- W - eta * grad, for a fixed number of epochs."""
    error_trace = []
    W_trace = []
    for _ in range(epochs):
        W_trace.append(mx.flatten(W))
        W -= eta * gradient_f(X, T, W)
        error_trace.append(rmse(T, forward(X, W)))
    return W, error_trace, W_trace
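
# Usage sketch (not part of the original file): for a linear model T ~ X @ W,
# `forward` and `gradient_f` might be the hypothetical helpers below; the
# gradient shown is that of the mean squared error.
#
#   def forward(X, W):
#       return X @ W
#
#   def gradient_f(X, T, W):
#       return -2 / X.shape[0] * (X.T @ (T - X @ W))
#
#   W, error_trace, W_trace = sgd(X, T, W0, forward, gradient_f, eta=0.01, epochs=500)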


def adam(X, T, W, forward, gradient_f, eta, epochs,
         beta1=0.9, beta2=0.999, epsilon=1e-8):
    """Adam optimizer (Kingma & Ba, 2015) with bias-corrected moment estimates."""
    m, v = 0, 0  # first and second moments; broadcast to W's shape on first update
    error_trace, W_trace = [], []
    for step in range(epochs):
        W_trace.append(mx.flatten(W))
        g = gradient_f(X, T, W)
        m = beta1 * m + (1 - beta1) * g        # moving average of gradients
        v = beta2 * v + (1 - beta2) * g * g    # moving average of squared gradients
        mhat = m / (1 - beta1 ** (step + 1))   # correct the zero-initialization bias
        vhat = v / (1 - beta2 ** (step + 1))
        W -= eta * mhat / (mx.sqrt(vhat) + epsilon)
        error_trace.append(rmse(T, forward(X, W)))
    return W, error_trace, W_trace
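
# adam is a drop-in replacement for sgd (same signature; sketch assumes the
# hypothetical forward / gradient_f above). The bias corrections keep the
# first few steps from being dominated by the zero-initialized moments:
#
#   W, error_trace, W_trace = adam(X, T, W0, forward, gradient_f, eta=0.01, epochs=500)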


def ols(X, T, pseudoinverse=True):
    """Ordinary least squares: solve min_W ||X @ W - T||^2 in closed form."""
    if pseudoinverse:  # Moore-Penrose pseudoinverse via SVD (CPU-only in MLX)
        U, S, Vt = mx.linalg.svd(X, stream=mx.cpu)
        K = min(X.shape)                         # reduce the full SVD to the thin SVD
        U, Vt = U[:, :K], Vt[:K, :]
        s_inv = mx.where(S > 1e-15, 1 / S, 0)    # invert only nonzero singular values
        pinv = Vt.T @ (mx.diag(s_inv) @ U.T)     # pinv(X) = V S^+ U^T
        return pinv @ T
    else:  # normal equations with a direct inverse (X.T @ X must be full rank)
        return mx.linalg.inv(X.T @ X, stream=mx.cpu) @ X.T @ T
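

# A minimal end-to-end sketch (assumes metrics.rmse(T, Y) returns a scalar):
# fit a synthetic linear model with all three solvers and compare final errors.
if __name__ == "__main__":
    mx.random.seed(0)
    X = mx.random.normal((100, 3))
    W_true = mx.array([[1.0], [-2.0], [0.5]])
    T = X @ W_true + 0.01 * mx.random.normal((100, 1))

    def forward(X, W):
        return X @ W

    def gradient_f(X, T, W):  # gradient of mean squared error
        return -2 / X.shape[0] * (X.T @ (T - X @ W))

    W_sgd, _, _ = sgd(X, T, mx.zeros((3, 1)), forward, gradient_f, eta=0.1, epochs=200)
    W_adam, _, _ = adam(X, T, mx.zeros((3, 1)), forward, gradient_f, eta=0.1, epochs=200)
    W_ols = ols(X, T)
    print("sgd :", rmse(T, forward(X, W_sgd)))
    print("adam:", rmse(T, forward(X, W_adam)))
    print("ols :", rmse(T, forward(X, W_ols)))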