From 3c479625db59ecf0a46369bbcc256f17c63a18bc Mon Sep 17 00:00:00 2001
From: Vishesh <87526302+vishesh9131@users.noreply.github.com>
Date: Fri, 26 Jul 2024 03:09:33 +0530
Subject: [PATCH] added engine/cr_boosters

-Sun-
1. Added optimizers.
2. test7: I'll fix Cora dataset streaming on GTransformer later.
3. cr_boosters is a package containing optimizers written as an extension to PyTorch.
---
 engine/__pycache__/core_rec.cpython-311.pyc   |  Bin 736 -> 1540 bytes
 engine/__pycache__/models.cpython-311.pyc     |  Bin 5451 -> 14647 bytes
 engine/core_rec.py                            |   29 +-
 engine/cr_boosters/__init__.py                |   38 +
 engine/cr_boosters/_functional.py             |   84 +
 engine/cr_boosters/_multi_tensor/__init__.py  |   30 +
 engine/cr_boosters/_multi_tensor/__init__.pyi |   15 +
 engine/cr_boosters/adadelta.py                |  452 ++++
 engine/cr_boosters/adagrad.py                 |  555 +++++
 engine/cr_boosters/adam.py                    |  790 ++++++
 engine/cr_boosters/adamax.py                  |  463 ++++
 engine/cr_boosters/asgd.py                    |  454 ++++
 engine/cr_boosters/lbfgs.py                   |  488 ++++
 engine/cr_boosters/lr_scheduler.py            | 2151 +++++++++++++++++
 engine/cr_boosters/nadam.py                   |  639 +++++
 engine/cr_boosters/optimizer.py               | 1052 ++++++++
 engine/cr_boosters/radam.py                   |  598 +++++
 engine/cr_boosters/rmsprop.py                 |  510 ++++
 engine/cr_boosters/sgd.py                     |  504 ++++
 engine/cr_boosters/sparse_adam.py             |  181 ++
 engine/cr_boosters/swa_utils.py               |  463 ++++
 engine/cr_utility/dataloader.py               | 1604 ++++++++++++
 engine/cr_utility/dataset.py                  |  489 ++++
 engine/models.py                              |    4 +-
 engine/timecapsule.py                         |    9 +-
 src/SANDBOX/tempCodeRunnerFile.py             |   41 +-
 src/SANDBOX/test7.py                          |   52 +
 src/SANDBOX/testcase.py                       |   52 +
 src/USECASES/transfmodel.py                   |   21 +-
 29 files changed, 11708 insertions(+), 60 deletions(-)
 create mode 100644 engine/cr_boosters/__init__.py
 create mode 100644 engine/cr_boosters/_functional.py
 create mode 100644 engine/cr_boosters/_multi_tensor/__init__.py
 create mode 100644 engine/cr_boosters/_multi_tensor/__init__.pyi
 create mode 100644 engine/cr_boosters/adadelta.py
 create mode 100644 engine/cr_boosters/adagrad.py
 create mode 100644 engine/cr_boosters/adam.py
 create mode 100644 engine/cr_boosters/adamax.py
 create mode 100644 engine/cr_boosters/asgd.py
 create mode 100644 engine/cr_boosters/lbfgs.py
 create mode 100644 engine/cr_boosters/lr_scheduler.py
 create mode 100644 engine/cr_boosters/nadam.py
 create mode 100644 engine/cr_boosters/optimizer.py
 create mode 100644 engine/cr_boosters/radam.py
 create mode 100644 engine/cr_boosters/rmsprop.py
 create mode 100644 engine/cr_boosters/sgd.py
 create mode 100644 engine/cr_boosters/sparse_adam.py
 create mode 100644 engine/cr_boosters/swa_utils.py
 create mode 100644 engine/cr_utility/dataloader.py
 create mode 100644 engine/cr_utility/dataset.py
 create mode 100644 src/SANDBOX/test7.py
 create mode 100644 src/SANDBOX/testcase.py

diff --git a/engine/__pycache__/core_rec.cpython-311.pyc b/engine/__pycache__/core_rec.cpython-311.pyc
index ad6517800d15af745a032dd11a5abcbe2be1cb01..4549cd87bc8c2b14ca9e43e3cd303d17f857d5b7 100644
Binary files a/engine/__pycache__/core_rec.cpython-311.pyc and b/engine/__pycache__/core_rec.cpython-311.pyc differ
diff --git a/engine/__pycache__/models.cpython-311.pyc b/engine/__pycache__/models.cpython-311.pyc
index 0ea6a937608466cafe8cd8b3772603b0a024a448..7acc1874a198179078275e1dbd370e0367ec35ef 100644
Binary files a/engine/__pycache__/models.cpython-311.pyc and b/engine/__pycache__/models.cpython-311.pyc differ
diff --git a/engine/cr_boosters/_functional.py b/engine/cr_boosters/_functional.py
new file mode 100644
--- /dev/null
+++ b/engine/cr_boosters/_functional.py
+#             # <==> old += (1 - b) * (new - old)
+#
old_exp_avg_values = exp_avg.sparse_mask(grad)._values() +# exp_avg_update_values = grad_values.sub(old_exp_avg_values).mul_(1 - beta1) +# exp_avg.add_(make_sparse(exp_avg_update_values)) +# old_exp_avg_sq_values = exp_avg_sq.sparse_mask(grad)._values() +# exp_avg_sq_update_values = ( +# grad_values.pow(2).sub_(old_exp_avg_sq_values).mul_(1 - beta2) +# ) +# exp_avg_sq.add_(make_sparse(exp_avg_sq_update_values)) + +# # Dense addition again is intended, avoiding another sparse_mask +# numer = exp_avg_update_values.add_(old_exp_avg_values) +# exp_avg_sq_update_values.add_(old_exp_avg_sq_values) +# denom = exp_avg_sq_update_values.sqrt_().add_(eps) +# del exp_avg_update_values, exp_avg_sq_update_values + +# bias_correction1 = 1 - beta1**step +# bias_correction2 = 1 - beta2**step +# step_size = lr * math.sqrt(bias_correction2) / bias_correction1 + +# param.add_(make_sparse(-step_size * numer.div_(denom))) diff --git a/engine/cr_boosters/_multi_tensor/__init__.py b/engine/cr_boosters/_multi_tensor/__init__.py new file mode 100644 index 0000000..41a1957 --- /dev/null +++ b/engine/cr_boosters/_multi_tensor/__init__.py @@ -0,0 +1,30 @@ +""" +:mod:`torch.optim._multi_tensor` is a package implementing various optimization algorithms. + +Most commonly used methods are already supported, and the interface is general +enough, so that more sophisticated ones can be also easily integrated in the +future. +""" +from functools import partialmethod + +from torch import optim + + +def partialclass(cls, *args, **kwargs): # noqa: D103 + class NewCls(cls): + __init__ = partialmethod(cls.__init__, *args, **kwargs) + + return NewCls + + +Adam = partialclass(optim.Adam, foreach=True) +AdamW = partialclass(optim.AdamW, foreach=True) +NAdam = partialclass(optim.NAdam, foreach=True) +SGD = partialclass(optim.SGD, foreach=True) +RAdam = partialclass(optim.RAdam, foreach=True) +RMSprop = partialclass(optim.RMSprop, foreach=True) +Rprop = partialclass(optim.Rprop, foreach=True) +ASGD = partialclass(optim.ASGD, foreach=True) +Adamax = partialclass(optim.Adamax, foreach=True) +Adadelta = partialclass(optim.Adadelta, foreach=True) +Adagrad = partialclass(optim.Adagrad, foreach=True) diff --git a/engine/cr_boosters/_multi_tensor/__init__.pyi b/engine/cr_boosters/_multi_tensor/__init__.pyi new file mode 100644 index 0000000..97c3e2d --- /dev/null +++ b/engine/cr_boosters/_multi_tensor/__init__.pyi @@ -0,0 +1,15 @@ +from functools import partial + +from torch import optim + +Adam = partial(optim.Adam, foreach=True) +AdamW = partial(optim.AdamW, foreach=True) +NAdam = partial(optim.NAdam, foreach=True) +SGD = partial(optim.SGD, foreach=True) +RAdam = partial(optim.RAdam, foreach=True) +RMSprop = partial(optim.RMSprop, foreach=True) +Rprop = partial(optim.Rprop, foreach=True) +ASGD = partial(optim.ASGD, foreach=True) +Adamax = partial(optim.Adamax, foreach=True) +Adadelta = partial(optim.Adadelta, foreach=True) +Adagrad = partial(optim.Adagrad, foreach=True) diff --git a/engine/cr_boosters/adadelta.py b/engine/cr_boosters/adadelta.py new file mode 100644 index 0000000..d6f19fb --- /dev/null +++ b/engine/cr_boosters/adadelta.py @@ -0,0 +1,452 @@ +# mypy: allow-untyped-defs +from typing import Any, Dict, List, Optional + +import torch +from torch import Tensor + +from .optimizer import ( + _capturable_doc, + _default_to_fused_or_foreach, + _differentiable_doc, + _disable_dynamo_if_unsupported, + _foreach_doc, + _get_capturable_supported_devices, + _get_scalar_dtype, + _maximize_doc, + _use_grad_for_differentiable, + _view_as_real, 
+ Optimizer, + ParamsT, +) + +__all__ = ["Adadelta", "adadelta"] + + +class Adadelta(Optimizer): + def __init__( + self, + params: ParamsT, + lr: float = 1.0, + rho: float = 0.9, + eps: float = 1e-6, + weight_decay: float = 0, + foreach: Optional[bool] = None, + *, + capturable: bool = False, + maximize: bool = False, + differentiable: bool = False, + ): + if not 0.0 <= lr: + raise ValueError(f"Invalid learning rate: {lr}") + if not 0.0 <= rho <= 1.0: + raise ValueError(f"Invalid rho value: {rho}") + if not 0.0 <= eps: + raise ValueError(f"Invalid epsilon value: {eps}") + if not 0.0 <= weight_decay: + raise ValueError(f"Invalid weight_decay value: {weight_decay}") + + defaults = dict( + lr=lr, + rho=rho, + eps=eps, + weight_decay=weight_decay, + maximize=maximize, + capturable=capturable, + foreach=foreach, + differentiable=differentiable, + ) + super().__init__(params, defaults) + + def __setstate__(self, state): + super().__setstate__(state) + for group in self.param_groups: + group.setdefault("foreach", None) + group.setdefault("maximize", False) + group.setdefault("differentiable", False) + group.setdefault("capturable", False) + for p in group["params"]: + p_state = self.state.get(p, []) + if len(p_state) != 0 and not torch.is_tensor(p_state["step"]): + step_val = float(p_state["step"]) + p_state["step"] = ( + torch.tensor( + step_val, dtype=_get_scalar_dtype(), device=p.device + ) + if group["capturable"] + else torch.tensor(step_val, dtype=_get_scalar_dtype()) + ) + + def _init_group( + self, + group: Dict[str, Any], + params_with_grad: List[Tensor], + grads: List[Tensor], + square_avgs: List[Tensor], + acc_deltas: List[Tensor], + state_steps: List[Tensor], + ): + has_complex = False + p: Tensor + for p in group["params"]: + if p.grad is None: + continue + has_complex |= torch.is_complex(p) + params_with_grad.append(p) + if p.grad.is_sparse: + raise RuntimeError("Adadelta does not support sparse gradients") + grads.append(p.grad) + + state = self.state[p] + + # Lazy state initialization + if len(state) == 0: + state["step"] = ( + torch.zeros((), dtype=_get_scalar_dtype(), device=p.device) + if group["capturable"] + else torch.zeros((), dtype=_get_scalar_dtype()) + ) + + state["square_avg"] = torch.zeros_like( + p, memory_format=torch.preserve_format + ) + state["acc_delta"] = torch.zeros_like( + p, memory_format=torch.preserve_format + ) + + square_avgs.append(state["square_avg"]) + acc_deltas.append(state["acc_delta"]) + state_steps.append(state["step"]) + + return has_complex + + @_use_grad_for_differentiable + def step(self, closure=None): + """Perform a single optimization step. + + Args: + closure (Callable, optional): A closure that reevaluates the model + and returns the loss. 
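+
+        A minimal usage sketch of the ``closure`` form (``model``, ``loss_fn``,
+        ``inputs`` and ``targets`` are assumed to be defined by the caller)::
+
+            optimizer = Adadelta(model.parameters(), lr=1.0)
+
+            def closure():
+                optimizer.zero_grad()
+                loss = loss_fn(model(inputs), targets)
+                loss.backward()
+                return loss
+
+            optimizer.step(closure)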
+ """ + self._cuda_graph_capture_health_check() + + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + params_with_grad: List[Tensor] = [] + grads: List[Tensor] = [] + square_avgs: List[Tensor] = [] + acc_deltas: List[Tensor] = [] + state_steps: List[Tensor] = [] + ( + lr, + rho, + eps, + weight_decay, + foreach, + maximize, + differentiable, + capturable, + ) = ( + group["lr"], + group["rho"], + group["eps"], + group["weight_decay"], + group["foreach"], + group["maximize"], + group["differentiable"], + group["capturable"], + ) + + has_complex = self._init_group( + group, params_with_grad, grads, square_avgs, acc_deltas, state_steps + ) + + adadelta( + params_with_grad, + grads, + square_avgs, + acc_deltas, + state_steps, + lr=lr, + rho=rho, + eps=eps, + weight_decay=weight_decay, + foreach=foreach, + maximize=maximize, + differentiable=differentiable, + capturable=capturable, + has_complex=has_complex, + ) + + return loss + + +Adadelta.__doc__ = ( + r"""Implements Adadelta algorithm. + + .. math:: + \begin{aligned} + &\rule{110mm}{0.4pt} \\ + &\textbf{input} : \gamma \text{ (lr)}, \: \theta_0 \text{ (params)}, + \: f(\theta) \text{ (objective)}, \: \rho \text{ (decay)}, + \: \lambda \text{ (weight decay)} \\ + &\textbf{initialize} : v_0 \leftarrow 0 \: \text{ (square avg)}, + \: u_0 \leftarrow 0 \: \text{ (accumulate variables)} \\[-1.ex] + &\rule{110mm}{0.4pt} \\ + &\textbf{for} \: t=1 \: \textbf{to} \: \ldots \: \textbf{do} \\ + &\hspace{5mm}g_t \leftarrow \nabla_{\theta} f_t (\theta_{t-1}) \\ + &\hspace{5mm}if \: \lambda \neq 0 \\ + &\hspace{10mm} g_t \leftarrow g_t + \lambda \theta_{t-1} \\ + &\hspace{5mm} v_t \leftarrow v_{t-1} \rho + g^2_t (1 - \rho) \\ + &\hspace{5mm}\Delta x_t \leftarrow \frac{\sqrt{u_{t-1} + + \epsilon }}{ \sqrt{v_t + \epsilon} }g_t \hspace{21mm} \\ + &\hspace{5mm} u_t \leftarrow u_{t-1} \rho + + \Delta x^2_t (1 - \rho) \\ + &\hspace{5mm}\theta_t \leftarrow \theta_{t-1} - \gamma \Delta x_t \\ + &\rule{110mm}{0.4pt} \\[-1.ex] + &\bf{return} \: \theta_t \\[-1.ex] + &\rule{110mm}{0.4pt} \\[-1.ex] + \end{aligned} + + For further details regarding the algorithm we refer to `ADADELTA: An Adaptive Learning Rate Method`_. + """ + + rf""" + Args: + params (iterable): iterable of parameters to optimize or dicts defining + parameter groups + rho (float, optional): coefficient used for computing a running average + of squared gradients (default: 0.9). A higher value of `rho` will + result in a slower average, which can be helpful for preventing + oscillations in the learning process. + eps (float, optional): term added to the denominator to improve + numerical stability (default: 1e-6). + lr (float, optional): coefficient that scale delta before it is applied + to the parameters (default: 1.0) + weight_decay (float, optional): weight decay (L2 penalty) (default: 0) + {_foreach_doc} + {_capturable_doc} + {_maximize_doc} + {_differentiable_doc} + + .. 
_ADADELTA\: An Adaptive Learning Rate Method: + https://arxiv.org/abs/1212.5701 + + """ +) + + +def _single_tensor_adadelta( + params: List[Tensor], + grads: List[Tensor], + square_avgs: List[Tensor], + acc_deltas: List[Tensor], + state_steps: List[Tensor], + *, + lr: float, + rho: float, + eps: float, + weight_decay: float, + maximize: bool, + differentiable: bool, + capturable: bool, + has_complex: bool, +): + # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable] + if not torch._utils.is_compiling() and capturable: + capturable_supported_devices = _get_capturable_supported_devices( + supports_xla=False + ) + assert all( + p.device.type == step.device.type + and p.device.type in capturable_supported_devices + for p, step in zip(params, state_steps) + ), f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}." + + for param, grad, square_avg, acc_delta, step in zip( + params, grads, square_avgs, acc_deltas, state_steps + ): + step += 1 + grad = grad if not maximize else -grad + + if weight_decay != 0: + grad = grad.add(param, alpha=weight_decay) + + if torch.is_complex(param): + square_avg = torch.view_as_real(square_avg) + acc_delta = torch.view_as_real(acc_delta) + grad = torch.view_as_real(grad) + + square_avg.mul_(rho).addcmul_(grad, grad, value=1 - rho) + std = square_avg.add(eps).sqrt_() + delta = acc_delta.add(eps).sqrt_() + if differentiable: + delta = delta.clone() + delta.div_(std).mul_(grad) + acc_delta.mul_(rho).addcmul_(delta, delta, value=1 - rho) + + if torch.is_complex(param): + delta = torch.view_as_complex(delta) + param.add_(delta, alpha=-lr) + + +def _multi_tensor_adadelta( + params: List[Tensor], + grads: List[Tensor], + square_avgs: List[Tensor], + acc_deltas: List[Tensor], + state_steps: List[Tensor], + *, + lr: float, + rho: float, + eps: float, + weight_decay: float, + maximize: bool, + differentiable: bool, + capturable: bool, + has_complex: bool, +): + assert not differentiable, "_foreach ops don't support autograd" + + # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable] + if not torch._utils.is_compiling() and capturable: + capturable_supported_devices = _get_capturable_supported_devices( + supports_xla=False + ) + assert all( + p.device.type == step.device.type + and p.device.type in capturable_supported_devices + for p, step in zip(params, state_steps) + ), f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}." + + if len(params) == 0: + return + + grouped_tensors = Optimizer._group_tensors_by_device_and_dtype( + [params, grads, square_avgs, acc_deltas, state_steps] + ) + for ( + device_params, + device_grads, + device_square_avgs, + device_acc_deltas, + device_state_steps, + ), _ in grouped_tensors.values(): + if has_complex: + _view_as_real( + device_params, device_grads, device_square_avgs, device_acc_deltas + ) + + # Update steps + # If steps are on CPU, foreach will fall back to the slow path, which is a for-loop calling t.add(1) over + # and over. 1 will then be wrapped into a Tensor over and over again, which is slower than if we just + # wrapped it once now. The alpha is required to assure we go to the right overload. 
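+        # Concretely: on CPU we pre-wrap the scalar once as torch.tensor(1.0) and reuse it
+        # for every step tensor, while on accelerator devices we pass the plain Python int 1
+        # and let the foreach kernel consume the scalar directly.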
+ if device_state_steps[0].is_cpu: + torch._foreach_add_( + device_state_steps, torch.tensor(1.0, device="cpu"), alpha=1.0 + ) + else: + torch._foreach_add_(device_state_steps, 1) + + if maximize: + device_grads = torch._foreach_neg(device_grads) # type: ignore[assignment] + + if weight_decay != 0: + # Re-use the intermediate memory (device_grads) already allocated for maximize + if maximize: + torch._foreach_add_(device_grads, device_params, alpha=weight_decay) + else: + device_grads = torch._foreach_add( # type: ignore[assignment] + device_grads, device_params, alpha=weight_decay + ) + + torch._foreach_mul_(device_square_avgs, rho) + torch._foreach_addcmul_( + device_square_avgs, device_grads, device_grads, value=1 - rho + ) + + std = torch._foreach_add(device_square_avgs, eps) + torch._foreach_sqrt_(std) + + deltas = torch._foreach_add(device_acc_deltas, eps) + torch._foreach_sqrt_(deltas) + torch._foreach_div_(deltas, std) + torch._foreach_mul_(deltas, device_grads) + + torch._foreach_mul_(device_acc_deltas, rho) + torch._foreach_addcmul_(device_acc_deltas, deltas, deltas, value=1 - rho) + + # If LR is a tensor, the else branch will internally call item() + # which will cause silent incorrectness if we are capturing + if capturable and isinstance(lr, torch.Tensor): + torch._foreach_mul_(deltas, -lr) + torch._foreach_add_(device_params, deltas) + else: + torch._foreach_add_(device_params, deltas, alpha=-lr) + + +@_disable_dynamo_if_unsupported(single_tensor_fn=_single_tensor_adadelta) +def adadelta( + params: List[Tensor], + grads: List[Tensor], + square_avgs: List[Tensor], + acc_deltas: List[Tensor], + state_steps: List[Tensor], + # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627 + # setting this as kwarg for now as functional API is compiled by torch/distributed/optim + capturable: bool = False, + foreach: Optional[bool] = None, + differentiable: bool = False, + has_complex: bool = False, + *, + lr: float, + rho: float, + eps: float, + weight_decay: float, + maximize: bool, +): + r"""Functional API that performs Adadelta algorithm computation. + + See :class:`~torch.optim.Adadelta` for details. + """ + + # this check is slow during compilation, so we skip it + # if it's strictly needed we can add this check back in dynamo + if not torch._utils.is_compiling() and not all( + isinstance(t, torch.Tensor) for t in state_steps + ): + raise RuntimeError( + "API has changed, `state_steps` argument must contain a list of singleton tensors" + ) + + # We still respect when the user inputs False for foreach. 
+ if foreach is None: + _, foreach = _default_to_fused_or_foreach( + params, differentiable, use_fused=False + ) + + if foreach and torch.jit.is_scripting(): + raise RuntimeError("torch.jit.script not supported with foreach optimizers") + + if foreach and not torch.jit.is_scripting(): + func = _multi_tensor_adadelta + else: + func = _single_tensor_adadelta + + func( + params, + grads, + square_avgs, + acc_deltas, + state_steps, + lr=lr, + rho=rho, + eps=eps, + weight_decay=weight_decay, + maximize=maximize, + differentiable=differentiable, + capturable=capturable, + has_complex=has_complex, + ) diff --git a/engine/cr_boosters/adagrad.py b/engine/cr_boosters/adagrad.py new file mode 100644 index 0000000..0b6dfe8 --- /dev/null +++ b/engine/cr_boosters/adagrad.py @@ -0,0 +1,555 @@ +# mypy: allow-untyped-defs +from typing import List, Optional + +import torch +from torch import Tensor +from torch.utils._foreach_utils import _get_fused_kernels_supported_devices +from .optimizer import ( + _default_to_fused_or_foreach, + _differentiable_doc, + _foreach_doc, + _get_scalar_dtype, + _get_value, + _maximize_doc, + _use_grad_for_differentiable, + _view_as_real, + Optimizer, + ParamsT, +) + +__all__ = ["Adagrad", "adagrad"] + + +class Adagrad(Optimizer): + def __init__( + self, + params: ParamsT, + lr: float = 1e-2, + lr_decay: float = 0, + weight_decay: float = 0, + initial_accumulator_value: float = 0, + eps: float = 1e-10, + foreach: Optional[bool] = None, + *, + maximize: bool = False, + differentiable: bool = False, + fused: Optional[bool] = None, + ): + if not 0.0 <= lr: + raise ValueError(f"Invalid learning rate: {lr}") + if not 0.0 <= lr_decay: + raise ValueError(f"Invalid lr_decay value: {lr_decay}") + if not 0.0 <= weight_decay: + raise ValueError(f"Invalid weight_decay value: {weight_decay}") + if not 0.0 <= initial_accumulator_value: + raise ValueError( + f"Invalid initial_accumulator_value value: {initial_accumulator_value}" + ) + if not 0.0 <= eps: + raise ValueError(f"Invalid epsilon value: {eps}") + + defaults = dict( + lr=lr, + lr_decay=lr_decay, + eps=eps, + weight_decay=weight_decay, + initial_accumulator_value=initial_accumulator_value, + foreach=foreach, + maximize=maximize, + differentiable=differentiable, + fused=fused, + ) + super().__init__(params, defaults) + + if fused: + if differentiable: + raise RuntimeError("`fused` does not support `differentiable`") + self._step_supports_amp_scaling = True + fused_supported_devices = _get_fused_kernels_supported_devices() + # Not support CUDA yet + fused_supported_devices.remove("cuda") + if not all( + p.device.type in fused_supported_devices and torch.is_floating_point(p) + for pg in self.param_groups + for p in pg["params"] + ): + raise RuntimeError( + "`fused=True` requires all the params to be floating point Tensors of " + f"supported devices: {fused_supported_devices}." 
+ ) + if foreach: + raise RuntimeError("`fused` and `foreach` cannot be `True` together.") + + for group in self.param_groups: + for p in group["params"]: + state = self.state[p] + state["step"] = ( + torch.zeros( + (), + dtype=_get_scalar_dtype(is_fused=group["fused"]), + device=p.device, + ) + if group["fused"] + else torch.tensor(0.0, dtype=_get_scalar_dtype()) + ) + init_value = ( + complex(initial_accumulator_value, initial_accumulator_value) + if torch.is_complex(p) + else initial_accumulator_value + ) + state["sum"] = torch.full_like( + p, init_value, memory_format=torch.preserve_format + ) + + def __setstate__(self, state): + super().__setstate__(state) + # define "fused" for + # MYPY error: Name "fused" may be undefined + fused = None + for group in self.param_groups: + group.setdefault("foreach", None) + group.setdefault("maximize", False) + group.setdefault("differentiable", False) + fused = group.setdefault("fused", None) + + state_values = list(self.state.values()) + step_is_tensor = (len(state_values) != 0) and torch.is_tensor( + state_values[0]["step"] + ) + if not step_is_tensor: + for s in state_values: + s["step"] = torch.tensor( + float(s["step"]), dtype=_get_scalar_dtype(is_fused=fused) + ) + + def share_memory(self): + for group in self.param_groups: + for p in group["params"]: + state = self.state[p] + state["sum"].share_memory_() + + def _init_group(self, group, params_with_grad, grads, state_sums, state_steps): + has_sparse_grad, has_complex = False, False + for p in group["params"]: + if p.grad is not None: + has_sparse_grad |= p.grad.is_sparse + has_complex |= torch.is_complex(p) + params_with_grad.append(p) + grads.append(p.grad) + state = self.state[p] + state_sums.append(state["sum"]) + state_steps.append(state["step"]) + + return has_sparse_grad, has_complex + + @_use_grad_for_differentiable + def step(self, closure=None): + """Perform a single optimization step. + + Args: + closure (Callable, optional): A closure that reevaluates the model + and returns the loss. + """ + loss = None + + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + params_with_grad: List[Tensor] = [] + grads: List[Tensor] = [] + state_sums: List[Tensor] = [] + state_steps: List[Tensor] = [] + + has_sparse_grad, has_complex = self._init_group( + group, params_with_grad, grads, state_sums, state_steps + ) + + adagrad( + params_with_grad, + grads, + state_sums, + state_steps, + lr=group["lr"], + weight_decay=group["weight_decay"], + lr_decay=group["lr_decay"], + eps=group["eps"], + has_sparse_grad=has_sparse_grad, + foreach=group["foreach"], + maximize=group["maximize"], + differentiable=group["differentiable"], + has_complex=has_complex, + fused=group["fused"], + grad_scale=getattr(self, "grad_scale", None), + found_inf=getattr(self, "found_inf", None), + ) + + return loss + + +Adagrad.__doc__ = ( + r"""Implements Adagrad algorithm. + + .. 
math:: + \begin{aligned} + &\rule{110mm}{0.4pt} \\ + &\textbf{input} : \gamma \text{ (lr)}, \: \theta_0 \text{ (params)}, \: f(\theta) + \text{ (objective)}, \: \lambda \text{ (weight decay)}, \\ + &\hspace{12mm} \tau \text{ (initial accumulator value)}, \: \eta\text{ (lr decay)}\\ + &\textbf{initialize} : state\_sum_0 \leftarrow \tau \\[-1.ex] + &\rule{110mm}{0.4pt} \\ + &\textbf{for} \: t=1 \: \textbf{to} \: \ldots \: \textbf{do} \\ + &\hspace{5mm}g_t \leftarrow \nabla_{\theta} f_t (\theta_{t-1}) \\ + &\hspace{5mm} \tilde{\gamma} \leftarrow \gamma / (1 +(t-1) \eta) \\ + &\hspace{5mm} \textbf{if} \: \lambda \neq 0 \\ + &\hspace{10mm} g_t \leftarrow g_t + \lambda \theta_{t-1} \\ + &\hspace{5mm}state\_sum_t \leftarrow state\_sum_{t-1} + g^2_t \\ + &\hspace{5mm}\theta_t \leftarrow + \theta_{t-1}- \tilde{\gamma} \frac{g_t}{\sqrt{state\_sum_t}+\epsilon} \\ + &\rule{110mm}{0.4pt} \\[-1.ex] + &\bf{return} \: \theta_t \\[-1.ex] + &\rule{110mm}{0.4pt} \\[-1.ex] + \end{aligned} + + For further details regarding the algorithm we refer to `Adaptive Subgradient Methods for Online Learning + and Stochastic Optimization`_. + """ + + rf""" + Args: + params (iterable): iterable of parameters to optimize or dicts defining + parameter groups + lr (float, optional): learning rate (default: 1e-2) + lr_decay (float, optional): learning rate decay (default: 0) + weight_decay (float, optional): weight decay (L2 penalty) (default: 0) + initial_accumulator_value (float, optional): initial value of the + sum of squares of gradients (default: 0) + eps (float, optional): term added to the denominator to improve + numerical stability (default: 1e-10) + {_foreach_doc} + {_maximize_doc} + {_differentiable_doc} + fused (bool, optional): whether the fused implementation (CPU only) is used. + Currently, `torch.float64`, `torch.float32`, `torch.float16`, and `torch.bfloat16` + are supported. (default: None). Please note that the fused implementations does not + support sparse or complex gradients. + .. _Adaptive Subgradient Methods for Online Learning and Stochastic + Optimization: http://jmlr.org/papers/v12/duchi11a.html + + """ +) + + +def adagrad( + params: List[Tensor], + grads: List[Tensor], + state_sums: List[Tensor], + state_steps: List[Tensor], + fused: Optional[bool] = None, + grad_scale: Optional[Tensor] = None, + found_inf: Optional[Tensor] = None, + # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627 + # setting these as kwargs for now as functional API is compiled by torch/distributed/optim + has_sparse_grad: bool = False, + foreach: Optional[bool] = None, + differentiable: bool = False, + has_complex: bool = False, + *, + lr: float, + weight_decay: float, + lr_decay: float, + eps: float, + maximize: bool, +): + r"""Functional API that performs Adagrad algorithm computation. + + See :class:`~torch.optim.Adagrad` for details. + """ + if not all(isinstance(t, torch.Tensor) for t in state_steps): + raise RuntimeError( + "API has changed, `state_steps` argument must contain a list of singleton tensors" + ) + + # Respect when the user inputs False/True for foreach or fused. We only want to change + # the default when neither have been user-specified. Note that we default to foreach + # and pass False to use_fused. This is not a mistake--we want to give the fused impl + # bake-in time before making it the default, even if it is typically faster. 
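+    # Resulting dispatch, in order of precedence: fused=True -> _fused_adagrad,
+    # else foreach=True -> _multi_tensor_adagrad, else _single_tensor_adagrad; an
+    # explicit False for either flag is never overridden by this heuristic.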
+ if fused is None and foreach is None: + _, foreach = _default_to_fused_or_foreach( + params, differentiable, use_fused=False + ) + + if fused is None: + fused = False + if foreach is None: + foreach = False + + if foreach and torch.jit.is_scripting(): + raise RuntimeError("torch.jit.script not supported with foreach optimizers") + if fused and torch.jit.is_scripting(): + raise RuntimeError("torch.jit.script not supported with fused optimizers") + + if fused and not torch.jit.is_scripting(): + func = _fused_adagrad + elif foreach and not torch.jit.is_scripting(): + func = _multi_tensor_adagrad + else: + func = _single_tensor_adagrad + + func( + params, + grads, + state_sums, + state_steps, + lr=lr, + weight_decay=weight_decay, + lr_decay=lr_decay, + eps=eps, + has_sparse_grad=has_sparse_grad, + maximize=maximize, + differentiable=differentiable, + has_complex=has_complex, + grad_scale=grad_scale, + found_inf=found_inf, + ) + + +def _make_sparse(grad, grad_indices, values): + size = grad.size() + return torch.sparse_coo_tensor(grad_indices, values, size) + + +def _single_tensor_adagrad( + params: List[Tensor], + grads: List[Tensor], + state_sums: List[Tensor], + state_steps: List[Tensor], + grad_scale: Optional[Tensor], + found_inf: Optional[Tensor], + *, + lr: float, + weight_decay: float, + lr_decay: float, + eps: float, + has_sparse_grad: bool, + maximize: bool, + differentiable: bool, + has_complex: bool, +): + assert grad_scale is None and found_inf is None + for param, grad, state_sum, step_t in zip(params, grads, state_sums, state_steps): + # update step + step_t += 1 + step = _get_value(step_t) + grad = grad if not maximize else -grad + + if weight_decay != 0: + if grad.is_sparse: + raise RuntimeError( + "weight_decay option is not compatible with sparse gradients" + ) + grad = grad.add(param, alpha=weight_decay) + + clr = lr / (1 + (step - 1) * lr_decay) + + if grad.is_sparse: + grad = grad.coalesce() # the update is non-linear so indices must be unique + grad_indices = grad._indices() + grad_values = grad._values() + + state_sum.add_(_make_sparse(grad, grad_indices, grad_values.pow(2))) + std = state_sum.sparse_mask(grad) + std_values = std._values().sqrt_().add_(eps) + param.add_( + _make_sparse(grad, grad_indices, grad_values / std_values), alpha=-clr + ) + else: + is_complex = torch.is_complex(param) + if is_complex: + grad = torch.view_as_real(grad) + state_sum = torch.view_as_real(state_sum) + param = torch.view_as_real(param) + state_sum.addcmul_(grad, grad, value=1) + if differentiable: + std = state_sum.sqrt() + eps + else: + std = state_sum.sqrt().add_(eps) + param.addcdiv_(grad, std, value=-clr) + if is_complex: + param = torch.view_as_complex(param) + state_sum = torch.view_as_complex(state_sum) + + +def _multi_tensor_adagrad( + params: List[Tensor], + grads: List[Tensor], + state_sums: List[Tensor], + state_steps: List[Tensor], + grad_scale: Optional[Tensor], + found_inf: Optional[Tensor], + *, + lr: float, + weight_decay: float, + lr_decay: float, + eps: float, + has_sparse_grad: bool, + maximize: bool, + differentiable: bool, + has_complex: bool, +): + assert not differentiable, "_foreach ops don't support autograd" + assert grad_scale is None and found_inf is None + + # Foreach functions will throw errors if given empty lists + if len(params) == 0: + return + + grouped_tensorlists = Optimizer._group_tensors_by_device_and_dtype( + [params, grads, state_sums, state_steps] + ) + for ( + device_params, + device_grads, + device_state_sums, + device_state_steps, + ), _ 
in grouped_tensorlists.values(): + device_has_sparse_grad = has_sparse_grad and any( + grad.is_sparse for grad in device_grads + ) + + if device_has_sparse_grad: + _single_tensor_adagrad( + device_params, + device_grads, + device_state_sums, + device_state_steps, + lr=lr, + weight_decay=weight_decay, + lr_decay=lr_decay, + eps=eps, + has_sparse_grad=True, + maximize=maximize, + differentiable=differentiable, + has_complex=has_complex, + grad_scale=grad_scale, + found_inf=found_inf, + ) + continue + + # Handle complex parameters + if has_complex: + _view_as_real(device_params, device_grads, device_state_sums) + + if maximize: + device_grads = torch._foreach_neg(device_grads) # type: ignore[assignment] + + # Update steps + # If steps are on CPU, foreach will fall back to the slow path, which is a for-loop calling t.add(1) over + # and over. 1 will then be wrapped into a Tensor over and over again, which is slower than if we just + # wrapped it once now. The alpha is required to assure we go to the right overload. + if device_state_steps[0].is_cpu: + torch._foreach_add_( + device_state_steps, torch.tensor(1.0, device="cpu"), alpha=1.0 + ) + else: + torch._foreach_add_(device_state_steps, 1) + + if weight_decay != 0: + # Re-use the intermediate memory (device_grads) already allocated for maximize + if maximize: + torch._foreach_add_(device_grads, device_params, alpha=weight_decay) + else: + device_grads = torch._foreach_add( # type: ignore[assignment] + device_grads, device_params, alpha=weight_decay + ) + + minus_clr = [ + -lr / (1 + (_get_value(step) - 1) * lr_decay) for step in device_state_steps + ] + + torch._foreach_addcmul_(device_state_sums, device_grads, device_grads, value=1) + + std = torch._foreach_sqrt(device_state_sums) + torch._foreach_add_(std, eps) + + if weight_decay != 0 or maximize: + # Again, re-use the intermediate memory (device_grads) already allocated + torch._foreach_mul_(device_grads, minus_clr) + numerator = device_grads + else: + numerator = torch._foreach_mul(device_grads, minus_clr) # type: ignore[assignment] + + torch._foreach_addcdiv_(device_params, numerator, std) + + +def _fused_adagrad( + params: List[Tensor], + grads: List[Tensor], + state_sums: List[Tensor], + state_steps: List[Tensor], + grad_scale: Optional[Tensor], + found_inf: Optional[Tensor], + *, + lr: float, + weight_decay: float, + lr_decay: float, + eps: float, + has_sparse_grad: bool, + maximize: bool, + differentiable: bool, + has_complex: bool, +) -> None: + if not params: + return + if has_sparse_grad or has_complex: + raise RuntimeError("`fused` does not support sparse grad or complex param") + + if differentiable: + raise RuntimeError( + "adagrad with fused=True does not support differentiable=True" + ) + + grad_scale_dict = ( + {grad_scale.device: grad_scale} if grad_scale is not None else None + ) + found_inf_dict = {found_inf.device: found_inf} if found_inf is not None else None + + grouped_tensors = Optimizer._group_tensors_by_device_and_dtype( + [params, grads, state_sums, state_steps] + ) + for (device, _), ( + ( + device_params, + device_grads, + device_state_sums, + device_state_steps, + ), + _, + ) in grouped_tensors.items(): + device_grad_scale, device_found_inf = None, None + if grad_scale is not None and grad_scale_dict is not None: + if device not in grad_scale_dict: + grad_scale_dict[device] = grad_scale.to(device, non_blocking=True) # type: ignore[index] + device_grad_scale = grad_scale_dict[device] # type: ignore[index] + if found_inf is not None and found_inf_dict is not 
None: + if found_inf not in found_inf_dict: + found_inf_dict[device] = found_inf.to(device, non_blocking=True) # type: ignore[index] + device_found_inf = found_inf_dict[device] # type: ignore[index] + torch._foreach_add_(device_state_steps, 1) + torch._fused_adagrad_( + device_params, + device_grads, + device_state_sums, + device_state_steps, + lr=lr, + lr_decay=lr_decay, + weight_decay=weight_decay, + eps=eps, + maximize=maximize, + grad_scale=device_grad_scale, + found_inf=device_found_inf, + ) + if device_found_inf is not None: + torch._foreach_sub_( + device_state_steps, [device_found_inf] * len(device_state_steps) + ) diff --git a/engine/cr_boosters/adam.py b/engine/cr_boosters/adam.py new file mode 100644 index 0000000..d5fe9a5 --- /dev/null +++ b/engine/cr_boosters/adam.py @@ -0,0 +1,790 @@ +# mypy: allow-untyped-defs +from typing import List, Optional, Tuple, Union + +import torch +from torch import Tensor +from torch.utils._foreach_utils import _get_fused_kernels_supported_devices +from .optimizer import ( + _capturable_doc, + _default_to_fused_or_foreach, + _differentiable_doc, + _disable_dynamo_if_unsupported, + _dispatch_sqrt, + _foreach_doc, + _fused_doc, + _get_capturable_supported_devices, + _get_scalar_dtype, + _get_value, + _maximize_doc, + _stack_if_compiling, + _use_grad_for_differentiable, + _view_as_real, + DeviceDict, + Optimizer, + ParamsT, +) + +__all__ = ["Adam", "adam"] + + +class Adam(Optimizer): + def __init__( + self, + params: ParamsT, + lr: Union[float, Tensor] = 1e-3, + betas: Tuple[float, float] = (0.9, 0.999), + eps: float = 1e-8, + weight_decay: float = 0, + amsgrad: bool = False, + *, + foreach: Optional[bool] = None, + maximize: bool = False, + capturable: bool = False, + differentiable: bool = False, + fused: Optional[bool] = None, + ): + if not 0.0 <= lr: + raise ValueError(f"Invalid learning rate: {lr}") + if isinstance(lr, Tensor) and foreach and not capturable: + raise ValueError( + "lr as a Tensor is not supported for capturable=False and foreach=True" + ) + if not 0.0 <= eps: + raise ValueError(f"Invalid epsilon value: {eps}") + if not 0.0 <= betas[0] < 1.0: + raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}") + if not 0.0 <= betas[1] < 1.0: + raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}") + if not 0.0 <= weight_decay: + raise ValueError(f"Invalid weight_decay value: {weight_decay}") + + defaults = dict( + lr=lr, + betas=betas, + eps=eps, + weight_decay=weight_decay, + amsgrad=amsgrad, + maximize=maximize, + foreach=foreach, + capturable=capturable, + differentiable=differentiable, + fused=fused, + ) + super().__init__(params, defaults) + + if fused: + if differentiable: + raise RuntimeError("`fused` does not support `differentiable`") + self._step_supports_amp_scaling = True + # TODO(crcrpar): [low prec params & their higher prec copy] + # Support AMP with FP16/BF16 model params which would need + # higher prec copy of params to do update math in higher prec to + # alleviate the loss of information. + fused_supported_devices = _get_fused_kernels_supported_devices() + if not all( + p.device.type in fused_supported_devices and torch.is_floating_point(p) + for pg in self.param_groups + for p in pg["params"] + ): + raise RuntimeError( + "`fused=True` requires all the params to be floating point Tensors of " + f"supported devices: {fused_supported_devices}." 
+ ) + if foreach: + raise RuntimeError("`fused` and `foreach` cannot be `True` together.") + + def __setstate__(self, state): + super().__setstate__(state) + for group in self.param_groups: + group.setdefault("amsgrad", False) + group.setdefault("maximize", False) + group.setdefault("foreach", None) + group.setdefault("capturable", False) + group.setdefault("differentiable", False) + fused = group.setdefault("fused", None) + for p in group["params"]: + p_state = self.state.get(p, []) + if len(p_state) != 0 and not torch.is_tensor(p_state["step"]): + step_val = float(p_state["step"]) + p_state["step"] = ( + torch.tensor( + step_val, + dtype=_get_scalar_dtype(is_fused=fused), + device=p.device, + ) + if group["capturable"] or group["fused"] + else torch.tensor(step_val, dtype=_get_scalar_dtype()) + ) + + def _init_group( + self, + group, + params_with_grad, + grads, + exp_avgs, + exp_avg_sqs, + max_exp_avg_sqs, + state_steps, + ): + has_complex = False + for p in group["params"]: + if p.grad is not None: + has_complex |= torch.is_complex(p) + params_with_grad.append(p) + if p.grad.is_sparse: + raise RuntimeError( + "Adam does not support sparse gradients, please consider SparseAdam instead" + ) + grads.append(p.grad) + + state = self.state[p] + # Lazy state initialization + if len(state) == 0: + # note(crcrpar): [special device hosting for step] + # Deliberately host `step` on CPU if both capturable and fused are off. + # This is because kernel launches are costly on CUDA and XLA. + state["step"] = ( + torch.zeros( + (), + dtype=_get_scalar_dtype(is_fused=group["fused"]), + device=p.device, + ) + if group["capturable"] or group["fused"] + else torch.tensor(0.0, dtype=_get_scalar_dtype()) + ) + # Exponential moving average of gradient values + state["exp_avg"] = torch.zeros_like( + p, memory_format=torch.preserve_format + ) + # Exponential moving average of squared gradient values + state["exp_avg_sq"] = torch.zeros_like( + p, memory_format=torch.preserve_format + ) + if group["amsgrad"]: + # Maintains max of all exp. moving avg. of sq. grad. values + state["max_exp_avg_sq"] = torch.zeros_like( + p, memory_format=torch.preserve_format + ) + + exp_avgs.append(state["exp_avg"]) + exp_avg_sqs.append(state["exp_avg_sq"]) + + if group["amsgrad"]: + max_exp_avg_sqs.append(state["max_exp_avg_sq"]) + if group["differentiable"] and state["step"].requires_grad: + raise RuntimeError( + "`requires_grad` is not supported for `step` in differentiable mode" + ) + + # Foreach without capturable does not support a tensor lr + if ( + group["foreach"] + and torch.is_tensor(group["lr"]) + and not group["capturable"] + ): + raise RuntimeError( + "lr as a Tensor is not supported for capturable=False and foreach=True" + ) + + state_steps.append(state["step"]) + return has_complex + + @_use_grad_for_differentiable + def step(self, closure=None): + """Perform a single optimization step. + + Args: + closure (Callable, optional): A closure that reevaluates the model + and returns the loss. 
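+
+        A minimal training-loop sketch using the plain (closure-free) form of ``step``
+        (``model``, ``loss_fn`` and ``data_loader`` are assumed to be defined)::
+
+            optimizer = Adam(model.parameters(), lr=1e-3)
+            for inputs, targets in data_loader:
+                optimizer.zero_grad()
+                loss = loss_fn(model(inputs), targets)
+                loss.backward()
+                optimizer.step()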
+ """ + self._cuda_graph_capture_health_check() + + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + params_with_grad: List[Tensor] = [] + grads: List[Tensor] = [] + exp_avgs: List[Tensor] = [] + exp_avg_sqs: List[Tensor] = [] + max_exp_avg_sqs: List[Tensor] = [] + state_steps: List[Tensor] = [] + beta1, beta2 = group["betas"] + + has_complex = self._init_group( + group, + params_with_grad, + grads, + exp_avgs, + exp_avg_sqs, + max_exp_avg_sqs, + state_steps, + ) + + adam( + params_with_grad, + grads, + exp_avgs, + exp_avg_sqs, + max_exp_avg_sqs, + state_steps, + amsgrad=group["amsgrad"], + has_complex=has_complex, + beta1=beta1, + beta2=beta2, + lr=group["lr"], + weight_decay=group["weight_decay"], + eps=group["eps"], + maximize=group["maximize"], + foreach=group["foreach"], + capturable=group["capturable"], + differentiable=group["differentiable"], + fused=group["fused"], + grad_scale=getattr(self, "grad_scale", None), + found_inf=getattr(self, "found_inf", None), + ) + + return loss + + +Adam.__doc__ = ( + r"""Implements Adam algorithm. + + .. math:: + \begin{aligned} + &\rule{110mm}{0.4pt} \\ + &\textbf{input} : \gamma \text{ (lr)}, \beta_1, \beta_2 + \text{ (betas)},\theta_0 \text{ (params)},f(\theta) \text{ (objective)} \\ + &\hspace{13mm} \lambda \text{ (weight decay)}, \: \textit{amsgrad}, + \:\textit{maximize} \\ + &\textbf{initialize} : m_0 \leftarrow 0 \text{ ( first moment)}, + v_0\leftarrow 0 \text{ (second moment)},\: \widehat{v_0}^{max}\leftarrow 0\\[-1.ex] + &\rule{110mm}{0.4pt} \\ + &\textbf{for} \: t=1 \: \textbf{to} \: \ldots \: \textbf{do} \\ + + &\hspace{5mm}\textbf{if} \: \textit{maximize}: \\ + &\hspace{10mm}g_t \leftarrow -\nabla_{\theta} f_t (\theta_{t-1}) \\ + &\hspace{5mm}\textbf{else} \\ + &\hspace{10mm}g_t \leftarrow \nabla_{\theta} f_t (\theta_{t-1}) \\ + &\hspace{5mm}\textbf{if} \: \lambda \neq 0 \\ + &\hspace{10mm} g_t \leftarrow g_t + \lambda \theta_{t-1} \\ + &\hspace{5mm}m_t \leftarrow \beta_1 m_{t-1} + (1 - \beta_1) g_t \\ + &\hspace{5mm}v_t \leftarrow \beta_2 v_{t-1} + (1-\beta_2) g^2_t \\ + &\hspace{5mm}\widehat{m_t} \leftarrow m_t/\big(1-\beta_1^t \big) \\ + &\hspace{5mm}\widehat{v_t} \leftarrow v_t/\big(1-\beta_2^t \big) \\ + &\hspace{5mm}\textbf{if} \: amsgrad \\ + &\hspace{10mm}\widehat{v_t}^{max} \leftarrow \mathrm{max}(\widehat{v_t}^{max}, + \widehat{v_t}) \\ + &\hspace{10mm}\theta_t \leftarrow \theta_{t-1} - \gamma \widehat{m_t}/ + \big(\sqrt{\widehat{v_t}^{max}} + \epsilon \big) \\ + &\hspace{5mm}\textbf{else} \\ + &\hspace{10mm}\theta_t \leftarrow \theta_{t-1} - \gamma \widehat{m_t}/ + \big(\sqrt{\widehat{v_t}} + \epsilon \big) \\ + &\rule{110mm}{0.4pt} \\[-1.ex] + &\bf{return} \: \theta_t \\[-1.ex] + &\rule{110mm}{0.4pt} \\[-1.ex] + \end{aligned} + + For further details regarding the algorithm we refer to `Adam: A Method for Stochastic Optimization`_. + """ + + rf""" + Args: + params (iterable): iterable of parameters to optimize or dicts defining + parameter groups + lr (float, Tensor, optional): learning rate (default: 1e-3). A tensor LR + is not yet supported for all our implementations. Please use a float + LR if you are not also specifying fused=True or capturable=True. 
+ betas (Tuple[float, float], optional): coefficients used for computing + running averages of gradient and its square (default: (0.9, 0.999)) + eps (float, optional): term added to the denominator to improve + numerical stability (default: 1e-8) + weight_decay (float, optional): weight decay (L2 penalty) (default: 0) + amsgrad (bool, optional): whether to use the AMSGrad variant of this + algorithm from the paper `On the Convergence of Adam and Beyond`_ + (default: False) + {_foreach_doc} + {_maximize_doc} + {_capturable_doc} + {_differentiable_doc} + {_fused_doc} + .. Note:: + A prototype implementation of Adam and AdamW for MPS supports `torch.float32` and `torch.float16`. + .. _Adam\: A Method for Stochastic Optimization: + https://arxiv.org/abs/1412.6980 + .. _On the Convergence of Adam and Beyond: + https://openreview.net/forum?id=ryQu7f-RZ + + """ +) + + +def _single_tensor_adam( + params: List[Tensor], + grads: List[Tensor], + exp_avgs: List[Tensor], + exp_avg_sqs: List[Tensor], + max_exp_avg_sqs: List[Tensor], + state_steps: List[Tensor], + grad_scale: Optional[Tensor], + found_inf: Optional[Tensor], + *, + amsgrad: bool, + has_complex: bool, + beta1: float, + beta2: float, + lr: Union[float, Tensor], + weight_decay: float, + eps: float, + maximize: bool, + capturable: bool, + differentiable: bool, +): + assert grad_scale is None and found_inf is None + + if torch.jit.is_scripting(): + # this assert is due to JIT being dumb and not realizing that the ops below + # have overloads to handle both float and Tensor lrs, so we just assert it's + # a float since most people using JIT are using floats + assert isinstance(lr, float) + + for i, param in enumerate(params): + grad = grads[i] if not maximize else -grads[i] + exp_avg = exp_avgs[i] + exp_avg_sq = exp_avg_sqs[i] + step_t = state_steps[i] + + # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable] + if not torch._utils.is_compiling() and capturable: + capturable_supported_devices = _get_capturable_supported_devices() + assert ( + param.device.type == step_t.device.type + and param.device.type in capturable_supported_devices + ), f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}." + + # update step + step_t += 1 + + if weight_decay != 0: + grad = grad.add(param, alpha=weight_decay) + + if torch.is_complex(param): + grad = torch.view_as_real(grad) + exp_avg = torch.view_as_real(exp_avg) + exp_avg_sq = torch.view_as_real(exp_avg_sq) + if amsgrad: + max_exp_avg_sqs[i] = torch.view_as_real(max_exp_avg_sqs[i]) + param = torch.view_as_real(param) + + # Decay the first and second moment running average coefficient + exp_avg.lerp_(grad, 1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(grad, grad.conj(), value=1 - beta2) + + if capturable or differentiable: + step = step_t + + bias_correction1 = 1 - beta1**step + bias_correction2 = 1 - beta2**step + + step_size = lr / bias_correction1 + step_size_neg = step_size.neg() + + bias_correction2_sqrt = bias_correction2.sqrt() + + if amsgrad: + # Maintains the maximum of all 2nd moment running avg. till now + if differentiable: + max_exp_avg_sq = max_exp_avg_sqs[i].clone() + else: + max_exp_avg_sq = max_exp_avg_sqs[i] + + max_exp_avg_sqs[i].copy_(torch.maximum(max_exp_avg_sq, exp_avg_sq)) + + # Uses the max. for normalizing running avg. 
of gradient + # Folds in (admittedly ugly) 1-elem step_size math here to avoid extra param-set-sized read+write + # (can't fold it into addcdiv_ below because addcdiv_ requires value is a Number, not a Tensor) + denom = ( + max_exp_avg_sqs[i].sqrt() / (bias_correction2_sqrt * step_size_neg) + ).add_(eps / step_size_neg) + else: + denom = ( + exp_avg_sq.sqrt() / (bias_correction2_sqrt * step_size_neg) + ).add_(eps / step_size_neg) + + param.addcdiv_(exp_avg, denom) + else: + step = _get_value(step_t) + + bias_correction1 = 1 - beta1**step + bias_correction2 = 1 - beta2**step + + step_size = lr / bias_correction1 + + bias_correction2_sqrt = _dispatch_sqrt(bias_correction2) + + if amsgrad: + # Maintains the maximum of all 2nd moment running avg. till now + torch.maximum(max_exp_avg_sqs[i], exp_avg_sq, out=max_exp_avg_sqs[i]) + + # Use the max. for normalizing running avg. of gradient + denom = (max_exp_avg_sqs[i].sqrt() / bias_correction2_sqrt).add_(eps) + else: + denom = (exp_avg_sq.sqrt() / bias_correction2_sqrt).add_(eps) + + param.addcdiv_(exp_avg, denom, value=-step_size) + + # Lastly, switch back to complex view + if amsgrad and torch.is_complex(params[i]): + max_exp_avg_sqs[i] = torch.view_as_complex(max_exp_avg_sqs[i]) + + +def _multi_tensor_adam( + params: List[Tensor], + grads: List[Tensor], + exp_avgs: List[Tensor], + exp_avg_sqs: List[Tensor], + max_exp_avg_sqs: List[Tensor], + state_steps: List[Tensor], + grad_scale: Optional[Tensor], + found_inf: Optional[Tensor], + *, + amsgrad: bool, + has_complex: bool, + beta1: float, + beta2: float, + lr: Union[float, Tensor], + weight_decay: float, + eps: float, + maximize: bool, + capturable: bool, + differentiable: bool, +): + if len(params) == 0: + return + + if isinstance(lr, Tensor) and not capturable: + raise RuntimeError( + "lr as a Tensor is not supported for capturable=False and foreach=True" + ) + + # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable] + if not torch._utils.is_compiling() and capturable: + capturable_supported_devices = _get_capturable_supported_devices( + supports_xla=False + ) + assert all( + p.device.type == step.device.type + and p.device.type in capturable_supported_devices + for p, step in zip(params, state_steps) + ), f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}." + + assert grad_scale is None and found_inf is None + + assert not differentiable, "_foreach ops don't support autograd" + + grouped_tensors = Optimizer._group_tensors_by_device_and_dtype( + [params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps] + ) + for ( + device_params, + device_grads, + device_exp_avgs, + device_exp_avg_sqs, + device_max_exp_avg_sqs, + device_state_steps, + ), _ in grouped_tensors.values(): + # Handle complex parameters + if has_complex: + if amsgrad: + _view_as_real( + device_params, + device_grads, + device_exp_avgs, + device_exp_avg_sqs, + device_max_exp_avg_sqs, + ) + else: + _view_as_real( + device_params, device_grads, device_exp_avgs, device_exp_avg_sqs + ) + + if maximize: + device_grads = torch._foreach_neg(device_grads) # type: ignore[assignment] + + # Update steps + # If steps are on CPU, foreach will fall back to the slow path, which is a for-loop calling t.add(1) over + # and over. 1 will then be wrapped into a Tensor over and over again, which is slower than if we just + # wrapped it once now. The alpha is required to assure we go to the right overload. 
+ if device_state_steps[0].is_cpu: + torch._foreach_add_( + device_state_steps, torch.tensor(1.0, device="cpu"), alpha=1.0 + ) + else: + torch._foreach_add_(device_state_steps, 1) + + if weight_decay != 0: + # Re-use the intermediate memory (device_grads) already allocated for maximize + if maximize: + torch._foreach_add_(device_grads, device_params, alpha=weight_decay) + else: + device_grads = torch._foreach_add( # type: ignore[assignment] + device_grads, device_params, alpha=weight_decay + ) + + # Decay the first and second moment running average coefficient + torch._foreach_lerp_(device_exp_avgs, device_grads, 1 - beta1) + + torch._foreach_mul_(device_exp_avg_sqs, beta2) + torch._foreach_addcmul_( + device_exp_avg_sqs, device_grads, device_grads, 1 - beta2 + ) + + # Delete the local intermediate since it won't be used anymore to save on peak memory + del device_grads + + bias_correction1: Union[Tuple[Tensor, ...], List[Tensor]] + bias_correction2: Union[Tuple[Tensor, ...], List[Tensor]] + bias_correction2_sqrt: Union[Tuple[Tensor, ...], List[Tensor]] + if capturable: + bias_correction1 = torch._foreach_pow(beta1, device_state_steps) + bias_correction2 = torch._foreach_pow(beta2, device_state_steps) + # foreach_sub doesn't allow a scalar as the first arg + torch._foreach_sub_(bias_correction1, 1) + torch._foreach_sub_(bias_correction2, 1) + # we do not negate bias_correction1 as it'll need to be negated later anyway + torch._foreach_neg_(bias_correction2) + + # foreach_div doesn't allow a scalar as the first arg + torch._foreach_div_(bias_correction1, lr) + torch._foreach_reciprocal_(bias_correction1) + + torch._foreach_sqrt_(bias_correction2) + + # Re-assign for clarity as we maintain minimal intermediates: we'll have + # step_size = - lr / (1 - beta1 ^ t) where t = num_steps + # bias_correction2_sqrt = sqrt(1 - beta2 ^ t) + step_size = bias_correction1 + bias_correction2_sqrt = bias_correction2 + + if amsgrad: + # Maintains the maximum of all 2nd moment running avg. till now + torch._foreach_maximum_(device_max_exp_avg_sqs, device_exp_avg_sqs) # type: ignore[assignment] + + # Set intermediate to the max. for normalizing running avg. of gradient when amsgrad + exp_avg_sq_sqrt = torch._foreach_sqrt(device_max_exp_avg_sqs) + else: + exp_avg_sq_sqrt = torch._foreach_sqrt(device_exp_avg_sqs) + + torch._foreach_div_(exp_avg_sq_sqrt, bias_correction2_sqrt) + torch._foreach_add_(exp_avg_sq_sqrt, eps) + torch._foreach_div_(exp_avg_sq_sqrt, step_size) + + # at this point, exp_avg_sq_sqrt = - (1 - beta^t) * [sqrt(exp_avg_sq / (1 - beta2^t)) + eps] / lr + torch._foreach_addcdiv_(device_params, device_exp_avgs, exp_avg_sq_sqrt) + else: + bias_correction1 = [ + 1 - beta1 ** _get_value(step) for step in device_state_steps + ] + bias_correction2 = [ + 1 - beta2 ** _get_value(step) for step in device_state_steps + ] + + step_size = _stack_if_compiling([(lr / bc) * -1 for bc in bias_correction1]) + + bias_correction2_sqrt = [_dispatch_sqrt(bc) for bc in bias_correction2] # type: ignore[arg-type] + + if amsgrad: + # Maintains the maximum of all 2nd moment running avg. till now + torch._foreach_maximum_(device_max_exp_avg_sqs, device_exp_avg_sqs) + + # Use the max. for normalizing running avg. 
of gradient + exp_avg_sq_sqrt = torch._foreach_sqrt(device_max_exp_avg_sqs) + else: + exp_avg_sq_sqrt = torch._foreach_sqrt(device_exp_avg_sqs) + + torch._foreach_div_(exp_avg_sq_sqrt, bias_correction2_sqrt) + torch._foreach_add_(exp_avg_sq_sqrt, eps) + torch._foreach_addcdiv_( + device_params, device_exp_avgs, exp_avg_sq_sqrt, step_size # type: ignore[arg-type] + ) + + +def _fused_adam( + params: List[Tensor], + grads: List[Tensor], + exp_avgs: List[Tensor], + exp_avg_sqs: List[Tensor], + max_exp_avg_sqs: List[Tensor], + state_steps: List[Tensor], + grad_scale: Optional[Tensor], + found_inf: Optional[Tensor], + *, + amsgrad: bool, + has_complex: bool, # Needed for consistency. + beta1: float, + beta2: float, + lr: Union[float, Tensor], + weight_decay: float, + eps: float, + maximize: bool, + capturable: bool, # Needed for consistency. + differentiable: bool, +) -> None: + if not params: + return + if differentiable: + raise RuntimeError("Adam with fused=True does not support differentiable=True") + + grad_scale_dict: DeviceDict = ( + {grad_scale.device: grad_scale} if grad_scale is not None else {} + ) + found_inf_dict: DeviceDict = ( + {found_inf.device: found_inf} if found_inf is not None else {} + ) + + # We only shuffle around the lr when it is a Tensor and on CUDA, otherwise, we prefer + # treating it as a scalar. + lr_dict: Optional[DeviceDict] = ( + {lr.device: lr} if isinstance(lr, Tensor) and str(lr.device) != "cpu" else None + ) + grouped_tensors = Optimizer._group_tensors_by_device_and_dtype( + [params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps] + ) + for (device, _), ( + ( + device_params, + device_grads, + device_exp_avgs, + device_exp_avg_sqs, + device_max_exp_avg_sqs, + device_state_steps, + ), + _, + ) in grouped_tensors.items(): + if device.type == "mps": # type: ignore[union-attr] + assert found_inf is None and grad_scale is None + + device_grad_scale, device_found_inf = None, None + if grad_scale is not None: + device_grad_scale = grad_scale_dict.setdefault( + device, grad_scale.to(device, non_blocking=True) + ) + if found_inf is not None: + device_found_inf = found_inf_dict.setdefault( + device, found_inf.to(device, non_blocking=True) + ) + if lr_dict is not None and device not in lr_dict: + lr_dict[device] = lr.to(device=device, non_blocking=True) # type: ignore[union-attr] + lr = lr_dict[device] + torch._foreach_add_(device_state_steps, 1) + torch._fused_adam_( + device_params, + device_grads, + device_exp_avgs, + device_exp_avg_sqs, + device_max_exp_avg_sqs, + device_state_steps, + amsgrad=amsgrad, + lr=lr, + beta1=beta1, + beta2=beta2, + weight_decay=weight_decay, + eps=eps, + maximize=maximize, + grad_scale=device_grad_scale, + found_inf=device_found_inf, + ) + if device_found_inf is not None: + torch._foreach_sub_( + device_state_steps, [device_found_inf] * len(device_state_steps) + ) + + +@_disable_dynamo_if_unsupported(single_tensor_fn=_single_tensor_adam) +def adam( + params: List[Tensor], + grads: List[Tensor], + exp_avgs: List[Tensor], + exp_avg_sqs: List[Tensor], + max_exp_avg_sqs: List[Tensor], + state_steps: List[Tensor], + # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627 + # setting this as kwarg for now as functional API is compiled by torch/distributed/optim + foreach: Optional[bool] = None, + capturable: bool = False, + differentiable: bool = False, + fused: Optional[bool] = None, + grad_scale: Optional[Tensor] = None, + found_inf: Optional[Tensor] = None, + has_complex: bool = False, + 
*, + amsgrad: bool, + beta1: float, + beta2: float, + lr: Union[float, Tensor], + weight_decay: float, + eps: float, + maximize: bool, +): + r"""Functional API that performs Adam algorithm computation. + + See :class:`~torch.optim.Adam` for details. + """ + # Respect when the user inputs False/True for foreach or fused. We only want to change + # the default when neither have been user-specified. Note that we default to foreach + # and pass False to use_fused. This is not a mistake--we want to give the fused impl + # bake-in time before making it the default, even if it is typically faster. + if fused is None and foreach is None: + _, foreach = _default_to_fused_or_foreach( + params, differentiable, use_fused=False + ) + # Do not flip on foreach for the unsupported case where lr is a Tensor and capturable=False. + if foreach and isinstance(lr, Tensor) and not capturable: + foreach = False + if fused is None: + fused = False + if foreach is None: + foreach = False + + # this check is slow during compilation, so we skip it + # if it's strictly needed we can add this check back in dynamo + if not torch._utils.is_compiling() and not all( + isinstance(t, torch.Tensor) for t in state_steps + ): + raise RuntimeError( + "API has changed, `state_steps` argument must contain a list of singleton tensors" + ) + + if foreach and torch.jit.is_scripting(): + raise RuntimeError("torch.jit.script not supported with foreach optimizers") + if fused and torch.jit.is_scripting(): + raise RuntimeError("torch.jit.script not supported with fused optimizers") + + if fused and not torch.jit.is_scripting(): + func = _fused_adam + elif foreach and not torch.jit.is_scripting(): + func = _multi_tensor_adam + else: + func = _single_tensor_adam + + func( + params, + grads, + exp_avgs, + exp_avg_sqs, + max_exp_avg_sqs, + state_steps, + amsgrad=amsgrad, + has_complex=has_complex, + beta1=beta1, + beta2=beta2, + lr=lr, + weight_decay=weight_decay, + eps=eps, + maximize=maximize, + capturable=capturable, + differentiable=differentiable, + grad_scale=grad_scale, + found_inf=found_inf, + ) diff --git a/engine/cr_boosters/adamax.py b/engine/cr_boosters/adamax.py new file mode 100644 index 0000000..27caa5f --- /dev/null +++ b/engine/cr_boosters/adamax.py @@ -0,0 +1,463 @@ +# mypy: allow-untyped-defs +from typing import List, Optional, Tuple, Union + +import torch +from torch import Tensor + +from .optimizer import ( + _capturable_doc, + _default_to_fused_or_foreach, + _differentiable_doc, + _disable_dynamo_if_unsupported, + _foreach_doc, + _get_capturable_supported_devices, + _get_scalar_dtype, + _get_value, + _maximize_doc, + _use_grad_for_differentiable, + _view_as_real, + Optimizer, + ParamsT, +) + +__all__ = ["Adamax", "adamax"] + + +class Adamax(Optimizer): + def __init__( + self, + params: ParamsT, + lr: float = 2e-3, + betas: Tuple[float, float] = (0.9, 0.999), + eps: float = 1e-8, + weight_decay: float = 0, + foreach: Optional[bool] = None, + *, + maximize: bool = False, + differentiable: bool = False, + capturable: bool = False, + ): + if not 0.0 <= lr: + raise ValueError(f"Invalid learning rate: {lr}") + if not 0.0 <= eps: + raise ValueError(f"Invalid epsilon value: {eps}") + if not 0.0 <= betas[0] < 1.0: + raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}") + if not 0.0 <= betas[1] < 1.0: + raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}") + if not 0.0 <= weight_decay: + raise ValueError(f"Invalid weight_decay value: {weight_decay}") + + defaults = dict( + lr=lr, + betas=betas, + 
eps=eps, + weight_decay=weight_decay, + foreach=foreach, + maximize=maximize, + differentiable=differentiable, + capturable=capturable, + ) + super().__init__(params, defaults) + + def __setstate__(self, state): + super().__setstate__(state) + for group in self.param_groups: + group.setdefault("foreach", None) + group.setdefault("maximize", False) + group.setdefault("differentiable", False) + group.setdefault("capturable", False) + for p in group["params"]: + p_state = self.state.get(p, []) + if len(p_state) != 0 and not torch.is_tensor(p_state["step"]): + step_val = float(p_state["step"]) + p_state["step"] = ( + torch.tensor( + step_val, dtype=_get_scalar_dtype(), device=p.device + ) + if group["capturable"] + else torch.tensor(step_val, dtype=_get_scalar_dtype()) + ) + + def _init_group( + self, group, params_with_grad, grads, exp_avgs, exp_infs, state_steps + ): + has_complex = False + for p in group["params"]: + if p.grad is None: + continue + has_complex |= torch.is_complex(p) + params_with_grad.append(p) + if p.grad.is_sparse: + raise RuntimeError("Adamax does not support sparse gradients") + grads.append(p.grad) + + state = self.state[p] + + # State initialization + if len(state) == 0: + state["step"] = ( + torch.zeros((), dtype=_get_scalar_dtype(), device=p.device) + if group["capturable"] + else torch.tensor(0.0, dtype=_get_scalar_dtype()) + ) + state["exp_avg"] = torch.zeros_like( + p, memory_format=torch.preserve_format + ) + state["exp_inf"] = torch.zeros_like( + p, memory_format=torch.preserve_format + ) + + exp_avgs.append(state["exp_avg"]) + exp_infs.append(state["exp_inf"]) + state_steps.append(state["step"]) + + return has_complex + + @_use_grad_for_differentiable + def step(self, closure=None): + """Performs a single optimization step. + + Args: + closure (Callable, optional): A closure that reevaluates the model + and returns the loss. + """ + self._cuda_graph_capture_health_check() + + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + params_with_grad: List[Tensor] = [] + grads: List[Tensor] = [] + exp_avgs: List[Tensor] = [] + exp_infs: List[Tensor] = [] + state_steps: List[Tensor] = [] + + beta1, beta2 = group["betas"] + eps = group["eps"] + lr = group["lr"] + weight_decay = group["weight_decay"] + foreach = group["foreach"] + maximize = group["maximize"] + differentiable = group["differentiable"] + capturable = group["capturable"] + + has_complex = self._init_group( + group, params_with_grad, grads, exp_avgs, exp_infs, state_steps + ) + + adamax( + params_with_grad, + grads, + exp_avgs, + exp_infs, + state_steps, + eps=eps, + beta1=beta1, + beta2=beta2, + lr=lr, + weight_decay=weight_decay, + foreach=foreach, + maximize=maximize, + differentiable=differentiable, + capturable=capturable, + has_complex=has_complex, + ) + + return loss + + +Adamax.__doc__ = ( + r"""Implements Adamax algorithm (a variant of Adam based on infinity norm). + + .. 
math:: + \begin{aligned} + &\rule{110mm}{0.4pt} \\ + &\textbf{input} : \gamma \text{ (lr)}, \beta_1, \beta_2 + \text{ (betas)},\theta_0 \text{ (params)},f(\theta) \text{ (objective)}, + \: \lambda \text{ (weight decay)}, \\ + &\hspace{13mm} \epsilon \text{ (epsilon)} \\ + &\textbf{initialize} : m_0 \leftarrow 0 \text{ ( first moment)}, + u_0 \leftarrow 0 \text{ ( infinity norm)} \\[-1.ex] + &\rule{110mm}{0.4pt} \\ + &\textbf{for} \: t=1 \: \textbf{to} \: \ldots \: \textbf{do} \\ + &\hspace{5mm}g_t \leftarrow \nabla_{\theta} f_t (\theta_{t-1}) \\ + &\hspace{5mm}if \: \lambda \neq 0 \\ + &\hspace{10mm} g_t \leftarrow g_t + \lambda \theta_{t-1} \\ + &\hspace{5mm}m_t \leftarrow \beta_1 m_{t-1} + (1 - \beta_1) g_t \\ + &\hspace{5mm}u_t \leftarrow \mathrm{max}(\beta_2 u_{t-1}, |g_{t}|+\epsilon) \\ + &\hspace{5mm}\theta_t \leftarrow \theta_{t-1} - \frac{\gamma m_t}{(1-\beta^t_1) u_t} \\ + &\rule{110mm}{0.4pt} \\[-1.ex] + &\bf{return} \: \theta_t \\[-1.ex] + &\rule{110mm}{0.4pt} \\[-1.ex] + \end{aligned} + + For further details regarding the algorithm we refer to `Adam: A Method for Stochastic Optimization`_. + """ + + rf""" + Args: + params (iterable): iterable of parameters to optimize or dicts defining + parameter groups + lr (float, optional): learning rate (default: 2e-3) + betas (Tuple[float, float], optional): coefficients used for computing + running averages of gradient and its square + eps (float, optional): term added to the denominator to improve + numerical stability (default: 1e-8) + weight_decay (float, optional): weight decay (L2 penalty) (default: 0) + {_foreach_doc} + {_maximize_doc} + {_differentiable_doc} + {_capturable_doc} + + .. _Adam\: A Method for Stochastic Optimization: + https://arxiv.org/abs/1412.6980 + + """ +) + + +def _single_tensor_adamax( + params: List[Tensor], + grads: List[Tensor], + exp_avgs: List[Tensor], + exp_infs: List[Tensor], + state_steps: List[Tensor], + *, + eps: float, + beta1: float, + beta2: float, + lr: float, + weight_decay: float, + maximize: bool, + differentiable: bool, + capturable: bool, + has_complex: bool, +): + for i, param in enumerate(params): + grad = grads[i] + grad = grad if not maximize else -grad + exp_avg = exp_avgs[i] + exp_inf = exp_infs[i] + step_t = state_steps[i] + + # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable] + if not torch._utils.is_compiling() and capturable: + capturable_supported_devices = _get_capturable_supported_devices() + assert ( + param.device.type == step_t.device.type + and param.device.type in capturable_supported_devices + ), f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}." + + # update step + step_t += 1 + + if weight_decay != 0: + grad = grad.add(param, alpha=weight_decay) + + if torch.is_complex(param): + param = torch.view_as_real(param) + grad = torch.view_as_real(grad) + exp_avg = torch.view_as_real(exp_avg) + exp_inf = torch.view_as_real(exp_inf) + + # Update biased first moment estimate. + exp_avg.lerp_(grad, 1 - beta1) + # Update the exponentially weighted infinity norm. + if not differentiable: + torch.maximum( + exp_inf.mul_(beta2), + grad.abs().add_(eps), + out=exp_inf, + ) + else: + norm_buf = torch.cat( + [exp_inf.mul_(beta2).unsqueeze(0), grad.abs().add_(eps).unsqueeze_(0)], + 0, + ) + exp_inf.copy_(torch.amax(norm_buf, 0, keepdim=False)) + + if capturable: + # why jump through extra hoops and negate bias_correction? 
check out #121238 + # once fixed, we should use bias_correction with addcdiv value=-1 for readability + neg_bias_correction = beta1**step_t - 1 + neg_bias_correction.div_(lr) + denom = exp_inf * neg_bias_correction + param.addcdiv_(exp_avg, denom) + else: + bias_correction = 1 - beta1 ** _get_value(step_t) + clr = lr / bias_correction + + param.addcdiv_(exp_avg, exp_inf, value=-clr) + + +def _multi_tensor_adamax( + params: List[Tensor], + grads: List[Tensor], + exp_avgs: List[Tensor], + exp_infs: List[Tensor], + state_steps: List[Tensor], + *, + eps: float, + beta1: float, + beta2: float, + lr: float, + weight_decay: float, + maximize: bool, + differentiable: bool, + capturable: bool, + has_complex: bool, +): + assert not differentiable, "_foreach ops don't support autograd" + + if len(params) == 0: + return + + # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable] + if not torch._utils.is_compiling() and capturable: + capturable_supported_devices = _get_capturable_supported_devices( + supports_xla=False + ) + assert all( + p.device.type == step.device.type + and p.device.type in capturable_supported_devices + for p, step in zip(params, state_steps) + ), f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}." + + grouped_tensors = Optimizer._group_tensors_by_device_and_dtype( + [params, grads, exp_avgs, exp_infs, state_steps] + ) + for ( + grouped_params, + grouped_grads, + grouped_exp_avgs, + grouped_exp_infs, + grouped_state_steps, + ), _ in grouped_tensors.values(): + if has_complex: + _view_as_real( + grouped_params, grouped_grads, grouped_exp_avgs, grouped_exp_infs + ) + + if maximize: + grouped_grads = torch._foreach_neg(grouped_grads) # type: ignore[assignment] + + # Update steps + # If steps are on CPU, foreach will fall back to the slow path, which is a for-loop calling t.add(1) over + # and over. 1 will then be wrapped into a Tensor over and over again, which is slower than if we just + # wrapped it once now. The alpha is required to assure we go to the right overload. + if grouped_state_steps[0].is_cpu: + torch._foreach_add_( + grouped_state_steps, torch.tensor(1.0, device="cpu"), alpha=1.0 + ) + else: + torch._foreach_add_(grouped_state_steps, 1) + + if weight_decay != 0: + if maximize: + # Re-use the intermediate memory (grouped_grads) already allocated for maximize + torch._foreach_add_(grouped_grads, grouped_params, alpha=weight_decay) + else: + grouped_grads = torch._foreach_add( # type: ignore[assignment] + grouped_grads, grouped_params, alpha=weight_decay + ) + + # Update biased first moment estimate. + torch._foreach_lerp_(grouped_exp_avgs, grouped_grads, 1 - beta1) + + # Update the exponentially weighted infinity norm. 
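+        # In scalar form this is u_t = max(beta2 * u_{t-1}, |g_t| + eps), mirroring the
+        # single-tensor path: an in-place scale of exp_inf by beta2 followed by an
+        # elementwise maximum against |grad| + eps.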
+ torch._foreach_mul_(grouped_exp_infs, beta2) + + # in this case, we need to introduce a copy of the grads + # since one has not been introduced previously + if not maximize and weight_decay == 0: + grouped_grads = torch._foreach_abs(grouped_grads) # type: ignore[assignment] + else: + torch._foreach_abs_(grouped_grads) + + torch._foreach_add_(grouped_grads, eps) + torch._foreach_maximum_(grouped_exp_infs, grouped_grads) + + bias_corrections: Union[Tuple[Tensor, ...], List[Tensor]] + if capturable: + bias_corrections = torch._foreach_pow(beta1, grouped_state_steps) + # foreach_sub doesn't allow a scalar as the first arg + torch._foreach_sub_(bias_corrections, 1) + torch._foreach_div_(bias_corrections, lr) + + denom = torch._foreach_mul(grouped_exp_infs, bias_corrections) + torch._foreach_addcdiv_(grouped_params, grouped_exp_avgs, denom) + else: + bias_corrections = [ + 1 - beta1 ** _get_value(step) for step in grouped_state_steps + ] + step_size = [(_get_value(lr) / bc) * -1 for bc in bias_corrections] + torch._foreach_addcdiv_( + grouped_params, grouped_exp_avgs, grouped_exp_infs, step_size + ) + + +@_disable_dynamo_if_unsupported(single_tensor_fn=_single_tensor_adamax) +def adamax( + params: List[Tensor], + grads: List[Tensor], + exp_avgs: List[Tensor], + exp_infs: List[Tensor], + state_steps: List[Tensor], + # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627 + # setting this as kwarg for now as functional API is compiled by torch/distributed/optim + foreach: Optional[bool] = None, + maximize: bool = False, + differentiable: bool = False, + capturable: bool = False, + has_complex: bool = False, + *, + eps: float, + beta1: float, + beta2: float, + lr: float, + weight_decay: float, +): + r"""Functional API that performs adamax algorithm computation. + + See :class:`~torch.optim.Adamax` for details. 
+ """ + + if not torch._utils.is_compiling() and not all( + isinstance(t, torch.Tensor) for t in state_steps + ): + raise RuntimeError( + "API has changed, `state_steps` argument must contain a list of singleton tensors" + ) + + if foreach is None: + _, foreach = _default_to_fused_or_foreach( + params, differentiable, use_fused=False + ) + + if foreach and torch.jit.is_scripting(): + raise RuntimeError("torch.jit.script not supported with foreach optimizers") + + if foreach and not torch.jit.is_scripting(): + func = _multi_tensor_adamax + else: + func = _single_tensor_adamax + + func( + params, + grads, + exp_avgs, + exp_infs, + state_steps, + eps=eps, + beta1=beta1, + beta2=beta2, + lr=lr, + weight_decay=weight_decay, + maximize=maximize, + differentiable=differentiable, + has_complex=has_complex, + capturable=capturable, + ) diff --git a/engine/cr_boosters/asgd.py b/engine/cr_boosters/asgd.py new file mode 100644 index 0000000..84c7602 --- /dev/null +++ b/engine/cr_boosters/asgd.py @@ -0,0 +1,454 @@ +# mypy: allow-untyped-defs +from typing import List, Optional, Tuple, Union + +import torch +from torch import Tensor + +from .optimizer import ( + _capturable_doc, + _default_to_fused_or_foreach, + _differentiable_doc, + _disable_dynamo_if_unsupported, + _foreach_doc, + _get_capturable_supported_devices, + _get_scalar_dtype, + _get_value, + _maximize_doc, + _use_grad_for_differentiable, + _view_as_real, + Optimizer, + ParamsT, +) + +__all__ = ["ASGD", "asgd"] + + +class ASGD(Optimizer): + def __init__( + self, + params: ParamsT, + lr: float = 1e-2, + lambd: float = 1e-4, + alpha: float = 0.75, + t0: float = 1e6, + weight_decay: float = 0, + foreach: Optional[bool] = None, + maximize: bool = False, + differentiable: bool = False, + capturable: bool = False, + ): + if not 0.0 <= lr: + raise ValueError(f"Invalid learning rate: {lr}") + if not 0.0 <= weight_decay: + raise ValueError(f"Invalid weight_decay value: {weight_decay}") + + defaults = dict( + lr=lr, + lambd=lambd, + alpha=alpha, + t0=t0, + weight_decay=weight_decay, + foreach=foreach, + maximize=maximize, + differentiable=differentiable, + capturable=capturable, + ) + super().__init__(params, defaults) + + def __setstate__(self, state): + super().__setstate__(state) + for group in self.param_groups: + group.setdefault("foreach", None) + group.setdefault("maximize", False) + group.setdefault("differentiable", False) + group.setdefault("capturable", False) + for p in group["params"]: + p_state = self.state.get(p, []) + if len(p_state) != 0: + if not torch.is_tensor(p_state["step"]): + step_val = float(p_state["step"]) + p_state["step"] = torch.tensor( + step_val, dtype=_get_scalar_dtype(), device=p.device + ) + if not torch.is_tensor(p_state["eta"]): + p_state["eta"] = torch.tensor( + p_state["eta"], dtype=_get_scalar_dtype(), device=p.device + ) + if not torch.is_tensor(p_state["mu"]): + p_state["mu"] = torch.tensor( + p_state["mu"], dtype=_get_scalar_dtype(), device=p.device + ) + + def _init_group(self, group, params_with_grad, grads, mus, axs, etas, state_steps): + has_complex = False + for p in group["params"]: + if p.grad is not None: + has_complex |= torch.is_complex(p) + params_with_grad.append(p) + if p.grad.is_sparse: + raise RuntimeError("ASGD does not support sparse gradients") + grads.append(p.grad) + + state = self.state[p] + # State initialization + if len(state) == 0: + state["step"] = torch.zeros( + (), device=p.device, dtype=_get_scalar_dtype() + ) + state["eta"] = ( + torch.as_tensor( + group["lr"], device=p.device, 
dtype=_get_scalar_dtype() + ) + .clone() + .detach() + ) + state["mu"] = torch.ones( + (), device=p.device, dtype=_get_scalar_dtype() + ) + state["ax"] = torch.zeros_like( + p, memory_format=torch.preserve_format + ) + + mus.append(state["mu"]) + axs.append(state["ax"]) + etas.append(state["eta"]) + state_steps.append(state["step"]) + return has_complex + + @_use_grad_for_differentiable + def step(self, closure=None): + """Perform a single optimization step. + + Args: + closure (Callable, optional): A closure that reevaluates the model + and returns the loss. + """ + self._cuda_graph_capture_health_check() + + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + params_with_grad: List[Tensor] = [] + grads: List[Tensor] = [] + mus: List[Tensor] = [] + axs: List[Tensor] = [] + etas: List[Tensor] = [] + state_steps: List[Tensor] = [] + + has_complex = self._init_group( + group, params_with_grad, grads, mus, axs, etas, state_steps + ) + + asgd( + params_with_grad, + grads, + axs, + mus, + etas, + state_steps, + lambd=group["lambd"], + lr=group["lr"], + t0=group["t0"], + alpha=group["alpha"], + weight_decay=group["weight_decay"], + foreach=group["foreach"], + maximize=group["maximize"], + differentiable=group["differentiable"], + capturable=group["capturable"], + has_complex=has_complex, + ) + + return loss + + +ASGD.__doc__ = rf"""Implements Averaged Stochastic Gradient Descent. + + It has been proposed in `Acceleration of stochastic approximation by + averaging`_. + + Args: + params (iterable): iterable of parameters to optimize or dicts defining + parameter groups + lr (float, optional): learning rate (default: 1e-2) + lambd (float, optional): decay term (default: 1e-4) + alpha (float, optional): power for eta update (default: 0.75) + t0 (float, optional): point at which to start averaging (default: 1e6) + weight_decay (float, optional): weight decay (L2 penalty) (default: 0) + {_foreach_doc} + {_maximize_doc} + {_differentiable_doc} + {_capturable_doc} + + .. _Acceleration of stochastic approximation by averaging: + https://dl.acm.org/citation.cfm?id=131098 + + """ + + +def _single_tensor_asgd( + params: List[Tensor], + grads: List[Tensor], + axs: List[Tensor], + mus: List[Tensor], + etas: List[Tensor], + state_steps: List[Tensor], + *, + lambd: float, + lr: float, + t0: float, + alpha: float, + weight_decay: float, + maximize: bool, + differentiable: bool, + capturable: bool, + has_complex: bool, +): + for i, param in enumerate(params): + grad = grads[i] + grad = grad if not maximize else -grad + mu = mus[i] + ax = axs[i] + eta = etas[i] + step_t = state_steps[i] + + # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable] + if not torch._utils.is_compiling() and capturable: + capturable_supported_devices = _get_capturable_supported_devices() + assert ( + param.device.type + == mu.device.type + == eta.device.type + == step_t.device.type + and param.device.type in capturable_supported_devices + ), ( + f"If capturable=True, params, mus, etas, and state_steps must be " + f"on supported devices: {capturable_supported_devices}." 
+ ) + + if torch.is_complex(param): + grad = torch.view_as_real(grad) + param = torch.view_as_real(param) + ax = torch.view_as_real(ax) + + # update step + step_t += 1 + + if weight_decay != 0: + grad = grad.add(param, alpha=weight_decay) + + if capturable: + param.mul_(1 - lambd * eta) + param.addcmul_(grad, eta, value=-1) # update parameter + else: + eta_value = _get_value(eta) + param.mul_(1 - lambd * eta_value) # decay term + param.add_(grad, alpha=-eta_value) # update parameter + + # averaging + if capturable or mu.item() != 1: + ax.add_(param.sub(ax).mul_(mu)) + else: + ax.copy_(param) + + if capturable: + eta.copy_(lr / ((1 + lambd * lr * step_t) ** alpha)) + mu.copy_(1 / torch.maximum(step_t - t0, torch.ones_like(step_t))) + else: + step = _get_value(step_t) + new_eta = torch.as_tensor(lr / ((1 + lambd * lr * step) ** alpha)) + eta.copy_(new_eta) + new_mu = torch.as_tensor(1 / max(1, step - t0)) + mu.copy_(new_mu) + + +def _multi_tensor_asgd( + params: List[Tensor], + grads: List[Tensor], + axs: List[Tensor], + mus: List[Tensor], + etas: List[Tensor], + state_steps: List[Tensor], + *, + lambd: float, + lr: float, + t0: float, + alpha: float, + weight_decay: float, + maximize: bool, + differentiable: bool, + capturable: bool, + has_complex: bool, +): + if len(params) == 0: + return + + assert not differentiable, "_foreach ops don't support autograd" + + # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable] + if not torch._utils.is_compiling() and capturable: + capturable_supported_devices = _get_capturable_supported_devices( + supports_xla=False + ) + assert all( + p.device.type == mu.device.type == eta.device.type == step.device.type + and p.device.type in capturable_supported_devices + for p, mu, eta, step in zip(params, mus, etas, state_steps) + ), f"If capturable=True, params, mus, etas, and state_steps must be on supported devices: {capturable_supported_devices}." + + grouped_tensors = Optimizer._group_tensors_by_device_and_dtype( + [params, grads, axs, mus, etas, state_steps] + ) + for (device, _), ( + ( + grouped_params, + grouped_grads, + grouped_axs, + grouped_mus, + grouped_etas, + grouped_state_steps, + ), + _, + ) in grouped_tensors.items(): + if has_complex: + _view_as_real(grouped_params, grouped_grads, grouped_axs) + + if maximize: + grouped_grads = torch._foreach_neg(grouped_grads) # type: ignore[assignment] + + # Update steps + # If steps are on CPU, foreach will fall back to the slow path, which is a for-loop calling t.add(1) over + # and over. 1 will then be wrapped into a Tensor over and over again, which is slower than if we just + # wrapped it once now. The alpha is required to assure we go to the right overload. 
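+        # The step counters are bumped first because eta and mu are recomputed from the
+        # updated step further below: eta_t = lr / (1 + lambd * lr * t) ** alpha and
+        # mu_t = 1 / max(1, t - t0), matching the single-tensor implementation.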
+ if grouped_state_steps[0].is_cpu: + torch._foreach_add_( + grouped_state_steps, torch.tensor(1.0, device="cpu"), alpha=1.0 + ) + else: + torch._foreach_add_(grouped_state_steps, 1) + + # intermediate = grad + param * lambd + intermediate: Union[Tuple[Tensor, ...], List[Tensor]] + if weight_decay != 0: + if maximize: + torch._foreach_add_(grouped_grads, grouped_params, alpha=weight_decay) + intermediate = grouped_grads + else: + intermediate = torch._foreach_add( + grouped_grads, grouped_params, alpha=weight_decay + ) + + torch._foreach_add_(intermediate, grouped_params, alpha=lambd) + else: + intermediate = torch._foreach_add( + grouped_grads, grouped_params, alpha=lambd + ) + + # update param + # param * (1 - lambd * eta) - eta * grad + # => param - param * lambd * eta - eta * grad + # => param - eta * intermediate + torch._foreach_addcmul_(grouped_params, intermediate, grouped_etas, value=-1) + del intermediate + + # update grouped_axs + # averaging: ax = ax + mu * (param - ax) + # Note (mlazos): We can't use lerp here since it requires weight to be float64 + # and our grouping code requires dtypes to match for all tensors in a group (and it should, since + # we use the mus in other places) + # all dtypes need to match, so we could introduce a cast in a loop + # but since this only adds one additional kernel launch, this looks like the cleaner + # and faster solution + intermediate = torch._foreach_sub(grouped_params, grouped_axs) + torch._foreach_addcmul_(grouped_axs, intermediate, grouped_mus) + del intermediate + + new_etas: Union[Tuple[Tensor, ...], List[Tensor]] + new_mus: Union[Tuple[Tensor, ...], List[Tensor]] + if capturable: + # update grouped_mus + new_mus = torch._foreach_sub(grouped_state_steps, t0) + torch._foreach_maximum_(new_mus, 1.0) + torch._foreach_reciprocal_(new_mus) + torch._foreach_copy_(grouped_mus, new_mus) + del new_mus + + # update eta = lr / ((1 + lambd * lr * step)^alpha) + new_etas = torch._foreach_mul(grouped_state_steps, lambd) + torch._foreach_mul_(new_etas, lr) + torch._foreach_add_(new_etas, 1) + torch._foreach_pow_(new_etas, alpha) + torch._foreach_reciprocal_(new_etas) + torch._foreach_mul_(new_etas, lr) + torch._foreach_copy_(grouped_etas, new_etas) + else: + new_etas = [ + torch.as_tensor(lr / ((1 + lambd * lr * step) ** alpha), device=device) + for step in grouped_state_steps + ] + new_mus = [ + torch.as_tensor(1 / max(1, _get_value(step) - t0), device=device) + for step in grouped_state_steps + ] + torch._foreach_copy_(grouped_etas, new_etas) + torch._foreach_copy_(grouped_mus, new_mus) + + +@_disable_dynamo_if_unsupported(single_tensor_fn=_single_tensor_asgd) +def asgd( + params: List[Tensor], + grads: List[Tensor], + axs: List[Tensor], + mus: List[Tensor], + etas: List[Tensor], + state_steps: List[Tensor], + # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627 + # setting this as kwarg for now as functional API is compiled by torch/distributed/optim + foreach: Optional[bool] = None, + maximize: bool = False, + differentiable: bool = False, + capturable: bool = False, + has_complex: bool = False, + *, + lambd: float, + lr: float, + t0: float, + alpha: float, + weight_decay: float, +): + r"""Functional API that performs asgd algorithm computation. + + See :class:`~torch.optim.ASGD` for details. 
+ """ + if foreach is None: + _, foreach = _default_to_fused_or_foreach( + params, differentiable, use_fused=False + ) + + if foreach and torch.jit.is_scripting(): + raise RuntimeError("torch.jit.script not supported with foreach optimizers") + + if foreach and not torch.jit.is_scripting(): + func = _multi_tensor_asgd + else: + func = _single_tensor_asgd + + func( + params, + grads, + axs, + mus, + etas, + state_steps, + lambd=lambd, + lr=lr, + t0=t0, + alpha=alpha, + weight_decay=weight_decay, + maximize=maximize, + differentiable=differentiable, + capturable=capturable, + has_complex=has_complex, + ) diff --git a/engine/cr_boosters/lbfgs.py b/engine/cr_boosters/lbfgs.py new file mode 100644 index 0000000..480b45c --- /dev/null +++ b/engine/cr_boosters/lbfgs.py @@ -0,0 +1,488 @@ +# mypy: allow-untyped-defs +from typing import Optional + +import torch +from .optimizer import Optimizer, ParamsT + +__all__ = ["LBFGS"] + + +def _cubic_interpolate(x1, f1, g1, x2, f2, g2, bounds=None): + # ported from https://github.com/torch/optim/blob/master/polyinterp.lua + # Compute bounds of interpolation area + if bounds is not None: + xmin_bound, xmax_bound = bounds + else: + xmin_bound, xmax_bound = (x1, x2) if x1 <= x2 else (x2, x1) + + # Code for most common case: cubic interpolation of 2 points + # w/ function and derivative values for both + # Solution in this case (where x2 is the farthest point): + # d1 = g1 + g2 - 3*(f1-f2)/(x1-x2); + # d2 = sqrt(d1^2 - g1*g2); + # min_pos = x2 - (x2 - x1)*((g2 + d2 - d1)/(g2 - g1 + 2*d2)); + # t_new = min(max(min_pos,xmin_bound),xmax_bound); + d1 = g1 + g2 - 3 * (f1 - f2) / (x1 - x2) + d2_square = d1**2 - g1 * g2 + if d2_square >= 0: + d2 = d2_square.sqrt() + if x1 <= x2: + min_pos = x2 - (x2 - x1) * ((g2 + d2 - d1) / (g2 - g1 + 2 * d2)) + else: + min_pos = x1 - (x1 - x2) * ((g1 + d2 - d1) / (g1 - g2 + 2 * d2)) + return min(max(min_pos, xmin_bound), xmax_bound) + else: + return (xmin_bound + xmax_bound) / 2.0 + + +def _strong_wolfe( + obj_func, x, t, d, f, g, gtd, c1=1e-4, c2=0.9, tolerance_change=1e-9, max_ls=25 +): + # ported from https://github.com/torch/optim/blob/master/lswolfe.lua + d_norm = d.abs().max() + g = g.clone(memory_format=torch.contiguous_format) + # evaluate objective and gradient using initial step + f_new, g_new = obj_func(x, t, d) + ls_func_evals = 1 + gtd_new = g_new.dot(d) + + # bracket an interval containing a point satisfying the Wolfe criteria + t_prev, f_prev, g_prev, gtd_prev = 0, f, g, gtd + done = False + ls_iter = 0 + while ls_iter < max_ls: + # check conditions + if f_new > (f + c1 * t * gtd) or (ls_iter > 1 and f_new >= f_prev): + bracket = [t_prev, t] + bracket_f = [f_prev, f_new] + bracket_g = [g_prev, g_new.clone(memory_format=torch.contiguous_format)] + bracket_gtd = [gtd_prev, gtd_new] + break + + if abs(gtd_new) <= -c2 * gtd: + bracket = [t] + bracket_f = [f_new] + bracket_g = [g_new] + done = True + break + + if gtd_new >= 0: + bracket = [t_prev, t] + bracket_f = [f_prev, f_new] + bracket_g = [g_prev, g_new.clone(memory_format=torch.contiguous_format)] + bracket_gtd = [gtd_prev, gtd_new] + break + + # interpolate + min_step = t + 0.01 * (t - t_prev) + max_step = t * 10 + tmp = t + t = _cubic_interpolate( + t_prev, f_prev, gtd_prev, t, f_new, gtd_new, bounds=(min_step, max_step) + ) + + # next step + t_prev = tmp + f_prev = f_new + g_prev = g_new.clone(memory_format=torch.contiguous_format) + gtd_prev = gtd_new + f_new, g_new = obj_func(x, t, d) + ls_func_evals += 1 + gtd_new = g_new.dot(d) + ls_iter += 1 + + # reached max 
number of iterations? + if ls_iter == max_ls: + bracket = [0, t] + bracket_f = [f, f_new] + bracket_g = [g, g_new] + + # zoom phase: we now have a point satisfying the criteria, or + # a bracket around it. We refine the bracket until we find the + # exact point satisfying the criteria + insuf_progress = False + # find high and low points in bracket + low_pos, high_pos = (0, 1) if bracket_f[0] <= bracket_f[-1] else (1, 0) # type: ignore[possibly-undefined] + while not done and ls_iter < max_ls: + # line-search bracket is so small + if abs(bracket[1] - bracket[0]) * d_norm < tolerance_change: # type: ignore[possibly-undefined] + break + + # compute new trial value + t = _cubic_interpolate( + bracket[0], + bracket_f[0], + bracket_gtd[0], # type: ignore[possibly-undefined] + bracket[1], + bracket_f[1], + bracket_gtd[1], + ) + + # test that we are making sufficient progress: + # in case `t` is so close to boundary, we mark that we are making + # insufficient progress, and if + # + we have made insufficient progress in the last step, or + # + `t` is at one of the boundary, + # we will move `t` to a position which is `0.1 * len(bracket)` + # away from the nearest boundary point. + eps = 0.1 * (max(bracket) - min(bracket)) + if min(max(bracket) - t, t - min(bracket)) < eps: + # interpolation close to boundary + if insuf_progress or t >= max(bracket) or t <= min(bracket): + # evaluate at 0.1 away from boundary + if abs(t - max(bracket)) < abs(t - min(bracket)): + t = max(bracket) - eps + else: + t = min(bracket) + eps + insuf_progress = False + else: + insuf_progress = True + else: + insuf_progress = False + + # Evaluate new point + f_new, g_new = obj_func(x, t, d) + ls_func_evals += 1 + gtd_new = g_new.dot(d) + ls_iter += 1 + + if f_new > (f + c1 * t * gtd) or f_new >= bracket_f[low_pos]: + # Armijo condition not satisfied or not lower than lowest point + bracket[high_pos] = t + bracket_f[high_pos] = f_new + bracket_g[high_pos] = g_new.clone(memory_format=torch.contiguous_format) # type: ignore[possibly-undefined] + bracket_gtd[high_pos] = gtd_new + low_pos, high_pos = (0, 1) if bracket_f[0] <= bracket_f[1] else (1, 0) + else: + if abs(gtd_new) <= -c2 * gtd: + # Wolfe conditions satisfied + done = True + elif gtd_new * (bracket[high_pos] - bracket[low_pos]) >= 0: + # old high becomes new low + bracket[high_pos] = bracket[low_pos] + bracket_f[high_pos] = bracket_f[low_pos] + bracket_g[high_pos] = bracket_g[low_pos] # type: ignore[possibly-undefined] + bracket_gtd[high_pos] = bracket_gtd[low_pos] + + # new point becomes new low + bracket[low_pos] = t + bracket_f[low_pos] = f_new + bracket_g[low_pos] = g_new.clone(memory_format=torch.contiguous_format) # type: ignore[possibly-undefined] + bracket_gtd[low_pos] = gtd_new + + # return stuff + t = bracket[low_pos] # type: ignore[possibly-undefined] + f_new = bracket_f[low_pos] + g_new = bracket_g[low_pos] # type: ignore[possibly-undefined] + return f_new, g_new, t, ls_func_evals + + +class LBFGS(Optimizer): + """Implements L-BFGS algorithm. + + Heavily inspired by `minFunc + `_. + + .. warning:: + This optimizer doesn't support per-parameter options and parameter + groups (there can be only one). + + .. warning:: + Right now all parameters have to be on a single device. This will be + improved in the future. + + .. note:: + This is a very memory intensive optimizer (it requires additional + ``param_bytes * (history_size + 1)`` bytes). If it doesn't fit in memory + try reducing the history size, or use a different algorithm. 
+ + Args: + params (iterable): iterable of parameters to optimize. Parameters must be real. + lr (float): learning rate (default: 1) + max_iter (int): maximal number of iterations per optimization step + (default: 20) + max_eval (int): maximal number of function evaluations per optimization + step (default: max_iter * 1.25). + tolerance_grad (float): termination tolerance on first order optimality + (default: 1e-7). + tolerance_change (float): termination tolerance on function + value/parameter changes (default: 1e-9). + history_size (int): update history size (default: 100). + line_search_fn (str): either 'strong_wolfe' or None (default: None). + """ + + def __init__( + self, + params: ParamsT, + lr: float = 1, + max_iter: int = 20, + max_eval: Optional[int] = None, + tolerance_grad: float = 1e-7, + tolerance_change: float = 1e-9, + history_size: int = 100, + line_search_fn: Optional[str] = None, + ): + if max_eval is None: + max_eval = max_iter * 5 // 4 + defaults = dict( + lr=lr, + max_iter=max_iter, + max_eval=max_eval, + tolerance_grad=tolerance_grad, + tolerance_change=tolerance_change, + history_size=history_size, + line_search_fn=line_search_fn, + ) + super().__init__(params, defaults) + + if len(self.param_groups) != 1: + raise ValueError( + "LBFGS doesn't support per-parameter options " "(parameter groups)" + ) + + self._params = self.param_groups[0]["params"] + self._numel_cache = None + + def _numel(self): + if self._numel_cache is None: + self._numel_cache = sum( + 2 * p.numel() if torch.is_complex(p) else p.numel() + for p in self._params + ) + + return self._numel_cache + + def _gather_flat_grad(self): + views = [] + for p in self._params: + if p.grad is None: + view = p.new(p.numel()).zero_() + elif p.grad.is_sparse: + view = p.grad.to_dense().view(-1) + else: + view = p.grad.view(-1) + if torch.is_complex(view): + view = torch.view_as_real(view).view(-1) + views.append(view) + return torch.cat(views, 0) + + def _add_grad(self, step_size, update): + offset = 0 + for p in self._params: + if torch.is_complex(p): + p = torch.view_as_real(p) + numel = p.numel() + # view as to avoid deprecated pointwise semantics + p.add_(update[offset : offset + numel].view_as(p), alpha=step_size) + offset += numel + assert offset == self._numel() + + def _clone_param(self): + return [p.clone(memory_format=torch.contiguous_format) for p in self._params] + + def _set_param(self, params_data): + for p, pdata in zip(self._params, params_data): + p.copy_(pdata) + + def _directional_evaluate(self, closure, x, t, d): + self._add_grad(t, d) + loss = float(closure()) + flat_grad = self._gather_flat_grad() + self._set_param(x) + return loss, flat_grad + + @torch.no_grad() + def step(self, closure): + """Perform a single optimization step. + + Args: + closure (Callable): A closure that reevaluates the model + and returns the loss. 
+ """ + assert len(self.param_groups) == 1 + + # Make sure the closure is always called with grad enabled + closure = torch.enable_grad()(closure) + + group = self.param_groups[0] + lr = group["lr"] + max_iter = group["max_iter"] + max_eval = group["max_eval"] + tolerance_grad = group["tolerance_grad"] + tolerance_change = group["tolerance_change"] + line_search_fn = group["line_search_fn"] + history_size = group["history_size"] + + # NOTE: LBFGS has only global state, but we register it as state for + # the first param, because this helps with casting in load_state_dict + state = self.state[self._params[0]] + state.setdefault("func_evals", 0) + state.setdefault("n_iter", 0) + + # evaluate initial f(x) and df/dx + orig_loss = closure() + loss = float(orig_loss) + current_evals = 1 + state["func_evals"] += 1 + + flat_grad = self._gather_flat_grad() + opt_cond = flat_grad.abs().max() <= tolerance_grad + + # optimal condition + if opt_cond: + return orig_loss + + # tensors cached in state (for tracing) + d = state.get("d") + t = state.get("t") + old_dirs = state.get("old_dirs") + old_stps = state.get("old_stps") + ro = state.get("ro") + H_diag = state.get("H_diag") + prev_flat_grad = state.get("prev_flat_grad") + prev_loss = state.get("prev_loss") + + n_iter = 0 + # optimize for a max of max_iter iterations + while n_iter < max_iter: + # keep track of nb of iterations + n_iter += 1 + state["n_iter"] += 1 + + ############################################################ + # compute gradient descent direction + ############################################################ + if state["n_iter"] == 1: + d = flat_grad.neg() + old_dirs = [] + old_stps = [] + ro = [] + H_diag = 1 + else: + # do lbfgs update (update memory) + y = flat_grad.sub(prev_flat_grad) + s = d.mul(t) + ys = y.dot(s) # y*s + if ys > 1e-10: + # updating memory + if len(old_dirs) == history_size: + # shift history by one (limited-memory) + old_dirs.pop(0) + old_stps.pop(0) + ro.pop(0) + + # store new direction/step + old_dirs.append(y) + old_stps.append(s) + ro.append(1.0 / ys) + + # update scale of initial Hessian approximation + H_diag = ys / y.dot(y) # (y*y) + + # compute the approximate (L-BFGS) inverse Hessian + # multiplied by the gradient + num_old = len(old_dirs) + + if "al" not in state: + state["al"] = [None] * history_size + al = state["al"] + + # iteration in L-BFGS loop collapsed to use just one buffer + q = flat_grad.neg() + for i in range(num_old - 1, -1, -1): + al[i] = old_stps[i].dot(q) * ro[i] + q.add_(old_dirs[i], alpha=-al[i]) + + # multiply by initial Hessian + # r/d is the final direction + d = r = torch.mul(q, H_diag) + for i in range(num_old): + be_i = old_dirs[i].dot(r) * ro[i] + r.add_(old_stps[i], alpha=al[i] - be_i) + + if prev_flat_grad is None: + prev_flat_grad = flat_grad.clone(memory_format=torch.contiguous_format) + else: + prev_flat_grad.copy_(flat_grad) + prev_loss = loss + + ############################################################ + # compute step length + ############################################################ + # reset initial guess for step size + if state["n_iter"] == 1: + t = min(1.0, 1.0 / flat_grad.abs().sum()) * lr + else: + t = lr + + # directional derivative + gtd = flat_grad.dot(d) # g * d + + # directional derivative is below tolerance + if gtd > -tolerance_change: + break + + # optional line search: user function + ls_func_evals = 0 + if line_search_fn is not None: + # perform line search, using user function + if line_search_fn != "strong_wolfe": + raise RuntimeError("only 
'strong_wolfe' is supported") + else: + x_init = self._clone_param() + + def obj_func(x, t, d): + return self._directional_evaluate(closure, x, t, d) + + loss, flat_grad, t, ls_func_evals = _strong_wolfe( + obj_func, x_init, t, d, loss, flat_grad, gtd + ) + self._add_grad(t, d) + opt_cond = flat_grad.abs().max() <= tolerance_grad + else: + # no line search, simply move with fixed-step + self._add_grad(t, d) + if n_iter != max_iter: + # re-evaluate function only if not in last iteration + # the reason we do this: in a stochastic setting, + # no use to re-evaluate that function here + with torch.enable_grad(): + loss = float(closure()) + flat_grad = self._gather_flat_grad() + opt_cond = flat_grad.abs().max() <= tolerance_grad + ls_func_evals = 1 + + # update func eval + current_evals += ls_func_evals + state["func_evals"] += ls_func_evals + + ############################################################ + # check conditions + ############################################################ + if n_iter == max_iter: + break + + if current_evals >= max_eval: + break + + # optimal condition + if opt_cond: + break + + # lack of progress + if d.mul(t).abs().max() <= tolerance_change: + break + + if abs(loss - prev_loss) < tolerance_change: + break + + state["d"] = d + state["t"] = t + state["old_dirs"] = old_dirs + state["old_stps"] = old_stps + state["ro"] = ro + state["H_diag"] = H_diag + state["prev_flat_grad"] = prev_flat_grad + state["prev_loss"] = prev_loss + + return orig_loss diff --git a/engine/cr_boosters/lr_scheduler.py b/engine/cr_boosters/lr_scheduler.py new file mode 100644 index 0000000..11bfff3 --- /dev/null +++ b/engine/cr_boosters/lr_scheduler.py @@ -0,0 +1,2151 @@ +# mypy: allow-untyped-defs +r"""Learning Rate Scheduler.""" +import math +import types +import warnings +from bisect import bisect_right +from collections import Counter +from functools import partial, wraps +from typing import ( + Any, + Callable, + cast, + Dict, + Iterable, + List, + Literal, + Optional, + Sequence, + SupportsFloat, + TypedDict, + Union, +) +from weakref import ref + +from torch import inf, Tensor + +from .optimizer import Optimizer + +__all__ = [ + "LambdaLR", + "MultiplicativeLR", + "StepLR", + "MultiStepLR", + "ConstantLR", + "LinearLR", + "ExponentialLR", + "SequentialLR", + "CosineAnnealingLR", + "ChainedScheduler", + "ReduceLROnPlateau", + "CyclicLR", + "CosineAnnealingWarmRestarts", + "OneCycleLR", + "PolynomialLR", + "LRScheduler", +] + +EPOCH_DEPRECATION_WARNING = ( + "The epoch parameter in `scheduler.step()` was not necessary and is being " + "deprecated where possible. Please use `scheduler.step()` to step the " + "scheduler. During the deprecation, if epoch is different from None, the " + "closed form is used instead of the new chainable form, where available. " + "Please open an issue if you are unable to replicate your use case: " + "https://github.com/pytorch/pytorch/issues/new/choose." +) + + +def _check_verbose_deprecated_warning(verbose): + """Raise a warning when verbose is not the default value.""" + if verbose != "deprecated": + warnings.warn( + "The verbose parameter is deprecated. 
Please use get_last_lr() " + "to access the learning rate.", + UserWarning, + ) + return verbose + return False + + +def _format_param(name: str, optimizer: Optimizer, param): + """Return correctly formatted lr/momentum for each param group.""" + + def _copy(_param): + return _param.clone() if isinstance(_param, Tensor) else _param + + if isinstance(param, (list, tuple)): + if len(param) != len(optimizer.param_groups): + raise ValueError( + f"{name} must have the same length as optimizer.param_groups. " + f"{name} has {len(param)} values, param_groups has {len(optimizer.param_groups)}." + ) + else: + param = [param] * len(optimizer.param_groups) + + return list(map(_copy, param)) + + +class LRScheduler: + r"""Adjusts the learning rate during optimization.""" + + _get_lr_called_within_step: bool = False + + def __init__( + self, optimizer: Optimizer, last_epoch=-1, verbose="deprecated" + ): # noqa: D107 + # Attach optimizer + if not isinstance(optimizer, Optimizer): + raise TypeError(f"{type(optimizer).__name__} is not an Optimizer") + self.optimizer = optimizer + + # Initialize epoch and base learning rates + if last_epoch == -1: + for group in optimizer.param_groups: + initial_lr = group["lr"] + if isinstance(initial_lr, Tensor): + initial_lr = initial_lr.clone() + group.setdefault("initial_lr", initial_lr) + else: + for i, group in enumerate(optimizer.param_groups): + if "initial_lr" not in group: + raise KeyError( + "param 'initial_lr' is not specified " + f"in param_groups[{i}] when resuming an optimizer" + ) + self.base_lrs: List[float] = [ + group["initial_lr"] for group in optimizer.param_groups + ] + self.last_epoch = last_epoch + + # Following https://github.com/pytorch/pytorch/issues/20124 + # We would like to ensure that `lr_scheduler.step()` is called after + # `optimizer.step()` + def patch_track_step_called(opt: Optimizer): + if hasattr(opt.step, "_wrapped_by_lr_sched"): + # we've already patched + return opt.step + + def wrap_step(step_fn): + opt_ref = ref(self.optimizer) + func = step_fn.__func__ + + @wraps(func) + def wrapper(*args, **kwargs): + opt = opt_ref() + opt._opt_called = True # type: ignore[union-attr] + return func.__get__(opt, opt.__class__)(*args, **kwargs) + + wrapper._wrapped_by_lr_sched = True # type: ignore[attr-defined] + return wrapper + + opt.step = wrap_step(opt.step) # type: ignore[method-assign] + + patch_track_step_called(self.optimizer) + self.verbose = _check_verbose_deprecated_warning(verbose) + self._initial_step() + + def _initial_step(self): + """Initialize step counts and perform a step.""" + self._step_count = 0 + self.step() + + def state_dict(self): + """Return the state of the scheduler as a :class:`dict`. + + It contains an entry for every variable in self.__dict__ which + is not the optimizer. + """ + return { + key: value for key, value in self.__dict__.items() if key != "optimizer" + } + + def load_state_dict(self, state_dict: Dict[str, Any]): + """Load the scheduler's state. + + Args: + state_dict (dict): scheduler state. Should be an object returned + from a call to :meth:`state_dict`. 
+ """ + self.__dict__.update(state_dict) + + def get_last_lr(self) -> List[float]: + """Return last computed learning rate by current scheduler.""" + return self._last_lr + + def get_lr(self) -> List[float]: + """Compute learning rate using chainable form of the scheduler.""" + raise NotImplementedError + + def print_lr( + self, + is_verbose: bool, + group: Dict[str, Any], + lr: float, + epoch: Optional[int] = None, + ): + """Display the current learning rate. + + .. deprecated:: 2.4 + ``print_lr()`` is deprecated. Please use ``get_last_lr()`` to access the + learning rate. + """ + warnings.warn( + "`LRScheduler.print_lr()` is being deprecated. To fetch the learning rate, " + "please use `get_last_lr()` instead. For more details, " + "see https://github.com/pytorch/pytorch/issues/99270.", + UserWarning, + ) + if is_verbose: + if epoch is None: + print(f"Adjusting learning rate of group {group} to {lr:.4e}.") + else: + epoch_str = ("%.2f" if isinstance(epoch, float) else "%.5d") % epoch + print( + f"Epoch {epoch_str}: adjusting learning rate of group {group} to {lr:.4e}." + ) + + def step(self, epoch: Optional[int] = None): + """Perform a step.""" + # Raise a warning if old pattern is detected + # https://github.com/pytorch/pytorch/issues/20124 + if self._step_count == 1: + if not hasattr(self.optimizer.step, "_wrapped_by_lr_sched"): + warnings.warn( + "Seems like `optimizer.step()` has been overridden after learning rate scheduler " + "initialization. Please, make sure to call `optimizer.step()` before " + "`lr_scheduler.step()`. See more details at " + "https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate", + UserWarning, + ) + + # Just check if there were two first lr_scheduler.step() calls before optimizer.step() + elif not getattr(self.optimizer, "_opt_called", False): + warnings.warn( + "Detected call of `lr_scheduler.step()` before `optimizer.step()`. " + "In PyTorch 1.1.0 and later, you should call them in the opposite order: " + "`optimizer.step()` before `lr_scheduler.step()`. Failure to do this " + "will result in PyTorch skipping the first value of the learning rate schedule. " + "See more details at " + "https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate", + UserWarning, + ) + self._step_count += 1 + + with _enable_get_lr_call(self): + if epoch is None: + self.last_epoch += 1 + values = self.get_lr() + else: + warnings.warn(EPOCH_DEPRECATION_WARNING, UserWarning) + self.last_epoch = epoch + if hasattr(self, "_get_closed_form_lr"): + values = cast(List[float], self._get_closed_form_lr()) + else: + values = self.get_lr() + + for i, data in enumerate(zip(self.optimizer.param_groups, values)): + param_group, lr = data + if isinstance(param_group["lr"], Tensor): + lr_val = lr.item() if isinstance(lr, Tensor) else lr # type: ignore[attr-defined] + param_group["lr"].fill_(lr_val) + else: + param_group["lr"] = lr + + self._last_lr: List[float] = [ + group["lr"] for group in self.optimizer.param_groups + ] + + +def _warn_get_lr_called_within_step(lr_scheduler: LRScheduler): + if not lr_scheduler._get_lr_called_within_step: + warnings.warn( + "To get the last learning rate computed by the scheduler, " + "please use `get_last_lr()`.", + UserWarning, + stacklevel=2, + ) + + +# Including _LRScheduler for backwards compatibility +# Subclass instead of assign because we want __name__ of _LRScheduler to be _LRScheduler (assigning would make it LRScheduler). 
+class _LRScheduler(LRScheduler): + pass + + +class _enable_get_lr_call: + def __init__(self, o: LRScheduler): + self.o = o + + def __enter__(self): + self.o._get_lr_called_within_step = True + return self + + def __exit__(self, type, value, traceback): + self.o._get_lr_called_within_step = False + + +class LambdaLR(LRScheduler): + """Sets the initial learning rate. + + The learning rate of each parameter group is set to the initial lr + times a given function. When last_epoch=-1, sets initial lr as lr. + + Args: + optimizer (Optimizer): Wrapped optimizer. + lr_lambda (function or list): A function which computes a multiplicative + factor given an integer parameter epoch, or a list of such + functions, one for each group in optimizer.param_groups. + last_epoch (int): The index of last epoch. Default: -1. + verbose (bool | str): If ``True``, prints a message to stdout for + each update. Default: ``False``. + + .. deprecated:: 2.2 + ``verbose`` is deprecated. Please use ``get_last_lr()`` to access the + learning rate. + + Example: + >>> # xdoctest: +SKIP + >>> # Assuming optimizer has two groups. + >>> lambda1 = lambda epoch: epoch // 30 + >>> lambda2 = lambda epoch: 0.95 ** epoch + >>> scheduler = LambdaLR(optimizer, lr_lambda=[lambda1, lambda2]) + >>> for epoch in range(100): + >>> train(...) + >>> validate(...) + >>> scheduler.step() + """ + + def __init__( + self, + optimizer: Optimizer, + lr_lambda: Union[Callable[[int], float], List[Callable[[int], float]]], + last_epoch=-1, + verbose="deprecated", + ): # noqa: D107 + self.optimizer = optimizer + + self.lr_lambdas: List[Callable[[int], float]] + if not isinstance(lr_lambda, list) and not isinstance(lr_lambda, tuple): + self.lr_lambdas = [lr_lambda] * len(optimizer.param_groups) + else: + if len(lr_lambda) != len(optimizer.param_groups): + raise ValueError( + f"Expected {len(optimizer.param_groups)} lr_lambdas, but got {len(lr_lambda)}" + ) + self.lr_lambdas = list(lr_lambda) + super().__init__(optimizer, last_epoch, verbose) + + def state_dict(self): + """Return the state of the scheduler as a :class:`dict`. + + It contains an entry for every variable in self.__dict__ which + is not the optimizer. + The learning rate lambda functions will only be saved if they are callable objects + and not if they are functions or lambdas. + + When saving or loading the scheduler, please make sure to also save or load the state of the optimizer. + """ + state_dict = { + key: value + for key, value in self.__dict__.items() + if key not in ("optimizer", "lr_lambdas") + } + state_dict["lr_lambdas"] = [None] * len(self.lr_lambdas) + + for idx, fn in enumerate(self.lr_lambdas): + if not isinstance(fn, types.FunctionType): + state_dict["lr_lambdas"][idx] = fn.__dict__.copy() + + return state_dict + + def load_state_dict(self, state_dict): + """Load the scheduler's state. + + When saving or loading the scheduler, please make sure to also save or load the state of the optimizer. + + Args: + state_dict (dict): scheduler state. Should be an object returned + from a call to :meth:`state_dict`. 
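+
+        A hedged sketch of the round trip (the ``WarmupFactor`` class and
+        ``optimizer`` below are illustrative, not part of this module): plain
+        functions and lambdas are reused as constructed, while the ``__dict__``
+        of a callable object is restored from the saved state.
+
+        Example:
+            >>> # xdoctest: +SKIP
+            >>> class WarmupFactor:
+            >>>     def __init__(self, warmup_epochs=5):
+            >>>         self.warmup_epochs = warmup_epochs
+            >>>     def __call__(self, epoch):
+            >>>         return min(1.0, (epoch + 1) / self.warmup_epochs)
+            >>> scheduler = LambdaLR(optimizer, lr_lambda=WarmupFactor())
+            >>> state = scheduler.state_dict()
+            >>> # a fresh scheduler picks up the saved __dict__ of the callable
+            >>> new_scheduler = LambdaLR(optimizer, lr_lambda=WarmupFactor())
+            >>> new_scheduler.load_state_dict(state)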
+ """ + lr_lambdas = state_dict.pop("lr_lambdas") + self.__dict__.update(state_dict) + # Restore state_dict keys in order to prevent side effects + # https://github.com/pytorch/pytorch/issues/32756 + state_dict["lr_lambdas"] = lr_lambdas + + for idx, fn in enumerate(lr_lambdas): + if fn is not None: + self.lr_lambdas[idx].__dict__.update(fn) + + def get_lr(self): + """Compute learning rate.""" + _warn_get_lr_called_within_step(self) + + return [ + base_lr * lmbda(self.last_epoch) + for lmbda, base_lr in zip(self.lr_lambdas, self.base_lrs) + ] + + +class MultiplicativeLR(LRScheduler): + """Multiply the learning rate of each parameter group by the factor given in the specified function. + + When last_epoch=-1, set initial lr as lr. + + Args: + optimizer (Optimizer): Wrapped optimizer. + lr_lambda (function or list): A function which computes a multiplicative + factor given an integer parameter epoch, or a list of such + functions, one for each group in optimizer.param_groups. + last_epoch (int): The index of last epoch. Default: -1. + verbose (bool | str): If ``True``, prints a message to stdout for + each update. Default: ``False``. + + .. deprecated:: 2.2 + ``verbose`` is deprecated. Please use ``get_last_lr()`` to access the + learning rate. + + Example: + >>> # xdoctest: +SKIP + >>> lmbda = lambda epoch: 0.95 + >>> scheduler = MultiplicativeLR(optimizer, lr_lambda=lmbda) + >>> for epoch in range(100): + >>> train(...) + >>> validate(...) + >>> scheduler.step() + """ + + def __init__( + self, + optimizer: Optimizer, + lr_lambda: Union[Callable[[int], float], List[Callable[[int], float]]], + last_epoch=-1, + verbose="deprecated", + ): # noqa: D107 + self.optimizer = optimizer + + self.lr_lambdas: List[Callable[[int], float]] + if not isinstance(lr_lambda, list) and not isinstance(lr_lambda, tuple): + self.lr_lambdas = [lr_lambda] * len(optimizer.param_groups) + else: + if len(lr_lambda) != len(optimizer.param_groups): + raise ValueError( + f"Expected {len(optimizer.param_groups)} lr_lambdas, but got {len(lr_lambda)}" + ) + self.lr_lambdas = list(lr_lambda) + super().__init__(optimizer, last_epoch, verbose) + + def state_dict(self): + """Return the state of the scheduler as a :class:`dict`. + + It contains an entry for every variable in self.__dict__ which + is not the optimizer. + The learning rate lambda functions will only be saved if they are callable objects + and not if they are functions or lambdas. + """ + state_dict = { + key: value + for key, value in self.__dict__.items() + if key not in ("optimizer", "lr_lambdas") + } + state_dict["lr_lambdas"] = [None] * len(self.lr_lambdas) + + for idx, fn in enumerate(self.lr_lambdas): + if not isinstance(fn, types.FunctionType): + state_dict["lr_lambdas"][idx] = fn.__dict__.copy() + + return state_dict + + def load_state_dict(self, state_dict): + """Load the scheduler's state. + + Args: + state_dict (dict): scheduler state. Should be an object returned + from a call to :meth:`state_dict`. 
+ """ + lr_lambdas = state_dict.pop("lr_lambdas") + self.__dict__.update(state_dict) + # Restore state_dict keys in order to prevent side effects + # https://github.com/pytorch/pytorch/issues/32756 + state_dict["lr_lambdas"] = lr_lambdas + + for idx, fn in enumerate(lr_lambdas): + if fn is not None: + self.lr_lambdas[idx].__dict__.update(fn) + + def get_lr(self): + """Compute the learning rate of each parameter group.""" + _warn_get_lr_called_within_step(self) + + if self.last_epoch > 0: + return [ + group["lr"] * lmbda(self.last_epoch) + for lmbda, group in zip(self.lr_lambdas, self.optimizer.param_groups) + ] + else: + return [group["lr"] for group in self.optimizer.param_groups] + + +class StepLR(LRScheduler): + """Decays the learning rate of each parameter group by gamma every step_size epochs. + + Notice that such decay can happen simultaneously with other changes to the learning rate + from outside this scheduler. When last_epoch=-1, sets initial lr as lr. + + Args: + optimizer (Optimizer): Wrapped optimizer. + step_size (int): Period of learning rate decay. + gamma (float): Multiplicative factor of learning rate decay. + Default: 0.1. + last_epoch (int): The index of last epoch. Default: -1. + verbose (bool | str): If ``True``, prints a message to stdout for + each update. Default: ``False``. + + .. deprecated:: 2.2 + ``verbose`` is deprecated. Please use ``get_last_lr()`` to access the + learning rate. + + Example: + >>> # xdoctest: +SKIP + >>> # Assuming optimizer uses lr = 0.05 for all groups + >>> # lr = 0.05 if epoch < 30 + >>> # lr = 0.005 if 30 <= epoch < 60 + >>> # lr = 0.0005 if 60 <= epoch < 90 + >>> # ... + >>> scheduler = StepLR(optimizer, step_size=30, gamma=0.1) + >>> for epoch in range(100): + >>> train(...) + >>> validate(...) + >>> scheduler.step() + """ + + def __init__( + self, + optimizer: Optimizer, + step_size: int, + gamma=0.1, + last_epoch=-1, + verbose="deprecated", + ): # noqa: D107 + self.step_size = step_size + self.gamma = gamma + super().__init__(optimizer, last_epoch, verbose) + + def get_lr(self): + """Compute the learning rate of each parameter group.""" + _warn_get_lr_called_within_step(self) + + if (self.last_epoch == 0) or (self.last_epoch % self.step_size != 0): + return [group["lr"] for group in self.optimizer.param_groups] + return [group["lr"] * self.gamma for group in self.optimizer.param_groups] + + def _get_closed_form_lr(self): + return [ + base_lr * self.gamma ** (self.last_epoch // self.step_size) + for base_lr in self.base_lrs + ] + + +class MultiStepLR(LRScheduler): + """Decays the learning rate of each parameter group by gamma once the number of epoch reaches one of the milestones. + + Notice that such decay can happen simultaneously with other changes to the learning rate + from outside this scheduler. When last_epoch=-1, sets initial lr as lr. + + Args: + optimizer (Optimizer): Wrapped optimizer. + milestones (list): List of epoch indices. Must be increasing. + gamma (float): Multiplicative factor of learning rate decay. + Default: 0.1. + last_epoch (int): The index of last epoch. Default: -1. + verbose (bool | str): If ``True``, prints a message to stdout for + each update. Default: ``False``. + + .. deprecated:: 2.2 + ``verbose`` is deprecated. Please use ``get_last_lr()`` to access the + learning rate. 
+ + Example: + >>> # xdoctest: +SKIP + >>> # Assuming optimizer uses lr = 0.05 for all groups + >>> # lr = 0.05 if epoch < 30 + >>> # lr = 0.005 if 30 <= epoch < 80 + >>> # lr = 0.0005 if epoch >= 80 + >>> scheduler = MultiStepLR(optimizer, milestones=[30,80], gamma=0.1) + >>> for epoch in range(100): + >>> train(...) + >>> validate(...) + >>> scheduler.step() + """ + + def __init__( + self, + optimizer: Optimizer, + milestones: Iterable[int], + gamma=0.1, + last_epoch=-1, + verbose="deprecated", + ): # noqa: D107 + self.milestones = Counter(milestones) + self.gamma = gamma + super().__init__(optimizer, last_epoch, verbose) + + def get_lr(self): + """Compute the learning rate of each parameter group.""" + _warn_get_lr_called_within_step(self) + + if self.last_epoch not in self.milestones: + return [group["lr"] for group in self.optimizer.param_groups] + return [ + group["lr"] * self.gamma ** self.milestones[self.last_epoch] + for group in self.optimizer.param_groups + ] + + def _get_closed_form_lr(self): + milestones = sorted(self.milestones.elements()) + return [ + base_lr * self.gamma ** bisect_right(milestones, self.last_epoch) + for base_lr in self.base_lrs + ] + + +class ConstantLR(LRScheduler): + """Multiply the learning rate of each parameter group by a small constant factor. + + The multiplication is done until the number of epoch reaches a pre-defined milestone: total_iters. + Notice that such multiplication of the small constant factor can + happen simultaneously with other changes to the learning rate from outside this scheduler. + When last_epoch=-1, sets initial lr as lr. + + Args: + optimizer (Optimizer): Wrapped optimizer. + factor (float): The number we multiply learning rate until the milestone. Default: 1./3. + total_iters (int): The number of steps that the scheduler multiplies the learning rate by the factor. + Default: 5. + last_epoch (int): The index of the last epoch. Default: -1. + verbose (bool | str): If ``True``, prints a message to stdout for + each update. Default: ``False``. + + .. deprecated:: 2.2 + ``verbose`` is deprecated. Please use ``get_last_lr()`` to access the + learning rate. + + Example: + >>> # xdoctest: +SKIP + >>> # Assuming optimizer uses lr = 0.05 for all groups + >>> # lr = 0.025 if epoch == 0 + >>> # lr = 0.025 if epoch == 1 + >>> # lr = 0.025 if epoch == 2 + >>> # lr = 0.025 if epoch == 3 + >>> # lr = 0.05 if epoch >= 4 + >>> scheduler = ConstantLR(optimizer, factor=0.5, total_iters=4) + >>> for epoch in range(100): + >>> train(...) + >>> validate(...) + >>> scheduler.step() + """ + + def __init__( + self, + optimizer: Optimizer, + factor=1.0 / 3, + total_iters=5, + last_epoch=-1, + verbose="deprecated", + ): # noqa: D107 + if factor > 1.0 or factor < 0: + raise ValueError( + "Constant multiplicative factor expected to be between 0 and 1." 
+ ) + + self.factor = factor + self.total_iters = total_iters + super().__init__(optimizer, last_epoch, verbose) + + def get_lr(self): + """Compute the learning rate of each parameter group.""" + _warn_get_lr_called_within_step(self) + + if self.last_epoch == 0: + return [group["lr"] * self.factor for group in self.optimizer.param_groups] + + if self.last_epoch != self.total_iters: + return [group["lr"] for group in self.optimizer.param_groups] + + return [ + group["lr"] * (1.0 / self.factor) for group in self.optimizer.param_groups + ] + + def _get_closed_form_lr(self): + return [ + base_lr + * (self.factor + (self.last_epoch >= self.total_iters) * (1 - self.factor)) + for base_lr in self.base_lrs + ] + + +class LinearLR(LRScheduler): + """Decays the learning rate of each parameter group by linearly changing small multiplicative factor. + + The multiplication is done until the number of epoch reaches a pre-defined milestone: total_iters. + Notice that such decay can happen simultaneously with other changes to the learning rate + from outside this scheduler. When last_epoch=-1, sets initial lr as lr. + + Args: + optimizer (Optimizer): Wrapped optimizer. + start_factor (float): The number we multiply learning rate in the first epoch. + The multiplication factor changes towards end_factor in the following epochs. + Default: 1./3. + end_factor (float): The number we multiply learning rate at the end of linear changing + process. Default: 1.0. + total_iters (int): The number of iterations that multiplicative factor reaches to 1. + Default: 5. + last_epoch (int): The index of the last epoch. Default: -1. + verbose (bool | str): If ``True``, prints a message to stdout for + each update. Default: ``False``. + + .. deprecated:: 2.2 + ``verbose`` is deprecated. Please use ``get_last_lr()`` to access the + learning rate. + + Example: + >>> # xdoctest: +SKIP + >>> # Assuming optimizer uses lr = 0.05 for all groups + >>> # lr = 0.025 if epoch == 0 + >>> # lr = 0.03125 if epoch == 1 + >>> # lr = 0.0375 if epoch == 2 + >>> # lr = 0.04375 if epoch == 3 + >>> # lr = 0.05 if epoch >= 4 + >>> scheduler = LinearLR(optimizer, start_factor=0.5, total_iters=4) + >>> for epoch in range(100): + >>> train(...) + >>> validate(...) + >>> scheduler.step() + """ + + def __init__( + self, + optimizer: Optimizer, + start_factor=1.0 / 3, + end_factor=1.0, + total_iters=5, + last_epoch=-1, + verbose="deprecated", + ): # noqa: D107 + if start_factor > 1.0 or start_factor <= 0: + raise ValueError( + "Starting multiplicative factor expected to be greater than 0 and less or equal to 1." + ) + + if end_factor > 1.0 or end_factor < 0: + raise ValueError( + "Ending multiplicative factor expected to be between 0 and 1." 
+ ) + + self.start_factor = start_factor + self.end_factor = end_factor + self.total_iters = total_iters + super().__init__(optimizer, last_epoch, verbose) + + def get_lr(self): + """Compute the learning rate.""" + _warn_get_lr_called_within_step(self) + + if self.last_epoch == 0: + return [ + group["lr"] * self.start_factor for group in self.optimizer.param_groups + ] + + if self.last_epoch > self.total_iters: + return [group["lr"] for group in self.optimizer.param_groups] + + return [ + group["lr"] + * ( + 1.0 + + (self.end_factor - self.start_factor) + / ( + self.total_iters * self.start_factor + + (self.last_epoch - 1) * (self.end_factor - self.start_factor) + ) + ) + for group in self.optimizer.param_groups + ] + + def _get_closed_form_lr(self): + return [ + base_lr + * ( + self.start_factor + + (self.end_factor - self.start_factor) + * min(self.total_iters, self.last_epoch) + / self.total_iters + ) + for base_lr in self.base_lrs + ] + + +class ExponentialLR(LRScheduler): + """Decays the learning rate of each parameter group by gamma every epoch. + + When last_epoch=-1, sets initial lr as lr. + + Args: + optimizer (Optimizer): Wrapped optimizer. + gamma (float): Multiplicative factor of learning rate decay. + last_epoch (int): The index of last epoch. Default: -1. + verbose (bool | str): If ``True``, prints a message to stdout for + each update. Default: ``False``. + + .. deprecated:: 2.2 + ``verbose`` is deprecated. Please use ``get_last_lr()`` to access the + learning rate. + """ + + def __init__( + self, optimizer: Optimizer, gamma: float, last_epoch=-1, verbose="deprecated" + ): # noqa: D107 + self.gamma = gamma + super().__init__(optimizer, last_epoch, verbose) + + def get_lr(self): + """Compute the learning rate of each parameter group.""" + _warn_get_lr_called_within_step(self) + + if self.last_epoch == 0: + return [group["lr"] for group in self.optimizer.param_groups] + return [group["lr"] * self.gamma for group in self.optimizer.param_groups] + + def _get_closed_form_lr(self): + return [base_lr * self.gamma**self.last_epoch for base_lr in self.base_lrs] + + +class SequentialLR(LRScheduler): + """Contains a list of schedulers expected to be called sequentially during the optimization process. + + Specifically, the schedulers will be called according to the milestone points, which should provide exact + intervals by which each scheduler should be called at a given epoch. + + Args: + optimizer (Optimizer): Wrapped optimizer. + schedulers (list): List of chained schedulers. + milestones (list): List of integers that reflects milestone points. + last_epoch (int): The index of last epoch. Default: -1. + verbose (bool | str): Does nothing. + + .. deprecated:: 2.2 + ``verbose`` is deprecated. Please use ``get_last_lr()`` to access the + learning rate. + + Example: + >>> # xdoctest: +SKIP + >>> # Assuming optimizer uses lr = 1. for all groups + >>> # lr = 0.1 if epoch == 0 + >>> # lr = 0.1 if epoch == 1 + >>> # lr = 0.9 if epoch == 2 + >>> # lr = 0.81 if epoch == 3 + >>> # lr = 0.729 if epoch == 4 + >>> scheduler1 = ConstantLR(optimizer, factor=0.1, total_iters=2) + >>> scheduler2 = ExponentialLR(optimizer, gamma=0.9) + >>> scheduler = SequentialLR(optimizer, schedulers=[scheduler1, scheduler2], milestones=[2]) + >>> for epoch in range(100): + >>> train(...) + >>> validate(...) 
+ >>> scheduler.step() + """ + + def __init__( + self, + optimizer: Optimizer, + schedulers: List[LRScheduler], + milestones: List[int], + last_epoch=-1, + verbose="deprecated", + ): # noqa: D107 + if len(schedulers) < 1: + raise ValueError( + f"{self.__class__.__name__} expects at least one scheduler, but got no scheduler." + ) + + for scheduler_idx, scheduler in enumerate(schedulers): + if not hasattr(scheduler, "optimizer"): + raise TypeError( + f"{self.__class__.__name__} at index {scheduler_idx} should have `optimizer` as its attribute." + ) + if isinstance(scheduler, ReduceLROnPlateau): + raise ValueError( + f"{self.__class__.__name__} does not support `ReduceLROnPlateau` scheduler as it " + "requires additional kwargs to be specified when calling `step`, " + f"but got one at index {scheduler_idx} in the given schedulers sequence." + ) + if optimizer != scheduler.optimizer: + raise ValueError( + f"{self.__class__.__name__} expects all schedulers to belong to the same optimizer, but " + f"got scheduler {scheduler.__class__.__name__} at index {scheduler_idx} has {scheduler.optimizer}, " + f"which is different from {optimizer.__class__.__name__}." + ) + + if len(milestones) != len(schedulers) - 1: + raise ValueError( + "Sequential Schedulers expects number of schedulers provided to be one more " + f"than the number of milestone points, but got number of schedulers {len(schedulers)} and the " + f"number of milestones to be equal to {len(milestones)}" + ) + _check_verbose_deprecated_warning(verbose) + self._schedulers = schedulers + self._milestones = milestones + self.last_epoch = last_epoch + 1 + self.optimizer = optimizer + + # Reset learning rates back to initial values + for group in self.optimizer.param_groups: + group["lr"] = group["initial_lr"] + + # "Undo" the step performed by other schedulers + for scheduler in self._schedulers: + scheduler.last_epoch -= 1 + + # Perform the initial step for only the first scheduler + self._schedulers[0]._initial_step() + + self._last_lr = schedulers[0].get_last_lr() + + def step(self): + """Perform a step.""" + self.last_epoch += 1 + idx = bisect_right(self._milestones, self.last_epoch) + scheduler = self._schedulers[idx] + if idx > 0 and self._milestones[idx - 1] == self.last_epoch: + scheduler.step(0) + else: + scheduler.step() + + self._last_lr = scheduler.get_last_lr() + + def state_dict(self): + """Return the state of the scheduler as a :class:`dict`. + + It contains an entry for every variable in self.__dict__ which + is not the optimizer. + The wrapped scheduler states will also be saved. + """ + state_dict = { + key: value + for key, value in self.__dict__.items() + if key not in ("optimizer", "_schedulers") + } + state_dict["_schedulers"] = [None] * len(self._schedulers) + + for idx, s in enumerate(self._schedulers): + state_dict["_schedulers"][idx] = s.state_dict() + + return state_dict + + def load_state_dict(self, state_dict): + """Load the scheduler's state. + + Args: + state_dict (dict): scheduler state. Should be an object returned + from a call to :meth:`state_dict`. 
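+
+        A rough restore sketch (``optimizer`` and the checkpointed ``state``
+        dict are assumed to come from an earlier run that used the same
+        scheduler layout; the wrapped scheduler states are loaded in order):
+
+        Example:
+            >>> # xdoctest: +SKIP
+            >>> warmup = ConstantLR(optimizer, factor=0.1, total_iters=2)
+            >>> decay = ExponentialLR(optimizer, gamma=0.9)
+            >>> scheduler = SequentialLR(optimizer, schedulers=[warmup, decay], milestones=[2])
+            >>> scheduler.load_state_dict(state)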
+ """ + _schedulers = state_dict.pop("_schedulers") + self.__dict__.update(state_dict) + # Restore state_dict keys in order to prevent side effects + # https://github.com/pytorch/pytorch/issues/32756 + state_dict["_schedulers"] = _schedulers + + for idx, s in enumerate(_schedulers): + self._schedulers[idx].load_state_dict(s) + + +class PolynomialLR(LRScheduler): + """Decays the learning rate of each parameter group using a polynomial function in the given total_iters. + + When last_epoch=-1, sets initial lr as lr. + + Args: + optimizer (Optimizer): Wrapped optimizer. + total_iters (int): The number of steps that the scheduler decays the learning rate. Default: 5. + power (float): The power of the polynomial. Default: 1.0. + verbose (bool | str): If ``True``, prints a message to stdout for + each update. Default: ``False``. + + .. deprecated:: 2.2 + ``verbose`` is deprecated. Please use ``get_last_lr()`` to access the + learning rate. + + Example: + >>> # xdoctest: +SKIP("undefined vars") + >>> # Assuming optimizer uses lr = 0.001 for all groups + >>> # lr = 0.001 if epoch == 0 + >>> # lr = 0.00075 if epoch == 1 + >>> # lr = 0.00050 if epoch == 2 + >>> # lr = 0.00025 if epoch == 3 + >>> # lr = 0.0 if epoch >= 4 + >>> scheduler = PolynomialLR(optimizer, total_iters=4, power=1.0) + >>> for epoch in range(100): + >>> train(...) + >>> validate(...) + >>> scheduler.step() + """ + + def __init__( + self, + optimizer: Optimizer, + total_iters=5, + power=1.0, + last_epoch=-1, + verbose="deprecated", + ): # noqa: D107 + self.total_iters = total_iters + self.power = power + super().__init__(optimizer, last_epoch, verbose) + + def get_lr(self): + """Compute the learning rate.""" + _warn_get_lr_called_within_step(self) + + if self.last_epoch == 0 or self.last_epoch > self.total_iters: + return [group["lr"] for group in self.optimizer.param_groups] + + decay_factor = ( + (1.0 - self.last_epoch / self.total_iters) + / (1.0 - (self.last_epoch - 1) / self.total_iters) + ) ** self.power + return [group["lr"] * decay_factor for group in self.optimizer.param_groups] + + def _get_closed_form_lr(self): + return [ + ( + base_lr + * (1.0 - min(self.total_iters, self.last_epoch) / self.total_iters) + ** self.power + ) + for base_lr in self.base_lrs + ] + + +class CosineAnnealingLR(LRScheduler): + r"""Set the learning rate of each parameter group using a cosine annealing schedule. + + The :math:`\eta_{max}` is set to the initial lr and + :math:`T_{cur}` is the number of epochs since the last restart in SGDR: + + .. math:: + \begin{aligned} + \eta_t & = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})\left(1 + + \cos\left(\frac{T_{cur}}{T_{max}}\pi\right)\right), + & T_{cur} \neq (2k+1)T_{max}; \\ + \eta_{t+1} & = \eta_{t} + \frac{1}{2}(\eta_{max} - \eta_{min}) + \left(1 - \cos\left(\frac{1}{T_{max}}\pi\right)\right), + & T_{cur} = (2k+1)T_{max}. + \end{aligned} + + When last_epoch=-1, sets initial lr as lr. Notice that because the schedule + is defined recursively, the learning rate can be simultaneously modified + outside this scheduler by other operators. If the learning rate is set + solely by this scheduler, the learning rate at each step becomes: + + .. math:: + \eta_t = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})\left(1 + + \cos\left(\frac{T_{cur}}{T_{max}}\pi\right)\right) + + It has been proposed in + `SGDR: Stochastic Gradient Descent with Warm Restarts`_. Note that this only + implements the cosine annealing part of SGDR, and not the restarts. + + Args: + optimizer (Optimizer): Wrapped optimizer. 
+ T_max (int): Maximum number of iterations. + eta_min (float): Minimum learning rate. Default: 0. + last_epoch (int): The index of last epoch. Default: -1. + verbose (bool | str): If ``True``, prints a message to stdout for + each update. Default: ``False``. + + .. deprecated:: 2.2 + ``verbose`` is deprecated. Please use ``get_last_lr()`` to access the + learning rate. + + .. _SGDR\: Stochastic Gradient Descent with Warm Restarts: + https://arxiv.org/abs/1608.03983 + """ + + def __init__( + self, + optimizer: Optimizer, + T_max: int, + eta_min=0, + last_epoch=-1, + verbose="deprecated", + ): # noqa: D107 + self.T_max = T_max + self.eta_min = eta_min + super().__init__(optimizer, last_epoch, verbose) + + def get_lr(self): + """Retrieve the learning rate of each parameter group.""" + _warn_get_lr_called_within_step(self) + + if self.last_epoch == 0: + return [group["lr"] for group in self.optimizer.param_groups] + elif self._step_count == 1 and self.last_epoch > 0: + return [ + self.eta_min + + (base_lr - self.eta_min) + * (1 + math.cos((self.last_epoch) * math.pi / self.T_max)) + / 2 + for base_lr, group in zip(self.base_lrs, self.optimizer.param_groups) + ] + elif (self.last_epoch - 1 - self.T_max) % (2 * self.T_max) == 0: + return [ + group["lr"] + + (base_lr - self.eta_min) * (1 - math.cos(math.pi / self.T_max)) / 2 + for base_lr, group in zip(self.base_lrs, self.optimizer.param_groups) + ] + return [ + (1 + math.cos(math.pi * self.last_epoch / self.T_max)) + / (1 + math.cos(math.pi * (self.last_epoch - 1) / self.T_max)) + * (group["lr"] - self.eta_min) + + self.eta_min + for group in self.optimizer.param_groups + ] + + def _get_closed_form_lr(self): + return [ + self.eta_min + + (base_lr - self.eta_min) + * (1 + math.cos(math.pi * self.last_epoch / self.T_max)) + / 2 + for base_lr in self.base_lrs + ] + + +class ChainedScheduler(LRScheduler): + """Chains a list of learning rate schedulers. + + Takes in a sequence of chainable learning rate schedulers and calls their + step() functions consecutively in just one call to step(). + + Args: + schedulers (sequence): sequence of chained schedulers. + optimizer (Optimizer, optional): Wrapped optimizer. Default: None. + + Example: + >>> # xdoctest: +SKIP + >>> # Assuming optimizer uses lr = 1. for all groups + >>> # lr = 0.09 if epoch == 0 + >>> # lr = 0.081 if epoch == 1 + >>> # lr = 0.729 if epoch == 2 + >>> # lr = 0.6561 if epoch == 3 + >>> # lr = 0.59049 if epoch >= 4 + >>> scheduler1 = ConstantLR(optimizer, factor=0.1, total_iters=2) + >>> scheduler2 = ExponentialLR(optimizer, gamma=0.9) + >>> scheduler = ChainedScheduler([scheduler1, scheduler2], optimizer=optimizer) + >>> for epoch in range(100): + >>> train(...) + >>> validate(...) + >>> scheduler.step() + """ + + def __init__( + self, schedulers: Sequence[LRScheduler], optimizer: Optional[Optimizer] = None + ): # noqa: D107 + if len(schedulers) < 1: + raise ValueError( + f"{self.__class__.__name__} expects at least one scheduler to be chained, but got no scheduler." + ) + + optimizer = optimizer or schedulers[0].optimizer + for scheduler_idx, scheduler in enumerate(schedulers): + if not hasattr(scheduler, "optimizer"): + raise TypeError( + f"{self.__class__.__name__} at index {scheduler_idx} should have `optimizer` as its attribute." 
+ ) + if isinstance(scheduler, ReduceLROnPlateau): + raise ValueError( + f"{self.__class__.__name__} does not support `ReduceLROnPlateau` scheduler as it " + "requires additional kwargs to be specified when calling `step`, " + f"but got one at index {scheduler_idx} in the given schedulers sequence." + ) + if optimizer != scheduler.optimizer: + raise ValueError( + f"{self.__class__.__name__} expects all schedulers to belong to the same optimizer, but " + f"got scheduler {scheduler.__class__.__name__} at index {scheduler_idx} has {scheduler.optimizer}, " + f"which is different from {optimizer.__class__.__name__}." + ) + self._schedulers = schedulers + self.optimizer = optimizer + self._last_lr = [ + group["lr"] for group in self._schedulers[-1].optimizer.param_groups + ] + + def step(self): + """Perform a step.""" + for scheduler in self._schedulers: + scheduler.step() + self._last_lr = [ + group["lr"] for group in self._schedulers[-1].optimizer.param_groups + ] + + def state_dict(self): + """Return the state of the scheduler as a :class:`dict`. + + It contains an entry for every variable in self.__dict__ which + is not the optimizer. + The wrapped scheduler states will also be saved. + """ + state_dict = { + key: value + for key, value in self.__dict__.items() + if key not in ("optimizer", "_schedulers") + } + state_dict["_schedulers"] = [None] * len(self._schedulers) + + for idx, s in enumerate(self._schedulers): + state_dict["_schedulers"][idx] = s.state_dict() + + return state_dict + + def load_state_dict(self, state_dict): + """Load the scheduler's state. + + Args: + state_dict (dict): scheduler state. Should be an object returned + from a call to :meth:`state_dict`. + """ + _schedulers = state_dict.pop("_schedulers") + self.__dict__.update(state_dict) + # Restore state_dict keys in order to prevent side effects + # https://github.com/pytorch/pytorch/issues/32756 + state_dict["_schedulers"] = _schedulers + + for idx, s in enumerate(_schedulers): + self._schedulers[idx].load_state_dict(s) + + +class ReduceLROnPlateau(LRScheduler): + """Reduce learning rate when a metric has stopped improving. + + Models often benefit from reducing the learning rate by a factor + of 2-10 once learning stagnates. This scheduler reads a metrics + quantity and if no improvement is seen for a 'patience' number + of epochs, the learning rate is reduced. + + Args: + optimizer (Optimizer): Wrapped optimizer. + mode (str): One of `min`, `max`. In `min` mode, lr will + be reduced when the quantity monitored has stopped + decreasing; in `max` mode it will be reduced when the + quantity monitored has stopped increasing. Default: 'min'. + factor (float): Factor by which the learning rate will be + reduced. new_lr = lr * factor. Default: 0.1. + patience (int): The number of allowed epochs with no improvement after + which the learning rate will be reduced. + For example, consider the case of having no patience (`patience = 0`). + In the first epoch, a baseline is established and is always considered good as there's no previous baseline. + In the second epoch, if the performance is worse than the baseline, + we have what is considered an intolerable epoch. + Since the count of intolerable epochs (1) is greater than the patience level (0), + the learning rate is reduced at the end of this epoch. + From the third epoch onwards, the learning rate continues to be reduced at the end of each epoch + if the performance is worse than the baseline. 
If the performance improves or remains the same, + the learning rate is not adjusted. + Default: 10. + threshold (float): Threshold for measuring the new optimum, + to only focus on significant changes. Default: 1e-4. + threshold_mode (str): One of `rel`, `abs`. In `rel` mode, + dynamic_threshold = best * ( 1 + threshold ) in 'max' + mode or best * ( 1 - threshold ) in `min` mode. + In `abs` mode, dynamic_threshold = best + threshold in + `max` mode or best - threshold in `min` mode. Default: 'rel'. + cooldown (int): Number of epochs to wait before resuming + normal operation after lr has been reduced. Default: 0. + min_lr (float or list): A scalar or a list of scalars. A + lower bound on the learning rate of all param groups + or each group respectively. Default: 0. + eps (float): Minimal decay applied to lr. If the difference + between new and old lr is smaller than eps, the update is + ignored. Default: 1e-8. + verbose (bool | str): If ``True``, prints a message to stdout for + each update. Default: ``False``. + + .. deprecated:: 2.2 + ``verbose`` is deprecated. Please use ``get_last_lr()`` to access the + learning rate. + + Example: + >>> # xdoctest: +SKIP + >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9) + >>> scheduler = ReduceLROnPlateau(optimizer, 'min') + >>> for epoch in range(10): + >>> train(...) + >>> val_loss = validate(...) + >>> # Note that step should be called after validate() + >>> scheduler.step(val_loss) + """ + + def __init__( + self, + optimizer: Optimizer, + mode: Literal["min", "max"] = "min", + factor=0.1, + patience=10, + threshold=1e-4, + threshold_mode: Literal["rel", "abs"] = "rel", + cooldown=0, + min_lr: Union[List[float], float] = 0, + eps=1e-8, + verbose="deprecated", + ): # noqa: D107 + if factor >= 1.0: + raise ValueError("Factor should be < 1.0.") + self.factor = factor + + # Attach optimizer + if not isinstance(optimizer, Optimizer): + raise TypeError(f"{type(optimizer).__name__} is not an Optimizer") + self.optimizer = optimizer + + if isinstance(min_lr, (list, tuple)): + if len(min_lr) != len(optimizer.param_groups): + raise ValueError( + f"expected {len(optimizer.param_groups)} min_lrs, got {len(min_lr)}" + ) + self.min_lrs = list(min_lr) + else: + self.min_lrs = [min_lr] * len(optimizer.param_groups) + + self.patience = patience + + self.verbose = _check_verbose_deprecated_warning(verbose) + self.cooldown = cooldown + self.cooldown_counter = 0 + self.mode = mode + self.threshold = threshold + self.threshold_mode = threshold_mode + self.best: float + self.num_bad_epochs: int + self.mode_worse: float # the worse value for the chosen mode + self.eps = eps + self.last_epoch = 0 + self._last_lr = [group["lr"] for group in self.optimizer.param_groups] + self._init_is_better( + mode=mode, threshold=threshold, threshold_mode=threshold_mode + ) + self._reset() + + def _reset(self): + """Reset num_bad_epochs counter and cooldown counter.""" + self.best = self.mode_worse + self.cooldown_counter = 0 + self.num_bad_epochs = 0 + + def step(self, metrics: SupportsFloat, epoch=None): # type: ignore[override] + """Perform a step.""" + # convert `metrics` to float, in case it's a zero-dim Tensor + current = float(metrics) + if epoch is None: + epoch = self.last_epoch + 1 + else: + warnings.warn(EPOCH_DEPRECATION_WARNING, UserWarning) + self.last_epoch = epoch + + if self.is_better(current, self.best): + self.best = current + self.num_bad_epochs = 0 + else: + self.num_bad_epochs += 1 + + if self.in_cooldown: + self.cooldown_counter -= 1 + 
self.num_bad_epochs = 0 # ignore any bad epochs in cooldown + + if self.num_bad_epochs > self.patience: + self._reduce_lr(epoch) + self.cooldown_counter = self.cooldown + self.num_bad_epochs = 0 + + self._last_lr = [group["lr"] for group in self.optimizer.param_groups] + + def _reduce_lr(self, epoch): + for i, param_group in enumerate(self.optimizer.param_groups): + old_lr = float(param_group["lr"]) + new_lr = max(old_lr * self.factor, self.min_lrs[i]) + if old_lr - new_lr > self.eps: + param_group["lr"] = new_lr + + @property + def in_cooldown(self): # noqa: D102 + return self.cooldown_counter > 0 + + def is_better(self, a, best): # noqa: D102 + if self.mode == "min" and self.threshold_mode == "rel": + rel_epsilon = 1.0 - self.threshold + return a < best * rel_epsilon + + elif self.mode == "min" and self.threshold_mode == "abs": + return a < best - self.threshold + + elif self.mode == "max" and self.threshold_mode == "rel": + rel_epsilon = self.threshold + 1.0 + return a > best * rel_epsilon + + else: # mode == 'max' and epsilon_mode == 'abs': + return a > best + self.threshold + + def _init_is_better(self, mode, threshold, threshold_mode): + if mode not in {"min", "max"}: + raise ValueError("mode " + mode + " is unknown!") + if threshold_mode not in {"rel", "abs"}: + raise ValueError("threshold mode " + threshold_mode + " is unknown!") + + if mode == "min": + self.mode_worse = inf + else: # mode == 'max': + self.mode_worse = -inf + + self.mode = mode + self.threshold = threshold + self.threshold_mode = threshold_mode + + def state_dict(self): # noqa: D102 + return { + key: value for key, value in self.__dict__.items() if key != "optimizer" + } + + def load_state_dict(self, state_dict): + """Load the scheduler's state.""" + self.__dict__.update(state_dict) + self._init_is_better( + mode=self.mode, threshold=self.threshold, threshold_mode=self.threshold_mode + ) + + +class CyclicLR(LRScheduler): + r"""Sets the learning rate of each parameter group according to cyclical learning rate policy (CLR). + + The policy cycles the learning rate between two boundaries with a constant frequency, + as detailed in the paper `Cyclical Learning Rates for Training Neural Networks`_. + The distance between the two boundaries can be scaled on a per-iteration + or per-cycle basis. + + Cyclical learning rate policy changes the learning rate after every batch. + `step` should be called after a batch has been used for training. + + This class has three built-in policies, as put forth in the paper: + + * "triangular": A basic triangular cycle without amplitude scaling. + * "triangular2": A basic triangular cycle that scales initial amplitude by half each cycle. + * "exp_range": A cycle that scales initial amplitude by :math:`\text{gamma}^{\text{cycle iterations}}` + at each cycle iteration. + + This implementation was adapted from the github repo: `bckenstler/CLR`_ + + Args: + optimizer (Optimizer): Wrapped optimizer. + base_lr (float or list): Initial learning rate which is the + lower boundary in the cycle for each parameter group. + max_lr (float or list): Upper learning rate boundaries in the cycle + for each parameter group. Functionally, + it defines the cycle amplitude (max_lr - base_lr). + The lr at any cycle is the sum of base_lr + and some scaling of the amplitude; therefore + max_lr may not actually be reached depending on + scaling function. + step_size_up (int): Number of training iterations in the + increasing half of a cycle. 
Default: 2000 + step_size_down (int): Number of training iterations in the + decreasing half of a cycle. If step_size_down is None, + it is set to step_size_up. Default: None + mode (str): One of {triangular, triangular2, exp_range}. + Values correspond to policies detailed above. + If scale_fn is not None, this argument is ignored. + Default: 'triangular' + gamma (float): Constant in 'exp_range' scaling function: + gamma**(cycle iterations) + Default: 1.0 + scale_fn (function): Custom scaling policy defined by a single + argument lambda function, where + 0 <= scale_fn(x) <= 1 for all x >= 0. + If specified, then 'mode' is ignored. + Default: None + scale_mode (str): {'cycle', 'iterations'}. + Defines whether scale_fn is evaluated on + cycle number or cycle iterations (training + iterations since start of cycle). + Default: 'cycle' + cycle_momentum (bool): If ``True``, momentum is cycled inversely + to learning rate between 'base_momentum' and 'max_momentum'. + Default: True + base_momentum (float or list): Lower momentum boundaries in the cycle + for each parameter group. Note that momentum is cycled inversely + to learning rate; at the peak of a cycle, momentum is + 'base_momentum' and learning rate is 'max_lr'. + Default: 0.8 + max_momentum (float or list): Upper momentum boundaries in the cycle + for each parameter group. Functionally, + it defines the cycle amplitude (max_momentum - base_momentum). + The momentum at any cycle is the difference of max_momentum + and some scaling of the amplitude; therefore + base_momentum may not actually be reached depending on + scaling function. Note that momentum is cycled inversely + to learning rate; at the start of a cycle, momentum is 'max_momentum' + and learning rate is 'base_lr' + Default: 0.9 + last_epoch (int): The index of the last batch. This parameter is used when + resuming a training job. Since `step()` should be invoked after each + batch instead of after each epoch, this number represents the total + number of *batches* computed, not the total number of epochs computed. + When last_epoch=-1, the schedule is started from the beginning. + Default: -1 + verbose (bool | str): If ``True``, prints a message to stdout for + each update. Default: ``False``. + + .. deprecated:: 2.2 + ``verbose`` is deprecated. Please use ``get_last_lr()`` to access the + learning rate. + + Example: + >>> # xdoctest: +SKIP + >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9) + >>> scheduler = torch.optim.lr_scheduler.CyclicLR(optimizer, base_lr=0.01, max_lr=0.1) + >>> data_loader = torch.utils.data.DataLoader(...) + >>> for epoch in range(10): + >>> for batch in data_loader: + >>> train_batch(...) + >>> scheduler.step() + + + .. _Cyclical Learning Rates for Training Neural Networks: https://arxiv.org/abs/1506.01186 + .. 
_bckenstler/CLR: https://github.com/bckenstler/CLR + """ + + def __init__( + self, + optimizer: Optimizer, + base_lr: Union[float, List[float]], + max_lr: Union[float, List[float]], + step_size_up=2000, + step_size_down: Optional[int] = None, + mode: Literal["triangular", "triangular2", "exp_range"] = "triangular", + gamma=1.0, + scale_fn: Optional[Callable[[float], float]] = None, + scale_mode: Literal["cycle", "iterations"] = "cycle", + cycle_momentum=True, + base_momentum=0.8, + max_momentum=0.9, + last_epoch=-1, + verbose="deprecated", + ): # noqa: D107 + # Attach optimizer + if not isinstance(optimizer, Optimizer): + raise TypeError(f"{type(optimizer).__name__} is not an Optimizer") + self.optimizer = optimizer + + base_lrs = _format_param("base_lr", optimizer, base_lr) + if last_epoch == -1: + for lr, group in zip(base_lrs, optimizer.param_groups): + if isinstance(group["lr"], Tensor): + lr_val = lr.item() if isinstance(lr, Tensor) else lr + group["lr"].fill_(lr_val) + else: + group["lr"] = lr + + self.max_lrs = _format_param("max_lr", optimizer, max_lr) + + step_size_up = float(step_size_up) + step_size_down = ( + float(step_size_down) if step_size_down is not None else step_size_up + ) + self.total_size = step_size_up + step_size_down + self.step_ratio = step_size_up / self.total_size + + if mode not in ["triangular", "triangular2", "exp_range"] and scale_fn is None: + raise ValueError("mode is invalid and scale_fn is None") + + self.mode = mode + self.gamma = gamma + + self._scale_fn_ref: Callable[[float], float] + self._scale_fn_custom = scale_fn + self.scale_mode = scale_mode + self._init_scale_fn() + + self.cycle_momentum = cycle_momentum + if cycle_momentum: + if ( + "momentum" not in optimizer.defaults + and "betas" not in optimizer.defaults + ): + raise ValueError( + "optimizer must support momentum or beta1 with `cycle_momentum` option enabled" + ) + + self.use_beta1 = "betas" in self.optimizer.defaults + self.base_momentums = _format_param( + "base_momentum", optimizer, base_momentum + ) + self.max_momentums = _format_param("max_momentum", optimizer, max_momentum) + if last_epoch == -1: + for m_momentum, b_momentum, group in zip( + self.max_momentums, self.base_momentums, optimizer.param_groups + ): + if self.use_beta1: + group["betas"] = (m_momentum, *group["betas"][1:]) + else: + group["momentum"] = m_momentum + group["max_momentum"] = m_momentum + group["base_momentum"] = b_momentum + + super().__init__(optimizer, last_epoch, verbose) + self.base_lrs = base_lrs + + def _init_scale_fn(self): + if self._scale_fn_custom is not None: + return + if self.mode == "triangular": + self._scale_fn_ref = self._triangular_scale_fn + self.scale_mode = "cycle" + elif self.mode == "triangular2": + self._scale_fn_ref = self._triangular2_scale_fn + self.scale_mode = "cycle" + elif self.mode == "exp_range": + self._scale_fn_ref = partial(self._exp_range_scale_fn, self.gamma) + self.scale_mode = "iterations" + + def scale_fn(self, x) -> float: + """Get the scaling policy.""" + if self._scale_fn_custom is not None: + return self._scale_fn_custom(x) + else: + return self._scale_fn_ref(x) # static method + + @staticmethod + def _triangular_scale_fn(x: float) -> float: + return 1.0 + + @staticmethod + def _triangular2_scale_fn(x: float) -> float: + return 1 / (2.0 ** (x - 1)) + + @staticmethod + def _exp_range_scale_fn(gamma: float, x: float) -> float: + return gamma**x + + def get_lr(self): + """Calculate the learning rate at batch index. 
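+
+        In rough terms, the cycle index is ``floor(1 + last_epoch / total_size)``,
+        the position within the cycle is ``x = 1 + last_epoch / total_size - cycle``,
+        and the resulting scale factor grows from 0 to 1 over the increasing half
+        of the cycle and falls back to 0 over the decreasing half.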
+ + This function treats `self.last_epoch` as the last batch index. + + If `self.cycle_momentum` is ``True``, this function has a side effect of + updating the optimizer's momentum. + """ + _warn_get_lr_called_within_step(self) + + cycle = math.floor(1 + self.last_epoch / self.total_size) + x = 1.0 + self.last_epoch / self.total_size - cycle + if x <= self.step_ratio: + scale_factor = x / self.step_ratio + else: + scale_factor = (x - 1) / (self.step_ratio - 1) + + lrs = [] + for base_lr, max_lr in zip(self.base_lrs, self.max_lrs): + base_height = (max_lr - base_lr) * scale_factor + if self.scale_mode == "cycle": + lr = base_lr + base_height * self.scale_fn(cycle) + else: + lr = base_lr + base_height * self.scale_fn(self.last_epoch) + lrs.append(lr) + + if self.cycle_momentum: + momentums = [] + for base_momentum, max_momentum in zip( + self.base_momentums, self.max_momentums + ): + base_height = (max_momentum - base_momentum) * scale_factor + if self.scale_mode == "cycle": + momentum = max_momentum - base_height * self.scale_fn(cycle) + else: + momentum = max_momentum - base_height * self.scale_fn( + self.last_epoch + ) + momentums.append(momentum) + for param_group, momentum in zip(self.optimizer.param_groups, momentums): + if self.use_beta1: + param_group["betas"] = (momentum, *param_group["betas"][1:]) + else: + param_group["momentum"] = momentum + + return lrs + + def state_dict(self): # noqa: D102 + state = super().state_dict() + # We are dropping the `_scale_fn_ref` attribute because it is a + # `weakref.WeakMethod` and can't be pickled. + state.pop("_scale_fn_ref", None) + fn = state.pop("_scale_fn_custom") + state["_scale_fn_custom"] = None + if fn is not None and not isinstance(fn, types.FunctionType): + # The _scale_fn_custom will only be saved if it is a callable object + # and not if it is a function or lambda. + state["_scale_fn_custom"] = fn.__dict__.copy() + + return state + + def load_state_dict(self, state_dict): + """Load the scheduler's state.""" + fn = state_dict.pop("_scale_fn_custom") + super().load_state_dict(state_dict) + if fn is not None: + self._scale_fn_custom.__dict__.update(fn) + self._init_scale_fn() + + +class CosineAnnealingWarmRestarts(LRScheduler): + r"""Set the learning rate of each parameter group using a cosine annealing schedule. + + The :math:`\eta_{max}` is set to the initial lr, :math:`T_{cur}` + is the number of epochs since the last restart and :math:`T_{i}` is the number + of epochs between two warm restarts in SGDR: + + .. math:: + \eta_t = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})\left(1 + + \cos\left(\frac{T_{cur}}{T_{i}}\pi\right)\right) + + When :math:`T_{cur}=T_{i}`, set :math:`\eta_t = \eta_{min}`. + When :math:`T_{cur}=0` after restart, set :math:`\eta_t=\eta_{max}`. + + It has been proposed in + `SGDR: Stochastic Gradient Descent with Warm Restarts`_. + + Args: + optimizer (Optimizer): Wrapped optimizer. + T_0 (int): Number of iterations until the first restart. + T_mult (int, optional): A factor by which :math:`T_{i}` increases after a restart. Default: 1. + eta_min (float, optional): Minimum learning rate. Default: 0. + last_epoch (int, optional): The index of the last epoch. Default: -1. + verbose (bool | str): If ``True``, prints a message to stdout for + each update. Default: ``False``. + + .. deprecated:: 2.2 + ``verbose`` is deprecated. Please use ``get_last_lr()`` to access the + learning rate. + + .. 
_SGDR\: Stochastic Gradient Descent with Warm Restarts: + https://arxiv.org/abs/1608.03983 + """ + + def __init__( + self, + optimizer: Optimizer, + T_0: int, + T_mult=1, + eta_min=0, + last_epoch=-1, + verbose="deprecated", + ): # noqa: D107 + if T_0 <= 0 or not isinstance(T_0, int): + raise ValueError(f"Expected positive integer T_0, but got {T_0}") + if T_mult < 1 or not isinstance(T_mult, int): + raise ValueError(f"Expected integer T_mult >= 1, but got {T_mult}") + if not isinstance(eta_min, (float, int)): + raise ValueError( + f"Expected float or int eta_min, but got {eta_min} of type {type(eta_min)}" + ) + self.T_0 = T_0 + self.T_i = T_0 + self.T_mult = T_mult + self.eta_min = eta_min + self.T_cur = last_epoch + super().__init__(optimizer, last_epoch, verbose) + + def get_lr(self): + """Compute the initial learning rate.""" + _warn_get_lr_called_within_step(self) + + return [ + self.eta_min + + (base_lr - self.eta_min) + * (1 + math.cos(math.pi * self.T_cur / self.T_i)) + / 2 + for base_lr in self.base_lrs + ] + + def step(self, epoch=None): + """Step could be called after every batch update. + + Example: + >>> # xdoctest: +SKIP("Undefined vars") + >>> scheduler = CosineAnnealingWarmRestarts(optimizer, T_0, T_mult) + >>> iters = len(dataloader) + >>> for epoch in range(20): + >>> for i, sample in enumerate(dataloader): + >>> inputs, labels = sample['inputs'], sample['labels'] + >>> optimizer.zero_grad() + >>> outputs = net(inputs) + >>> loss = criterion(outputs, labels) + >>> loss.backward() + >>> optimizer.step() + >>> scheduler.step(epoch + i / iters) + + This function can be called in an interleaved way. + + Example: + >>> # xdoctest: +SKIP("Undefined vars") + >>> scheduler = CosineAnnealingWarmRestarts(optimizer, T_0, T_mult) + >>> for epoch in range(20): + >>> scheduler.step() + >>> scheduler.step(26) + >>> scheduler.step() # scheduler.step(27), instead of scheduler(20) + """ + if epoch is None and self.last_epoch < 0: + epoch = 0 + + if epoch is None: + epoch = self.last_epoch + 1 + self.T_cur = self.T_cur + 1 + if self.T_cur >= self.T_i: + self.T_cur = self.T_cur - self.T_i + self.T_i = self.T_i * self.T_mult + else: + if epoch < 0: + raise ValueError(f"Expected non-negative epoch, but got {epoch}") + if epoch >= self.T_0: + if self.T_mult == 1: + self.T_cur = epoch % self.T_0 + else: + n = int( + math.log( + (epoch / self.T_0 * (self.T_mult - 1) + 1), self.T_mult + ) + ) + self.T_cur = epoch - self.T_0 * (self.T_mult**n - 1) / ( + self.T_mult - 1 + ) + self.T_i = self.T_0 * self.T_mult ** (n) + else: + self.T_i = self.T_0 + self.T_cur = epoch + self.last_epoch = math.floor(epoch) + + with _enable_get_lr_call(self): + for i, data in enumerate(zip(self.optimizer.param_groups, self.get_lr())): + param_group, lr = data + param_group["lr"] = lr + + self._last_lr = [group["lr"] for group in self.optimizer.param_groups] + + +class _SchedulePhase(TypedDict): + end_step: float + start_lr: str + end_lr: str + start_momentum: str + end_momentum: str + + +class OneCycleLR(LRScheduler): + r"""Sets the learning rate of each parameter group according to the 1cycle learning rate policy. + + The 1cycle policy anneals the learning rate from an initial learning rate to some maximum + learning rate and then from that maximum learning rate to some minimum learning rate much + lower than the initial learning rate. + This policy was initially described in the paper `Super-Convergence: + Very Fast Training of Neural Networks Using Large Learning Rates`_. 
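+
+    As a concrete, purely illustrative instance of the schedule's shape: with the
+    defaults ``div_factor=25`` and ``final_div_factor=1e4``, a ``max_lr`` of 0.01
+    starts from ``initial_lr = 0.01 / 25 = 4e-4`` and is finally annealed down to
+    ``min_lr = 4e-4 / 1e4 = 4e-8``.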
+ + The 1cycle learning rate policy changes the learning rate after every batch. + `step` should be called after a batch has been used for training. + + This scheduler is not chainable. + + Note also that the total number of steps in the cycle can be determined in one + of two ways (listed in order of precedence): + + #. A value for total_steps is explicitly provided. + #. A number of epochs (epochs) and a number of steps per epoch + (steps_per_epoch) are provided. + In this case, the number of total steps is inferred by + total_steps = epochs * steps_per_epoch + + You must either provide a value for total_steps or provide a value for both + epochs and steps_per_epoch. + + The default behaviour of this scheduler follows the fastai implementation of 1cycle, which + claims that "unpublished work has shown even better results by using only two phases". To + mimic the behaviour of the original paper instead, set ``three_phase=True``. + + Args: + optimizer (Optimizer): Wrapped optimizer. + max_lr (float or list): Upper learning rate boundaries in the cycle + for each parameter group. + total_steps (int): The total number of steps in the cycle. Note that + if a value is not provided here, then it must be inferred by providing + a value for epochs and steps_per_epoch. + Default: None + epochs (int): The number of epochs to train for. This is used along + with steps_per_epoch in order to infer the total number of steps in the cycle + if a value for total_steps is not provided. + Default: None + steps_per_epoch (int): The number of steps per epoch to train for. This is + used along with epochs in order to infer the total number of steps in the + cycle if a value for total_steps is not provided. + Default: None + pct_start (float): The percentage of the cycle (in number of steps) spent + increasing the learning rate. + Default: 0.3 + anneal_strategy (str): {'cos', 'linear'} + Specifies the annealing strategy: "cos" for cosine annealing, "linear" for + linear annealing. + Default: 'cos' + cycle_momentum (bool): If ``True``, momentum is cycled inversely + to learning rate between 'base_momentum' and 'max_momentum'. + Default: True + base_momentum (float or list): Lower momentum boundaries in the cycle + for each parameter group. Note that momentum is cycled inversely + to learning rate; at the peak of a cycle, momentum is + 'base_momentum' and learning rate is 'max_lr'. + Default: 0.85 + max_momentum (float or list): Upper momentum boundaries in the cycle + for each parameter group. Functionally, + it defines the cycle amplitude (max_momentum - base_momentum). + Note that momentum is cycled inversely + to learning rate; at the start of a cycle, momentum is 'max_momentum' + and learning rate is 'base_lr' + Default: 0.95 + div_factor (float): Determines the initial learning rate via + initial_lr = max_lr/div_factor + Default: 25 + final_div_factor (float): Determines the minimum learning rate via + min_lr = initial_lr/final_div_factor + Default: 1e4 + three_phase (bool): If ``True``, use a third phase of the schedule to annihilate the + learning rate according to 'final_div_factor' instead of modifying the second + phase (the first two phases will be symmetrical about the step indicated by + 'pct_start'). + last_epoch (int): The index of the last batch. This parameter is used when + resuming a training job. Since `step()` should be invoked after each + batch instead of after each epoch, this number represents the total + number of *batches* computed, not the total number of epochs computed. 
+ When last_epoch=-1, the schedule is started from the beginning. + Default: -1 + verbose (bool | str): If ``True``, prints a message to stdout for + each update. Default: ``False``. + + .. deprecated:: 2.2 + ``verbose`` is deprecated. Please use ``get_last_lr()`` to access the + learning rate. + + Example: + >>> # xdoctest: +SKIP + >>> data_loader = torch.utils.data.DataLoader(...) + >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9) + >>> scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=0.01, steps_per_epoch=len(data_loader), epochs=10) + >>> for epoch in range(10): + >>> for batch in data_loader: + >>> train_batch(...) + >>> optimizer.step() + >>> scheduler.step() + + + .. _Super-Convergence\: Very Fast Training of Neural Networks Using Large Learning Rates: + https://arxiv.org/abs/1708.07120 + """ + + def __init__( + self, + optimizer: Optimizer, + max_lr: Union[float, List[float]], + total_steps: Optional[int] = None, + epochs: Optional[int] = None, + steps_per_epoch: Optional[int] = None, + pct_start=0.3, + anneal_strategy: Literal["cos", "linear"] = "cos", + cycle_momentum=True, + base_momentum: Union[float, List[float]] = 0.85, + max_momentum: Union[float, List[float]] = 0.95, + div_factor=25.0, + final_div_factor=1e4, + three_phase=False, + last_epoch=-1, + verbose="deprecated", + ): # noqa: D107 + # Validate optimizer + if not isinstance(optimizer, Optimizer): + raise TypeError(f"{type(optimizer).__name__} is not an Optimizer") + self.optimizer = optimizer + + # Validate total_steps + if total_steps is not None: + if total_steps <= 0 or not isinstance(total_steps, int): + raise ValueError( + f"Expected positive integer total_steps, but got {total_steps}" + ) + self.total_steps = total_steps + elif epochs is not None and steps_per_epoch is not None: + if not isinstance(epochs, int) or epochs <= 0: + raise ValueError(f"Expected positive integer epochs, but got {epochs}") + if not isinstance(steps_per_epoch, int) or steps_per_epoch <= 0: + raise ValueError( + f"Expected positive integer steps_per_epoch, but got {steps_per_epoch}" + ) + self.total_steps = epochs * steps_per_epoch + else: + raise ValueError( + "You must define either total_steps OR (epochs AND steps_per_epoch)" + ) + + self._schedule_phases: List[_SchedulePhase] + if three_phase: + self._schedule_phases = [ + { + "end_step": float(pct_start * self.total_steps) - 1, + "start_lr": "initial_lr", + "end_lr": "max_lr", + "start_momentum": "max_momentum", + "end_momentum": "base_momentum", + }, + { + "end_step": float(2 * pct_start * self.total_steps) - 2, + "start_lr": "max_lr", + "end_lr": "initial_lr", + "start_momentum": "base_momentum", + "end_momentum": "max_momentum", + }, + { + "end_step": self.total_steps - 1, + "start_lr": "initial_lr", + "end_lr": "min_lr", + "start_momentum": "max_momentum", + "end_momentum": "max_momentum", + }, + ] + else: + self._schedule_phases = [ + { + "end_step": float(pct_start * self.total_steps) - 1, + "start_lr": "initial_lr", + "end_lr": "max_lr", + "start_momentum": "max_momentum", + "end_momentum": "base_momentum", + }, + { + "end_step": self.total_steps - 1, + "start_lr": "max_lr", + "end_lr": "min_lr", + "start_momentum": "base_momentum", + "end_momentum": "max_momentum", + }, + ] + + # Validate pct_start + if pct_start < 0 or pct_start > 1 or not isinstance(pct_start, float): + raise ValueError( + f"Expected float between 0 and 1 pct_start, but got {pct_start}" + ) + + # Validate anneal_strategy + if anneal_strategy not in ["cos", 
"linear"]: + raise ValueError( + f"anneal_strategy must be one of 'cos' or 'linear', instead got {anneal_strategy}" + ) + else: + self._anneal_func_type = anneal_strategy + + # Initialize learning rate variables + max_lrs = _format_param("max_lr", self.optimizer, max_lr) + if last_epoch == -1: + for idx, group in enumerate(self.optimizer.param_groups): + group["initial_lr"] = max_lrs[idx] / div_factor + group["max_lr"] = max_lrs[idx] + group["min_lr"] = group["initial_lr"] / final_div_factor + + # Initialize momentum variables + self.cycle_momentum = cycle_momentum + if self.cycle_momentum: + if ( + "momentum" not in self.optimizer.defaults + and "betas" not in self.optimizer.defaults + ): + raise ValueError( + "optimizer must support momentum or beta1 with `cycle_momentum` option enabled" + ) + self.use_beta1 = "betas" in self.optimizer.defaults + max_momentums = _format_param("max_momentum", optimizer, max_momentum) + base_momentums = _format_param("base_momentum", optimizer, base_momentum) + if last_epoch == -1: + for m_momentum, b_momentum, group in zip( + max_momentums, base_momentums, optimizer.param_groups + ): + if self.use_beta1: + group["betas"] = (m_momentum, *group["betas"][1:]) + else: + group["momentum"] = m_momentum + group["max_momentum"] = m_momentum + group["base_momentum"] = b_momentum + + super().__init__(optimizer, last_epoch, verbose) + + def _anneal_func(self, *args, **kwargs): + if hasattr(self, "_anneal_func_type"): + if self._anneal_func_type == "cos": + return self._annealing_cos(*args, **kwargs) + elif self._anneal_func_type == "linear": + return self._annealing_linear(*args, **kwargs) + else: + raise ValueError(f"Unknown _anneal_func_type: {self._anneal_func_type}") + else: + # For BC + return self.anneal_func(*args, **kwargs) # type: ignore[attr-defined] + + @staticmethod + def _annealing_cos(start, end, pct): + """Cosine anneal from `start` to `end` as pct goes from 0.0 to 1.0.""" + cos_out = math.cos(math.pi * pct) + 1 + return end + (start - end) / 2.0 * cos_out + + @staticmethod + def _annealing_linear(start, end, pct): + """Linearly anneal from `start` to `end` as pct goes from 0.0 to 1.0.""" + return (end - start) * pct + start + + def get_lr(self): + """Compute the learning rate of each parameter group.""" + _warn_get_lr_called_within_step(self) + + lrs = [] + step_num = self.last_epoch + + if step_num > self.total_steps: + raise ValueError( + f"Tried to step {step_num} times. 
The specified number of total steps is {self.total_steps}" # noqa: UP032 + ) + + for group in self.optimizer.param_groups: + start_step = 0.0 + for i, phase in enumerate(self._schedule_phases): + end_step = phase["end_step"] + if step_num <= end_step or i == len(self._schedule_phases) - 1: + pct = (step_num - start_step) / (end_step - start_step) + computed_lr = self._anneal_func( + group[phase["start_lr"]], group[phase["end_lr"]], pct + ) + if self.cycle_momentum: + computed_momentum = self._anneal_func( + group[phase["start_momentum"]], + group[phase["end_momentum"]], + pct, + ) + break + start_step = phase["end_step"] + + lrs.append(computed_lr) # type: ignore[possibly-undefined] + if self.cycle_momentum: + if self.use_beta1: + group["betas"] = (computed_momentum, *group["betas"][1:]) # type: ignore[possibly-undefined] + else: + group[ + "momentum" + ] = computed_momentum # type: ignore[possibly-undefined] + + return lrs diff --git a/engine/cr_boosters/nadam.py b/engine/cr_boosters/nadam.py new file mode 100644 index 0000000..0e5b6b9 --- /dev/null +++ b/engine/cr_boosters/nadam.py @@ -0,0 +1,639 @@ +# mypy: allow-untyped-defs +r"""Implementation for the NAdam algorithm.""" +from typing import cast, List, Optional, Tuple, Union + +import torch +from torch import Tensor +from .optimizer import ( + _capturable_doc, + _default_to_fused_or_foreach, + _differentiable_doc, + _disable_dynamo_if_unsupported, + _dispatch_sqrt, + _foreach_doc, + _get_capturable_supported_devices, + _get_scalar_dtype, + _get_value, + _maximize_doc, + _stack_if_compiling, + _use_grad_for_differentiable, + _view_as_real, + Optimizer, + ParamsT, +) + +__all__ = ["NAdam", "nadam"] + + +class NAdam(Optimizer): # noqa: D101 + def __init__( + self, + params: ParamsT, + lr: float = 2e-3, + betas: Tuple[float, float] = (0.9, 0.999), + eps: float = 1e-8, + weight_decay: float = 0, + momentum_decay: float = 4e-3, + decoupled_weight_decay: bool = False, + *, + foreach: Optional[bool] = None, + maximize: bool = False, + capturable: bool = False, + differentiable: bool = False, + ): # noqa: D107 + if not 0.0 <= lr: + raise ValueError(f"Invalid learning rate: {lr}") + if not 0.0 <= eps: + raise ValueError(f"Invalid epsilon value: {eps}") + if not 0.0 <= betas[0] < 1.0: + raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}") + if not 0.0 <= betas[1] < 1.0: + raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}") + if not 0.0 <= weight_decay: + raise ValueError(f"Invalid weight_decay value: {weight_decay}") + if not 0.0 <= momentum_decay: + raise ValueError(f"Invalid momentum_decay value: {momentum_decay}") + defaults = dict( + lr=lr, + betas=betas, + eps=eps, + weight_decay=weight_decay, + momentum_decay=momentum_decay, + decoupled_weight_decay=decoupled_weight_decay, + maximize=maximize, + foreach=foreach, + capturable=capturable, + differentiable=differentiable, + ) + super().__init__(params, defaults) + + def __setstate__(self, state): # noqa: D105 + super().__setstate__(state) + for group in self.param_groups: + group.setdefault("maximize", False) + group.setdefault("foreach", None) + group.setdefault("capturable", False) + group.setdefault("differentiable", False) + group.setdefault("decoupled_weight_decay", False) + for p in group["params"]: + p_state = self.state.get(p, []) + if len(p_state) != 0: + if not torch.is_tensor(p_state["step"]): + step_val = float(p_state["step"]) + p_state["step"] = ( + torch.tensor( + step_val, dtype=_get_scalar_dtype(), device=p.device + ) + if group["capturable"] + 
else torch.tensor(step_val, dtype=_get_scalar_dtype()) + ) + if not torch.is_tensor(p_state["mu_product"]): + mu_prod_val = p_state["mu_product"] + p_state["mu_product"] = ( + torch.tensor( + mu_prod_val, dtype=_get_scalar_dtype(), device=p.device + ) + if group["capturable"] + else torch.tensor(mu_prod_val, dtype=_get_scalar_dtype()) + ) + + def _init_group( + self, + group, + params_with_grad, + grads, + exp_avgs, + exp_avg_sqs, + mu_products, + state_steps, + ): + has_complex = False + for p in group["params"]: + if p.grad is not None: + has_complex |= torch.is_complex(p) + params_with_grad.append(p) + if p.grad.is_sparse: + raise RuntimeError("NAdam does not support sparse gradients") + grads.append(p.grad) + + state = self.state[p] + # Lazy state initialization + if len(state) == 0: + # note(crcrpar): [special device hosting for step] + # Deliberately host `step` and `mu_product` on CPU if capturable is False. + # This is because kernel launches are costly on CUDA and XLA. + state["step"] = ( + torch.zeros((), dtype=_get_scalar_dtype(), device=p.device) + if group["capturable"] + else torch.tensor(0.0, dtype=_get_scalar_dtype()) + ) + state["mu_product"] = ( + torch.ones((), dtype=_get_scalar_dtype(), device=p.device) + if group["capturable"] + else torch.tensor(1.0, dtype=_get_scalar_dtype()) + ) + # Exponential moving average of gradient values + state["exp_avg"] = torch.zeros_like( + p, memory_format=torch.preserve_format + ) + # Exponential moving average of squared gradient values + state["exp_avg_sq"] = torch.zeros_like( + p, memory_format=torch.preserve_format + ) + + exp_avgs.append(state["exp_avg"]) + exp_avg_sqs.append(state["exp_avg_sq"]) + mu_products.append(state["mu_product"]) + state_steps.append(state["step"]) + return has_complex + + @_use_grad_for_differentiable + def step(self, closure=None): + """Perform a single optimization step. + + Args: + closure (Callable, optional): A closure that reevaluates the model + and returns the loss. + """ + self._cuda_graph_capture_health_check() + + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + params_with_grad: List[Tensor] = [] + grads: List[Tensor] = [] + exp_avgs: List[Tensor] = [] + exp_avg_sqs: List[Tensor] = [] + mu_products: List[Tensor] = [] + state_steps: List[Tensor] = [] + beta1, beta2 = cast(Tuple[float, float], group["betas"]) + + has_complex = self._init_group( + group, + params_with_grad, + grads, + exp_avgs, + exp_avg_sqs, + mu_products, + state_steps, + ) + + nadam( + params_with_grad, + grads, + exp_avgs, + exp_avg_sqs, + mu_products, + state_steps, + beta1=beta1, + beta2=beta2, + lr=group["lr"], + weight_decay=group["weight_decay"], + momentum_decay=group["momentum_decay"], + eps=group["eps"], + maximize=group["maximize"], + decoupled_weight_decay=group["decoupled_weight_decay"], + foreach=group["foreach"], + capturable=group["capturable"], + differentiable=group["differentiable"], + has_complex=has_complex, + ) + + return loss + + +NAdam.__doc__ = ( + r"""Implements NAdam algorithm. + + .. 
math:: + \begin{aligned} + &\rule{110mm}{0.4pt} \\ + &\textbf{input} : \gamma_t \text{ (lr)}, \: \beta_1,\beta_2 \text{ (betas)}, + \: \theta_0 \text{ (params)}, \: f(\theta) \text{ (objective)} \\ + &\hspace{13mm} \: \lambda \text{ (weight decay)}, \:\psi \text{ (momentum decay)} \\ + &\hspace{13mm} \: \textit{decoupled\_weight\_decay}, \:\textit{maximize} \\ + &\textbf{initialize} : m_0 \leftarrow 0 \text{ ( first moment)}, + v_0 \leftarrow 0 \text{ ( second moment)} \\[-1.ex] + &\rule{110mm}{0.4pt} \\ + &\textbf{for} \: t=1 \: \textbf{to} \: \ldots \: \textbf{do} \\ + &\hspace{5mm}\textbf{if} \: \textit{maximize}: \\ + &\hspace{10mm}g_t \leftarrow -\nabla_{\theta} f_t (\theta_{t-1}) \\ + &\hspace{5mm}\textbf{else} \\ + &\hspace{10mm}g_t \leftarrow \nabla_{\theta} f_t (\theta_{t-1}) \\ + &\hspace{5mm} \theta_t \leftarrow \theta_{t-1} \\ + &\hspace{5mm} \textbf{if} \: \lambda \neq 0 \\ + &\hspace{10mm}\textbf{if} \: \textit{decoupled\_weight\_decay} \\ + &\hspace{15mm} \theta_t \leftarrow \theta_{t-1} - \gamma \lambda \theta_{t-1} \\ + &\hspace{10mm}\textbf{else} \\ + &\hspace{15mm} g_t \leftarrow g_t + \lambda \theta_{t-1} \\ + &\hspace{5mm} \mu_t \leftarrow \beta_1 \big(1 - \frac{1}{2} 0.96^{t \psi} \big) \\ + &\hspace{5mm} \mu_{t+1} \leftarrow \beta_1 \big(1 - \frac{1}{2} 0.96^{(t+1)\psi}\big)\\ + &\hspace{5mm}m_t \leftarrow \beta_1 m_{t-1} + (1 - \beta_1) g_t \\ + &\hspace{5mm}v_t \leftarrow \beta_2 v_{t-1} + (1-\beta_2) g^2_t \\ + &\hspace{5mm}\widehat{m_t} \leftarrow \mu_{t+1} m_t/(1-\prod_{i=1}^{t+1}\mu_i)\\[-1.ex] + & \hspace{11mm} + (1-\mu_t) g_t /(1-\prod_{i=1}^{t} \mu_{i}) \\ + &\hspace{5mm}\widehat{v_t} \leftarrow v_t/\big(1-\beta_2^t \big) \\ + &\hspace{5mm}\theta_t \leftarrow \theta_t - \gamma \widehat{m_t}/ + \big(\sqrt{\widehat{v_t}} + \epsilon \big) \\ + &\rule{110mm}{0.4pt} \\[-1.ex] + &\bf{return} \: \theta_t \\[-1.ex] + &\rule{110mm}{0.4pt} \\[-1.ex] + \end{aligned} + + For further details regarding the algorithm we refer to `Incorporating Nesterov Momentum into Adam`_. + """ + + rf""" + Args: + params (iterable): iterable of parameters to optimize or dicts defining + parameter groups + lr (float, optional): learning rate (default: 2e-3) + betas (Tuple[float, float], optional): coefficients used for computing + running averages of gradient and its square (default: (0.9, 0.999)) + eps (float, optional): term added to the denominator to improve + numerical stability (default: 1e-8) + weight_decay (float, optional): weight decay (L2 penalty) (default: 0) + momentum_decay (float, optional): momentum momentum_decay (default: 4e-3) + decoupled_weight_decay (bool, optional): whether to use decoupled weight + decay as in AdamW to obtain NAdamW (default: False) + {_foreach_doc} + {_maximize_doc} + {_capturable_doc} + {_differentiable_doc} + + .. _Incorporating Nesterov Momentum into Adam: + https://openreview.net/forum?id=OM0jvwB8jIp57ZJjtNEZ + .. 
_Decoupled Weight Decay Regularization: + https://arxiv.org/abs/1711.05101 + + """ +) + + +def _single_tensor_nadam( + params: List[Tensor], + grads: List[Tensor], + exp_avgs: List[Tensor], + exp_avg_sqs: List[Tensor], + mu_products: List[Tensor], + state_steps: List[Tensor], + *, + beta1: float, + beta2: float, + lr: float, + weight_decay: float, + momentum_decay: float, + eps: float, + decoupled_weight_decay: bool, + maximize: bool, + capturable: bool, + differentiable: bool, + has_complex: bool, +): + for i, param in enumerate(params): + grad = grads[i] if not maximize else -grads[i] + exp_avg = exp_avgs[i] + exp_avg_sq = exp_avg_sqs[i] + mu_product = mu_products[i] + step_t = state_steps[i] + + if torch.is_complex(param): + param = torch.view_as_real(param) + grad = torch.view_as_real(grad) + exp_avg = torch.view_as_real(exp_avg) + exp_avg_sq = torch.view_as_real(exp_avg_sq) + + # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable] + if not torch._utils.is_compiling() and capturable: + capturable_supported_devices = _get_capturable_supported_devices() + assert ( + param.device.type == mu_product.device.type == step_t.device.type + and param.device.type in capturable_supported_devices + ), ( + f"If capturable=True, params, mu_products and state_steps must be " + f"on supported devices: {capturable_supported_devices}." + ) + + # update step + step_t += 1 + + if capturable: + step = step_t + else: + step = _get_value(step_t) + + bias_correction2 = 1 - beta2**step + + if weight_decay != 0: + if decoupled_weight_decay: + # Perform stepweight decay + param.mul_(1 - lr * weight_decay) + else: + grad = grad.add(param, alpha=weight_decay) + + # calculate the momentum cache \mu^{t} and \mu^{t+1} + mu = beta1 * (1.0 - 0.5 * (0.96 ** (step * momentum_decay))) + mu_next = beta1 * (1.0 - 0.5 * (0.96 ** ((step + 1) * momentum_decay))) + + # update mu_product + mu_product *= mu + + # decay the first and second moment running average coefficient + exp_avg.lerp_(grad, 1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) + denom = exp_avg_sq.div(bias_correction2).sqrt() + + if differentiable or capturable: + denom = denom.add(eps) + # Make autograd track the operations + # by updating the grad and exp_avg directly and not using the + # scalar "value" argument of addcdiv. 
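+            # The two addcdiv_ calls below realise the update
+            #   param <- param - lr * (1 - mu) * grad / ((1 - mu_product) * denom)
+            #   param <- param - lr * mu_next * exp_avg / ((1 - mu_product * mu_next) * denom)
+            # by folding the step sizes into grad and exp_avg as tensor ops,
+            # keeping the computation traceable for autograd / CUDA graphs.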
+ mu_product_next = mu_product * mu_next + grad = grad * (-lr * (1.0 - mu) / (1.0 - mu_product)) + exp_avg = exp_avg * (-lr * mu_next / (1.0 - mu_product_next)) + param.addcdiv_(grad, denom) + param.addcdiv_(exp_avg, denom) + else: + mu_product_next = _get_value(mu_product) * mu_next + denom.add_(eps) + param.addcdiv_( + grad, denom, value=(-lr * (1.0 - mu) / (1.0 - _get_value(mu_product))) + ) + param.addcdiv_( + exp_avg, denom, value=(-lr * mu_next) / (1.0 - mu_product_next) + ) + + +def _multi_tensor_nadam( + params: List[Tensor], + grads: List[Tensor], + exp_avgs: List[Tensor], + exp_avg_sqs: List[Tensor], + mu_products: List[Tensor], + state_steps: List[Tensor], + *, + beta1: float, + beta2: float, + lr: float, + weight_decay: float, + momentum_decay: float, + eps: float, + decoupled_weight_decay: bool, + maximize: bool, + capturable: bool, + differentiable: bool, + has_complex: bool, +): + if len(params) == 0: + return + + assert not differentiable, "_foreach ops don't support autograd" + + # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable] + if not torch._utils.is_compiling() and capturable: + capturable_supported_devices = _get_capturable_supported_devices( + supports_xla=False + ) + assert all( + p.device.type == mp.device.type == step.device.type + and p.device.type in capturable_supported_devices + for p, mp, step in zip(params, mu_products, state_steps) + ), f"If capturable=True, params, mu_products, and state_steps must be on supported devices: {capturable_supported_devices}." + + grouped_tensors = Optimizer._group_tensors_by_device_and_dtype( + [params, grads, exp_avgs, exp_avg_sqs, mu_products, state_steps] + ) + for ( + grouped_params, + grouped_grads, + grouped_exp_avgs, + grouped_exp_avg_sqs, + grouped_mu_products, + grouped_state_steps, + ), _ in grouped_tensors.values(): + # handle complex + if has_complex: + _view_as_real( + grouped_params, grouped_grads, grouped_exp_avgs, grouped_exp_avg_sqs + ) + + if maximize: + grouped_grads = torch._foreach_neg(grouped_grads) # type: ignore[assignment] + + # Update steps + # If steps are on CPU, foreach will fall back to the slow path, which is a for-loop calling t.add(1) over + # and over. 1 will then be wrapped into a Tensor over and over again, which is slower than if we just + # wrapped it once now. The alpha is required to assure we go to the right overload. 
+ if grouped_state_steps[0].is_cpu: + torch._foreach_add_( + grouped_state_steps, torch.tensor(1.0, device="cpu"), alpha=1.0 + ) + else: + torch._foreach_add_(grouped_state_steps, 1) + + if weight_decay != 0: + if decoupled_weight_decay: + # Perform stepweight decay + torch._foreach_mul_(grouped_params, 1 - lr * weight_decay) + else: + # Re-use the intermediate memory (grouped_grads) already allocated for maximize + if maximize: + torch._foreach_add_( + grouped_grads, grouped_params, alpha=weight_decay + ) + else: + grouped_grads = torch._foreach_add( # type: ignore[assignment] + grouped_grads, grouped_params, alpha=weight_decay + ) + + # Decay the first and second moment running average coefficient + torch._foreach_lerp_(grouped_exp_avgs, grouped_grads, 1 - beta1) + + torch._foreach_mul_(grouped_exp_avg_sqs, beta2) + torch._foreach_addcmul_( + grouped_exp_avg_sqs, grouped_grads, grouped_grads, 1 - beta2 + ) + + exp_avg_sq_sqrt = torch._foreach_sqrt(grouped_exp_avg_sqs) + + bias_correction_sqrt: Union[Tuple[Tensor, ...], List[Tensor]] + mus: Union[Tuple[Tensor, ...], List[Tensor]] + mu_nexts: Union[Tuple[Tensor, ...], List[Tensor]] + if capturable: + # mus will be beta1 * (1 - 0.5 * 0.96 ** (step * momentum_decay)) + exponent = torch._foreach_mul(grouped_state_steps, momentum_decay) + mus = torch._foreach_pow(0.96, exponent) + torch._foreach_mul_(mus, -0.5) + torch._foreach_add_(mus, 1.0) + torch._foreach_mul_(mus, beta1) + + # mu_nexts will be beta1 * (1 - 0.5 * 0.96 ** ((step + 1) * momentum_decay)) + torch._foreach_add_(exponent, momentum_decay) + mu_nexts = torch._foreach_pow(0.96, exponent) + torch._foreach_mul_(mu_nexts, -0.5) + torch._foreach_add_(mu_nexts, 1.0) + torch._foreach_mul_(mu_nexts, beta1) + + # save peak memory as we don't need exponent anymore + del exponent + + bias_correction_sqrt = torch._foreach_pow(beta2, grouped_state_steps) + # foreach_sub doesn't allow a scalar as the first arg + torch._foreach_sub_(bias_correction_sqrt, 1.0) + torch._foreach_neg_(bias_correction_sqrt) + torch._foreach_sqrt_(bias_correction_sqrt) + else: + bias_correction_sqrt = [ + _dispatch_sqrt(1 - beta2 ** _get_value(step)) + for step in grouped_state_steps + ] + mus = [ + beta1 * (1.0 - 0.5 * (0.96 ** (_get_value(step) * momentum_decay))) + for step in grouped_state_steps + ] + mu_nexts = [ + beta1 + * (1.0 - 0.5 * (0.96 ** ((_get_value(step) + 1) * momentum_decay))) + for step in grouped_state_steps + ] + + # update mu_products + torch._foreach_mul_(grouped_mu_products, mus) + + torch._foreach_div_(exp_avg_sq_sqrt, bias_correction_sqrt) + torch._foreach_add_(exp_avg_sq_sqrt, eps) + + # explicitly delete bias_correction refs to save memory + del bias_correction_sqrt + + if capturable: + # Build up the step_size multiplier for grad, reusing mus' memory + torch._foreach_sub_(mus, 1.0) + torch._foreach_mul_(mus, lr) + # foreach_sub doesn't allow a scalar as the first arg + denom = torch._foreach_sub(grouped_mu_products, 1.0) + torch._foreach_neg_(denom) + torch._foreach_div_(mus, denom) + # - lr * (1 - mu) / (1 - mu_product) + step_size_grads = mus + # explicitly delete denom to save memory + del denom + + # Build up the step_size multiplier for exp_avg, reusing mu_nexts' memory + denom = torch._foreach_mul(grouped_mu_products, mu_nexts) + torch._foreach_mul_(mu_nexts, lr) + # foreach_sub doesn't allow a scalar as the first arg, but it's okay because + # we need a negative here anyway + torch._foreach_sub_(denom, 1.0) + torch._foreach_div_(mu_nexts, denom) + # - lr * mu_next / (1 - mu_product 
* mu_next) + step_size_expavg = mu_nexts + # explicitly delete denom to save memory + del denom + + # we cannot inplace into step_size_grads cuz it is a list of ScalarTensors + # and mul'ing with grouped_grads will result in a list of bigger Tensors + numerator = torch._foreach_mul(step_size_grads, grouped_grads) + torch._foreach_addcmul_(numerator, step_size_expavg, grouped_exp_avgs) + + # finally, update params + torch._foreach_addcdiv_(grouped_params, numerator, exp_avg_sq_sqrt) + else: + step_size_grads = _stack_if_compiling( + [ + (_get_value(lr) * (1.0 - mu) / (1.0 - _get_value(mu_product))) * -1 + for mu_product, mu in zip(grouped_mu_products, mus) + ] + ) + step_size_expavg = _stack_if_compiling( + [ + ( + _get_value(lr) + * mu_next + / (1.0 - _get_value(mu_product) * mu_next) + ) + * -1 + for mu_product, mu_next in zip(grouped_mu_products, mu_nexts) + ] + ) + + torch._foreach_addcdiv_( + grouped_params, grouped_grads, exp_avg_sq_sqrt, step_size_grads # type: ignore[arg-type] + ) + torch._foreach_addcdiv_( + grouped_params, grouped_exp_avgs, exp_avg_sq_sqrt, step_size_expavg # type: ignore[arg-type] + ) + + +@_disable_dynamo_if_unsupported(single_tensor_fn=_single_tensor_nadam) +def nadam( + params: List[Tensor], + grads: List[Tensor], + exp_avgs: List[Tensor], + exp_avg_sqs: List[Tensor], + mu_products: List[Tensor], + state_steps: List[Tensor], + # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627 + # setting this as kwarg for now as functional API is compiled by torch/distributed/optim + decoupled_weight_decay: bool = False, + foreach: Optional[bool] = None, + capturable: bool = False, + differentiable: bool = False, + has_complex: bool = False, + maximize: bool = False, + *, + beta1: float, + beta2: float, + lr: float, + weight_decay: float, + momentum_decay: float, + eps: float, +): + r"""Functional API that performs NAdam algorithm computation. + + See :class:`~torch.optim.NAdam` for details. 
+ """ + if not all(isinstance(t, torch.Tensor) for t in state_steps): + raise RuntimeError( + "API has changed, `state_steps` argument must contain a list of singleton tensors" + ) + + if not all(isinstance(t, torch.Tensor) for t in mu_products): + raise RuntimeError( + "API has changed, `mu_products` argument must contain a list of singleton tensors" + ) + + if foreach is None: + _, foreach = _default_to_fused_or_foreach( + params, differentiable, use_fused=False + ) + + if foreach and torch.jit.is_scripting(): + raise RuntimeError("torch.jit.script not supported with foreach optimizers") + + if foreach and not torch.jit.is_scripting(): + func = _multi_tensor_nadam + else: + func = _single_tensor_nadam + + func( + params, + grads, + exp_avgs, + exp_avg_sqs, + mu_products, + state_steps, + beta1=beta1, + beta2=beta2, + lr=lr, + weight_decay=weight_decay, + momentum_decay=momentum_decay, + maximize=maximize, + decoupled_weight_decay=decoupled_weight_decay, + eps=eps, + capturable=capturable, + differentiable=differentiable, + has_complex=has_complex, + ) diff --git a/engine/cr_boosters/optimizer.py b/engine/cr_boosters/optimizer.py new file mode 100644 index 0000000..129e377 --- /dev/null +++ b/engine/cr_boosters/optimizer.py @@ -0,0 +1,1052 @@ +# mypy: allow-untyped-defs +"""Base optimizer.""" +import functools +import math +import warnings +from collections import defaultdict, OrderedDict +from copy import deepcopy +from itertools import chain +from typing import ( + Any, + Callable, + cast, + DefaultDict, + Dict, + Hashable, + Iterable, + List, + Optional, + overload, + Set, + Tuple, + TypeVar, + Union, +) +from typing_extensions import ParamSpec, Self, TypeAlias + +import torch +import torch.utils.hooks as hooks +from torch._utils import is_compiling +from torch.utils._foreach_utils import ( + _get_foreach_kernels_supported_devices, + _get_fused_kernels_supported_devices, + _group_tensors_by_device_and_dtype, + Indices, +) +from torch.utils.hooks import RemovableHandle + +Args: TypeAlias = Tuple[Any, ...] +Kwargs: TypeAlias = Dict[str, Any] +StateDict: TypeAlias = Dict[str, Any] +TensorListList: TypeAlias = List[List[torch.Tensor]] +DeviceDict = Dict[Optional[torch.device], torch.Tensor] + + +GlobalOptimizerPreHook: TypeAlias = Callable[ + ["Optimizer", Args, Kwargs], Optional[Tuple[Args, Kwargs]] +] +GlobalOptimizerPostHook: TypeAlias = Callable[["Optimizer", Args, Kwargs], None] + +__all__ = [ + "Optimizer", + "register_optimizer_step_pre_hook", + "register_optimizer_step_post_hook", +] +_global_optimizer_pre_hooks: Dict[int, GlobalOptimizerPreHook] = OrderedDict() +_global_optimizer_post_hooks: Dict[int, GlobalOptimizerPostHook] = OrderedDict() +_foreach_supported_types = [torch.Tensor, torch.nn.parameter.Parameter] + + +class _RequiredParameter: + """Singleton class representing a required parameter for an Optimizer.""" + + def __repr__(self) -> str: + return "" + + +required = _RequiredParameter() + + +def _use_grad_for_differentiable(func): + def _use_grad(self, *args, **kwargs): + import torch._dynamo + + prev_grad = torch.is_grad_enabled() + try: + # Note on graph break below: + # we need to graph break to ensure that aot respects the no_grad annotation. + # This is important for perf because without this, functionalization will generate an epilogue + # which updates the mutated parameters of the optimizer which is *not* visible to inductor, as a result, + # inductor will allocate for every parameter in the model, which is horrible. 
+ # With this, aot correctly sees that this is an inference graph, and functionalization will generate + # an epilogue which is appended to the graph, which *is* visible to inductor, as a result, inductor sees that + # step is in place and is able to avoid the extra allocation. + # In the future, we will either 1) continue to graph break on backward, so this graph break does not matter + # or 2) have a fully fused forward and backward graph, which will have no_grad by default, and we can remove this + # graph break to allow the fully fused fwd-bwd-optimizer graph to be compiled. + # see https://github.com/pytorch/pytorch/issues/104053 + torch.set_grad_enabled(self.defaults["differentiable"]) + torch._dynamo.graph_break() + ret = func(self, *args, **kwargs) + finally: + torch._dynamo.graph_break() + torch.set_grad_enabled(prev_grad) + return ret + + functools.update_wrapper(_use_grad, func) + return _use_grad + + +def _get_value(x): + # item is significantly faster than a cpu tensor in eager mode + if not torch.jit.is_scripting() and is_compiling(): + return x + else: + return x.item() if isinstance(x, torch.Tensor) else x + + +def _stack_if_compiling(x): + if not torch.jit.is_scripting() and is_compiling(): + return torch.stack(x) + else: + return x + + +def _dispatch_sqrt( + x: float, +): # float annotation is needed because of torchscript type inference + if not torch.jit.is_scripting() and isinstance(x, torch.Tensor): + return x.sqrt() + else: + return math.sqrt(x) + + +def _disable_dynamo_if_unsupported(single_tensor_fn=None): + # workaround for torchscript BC + # it requires all called functions to be in the + # global environment at the site at which the + # maybe_fallback closure is created + if single_tensor_fn: + globals()[single_tensor_fn.__name__] = single_tensor_fn + + def wrapper(func): + import inspect + + disabled_func = torch._disable_dynamo(func) + ps = inspect.signature(func).parameters + has_state_steps = True + try: + state_steps_ind = list(ps.keys()).index("state_steps") + except ValueError: + has_state_steps = False + + # Today, there are cases where we stack state steps + # and pass them as the value arg of foreach ops. + # Having state steps on cuda as the value arg is not supported in eager, + # but this only occurs in the rare case that the user explicitly deletes + # the capturable flag. If capturable=True, this is not a problem. + @functools.wraps(func) + def maybe_fallback(*args, **kwargs): + if is_compiling() and ( + not kwargs.get("capturable", False) + and has_state_steps + and (args[state_steps_ind] and args[state_steps_ind][0].is_cuda) + or ( + "state_steps" in kwargs + and kwargs["state_steps"] + and kwargs["state_steps"][0].is_cuda + ) + ): + return disabled_func(*args, **kwargs) + else: + return func(*args, **kwargs) + + return maybe_fallback + + return wrapper + + +# For any optimizer with a faster implementation, we attempt to default to the +# fastest + stablest whenever possible. For foreach, the requirements are to have +# native params all on CUDA. For fused, there's currently the additional requirement +# that the tensors' dtypes must be floating point. Neither alternative supports +# torch.jit.script nor differentiable, so we fall back to the single tensor +# implementation in those cases. 
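+# Returns a (fused, foreach) pair; at most one of the two can be True, and
+# both are False under torch.jit.script or when differentiable=True.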
+def _default_to_fused_or_foreach( + params: List[torch.Tensor], differentiable: bool, use_fused: bool = False +) -> Tuple[bool, bool]: + if torch.jit.is_scripting() or differentiable: + return False, False + + fused_supported_devices = _get_fused_kernels_supported_devices() + foreach_supported_devices = _get_foreach_kernels_supported_devices() + fused = use_fused and all( + p is None + or ( + type(p) in _foreach_supported_types + and p.device.type in fused_supported_devices + and torch.is_floating_point(p) + ) + for p in params + ) + foreach = not fused and all( + p is None + or ( + type(p) in _foreach_supported_types + and p.device.type in foreach_supported_devices + ) + for p in params + ) + return fused, foreach + + +def _view_as_real(params, *state_and_grads): + for i, p in enumerate(params): + if torch.is_complex(p): + params[i] = torch.view_as_real(params[i]) + for s in state_and_grads: + s[i] = torch.view_as_real(s[i]) + + +def _get_scalar_dtype(is_fused=None): + if is_fused: + return torch.float32 + return ( + torch.float64 if torch.get_default_dtype() == torch.float64 else torch.float32 + ) + + +def _get_capturable_supported_devices(supports_xla: bool = True) -> List[str]: + r"""Return the device type list that supports capturable optimizer.""" + capturable_supported_devices = ["cuda"] + if not torch.jit.is_scripting(): + capturable_supported_devices.append(torch._C._get_privateuse1_backend_name()) + if supports_xla: + capturable_supported_devices.append("xla") + return capturable_supported_devices + + +# Common doc strings among optimizers +_foreach_doc = r"""foreach (bool, optional): whether foreach implementation of optimizer + is used. If unspecified by the user (so foreach is None), we will try to use + foreach over the for-loop implementation on CUDA, since it is usually + significantly more performant. Note that the foreach implementation uses + ~ sizeof(params) more peak memory than the for-loop version due to the intermediates + being a tensorlist vs just one tensor. If memory is prohibitive, batch fewer + parameters through the optimizer at a time or switch this flag to False (default: None)""" + +_fused_doc = r"""fused (bool, optional): whether the fused implementation is used. + Currently, `torch.float64`, `torch.float32`, `torch.float16`, and `torch.bfloat16` + are supported. (default: None) + + .. note:: The foreach and fused implementations are typically faster than the for-loop, + single-tensor implementation. Thus, if the user has not specified BOTH flags + (i.e., when foreach = fused = None), we will attempt defaulting to the foreach + implementation when the tensors are all on CUDA. For example, if the user specifies + True for fused but nothing for foreach, we will run the fused implementation. If + the user specifies False for foreach but nothing for fused (or False for fused but + nothing for foreach), we will run the for-loop implementation. If the user specifies + True for both foreach and fused, we will prioritize fused over foreach, as it is + typically faster. We attempt to use the fastest, so the hierarchy goes fused -> + foreach -> for-loop. HOWEVER, since the fused implementation is relatively new, + we want to give it sufficient bake-in time, so we default to foreach and NOT + fused when the user has not specified either flag.""" + +_capturable_doc = r"""capturable (bool, optional): whether this instance is safe to + capture in a CUDA graph. 
Passing True can impair ungraphed performance, + so if you don't intend to graph capture this instance, leave it False + (default: False)""" + +_differentiable_doc = r"""differentiable (bool, optional): whether autograd should + occur through the optimizer step in training. Otherwise, the step() + function runs in a torch.no_grad() context. Setting to True can impair + performance, so leave it False if you don't intend to run autograd + through this instance (default: False)""" + +_maximize_doc = r"""maximize (bool, optional): maximize the objective with respect to the + params, instead of minimizing (default: False)""" + + +def register_optimizer_step_pre_hook(hook: GlobalOptimizerPreHook) -> RemovableHandle: + r"""Register a pre hook common to all optimizers. + + The hook should have the following signature:: + + hook(optimizer, args, kwargs) -> None or modified args and kwargs + + Args: + hook (Callable): A user defined hook which is registered on all optimizers. + + Returns: + :class:`torch.utils.hooks.RemovableHandle`: + a handle that can be used to remove the added hook by calling + ``handle.remove()`` + """ + handle = hooks.RemovableHandle(_global_optimizer_pre_hooks) + _global_optimizer_pre_hooks[handle.id] = hook + return handle + + +def register_optimizer_step_post_hook(hook: GlobalOptimizerPostHook) -> RemovableHandle: + r"""Register a post hook common to all optimizers. + + The hook should have the following signature:: + + hook(optimizer, args, kwargs) -> None + + Args: + hook (Callable): A user defined hook which is registered on all optimizers. + + Returns: + :class:`torch.utils.hooks.RemovableHandle`: + a handle that can be used to remove the added hook by calling + ``handle.remove()`` + """ + handle = hooks.RemovableHandle(_global_optimizer_post_hooks) + _global_optimizer_post_hooks[handle.id] = hook + return handle + + +ParamsT: TypeAlias = Union[Iterable[torch.Tensor], Iterable[Dict[str, Any]]] + +_P = ParamSpec("_P") +R = TypeVar("R") +T = TypeVar("T") + + +class Optimizer: + r"""Base class for all optimizers. + + .. warning:: + Parameters need to be specified as collections that have a deterministic + ordering that is consistent between runs. Examples of objects that don't + satisfy those properties are sets and iterators over values of dictionaries. + + Args: + params (iterable): an iterable of :class:`torch.Tensor` s or + :class:`dict` s. Specifies what Tensors should be optimized. + defaults: (dict): a dict containing default values of optimization + options (used when a parameter group doesn't specify them). 
+ """ + + OptimizerPreHook: TypeAlias = Callable[[Self, Args, Kwargs], Optional[Tuple[Args, Kwargs]]] # type: ignore[misc] + OptimizerPostHook: TypeAlias = Callable[[Self, Args, Kwargs], None] # type: ignore[misc] + + _optimizer_step_pre_hooks: Dict[int, OptimizerPreHook] + _optimizer_step_post_hooks: Dict[int, OptimizerPostHook] + _optimizer_state_dict_pre_hooks: 'OrderedDict[int, Callable[["Optimizer"], None]]' + _optimizer_state_dict_post_hooks: 'OrderedDict[int, Callable[["Optimizer", StateDict], Optional[StateDict]]]' + _optimizer_load_state_dict_pre_hooks: 'OrderedDict[int, Callable[["Optimizer", StateDict], Optional[StateDict]]]' + _optimizer_load_state_dict_post_hooks: 'OrderedDict[int, Callable[["Optimizer"], None]]' + + def __init__(self, params: ParamsT, defaults: Dict[str, Any]) -> None: # noqa: D107 + torch._C._log_api_usage_once("python.optimizer") + self.defaults = defaults + self._optimizer_step_pre_hooks = OrderedDict() + self._optimizer_step_post_hooks = OrderedDict() + self._optimizer_state_dict_pre_hooks = OrderedDict() + self._optimizer_state_dict_post_hooks = OrderedDict() + self._optimizer_load_state_dict_pre_hooks = OrderedDict() + self._optimizer_load_state_dict_post_hooks = OrderedDict() + + self._patch_step_function() + + if isinstance(params, torch.Tensor): + raise TypeError( + "params argument given to the optimizer should be " + "an iterable of Tensors or dicts, but got " + torch.typename(params) + ) + + self.state: DefaultDict[torch.Tensor, Any] = defaultdict(dict) + self.param_groups: List[Dict[str, Any]] = [] + + param_groups = list(params) + if len(param_groups) == 0: + raise ValueError("optimizer got an empty parameter list") + if not isinstance(param_groups[0], dict): + param_groups = [{"params": param_groups}] + + for param_group in param_groups: + self.add_param_group(cast(dict, param_group)) + + # Allows _cuda_graph_capture_health_check to rig a poor man's TORCH_WARN_ONCE in python, + # which I don't think exists + # https://github.com/pytorch/pytorch/issues/72948 + self._warned_capturable_if_run_uncaptured = True + + def __getstate__(self) -> Dict[str, Any]: # noqa: D105 + return { + "defaults": self.defaults, + "state": self.state, + "param_groups": self.param_groups, + } + + def __setstate__(self, state: Dict[str, Any]) -> None: # noqa: D105 + self.__dict__.update(state) + if "_optimizer_step_pre_hooks" not in self.__dict__: + self._optimizer_step_pre_hooks = OrderedDict() + if "_optimizer_step_post_hooks" not in self.__dict__: + self._optimizer_step_post_hooks = OrderedDict() + if "_optimizer_state_dict_pre_hooks" not in self.__dict__: + self._optimizer_state_dict_pre_hooks = OrderedDict() + if "_optimizer_state_dict_post_hooks" not in self.__dict__: + self._optimizer_state_dict_post_hooks = OrderedDict() + if "_optimizer_load_state_dict_pre_hooks" not in self.__dict__: + self._optimizer_load_state_dict_pre_hooks = OrderedDict() + if "_optimizer_load_state_dict_post_hooks" not in self.__dict__: + self._optimizer_load_state_dict_post_hooks = OrderedDict() + self._patch_step_function() # To support multiprocessing pickle/unpickle + self.defaults.setdefault("differentiable", False) + + def __repr__(self) -> str: # noqa: D105 + format_string = self.__class__.__name__ + " (" + for i, group in enumerate(self.param_groups): + format_string += "\n" + format_string += f"Parameter Group {i}\n" + for key in sorted(group.keys()): + if key != "params": + format_string += f" {key}: {group[key]}\n" + format_string += ")" + return format_string + + # Currently 
needed by Adam and AdamW + def _cuda_graph_capture_health_check(self) -> None: + # Note [torch.compile x capturable] + # If we are compiling, we try to take the capturable path automatically by + # setting the flag to True during tracing. Due to this, we skip all the checks + # normally required for determining whether we can use CUDA graphs and + # shunt the responsibility to torch.inductor. This saves time during tracing + # since the checks are slow without sacrificing UX since inductor will warn + # later if CUDA graphs cannot be enabled, e.g., + # https://github.com/pytorch/pytorch/blob/d3ba8901d8640eb16f88b2bfef9df7fa383d4b47/torch/_inductor/compile_fx.py#L390. + # Thus, when compiling, inductor will determine if cudagraphs + # can be enabled based on whether there is input mutation or CPU tensors. + if ( + not is_compiling() + and torch.backends.cuda.is_built() + and torch.cuda.is_available() + ): + capturing = torch.cuda.is_current_stream_capturing() + + if capturing and not all( + group["capturable"] for group in self.param_groups + ): + raise RuntimeError( + "Attempting CUDA graph capture of step() for an instance of " + + self.__class__.__name__ + + " but param_groups' capturable is False." + ) + + if ( + (not getattr(self, "_warned_capturable_if_run_uncaptured", False)) + and all(group["capturable"] for group in self.param_groups) + and (not capturing) + ): + warnings.warn( + "This instance was constructed with capturable=True or some of all the param_groups came with capturable=True, " + "but step() is running without CUDA graph capture. If you never intend to graph-capture this " + "instance, capturable=True can impair performance, and you should set capturable=False." + ) + self._warned_capturable_if_run_uncaptured = True + + def _optimizer_step_code(self) -> None: + """Entry point for `torch.profile.profiler`. + + When python tracing is enabled the profiler will hook into this + function at the CPython level to inspect the optimizer's parameters and + param groups. It is called it after `step()` since many optimizers + lazily initialize state. + + This is a workaround due to lack of a proper step hook on the optimizer, + and will be removed if it exists. + """ + pass + + @staticmethod + def profile_hook_step(func: Callable[_P, R]) -> Callable[_P, R]: # noqa: D102 + @functools.wraps(func) + def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> R: + self, *_ = args + self = cast(Optimizer, self) + profile_name = f"Optimizer.step#{self.__class__.__name__}.step" + with torch.autograd.profiler.record_function(profile_name): + # call optimizer step pre hooks + for pre_hook in chain( + _global_optimizer_pre_hooks.values(), + self._optimizer_step_pre_hooks.values(), + ): + result = pre_hook(self, args, kwargs) + if result is not None: + if isinstance(result, tuple) and len(result) == 2: + args, kwargs = result # type: ignore[assignment] + else: + raise RuntimeError( + f"{func} must return None or a tuple of (new_args, new_kwargs), but got {result}." 
+ ) + + out = func(*args, **kwargs) + self._optimizer_step_code() + + # call optimizer step post hooks + for post_hook in chain( + self._optimizer_step_post_hooks.values(), + _global_optimizer_post_hooks.values(), + ): + post_hook(self, args, kwargs) + + return out + + return wrapper + + @staticmethod + def _group_tensors_by_device_and_dtype( + tensorlistlist: TensorListList, + with_indices: bool = False, + ) -> Union[ + Dict[Tuple[None, None], Tuple[TensorListList, Indices]], + Dict[Tuple[torch.device, torch.dtype], Tuple[TensorListList, Indices]], + ]: + """Group a list of lists of tensors by device and dtype. + + Skips this step if we are compiling since this will occur during inductor lowering. + """ + if is_compiling(): + return {(None, None): (tensorlistlist, list(range(len(tensorlistlist[0]))))} + else: + return _group_tensors_by_device_and_dtype(tensorlistlist, with_indices) # type: ignore[return-value, arg-type] + + def _patch_step_function(self) -> None: + self._zero_grad_profile_name = ( + f"Optimizer.zero_grad#{self.__class__.__name__}.zero_grad" + ) + hooked = getattr(self.__class__.step, "hooked", None) + if not hooked: + self.__class__.step = self.profile_hook_step(self.__class__.step) # type: ignore[assignment] + self.__class__.step.hooked = True # type: ignore[attr-defined] + + def register_step_pre_hook(self, hook: OptimizerPreHook) -> RemovableHandle: + r"""Register an optimizer step pre hook which will be called before optimizer step. + + It should have the following signature:: + + hook(optimizer, args, kwargs) -> None or modified args and kwargs + + The ``optimizer`` argument is the optimizer instance being used. If + args and kwargs are modified by the pre-hook, then the transformed + values are returned as a tuple containing the new_args and new_kwargs. + + Args: + hook (Callable): The user defined hook to be registered. + + Returns: + :class:`torch.utils.hooks.RemovableHandle`: + a handle that can be used to remove the added hook by calling + ``handle.remove()`` + """ + handle = hooks.RemovableHandle(self._optimizer_step_pre_hooks) + self._optimizer_step_pre_hooks[handle.id] = hook + return handle + + def register_step_post_hook(self, hook: OptimizerPostHook) -> RemovableHandle: + r"""Register an optimizer step post hook which will be called after optimizer step. + + It should have the following signature:: + + hook(optimizer, args, kwargs) -> None + + The ``optimizer`` argument is the optimizer instance being used. + + Args: + hook (Callable): The user defined hook to be registered. + + Returns: + :class:`torch.utils.hooks.RemovableHandle`: + a handle that can be used to remove the added hook by calling + ``handle.remove()`` + """ + handle = hooks.RemovableHandle(self._optimizer_step_post_hooks) + self._optimizer_step_post_hooks[handle.id] = hook + return handle + + def register_state_dict_pre_hook( + self, hook: Callable[["Optimizer"], None], prepend: bool = False + ) -> RemovableHandle: # noqa: D101 + r"""Register a state dict pre-hook which will be called before :meth:`~torch.optim.Optimizer.state_dict` is called. + + It should have the following signature:: + + hook(optimizer) -> None + + The ``optimizer`` argument is the optimizer instance being used. + The hook will be called with argument ``self`` before calling ``state_dict`` on ``self``. + The registered hook can be used to perform pre-processing before the ``state_dict`` + call is made. + + Args: + hook (Callable): The user defined hook to be registered. 
+ prepend (bool): If True, the provided pre ``hook`` will be fired before + all the already registered pre-hooks on ``state_dict``. Otherwise, + the provided ``hook`` will be fired after all the already registered + pre-hooks. (default: False) + + Returns: + :class:`torch.utils.hooks.RemoveableHandle`: + a handle that can be used to remove the added hook by calling + ``handle.remove()`` + """ + handle = hooks.RemovableHandle(self._optimizer_state_dict_pre_hooks) + self._optimizer_state_dict_pre_hooks[handle.id] = hook + if prepend: + self._optimizer_state_dict_pre_hooks.move_to_end(handle.id, last=False) + return handle + + def register_state_dict_post_hook( + self, + hook: Callable[["Optimizer", StateDict], Optional[StateDict]], + prepend: bool = False, + ) -> RemovableHandle: + r"""Register a state dict post-hook which will be called after :meth:`~torch.optim.Optimizer.state_dict` is called. + + It should have the following signature:: + + hook(optimizer, state_dict) -> state_dict or None + + The hook will be called with arguments ``self`` and ``state_dict`` after generating + a ``state_dict`` on ``self``. The hook may modify the state_dict inplace or optionally + return a new one. The registered hook can be used to perform post-processing + on the ``state_dict`` before it is returned. + + Args: + hook (Callable): The user defined hook to be registered. + prepend (bool): If True, the provided post ``hook`` will be fired before + all the already registered post-hooks on ``state_dict``. Otherwise, + the provided ``hook`` will be fired after all the already registered + post-hooks. (default: False) + + Returns: + :class:`torch.utils.hooks.RemoveableHandle`: + a handle that can be used to remove the added hook by calling + ``handle.remove()`` + """ + handle = hooks.RemovableHandle(self._optimizer_state_dict_post_hooks) + self._optimizer_state_dict_post_hooks[handle.id] = hook + if prepend: + self._optimizer_state_dict_post_hooks.move_to_end(handle.id, last=False) + return handle + + @torch._disable_dynamo + def state_dict(self) -> StateDict: + r"""Return the state of the optimizer as a :class:`dict`. + + It contains two entries: + + * ``state``: a Dict holding current optimization state. Its content + differs between optimizer classes, but some common characteristics + hold. For example, state is saved per parameter, and the parameter + itself is NOT saved. ``state`` is a Dictionary mapping parameter ids + to a Dict with state corresponding to each parameter. + * ``param_groups``: a List containing all parameter groups where each + parameter group is a Dict. Each parameter group contains metadata + specific to the optimizer, such as learning rate and weight decay, + as well as a List of parameter IDs of the parameters in the group. + + NOTE: The parameter IDs may look like indices but they are just IDs + associating state with param_group. When loading from a state_dict, + the optimizer will zip the param_group ``params`` (int IDs) and the + optimizer ``param_groups`` (actual ``nn.Parameter`` s) in order to + match state WITHOUT additional verification. + + A returned state dict might look something like: + + .. code-block:: text + + { + 'state': { + 0: {'momentum_buffer': tensor(...), ...}, + 1: {'momentum_buffer': tensor(...), ...}, + 2: {'momentum_buffer': tensor(...), ...}, + 3: {'momentum_buffer': tensor(...), ...} + }, + 'param_groups': [ + { + 'lr': 0.01, + 'weight_decay': 0, + ... + 'params': [0] + }, + { + 'lr': 0.001, + 'weight_decay': 0.5, + ... 
+ 'params': [1, 2, 3] + } + ] + } + + """ + for pre_hook in self._optimizer_state_dict_pre_hooks.values(): + pre_hook(self) + + # Save order indices instead of Tensors + param_mappings: Dict[int, int] = {} + start_index = 0 + + def pack_group(group: Dict[str, Any]) -> Dict[str, Any]: + nonlocal start_index + packed = {k: v for k, v in group.items() if k != "params"} + param_mappings.update( + { + id(p): i + for i, p in enumerate(group["params"], start_index) + if id(p) not in param_mappings + } + ) + packed["params"] = [param_mappings[id(p)] for p in group["params"]] + start_index += len(packed["params"]) + return packed + + param_groups = [pack_group(g) for g in self.param_groups] + # Remap state to use order indices as keys + packed_state = { + (param_mappings[id(k)] if isinstance(k, torch.Tensor) else k): v + for k, v in self.state.items() + } + + state_dict = { + "state": packed_state, + "param_groups": param_groups, + } + + for post_hook in self._optimizer_state_dict_post_hooks.values(): + hook_result = post_hook(self, state_dict) + if hook_result is not None: + state_dict = hook_result + return state_dict + + @staticmethod + def _process_value_according_to_param_policy( + param: torch.Tensor, + value: torch.Tensor, + param_id: int, + param_groups: List[Dict[Any, Any]], + key: Hashable = None, + ) -> torch.Tensor: + # Floating-point types are a bit special here. They are the only ones + # that are assumed to always match the type of params. + # Make sure state['step'] is not casted https://github.com/pytorch/pytorch/issues/74424 + # UNLESS fused or capturable, see note [special device hosting for step] + fused = False + capturable = False + assert param_groups is not None + for pg in param_groups: + if param_id in pg["params"]: + fused = pg["fused"] if "fused" in pg else False + capturable = pg["capturable"] if "capturable" in pg else False + break + if key == "step": + if capturable or fused: + return value.to(dtype=torch.float32, device=param.device) + else: + return value + else: + if param.is_floating_point(): + return value.to(dtype=param.dtype, device=param.device) + else: + return value.to(device=param.device) + + def register_load_state_dict_pre_hook( + self, + hook: Callable[["Optimizer", StateDict], Optional[StateDict]], + prepend: bool = False, + ) -> RemovableHandle: # noqa: D205 D400 + r"""Register a load_state_dict pre-hook which will be called before + :meth:`~torch.optim.Optimizer.load_state_dict` is called. It should have the + following signature:: + + hook(optimizer, state_dict) -> state_dict or None + + The ``optimizer`` argument is the optimizer instance being used and the + ``state_dict`` argument is a shallow copy of the ``state_dict`` the user + passed in to ``load_state_dict``. The hook may modify the state_dict inplace + or optionally return a new one. If a state_dict is returned, it will be used + to be loaded into the optimizer. + + The hook will be called with argument ``self`` and ``state_dict`` before + calling ``load_state_dict`` on ``self``. The registered hook can be used to + perform pre-processing before the ``load_state_dict`` call is made. + + Args: + hook (Callable): The user defined hook to be registered. + prepend (bool): If True, the provided pre ``hook`` will be fired before + all the already registered pre-hooks on ``load_state_dict``. Otherwise, + the provided ``hook`` will be fired after all the already registered + pre-hooks. 
(default: False) + + Returns: + :class:`torch.utils.hooks.RemoveableHandle`: + a handle that can be used to remove the added hook by calling + ``handle.remove()`` + """ + handle = hooks.RemovableHandle(self._optimizer_load_state_dict_pre_hooks) + self._optimizer_load_state_dict_pre_hooks[handle.id] = hook + if prepend: + self._optimizer_load_state_dict_pre_hooks.move_to_end(handle.id, last=False) + return handle + + def register_load_state_dict_post_hook( + self, hook: Callable[["Optimizer"], None], prepend: bool = False + ) -> RemovableHandle: # noqa: D205 D400 + r"""Register a load_state_dict post-hook which will be called after + :meth:`~torch.optim.Optimizer.load_state_dict` is called. It should have the + following signature:: + + hook(optimizer) -> None + + The ``optimizer`` argument is the optimizer instance being used. + + The hook will be called with argument ``self`` after calling + ``load_state_dict`` on ``self``. The registered hook can be used to + perform post-processing after ``load_state_dict`` has loaded the + ``state_dict``. + + Args: + hook (Callable): The user defined hook to be registered. + prepend (bool): If True, the provided post ``hook`` will be fired before + all the already registered post-hooks on ``load_state_dict``. Otherwise, + the provided ``hook`` will be fired after all the already registered + post-hooks. (default: False) + + Returns: + :class:`torch.utils.hooks.RemoveableHandle`: + a handle that can be used to remove the added hook by calling + ``handle.remove()`` + """ + handle = hooks.RemovableHandle(self._optimizer_load_state_dict_post_hooks) + self._optimizer_load_state_dict_post_hooks[handle.id] = hook + if prepend: + self._optimizer_load_state_dict_post_hooks.move_to_end(handle.id, last=False) # type: ignore[attr-defined] + return handle + + @torch._disable_dynamo + def load_state_dict(self, state_dict: StateDict) -> None: + r"""Load the optimizer state. + + Args: + state_dict (dict): optimizer state. Should be an object returned + from a call to :meth:`state_dict`. 
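+
+        .. note::
+            State is matched to the current ``param_groups`` purely by
+            position, so the optimizer must be constructed over the same
+            parameters, in the same order, as when the state dict was saved.
+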
+ """ + # shallow copy, to be consistent with module API + state_dict = state_dict.copy() + + for pre_hook in self._optimizer_load_state_dict_pre_hooks.values(): + hook_result = pre_hook(self, state_dict) + if hook_result is not None: + state_dict = hook_result + + # Validate the state_dict + groups = self.param_groups + + # Deepcopy as we write into saved_groups later to update state + saved_groups = deepcopy(state_dict["param_groups"]) + + if len(groups) != len(saved_groups): + raise ValueError( + "loaded state dict has a different number of " "parameter groups" + ) + param_lens = (len(g["params"]) for g in groups) + saved_lens = (len(g["params"]) for g in saved_groups) + if any(p_len != s_len for p_len, s_len in zip(param_lens, saved_lens)): + raise ValueError( + "loaded state dict contains a parameter group " + "that doesn't match the size of optimizer's group" + ) + + # Update the state + id_map = dict( + zip( + chain.from_iterable(g["params"] for g in saved_groups), + chain.from_iterable(g["params"] for g in groups), + ) + ) + + def _cast(param, value, param_id=None, param_groups=None, key=None): + r"""Make a deep copy of value, casting all tensors to device of param.""" + if isinstance(value, torch.Tensor): + return Optimizer._process_value_according_to_param_policy( + param, value, param_id, param_groups, key + ) + elif isinstance(value, dict): + return { + k: _cast( + param, v, param_id=param_id, param_groups=param_groups, key=k + ) + for k, v in value.items() + } + elif isinstance(value, Iterable): + return type(value)(_cast(param, v, param_id=param_id, param_groups=param_groups) for v in value) # type: ignore[call-arg] + else: + return value + + # Copy state assigned to params (and cast tensors to appropriate types). + # State that is not assigned to params is copied as is (needed for + # backward compatibility). + state: DefaultDict[torch.Tensor, Dict[Any, Any]] = defaultdict(dict) + for k, v in state_dict["state"].items(): + if k in id_map: + param = id_map[k] + state[param] = _cast( + param, v, param_id=k, param_groups=state_dict["param_groups"] + ) + else: + state[k] = v + + # Update parameter groups, setting their 'params' value + def update_group( + group: Dict[str, Any], new_group: Dict[str, Any] + ) -> Dict[str, Any]: + new_group["params"] = group["params"] + return new_group + + param_groups = [update_group(g, ng) for g, ng in zip(groups, saved_groups)] + self.__setstate__({"state": state, "param_groups": param_groups}) + + for post_hook in self._optimizer_load_state_dict_post_hooks.values(): + post_hook(self) + + @torch._disable_dynamo + def zero_grad(self, set_to_none: bool = True) -> None: + r"""Reset the gradients of all optimized :class:`torch.Tensor` s. + + Args: + set_to_none (bool): instead of setting to zero, set the grads to None. + This will in general have lower memory footprint, and can modestly improve performance. + However, it changes certain behaviors. For example: + 1. When the user tries to access a gradient and perform manual ops on it, + a None attribute or a Tensor full of 0s will behave differently. + 2. If the user requests ``zero_grad(set_to_none=True)`` followed by a backward pass, ``.grad``\ s + are guaranteed to be None for params that did not receive a gradient. + 3. ``torch.optim`` optimizers have a different behavior if the gradient is 0 or None + (in one case it does the step with a gradient of 0 and in the other it skips + the step altogether). 
+ """ + foreach = self.defaults.get("foreach", False) or self.defaults.get( + "fused", False + ) + + if not hasattr(self, "_zero_grad_profile_name"): + self._patch_step_function() + + per_device_and_dtype_grads: Optional[ + DefaultDict[torch.device, DefaultDict[torch.dtype, List[torch.Tensor]]] + ] + if foreach: + per_device_and_dtype_grads = defaultdict(lambda: defaultdict(list)) + else: + per_device_and_dtype_grads = None + + with torch.autograd.profiler.record_function(self._zero_grad_profile_name): + for group in self.param_groups: + for p in group["params"]: + if p.grad is not None: + if set_to_none: + p.grad = None + else: + if p.grad.grad_fn is not None: + p.grad.detach_() + else: + p.grad.requires_grad_(False) + if not foreach or p.grad.is_sparse: + p.grad.zero_() + else: + assert per_device_and_dtype_grads is not None + per_device_and_dtype_grads[p.grad.device][ + p.grad.dtype + ].append(p.grad) + if foreach: + assert per_device_and_dtype_grads is not None + for per_dtype_grads in per_device_and_dtype_grads.values(): + for grads in per_dtype_grads.values(): + torch._foreach_zero_(grads) + + @overload + def step(self, closure: None = ...) -> None: + ... + + @overload + def step(self, closure: Callable[[], float]) -> float: + ... + + def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]: + r"""Perform a single optimization step to update parameter. + + Args: + closure (Callable): A closure that reevaluates the model and + returns the loss. Optional for most optimizers. + + .. note:: + Unless otherwise specified, this function should not modify the + ``.grad`` field of the parameters. + """ + raise NotImplementedError + + @torch._disable_dynamo + def add_param_group(self, param_group: Dict[str, Any]) -> None: + r"""Add a param group to the :class:`Optimizer` s `param_groups`. + + This can be useful when fine tuning a pre-trained network as frozen layers can be made + trainable and added to the :class:`Optimizer` as training progresses. + + Args: + param_group (dict): Specifies what Tensors should be optimized along with group + specific optimization options. + """ + if not isinstance(param_group, dict): + raise TypeError(f"param_group must be a dict, but got {type(param_group)}") + + params = param_group["params"] + if isinstance(params, torch.Tensor): + param_group["params"] = [params] + elif isinstance(params, set): + raise TypeError( + "optimizer parameters need to be organized in ordered collections, but " + "the ordering of tensors in sets will change between runs. Please use a list instead." 
+ ) + else: + param_group["params"] = list(params) + + for param in param_group["params"]: + if not isinstance(param, torch.Tensor): + raise TypeError( + "optimizer can only optimize Tensors, " + "but one of the params is " + torch.typename(param) + ) + if not self.defaults.get("differentiable", None) and not ( + param.is_leaf or param.retains_grad + ): + raise ValueError("can't optimize a non-leaf Tensor") + + for name, default in self.defaults.items(): + if default is required and name not in param_group: + raise ValueError( + f"parameter group didn't specify a value of required optimization parameter {name}" + ) + else: + param_group.setdefault(name, default) + + params = param_group["params"] + if len(params) != len(set(params)): + warnings.warn( + "optimizer contains a parameter group with duplicate parameters; " + "in future, this will cause an error; " + "see github.com/pytorch/pytorch/issues/40967 for more information", + stacklevel=3, + ) + + param_set: Set[torch.Tensor] = set() + for group in self.param_groups: + param_set.update(set(group["params"])) + + if not param_set.isdisjoint(set(param_group["params"])): + raise ValueError("some parameters appear in more than one parameter group") + + self.param_groups.append(param_group) diff --git a/engine/cr_boosters/radam.py b/engine/cr_boosters/radam.py new file mode 100644 index 0000000..fc6955c --- /dev/null +++ b/engine/cr_boosters/radam.py @@ -0,0 +1,598 @@ +# mypy: allow-untyped-defs +r"""Implementation for the RAdam algorithm.""" +from typing import cast, List, Optional, Tuple, Union + +import torch +from torch import Tensor + +from .optimizer import ( + _capturable_doc, + _default_to_fused_or_foreach, + _differentiable_doc, + _disable_dynamo_if_unsupported, + _dispatch_sqrt, + _foreach_doc, + _get_capturable_supported_devices, + _get_scalar_dtype, + _get_value, + _maximize_doc, + _use_grad_for_differentiable, + _view_as_real, + Optimizer, + ParamsT, +) + +__all__ = ["RAdam", "radam"] + + +class RAdam(Optimizer): # noqa: D101 + def __init__( + self, + params: ParamsT, + lr: float = 1e-3, + betas: Tuple[float, float] = (0.9, 0.999), + eps: float = 1e-8, + weight_decay: float = 0, + decoupled_weight_decay: bool = False, + *, + foreach: Optional[bool] = None, + maximize: bool = False, + capturable: bool = False, + differentiable: bool = False, + ): # noqa: D107 + if not 0.0 <= lr: + raise ValueError(f"Invalid learning rate: {lr}") + if not 0.0 <= eps: + raise ValueError(f"Invalid epsilon value: {eps}") + if not 0.0 <= betas[0] < 1.0: + raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}") + if not 0.0 <= betas[1] < 1.0: + raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}") + if not 0.0 <= weight_decay: + raise ValueError(f"Invalid weight_decay value: {weight_decay}") + + defaults = dict( + lr=lr, + betas=betas, + eps=eps, + weight_decay=weight_decay, + maximize=maximize, + foreach=foreach, + capturable=capturable, + decoupled_weight_decay=decoupled_weight_decay, + differentiable=differentiable, + ) + super().__init__(params, defaults) + + def __setstate__(self, state): # noqa: D105 + super().__setstate__(state) + for group in self.param_groups: + group.setdefault("foreach", None) + group.setdefault("maximize", False) + group.setdefault("differentiable", False) + group.setdefault("decoupled_weight_decay", False) + group.setdefault("capturable", False) + for p in group["params"]: + p_state = self.state.get(p, []) + if len(p_state) != 0 and not torch.is_tensor(p_state["step"]): + step_val = 
float(p_state["step"]) + p_state["step"] = ( + torch.tensor( + step_val, dtype=_get_scalar_dtype(), device=p.device + ) + if group["capturable"] + else torch.tensor(step_val, dtype=_get_scalar_dtype()) + ) + + def _init_group( + self, group, params_with_grad, grads, exp_avgs, exp_avg_sqs, state_steps + ): + has_complex = False + for p in group["params"]: + if p.grad is not None: + has_complex |= torch.is_complex(p) + params_with_grad.append(p) + if p.grad.is_sparse: + raise RuntimeError("RAdam does not support sparse gradients") + grads.append(p.grad) + + state = self.state[p] + # Lazy state initialization + if len(state) == 0: + state["step"] = ( + torch.zeros((), dtype=_get_scalar_dtype(), device=p.device) + if group["capturable"] + else torch.tensor(0.0, dtype=_get_scalar_dtype()) + ) + # Exponential moving average of gradient values + state["exp_avg"] = torch.zeros_like( + p, memory_format=torch.preserve_format + ) + # Exponential moving average of squared gradient values + state["exp_avg_sq"] = torch.zeros_like( + p, memory_format=torch.preserve_format + ) + + exp_avgs.append(state["exp_avg"]) + exp_avg_sqs.append(state["exp_avg_sq"]) + state_steps.append(state["step"]) + + return has_complex + + @_use_grad_for_differentiable + def step(self, closure=None): + """Perform a single optimization step. + + Args: + closure (Callable, optional): A closure that reevaluates the model + and returns the loss. + """ + self._cuda_graph_capture_health_check() + + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + params_with_grad: List[Tensor] = [] + grads: List[Tensor] = [] + exp_avgs: List[Tensor] = [] + exp_avg_sqs: List[Tensor] = [] + state_steps: List[Tensor] = [] + beta1, beta2 = cast(Tuple[float, float], group["betas"]) + + has_complex = self._init_group( + group, params_with_grad, grads, exp_avgs, exp_avg_sqs, state_steps + ) + + radam( + params_with_grad, + grads, + exp_avgs, + exp_avg_sqs, + state_steps, + beta1=beta1, + beta2=beta2, + lr=group["lr"], + weight_decay=group["weight_decay"], + eps=group["eps"], + maximize=group["maximize"], + foreach=group["foreach"], + capturable=group["capturable"], + differentiable=group["differentiable"], + decoupled_weight_decay=group["decoupled_weight_decay"], + has_complex=has_complex, + ) + + return loss + + +RAdam.__doc__ = ( + r"""Implements RAdam algorithm. + + .. 
math:: + \begin{aligned} + &\rule{110mm}{0.4pt} \\ + &\textbf{input} : \gamma \text{ (lr)}, \: \beta_1, \beta_2 + \text{ (betas)}, \: \theta_0 \text{ (params)}, \:f(\theta) \text{ (objective)}, \: + \lambda \text{ (weightdecay)}, \:\textit{maximize} \\ + &\hspace{13mm} \epsilon \text{ (epsilon)}, \textit{decoupled\_weight\_decay} \\ + &\textbf{initialize} : m_0 \leftarrow 0 \text{ ( first moment)}, + v_0 \leftarrow 0 \text{ ( second moment)}, \\ + &\hspace{18mm} \rho_{\infty} \leftarrow 2/(1-\beta_2) -1 \\[-1.ex] + &\rule{110mm}{0.4pt} \\ + &\textbf{for} \: t=1 \: \textbf{to} \: \ldots \: \textbf{do} \\ + &\hspace{6mm}\textbf{if} \: \textit{maximize}: \\ + &\hspace{12mm}g_t \leftarrow -\nabla_{\theta} f_t (\theta_{t-1}) \\ + &\hspace{6mm}\textbf{else} \\ + &\hspace{12mm}g_t \leftarrow \nabla_{\theta} f_t (\theta_{t-1}) \\ + &\hspace{6mm} \theta_t \leftarrow \theta_{t-1} \\ + &\hspace{6mm} \textbf{if} \: \lambda \neq 0 \\ + &\hspace{12mm}\textbf{if} \: \textit{decoupled\_weight\_decay} \\ + &\hspace{18mm} \theta_t \leftarrow \theta_{t} - \gamma \lambda \theta_{t} \\ + &\hspace{12mm}\textbf{else} \\ + &\hspace{18mm} g_t \leftarrow g_t + \lambda \theta_{t} \\ + &\hspace{6mm}m_t \leftarrow \beta_1 m_{t-1} + (1 - \beta_1) g_t \\ + &\hspace{6mm}v_t \leftarrow \beta_2 v_{t-1} + (1-\beta_2) g^2_t \\ + &\hspace{6mm}\widehat{m_t} \leftarrow m_t/\big(1-\beta_1^t \big) \\ + &\hspace{6mm}\rho_t \leftarrow \rho_{\infty} - + 2 t \beta^t_2 /\big(1-\beta_2^t \big) \\[0.1.ex] + &\hspace{6mm}\textbf{if} \: \rho_t > 5 \\ + &\hspace{12mm} l_t \leftarrow \frac{\sqrt{ (1-\beta^t_2) }}{ \sqrt{v_t} +\epsilon } \\ + &\hspace{12mm} r_t \leftarrow + \sqrt{\frac{(\rho_t-4)(\rho_t-2)\rho_{\infty}}{(\rho_{\infty}-4)(\rho_{\infty}-2) \rho_t}} \\ + &\hspace{12mm}\theta_t \leftarrow \theta_t - \gamma \widehat{m_t} r_t l_t \\ + &\hspace{6mm}\textbf{else} \\ + &\hspace{12mm}\theta_t \leftarrow \theta_t - \gamma \widehat{m_t} \\ + &\rule{110mm}{0.4pt} \\[-1.ex] + &\bf{return} \: \theta_t \\[-1.ex] + &\rule{110mm}{0.4pt} \\[-1.ex] + \end{aligned} + + For further details regarding the algorithm we refer to `On the variance of the adaptive learning rate and beyond`_. + + This implementation provides an option to use either the original weight_decay implementation as in Adam + (where the weight_decay is applied to the gradient) or the one from AdamW (where weight_decay is applied + to the weight) through the decoupled_weight_decay option. When decoupled_weight_decay is set to False + (default), it uses the original Adam style weight decay, otherwise, it uses the AdamW style which + corresponds more closely to the `author's implementation`_ in the RAdam paper. Further information + about decoupled weight decay can be found in `Decoupled Weight Decay Regularization`_. + + """ + + rf""" + Args: + params (iterable): iterable of parameters to optimize or dicts defining + parameter groups + lr (float, optional): learning rate (default: 1e-3) + betas (Tuple[float, float], optional): coefficients used for computing + running averages of gradient and its square (default: (0.9, 0.999)) + eps (float, optional): term added to the denominator to improve + numerical stability (default: 1e-8) + weight_decay (float, optional): weight decay (L2 penalty) (default: 0) + decoupled_weight_decay (bool, optional): whether to use decoupled weight + decay as in AdamW to obtain RAdamW (default: False) + {_foreach_doc} + {_maximize_doc} + {_differentiable_doc} + {_capturable_doc} + + .. 
_On the variance of the adaptive learning rate and beyond: + https://arxiv.org/abs/1908.03265 + .. _author's implementation: + https://github.com/LiyuanLucasLiu/RAdam + .. _Decoupled Weight Decay Regularization: + https://arxiv.org/abs/1711.05101 + + """ +) + + +def _single_tensor_radam( + params: List[Tensor], + grads: List[Tensor], + exp_avgs: List[Tensor], + exp_avg_sqs: List[Tensor], + state_steps: List[Tensor], + *, + beta1: float, + beta2: float, + lr: float, + weight_decay: float, + eps: float, + decoupled_weight_decay: bool, + differentiable: bool, + maximize: bool, + capturable: bool, + has_complex: bool, +): + for i, param in enumerate(params): + grad = grads[i] if not maximize else -grads[i] + exp_avg = exp_avgs[i] + exp_avg_sq = exp_avg_sqs[i] + step_t = state_steps[i] + + # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable] + if not torch._utils.is_compiling() and capturable: + capturable_supported_devices = _get_capturable_supported_devices() + assert ( + param.device.type == step_t.device.type + and param.device.type in capturable_supported_devices + ), f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}." + + if torch.is_complex(param): + param = torch.view_as_real(param) + grad = torch.view_as_real(grad) + exp_avg = torch.view_as_real(exp_avg) + exp_avg_sq = torch.view_as_real(exp_avg_sq) + + # update step + step_t += 1 + step = step_t if capturable else _get_value(step_t) + + if weight_decay != 0: + if decoupled_weight_decay: + param.mul_(1 - lr * weight_decay) + else: + grad = grad.add(param, alpha=weight_decay) + + # Decay the first and second moment running average coefficient + exp_avg.lerp_(grad, 1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) + + bias_correction1 = 1 - beta1**step + bias_correction2 = 1 - beta2**step + + # correcting bias for the first moving moment + bias_corrected_exp_avg = exp_avg / bias_correction1 + + # maximum length of the approximated SMA + rho_inf = 2 / (1 - beta2) - 1 + # compute the length of the approximated SMA + rho_t = rho_inf - 2 * step * (beta2**step) / bias_correction2 + + def _compute_rect(): + return ( + (rho_t - 4) + * (rho_t - 2) + * rho_inf + / ((rho_inf - 4) * (rho_inf - 2) * rho_t) + ) ** 0.5 + + def _compute_adaptive_lr(): + exp_avg_sq_sqrt = exp_avg_sq.sqrt() + if differentiable: + exp_avg_sq_sqrt = exp_avg_sq_sqrt.add(eps) + else: + exp_avg_sq_sqrt = exp_avg_sq_sqrt.add_(eps) + + return (bias_correction2**0.5) / exp_avg_sq_sqrt + + # Compute the variance rectification term and update parameters accordingly + if capturable: + update = torch.where( + rho_t > 5.0, _compute_rect() * _compute_adaptive_lr(), 1.0 + ) + param.add_(bias_corrected_exp_avg * lr * update, alpha=-1.0) + else: + if rho_t > 5.0: + param.add_( + bias_corrected_exp_avg + * lr + * _compute_adaptive_lr() + * _compute_rect(), + alpha=-1.0, + ) + else: + param.add_(bias_corrected_exp_avg * lr, alpha=-1.0) + + +def _multi_tensor_radam( + params: List[Tensor], + grads: List[Tensor], + exp_avgs: List[Tensor], + exp_avg_sqs: List[Tensor], + state_steps: List[Tensor], + *, + beta1: float, + beta2: float, + lr: float, + weight_decay: float, + eps: float, + decoupled_weight_decay: bool, + differentiable: bool, + maximize: bool, + capturable: bool, + has_complex: bool, +): + if len(params) == 0: + return + + assert not differentiable, "_foreach ops don't support autograd" + + # If compiling, the compiler will handle cudagraph checks, see note 
[torch.compile x capturable] + if not torch._utils.is_compiling() and capturable: + capturable_supported_devices = _get_capturable_supported_devices( + supports_xla=False + ) + assert all( + p.device.type == step.device.type + and p.device.type in capturable_supported_devices + for p, step in zip(params, state_steps) + ), f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}." + + grouped_tensors = Optimizer._group_tensors_by_device_and_dtype( + [params, grads, exp_avgs, exp_avg_sqs, state_steps] + ) + for ( + grouped_params, + grouped_grads, + grouped_exp_avgs, + grouped_exp_avg_sqs, + grouped_state_steps, + ), _ in grouped_tensors.values(): + # Update steps + # If steps are on CPU, foreach will fall back to the slow path, which is a for-loop calling t.add(1) over + # and over. 1 will then be wrapped into a Tensor over and over again, which is slower than if we just + # wrapped it once now. The alpha is required to assure we go to the right overload. + if grouped_state_steps[0].is_cpu: + torch._foreach_add_( + grouped_state_steps, torch.tensor(1.0, device="cpu"), alpha=1.0 + ) + else: + torch._foreach_add_(grouped_state_steps, 1) + + if has_complex: + _view_as_real( + grouped_params, grouped_grads, grouped_exp_avgs, grouped_exp_avg_sqs + ) + + if maximize: + grouped_grads = torch._foreach_neg(grouped_grads) # type: ignore[assignment] + + # maximum length of the approximated SMA + rho_inf = 2 / (1 - beta2) - 1 + # compute the length of the approximated SMA + bias_correction1: Union[Tuple[Tensor, ...], List[Tensor]] + bias_correction2: Union[Tuple[Tensor, ...], List[Tensor]] + rho_t_list: Union[Tuple[Tensor, ...], List[Tensor]] + if capturable: + bias_correction1 = torch._foreach_pow(beta2, grouped_state_steps) + torch._foreach_neg_(bias_correction1) + torch._foreach_add_(bias_correction1, 1) + bias_correction2 = torch._foreach_pow(beta2, grouped_state_steps) + torch._foreach_mul_(bias_correction2, grouped_state_steps) + torch._foreach_mul_(bias_correction2, 2) + torch._foreach_div_(bias_correction2, bias_correction1) + torch._foreach_neg_(bias_correction2) + torch._foreach_add_(bias_correction2, rho_inf) + rho_t_list = bias_correction2 + else: + rho_t_list = [ + rho_inf + - 2 + * _get_value(step) + * (beta2 ** _get_value(step)) + / (1 - beta2 ** _get_value(step)) + for step in grouped_state_steps + ] + + if weight_decay != 0: + if decoupled_weight_decay: + torch._foreach_mul_(grouped_params, 1 - lr * weight_decay) + else: + # Re-use the intermediate memory (grouped_grads) already allocated for maximize + if maximize: + torch._foreach_add_( + grouped_grads, grouped_params, alpha=weight_decay + ) + else: + grouped_grads = torch._foreach_add( # type: ignore[assignment] + grouped_grads, grouped_params, alpha=weight_decay + ) + + # Decay the first and second moment running average coefficient + torch._foreach_lerp_(grouped_exp_avgs, grouped_grads, 1 - beta1) + + torch._foreach_mul_(grouped_exp_avg_sqs, beta2) + torch._foreach_addcmul_( + grouped_exp_avg_sqs, grouped_grads, grouped_grads, 1 - beta2 + ) + + # Delete the local intermediate since it won't be used anymore to save on peak memory + del grouped_grads + + if capturable: + num = torch._foreach_sub(rho_t_list, 4) + sub2 = torch._foreach_sub(rho_t_list, 2) + torch._foreach_mul_(num, sub2) + del sub2 + torch._foreach_mul_(num, rho_inf) + rho_inf = (rho_inf - 4) * (rho_inf - 2) + denom = torch._foreach_mul(rho_t_list, rho_inf) + torch._foreach_div_(num, denom) + del denom + 
torch._foreach_sqrt_(num) + + # TODO(mlazos): we should try and get a foreach_where op https://github.com/pytorch/pytorch/issues/117884 + rect = [ + torch.where(rho_t > 5.0, n, 0.0) for n, rho_t in zip(num, rho_t_list) + ] + del num + del rho_t_list + unrect_step_size = [torch.where(rect > 0, 0.0, 1.0) for rect in rect] + torch._foreach_mul_(unrect_step_size, lr) + + bias_correction1 = torch._foreach_pow(beta1, grouped_state_steps) + torch._foreach_neg_(bias_correction1) + torch._foreach_add_(bias_correction1, 1) + + torch._foreach_div_(unrect_step_size, bias_correction1) + torch._foreach_neg_(unrect_step_size) + + bias_correction2 = torch._foreach_pow(beta2, grouped_state_steps) + torch._foreach_neg_(bias_correction2) + torch._foreach_add_(bias_correction2, 1) + torch._foreach_sqrt_(bias_correction2) + torch._foreach_mul_(bias_correction2, lr) + torch._foreach_mul_(bias_correction2, rect) + del rect + torch._foreach_neg_(bias_correction2) + torch._foreach_div_(bias_correction2, bias_correction1) + del bias_correction1 + else: + rect = [ + _dispatch_sqrt( + (rho_t - 4) # type: ignore[arg-type] + * (rho_t - 2) + * rho_inf + / ((rho_inf - 4) * (rho_inf - 2) * rho_t) + ) + if rho_t > 5 + else 0 + for rho_t in rho_t_list + ] + unrectified = [0 if rect > 0 else 1.0 for rect in rect] + + bias_correction1 = [ + 1 - beta1 ** _get_value(step) for step in grouped_state_steps + ] + unrect_step_size = [ + (lr * rect / bc) * -1 for rect, bc in zip(unrectified, bias_correction1) + ] + bias_correction2 = [ + _dispatch_sqrt(1 - beta2 ** _get_value(step)) * (lr * rect / bc) * -1 + for step, rect, bc in zip(grouped_state_steps, rect, bias_correction1) + ] + + buffer = torch._foreach_sqrt(grouped_exp_avg_sqs) + torch._foreach_add_(buffer, eps) + torch._foreach_div_(buffer, bias_correction2) + torch._foreach_reciprocal_(buffer) + torch._foreach_add_(buffer, unrect_step_size) + + # Here, buffer = sqrt(1 - beta2^t) * rect_step_size / (sqrt(v) + eps) + unrect_step_size + torch._foreach_addcmul_(grouped_params, grouped_exp_avgs, buffer) + + +@_disable_dynamo_if_unsupported(single_tensor_fn=_single_tensor_radam) +def radam( + params: List[Tensor], + grads: List[Tensor], + exp_avgs: List[Tensor], + exp_avg_sqs: List[Tensor], + state_steps: List[Tensor], + # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627 + # setting this as kwarg for now as functional API is compiled by torch/distributed/optim + decoupled_weight_decay: bool = False, + foreach: Optional[bool] = None, + differentiable: bool = False, + capturable: bool = False, + has_complex: bool = False, + maximize: bool = False, + *, + beta1: float, + beta2: float, + lr: float, + weight_decay: float, + eps: float, +): + r"""Functional API that performs RAdam algorithm computation. + + See :class:`~torch.optim.RAdam` for details. 
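[Editor's note] A usage sketch for the RAdam class defined above; illustrative only, not part of the patch, and the engine.cr_boosters import path is an assumption taken from this patch's file layout:

import torch
from engine.cr_boosters.radam import RAdam  # assumed package path introduced by this patch

model = torch.nn.Linear(4, 2)
opt = RAdam(model.parameters(), lr=1e-3, decoupled_weight_decay=True)

x, y = torch.randn(8, 4), torch.randn(8, 2)
loss = torch.nn.functional.mse_loss(model(x), y)
opt.zero_grad()
loss.backward()
opt.step()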
+ """ + if not all(isinstance(t, torch.Tensor) for t in state_steps): + raise RuntimeError( + "API has changed, `state_steps` argument must contain a list of singleton tensors" + ) + + if foreach is None: + _, foreach = _default_to_fused_or_foreach( + params, differentiable, use_fused=False + ) + + if foreach and torch.jit.is_scripting(): + raise RuntimeError("torch.jit.script not supported with foreach optimizers") + + if foreach and not torch.jit.is_scripting(): + func = _multi_tensor_radam + else: + func = _single_tensor_radam + + func( + params, + grads, + exp_avgs, + exp_avg_sqs, + state_steps, + beta1=beta1, + beta2=beta2, + lr=lr, + weight_decay=weight_decay, + eps=eps, + maximize=maximize, + decoupled_weight_decay=decoupled_weight_decay, + differentiable=differentiable, + capturable=capturable, + has_complex=has_complex, + ) diff --git a/engine/cr_boosters/rmsprop.py b/engine/cr_boosters/rmsprop.py new file mode 100644 index 0000000..d3534a7 --- /dev/null +++ b/engine/cr_boosters/rmsprop.py @@ -0,0 +1,510 @@ +# mypy: allow-untyped-defs +r"""Implementation for the RMSprop algorithm.""" +from typing import List, Optional + +import torch +from torch import Tensor +from .optimizer import ( + _capturable_doc, + _default_to_fused_or_foreach, + _differentiable_doc, + _disable_dynamo_if_unsupported, + _foreach_doc, + _get_capturable_supported_devices, + _get_scalar_dtype, + _maximize_doc, + _use_grad_for_differentiable, + _view_as_real, + Optimizer, + ParamsT, +) + +__all__ = ["RMSprop", "rmsprop"] + + +class RMSprop(Optimizer): # noqa: D101 + def __init__( + self, + params: ParamsT, + lr: float = 1e-2, + alpha: float = 0.99, + eps: float = 1e-8, + weight_decay: float = 0, + momentum: float = 0, + centered=False, + capturable=False, + foreach: Optional[bool] = None, + maximize: bool = False, + differentiable: bool = False, + ): # noqa: D107 + if not 0.0 <= lr: + raise ValueError(f"Invalid learning rate: {lr}") + if not 0.0 <= eps: + raise ValueError(f"Invalid epsilon value: {eps}") + if not 0.0 <= momentum: + raise ValueError(f"Invalid momentum value: {momentum}") + if not 0.0 <= weight_decay: + raise ValueError(f"Invalid weight_decay value: {weight_decay}") + if not 0.0 <= alpha: + raise ValueError(f"Invalid alpha value: {alpha}") + + defaults = dict( + lr=lr, + momentum=momentum, + alpha=alpha, + eps=eps, + centered=centered, + weight_decay=weight_decay, + capturable=capturable, + foreach=foreach, + maximize=maximize, + differentiable=differentiable, + ) + super().__init__(params, defaults) + + def __setstate__(self, state): # noqa: D105 + super().__setstate__(state) + for group in self.param_groups: + group.setdefault("momentum", 0) + group.setdefault("centered", False) + group.setdefault("foreach", None) + group.setdefault("maximize", False) + group.setdefault("differentiable", False) + group.setdefault("capturable", False) + for p in group["params"]: + p_state = self.state.get(p, []) + if len(p_state) != 0 and not torch.is_tensor(p_state["step"]): + step_val = float(p_state["step"]) + p_state["step"] = ( + torch.tensor( + step_val, dtype=_get_scalar_dtype(), device=p.device + ) + if group["capturable"] + else torch.tensor(step_val, dtype=_get_scalar_dtype()) + ) + + def _init_group( + self, + group, + params_with_grad, + grads, + square_avgs, + momentum_buffer_list, + grad_avgs, + state_steps, + ): + has_complex = False + for p in group["params"]: + if p.grad is None: + continue + has_complex |= torch.is_complex(p) + params_with_grad.append(p) + + if p.grad.is_sparse: + raise 
RuntimeError("RMSprop does not support sparse gradients") + grads.append(p.grad) + + state = self.state[p] + + # State initialization + if len(state) == 0: + state["step"] = ( + torch.zeros((), dtype=_get_scalar_dtype(), device=p.device) + if group["capturable"] + else torch.zeros((), dtype=_get_scalar_dtype()) + ) + state["square_avg"] = torch.zeros_like( + p, memory_format=torch.preserve_format + ) + if group["momentum"] > 0: + state["momentum_buffer"] = torch.zeros_like( + p, memory_format=torch.preserve_format + ) + if group["centered"]: + state["grad_avg"] = torch.zeros_like( + p, memory_format=torch.preserve_format + ) + square_avgs.append(state["square_avg"]) + state_steps.append(state["step"]) + + if group["momentum"] > 0: + momentum_buffer_list.append(state["momentum_buffer"]) + if group["centered"]: + grad_avgs.append(state["grad_avg"]) + + return has_complex + + @_use_grad_for_differentiable + def step(self, closure=None): + """Perform a single optimization step. + + Args: + closure (Callable, optional): A closure that reevaluates the model + and returns the loss. + """ + self._cuda_graph_capture_health_check() + + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + params_with_grad: List[Tensor] = [] + grads: List[Tensor] = [] + square_avgs: List[Tensor] = [] + grad_avgs: List[Tensor] = [] + momentum_buffer_list: List[Tensor] = [] + state_steps: List[Tensor] = [] + + has_complex = self._init_group( + group, + params_with_grad, + grads, + square_avgs, + momentum_buffer_list, + grad_avgs, + state_steps, + ) + + rmsprop( + params_with_grad, + grads, + square_avgs, + grad_avgs, + momentum_buffer_list, + state_steps, + lr=group["lr"], + alpha=group["alpha"], + eps=group["eps"], + weight_decay=group["weight_decay"], + momentum=group["momentum"], + centered=group["centered"], + foreach=group["foreach"], + maximize=group["maximize"], + differentiable=group["differentiable"], + capturable=group["capturable"], + has_complex=has_complex, + ) + + return loss + + +RMSprop.__doc__ = ( + r"""Implements RMSprop algorithm. + + .. 
math:: + \begin{aligned} + &\rule{110mm}{0.4pt} \\ + &\textbf{input} : \alpha \text{ (alpha)},\: \gamma \text{ (lr)}, + \: \theta_0 \text{ (params)}, \: f(\theta) \text{ (objective)} \\ + &\hspace{13mm} \lambda \text{ (weight decay)},\: \mu \text{ (momentum)},\: centered\\ + &\textbf{initialize} : v_0 \leftarrow 0 \text{ (square average)}, \: + \textbf{b}_0 \leftarrow 0 \text{ (buffer)}, \: g^{ave}_0 \leftarrow 0 \\[-1.ex] + &\rule{110mm}{0.4pt} \\ + &\textbf{for} \: t=1 \: \textbf{to} \: \ldots \: \textbf{do} \\ + &\hspace{5mm}g_t \leftarrow \nabla_{\theta} f_t (\theta_{t-1}) \\ + &\hspace{5mm}if \: \lambda \neq 0 \\ + &\hspace{10mm} g_t \leftarrow g_t + \lambda \theta_{t-1} \\ + &\hspace{5mm}v_t \leftarrow \alpha v_{t-1} + (1 - \alpha) g^2_t + \hspace{8mm} \\ + &\hspace{5mm} \tilde{v_t} \leftarrow v_t \\ + &\hspace{5mm}if \: centered \\ + &\hspace{10mm} g^{ave}_t \leftarrow g^{ave}_{t-1} \alpha + (1-\alpha) g_t \\ + &\hspace{10mm} \tilde{v_t} \leftarrow \tilde{v_t} - \big(g^{ave}_{t} \big)^2 \\ + &\hspace{5mm}if \: \mu > 0 \\ + &\hspace{10mm} \textbf{b}_t\leftarrow \mu \textbf{b}_{t-1} + + g_t/ \big(\sqrt{\tilde{v_t}} + \epsilon \big) \\ + &\hspace{10mm} \theta_t \leftarrow \theta_{t-1} - \gamma \textbf{b}_t \\ + &\hspace{5mm} else \\ + &\hspace{10mm}\theta_t \leftarrow \theta_{t-1} - + \gamma g_t/ \big(\sqrt{\tilde{v_t}} + \epsilon \big) \hspace{3mm} \\ + &\rule{110mm}{0.4pt} \\[-1.ex] + &\bf{return} \: \theta_t \\[-1.ex] + &\rule{110mm}{0.4pt} \\[-1.ex] + \end{aligned} + + For further details regarding the algorithm we refer to + `lecture notes `_ by G. Hinton. + and centered version `Generating Sequences + With Recurrent Neural Networks `_. + The implementation here takes the square root of the gradient average before + adding epsilon (note that TensorFlow interchanges these two operations). The effective + learning rate is thus :math:`\gamma/(\sqrt{v} + \epsilon)` where :math:`\gamma` + is the scheduled learning rate and :math:`v` is the weighted moving average + of the squared gradient. 
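[Editor's note] To make the update rule above concrete, a hand-checked single step; editor's sketch, not part of the patch, using the stock torch.optim.RMSprop which this file mirrors:

import torch

p = torch.nn.Parameter(torch.tensor([2.0]))
opt = torch.optim.RMSprop([p], lr=0.1, alpha=0.99, eps=1e-8)

(3.0 * p).sum().backward()             # g = 3.0
opt.step()

g, lr, alpha, eps = 3.0, 0.1, 0.99, 1e-8
v = (1 - alpha) * g * g                # square average after the first step
print(p.item(), 2.0 - lr * g / (v ** 0.5 + eps))  # both are ~1.0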
+ """ + + rf""" + Args: + params (iterable): iterable of parameters to optimize or dicts defining + parameter groups + lr (float, optional): learning rate (default: 1e-2) + momentum (float, optional): momentum factor (default: 0) + alpha (float, optional): smoothing constant (default: 0.99) + eps (float, optional): term added to the denominator to improve + numerical stability (default: 1e-8) + centered (bool, optional) : if ``True``, compute the centered RMSProp, + the gradient is normalized by an estimation of its variance + weight_decay (float, optional): weight decay (L2 penalty) (default: 0) + {_foreach_doc} + {_maximize_doc} + {_capturable_doc} + {_differentiable_doc} + + """ +) + + +def _single_tensor_rmsprop( + params: List[Tensor], + grads: List[Tensor], + square_avgs: List[Tensor], + grad_avgs: List[Tensor], + momentum_buffer_list: List[Tensor], + state_steps: List[Tensor], + *, + lr: float, + alpha: float, + eps: float, + weight_decay: float, + momentum: float, + centered: bool, + maximize: bool, + differentiable: bool, + capturable: bool, + has_complex: bool, +): + for i, param in enumerate(params): + step = state_steps[i] + + # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable] + if not torch._utils.is_compiling() and capturable: + capturable_supported_devices = _get_capturable_supported_devices() + assert ( + param.device.type == step.device.type + and param.device.type in capturable_supported_devices + ), f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}." + + grad = grads[i] + grad = grad if not maximize else -grad + square_avg = square_avgs[i] + + step += 1 + + if weight_decay != 0: + grad = grad.add(param, alpha=weight_decay) + + is_complex_param = torch.is_complex(param) + if is_complex_param: + param = torch.view_as_real(param) + grad = torch.view_as_real(grad) + square_avg = torch.view_as_real(square_avg) + + square_avg.mul_(alpha).addcmul_(grad, grad, value=1 - alpha) + + if centered: + grad_avg = grad_avgs[i] + if is_complex_param: + grad_avg = torch.view_as_real(grad_avg) + grad_avg.lerp_(grad, 1 - alpha) + avg = square_avg.addcmul(grad_avg, grad_avg, value=-1).sqrt_() + else: + avg = square_avg.sqrt() + + if differentiable: + avg = avg.add(eps) + else: + avg = avg.add_(eps) + + if momentum > 0: + buf = momentum_buffer_list[i] + if is_complex_param: + buf = torch.view_as_real(buf) + buf.mul_(momentum).addcdiv_(grad, avg) + param.add_(buf, alpha=-lr) + else: + param.addcdiv_(grad, avg, value=-lr) + + +def _multi_tensor_rmsprop( + params: List[Tensor], + grads: List[Tensor], + square_avgs: List[Tensor], + grad_avgs: List[Tensor], + momentum_buffer_list: List[Tensor], + state_steps: List[Tensor], + *, + lr: float, + alpha: float, + eps: float, + weight_decay: float, + momentum: float, + centered: bool, + maximize: bool, + differentiable: bool, + capturable: bool, + has_complex: bool, +): + if len(params) == 0: + return + + assert not differentiable, "_foreach ops don't support autograd" + + # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable] + if not torch._utils.is_compiling() and capturable: + capturable_supported_devices = _get_capturable_supported_devices() + assert all( + p.device.type == step.device.type + and p.device.type in capturable_supported_devices + for p, step in zip(params, state_steps) + ), f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}." 
+ + grouped_tensors = Optimizer._group_tensors_by_device_and_dtype( + [params, grads, square_avgs, grad_avgs, momentum_buffer_list, state_steps] + ) + for ( + ( + grouped_params, + grouped_grads, + grouped_square_avgs, + grouped_grad_avgs, + grouped_momentum_buffer_list, + grouped_state_steps, + ) + ), _ in grouped_tensors.values(): + if has_complex: + state_and_grads = [grouped_grads, grouped_square_avgs] + if momentum > 0: + state_and_grads.append(grouped_momentum_buffer_list) + if centered: + state_and_grads.append(grouped_grad_avgs) + _view_as_real(grouped_params, *state_and_grads) + + if maximize: + grouped_grads = torch._foreach_neg(grouped_grads) # type: ignore[assignment] + + # Update steps + # If steps are on CPU, foreach will fall back to the slow path, which is a for-loop calling t.add(1) over + # and over. 1 will then be wrapped into a Tensor over and over again, which is slower than if we just + # wrapped it once now. The alpha is required to assure we go to the right overload. + if grouped_state_steps[0].is_cpu: + torch._foreach_add_( + grouped_state_steps, torch.tensor(1.0, device="cpu"), alpha=1.0 + ) + else: + torch._foreach_add_(grouped_state_steps, 1) + + if weight_decay != 0: + # Re-use the intermediate memory (grouped_grads) already allocated for maximize + if maximize: + torch._foreach_add_(grouped_grads, grouped_params, alpha=weight_decay) + else: + grouped_grads = torch._foreach_add( # type: ignore[assignment] + grouped_grads, grouped_params, alpha=weight_decay + ) + + torch._foreach_mul_(grouped_square_avgs, alpha) + torch._foreach_addcmul_( + grouped_square_avgs, grouped_grads, grouped_grads, value=1 - alpha + ) + + if centered: + torch._foreach_lerp_(grouped_grad_avgs, grouped_grads, 1 - alpha) + avg = torch._foreach_addcmul( + grouped_square_avgs, grouped_grad_avgs, grouped_grad_avgs, value=-1 + ) + torch._foreach_sqrt_(avg) + torch._foreach_add_(avg, eps) + else: + avg = torch._foreach_sqrt(grouped_square_avgs) + torch._foreach_add_(avg, eps) + + if momentum > 0: + torch._foreach_mul_(grouped_momentum_buffer_list, momentum) + torch._foreach_addcdiv_(grouped_momentum_buffer_list, grouped_grads, avg) + # If LR is a tensor, the else branch will internally call item() + # which will cause silent incorrectness if we are capturing + if capturable and isinstance(lr, torch.Tensor): + momentum_lr = torch._foreach_mul(grouped_momentum_buffer_list, -lr) + torch._foreach_add_(grouped_params, momentum_lr) + else: + torch._foreach_add_( + grouped_params, grouped_momentum_buffer_list, alpha=-lr + ) + else: + # If LR is a tensor, the else branch will internally call item() + # which will cause silent incorrectness if we are capturing + if capturable and isinstance(lr, torch.Tensor): + torch._foreach_div_(avg, -lr) + torch._foreach_addcdiv_(grouped_params, grouped_grads, avg) + else: + torch._foreach_addcdiv_(grouped_params, grouped_grads, avg, value=-lr) + + +@_disable_dynamo_if_unsupported(single_tensor_fn=_single_tensor_rmsprop) +def rmsprop( + params: List[Tensor], + grads: List[Tensor], + square_avgs: List[Tensor], + grad_avgs: List[Tensor], + momentum_buffer_list: List[Tensor], + state_steps: List[Tensor], + # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627 + # setting this as kwarg for now as functional API is compiled by torch/distributed/optim + foreach: Optional[bool] = None, + maximize: bool = False, + differentiable: bool = False, + capturable: bool = False, + has_complex: bool = False, + *, + lr: float, + alpha: 
float, + eps: float, + weight_decay: float, + momentum: float, + centered: bool, +): + r"""Functional API that performs rmsprop algorithm computation. + + See :class:`~torch.optim.RMSProp` for details. + """ + # this check is slow during compilation, so we skip it + # if it's strictly needed we can add this check back in dynamo + if not torch._utils.is_compiling() and not all( + isinstance(t, torch.Tensor) for t in state_steps + ): + raise RuntimeError( + "API has changed, `state_steps` argument must contain a list of singleton tensors" + ) + + if foreach is None: + _, foreach = _default_to_fused_or_foreach( + params, differentiable, use_fused=False + ) + + if foreach and torch.jit.is_scripting(): + raise RuntimeError("torch.jit.script not supported with foreach optimizers") + + if foreach and not torch.jit.is_scripting(): + func = _multi_tensor_rmsprop + else: + func = _single_tensor_rmsprop + + func( + params, + grads, + square_avgs, + grad_avgs, + momentum_buffer_list, + state_steps, + lr=lr, + alpha=alpha, + eps=eps, + weight_decay=weight_decay, + momentum=momentum, + centered=centered, + maximize=maximize, + capturable=capturable, + differentiable=differentiable, + has_complex=has_complex, + ) diff --git a/engine/cr_boosters/sgd.py b/engine/cr_boosters/sgd.py new file mode 100644 index 0000000..418ef19 --- /dev/null +++ b/engine/cr_boosters/sgd.py @@ -0,0 +1,504 @@ +# mypy: allow-untyped-defs +r"""Implementation for Stochastic Gradient Descent optimizer.""" +from typing import List, Optional + +import torch +from torch import Tensor +from torch.utils._foreach_utils import _get_fused_kernels_supported_devices +from .optimizer import ( + _default_to_fused_or_foreach, + _differentiable_doc, + _foreach_doc, + _fused_doc, + _maximize_doc, + _use_grad_for_differentiable, + DeviceDict, + Optimizer, +) + +__all__ = ["SGD", "sgd"] + + +class SGD(Optimizer): # noqa: D101 + def __init__( + self, + params, + lr: float = 1e-3, + momentum: float = 0, + dampening: float = 0, + weight_decay: float = 0, + nesterov=False, + *, + maximize: bool = False, + foreach: Optional[bool] = None, + differentiable: bool = False, + fused: Optional[bool] = None, + ): # noqa: D107 + if lr < 0.0: + raise ValueError(f"Invalid learning rate: {lr}") + if momentum < 0.0: + raise ValueError(f"Invalid momentum value: {momentum}") + if weight_decay < 0.0: + raise ValueError(f"Invalid weight_decay value: {weight_decay}") + + defaults = dict( + lr=lr, + momentum=momentum, + dampening=dampening, + weight_decay=weight_decay, + nesterov=nesterov, + maximize=maximize, + foreach=foreach, + differentiable=differentiable, + fused=fused, + ) + if nesterov and (momentum <= 0 or dampening != 0): + raise ValueError("Nesterov momentum requires a momentum and zero dampening") + super().__init__(params, defaults) + + if fused: + self._step_supports_amp_scaling = True + + fused_supported_devices = _get_fused_kernels_supported_devices() + if not all( + p.device.type in fused_supported_devices and torch.is_floating_point(p) + for pg in self.param_groups + for p in pg["params"] + ): + raise RuntimeError( + "`fused=True` requires all the params to be floating point Tensors of " + f"supported devices: {fused_supported_devices}." 
+ ) + if differentiable: + raise RuntimeError("`fused` does not support `differentiable`") + if foreach: + raise RuntimeError("`fused` and `foreach` cannot be `True` together.") + + def __setstate__(self, state): # noqa: D105 + super().__setstate__(state) + for group in self.param_groups: + group.setdefault("nesterov", False) + group.setdefault("maximize", False) + group.setdefault("foreach", None) + group.setdefault("differentiable", False) + group.setdefault("fused", False) + + def _init_group(self, group, params, grads, momentum_buffer_list): + has_sparse_grad = False + + for p in group["params"]: + if p.grad is not None: + params.append(p) + grads.append(p.grad) + if p.grad.is_sparse: + has_sparse_grad = True + + if group["momentum"] != 0: + state = self.state[p] + momentum_buffer_list.append(state.get("momentum_buffer")) + + return has_sparse_grad + + @_use_grad_for_differentiable + def step(self, closure=None): + """Perform a single optimization step. + + Args: + closure (Callable, optional): A closure that reevaluates the model + and returns the loss. + """ + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + params: List[Tensor] = [] + grads: List[Tensor] = [] + momentum_buffer_list: List[Optional[Tensor]] = [] + + has_sparse_grad = self._init_group( + group, params, grads, momentum_buffer_list + ) + + sgd( + params, + grads, + momentum_buffer_list, + weight_decay=group["weight_decay"], + momentum=group["momentum"], + lr=group["lr"], + dampening=group["dampening"], + nesterov=group["nesterov"], + maximize=group["maximize"], + has_sparse_grad=has_sparse_grad, + foreach=group["foreach"], + fused=group["fused"], + grad_scale=getattr(self, "grad_scale", None), + found_inf=getattr(self, "found_inf", None), + ) + + if group["momentum"] != 0: + # update momentum_buffers in state + for p, momentum_buffer in zip(params, momentum_buffer_list): + state = self.state[p] + state["momentum_buffer"] = momentum_buffer + + return loss + + +SGD.__doc__ = ( + r"""Implements stochastic gradient descent (optionally with momentum). + + .. 
math:: + \begin{aligned} + &\rule{110mm}{0.4pt} \\ + &\textbf{input} : \gamma \text{ (lr)}, \: \theta_0 \text{ (params)}, \: f(\theta) + \text{ (objective)}, \: \lambda \text{ (weight decay)}, \\ + &\hspace{13mm} \:\mu \text{ (momentum)}, \:\tau \text{ (dampening)}, + \:\textit{ nesterov,}\:\textit{ maximize} \\[-1.ex] + &\rule{110mm}{0.4pt} \\ + &\textbf{for} \: t=1 \: \textbf{to} \: \ldots \: \textbf{do} \\ + &\hspace{5mm}g_t \leftarrow \nabla_{\theta} f_t (\theta_{t-1}) \\ + &\hspace{5mm}\textbf{if} \: \lambda \neq 0 \\ + &\hspace{10mm} g_t \leftarrow g_t + \lambda \theta_{t-1} \\ + &\hspace{5mm}\textbf{if} \: \mu \neq 0 \\ + &\hspace{10mm}\textbf{if} \: t > 1 \\ + &\hspace{15mm} \textbf{b}_t \leftarrow \mu \textbf{b}_{t-1} + (1-\tau) g_t \\ + &\hspace{10mm}\textbf{else} \\ + &\hspace{15mm} \textbf{b}_t \leftarrow g_t \\ + &\hspace{10mm}\textbf{if} \: \textit{nesterov} \\ + &\hspace{15mm} g_t \leftarrow g_{t} + \mu \textbf{b}_t \\ + &\hspace{10mm}\textbf{else} \\[-1.ex] + &\hspace{15mm} g_t \leftarrow \textbf{b}_t \\ + &\hspace{5mm}\textbf{if} \: \textit{maximize} \\ + &\hspace{10mm}\theta_t \leftarrow \theta_{t-1} + \gamma g_t \\[-1.ex] + &\hspace{5mm}\textbf{else} \\[-1.ex] + &\hspace{10mm}\theta_t \leftarrow \theta_{t-1} - \gamma g_t \\[-1.ex] + &\rule{110mm}{0.4pt} \\[-1.ex] + &\bf{return} \: \theta_t \\[-1.ex] + &\rule{110mm}{0.4pt} \\[-1.ex] + \end{aligned} + + Nesterov momentum is based on the formula from + `On the importance of initialization and momentum in deep learning`__. + """ + + rf""" + Args: + params (iterable): iterable of parameters to optimize or dicts defining + parameter groups + lr (float, optional): learning rate (default: 1e-3) + momentum (float, optional): momentum factor (default: 0) + weight_decay (float, optional): weight decay (L2 penalty) (default: 0) + dampening (float, optional): dampening for momentum (default: 0) + nesterov (bool, optional): enables Nesterov momentum (default: False) + {_maximize_doc} + {_foreach_doc} + {_differentiable_doc} + {_fused_doc} + """ + + r""" + + Example: + >>> # xdoctest: +SKIP + >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9) + >>> optimizer.zero_grad() + >>> loss_fn(model(input), target).backward() + >>> optimizer.step() + + __ http://www.cs.toronto.edu/%7Ehinton/absps/momentum.pdf + + .. note:: + The implementation of SGD with Momentum/Nesterov subtly differs from + Sutskever et al. and implementations in some other frameworks. + + Considering the specific case of Momentum, the update can be written as + + .. math:: + \begin{aligned} + v_{t+1} & = \mu * v_{t} + g_{t+1}, \\ + p_{t+1} & = p_{t} - \text{lr} * v_{t+1}, + \end{aligned} + + where :math:`p`, :math:`g`, :math:`v` and :math:`\mu` denote the + parameters, gradient, velocity, and momentum respectively. + + This is in contrast to Sutskever et al. and + other frameworks which employ an update of the form + + .. math:: + \begin{aligned} + v_{t+1} & = \mu * v_{t} + \text{lr} * g_{t+1}, \\ + p_{t+1} & = p_{t} - v_{t+1}. + \end{aligned} + + The Nesterov version is analogously modified. + + Moreover, the initial value of the momentum buffer is set to the + gradient value at the first step. This is in contrast to some other + frameworks that initialize it to all zeros. 
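[Editor's note] The statement above that the momentum buffer is seeded with the first gradient can be checked directly; editor's sketch, not part of the patch:

import torch

p = torch.nn.Parameter(torch.tensor([1.0]))
opt = torch.optim.SGD([p], lr=0.1, momentum=0.9)

(2.0 * p).sum().backward()                 # g = 2.0
opt.step()
print(opt.state[p]["momentum_buffer"])     # tensor([2.]) -- equals the first gradient
print(p.item())                            # 1.0 - 0.1 * 2.0 = 0.8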
+ + """ +) + + +def sgd( + params: List[Tensor], + d_p_list: List[Tensor], + momentum_buffer_list: List[Optional[Tensor]], + # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627 + # setting this as kwarg for now as functional API is compiled by torch/distributed/optim + has_sparse_grad: bool = False, + foreach: Optional[bool] = None, + fused: Optional[bool] = None, + grad_scale: Optional[Tensor] = None, + found_inf: Optional[Tensor] = None, + *, + weight_decay: float, + momentum: float, + lr: float, + dampening: float, + nesterov: bool, + maximize: bool, +): + r"""Functional API that performs SGD algorithm computation. + + See :class:`~torch.optim.SGD` for details. + """ + # Respect when the user inputs False/True for foreach or fused. We only want to change + # the default when neither have been user-specified. Note that we default to foreach + # and pass False to use_fused. This is not a mistake--we want to give the fused impl + # bake-in time before making it the default, even if it is typically faster. + if foreach is None and fused is None: + # why must we be explicit about an if statement for torch.jit.is_scripting here? + # because JIT can't handle Optionals nor fancy conditionals when scripting + if not torch.jit.is_scripting(): + fused, foreach = _default_to_fused_or_foreach( + params, differentiable=False, use_fused=False + ) + else: + foreach = False + fused = False + if foreach is None: + foreach = False + if fused is None: + fused = False + + if foreach and torch.jit.is_scripting(): + raise RuntimeError("torch.jit.script not supported with foreach optimizers") + if fused and torch.jit.is_scripting(): + raise RuntimeError("torch.jit.script not supported with fused optimizers") + + if foreach and not torch.jit.is_scripting(): + func = _multi_tensor_sgd + elif fused and not torch.jit.is_scripting(): + func = _fused_sgd + else: + func = _single_tensor_sgd + + func( + params, + d_p_list, + momentum_buffer_list, + weight_decay=weight_decay, + momentum=momentum, + lr=lr, + dampening=dampening, + nesterov=nesterov, + has_sparse_grad=has_sparse_grad, + maximize=maximize, + grad_scale=grad_scale, + found_inf=found_inf, + ) + + +def _single_tensor_sgd( + params: List[Tensor], + grads: List[Tensor], + momentum_buffer_list: List[Optional[Tensor]], + grad_scale: Optional[Tensor], + found_inf: Optional[Tensor], + *, + weight_decay: float, + momentum: float, + lr: float, + dampening: float, + nesterov: bool, + maximize: bool, + has_sparse_grad: bool, +): + assert grad_scale is None and found_inf is None + + for i, param in enumerate(params): + grad = grads[i] if not maximize else -grads[i] + + if weight_decay != 0: + grad = grad.add(param, alpha=weight_decay) + + if momentum != 0: + buf = momentum_buffer_list[i] + + if buf is None: + buf = torch.clone(grad).detach() + momentum_buffer_list[i] = buf + else: + buf.mul_(momentum).add_(grad, alpha=1 - dampening) + + if nesterov: + grad = grad.add(buf, alpha=momentum) + else: + grad = buf + + param.add_(grad, alpha=-lr) + + +def _multi_tensor_sgd( + params: List[Tensor], + grads: List[Tensor], + momentum_buffer_list: List[Optional[Tensor]], + grad_scale: Optional[Tensor], + found_inf: Optional[Tensor], + *, + weight_decay: float, + momentum: float, + lr: float, + dampening: float, + nesterov: bool, + maximize: bool, + has_sparse_grad: bool, +): + assert grad_scale is None and found_inf is None + + if len(params) == 0: + return + + grouped_tensors = Optimizer._group_tensors_by_device_and_dtype( + [params, 
grads, momentum_buffer_list], with_indices=True # type: ignore[list-item] + ) + for ( + device_params, + device_grads, + device_momentum_buffer_list, + ), indices in grouped_tensors.values(): + device_has_sparse_grad = has_sparse_grad and any( + grad.is_sparse for grad in device_grads + ) + + if maximize: + device_grads = torch._foreach_neg(device_grads) # type: ignore[assignment] + + if weight_decay != 0: + # Re-use the intermediate memory (device_grads) already allocated for maximize + if maximize: + torch._foreach_add_(device_grads, device_params, alpha=weight_decay) + else: + device_grads = torch._foreach_add( # type: ignore[assignment] + device_grads, device_params, alpha=weight_decay + ) + + if momentum != 0: + bufs = [] + + all_states_with_momentum_buffer = True + for i in range(len(device_momentum_buffer_list)): + if device_momentum_buffer_list[i] is None: + all_states_with_momentum_buffer = False + break + else: + bufs.append(device_momentum_buffer_list[i]) + + if all_states_with_momentum_buffer: + torch._foreach_mul_(bufs, momentum) + torch._foreach_add_(bufs, device_grads, alpha=1 - dampening) + else: + bufs = [] + for i in range(len(device_momentum_buffer_list)): + if device_momentum_buffer_list[i] is None: + buf = device_momentum_buffer_list[i] = momentum_buffer_list[ + indices[i] + ] = torch.clone(device_grads[i]).detach() + else: + buf = device_momentum_buffer_list[i] + buf.mul_(momentum).add_(device_grads[i], alpha=1 - dampening) + + bufs.append(buf) + + if nesterov: + torch._foreach_add_(device_grads, bufs, alpha=momentum) + else: + device_grads = bufs + + if not device_has_sparse_grad: + # handle internal item() call if lr is a tensor + if isinstance(lr, torch.Tensor) and torch._utils.is_compiling(): + grads_x_lr = torch._foreach_mul(device_grads, -lr) + torch._foreach_add_(device_params, grads_x_lr) + else: + torch._foreach_add_(device_params, device_grads, alpha=-lr) + else: + # foreach APIs don't support sparse + for i in range(len(device_params)): + device_params[i].add_(device_grads[i], alpha=-lr) + + +def _fused_sgd( + params: List[Tensor], + grads: List[Tensor], + momentum_buffer_list: List[Optional[Tensor]], + grad_scale: Optional[Tensor], + found_inf: Optional[Tensor], + *, + weight_decay: float, + momentum: float, + lr: float, + dampening: float, + nesterov: bool, + maximize: bool, + has_sparse_grad: bool, +) -> None: + if not params: + return + if has_sparse_grad: + raise RuntimeError("`_fused_sgd` does not support sparse gradients") + grad_scale_dict: DeviceDict = ( + {grad_scale.device: grad_scale} if grad_scale is not None else {} + ) + found_inf_dict: DeviceDict = ( + {found_inf.device: found_inf} if found_inf is not None else {} + ) + + no_momentum_buffer = momentum == 0 + is_first_step = ( + all(t is None for t in momentum_buffer_list) and not no_momentum_buffer + ) + if is_first_step: + for i, g in enumerate(grads): + momentum_buffer_list[i] = torch.empty_like(g) + grouped_tensors = Optimizer._group_tensors_by_device_and_dtype( + [params, grads, momentum_buffer_list], with_indices=False # type: ignore[list-item] + ) + for (device, _), ( + (device_params, device_grads, device_momentum_buffer_list), + _, + ) in grouped_tensors.items(): + device_grad_scale, device_found_inf = None, None + if grad_scale is not None: + device_grad_scale = grad_scale_dict.setdefault( + device, grad_scale.to(device) + ) + if found_inf_dict is not None and found_inf is not None: + device_found_inf = found_inf_dict.setdefault(device, found_inf.to(device)) + torch._fused_sgd_( + 
device_params, + device_grads, + [] if no_momentum_buffer else device_momentum_buffer_list, + weight_decay=weight_decay, + momentum=momentum, + lr=lr, + dampening=dampening, + nesterov=nesterov, + maximize=maximize, + is_first_step=is_first_step, + grad_scale=device_grad_scale, + found_inf=device_found_inf, + ) diff --git a/engine/cr_boosters/sparse_adam.py b/engine/cr_boosters/sparse_adam.py new file mode 100644 index 0000000..adb7c17 --- /dev/null +++ b/engine/cr_boosters/sparse_adam.py @@ -0,0 +1,181 @@ +# mypy: allow-untyped-defs +from typing import List, Tuple + +import torch +from torch import Tensor +from . import _functional as F +from .optimizer import _maximize_doc, Optimizer, ParamsT + +__all__ = ["SparseAdam"] + + +class SparseAdam(Optimizer): + def __init__( + self, + params: ParamsT, + lr: float = 1e-3, + betas: Tuple[float, float] = (0.9, 0.999), + eps: float = 1e-8, + maximize: bool = False, + ): + if not 0.0 < lr: + raise ValueError(f"Invalid learning rate: {lr}") + if not 0.0 < eps: + raise ValueError(f"Invalid epsilon value: {eps}") + if not 0.0 <= betas[0] < 1.0: + raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}") + if not 0.0 <= betas[1] < 1.0: + raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}") + + defaults = dict(lr=lr, betas=betas, eps=eps, maximize=maximize) + super().__init__(params, defaults) + + sparse_params = [] + complex_params = [] + for index, param_group in enumerate(self.param_groups): + assert isinstance( + param_group, dict + ), f"param_groups must be a list of dicts, but got {type(param_group)}" + # given param group, convert given params to a list first before iterating + for d_index, d_param in enumerate(param_group["params"]): + if d_param.is_sparse: + sparse_params.append([index, d_index]) + if d_param.is_complex(): + complex_params.append([index, d_index]) + if sparse_params: + raise ValueError( + f"Sparse params at indices {sparse_params}: SparseAdam requires dense parameter tensors" + ) + if complex_params: + raise ValueError( + f"Complex params at indices {complex_params}: SparseAdam does not support complex parameters" + ) + + @torch.no_grad() + def step(self, closure=None): + """Perform a single optimization step. + + Args: + closure (Callable, optional): A closure that reevaluates the model + and returns the loss. 
+ """ + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + params_with_grad: List[Tensor] = [] + grads: List[Tensor] = [] + exp_avgs: List[Tensor] = [] + exp_avg_sqs: List[Tensor] = [] + state_steps: List[int] = [] + beta1, beta2 = group["betas"] + maximize = group.get("maximize", False) + + for p in group["params"]: + if p.grad is not None: + params_with_grad.append(p) + if not p.grad.is_sparse: + raise RuntimeError( + "SparseAdam does not support dense gradients, please consider Adam instead" + ) + grads.append(p.grad) + + state = self.state[p] + + # State initialization + if len(state) == 0: + state["step"] = 0 + # Exponential moving average of gradient values + state["exp_avg"] = torch.zeros_like( + p, memory_format=torch.preserve_format + ) + # Exponential moving average of squared gradient values + state["exp_avg_sq"] = torch.zeros_like( + p, memory_format=torch.preserve_format + ) + + exp_avgs.append(state["exp_avg"]) + exp_avg_sqs.append(state["exp_avg_sq"]) + + # update the steps for each param group update + state["step"] += 1 + # record the step after step update + state_steps.append(state["step"]) + + F.sparse_adam( + params_with_grad, + grads, + exp_avgs, + exp_avg_sqs, + state_steps, + eps=group["eps"], + beta1=beta1, + beta2=beta2, + lr=group["lr"], + maximize=maximize, + ) + + return loss + + +SparseAdam.__doc__ = rf"""SparseAdam implements a masked version of the Adam algorithm + suitable for sparse gradients. Currently, due to implementation constraints (explained + below), SparseAdam is only intended for a narrow subset of use cases, specifically + parameters of a dense layout with gradients of a sparse layout. This occurs in a + special case where the module backwards produces grads already in a sparse layout. + One example NN module that behaves as such is ``nn.Embedding(sparse=True)``. + + SparseAdam approximates the Adam algorithm by masking out the parameter and moment + updates corresponding to the zero values in the gradients. Whereas the Adam algorithm + will update the first moment, the second moment, and the parameters based on all values + of the gradients, SparseAdam only updates the moments and parameters corresponding + to the non-zero values of the gradients. + + A simplified way of thinking about the `intended` implementation is as such: + + 1. Create a mask of the non-zero values in the sparse gradients. For example, + if your gradient looks like [0, 5, 0, 0, 9], the mask would be [0, 1, 0, 0, 1]. + 2. Apply this mask over the running moments and do computation on only the + non-zero values. + 3. Apply this mask over the parameters and only apply an update on non-zero values. + + In actuality, we use sparse layout Tensors to optimize this approximation, which means the + more gradients that are masked by not being materialized, the more performant the optimization. + Since we rely on using sparse layout tensors, we infer that any materialized value in the + sparse layout is non-zero and we do NOT actually verify that all values are not zero! + It is important to not conflate a semantically sparse tensor (a tensor where many + of its values are zeros) with a sparse layout tensor (a tensor where ``.is_sparse`` + returns ``True``). The SparseAdam approximation is intended for `semantically` sparse + tensors and the sparse layout is only a implementation detail. A clearer implementation + would be to use MaskedTensors, but those are experimental. + + + .. 
note:: + + If you suspect your gradients are semantically sparse (but do not have sparse + layout), this variant may not be the best for you. Ideally, you want to avoid + materializing anything that is suspected to be sparse in the first place, since + needing to convert all your grads from dense layout to sparse layout may outweigh + the performance gain. Here, using Adam may be the best alternative, unless you + can easily rig up your module to output sparse grads similar to + ``nn.Embedding(sparse=True)``. If you insist on converting your grads, you can do + so by manually overriding your parameters' ``.grad`` fields with their sparse + equivalents before calling ``.step()``. + + + Args: + params (iterable): iterable of parameters to optimize or dicts defining + parameter groups + lr (float, optional): learning rate (default: 1e-3) + betas (Tuple[float, float], optional): coefficients used for computing + running averages of gradient and its square (default: (0.9, 0.999)) + eps (float, optional): term added to the denominator to improve + numerical stability (default: 1e-8) + {_maximize_doc} + + .. _Adam\: A Method for Stochastic Optimization: + https://arxiv.org/abs/1412.6980 + + """ diff --git a/engine/cr_boosters/swa_utils.py b/engine/cr_boosters/swa_utils.py new file mode 100644 index 0000000..a17f387 --- /dev/null +++ b/engine/cr_boosters/swa_utils.py @@ -0,0 +1,463 @@ +# mypy: allow-untyped-defs +r"""Implementation for Stochastic Weight Averaging implementation.""" +import itertools +import math +import warnings +from copy import deepcopy +from typing import Any, Callable, Iterable, List, Literal, Optional, Tuple, Union + +import torch +from torch import Tensor +from torch.nn import Module +from torch.optim.lr_scheduler import _format_param, LRScheduler +from torch.utils._foreach_utils import _get_foreach_kernels_supported_devices +from .optimizer import Optimizer + +__all__ = [ + "AveragedModel", + "update_bn", + "SWALR", + "get_ema_multi_avg_fn", + "get_swa_multi_avg_fn", + "get_ema_avg_fn", + "get_swa_avg_fn", +] + +from torch.utils._foreach_utils import _group_tensors_by_device_and_dtype + +PARAM_LIST = Union[Tuple[Tensor, ...], List[Tensor]] + + +def get_ema_multi_avg_fn(decay=0.999): + """Get the function applying exponential moving average (EMA) across multiple params.""" + + @torch.no_grad() + def ema_update(ema_param_list: PARAM_LIST, current_param_list: PARAM_LIST, _): + # foreach lerp only handles float and complex + if torch.is_floating_point(ema_param_list[0]) or torch.is_complex( + ema_param_list[0] + ): + torch._foreach_lerp_(ema_param_list, current_param_list, 1 - decay) + else: + for p_ema, p_model in zip(ema_param_list, current_param_list): + p_ema.copy_(p_ema * decay + p_model * (1 - decay)) + + return ema_update + + +def get_swa_multi_avg_fn(): + """Get the function applying stochastic weight average (SWA) across multiple params.""" + + @torch.no_grad() + def swa_update( + averaged_param_list: PARAM_LIST, + current_param_list: PARAM_LIST, + num_averaged: Union[Tensor, int], + ): + # foreach lerp only handles float and complex + if torch.is_floating_point(averaged_param_list[0]) or torch.is_complex( + averaged_param_list[0] + ): + torch._foreach_lerp_( + averaged_param_list, current_param_list, 1 / (num_averaged + 1) + ) + else: + diffs = torch._foreach_sub(current_param_list, averaged_param_list) + if isinstance(num_averaged, Tensor): + torch._foreach_addcdiv_( + averaged_param_list, + diffs, + [num_averaged + 1] * len(averaged_param_list), + ) + else: + 
torch._foreach_add_( + averaged_param_list, diffs, alpha=1.0 / (num_averaged + 1) + ) + + return swa_update + + +def get_ema_avg_fn(decay=0.999): + """Get the function applying exponential moving average (EMA) across a single param.""" + + @torch.no_grad() + def ema_update(ema_param: Tensor, current_param: Tensor, num_averaged): + return decay * ema_param + (1 - decay) * current_param + + return ema_update + + +def get_swa_avg_fn(): + """Get the function applying stochastic weight average (SWA) across a single param.""" + + @torch.no_grad() + def swa_update( + averaged_param: Tensor, current_param: Tensor, num_averaged: Union[Tensor, int] + ): + return averaged_param + (current_param - averaged_param) / (num_averaged + 1) + + return swa_update + + +class AveragedModel(Module): + r"""Implements averaged model for Stochastic Weight Averaging (SWA) and Exponential Moving Average (EMA). + + Stochastic Weight Averaging was proposed in `Averaging Weights Leads to + Wider Optima and Better Generalization`_ by Pavel Izmailov, Dmitrii + Podoprikhin, Timur Garipov, Dmitry Vetrov and Andrew Gordon Wilson + (UAI 2018). + + Exponential Moving Average is a variation of `Polyak averaging`_, + but using exponential weights instead of equal weights across iterations. + + AveragedModel class creates a copy of the provided module :attr:`model` + on the device :attr:`device` and allows to compute running averages of the + parameters of the :attr:`model`. + + Args: + model (torch.nn.Module): model to use with SWA/EMA + device (torch.device, optional): if provided, the averaged model will be + stored on the :attr:`device` + avg_fn (function, optional): the averaging function used to update + parameters; the function must take in the current value of the + :class:`AveragedModel` parameter, the current value of :attr:`model` + parameter, and the number of models already averaged; if None, + an equally weighted average is used (default: None) + multi_avg_fn (function, optional): the averaging function used to update + parameters inplace; the function must take in the current values of the + :class:`AveragedModel` parameters as a list, the current values of :attr:`model` + parameters as a list, and the number of models already averaged; if None, + an equally weighted average is used (default: None) + use_buffers (bool): if ``True``, it will compute running averages for + both the parameters and the buffers of the model. (default: ``False``) + + Example: + >>> # xdoctest: +SKIP("undefined variables") + >>> loader, optimizer, model, loss_fn = ... + >>> swa_model = torch.optim.swa_utils.AveragedModel(model) + >>> scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, + >>> T_max=300) + >>> swa_start = 160 + >>> swa_scheduler = SWALR(optimizer, swa_lr=0.05) + >>> for i in range(300): + >>> for input, target in loader: + >>> optimizer.zero_grad() + >>> loss_fn(model(input), target).backward() + >>> optimizer.step() + >>> if i > swa_start: + >>> swa_model.update_parameters(model) + >>> swa_scheduler.step() + >>> else: + >>> scheduler.step() + >>> + >>> # Update bn statistics for the swa_model at the end + >>> torch.optim.swa_utils.update_bn(loader, swa_model) + + You can also use custom averaging functions with the `avg_fn` or `multi_avg_fn` parameters. + If no averaging function is provided, the default is to compute + equally-weighted average of the weights (SWA). 
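The two single-parameter averaging functions defined above reduce to simple recurrences; the standalone sketch below (re-implemented here only for illustration) shows that the SWA rule keeps an exact running mean while the EMA rule keeps an exponentially weighted one.

import torch

def ema_avg(avg, x, decay=0.999):
    # same recurrence as get_ema_avg_fn above
    return decay * avg + (1 - decay) * x

def swa_avg(avg, x, n):
    # same recurrence as get_swa_avg_fn above: avg_{n+1} = avg_n + (x - avg_n) / (n + 1)
    return avg + (x - avg) / (n + 1)

xs = [torch.tensor(v) for v in (1.0, 2.0, 3.0, 4.0)]
avg = xs[0].clone()
for n, x in enumerate(xs[1:], start=1):
    avg = swa_avg(avg, x, n)
print(avg)   # tensor(2.5000), the arithmetic mean of 1..4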
+ + Example: + >>> # xdoctest: +SKIP("undefined variables") + >>> # Compute exponential moving averages of the weights and buffers + >>> ema_model = torch.optim.swa_utils.AveragedModel(model, + >>> torch.optim.swa_utils.get_ema_multi_avg_fn(0.9), use_buffers=True) + + .. note:: + When using SWA/EMA with models containing Batch Normalization you may + need to update the activation statistics for Batch Normalization. + This can be done either by using the :meth:`torch.optim.swa_utils.update_bn` + or by setting :attr:`use_buffers` to `True`. The first approach updates the + statistics in a post-training step by passing data through the model. The + second does it during the parameter update phase by averaging all buffers. + Empirical evidence has shown that updating the statistics in normalization + layers increases accuracy, but you may wish to empirically test which + approach yields the best results in your problem. + + .. note:: + :attr:`avg_fn` and `multi_avg_fn` are not saved in the :meth:`state_dict` of the model. + + .. note:: + When :meth:`update_parameters` is called for the first time (i.e. + :attr:`n_averaged` is `0`) the parameters of `model` are copied + to the parameters of :class:`AveragedModel`. For every subsequent + call of :meth:`update_parameters` the function `avg_fn` is used + to update the parameters. + + .. _Averaging Weights Leads to Wider Optima and Better Generalization: + https://arxiv.org/abs/1803.05407 + .. _There Are Many Consistent Explanations of Unlabeled Data: Why You Should + Average: + https://arxiv.org/abs/1806.05594 + .. _SWALP: Stochastic Weight Averaging in Low-Precision Training: + https://arxiv.org/abs/1904.11943 + .. _Stochastic Weight Averaging in Parallel: Large-Batch Training That + Generalizes Well: + https://arxiv.org/abs/2001.02312 + .. 
_Polyak averaging: + https://paperswithcode.com/method/polyak-averaging + """ + + n_averaged: Tensor + + def __init__( + self, + model: Module, + device: Optional[Union[int, torch.device]] = None, + avg_fn: Optional[Callable[[Tensor, Tensor, Union[Tensor, int]], Tensor]] = None, + multi_avg_fn: Optional[ + Callable[[PARAM_LIST, PARAM_LIST, Union[Tensor, int]], None] + ] = None, + use_buffers=False, + ): # noqa: D107 + super().__init__() + assert ( + avg_fn is None or multi_avg_fn is None + ), "Only one of avg_fn and multi_avg_fn should be provided" + self.module = deepcopy(model) + if device is not None: + self.module = self.module.to(device) + self.register_buffer( + "n_averaged", torch.tensor(0, dtype=torch.long, device=device) + ) + self.avg_fn = avg_fn + self.multi_avg_fn = multi_avg_fn + self.use_buffers = use_buffers + + def forward(self, *args, **kwargs): + """Forward pass.""" + return self.module(*args, **kwargs) + + def update_parameters(self, model: Module): + """Update model parameters.""" + self_param = ( + itertools.chain(self.module.parameters(), self.module.buffers()) + if self.use_buffers + else self.parameters() + ) + model_param = ( + itertools.chain(model.parameters(), model.buffers()) + if self.use_buffers + else model.parameters() + ) + self_param_detached: List[Optional[Tensor]] = [] + model_param_detached: List[Optional[Tensor]] = [] + for p_averaged, p_model in zip(self_param, model_param): + p_model_ = p_model.detach().to(p_averaged.device) + self_param_detached.append(p_averaged.detach()) + model_param_detached.append(p_model_) + if self.n_averaged == 0: + p_averaged.detach().copy_(p_model_) + + if self.n_averaged > 0: + if self.multi_avg_fn is not None or self.avg_fn is None: + grouped_tensors = _group_tensors_by_device_and_dtype( + [self_param_detached, model_param_detached] + ) + for (device, _), ( + [self_params, model_params], + _, + ) in grouped_tensors.items(): + if self.multi_avg_fn: + self.multi_avg_fn( + self_params, model_params, self.n_averaged.to(device) # type: ignore[arg-type] + ) + elif ( + device is not None + and device.type in _get_foreach_kernels_supported_devices() + ): + multi_avg_fn = get_swa_multi_avg_fn() + multi_avg_fn( + self_params, model_params, self.n_averaged.to(device) + ) + else: + avg_fn = get_swa_avg_fn() + n_averaged = self.n_averaged.to(device) + for p_averaged, p_model in zip(self_params, model_params): # type: ignore[assignment] + p_averaged.copy_(avg_fn(p_averaged, p_model, n_averaged)) + else: + for p_averaged, p_model in zip( # type: ignore[assignment] + self_param_detached, model_param_detached + ): + n_averaged = self.n_averaged.to(p_averaged.device) + p_averaged.detach().copy_( + self.avg_fn(p_averaged.detach(), p_model, n_averaged) + ) + + if not self.use_buffers: + # If not apply running averages to the buffers, + # keep the buffers in sync with the source model. + for b_swa, b_model in zip(self.module.buffers(), model.buffers()): + b_swa.detach().copy_(b_model.detach().to(b_swa.device)) + self.n_averaged += 1 + + +@torch.no_grad() +def update_bn( + loader: Iterable[Any], + model: Module, + device: Optional[Union[int, torch.device]] = None, +): + r"""Update BatchNorm running_mean, running_var buffers in the model. + + It performs one pass over data in `loader` to estimate the activation + statistics for BatchNorm layers in the model. + Args: + loader (torch.utils.data.DataLoader): dataset loader to compute the + activation statistics on. 
Each data batch should be either a + tensor, or a list/tuple whose first element is a tensor + containing data. + model (torch.nn.Module): model for which we seek to update BatchNorm + statistics. + device (torch.device, optional): If set, data will be transferred to + :attr:`device` before being passed into :attr:`model`. + + Example: + >>> # xdoctest: +SKIP("Undefined variables") + >>> loader, model = ... + >>> torch.optim.swa_utils.update_bn(loader, model) + + .. note:: + The `update_bn` utility assumes that each data batch in :attr:`loader` + is either a tensor or a list or tuple of tensors; in the latter case it + is assumed that :meth:`model.forward()` should be called on the first + element of the list or tuple corresponding to the data batch. + """ + momenta = {} + for module in model.modules(): + if isinstance(module, torch.nn.modules.batchnorm._BatchNorm): + module.reset_running_stats() + momenta[module] = module.momentum + + if not momenta: + return + + was_training = model.training + model.train() + for module in momenta.keys(): + module.momentum = None + + for input in loader: + if isinstance(input, (list, tuple)): + input = input[0] + if device is not None: + input = input.to(device) + + model(input) + + for bn_module in momenta.keys(): + bn_module.momentum = momenta[bn_module] + model.train(was_training) + + +class SWALR(LRScheduler): + r"""Anneals the learning rate in each parameter group to a fixed value. + + This learning rate scheduler is meant to be used with Stochastic Weight + Averaging (SWA) method (see `torch.optim.swa_utils.AveragedModel`). + + Args: + optimizer (torch.optim.Optimizer): wrapped optimizer + swa_lrs (float or list): the learning rate value for all param groups + together or separately for each group. + annealing_epochs (int): number of epochs in the annealing phase + (default: 10) + annealing_strategy (str): "cos" or "linear"; specifies the annealing + strategy: "cos" for cosine annealing, "linear" for linear annealing + (default: "cos") + last_epoch (int): the index of the last epoch (default: -1) + + The :class:`SWALR` scheduler can be used together with other + schedulers to switch to a constant learning rate late in the training + as in the example below. + + Example: + >>> # xdoctest: +SKIP("Undefined variables") + >>> loader, optimizer, model = ... + >>> lr_lambda = lambda epoch: 0.9 + >>> scheduler = torch.optim.lr_scheduler.MultiplicativeLR(optimizer, + >>> lr_lambda=lr_lambda) + >>> swa_scheduler = torch.optim.swa_utils.SWALR(optimizer, + >>> anneal_strategy="linear", anneal_epochs=20, swa_lr=0.05) + >>> swa_start = 160 + >>> for i in range(300): + >>> for input, target in loader: + >>> optimizer.zero_grad() + >>> loss_fn(model(input), target).backward() + >>> optimizer.step() + >>> if i > swa_start: + >>> swa_scheduler.step() + >>> else: + >>> scheduler.step() + + .. 
_Averaging Weights Leads to Wider Optima and Better Generalization: + https://arxiv.org/abs/1803.05407 + """ + + def __init__( + self, + optimizer: Optimizer, + swa_lr: float, + anneal_epochs=10, + anneal_strategy: Literal["cos", "linear"] = "cos", + last_epoch=-1, + ): # noqa: D107 + swa_lrs = _format_param("swa_lr", optimizer, swa_lr) + for swa_lr, group in zip(swa_lrs, optimizer.param_groups): + group["swa_lr"] = swa_lr + if anneal_strategy not in ["cos", "linear"]: + raise ValueError( + "anneal_strategy must by one of 'cos' or 'linear', " + f"instead got {anneal_strategy}" + ) + elif anneal_strategy == "cos": + self.anneal_func = self._cosine_anneal + elif anneal_strategy == "linear": + self.anneal_func = self._linear_anneal + if not isinstance(anneal_epochs, int) or anneal_epochs < 0: + raise ValueError( + f"anneal_epochs must be equal or greater than 0, got {anneal_epochs}" + ) + self.anneal_epochs = anneal_epochs + super().__init__(optimizer, last_epoch) + + @staticmethod + def _linear_anneal(t): + return t + + @staticmethod + def _cosine_anneal(t): + return (1 - math.cos(math.pi * t)) / 2 + + @staticmethod + def _get_initial_lr(lr, swa_lr, alpha): + if alpha == 1: + return swa_lr + return (lr - alpha * swa_lr) / (1 - alpha) + + def get_lr(self): + """Get learning rate.""" + # `_get_lr_called_within_step` is only available `_enable_get_lr_call`, + # so we ignore the type error here. See `LRScheduler.step()` for more details. + if not self._get_lr_called_within_step: # type: ignore[attr-defined] + warnings.warn( + "To get the last learning rate computed by the scheduler, " + "please use `get_last_lr()`.", + UserWarning, + ) + # Set in `LRScheduler._initial_step()` + step = self._step_count - 1 # type: ignore[attr-defined] + if self.anneal_epochs == 0: + step = max(1, step) + prev_t = max(0, min(1, (step - 1) / max(1, self.anneal_epochs))) + prev_alpha = self.anneal_func(prev_t) + prev_lrs = [ + self._get_initial_lr(group["lr"], group["swa_lr"], prev_alpha) + for group in self.optimizer.param_groups + ] + t = max(0, min(1, step / max(1, self.anneal_epochs))) + alpha = self.anneal_func(t) + return [ + group["swa_lr"] * alpha + lr * (1 - alpha) + for group, lr in zip(self.optimizer.param_groups, prev_lrs) + ] diff --git a/engine/cr_utility/dataloader.py b/engine/cr_utility/dataloader.py new file mode 100644 index 0000000..a9a4d5c --- /dev/null +++ b/engine/cr_utility/dataloader.py @@ -0,0 +1,1604 @@ +# mypy: allow-untyped-defs +r"""Definition of the DataLoader and associated iterators that subclass _BaseDataLoaderIter. + +To support these two classes, in `./_utils` we define many utility methods and +functions to be run in multiprocessing. E.g., the data loading worker loop is +in `./_utils/worker.py`. 
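Referring back to the SWALR scheduler added above: its effective schedule interpolates each parameter group's learning rate toward swa_lr over anneal_epochs steps using the chosen annealing function. A back-of-the-envelope sketch follows; base_lr, swa_lr and anneal_epochs are illustrative values, and the real scheduler derives the base rate from the optimizer's param groups rather than from a constant.

import math

def cosine_alpha(step, anneal_epochs):
    # interpolation factor for the "cos" strategy: 0 keeps the base lr, 1 reaches swa_lr
    t = max(0.0, min(1.0, step / max(1, anneal_epochs)))
    return (1 - math.cos(math.pi * t)) / 2

base_lr, swa_lr, anneal_epochs = 0.1, 0.05, 10
for step in (0, 5, 10):
    alpha = cosine_alpha(step, anneal_epochs)
    print(step, swa_lr * alpha + base_lr * (1 - alpha))   # 0.1, 0.075, 0.05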
+""" + +import functools +import itertools +import logging +import multiprocessing as python_multiprocessing +import os +import queue +import threading +import warnings +from typing import Any, Callable, Generic, Iterable, List, Optional, TypeVar, Union + +import torch +import torch.distributed as dist +import torch.utils.data.graph_settings +from torch._utils import ExceptionWrapper +from torch.utils.data import _utils +from torch.utils.data.datapipes.datapipe import ( + _IterDataPipeSerializationWrapper, + _MapDataPipeSerializationWrapper, + IterDataPipe, + MapDataPipe, +) +from torch.utils.data.dataset import Dataset, IterableDataset +from torch.utils.data.sampler import ( + BatchSampler, + RandomSampler, + Sampler, + SequentialSampler, +) + + +__all__ = [ + "DataLoader", + "get_worker_info", + "default_collate", + "default_convert", +] + + +_T = TypeVar("_T") +_T_co = TypeVar("_T_co", covariant=True) +_worker_init_fn_t = Callable[[int], None] + +# Ideally we would parameterize `DataLoader` by the return type of `collate_fn`, but there is currently no way to have that +# type parameter set to a default value if the user doesn't pass in a custom 'collate_fn'. +# See https://github.com/python/mypy/issues/3737. +_collate_fn_t = Callable[[List[_T]], Any] + + +# These functions used to be defined in this file. However, it was moved to +# _utils/collate.py. Although it is rather hard to access this from user land +# (one has to explicitly directly `import torch.utils.data.dataloader`), there +# probably is user code out there using it. This aliasing maintains BC in this +# aspect. +default_collate: _collate_fn_t = _utils.collate.default_collate +default_convert = _utils.collate.default_convert + +get_worker_info = _utils.worker.get_worker_info + +logger = logging.getLogger(__name__) + + +class _DatasetKind: + Map = 0 + Iterable = 1 + + @staticmethod + def create_fetcher(kind, dataset, auto_collation, collate_fn, drop_last): + if kind == _DatasetKind.Map: + return _utils.fetch._MapDatasetFetcher( + dataset, auto_collation, collate_fn, drop_last + ) + else: + return _utils.fetch._IterableDatasetFetcher( + dataset, auto_collation, collate_fn, drop_last + ) + + +class _InfiniteConstantSampler(Sampler): + r"""Analogous to ``itertools.repeat(None, None)``. + + Used as sampler for :class:`~torch.utils.data.IterableDataset`. 
+ """ + + def __iter__(self): + while True: + yield None + + +def _get_distributed_settings(): + if dist.is_available() and dist.is_initialized(): + return dist.get_world_size(), dist.get_rank() + else: + return 1, 0 + + +def _sharding_worker_init_fn(worker_init_fn, world_size, rank_id, worker_id): + global_worker_id = worker_id + info = torch.utils.data.get_worker_info() + assert info is not None + total_workers = info.num_workers + datapipe = info.dataset + assert isinstance(datapipe, (IterDataPipe, MapDataPipe)) + # To distribute elements across distributed process evenly, we should shard data on distributed + # processes first then shard on worker processes + total_workers *= world_size + global_worker_id = global_worker_id * world_size + rank_id + # For BC, use default SHARDING_PRIORITIES + torch.utils.data.graph_settings.apply_sharding( + datapipe, total_workers, global_worker_id + ) + if worker_init_fn is not None: + worker_init_fn(worker_id) + + +def _share_dist_seed(generator, pg): + _shared_seed = torch.empty((), dtype=torch.int64).random_(generator=generator) + if isinstance(pg, dist.ProcessGroup): + dist.broadcast(_shared_seed, src=0, group=pg) + return _shared_seed.item() + + +class DataLoader(Generic[_T_co]): + r""" + Data loader combines a dataset and a sampler, and provides an iterable over the given dataset. + + The :class:`~torch.utils.data.DataLoader` supports both map-style and + iterable-style datasets with single- or multi-process loading, customizing + loading order and optional automatic batching (collation) and memory pinning. + + See :py:mod:`torch.utils.data` documentation page for more details. + + Args: + dataset (Dataset): dataset from which to load the data. + batch_size (int, optional): how many samples per batch to load + (default: ``1``). + shuffle (bool, optional): set to ``True`` to have the data reshuffled + at every epoch (default: ``False``). + sampler (Sampler or Iterable, optional): defines the strategy to draw + samples from the dataset. Can be any ``Iterable`` with ``__len__`` + implemented. If specified, :attr:`shuffle` must not be specified. + batch_sampler (Sampler or Iterable, optional): like :attr:`sampler`, but + returns a batch of indices at a time. Mutually exclusive with + :attr:`batch_size`, :attr:`shuffle`, :attr:`sampler`, + and :attr:`drop_last`. + num_workers (int, optional): how many subprocesses to use for data + loading. ``0`` means that the data will be loaded in the main process. + (default: ``0``) + collate_fn (Callable, optional): merges a list of samples to form a + mini-batch of Tensor(s). Used when using batched loading from a + map-style dataset. + pin_memory (bool, optional): If ``True``, the data loader will copy Tensors + into device/CUDA pinned memory before returning them. If your data elements + are a custom type, or your :attr:`collate_fn` returns a batch that is a custom type, + see the example below. + drop_last (bool, optional): set to ``True`` to drop the last incomplete batch, + if the dataset size is not divisible by the batch size. If ``False`` and + the size of dataset is not divisible by the batch size, then the last batch + will be smaller. (default: ``False``) + timeout (numeric, optional): if positive, the timeout value for collecting a batch + from workers. Should always be non-negative. 
(default: ``0``) + worker_init_fn (Callable, optional): If not ``None``, this will be called on each + worker subprocess with the worker id (an int in ``[0, num_workers - 1]``) as + input, after seeding and before data loading. (default: ``None``) + multiprocessing_context (str or multiprocessing.context.BaseContext, optional): If + ``None``, the default `multiprocessing context`_ of your operating system will + be used. (default: ``None``) + generator (torch.Generator, optional): If not ``None``, this RNG will be used + by RandomSampler to generate random indexes and multiprocessing to generate + ``base_seed`` for workers. (default: ``None``) + prefetch_factor (int, optional, keyword-only arg): Number of batches loaded + in advance by each worker. ``2`` means there will be a total of + 2 * num_workers batches prefetched across all workers. (default value depends + on the set value for num_workers. If value of num_workers=0 default is ``None``. + Otherwise, if value of ``num_workers > 0`` default is ``2``). + persistent_workers (bool, optional): If ``True``, the data loader will not shut down + the worker processes after a dataset has been consumed once. This allows to + maintain the workers `Dataset` instances alive. (default: ``False``) + pin_memory_device (str, optional): the device to :attr:`pin_memory` to if ``pin_memory`` is + ``True``. + + + .. warning:: If the ``spawn`` start method is used, :attr:`worker_init_fn` + cannot be an unpicklable object, e.g., a lambda function. See + :ref:`multiprocessing-best-practices` on more details related + to multiprocessing in PyTorch. + + .. warning:: ``len(dataloader)`` heuristic is based on the length of the sampler used. + When :attr:`dataset` is an :class:`~torch.utils.data.IterableDataset`, + it instead returns an estimate based on ``len(dataset) / batch_size``, with proper + rounding depending on :attr:`drop_last`, regardless of multi-process loading + configurations. This represents the best guess PyTorch can make because PyTorch + trusts user :attr:`dataset` code in correctly handling multi-process + loading to avoid duplicate data. + + However, if sharding results in multiple workers having incomplete last batches, + this estimate can still be inaccurate, because (1) an otherwise complete batch can + be broken into multiple ones and (2) more than one batch worth of samples can be + dropped when :attr:`drop_last` is set. Unfortunately, PyTorch can not detect such + cases in general. + + See `Dataset Types`_ for more details on these two types of datasets and how + :class:`~torch.utils.data.IterableDataset` interacts with + `Multi-process data loading`_. + + .. warning:: See :ref:`reproducibility`, and :ref:`dataloader-workers-random-seed`, and + :ref:`data-loading-randomness` notes for random seed related questions. + + .. 
_multiprocessing context: + https://docs.python.org/3/library/multiprocessing.html#contexts-and-start-methods + """ + + dataset: Dataset[_T_co] + batch_size: Optional[int] + num_workers: int + pin_memory: bool + drop_last: bool + timeout: float + sampler: Union[Sampler, Iterable] + pin_memory_device: str + prefetch_factor: Optional[int] + _iterator: Optional["_BaseDataLoaderIter"] + __initialized = False + + def __init__( + self, + dataset: Dataset[_T_co], + batch_size: Optional[int] = 1, + shuffle: Optional[bool] = None, + sampler: Union[Sampler, Iterable, None] = None, + batch_sampler: Union[Sampler[List], Iterable[List], None] = None, + num_workers: int = 0, + collate_fn: Optional[_collate_fn_t] = None, + pin_memory: bool = False, + drop_last: bool = False, + timeout: float = 0, + worker_init_fn: Optional[_worker_init_fn_t] = None, + multiprocessing_context=None, + generator=None, + *, + prefetch_factor: Optional[int] = None, + persistent_workers: bool = False, + pin_memory_device: str = "", + ): + torch._C._log_api_usage_once("python.data_loader") + + if num_workers < 0: + raise ValueError( + "num_workers option should be non-negative; " + "use num_workers=0 to disable multiprocessing." + ) + + if timeout < 0: + raise ValueError("timeout option should be non-negative") + + if num_workers == 0 and prefetch_factor is not None: + raise ValueError( + "prefetch_factor option could only be specified in multiprocessing." + "let num_workers > 0 to enable multiprocessing, otherwise set prefetch_factor to None." + ) + elif num_workers > 0 and prefetch_factor is None: + prefetch_factor = 2 + elif prefetch_factor is not None and prefetch_factor < 0: + raise ValueError("prefetch_factor option should be non-negative") + + if persistent_workers and num_workers == 0: + raise ValueError("persistent_workers option needs num_workers > 0") + + self.dataset = dataset + self.num_workers = num_workers + self.prefetch_factor = prefetch_factor + self.pin_memory = pin_memory + self.pin_memory_device = pin_memory_device + self.timeout = timeout + self.worker_init_fn = worker_init_fn + self.multiprocessing_context = multiprocessing_context + + # Adds forward compatibilities so classic DataLoader can work with DataPipes: + # _DataPipeSerializationWrapper container makes it easier to serialize without redefining pickler + if isinstance(self.dataset, IterDataPipe): + self.dataset = _IterDataPipeSerializationWrapper(self.dataset) + elif isinstance(self.dataset, MapDataPipe): + self.dataset = _MapDataPipeSerializationWrapper(self.dataset) + + # Arg-check dataset related before checking samplers because we want to + # tell users that iterable-style datasets are incompatible with custom + # samplers first, so that they don't learn that this combo doesn't work + # after spending time fixing the custom sampler errors. + if isinstance(dataset, IterableDataset): + self._dataset_kind = _DatasetKind.Iterable + # NOTE [ Custom Samplers and IterableDataset ] + # + # `IterableDataset` does not support custom `batch_sampler` or + # `sampler` since the key is irrelevant (unless we support + # generator-style dataset one day...). + # + # For `sampler`, we always create a dummy sampler. 
This is an + # infinite sampler even when the dataset may have an implemented + # finite `__len__` because in multi-process data loading, naive + # settings will return duplicated data (which may be desired), and + # thus using a sampler with length matching that of dataset will + # cause data lost (you may have duplicates of the first couple + # batches, but never see anything afterwards). Therefore, + # `Iterabledataset` always uses an infinite sampler, an instance of + # `_InfiniteConstantSampler` defined above. + # + # A custom `batch_sampler` essentially only controls the batch size. + # However, it is unclear how useful it would be since an iterable-style + # dataset can handle that within itself. Moreover, it is pointless + # in multi-process data loading as the assignment order of batches + # to workers is an implementation detail so users can not control + # how to batchify each worker's iterable. Thus, we disable this + # option. If this turns out to be useful in future, we can re-enable + # this, and support custom samplers that specify the assignments to + # specific workers. + if isinstance(dataset, IterDataPipe): + if shuffle is not None: + dataset = torch.utils.data.graph_settings.apply_shuffle_settings( + dataset, shuffle=shuffle + ) + # We cannot check `shuffle is not None` here, since previously `shuffle=False` was the default. + elif shuffle not in {False, None}: + raise ValueError( + f"DataLoader with IterableDataset: expected unspecified shuffle option, but got shuffle={shuffle}" + ) + + if sampler is not None: + # See NOTE [ Custom Samplers and IterableDataset ] + raise ValueError( + f"DataLoader with IterableDataset: expected unspecified sampler option, but got sampler={sampler}" + ) + elif batch_sampler is not None: + # See NOTE [ Custom Samplers and IterableDataset ] + raise ValueError( + "DataLoader with IterableDataset: expected unspecified " + f"batch_sampler option, but got batch_sampler={batch_sampler}" + ) + else: + shuffle = bool(shuffle) + self._dataset_kind = _DatasetKind.Map + + if sampler is not None and shuffle: + raise ValueError("sampler option is mutually exclusive with " "shuffle") + + if batch_sampler is not None: + # auto_collation with custom batch_sampler + if batch_size != 1 or shuffle or sampler is not None or drop_last: + raise ValueError( + "batch_sampler option is mutually exclusive " + "with batch_size, shuffle, sampler, and " + "drop_last" + ) + batch_size = None + drop_last = False + elif batch_size is None: + # no auto_collation + if drop_last: + raise ValueError( + "batch_size=None option disables auto-batching " + "and is mutually exclusive with drop_last" + ) + + if sampler is None: # give default samplers + if self._dataset_kind == _DatasetKind.Iterable: + # See NOTE [ Custom Samplers and IterableDataset ] + sampler = _InfiniteConstantSampler() + else: # map-style + if shuffle: + sampler = RandomSampler(dataset, generator=generator) # type: ignore[arg-type] + else: + sampler = SequentialSampler(dataset) # type: ignore[arg-type] + + if batch_size is not None and batch_sampler is None: + # auto_collation without custom batch_sampler + batch_sampler = BatchSampler(sampler, batch_size, drop_last) + + self.batch_size = batch_size + self.drop_last = drop_last + self.sampler = sampler + self.batch_sampler = batch_sampler + self.generator = generator + + if collate_fn is None: + if self._auto_collation: + collate_fn = _utils.collate.default_collate + else: + collate_fn = _utils.collate.default_convert + + self.collate_fn = collate_fn + 
self.persistent_workers = persistent_workers + + self.__initialized = True + self._IterableDataset_len_called = ( + None # See NOTE [ IterableDataset and __len__ ] + ) + + self._iterator = None + + self.check_worker_number_rationality() + + torch.set_vital("Dataloader", "enabled", "True") # type: ignore[attr-defined] + + def _get_iterator(self) -> "_BaseDataLoaderIter": + if self.num_workers == 0: + return _SingleProcessDataLoaderIter(self) + else: + self.check_worker_number_rationality() + return _MultiProcessingDataLoaderIter(self) + + @property + def multiprocessing_context(self): + return self.__multiprocessing_context + + @multiprocessing_context.setter + def multiprocessing_context(self, multiprocessing_context): + if multiprocessing_context is not None: + if self.num_workers > 0: + if isinstance(multiprocessing_context, str): + valid_start_methods = torch.multiprocessing.get_all_start_methods() + if multiprocessing_context not in valid_start_methods: + raise ValueError( + "multiprocessing_context option " + f"should specify a valid start method in {valid_start_methods!r}, but got " + f"multiprocessing_context={multiprocessing_context!r}" + ) + multiprocessing_context = torch.multiprocessing.get_context( + multiprocessing_context + ) + + if not isinstance( + multiprocessing_context, python_multiprocessing.context.BaseContext + ): + raise TypeError( + "multiprocessing_context option should be a valid context " + "object or a string specifying the start method, but got " + f"multiprocessing_context={multiprocessing_context}" + ) + else: + raise ValueError( + "multiprocessing_context can only be used with " + "multi-process loading (num_workers > 0), but got " + f"num_workers={self.num_workers}" + ) + + self.__multiprocessing_context = multiprocessing_context + + def __setattr__(self, attr, val): + if self.__initialized and attr in ( + "batch_size", + "batch_sampler", + "sampler", + "drop_last", + "dataset", + "persistent_workers", + ): + raise ValueError( + f"{attr} attribute should not be set after {self.__class__.__name__} is initialized" + ) + + super().__setattr__(attr, val) + + # We quote '_BaseDataLoaderIter' since it isn't defined yet and the definition can't be moved up + # since '_BaseDataLoaderIter' references 'DataLoader'. + def __iter__(self) -> "_BaseDataLoaderIter": + # When using a single worker the returned iterator should be + # created everytime to avoid resetting its state + # However, in the case of a multiple workers iterator + # the iterator is only created once in the lifetime of the + # DataLoader object so that workers can be reused + if self.persistent_workers and self.num_workers > 0: + if self._iterator is None: + self._iterator = self._get_iterator() + else: + self._iterator._reset(self) + return self._iterator + else: + return self._get_iterator() + + @property + def _auto_collation(self): + return self.batch_sampler is not None + + @property + def _index_sampler(self): + # The actual sampler used for generating indices for `_DatasetFetcher` + # (see _utils/fetch.py) to read data at each time. This would be + # `.batch_sampler` if in auto-collation mode, and `.sampler` otherwise. + # We can't change `.sampler` and `.batch_sampler` attributes for BC + # reasons. 
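Two of the construction rules enforced above are easy to demonstrate with the upstream torch.utils.data.DataLoader, which this vendored copy follows: batch_sampler excludes batch_size/shuffle/sampler/drop_last, and key attributes are frozen once __init__ has run. The dataset and sizes below are only for illustration.

import torch
from torch.utils.data import DataLoader, TensorDataset

ds = TensorDataset(torch.arange(10).float())

try:
    # batch_sampler is mutually exclusive with batch_size (and shuffle/sampler/drop_last)
    DataLoader(ds, batch_size=4, batch_sampler=[[0, 1], [2, 3]])
except ValueError as err:
    print("rejected:", err)

loader = DataLoader(ds, batch_size=4)
try:
    loader.batch_size = 8            # blocked by __setattr__ after initialization
except ValueError as err:
    print("rejected:", err)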
+ if self._auto_collation: + return self.batch_sampler + else: + return self.sampler + + def __len__(self) -> int: + if self._dataset_kind == _DatasetKind.Iterable: + # NOTE [ IterableDataset and __len__ ] + # + # For `IterableDataset`, `__len__` could be inaccurate when one naively + # does multi-processing data loading, since the samples will be duplicated. + # However, no real use case should be actually using that behavior, so + # it should count as a user error. We should generally trust user + # code to do the proper thing (e.g., configure each replica differently + # in `__iter__`), and give us the correct `__len__` if they choose to + # implement it (this will still throw if the dataset does not implement + # a `__len__`). + # + # To provide a further warning, we track if `__len__` was called on the + # `DataLoader`, save the returned value in `self._len_called`, and warn + # if the iterator ends up yielding more than this number of samples. + + # Cannot statically verify that dataset is Sized + length = self._IterableDataset_len_called = len(self.dataset) # type: ignore[assignment, arg-type] + if ( + self.batch_size is not None + ): # IterableDataset doesn't allow custom sampler or batch_sampler + from math import ceil + + if self.drop_last: + length = length // self.batch_size + else: + length = ceil(length / self.batch_size) + return length + else: + return len(self._index_sampler) + + def check_worker_number_rationality(self): + # This function check whether the dataloader's worker number is rational based on + # current system's resource. Current rule is that if the number of workers this + # Dataloader will create is bigger than the number of logical cpus that is allowed to + # use, than we will pop up a warning to let user pay attention. + # + # eg. If current system has 2 physical CPUs with 16 cores each. And each core support 2 + # threads, then the total logical cpus here is 2 * 16 * 2 = 64. Let's say current + # DataLoader process can use half of them which is 32, then the rational max number of + # worker that initiated from this process is 32. + # Now, let's say the created DataLoader has num_works = 40, which is bigger than 32. + # So the warning message is triggered to notify the user to lower the worker number if + # necessary. + # + # + # [Note] Please note that this function repects `cpuset` only when os.sched_getaffinity is + # available (available in most of Linux system, but not OSX and Windows). + # When os.sched_getaffinity is not available, os.cpu_count() is called instead, but + # it doesn't repect cpuset. + # We don't take threading into account since each worker process is single threaded + # at this time. + # + # We don't set any threading flags (eg. OMP_NUM_THREADS, MKL_NUM_THREADS, etc) + # other than `torch.set_num_threads` to 1 in the worker process, if the passing + # in functions use 3rd party modules that rely on those threading flags to determine + # how many thread to create (eg. numpy, etc), then it is caller's responsibility to + # set those flags correctly. + def _create_warning_msg(num_worker_suggest, num_worker_created, cpuset_checked): + suggested_max_worker_msg = ( + ( + ( + "Our suggested max number of worker in current system is {}{}, which is smaller " + "than what this DataLoader is going to create." 
+ ).format( + num_worker_suggest, + ( + "" + if cpuset_checked + else " (`cpuset` is not taken into account)" + ), + ) + ) + if num_worker_suggest is not None + else ( + "DataLoader is not able to compute a suggested max number of worker in current system." + ) + ) + + warn_msg = ( + f"This DataLoader will create {num_worker_created} worker processes in total. {suggested_max_worker_msg} " + "Please be aware that excessive worker creation might get DataLoader running slow or even freeze, " + "lower the worker number to avoid potential slowness/freeze if necessary." + ) + return warn_msg + + if not self.num_workers or self.num_workers == 0: + return + + # try to compute a suggested max number of worker based on system's resource + max_num_worker_suggest = None + cpuset_checked = False + if hasattr(os, "sched_getaffinity"): + try: + max_num_worker_suggest = len(os.sched_getaffinity(0)) + cpuset_checked = True + except Exception: + pass + if max_num_worker_suggest is None: + # os.cpu_count() could return Optional[int] + # get cpu count first and check None in order to satisfy mypy check + cpu_count = os.cpu_count() + if cpu_count is not None: + max_num_worker_suggest = cpu_count + + if max_num_worker_suggest is None: + warnings.warn( + _create_warning_msg( + max_num_worker_suggest, self.num_workers, cpuset_checked + ) + ) + return + + if self.num_workers > max_num_worker_suggest: + warnings.warn( + _create_warning_msg( + max_num_worker_suggest, self.num_workers, cpuset_checked + ) + ) + + +class _BaseDataLoaderIter: + def __init__(self, loader: DataLoader) -> None: + self._dataset = loader.dataset + self._shared_seed = None + self._pg = None + if isinstance(self._dataset, IterDataPipe): + if dist.is_available() and dist.is_initialized(): + self._pg = dist.new_group(backend="gloo") + self._shared_seed = _share_dist_seed(loader.generator, self._pg) + shared_rng = torch.Generator() + shared_rng.manual_seed(self._shared_seed) + self._dataset = torch.utils.data.graph_settings.apply_random_seed( + self._dataset, shared_rng + ) + self._dataset_kind = loader._dataset_kind + self._IterableDataset_len_called = loader._IterableDataset_len_called + self._auto_collation = loader._auto_collation + self._drop_last = loader.drop_last + self._index_sampler = loader._index_sampler + self._num_workers = loader.num_workers + ws, rank = _get_distributed_settings() + self._world_size = ws + self._rank = rank + # for other backends, pin_memory_device need to set. if not set + # default behaviour is CUDA device. if pin_memory_device is selected + # and pin_memory is not set, the default behaviour false. 
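The worker-count warning built above is driven by a simple heuristic: prefer the process CPU affinity mask (which respects cpusets on Linux) and fall back to os.cpu_count(). A minimal standalone sketch of that heuristic, with a hypothetical helper name:

import os

def suggested_max_workers():
    if hasattr(os, "sched_getaffinity"):
        try:
            return len(os.sched_getaffinity(0))   # respects cpuset restrictions
        except Exception:
            pass
    return os.cpu_count()   # may be None on unusual platforms

print(suggested_max_workers())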
+ if len(loader.pin_memory_device) == 0: + self._pin_memory = loader.pin_memory and torch.cuda.is_available() + self._pin_memory_device = None + else: + if not loader.pin_memory: + warn_msg = ( + "pin memory device is set and pin_memory flag is not used then device pinned memory won't be used" + "please set pin_memory to true, if you need to use the device pin memory" + ) + warnings.warn(warn_msg) + + self._pin_memory = loader.pin_memory + self._pin_memory_device = loader.pin_memory_device + self._timeout = loader.timeout + self._collate_fn = loader.collate_fn + self._sampler_iter = iter(self._index_sampler) + self._base_seed = ( + torch.empty((), dtype=torch.int64) + .random_(generator=loader.generator) + .item() + ) + self._persistent_workers = loader.persistent_workers + self._num_yielded = 0 + self._profile_name = f"enumerate(DataLoader)#{self.__class__.__name__}.__next__" + + def __iter__(self) -> "_BaseDataLoaderIter": + return self + + def _reset(self, loader, first_iter=False): + self._sampler_iter = iter(self._index_sampler) + self._num_yielded = 0 + self._IterableDataset_len_called = loader._IterableDataset_len_called + if isinstance(self._dataset, IterDataPipe): + self._shared_seed = _share_dist_seed(loader.generator, self._pg) + shared_rng = torch.Generator() + shared_rng.manual_seed(self._shared_seed) + self._dataset = torch.utils.data.graph_settings.apply_random_seed( + self._dataset, shared_rng + ) + + def _next_index(self): + return next(self._sampler_iter) # may raise StopIteration + + def _next_data(self): + raise NotImplementedError + + def __next__(self) -> Any: + with torch.autograd.profiler.record_function(self._profile_name): + if self._sampler_iter is None: + # TODO(https://github.com/pytorch/pytorch/issues/76750) + self._reset() # type: ignore[call-arg] + data = self._next_data() + self._num_yielded += 1 + if ( + self._dataset_kind == _DatasetKind.Iterable + and self._IterableDataset_len_called is not None + and self._num_yielded > self._IterableDataset_len_called + ): + warn_msg = ( + f"Length of IterableDataset {self._dataset} was reported to be {self._IterableDataset_len_called}" + f"(when accessing len(dataloader)), but {self._num_yielded} samples have been fetched. " + ) + if self._num_workers > 0: + warn_msg += ( + "For multiprocessing data-loading, this could be caused by not properly configuring the " + "IterableDataset replica at each worker. Please see " + "https://pytorch.org/docs/stable/data.html#torch.utils.data.IterableDataset for examples." + ) + warnings.warn(warn_msg) + return data + + def __len__(self) -> int: + return len(self._index_sampler) + + def __getstate__(self): + # TODO: add limited pickling support for sharing an iterator + # across multiple threads for HOGWILD. 
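The length check in __next__ above can be seen with a small, contrived IterableDataset whose reported __len__ is shorter than what it actually yields; once more samples than len(dataloader) have been fetched, the warning fires. Sketch for illustration only; the class name and sizes are invented.

import warnings
from torch.utils.data import DataLoader, IterableDataset

class Leaky(IterableDataset):
    def __iter__(self):
        return iter(range(4))        # actually yields 4 items
    def __len__(self):
        return 2                     # but claims only 2

loader = DataLoader(Leaky(), batch_size=None)
len(loader)                          # records the reported length
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    list(loader)
print(any("samples have been fetched" in str(w.message) for w in caught))   # True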
+ # Probably the best way to do this is by moving the sample pushing + # to a separate thread and then just sharing the data queue + # but signalling the end is tricky without a non-blocking API + raise NotImplementedError("{} cannot be pickled", self.__class__.__name__) + + +class _SingleProcessDataLoaderIter(_BaseDataLoaderIter): + def __init__(self, loader): + super().__init__(loader) + assert self._timeout == 0 + assert self._num_workers == 0 + + # Adds forward compatibilities so classic DataLoader can work with DataPipes: + # Taking care of distributed sharding + if isinstance(self._dataset, (IterDataPipe, MapDataPipe)): + # For BC, use default SHARDING_PRIORITIES + torch.utils.data.graph_settings.apply_sharding( + self._dataset, self._world_size, self._rank + ) + + self._dataset_fetcher = _DatasetKind.create_fetcher( + self._dataset_kind, + self._dataset, + self._auto_collation, + self._collate_fn, + self._drop_last, + ) + + def _next_data(self): + index = self._next_index() # may raise StopIteration + data = self._dataset_fetcher.fetch(index) # may raise StopIteration + if self._pin_memory: + data = _utils.pin_memory.pin_memory(data, self._pin_memory_device) + return data + + +class _MultiProcessingDataLoaderIter(_BaseDataLoaderIter): + r"""Iterates once over the DataLoader's dataset, as specified by the sampler.""" + + # NOTE [ Data Loader Multiprocessing Shutdown Logic ] + # + # Preliminary: + # + # Our data model looks like this (queues are indicated with curly brackets): + # + # main process || + # | || + # {index_queue} || + # | || + # worker processes || DATA + # | || + # {worker_result_queue} || FLOW + # | || + # pin_memory_thread of main process || DIRECTION + # | || + # {data_queue} || + # | || + # data output \/ + # + # P.S. `worker_result_queue` and `pin_memory_thread` part may be omitted if + # `pin_memory=False`. + # + # + # Terminating multiprocessing logic requires very careful design. In + # particular, we need to make sure that + # + # 1. The iterator gracefully exits the workers when its last reference is + # gone or it is depleted. + # + # In this case, the workers should be gracefully exited because the + # main process may still need to continue to run, and we want cleaning + # up code in the workers to be executed (e.g., releasing GPU memory). + # Naturally, we implement the shutdown logic in `__del__` of + # DataLoaderIterator. + # + # We delay the discussion on the logic in this case until later. + # + # 2. The iterator exits the workers when the loader process and/or worker + # processes exits normally or with error. + # + # We set all workers and `pin_memory_thread` to have `daemon=True`. + # + # You may ask, why can't we make the workers non-daemonic, and + # gracefully exit using the same logic as we have in `__del__` when the + # iterator gets deleted (see 1 above)? + # + # First of all, `__del__` is **not** guaranteed to be called when + # interpreter exits. Even if it is called, by the time it executes, + # many Python core library resources may already be freed, and even + # simple things like acquiring an internal lock of a queue may hang. + # Therefore, in this case, we actually need to prevent `__del__` from + # being executed, and rely on the automatic termination of daemonic + # children. + # + # Thus, we register an `atexit` hook that sets a global flag + # `_utils.python_exit_status`. 
Since `atexit` hooks are executed in the + # reverse order of registration, we are guaranteed that this flag is + # set before library resources we use are freed (which, at least in + # CPython, is done via an `atexit` handler defined in + # `multiprocessing/util.py` + # https://github.com/python/cpython/blob/c606624af8d4cb3b4a052fb263bb983b3f87585b/Lib/multiprocessing/util.py#L320-L362 + # registered when an object requiring this mechanism is first + # created, e.g., `mp.Queue` + # https://github.com/python/cpython/blob/c606624af8d4cb3b4a052fb263bb983b3f87585b/Lib/multiprocessing/context.py#L100-L103 + # https://github.com/python/cpython/blob/c606624af8d4cb3b4a052fb263bb983b3f87585b/Lib/multiprocessing/queues.py#L29 + # ) + # + # So in `__del__`, we check if `_utils.python_exit_status` is set or + # `None` (freed), and perform no-op if so. + # + # However, simply letting library clean-up codes run can also be bad, + # because such codes (i.e., `multiprocessing.util._exit_function()`) + # include join putting threads for `mp.Queue`, which can be blocking. + # Hence, the main process putting threads are called with + # `cancel_join_thread` at creation. See later section + # [ 3b. A process won't hang when putting into a queue; ] + # for more details. + # + # Here are two example cases where library clean-up codes can run + # before `__del__` is called: + # + # 1. If we hold onto a reference to the iterator, it more often + # than not tries to do `multiprocessing` library cleaning before + # clearing the alive referenced objects (https://github.com/pytorch/pytorch/issues/48666) + # and thus prevents our cleaning-up code to run first. + # + # 2. A similar issue araises when a `DataLoader` is used in a subprocess. + # When a process ends, it shuts the all its daemonic children + # down with a SIGTERM (instead of joining them without a timeout). + # Simiarly for threads, but by a different mechanism. This fact, + # together with a few implementation details of multiprocessing, forces + # us to make workers daemonic. All of our problems arise when a + # DataLoader is used in a subprocess, and are caused by multiprocessing + # code which looks more or less like this: + # + # try: + # your_function_using_a_dataloader() + # finally: + # multiprocessing.util._exit_function() + # + # The joining/termination mentioned above happens inside + # `_exit_function()`. Now, if `your_function_using_a_dataloader()` + # throws, the stack trace stored in the exception will prevent the + # frame which uses `DataLoaderIter` to be freed. If the frame has any + # reference to the `DataLoaderIter` (e.g., in a method of the iter), + # its `__del__`, which starts the shutdown procedure, will not be + # called. That, in turn, means that workers aren't notified. Attempting + # to join in `_exit_function` will then result in a hang. + # + # For context, `_exit_function` is also registered as an `atexit` call. + # So it is unclear to me (@ssnl) why this is needed in a finally block. + # The code dates back to 2008 and there is no comment on the original + # PEP 371 or patch https://bugs.python.org/issue3050 (containing both + # the finally block and the `atexit` registration) that explains this. + # + # + # Finally, another choice is to just shutdown workers with logic in 1 + # above whenever we see an error in `next`. This isn't ideal because + # a. It prevents users from using try-catch to resume data loading. + # b. It doesn't prevent hanging if users have references to the + # iterator. + # + # 3. 
All processes exit if any of them die unexpectedly by fatal signals. + # + # As shown above, the workers are set as daemonic children of the main + # process. However, automatic cleaning-up of such child processes only + # happens if the parent process exits gracefully (e.g., not via fatal + # signals like SIGKILL). So we must ensure that each process will exit + # even the process that should send/receive data to/from it were + # killed, i.e., + # + # a. A process won't hang when getting from a queue. + # + # Even with carefully designed data dependencies (i.e., a `put()` + # always corresponding to a `get()`), hanging on `get()` can still + # happen when data in queue is corrupted (e.g., due to + # `cancel_join_thread` or unexpected exit). + # + # For child exit, we set a timeout whenever we try to get data + # from `data_queue`, and check the workers' status on each timeout + # and error. + # See `_DataLoaderiter._get_batch()` and + # `_DataLoaderiter._try_get_data()` for details. + # + # Additionally, for child exit on non-Windows platforms, we also + # register a SIGCHLD handler (which is supported on Windows) on + # the main process, which checks if any of the workers fail in the + # (Python) handler. This is more efficient and faster in detecting + # worker failures, compared to only using the above mechanism. + # See `DataLoader.cpp` and `_utils/signal_handling.py` for details. + # + # For `.get()` calls where the sender(s) is not the workers, we + # guard them with timeouts, and check the status of the sender + # when timeout happens: + # + in the workers, the `_utils.worker.ManagerWatchdog` class + # checks the status of the main process. + # + if `pin_memory=True`, when getting from `pin_memory_thread`, + # check `pin_memory_thread` status periodically until `.get()` + # returns or see that `pin_memory_thread` died. + # + # b. A process won't hang when putting into a queue; + # + # We use `mp.Queue` which has a separate background thread to put + # objects from an unbounded buffer array. The background thread is + # daemonic and usually automatically joined when the process + # *exits*. + # + # In case that the receiver has ended abruptly while + # reading from the pipe, the join will hang forever. The usual + # solution for this in Python is calling `q.cancel_join_thread`, + # which prevents automatically joining it when finalizing + # (exiting). + # + # Nonetheless, `cancel_join_thread` must only be called when the + # queue is **not** going to be read from or write into by another + # process, because it may hold onto a lock or leave corrupted data + # in the queue, leading other readers/writers to hang. + # + # Hence, + # + For worker processes, we only do so (for their output + # queues, i.e., `worker_result_queue`) before exiting. + # + For `pin_memory_thread`, its output queue `data_queue` is a + # `queue.Queue` that does blocking `put` if the queue is full. + # So there is no above problem, but as a result, in + # `_pin_memory_loop`, we do need to wrap the `put` in a loop + # that breaks not only upon success, but also when the main + # process stops reading, i.e., is shutting down. + # + For loader process, we `cancel_join_thread()` for all + # `_index_queues` because the whole purpose of workers and + # `pin_memory_thread` is to serve the loader process. If + # loader process is already exiting, we don't really care if + # the queues are corrupted. 
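The worker and queue layout described in the comments above boils down to a classic producer/consumer pattern: the main process feeds indices into a per-worker index queue, and a daemon worker writes (index, data) pairs into a shared result queue, exiting on a final None. A toy standalone version follows; it is not the actual _utils.worker._worker_loop, and the function name and fake dataset are illustrative.

import multiprocessing as mp

def worker_loop(index_queue, result_queue):
    while True:
        idx = index_queue.get()
        if idx is None:                        # final None signals shutdown
            break
        result_queue.put((idx, idx * 10))      # stand-in for dataset[idx]

if __name__ == "__main__":
    ctx = mp.get_context("spawn")
    index_q, result_q = ctx.Queue(), ctx.Queue()
    w = ctx.Process(target=worker_loop, args=(index_q, result_q), daemon=True)
    w.start()
    for i in range(3):
        index_q.put(i)
    print(sorted(result_q.get() for _ in range(3)))   # [(0, 0), (1, 10), (2, 20)]
    index_q.put(None)
    w.join()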
+ # + # + # Now let's get back to 1: + # how we gracefully exit the workers when the last reference to the + # iterator is gone. + # + # To achieve this, we implement the following logic along with the design + # choices mentioned above: + # + # `workers_done_event`: + # A `multiprocessing.Event` shared among the main process and all worker + # processes. This is used to signal the workers that the iterator is + # shutting down. After it is set, they will not send processed data to + # queues anymore, and only wait for the final `None` before exiting. + # `done_event` isn't strictly needed. I.e., we can just check for `None` + # from the input queue, but it allows us to skip wasting resources + # processing data if we are already shutting down. + # + # `pin_memory_thread_done_event`: + # A `threading.Event` for a similar purpose to that of + # `workers_done_event`, but is for the `pin_memory_thread`. The reason + # that separate events are needed is that `pin_memory_thread` reads from + # the output queue of the workers. But the workers, upon seeing that + # `workers_done_event` is set, only wants to see the final `None`, and is + # not required to flush all data in the output queue (e.g., it may call + # `cancel_join_thread` on that queue if its `IterableDataset` iterator + # happens to exhaust coincidentally, which is out of the control of the + # main process). Thus, since we will exit `pin_memory_thread` before the + # workers (see below), two separete events are used. + # + # NOTE: In short, the protocol is that the main process will set these + # `done_event`s and then the corresponding processes/threads a `None`, + # and that they may exit at any time after receiving the `None`. + # + # NOTE: Using `None` as the final signal is valid, since normal data will + # always be a 2-tuple with the 1st element being the index of the data + # transferred (different from dataset index/key), and the 2nd being + # either the dataset key or the data sample (depending on which part + # of the data model the queue is at). + # + # [ worker processes ] + # While loader process is alive: + # Get from `index_queue`. + # If get anything else, + # Check `workers_done_event`. + # If set, continue to next iteration + # i.e., keep getting until see the `None`, then exit. + # Otherwise, process data: + # If is fetching from an `IterableDataset` and the iterator + # is exhausted, send an `_IterableDatasetStopIteration` + # object to signal iteration end. The main process, upon + # receiving such an object, will send `None` to this + # worker and not use the corresponding `index_queue` + # anymore. + # If timed out, + # No matter `workers_done_event` is set (still need to see `None`) + # or not, must continue to next iteration. + # (outside loop) + # If `workers_done_event` is set, (this can be False with `IterableDataset`) + # `data_queue.cancel_join_thread()`. (Everything is ending here: + # main process won't read from it; + # other workers will also call + # `cancel_join_thread`.) + # + # [ pin_memory_thread ] + # # No need to check main thread. If this thread is alive, the main loader + # # thread must be alive, because this thread is set as daemonic. + # While `pin_memory_thread_done_event` is not set: + # Get from `worker_result_queue`. + # If timed out, continue to get in the next iteration. + # Otherwise, process data. + # While `pin_memory_thread_done_event` is not set: + # Put processed data to `data_queue` (a `queue.Queue` with blocking put) + # If timed out, continue to put in the next iteration. 
+ # Otherwise, break, i.e., continuing to the out loop. + # + # NOTE: we don't check the status of the main thread because + # 1. if the process is killed by fatal signal, `pin_memory_thread` + # ends. + # 2. in other cases, either the cleaning-up in __del__ or the + # automatic exit of daemonic thread will take care of it. + # This won't busy-wait either because `.get(timeout)` does not + # busy-wait. + # + # [ main process ] + # In the DataLoader Iter's `__del__` + # b. Exit `pin_memory_thread` + # i. Set `pin_memory_thread_done_event`. + # ii Put `None` in `worker_result_queue`. + # iii. Join the `pin_memory_thread`. + # iv. `worker_result_queue.cancel_join_thread()`. + # + # c. Exit the workers. + # i. Set `workers_done_event`. + # ii. Put `None` in each worker's `index_queue`. + # iii. Join the workers. + # iv. Call `.cancel_join_thread()` on each worker's `index_queue`. + # + # NOTE: (c) is better placed after (b) because it may leave corrupted + # data in `worker_result_queue`, which `pin_memory_thread` + # reads from, in which case the `pin_memory_thread` can only + # happen at timing out, which is slow. Nonetheless, same thing + # happens if a worker is killed by signal at unfortunate times, + # but in other cases, we are better off having a non-corrupted + # `worker_result_queue` for `pin_memory_thread`. + # + # NOTE: If `pin_memory=False`, there is no `pin_memory_thread` and (b) + # can be omitted + # + # NB: `done_event`s isn't strictly needed. E.g., we can just check for + # `None` from `index_queue`, but it allows us to skip wasting resources + # processing indices already in `index_queue` if we are already shutting + # down. + + def __init__(self, loader): + super().__init__(loader) + + self._prefetch_factor = loader.prefetch_factor + + assert self._num_workers > 0 + assert self._prefetch_factor > 0 + + if loader.multiprocessing_context is None: + multiprocessing_context = torch.multiprocessing + else: + multiprocessing_context = loader.multiprocessing_context + + self._worker_init_fn = loader.worker_init_fn + + # Adds forward compatibilities so classic DataLoader can work with DataPipes: + # Additional worker init function will take care of sharding in MP and Distributed + if isinstance(self._dataset, (IterDataPipe, MapDataPipe)): + self._worker_init_fn = functools.partial( + _sharding_worker_init_fn, + self._worker_init_fn, + self._world_size, + self._rank, + ) + + # No certainty which module multiprocessing_context is + self._worker_result_queue = multiprocessing_context.Queue() # type: ignore[var-annotated] + self._worker_pids_set = False + self._shutdown = False + self._workers_done_event = multiprocessing_context.Event() + + self._index_queues = [] + self._workers = [] + for i in range(self._num_workers): + # No certainty which module multiprocessing_context is + index_queue = multiprocessing_context.Queue() # type: ignore[var-annotated] + # Need to `cancel_join_thread` here! + # See sections (2) and (3b) above. + index_queue.cancel_join_thread() + w = multiprocessing_context.Process( + target=_utils.worker._worker_loop, + args=( + self._dataset_kind, + self._dataset, + index_queue, + self._worker_result_queue, + self._workers_done_event, + self._auto_collation, + self._collate_fn, + self._drop_last, + self._base_seed, + self._worker_init_fn, + i, + self._num_workers, + self._persistent_workers, + self._shared_seed, + ), + ) + w.daemon = True + # NB: Process.start() actually take some time as it needs to + # start a process and pass the arguments over via a pipe. 
+ # Therefore, we only add a worker to self._workers list after + # it started, so that we do not call .join() if program dies + # before it starts, and __del__ tries to join but will get: + # AssertionError: can only join a started process. + w.start() + self._index_queues.append(index_queue) + self._workers.append(w) + + if self._pin_memory: + self._pin_memory_thread_done_event = threading.Event() + + # Queue is not type-annotated + self._data_queue = queue.Queue() # type: ignore[var-annotated] + if self._pin_memory_device == "xpu": + current_device = torch.xpu.current_device() # type: ignore[attr-defined] + elif self._pin_memory_device == torch._C._get_privateuse1_backend_name(): + custom_device_mod = getattr( + torch, torch._C._get_privateuse1_backend_name() + ) + current_device = custom_device_mod.current_device() + else: + current_device = torch.cuda.current_device() # choose cuda for default + pin_memory_thread = threading.Thread( + target=_utils.pin_memory._pin_memory_loop, + args=( + self._worker_result_queue, + self._data_queue, + current_device, + self._pin_memory_thread_done_event, + self._pin_memory_device, + ), + ) + pin_memory_thread.daemon = True + pin_memory_thread.start() + # Similar to workers (see comment above), we only register + # pin_memory_thread once it is started. + self._pin_memory_thread = pin_memory_thread + else: + self._data_queue = self._worker_result_queue # type: ignore[assignment] + + # In some rare cases, persistent workers (daemonic processes) + # would be terminated before `__del__` of iterator is invoked + # when main process exits + # It would cause failure when pin_memory_thread tries to read + # corrupted data from worker_result_queue + # atexit is used to shutdown thread and child processes in the + # right sequence before main process exits + if self._persistent_workers and self._pin_memory: + import atexit + + for w in self._workers: + atexit.register(_MultiProcessingDataLoaderIter._clean_up_worker, w) + + # .pid can be None only before process is spawned (not the case, so ignore) + _utils.signal_handling._set_worker_pids(id(self), tuple(w.pid for w in self._workers)) # type: ignore[misc] + _utils.signal_handling._set_SIGCHLD_handler() + self._worker_pids_set = True + self._reset(loader, first_iter=True) + + def _reset(self, loader, first_iter=False): + super()._reset(loader, first_iter) + self._send_idx = 0 # idx of the next task to be sent to workers + self._rcvd_idx = 0 # idx of the next task to be returned in __next__ + # information about data not yet yielded, i.e., tasks w/ indices in range [rcvd_idx, send_idx). + # map: task idx => - (worker_id,) if data isn't fetched (outstanding) + # \ (worker_id, data) if data is already fetched (out-of-order) + self._task_info = {} + self._tasks_outstanding = ( + 0 # always equal to count(v for v in task_info.values() if len(v) == 1) + ) + # A list of booleans representing whether each worker still has work to + # do, i.e., not having exhausted its iterable dataset object. It always + # contains all `True`s if not using an iterable-style dataset + # (i.e., if kind != Iterable). + # Not that this indicates that a worker still has work to do *for this epoch*. + # It does not mean that a worker is dead. In case of `_persistent_workers`, + # the worker will be reset to available in the next epoch. 
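The send/receive bookkeeping described in the comments above can be hard to picture. Here is a minimal, self-contained sketch (illustrative only, not part of the patch; names are mine) of how a `task_info`-style table plus a running received index turns out-of-order worker results back into order:

    # Illustrative sketch: reordering out-of-order (task_idx, data) results.
    def reorder(results):
        task_info = {}   # task_idx -> data that arrived early
        rcvd_idx = 0     # next task index the caller expects
        ordered = []
        for idx, data in results:
            if idx != rcvd_idx:
                task_info[idx] = data            # out of order: park it
                continue
            ordered.append(data)
            rcvd_idx += 1
            while rcvd_idx in task_info:         # drain anything now in order
                ordered.append(task_info.pop(rcvd_idx))
                rcvd_idx += 1
        return ordered

    print(reorder([(0, "a"), (2, "c"), (1, "b"), (3, "d")]))  # ['a', 'b', 'c', 'd']

The real loader additionally tracks which worker produced each task and how many tasks are outstanding, but the reordering idea is the same.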
+ self._workers_status = [True for i in range(self._num_workers)] + # Reset the worker queue cycle so it resumes next epoch at worker 0 + self._worker_queue_idx_cycle = itertools.cycle(range(self._num_workers)) + # We resume the prefetching in case it was enabled + if not first_iter: + for idx in range(self._num_workers): + self._index_queues[idx].put( + _utils.worker._ResumeIteration(self._shared_seed) + ) + resume_iteration_cnt = self._num_workers + while resume_iteration_cnt > 0: + return_idx, return_data = self._get_data() + if isinstance(return_idx, _utils.worker._ResumeIteration): + assert return_data is None + resume_iteration_cnt -= 1 + # prime the prefetch loop + for _ in range(self._prefetch_factor * self._num_workers): + self._try_put_index() + + def _try_get_data(self, timeout=_utils.MP_STATUS_CHECK_INTERVAL): + # Tries to fetch data from `self._data_queue` once for a given timeout. + # This can also be used as inner loop of fetching without timeout, with + # the sender status as the loop condition. + # + # This raises a `RuntimeError` if any worker died expectedly. This error + # can come from either the SIGCHLD handler in `_utils/signal_handling.py` + # (only for non-Windows platforms), or the manual check below on errors + # and timeouts. + # + # Returns a 2-tuple: + # (bool: whether successfully get data, any: data if successful else None) + try: + data = self._data_queue.get(timeout=timeout) + return (True, data) + except Exception as e: + # At timeout and error, we manually check whether any worker has + # failed. Note that this is the only mechanism for Windows to detect + # worker failures. + failed_workers = [] + for worker_id, w in enumerate(self._workers): + if self._workers_status[worker_id] and not w.is_alive(): + failed_workers.append(w) + self._mark_worker_as_unavailable(worker_id) + if len(failed_workers) > 0: + pids_str = ", ".join(str(w.pid) for w in failed_workers) + raise RuntimeError( + f"DataLoader worker (pid(s) {pids_str}) exited unexpectedly" + ) from e + if isinstance(e, queue.Empty): + return (False, None) + + import errno + import tempfile + + try: + # Raise an exception if we are this close to the FDs limit. + # Apparently, trying to open only one file is not a sufficient + # test. + # See NOTE [ DataLoader on Linux and open files limit ] + fds_limit_margin = 10 + fs = [tempfile.NamedTemporaryFile() for i in range(fds_limit_margin)] + except OSError as e: + if e.errno == errno.EMFILE: + raise RuntimeError( + "Too many open files. Communication with the" + " workers is no longer possible. Please increase the" + " limit using `ulimit -n` in the shell or change the" + " sharing strategy by calling" + " `torch.multiprocessing.set_sharing_strategy('file_system')`" + " at the beginning of your code" + ) from None + raise + + # NOTE [ DataLoader on Linux and open files limit ] + # + # On Linux when DataLoader is used with multiprocessing we pass the data between + # the root process and the workers through SHM files. We remove those files from + # the filesystem as soon as they are created and keep them alive by + # passing around their file descriptors through AF_UNIX sockets. (See + # docs/source/multiprocessing.rst and 'Multiprocessing Technical Notes` in + # the wiki (https://github.com/pytorch/pytorch/wiki).) + # + # This sometimes leads us to exceeding the open files limit. 
When that happens, + # and the offending file descriptor is coming over a socket, the `socket` Python + # package silently strips the file descriptor from the message, setting only the + # `MSG_CTRUNC` flag (which might be a bit misleading since the manpage says that + # it _indicates that some control data were discarded due to lack of space in + # the buffer for ancillary data_). This might reflect the C implementation of + # AF_UNIX sockets. + # + # This behaviour can be reproduced with the script and instructions at the + # bottom of this note. + # + # When that happens, the standard Python `multiprocessing` (and not + # `torch.multiprocessing`) raises a `RuntimeError: received 0 items of ancdata` + # + # Sometimes, instead of the FD being stripped, you may get an `OSError: + # Too many open files`, both in the script below and in DataLoader. However, + # this is rare and seems to be nondeterministic. + # + # + # #!/usr/bin/env python3 + # import sys + # import socket + # import os + # import array + # import shutil + # import socket + # + # + # if len(sys.argv) != 4: + # print("Usage: ", sys.argv[0], " tmp_dirname iteration (send|recv)") + # sys.exit(1) + # + # if __name__ == '__main__': + # dirname = sys.argv[1] + # sock_path = dirname + "/sock" + # iterations = int(sys.argv[2]) + # def dummy_path(i): + # return dirname + "/" + str(i) + ".dummy" + # + # + # if sys.argv[3] == 'send': + # while not os.path.exists(sock_path): + # pass + # client = socket.socket(socket.AF_UNIX, socket.SOCK_DGRAM) + # client.connect(sock_path) + # for i in range(iterations): + # fd = os.open(dummy_path(i), os.O_WRONLY | os.O_CREAT) + # ancdata = array.array('i', [fd]) + # msg = bytes([i % 256]) + # print("Sending fd ", fd, " (iteration #", i, ")") + # client.sendmsg([msg], [(socket.SOL_SOCKET, socket.SCM_RIGHTS, ancdata)]) + # + # + # else: + # assert sys.argv[3] == 'recv' + # + # if os.path.exists(dirname): + # raise Exception("Directory exists") + # + # os.mkdir(dirname) + # + # print("Opening socket...") + # server = socket.socket(socket.AF_UNIX, socket.SOCK_DGRAM) + # server.bind(sock_path) + # + # print("Listening...") + # for i in range(iterations): + # a = array.array('i') + # msg, ancdata, flags, addr = server.recvmsg(1, socket.CMSG_SPACE(a.itemsize)) + # assert(len(ancdata) == 1) + # cmsg_level, cmsg_type, cmsg_data = ancdata[0] + # a.frombytes(cmsg_data) + # print("Received fd ", a[0], " (iteration #", i, ")") + # + # shutil.rmtree(dirname) + # + # Steps to reproduce: + # + # 1. Run two shells and set lower file descriptor limit in the receiving one: + # (shell1) ulimit -n 1020 + # (shell2) ulimit -n 1022 + # + # 2. Run the script above with the `recv` option in the first shell + # (shell1) ./test_socket.py sock_tmp 1017 recv + # + # 3. Run the script with the `send` option in the second shell: + # (shell2) ./test_socket.py sock_tmp 1017 send + + def _get_data(self): + # Fetches data from `self._data_queue`. + # + # We check workers' status every `MP_STATUS_CHECK_INTERVAL` seconds, + # which we achieve by running `self._try_get_data(timeout=MP_STATUS_CHECK_INTERVAL)` + # in a loop. This is the only mechanism to detect worker failures for + # Windows. For other platforms, a SIGCHLD handler is also used for + # worker failure detection. + # + # If `pin_memory=True`, we also need check if `pin_memory_thread` had + # died at timeouts. 
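As a reference for the polling scheme just described, a small sketch (the function and constant names are mine, not the patch's) of getting from a result queue in short timeout slices while checking that the producing process is still alive:

    import multiprocessing as mp
    import queue

    STATUS_CHECK_INTERVAL = 5.0  # stands in for MP_STATUS_CHECK_INTERVAL

    def poll_queue(result_queue, producer):
        # Block only for a short interval, then re-check liveness, so a dead
        # producer raises an error instead of hanging the consumer forever.
        while True:
            try:
                return result_queue.get(timeout=STATUS_CHECK_INTERVAL)
            except queue.Empty:
                if not producer.is_alive():
                    raise RuntimeError(
                        f"producer (pid {producer.pid}) exited unexpectedly"
                    )

Callers pass the worker `multiprocessing.Process` and the queue it writes to; the loop either returns data or raises once the producer dies.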
+ if self._timeout > 0: + success, data = self._try_get_data(self._timeout) + if success: + return data + else: + raise RuntimeError( + f"DataLoader timed out after {self._timeout} seconds" + ) + elif self._pin_memory: + while self._pin_memory_thread.is_alive(): + success, data = self._try_get_data() + if success: + return data + else: + # while condition is false, i.e., pin_memory_thread died. + raise RuntimeError("Pin memory thread exited unexpectedly") + # In this case, `self._data_queue` is a `queue.Queue`,. But we don't + # need to call `.task_done()` because we don't use `.join()`. + else: + while True: + success, data = self._try_get_data() + if success: + return data + + def _next_data(self): + while True: + # If the worker responsible for `self._rcvd_idx` has already ended + # and was unable to fulfill this task (due to exhausting an `IterableDataset`), + # we try to advance `self._rcvd_idx` to find the next valid index. + # + # This part needs to run in the loop because both the `self._get_data()` + # call and `_IterableDatasetStopIteration` check below can mark + # extra worker(s) as dead. + while self._rcvd_idx < self._send_idx: + info = self._task_info[self._rcvd_idx] + worker_id = info[0] + if ( + len(info) == 2 or self._workers_status[worker_id] + ): # has data or is still active + break + del self._task_info[self._rcvd_idx] + self._rcvd_idx += 1 + else: + # no valid `self._rcvd_idx` is found (i.e., didn't break) + if not self._persistent_workers: + self._shutdown_workers() + raise StopIteration + + # Now `self._rcvd_idx` is the batch index we want to fetch + + # Check if the next sample has already been generated + if len(self._task_info[self._rcvd_idx]) == 2: + data = self._task_info.pop(self._rcvd_idx)[1] + return self._process_data(data) + + assert not self._shutdown and self._tasks_outstanding > 0 + idx, data = self._get_data() + self._tasks_outstanding -= 1 + if self._dataset_kind == _DatasetKind.Iterable: + # Check for _IterableDatasetStopIteration + if isinstance(data, _utils.worker._IterableDatasetStopIteration): + if self._persistent_workers: + self._workers_status[data.worker_id] = False + else: + self._mark_worker_as_unavailable(data.worker_id) + self._try_put_index() + continue + + if idx != self._rcvd_idx: + # store out-of-order samples + self._task_info[idx] += (data,) + else: + del self._task_info[idx] + return self._process_data(data) + + def _try_put_index(self): + assert self._tasks_outstanding < self._prefetch_factor * self._num_workers + + try: + index = self._next_index() + except StopIteration: + return + for _ in range(self._num_workers): # find the next active worker, if any + worker_queue_idx = next(self._worker_queue_idx_cycle) + if self._workers_status[worker_queue_idx]: + break + else: + # not found (i.e., didn't break) + return + + self._index_queues[worker_queue_idx].put((self._send_idx, index)) # type: ignore[possibly-undefined] + self._task_info[self._send_idx] = (worker_queue_idx,) + self._tasks_outstanding += 1 + self._send_idx += 1 + + def _process_data(self, data): + self._rcvd_idx += 1 + self._try_put_index() + if isinstance(data, ExceptionWrapper): + data.reraise() + return data + + def _mark_worker_as_unavailable(self, worker_id, shutdown=False): + # Mark a worker as having finished its work e.g., due to + # exhausting an `IterableDataset`. This should be used only when this + # `_MultiProcessingDataLoaderIter` is going to continue running. 
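The `_try_put_index` method above walks the worker cycle at most `num_workers` steps and dispatches only to workers whose status flag is still True; `_mark_worker_as_unavailable`, whose code follows, is what flips that flag off. A toy, standalone sketch of that selection logic (not part of the patch):

    import itertools

    def pick_active_worker(worker_cycle, workers_status):
        # Advance the round-robin cycle until an active worker is found,
        # giving up after one full pass.
        for _ in range(len(workers_status)):
            worker_id = next(worker_cycle)
            if workers_status[worker_id]:
                return worker_id
        return None  # every worker has exhausted its shard

    status = [True, False, True, True]             # worker 1 is done
    cycle = itertools.cycle(range(len(status)))
    print([pick_active_worker(cycle, status) for _ in range(5)])  # [0, 2, 3, 0, 2]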
+ + assert self._workers_status[worker_id] or ( + self._persistent_workers and shutdown + ) + + # Signal termination to that specific worker. + q = self._index_queues[worker_id] + # Indicate that no more data will be put on this queue by the current + # process. + q.put(None) + + # Note that we don't actually join the worker here, nor do we remove the + # worker's pid from C side struct because (1) joining may be slow, and + # (2) since we don't join, the worker may still raise error, and we + # prefer capturing those, rather than ignoring them, even though they + # are raised after the worker has finished its job. + # Joinning is deferred to `_shutdown_workers`, which it is called when + # all workers finish their jobs (e.g., `IterableDataset` replicas) or + # when this iterator is garbage collected. + + self._workers_status[worker_id] = False + + assert self._workers_done_event.is_set() == shutdown + + def _shutdown_workers(self): + # Called when shutting down this `_MultiProcessingDataLoaderIter`. + # See NOTE [ Data Loader Multiprocessing Shutdown Logic ] for details on + # the logic of this function. + if ( + _utils is None + or _utils.python_exit_status is True + or _utils.python_exit_status is None + ): + # See (2) of the note. If Python is shutting down, do no-op. + return + # Normal exit when last reference is gone / iterator is depleted. + # See (1) and the second half of the note. + if not self._shutdown: + self._shutdown = True + try: + # Normal exit when last reference is gone / iterator is depleted. + # See (1) and the second half of the note. + + # Exit `pin_memory_thread` first because exiting workers may leave + # corrupted data in `worker_result_queue` which `pin_memory_thread` + # reads from. + if hasattr(self, "_pin_memory_thread"): + # Use hasattr in case error happens before we set the attribute. + self._pin_memory_thread_done_event.set() + # Send something to pin_memory_thread in case it is waiting + # so that it can wake up and check `pin_memory_thread_done_event` + self._worker_result_queue.put((None, None)) + self._pin_memory_thread.join() + self._worker_result_queue.cancel_join_thread() + self._worker_result_queue.close() + + # Exit workers now. + self._workers_done_event.set() + for worker_id in range(len(self._workers)): + # Get number of workers from `len(self._workers)` instead of + # `self._num_workers` in case we error before starting all + # workers. + # If we are using workers_status with persistent_workers + # we have to shut it down because the worker is paused + if self._persistent_workers or self._workers_status[worker_id]: + self._mark_worker_as_unavailable(worker_id, shutdown=True) + for w in self._workers: + # We should be able to join here, but in case anything went + # wrong, we set a timeout and if the workers fail to join, + # they are killed in the `finally` block. + w.join(timeout=_utils.MP_STATUS_CHECK_INTERVAL) + for q in self._index_queues: + q.cancel_join_thread() + q.close() + finally: + # Even though all this function does is putting into queues that + # we have called `cancel_join_thread` on, weird things can + # happen when a worker is killed by a signal, e.g., hanging in + # `Event.set()`. So we need to guard this with SIGCHLD handler, + # and remove pids from the C side data structure only at the + # end. + # + # FIXME: Unfortunately, for Windows, we are missing a worker + # error detection mechanism here in this function, as it + # doesn't provide a SIGCHLD handler. 
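To make the shutdown ordering in the notes concrete (set the event, then send the `None` sentinel, then join), here is a stripped-down standalone sketch with a single toy worker; the worker function and the doubling "work" are made up for illustration and do not come from the patch:

    import multiprocessing as mp

    def toy_worker(index_queue, result_queue, done_event):
        while True:
            item = index_queue.get()
            if item is None:            # final sentinel: always exit on it
                break
            if done_event.is_set():     # shutting down: stop producing results
                continue
            result_queue.put(item * 2)

    if __name__ == "__main__":
        ctx = mp.get_context("spawn")
        index_q, result_q = ctx.Queue(), ctx.Queue()
        done = ctx.Event()
        w = ctx.Process(target=toy_worker, args=(index_q, result_q, done))
        w.start()
        index_q.put(21)
        print(result_q.get())           # 42
        done.set()                      # 1) signal shutdown
        index_q.put(None)               # 2) wake the worker with the sentinel
        w.join()                        # 3) join (the real loader uses a timeout)
        index_q.cancel_join_thread()
        index_q.close()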
+ if self._worker_pids_set: + _utils.signal_handling._remove_worker_pids(id(self)) + self._worker_pids_set = False + for w in self._workers: + if w.is_alive(): + # Existing mechanisms try to make the workers exit + # peacefully, but in case that we unfortunately reach + # here, which we shouldn't, (e.g., pytorch/pytorch#39570), + # we kill the worker. + w.terminate() + + # staticmethod is used to remove reference to `_MultiProcessingDataLoaderIter` + @staticmethod + def _clean_up_worker(w): + try: + w.join(timeout=_utils.MP_STATUS_CHECK_INTERVAL) + finally: + if w.is_alive(): + w.terminate() + + def __del__(self): + self._shutdown_workers() \ No newline at end of file diff --git a/engine/cr_utility/dataset.py b/engine/cr_utility/dataset.py new file mode 100644 index 0000000..9c2f7da --- /dev/null +++ b/engine/cr_utility/dataset.py @@ -0,0 +1,489 @@ +# mypy: allow-untyped-defs +import bisect +import itertools +import math +import warnings +from typing import ( + cast, + Dict, + Generic, + Iterable, + List, + Optional, + Sequence, + Tuple, + TypeVar, + Union, +) +from typing_extensions import deprecated + +# No 'default_generator' in torch/__init__.pyi +from torch import default_generator, Generator, randperm, Tensor + + +__all__ = [ + "Dataset", + "IterableDataset", + "TensorDataset", + "StackDataset", + "ConcatDataset", + "ChainDataset", + "Subset", + "random_split", +] + + +_T = TypeVar("_T") +_T_co = TypeVar("_T_co", covariant=True) +_T_dict = Dict[str, _T_co] +_T_tuple = Tuple[_T_co, ...] +_T_stack = TypeVar("_T_stack", _T_tuple, _T_dict) + + +class Dataset(Generic[_T_co]): + r"""An abstract class representing a :class:`Dataset`. + + All datasets that represent a map from keys to data samples should subclass + it. All subclasses should overwrite :meth:`__getitem__`, supporting fetching a + data sample for a given key. Subclasses could also optionally overwrite + :meth:`__len__`, which is expected to return the size of the dataset by many + :class:`~torch.utils.data.Sampler` implementations and the default options + of :class:`~torch.utils.data.DataLoader`. Subclasses could also + optionally implement :meth:`__getitems__`, for speedup batched samples + loading. This method accepts list of indices of samples of batch and returns + list of samples. + + .. note:: + :class:`~torch.utils.data.DataLoader` by default constructs an index + sampler that yields integral indices. To make it work with a map-style + dataset with non-integral indices/keys, a custom sampler must be provided. + """ + + def __getitem__(self, index) -> _T_co: + raise NotImplementedError("Subclasses of Dataset should implement __getitem__.") + + # def __getitems__(self, indices: List) -> List[_T_co]: + # Not implemented to prevent false-positives in fetcher check in + # torch.utils.data._utils.fetch._MapDatasetFetcher + + def __add__(self, other: "Dataset[_T_co]") -> "ConcatDataset[_T_co]": + return ConcatDataset([self, other]) + + # No `def __len__(self)` default? + # See NOTE [ Lack of Default `__len__` in Python Abstract Base Classes ] + # in pytorch/torch/utils/data/sampler.py + + +class IterableDataset(Dataset[_T_co], Iterable[_T_co]): + r"""An iterable Dataset. + + All datasets that represent an iterable of data samples should subclass it. + Such form of datasets is particularly useful when data come from a stream. + + All subclasses should overwrite :meth:`__iter__`, which would return an + iterator of samples in this dataset. 
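For contrast with the iterable-style examples that follow in this docstring, a minimal map-style dataset under the `Dataset` contract defined above could look like the snippet below (it uses the upstream `torch.utils.data` classes, which the vendored copies mirror, so it runs on its own):

    import torch
    from torch.utils.data import Dataset, DataLoader

    class SquaresDataset(Dataset):
        # Map-style: only __getitem__ and __len__ are required.
        def __init__(self, n):
            self.n = n
        def __len__(self):
            return self.n
        def __getitem__(self, idx):
            return torch.tensor(idx), torch.tensor(idx * idx)

    loader = DataLoader(SquaresDataset(5), batch_size=2)
    for xs, ys in loader:
        print(xs.tolist(), ys.tolist())   # [0, 1] [0, 1], then [2, 3] [4, 9], then [4] [16]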
+ + When a subclass is used with :class:`~torch.utils.data.DataLoader`, each + item in the dataset will be yielded from the :class:`~torch.utils.data.DataLoader` + iterator. When :attr:`num_workers > 0`, each worker process will have a + different copy of the dataset object, so it is often desired to configure + each copy independently to avoid having duplicate data returned from the + workers. :func:`~torch.utils.data.get_worker_info`, when called in a worker + process, returns information about the worker. It can be used in either the + dataset's :meth:`__iter__` method or the :class:`~torch.utils.data.DataLoader` 's + :attr:`worker_init_fn` option to modify each copy's behavior. + + Example 1: splitting workload across all workers in :meth:`__iter__`:: + + >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_DATALOADER) + >>> # xdoctest: +SKIP("Fails on MacOS12") + >>> class MyIterableDataset(torch.utils.data.IterableDataset): + ... def __init__(self, start, end): + ... super(MyIterableDataset).__init__() + ... assert end > start, "this example code only works with end >= start" + ... self.start = start + ... self.end = end + ... + ... def __iter__(self): + ... worker_info = torch.utils.data.get_worker_info() + ... if worker_info is None: # single-process data loading, return the full iterator + ... iter_start = self.start + ... iter_end = self.end + ... else: # in a worker process + ... # split workload + ... per_worker = int(math.ceil((self.end - self.start) / float(worker_info.num_workers))) + ... worker_id = worker_info.id + ... iter_start = self.start + worker_id * per_worker + ... iter_end = min(iter_start + per_worker, self.end) + ... return iter(range(iter_start, iter_end)) + ... + >>> # should give same set of data as range(3, 7), i.e., [3, 4, 5, 6]. + >>> ds = MyIterableDataset(start=3, end=7) + + >>> # Single-process loading + >>> print(list(torch.utils.data.DataLoader(ds, num_workers=0))) + [tensor([3]), tensor([4]), tensor([5]), tensor([6])] + + >>> # xdoctest: +REQUIRES(POSIX) + >>> # Mult-process loading with two worker processes + >>> # Worker 0 fetched [3, 4]. Worker 1 fetched [5, 6]. + >>> # xdoctest: +IGNORE_WANT("non deterministic") + >>> print(list(torch.utils.data.DataLoader(ds, num_workers=2))) + [tensor([3]), tensor([5]), tensor([4]), tensor([6])] + + >>> # With even more workers + >>> # xdoctest: +IGNORE_WANT("non deterministic") + >>> print(list(torch.utils.data.DataLoader(ds, num_workers=12))) + [tensor([3]), tensor([5]), tensor([4]), tensor([6])] + + Example 2: splitting workload across all workers using :attr:`worker_init_fn`:: + + >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_DATALOADER) + >>> class MyIterableDataset(torch.utils.data.IterableDataset): + ... def __init__(self, start, end): + ... super(MyIterableDataset).__init__() + ... assert end > start, "this example code only works with end >= start" + ... self.start = start + ... self.end = end + ... + ... def __iter__(self): + ... return iter(range(self.start, self.end)) + ... + >>> # should give same set of data as range(3, 7), i.e., [3, 4, 5, 6]. + >>> ds = MyIterableDataset(start=3, end=7) + + >>> # Single-process loading + >>> print(list(torch.utils.data.DataLoader(ds, num_workers=0))) + [3, 4, 5, 6] + >>> + >>> # Directly doing multi-process loading yields duplicate data + >>> print(list(torch.utils.data.DataLoader(ds, num_workers=2))) + [3, 3, 4, 4, 5, 5, 6, 6] + + >>> # Define a `worker_init_fn` that configures each dataset copy differently + >>> def worker_init_fn(worker_id): + ... 
worker_info = torch.utils.data.get_worker_info() + ... dataset = worker_info.dataset # the dataset copy in this worker process + ... overall_start = dataset.start + ... overall_end = dataset.end + ... # configure the dataset to only process the split workload + ... per_worker = int(math.ceil((overall_end - overall_start) / float(worker_info.num_workers))) + ... worker_id = worker_info.id + ... dataset.start = overall_start + worker_id * per_worker + ... dataset.end = min(dataset.start + per_worker, overall_end) + ... + + >>> # Mult-process loading with the custom `worker_init_fn` + >>> # Worker 0 fetched [3, 4]. Worker 1 fetched [5, 6]. + >>> print(list(torch.utils.data.DataLoader(ds, num_workers=2, worker_init_fn=worker_init_fn))) + [3, 5, 4, 6] + + >>> # With even more workers + >>> print(list(torch.utils.data.DataLoader(ds, num_workers=12, worker_init_fn=worker_init_fn))) + [3, 4, 5, 6] + """ + + def __add__(self, other: Dataset[_T_co]): + return ChainDataset([self, other]) + + # No `def __len__(self)` default? Subclasses raise `TypeError` when needed. + # See NOTE [ Lack of Default `__len__` in Python Abstract Base Classes ] + + +class TensorDataset(Dataset[Tuple[Tensor, ...]]): + r"""Dataset wrapping tensors. + + Each sample will be retrieved by indexing tensors along the first dimension. + + Args: + *tensors (Tensor): tensors that have the same size of the first dimension. + """ + + tensors: Tuple[Tensor, ...] + + def __init__(self, *tensors: Tensor) -> None: + assert all( + tensors[0].size(0) == tensor.size(0) for tensor in tensors + ), "Size mismatch between tensors" + self.tensors = tensors + + def __getitem__(self, index): + return tuple(tensor[index] for tensor in self.tensors) + + def __len__(self): + return self.tensors[0].size(0) + + +class StackDataset(Dataset[_T_stack]): + r"""Dataset as a stacking of multiple datasets. + + This class is useful to assemble different parts of complex input data, given as datasets. + + Example: + >>> # xdoctest: +SKIP + >>> images = ImageDataset() + >>> texts = TextDataset() + >>> tuple_stack = StackDataset(images, texts) + >>> tuple_stack[0] == (images[0], texts[0]) + >>> dict_stack = StackDataset(image=images, text=texts) + >>> dict_stack[0] == {'image': images[0], 'text': texts[0]} + + Args: + *args (Dataset): Datasets for stacking returned as tuple. + **kwargs (Dataset): Datasets for stacking returned as dict. + """ + + datasets: Union[tuple, dict] + + def __init__(self, *args: Dataset[_T_co], **kwargs: Dataset[_T_co]) -> None: + if args: + if kwargs: + raise ValueError( + "Supported either ``tuple``- (via ``args``) or" + "``dict``- (via ``kwargs``) like input/output, but both types are given." + ) + self._length = len(args[0]) # type: ignore[arg-type] + if any(self._length != len(dataset) for dataset in args): # type: ignore[arg-type] + raise ValueError("Size mismatch between datasets") + self.datasets = args + elif kwargs: + tmp = list(kwargs.values()) + self._length = len(tmp[0]) # type: ignore[arg-type] + if any(self._length != len(dataset) for dataset in tmp): # type: ignore[arg-type] + raise ValueError("Size mismatch between datasets") + self.datasets = kwargs + else: + raise ValueError("At least one dataset should be passed") + + def __getitem__(self, index): + if isinstance(self.datasets, dict): + return {k: dataset[index] for k, dataset in self.datasets.items()} + return tuple(dataset[index] for dataset in self.datasets) + + def __getitems__(self, indices: list): + # add batched sampling support when parent datasets supports it. 
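Before the batched `__getitems__` path below, basic indexing of the two wrappers defined above looks like this; the snippet is a sketch against the upstream `torch.utils.data` names (StackDataset exists in recent PyTorch releases) so that it is standalone:

    import torch
    from torch.utils.data import TensorDataset, StackDataset

    features = torch.arange(12.0).reshape(4, 3)   # 4 samples, 3 features each
    targets = torch.tensor([0, 1, 0, 1])

    xs = TensorDataset(features)                  # each item is a 1-tuple (features[i],)
    ys = TensorDataset(targets)
    stacked = StackDataset(x=xs, y=ys)            # dict-style stacking via kwargs

    print(len(stacked))   # 4
    print(stacked[2])     # roughly {'x': (tensor([6., 7., 8.]),), 'y': (tensor(0),)}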
+ if isinstance(self.datasets, dict): + dict_batch: List[_T_dict] = [{} for _ in indices] + for k, dataset in self.datasets.items(): + if callable(getattr(dataset, "__getitems__", None)): + items = dataset.__getitems__(indices) # type: ignore[attr-defined] + if len(items) != len(indices): + raise ValueError( + "Nested dataset's output size mismatch." + f" Expected {len(indices)}, got {len(items)}" + ) + for data, d_sample in zip(items, dict_batch): + d_sample[k] = data + else: + for idx, d_sample in zip(indices, dict_batch): + d_sample[k] = dataset[idx] + return dict_batch + + # tuple data + list_batch: List[list] = [[] for _ in indices] + for dataset in self.datasets: + if callable(getattr(dataset, "__getitems__", None)): + items = dataset.__getitems__(indices) # type: ignore[attr-defined] + if len(items) != len(indices): + raise ValueError( + "Nested dataset's output size mismatch." + f" Expected {len(indices)}, got {len(items)}" + ) + for data, t_sample in zip(items, list_batch): + t_sample.append(data) + else: + for idx, t_sample in zip(indices, list_batch): + t_sample.append(dataset[idx]) + tuple_batch: List[_T_tuple] = [tuple(sample) for sample in list_batch] + return tuple_batch + + def __len__(self): + return self._length + + +class ConcatDataset(Dataset[_T_co]): + r"""Dataset as a concatenation of multiple datasets. + + This class is useful to assemble different existing datasets. + + Args: + datasets (sequence): List of datasets to be concatenated + """ + + datasets: List[Dataset[_T_co]] + cumulative_sizes: List[int] + + @staticmethod + def cumsum(sequence): + r, s = [], 0 + for e in sequence: + l = len(e) + r.append(l + s) + s += l + return r + + def __init__(self, datasets: Iterable[Dataset]) -> None: + super().__init__() + self.datasets = list(datasets) + assert len(self.datasets) > 0, "datasets should not be an empty iterable" # type: ignore[arg-type] + for d in self.datasets: + assert not isinstance( + d, IterableDataset + ), "ConcatDataset does not support IterableDataset" + self.cumulative_sizes = self.cumsum(self.datasets) + + def __len__(self): + return self.cumulative_sizes[-1] + + def __getitem__(self, idx): + if idx < 0: + if -idx > len(self): + raise ValueError( + "absolute value of index should not exceed dataset length" + ) + idx = len(self) + idx + dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx) + if dataset_idx == 0: + sample_idx = idx + else: + sample_idx = idx - self.cumulative_sizes[dataset_idx - 1] + return self.datasets[dataset_idx][sample_idx] + + @property + @deprecated( + "`cummulative_sizes` attribute is renamed to `cumulative_sizes`", + category=FutureWarning, + ) + def cummulative_sizes(self): + return self.cumulative_sizes + + +class ChainDataset(IterableDataset): + r"""Dataset for chaining multiple :class:`IterableDataset` s. + + This class is useful to assemble different existing dataset streams. The + chaining operation is done on-the-fly, so concatenating large-scale + datasets with this class will be efficient. 
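The global-to-local index translation that `ConcatDataset.__getitem__` performs above is easy to check in isolation; a small sketch of the same `cumsum` plus `bisect_right` arithmetic (helper name is mine):

    import bisect

    def locate(cumulative_sizes, idx):
        # Map a global index to (dataset index, index within that dataset).
        dataset_idx = bisect.bisect_right(cumulative_sizes, idx)
        local_idx = idx if dataset_idx == 0 else idx - cumulative_sizes[dataset_idx - 1]
        return dataset_idx, local_idx

    cumulative = [3, 8, 10]          # what cumsum yields for lengths [3, 5, 2]
    for i in (0, 2, 3, 7, 9):
        print(i, "->", locate(cumulative, i))
    # 0 -> (0, 0), 2 -> (0, 2), 3 -> (1, 0), 7 -> (1, 4), 9 -> (2, 1)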
+ + Args: + datasets (iterable of IterableDataset): datasets to be chained together + """ + + def __init__(self, datasets: Iterable[Dataset]) -> None: + super().__init__() + self.datasets = datasets + + def __iter__(self): + for d in self.datasets: + assert isinstance( + d, IterableDataset + ), "ChainDataset only supports IterableDataset" + yield from d + + def __len__(self): + total = 0 + for d in self.datasets: + assert isinstance( + d, IterableDataset + ), "ChainDataset only supports IterableDataset" + total += len(d) # type: ignore[arg-type] + return total + + +class Subset(Dataset[_T_co]): + r""" + Subset of a dataset at specified indices. + + Args: + dataset (Dataset): The whole Dataset + indices (sequence): Indices in the whole set selected for subset + """ + + dataset: Dataset[_T_co] + indices: Sequence[int] + + def __init__(self, dataset: Dataset[_T_co], indices: Sequence[int]) -> None: + self.dataset = dataset + self.indices = indices + + def __getitem__(self, idx): + if isinstance(idx, list): + return self.dataset[[self.indices[i] for i in idx]] + return self.dataset[self.indices[idx]] + + def __getitems__(self, indices: List[int]) -> List[_T_co]: + # add batched sampling support when parent dataset supports it. + # see torch.utils.data._utils.fetch._MapDatasetFetcher + if callable(getattr(self.dataset, "__getitems__", None)): + return self.dataset.__getitems__([self.indices[idx] for idx in indices]) # type: ignore[attr-defined] + else: + return [self.dataset[self.indices[idx]] for idx in indices] + + def __len__(self): + return len(self.indices) + + +def random_split( + dataset: Dataset[_T], + lengths: Sequence[Union[int, float]], + generator: Optional[Generator] = default_generator, +) -> List[Subset[_T]]: + r""" + Randomly split a dataset into non-overlapping new datasets of given lengths. + + If a list of fractions that sum up to 1 is given, + the lengths will be computed automatically as + floor(frac * len(dataset)) for each fraction provided. + + After computing the lengths, if there are any remainders, 1 count will be + distributed in round-robin fashion to the lengths + until there are no remainders left. + + Optionally fix the generator for reproducible results, e.g.: + + Example: + >>> # xdoctest: +SKIP + >>> generator1 = torch.Generator().manual_seed(42) + >>> generator2 = torch.Generator().manual_seed(42) + >>> random_split(range(10), [3, 7], generator=generator1) + >>> random_split(range(30), [0.3, 0.3, 0.4], generator=generator2) + + Args: + dataset (Dataset): Dataset to be split + lengths (sequence): lengths or fractions of splits to be produced + generator (Generator): Generator used for the random permutation. + """ + if math.isclose(sum(lengths), 1) and sum(lengths) <= 1: + subset_lengths: List[int] = [] + for i, frac in enumerate(lengths): + if frac < 0 or frac > 1: + raise ValueError(f"Fraction at index {i} is not between 0 and 1") + n_items_in_split = int( + math.floor(len(dataset) * frac) # type: ignore[arg-type] + ) + subset_lengths.append(n_items_in_split) + remainder = len(dataset) - sum(subset_lengths) # type: ignore[arg-type] + # add 1 to all the lengths in round-robin fashion until the remainder is 0 + for i in range(remainder): + idx_to_add_at = i % len(subset_lengths) + subset_lengths[idx_to_add_at] += 1 + lengths = subset_lengths + for i, length in enumerate(lengths): + if length == 0: + warnings.warn( + f"Length of split at index {i} is 0. " + f"This might result in an empty dataset." 
+ ) + + # Cannot verify that dataset is Sized + if sum(lengths) != len(dataset): # type: ignore[arg-type] + raise ValueError( + "Sum of input lengths does not equal the length of the input dataset!" + ) + + indices = randperm(sum(lengths), generator=generator).tolist() # type: ignore[arg-type, call-overload] + lengths = cast(Sequence[int], lengths) + return [ + Subset(dataset, indices[offset - length : offset]) + for offset, length in zip(itertools.accumulate(lengths), lengths) + ] \ No newline at end of file diff --git a/engine/models.py b/engine/models.py index 7c879d7..20e8ad5 100644 --- a/engine/models.py +++ b/engine/models.py @@ -40,7 +40,8 @@ """ from torch.nn import Module, Linear, TransformerEncoderLayer, TransformerEncoder, ReLU, ModuleList, Embedding from torch_geometric.utils import degree -from cr_pkg import sage_conv, gat_conv, gcn_conv, han_conv +# from cr_pkg import sage_conv, gat_conv, gcn_conv, han_conv +from torch_geometric.nn import SAGEConv, GATConv, GCNConv, HANConv import torch class GraphTransformer(Module): @@ -220,4 +221,3 @@ def forward(self, data): return x - diff --git a/engine/timecapsule.py b/engine/timecapsule.py index 6b30413..04df314 100644 --- a/engine/timecapsule.py +++ b/engine/timecapsule.py @@ -1,4 +1,3 @@ - # This is a time capsule module for the corerec # ############################################################################################################### # --CoreRec: Connecting to the Unseen-- @@ -13,6 +12,7 @@ from common_import import * from async_ddp import * +from torch_geometric.data import Data class GraphTransformer(Module): ''' @@ -60,10 +60,12 @@ def __len__(self): def __getitem__(self, idx): node_features = self.adj_matrix[idx] + edge_index = torch.nonzero(self.adj_matrix).t().contiguous() + data = Data(x=node_features, edge_index=edge_index) if self.weight_matrix is not None: weights = self.weight_matrix[idx] - return node_features, weights - return node_features, node_features # Return node_features as targets if no weights + data.weights = weights + return data # Training Loop def train_model(model, data_loader, criterion, optimizer, num_epochs): @@ -259,4 +261,3 @@ def explainable_predict(model, graph, node_index, top_k=5, threshold=0.5): explanations.append(explanation) return recommended_indices, explanations - diff --git a/src/SANDBOX/tempCodeRunnerFile.py b/src/SANDBOX/tempCodeRunnerFile.py index e51987e..07719c0 100644 --- a/src/SANDBOX/tempCodeRunnerFile.py +++ b/src/SANDBOX/tempCodeRunnerFile.py @@ -1,40 +1 @@ -import numpy as np -import pandas as pd -import matplotlib.pyplot as plt -from matplotlib.table import Table - -# Load node labels from the CSV file -df_labels = pd.read_csv("SANDBOX/labelele.csv") -node_labels = df_labels["Names"].tolist() - -# Assuming you have the recommended nodes for each node index stored in a list of lists -recommended_nodes_list = [] -for i in range(11): - recommended_nodes = cs.predict(model, adj_matrix, i, top_k=3, threshold=0.7) - recommended_nodes_list.append(recommended_nodes) - -# Create a scatter plot with labels for node indexes -plt.figure(figsize=(10, 8)) -for i in range(11): - x = [i] * len(recommended_nodes_list[i]) - y = recommended_nodes_list[i] - plt.scatter(x, y, label=f"Node {i}") - for j, txt in enumerate(y): - plt.annotate(node_labels[y[j]], (x[j], y[j]), textcoords="offset points", xytext=(0,10), ha='center') - -plt.xticks(range(11), node_labels, rotation=45) -plt.yticks(range(11), node_labels) -plt.xlabel('Node Index / Names') -plt.ylabel('Recommended Node Index / 
Names') -plt.title('Recommendations for Node Index') - -# Create a legend table with batch size and threshold details outside the plot area -legend_data = [['Batch Size', 'Threshold'], - [10, 0.7]] -table = plt.table(cellText=legend_data, loc='upper right', cellLoc='center', colWidths=[0.1, 0.1]) -table.auto_set_font_size(False) -table.set_fontsize(12) -table.scale(1.5, 1.5) - -plt.grid(True) -plt.show() \ No newline at end of file +t_labeled=False) \ No newline at end of file diff --git a/src/SANDBOX/test7.py b/src/SANDBOX/test7.py new file mode 100644 index 0000000..1954844 --- /dev/null +++ b/src/SANDBOX/test7.py @@ -0,0 +1,52 @@ +import pandas as pd +import torch +import numpy as np +import core_rec as cr +import vish_graphs as vg +from engine.torch_nn import * +from torch_geometric.datasets import Planetoid +from torch_geometric.data import DataLoader + +def train_model123(model, data_loader, criterion, optimizer, num_epochs): + model.train() + for epoch in range(num_epochs): + for batch in data_loader: + optimizer.zero_grad() + outputs = model(batch) + loss = criterion(outputs, batch.y) + loss.backward() + optimizer.step() + print(f"Epoch {epoch + 1}/{num_epochs}, Loss: {loss.item()}") + + + +# Load the Cora dataset +dataset = Planetoid(root='/tmp/Cora', name='Cora') +data = dataset[0] + +# Create a DataLoader for the Cora dataset +data_loader = DataLoader([data], batch_size=1, shuffle=True) + +# Define model parameters +num_layers = 2 +d_model = 128 +num_heads = 8 +d_feedforward = 512 +input_dim = dataset.num_node_features + +# Initialize model, loss function, and optimizer +model = cr.GraphTransformer(num_layers, d_model, num_heads, d_feedforward, input_dim) +criterion = torch.nn.CrossEntropyLoss() +optimizer = torch.optim.Adam(model.parameters(), lr=0.001) + +# Train the model +num_epochs = 10 +train_model123(model, data_loader, criterion, optimizer, num_epochs) + +# Predict recommendations for a specific node +node_index = 2 #target node +recommended_nodes = cr.predict(model, data.x.numpy(), node_index, top_k=5, threshold=0.5) +print(f"Recommended nodes for node {node_index}: {recommended_nodes}") + +# Visualize the graph (optional) +vg.draw_graph(data.x.numpy(), top_nodes, recommended_nodes, transparent_labeled=False) \ No newline at end of file diff --git a/src/SANDBOX/testcase.py b/src/SANDBOX/testcase.py new file mode 100644 index 0000000..a284c85 --- /dev/null +++ b/src/SANDBOX/testcase.py @@ -0,0 +1,52 @@ +import pandas as pd +import torch +import numpy as np +import core_rec as cr +import vish_graphs as vg +from engine.torch_nn import * +from torch_geometric.datasets import Planetoid +from torch_geometric.data import DataLoader + +def train_model123(model, data_loader, criterion, optimizer, num_epochs): + model.train() + for epoch in range(num_epochs): + for batch in data_loader: + optimizer.zero_grad() + outputs = model(batch) + loss = criterion(outputs, batch.y) + loss.backward() + optimizer.step() + print(f"Epoch {epoch + 1}/{num_epochs}, Loss: {loss.item()}") + + + +# Load the Cora dataset +dataset = Planetoid(root='/tmp/Cora', name='Cora') +data = dataset[0] + +# Create a DataLoader for the Cora dataset +data_loader = DataLoader([data], batch_size=1, shuffle=True) + +# Define model parameters +num_layers = 2 +d_model = 128 +num_heads = 8 +d_feedforward = 512 +input_dim = dataset.num_node_features + +# Initialize model, loss function, and optimizer +model = cr.GraphTransformer(num_layers, d_model, num_heads, d_feedforward, input_dim) +criterion = 
torch.nn.CrossEntropyLoss() +optimizer = torch.optim.Adam(model.parameters(), lr=0.001) + +# Train the model +num_epochs = 10 +train_model123(model, data_loader, criterion, optimizer, num_epochs) + +# Predict recommendations for a specific node +node_index = 2 #target node +recommended_nodes = cr.predict(model, data.x.numpy(), node_index, top_k=5, threshold=0.5) +print(f"Recommended nodes for node {node_index}: {recommended_nodes}") + +# Visualize the graph (optional) +vg.draw_graph(data.x.n, top_nodes, recommended_nodes, transparent_labeled=False) \ No newline at end of file diff --git a/src/USECASES/transfmodel.py b/src/USECASES/transfmodel.py index f4f5f1f..3b73514 100644 --- a/src/USECASES/transfmodel.py +++ b/src/USECASES/transfmodel.py @@ -1,12 +1,10 @@ +import pandas as pd +import torch import numpy as np + import core_rec as cs import vish_graphs as vg -import pandas as pd -import torch -import torch.nn as nn from engine.torch_nn import * -import torch.optim as optim -from torch.utils.data import Dataset, DataLoader # # Generate random graph and load adjacency matrix @@ -24,11 +22,12 @@ col=[1,2,3,4,5] node_labels = {i: label for i, label in enumerate(col)} - # Convert adjacency matrix to dataset +# Convert adjacency matrix to dataset graph_dataset = cs.GraphDataset(adj_matrix) -data_loader = DataLoader(graph_dataset, batch_size=5, shuffle=True) +data_loader = cs.DataLoader(graph_dataset, batch_size=5, shuffle=True) + - # Define model parameters +# Define model parameters num_layers = 2 d_model = 128 num_heads = 8 @@ -38,7 +37,7 @@ # Initialize model, loss function, and optimizer model = cs.GraphTransformer(num_layers, d_model, num_heads, d_feedforward, input_dim) criterion = MSELoss() -optimizer = optim.Adam(model.parameters(), lr=0.001) +optimizer = cs.optim.Adam(model.parameters(), lr=0.001) top_nodes = vg.find_top_nodes(adj_matrix, num_nodes=5) # Train the model @@ -52,8 +51,8 @@ print(f"Recommended nodes for node {node_index}: {recommended_nodes}") - # Draw the graph -# vg.draw_graph(adj_matrix, top_nodes, recommended_nodes,node_labels,transparent_labeled=False) +# Draw the graph +vg.draw_graph(adj_matrix, top_nodes, recommended_nodes,node_labels,transparent_labeled=False) # Draw the graph in 3D # vg.draw_graph_3d(adj_matrix, top_nodes, recommended_nodes,transparent_labeled=False)
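Finally, a short end-to-end check of the dataset utilities vendored earlier in this patch, run here against the upstream `torch.utils.data` equivalents that the copies mirror (fractional lengths for `random_split` need a reasonably recent PyTorch):

    import torch
    from torch.utils.data import DataLoader, TensorDataset, random_split

    ds = TensorDataset(torch.arange(10.0).unsqueeze(1))   # 10 one-feature samples
    gen = torch.Generator().manual_seed(42)
    train, val, test = random_split(ds, [0.6, 0.2, 0.2], generator=gen)
    print(len(train), len(val), len(test))                # 6 2 2

    for (batch,) in DataLoader(train, batch_size=3, shuffle=True):
        print(batch.shape)                                # torch.Size([3, 1]), twice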