-
Notifications
You must be signed in to change notification settings - Fork 204
/
Copy pathalias.py
2081 lines (1813 loc) · 80.9 KB
/
alias.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# Copyright 2019 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Aliases for popular optimizers."""
import functools
from typing import Any, Callable, Optional, Union
import jax.numpy as jnp
from optax._src import base
from optax._src import clipping
from optax._src import combine
from optax._src import factorized
from optax._src import linesearch as _linesearch
from optax._src import transform
from optax._src import wrappers
# Annotation used for the `weight_decay_mask` arguments below: either a mask
# object, a callable mapping the params to such an object, or None.
MaskOrFn = Optional[Union[Any, Callable[[base.Params], Any]]]
def adabelief(
    learning_rate: base.ScalarOrSchedule,
    b1: float = 0.9,
    b2: float = 0.999,
    eps: float = 1e-16,
    eps_root: float = 1e-16,
) -> base.GradientTransformation:
  r"""The AdaBelief optimizer.

  AdaBelief is an adaptive learning rate optimizer that aims at fast
  convergence, good generalization and stable training. The step size is
  adapted according to the optimizer's "belief" in the gradient direction:
  updates are scaled by the gap between the predicted and the observed
  gradients. AdaBelief is a modification of :func:`optax.adam` and has the same
  number of parameters.

  Denote by :math:`\alpha_t` the learning rate and by :math:`\beta_1, \beta_2`,
  :math:`\varepsilon`, :math:`\bar{\varepsilon}` the arguments ``b1``, ``b2``,
  ``eps`` and ``eps_root`` respectively. The learning rate carries the index
  :math:`t` because it may be produced by a schedule function.

  The ``init`` function of this optimizer builds an internal state
  :math:`S_0 := (m_0, s_0) = (0, 0)` holding the initial estimates of the first
  and second moments. In practice these are pytrees filled with zeros that have
  the same shape as the model updates.

  At step :math:`t`, the ``update`` function takes the incoming gradients
  :math:`g_t` and the optimizer state :math:`S_t`, and returns the updates
  :math:`u_t` along with the new state :math:`S_{t+1}`. For :math:`t > 0`,

  .. math::

    \begin{align*}
      m_t &\leftarrow \beta_1 \cdot m_{t-1} + (1-\beta_1) \cdot g_t \\
      s_t &\leftarrow \beta_2 \cdot s_{t-1} + (1-\beta_2) \cdot (g_t - m_t)^2
      + \bar{\varepsilon} \\
      \hat{m}_t &\leftarrow m_t / {(1-\beta_1^t)} \\
      \hat{s}_t &\leftarrow s_t / {(1-\beta_2^t)} \\
      u_t &\leftarrow -\alpha_t \cdot \hat{m}_t / \left(\sqrt{\hat{s}_t}
      + \varepsilon \right) \\
      S_t &\leftarrow (m_t, s_t).
    \end{align*}

  Examples:
    >>> import optax
    >>> import jax
    >>> import jax.numpy as jnp
    >>> def f(x): return jnp.sum(x ** 2)  # simple quadratic function
    >>> solver = optax.adabelief(learning_rate=0.003)
    >>> params = jnp.array([1., 2., 3.])
    >>> print('Objective function: ', f(params))
    Objective function:  14.0
    >>> opt_state = solver.init(params)
    >>> for _ in range(5):
    ...   grad = jax.grad(f)(params)
    ...   updates, opt_state = solver.update(grad, opt_state, params)
    ...   params = optax.apply_updates(params, updates)
    ...   print('Objective function: {:.2E}'.format(f(params)))
    Objective function: 1.40E+01
    Objective function: 1.39E+01
    Objective function: 1.39E+01
    Objective function: 1.38E+01
    Objective function: 1.38E+01

  References:
    Zhuang et al, 2020: https://arxiv.org/abs/2010.07468

  Args:
    learning_rate: A global scaling factor, either fixed or evolving along
      iterations with a scheduler, see :func:`optax.scale_by_learning_rate`.
    b1: Exponential decay rate used to track the first moment of past
      gradients.
    b2: Exponential decay rate used to track the second moment of past
      gradients.
    eps: Term added to the denominator to improve numerical stability.
    eps_root: Term added to the second moment of the prediction error to
      improve numerical stability. It must be non-zero when backpropagating
      gradients through the gradient transformation (e.g. for meta-learning).

  Returns:
    The corresponding `GradientTransformation`.
  """
  # Stage 1 rescales the gradients with the belief-based moment estimates,
  # stage 2 applies the (possibly scheduled) learning rate.
  belief_scaling = transform.scale_by_belief(
      b1=b1, b2=b2, eps=eps, eps_root=eps_root)
  lr_scaling = transform.scale_by_learning_rate(learning_rate)
  return combine.chain(belief_scaling, lr_scaling)
