Skip to content

Commit 99cfbfe

Browse files
committed
Convert to scala-infer 0.3, add ARD notebook
1 parent 6307997 commit 99cfbfe

File tree

4 files changed

+454
-104
lines changed

4 files changed

+454
-104
lines changed

Diff for: AutoDiff.ipynb

+38-47
Original file line numberDiff line numberDiff line change
@@ -9,72 +9,63 @@
99
},
1010
{
1111
"cell_type": "code",
12-
"execution_count": 9,
12+
"execution_count": 1,
13+
"metadata": {},
14+
"outputs": [],
15+
"source": [
16+
"interp.repositories() ++= Seq(\n",
17+
" coursier.MavenRepository(\"https://dl.bintray.com/scala-infer/maven\")\n",
18+
")"
19+
]
20+
},
21+
{
22+
"cell_type": "code",
23+
"execution_count": 2,
1324
"metadata": {
1425
"collapsed": true
1526
},
1627
"outputs": [
17-
{
18-
"name": "stderr",
19-
"output_type": "stream",
20-
"text": [
21-
"Downloading https://repo1.maven.org/maven2/org/bytedeco/javacpp-presets/mkl-platform/2019.0-1.4.3/mkl-platform-2019.0-1.4.3-sources.jar\n",
22-
"Downloading https://repo1.maven.org/maven2/org/bytedeco/javacpp-presets/mkl-platform/2019.0-1.4.3/mkl-platform-2019.0-1.4.3-sources.jar.sha1\n",
23-
"Downloading https://repo1.maven.org/maven2/org/bytedeco/javacpp-presets/mkl-dnn-platform/0.16-1.4.3/mkl-dnn-platform-0.16-1.4.3-sources.jar\n",
24-
"Downloading https://repo1.maven.org/maven2/org/bytedeco/javacpp-presets/mkl-dnn-platform/0.16-1.4.3/mkl-dnn-platform-0.16-1.4.3-sources.jar.sha1\n",
25-
"Downloading https://repo1.maven.org/maven2/org/nd4j/nd4j-backend-impls/1.0.0-beta3/nd4j-backend-impls-1.0.0-beta3-sources.jar\n",
26-
"Downloading https://repo1.maven.org/maven2/org/nd4j/nd4j-backend-impls/1.0.0-beta3/nd4j-backend-impls-1.0.0-beta3-sources.jar.sha1\n",
27-
"Downloading https://repo1.maven.org/maven2/org/bytedeco/javacpp-presets/openblas-platform/0.3.3-1.4.3/openblas-platform-0.3.3-1.4.3-sources.jar\n",
28-
"Downloading https://repo1.maven.org/maven2/org/bytedeco/javacpp-presets/openblas-platform/0.3.3-1.4.3/openblas-platform-0.3.3-1.4.3-sources.jar.sha1\n"
29-
]
30-
},
3128
{
3229
"data": {
3330
"text/plain": [
34-
"\u001b[32mimport \u001b[39m\u001b[36m$ivy.$ \u001b[39m"
31+
"\u001b[32mimport \u001b[39m\u001b[36m$ivy.$ \u001b[39m"
3532
]
3633
},
37-
"execution_count": 9,
34+
"execution_count": 2,
3835
"metadata": {},
3936
"output_type": "execute_result"
4037
}
4138
],
4239
"source": [
43-
"interp.repositories() ++= Seq(\n",
44-
" coursier.MavenRepository(\"https://dl.bintray.com/fvlankvelt/maven\")\n",
45-
")\n",
46-
"import $ivy.`fvlankvelt::scala-infer:0.1`"
40+
"import $ivy.`scala-infer::scala-infer:0.3`"
4741
]
4842
},
4943
{
5044
"cell_type": "code",
51-
"execution_count": 10,
45+
"execution_count": 4,
5246
"metadata": {},
5347
"outputs": [
5448
{
5549
"data": {
5650
"text/plain": [
5751
"\u001b[32mimport \u001b[39m\u001b[36mscappla._\n",
5852
"\u001b[39m\n",
59-
"\u001b[32mimport \u001b[39m\u001b[36mscappla.Real._\n",
60-
"\u001b[39m\n",
6153
"\u001b[32mimport \u001b[39m\u001b[36mscappla.Functions._\n",
6254
"\n",
6355
"\u001b[39m\n",
64-
"defined \u001b[32mclass\u001b[39m \u001b[36mParam\u001b[39m"
56+
"defined \u001b[32mclass\u001b[39m \u001b[36mVar\u001b[39m"
6557
]
6658
},
67-
"execution_count": 10,
59+
"execution_count": 4,
6860
"metadata": {},
6961
"output_type": "execute_result"
7062
}
7163
],
7264
"source": [
7365
"import scappla._\n",
74-
"import scappla.Real._\n",
7566
"import scappla.Functions._\n",
7667
"\n",
77-
"case class Param(name: String, v: Double = 0.0) extends Real {\n",
68+
"case class Var(name: String, v: Double = 0.0) extends AbstractReal {\n",
7869
" \n",
7970
" override def dv(d: Double): Unit = {\n",
8071
" println(s\"grad $name = $d\")\n",
@@ -94,33 +85,33 @@
9485
},
9586
{
9687
"cell_type": "code",
97-
"execution_count": 11,
88+
"execution_count": 5,
9889
"metadata": {},
9990
"outputs": [
10091
{
10192
"data": {
10293
"text/plain": [
103-
"\u001b[36mx\u001b[39m: \u001b[32mParam\u001b[39m = \u001b[33mParam\u001b[39m(\u001b[32m\"x\"\u001b[39m, \u001b[32m2.0\u001b[39m)\n",
104-
"\u001b[36my\u001b[39m: \u001b[32mParam\u001b[39m = \u001b[33mParam\u001b[39m(\u001b[32m\"y\"\u001b[39m, \u001b[32m3.0\u001b[39m)\n",
105-
"\u001b[36mz\u001b[39m: \u001b[32mExpr\u001b[39m[\u001b[32mDouble\u001b[39m] = \u001b[33mDAdd\u001b[39m(\u001b[33mParam\u001b[39m(\u001b[32m\"x\"\u001b[39m, \u001b[32m2.0\u001b[39m), \u001b[33mParam\u001b[39m(\u001b[32m\"y\"\u001b[39m, \u001b[32m3.0\u001b[39m))\n",
106-
"\u001b[36mres10_3\u001b[39m: \u001b[32mDouble\u001b[39m = \u001b[32m5.0\u001b[39m"
94+
"\u001b[36mx\u001b[39m: \u001b[32mVar\u001b[39m = \u001b[33mVar\u001b[39m(\u001b[32m\"x\"\u001b[39m, \u001b[32m2.0\u001b[39m)\n",
95+
"\u001b[36my\u001b[39m: \u001b[32mVar\u001b[39m = \u001b[33mVar\u001b[39m(\u001b[32m\"y\"\u001b[39m, \u001b[32m3.0\u001b[39m)\n",
96+
"\u001b[36mz\u001b[39m: \u001b[32mValue\u001b[39m[\u001b[32mDouble\u001b[39m, \u001b[32mUnit\u001b[39m] = \u001b[33mVPlus\u001b[39m(\u001b[33mVar\u001b[39m(\u001b[32m\"x\"\u001b[39m, \u001b[32m2.0\u001b[39m), \u001b[33mVar\u001b[39m(\u001b[32m\"y\"\u001b[39m, \u001b[32m3.0\u001b[39m))\n",
97+
"\u001b[36mres4_3\u001b[39m: \u001b[32mDouble\u001b[39m = \u001b[32m5.0\u001b[39m"
10798
]
10899
},
109-
"execution_count": 11,
100+
"execution_count": 5,
110101
"metadata": {},
111102
"output_type": "execute_result"
112103
}
113104
],
114105
"source": [
115-
"val x = Param(\"x\", 2.0)\n",
116-
"val y = Param(\"y\", 3.0)\n",
106+
"val x = Var(\"x\", 2.0)\n",
107+
"val y = Var(\"y\", 3.0)\n",
117108
"val z = x + y\n",
118109
"z.v"
119110
]
120111
},
121112
{
122113
"cell_type": "code",
123-
"execution_count": 12,
114+
"execution_count": 6,
124115
"metadata": {},
125116
"outputs": [
126117
{
@@ -147,17 +138,17 @@
147138
},
148139
{
149140
"cell_type": "code",
150-
"execution_count": 13,
141+
"execution_count": 7,
151142
"metadata": {},
152143
"outputs": [
153144
{
154145
"data": {
155146
"text/plain": [
156-
"\u001b[36mw\u001b[39m: \u001b[32mReal\u001b[39m = Log(x)\n",
157-
"\u001b[36mres12_1\u001b[39m: \u001b[32mDouble\u001b[39m = \u001b[32m0.6931471805599453\u001b[39m"
147+
"\u001b[36mw\u001b[39m: \u001b[32mValue\u001b[39m[\u001b[32mDouble\u001b[39m, \u001b[32mUnit\u001b[39m] = \u001b[33mVLog\u001b[39m(\u001b[33mVar\u001b[39m(\u001b[32m\"x\"\u001b[39m, \u001b[32m2.0\u001b[39m))\n",
148+
"\u001b[36mres6_1\u001b[39m: \u001b[32mDouble\u001b[39m = \u001b[32m0.6931471805599453\u001b[39m"
158149
]
159150
},
160-
"execution_count": 13,
151+
"execution_count": 7,
161152
"metadata": {},
162153
"output_type": "execute_result"
163154
}
@@ -176,7 +167,7 @@
176167
},
177168
{
178169
"cell_type": "code",
179-
"execution_count": 14,
170+
"execution_count": 8,
180171
"metadata": {},
181172
"outputs": [
182173
{
@@ -201,17 +192,17 @@
201192
},
202193
{
203194
"cell_type": "code",
204-
"execution_count": 15,
195+
"execution_count": 9,
205196
"metadata": {},
206197
"outputs": [
207198
{
208199
"data": {
209200
"text/plain": [
210-
"\u001b[36mu\u001b[39m: \u001b[32mReal\u001b[39m = Log(Const(1.0000) / (Exp(-x) + Const(1.0000)))\n",
211-
"\u001b[36mres14_1\u001b[39m: \u001b[32mDouble\u001b[39m = \u001b[32m-0.12692801104297263\u001b[39m"
201+
"\u001b[36mu\u001b[39m: \u001b[32mValue\u001b[39m[\u001b[32mDouble\u001b[39m, \u001b[32mUnit\u001b[39m] = \u001b[33mVLog\u001b[39m(\u001b[33mVSigmoid\u001b[39m(\u001b[33mVar\u001b[39m(\u001b[32m\"x\"\u001b[39m, \u001b[32m2.0\u001b[39m)))\n",
202+
"\u001b[36mres8_1\u001b[39m: \u001b[32mDouble\u001b[39m = \u001b[32m-0.12692801104297263\u001b[39m"
212203
]
213204
},
214-
"execution_count": 15,
205+
"execution_count": 9,
215206
"metadata": {},
216207
"output_type": "execute_result"
217208
}
@@ -223,7 +214,7 @@
223214
},
224215
{
225216
"cell_type": "code",
226-
"execution_count": 16,
217+
"execution_count": 10,
227218
"metadata": {},
228219
"outputs": [
229220
{

0 commit comments

Comments
 (0)