|
28 | 28 | },
|
29 | 29 | {
|
30 | 30 | "cell_type": "code",
|
31 |
| - "execution_count": 2, |
32 |
| - "metadata": { |
33 |
| - "collapsed": true |
34 |
| - }, |
35 |
| - "outputs": [ |
36 |
| - { |
37 |
| - "data": { |
38 |
| - "text/plain": [ |
39 |
| - "\u001b[32mimport \u001b[39m\u001b[36m$ivy.$ \n", |
40 |
| - "\u001b[39m\n", |
41 |
| - "\u001b[32mimport \u001b[39m\u001b[36m$ivy.$ \u001b[39m" |
42 |
| - ] |
43 |
| - }, |
44 |
| - "execution_count": 2, |
45 |
| - "metadata": {}, |
46 |
| - "output_type": "execute_result" |
47 |
| - } |
48 |
| - ], |
| 31 | + "execution_count": null, |
| 32 | + "metadata": {}, |
| 33 | + "outputs": [], |
49 | 34 | "source": [
|
50 | 35 | "import $ivy.`scala-infer::scala-infer:0.3`\n",
|
51 | 36 | "import $ivy.`org.jupyter-scala::kernel-api:0.4.1`"
|
52 | 37 | ]
|
53 | 38 | },
|
54 | 39 | {
|
55 | 40 | "cell_type": "code",
|
56 |
| - "execution_count": 3, |
57 |
| - "metadata": { |
58 |
| - "collapsed": true |
59 |
| - }, |
60 |
| - "outputs": [ |
61 |
| - { |
62 |
| - "data": { |
63 |
| - "text/plain": [ |
64 |
| - "\u001b[32mimport \u001b[39m\u001b[36mscappla._\n", |
65 |
| - "\u001b[39m\n", |
66 |
| - "\u001b[32mimport \u001b[39m\u001b[36mscappla.Functions._\n", |
67 |
| - "\u001b[39m\n", |
68 |
| - "\u001b[32mimport \u001b[39m\u001b[36mscappla.distributions._\n", |
69 |
| - "\u001b[39m\n", |
70 |
| - "\u001b[32mimport \u001b[39m\u001b[36mscappla.guides._\n", |
71 |
| - "\u001b[39m\n", |
72 |
| - "\u001b[32mimport \u001b[39m\u001b[36mscappla.optimization._\n", |
73 |
| - "\u001b[39m\n", |
74 |
| - "\u001b[32mimport \u001b[39m\u001b[36mscappla.tensor.Tensor._\n", |
75 |
| - "\u001b[39m\n", |
76 |
| - "\u001b[32mimport \u001b[39m\u001b[36mscappla.tensor._\u001b[39m" |
77 |
| - ] |
78 |
| - }, |
79 |
| - "execution_count": 3, |
80 |
| - "metadata": {}, |
81 |
| - "output_type": "execute_result" |
82 |
| - } |
83 |
| - ], |
| 41 | + "execution_count": null, |
| 42 | + "metadata": {}, |
| 43 | + "outputs": [], |
84 | 44 | "source": [
|
85 | 45 | "import scappla._\n",
|
86 | 46 | "import scappla.Functions._\n",
|
87 | 47 | "import scappla.distributions._\n",
|
88 | 48 | "import scappla.guides._\n",
|
89 | 49 | "import scappla.optimization._\n",
|
90 | 50 | "import scappla.tensor.Tensor._\n",
|
91 |
| - "import scappla.tensor._" |
92 |
| - ] |
93 |
| - }, |
94 |
| - { |
95 |
| - "cell_type": "code", |
96 |
| - "execution_count": 4, |
97 |
| - "metadata": { |
98 |
| - "collapsed": true |
99 |
| - }, |
100 |
| - "outputs": [ |
101 |
| - { |
102 |
| - "data": { |
103 |
| - "text/plain": [ |
104 |
| - "\u001b[32mimport \u001b[39m\u001b[36mscala.util.Random\u001b[39m" |
105 |
| - ] |
106 |
| - }, |
107 |
| - "execution_count": 4, |
108 |
| - "metadata": {}, |
109 |
| - "output_type": "execute_result" |
110 |
| - } |
111 |
| - ], |
112 |
| - "source": [ |
| 51 | + "import scappla.tensor._\n", |
| 52 | + "\n", |
113 | 53 | "import scala.util.Random"
|
114 | 54 | ]
|
115 | 55 | },
|
116 | 56 | {
|
117 | 57 | "cell_type": "code",
|
118 |
| - "execution_count": 5, |
119 |
| - "metadata": { |
120 |
| - "collapsed": true |
121 |
| - }, |
122 |
| - "outputs": [ |
123 |
| - { |
124 |
| - "data": { |
125 |
| - "text/plain": [ |
126 |
| - "defined \u001b[32mclass\u001b[39m \u001b[36mRecord\u001b[39m\n", |
127 |
| - "defined \u001b[32mclass\u001b[39m \u001b[36mBatch\u001b[39m\n", |
128 |
| - "\u001b[36mbatch\u001b[39m: \u001b[32mBatch\u001b[39m = \u001b[33mBatch\u001b[39m(\u001b[32m1000\u001b[39m)\n", |
129 |
| - "\u001b[36ma_vals\u001b[39m: \u001b[32mValue\u001b[39m[\u001b[32mArrayTensor\u001b[39m, \u001b[32mBatch\u001b[39m] = scappla.Constant@12f3f093\n", |
130 |
| - "\u001b[36mb_vals\u001b[39m: \u001b[32mValue\u001b[39m[\u001b[32mArrayTensor\u001b[39m, \u001b[32mBatch\u001b[39m] = scappla.Constant@31e80b51\n", |
131 |
| - "\u001b[36my_vals\u001b[39m: \u001b[32mValue\u001b[39m[\u001b[32mArrayTensor\u001b[39m, \u001b[32mBatch\u001b[39m] = scappla.Constant@5f6ae152" |
132 |
| - ] |
133 |
| - }, |
134 |
| - "execution_count": 5, |
135 |
| - "metadata": {}, |
136 |
| - "output_type": "execute_result" |
137 |
| - } |
138 |
| - ], |
| 58 | + "execution_count": null, |
| 59 | + "metadata": {}, |
| 60 | + "outputs": [], |
139 | 61 | "source": [
|
140 | 62 | "case class Record(a: Float, b: Float, y: Float)\n",
|
141 | 63 | "\n",
|
|
165 | 87 | },
|
166 | 88 | {
|
167 | 89 | "cell_type": "code",
|
168 |
| - "execution_count": 6, |
169 |
| - "metadata": { |
170 |
| - "collapsed": true |
171 |
| - }, |
172 |
| - "outputs": [ |
173 |
| - { |
174 |
| - "data": { |
175 |
| - "text/plain": [ |
176 |
| - "\u001b[36ma_prior_s\u001b[39m: \u001b[32mParam\u001b[39m[\u001b[32mDouble\u001b[39m, \u001b[32mUnit\u001b[39m] = scappla.Param@7e13f371\n", |
177 |
| - "\u001b[36mb_prior_s\u001b[39m: \u001b[32mParam\u001b[39m[\u001b[32mDouble\u001b[39m, \u001b[32mUnit\u001b[39m] = scappla.Param@314da4f9\n", |
178 |
| - "\u001b[36ma_post_mu\u001b[39m: \u001b[32mParam\u001b[39m[\u001b[32mDouble\u001b[39m, \u001b[32mUnit\u001b[39m] = scappla.Param@33764907\n", |
179 |
| - "\u001b[36ma_post_s\u001b[39m: \u001b[32mParam\u001b[39m[\u001b[32mDouble\u001b[39m, \u001b[32mUnit\u001b[39m] = scappla.Param@3a272583\n", |
180 |
| - "\u001b[36ma_guide\u001b[39m: \u001b[32mReparamGuide\u001b[39m[\u001b[32mDouble\u001b[39m, \u001b[32mUnit\u001b[39m] = \u001b[33mReparamGuide\u001b[39m(\n", |
181 |
| - " \u001b[33mNormal\u001b[39m(scappla.Param@33764907, \u001b[33mApply1\u001b[39m(scappla.Param@3a272583, <function1>))\n", |
182 |
| - ")\n", |
183 |
| - "\u001b[36mb_post_mu\u001b[39m: \u001b[32mParam\u001b[39m[\u001b[32mDouble\u001b[39m, \u001b[32mUnit\u001b[39m] = scappla.Param@5cbb302d\n", |
184 |
| - "\u001b[36mb_post_s\u001b[39m: \u001b[32mParam\u001b[39m[\u001b[32mDouble\u001b[39m, \u001b[32mUnit\u001b[39m] = scappla.Param@5d9b0110\n", |
185 |
| - "\u001b[36mb_guide\u001b[39m: \u001b[32mReparamGuide\u001b[39m[\u001b[32mDouble\u001b[39m, \u001b[32mUnit\u001b[39m] = \u001b[33mReparamGuide\u001b[39m(\n", |
186 |
| - " \u001b[33mNormal\u001b[39m(scappla.Param@5cbb302d, \u001b[33mApply1\u001b[39m(scappla.Param@5d9b0110, <function1>))\n", |
187 |
| - ")\n", |
188 |
| - "\u001b[36mnoise_mu\u001b[39m: \u001b[32mParam\u001b[39m[\u001b[32mDouble\u001b[39m, \u001b[32mUnit\u001b[39m] = scappla.Param@33f65cad\n", |
189 |
| - "\u001b[36mnoise_s\u001b[39m: \u001b[32mParam\u001b[39m[\u001b[32mDouble\u001b[39m, \u001b[32mUnit\u001b[39m] = scappla.Param@58927bd0\n", |
190 |
| - "\u001b[36mnoise_guide\u001b[39m: \u001b[32mReparamGuide\u001b[39m[\u001b[32mDouble\u001b[39m, \u001b[32mUnit\u001b[39m] = \u001b[33mReparamGuide\u001b[39m(\n", |
191 |
| - " \u001b[33mNormal\u001b[39m(scappla.Param@33f65cad, \u001b[33mApply1\u001b[39m(scappla.Param@58927bd0, <function1>))\n", |
192 |
| - ")\n", |
193 |
| - "\u001b[36mmodel\u001b[39m: \u001b[32mModel\u001b[39m[\u001b[32mUnit\u001b[39m] = ammonite.$sess.cmd5$Helper$$anon$1@67449597" |
194 |
| - ] |
195 |
| - }, |
196 |
| - "execution_count": 6, |
197 |
| - "metadata": {}, |
198 |
| - "output_type": "execute_result" |
199 |
| - } |
200 |
| - ], |
| 90 | + "execution_count": null, |
| 91 | + "metadata": {}, |
| 92 | + "outputs": [], |
201 | 93 | "source": [
|
202 | 94 | "val a_prior_s = Param(0.0)\n",
|
203 | 95 | "val b_prior_s = Param(0.0)\n",
|
|
229 | 121 | },
|
230 | 122 | {
|
231 | 123 | "cell_type": "code",
|
232 |
| - "execution_count": 7, |
233 |
| - "metadata": { |
234 |
| - "collapsed": true |
235 |
| - }, |
236 |
| - "outputs": [ |
237 |
| - { |
238 |
| - "data": { |
239 |
| - "text/plain": [ |
240 |
| - "\u001b[36mopt\u001b[39m: \u001b[32mAdam\u001b[39m = scappla.optimization.Adam@7bc3e74f\n", |
241 |
| - "\u001b[36minterpreter\u001b[39m: \u001b[32mOptimizingInterpreter\u001b[39m = scappla.OptimizingInterpreter@115f4b3" |
242 |
| - ] |
243 |
| - }, |
244 |
| - "execution_count": 7, |
245 |
| - "metadata": {}, |
246 |
| - "output_type": "execute_result" |
247 |
| - } |
248 |
| - ], |
249 |
| - "source": [ |
250 |
| - "val opt = new Adam(0.1)\n", |
251 |
| - "val interpreter = new OptimizingInterpreter(opt)" |
252 |
| - ] |
253 |
| - }, |
254 |
| - { |
255 |
| - "cell_type": "code", |
256 |
| - "execution_count": 15, |
| 124 | + "execution_count": null, |
257 | 125 | "metadata": {},
|
258 | 126 | "outputs": [],
|
259 | 127 | "source": [
|
| 128 | + "val opt = new Adam(0.1)\n", |
| 129 | + "val interpreter = new OptimizingInterpreter(opt)\n", |
| 130 | + "\n", |
260 | 131 | "for { _ <- 0 until 10000 } {\n",
|
261 | 132 | " interpreter.reset()\n",
|
262 | 133 | " model.sample(interpreter)\n",
|
|
265 | 136 | },
|
266 | 137 | {
|
267 | 138 | "cell_type": "code",
|
268 |
| - "execution_count": 17, |
269 |
| - "metadata": { |
270 |
| - "collapsed": true |
271 |
| - }, |
272 |
| - "outputs": [ |
273 |
| - { |
274 |
| - "data": { |
275 |
| - "text/plain": [ |
276 |
| - "\u001b[36mparams\u001b[39m: \u001b[32mSeq\u001b[39m[(\u001b[32mString\u001b[39m, \u001b[32mExpr\u001b[39m[\u001b[32mDouble\u001b[39m, \u001b[32mUnit\u001b[39m])] = \u001b[33mList\u001b[39m(\n", |
277 |
| - " (\u001b[32m\"a_prior\"\u001b[39m, \u001b[33mApply1\u001b[39m(scappla.Param@7e13f371, <function1>)),\n", |
278 |
| - " (\u001b[32m\"b_prior\"\u001b[39m, \u001b[33mApply1\u001b[39m(scappla.Param@314da4f9, <function1>)),\n", |
279 |
| - " (\u001b[32m\"a_post_mu\"\u001b[39m, scappla.Param@33764907),\n", |
280 |
| - " (\u001b[32m\"a_post_s\"\u001b[39m, \u001b[33mApply1\u001b[39m(scappla.Param@3a272583, <function1>)),\n", |
281 |
| - " (\u001b[32m\"b_post_mu\"\u001b[39m, scappla.Param@5cbb302d),\n", |
282 |
| - " (\u001b[32m\"b_post_s\"\u001b[39m, \u001b[33mApply1\u001b[39m(scappla.Param@5d9b0110, <function1>)),\n", |
283 |
| - " (\u001b[32m\"noise_mu\"\u001b[39m, \u001b[33mApply1\u001b[39m(scappla.Param@33f65cad, <function1>)),\n", |
284 |
| - " (\u001b[32m\"noise_s\"\u001b[39m, \u001b[33mApply1\u001b[39m(scappla.Param@58927bd0, <function1>))\n", |
285 |
| - ")" |
286 |
| - ] |
287 |
| - }, |
288 |
| - "execution_count": 17, |
289 |
| - "metadata": {}, |
290 |
| - "output_type": "execute_result" |
291 |
| - } |
292 |
| - ], |
| 139 | + "execution_count": null, |
| 140 | + "metadata": {}, |
| 141 | + "outputs": [], |
293 | 142 | "source": [
|
294 | 143 | "val params = Seq(\n",
|
295 | 144 | " \"a_prior\" -> exp(a_prior_s),\n",
|
|
0 commit comments