def adadelta(
    learning_rate: Optional[base.ScalarOrSchedule] = None,
    rho: float = 0.9,
    eps: float = 1e-6,
    weight_decay: float = 0.0,
    weight_decay_mask: MaskOrFn = None,
) -> base.GradientTransformation:
  """The Adadelta optimizer.

  Adadelta is a stochastic gradient descent variant whose learning rates are
  adapted from a moving window of past gradient updates. It is a modification
  of Adagrad.

  Examples:
    >>> import optax
    >>> import jax
    >>> import jax.numpy as jnp
    >>> f = lambda x: jnp.sum(x ** 2)  # simple quadratic function
    >>> solver = optax.adadelta(learning_rate=10.)
    >>> params = jnp.array([1., 2., 3.])
    >>> print('Objective function: ', f(params))
    Objective function:  14.0
    >>> opt_state = solver.init(params)
    >>> for _ in range(5):
    ...   grad = jax.grad(f)(params)
    ...   updates, opt_state = solver.update(grad, opt_state, params)
    ...   params = optax.apply_updates(params, updates)
    ...   print('Objective function: {:.2E}'.format(f(params)))
    Objective function: 1.36E+01
    Objective function: 1.32E+01
    Objective function: 1.29E+01
    Objective function: 1.25E+01
    Objective function: 1.21E+01

  References:
    Zeiler, 2012: https://arxiv.org/pdf/1212.5701.pdf

  Args:
    learning_rate: A global scaling factor, either fixed or evolving along
      iterations with a scheduler, see :func:`optax.scale_by_learning_rate`.
    rho: Coefficient of the running average of squared gradients.
    eps: Term added to the denominator to improve numerical stability.
    weight_decay: Optional rate at which to decay weights.
    weight_decay_mask: A tree with the same structure as (or a prefix of) the
      params PyTree, or a Callable returning such a pytree given the
      params/updates. Leaves should be booleans: `True` for leaves/subtrees
      the transformation is applied to, `False` for those to skip.

  Returns:
    The corresponding `GradientTransformation`.
  """
  # NOTE(review): `learning_rate` defaults to None yet is forwarded
  # unconditionally (adafactor, by contrast, skips this stage when it is
  # None) -- confirm that `scale_by_learning_rate` accepts None.
  stages = (
      transform.add_decayed_weights(weight_decay, mask=weight_decay_mask),
      transform.scale_by_adadelta(rho=rho, eps=eps),
      transform.scale_by_learning_rate(learning_rate),
  )
  return combine.chain(*stages)
def adafactor(
    learning_rate: Optional[base.ScalarOrSchedule] = None,
    min_dim_size_to_factor: int = 128,
    decay_rate: float = 0.8,
    decay_offset: int = 0,
    multiply_by_parameter_scale: bool = True,
    clipping_threshold: Optional[float] = 1.0,
    momentum: Optional[float] = None,
    dtype_momentum: Any = jnp.float32,
    weight_decay_rate: Optional[float] = None,
    eps: float = 1e-30,
    factored: bool = True,
    weight_decay_mask: MaskOrFn = None,
) -> base.GradientTransformation:
  """The Adafactor optimizer.

  Adafactor is an adaptive learning rate optimizer that focuses on fast
  training of large scale neural networks. It saves memory by using a factored
  estimate of the second order moments used to scale gradients.

  Examples:
    >>> import optax
    >>> import jax
    >>> import jax.numpy as jnp
    >>> def f(x): return jnp.sum(x ** 2)  # simple quadratic function
    >>> solver = optax.adafactor(learning_rate=0.003)
    >>> params = jnp.array([1., 2., 3.])
    >>> print('Objective function: ', f(params))
    Objective function:  14.0
    >>> opt_state = solver.init(params)
    >>> for _ in range(5):
    ...   grad = jax.grad(f)(params)
    ...   updates, opt_state = solver.update(grad, opt_state, params)
    ...   params = optax.apply_updates(params, updates)
    ...   print('Objective function: {:.2E}'.format(f(params)))
    Objective function: 1.39E+01
    Objective function: 1.38E+01
    Objective function: 1.38E+01
    Objective function: 1.37E+01
    Objective function: 1.36E+01

  References:
    Shazeer and Stern, 2018: https://arxiv.org/abs/1804.04235

  Args:
    learning_rate: A global scaling factor, either fixed or evolving along
      iterations with a scheduler, see :func:`optax.scale_by_learning_rate`.
      Note that the natural scale for Adafactor's LR is markedly different
      from Adam, one doesn't use the 1/sqrt(hidden) correction for this optim
      with attention-based models. If None, no learning-rate stage is added.
    min_dim_size_to_factor: Only factor the statistics if two array dimensions
      have at least this size.
    decay_rate: Controls second-moment exponential decay schedule.
    decay_offset: For fine-tuning, one may set this to the starting step
      number of the fine-tuning phase.
    multiply_by_parameter_scale: If True, then scale learning_rate by
      parameter norm. If False, provided learning_rate is absolute step size.
    clipping_threshold: Optional clipping threshold. Must be >= 1. If None,
      clipping is disabled.
    momentum: Optional value between 0 and 1, enables momentum and uses extra
      memory if non-None! None by default.
    dtype_momentum: Data type of momentum buffers.
    weight_decay_rate: Optional rate at which to decay weights. If None, no
      weight-decay stage is added.
    eps: Regularization constant for root mean squared gradient.
    factored: Whether to use factored second-moment estimates.
    weight_decay_mask: A tree with same structure as (or a prefix of)
      the params PyTree, or a Callable that returns such a pytree given
      the params/updates. The leaves should be booleans, `True`
      for leaves/subtrees you want to apply the transformation to,
      and `False` for those you want to skip.

  Returns:
    The corresponding `GradientTransformation`.
  """
  # The core of the algorithm is a procedure for rescaling gradients
  # by a factored estimate of the root mean squared gradients.
  # This reduces memory compared to algorithms such as Adam or RmsProp,
  # by not having to hold a separate estimate for each weight.
  tx = [
      factorized.scale_by_factored_rms(
          factored, decay_rate, decay_offset, min_dim_size_to_factor, eps)]
  # This basic rescaling is typically combined with one or more of the
  # following transformations (all can be disabled via adafactor's constructor
  # args).
  if clipping_threshold is not None:
    tx.append(clipping.clip_by_block_rms(clipping_threshold))
  if learning_rate is not None:
    # The sign is deliberately not flipped here: the updates are negated once,
    # by the final `scale(-1)` stage, after momentum and weight decay.
    tx.append(transform.scale_by_learning_rate(learning_rate, flip_sign=False))
  if multiply_by_parameter_scale:
    tx.append(transform.scale_by_param_block_rms())
  if momentum is not None:
    tx.append(
        transform.ema(momentum, debias=False, accumulator_dtype=dtype_momentum))
  if weight_decay_rate is not None:
    tx.append(transform.add_decayed_weights(
        weight_decay_rate, mask=weight_decay_mask))
  # In gradient "descent" we follow the negative gradient.
  tx.append(transform.scale(-1))
  return combine.chain(*tx)
