From 837847e406e85a6181dc0e4716b98ddaac1a3670 Mon Sep 17 00:00:00 2001 From: FredericWantiez Date: Thu, 6 Jul 2023 21:52:51 +0100 Subject: [PATCH 01/18] Add details to doc --- docs/src/images/state_space_model.png | Bin 0 -> 18077 bytes docs/src/index.md | 73 +++++++++++++++++++++++++- 2 files changed, 72 insertions(+), 1 deletion(-) create mode 100644 docs/src/images/state_space_model.png diff --git a/docs/src/images/state_space_model.png b/docs/src/images/state_space_model.png new file mode 100644 index 0000000000000000000000000000000000000000..02f1afb7a5ec6b0f8d1c6e62e917d80c5850cf6e GIT binary patch literal 18077 zcmbSzV{~T0)@H}(*y`A}la6iMcHY>@8{4*RyJI`)*mlyfr|-Jo+_^KeX01EluTxd0 z&c?H!y=tA>bxyc~oH#rz4(zvY-{2)BM3laL1HJwB4a5){?C;351qA%x3%I$Utl+nA z^|9aI4I%#aL7bGtg}&8HaHnt%HSMfRh64KJr(%<{8uC8WgW>ZsB%*@OxDk{Xp#FdqmnVFeJMn>V`;hdbD zBO@aS2nf{F)a~u?l-b$YGcz-H zcXz3&spRD3$H&JN3a-fCz7c+t6cJQ$2f6f#^jzmmEGE;8G0|3|<#D9{n2~u(-CTY0 z>*sSB2qq?t`zARjJ4_TH1X9LLi2+MUwz=uTdlmck@n~8)GU3|W+k3q-cyp2Yw&2^z zMnupO`tj<^;@z6<IdtS<3T|;> zYcsY!uJd?d0Umqg7MRuQ?j{GfY4wrD*t2PfwZEQ~HM4Pv-{J`_geS+rC^v$XVvat~ z6})+6`m;=av^heRv@`&FqnzPQyMxPV2`9?Fk>CA#$;^AEIr`eNR&790h82rG$=Oc74hn1jSTg_(qrFWcr|TIhr#0b?WFZvh5U(%EsYAnKK#H6C zI+nmX2)nYIbGq^_zoMwQ~#A);)+{A#sVKMhEX920Bfs`@Utc(XLg_x4mdPGd)KKb&W21hvOy%1f_3bAgY?D!oJV6#ie+!%dvPqe;Et;R5>o*ehS2bi_-Q}Mx z_a0>wB+4l}e4~<=fM}GST^u?H0p0J((R%MQjG8?PH=AJQ5=ZEx@8@?Xg1Yf}2a_a6 zXl+C-(#W?@U8EBq7nUEEN6SD<1APlT#&5*3T-oJJ$by*Tto@cV(#FO^$wo+HieEwxNPIV_tx(PQ1w{cDgK%~k!zf(-Xy#J&XJ z_quWZ+k3H|6TE^wp){oaoK{);09*4f6_SWQn{i;(-Fs^W2a(G_J&~?5{|9T)iKY^& zBB~ANx{?Da!#yXNR~fE`WM$c4i-_*M9;iBbtA?vjf+4Bo5g!QUjx+d25e?bt<*byEpGMKax_Ab<$7o0UoaU2Oq%Op6h zE01JWi?Qqtk(v38CGrC}OTi5tIhN8R7^W0I1miQDIhM3q5iOI)h-e`UI$oQssv$no zF`PM?hx%4>ajWF@1Qu_kkv{vm4`h~Oz}KM5Stbl)YT7mIPXwne!=TKf?ZhUDm2fS@ zNXZy+7>>t(rBGixiu=r-HTTzeXR8s9^~2?rJL+%E*(eIGD>PBUb3w7jl~pP-w*^Bv zo4C2P!bAWj_6fCKT2gb)>(pU7Z_AbA`m|QszD|APT$LLJ-us75Q1Wr zJys?H^mNIWPF`9d3%{u&WA0(H36Cnb@t99~O$h0ph`5&4qYyoLOgU5%wed@V3{TA< zsJBTHm0ai-uP5^y${$3jbJK7gF^M+1_0jY-KOWb_AvJqo;YHnxJgcqFk^Ujhw7-Hf z^9b?119s~PFU(15*C%}imgl-Oi5Ug2*hsHhp1$W7bxlhx+^r07*UIl#=@dSC))qb$+vA5R34ZGk_+`TT(oD>23t+ z2YlB+JcBd{1-q4)!z~ihk5%h|8>tUndkS5q^!qJ&MDVLcb8|D!nS{1KhP2B(TXilW zb;1xbn`9PlCouF2yo2u3pl=F30~Fx6f%b+}-lFV0ThJI&X+LO9HGV5U$iEgJ!1IdC zL6@%C7aPaD_H?n}GdGo*Oy{XaMh?yaYdbp=u5#sozEMim)~qk$;>)r zsr`dgg*kL`a1`cwjvwaAF;7JLl6GYQwo_Vag~B)xyHWL#Dh#`)m{@O3J479r-Vw;Z zwttuX^JKwE%#moN*zne9j3Umvr!oyYt?7FL6Hu>0iRx`8?Xq89V94mt+KS4c^oQc` zpHMp%ERY=ZH0}Oc{ht@h3FnY#YVe!Xj2w$SJ%$?clTjYEy5Uhfl{rZlo;3sW^ryB8 zn<0i;SsCtqVe0oK`f8I>1bK)#Y5e-#9H`to5M?n;kZeI(;g2{6+(t*h`4e94X>O%I z>+hO@Zsf<+NX%|6R=W@*9`f6w=}m4Ogy$c@5& zy9USGRH{W?1BrpYZf2DCL8mru(leSY8CEK3m^tGWE{CDdU%^Y^+T^}UM*W(mmNTTh^4P#(eP$+!CABq4jl{CNck1$7 z{iOBdkT#ilWUpSi&g-6?Jvu$KyML-D>z@3=>~^SdZDMcv;9$mek)zaSCtmntdI2rb z9PcZYXDx@5{>nnq?f%g?GA#`B0-rmTrr_6}pEiccGc6IeziwgLH=Fw>cBjr)LtsCc@gmcp~d)MRo$<5ZW3_EdL=OxCH z0Yy$ICwtb&fSJqqhp3_R+Nrw-FP}>jsrQm)W}2IPY`gsWXacunyVq7N#|wk-jYE^v zOB3H_uAbiHiPp#zDv|kP{pR);fj&!i|72lzYw`R;88}$7PZw`Lu#;E7dceAQJyCu` zTKL@BUf-ql`FImm+va{XlP}U|{GC0$2BUZVD)uWG{q3z!BCMTYc8>tTzh^yrmHg03 zx)CBr{_yg}x&HZhWZ~-$saV+oqoc`My7*jC5;Jn-muur+e|{6674k!^>IGCc4qcZU zZjr&JbmGDQ99nh7gr(NFcM>0e#RVZkO$wxsy!QES716o0$JeR{3D!?$&)~JYG(4+y-BZlz zFAD2qG7s)G=5FuD6E{>Tq#=4G}{hph%z)Xbg)bQT=h+N-h zsN>P}?o85tdBo6Ypmurzd<>y=odq^Oz}qhWI@$Q;TK`wJno$?7*WAroUigrHCj;J5 zSRoddp4T&$S!L~)-Ca)z2>YIv3`Gt;KF|=D_lY2_o{$ifrjllR8^%~$6;Q8&Qb9I!h4R!aOBWNyRbWx~^iu^6-5{XE$0GqJ!HNwK_|ul&Rf56S*a zUJli~W1FLTw7EPpN6!q&?!_?L`+yoLS(s&eyG$1zeRSFvo70pV3gV7xP;+4ER1+JT z#g5n@U|zb=Jo;EQToOnC_-kzXb%frJ?&p}vcpHtJHw%Ni z?eVs#)PKmNaLe2*A5@$tr$w4e^Tf2j+@tJ+ZQdC1>Cp7+Yw#U-8-aQ?E&Zqr;u!xnRNtCqNP=48#82k&gQ07~+2iGk)x<4xp}cY~qfS?{ zvxS+; zGayaksu3T17OJUH$7X}`Sjt?+V=-WbZDg-mkcOFwYevKMg+$YeFq+^SBEh^bX184v3%@Di1da+bCA^O@V; znpz4%K83C$vqP7-)5>r7Q?Hy==Y){u@%K9FN$oC(lbY~U)QRy^iVY-JAi`n)f@<*L#*Y2Iv8av5i&XD-&HLn4F zgV;{I@=IW$8`Pp#WUZ)2-)YO-95%~^jp%Qs>ISAi9x9Ls7fL@ZTsil;m!Lu?xq~gt zABAlCx>b>^ollUeR>a4>$Eu4r6FQ5F+gV{olXyk{>TK6c!62$6#ptx zEqz}Xsv&_>S`{vvlw%V6NDXqZ48uH~U_F`6*-93WyyFav#hOxu&y|S{XY)iP#)d*L z2t+&NqOZ+!K9RLgqhQW#Bd`I#1m1Zf@>LoyOoZE9-e%s%)G{Gfo)Js86nAJ^iWRRW zU~8O`#w^enGt5{de3h}p%wR-sMOp9Eu#PK#U_LPo`)C#-W<+h^oKn+%U=5!X=9+9; z8(z(LM3#9C2u3&ZM|{}Nx)W>mr9@n(2U&o$N1Rv9OAx9^qv%G}msy^+^Pf(B{k|?R z#ixp@XECWVfE?IPbeTA&A$lG^(#m%r7;4RyQmrrU zyPl9DdfrJ??7EHX_h!G(Uka$?)jt|Y>uKltJD_XZGi|h>kbixA|W{b3ctcjWui<&bLuTBt>nb) zLe$w{lt?hIIZR>{Zr-lb4dC3b{mwncC37H{V_7%$43}IFU9?1q$YgOOm+m4v5ar)7 zYd)BoTyCA#GqKh>uMonuQGHOKM(8re7&+MPP6Lm|tdNU7K*p+0_k5%Ad)X&#$9U9X ztdL3YOz1|DcnTd#^84f&E;|LQmJ^1Vy#lC^_9n;;IhXYiK+a@M4#!6F!ZV#z@cX3l zWlpWueKJo$Ebl}1JW`<^+l^A%nLMsr;Krqom`fNpTJ7N%qgV`$ERc&QqBR!%msK@`J~vhO3x3!-?2H|4fC4E*Wd#obP<5md$$ z+jwVQW%0y`el^WF7yOJB8r zR@gQVJMjr3jMMQhk;mniOkkv!k}m!ExoMQl9aCO_YljQuMwQ4`TRULU;y9^RyVEL7 zHmht+GsjfyGF@1jTXe`9knczD!moy~u1Q{#DVw-oQ_HnrgCt#gte*2A$iJbhy5i5& z9d99QXfI`xFhg6y1>}Hy&q=F(L^EC1gx665f|l;2Rxs4Jq*Rv$%ygTZ3IM?Ot01U;t@Ax~H=p!WwX!6GoFC9S zhh8J_{h*&I_j%M?!N)Hkh@24t>MHpOPhJGCva=@N-$99MsXu{~M0yXukpdo9rS%)L ze1`kDb9bA+c5b@5y-pr>YRz@4*}DQ$0}gJ=JAcMmkvVFM^K_)*R#k3ylm?4kx3?5V zwGR3ggjaVCTxw3SroQBBgRl8n)#-KW4Z*m20u>n9oRg>To*aIA>OaDVdT)OyLDQF< z;pk6H^y#!x1X{uBZc2s4fFCftO`bkmZqFnJ>pn_MM1Q0`4e+{UveDTLhj$p6n!pjT zEdR(voml00kA1sJ%v{{NzT(Cum&@_vIlQBrvFN z?D_q4G*}1spw@?8ZW=5M(DR}?kvzO$@6$W&vA*@}oN{*2)t){$mc<0;Sk4|)^tSCK z(^bZSAuHd$=dAt>#q_jknBn12+`4}Cl7vs14cytL(1y0_$~w)w(+foeutV){r}Q3s zAo$(`6E)InH~I8kCeoiR@dQ4umY<^I*EVXOx=l=tvK9_0tcAWpxN9$=uHg;3g;<@UEA08hGMK?RKEPP z^3I)L!)_QSSZ;^s-T29w)$1E#!Cl3Q)(bb1hg~gx3ukn1f1fbQrlpW?T0| zd&>K*8s4z!9t)U}^3mHyw^Sxavih|AY`&30rEz2V#*ODkkFg?G|C)&RdU>@(i;;`d z`*&#cU=}lSM)0)79V3zfI~LdL2&b&!`T1KEWxIwq_(Hpyw?Wx#K3vmTUrc#X8gZ8 z@;GEShxVi$^!bsM{;4V#@+sd-ckdak{~vz+Kji#d001X{^+Oe}?`hKZnDC}-j{VEb ze~B;c^EuH=<>dd*t^8BoiKlze*<c*Hc(MK8^y&eBI!w3z->A0#{~#rr z3--9Z3ujkyw7%PF-Xu*^uUB%zu}Tw;KLV!aki3MfxOhQFzC4T~{XF?S_AZMNLeaI~ zUrn8?`!`RHby{PAE2j|rwp^g#Eu8~DJ2&?SoCG?%4~}*Z|91TNuI=o}JLWdWrecu1 zytX!*C^Gx!b6f!%@^c=(xpg$$dVj2=hCB)}{7S<%TPSd?;|hH7SCd+hFCUpJp)j1n z_ZmaZhWsxv2B(pPy`%Pz0(Nol48tX7h8ZN01`~$9_JEjJ?*mikGjq%PQ;{CRp!j2n zn_C^nkVpmo_K?4Lg(pjK{jK?l<7hhDtE;vB?ZFJ}yf>v7cZa*$UC#OJ1UfEh~z;a7uiogZ`=82uU8lbhA#T#UDlZdh%h)Kfn%?5dN2u=-l2mV)S%t1!l$8XSFs058>n2wFvF5mfr@{a`Z1;*9ShGd#m)XbL7sl1~)!l zb{#s%+;bi!xt3dZZ&!^sCht1;b6T3TVp2!ezr&1V3gkQV>r5c!LXBCTkno8b7=Nu9YKii`L4GO)`h^aN7@#eMf5+<0t`1gU!7XidHZ?! ze0aX^#n$?D_w;-q1?fGaUbj>O4zpg5in=#(-F;g%2^L$neY-zDUw1j-Yw4*Wz|w2PwY2|#WEV5xtyOB)laq#=wHuS zyiV#vxtPtotwTeN)V+CyhxdNGo*KN|9Y(%6op`^0l6~$m%1Zb#>_1&Ps7`al-#OxN znNA}8mU&m2mjtF0)cUm(IR{1r@x+{w;(wB{_ZoVIVmd$vS@@qQ2)%wBGvq)~)RDC_ z;0_#{@TVq5`}Q*pCe}-4A7x;gZtMH_DRU_g!w%`#0R8qK@;HJ^m+p_^)D^=-X0q*m#;o! zWhR)fjigHroT21u+>iIg&AG;A?)Pw@{Yx<1H11?e(IIPq2-X48)Y1AkvSPc#Ko zLr+C^_x1?81poyM4ux3@>cn8T#2=tM^?NFusSRSX&RR>b2R>AIHg~&hDR-?Te`x13 z&Lw*vLd?33AI`+7_3i}|u6uz4TYW==QYM1S_Ym<%!c;iO0QeqorCY4n0IZp{o>C8< z97J2f7b|RBm5IBDS^`!E-|4?)cz-VMxGh@&Xar7ap;hmr?|-Z9yRtonS;UUY)0a1Z z9L+6}AhwVPZAN=21#sbN6`Cd*PTx-fumuQ{f&7+`KEH+FJR(7}OS+AG|DbovdQ)E` z%V);>5)SyC*p`nF1zrqZB&9`0F2BO?)civ{yP}R}86IT&YtXP9Y$~vUWXgc=Bwd^) zAKvu96yWxDc1cuVfRQ@MpXN(v(nJDR$RZF@F5z@q+hr#%MHGB}$?5WfGdoD_(W&^e z0OwdDXn>#+BWS)L3Fhs4yUt()-096#RY89Je5OK>hh(pxI1kvge7Y%FOlB55 zktyp>iD~4g6xhoaw+D)~KsO#OP>5+)EOu8mPiL~v-ceDR#;xn>x1oqGjG}5pu>m<( z(+@5Z2|}vBWwWGPICm^+3Eu)X8u+~HIZ~zw-Qcm~q#5Yzrd#%L?lMbq4DipMGhGOf6{5UiR^jsoE>|1Z3nLOe;lB^MWBXW{YMb z*p&zLbd_Z(MU--%X22ME3?aqTO~?4-b2$MczP!`^z{Z+jlM0X5jAJt~kLGnP;4ho6 z?hywbH|H&EkipSGB?=3pfD}sqf#VEHEoGQ|L5Mx{j3~Zq7ib|22UMI|eZ60h%S$)v zU)TQzh6au|fJS8L8v3Kpoy1QZGkUyY)cS3y%Ga_gWz}1M6M+jhW_;X_739a2=P0q` z4${?HtI5xzjYk3+70ml-nQu>uvz~SOtu5|_Iuj8iM7hI-Pa1Hgu?mX3$eP5$H>gnsX2-QH5Il~|%2fXf>t0MH{RNzhQ{Bf9*4 zgiKTM>(Vx#PlUsLwZi zHjDk$ETWXbrJNMv5Efnh&V&VI$!V8$6D9)3U`tY_ zbKFhRNtV#U#YwO`N|M^jTN8t}rFfyJmLrrb|YfehnQxDJ7rY9wT5mICli>d)~E-Pv^)~ST2?F&*GUa3Ow~45ELl~Q1aF#oQktTg6i%#;2Gcm0 zeIz#pbs4+-co9PN6EFC|eikoP+C2%pt(WFqU@9baIhwwx(8A%@ZaF>4176`=vJKmr zjF{#lCnu_r>VL~qlyq!2|0zrRX7*^bG@*;H>1J0Y-cEiXMRd;x$ z!UH~+)1szr>`{QUce46RvGkvDIugq45E`7#c&Ty9+Ufg`806%DY*||o!F;WBT@u%k z6OR_w)iP)JhIyUfF$)whdsnIYvzL+gtl^M= zksCOek#i5qLo0w$$hx+>;?Sc_;H{fgFm=Td0#52I@VQ_1+128c#=};W10Tp=bScI` z2{q(K^D_Jr6|7av{VboTxZrt|78*Iite-spz5EoRfC@ z3l=M_>Rsi3t{!f<1xcPTlfP>z znem1MDLLi`G9(OR6Le5|38*xXOu(eQ!_ur9y5MkG7x*;|M{Oq6W*SseIFSCvZeiYYz%z#7M>miOgg&`-pR> z-;*5l9lB=&%A<>N5%oBVba}@HLXW%}6+I~Z2^P0J5OIA8LC*KAB|jh`KL>#`UtE+) z3zAtmP_0NRpQ5uu(iT!wtRCq{mr@oc@xY<9kW7@IWfD9V+e;|F)7;N0!J8p@780Ki5jam7H+op{s1 zU@`*#=y{V8a#;HW5qAR1X3gUOeujL%kvgRsePDj;;N`Tg@tDXdO2&0atcoT`F&Yi- zdhnSEg1slW_{p1r#=64CAy3`IB@@3e^={|;_Sn|hsWI8kshow>$aiI zjwr@7bfL|$>JJynn;_8bnfZd3iXaqaDJ>;cnec~&Mc^7SDR}|Sg{9dsEz+d(#=E00U zDCkCC77#+~TNY~(Ee?03ft#qfhpC&9?31JXF#r*jcS%{EMU2ZvlO3ew3lUx}OBI}4 z0ULVTu2zK8^V`tuZ;d4c4Bo^$d~ zo`s0G|1=cuF=GGT@^bBU$=h;n?~N0x{(!!x9V@a4mt?8mYshyL?|c|*Hi-h=gXtpD zy){YfKO~#HVJ5U~LYoOA1F~bx3wKw;W zLBm;K>j2F9U-vkDdw2B$LDnZOga+C3t@KP*)_!vDM5B22sz1ejjU(wN@Xl|&)KE}0 z+!G_kMrzWb>^|ge?Y5+hmKW!?#f4ojV5@ z0cm{FBZor3r$#uiQZkSl9mSNVo8n?6G?t#;ZU`=ArHt2d*rXdyK$I$zCL@#!BTS;P zQoi!-CobH~W0=gaeKtachi=kxM$;Cb z4B(6yIDS_j8fij(P_^!W*1K;l4?EY0jqH2lKmhg)qAPZwEpF1d zK6H4C04WWr(hQm08#I8xc^;qhx~QmIAwJ=h6ow#WpGGdF70_d3Z!x{Cowycfp}4%e zgthbyL3nhD9cUDK5!*f)ZYEg*)CF@9rJ=72196S~5y!4a(!rcMi&4+FYQB!RK5+$R zV=>=+hCEPFWJ(=lng9KzJ)!VFm&a$}BA=FztgPR8F~HR5&Qer+I^x&Y*X)wuI{R_PK=LMq zG943$?!&IANB3f@+lq{3B&vzg9uwNSZ~h)6Vh}u85H@UA9+LD7uX14{cROiFlV+BQX&)C^sZA5Uj z_O(!6l`_vjZo!0c&(~?f*9e>_X7WOQNw^+!e4Fd{$dTs~nAN)bu9Solhoi#MEKIt) zzA1a#@<`QA3~qwjZ$snh!BHsz<&qsYr!^xfu z>kIkq_LBnB8iI?^r8OnYN>@5($F8zPPM)dwoj`H6uxA0@E*AQq&M#jjD_Zb~845?H zw?(W6W^)hJR*X2i;dDgq?Ze64XTr9`8ay=$DDwdM?bCt`jK1ktCnd;9zpWD&SbAg- zRqpr-BI*Xd>&coFF0|geTR9ftF+&vcu2>F-twZgJpBKf(x0s7sxJG_Q`}?=x5nIIb z0|EZLrI!cItOW98VneDv-i1U)e{8Z&+h?MGJq4b3KDOB={))J$;FoS79%AqafF>h= zkFQENBe}TIY?;Ixrw+Aw3h<-qvo^7)uPGo(NZN^tu2B5+)|?16@dIT$;`G`pi#{>s zm(Me}^^2DxZxh^!za`6p?E<;8P<_13a(&No;4&(ak(UN$?It94KH6%lw~gmEx*zr( z>hkwMytFrj?MJd5;#^T3$Ej@;Up7+CCZ$-H(J4-gtG=*C+3yW{WSvlBUOYO-+jHwZ z$rXgrbC;nMG@&aXjC0xNd35X$Mp8VyIKbqhYQ&A0`@eP70*UGgIu)?@Yxd|#r~-*e z5qD#w^57{n;tER4?-si@U2oKep~{KI?^fkvL|{H@@%t%Nlo1g2>5SDa`v8FKsZ5jU zJ4?K3jB>>f`^$;s>7zzoQI88$#Ej%Q z*y%jwIf5>1w)1PuBd+)dCE1sPQvrQh96dF#sP_@WF4d(qz#Lp*u-0?lxeh-pgEU^E z{x^vuJ^gz4@O}RvBQgvlN8TlyTHD2(>umc2zR$bIY^-?(CJ`mA-oeoOT6^DAcul$o zs^twEZ?(?YP?-InocV4G}O(T?uUz!_Jea?b`nkF{ESRGd$S7b$VH-k{sz3x zWXzwTUTpofrYaJ=p7)Os&p`6?NZbSV)1pi5oZ8=+I((WoO>hqfWNbvD-Vjo1ysSEx>FL|`; zr%noCIN^BPoqmrS*2ztHI8FhLN1;`DJg&XIZu8#xNl?RxCAU`yR(qTJDWh!?P-y-f z*(!r;o>rxuxuxTuSSy%R*lk)nq_p7P^rPo~BM$2ahO>La%x!o>QS|a4iGVZB9H0bR zQLU!m*rc6Wwqgg)`Q7>AGo}WtkQR(&(4cw{WVprh=^}5U`CM#au8#=iyTnpjtP-%lo~%;)JFu?G=lujW=yh*X5lN6+dw69R$%_xd+1^ z``HhZl_O5@ko^tCZ}E5@*AEWKZJ8`GO%L#UZf^95ziEpc6c&*Vb0J|fM=JI-0B zVLYsVBcb0%t6>pRkdeIj^^o1M`5cqB02cv^GX8D7fzL-hdI=YUXUt#rYEX&)? zq5zQ`$~0w!EiZyr?8cOzj2Z)6p9Ep!ux2nZy@nk&1v8pXyf^`YYKnVM=x~z6Qr5^G zAjhLkH|>vZxI{d}OFft;6BSt}%1WZig^ocK=%b=X)2655Uq%uiDczLz+S+C80zC_}pO=Vw|$pPvIX_=kO{e1_K1i<}#(wp|fOsvEqLy++Mbe4i{U&?1$LwVf^T# z<{yes{!LazOD=&ccE?qMMkI_H1Veb4K)X>B?=v0LSZU1CKXhKzKkq<0(B#Qn+72}J zXs3>!WGrsu5xa6Kl<+t8c>^ ztINzeM-USIN-Q}vuN?Y7f78&?2iLlhdytK!apF;$RFRMgk~KJiF54<-co7K&EXwB% zfMqI~>nbn1VCTGoEM^$#-2zx{y^IBR01gR8GU;arl$&Y{JDE(i+!$06cT0@r0TSu}np+x8Y8cLFh=sIKe~oX9=Y@x{S& zFqFCEti9QkWogI_F!8I;id#2pUPN*8U~hUDAf`YlP%2FC{7LD2>*lt1&@O7L>m@La z+<(udxdNI)SRP^2e0zhf^Y_a-E)OLQ=OUH2-SOxlWN-nJ&NYx-#BMx|25#xCx*Lb4Uy; za-wyc+$dF|;$jo(_}^V|xk&#Sol_w;^Z_L)66z09`kdU(Kk}&&F-1MN+O;h-P2e?x z^F&lj#WEUc-|Nrb|KZC!0|r+~=_&cMLX7l#_@80gF@FQH!8{WwHOK6;w5LePtR7ViM@3iz40mXxIk0K3v@IuNT1c%C#CrnQV;MbgPF$tx`(RqW%^ zWx#k&;2#kUA)gNbAA!ae6Pu($> zd6Ro0*$nVDEoS<-tely)szpo%n#K+HHUp|#7Py6WM|9|obSZRn`MdHmW?qF}c}KlC zk0Rnpzei0CjnSz8K?TBd6OcThV9D;OYmcc$&c)4+-xB$#cXbwdcl+tf7%}hgMu-l$30P zZCUbPD%!J&S@)!KKYqjLXZuq^23Pjo8j@s(z^Hi^Bgq=(Uh6#%7|NsdAg+6Gc8J7iE`xAH zWH1FS2W;woWcsOEMcL1>7yR*n&jW^z2yfCzbgTY)T-{oY=IN8Wu~q+?oA70W$b^8v zVORJBb5^9Hm9zrysdOIg%#JAcoFSLheaSLb3dNJsXMzm;zLPQeO zl$i8>H|%0d>2DPnfsAn+VeCR#;}8Y6#^1g{B>v;S0HBFR#MpOW_bPlV z)}!UR(yej`^2wm6GAFy~q`qJL{#k)=O0i;_D0E`|(JFzoD;8AA5Py@3eahSc=>Xp3 zEd?!FHWj+B4{lUEs%ImDYNfIlr-*3LzP3K5%sSN-hCx9A2c>z9XY6IQ#qzzINV=(H zE5W?V0c8>o5M7mQjZ5$s*INi?#PDJ{MzSuuUuSD$HCA|_qG5Uxf?AEv`Eh1(;@m4v zPJj-4RRIh-B+wTLYAVm(nt zl!?^kgnDdT32)zOl96%hK4O!aQ|53WXl-+%z1$%WZJa0%2Mwash2CN%<9s?sHR8bN zDOnDysB{yf0lWYJVSr?)IHS3#9iDih=C|}mx)*T}r9s{{x`FJJ7%A$;qhLdKRGD8q z5j+Ik*h_4p@*$Jv1x^??MX=y1|MEBTix8c`M?SFtCuFo}$EOQebW|jetCn0!@FLUQ z)|BGzK2*XW(!+Zr7{eX=xoy5>&qYuEm?zTXZ9ADPB-hjq6pH-1H1`@j#2o%9G#`Y- zLWN(iGjj$Fm3et-ta*y~@`sB&kHq||H7U6uB)Y18K#7wcR-i1l1zj5t+$l_sMM;T^ zD`EOr2`iN)>f+h^OAy#e6t5iKnCQ(@E~fw-1inBq&J+_&f5gRK0ZSSiOinZzAb#B} zaLe-eWpPg*esjaW?6IpY_)2OYq_-~{f&ii{-&SHZ8hvvR`E*+nD7qx5u1S%JBa$@% z;0A?^J4X@^PThLrd2_E1p)76mcI~?M%n}O^9O1Vwq_0WruAzmr%6e&h&<)W%L=nU8&0*uL+rV-reYyRVGxQ~+Snb*7e=8t@kwM2G z2~r+B$@6^p+5H2DHsTLN>n0W+6-#lTw|f~kGL|tkzr{kDq`1l-1T-2YM9$~&?#oL? zIC2?pOwckd1cwHY0Wy@L%$KjUz~4O%x7R!}2C44UAH2%0w_y=MBSbmB|4j><-OK1t zFHKHty?G(Sq}`Bmeg8R;oGF~grf*pPAHU_hY#5XPHv^GLbo%%`_PjkdAVGiP4|d-U zKi|b#fPh0Hg68ArFEV!odQ$|(8(xw@PX1nSmJsrqu$G|w$8VMM6sY;HOWyM`ZTI6< zo>~)N3>u#-!y|rt{fEc(pe z66ySvBtHJ;*?Fq_@|JF_dwPC{L!AVn!;l5YefjhC^#?OeFmF$Z)#~_2iya)_Uw^J) zPLUOTeEscU*^e8)H$Yjf&*UUoRoahrAm_ivUo{|#vYV4QdN_&Kzx+nDVD{y>em&6U z436j1zixexFWlbqjtB$r&)cgVZ|d3!mi;Zja(wi1z6EmrV9Dw9uN=O}(u2%ueNblx zXaDxghV?){<*LVF{ppvh0U%NGX|MC3VG&%nr`v($DY!WU5jlCdAvq&(C(O7Rya^U! zzsL0yi?Rv9Xm-81KkW}T6*52%+rxf+e`~OioZE(+k2ff16W^)wPxt%%Hu7gL+x`C7 bG4}iaW7%${$3{`O00000NkvXXu0mjfij}=7 literal 0 HcmV?d00001 diff --git a/docs/src/index.md b/docs/src/index.md index de6bc24..4d6260b 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -1,6 +1,77 @@ # SSMProblems -### API +### Installation +In the `julia` REPL: +```julia +]add SSMProblems +``` + +### Documentation + +The package defines a generic interface to work with State Space Problems (SSM). The main objective is to provide a consistent +interface for + +![state space model](docs/images/state_space_model.png) +Source[^Murray] +[^Murray]: + > Murray, Lawrence & Lee, Anthony & Jacob, Pierre. (2013). Rethinking resampling in the particle filter on graphics processing +units. + +The model is fully specified by the following densities: +- *Initialisation*: ``f_0(x)`` +- *Transition*: ``f(x)`` +- *Emission*: ``g(x)`` + +And the dynamics of the model reduce to: +```math +x_t | x_{t-1} \sim f(x_t | x_{t-1}) +y_t | x_t \sim g(y_t | x_{t}) +``` +assuming ``x_0 \sim f_0(x)``. The joint law is then fully describes: + +```math +p(x_{0:T}, y_{0:T}) = f_0{x_0} \prod_t g(y_t | x_t) f(x_t | x_{t-1}) +``` + +Model users can define their `SSM` using the following interface: +```julia + +struct Model <: AbstractParticle end + +function transition!!(rng, step, model::Model) + if step == 1 + ... # Sample from the initial density + end + ... # Sample from the transition density +end + +function emission_logdensity(step, model::Model) + ... # Return log density of the model at +end + +isdone(step, model::Model) = ... # Define the stopping criterion + +# Optionally, if the transition density is known, the model can also specify it +function transition_logdensity(step, particle, x) + ... # Scores the forward transition at `x` +end +``` + +Package users can then consume the model `logdensity` through calls to `emission_logdensity`. + +For example, a bootstrap filter targeting the filtering distribution ``p(x_t | y_{0:t})`` using `N` particles would read: +```julia +while !isdone(t, model) + ancestors = resample(rng, logweigths) + particles = particles[ancestors] + for i in 1:N + particles[i] = transition!!(rng, t, particles[i]) + logweights[i] += emission_logdensity(t, particles[i]) + end +end +``` + +### Interface ```@autodocs Modules = [SSMProblems] Order = [:type, :function] From 3648f4cf87d52028bb227f3b186bdb9bf6efd05b Mon Sep 17 00:00:00 2001 From: FredericWantiez Date: Thu, 6 Jul 2023 21:55:54 +0100 Subject: [PATCH 02/18] Fix source --- docs/src/index.md | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/docs/src/index.md b/docs/src/index.md index 4d6260b..9016d09 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -9,10 +9,11 @@ In the `julia` REPL: ### Documentation The package defines a generic interface to work with State Space Problems (SSM). The main objective is to provide a consistent -interface for +interface to work with SSMs and their logdensities. +Consider a markovian model from[^Murray]: ![state space model](docs/images/state_space_model.png) -Source[^Murray] + [^Murray]: > Murray, Lawrence & Lee, Anthony & Jacob, Pierre. (2013). Rethinking resampling in the particle filter on graphics processing units. From f7033db6f75b6e2f9aee7b4397129a145e5b8ddc Mon Sep 17 00:00:00 2001 From: FredericWantiez Date: Thu, 6 Jul 2023 22:14:18 +0100 Subject: [PATCH 03/18] Typo --- docs/src/index.md | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/docs/src/index.md b/docs/src/index.md index 9016d09..551c1d7 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -19,16 +19,18 @@ Consider a markovian model from[^Murray]: units. The model is fully specified by the following densities: -- *Initialisation*: ``f_0(x)`` -- *Transition*: ``f(x)`` -- *Emission*: ``g(x)`` +- __Initialisation__: ``f_0(x)`` +- __Transition__: ``f(x)`` +- __Emission__: ``g(x)`` -And the dynamics of the model reduce to: +And the dynamics of the model reduces to: ```math x_t | x_{t-1} \sim f(x_t | x_{t-1}) +``` +```math y_t | x_t \sim g(y_t | x_{t}) ``` -assuming ``x_0 \sim f_0(x)``. The joint law is then fully describes: +assuming ``x_0 \sim f_0(x)``. The joint law is then fully described: ```math p(x_{0:T}, y_{0:T}) = f_0{x_0} \prod_t g(y_t | x_t) f(x_t | x_{t-1}) From 673e670057925b19a2505cd5ff18124242f2ef58 Mon Sep 17 00:00:00 2001 From: FredericWantiez Date: Wed, 12 Jul 2023 22:27:19 +0100 Subject: [PATCH 04/18] Update SSM Interface --- examples/smc.jl | 53 +++++++++++++++++++++++++++------------------- src/SSMProblems.jl | 19 +++++++++++++++-- 2 files changed, 48 insertions(+), 24 deletions(-) diff --git a/examples/smc.jl b/examples/smc.jl index 9945e16..5455207 100644 --- a/examples/smc.jl +++ b/examples/smc.jl @@ -7,12 +7,12 @@ using StatsFuns # Particle Filter implementation struct Particle{T} # Here we just need a tree parent::Union{Particle,Nothing} - state::T + model::T end -Particle(state::T) where {T} = Particle(nothing, state) +Particle(model::T) where {T} = Particle(nothing, model) Particle() = Particle(nothing, nothing) -Base.show(io::IO, p::Particle) = print(io, "Particle($(p.state))") +Base.show(io::IO, p::Particle) = print(io, "Particle($(p.model))") """ linearize(particle) @@ -23,7 +23,7 @@ function linearize(particle::Particle{T}) where {T} trace = T[] parent = particle.parent while !isnothing(parent) - push!(trace, parent.state) + push!(trace, parent.model) parent = parent.parent end return trace @@ -44,7 +44,7 @@ function sweep!(rng::AbstractRNG, particles::ParticleContainer, resampling, thre t = 1 N = length(particles) logweights = zeros(length(particles)) - while !isdone(t, particles[1].state) + while !isdone(t, particles[1].model) # Resample step weights = get_weights(logweights) @@ -57,9 +57,9 @@ function sweep!(rng::AbstractRNG, particles::ParticleContainer, resampling, thre # Mutation step for i in eachindex(particles) parent = particles[i] - mutated = transition!!(rng, t, parent.state) + mutated = transition!!(rng, t, parent.model) particles[i] = Particle(parent, mutated) - logweights[i] += emission_logdensity(t, particles[i].state) + logweights[i] += emission_logdensity(t, particles[i].model) end t += 1 @@ -81,6 +81,14 @@ Base.@kwdef struct Parameters u::Float64 = 0.7 # Observation noise stdev end +struct LinearSSM{T} <: AbstractStateSpaceModel + state::T +end + +LinearSSM() = LinearSSM(Nothing) + +dimension(::LinearSSM) = 1 + # Simulation T = 250 seed = 1 @@ -88,31 +96,32 @@ N = 1000 rng = MersenneTwister(seed) params = Parameters(; v=0.2, u=0.7) -function transition!!(rng::AbstractRNG, t::Int, state=nothing) - if isnothing(state) - return rand(rng, Normal(0, 1)) - end - return rand(rng, Normal(state, params.v)) +function transition!!(rng::AbstractRNG, t::Int, model::LinearSSM) + return LinearSSM(rand(rng, Normal(model.state, params.v))) +end + +function transition!!(rng::AbstractRNG, t::Int) + # Initial transition + return LinearSSM(rand(rng, Normal(0, 1))) end -function emission_logdensity(t, state) - return logpdf(Normal(state, params.u), observations[t]) +function emission_logdensity(t, model::LinearSSM) + return logpdf(Normal(model.state, params.u), observations[t]) end -isdone(t, state) = t > T -isdone(t, ::Nothing) = false +isdone(t, state::LinearSSM) = t > T -x, observations = zeros(T), zeros(T) -x[1] = rand(rng, Normal(0, 1)) +x, observations = Vector{LinearSSM}(undef, T), zeros(T) +x[1] = transition!!(rng, 1) for t in 1:T - observations[t] = rand(rng, Normal(x[t], params.u)) + observations[t] = rand(rng, Normal(x[t].state, params.u)) if t < T - x[t + 1] = rand(rng, Normal(x[t], params.v)) + x[t + 1] = transition!!(rng, t, x[t]) end end samples = sweep!(rng, fill(Particle(x[1]), N), systematic_resampling) -traces = reverse(hcat(map(linearize, samples)...)) +traces = map(model-> model.state, reverse(hcat(map(linearize, samples)...))) scatter(traces; color=:black, opacity=0.3, label=false) -plot!(x; label="True state") +plot!(map(model -> model.state, x); label="True state") diff --git a/src/SSMProblems.jl b/src/SSMProblems.jl index fd91a4b..1d29374 100644 --- a/src/SSMProblems.jl +++ b/src/SSMProblems.jl @@ -5,7 +5,8 @@ module SSMProblems """ """ -abstract type AbstractParticle end +abstract type AbstractStateSpaceModel end +abstract type AbstractParticle end # Could be a message as well ? abstract type AbstractParticleCache end """ @@ -36,6 +37,20 @@ Determine whether we have reached the last time step of the Markov process. Retu """ function isdone end -export transition!!, transition_logdensity, emission_logdensity, isdone, AbstractParticle +""" + dimension(::Type{AbstractStateSpaceModel}) + +Returns the dimension of the state space for a given model type +""" +dimension(::Type{<:AbstractStateSpaceModel}) = Nothing +dimension(model::AbstractStateSpaceModel) = dimension(typeof(model)) + +export transition!!, + transition_logdensity, + emission_logdensity, + isdone, + AbstractParticle, + AbstractStateSpaceModel, + dimension end From 8732544c837d161f3c2586d2621e9e3b22add918 Mon Sep 17 00:00:00 2001 From: FredericWantiez Date: Wed, 19 Jul 2023 22:41:19 +0100 Subject: [PATCH 05/18] Fix linearize bug --- examples/smc.jl | 72 ++++++++++++++++++++++++++++++------------------- 1 file changed, 44 insertions(+), 28 deletions(-) diff --git a/examples/smc.jl b/examples/smc.jl index 5455207..681be60 100644 --- a/examples/smc.jl +++ b/examples/smc.jl @@ -5,7 +5,7 @@ using Plots using StatsFuns # Particle Filter implementation -struct Particle{T} # Here we just need a tree +struct Particle{T} <: AbstractParticle parent::Union{Particle,Nothing} model::T end @@ -21,16 +21,22 @@ Return the trace of a particle, i.e. the sequence of states from the root to the """ function linearize(particle::Particle{T}) where {T} trace = T[] - parent = particle.parent + parent = particle while !isnothing(parent) push!(trace, parent.model) parent = parent.parent end - return trace + return trace[1:(end - 1)] end ParticleContainer = AbstractVector{<:Particle} +# Specialize `isdone` to the concrete `Particle` type +function isdone(t, particles::AbstractVector{<:Particle}) + return all(map(particle -> isdone(t, particle), particles)) +end +isdone(t, particle::AbstractParticle) = isdone(t, particle.model) + ess(weights) = inv(sum(abs2, weights)) get_weights(logweights::T) where {T<:AbstractVector{<:Real}} = StatsFuns.softmax(logweights) @@ -44,7 +50,7 @@ function sweep!(rng::AbstractRNG, particles::ParticleContainer, resampling, thre t = 1 N = length(particles) logweights = zeros(length(particles)) - while !isdone(t, particles[1].model) + while !isdone(t, particles) # Resample step weights = get_weights(logweights) @@ -70,9 +76,16 @@ function sweep!(rng::AbstractRNG, particles::ParticleContainer, resampling, thre return particles[idx] end -function sweep!(rng::AbstractRNG, n::Int, resampling, threshold=0.5) - particles = [Particle(0.0) for _ in 1:n] - return sweep!(rng, particles, resampling, threshold) +function sample( + rng::AbstractRNG, + n::Int, + model::AbstractStateSpaceModel; + resampling=systematic_resampling, + threshold=0.5, +) + particles = fill(Particle(model), N) + samples = sweep!(rng, particles, resampling, threshold) + return samples end # Inference code @@ -85,43 +98,46 @@ struct LinearSSM{T} <: AbstractStateSpaceModel state::T end -LinearSSM() = LinearSSM(Nothing) - +LinearSSM() = LinearSSM(zero(Float64)) dimension(::LinearSSM) = 1 # Simulation T = 250 seed = 1 -N = 1000 +N = 1_000 rng = MersenneTwister(seed) params = Parameters(; v=0.2, u=0.7) -function transition!!(rng::AbstractRNG, t::Int, model::LinearSSM) - return LinearSSM(rand(rng, Normal(model.state, params.v))) +f0(t) = Normal(0, 1) +f(t, x) = Normal(x, params.v) +g(t, x) = Normal(x, params.u) + +# Generate synthtetic data +x, observations = zeros(T), zeros(T) +x[1] = rand(rng, f0(1)) +for t in 1:T + observations[t] = rand(rng, g(t, x[t])) + if t < T + x[t + 1] = rand(rng, f(t, x[t])) + end end -function transition!!(rng::AbstractRNG, t::Int) - # Initial transition - return LinearSSM(rand(rng, Normal(0, 1))) +function transition!!(rng::AbstractRNG, t::Int, model::LinearSSM) + if t == 1 + return LinearSSM(rand(rng, f0(t))) + else + return LinearSSM(rand(rng, f(t, model.state))) + end end function emission_logdensity(t, model::LinearSSM) - return logpdf(Normal(model.state, params.u), observations[t]) + return logpdf(g(t, model.state), observations[t]) end isdone(t, state::LinearSSM) = t > T -x, observations = Vector{LinearSSM}(undef, T), zeros(T) -x[1] = transition!!(rng, 1) -for t in 1:T - observations[t] = rand(rng, Normal(x[t].state, params.u)) - if t < T - x[t + 1] = transition!!(rng, t, x[t]) - end -end - -samples = sweep!(rng, fill(Particle(x[1]), N), systematic_resampling) -traces = map(model-> model.state, reverse(hcat(map(linearize, samples)...))) +samples = sample(rng, N, LinearSSM()) +traces = map(model -> model.state, reverse(hcat(map(linearize, samples)...))) scatter(traces; color=:black, opacity=0.3, label=false) -plot!(map(model -> model.state, x); label="True state") +plot!(x; label="True state") From 639470495a3478fcfe5c68a22080d6452a72a472 Mon Sep 17 00:00:00 2001 From: FredericWantiez Date: Thu, 20 Jul 2023 20:41:49 +0100 Subject: [PATCH 06/18] Interface --- src/SSMProblems.jl | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/SSMProblems.jl b/src/SSMProblems.jl index 1d29374..7ff69ed 100644 --- a/src/SSMProblems.jl +++ b/src/SSMProblems.jl @@ -6,32 +6,32 @@ module SSMProblems """ """ abstract type AbstractStateSpaceModel end -abstract type AbstractParticle end # Could be a message as well ? +abstract type AbstractParticle end abstract type AbstractParticleCache end """ - transition!!(rng, step, particle[, cache]) + transition!!(rng, step, model, particle[, cache]) Simulate the particle for the next time step from the forward dynamics. """ function transition!! end """ - transition_logdensity(step, particle, x[, cache]) + transition_logdensity(step, model, particle, x[, cache]) (Optional) Computes the log-density of the forward transition if the density is available. """ function transition_logdensity end """ - emission_logdensity(step, particle[, cache]) + emission_logdensity(step, model, particle[, cache]) Compute the log potential of current particle. This effectively "reweight" each particle. """ function emission_logdensity end """ - isdone(step, particle[, cache]) + isdone(step, model, particle[, cache]) Determine whether we have reached the last time step of the Markov process. Return `true` if yes, otherwise return `false`. """ @@ -42,7 +42,7 @@ function isdone end Returns the dimension of the state space for a given model type """ -dimension(::Type{<:AbstractStateSpaceModel}) = Nothing +dimension(::Type{<:AbstractStateSpaceModel}) = nothing dimension(model::AbstractStateSpaceModel) = dimension(typeof(model)) export transition!!, From f7bdf6ae40d8d3f288ed40c7127d5b77875fbafd Mon Sep 17 00:00:00 2001 From: FredericWantiez Date: Thu, 20 Jul 2023 21:05:43 +0100 Subject: [PATCH 07/18] Format --- src/SSMProblems.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/SSMProblems.jl b/src/SSMProblems.jl index 7ff69ed..fb3189b 100644 --- a/src/SSMProblems.jl +++ b/src/SSMProblems.jl @@ -6,7 +6,7 @@ module SSMProblems """ """ abstract type AbstractStateSpaceModel end -abstract type AbstractParticle end +abstract type AbstractParticle end abstract type AbstractParticleCache end """ From 573cab3681481fd0e36666aa184f46e2edeb27e1 Mon Sep 17 00:00:00 2001 From: FredericWantiez Date: Mon, 24 Jul 2023 21:53:14 +0100 Subject: [PATCH 08/18] Trying things --- examples/smc.jl | 25 +++++++++++++++---------- src/SSMProblems.jl | 4 ++-- 2 files changed, 17 insertions(+), 12 deletions(-) diff --git a/examples/smc.jl b/examples/smc.jl index 681be60..9a10675 100644 --- a/examples/smc.jl +++ b/examples/smc.jl @@ -5,7 +5,7 @@ using Plots using StatsFuns # Particle Filter implementation -struct Particle{T} <: AbstractParticle +mutable struct Particle{T<:AbstractStateSpaceModel} <: AbstractParticle{T} parent::Union{Particle,Nothing} model::T end @@ -14,6 +14,11 @@ Particle(model::T) where {T} = Particle(nothing, model) Particle() = Particle(nothing, nothing) Base.show(io::IO, p::Particle) = print(io, "Particle($(p.model))") +function set_parent!(particle::Particle, parent::Particle) + setproperty!(particle, :parent, parent) + return particle +end + """ linearize(particle) @@ -29,13 +34,12 @@ function linearize(particle::Particle{T}) where {T} return trace[1:(end - 1)] end -ParticleContainer = AbstractVector{<:Particle} +const ParticleContainer = AbstractVector{<:Particle} # Specialize `isdone` to the concrete `Particle` type -function isdone(t, particles::AbstractVector{<:Particle}) +function isdone(t, particles::ParticleContainer) return all(map(particle -> isdone(t, particle), particles)) end -isdone(t, particle::AbstractParticle) = isdone(t, particle.model) ess(weights) = inv(sum(abs2, weights)) get_weights(logweights::T) where {T<:AbstractVector{<:Real}} = StatsFuns.softmax(logweights) @@ -63,8 +67,9 @@ function sweep!(rng::AbstractRNG, particles::ParticleContainer, resampling, thre # Mutation step for i in eachindex(particles) parent = particles[i] - mutated = transition!!(rng, t, parent.model) - particles[i] = Particle(parent, mutated) + mutated = transition!!(rng, t, parent) + mutated = set_parent!(mutated, parent) + particles[i] = mutated logweights[i] += emission_logdensity(t, particles[i].model) end @@ -122,11 +127,11 @@ for t in 1:T end end -function transition!!(rng::AbstractRNG, t::Int, model::LinearSSM) +function transition!!(rng::AbstractRNG, t::Int, particle::Particle{<:LinearSSM}) if t == 1 - return LinearSSM(rand(rng, f0(t))) + return Particle(LinearSSM(rand(rng, f0(t)))) else - return LinearSSM(rand(rng, f(t, model.state))) + return Particle(LinearSSM(rand(rng, f(t, particle.model.state)))) end end @@ -134,7 +139,7 @@ function emission_logdensity(t, model::LinearSSM) return logpdf(g(t, model.state), observations[t]) end -isdone(t, state::LinearSSM) = t > T +isdone(t, ::Particle{LinearSSM{F}}) where {F} = t > T samples = sample(rng, N, LinearSSM()) traces = map(model -> model.state, reverse(hcat(map(linearize, samples)...))) diff --git a/src/SSMProblems.jl b/src/SSMProblems.jl index fb3189b..ee1a0b4 100644 --- a/src/SSMProblems.jl +++ b/src/SSMProblems.jl @@ -6,11 +6,11 @@ module SSMProblems """ """ abstract type AbstractStateSpaceModel end -abstract type AbstractParticle end +abstract type AbstractParticle{T<:AbstractStateSpaceModel} end abstract type AbstractParticleCache end """ - transition!!(rng, step, model, particle[, cache]) + transition!!(rng, step, particle[, cache]) Simulate the particle for the next time step from the forward dynamics. """ From 71c74cd219fb41775184427d417367a1a0be45fb Mon Sep 17 00:00:00 2001 From: FredericWantiez Date: Sat, 29 Jul 2023 18:50:57 +0100 Subject: [PATCH 09/18] Fix transition!! --- examples/smc.jl | 95 +++++++++++++++++++++++++--------------------- src/SSMProblems.jl | 10 ++--- 2 files changed, 55 insertions(+), 50 deletions(-) diff --git a/examples/smc.jl b/examples/smc.jl index 9a10675..47e4351 100644 --- a/examples/smc.jl +++ b/examples/smc.jl @@ -5,40 +5,42 @@ using Plots using StatsFuns # Particle Filter implementation -mutable struct Particle{T<:AbstractStateSpaceModel} <: AbstractParticle{T} +struct Particle{T<:AbstractStateSpaceModel,V} <: AbstractParticle{T} parent::Union{Particle,Nothing} - model::T + state::V end -Particle(model::T) where {T} = Particle(nothing, model) -Particle() = Particle(nothing, nothing) -Base.show(io::IO, p::Particle) = print(io, "Particle($(p.model))") - -function set_parent!(particle::Particle, parent::Particle) - setproperty!(particle, :parent, parent) - return particle +function Particle(parent, state, model::T) where {T<:AbstractStateSpaceModel} + return Particle{T, particleof(model)}(parent, state) +end +function Particle(model::T) where {T<:AbstractStateSpaceModel} + N = dimension(model) + V = particleof(model) + state = N == 1 ? zero(V) : zeros(V, N) + return Particle{T, V}(nothing, state) end +Base.show(io::IO, p::Particle) = print(io, "Particle($(p.state))") """ linearize(particle) Return the trace of a particle, i.e. the sequence of states from the root to the particle. """ -function linearize(particle::Particle{T}) where {T} - trace = T[] - parent = particle - while !isnothing(parent) - push!(trace, parent.model) - parent = parent.parent +function linearize(particle::Particle{T, V}) where {T, V} + trace = V[] + current = particle + while !isnothing(current) + push!(trace, current.state) + current = current.parent end return trace[1:(end - 1)] end -const ParticleContainer = AbstractVector{<:Particle} +const ParticleContainer{T} = AbstractVector{<:Particle{T}} # Specialize `isdone` to the concrete `Particle` type -function isdone(t, particles::ParticleContainer) - return all(map(particle -> isdone(t, particle), particles)) +function isdone(t, model::AbstractStateSpaceModel, particles::ParticleContainer) + return all(map(particle -> isdone(t, model, particle), particles)) end ess(weights) = inv(sum(abs2, weights)) @@ -50,11 +52,11 @@ function systematic_resampling( return rand(rng, Distributions.Categorical(weights), n) end -function sweep!(rng::AbstractRNG, particles::ParticleContainer, resampling, threshold=0.5) +function sweep!(rng::AbstractRNG, model::AbstractStateSpaceModel, particles::ParticleContainer, resampling, threshold=0.5) t = 1 N = length(particles) logweights = zeros(length(particles)) - while !isdone(t, particles) + while !isdone(t, model, particles) # Resample step weights = get_weights(logweights) @@ -67,10 +69,9 @@ function sweep!(rng::AbstractRNG, particles::ParticleContainer, resampling, thre # Mutation step for i in eachindex(particles) parent = particles[i] - mutated = transition!!(rng, t, parent) - mutated = set_parent!(mutated, parent) + mutated = transition!!(rng, t, model, parent) particles[i] = mutated - logweights[i] += emission_logdensity(t, particles[i].model) + logweights[i] += emission_logdensity(t, model, particles[i]) end t += 1 @@ -89,33 +90,27 @@ function sample( threshold=0.5, ) particles = fill(Particle(model), N) - samples = sweep!(rng, particles, resampling, threshold) + samples = sweep!(rng, model, particles, resampling, threshold) return samples end # Inference code -Base.@kwdef struct Parameters +Base.@kwdef struct LinearSSM <: AbstractStateSpaceModel v::Float64 = 0.2 # Transition noise stdev u::Float64 = 0.7 # Observation noise stdev end -struct LinearSSM{T} <: AbstractStateSpaceModel - state::T -end - -LinearSSM() = LinearSSM(zero(Float64)) -dimension(::LinearSSM) = 1 - # Simulation T = 250 seed = 1 N = 1_000 rng = MersenneTwister(seed) -params = Parameters(; v=0.2, u=0.7) + +model = LinearSSM(0.2, 0.7) f0(t) = Normal(0, 1) -f(t, x) = Normal(x, params.v) -g(t, x) = Normal(x, params.u) +f(t, x) = Normal(x, model.v) +g(t, x) = Normal(x, model.u) # Generate synthtetic data x, observations = zeros(T), zeros(T) @@ -127,22 +122,34 @@ for t in 1:T end end -function transition!!(rng::AbstractRNG, t::Int, particle::Particle{<:LinearSSM}) +function transition!!( + rng::AbstractRNG, t::Int, model::LinearSSM, particle::AbstractParticle +) if t == 1 - return Particle(LinearSSM(rand(rng, f0(t)))) + return Particle(particle, rand(rng, f0(t)), model) else - return Particle(LinearSSM(rand(rng, f(t, particle.model.state)))) + return Particle(particle, rand(rng, f(t, particle.state)), model) end end -function emission_logdensity(t, model::LinearSSM) - return logpdf(g(t, model.state), observations[t]) +function emission_logdensity(t, model::LinearSSM, particle::AbstractParticle) + return logpdf(g(t, particle.state), observations[t]) end -isdone(t, ::Particle{LinearSSM{F}}) where {F} = t > T +# isdone +isdone(t, ::LinearSSM, ::AbstractParticle) = t > T + +# Type of latent space +# particleof(::S) :: S -> T +# f(t, x) : Int -> T -> T +particleof(::LinearSSM) = Float64 +dimension(::LinearSSM) = 1 samples = sample(rng, N, LinearSSM()) -traces = map(model -> model.state, reverse(hcat(map(linearize, samples)...))) +traces = reverse(hcat(map(linearize, samples)...)) + +#scatter(traces; color=:black, opacity=0.3, label=false) +plot(x; label="True state") +plot!(mean(traces, dims=2); label="Posterior mean") -scatter(traces; color=:black, opacity=0.3, label=false) -plot!(x; label="True state") +gui() diff --git a/src/SSMProblems.jl b/src/SSMProblems.jl index ee1a0b4..6f6b07e 100644 --- a/src/SSMProblems.jl +++ b/src/SSMProblems.jl @@ -10,7 +10,7 @@ abstract type AbstractParticle{T<:AbstractStateSpaceModel} end abstract type AbstractParticleCache end """ - transition!!(rng, step, particle[, cache]) + transition!!(rng, step, model, particle[, cache]) Simulate the particle for the next time step from the forward dynamics. """ @@ -38,12 +38,10 @@ Determine whether we have reached the last time step of the Markov process. Retu function isdone end """ - dimension(::Type{AbstractStateSpaceModel}) - -Returns the dimension of the state space for a given model type + particleof(::Type{AbstractStateSpaceModel}) """ -dimension(::Type{<:AbstractStateSpaceModel}) = nothing -dimension(model::AbstractStateSpaceModel) = dimension(typeof(model)) +particleof(::Type{AbstractStateSpaceModel}) = Nothing +particleof(model::AbstractStateSpaceModel) = particleof(typeof(model)) export transition!!, transition_logdensity, From d603f5674f017b33564df7006824c0330e082b72 Mon Sep 17 00:00:00 2001 From: FredericWantiez Date: Sat, 29 Jul 2023 18:51:18 +0100 Subject: [PATCH 10/18] Format --- examples/smc.jl | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/examples/smc.jl b/examples/smc.jl index 47e4351..b5ae5d0 100644 --- a/examples/smc.jl +++ b/examples/smc.jl @@ -11,13 +11,13 @@ struct Particle{T<:AbstractStateSpaceModel,V} <: AbstractParticle{T} end function Particle(parent, state, model::T) where {T<:AbstractStateSpaceModel} - return Particle{T, particleof(model)}(parent, state) + return Particle{T,particleof(model)}(parent, state) end function Particle(model::T) where {T<:AbstractStateSpaceModel} N = dimension(model) V = particleof(model) state = N == 1 ? zero(V) : zeros(V, N) - return Particle{T, V}(nothing, state) + return Particle{T,V}(nothing, state) end Base.show(io::IO, p::Particle) = print(io, "Particle($(p.state))") @@ -26,7 +26,7 @@ Base.show(io::IO, p::Particle) = print(io, "Particle($(p.state))") Return the trace of a particle, i.e. the sequence of states from the root to the particle. """ -function linearize(particle::Particle{T, V}) where {T, V} +function linearize(particle::Particle{T,V}) where {T,V} trace = V[] current = particle while !isnothing(current) @@ -52,7 +52,13 @@ function systematic_resampling( return rand(rng, Distributions.Categorical(weights), n) end -function sweep!(rng::AbstractRNG, model::AbstractStateSpaceModel, particles::ParticleContainer, resampling, threshold=0.5) +function sweep!( + rng::AbstractRNG, + model::AbstractStateSpaceModel, + particles::ParticleContainer, + resampling, + threshold=0.5, +) t = 1 N = length(particles) logweights = zeros(length(particles)) @@ -150,6 +156,6 @@ traces = reverse(hcat(map(linearize, samples)...)) #scatter(traces; color=:black, opacity=0.3, label=false) plot(x; label="True state") -plot!(mean(traces, dims=2); label="Posterior mean") +plot!(mean(traces; dims=2); label="Posterior mean") gui() From a3af900f5ca64877651479e64486195faeadef9f Mon Sep 17 00:00:00 2001 From: FredericWantiez Date: Sun, 30 Jul 2023 21:22:29 +0100 Subject: [PATCH 11/18] Fix doc --- docs/src/index.md | 31 ++++++++++++++++++------------- examples/smc.jl | 37 +++++++++++++++++-------------------- src/SSMProblems.jl | 15 +++++++++++++-- 3 files changed, 48 insertions(+), 35 deletions(-) diff --git a/docs/src/index.md b/docs/src/index.md index 551c1d7..ab95e56 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -8,15 +8,14 @@ In the `julia` REPL: ### Documentation -The package defines a generic interface to work with State Space Problems (SSM). The main objective is to provide a consistent +`SSMProblems` defines a generic interface for State Space Problems (SSM). The main objective is to provide a consistent interface to work with SSMs and their logdensities. Consider a markovian model from[^Murray]: ![state space model](docs/images/state_space_model.png) [^Murray]: - > Murray, Lawrence & Lee, Anthony & Jacob, Pierre. (2013). Rethinking resampling in the particle filter on graphics processing -units. + > Murray, Lawrence & Lee, Anthony & Jacob, Pierre. (2013). Rethinking resampling in the particle filter on graphics processing units. The model is fully specified by the following densities: - __Initialisation__: ``f_0(x)`` @@ -30,7 +29,7 @@ x_t | x_{t-1} \sim f(x_t | x_{t-1}) ```math y_t | x_t \sim g(y_t | x_{t}) ``` -assuming ``x_0 \sim f_0(x)``. The joint law is then fully described: +assuming ``x_0 \sim f_0(x)``. The joint law follows: ```math p(x_{0:T}, y_{0:T}) = f_0{x_0} \prod_t g(y_t | x_t) f(x_t | x_{t-1}) @@ -39,20 +38,24 @@ p(x_{0:T}, y_{0:T}) = f_0{x_0} \prod_t g(y_t | x_t) f(x_t | x_{t-1}) Model users can define their `SSM` using the following interface: ```julia -struct Model <: AbstractParticle end +struct Model <: AbstractStateSpaceModel end -function transition!!(rng, step, model::Model) +# Define the structure of the latent space +particleof(::Model) = Float64 +dimension(::Model) = 2 + +function transition!!(rng::Random.AbstractRNG, step, model::Model, particle::AbstractParticl{<:AbstractStateSpaceModel}) if step == 1 ... # Sample from the initial density end ... # Sample from the transition density end -function emission_logdensity(step, model::Model) - ... # Return log density of the model at +function emission_logdensity(step, model::Model, particle::AbstractParticle) + ... # Return log density of the model at *time* `step` end -isdone(step, model::Model) = ... # Define the stopping criterion +isdone(step, model::Model, particle::AbstractParticle) = ... # Stops the state machine # Optionally, if the transition density is known, the model can also specify it function transition_logdensity(step, particle, x) @@ -62,14 +65,16 @@ end Package users can then consume the model `logdensity` through calls to `emission_logdensity`. -For example, a bootstrap filter targeting the filtering distribution ``p(x_t | y_{0:t})`` using `N` particles would read: +For example, a bootstrap filter targeting the filtering distribution ``p(x_t | y_{0:t})`` using `N` particles would roughly follow: ```julia -while !isdone(t, model) +struct Particle{T<:AbstractStateSpaceModel} <: AbstractParticle{T} end + +while !all(map(particle -> isdone(t, model, particles), particles)): ancestors = resample(rng, logweigths) particles = particles[ancestors] for i in 1:N - particles[i] = transition!!(rng, t, particles[i]) - logweights[i] += emission_logdensity(t, particles[i]) + particles[i] = transition!!(rng, t, model, particles[i]) + logweights[i] += emission_logdensity(t, model, particles[i]) end end ``` diff --git a/examples/smc.jl b/examples/smc.jl index b5ae5d0..940bd11 100644 --- a/examples/smc.jl +++ b/examples/smc.jl @@ -4,7 +4,7 @@ using Distributions using Plots using StatsFuns -# Particle Filter implementation +# Particle Filter struct Particle{T<:AbstractStateSpaceModel,V} <: AbstractParticle{T} parent::Union{Particle,Nothing} state::V @@ -13,13 +13,17 @@ end function Particle(parent, state, model::T) where {T<:AbstractStateSpaceModel} return Particle{T,particleof(model)}(parent, state) end + function Particle(model::T) where {T<:AbstractStateSpaceModel} N = dimension(model) V = particleof(model) state = N == 1 ? zero(V) : zeros(V, N) return Particle{T,V}(nothing, state) end -Base.show(io::IO, p::Particle) = print(io, "Particle($(p.state))") + +Base.show(io::IO, p::Particle{T,V}) where {T,V} = print(io, "Particle{$T, $V}($(p.state))") + +const ParticleContainer{T} = AbstractVector{<:Particle{T}} """ linearize(particle) @@ -33,12 +37,9 @@ function linearize(particle::Particle{T,V}) where {T,V} push!(trace, current.state) current = current.parent end - return trace[1:(end - 1)] + return trace[1:(end - 1)] # Discard the root node end -const ParticleContainer{T} = AbstractVector{<:Particle{T}} - -# Specialize `isdone` to the concrete `Particle` type function isdone(t, model::AbstractStateSpaceModel, particles::ParticleContainer) return all(map(particle -> isdone(t, model, particle), particles)) end @@ -74,9 +75,7 @@ function sweep!( # Mutation step for i in eachindex(particles) - parent = particles[i] - mutated = transition!!(rng, t, model, parent) - particles[i] = mutated + particles[i] = transition!!(rng, t, model, particles[i]) logweights[i] += emission_logdensity(t, model, particles[i]) end @@ -88,6 +87,7 @@ function sweep!( return particles[idx] end +# Turing style sample method function sample( rng::AbstractRNG, n::Int, @@ -106,6 +106,10 @@ Base.@kwdef struct LinearSSM <: AbstractStateSpaceModel u::Float64 = 0.7 # Observation noise stdev end +# Structure of the latents space +particleof(::LinearSSM) = Float64 +dimension(::LinearSSM) = 1 + # Simulation T = 250 seed = 1 @@ -128,6 +132,7 @@ for t in 1:T end end +# Model dynamics function transition!!( rng::AbstractRNG, t::Int, model::LinearSSM, particle::AbstractParticle ) @@ -142,20 +147,12 @@ function emission_logdensity(t, model::LinearSSM, particle::AbstractParticle) return logpdf(g(t, particle.state), observations[t]) end -# isdone isdone(t, ::LinearSSM, ::AbstractParticle) = t > T -# Type of latent space -# particleof(::S) :: S -> T -# f(t, x) : Int -> T -> T -particleof(::LinearSSM) = Float64 -dimension(::LinearSSM) = 1 - +# Sample latent state trajectories samples = sample(rng, N, LinearSSM()) traces = reverse(hcat(map(linearize, samples)...)) -#scatter(traces; color=:black, opacity=0.3, label=false) -plot(x; label="True state") +scatter(traces; color=:black, opacity=0.3, label=false) +plot!(x; label="True state") plot!(mean(traces; dims=2); label="Posterior mean") - -gui() diff --git a/src/SSMProblems.jl b/src/SSMProblems.jl index 6f6b07e..bcc76c7 100644 --- a/src/SSMProblems.jl +++ b/src/SSMProblems.jl @@ -17,7 +17,7 @@ Simulate the particle for the next time step from the forward dynamics. function transition!! end """ - transition_logdensity(step, model, particle, x[, cache]) + transition_logdensity(step, model, prev_particle, next_particle[, cache]) (Optional) Computes the log-density of the forward transition if the density is available. """ @@ -39,16 +39,27 @@ function isdone end """ particleof(::Type{AbstractStateSpaceModel}) + +Returns the type of the latent state. """ particleof(::Type{AbstractStateSpaceModel}) = Nothing particleof(model::AbstractStateSpaceModel) = particleof(typeof(model)) +""" + dimension(::Type{AbstractStateSpaceModel}) + +Returns the dimension of the latent state. +""" +dimension(::Type{AbstractStateSpaceModel}) = Nothing +dimension(model::AbstractStateSpaceModel) = dimension(typeof(model)) + export transition!!, transition_logdensity, emission_logdensity, isdone, AbstractParticle, AbstractStateSpaceModel, - dimension + dimension, + particleof end From d40c277b550ffa1882b23f9574c31a33e2af3c0f Mon Sep 17 00:00:00 2001 From: FredericWantiez Date: Sun, 30 Jul 2023 22:09:45 +0100 Subject: [PATCH 12/18] Helper --- docs/src/index.md | 17 ++++++++++++----- src/SSMProblems.jl | 16 +++++++++++++++- 2 files changed, 27 insertions(+), 6 deletions(-) diff --git a/docs/src/index.md b/docs/src/index.md index ab95e56..09450f8 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -12,7 +12,7 @@ In the `julia` REPL: interface to work with SSMs and their logdensities. Consider a markovian model from[^Murray]: -![state space model](docs/images/state_space_model.png) +![state space model](./docs/images/state_space_model.png) [^Murray]: > Murray, Lawrence & Lee, Anthony & Jacob, Pierre. (2013). Rethinking resampling in the particle filter on graphics processing units. @@ -29,10 +29,12 @@ x_t | x_{t-1} \sim f(x_t | x_{t-1}) ```math y_t | x_t \sim g(y_t | x_{t}) ``` -assuming ``x_0 \sim f_0(x)``. The joint law follows: +assuming ``x_0 \sim f_0(x)``. + +The joint law follows: ```math -p(x_{0:T}, y_{0:T}) = f_0{x_0} \prod_t g(y_t | x_t) f(x_t | x_{t-1}) +p(x_{0:T}, y_{0:T}) = f_0(x_0) \prod_t g(y_t | x_t) f(x_t | x_{t-1}) ``` Model users can define their `SSM` using the following interface: @@ -44,7 +46,12 @@ struct Model <: AbstractStateSpaceModel end particleof(::Model) = Float64 dimension(::Model) = 2 -function transition!!(rng::Random.AbstractRNG, step, model::Model, particle::AbstractParticl{<:AbstractStateSpaceModel}) +function transition!!( + rng::Random.AbstractRNG, + step, + model::Model, + particle::AbstractParticl{<:AbstractStateSpaceModel} +) if step == 1 ... # Sample from the initial density end @@ -58,7 +65,7 @@ end isdone(step, model::Model, particle::AbstractParticle) = ... # Stops the state machine # Optionally, if the transition density is known, the model can also specify it -function transition_logdensity(step, particle, x) +function transition_logdensity(step, prev_particle::AbstractParticle, next_particle::AbstractParticle) ... # Scores the forward transition at `x` end ``` diff --git a/src/SSMProblems.jl b/src/SSMProblems.jl index bcc76c7..84d16ec 100644 --- a/src/SSMProblems.jl +++ b/src/SSMProblems.jl @@ -4,8 +4,13 @@ A unified interface to define State Space Models interfaces in the context of Pa module SSMProblems """ + AbstractStateSpaceModel """ abstract type AbstractStateSpaceModel end + +""" + AbstractParticle{T<:AbstractStateSpaceModel} +""" abstract type AbstractParticle{T<:AbstractStateSpaceModel} end abstract type AbstractParticleCache end @@ -53,6 +58,14 @@ Returns the dimension of the latent state. dimension(::Type{AbstractStateSpaceModel}) = Nothing dimension(model::AbstractStateSpaceModel) = dimension(typeof(model)) +""" + latent_space_dimension(::Type{AbstractStateSpaceModel}) + +Returns the type of the latent space and its dimension. +""" +latent_space_dimension(T::Type{AbstractStateSpaceModel}) = particleof(T), dimension(T) +latent_space_dimension(model::AbstractStateSpaceModel) = latent_space_dimension(typeof(model)) + export transition!!, transition_logdensity, emission_logdensity, @@ -60,6 +73,7 @@ export transition!!, AbstractParticle, AbstractStateSpaceModel, dimension, - particleof + particleof, + latent_space_dimension end From 8f64bf4be70936e5352d2f83368ea6b5188fe60a Mon Sep 17 00:00:00 2001 From: FredericWantiez Date: Sun, 30 Jul 2023 22:10:06 +0100 Subject: [PATCH 13/18] Format --- src/SSMProblems.jl | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/SSMProblems.jl b/src/SSMProblems.jl index 84d16ec..4991560 100644 --- a/src/SSMProblems.jl +++ b/src/SSMProblems.jl @@ -64,7 +64,9 @@ dimension(model::AbstractStateSpaceModel) = dimension(typeof(model)) Returns the type of the latent space and its dimension. """ latent_space_dimension(T::Type{AbstractStateSpaceModel}) = particleof(T), dimension(T) -latent_space_dimension(model::AbstractStateSpaceModel) = latent_space_dimension(typeof(model)) +function latent_space_dimension(model::AbstractStateSpaceModel) + return latent_space_dimension(typeof(model)) +end export transition!!, transition_logdensity, From 593d30582d1b18aa97a357fdca0d7f0def00b3cd Mon Sep 17 00:00:00 2001 From: FredericWantiez Date: Thu, 7 Sep 2023 22:10:18 +0100 Subject: [PATCH 14/18] Forget about particles --- examples/smc.jl | 98 +++++++++++++++++++++------------------------- src/SSMProblems.jl | 56 +++----------------------- 2 files changed, 50 insertions(+), 104 deletions(-) diff --git a/examples/smc.jl b/examples/smc.jl index 940bd11..ad0df4a 100644 --- a/examples/smc.jl +++ b/examples/smc.jl @@ -5,23 +5,20 @@ using Plots using StatsFuns # Particle Filter -struct Particle{T<:AbstractStateSpaceModel,V} <: AbstractParticle{T} - parent::Union{Particle,Nothing} - state::V -end +abstract type Node{T} end -function Particle(parent, state, model::T) where {T<:AbstractStateSpaceModel} - return Particle{T,particleof(model)}(parent, state) -end +struct Root{T} <: Node{T} end +Root(T) = Root{T}() +Root() = Root(Any) -function Particle(model::T) where {T<:AbstractStateSpaceModel} - N = dimension(model) - V = particleof(model) - state = N == 1 ? zero(V) : zeros(V, N) - return Particle{T,V}(nothing, state) +struct Particle{T} <: Node{T} + parent::Node{T} + state::T end -Base.show(io::IO, p::Particle{T,V}) where {T,V} = print(io, "Particle{$T, $V}($(p.state))") +Particle(state::T) where {T} = Particle(Root(T), state) + +Base.show(io::IO, p::Particle{T}) where {T} = print(io, "Particle{$T}($(p.state))") const ParticleContainer{T} = AbstractVector{<:Particle{T}} @@ -30,18 +27,14 @@ const ParticleContainer{T} = AbstractVector{<:Particle{T}} Return the trace of a particle, i.e. the sequence of states from the root to the particle. """ -function linearize(particle::Particle{T,V}) where {T,V} - trace = V[] +function linearize(particle::Particle{T}) where {T} + trace = T[] current = particle - while !isnothing(current) + while !isa(current, Root) push!(trace, current.state) current = current.parent end - return trace[1:(end - 1)] # Discard the root node -end - -function isdone(t, model::AbstractStateSpaceModel, particles::ParticleContainer) - return all(map(particle -> isdone(t, model, particle), particles)) + return trace end ess(weights) = inv(sum(abs2, weights)) @@ -57,29 +50,30 @@ function sweep!( rng::AbstractRNG, model::AbstractStateSpaceModel, particles::ParticleContainer, - resampling, + observations::AbstractArray, + resampling=systematic_resampling, threshold=0.5, ) - t = 1 N = length(particles) - logweights = zeros(length(particles)) - while !isdone(t, model, particles) + logweights = zeros(N) + for (timestep, observation) in enumerate(observations) # Resample step weights = get_weights(logweights) if ess(weights) <= threshold * N idx = resampling(rng, weights) particles = particles[idx] - logweights = zeros(length(particles)) + fill!(logweights, 0) end # Mutation step for i in eachindex(particles) - particles[i] = transition!!(rng, t, model, particles[i]) - logweights[i] += emission_logdensity(t, model, particles[i]) + latent_state = transition!!(rng, model, timestep, particles[i].state) + particles[i] = Particle(particles[i], latent_state) + logweights[i] += emission_logdensity( + model, timestep, particles[i].state, observation + ) end - - t += 1 end # Return unweighted set @@ -90,13 +84,17 @@ end # Turing style sample method function sample( rng::AbstractRNG, + model::AbstractStateSpaceModel, n::Int, - model::AbstractStateSpaceModel; + observations::AbstractVector; resampling=systematic_resampling, threshold=0.5, ) - particles = fill(Particle(model), N) - samples = sweep!(rng, model, particles, resampling, threshold) + particles = map(1:N) do i + state = transition!!(rng, model) + Particle(state) + end + samples = sweep!(rng, model, particles, observations, resampling, threshold) return samples end @@ -106,10 +104,6 @@ Base.@kwdef struct LinearSSM <: AbstractStateSpaceModel u::Float64 = 0.7 # Observation noise stdev end -# Structure of the latents space -particleof(::LinearSSM) = Float64 -dimension(::LinearSSM) = 1 - # Simulation T = 250 seed = 1 @@ -118,13 +112,13 @@ rng = MersenneTwister(seed) model = LinearSSM(0.2, 0.7) -f0(t) = Normal(0, 1) -f(t, x) = Normal(x, model.v) -g(t, x) = Normal(x, model.u) +f0() = Normal(0, 1) +f(t::Int, x::Float64) = Normal(x, model.v) +g(t::Int, y::Float64) = Normal(y, model.u) # Generate synthtetic data x, observations = zeros(T), zeros(T) -x[1] = rand(rng, f0(1)) +x[1] = rand(rng, f0()) for t in 1:T observations[t] = rand(rng, g(t, x[t])) if t < T @@ -133,24 +127,22 @@ for t in 1:T end # Model dynamics -function transition!!( - rng::AbstractRNG, t::Int, model::LinearSSM, particle::AbstractParticle -) - if t == 1 - return Particle(particle, rand(rng, f0(t)), model) - else - return Particle(particle, rand(rng, f(t, particle.state)), model) - end +function transition!!(rng::AbstractRNG, model::LinearSSM) + return rand(rng, f0()) end -function emission_logdensity(t, model::LinearSSM, particle::AbstractParticle) - return logpdf(g(t, particle.state), observations[t]) +function transition!!(rng::AbstractRNG, model::LinearSSM, timestep::Int, state::Float64) + return rand(rng, f(timestep, state)) end -isdone(t, ::LinearSSM, ::AbstractParticle) = t > T +function emission_logdensity( + model::LinearSSM, timestep::Int, state::Float64, observation::Float64 +) + return logpdf(g(timestep, state), observation) +end # Sample latent state trajectories -samples = sample(rng, N, LinearSSM()) +samples = sample(rng, LinearSSM(), N, observations) traces = reverse(hcat(map(linearize, samples)...)) scatter(traces; color=:black, opacity=0.3, label=false) diff --git a/src/SSMProblems.jl b/src/SSMProblems.jl index 4991560..c04035c 100644 --- a/src/SSMProblems.jl +++ b/src/SSMProblems.jl @@ -7,75 +7,29 @@ module SSMProblems AbstractStateSpaceModel """ abstract type AbstractStateSpaceModel end - -""" - AbstractParticle{T<:AbstractStateSpaceModel} -""" -abstract type AbstractParticle{T<:AbstractStateSpaceModel} end abstract type AbstractParticleCache end """ - transition!!(rng, step, model, particle[, cache]) + transition!!(rng, model[, timestep, state, cache]) Simulate the particle for the next time step from the forward dynamics. """ function transition!! end """ - transition_logdensity(step, model, prev_particle, next_particle[, cache]) + transition_logdensity(model, timestep, prev_state, next_state[, cache]) (Optional) Computes the log-density of the forward transition if the density is available. """ function transition_logdensity end """ - emission_logdensity(step, model, particle[, cache]) + emission_logdensity(model, timestep, state, observation[, cache]) -Compute the log potential of current particle. This effectively "reweight" each particle. +Compute the log potential of the current particle. This effectively "reweight" each particle. """ function emission_logdensity end -""" - isdone(step, model, particle[, cache]) - -Determine whether we have reached the last time step of the Markov process. Return `true` if yes, otherwise return `false`. -""" -function isdone end - -""" - particleof(::Type{AbstractStateSpaceModel}) - -Returns the type of the latent state. -""" -particleof(::Type{AbstractStateSpaceModel}) = Nothing -particleof(model::AbstractStateSpaceModel) = particleof(typeof(model)) - -""" - dimension(::Type{AbstractStateSpaceModel}) - -Returns the dimension of the latent state. -""" -dimension(::Type{AbstractStateSpaceModel}) = Nothing -dimension(model::AbstractStateSpaceModel) = dimension(typeof(model)) - -""" - latent_space_dimension(::Type{AbstractStateSpaceModel}) - -Returns the type of the latent space and its dimension. -""" -latent_space_dimension(T::Type{AbstractStateSpaceModel}) = particleof(T), dimension(T) -function latent_space_dimension(model::AbstractStateSpaceModel) - return latent_space_dimension(typeof(model)) -end - -export transition!!, - transition_logdensity, - emission_logdensity, - isdone, - AbstractParticle, - AbstractStateSpaceModel, - dimension, - particleof, - latent_space_dimension +export AbstractStateSpaceModel end From 0a21829cdb14a60ddba5c6e8c1354c4b74fa46fb Mon Sep 17 00:00:00 2001 From: FredericWantiez Date: Wed, 13 Sep 2023 21:51:02 +0100 Subject: [PATCH 15/18] Optional timestep --- examples/smc.jl | 68 ++++++++++++---------------------------------- src/SSMProblems.jl | 7 +++-- 2 files changed, 22 insertions(+), 53 deletions(-) diff --git a/examples/smc.jl b/examples/smc.jl index ad0df4a..e6a2e24 100644 --- a/examples/smc.jl +++ b/examples/smc.jl @@ -5,38 +5,6 @@ using Plots using StatsFuns # Particle Filter -abstract type Node{T} end - -struct Root{T} <: Node{T} end -Root(T) = Root{T}() -Root() = Root(Any) - -struct Particle{T} <: Node{T} - parent::Node{T} - state::T -end - -Particle(state::T) where {T} = Particle(Root(T), state) - -Base.show(io::IO, p::Particle{T}) where {T} = print(io, "Particle{$T}($(p.state))") - -const ParticleContainer{T} = AbstractVector{<:Particle{T}} - -""" - linearize(particle) - -Return the trace of a particle, i.e. the sequence of states from the root to the particle. -""" -function linearize(particle::Particle{T}) where {T} - trace = T[] - current = particle - while !isa(current, Root) - push!(trace, current.state) - current = current.parent - end - return trace -end - ess(weights) = inv(sum(abs2, weights)) get_weights(logweights::T) where {T<:AbstractVector{<:Real}} = StatsFuns.softmax(logweights) @@ -49,7 +17,7 @@ end function sweep!( rng::AbstractRNG, model::AbstractStateSpaceModel, - particles::ParticleContainer, + particles::SSMProblems.ParticleContainer, observations::AbstractArray, resampling=systematic_resampling, threshold=0.5, @@ -68,10 +36,10 @@ function sweep!( # Mutation step for i in eachindex(particles) - latent_state = transition!!(rng, model, timestep, particles[i].state) - particles[i] = Particle(particles[i], latent_state) + latent_state = transition!!(rng, model, particles[i].state, timestep) + particles[i] = SSMProblems.Particle(particles[i], latent_state) logweights[i] += emission_logdensity( - model, timestep, particles[i].state, observation + model, particles[i].state, observation, timestep ) end end @@ -92,7 +60,7 @@ function sample( ) particles = map(1:N) do i state = transition!!(rng, model) - Particle(state) + SSMProblems.Particle(state) end samples = sweep!(rng, model, particles, observations, resampling, threshold) return samples @@ -112,38 +80,36 @@ rng = MersenneTwister(seed) model = LinearSSM(0.2, 0.7) -f0() = Normal(0, 1) -f(t::Int, x::Float64) = Normal(x, model.v) -g(t::Int, y::Float64) = Normal(y, model.u) +f0(::LinearSSM) = Normal(0, 1) +f(x::Float64, model::LinearSSM) = Normal(x, model.v) +g(y::Float64, model::LinearSSM) = Normal(y, model.u) # Generate synthtetic data x, observations = zeros(T), zeros(T) -x[1] = rand(rng, f0()) +x[1] = rand(rng, f0(model)) for t in 1:T - observations[t] = rand(rng, g(t, x[t])) + observations[t] = rand(rng, g(x[t], model)) if t < T - x[t + 1] = rand(rng, f(t, x[t])) + x[t + 1] = rand(rng, f(x[t], model)) end end # Model dynamics function transition!!(rng::AbstractRNG, model::LinearSSM) - return rand(rng, f0()) + return rand(rng, f0(model)) end -function transition!!(rng::AbstractRNG, model::LinearSSM, timestep::Int, state::Float64) - return rand(rng, f(timestep, state)) +function transition!!(rng::AbstractRNG, model::LinearSSM, state::Float64, ::Int) + return rand(rng, f(state, model)) end -function emission_logdensity( - model::LinearSSM, timestep::Int, state::Float64, observation::Float64 -) - return logpdf(g(timestep, state), observation) +function emission_logdensity(model::LinearSSM, state::Float64, observation::Float64, ::Int) + return logpdf(g(state, model), observation) end # Sample latent state trajectories samples = sample(rng, LinearSSM(), N, observations) -traces = reverse(hcat(map(linearize, samples)...)) +traces = reverse(hcat(map(SSMProblems.linearize, samples)...)) scatter(traces; color=:black, opacity=0.3, label=false) plot!(x; label="True state") diff --git a/src/SSMProblems.jl b/src/SSMProblems.jl index c04035c..5854d03 100644 --- a/src/SSMProblems.jl +++ b/src/SSMProblems.jl @@ -17,19 +17,22 @@ Simulate the particle for the next time step from the forward dynamics. function transition!! end """ - transition_logdensity(model, timestep, prev_state, next_state[, cache]) + transition_logdensity(model, prev_state, current_state[, timestep, cache]) (Optional) Computes the log-density of the forward transition if the density is available. """ function transition_logdensity end """ - emission_logdensity(model, timestep, state, observation[, cache]) + emission_logdensity(model, state, observation[, timestep, cache]) Compute the log potential of the current particle. This effectively "reweight" each particle. """ function emission_logdensity end +# Include utils and adjacent code +include("utils/particles.jl") + export AbstractStateSpaceModel end From 18ab51e08d83b9771451ff45c5f43b277711aa17 Mon Sep 17 00:00:00 2001 From: FredericWantiez Date: Wed, 13 Sep 2023 21:56:47 +0100 Subject: [PATCH 16/18] Add utils --- src/utils/particles.jl | 32 ++++++++++++++++++++++++++++++++ 1 file changed, 32 insertions(+) create mode 100644 src/utils/particles.jl diff --git a/src/utils/particles.jl b/src/utils/particles.jl new file mode 100644 index 0000000..ccb4697 --- /dev/null +++ b/src/utils/particles.jl @@ -0,0 +1,32 @@ +# Concrete Particles as LinkedLists +abstract type Node{T} end + +struct Root{T} <: Node{T} end +Root(T) = Root{T}() +Root() = Root(Any) + +struct Particle{T} <: Node{T} + parent::Node{T} + state::T +end + +Particle(state::T) where {T} = Particle(Root(T), state) + +Base.show(io::IO, p::Particle{T}) where {T} = print(io, "Particle{$T}($(p.state))") + +const ParticleContainer{T} = AbstractVector{<:Particle{T}} + +""" + linearize(particle) + +Return the trace of a particle, i.e. the sequence of states from the root to the particle. +""" +function linearize(particle::Particle{T}) where {T} + trace = T[] + current = particle + while !isa(current, Root) + push!(trace, current.state) + current = current.parent + end + return trace +end From 5bcb4f1cab802becc278846fae4202d32cca0e71 Mon Sep 17 00:00:00 2001 From: FredericWantiez Date: Wed, 13 Sep 2023 21:58:16 +0100 Subject: [PATCH 17/18] Apply suggestions from code review Co-authored-by: David Widmann --- docs/src/index.md | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/docs/src/index.md b/docs/src/index.md index 09450f8..541e875 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -24,10 +24,10 @@ The model is fully specified by the following densities: And the dynamics of the model reduces to: ```math -x_t | x_{t-1} \sim f(x_t | x_{t-1}) -``` -```math -y_t | x_t \sim g(y_t | x_{t}) +\begin{aligned} +x_t | x_{t-1} &\sim f(x_t | x_{t-1}) \\ +y_t | x_t &\sim g(y_t | x_{t}) +\end{aligned} ``` assuming ``x_0 \sim f_0(x)``. @@ -37,9 +37,8 @@ The joint law follows: p(x_{0:T}, y_{0:T}) = f_0(x_0) \prod_t g(y_t | x_t) f(x_t | x_{t-1}) ``` -Model users can define their `SSM` using the following interface: +Users can define their SSM with `SSMProblems` in the following way: ```julia - struct Model <: AbstractStateSpaceModel end # Define the structure of the latent space From 09f168703df90a968faafda93cec34dd41f996d1 Mon Sep 17 00:00:00 2001 From: FredericWantiez Date: Wed, 13 Sep 2023 22:40:55 +0100 Subject: [PATCH 18/18] Utils module --- examples/smc.jl | 8 ++++---- src/SSMProblems.jl | 2 +- src/utils/particles.jl | 13 ++++++++++++- 3 files changed, 17 insertions(+), 6 deletions(-) diff --git a/examples/smc.jl b/examples/smc.jl index e6a2e24..e27d7d6 100644 --- a/examples/smc.jl +++ b/examples/smc.jl @@ -17,7 +17,7 @@ end function sweep!( rng::AbstractRNG, model::AbstractStateSpaceModel, - particles::SSMProblems.ParticleContainer, + particles::SSMProblems.Utils.ParticleContainer, observations::AbstractArray, resampling=systematic_resampling, threshold=0.5, @@ -37,7 +37,7 @@ function sweep!( # Mutation step for i in eachindex(particles) latent_state = transition!!(rng, model, particles[i].state, timestep) - particles[i] = SSMProblems.Particle(particles[i], latent_state) + particles[i] = SSMProblems.Utils.Particle(particles[i], latent_state) logweights[i] += emission_logdensity( model, particles[i].state, observation, timestep ) @@ -60,7 +60,7 @@ function sample( ) particles = map(1:N) do i state = transition!!(rng, model) - SSMProblems.Particle(state) + SSMProblems.Utils.Particle(state) end samples = sweep!(rng, model, particles, observations, resampling, threshold) return samples @@ -109,7 +109,7 @@ end # Sample latent state trajectories samples = sample(rng, LinearSSM(), N, observations) -traces = reverse(hcat(map(SSMProblems.linearize, samples)...)) +traces = reverse(hcat(map(SSMProblems.Utils.linearize, samples)...)) scatter(traces; color=:black, opacity=0.3, label=false) plot!(x; label="True state") diff --git a/src/SSMProblems.jl b/src/SSMProblems.jl index 5854d03..b3f1cac 100644 --- a/src/SSMProblems.jl +++ b/src/SSMProblems.jl @@ -10,7 +10,7 @@ abstract type AbstractStateSpaceModel end abstract type AbstractParticleCache end """ - transition!!(rng, model[, timestep, state, cache]) + transition!!(rng, model[, state, timestep, cache]) Simulate the particle for the next time step from the forward dynamics. """ diff --git a/src/utils/particles.jl b/src/utils/particles.jl index ccb4697..15ecb5d 100644 --- a/src/utils/particles.jl +++ b/src/utils/particles.jl @@ -1,10 +1,19 @@ -# Concrete Particles as LinkedLists +""" + Common concrete implementations of Particle types for Particle Filter kernels. +""" +module Utils + abstract type Node{T} end struct Root{T} <: Node{T} end Root(T) = Root{T}() Root() = Root(Any) +""" + Particle{T} + +Particle as immutable LinkedList. +""" struct Particle{T} <: Node{T} parent::Node{T} state::T @@ -30,3 +39,5 @@ function linearize(particle::Particle{T}) where {T} end return trace end + +end