forked from automata/mojograd
-
Notifications
You must be signed in to change notification settings - Fork 0
/
mojograd.🔥
188 lines (152 loc) · 5.82 KB
/
mojograd.🔥
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
from random import random_float64
from math import tanh
# Scalar reverse-mode autograd node.
#
# A Value holds a number (`data`), the gradient accumulated for it (`grad`),
# the operation that produced it (`op`, "" for a leaf) and pointers to heap
# copies of the operands it was computed from.
#
# Operands are copied by value into freshly allocated storage by `new`, so
# gradients computed by `backward` live in those heap copies (reachable through
# `l` / `r`, e.g. with `print_backward`), not in the caller's original
# variables. Nothing in this struct frees that storage.
@value
@register_passable("trivial")
struct Value:
    # Operand pointers, stored type-erased as Pointer[Int] and bitcast back to
    # Pointer[Value] on use. `l` is the left (or only) operand, `r` the right
    # one; a null pointer means "no operand".
    # NOTE: the @value memberwise constructor takes arguments in declaration
    # order, i.e. (r, l, data, grad, op, _id).
    var r: Pointer[Int]
    var l: Pointer[Int]
    var data: Float64
    var grad: Float64
    var op: StringLiteral
    # Identity tag used by __eq__; copies of a node share the same tag.
    var _id: Float64

    # Builds a leaf node (no operands, empty op, zero grad) holding `data`.
    fn __init__(data: Float64) -> Value:
        return Value(Pointer[Int].get_null(), Pointer[Int].get_null(), data, 0.0, "", random_float64())

    # Two values are "the same node" when their identity tags match.
    fn __eq__(self, other : Value) -> Bool:
        # For now using a random_float64 value :-)
        return self._id == other._id

    # Add
    fn __add__(self, other: Value) -> Value:
        return self.new(self.data + other.data, other, "+")

    fn __radd__(self, other:Value) -> Value:
        return self + other

    fn __add__(self, other: Float64) -> Value:
        return self + Value(other)

    fn __radd__(self, other: Float64) -> Value:
        return self + Value(other)

    # Backward rule for "+": both operands receive node.grad unchanged.
    @staticmethod
    fn backward_add(inout node: Value):
        var l = node.l.bitcast[Value]().load(0)
        var r = node.r.bitcast[Value]().load(0)
        l.grad += node.grad
        r.grad += node.grad
        # Write each updated operand back to its own slot (the right operand
        # must go through node.r, not node.l).
        node.l.bitcast[Value]().store(0, l)
        node.r.bitcast[Value]().store(0, r)
        Value._backward(l)
        Value._backward(r)

    # Mul
    fn __mul__(self, other: Value) -> Value:
        return self.new(self.data * other.data, other, "*")

    fn __rmul__(self, other: Value) -> Value:
        return self * other

    fn __mul__(self, other: Float64) -> Value:
        return self * Value(other)

    fn __rmul__(self, other: Float64) -> Value:
        return self * Value(other)

    # Backward rule for "*": each operand receives the other operand's data
    # times node.grad.
    @staticmethod
    fn backward_mul(inout node: Value):
        var left = node.l.bitcast[Value]().load(0)
        var right = node.r.bitcast[Value]().load(0)
        left.grad += right.data * node.grad
        right.grad += left.data * node.grad
        node.l.bitcast[Value]().store(0, left)
        node.r.bitcast[Value]().store(0, right)
        Value._backward(left)
        Value._backward(right)

    # Neg
    fn __neg__(self) -> Value:
        return self * -1

    # Sub
    fn __sub__(self, other: Value) -> Value:
        return self + (-other)

    fn __sub__(self, other: Float64) -> Value:
        return self + (-Value(other))

    # Tanh
    fn tanh(self) -> Value:
        return self.new(tanh(self.data), "tanh")

    # Backward rule for tanh: d/dx tanh(x) = 1 - tanh(x)^2.
    # Static like the other backward_* helpers; it is invoked as
    # Value.backward_tanh(node) from _backward.
    @staticmethod
    fn backward_tanh(inout node: Value):
        var left = node.l.bitcast[Value]().load(0)
        left.grad += (1 - tanh(left.data)**2) * node.grad
        node.l.bitcast[Value]().store(0, left)
        Value._backward(left)

    # Value alloc
    # Creates the result node of a unary op: `self` is copied to the heap and
    # becomes the node's left operand; the right operand stays null.
    fn new(self, data: Float64, op: StringLiteral) -> Value:
        let l = Pointer[Value].alloc(1)
        l.store(self)
        # Memberwise constructor order is (r, l, ...): null goes first.
        return Value(Pointer[Int].get_null(), l.bitcast[Int](), data, 0.0, op, random_float64())

    # Creates the result node of a binary op: `self` and `right` are copied to
    # the heap and become the node's left and right operands.
    fn new(self, data: Float64, right: Value, op: StringLiteral) -> Value:
        let l = Pointer[Value].alloc(1)
        l.store(self)
        let r = Pointer[Value].alloc(1)
        r.store(right)
        # Memberwise constructor order is (r, l, ...): right pointer goes first.
        return Value(r.bitcast[Int](), l.bitcast[Int](), data, 0.0, op, random_float64())

    # Autograd
    # Dispatches to the backward rule matching node.op; leaves (op == "") and
    # unknown ops are a no-op.
    @staticmethod
    fn _backward(inout node: Value):
        if node.op == "+":
            Value.backward_add(node)
        elif node.op == "*":
            Value.backward_mul(node)
        elif node.op == "tanh":
            Value.backward_tanh(node)

    # Runs backpropagation from this node: seeds its gradient with 1.0 and
    # pushes gradients down into the heap copies of its operands.
    fn backward(inout self):
        # The seed must be set before anything copies `self`; Value is a
        # trivially-copied type, so a copy taken earlier would keep grad == 0.
        self.grad = 1.0
        # Every backward_* rule already recurses into its operands, so one call
        # on the root covers the whole expression. Re-running the rules on the
        # by-value snapshots that build_topo collects would push already
        # accumulated gradients down a second time.
        # NOTE(review): an intermediate result that is reused by several later
        # operations is reached once per use and may still be over-counted;
        # fixing that needs nodes shared by pointer rather than copied.
        Value._backward(self)

    # Appends `v` and everything reachable from it to `topo` in post-order
    # (operands before the node that uses them). `visited` holds the nodes
    # already seen, compared by identity tag. Entries are by-value copies.
    fn build_topo(inout self, v : Value, inout visited : DynamicVector[Value], inout topo : DynamicVector[Value]):
        for i in range(len(visited)):
            if v == visited[i]:
                # Already handled; no need to scan the rest.
                return
        visited.push_back(v)
        # It's pushing back, so visit in reverse, first right then left
        if v.r:
            self.build_topo(v.r.bitcast[Value]().load(0), visited, topo)
        if v.l:
            self.build_topo(v.l.bitcast[Value]().load(0), visited, topo)
        topo.push_back(v)

    # Returns a new vector with the elements of `vec` in reverse order.
    @staticmethod
    fn reverse(vec : DynamicVector[Value]) -> DynamicVector[Value]:
        var reversed : DynamicVector[Value] = DynamicVector[Value](len(vec))
        for i in range(len(vec)-1, -1, -1):
            reversed.push_back(vec[i])
        return reversed

    # Prints this node's data, grad and op, tagged with `label`.
    fn show(self, label : StringLiteral):
        print("<Value", label, "::", "data:", self.data, "grad:", self.grad, "op:", self.op, ">")

    # Recursively prints the expression below `node`, one line per operation,
    # showing each operand's data and (grad).
    @staticmethod
    fn print_backward(node: Value):
        if node.l and node.r:
            let left = node.l.bitcast[Value]().load(0)
            let right = node.r.bitcast[Value]().load(0)
            print(left.data, "(", left.grad, ")", node.op, right.data, "(", right.grad, ")", "=", node.data)
        elif node.l:
            let left = node.l.bitcast[Value]().load(0)
            print(left.data, "(", left.grad, ")", node.op, "=", node.data)
        if node.l:
            let left = node.l.bitcast[Value]().load(0)
            Value.print_backward(left)
        if node.r:
            let right = node.r.bitcast[Value]().load(0)
            Value.print_backward(right)
# Demo: builds s2 = (a + b) * c, runs backpropagation from s2 and prints every
# value involved.
fn main():
    let a = Value(1)
    let b = Value(2)
    let c = Value(7)
    let s1 = a + b
    var s2 = s1 * c
    s2.backward()
    # Show each value under its own label (previously every line printed `a`).
    # NOTE(review): Value.new stores copies of the operands, so the grads
    # printed for a, b, c and s1 are those of these locals, not of the copies
    # held inside s2's graph - confirm whether Value.print_backward(s2) is what
    # should be displayed here instead.
    a.show("a")
    b.show("b")
    c.show("c")
    s1.show("s1")
    s2.show("s2")