def adagrad(
    learning_rate: base.ScalarOrSchedule,
    initial_accumulator_value: float = 0.1,
    eps: float = 1e-7,
) -> base.GradientTransformation:
  r"""The Adagrad optimizer.

  AdaGrad is a sub-gradient algorithm for stochastic optimization in which the
  learning rate of each feature is adapted individually from the history of
  its gradients.

  The parameters are updated as:

  .. math::

    w_{t+1}^{(i)} = w_{t}^{(i)} - \eta \frac{g_{t}^{(i)}}
    {\sqrt{\sum_{\tau=1}^{t} (g_{\tau}^{(i)})^2 + \epsilon}}

  where:

  - :math:`w_t^{(i)}` is the parameter :math:`i` at time step :math:`t`,
  - :math:`\eta` is the learning rate,
  - :math:`g_t^{(i)}` is the gradient of parameter :math:`i` at time step
    :math:`t`,
  - :math:`\epsilon` is a small constant to ensure numerical stability.

  With :math:`G = \sum_{\tau=1}^{t} g_\tau g_\tau^\top`, the same update reads

  .. math::

    w_{t+1} = w_{t} - \eta \cdot \text{diag}(G + \epsilon I)^{-1/2}
    \cdot g_t

  where :math:`\text{diag} (G) = (G_{ii})_{i=1}^p` is the vector of diagonal
  entries of :math:`G \in \mathbb{R}^{p \times p}` and :math:`I` is the
  identity matrix in :math:`\mathbb{R}^{p \times p}`.

  .. warning::
    Adagrad's main limitation is the monotonic accumulation of squared
    gradients in the denominator: since every term is non-negative, the sum
    keeps growing during training and the learning rate eventually becomes
    vanishingly small.

  Examples:
    >>> import optax
    >>> import jax
    >>> import jax.numpy as jnp
    >>> def f(x): return jnp.sum(x ** 2)  # simple quadratic function
    >>> solver = optax.adagrad(learning_rate=1.0)
    >>> params = jnp.array([1., 2., 3.])
    >>> print('Objective function: ', f(params))
    Objective function:  14.0
    >>> opt_state = solver.init(params)
    >>> for _ in range(5):
    ...   grad = jax.grad(f)(params)
    ...   updates, opt_state = solver.update(grad, opt_state, params)
    ...   params = optax.apply_updates(params, updates)
    ...   print('Objective function: {:.2E}'.format(f(params)))
    Objective function: 5.01E+00
    Objective function: 2.40E+00
    Objective function: 1.25E+00
    Objective function: 6.86E-01
    Objective function: 3.85E-01

  References:
    Duchi et al, 2011: https://jmlr.org/papers/v12/duchi11a.html

  Args:
    learning_rate: A global scaling factor, either fixed or evolving along
      iterations with a scheduler, see :func:`optax.scale_by_learning_rate`.
    initial_accumulator_value: Initial value for the accumulator.
    eps: A small constant applied to the denominator inside of the square root
      (as in RMSProp) to avoid dividing by zero when rescaling.

  Returns:
    The corresponding `GradientTransformation`.
  """
  # Stage 1 rescales by the accumulated squared gradients, stage 2 applies the
  # (possibly scheduled) learning rate.
  rss_scaling = transform.scale_by_rss(
      initial_accumulator_value=initial_accumulator_value, eps=eps)
  lr_scaling = transform.scale_by_learning_rate(learning_rate)
  return combine.chain(rss_scaling, lr_scaling)
