{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/flax/blob/main/docs/guides/flax_fundamentals/flax_basics.ipynb)\n",
"[![Open On GitHub](https://img.shields.io/badge/Open-on%20GitHub-blue?logo=GitHub)](https://github.com/google/flax/blob/main/docs/guides/flax_fundamentals/flax_basics.ipynb)\n",
"\n",
"# Flax Basics\n",
"\n",
"This notebook will walk you through the following workflow:\n",
"\n",
"* Instantiating a model from Flax built-in layers or third-party models.\n",
"* Initializing the model's parameters and training it with a hand-written loop.\n",
"* Using optimizers provided by Optax to ease training.\n",
"* Serializing parameters and other objects.\n",
"* Creating your own models and managing state."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Setting up our environment\n",
"\n",
"Here we provide the code needed to set up the environment for our notebook."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"outputId": "e30aa464-fa52-4f35-df96-716c68a4b3ee",
"tags": [
"skip-execution"
]
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[33mWARNING: Running pip as root will break packages and permissions. You should install packages reliably by using venv: https://pip.pypa.io/warnings/venv\u001b[0m\n",
"\u001b[33mWARNING: Running pip as root will break packages and permissions. You should install packages reliably by using venv: https://pip.pypa.io/warnings/venv\u001b[0m\n"
]
}
],
"source": [
"# Install the latest versions of JAX and jaxlib.\n",
"!pip install --upgrade -q pip jax jaxlib\n",
"# Install Flax at head:\n",
"!pip install --upgrade -q git+https://github.com/google/flax.git"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"import jax\n",
"from typing import Any, Callable, Sequence\n",
"from jax import random, numpy as jnp\n",
"import flax\n",
"from flax import linen as nn"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Linear regression with Flax\n",
"\n",
"In the previous *JAX for the impatient* notebook, we finished with a linear regression example. Linear regression can also be written as a single dense neural network layer; we show this below so that the two approaches can be compared.\n",
"\n",
"A dense layer has a kernel parameter $W\\in\\mathcal{M}_{m,n}(\\mathbb{R})$, where $m$ is the number of output features and $n$ is the dimensionality of the input, and a bias parameter $b\\in\\mathbb{R}^m$. Given an input $x\\in\\mathbb{R}^n$, the dense layer returns $Wx+b$.\n",
"\n",
"This dense layer is already provided by Flax in the `flax.linen` module (here imported as `nn`)."
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"# Create one dense layer instance ('features' sets the output dimension).\n",
"model = nn.Dense(features=5)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Layers (and models in general, which is the term we'll use from now on) are subclasses of the `linen.Module` class.\n",
"\n",
"### Model parameters & initialization\n",
"\n",
"Parameters are not stored with the models themselves. You initialize them by calling the model's `init` method with a PRNG key and dummy input data."
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"outputId": "06feb9d2-db50-4f41-c169-6df4336f43a5"
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)\n"
]
},
{
"data": {
"text/plain": [
"FrozenDict({\n",
" params: {\n",
" bias: (5,),\n",
" kernel: (10, 5),\n",
" },\n",
"})"
]
},
"execution_count": 4,
"metadata": {
"tags": []
},
"output_type": "execute_result"
}
],
"source": [
"key1, key2 = random.split(random.key(0))\n",
"x = random.normal(key1, (10,)) # Dummy input data\n",
"params = model.init(key2, x) # Initialization call\n",
"jax.tree_util.tree_map(lambda x: x.shape, params) # Check the parameter shapes"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"*Note: JAX and Flax, like NumPy, are row-based systems, meaning that vectors are represented as row vectors and not column vectors. This is why the kernel here has shape `(10, 5)`, i.e. (input dimension, output dimension): the layer computes `x @ kernel + bias`.*\n",
"\n",
"The result is what we expect: bias and kernel parameters of the correct size. Under the hood:\n",
"\n",
"* The dummy input data `x` is used to trigger shape inference: we only declared the number of features we wanted in the output of the model, not the size of the input. Flax infers the correct size of the kernel on its own.\n",
"* The PRNG key is passed to the initialization functions (the module provides default initializers here).\n",
"* Initialization functions are called to generate the initial set of parameters that the model will use. They take `(PRNG key, shape, dtype)` as arguments and return an Array of shape `shape`.\n",
"* The `init` method returns the initialized set of parameters. (You can also get the output of the forward pass on the dummy input, with the same syntax, by using the `init_with_output` method instead of `init`.)"
]
},
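{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a quick illustration of `init_with_output`, here is a minimal sketch that initializes a fresh `Dense` layer and gets the forward-pass output together with the variables in a single call. The `demo_*` names are hypothetical, chosen only so that this cell does not overwrite the notebook's own `model`, `x` and `params`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import jax\n",
"from jax import random, numpy as jnp\n",
"from flax import linen as nn\n",
"\n",
"# Hypothetical demo_* names avoid clobbering the notebook's variables.\n",
"demo_model = nn.Dense(features=5)\n",
"demo_x = random.normal(random.key(1), (10,))\n",
"# init_with_output returns an (output, variables) pair.\n",
"demo_y, demo_variables = demo_model.init_with_output(random.key(2), demo_x)\n",
"print(demo_y.shape)\n",
"print(jax.tree_util.tree_map(jnp.shape, demo_variables))"
]
},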
{
"cell_type": "markdown",
"metadata": {},
"source": [
"To run a forward pass with a given set of parameters (which are never stored with the model), use the `apply` method, passing it both the parameters and the input:"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"outputId": "7bbe6bb4-94d5-4574-fbb5-aa0fcd1c84ae"
},
"outputs": [
{
"data": {
"text/plain": [
"DeviceArray([-0.7358944, 1.3583755, -0.7976872, 0.8168598, 0.6297793], dtype=float32)"
]
},
"execution_count": 6,
"metadata": {
"tags": []
},
"output_type": "execute_result"
}
],
"source": [
"model.apply(params, x)"
]
},
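{
"cell_type": "markdown",
"metadata": {},
"source": [
"To connect `apply` with the $Wx+b$ formula, the minimal sketch below recomputes the output of a `Dense` layer by hand. Because of the row-vector convention, the layer computes `x @ kernel + bias`, with a kernel of shape (input dimension, output dimension). The `demo_*` names are hypothetical and only used to keep this cell independent."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from jax import random, numpy as jnp\n",
"from flax import linen as nn\n",
"\n",
"demo_model = nn.Dense(features=5)\n",
"demo_x = random.normal(random.key(1), (10,))\n",
"demo_params = demo_model.init(random.key(2), demo_x)\n",
"\n",
"demo_kernel = demo_params['params']['kernel']  # shape (10, 5)\n",
"demo_bias = demo_params['params']['bias']      # shape (5,)\n",
"# Row-vector convention: the input multiplies the kernel from the left.\n",
"manual_out = demo_x @ demo_kernel + demo_bias\n",
"print(jnp.allclose(demo_model.apply(demo_params, demo_x), manual_out, atol=1e-5))"
]
},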
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Gradient descent\n",
"\n",
"If you jumped here directly without going through the JAX part, here is the linear regression formulation we're going to use: from a set of data points $\\{(x_i,y_i), i\\in \\{1,\\ldots, k\\}, x_i\\in\\mathbb{R}^n,y_i\\in\\mathbb{R}^m\\}$, we try to find a set of parameters $W\\in \\mathcal{M}_{m,n}(\\mathbb{R}), b\\in\\mathbb{R}^m$ such that the function $f_{W,b}(x)=Wx+b$ minimizes the mean squared error:\n",
"\n",
"$$\\mathcal{L}(W,b)=\\frac{1}{k}\\sum_{i=1}^{k} \\frac{1}{2}\\|y_i-f_{W,b}(x_i)\\|^2_2$$\n",
"\n",
"The tuple $(W,b)$ matches the parameters of the Dense layer, so we'll perform gradient descent on those. First, let's generate the synthetic data we'll use; it is exactly the same as in the linear regression pytree example from the JAX part."
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"outputId": "6eae59dc-0632-4f53-eac8-c22a7c646a52"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"x shape: (20, 10) ; y shape: (20, 5)\n"
]
}
],
"source": [
"# Set problem dimensions.\n",
"n_samples = 20\n",
"x_dim = 10\n",
"y_dim = 5\n",
"\n",
"# Generate random ground truth W and b.\n",
"key = random.key(0)\n",
"k1, k2 = random.split(key)\n",
"W = random.normal(k1, (x_dim, y_dim))\n",
"b = random.normal(k2, (y_dim,))\n",
"# Store the parameters in a FrozenDict pytree.\n",
"true_params = flax.core.freeze({'params': {'bias': b, 'kernel': W}})\n",
"\n",
"# Generate samples with additional noise.\n",
"key_sample, key_noise = random.split(k1)\n",
"x_samples = random.normal(key_sample, (n_samples, x_dim))\n",
"y_samples = jnp.dot(x_samples, W) + b + 0.1 * random.normal(key_noise,(n_samples, y_dim))\n",
"print('x shape:', x_samples.shape, '; y shape:', y_samples.shape)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We copy the same training loop that we used in the JAX pytree linear regression example with `jax.value_and_grad()`, but here we can use `model.apply()` instead of having to define our own feed-forward function (`predict_pytree()` in the [JAX example](https://flax.readthedocs.io/en/latest/guides/jax_for_the_impatient.html#linear-regression-with-pytrees))."
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"# Same as JAX version but using model.apply().\n",
"@jax.jit\n",
"def mse(params, x_batched, y_batched):\n",
" # Define the squared loss for a single pair (x,y)\n",
" def squared_error(x, y):\n",
" pred = model.apply(params, x)\n",
" return jnp.inner(y-pred, y-pred) / 2.0\n",
" # Vectorize the previous to compute the average of the loss on all samples.\n",
" return jnp.mean(jax.vmap(squared_error)(x_batched,y_batched), axis=0)"
]
},
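{
"cell_type": "markdown",
"metadata": {},
"source": [
"Since `nn.Dense` operates on the last axis of its input, the `vmap` over samples is not strictly required: the model can be applied to the whole batch at once and the loss formula written directly. Here is a minimal sketch comparing both versions on random data; `mse_batched`, `mse_vmapped` and the `demo_*` variables are hypothetical names."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import jax\n",
"from jax import random, numpy as jnp\n",
"from flax import linen as nn\n",
"\n",
"demo_model = nn.Dense(features=5)\n",
"demo_x = random.normal(random.key(1), (20, 10))\n",
"demo_y = random.normal(random.key(2), (20, 5))\n",
"demo_params = demo_model.init(random.key(3), demo_x)\n",
"\n",
"def mse_batched(params, x_batched, y_batched):\n",
"  pred = demo_model.apply(params, x_batched)  # shape (k, m)\n",
"  # (1/k) * sum_i (1/2) * ||y_i - f(x_i)||^2\n",
"  return jnp.mean(0.5 * jnp.sum((y_batched - pred) ** 2, axis=-1))\n",
"\n",
"def mse_vmapped(params, x_batched, y_batched):\n",
"  def squared_error(x, y):\n",
"    pred = demo_model.apply(params, x)\n",
"    return jnp.inner(y - pred, y - pred) / 2.0\n",
"  return jnp.mean(jax.vmap(squared_error)(x_batched, y_batched), axis=0)\n",
"\n",
"print(jnp.allclose(mse_batched(demo_params, demo_x, demo_y),\n",
"                   mse_vmapped(demo_params, demo_x, demo_y), rtol=1e-5))"
]
},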
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Finally, we perform gradient descent."
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"outputId": "50d975b3-4706-4d8a-c4b8-2629ab8e3ac4"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Loss for \"true\" W,b: 0.023639778\n",
"Loss step 0: 38.094772\n",
"Loss step 10: 0.44692168\n",
"Loss step 20: 0.10053458\n",
"Loss step 30: 0.035822745\n",
"Loss step 40: 0.018846875\n",
"Loss step 50: 0.013864839\n",
"Loss step 60: 0.012312559\n",
"Loss step 70: 0.011812928\n",
"Loss step 80: 0.011649306\n",
"Loss step 90: 0.011595251\n",
"Loss step 100: 0.0115773035\n"
]
}
],
"source": [
"learning_rate = 0.3 # Gradient step size.\n",
"print('Loss for \"true\" W,b: ', mse(true_params, x_samples, y_samples))\n",
"loss_grad_fn = jax.value_and_grad(mse)\n",
"\n",
"@jax.jit\n",
"def update_params(params, learning_rate, grads):\n",
" params = jax.tree_util.tree_map(\n",
" lambda p, g: p - learning_rate * g, params, grads)\n",
" return params\n",
"\n",
"for i in range(101):\n",
" # Perform one gradient update.\n",
" loss_val, grads = loss_grad_fn(params, x_samples, y_samples)\n",
" params = update_params(params, learning_rate, grads)\n",
" if i % 10 == 0:\n",
" print(f'Loss step {i}: ', loss_val)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Optimizing with Optax\n",
"\n",
"Flax previously used its own `flax.optim` package for optimization, but with\n",
"[FLIP #1009](https://github.com/google/flax/blob/main/docs/flip/1009-optimizer-api.md)\n",
"this was deprecated in favor of\n",
"[Optax](https://github.com/deepmind/optax).\n",
"\n",
"Basic usage of Optax is straightforward:\n",
"\n",
"1. Choose an optimization method (e.g. `optax.adam`).\n",
"2. Create optimizer state from parameters (for the Adam optimizer, this state will contain the [momentum values](https://optax.readthedocs.io/en/latest/api.html#optax.adam)).\n",
"3. Compute the gradients of your loss with `jax.value_and_grad()`.\n",
"4. At every iteration, call the Optax `update` function to update the internal\n",
" optimizer state and create an update to the parameters. Then add the update\n",
" to the parameters with Optax's `apply_updates` method.\n",
"\n",
"Note that Optax can do a lot more: it's designed for composing simple gradient\n",
"transformations into more complex ones, which makes it possible to implement a\n",
"wide range of optimizers. There is also support for changing optimizer\n",
"hyperparameters over time (\"schedules\"), applying different updates to different\n",
"parts of the parameter tree (\"masking\") and much more. For details please refer\n",
"to the\n",
"[official documentation](https://optax.readthedocs.io/en/latest/)."
]
},
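{
"cell_type": "markdown",
"metadata": {},
"source": [
"As an illustration of composing transformations and using a schedule, here is a minimal sketch that clips gradients by their global norm and then applies Adam with an exponentially decaying learning rate. The hyperparameter values and the toy `demo_*` pytrees are arbitrary assumptions, not recommendations."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import jax.numpy as jnp\n",
"import optax\n",
"\n",
"# Learning rate 0.3 * 0.9 ** (step / 10): decays by a factor 0.9 every 10 steps.\n",
"demo_schedule = optax.exponential_decay(init_value=0.3, transition_steps=10, decay_rate=0.9)\n",
"# optax.chain applies the transformations in order: clip first, then Adam.\n",
"demo_tx = optax.chain(\n",
"    optax.clip_by_global_norm(1.0),\n",
"    optax.adam(learning_rate=demo_schedule),\n",
")\n",
"\n",
"demo_params = {'w': jnp.ones(3)}\n",
"demo_opt_state = demo_tx.init(demo_params)\n",
"demo_grads = {'w': jnp.full(3, 10.0)}  # A large gradient that will be clipped.\n",
"demo_updates, demo_opt_state = demo_tx.update(demo_grads, demo_opt_state, demo_params)\n",
"demo_params = optax.apply_updates(demo_params, demo_updates)\n",
"print(demo_params)"
]
},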
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"import optax\n",
"tx = optax.adam(learning_rate=learning_rate)\n",
"opt_state = tx.init(params)\n",
"loss_grad_fn = jax.value_and_grad(mse)"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {
"outputId": "eec0c096-1d9e-4b3c-f8e5-942ee63828ec"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Loss step 0: 0.011576377\n",
"Loss step 10: 0.0115710115\n",
"Loss step 20: 0.011569244\n",
"Loss step 30: 0.011568661\n",
"Loss step 40: 0.011568454\n",
"Loss step 50: 0.011568379\n",
"Loss step 60: 0.011568358\n",
"Loss step 70: 0.01156836\n",
"Loss step 80: 0.01156835\n",
"Loss step 90: 0.011568353\n",
"Loss step 100: 0.011568348\n"
]
}
],
"source": [
"for i in range(101):\n",
" loss_val, grads = loss_grad_fn(params, x_samples, y_samples)\n",
" updates, opt_state = tx.update(grads, opt_state)\n",
" params = optax.apply_updates(params, updates)\n",
" if i % 10 == 0:\n",
" print('Loss step {}: '.format(i), loss_val)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Serializing the result\n",
"\n",
"Now that we're happy with the result of our training, we might want to save the model parameters to load them back later. Flax provides a serialization package, `flax.serialization`, for this."
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {
"outputId": "b97e7d83-3e40-4a80-b1fe-1f6ceff30a0c"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Dict output\n",
"{'params': {'bias': DeviceArray([-1.4540135, -2.0262308, 2.0806582, 1.2201802, -0.9964547], dtype=float32), 'kernel': DeviceArray([[ 1.0106664 , 0.19014716, 0.04533899, -0.92722285,\n",
" 0.34720102],\n",
" [ 1.7320251 , 0.9901233 , 1.1662225 , 1.1027892 ,\n",
" -0.10574618],\n",
" [-1.2009128 , 0.28837162, 1.4176372 , 0.12073109,\n",
" -1.3132601 ],\n",
" [-1.1944956 , -0.18993308, 0.03379077, 1.3165942 ,\n",
" 0.07996067],\n",
" [ 0.14103189, 1.3737966 , -1.3162128 , 0.53401774,\n",
" -2.239638 ],\n",
" [ 0.5643044 , 0.813604 , 0.31888172, 0.5359193 ,\n",
" 0.90352124],\n",
" [-0.37948322, 1.7408353 , 1.0788013 , -0.5041964 ,\n",
" 0.9286919 ],\n",
" [ 0.9701384 , -1.3158673 , 0.33630812, 0.80941117,\n",
" -1.202457 ],\n",
" [ 1.0198247 , -0.6198277 , 1.0822718 , -1.8385581 ,\n",
" -0.45790705],\n",
" [-0.64384323, 0.4564892 , -1.1331053 , -0.68556863,\n",
" 0.17010891]], dtype=float32)}}\n",
"Bytes output\n",
"b'\\x81\\xa6params\\x82\\xa4bias\\xc7!\\x01\\x93\\x91\\x05\\xa7float32\\xc4\\x14\\x1d\\x1d\\xba\\xbf\\xc4\\xad\\x01\\xc0\\x81)\\x05@\\xdd.\\x9c?\\xa8\\x17\\x7f\\xbf\\xa6kernel\\xc7\\xd6\\x01\\x93\\x92\\n\\x05\\xa7float32\\xc4\\xc8\\x84]\\x81?\\xf0\\xb5B>`\\xb59=z^m\\xbfU\\xc4\\xb1>\\x00\\xb3\\xdd?\\xb8x}?\\xc7F\\x95?2(\\x8d?t\\x91\\xd8\\xbd\\x83\\xb7\\x99\\xbfr\\xa5\\x93>#u\\xb5?\\xdcA\\xf7=\\xe8\\x18\\xa8\\xbf;\\xe5\\x98\\xbf\\xd1}B\\xbe0h\\n=)\\x86\\xa8?k\\xc2\\xa3=\\xaaj\\x10>\\x91\\xd8\\xaf?\\xa9y\\xa8\\xbfc\\xb5\\x08?;V\\x0f\\xc0Av\\x10?ZHP?wD\\xa3>\\x022\\t?+Mg?\\xa0K\\xc2\\xbe\\xb1\\xd3\\xde?)\\x16\\x8a?\\x04\\x13\\x01\\xbf\\xc1\\xbem?\\xfdZx?Wn\\xa8\\xbf\\x940\\xac>\\x925O?\\x1c\\xea\\x99\\xbf\\x9e\\x89\\x82?\\x07\\xad\\x1e\\xbf\\xe2\\x87\\x8a?\\xdfU\\xeb\\xbf\\xcbr\\xea\\xbe\\xe9\\xd2$\\xbf\\xf4\\xb8\\xe9>\\x98\\t\\x91\\xbfm\\x81/\\xbf\\x081.>'\n"
]
}
],
"source": [
"from flax import serialization\n",
"bytes_output = serialization.to_bytes(params)\n",
"dict_output = serialization.to_state_dict(params)\n",
"print('Dict output')\n",
"print(dict_output)\n",
"print('Bytes output')\n",
"print(bytes_output)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"To load the parameters back, you need a template of the parameter structure, such as the one returned by model initialization. Here, we use the previously generated `params` as the template. Note that this produces a new variable structure rather than mutating the template in place.\n",
"\n",
"*Enforcing the structure through a template helps avoid issues downstream, so you first need the right model to generate the parameter structure.*"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {
"outputId": "13acc4e1-8757-4554-e2c8-d594ba6e67dc"
},
"outputs": [
{
"data": {
"text/plain": [
"FrozenDict({\n",
" params: {\n",
" bias: array([-1.4540135, -2.0262308, 2.0806582, 1.2201802, -0.9964547],\n",
" dtype=float32),\n",
" kernel: array([[ 1.0106664 , 0.19014716, 0.04533899, -0.92722285, 0.34720102],\n",
" [ 1.7320251 , 0.9901233 , 1.1662225 , 1.1027892 , -0.10574618],\n",
" [-1.2009128 , 0.28837162, 1.4176372 , 0.12073109, -1.3132601 ],\n",
" [-1.1944956 , -0.18993308, 0.03379077, 1.3165942 , 0.07996067],\n",
" [ 0.14103189, 1.3737966 , -1.3162128 , 0.53401774, -2.239638 ],\n",
" [ 0.5643044 , 0.813604 , 0.31888172, 0.5359193 , 0.90352124],\n",
" [-0.37948322, 1.7408353 , 1.0788013 , -0.5041964 , 0.9286919 ],\n",
" [ 0.9701384 , -1.3158673 , 0.33630812, 0.80941117, -1.202457 ],\n",
" [ 1.0198247 , -0.6198277 , 1.0822718 , -1.8385581 , -0.45790705],\n",
" [-0.64384323, 0.4564892 , -1.1331053 , -0.68556863, 0.17010891]],\n",
" dtype=float32),\n",
" },\n",
"})"
]
},
"execution_count": 14,
"metadata": {
"tags": []
},
"output_type": "execute_result"
}
],
"source": [
"serialization.from_bytes(params, bytes_output)"
]
},
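{
"cell_type": "markdown",
"metadata": {},
"source": [
"Only the *structure* of the template matters: its values are replaced by the stored ones. The minimal sketch below checks this with a round trip through `to_bytes` and `from_bytes`, using a zero-filled template and hypothetical `demo_*` pytrees."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import jax\n",
"from jax import numpy as jnp\n",
"from flax import serialization\n",
"\n",
"demo_trained = {'params': {'bias': jnp.arange(2.0), 'kernel': jnp.ones((3, 2))}}\n",
"# Same tree structure and shapes, but all-zero values.\n",
"demo_template = jax.tree_util.tree_map(jnp.zeros_like, demo_trained)\n",
"\n",
"demo_bytes = serialization.to_bytes(demo_trained)\n",
"demo_restored = serialization.from_bytes(demo_template, demo_bytes)\n",
"\n",
"# Compare leaf by leaf: the restored values match the trained ones.\n",
"print(jax.tree_util.tree_all(jax.tree_util.tree_map(\n",
"    lambda a, b: bool(jnp.array_equal(a, b)), demo_trained, demo_restored)))"
]
},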
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Defining your own models\n",
"\n",
"Flax allows you to define your own models, which can be more complex than a linear regression. In this section, we'll show you how to build simple models. To do so, you'll need to create subclasses of the base `nn.Module` class.\n",
"\n",
"*Keep in mind that we imported* `linen as nn`, *and this only works with the Linen API.*"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Module basics\n",
"\n",
"The base abstraction for models is the `nn.Module` class, and every predefined layer type in Flax (like the previous `Dense`) is a subclass of `nn.Module`. Let's start by defining a simple custom multi-layer perceptron, i.e. a sequence of Dense layers interleaved with calls to a non-linear activation function."
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {
"outputId": "b59c679c-d164-4fd6-92db-b50f0d310ec3"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"initialized parameter shapes:\n",
" {'params': {'layers_0': {'bias': (3,), 'kernel': (4, 3)}, 'layers_1': {'bias': (4,), 'kernel': (3, 4)}, 'layers_2': {'bias': (5,), 'kernel': (4, 5)}}}\n",
"output:\n",
" [[ 4.2292815e-02 -4.3807115e-02 2.9323792e-02 6.5492536e-03\n",
" -1.7147182e-02]\n",
" [ 1.2967806e-01 -1.4551792e-01 9.4432183e-02 1.2521387e-02\n",
" -4.5417298e-02]\n",
" [ 0.0000000e+00 0.0000000e+00 0.0000000e+00 0.0000000e+00\n",
" 0.0000000e+00]\n",
" [ 9.3024032e-04 2.7864395e-05 2.4478821e-04 8.1344310e-04\n",
" -1.0110770e-03]]\n"
]
}
],
"source": [
"class ExplicitMLP(nn.Module):\n",
" features: Sequence[int]\n",
"\n",
" def setup(self):\n",
" # Flax automatically registers submodules stored in lists and dicts\n",
" self.layers = [nn.Dense(feat) for feat in self.features]\n",
" # for single submodules, we would just write:\n",
" # self.layer1 = nn.Dense(feat1)\n",
"\n",
" def __call__(self, inputs):\n",
" x = inputs\n",
" for i, lyr in enumerate(self.layers):\n",
" x = lyr(x)\n",
" if i != len(self.layers) - 1:\n",
" x = nn.relu(x)\n",
" return x\n",
"\n",
"key1, key2 = random.split(random.key(0), 2)\n",
"x = random.uniform(key1, (4,4))\n",
"\n",
"model = ExplicitMLP(features=[3,4,5])\n",
"params = model.init(key2, x)\n",
"y = model.apply(params, x)\n",
"\n",
"print('initialized parameter shapes:\\n', jax.tree_util.tree_map(jnp.shape, flax.core.unfreeze(params)))\n",
"print('output:\\n', y)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"As we can see, a `nn.Module` subclass is made of:\n",
"\n",
"* A collection of data fields (`nn.Module` subclasses are Python dataclasses); here we only have the `features` field of type `Sequence[int]`.\n",
"* A `setup()` method, called lazily once the module is bound (for example inside `init` or `apply`), where you can register the submodules, variables, and parameters your model needs.\n",
"* A `__call__` method that returns the output of the model for a given input.\n",
"* The model structure defines a pytree of parameters following the same tree structure as the model: the params tree contains one `layers_n` sub dict per layer, and each of those contains the parameters of the associated Dense layer. The layout is very explicit.\n",
"\n",
"*Note: lists are mostly handled as you would expect (this is still a work in progress), but there are corner cases you should be aware of, as pointed out* [here](https://github.com/google/flax/issues/524)\n",
"\n",
"Since the module structure and its parameters are not tied to each other, you can't directly call `model(x)` on a given input: doing so raises an error. Instead, `__call__` is invoked through `apply`, which is the method to call on an input:"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {
"outputId": "4af16ec5-b52a-43b0-fc47-1f8ab25e7058"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\"ExplicitMLP\" object has no attribute \"layers\"\n"
]
}
],
"source": [
"try:\n",
" y = model(x) # Raises an error\n",
"except AttributeError as e:\n",
" print(e)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Since this is a very simple model, we could instead declare the submodules inline in `__call__` using the `@nn.compact` decorator, an alternative but equivalent approach:"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {
"outputId": "183a74ef-f54e-4848-99bf-fee4c174ba6d"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"initialized parameter shapes:\n",
" {'params': {'layers_0': {'bias': (3,), 'kernel': (4, 3)}, 'layers_1': {'bias': (4,), 'kernel': (3, 4)}, 'layers_2': {'bias': (5,), 'kernel': (4, 5)}}}\n",
"output:\n",
" [[ 4.2292815e-02 -4.3807115e-02 2.9323792e-02 6.5492536e-03\n",
" -1.7147182e-02]\n",
" [ 1.2967806e-01 -1.4551792e-01 9.4432183e-02 1.2521387e-02\n",
" -4.5417298e-02]\n",
" [ 0.0000000e+00 0.0000000e+00 0.0000000e+00 0.0000000e+00\n",
" 0.0000000e+00]\n",
" [ 9.3024032e-04 2.7864395e-05 2.4478821e-04 8.1344310e-04\n",
" -1.0110770e-03]]\n"
]
}
],
"source": [
"class SimpleMLP(nn.Module):\n",
" features: Sequence[int]\n",
"\n",
" @nn.compact\n",
" def __call__(self, inputs):\n",
" x = inputs\n",
" for i, feat in enumerate(self.features):\n",
" x = nn.Dense(feat, name=f'layers_{i}')(x)\n",
" if i != len(self.features) - 1:\n",
" x = nn.relu(x)\n",
" # providing a name is optional though!\n",
" # the default autonames would be \"Dense_0\", \"Dense_1\", ...\n",
" return x\n",
"\n",
"key1, key2 = random.split(random.key(0), 2)\n",
"x = random.uniform(key1, (4,4))\n",
"\n",
"model = SimpleMLP(features=[3,4,5])\n",
"params = model.init(key2, x)\n",
"y = model.apply(params, x)\n",
"\n",
"print('initialized parameter shapes:\\n', jax.tree_util.tree_map(jnp.shape, flax.core.unfreeze(params)))\n",
"print('output:\\n', y)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"There are, however, a few differences you should be aware of between the two declaration modes:\n",
"\n",
"* In `setup`, you are able to name some sublayers and keep them around for further use (e.g. encoder/decoder methods in autoencoders).\n",
"* If you want to have multiple methods, then you **need** to declare the module using `setup`, as the `@nn.compact` decorator can only be applied to one method.\n",
"* The last initialization will be handled differently. See these notes for more details (TODO: add notes link)."
]
},
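{
"cell_type": "markdown",
"metadata": {},
"source": [
"To illustrate the point about multiple methods, here is a minimal sketch of an autoencoder declared with `setup`: the encoder and decoder submodules are named once and reused from several methods, and the `method` argument of `apply` selects which method to run. `DemoAutoEncoder` and its field names are hypothetical."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from jax import random, numpy as jnp\n",
"from flax import linen as nn\n",
"\n",
"class DemoAutoEncoder(nn.Module):\n",
"  latent_dim: int\n",
"  out_dim: int\n",
"\n",
"  def setup(self):\n",
"    # Named submodules, reusable from any method.\n",
"    self.encoder = nn.Dense(self.latent_dim)\n",
"    self.decoder = nn.Dense(self.out_dim)\n",
"\n",
"  def encode(self, x):\n",
"    return self.encoder(x)\n",
"\n",
"  def decode(self, z):\n",
"    return self.decoder(z)\n",
"\n",
"  def __call__(self, x):\n",
"    return self.decode(self.encode(x))\n",
"\n",
"ae_model = DemoAutoEncoder(latent_dim=2, out_dim=4)\n",
"ae_x = jnp.ones((3, 4))\n",
"ae_params = ae_model.init(random.key(0), ae_x)  # Runs __call__, creating both submodules.\n",
"ae_z = ae_model.apply(ae_params, ae_x, method=DemoAutoEncoder.encode)\n",
"print(ae_z.shape)  # (3, 2)"
]
},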
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Module parameters\n",
"\n",
"In the previous MLP example, we relied only on predefined layers and operators (`Dense`, `relu`). Let's imagine that you didn't have a Dense layer provided by Flax and you wanted to write it on your own. Here is what it would look like using the `@nn.compact` way to declare a new module:"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {
"outputId": "83b5fea4-071e-4ea0-8fa8-610e69fb5fd5"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"initialized parameters:\n",
" FrozenDict({\n",
" params: {\n",
" kernel: DeviceArray([[ 0.6503669 , 0.86789787, 0.4604268 ],\n",
" [ 0.05673932, 0.9909285 , -0.63536596],\n",
" [ 0.76134115, -0.3250529 , -0.65221626],\n",
" [-0.82430327, 0.4150194 , 0.19405058]], dtype=float32),\n",
" bias: DeviceArray([0., 0., 0.], dtype=float32),\n",
" },\n",
"})\n",
"output:\n",
" [[ 0.5035518 1.8548558 -0.4270195 ]\n",
" [ 0.0279097 0.5589246 -0.43061772]\n",
" [ 0.3547128 1.5740999 -0.32865518]\n",
" [ 0.5264864 1.2928858 0.10089308]]\n"
]
}
],
"source": [
"class SimpleDense(nn.Module):\n",
" features: int\n",
" kernel_init: Callable = nn.initializers.lecun_normal()\n",
" bias_init: Callable = nn.initializers.zeros_init()\n",
"\n",
" @nn.compact\n",
" def __call__(self, inputs):\n",
" kernel = self.param('kernel',\n",
" self.kernel_init, # Initialization function\n",
" (inputs.shape[-1], self.features)) # shape info.\n",
" y = jnp.dot(inputs, kernel)\n",
" bias = self.param('bias', self.bias_init, (self.features,))\n",
" y = y + bias\n",
" return y\n",
"\n",
"key1, key2 = random.split(random.key(0), 2)\n",
"x = random.uniform(key1, (4,4))\n",
"\n",
"model = SimpleDense(features=3)\n",
"params = model.init(key2, x)\n",
"y = model.apply(params, x)\n",
"\n",
"print('initialized parameters:\\n', params)\n",
"print('output:\\n', y)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Here, we see how to both declare and assign a parameter to the model using the `self.param` method. It takes as input `(name, init_fn, *init_args, **init_kwargs)`:\n",
"\n",
"* `name` is simply the name of the parameter that will end up in the parameter structure.\n",
"* `init_fn` is a function with input `(PRNGKey, *init_args, **init_kwargs)` returning an Array.\n",
"* `init_args` and `init_kwargs` are the arguments to provide to the initialization function.\n",
"\n",
"Such params can also be declared in the `setup` method; however, they can't use shape inference there, since `setup` does not see the inputs (in `@nn.compact` methods, Flax initializes parameters lazily at the first call site, where the input shape is known)."
]
},
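{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a sketch of declaring parameters in `setup`, the hypothetical `DemoSetupDense` below re-implements the dense layer without shape inference: because `setup` never sees the inputs, the input dimension has to be passed explicitly as a field."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from jax import random, numpy as jnp\n",
"from flax import linen as nn\n",
"\n",
"class DemoSetupDense(nn.Module):\n",
"  in_features: int  # Needed explicitly: setup() has no access to the inputs.\n",
"  features: int\n",
"\n",
"  def setup(self):\n",
"    # self.param(name, init_fn, *init_args): here init_args is the shape.\n",
"    self.kernel = self.param('kernel', nn.initializers.lecun_normal(),\n",
"                             (self.in_features, self.features))\n",
"    self.bias = self.param('bias', nn.initializers.zeros_init(), (self.features,))\n",
"\n",
"  def __call__(self, inputs):\n",
"    return jnp.dot(inputs, self.kernel) + self.bias\n",
"\n",
"sd_model = DemoSetupDense(in_features=4, features=3)\n",
"sd_x = jnp.ones((2, 4))\n",
"sd_params = sd_model.init(random.key(0), sd_x)\n",
"print(sd_model.apply(sd_params, sd_x).shape)  # (2, 3)"
]
},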
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Variables and collections of variables\n",
"\n",
"As we've seen so far, working with models means working with:\n",
"\n",
"* A subclass of `nn.Module`;\n",
"* A pytree of parameters for the model (typically from `model.init()`);\n",
"\n",
"However, this is not enough to cover everything we need for machine learning, especially for neural networks. In some cases, you might want your neural network to keep track of some internal state while it runs (e.g. batch normalization layers). You can declare variables beyond the parameters of the model with the `variable` method.\n",
"\n",
"For demonstration purposes, we'll implement a simplified mechanism similar to batch normalization: we'll store running averages and subtract them from the input at training time. For proper batchnorm, you should use (and look at) the implementation [here](https://github.com/google/flax/blob/main/flax/linen/normalization.py)."
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {
"outputId": "75465fd6-cdc8-497c-a3ec-7f709b5dde7a"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"initialized variables:\n",
" FrozenDict({\n",
" batch_stats: {\n",
" mean: DeviceArray([0., 0., 0., 0., 0.], dtype=float32),\n",
" },\n",
" params: {\n",
" bias: DeviceArray([0., 0., 0., 0., 0.], dtype=float32),\n",
" },\n",
"})\n",
"updated state:\n",
" FrozenDict({\n",
" batch_stats: {\n",
" mean: DeviceArray([[0.01, 0.01, 0.01, 0.01, 0.01]], dtype=float32),\n",
" },\n",
"})\n"
]
}
],
"source": [
"class BiasAdderWithRunningMean(nn.Module):\n",
" decay: float = 0.99\n",
"\n",
" @nn.compact\n",
" def __call__(self, x):\n",
" # easy pattern to detect if we're initializing via empty variable tree\n",
" is_initialized = self.has_variable('batch_stats', 'mean')\n",
" ra_mean = self.variable('batch_stats', 'mean',\n",
" lambda s: jnp.zeros(s),\n",
" x.shape[1:])\n",
" bias = self.param('bias', lambda rng, shape: jnp.zeros(shape), x.shape[1:])\n",
" if is_initialized:\n",
" ra_mean.value = self.decay * ra_mean.value + (1.0 - self.decay) * jnp.mean(x, axis=0, keepdims=True)\n",
"\n",
" return x - ra_mean.value + bias\n",
"\n",
"\n",
"key1, key2 = random.split(random.key(0), 2)\n",
"x = jnp.ones((10,5))\n",
"model = BiasAdderWithRunningMean()\n",
"variables = model.init(key1, x)\n",
"print('initialized variables:\\n', variables)\n",
"y, updated_state = model.apply(variables, x, mutable=['batch_stats'])\n",
"print('updated state:\\n', updated_state)"
]
},
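{
"cell_type": "markdown",
"metadata": {},
"source": [
"The `mutable` argument is what allows `apply` to write to a collection. Here is a minimal, self-contained sketch. The toy `CallCounter` module and its `'counter'` collection are hypothetical and exist only for illustration. Listing the collection in `mutable` returns the updated state alongside the output. Omitting it leaves the collection read-only, so assigning to the variable's `.value` fails. We assume Flax reports this with `flax.errors.ModifyScopeVariableError`:\n",
"\n",
"```python\n",
"import jax\n",
"import jax.numpy as jnp\n",
"import flax.linen as nn\n",
"from flax.errors import ModifyScopeVariableError\n",
"\n",
"\n",
"class CallCounter(nn.Module):\n",
"  # Hypothetical toy module that counts forward passes in a 'counter' collection.\n",
"  @nn.compact\n",
"  def __call__(self, x):\n",
"    # Same trick as above: 'count' does not exist yet while initializing.\n",
"    is_initialized = self.has_variable('counter', 'count')\n",
"    count = self.variable('counter', 'count', lambda: jnp.zeros((), jnp.int32))\n",
"    if is_initialized:\n",
"      count.value += 1\n",
"    return x\n",
"\n",
"\n",
"model = CallCounter()\n",
"x = jnp.ones((2, 3))\n",
"variables = model.init(jax.random.key(0), x)  # only a 'counter' collection, count == 0\n",
"\n",
"# Writable collection: the new count comes back as the second return value.\n",
"y, updated_state = model.apply(variables, x, mutable=['counter'])\n",
"print(updated_state['counter']['count'])  # 1\n",
"\n",
"# Read-only collection: writing count.value raises instead of silently dropping the update.\n",
"raised = False\n",
"try:\n",
"  model.apply(variables, x)\n",
"except ModifyScopeVariableError:\n",
"  raised = True\n",
"print('raised:', raised)  # raised: True\n",
"```"
]
},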
{
"cell_type": "markdown",
"metadata": {},
"source": [
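"In the `BiasAdderWithRunningMean` example, `updated_state` contains only the state variables that the model mutated while being applied to the data (here, the `batch_stats` collection). The parameters are not included. To merge the updated state back into the variables, so that the next call to `apply` starts from the new running mean, we can use the following pattern:"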
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {
"outputId": "09a8bdd1-eaf8-401a-cf7c-386a7a5aa87b"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"updated state:\n",
" FrozenDict({\n",
" batch_stats: {\n",
" mean: DeviceArray([[0.01, 0.01, 0.01, 0.01, 0.01]], dtype=float32),\n",
" },\n",
"})\n",
"updated state:\n",
" FrozenDict({\n",
" batch_stats: {\n",
" mean: DeviceArray([[0.0299, 0.0299, 0.0299, 0.0299, 0.0299]], dtype=float32),\n",
" },\n",
"})\n",
"updated state:\n",
" FrozenDict({\n",
" batch_stats: {\n",
" mean: DeviceArray([[0.059601, 0.059601, 0.059601, 0.059601, 0.059601]], dtype=float32),\n",
" },\n",
"})\n"
]
}
],
"source": [
"for val in [1.0, 2.0, 3.0]:\n",
" x = val * jnp.ones((10,5))\n",
" y, updated_state = model.apply(variables, x, mutable=['batch_stats'])\n",
" old_state, params = flax.core.pop(variables, 'params')\n",
" variables = flax.core.freeze({'params': params, **updated_state})\n",
" print('updated state:\\n', updated_state) # Shows only the mutable part"
]
},
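{
"cell_type": "markdown",
"metadata": {},
"source": [
"The printed means follow directly from the update rule in `BiasAdderWithRunningMean`, `mean = decay * mean + (1 - decay) * batch_mean` with `decay = 0.99`. Every input row is the constant `val`, so the batch mean is simply `val` for each feature, and the recurrence can be replayed by hand. The following plain-Python sketch (no JAX or Flax involved) reproduces the three values printed above:\n",
"\n",
"```python\n",
"decay = 0.99\n",
"mean = 0.0  # running mean right after model.init()\n",
"history = []\n",
"for val in [1.0, 2.0, 3.0]:\n",
"  # The batch mean of val * ones((10, 5)) is just val along every feature.\n",
"  mean = decay * mean + (1.0 - decay) * val\n",
"  history.append(round(mean, 6))\n",
"print(history)  # [0.01, 0.0299, 0.059601]\n",
"```"
]
},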
{
"cell_type": "markdown",
"metadata": {},
"source": [
"From this simplified example, you should be able to derive a full BatchNorm implementation, or any other layer that carries state. To finish, let's add an optimizer to see how parameters updated by an optimizer and state variables updated by the model can be handled together.\n",
"\n",
"*This example doesn't learn anything meaningful and is only for demonstration purposes.*"
]
},
{
"cell_type": "code",
"execution_count": 29,
"metadata": {
"outputId": "0906fbab-b866-4956-d231-b1374415d448"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Updated state: FrozenDict({\n",
" batch_stats: {\n",
" mean: DeviceArray([[0.01, 0.01, 0.01, 0.01, 0.01]], dtype=float32),\n",
" },\n",
"})\n",
"Updated state: FrozenDict({\n",
" batch_stats: {\n",
" mean: DeviceArray([[0.0199, 0.0199, 0.0199, 0.0199, 0.0199]], dtype=float32),\n",
" },\n",
"})\n",
"Updated state: FrozenDict({\n",
" batch_stats: {\n",
" mean: DeviceArray([[0.029701, 0.029701, 0.029701, 0.029701, 0.029701]], dtype=float32),\n",
" },\n",
"})\n"
]
}
],
"source": [
"from functools import partial\n",
"\n",
"@partial(jax.jit, static_argnums=(0, 1))\n",
"def update_step(tx, apply_fn, x, opt_state, params, state):\n",
"\n",
" def loss(params):\n",
" y, updated_state = apply_fn({'params': params, **state},\n",
" x, mutable=list(state.keys()))\n",
" l = ((x - y) ** 2).sum()\n",
" return l, updated_state\n",
"\n",
" (l, state), grads = jax.value_and_grad(loss, has_aux=True)(params)\n",
" updates, opt_state = tx.update(grads, opt_state)\n",
" params = optax.apply_updates(params, updates)\n",
" return opt_state, params, state\n",
"\n",
"x = jnp.ones((10,5))\n",
"variables = model.init(random.key(0), x)\n",
"state, params = flax.core.pop(variables, 'params')\n",
"del variables\n",
"tx = optax.sgd(learning_rate=0.02)\n",
"opt_state = tx.init(params)\n",
"\n",
"for _ in range(3):\n",
" opt_state, params, state = update_step(tx, model.apply, x, opt_state, params, state)\n",
" print('Updated state: ', state)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Note that the function above has a quite verbose signature. It only works with `jax.jit()` because `tx` and `apply_fn`, which are not valid JAX types, are marked as static through `static_argnums=(0, 1)`.\n",
"\n",
"Flax provides a handy wrapper, `TrainState`, that simplifies the code above. Check out [`flax.training.train_state.TrainState`](https://flax.readthedocs.io/en/latest/api_reference/flax.training.html#flax.training.train_state.TrainState) to learn more."
]
},
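{
"cell_type": "markdown",
"metadata": {},
"source": [
"Here is a minimal sketch of how the optimizer example above could be written with `TrainState`. `TrainStateWithStats` and its `batch_stats` field are our own hypothetical names, not part of Flax. We subclass `TrainState` so the running statistics travel in the same pytree as `params` and `opt_state`. `TrainState` stores `apply_fn` and `tx` as static (non-pytree) fields, so the whole state object can be passed to `jax.jit` without `static_argnums`. The module from earlier is repeated so the sketch runs on its own:\n",
"\n",
"```python\n",
"from typing import Any\n",
"\n",
"import jax\n",
"import jax.numpy as jnp\n",
"import flax.linen as nn\n",
"import optax\n",
"from flax.training import train_state\n",
"\n",
"\n",
"class BiasAdderWithRunningMean(nn.Module):\n",
"  # Same toy module as earlier in this notebook, repeated so the sketch runs on its own.\n",
"  decay: float = 0.99\n",
"\n",
"  @nn.compact\n",
"  def __call__(self, x):\n",
"    is_initialized = self.has_variable('batch_stats', 'mean')\n",
"    ra_mean = self.variable('batch_stats', 'mean', lambda s: jnp.zeros(s), x.shape[1:])\n",
"    bias = self.param('bias', lambda rng, shape: jnp.zeros(shape), x.shape[1:])\n",
"    if is_initialized:\n",
"      ra_mean.value = self.decay * ra_mean.value + (1.0 - self.decay) * jnp.mean(x, axis=0, keepdims=True)\n",
"    return x - ra_mean.value + bias\n",
"\n",
"\n",
"class TrainStateWithStats(train_state.TrainState):\n",
"  # Hypothetical extra pytree field holding the non-trainable 'batch_stats' collection.\n",
"  batch_stats: Any\n",
"\n",
"\n",
"@jax.jit  # no static_argnums: apply_fn and tx are static fields of the state\n",
"def update_step(state, x):\n",
"  def loss(params):\n",
"    y, updates = state.apply_fn(\n",
"        {'params': params, 'batch_stats': state.batch_stats}, x, mutable=['batch_stats'])\n",
"    return ((x - y) ** 2).sum(), updates\n",
"\n",
"  (_, updates), grads = jax.value_and_grad(loss, has_aux=True)(state.params)\n",
"  state = state.apply_gradients(grads=grads)  # SGD update, step += 1\n",
"  return state.replace(batch_stats=updates['batch_stats'])  # new running mean\n",
"\n",
"\n",
"model = BiasAdderWithRunningMean()\n",
"x = jnp.ones((10, 5))\n",
"variables = model.init(jax.random.key(0), x)\n",
"state = TrainStateWithStats.create(\n",
"    apply_fn=model.apply,\n",
"    params=variables['params'],\n",
"    tx=optax.sgd(learning_rate=0.02),\n",
"    batch_stats=variables['batch_stats'],\n",
")\n",
"\n",
"for _ in range(3):\n",
"  state = update_step(state, x)\n",
"  print(int(state.step), state.batch_stats['mean'])\n",
"```\n",
"\n",
"Each call to `update_step` returns a new, immutable state object. The loop therefore rebinds a single `state` instead of threading `opt_state`, `params` and `state` separately."
]
},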
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Exporting to TensorFlow's SavedModel with jax2tf\n",
"\n",
"JAX released an experimental converter called [jax2tf](https://github.com/jax-ml/jax/tree/main/jax/experimental/jax2tf), which converts trained Flax models into TensorFlow's SavedModel format so they can be used with [TF Hub](https://www.tensorflow.org/hub), [TF Lite](https://www.tensorflow.org/lite), [TF.js](https://www.tensorflow.org/js), or other downstream applications. The repository contains more documentation and various examples for Flax."
]
}
],
"metadata": {
"jupytext": {
"formats": "ipynb,md:myst"
},
"language_info": {
"name": "python",
"version": "3.8.15"
}
},
"nbformat": 4,
"nbformat_minor": 0
}