From e522e8f97e670471a00779ac8506253c5c8b3ac2 Mon Sep 17 00:00:00 2001 From: Kevin Musgrave Date: Wed, 24 Jul 2024 21:15:55 +0200 Subject: [PATCH 01/16] Update setup.py --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 2e8a8dee..0818062c 100644 --- a/setup.py +++ b/setup.py @@ -40,7 +40,7 @@ python_requires=">=3.0", install_requires=[ "numpy < 2.0", - "scikit-learn", + "scikit-learn <= 1.4.1.post1", "tqdm", "torch >= 1.6.0", ], From 2ab5840b5f677c20c461fa8c6c0b58f27f2cf36e Mon Sep 17 00:00:00 2001 From: Kevin Musgrave Date: Wed, 24 Jul 2024 21:16:57 +0200 Subject: [PATCH 02/16] Update __init__.py --- src/pytorch_metric_learning/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/pytorch_metric_learning/__init__.py b/src/pytorch_metric_learning/__init__.py index e5e59e38..ff4a5f40 100644 --- a/src/pytorch_metric_learning/__init__.py +++ b/src/pytorch_metric_learning/__init__.py @@ -1 +1 @@ -__version__ = "2.6.0" +__version__ = "2.6.1.dev0" From 64d8d2eaaed54a5482f169817737d1be4c344f50 Mon Sep 17 00:00:00 2001 From: Kevin Musgrave Date: Thu, 25 Jul 2024 03:26:32 +0200 Subject: [PATCH 03/16] Update __init__.py --- src/pytorch_metric_learning/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/pytorch_metric_learning/__init__.py b/src/pytorch_metric_learning/__init__.py index ff4a5f40..fab833f3 100644 --- a/src/pytorch_metric_learning/__init__.py +++ b/src/pytorch_metric_learning/__init__.py @@ -1 +1 @@ -__version__ = "2.6.1.dev0" +__version__ = "2.6.1" From f21a8315d8bcfdff034e360c9633aa7535040a9a Mon Sep 17 00:00:00 2001 From: ir2718 Date: Wed, 16 Oct 2024 23:29:36 +0200 Subject: [PATCH 04/16] implement tcm loss --- docs/imgs/tcm_loss_equation.png | Bin 0 -> 39338 bytes docs/losses.md | 29 +++++++++ .../losses/__init__.py | 1 + .../losses/tcm_loss.py | 60 ++++++++++++++++++ tests/losses/test_tcm_loss.py | 46 ++++++++++++++ 5 files changed, 136 insertions(+) create mode 100644 docs/imgs/tcm_loss_equation.png create mode 100644 src/pytorch_metric_learning/losses/tcm_loss.py create mode 100644 tests/losses/test_tcm_loss.py diff --git a/docs/imgs/tcm_loss_equation.png b/docs/imgs/tcm_loss_equation.png new file mode 100644 index 0000000000000000000000000000000000000000..b8e62954021bd84dda90631bd981db333ffeb674 GIT binary patch literal 39338 zcmeFZ^;?x+^ewtED3J!~1`!15ZZ;_`C|y#5bhk>kh#(-SfGCa9DFV_hB@NPD(sk$l ze9w9AIltXM;O@sq1$MmeT5HZZ#+YLgp`oU58=C?ffk50=QbcMY5E$g}{~ws>@b_xh z_X_yuhO3N{HYPm$FfGF2YjQU^UAIS$R&JhVE|v&u2S8HM<+82GddnF9zISw9v%^X zZV_$)I$60#+B+e1xCjItLJ29Y?UlAM98W9{f&9To>tA{VO6P0)WaJz5lYW!aeE5y;n;h38r+ZE}d2g~?Y9u&YawST? z?Ayk3v#@v{|C0B5@$!6QvvAw~N!ro3ZsG0yF$)~mVBJ=x@peE1A_nc(A@tHo0Sv8djEcK zRMagCo|LJ8Dc|jStw72SwN*C}v6F>__TJtk5n=|V&)L~#BL#Y_?Cc5fB}1mj<03;W z1qFqU{g@@ci~WyRZ%-!_JvKTM*~4RFTQlXNKI4Wc@^tHkkFyBy5E7IVD8$i(7#F`2 ziMG1O%KCG<-ZvXJXhS6+K!V0=7L7{WhuGNISSw7H3Cq#Z@%dB@iETLryd+VC3ht4p z;e78F8ynka+$=qh=wy_X5TXa6Sz21cjWv&zT5D=+ z2QwD{-^(> z&CMA^JT`0&TpvGqLg|{=_3nXe@n}{-0nS&obe^Sq`h_Ezq0PlVPO*vZs+IW^40)uB zdg4B9iMia0f4-TX8QCxc+&eW0?lCx>X$Vi|`xi{h1VD{&mnLEWnYYb&r45gq9 zOM(*N6L>|1{Y0q!M8#YfZNSgf!LO8J6;Fzdg}nEesHmv)xn)WeLPA2kb{a7>b#y{t zISGO=@RCHBC|l%P4=uK*>vs+g1G=$M&rKKiOwSQ>bFe&Tg7 z{De6iv&!dy?Qml(2tLc~>~Ql>iRG=%t}aS32`(=DXV0D$lF^TqK77Y1oi$(z@26XC zOY%kiBPT46z)RszO9yoF0hq_sK%8-5Z2jAX%#O1owEpqnM+JQDPR8FH^9TlM@%}9-kO{hun@vmUCIt+Zz07S40lAb|-UZ zP17_L*B?@NY>qQaBfr4*mdyUdMA`>r|yKi^^CGjJ}#iMlrA%viZXd zVq$6QsrXYCzdtJ_OZbcI?eFhSy7nni$Eu1t%_5Q>n8eD4;VMyM6KK}%O#5yRr%)8b z(S_A7k2qLNPfy=14!pR9iV_}PUT)K$h=i*K*a?Ug=E&HMj#oEt?QB-g{QYsZC99$m zJv4_uMzwT*>1nS_MeJZUM5lD%<;h(!aq;#0pT*(T3yX^xzDHYxa*rDbw%&O`i0$SQHuhtKa(`__WJMaTcn(Z;nC4J?Tapm zpG8Ks8oIiy&K4H=kj8epdHYZc5Fa1^+?2`ohHKI(w>7AJD1S;4O=*{5P>8>MjRFZ2sr=i?+f-Gz_>_NV_=$@pyWVVicF zZaq+>j%DiX?0gOJS~=}=3o=@;OvcZy2MygAWJn}oRCxH$tT+jM>p%;O&#Gz|f)0~H zs3mVzOUr*;_(q7KL}!1wuWhb5Y@pokN!|Hg7i4CO{p3SwUQu((|GHz4JHdp5|D za8WrEAbPkpbt_SsS<~V+JEV6=pUzv8!smw*j$}ekpN8n;V-(fZ+kQn5p&`i0$#;&9 z-U%`IHo{RUEw>vr-&^QFyhuxL@!mAy>v(v0Z{NO6-k=q6V})?L3FRn)!R`<2=*!tz(|?;2ck~NMc~Vr? zUbly)iFp%V$G}B;dkbe~X2K@kg{0#FNynA6#F454f;Y6of~UZ+nlSLB%S7eSq{a=a zv9S@B`88`v=xqwVb|E8lU(7-G;o)Jgi-S=##L-NkBqdoy_rG66>Yq#+0uV3W8T%;x zwb37e#1WsK9Wr{oR1^n3g6T^1xufiqDRi>Tz%$)9N<8w99IJ}#846|u0e zAkyIr1r+XwDKgKw)T*ZmYE6(%{^OU|x3F@^TaIDF5ieAW*vIv8#^QIhMW7t$?JT~! zN6C&zYx>Qd^fp0yX$7b9r-e;+z`R?Rh}+7GuO9;_5a0aIJT}J5AN~#3kIl+r^f_3& znXQGDs2l=C2NU%AzopvM@iy*M;3sB629i(Qr|~B3=9m9 za3&&=bli`Adsr`0(G#HTKcpfkchqOXsA(JcNljtBdVG_rt*b9zlWXim8KDo(jjASj{YXC2CxyjqhzV&7rp`XIEA*)73EU z-o2YE`zDxPkYep86Y<(@(!^SBTRq(B$v1CeHBRr+iHZ|Dkfz5!@7SGh8)yg&gxWth zX8rNMcy(8NLoVW$^SxQdKFudQTu~%3F_Dblme^x!@>$7DOUuo(qiuzmAO6gd1!hoc zP}`R=xZ0on@D6IDzB$dw%wraGcdHzqoG77gPUJJX9ezqt$Pc-`c}(j*o6pPgS=B^o zJl554+{P&XgF{45iVR=I#o--pPFR;TKrQ~$5GV;T%m@K=cD~=oZPAYJ<>jSFjjgn@ z(2=k*T1=alm&fJ3XW_Toaz`u7!n%ru5$(s1AAc(BDQi47Bj3CsblV))S$oaY7@Qv7 zEHeOdY*Haf#E$bR23#>=HamhI9hcVhZBUzp(^8g zA9^OH;M>tR5bYfu(q?A#2s$pV*HBX{!-a3f`tf`_~{nw-&@4|%xi(}er8${3%3Af6+I#}ZpcR(5^kqhz22AfJ&_G=h|8 zVCH_Od=~%aKw+!Vzf@cdl(zv*|FolQbm|@`2{Bmr#Ku3`A zSYTf6y=zdUB8d0j+u7lYt!VG*X*tn*Dr%#&X8S4ydt{_-$gM5;2LUQMqrS0k7HU z`oW)C&tkcT?Zw5#-A2xyW$(ZOfq| z-i?fmj2^sU&~}i`*+)6HZNb_c_=gr4FMNi=D#ritqpfoC8;YZ zDV5IL87p>O?lqfh3L)b*zaiqe6%FqpTln}XY{AN8HLsYM*f#)U`FujSlmgwyse`?6 zehLAKG>LB4J+cn;@ex7gG%Nrp$>sZrU!Vw+XJlkRdBgBI+KM^bs>MTHwxOY6f32ru zaA@dX_z?$&7S=85huN`i-4@lX;dSXtHCn4sLU2#K5u9l&frL zqFfj#)&|~xMIo1fQQj1wIHXLd<;TbKam=U#I%I>yfPn!bTnEhie5|Bp;drXn^XX8IW-|9v^vS8I-6_u*>HOip=O_EemuK7DHvM!c z7>g%)&UJY!s@WbE%2h(@$oMqz~&&DbZb30P~`;pmrS)H$% z#FiTtKC$<6zF#Y>0!V8MfH~00&b$gxce<~em0U`uJ z`J4lI(hMBJd8MB|N1Kg_g{4CGzWo^R2%H%CAb_cKl9IFySC^uwTRvI-`1E*Z4h}a< z+t<&ZVcE@g|F&S_;2_o1aNR_xFOTQ(3iQjdA(j4uI>7C=q5)*Jy!T|<|H!DSrpA0Y zPsi)>WcmBkHe6J2Da2AWqCOiIz5Bf&9?+`L$NBcz+po%!OYz6(rJ>Y4c6ToaR+)w{ z-5Z9Qg1XGy<}8n*oSgQBqdj{>-BNElJCLz;j*sc-_AXmc zyam@U_zSc@A9vjRgm0#&N~2O}(fI1-wHGT>wUx&Jlqtk3a&i<|6P3BGy2*eVx_2#~ z1CNMqa7`C+x83n9<)GxTP)xU@LCmeJu)5nIk!}ftU)67AYw0cX`u2a4{`))op;njz zQ!suL)QjWVlODmDvUcIfsHher&nad^xL&ChR#0RA?05JvgNTSekY-K4tBWHJLUjWj zBO?wVZ=8O9en81WoUh=`VkDnRN=nXOuB_OgNdR?o)8cd z4;6h`pqtnSl0sp^PK@ogKC+p#UgNdPki=;e3AmD=sHmv*qTJ-8h&%RsLC4lqemi_( z;$XNM3rouk0D%W1xf-mq%gY#_tm?z%qDavYC>j8?!F+SP+zTohER?RJw~2|#(HgbtO%JH*65QM591!W~eA2993z1VmTasc?~TedN!dQdGh&HzXt^P#FPE z2BnxWOM)3BBUJK1osImq!6Zv->!-l;QCFxs5>lAX8vvVtss>hlA7Va#X7YmzR8&@$ z0kXATxDVwe3su#V`HFpTsQV!Btfy4V@OKV7KRsRM}j}5p^6GShHOk`KYX_X zOd1VA0j0vOOQ!9w>4p4H-AcD8 zC$b>zHFb66t6}^c7tX@JMo&+#TjO@`_wV0aKzQZXH~xIjQes?8;xO15w<|V-MFpA7 zp~S2PPS8sKT1X071e`>-O%6El}qRF)%e0glYCL;f=Zc7-}$J+`JdOy6+Ph5Y7`hv)}+F7lmHG1$VZcDT3=mV z%v@zuRq4J854W96d&nnfU9D$n$xKK{h(Of)95ew4nOWJDy!e-}>-Fj6C}y(*qU*cI zrar)ss)vmVMvp7GW9hTP0 zW2*eC!Zq73iawj=qxT~>YsMe#X-w$U&dbXVvJHP3U5_dD{kQTGB_io9KVeI39o-;N z*-R-!6(Ijf<<~S^OH<#xc?Ft5>V{emCMtILlnIqB^~{y(u^A+ z^@QSQP6s=$j})>@+`7a03$n5U-0x^io$(J)g{%?_ArovjTsce%e7(C~=zN39?&_gd zUj3K9pbWkjyM*V5zI%74HI_OGG>U}WAO|^6XQGQ?|0fD+x^;AQ!-0l zjwg-u>Y?(NxbNZY&zIM8ii#X|_X*1R|BHw$->vGe(=BtAAZr2?e~3>|ZvF0oiNFts zUd2|EzzdS2Y2OE1Kh7KyE}*agM|$r(0YvWVxJ|OV=xAfC6eMpxhkR|eXj;kXjXhD% zExuLnnKv=fsXK}{LXZqYr6Fq4R{7~=0f04*1FyEP07BlnO|!kc90po4N=D7`%oBPz zOe3y$=2*ZC3pGEWK-&U5iVN^74wVqTu&^+Pk6Mi!I)@8ysiw_q5qyTsP}Ou162On{ zhi$E`4D9Ur2iG$_`t(nfWgQ$glDW)evW~QHadD3~O8a*9`y}TjuTIo!$q%eh0>{?h1HU4(rb^-x}rm%Qx3tw!Lw)E4Q;3ZJZu1dMkeZU zUp|Hc1f8L2UG`^z5JEeh?8QT4fv!zC12!rniqAViuAmF%axf5ySp-mMT5>TPgr~1P zO{jb&^!c<>_n^n+^xrcmjOF7wHK%T?zi&cV+q(4 zcfo4}V@0szc&L=q@thF093`R#T%K4LNh=iQWmA$nZZYW@8b-%6DwyxgHlmOsv<0pM z?5?V=*1MRzE$p^}kJ5A{ug*({d%bz@+_@umJR9T$*i1hi$g)woUiynvequXb&(y(n zTTk1Y&qb}vmbM}Qao-FFgDFGaZRNT=ZfjqY{n8flmyB`4sQP7IU z?Ma~ot=dLZH!P2lo137zx?0HRUyid0D4Ww>^S1$l%Qdy@^GDL0@0v$P?*P;lyM%y~1^H+5ynJ;3*=n#q)Kd3DfU!8R+mzaDI#D0qr+(^F zurEUJ+@U!xHMwnKsv4?PLH=_LoP*4|?*>dpfOLBY|gLap5XViU=zUdLj6cI#@t6PNm| zp!Ke&InJr=H-Z-wp0->T>Yh%L-H+LoS9HTMZ@L)!^9T38Cgj@6biL@~V&fQ)Kjz_J zBM>Pnp#a5{smROkLYTFVmeVrmMZJxQ&d$#MQ}0{p>ykoq=kr zm5>2*#8Fr8BVBzNP-^|rpA=df&45NJ#lB=_UYiNL5=V7p8ZI^+QK3uWw_}z&zRoam zNoj|5AinUX!$O0Vn8U}|Hkk8H$97$WV1hvAA38d^uzVAtrJg969DOYI3TsdYlg*k# z=>-HxQL>=33FPASbmhbx=5n_mzjGhm^#c+KVfj3(G=P|d1Vw65pyF|%enz=a85I@PahfM};f=%JJHSiawf~CifhhTQ-8st}&#hmP)p6#F?`beG zKZbd0zT9~}c3H+9<4xq;ozm*qPCIANg)Ja>-BV#r-=^Oi~9h?{H@D*Na4 zZak{~U1FkM`PM06E1F>5A{+buG_0d0HOa%)sL~<5!^5|e+<@2Gdp<@GpPc;JWKxF( z6-i)}MM>|9ar(JP9i!*YFHHZdXw(Fz(lN`AawF$N;?@=|Dg7mVP zG%i~{iYUTXRFU``uaPyN>vjos`nSay$j@atVq=x-5$kaGyYQvy+E8wc1liONe^Ds= zg)Up;F{$OW(dEGyG+u~WkIkk6y;8ukNWat7uPD-B@Gs4xQzJRH%2b{8EUqz)MASL2 z35!PhyCQpk2A6m&XnGdViCE{DCw>D<$w?D%Y!<p33HxJRcU3N1i68;W4(E)WBe$T^ES!>F&Xq|$ z^pq^ZtyEJai{;D^evX{WCqjq#>f1v$zM#h1`q^)OHtG!4EFPT)sKDibN5+L{4#TTl zIUf=LnU(~lp+jpR79_EAAXCQ+3FUCIgy30@{IlIx3ci#FO;FvV&GVR^OIHK z{2gO2>A&IaZ*fmnvBGgiej2UfDCk+cd9323bQAPC*xm?jqp$CGW5Muu;1f%~h@afK zfhUm5yq|a~w)68ks{I8)^%4-~EPQ?w0H-;)WS2=eLN6iAzqHIhTQ5lGJv}{5HSz!q z!gAXl6vpSf*|^V@?3VAt|9=5RMzQRu0JR}L>ps(8*>_}~ybk?ChbN?y)-;mcXQmP; zm?KN{x}&vK;AITB7%Wht0a$E)0sw%_jaOjO@I%L8tzD(&Ig=q}BH;%OOa)aYLM>j9 zuEOfl?prXm(erfmObyNd2QkGQS-u>NbDB!oJGhLWC$`9tL8Z zC4TUzG=B1MiCxv%$>|032cXw04&eG{S6A1<)&Aa|KL3Y`rKLrFl==uxjnR!gfW58I zR*9KlOYF>C23_ZRQ5ubXB405 z^+|b+RRozOF1Fd*f#nw#TmO;o3ah@^(~+lTd}foVsyZ<>HI*4M zCwaEPo3IJ}f??sM$!gc*Wy!0EIEhnC6uGK(C=r*zLwNKo125~ zpU@@vOUTxt`11heD*c-~#aMGVLpH{Hy=<%_JR(+lt@lxmo{MJ4iyg9GNQ{Z8>k|gn zQVwMlFO~F@2WiJO9;2JyzLw5wrZ$+LWgzr>fBut$)mx<7I>^}s`1mgX^bvVZI%5Do zPxUuqV&^7i5az?5v%_s36TJ2;#rYsM^@-gz0t)i5rXi;rp=!bMchLLJv;XL#3*EFe zm7Jot8~X7s50oJQJMX8YD|2!%L48K4Cs4h**Sd2wSigDiTm4%f#k=#s_y%IW!OaeI z9Y>d$=Sp9-K3RFkfQU{}$w7Y8d-XRnG|+wpIiQga3d8m-XLbWPl7rWpVD#A&b zGUxEPU{LITnA6Q$q-1yQm{d8L)u!&xjI8oBXJ8pdHunZ7cB?5=Srf$}6Dg3^Uu2h8 zgId^&HeGdO=SP0W>#J{CWFDSI^dx@#XmZ!A%mTX&EFM!XO}c=p3k@nMD#gXcp=tG6 zq^;1no&vC3Gql1jJf(kCGv~iqU6p^#Pr!JU`4pdsJp53QF^}=B?QK5&LOna;^}W%j*59vvNc9of z7>Xe*^A!^g%@E%GU)5>BSb*SKU~DuA5;nN45n@Rv+R+fv z_fVk<<1XXir=aET>>NFJHQe%s@d7p4iiIWLTsQ$%{5+RHLILV>3*Z9*MY)jp^JTL@ z@1g-q3Q9UhIE{jO=|Xh_xJ}A)1CjlnrJnby6&xMpqlE@xAi0oNepN|%rIx-X(PUC) zL!g!>WTt6&_)7TU#Bz={Bi}vB4CfQYm?&wL-#%2+Rb;$^EcD}oWnaE)&Rcno-YgZa z1%Hf`V~@PNg0;&zw$^`L1_B#R6M;T z?#q`CovUO*!rgc&$YQ=~GqP}Eb`Xtz!g-OQCnS2ZbA#7AKH(N_LWm38$Rzv zb=G>YvY6MozkuG2X*L-ErE^W(;b3hT#Q_z#$E!i#Na5qUCS_)pF}7%TRWvl}JEjG4 zG=GXqB9^b+fb9G9nC?kKVq5qAW^!dg!K&K)x$WZ6-+%3bXijpS!y}ZMPMi7hOc$^n zZz#0Uvls@Eek2(FBxmJ9$g_Xogs3_ zVw<+DiR(Ux8xi0enFG#r191**kDk==-RyV9JmfEJxX3gN4WAzt+$-Ov9y6=*hRUa} z6#xsxGI6jN4Niq);Mw>@L{DL}UkFDbaZvZeC@4r9A|gJ!f@AIJ>E(0UY%(88dKT<) zh;%X(d6ji5{sVDPauxrd2KRd(C`*aCr|H9EBc!XYIQRtCO5)SfVhuTFmsaC`YFzn< zno+o&wD)$)dj!zCp)0UXP0r?GAwSi;bfs`>?@>|kaG);}xjwsru3nCQ`O>y)k4r&4 zm?fKG==YIj7W9h=#<0X)=3fy|y}RA}+uk26U&#XY6v6maSC?nPKAkwHP0v5GAQ^P$ zu{^VRsVsTifL6voAWyqUup6QN9!{z>$NfhSja7ya|7*p@NavU&8HF&=#F0NwwxPp> zLHYG(UO~Yx;MfIP#lWgoS65p{sYTp=LTIj0xh}L5EG#YkHMu%{C+8f$Iacjzo$jCX z&wRSnI&RlpS68=m9$5J6-2-^Wlk>#~fzLeBhsdsD4>D#rKgzsH)?l)*HFMIIS;bk| z=Z(iNYNS?H5B_S`Q=lF;xuAGldvA_^k-9+A5WPB|mLX|bD!)t4#e$uWYmJv+Yc+QE zv&>3{tWUQ}^lEW13uooXMGEei^Gl~EHH=&=)FS#P=)>TyG~_?O*dB^u~XfAJJn|T3KS>E+9;y6o`*{3Kuf;+2$@F)~J5>OcyV(qkjO-1=oS`vC%7G#M#lhG(3w}lhmW8p!10Qje zg<$dr$RE(m?`HGe{VSK?HwYzn!E>Vq%r)Tom`kz#z%9z@+nAPJzWSPc8P9Y1atA9` zx@H8SkWJXzi@s8lK(IMOt9)e2bVuJ+zt2agmXS+@jdc^96cmV+hc;p#qpiTe55``RV}lazWzPPhcU3saBqmH({8YGZ4M_pi>6v>Ao=eg zZSPiw;_yh3PS{4UX%+j_ENyzEpZIKnTd;&6TKprwev{RR9XIN64r!#tof2d?(TxRZ z2GuonwT^ECA03PK=JY`cD7^lR1qqLN~sD6-k}nqJo7bff6R#m%j3 zFQrq?xUSggTCkl|W_#{YbA$H6uT$*1)nVbCb_GR*@d;!sek3kW>(ufP1c`6%Di>0@ z&nwI6zm^d?X^ncdmntr|OEYS&75+NfgBU_aovaoyEI50VN8FvGe9x@Bc1Fk9LlhB!jW)^^Kb%TCw8+fB&aHA>Gu{4jp(ucXA^X!VRgRbf zvZ^ygLA$qnT{%~G%d-&oDAD@&x$$B9o^LJA1U}uC=y@%`~y?K;Fb|?eE_HDFL%=o7vUixPl z8~MdQZ0j=oz*z>$g)32Oal_^K=W`%j9Beb7e?0}Wl7B_-6|Inyl-o+bG>UhV4nM58 z2fkMXJ}4LF&@&`W1jE!7J0ctkSOj?fbeTT@$(%kudzYMS1LsiHj(KBI&#Z|AUqddU zE|il%?jFOhEEf8MaC;$24%OXS=gs)q$WM!^Od}lGbdSB-jpUGpQA*y*FSrH1zF;L- z++-pS*uDK@c~5^^jwF$>BvUNdaPBT4C7Z?@*1F;b&uuSFiG zper5;P2q>sZ+h5zB1i%vDbs0<{AH2a=8^4etOOiMd2jFE`LSn*XCCCRJIKd~6q4pV zPkQKLSsJPOUVwc2u6!KLR6)&5`nwoYs+nY?FIa_WkmyYsWI@R*rd}0)_jIb2lvv&H6!K16CDPTx*4dZd76xKq zb3W0dC&$O1phcDgPz7bn)zmaI+-uYt7#IMT$FHra+PQ7V>BvTu#r$cUcqm91 zbgZo4D8Q<8bEs)d70yy&DYfyAMR}2+qXfz^Vr?lJ_5oJ32?RKFPfyRNza@9N_4Myq z4}4|hbdY5GBxPC^c-@Y3iILuq+vCGEgu>+BUVUZzr!ILH6`a3sD*qBAq;Pel_+%*k zs&v8^f(==(&(AtW>)@qtog(u|{?r^>R)t#7NvciFiX?!li|9n8$=H5FN4I3n) z&+Z&JX%&=zqKkPQP;4$k{_C2=gz$DTsty}_yO9F!!%%EIymum06FRBPP1^jA*oC}< zBg!&I(r7ca?;q6<*!6TgSS0;I&xYR>bdJ=1EbCv^DW~=}OhZT*>n`t)qNCn7@4f4q z81wWPSsxfy(SfP`24VmJoF-^;(c{6%BJO!g>OY~Mud8jGH9*!SA|wa0!tPD^;lG4`YLI*azw_bU|X!1Am|gNrP%D9ZkWn zjKGJf5QYbRKJ|3)Ms{gyPja;@##X(G1Fs4$EH#VgeiP4m+FOJ%GJumH1X2cJHr*j9Ty<6p4E5Cg*^aj7<}I~=#mEGa3`TmC`;Iyd*F7`X852mD>Z z)bhH^agz2nt;FQt{RhT%e{=(&@!xs@{yGxCFXk{eV1H=cla_tplC&3%Kj&ueX42e6 z_}1{fnhvL7ki(+)P~p7?o%p--)8+CHM6lvM>Bc|q|D{vJYV2ib(%Q)SF!-LW|3__t z%-t@XYQnh8URzZT1`1**fX-ho3tm9M*tZnWuY(3YgER6#)O865 z3`hfo27bVwcyX-b@36vQ-q?&_JPFT1)&ao{B0wJB*#~C|t(Mqgni& zekkvxxY5f({w(q?fzT#NDn(Vy(Au%>M(T{*l$i5||6%SMs^J(Q&dHP)~**NpN2qDEu5pSGHUKh*Hgj~Hin_m&UhzB?!p9O?;50!7EA3YQK- zM~y_Va&WXj&)RusR@x98;;7y#${J+-!5bZz&NqE3(63r&uD*Wo!RATnwsoI)15MIQ zc;V}=MJFdGDQMQBnhQ|E(ctID+daYf(Af^n(NntyuH)9?$VhDMCr{qf220~zl*dzl z9l0Sp4>@#4%4W!crQULi(KX}N@b%tas{=QkJlXiyRsJWz&c#`7ud2Qm~r^HN7`Vp`Fdy&PP zHZ?W3x{4PVV-Up1zNVc(8S%3{M(F0vYg#4oG#gf{6E-xS{F93hH4V`aVc&(g@Eulj z!^`N*T!wzpKYH5gtjy0xL0z`Qbbsl?#2EF59wlwt6zJW$xVoYZO}g^mHSq>#54XpL z?tT3-Oqkyx18pZXLpCfEtTNCjYC>(s#vg7JQ-@-Y!J?WP3~D(E%Ky~+o<1Rsn}ww{ z^K5DSgaHA)n~A^sSL-HMO;Al?w1==emH!sx$3}RQ?~XGP@B~Vbx{ef%;R_I`%WF{9d}^*>F93m7NAyt2ZdcFU1Ydh3!V5&L-g_zlbPLj#x3u5JPLI|AyJA?qofR;>yOW$J1P9x z<6}av;n#c5>3t2or#sk~Ah;M?i)%pc73Ji)O_gtbP;)CaPkWlq&cTtwi;uK?6#QaS zsS|tmOt+?_mQ?PDrEbOD&c4Uxc#W1k_5B@`S#M{5|2m>OUGxsn)h&)u9WAX_V7tnM zc^J%Fw@Rzf&LGX9S{(iLzA!O^`~(HOlDG}kX+m}HKsrD*g;Bl5 zsm^;2)2%Wfj*R{WT|(or5o>hN@~3y-X%Pznxe#4!y&47h>h@9y6MSO-j|GtkrvTzR7I%&p}jE3L_2iiQ$ zmX&vet->m_fXO2c8Y+|C%js50zv+1$SeELl3Bu+g!>*mrt}fvytYolvw~fTqesaAf zV?OQ@EHIqa|2j(8SVEEYZS%aOaJ%=|t@QU7MfbB6Y(G(cDR{uU72iRXv2&>KsTD6E zfFfsji_~rBx#2^@`9k-;?9ufB&_%(m@)Qs$IL)ECK`b+SygLu=Z$ZdZ=-@Tq+j9ZO zC;S1ktk}PP*_EMOB3Ddy=s+|Bo4|#JJ=nN<{_agw*rSA2EycK9~S@ zuEV~4(#quuIQN=_U9mr1hQf5xeRoLu$mr&YHqLDk?e|(dki4Wh_&OcmG z=Htgm&uKp*rFA^n2RGY3Cp{S~Aj7P-er zxm}g>@YQMTvh9xtE+$O2r3BM9<5sEj!4@m8H)e3(lN)$1FcA!J{B_G6aRP>mva!zd zVf!9##KVvVm~nJK;?g?@PmNU)>mBdiAAlc<&?0WD&^}>vG4~Hin3P5G^Yg2DCta7T~7V6xksZq*p$8@!KQ^u*5VfMS;gnVwE+Wu^A)i68R7*eHjRn3SW-d5dyuYjl} zw&ufo4<6i6Q{#^s_4W0ANUc_@kxIa#t*T=Ed8eY6cydurkAb`U3G0I37b2~j$cVRf zN%bCyfXTQ+DjdxR2k#~_VLQ&3$p)BvC(I424CwEFAZB3dsu>l8xeNaIz{cYFxC@1NdiKG2U;cA`ko@Aj}M0)sgutf0Uc$7IWvzi;fHqldLjzbHobqD z#AK49SgdZYuHR#mJjf;fdbR@*!&&GVQ=>IF20 zYz~tZ3bT6Q;^(4&-Ahm2II@t~xc}wTyD$t}p zet0otF11A+Q}y$4I^zLL?~Zz>oPIJ!8ii2i-yyT-Mh(OdBn1g1F48lgM||(e%+GKFtBzE{aG0GX#auY=B(`bQBUz2yjY#H@9=@$<^`@OYlmOX?ZQx^OXyIgM>qJ=9XLhM+ z6cex0BhJ?;KV0_nxj*YsmpeT$(7b#573IfK81q`XXaBL$^e0>fYT^hG7BoX?dV95| zgP5EqpZ#wyfZ%iR)3S1MR@hLCo>xf)7farhQ)K&{#eWdHHMq#aD11e%krwgqyynUz zUx11Ex!g@+OttV}>!)EAd`C_eR&CAd1O0EmTe40!5U`W>HNNEzn*Fy zx2~LwM2!YBTv_&mT2yy%|LMo&sW*NW%f__!+t}q_BFd=|3)V{%u2fHAYC0Av%2spv zd211`?XJD=bslHSH4t7e1#*h=l_M}=kPJ15I8bV>j>=!C0WRl0@gq#&4S>y9;k1Wn zmVxOAD42}OkO*E;0 z8)J_QtDL0QoW|3CXNtuqzF{@bYC^s!&UdO=-PgC2kJIl`OTMd>dCk^{eL-f9nt`x1 zGWbA*ew4;?g^j@_l{22?3)c_p(eN*AQwzvRvX*V)w`MJW?B0G3!D{$c_C>+;jq0~e zv5*3otNL|(-x}OBT495qfcr<$i2H;?hK~d?mhvQrPxXfIaC9fG-OfIbI+9%SiB3#k z&6_(qy!zApGUJ-j<%xG9qJLo$%of_1Yj>CTxo_*n50(&pQ$*S4uDz&QFRDdaT4tjr ziAk33Yq8y7_SR#t_pS@azD4+h?dyAqN)m?q`6tV)icuueH05iLy`o+p;|!P+xsbHX zBrCPtQ~b>D{XvW%yM{ZGmnkdvnjFLQnSAeJW!;4pslZr!70lOjL;XNa1EGd220&|< zcT@xYunMg3Az#0K-8noY1rr(Jd}|C5jM>MZ7UGX(Lnayc{#0RpMi=N}3&6zBv5_&+ zJus`tzOa`n4+hoNvFA-Fc@j0KhKa}o4v|?m*)1CLACXKIlGBh68r&iN<$LR{RJ)ED zL)Jh=^L>KM7YIMOOx0<|w^46i-H(4m#@uwxFGQK%@l(ct0TZd-QMZ=bgF6HSW*aZu zK&QrHpevC?ISF5PSX_IdTV@kKG*{-I&74^l@$iYaQd0Oq?Z7O0!`q}3iK03wi(lwi zcyfhdH8j!;0)?MWZhHtaZ>TQKklx)8R8905++Ae8nc41fiqK>)=@AP_WtWhcKIZjb zFRc8G`&Am+a44|LC*blvUyN};dgVe&3H zKVz_F$rmi(=#m`AeNmhP=fP~}W*DcgSo4`Hs-CMqtz_$xbk%&>sk7F^O zLyjuB?GGae(aI?}7qLwUl9GpAVxmU->L#~a_izKc-uG536NVJ;dZ(HUKKe-dYb2Ge zd}}vb?&jtf*(_B7o{?q1I>kp{m^e9mUmethw;MtLr*+JX12DSpjC^iUr*l{R12KyA z9_EMzUe*oK4}&=0`6v8_6in4OpEs3{dcR88!Hl@`OjzzI^|oz;2sW=&MkyFSPu!2QWk+ABHc%XH(VYR|J_@J zvF5?h#PDd$xKxZ>7d(QAq5e0&`3Rgg_w{Msj+Ef(bwwc`nFk@899vjM3)mCeqq64?ZYO(eu}sJew2kBUz4p`gTK0pr(CDv-B%4)fU|h5q4Re? zcr{s#PLHe~m_rk%b*2_DYl=~YSmDU2noA;PxGI6(&4Xp8Mb!gs0zvuEiQm!eF>DVA zyvZZJtUI7xb1~TL9}|NCj%zP=_V(VuuirpzQR3U>aNX_HR*7o1%#MS=uF|I9_42;w zCQ-%0%na7v!8@+?@j>lI7r(Box3`)n3f3=PHi!zmY233RC3})QKfBr?W)Q%>#@l!5C&s$ku{M38Dae}|F@m3LV_5tvPjTX!aGmJgD-#XE4QV1XBYl{y+mAC+AP+M55{+@@Af^Z{DepuKEX8S zR&p@LtBgN3Fksi@<3o+HDzFmE7#q`|rf1UJNA$G9_+*|wWyC?hZtOJ|l3@=v-B#vz zhCDAHM>}~dXHb5xr}WkL%naB^FA%mSANs;`oD~sqlQls)DS~o=NOapGx5vqsy6r=h zE-_L4g5hEPwm%+t)aY@kMS_8Nz(94Uo($%sOBct=`g${X$BwIoyZ**w9^y=E5;r+K zjt!Lh6Zew_!A6-T;GUF4`8l*>OP-EG_My?uNH^?%TVpR(*B{H|3uJW@$=omWXu`Zv zixKd+(c<@4F$QAH49{8Xn6c>HILzC2CK&lsLL1+Ot~t<}pWq6qJiGa@6o3Ei*c{O5 zZ~BGDD_L<);DkW+bpFfWX3WAa3z#%MEA)nL19CQwjn~e5V9wqBtdVF zotd79m=djgx^HCn-6e%WzV^->mcK)v#s?n!WWq5uiiQ4(dZ#=0(6^2lC?c#*(gwF- zG^qtn!8n-vpI;Mk8~Wau>N*8B2=kBa2q4odB#n$_f0aV@*q5^K08_kM7*2=S%+7q4Mo>H#Gz14blIHtFMlV z>Wlszx{;J_fgz;3rMtUZS{gy5J4HaGVd#+Vk}g3S0YL=>BozcD^_}bY_gm||H~+bo z%-nnKx%=$BKe566_uC7lbPF>+LdE!KWMU(H@8&%>VbsS3?EDR^LxJa{aM7xoynQj? zVxf56wuOm~II6C%*SgLXKOH9K=UZ}A^9$FW;M=(IKFU=uC%R(Wdzkw?1T-`6Nn%Lf zZr&@yXPp2Il@UEHNoKB&FA+MSk3m~J-KD*xkiK-O0m7*uT;9}Ma*OrkncyKU(>0RyYR8E|&O`4IAQ|_PTqo`gDSRm4h;G%JKbImB=x8E6klK0@!E?rk>f(BEiCLw z@HUlar};d3)SrTWi1rI~_=()Xn>ZZ+)H(&Fqt(5vtesQ+kzWeN|l!?+BV0u*)qgBOm{>BJ|tcw%x=m$#4P z6>pxf5z3A4MJ+jvj>gfJ;XQiFxv*CFTg3a@L(R)UofhGxqMlL`hx_i6C0he9^v?L-^I4Bv7ads|sF~6yet_cqQL%PWR~IF4 zimF~@Alw1IJ_}k_yNXp@;oAK;q4Y^QZ&l)BHv>Le^SR3?n31v2798(bWxK z+lAXcx;F{f>(aXq2&je969#M*tJRdFkjaCLmdswb*iwFTZ@|}NxUYpfg+`#%Fo|4) z0sk3D4>s9GX`OfZHbCYn=DJ<~)fox`tU!tV;O^K0aOa5>1_db=NRt&uUNWnp!cIrH zLv)S+v9o4R-1}?XxBY!WKvCHO&%6R+v<<({7gw;b$Y1I3D!KgC{`R3dOS&#0t^ZWO zqQl37EuZ;*_$Tq~oHjeQ8dbS-@yThYITw-)lqe8`&EE)0mz^k3=*NS+)Cxubl@MZi zqUCA=A2Q;QHU8ejs@f4;p`js)ie$bX3dka1!hI19Es6p=gQ1A@MIAk!Sq@`{g!%pD zb<2lm4k+JM6!7+w*AAQs1r@{?P@-hhxp66@U$FTiI?5I1Y3G^;njGGt!gQ>xxc%UI zufzkfHHN>3{1#8k%3_7ZfpUuLaPJ26NQXCvVik;xvQ_{q%E{fG6!6e`egyRn?s&%o zNIRkVy_oB;SfSkpkWjPw{wYXF$@u%-PvAow!shVZo>6;hsLZ?mZ&+B47S3EBFvqbg z1oTvuH2(A}ry=X`M}EGy0NF~#0}f~$sB{wnX-Mn9GtL{`=RTHSYuwtIY_Xn3L@;Z+_aRV+>u78ZzNCu{pMPU$xn54sN$V!ZPra(=bEpI8O9~X#5GzH@0*--N^B_sjy{Y~6G|7Hd@#vCjFr*)RT?d|E2AWT+dj*C(l z^`#c`RVGg!t<^lvI5rCm49qe5$UylH@6Kv%5|?o&slx&fq0Gx8Y2YO{$;}3SH42%a z>GmK)@4kbsmR+N^dZWn7f#eDyv6a-Z`CEGbe@)PREoge?v~ z5g${dXNgWOZUla+Z^l?}dDZI(d#`=@cHIlWGfG zt_Gb15y>t1g^PH^A9Mg4z7J(b^V!Hf7rYu7ogScM6%Rl~C zIr`iR-(21swH9IFgqSX|xfO}i;KyyR2^B^M3h;)Wn5U*W1p)Q9edXBCFUJ5P$)YE^ zPq6?=aVLnsSACK#)(#^37|cni#ZB?#{tji_pSgJFp4fANcwLRqGs|_J=if!AcXIx? z^64i`EMa_j_X7k4KqMd$CTnSSA z{{C6owimN*Oj+Q$rpRU#;o&x5Cu<;_oFtBqkAGYRgEO8PctKrVT{$09rOG!v{*i(p z_@z%=9(#v*gjCctko`XD%$wi5P8L2Jf6a$+yU-+LhdS`3-=MW?^GwV4&47q3K2=_K z5Cgw{|DKSRjm20t12^|QNGT4nf;-H)dM>vTbz|ipU=_~9PUi}5gSO!qFwl4ix~B3b zV-gOISJO{IAU7_x1D|fRSi#S}w0SQpAijge5ViOyX%BKc`T*+QQg@Kmk1?f6FzU$5`^@uAR=v7cXt`S-MpsMl|83f8)t969KMs0 zhb-bp$l|g{!}V0O-FK>cn{@_``q4}9=WS>OOBJOQI^r%5_r8qMY`mb#cXFwYV#_t? zeof|hEOta~o^R@`=wOGE=bSGP(;u^Be28D7WE9a=Zdz;dDr>~rX9l1oc&(<%TF(#k zbYo=&j-{@@F@a!#`8dxjX0ez+2j2{c-vIY97Ch-qprX$d%VT?&ED=3F_u)fqN(%n@ z`MEO9Wa&ebCRA02pb44M$vDyg1s!e6#GFY%V_wbbmvJ49K|AQE4hm13Mp9dy%jG1e z@$-B(S&m{Tm%AJVR$BIHL{y6T1Pwh-sJrT0mx`;F(ceg1G{!ezxPDm24Lx6gUpf<0 z3g<<76}&WY1!`OopduBZOa_;y5+RB+cO+nS0Mo@nk7R1gp5Ap%jdbF=nKbG{2K*NQ zXWauA?oO5do4Un!U-w^P!&4qRffdk5=Zbt$`8fpteT!L}- zk4slZo;RP7{+jZ0^3nz4>%DU1i>D_);O03{rTW4fWD`&u-8X^}yAaQp>%3@d$4D7f zKZ#Zmz$#|*ZDy=?tyW%y81!vX)n0L?iQzI%0K>VQeD9D4cEDZ;F zxVYVN$bJVZKw?k0&DwB*+Fc4y9%*!nxXvIP}eXQH1Z# zZa2JJu?%yug2g8uidxu90DnnM(H#+_ZR1)bWh>6Pb`izp4th1W6`c&nf=ZkY&WS$|n8w=FK-WSTvG? z5LK-q&J>Ge%Vbpnb9m7Ldf^=2%&8s?vn?hXGA4Q?rUD{vWLsGVtJM_m%_rQ%Y3z5E zCDie8PUVR=oZ9sp;hz?}!EuT}OJh%gH<;0~a=+D7Rli&WRqP*e^67fS#rArNn*TUj_gN?FJf?y|3o&+hCSkp+- zH1%(RQF)jw8(p&=!rVP-h$<}dPCYKpiHK;9J~*7m)+GPA0)E*VX5z!2{E2gmU(a)r zz)=rxH-C2Asxol{YL4Wb97;}e5wa||icVGMmLZy!AyPcXFTaPr-(#dsX1ISS9}vLW zHFngo6fZAOu*4k)mi_+ALz~7!n=~TgkqtVw@suBrK5=fpNzzjGCwR0KwS+>amk20gEU+sQVvjA(~9Q) zIk~w}L%Es|%Xj%mV`a#qI>&!XFHYBn0dl8qVtSmI1mH$pU2gd>y|VI`abrkVB8TD3 z$9XVmn_o9WbbrDCrPSzvCkCqJAYc+4fIfBOnCDsw0uq7uP>~^@8$f`1xx@c1^(G5^ zCveFq^bG`}TyXL0nav7{bz>q0sJ5H*$l*4w;eP%l0bg{oWl?%3z(!ALn!7KX9Av|5 zFhCwxbkL#DOPod!F?UJ6mnr8iFi`p4I*pa~5f2HCPtc>zWz=TN?9QLbTcA5R48rwe zUq3&G3yJ-g?{;PnFGR&%+8Y`OK;}%I=X?c-7M1n&V-MTc!G%%j?6NB|>Y)3)ngOf} z@ZnGVzb?>~FGrmPm^^Ygn=OB^!(Aydw~+oNyOM2RDQ~udbr*%7z;sj-v9LdyB8kGR z5RHH3NUYAnbnYUU5ywQENKb;xNZ_VWu^#r4{oc9*%X!uTqi1f?N`qSb{4Tjf`B(m^ z{GhyMuiq;Sm8x?j7yWr@Vs@woicBXi$230o44jE=X?<8U6)9#H79yc8!zG|D7q&bH z@6}U4_+d_9ZQ7hu~_1_ zf>eF0E8?aIkHz-HdO(CVj3AUBWIzo#ky;MnOdx;ZpTJfmGB0(eXU(N=+iF{ z?PM|)WCkH*Q$l_@X{iqhD&QLYQCQv(sdiF(?<~&9WqW}>pgB0~ zoTOt{X#x=+dE4>XS%L8ubvNRu)8`Y9v87@v3RLfMwDM5{XlGe@4 z4Sud1VU|agXACYrGPwu0Y<2Ix|B`dAk67&vsT{9t@78wAUb)nq>;uTk0h%ASnrY zF&(==K8PCm=oV1Fa&*OkQV|xmSMUt~^wRj9eJQX3^ISH0JYY}Bid29;rKfuXwSEEE zzYPSjavve&7P5tfs>=WUzQ{aIKF4R0iXhboMp7BCUa`2RD%~U|b)AU`h^Ui?BIHEB zo}HegdPgls9E_R{8=6O7zbCi#vd&{?2@{E+GPJoU_xzg^&4)5cm8;!*K%!pyw%GiM zl3Y5jt6>Pv$FVm&Dc-f_`1DA2*(ki9#g4wA%^t*ve+_%aer3Oi6EKb$tCZUC`|_s_ z7i3%lvnZVWUoJph&zY)#px`DVk!ujVAZnb(4X_WE&mID8OYAR@aDL3w{A2}M_k{!Dpj>pry=HN^dtO{HZks&Lj|P6q{55;@(5o4*!=9;k z(j&5-#Kw1@gv7q|_|qp#$Tr99^YyY<>H~tr5hN2N>i3y!*hv7bn4}ar7F&Fh#6&qH zqA-fe=z&vv5{}H|3&%s<6sn+pD{>t;V~$h!g!Q~e$Zp1-Qv2hV2GK@4SJp7eA6<8z zSG842dMONF{h4gNhRQRUBqoI5HxR}EY_bnmfM2H6&FakH_uuzOBc60@#*}`&15Nl* z0MJC>aCOL3dos5IHx|N=zzDKQYMS5k=pFP*F}~Ckl1{hqTQax=xZd2-l3 zwgHrSfD+7=eoULN7XcttBe1(+8{s#Z(wi4HE#Uj~OPxrT#|R2ilf2n~f?xS4jDp_N z=1k}xpZPZRy0J^JH7Nm2l*S?LM#T2ivfnA1`Q1Ah<(vCXLP5L&;)ie|2G&|Cf4n!F zl{G&!yZS|xJ?bJVhEtc!#94g=IPt}O&4g`*!#v(n{drCLYeK>ZOL|?Zf45y`g2?W@ z7%*eudVA}A3qWtk2%3LJA`+Nz<4Sr6PRHuoAKHE#E?Se|(XIPBY? z?>}=L2*s7S(oqQq?*s71fY?n9u(V;|nV13!R~k^HDuVwGj36K1QA45j!J9x`49_07 zy0MkDwS?VWM-Y~i&YY+l$%~d;`MvAB*p{?8f(TPbl%9+1`VI2eVplebT2C)p!|f<0 z&sd)|+!u;PO9+4VhE_NtKE91G`~ywQgw5FSyL15>VAu*v9Ekw#(^#+{U}tW>RO6!X zv&x;03H+IIC)@74rjn^Gz{p6 zZXo6*m%VKOcjA+W;)8-C|Xtm6rEG{RwhXaMkM>dJ3ap#V-MpYYi0XVA{aaH$9xm zhd-j2Ubt4+VQWcF-nutTv@+2pTyMJ5_*>b;jJ)1#&ehO%!`1Y-Ht3|lk44J;?t(U( z9QHA`vXyh1&2byE<9XOam>T|c#x>Xp+pQlgL4o`rO&QgroeGMEZ9XTT-_qSrM4O(T z{trO8oX`q+dDijd`JtK0MEB?KV?Zn0=ro6Li~j#FI1^kjm}@`*?6OrrSWwSokNO89 z0KMl{43!7D9XE5!wfeM5PL6I3ouA<7)7jMVA9z7b=M-Gy6AukAg%peP>`!%{aLa|e zO8Wmo?iJto5%o?xQ`1a`ZvBxY^`?5tR9+X5>)QhLRNNS=a2Njq@KEkMFXPs54R7r3 zrUIJal~z|WSih41ly9o%*gA0nKDA)q@80j5AReH@sHK9*@$m3E{#+k6Q}RQRfNT)hM z1QiCE<^UW_OQWjyNR|G6BD1^BWM+EqvOrBngV@n5?)u?*_okL0CJ(?#sbZIGX#7xF zRhVc~x98QkX;jy>#P%QD0-R#RxguYv*gwT;?;QBq-Nk+2`?nxEVy`QJ6E#>FTv0uH z2@riPpI1`A$YR#?Q}h6NHrks9jX-P9^LC`MR(y6_|64 z0k%eZ%Q0wo7V7?+K_f5F&Z7mLbc{JB-iYx_V9mS@ItkQ`ci>FXJiG)q-UD*cr=XPc z1PQJkNck^8@XctM0aDLcVoebSrc=js(;AJg!niS|6ptxFmt2&#f$;KY!JZhSaRNs% zthO;j<@t59PSgO4?FZug1d!plJAG__d*==XbUl~fzy#BO{tncRG!RLoVv&VV<2_)P zSWFd4(}4t-|JN^Y;F=y(K-FM5;ZG7oXLu*5-R6$?Lc6pe?0<0YC&UB*K-Htik6ERF z7o5)TUO-dVHZ#M?Ru|ghWk3a+vN7PhB_KUNaZ+7uoXQy4xL{!-SxG(#zv-QWBVgwg zQd*2@VZqgUjgt>OSNvAGyoAFvFE^xyIsGfsXQDt*iJHdXAhVmo0uJI@x?6vi>O)8Y z;aJb(gFW9S#0>yWv{Eg{kEL%Y+T2$8cG}{dY)#F~yuc%`8hZV5jP{w^@Hh&z)F9m_ zn8uqj0QcjV<;=||BrIG~RfPqV%PO7U0TE8Ic0q(9%fjDet%R|RMcB;x<&vG!_x>1V zg3S)%YxVoz5FWLrjTF1PLi(Cse{@0Kd3AEJvV?Pf;yFasI3(J`NcnT%*r*!LpbjZ`uSTWk?y|3G_4o@dW@O9n{AA6u3tNP}RTvzyNV3BR>Ki4H@ZZ zqnoVoYU(i`Ei}Jqhjy1u-Ur2p6vX*ehqc5cfZy4q4xGZ@G?-DZG(NZHw0na2M;lp4 z4PDK54s(qd87lo|=!g?wOJBprc+60w!LooaXkQhjx zKYtDp93lFLW37AhYEpSQ%i_|K4Ui_oUKqwAGfvh5xh^FjIjr;q`po|V1QsOFE)eGG zyR=o=0uEr(mLXRTpfMTR4{*)7^>v9B|FhE+YSDNk;W=~(uail8eaDFum`gE-;6v43&L*_5a{xEhW0Kzc8l(V*_wisTOr1J zi(Gb={IA}q6?I9X#`-5cjdYl;TT9Y}X_J&Ni17%zFh=q9l6nQ8_VMni;uWj+66K2D zeBUu@XA)QjV{qW^pcoPX#1E+H$Ko!1+_Escww5!;eZq2>AMj}a*ryW<;cbSv!EApcg{# zbqzsT8+boK&4{R~QaR9$iyIwbwA0ib+C42&7W8ayC*_JIo5O9%!|c>)Rl}+o!uCU7 zJFJM^(1B6Fipaz_Jmd)?XK2FOYB}vj)WKK2Zk~~%x+i%)8wXrWpNw&boZ?_HQf`3h zI-Y=y2?_az{~;`6|J?5oh*fsVWGMCQIBI(5IfN%{4LB75B9a$yM_0mlXxU@Dr5xu} zXICdHAjW_S-fsvPfD$vSy91+BaGqEkG=OU!X0)r*{rKe~Q|^ORr`+!pKI43b;fi@j zM%0Su#9L@w`eou>JZYaiTxw932L`gbdn>(-255u^Z1!wHJl5^H+RX)1tpscV;pB75 zTVSj~Brnhg(Gq4Pe&PbWuCEvZBIK;B3Zr!9+#n*(<|{Lg0kXH33aH58#m98d&(orG z2njXn#ZNj=pdmNWWR`G0xV$`g;4%Bh$0NjA65&4YMKOdoe2KSC?LH&(KW0Ffk1}qP zl$wO*3Y4uT7rpf@nq5g;PQGuykN&g~bWu%0DYIZpme@MqSu2Rij9)v^dj};}lIUU|S0Z8-xt&9(j0+Sd({E zgiqT%n#|TE<{gfebU)2lg0@v96m!5M80ra}cyiWB|=m!a4ZsnptybP?Qbu zqlYwNK;n`L5dZDcbj?AiG6()}qh{>@5P@MA>ye*Fg$e(jQ9x%UK)P9HAM7!J@9 zPdM!Fd&zG!eI>H7vNdRmy_7DUM0bZqt~PAVrIOyB>Zd$CTo~pB`6~-JGUc*;(@09$ zc!ufz8Mci^b;4$I9lCI+jp+j~5G= z+6eoIJRsCz0cWIS2cR)1xz&Tdr|{b=4e&-G(P8d+ z6S1-(sg0CTRm6W#Y*162!Ag*-K#0|Dt0hdoKsvbTJ>0;kWClX{90 zHY^h^m?nell7Uv@nXA}&_JWutk&Y!4llCY4vHGQrLER%*g55Ezq;TYn1$8`XQ5y!q z#cS5f2jHnY0iH|(nCfx6uNA7gR;LXoz>h6naw)@dqlT1|s;a8cc!YHjia;k!Wpi4Uj%{_SX zNdgjNLK-$GJb$|gGWQ%n`{$HAu6n z|4^+j3m`Ptrhb>k{!*jzf~E|B<|Tm>lngzR!I>}+X58t`{O3okG&IeENI|Fzq>SXr ziov`5g#(~KeO5J8OmP6_R}BJck>c|t1z?T^sy*=hcALPy&IJF?N%KD3$YznjZ=H+a zD^URc<#A5vqT5OQkcX$a&$OV*$K2dpNJH`kA~49OC3^q@##2yG%+1duGoXUTPbvO? zLn@oX{w0P}{M;8@P*ss@4pRA!{qkHzw|byI1#^nvkL0(-0mm2W8S*=WSry=&-T|i% zNdGxGfo7U;`VpuNt!}r-(%UXT&eL^RH`Mq1=0QRNwD00}#dg`j7xQBzLg};5EU-k? zF*3sFlnC7U1p1ah=Bx><{h`mlvjSX#PXqKw1eBhGNmZkA_as4VhdMgIEdva(`|tBF zG0KjBmIfhu1fqx{Q0m7WL*qzHO-;dwoBuxlZcQCvdSLfZ{|0F<)Lm6u`vtQ3A;;!T zAk75RTEJC->!SJhmFuJb{oeqH8Vh39Yhmsyy8*E3P((5Kf3Gv+qw&AtpuoTI3x@RN zK-{rqJ(-r&RVlEq|M%gA8e)?U|JH~Q_64wm1{BCF1ZncIhFNwjX09lT3 zV4 zW#b*I^cBQl5(SrqaEp&c!HlxijRg}tHcF%Ycf{O2 zbz<~l(*suIPOE(qLeX96f)PIZXq~5;njRWMN zFWW4@nGKZ(%n5PcvJJQ{WfTRKi{Z1ETbMU_4F|2pjk{cn?f!D*lS(QY#A|jomwATY zL?fVj%@GbQ_oKS@!F?R2jM6-AT=7;rh9p0!ODnMaz5czFJ)%ZL>7y$Yi(${=z;yF5I zjxn!B6u`sDL4%eNzjHaMY0c5MSCuC`U$M~zEOCp!=kkSrWC__1ObHUv zzgV^ty*oR5!RnU>7FM33-`t_xrjC#G>kh+{Rn=M_41(m=qB;bo)ur#>wI7FuF2zDw z1H?%Y^{9x|$p0SX0+i-K6$$Li=~EDs$-v|)=6EH7A0jrwRK$ovFnUypv7a}8VHduK?L(RD@&_5(Z%GMc?r zMu&mYXUL;njf#LKJ&y59i)_7R_!2%H!i7JNRyrrgY<7MVJNmvc!q)FFb%xwgM+SZrEFl*) zRyKcvL+r<~bm(f!TUS0+J@T?3kjnMHz1uN0-gtDmrC9r^$IrTi@IEo?Odfu`G5LTD z!KNqb+*4%j>f$9T8$L7ZWv$x)0Q~{umQVTLj?iFez6y6E3LnRxJmKS_{WWli$B6te z$q~nGd9K1HSqT{vH3Cz?7H~SsK4B&s(0}-d zo;ubX{E-XQU2`yFf;N4aHBXWg^^2O;*<8*B(+vRRM=D!$+Oo0kR(&%)`mL-?PoHH~eLt?w*oQ`9kWBd9L! zARW{;XQNyzKa0qTqn$Iem;HVm!2$ZyNM>MwSX58Pw$ya11 zX_`wPEX}Pp7El+#*(U^Iv_IQ^=_)5FTW9P9G;Z2+JX?uj1-stGR2}FV>I>WzSod5*U-!oq(og`xL`hFkTlg z#`61gc8~dgc?jn&J2AhJPiD5ubwZLPqO3~>4l8D>CY+`zIbLAS;lPQKbtE8NFugb? zJWgIh*f)H`z=Y!>D!uJf(rVQNNNS=GMM0H%cF$UEitsYw&PRlZ8Wa6pjwo3pMi)hM z-`L8B)Ns#9YX+bVG?&2LcjN=`%oa#79*eOu$M-X09A7m5ZAE*=c(ySPA9!ywt z%gZ&&!m(XP8W-B*us5vT&~;2rI0g9f;qKHN{r7cvpFV72)cK&xk)Y34%;8RAM>D6- zcR3P)_^TGKZUY^{PKvg5d`LbLM%h*6HRFIK&MpuZRwvt%bkXfmPt1%`f{m9+Or4!W zD2XU(l>N$)B8jox*M=^pfIp-4>>vR>A&ejafvuBPPf$I0z+2y0QRnFtYN62VRLz5# zU9G*C*-jAxoOz>1ReH>SDx zpE7v~Jr_jARm@(~csEY7G|X#)qm#cTH6Jvq@yg8cW$q1m-6*>}O5DVC3T} z@{jb{cVoxy83zkdcO%Gc8wU!_8n5cWk`Dba9_?@(FV;t1K`9cBB&(EnGiRcVP$ec*qQ)Ws*}*a6v`G|Q(IeFzT~*pQT+5O znTvyVy7LI(DMkYZEQl#i9x2JcS5@+ylEVL32hlKz_4Dv-zbB`_CE1$2U+RgrO^*1p(NZ1g1YN-EHDN)Cdzoyf`>VSnf-{K*fe z4gb6yQB(akof%enrZ)f0o+4uAErw~XgYk8yj0m|@TOuc&L5Q&W_HD}cP{PY^=<2Of zB4g~p{+TnTI7xx4uuro{UNjC$IS`_?6%`nUHKc&J5DjPrfQ@|$S`VfN4c^+gA-NQ; zhreM7$oiMTm6Se8Nn<;v#8bP*K?7!hyIHs&eU?7VzK}C5KHAE=CjU`hRWB4R$!Y$mSFl3j?%glyK zd96kcN#uOtqn3{AqFng$AKJluGn?kh&7a-Y*)l?IzL%SHdtRNN@Ka|B49&a>(JE?_ z`%QOZ#g~MU7QTX?9R^(s;F{ZDJ%9{pCBXo5DbPdV0c|2MOFsGOUEH6GiC7N+tcU;+ zErf)CJ^*TOZ!hryTfHz9!^bX^URcoYqyAAp@djXnRTeqt6tU{4ls%36Mq(Oy+td5u zlmjj5J9Esy8Ytbf`|a;$-c952*e0}b&V%x?@0YKy*XOo_5?Beaqhf=uT#oRK^~}{I zNcy#BFHIaHM3pSwVA{|{6pC;ASuM=K~OIFHoMwVsD{|ANT~gPDc>Id7<7@FHGv z8$J^oj^+UueO3*+9K3jAOsOnluJ7;v>_2CSEN{T&$|Ck(Ky9Kd!wc_t7-;q&Zq}Mr z$%KL0VU#eI8Hb6-Hw^iHtfaymPq;O`FDENswy}L_r$*NfJ;=nARi&>cv|{e-fLWzC zW%!TPg>p4+0ZZ#<82`ZK{?Z6~F)T_wlRH6mQUTZnR>HXn70|?r!$5*jqu^si>E3FP z7)*g$L|mp}$9Y}gE~$Z#_j8+j21Dk{ZyG&fiqmC5z-FANW(Cz*o_U zSP(7Y>*s}2;cP_`1K~tp%b@K;qKAH@2<3pIV7`B89nrDLb6$bO`Qi@<)eNCBDMXdL zAV~>JGMR>vd+`fADXS|%#Pdnsdka7)hh$Uv-l+yGe}>JJ!Rs1re_z2@*c(r;u<7-$ z$`A}{DhgW9N-DV~$koZWAmKQMK9aKIfpJcU16GSlWAPF*terq*@yXeKt zb5075c_WCb{``vbM*W`#8gMxo20bF+;qoNpXRdB6e|cM*5Y(|e{R}1-?$a&K(H{kM zj_Z-jUoCu(8+!g6@x9~Yrf?z22xX_XDU~kCvo85agHaVU8qM;eR9Yi*733o77%? zRosv9ufvz|^K}i6QL)%xlZHe8(QZrf^n+3I`E%o@}Fe9xZZxP z*jf$1PXq(Or}CB0<3?>qqR!$Ax)3(`)HQkm-fnDUIX0p?lO(ZDuk8LMM=4i8u*C(4 z{~3Sot)oW*R@RF}hU9%9q#=b3fBy5UT_Dwoj$EajAllFF!sv)-Iy3rCf!-t`s;Ia8 z?c)Z7@}dLJKu#X0Kg>7-2n%=y?@<6fmmwbr5wDkW(ABNfbd1( z*|TnF1&rQaXF^XI_$~Xdjhzc*{`w8iNai1ODH@-(iee+w*lq!opU;IN8DF)xX6GKb za3m=7v5d6jg~99xwLfK?e$7`H-M)ycOQ|ENPve?1G-AYUvcY=7g3kZv@2?ZOw&d%- z+m|&~2Cdn)=ou)eyqsPYpPqzu;Rs^hd4GKgH1CQ0J1_57 zw_U&YJ{87xgzUK8jE~PskBqKpKl?hq`YSW#d4yQ!YM;pRXOjUnGb+iBzsbM27hLnM zLVlXwEPivkSm2|7;xiEfZ@Hg~2OGT_> zPjYeEBIdWhCB@;>{^xIWzkmF@d$IqEJ8$vn;H?@AjHa6PX!is2n(qKN4a{@yhGu3l zr%z1NF3VZ;1{nYjhj#bIGBDKbmU>?!o|9Snigm`yu|i(8s`z#3qvVW)NzNOmH1i5J zw(_KWot{IT^XxR8I{oy+*sj349IgWCvJ$(JHA#u06%_Gy0lziFy8CW}GKIKlMrjTb zWfg_bO()ze-u*2jR&}VUPX6B2?S8&r$~P~XNnEDtdCA#EdAhV)JM-H)5{QezM=O>@AD=0#GA1e;jG@re;W$Yop%kaj(;$- zc~olyQ~D75PjfJsjD2SJIJ}yn|j*h#(joqvvSF>-Z);5LSvTCTkH4b7N6V^|>1#sP^;o+zU`+%_f0ic?V6HGxJ z`}@wl7t3Y|2=UA)Ful_6xf?l`AzitEhY7nO)mg+1I}w(ef%IZBRe0X5ssfRfMm;lI zneUi=ljv3q8;6;8m1)*tbb^Qa;n`{lC5e?>X2CtNGjVDg za-6ClcnOz_uj@;mX7(jXv;{Ih+s|BhZM)|iWU&|5YB5u8Dk&fATHgokaG@UX*+4W&y|>zCQ(sao4u%3Pq$3=dAdrl!_2ON%SG&mKDUDd z%krgI7xmb@N(vjt8%@FKalfxz9;JM$^&;<;HOuNoqWJUPeNNY4Y_DBQVt98#o)Pvi zni2f2R@lUZ3XOD#`OA!VJLdq_@;Hm51!FOMf_5^E&OtJu{1Q;R40>i^5XlQP6 z@HdK~=JNdSN+W|+4Y=93e~&f{2BYP+CvCzuDGQaIS=f?zgqH-*+V}JpJoI*C#&a~7 z^1uE{Gg(9CP=R}V`&;SS(M7XypjXUyw<+Y*!T{%$E|n?D+xnlXU5W2G-oJd$k)32H zHE3~mIn?B%wqOOXFxPA7E*exsll(&XD2)&n_V5HO*&?XyD*U_O01U<;Q`Mn4_B;#_ z%7A1?-OkPqj1hQy@<$?f`q#OLVqNXor*MprPdT}BOKMios$S?avUML-R~%6%vbE7# zQ7wz$_!gu-4%1G`vDG^%H1dO{ zhPxkG1x5MmS7zf<@_HGioho0>z&k1EK0Y@}9!WTmF*O{s1*XPEuPAxG@2eAwq3s;T zob3e1GBbv~iG{YFnK5|`B@N9|Vdx*WAszu<-eTjUfh8WHcfg9G8jxL#p(&AoUlcI& z9nf&fAus$fC3Uv%0=~ayE%CU1^O-g0+`<<|#IQyOaknhQ=8{Mh4o^whoeWQUMK9dHeuwS)I-KEcuvx!~= z$5J#scP!215I>7_6$`ahx4Sam>n2*t*simY98GDCU>=%u`92s=iiVT5JXp@&R%UBm zV><0c2Hg&($qoYui%<3ulQ%H}d&~MY!e1q8BWTdPT@_Xx6T@HGVV_qXJgOMkK_ z(agvH{6KP}r|NTTl;AIyM8sc43R%7(STS_z0lA>yG!-9ua&4zB1Jx4!b-zBRoF$ z-Ob)vK06ka1XBWfr5~;9*Uf9!QIETJd0QB1v{xyX6jybQyi`&%QroWHJ8a6D)>J9e z<8>ZQKU7%zBE@}nkg`kEmLw?Q&Y1Q>fSZk(zf%)sA;0y{k&f1{a>GelJg??TMgFqW zt&X#2nl69)i(=AMDB?RvqOpUu_@$qS z_)G@mky7t_j0m==U73?fM!pO*a4l%MTIe!uvlk-5&gsi|Uvu)xLB|FaX8n%7=7qs^ zrh%ulod26Dj?9PzW;1c=m*}y&sw~6_Y6-|zpQdVeF=7N?K0G{oYMJdJ9l+|IS?w>4 z@i~2@RksYuB(TP3fIop`!eR^ z>-jqi_qaLNDJh?Rqu0FBaI3+dZ%H+04G*ZgH+<&o=0t^=$sFc4f^??Q;=dyq1;8&7QQy zTYcGVt#wn<{&rnjd3cwU{ME|+>)yDrwr|inZLz}pr0}e7VUjxz&wBHu&iCB$E#G_A zyUn|O02m<~exA@et-KmozMHQ(32YgNUCuH4dXQZ{=bOQ`C%}bjZ?v{}H)j28($igY z)ecsBGnQ#A?mK--b+c9!^QAvax{e-mzJB|4JFu?Yrt-5->fkCl$9F$=s6AhLV_nlJ zODAifJ!bO2o8MDJ84f&d1D9DoJ7cZ^Pd_lPmdKuTwrsoCoF}VaZhx1@d)spDg+PxW zjpgsWx-!c*E;zUPh-%98V_U7&&elx~U%1im;@+1#W=&smtz2vRoq75aEPin}FB@Mv z5H~gJv61}$eZV`<7;Z3~(9%`}`N@8ZH!DLzcjnfpS+i$f1|~4GYK_(34sq*e0FUmD zeIvPImglPlvoA#~T~%~0rcjo3f138*`k;&VjEkgyu8aO{;+?s)KNi^JTJL8nBl>Gr zi|y1ibFIv_?yy}k(YWtbtge#$$<bKl>dCwJ$7J&9z&X9KWOM@AvA$-?vpaq2?VnS&eNCTogTZmL-tvdqZ(eUck~(AB8=p`XRa%{sI8lB1wF%5nd?=Cy|I*8wr9Xsh4D@X&y*bAi=-AhNYn*@=Cmouln*3jNtJV8mX@}j{pVVG^ zr$FMsvjxEREW@@wjS1ItPJMQ64OBTQY}B(PXxYjqyUtYgTv_uZMSS1I9Pgl>#~+)WMKH!3361HlE2^CiQ+D{>(Acowk|hHyS_Jr zFQGOdg#{SKfg&a$x?-mPu7;ky`+fI~0}VnSpI$1C=?n=AUB1MGQ5xvA16H8CX9irJ z!NAaPmCM%jRq54Nr8|A&>kO9vm85pD+KpE)3t}YSX z+uCaDQq}pGL2~gEfpV|AL2^*gFz-+HW_B14*uVPd1Z-NtfP+6MnqlAtC&&dbV7m|$ zs8Fy#T?FJ`DEQKXVguNtP^k;TpyCS#s_1IjC;9Ljr`y|4M$3Zi_jL7hS?83{1ON+r Bny3H( literal 0 HcmV?d00001 diff --git a/docs/losses.md b/docs/losses.md index c4126f0d..a85c00f6 100644 --- a/docs/losses.md +++ b/docs/losses.md @@ -1249,6 +1249,35 @@ losses.SupConLoss(temperature=0.1, **kwargs) * **loss**: The loss per element in the batch. If an element has only negative pairs or no pairs, it's ignored thanks to `AvgNonZeroReducer`. Reduction type is ```"element"```. +## ThresholdConsistentMarginLoss +[Threshold-Consistent Margin Loss for Open-World Deep Metric Learning](https://arxiv.org/pdf/2307.04047){target=_blank} + +```python +losses.ThresholdConsistentMarginLoss( + base_loss, + lambda_plus=1.0, + lambda_minus=1.0, + margin_plus=0.9, + margin_minus=0.5, + **kwargs +) +``` +**Equation**: +![threshold_consistent_margin_loss](imgs/tcm_loss_equation.png) + +**Parameters**: + +* **base_loss**: The final loss is calculated as `base_loss + tcm_loss`. +* **lambda_plus**: The scaling coefficient for the anchor-positive part of the loss. This is $\lambda^+$ in the above equation. +* **lambda_minus**: The scaling coefficient for the anchor-negative part of the loss. This is $\lambda^-$ in the above equation. +* **margin_plus**: The minimum anchor-positive similarity to be included in the loss. This is $m^+$ in the above equation. +* **margin_minus**: The maximum anchor-negative similarity to be included in the loss. This is $m^-$ in the above equation. + + +**Default distance**: + + - [```CosineSimilarity()```](distances.md#cosinesimilarity) + - This is the only compatible distance. ## TripletMarginLoss diff --git a/src/pytorch_metric_learning/losses/__init__.py b/src/pytorch_metric_learning/losses/__init__.py index ba653cda..543dea04 100644 --- a/src/pytorch_metric_learning/losses/__init__.py +++ b/src/pytorch_metric_learning/losses/__init__.py @@ -37,3 +37,4 @@ from .triplet_margin_loss import TripletMarginLoss from .tuplet_margin_loss import TupletMarginLoss from .vicreg_loss import VICRegLoss +from .tcm_loss import ThresholdConsistentMarginLoss diff --git a/src/pytorch_metric_learning/losses/tcm_loss.py b/src/pytorch_metric_learning/losses/tcm_loss.py new file mode 100644 index 00000000..045be5fe --- /dev/null +++ b/src/pytorch_metric_learning/losses/tcm_loss.py @@ -0,0 +1,60 @@ +import torch.nn.functional as F + +from ..utils.loss_and_miner_utils import convert_to_pairs +from .base_metric_loss_function import BaseMetricLossFunction +from ..distances import CosineSimilarity +from ..utils import common_functions as c_f + +class ThresholdConsistentMarginLoss(BaseMetricLossFunction): + """ + Implements the TCM loss from: https://arxiv.org/abs/2307.04047 + """ + + def __init__( + self, base_loss, lambda_plus=1.0, lambda_minus=1.0, margin_plus=0.9, margin_minus=0.5, **kwargs + ): + super().__init__(**kwargs) + c_f.assert_distance_type(self, CosineSimilarity) + self.base_loss = base_loss + self.lambda_plus = lambda_plus + self.lambda_minus = lambda_minus + self.margin_plus = margin_plus + self.margin_minus = margin_minus + + + def get_default_distance(self): + return CosineSimilarity() + + def compute_loss(self, embeddings, labels, indices_tuple, ref_emb, ref_labels): + ap, p, an, n = convert_to_pairs(indices_tuple, labels, ref_labels) + + # calculate the similarities for positive and negative pairs + ap, p = embeddings[ap], embeddings[p] + an, n = embeddings[an], embeddings[n] + + pos_sims = F.cosine_similarity(ap, p) + neg_sims = F.cosine_similarity(an, n) + + # calculate the positive part + s_lte_m = (pos_sims <= self.margin_plus) + tcm_pos_num = ((self.margin_plus - pos_sims) * s_lte_m).sum() + tcm_pos_denom = s_lte_m.sum() + pos_tcm = 0 if s_lte_m.sum() == 0 else tcm_pos_num / tcm_pos_denom + + # calculate the negative part + s_gte_m = (neg_sims >= self.margin_minus) + tcm_neg_num = ((neg_sims - self.margin_minus) * s_gte_m).sum() + tcm_neg_denom = s_gte_m.sum() + neg_tcm = 0 if s_gte_m.sum() == 0 else tcm_neg_num / tcm_neg_denom + + # add the components for final loss + tcm_loss = self.lambda_plus * pos_tcm + self.lambda_minus * neg_tcm + base_loss = self.base_loss(embeddings, labels, indices_tuple, ref_emb, ref_labels) + total_loss = base_loss + tcm_loss + return { + "loss": { + "losses": total_loss, + "indices": None, + "reduction_type": "already_reduced", + } + } \ No newline at end of file diff --git a/tests/losses/test_tcm_loss.py b/tests/losses/test_tcm_loss.py new file mode 100644 index 00000000..cc1acfe0 --- /dev/null +++ b/tests/losses/test_tcm_loss.py @@ -0,0 +1,46 @@ +import unittest + +import torch +import torch.nn.functional as F + +from pytorch_metric_learning.losses import ThresholdConsistentMarginLoss, ContrastiveLoss +from pytorch_metric_learning.distances import CosineSimilarity + +from .. import TEST_DEVICE, TEST_DTYPES + +class TestThresholdConsistentMarginLoss(unittest.TestCase): + def test_tcm_loss(self): + torch.manual_seed(3459) + for dtype in TEST_DTYPES: + loss_func = ThresholdConsistentMarginLoss( + base_loss=ContrastiveLoss( + distance=CosineSimilarity(), + pos_margin=0.9, + neg_margin=0.4, + ) + ) + embs = torch.tensor([ + [0.00, 1.00], + [0.43, 0.90], + [1.00, 0.00], + [0.50, 0.50], + ], device=TEST_DEVICE, dtype=dtype, requires_grad=True) + labels = torch.tensor([0, 0, 1, 1]) + + # Contrastive loss = 0.4866 + # + # TCM loss part: + # Only pair (2, 3) is taken into account for positive part + # Positive part = 1 * ( 0.9 - 0.7071 ) / ( 1 ) = 0.1929 + # + # Only pairs (1, 2) and (1, 3) are taken into account for negative part + # Negative part = 1 * ( 0.7071 - 0.5 + 0.9429 - 0.5 ) / ( 2 ) = 0.325 + # + # Sum of these losses -> 0.4866 + 0.518 = 1.0046 + correct_loss = torch.tensor(1.0045).to(dtype) + + with torch.no_grad(): + res = loss_func.compute_loss(embs, labels, None, embs, labels) + rtol = 1e-2 if dtype == torch.float16 else 1e-5 + atol = 1e-4 + self.assertTrue(torch.isclose(res["loss"]["losses"], correct_loss, rtol=rtol, atol=atol)) From 19ecba5fe65eea9c708467141f86b0bd708dd4c2 Mon Sep 17 00:00:00 2001 From: ir2718 Date: Wed, 16 Oct 2024 23:40:06 +0200 Subject: [PATCH 05/16] run formatting --- .../losses/__init__.py | 2 +- .../losses/tcm_loss.py | 30 ++++++++++------- tests/losses/test_tcm_loss.py | 33 +++++++++++++------ 3 files changed, 43 insertions(+), 22 deletions(-) diff --git a/src/pytorch_metric_learning/losses/__init__.py b/src/pytorch_metric_learning/losses/__init__.py index 543dea04..6bd679a7 100644 --- a/src/pytorch_metric_learning/losses/__init__.py +++ b/src/pytorch_metric_learning/losses/__init__.py @@ -34,7 +34,7 @@ from .sphereface_loss import SphereFaceLoss from .subcenter_arcface_loss import SubCenterArcFaceLoss from .supcon_loss import SupConLoss +from .tcm_loss import ThresholdConsistentMarginLoss from .triplet_margin_loss import TripletMarginLoss from .tuplet_margin_loss import TupletMarginLoss from .vicreg_loss import VICRegLoss -from .tcm_loss import ThresholdConsistentMarginLoss diff --git a/src/pytorch_metric_learning/losses/tcm_loss.py b/src/pytorch_metric_learning/losses/tcm_loss.py index 045be5fe..bf7a6aef 100644 --- a/src/pytorch_metric_learning/losses/tcm_loss.py +++ b/src/pytorch_metric_learning/losses/tcm_loss.py @@ -1,17 +1,24 @@ import torch.nn.functional as F -from ..utils.loss_and_miner_utils import convert_to_pairs -from .base_metric_loss_function import BaseMetricLossFunction from ..distances import CosineSimilarity from ..utils import common_functions as c_f +from ..utils.loss_and_miner_utils import convert_to_pairs +from .base_metric_loss_function import BaseMetricLossFunction + class ThresholdConsistentMarginLoss(BaseMetricLossFunction): """ - Implements the TCM loss from: https://arxiv.org/abs/2307.04047 + Implements the TCM loss from: https://arxiv.org/abs/2307.04047 """ - + def __init__( - self, base_loss, lambda_plus=1.0, lambda_minus=1.0, margin_plus=0.9, margin_minus=0.5, **kwargs + self, + base_loss, + lambda_plus=1.0, + lambda_minus=1.0, + margin_plus=0.9, + margin_minus=0.5, + **kwargs ): super().__init__(**kwargs) c_f.assert_distance_type(self, CosineSimilarity) @@ -21,7 +28,6 @@ def __init__( self.margin_plus = margin_plus self.margin_minus = margin_minus - def get_default_distance(self): return CosineSimilarity() @@ -36,20 +42,22 @@ def compute_loss(self, embeddings, labels, indices_tuple, ref_emb, ref_labels): neg_sims = F.cosine_similarity(an, n) # calculate the positive part - s_lte_m = (pos_sims <= self.margin_plus) + s_lte_m = pos_sims <= self.margin_plus tcm_pos_num = ((self.margin_plus - pos_sims) * s_lte_m).sum() tcm_pos_denom = s_lte_m.sum() pos_tcm = 0 if s_lte_m.sum() == 0 else tcm_pos_num / tcm_pos_denom # calculate the negative part - s_gte_m = (neg_sims >= self.margin_minus) + s_gte_m = neg_sims >= self.margin_minus tcm_neg_num = ((neg_sims - self.margin_minus) * s_gte_m).sum() - tcm_neg_denom = s_gte_m.sum() + tcm_neg_denom = s_gte_m.sum() neg_tcm = 0 if s_gte_m.sum() == 0 else tcm_neg_num / tcm_neg_denom # add the components for final loss tcm_loss = self.lambda_plus * pos_tcm + self.lambda_minus * neg_tcm - base_loss = self.base_loss(embeddings, labels, indices_tuple, ref_emb, ref_labels) + base_loss = self.base_loss( + embeddings, labels, indices_tuple, ref_emb, ref_labels + ) total_loss = base_loss + tcm_loss return { "loss": { @@ -57,4 +65,4 @@ def compute_loss(self, embeddings, labels, indices_tuple, ref_emb, ref_labels): "indices": None, "reduction_type": "already_reduced", } - } \ No newline at end of file + } diff --git a/tests/losses/test_tcm_loss.py b/tests/losses/test_tcm_loss.py index cc1acfe0..f0349a02 100644 --- a/tests/losses/test_tcm_loss.py +++ b/tests/losses/test_tcm_loss.py @@ -3,11 +3,15 @@ import torch import torch.nn.functional as F -from pytorch_metric_learning.losses import ThresholdConsistentMarginLoss, ContrastiveLoss from pytorch_metric_learning.distances import CosineSimilarity +from pytorch_metric_learning.losses import ( + ContrastiveLoss, + ThresholdConsistentMarginLoss, +) from .. import TEST_DEVICE, TEST_DTYPES + class TestThresholdConsistentMarginLoss(unittest.TestCase): def test_tcm_loss(self): torch.manual_seed(3459) @@ -18,13 +22,18 @@ def test_tcm_loss(self): pos_margin=0.9, neg_margin=0.4, ) - ) - embs = torch.tensor([ - [0.00, 1.00], - [0.43, 0.90], - [1.00, 0.00], - [0.50, 0.50], - ], device=TEST_DEVICE, dtype=dtype, requires_grad=True) + ) + embs = torch.tensor( + [ + [0.00, 1.00], + [0.43, 0.90], + [1.00, 0.00], + [0.50, 0.50], + ], + device=TEST_DEVICE, + dtype=dtype, + requires_grad=True, + ) labels = torch.tensor([0, 0, 1, 1]) # Contrastive loss = 0.4866 @@ -38,9 +47,13 @@ def test_tcm_loss(self): # # Sum of these losses -> 0.4866 + 0.518 = 1.0046 correct_loss = torch.tensor(1.0045).to(dtype) - + with torch.no_grad(): res = loss_func.compute_loss(embs, labels, None, embs, labels) rtol = 1e-2 if dtype == torch.float16 else 1e-5 atol = 1e-4 - self.assertTrue(torch.isclose(res["loss"]["losses"], correct_loss, rtol=rtol, atol=atol)) + self.assertTrue( + torch.isclose( + res["loss"]["losses"], correct_loss, rtol=rtol, atol=atol + ) + ) From 285cea0c89a7e12f5021bd67059d45901eefe4af Mon Sep 17 00:00:00 2001 From: ir2718 Date: Mon, 28 Oct 2024 16:35:31 +0100 Subject: [PATCH 06/16] remove base loss --- src/pytorch_metric_learning/losses/tcm_loss.py | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/src/pytorch_metric_learning/losses/tcm_loss.py b/src/pytorch_metric_learning/losses/tcm_loss.py index bf7a6aef..9038c670 100644 --- a/src/pytorch_metric_learning/losses/tcm_loss.py +++ b/src/pytorch_metric_learning/losses/tcm_loss.py @@ -13,7 +13,6 @@ class ThresholdConsistentMarginLoss(BaseMetricLossFunction): def __init__( self, - base_loss, lambda_plus=1.0, lambda_minus=1.0, margin_plus=0.9, @@ -22,7 +21,6 @@ def __init__( ): super().__init__(**kwargs) c_f.assert_distance_type(self, CosineSimilarity) - self.base_loss = base_loss self.lambda_plus = lambda_plus self.lambda_minus = lambda_minus self.margin_plus = margin_plus @@ -55,13 +53,9 @@ def compute_loss(self, embeddings, labels, indices_tuple, ref_emb, ref_labels): # add the components for final loss tcm_loss = self.lambda_plus * pos_tcm + self.lambda_minus * neg_tcm - base_loss = self.base_loss( - embeddings, labels, indices_tuple, ref_emb, ref_labels - ) - total_loss = base_loss + tcm_loss return { "loss": { - "losses": total_loss, + "losses": tcm_loss, "indices": None, "reduction_type": "already_reduced", } From 31eeaaad1a4313380cbb96b83c1b1560e5f666fe Mon Sep 17 00:00:00 2001 From: ir2718 Date: Mon, 28 Oct 2024 16:41:40 +0100 Subject: [PATCH 07/16] remove base loss from docs --- docs/losses.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/losses.md b/docs/losses.md index a85c00f6..b8ced202 100644 --- a/docs/losses.md +++ b/docs/losses.md @@ -1252,9 +1252,10 @@ losses.SupConLoss(temperature=0.1, **kwargs) ## ThresholdConsistentMarginLoss [Threshold-Consistent Margin Loss for Open-World Deep Metric Learning](https://arxiv.org/pdf/2307.04047){target=_blank} +This loss acts as a form of regularization and is usually combined with another metric loss function. + ```python losses.ThresholdConsistentMarginLoss( - base_loss, lambda_plus=1.0, lambda_minus=1.0, margin_plus=0.9, @@ -1267,7 +1268,6 @@ losses.ThresholdConsistentMarginLoss( **Parameters**: -* **base_loss**: The final loss is calculated as `base_loss + tcm_loss`. * **lambda_plus**: The scaling coefficient for the anchor-positive part of the loss. This is $\lambda^+$ in the above equation. * **lambda_minus**: The scaling coefficient for the anchor-negative part of the loss. This is $\lambda^-$ in the above equation. * **margin_plus**: The minimum anchor-positive similarity to be included in the loss. This is $m^+$ in the above equation. From 60c0583e4c0c94d654347983ecbdabff2383d576 Mon Sep 17 00:00:00 2001 From: ir2718 Date: Mon, 28 Oct 2024 17:47:18 +0100 Subject: [PATCH 08/16] fix test --- tests/losses/test_tcm_loss.py | 20 ++++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) diff --git a/tests/losses/test_tcm_loss.py b/tests/losses/test_tcm_loss.py index f0349a02..586f9107 100644 --- a/tests/losses/test_tcm_loss.py +++ b/tests/losses/test_tcm_loss.py @@ -5,6 +5,7 @@ from pytorch_metric_learning.distances import CosineSimilarity from pytorch_metric_learning.losses import ( + MultipleLosses, ContrastiveLoss, ThresholdConsistentMarginLoss, ) @@ -16,12 +17,15 @@ class TestThresholdConsistentMarginLoss(unittest.TestCase): def test_tcm_loss(self): torch.manual_seed(3459) for dtype in TEST_DTYPES: - loss_func = ThresholdConsistentMarginLoss( - base_loss=ContrastiveLoss( - distance=CosineSimilarity(), - pos_margin=0.9, - neg_margin=0.4, - ) + loss_func = MultipleLosses( + losses=[ + ContrastiveLoss( + distance=CosineSimilarity(), + pos_margin=0.9, + neg_margin=0.4, + ), + ThresholdConsistentMarginLoss() + ] ) embs = torch.tensor( [ @@ -49,11 +53,11 @@ def test_tcm_loss(self): correct_loss = torch.tensor(1.0045).to(dtype) with torch.no_grad(): - res = loss_func.compute_loss(embs, labels, None, embs, labels) + res = loss_func.forward(embs, labels) rtol = 1e-2 if dtype == torch.float16 else 1e-5 atol = 1e-4 self.assertTrue( torch.isclose( - res["loss"]["losses"], correct_loss, rtol=rtol, atol=atol + res, correct_loss, rtol=rtol, atol=atol ) ) From c6d5b525baa956e4f1f03720a9df68767501a7b5 Mon Sep 17 00:00:00 2001 From: KevinMusgrave Date: Sat, 2 Nov 2024 20:01:41 +0000 Subject: [PATCH 09/16] bump version --- src/pytorch_metric_learning/__init__.py | 2 +- tests/losses/test_tcm_loss.py | 10 +++------- 2 files changed, 4 insertions(+), 8 deletions(-) diff --git a/src/pytorch_metric_learning/__init__.py b/src/pytorch_metric_learning/__init__.py index fab833f3..2614ce9d 100644 --- a/src/pytorch_metric_learning/__init__.py +++ b/src/pytorch_metric_learning/__init__.py @@ -1 +1 @@ -__version__ = "2.6.1" +__version__ = "2.7.0" diff --git a/tests/losses/test_tcm_loss.py b/tests/losses/test_tcm_loss.py index 586f9107..e961bab1 100644 --- a/tests/losses/test_tcm_loss.py +++ b/tests/losses/test_tcm_loss.py @@ -5,8 +5,8 @@ from pytorch_metric_learning.distances import CosineSimilarity from pytorch_metric_learning.losses import ( - MultipleLosses, ContrastiveLoss, + MultipleLosses, ThresholdConsistentMarginLoss, ) @@ -24,7 +24,7 @@ def test_tcm_loss(self): pos_margin=0.9, neg_margin=0.4, ), - ThresholdConsistentMarginLoss() + ThresholdConsistentMarginLoss(), ] ) embs = torch.tensor( @@ -56,8 +56,4 @@ def test_tcm_loss(self): res = loss_func.forward(embs, labels) rtol = 1e-2 if dtype == torch.float16 else 1e-5 atol = 1e-4 - self.assertTrue( - torch.isclose( - res, correct_loss, rtol=rtol, atol=atol - ) - ) + self.assertTrue(torch.isclose(res, correct_loss, rtol=rtol, atol=atol)) From 3b371662276f58ea05b4c31b7561aa20251dad8f Mon Sep 17 00:00:00 2001 From: KevinMusgrave Date: Sat, 2 Nov 2024 20:04:43 +0000 Subject: [PATCH 10/16] remove version restriction for numpy --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 0818062c..b8ca0431 100644 --- a/setup.py +++ b/setup.py @@ -39,7 +39,7 @@ ], python_requires=">=3.0", install_requires=[ - "numpy < 2.0", + "numpy", "scikit-learn <= 1.4.1.post1", "tqdm", "torch >= 1.6.0", From c1aba4c5a72d7b0c1cc214f0f2c0b2d47c027d7f Mon Sep 17 00:00:00 2001 From: KevinMusgrave Date: Sat, 2 Nov 2024 20:07:04 +0000 Subject: [PATCH 11/16] update versions in base_test_workflow --- .github/workflows/base_test_workflow.yml | 6 +++--- setup.py | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/.github/workflows/base_test_workflow.yml b/.github/workflows/base_test_workflow.yml index 02367536..c697a94c 100644 --- a/.github/workflows/base_test_workflow.yml +++ b/.github/workflows/base_test_workflow.yml @@ -18,8 +18,8 @@ jobs: pytorch-version: 1.6 torchvision-version: 0.7 - python-version: 3.9 - pytorch-version: 2.3 - torchvision-version: 0.18 + pytorch-version: 2.5 + torchvision-version: 0.20 steps: - uses: actions/checkout@v2 @@ -30,7 +30,7 @@ jobs: - name: Install dependencies run: | pip install .[with-hooks-cpu] - pip install "numpy<2.0" torch==${{ matrix.pytorch-version }} torchvision==${{ matrix.torchvision-version }} --force-reinstall + pip install torch==${{ matrix.pytorch-version }} torchvision==${{ matrix.torchvision-version }} --force-reinstall pip install --upgrade protobuf==3.20.1 pip install six pip install packaging diff --git a/setup.py b/setup.py index b8ca0431..66326d46 100644 --- a/setup.py +++ b/setup.py @@ -40,7 +40,7 @@ python_requires=">=3.0", install_requires=[ "numpy", - "scikit-learn <= 1.4.1.post1", + "scikit-learn", "tqdm", "torch >= 1.6.0", ], From 892f8f724decf3d3576c97864c47d56ef36d3935 Mon Sep 17 00:00:00 2001 From: KevinMusgrave Date: Sat, 2 Nov 2024 20:29:29 +0000 Subject: [PATCH 12/16] minor change to tests --- tests/samplers/test_tuples_to_weights_sampler.py | 2 +- tests/utils/test_inference.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/samplers/test_tuples_to_weights_sampler.py b/tests/samplers/test_tuples_to_weights_sampler.py index abe9a56f..96eecf7a 100644 --- a/tests/samplers/test_tuples_to_weights_sampler.py +++ b/tests/samplers/test_tuples_to_weights_sampler.py @@ -22,7 +22,7 @@ def test_tuplestoweights_sampler(self): eval_transform = transforms.Compose( [ - transforms.Resize(128), + transforms.Resize(256), transforms.ToTensor(), transforms.Normalize( mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] diff --git a/tests/utils/test_inference.py b/tests/utils/test_inference.py index 7a6f4ee1..4424d0c3 100644 --- a/tests/utils/test_inference.py +++ b/tests/utils/test_inference.py @@ -39,7 +39,7 @@ def setUpClass(cls): transform = transforms.Compose( [ - transforms.Resize(64), + transforms.Resize(256), transforms.ToTensor(), transforms.Normalize( mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] From e5117dbeb4d3eeb586778b8dfb22ef1fd51634b0 Mon Sep 17 00:00:00 2001 From: Kevin Musgrave Date: Sat, 2 Nov 2024 17:27:38 -0400 Subject: [PATCH 13/16] Update test_metric_loss_only.py --- tests/trainers/test_metric_loss_only.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tests/trainers/test_metric_loss_only.py b/tests/trainers/test_metric_loss_only.py index 10f85242..b9f25cd3 100644 --- a/tests/trainers/test_metric_loss_only.py +++ b/tests/trainers/test_metric_loss_only.py @@ -66,6 +66,8 @@ def test_metric_loss_only(self): dataset_folder, train=True, download=True, transform=train_transform ) + train_targets = np.array(train_dataset.targets)[subset_idx] + train_dataset_for_eval = datasets.CIFAR100( dataset_folder, train=True, download=True, transform=eval_transform ) @@ -79,6 +81,7 @@ def test_metric_loss_only(self): train_dataset_for_eval, subset_idx ) val_dataset = torch.utils.data.Subset(val_dataset, subset_idx) + for dtype in TEST_DTYPES: for splits_to_eval in [ @@ -114,7 +117,7 @@ def test_metric_loss_only(self): optimizer_dict = {"trunk_optimizer": optimizer} loss_fn_dict = {"metric_loss": loss_fn} sampler = MPerClassSampler( - np.array(train_dataset.dataset.targets)[subset_idx], + train_targets, m=4, batch_size=32, length_before_new_iter=len(train_dataset), From 5472652e271769c500e5998939383b301161c312 Mon Sep 17 00:00:00 2001 From: Kevin Musgrave Date: Sat, 2 Nov 2024 17:30:41 -0400 Subject: [PATCH 14/16] Update base_test_workflow.yml --- .github/workflows/base_test_workflow.yml | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/.github/workflows/base_test_workflow.yml b/.github/workflows/base_test_workflow.yml index c697a94c..d23f6859 100644 --- a/.github/workflows/base_test_workflow.yml +++ b/.github/workflows/base_test_workflow.yml @@ -14,12 +14,12 @@ jobs: strategy: matrix: include: - - python-version: 3.8 - pytorch-version: 1.6 - torchvision-version: 0.7 - - python-version: 3.9 - pytorch-version: 2.5 - torchvision-version: 0.20 + - python-version: "3.8" + pytorch-version: "1.6" + torchvision-version: "0.7" + - python-version: "3.9" + pytorch-version: "2.5" + torchvision-version: "0.20" steps: - uses: actions/checkout@v2 From 206e8396c72113b0137e101500444c2ac8f72ceb Mon Sep 17 00:00:00 2001 From: KevinMusgrave Date: Sat, 2 Nov 2024 21:41:12 +0000 Subject: [PATCH 15/16] Revert "minor change to tests" This reverts commit 892f8f724decf3d3576c97864c47d56ef36d3935. --- tests/samplers/test_tuples_to_weights_sampler.py | 2 +- tests/utils/test_inference.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/samplers/test_tuples_to_weights_sampler.py b/tests/samplers/test_tuples_to_weights_sampler.py index 96eecf7a..abe9a56f 100644 --- a/tests/samplers/test_tuples_to_weights_sampler.py +++ b/tests/samplers/test_tuples_to_weights_sampler.py @@ -22,7 +22,7 @@ def test_tuplestoweights_sampler(self): eval_transform = transforms.Compose( [ - transforms.Resize(256), + transforms.Resize(128), transforms.ToTensor(), transforms.Normalize( mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] diff --git a/tests/utils/test_inference.py b/tests/utils/test_inference.py index 4424d0c3..7a6f4ee1 100644 --- a/tests/utils/test_inference.py +++ b/tests/utils/test_inference.py @@ -39,7 +39,7 @@ def setUpClass(cls): transform = transforms.Compose( [ - transforms.Resize(256), + transforms.Resize(64), transforms.ToTensor(), transforms.Normalize( mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] From 63d27dd23b0cf2eb0c8b3933a7ee6782376570fa Mon Sep 17 00:00:00 2001 From: KevinMusgrave Date: Sat, 2 Nov 2024 21:41:21 +0000 Subject: [PATCH 16/16] Revert "Update test_metric_loss_only.py" This reverts commit e5117dbeb4d3eeb586778b8dfb22ef1fd51634b0. --- tests/trainers/test_metric_loss_only.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/tests/trainers/test_metric_loss_only.py b/tests/trainers/test_metric_loss_only.py index b9f25cd3..10f85242 100644 --- a/tests/trainers/test_metric_loss_only.py +++ b/tests/trainers/test_metric_loss_only.py @@ -66,8 +66,6 @@ def test_metric_loss_only(self): dataset_folder, train=True, download=True, transform=train_transform ) - train_targets = np.array(train_dataset.targets)[subset_idx] - train_dataset_for_eval = datasets.CIFAR100( dataset_folder, train=True, download=True, transform=eval_transform ) @@ -81,7 +79,6 @@ def test_metric_loss_only(self): train_dataset_for_eval, subset_idx ) val_dataset = torch.utils.data.Subset(val_dataset, subset_idx) - for dtype in TEST_DTYPES: for splits_to_eval in [ @@ -117,7 +114,7 @@ def test_metric_loss_only(self): optimizer_dict = {"trunk_optimizer": optimizer} loss_fn_dict = {"metric_loss": loss_fn} sampler = MPerClassSampler( - train_targets, + np.array(train_dataset.dataset.targets)[subset_idx], m=4, batch_size=32, length_before_new_iter=len(train_dataset),