def adam(
    learning_rate: base.ScalarOrSchedule,
    b1: float = 0.9,
    b2: float = 0.999,
    eps: float = 1e-8,
    eps_root: float = 0.0,
    mu_dtype: Optional[Any] = None,
    *,
    nesterov: bool = False,
) -> base.GradientTransformation:
  r"""The Adam optimizer.

  Adam is a variant of SGD with adaptive gradient scaling. The scaling applied
  to each parameter is derived from estimates of the first and second-order
  moments of the gradients, maintained as exponential moving averages.

  Denote by :math:`\alpha_t` the learning rate and by :math:`\beta_1, \beta_2`,
  :math:`\varepsilon`, :math:`\bar{\varepsilon}` the arguments ``b1``, ``b2``,
  ``eps`` and ``eps_root`` respectively. The learning rate carries the index
  :math:`t` because it may be produced by a schedule function.

  The ``init`` function of this optimizer builds an internal state
  :math:`S_0 := (m_0, v_0) = (0, 0)` holding the initial estimates of the first
  and second moments. In practice these are pytrees filled with zeros that have
  the same shape as the model updates.

  At step :math:`t`, the ``update`` function takes the incoming gradients
  :math:`g_t` and the optimizer state :math:`S_t`, and returns the updates
  :math:`u_t` along with the new state :math:`S_{t+1}`. For :math:`t > 0`,

  .. math::

    \begin{align*}
      m_t &\leftarrow \beta_1 \cdot m_{t-1} + (1-\beta_1) \cdot g_t \\
      v_t &\leftarrow \beta_2 \cdot v_{t-1} + (1-\beta_2) \cdot {g_t}^2 \\
      \hat{m}_t &\leftarrow m_t / {(1-\beta_1^t)} \\
      \hat{v}_t &\leftarrow v_t / {(1-\beta_2^t)} \\
      u_t &\leftarrow -\alpha_t \cdot \hat{m}_t / \left({\sqrt{\hat{v}_t +
      \bar{\varepsilon}} + \varepsilon} \right)\\
      S_t &\leftarrow (m_t, v_t).
    \end{align*}

  When the keyword argument `nesterov=True` is given, Nesterov momentum is
  used and the above :math:`\hat{m}_t` is replaced with

  .. math::

    \hat{m}_t \leftarrow
      \beta_1 m_t / {(1-\beta_1^{t+1})} + (1 - \beta_1) g_t / {(1-\beta_1^t)}.

  Examples:
    >>> import optax
    >>> import jax
    >>> import jax.numpy as jnp
    >>> def f(x): return jnp.sum(x ** 2)  # simple quadratic function
    >>> solver = optax.adam(learning_rate=0.003)
    >>> params = jnp.array([1., 2., 3.])
    >>> print('Objective function: ', f(params))
    Objective function:  14.0
    >>> opt_state = solver.init(params)
    >>> for _ in range(5):
    ...   grad = jax.grad(f)(params)
    ...   updates, opt_state = solver.update(grad, opt_state, params)
    ...   params = optax.apply_updates(params, updates)
    ...   print('Objective function: {:.2E}'.format(f(params)))
    Objective function: 1.40E+01
    Objective function: 1.39E+01
    Objective function: 1.39E+01
    Objective function: 1.39E+01
    Objective function: 1.38E+01

  References:
    Kingma et al, `Adam: A Method for Stochastic Optimization
    <https://arxiv.org/abs/1412.6980>`_, 2014

    Dozat, `Incorporating Nesterov Momentum into Adam
    <https://openreview.net/pdf?id=OM0jvwB8jIp57ZJjtNEZ>`_, 2016

  .. warning::
    PyTorch and optax's implementation follow Algorithm 1 of [Kingma et al.
    2014]. Note that TensorFlow used instead the formulation just before
    Section 2.1 of the paper. See https://github.com/deepmind/optax/issues/571
    for more detail.

  Args:
    learning_rate: A global scaling factor, either fixed or evolving along
      iterations with a scheduler, see :func:`optax.scale_by_learning_rate`.
    b1: Exponential decay rate used to track the first moment of past
      gradients.
    b2: Exponential decay rate used to track the second moment of past
      gradients.
    eps: A small constant applied to the denominator outside of the square
      root (as in the Adam paper) to avoid dividing by zero when rescaling.
    eps_root: A small constant applied to the denominator inside the square
      root (as in RMSProp), to avoid dividing by zero when rescaling. This is
      needed for example when computing (meta-)gradients through Adam.
    mu_dtype: Optional `dtype` to be used for the first order accumulator; if
      `None` then the `dtype` is inferred from `params` and `updates`.
    nesterov: Whether to use Nesterov momentum. The solver with
      nesterov=True is equivalent to the :func:`optax.nadam` optimizer, and
      described in [Dozat 2016].

  Returns:
    The corresponding `GradientTransformation`.

  .. seealso:: :func:`optax.nadam`, :func:`optax.adamw`.
  """
  # Stage 1 rescales by the Adam moment estimates, stage 2 applies the
  # (possibly scheduled) learning rate.
  adam_scaling = transform.scale_by_adam(
      b1=b1,
      b2=b2,
      eps=eps,
      eps_root=eps_root,
      mu_dtype=mu_dtype,
      nesterov=nesterov,
  )
  lr_scaling = transform.scale_by_learning_rate(learning_rate)
  return combine.chain(adam_scaling, lr_scaling)
# NAdam is `adam` with Nesterov momentum pre-selected. A `functools.partial`
# object does not inherit the wrapped function's docstring, so a dedicated one
# is attached explicitly below.
nadam = functools.partial(adam, nesterov=True)
nadam.__doc__ = (
    r"""The NAdam optimizer.

  Nadam is a variant of :func:`optax.adam` with Nesterov's momentum. The update
  rule of this solver is as follows:

  .. math::

    \begin{align*}
      m_t &\leftarrow \beta_1 \cdot m_{t-1} + (1-\beta_1) \cdot g_t \\
      v_t &\leftarrow \beta_2 \cdot v_{t-1} + (1-\beta_2) \cdot {g_t}^2 \\
      \hat{m}_t &\leftarrow
      \beta_1 m_t / {(1-\beta_1^{t+1})} + (1 - \beta_1) g_t / {(1-\beta_1^t)}\\
      \hat{v}_t &\leftarrow v_t / {(1-\beta_2^t)} \\
      u_t &\leftarrow -\alpha_t \cdot \hat{m}_t / \left({\sqrt{\hat{v}_t +
      \bar{\varepsilon}} + \varepsilon} \right)\\
      S_t &\leftarrow (m_t, v_t).
    \end{align*}

  Examples:
    >>> import optax
    >>> import jax
    >>> import jax.numpy as jnp
    >>> def f(x): return jnp.sum(x ** 2)  # simple quadratic function
    >>> solver = optax.nadam(learning_rate=0.003)
    >>> params = jnp.array([1., 2., 3.])
    >>> print('Objective function: ', f(params))
    Objective function:  14.0
    >>> opt_state = solver.init(params)
    >>> for _ in range(5):
    ...   grad = jax.grad(f)(params)
    ...   updates, opt_state = solver.update(grad, opt_state, params)
    ...   params = optax.apply_updates(params, updates)
    ...   print('Objective function: {:.2E}'.format(f(params)))
    Objective function: 1.39E+01
    Objective function: 1.39E+01
    Objective function: 1.39E+01
    Objective function: 1.38E+01
    Objective function: 1.38E+01

  References:
    Dozat, `Incorporating Nesterov Momentum into Adam
    <https://openreview.net/pdf?id=OM0jvwB8jIp57ZJjtNEZ>`_, 2016

  .. versionadded:: 0.1.9

  Args:
    learning_rate: A global scaling factor, either fixed or evolving along
      iterations with a scheduler, see :func:`optax.scale_by_learning_rate`.
    b1: Exponential decay rate to track the first moment of past gradients.
    b2: Exponential decay rate to track the second moment of past gradients.
    eps: A small constant applied to denominator outside of the square root
      (as in the Adam paper) to avoid dividing by zero when rescaling.
    eps_root: A small constant applied to denominator inside the square root (as
      in RMSProp), to avoid dividing by zero when rescaling. This is needed for
      example when computing (meta-)gradients through Adam.
    mu_dtype: Optional `dtype` to be used for the first order accumulator; if
      `None` then the `dtype` is inferred from `params` and `updates`.

  Returns:
    The corresponding `GradientTransformation`.

  .. seealso:: :func:`optax.adam`, :func:`optax.nadamw`.
"""
)
def adamw(
    learning_rate: base.ScalarOrSchedule,
    b1: float = 0.9,
    b2: float = 0.999,
    eps: float = 1e-8,
    eps_root: float = 0.0,
    mu_dtype: Optional[Any] = None,
    weight_decay: float = 1e-4,
    mask: Optional[Union[Any, Callable[[base.Params], Any]]] = None,
    *,
    nesterov: bool = False,
) -> base.GradientTransformation:
  r"""Adam with weight decay regularization.

  AdamW regularizes learning towards small weights through weight decay, which
  leads to better generalization. With SGD the same effect can be obtained by
  adding an L2 regularization term to the loss, but L2 regularization does not
  behave as intended for adaptive gradient algorithms such as Adam, see
  [Loshchilov et al, 2019].

  Denote by :math:`\alpha_t` the learning rate and by :math:`\beta_1, \beta_2`,
  :math:`\varepsilon`, :math:`\bar{\varepsilon}` the arguments ``b1``, ``b2``,
  ``eps`` and ``eps_root`` respectively. The learning rate carries the index
  :math:`t` because it may be produced by a schedule function. Let
  :math:`\lambda` be the weight decay and :math:`\theta_t` the parameter vector
  at time :math:`t`.

  The ``init`` function of this optimizer builds an internal state
  :math:`S_0 := (m_0, v_0) = (0, 0)` holding the initial estimates of the first
  and second moments. In practice these are pytrees filled with zeros that have
  the same shape as the model updates.

  At step :math:`t`, the ``update`` function takes the incoming gradients
  :math:`g_t`, the optimizer state :math:`S_t` and the parameters
  :math:`\theta_t`, and returns the updates :math:`u_t` along with the new
  state :math:`S_{t+1}`. For :math:`t > 0`,

  .. math::

    \begin{align*}
      m_t &\leftarrow \beta_1 \cdot m_{t-1} + (1-\beta_1) \cdot g_t \\
      v_t &\leftarrow \beta_2 \cdot v_{t-1} + (1-\beta_2) \cdot {g_t}^2 \\
      \hat{m}_t &\leftarrow m_t / {(1-\beta_1^t)} \\
      \hat{v}_t &\leftarrow v_t / {(1-\beta_2^t)} \\
      u_t &\leftarrow -\alpha_t \cdot \left( \hat{m}_t / \left({\sqrt{\hat{v}_t
      + \bar{\varepsilon}} + \varepsilon} \right) + \lambda \theta_{t} \right)\\
      S_t &\leftarrow (m_t, v_t).
    \end{align*}

  This implementation can also use the Nesterov-style momentum introduced by
  [Dozat 2016]; the resulting optimizer is often referred to as NAdamW.
  When the keyword argument `nesterov=True` is given, the above
  :math:`\hat{m}_t` is replaced with

  .. math::

    \hat{m}_t \leftarrow
      \beta_1 m_t / {(1-\beta_1^{t+1})} + (1 - \beta_1) g_t / {(1-\beta_1^t)}.

  Examples:
    >>> import optax
    >>> import jax
    >>> import jax.numpy as jnp
    >>> def f(x): return jnp.sum(x ** 2)  # simple quadratic function
    >>> solver = optax.adamw(learning_rate=0.003)
    >>> params = jnp.array([1., 2., 3.])
    >>> print('Objective function: ', f(params))
    Objective function:  14.0
    >>> opt_state = solver.init(params)
    >>> for _ in range(5):
    ...   grad = jax.grad(f)(params)
    ...   updates, opt_state = solver.update(grad, opt_state, params)
    ...   params = optax.apply_updates(params, updates)
    ...   print('Objective function: {:.2E}'.format(f(params)))
    Objective function: 1.40E+01
    Objective function: 1.39E+01
    Objective function: 1.39E+01
    Objective function: 1.39E+01
    Objective function: 1.38E+01

  References:
    Loshchilov et al, `Decoupled Weight Decay
    Regularization <https://arxiv.org/abs/1711.05101>`_, 2019

    Dozat, `Incorporating Nesterov Momentum into Adam
    <https://openreview.net/pdf?id=OM0jvwB8jIp57ZJjtNEZ>`_, 2016

  Args:
    learning_rate: A global scaling factor, either fixed or evolving along
      iterations with a scheduler, see :func:`optax.scale_by_learning_rate`.
    b1: Exponential decay rate used to track the first moment of past
      gradients.
    b2: Exponential decay rate used to track the second moment of past
      gradients.
    eps: A small constant applied to the denominator outside of the square
      root (as in the Adam paper) to avoid dividing by zero when rescaling.
    eps_root: A small constant applied to the denominator inside the square
      root (as in RMSProp), to avoid dividing by zero when rescaling. This is
      needed for instance when computing (meta-)gradients through Adam.
    mu_dtype: Optional `dtype` to be used for the first order accumulator; if
      `None` then the `dtype` is inferred from `params` and `updates`.
    weight_decay: Strength of the weight decay regularization. Note that this
      weight decay is multiplied with the learning rate. This is consistent
      with other frameworks such as PyTorch, but different from
      (Loshchilov et al, 2019) where the weight decay is only multiplied with
      the "schedule multiplier", but not the base learning rate.
    mask: A tree with the same structure as (or a prefix of) the params PyTree,
      or a Callable returning such a pytree given the params/updates. Leaves
      should be booleans: `True` for leaves/subtrees the weight decay is
      applied to, `False` for those to skip. Note that the Adam gradient
      transformations are applied to all parameters.
    nesterov: Whether to use Nesterov momentum. The solver with
      nesterov=True is equivalent to the :func:`optax.nadamw` optimizer. This
      modification is described in [Dozat 2016].

  Returns:
    The corresponding `GradientTransformation`.

  .. seealso:: :func:`optax.adam`, :func:`optax.nadamw`.
  """
  # The decayed weights are added after the Adam rescaling and before the
  # learning-rate stage, so the learning rate multiplies both terms.
  stages = [
      transform.scale_by_adam(
          b1=b1,
          b2=b2,
          eps=eps,
          eps_root=eps_root,
          mu_dtype=mu_dtype,
          nesterov=nesterov,
      ),
      transform.add_decayed_weights(weight_decay, mask),
      transform.scale_by_learning_rate(learning_rate),
  ]
  return combine.chain(*stages)
# NAdamW is `adamw` with Nesterov momentum pre-selected. A `functools.partial`
# object does not inherit the wrapped function's docstring, so a dedicated one
# is attached explicitly below.
nadamw = functools.partial(adamw, nesterov=True)
nadamw.__doc__ = (
    r"""NAdamW optimizer, implemented as part of the AdamW optimizer.

  NadamW is a variant of :func:`optax.adamw` with Nesterov's momentum. Compared
  to AdamW, this optimizer replaces the assignment

  .. math::

    \hat{m}_t \leftarrow m_t / {(1-\beta_1^t)}

  with

  .. math::

    \hat{m}_t \leftarrow
      \beta_1 m_t / {(1-\beta_1^{t+1})} + (1 - \beta_1) g_t / {(1-\beta_1^t)}.

  Examples:
    >>> import optax
    >>> import jax
    >>> import jax.numpy as jnp
    >>> def f(x): return jnp.sum(x ** 2)  # simple quadratic function
    >>> solver = optax.nadamw(learning_rate=0.003)
    >>> params = jnp.array([1., 2., 3.])
    >>> print('Objective function: ', f(params))
    Objective function:  14.0
    >>> opt_state = solver.init(params)
    >>> for _ in range(5):
    ...   grad = jax.grad(f)(params)
    ...   updates, opt_state = solver.update(grad, opt_state, params)
    ...   params = optax.apply_updates(params, updates)
    ...   print('Objective function: {:.2E}'.format(f(params)))
    Objective function: 1.39E+01
    Objective function: 1.39E+01
    Objective function: 1.39E+01
    Objective function: 1.38E+01
    Objective function: 1.38E+01

  References:
    Loshchilov et al, `Decoupled Weight Decay
    Regularization <https://arxiv.org/abs/1711.05101>`_, 2019

    Dozat, `Incorporating Nesterov Momentum into Adam
    <https://openreview.net/pdf?id=OM0jvwB8jIp57ZJjtNEZ>`_, 2016

  .. versionadded:: 0.1.9

  Args:
    learning_rate: A global scaling factor, either fixed or evolving along
      iterations with a scheduler, see :func:`optax.scale_by_learning_rate`.
    b1: Exponential decay rate to track the first moment of past gradients.
    b2: Exponential decay rate to track the second moment of past gradients.
    eps: A small constant applied to denominator outside of the square root
      (as in the Adam paper) to avoid dividing by zero when rescaling.
    eps_root: A small constant applied to denominator inside the square root (as
      in RMSProp), to avoid dividing by zero when rescaling. This is needed for
      instance when computing (meta-)gradients through Adam.
    mu_dtype: Optional `dtype` to be used for the first order accumulator; if
      `None` then the `dtype` is inferred from `params` and `updates`.
    weight_decay: Strength of the weight decay regularization. Note that this
      weight decay is multiplied with the learning rate. This is consistent
      with other frameworks such as PyTorch, but different from
      (Loshchilov et al, 2019) where the weight decay is only multiplied with
      the "schedule multiplier", but not the base learning rate.
    mask: A tree with same structure as (or a prefix of) the params PyTree,
      or a Callable that returns such a pytree given the params/updates.
      The leaves should be booleans, `True` for leaves/subtrees you want to
      apply the weight decay to, and `False` for those you want to skip. Note
      that the Adam gradient transformations are applied to all parameters.

  Returns:
    The corresponding `GradientTransformation`.

  .. seealso:: :func:`optax.adam`, :func:`optax.adamw`.
"""
)
def lion(
    learning_rate: base.ScalarOrSchedule,
    b1: float = 0.9,
    b2: float = 0.99,
    mu_dtype: Optional[Any] = None,
    weight_decay: float = 1e-3,
    mask: Optional[Union[Any, Callable[[base.Params], Any]]] = None,
) -> base.GradientTransformation:
  """Builds the Lion optimizer.

  Lion was found through symbolic program search. In contrast to most adaptive
  optimizers (AdamW, for example) it only keeps track of momentum, which makes
  it cheaper in memory. Every update goes through a sign operation, so its norm
  is larger than that of updates produced by optimizers such as SGD or AdamW.
  A learning rate 3-10x smaller than the one used for AdamW is typically
  suitable for Lion, and the weight decay should in turn be 3-10x larger so
  that the effective strength (lr * wd) stays comparable.

  Examples:
    >>> import optax
    >>> import jax
    >>> import jax.numpy as jnp
    >>> def f(x): return jnp.sum(x ** 2)  # simple quadratic function
    >>> solver = optax.lion(learning_rate=0.003)
    >>> params = jnp.array([1., 2., 3.])
    >>> print('Objective function: ', f(params))
    Objective function:  14.0
    >>> opt_state = solver.init(params)
    >>> for _ in range(5):
    ...   grad = jax.grad(f)(params)
    ...   updates, opt_state = solver.update(grad, opt_state, params)
    ...   params = optax.apply_updates(params, updates)
    ...   print('Objective function: {:.2E}'.format(f(params)))
    Objective function: 1.40E+01
    Objective function: 1.39E+01
    Objective function: 1.39E+01
    Objective function: 1.39E+01
    Objective function: 1.38E+01

  References:
    Chen et al, 2023: https://arxiv.org/abs/2302.06675

  Args:
    learning_rate: Global scaling factor; either a fixed value or one that
      evolves over iterations through a scheduler, see
      :func:`optax.scale_by_learning_rate`.
    b1: Rate used to combine the momentum with the current gradient.
    b2: Exponential decay rate used to track the momentum of past gradients.
    mu_dtype: Optional `dtype` for the first order accumulator; when `None`
      the `dtype` is inferred from `params` and `updates`.
    weight_decay: Strength of the weight decay regularization. The decay term
      is multiplied with the learning rate, which is consistent with other
      frameworks such as PyTorch but differs from (Loshchilov et al, 2019),
      where the weight decay is only multiplied with the "schedule
      multiplier" and not with the base learning rate.
    mask: A tree with the same structure as (or a prefix of) the params
      PyTree, or a Callable returning such a pytree given the params/updates.
      Leaves should be booleans: `True` for leaves/subtrees the weight decay
      should be applied to, `False` for those to skip. The mask is only
      handed to the weight decay stage; the Lion scaling stage is built
      without it.

  Returns:
    The corresponding `GradientTransformation`.
  """
  # Stages are created in the same order in which they are chained:
  # Lion scaling, then weight decay, then learning-rate scaling.
  lion_scaling = transform.scale_by_lion(b1=b1, b2=b2, mu_dtype=mu_dtype)
  decay = transform.add_decayed_weights(weight_decay, mask)
  lr_scaling = transform.scale_by_learning_rate(learning_rate)
  return combine.chain(lion_scaling, decay, lr_scaling)
def amsgrad(
    learning_rate: base.ScalarOrSchedule,
    b1: float = 0.9,
    b2: float = 0.999,
    eps: float = 1e-8,
    eps_root: float = 0.0,
    mu_dtype: Optional[Any] = None,
) -> base.GradientTransformation:
  """Builds the AMSGrad optimizer.

  In some cases the original Adam fails to converge to the optimal solution.
  AMSGrad guarantees convergence by relying on a long-term memory of past
  gradients.

  Examples:
    >>> import optax
    >>> import jax
    >>> import jax.numpy as jnp
    >>> def f(x): return jnp.sum(x ** 2)  # simple quadratic function
    >>> solver = optax.amsgrad(learning_rate=0.003)
    >>> params = jnp.array([1., 2., 3.])
    >>> print('Objective function: ', f(params))
    Objective function:  14.0
    >>> opt_state = solver.init(params)
    >>> for _ in range(5):
    ...   grad = jax.grad(f)(params)
    ...   updates, opt_state = solver.update(grad, opt_state, params)
    ...   params = optax.apply_updates(params, updates)
    ...   print('Objective function: {:.2E}'.format(f(params)))
    Objective function: 1.40E+01
    Objective function: 1.39E+01
    Objective function: 1.39E+01
    Objective function: 1.39E+01
    Objective function: 1.38E+01

  References:
    Reddi et al, 2018: https://openreview.net/forum?id=ryQu7f-RZ

  Args:
    learning_rate: Global scaling factor; either a fixed value or one that
      evolves over iterations through a scheduler, see
      :func:`optax.scale_by_learning_rate`.
    b1: Exponential decay rate used to track the first moment of past
      gradients.
    b2: Exponential decay rate used to track the second moment of past
      gradients.
    eps: Small constant added to the denominator outside of the square root
      (as in the Adam paper) to avoid dividing by zero when rescaling.
    eps_root: Small constant added to the denominator inside the square root
      (as in RMSProp) to avoid dividing by zero when rescaling. This is
      needed, for instance, when computing (meta-)gradients through Adam.
    mu_dtype: Optional `dtype` for the first order accumulator; when `None`
      the `dtype` is inferred from `params` and `updates`.

  Returns:
    The corresponding `GradientTransformation`.
  """
  # AMSGrad rescaling first, learning-rate scaling second.
  stages = (
      transform.scale_by_amsgrad(
          b1=b1, b2=b2, eps=eps, eps_root=eps_root, mu_dtype=mu_dtype
      ),
      transform.scale_by_learning_rate(learning_rate),
  )
  return combine.chain(*stages)
def fromage(
    learning_rate: float,
    min_norm: float = 1e-6
) -> base.GradientTransformation:
  """Builds the Frobenius matched gradient descent (Fromage) optimizer.

  Fromage is a learning algorithm that does not require learning rate tuning.
  It is based on modeling neural network gradients via deep relative trust (a
  distance function on deep neural networks). Fromage is similar to the LARS
  optimizer and can work on a range of standard neural network benchmarks,
  such as natural language Transformers and generative adversarial networks.

  Examples:
    >>> import optax
    >>> import jax
    >>> import jax.numpy as jnp
    >>> def f(x): return jnp.sum(x ** 2)  # simple quadratic function
    >>> solver = optax.fromage(learning_rate=0.003)
    >>> params = jnp.array([1., 2., 3.])
    >>> print('Objective function: ', f(params))
    Objective function:  14.0
    >>> opt_state = solver.init(params)
    >>> for _ in range(5):
    ...   grad = jax.grad(f)(params)
    ...   updates, opt_state = solver.update(grad, opt_state, params)
    ...   params = optax.apply_updates(params, updates)
    ...   print('Objective function: {:.2E}'.format(f(params)))
    Objective function: 1.39E+01
    Objective function: 1.38E+01
    Objective function: 1.37E+01
    Objective function: 1.37E+01
    Objective function: 1.36E+01

  References:
    Bernstein et al, 2020: https://arxiv.org/abs/2002.03432

  Args:
    learning_rate: A global scaling factor. The value is used in arithmetic
      when the optimizer is built (`learning_rate ** 2`), so it has to be a
      number rather than a schedule callable.
    min_norm: Lower bound to which the norm of the gradient updates and the
      norm of the layer parameters can be clipped, to avoid dividing by zero
      when computing the trust ratio (as in the LARS paper).

  Returns:
    The corresponding `GradientTransformation`.
  """
  # Factor 1 / sqrt(1 + lr^2), shared by the last two stages of the chain.
  shrink = 1 / jnp.sqrt(1 + learning_rate ** 2)

  trust_ratio_scaling = transform.scale_by_trust_ratio(min_norm)
  lr_scaling = transform.scale_by_learning_rate(learning_rate * shrink)
  # `shrink - 1` is never positive, so this stage receives a non-positive
  # decay coefficient.
  param_shrinkage = transform.add_decayed_weights(shrink - 1)
  return combine.chain(trust_ratio_scaling, lr_scaling, param_shrinkage)
def lars(
learning_rate: base.ScalarOrSchedule,
weight_decay: float = 0.,
weight_decay_mask: MaskOrFn = True,
trust_coefficient: float = 0.001,
eps: float = 0.,
trust_ratio_mask: MaskOrFn = True,
momentum: float = 0.9,
nesterov: bool = False,
) -> base.GradientTransformation:
"""The LARS optimizer.
LARS is a layer-wise adaptive optimizer introduced to help scale SGD to
larger batch sizes. LARS later inspired the LAMB optimizer.
Examples:
>>> import optax
>>> import jax
>>> import jax.numpy as jnp
>>> def f(x): return jnp.sum(x ** 2) # simple quadratic function
>>> solver = optax.lars(learning_rate=0.003)
>>> params = jnp.array([1., 2., 3.])
>>> print('Objective function: ', f(params))
Objective function: 14.0
>>> opt_state = solver.init(params)
>>> for _ in range(5):
... grad = jax.grad(f)(params)
... updates, opt_state = solver.update(grad, opt_state, params)
... params = optax.apply_updates(params, updates)
... print('Objective function: {:.2E}'.format(f(params)))
Objective function: 1.40E+01
Objective function: 1.40E+01
Objective function: 1.40E+01
Objective function: 1.40E+01
Objective function: 1.40E+01
References:
You et al, 2017: https://arxiv.org/abs/1708.03888
Args:
learning_rate: A global scaling factor, either fixed or evolving along
iterations with a scheduler, see :func:`optax.scale_by_learning_rate`.
weight_decay: Strength of the weight decay regularization.
weight_decay_mask: A tree with same structure as (or a prefix of) the params
PyTree, or a Callable that returns such a pytree given the params/updates.
The leaves should be booleans, `True` for leaves/subtrees you want to
apply the transformation to, and `False` for those you want to skip.
trust_coefficient: A multiplier for the trust ratio.
eps: Optional additive constant in the trust ratio denominator.
trust_ratio_mask: A tree with same structure as (or a prefix of) the params
PyTree, or a Callable that returns such a pytree given the params/updates.
The leaves should be booleans, `True` for leaves/subtrees you want to
apply the transformation to, and `False` for those you want to skip.
momentum: Decay rate for momentum.
nesterov: Whether to use Nesterov momentum.
Returns:
The corresponding `GradientTransformation`.
"""
return combine.chain(