From 8017160f7f03315887ef681144a96c1a20e50058 Mon Sep 17 00:00:00 2001 From: Daniel Reiff <31421471+reiffd7@users.noreply.github.com> Date: Thu, 31 Oct 2024 12:40:08 -0400 Subject: [PATCH 1/2] Create OCR detections stitch workflow block refactored block code and created unit tests fixed some bugs in the unit tests with tolerance and vertical top to bottom Bump version Make linters happpy Adding fixes for the block discovered a bug with reading vertically. fixed it by switching initial grouping to x dimension. adjusted unit tests appropriately Make linters happpy --- docs/workflows/blocks.md | 92 +++--- docs/workflows/kinds.md | 49 +-- inference/core/version.py | 2 +- inference/core/workflows/core_steps/loader.py | 4 + .../stitch_ocr_detections/__init__.py | 0 .../stitch_ocr_detections/v1.py | 294 ++++++++++++++++++ .../models_predictions_tests/test_owlv2.py | 29 +- .../execution/assets/image_credits.txt | 1 + .../execution/assets/multi_line_text.jpg | Bin 0 -> 27768 bytes .../integration_tests/execution/conftest.py | 5 + ...st_workflow_with_keypoint_visualization.py | 3 +- ..._workflow_with_ocr_detections_stitching.py | 96 ++++++ .../test_stitch_ocr_detections.py | 195 ++++++++++++ 13 files changed, 693 insertions(+), 77 deletions(-) create mode 100644 inference/core/workflows/core_steps/transformations/stitch_ocr_detections/__init__.py create mode 100644 inference/core/workflows/core_steps/transformations/stitch_ocr_detections/v1.py create mode 100644 tests/workflows/integration_tests/execution/assets/multi_line_text.jpg create mode 100644 tests/workflows/integration_tests/execution/test_workflow_with_ocr_detections_stitching.py create mode 100644 tests/workflows/unit_tests/core_steps/transformations/test_stitch_ocr_detections.py diff --git a/docs/workflows/blocks.md b/docs/workflows/blocks.md index 06cea3d4c..56e2748da 100644 --- a/docs/workflows/blocks.md +++ b/docs/workflows/blocks.md @@ -13,24 +13,6 @@ hide:
eF@yxM?xQf82-e!~ z0q@C9?bU6C%kOrFQbv=$KVTQ3yQ~r&Pcms;d!ijbbc=(%e6++&+iFETeu;BSISDhq zahPbcU{Ym$6=hNllXo{d8NYA${RV?=e-IM@AfqrXHI@SqP{{!w02p%1+&5=C1vhxO zG;tQKag?|gf;2KTFzkz$*V{T5?2icMJN>?dlgs;y4tdv7T5VmMkT!AOM^lnTq4`DK zk)UBpj>~81ia!A7U6b?d+b5Ji0D(!#fRO$pN#E^_%!?ldRI3-B#w~%oiBjd8? k0Gta`bK%R#{auQ>xQfUeHc9=Qb&mSY%Af1;TP6MsqX=Wd&iHL(NrrJ^-T% zaNiBe`gbk@ix87TqB9b}m^qz2Ts#I7A>1?VE{)}qx5v2rTX_P`eG=p2`38HVizWd9 zEMHbPdd4YNmk`48O)g;) z4Z@06FTK$UQV!@A3z80ixN<@VbD_^~5#@Zr81(1(e2F;qJff)%ZyJz8BhbQSO>l)* z)Au*EzkT?E{(_AHH;Xa55WT>~lHkkBYqc1$)0<*dS)_^i&rZ7frdn9a@oU)h+imaG zA(?*jhfi8Q7tjasXUCI8sU_XYCZ8Pk^+xfeROohve{uHVF=NGN-S!!6j`NAi1nOW3 z<8L`U48 WBis{Ql@d1En=y6Tz+UksD9PWyyUq_^#~?-DLHQe~J%WHMb1J7j68 ze%K%RT~sy=9jnMVQKi;#K*=QP7hRoWpBtK!i-lFP;&{$bvHTY8Z Z*$zG0$I>L)^?}t qn|H-<$z A$-z7;+i9sxZc zArT23BR??%6Ell|@avaCq_1QJ<+R~jTIg_A1a!bn>>fq2Foc2EQr5RV4637Ta(+QA zaYj{hv6SzbetaI*z!a$llLs!pDCbxo)-%E{`yT%FdqBQ2aD}T@$KtClOgt;*T$agz z!Bov^8LJ!m6q9I~F(<5(W S*P7zM4?KAZ_3B9?B6N=gX%<>AhC<479e)^o zw5hUGgw~f`kbkIh{6F3A4h}>8;)7Wm=tW>Q-z_*TLFa6@>`akvOC7DJqiyZYuf4e5 z-71Kl(E!=iV4BY0TGzO47|u|CJ<-j%(^7?gfa)X+pJ=G|)UK{(uPa!s+g(sOei^2b zXcBSHG!o=LB;ZOw&bk(TOX@Cqs^Q5bgB$_<*kT@+4C8ykw!q8dFjZ4=s?%y?#+ir- zulEptXR7ZEb@t{DrnHuh3vx>12ru1@I%W3~kKp6k^-E5aCtynLS2X6YJeOZ9QShHm zT6Q#&+E(j^c0?!eWhH W6jejA{N z=167G^cOuAB+tZo>uThvB&ehTeLY*9&fU*P#c>)9XO1F-Su~Qry9)S+kGc*i2mJ(I zu(m`dp6XilMaIEsK-w%b%&aih%?74+Qcf6?3W1!Qyb_FGK?{kF7nyAMo_Y7^90+D= zPm+^QppRF&5>!hL<$F?=<&+y{U2Dv9)$^q(l=4iVtWE&709k2Dbeh*wQ@={!6g2~F zv{_!?617tsCB&z &lEr8FY6pjGX}yBfFjJ1M+}d7rdeze5*HaSvhV!wm9ayOxRz zeNTxT@5vGH$ocLG=wRv@^ht_+xF>KOM+YL2g^V0I!DPys6I38!TIVUC`2A^$B4Yzc znzm*Ykm{CdeBkvhUVQM7asowInHWSj!Bp--?pqoZ|5U4+uYSb|KlK;qJxhCZ0*`h) zTnMI-TvWLf0gF;==d~z?@nL?;P^}b!hdmzk_`#pj%EmQYFg;5C;ASh#D2wEz*nU-s zq|d6kW9ub4tyU6gb(oKlzER(nkrgPR%8@ep=dt-JDSpH;?eVe~vML1yryAHT?7VGy zksL}e2Ko4@ByuL@%v0s?NK=h=;`kwaykTt1fmf!eYDr9dPq>Ms98>G0cn+>o8LD-% zwVw!p@}Dy~rkEAZIk jMgb5fN!yQ%89jw`JS#A|911)hP)$=~(GprSE zFoX1#`5lv(hjA?eR0w~ RwUWaBH6@Xz~Siih0!0GD~5sGo{z*cv|qiyBSA8Rg)2r-j#I(=(c&Xl z6RH}sO6+sS*;Rz84!G?jWt*)5CNhVjt0q4!3twE;ZpcV+K!3mG(ymAHv7=7Kh{v*K z^DYN{qaaTo?qNAItbF747effXytcSzQqiIKi=fX}6T$9``P68yr`#E5#a{&b*p#@_ zWE|!G{Y=IFs}B88sRT~;v4QhYv#N?~PlX|CvIuphN8l!r4US*P?irS|HxDwF1&5D? z$PzcRui@3U({m3t_hCr)>#xmm`{(1FmCxFiIw>Tdo-M2PEGO4BJC;V)(T5Kv^nyb| z2C*<=NN~xxNVqCL^k4VI=JnUvt!ba{Gu^ro__l 7_Z|3M+Q$YpkKQpwP#fCfP{o zNt|i5Ff_teQ!^j7^Lx$B7r|OmCx5liNBolPdb@ Kg;A6jL}?HkXtywG^A3Kii#;q<&VN-1(Qam6#cQH> zR9uRwqqGq2S98uF1RX$Jb*s13G>=ECLRE8`%&Y}`Ccef3zgA0er&p%tn-=XjUzTvX zMUuBaIXbeha66Zq-{zomB ao85L4AFt6*J`vcgsS9+>cp7F^fA9S1Lgp8FsUwoFZ(Ce#?Fd}GXcGO%d zchqg+SWU4NG@ZSNa&^ozNvnftMaI!f1*i3lBC;Y9EJQ?z+LJ4Zj?YJQ36v|PdoS!2 z?jEf9Df?OqJNv(OeG2LN6jP_V6au4!3h&zDf97B+<7sqdiQ7Ka;Q!5a7TvD|%gC!u z`&!{Lk+$iqJad1IE5AcLlAGW1S@b+p(8 5qII*xvj3mH5k!%1S6Z^sJj;cmkSqYYJiy#n)q5{S( zn;MK~$4EfCYm7ZT+!@($`lr#GpKtOv3tyGK3Ly{-Q3my^g9bG`3L{*ftP-49Rkvdl z+<<~-74wYbVvsvKt&7Kcy!-1HK@p^o&f4Jypv9hb9rw@Jt;HGPqbnl%jq-j&r{m$& z`?R>NIosaXcGpRhZ-U!ieA=Mp+P)H=sTY-^zGl{*mEs3I;+5Vqh=WVS!9`*6DFh&l zDY$jcMk7%TQ|F2dy(^T#WCH>abql OqKBt^i5MFzsdU4RVT9#C_oLY~j9>0-n+G1$ zF_Bp3RHjk4dP*+@T2G~|i^osjmlrWJrD-0zo(?IU%ezu3A$)V8!q8z=u2im-Z!NI4 zbCR`cnFqxW@Fm`n-ShZnWreH5%dvPQkR8NrnXP(6@sYQY&hGB1(O~v8$}j*EkNs7! za-TK F_TuQ)4L;MNvAr>_FSm=h zXo> UahNRdg6L(nC*w-yKMX$_>l7cw*i$2GcJK#g>GJPc^vqr83UNUcEaE~ za_kTr#hKLLdHZVNnbXjAe$Mm0r6pVVF!1}5RpU8dwU+LV&)yTV@pgg0v@B|y;N6&k ze_3L@5B_h2|5Vnh7}9T;6TP$XjGAM{!Cd4=QpYA=5}r@bV+Z{MnC@dMw>)(Ab}TKZ zwlp>=w_r1(__R7bd*|aEHoZz|9Jk0A*%J9((Xu6)DtC1GfVh22Xr*Pr( zXAC{l apYODkIozp*)XYBBO(ZNkDYU<{^~@>^deTda3rWlQc&XDnnnaz`Z+=U>RxZ4v zGJocnz&D=CFXL3LT@|el*=zFu&AOHFpw9xIiI~)@qC$2jYI{NK{vL{a(Uzgk!*Pk$ zJMe4U!g_yyf7?oo+9V2c!LDQ(efvm8;e2CZNV0)1Um`EN8#4N6-P%&C!>-Y*HSLV# ztofG9$+?+3U&>$8t<<(KzkGyjxSW@3yE$Rf`gmu@yeEe;xm%49_^_Q=DO)U}6znq- zRkEcqS|IfXZnu8jhrBLK&s%E=S2w&uSGzzO;ywQ?0pN@R?dKh|U*^35lT)R2-DCrT z^;t65aWgP!$G_t1Kabu10SL*aErQ^Ys~q56SK>Gx6FLBo-M JkR9aen|!`7 zU|MSOWKi;Xoh;S;yTwM }Mp#b>};)AQKsxg2#3yjf0440d_TnaZcet@R_tPPde#a@lPJ z`Xv{ HWjn|D^ NCOM9G~H>bcw zlfPads_{w=Cyy~~GNJDvBx(Y0a6i57KH;r33OJ28_3e#hIhJvuQRFO(FORxshAat8 z=4{21oRafOv*&jdrGINxEs#3rXVs5sPW6wc_w=ilVNOZZHE(Vl%Q+n@q|&$~jO#en zVVsiPoK&N%;B?lwv7|xk$n>#k9nvh0?FI#7P4Ht^-yRrsPnKKLw=NpYq>_~t?>_|f z%lp PEu!^a+UUZJfI-JLdwe9XGg9B< z%vHLYLwr<(mZ)jtv#H>D -Rut&Kw{1hl$_SgD~1AY3glo#MT z(m?_)$F;m4g`}~i{_n~rISU<6e|gm8m(+pk#SBRr`hK4Gj68&l(M4kY&=sBPQv0@9 z 1^xtfVzHlNJ~c?IzLV5=W55h>BOsF&7Rs@EUb%T`t2!Wei?_(PH!CS z=iyTxm{v3@>+Ve_TRjs@?RL6^bU32CD|S5-C7`H*y#T$S^I9iHrvE6L@L9un{*9W( z9(iHfvAPzAnc80eb!S#YcmMs_qIjUHD9l;LZeT%F&0Ce{d3pEe(rGk%&-o;^mEDW1 zz@H~?J#08AcqHB(GKP;H7*L5Dy RbZVdCO)0wXITlYLO ~-oo%)l^JF;JETjzd`vD$0d~un0_adVKw#K(2d-&LLzpw1#mLe)>k5&uAz7L(L zI(}3dmy6N58Yc|RG@{HM(`)`ZXG}NKENb*vA d~C;ej17g3sz2+Pt%i zXt>{xIW3aI`PXkNT&th6TSV!fKK;rZz3Y58^#?HW<*EF)Wb4-_C>ol~hKXC)Of2Wp zNUrM4k;|yIp-z2HNC2x%)ejCM!#7*pMQs q* Q_4`Z3*Sbk=NE={loMp~FQMM24bJa5a4Fm~N z|G-Z%oTj99YZEfOuD~D30m=62DA=9ip7Z>f`a)pfo9c*}jP{3w7%o(jP)=&qg)I0m zb%TbD_w45i&nIkq52Hq!z?`kJFN4A4{#@B@<}Air7uP4C1P29ROR EzP@Zn(m?`#(i2K%Lb>% eM$r(%7E5$F(qG(R$9 zfGlaB-9|#84~IyxIB7VDr(gCE1c_5Tys8PGZc{#NW7IqZ)u1TuUe%yb9B8 ;`yj`sMUB&u)q_~LW6L7liLuujV zVQJCj(5;yiiIT>_(9zfhW2wNQVW4=(#|ivi>z1XrjzeGWgHEHr)sefMX1CUfIZf6> z3;L}HNd(yZGm>d|Is7wHNCY_ivr=dTxFoHO%(v^e`>ejZojP`N@mR&7Wk~mq6Ljx4 zuf3I=mHTErSlSTn<0h`Aey}Vhl4JCPn>&}e__FA##yiD5)IY@?@|ojs>^^xWqXM+Y z-oxw7I^+I=3!W3IrK!v?57%9+!nf84MVKJKKt2j E$uvBi!ZdQ}KdvHj>Y>UmrWKH$!$%Rp6v3pGsO* dH7?!GuKC7Fj;OGei&W3- z!v)e*2ep$*b^UU=jjHq2WRF ?FZ~;?m?|w$W5)x>G*=G&|Sh!#{vr9u3<5%;k=C zN^-Y_jMWp)b|1Z^-$of;3%|-oq+hTeb2T$B`BJsURmzn1nHajX=vrTRtF`IJ^)g=7 z7_5Usd-Y2de_)jH-TQhkHd<*F$X)QukK|%LhG7SXdcideC5$fT80In3J^b%C4ipLbbI9w%$GsE_p zB=3&4D(C8^GccBqC9U3L9oyi^%oM6xI6=iQv}?-*6;hg|tNWW{VJTc^^~O!5zoP3R z2Isnl3|HVs$ATJ$DHuaC{8$2Ft^<5lwm=iQvX#@!%!&m$Y#Ky)Yl=nFcO3`$XMW1l zZ=;EGa*U)Y?ToEGj~{aObEnCSnfp~2pbrMP#n7D>-yIF8m;w{lu`3^;WW$6U_3c88 zPt3q=4N$_|M2Uw$c~*n#1$;I5J@pTO!g~Ly<<1s9I^aDY*WPAxnx(dtutFl-JAOYW zj%JGnR=V!1CLb>u&xX>5wbjJ@j#2&eu zJS{WgN{hIcU30SM%+f?%&@wz<(s(vm9#WsiVJ5pb@Z ?hny>{@nzuF@-|G&1H+ z)8TA}wB44d6t9b2VPTadD?8#{d$DJNa<|9CJCWVI-;uOH!BhB3RqRHg>X9LM1WL~U zysr~F`OTUt3eOMpz%g!p*s+AHnkyk;q;ABuPGzG|f)zK>7~bJ`w%Nrwv6>zb55=bb zxdFYFrfHJP+T0dbxbMYn!-6z++1~D-@>B+nXG->3ok7!2d{@nhk2&1!HFn_dn_o-Q z&4MoZ5 QMU+Aoe=r^BFIfdyVkI4piX6-XTQ^#a}>G+*!hEv9GO> z!`J{++RNF0n${RwklbmkF(D)~{A&>T;6C>4TGC@xG4oO1v1mVcefo_{lt*xz(^zZX zNCtai21kftv3+sBGfP0t%{*4zS|9=ZRN%>*hlG&axoHQZZYQ&D+0A0(QsW%wEF@#y z?`l@sy+>yizgG9`pN%}t;b}8XzYYvDPdvZR`&k7kG(|?N%2hf{qd%o*E9Ck3X8AfH z&6{GwthfO+BCx_0O?xgjUv0QgqDu>HO00z4(PidGDXWL_G27%`Vbu<(i*+!YD0Nvf ze?hZ!-HIcR2dT`JQ{4)BOG3IMkgd2e{obk|5q!nN*|Os^IYHtDBgSuXh@4 z>ofE4Q^2r`FSFaSL+-fojs0MQo>XSt>3p1Zm9G?dxsqn`?f)R6s|u)Wmp0-%8Ff3m zZsxma!`O|JTJOIIc%DBJdOOx-U_}SWhKvdL>4YrsS7SM`noEpbhmck5iGP>Q?~Ck8 z)DgAQxy&5Iw|!V*)wbmd%P(VNIZmkDP;9PzFygUOixd=kAzW?d?>=o+{ V>q>#oBFv~t^!zjF@2nEoqOkt6eO_9nnZ1>S$|~3 z@tudU#v#WozSq=fre#{Z;*!tG)o@!MWqfcn0k{K`NCvfWtffFOZ&59puPpVA# hJRUm>x(m4@4C~w z^ar5b_tvSkPV~k}IF- z+U7u~_((&v!O vCI zl2Y$Fh}Rg}aHPq{Mo{q`$5!Oi>(iSJSHFb$r7C8hU1{>4P?{T(lD9$I`%PAR35y$U zV3dYui|5*{_pa1#;qRlKnxz*~YH -jubF;NH zSxX6LBmP#XfE-*0oDxf!d<0AeiisOG?3hX_YI$<>w{-Gub_7KyorY%1!ewpTy?(kP zoRu-oJ9){nJ8LmJMtb`S>P### 4o{p?vt*LF~+|51Z?yND`hxOp3YU4r#^JF1m?Z}-bV67koMPV~|mml)yScil8k z? pWmmJZ1~vEfwk X z=Ja*FYZV-k{um|onplrtHjya4D1=`5SbTaRwvi+J8n1#Lq*M{7OH&~Qf Nh@N4ItubMNW>=I$f#I@>Am2kgW>JugPVz>i#&GIi&L)<{FLRD_UYH zdh+iIiHUBdZ?9(k2c9UU5BvxA>INt?=f7qK>xxBY_iI{|^66w8ja}{gxJilS5E6?; zrpQf~`3RvX<{!>WQG_)tNM(LG)Kw}Km?`5RHu%@nB2KK}47%|t%?JJ0 K4#uvm#rdV!Uqpu~%9EPQD<8pM+!eMK^b_zg;xjQZuPs3f z1=)d7Mi0p%jH1pfLhl~iA7;}^9qfX7S1o>BRoh0N4mdUwMy@?pRSJ=vJ2<>uljUAF zxKcTrV#7p_Vd7A}N?>CWbks2!aZgezB&Ze)S0GR>C5BNrfM*ombI1M0i{0aJaMhM8 zq>h}MvD9X2X73i}a9z4e;IOSpWUJXVCkaz!EhYdZ0I7#wUBw?U&V(Q4Ps mn9^_rca~dl-7 =1$Z{4tP@|#r)g8JO;apaV^c0Lm!83XItWzPB@LSkex#{N zV%6^AubMQ8?J%>@r3*8}At)!<_f j%kUg`#aV@HZ3z?ur4Ngbfq^LbL@rH82R%_ayc3k<*`6w zgZzcm1LFu#*>p!3xuY}FFhoPmd_=jH-O)@;jK9q{g^($IYWmA9`R@x^>)l{sUY66f zq}_0Vvg5U^10*-sX7%tEi3`h)gx&E@616Q^yqLc6J|<6EA!OA|YC#{{H}d7xVaV>> zVaVdHED563ktYqj`kuz#EGjrS@R#TzBeG;2mua!e@a;L+U>1wKaTYl*llPCG^~QPM zg6W~MBm<*q@Z&t{cmM?WF&<1zR5Ub1RHVPBc>st=_{e|H@@NAS66>Z3=@?#v%w6;A zyHQ@g()pOg$S?b*pl6FnK+XdCX~w02Nzg5*_a{6$|Lc$sh791EnaE$IE${(fX`WV3 zW`?Hg>i;R@tK;I>xxXpy#oZkiD6$lHDeh3%;=Z^Qmtw`COVP!p=;H2BT#GxD;-xsH zSfTIId++o6 bng9fUD4Yv0ztyZ!jhrShHPGOHxrwR8dQ!8-?;d+2uqv*xj=IqY-px(VIoG zRpD`G=$fTfU?lNGHRv7m=OWz>bv~M%p5E&f{L-ua#rxu8i3a)~GE|zO=zd2A;_d6r zU{&V`N>8?oe3=PN*!f|EO{^|}QzXQYehza4H&0*aEn;nNDRB&r6ihSFs0XHLk75y~ z2_T{fF)*9M>_g4dCzB|>N2>fUhAtpeBV7)|v#-DwqLJFULNSljGccdmC-VeggNS3e zFb7B%ym?1qT6P=*w0^Zi8!*Sy4OrZacLk$sgGct^@H(qs9S+>Z+q&qR!{>{?++=r3 z!=2RIge!~%4n~-+twAxh1=JH(xS~P29$ylJCykYllDSt9zoj5Nc-O)vwWldJ*jbhD z o77+z3?W98gpZF!9QMM)?(A8G8ZJli@bxGLjv4Yh z;R!`e^ QM*C^}svsepgUG1k{-l_Qs}ze)l{pXHg{b`e9rists# zAze}h4|u>_>V{0F*qr_QYPe%@7da6$gqbc8c7hV_E?(R@_5iolie?@-yXQ_^-t~G2 zuPT4%byAP7#?Qd;rPC_HHFGb6K~Mf55ZxH}xq5eyvmP!qmXbT$#HY#kMD)7S)p4FW zUt5B=ki`|`PfQyFqsfj@D+nz;8ZgUszey6G4ux|r+^!8!E|VQOpk=6);2czlB|O## z2GkC4PX9(4p&s1t5Jf!raP0n;r4hs$-7rWQI4Z)!vcMLUYLx6ue=g{ey#@Fr+vq7Q zrnY61X!vyT?UWjo!<~H{DnzxU0kW|HkcZv)JeuZr^QYmKx-C_ZTbbdnqVKDqv%D_X zVjKA*PLOk&g^dphBp~P+A6gAg!G0%Prt7i08&{mjnN^p7h%)`m3bU)*1b*h!SIPLb zeWxdpihFX7lV)7zzFpjCg#N_kpd^Gl4Bp&`wXpe!ztwP_{o&4?5TL>Yo3XELllqdX zbRmnKT1K7UTxdZ%B!SnKb_TnU=!|C;Lw%e@xzDmgcuvezp-a#OFzmwiaqINJRc{FP zx3+xu2ch^#0+4fdgq7ddxQL5^0Od7t8}WY2y)Ozj=AbU0yxqa!=ClS#hql;t;)4;y z2<+B6E11W_7OVi0tDQ?LfzI+1Xc?w`AE=oUfx+IuM8>-^4IrBo)34OU3IEF)?jrE# z+M2}s1e{B$j~_5g*QS#j!h-{;YQav|yRA8| z|6$nOkiFxFWld8qlHB2lEeq-H&V#nUJh!#zYvo9L2qetPwIYiG5WrXAl)ESB2Nh;d zxayrd{xqIbPjC3RezMwdcaXD5kC `|XpjdWInn7cHo@5zM^;@Q+ z7Z}zj`iy9gl)-X4qK@>dm|*hZGvzB3N^+((N&|}f6BT=}l;|T5&lmuZX%GbRF^@0K zyQY_w)>^1flZ$moQ_Sqoyy|@~tL`@H0`+E?rDeTaj_h!3O+Gf};QI8kGoQB&4=IhL z&w~{<0dJUkl%B|XYke}7q@OPzj&OlgOS;(ad6clt^2&xnq(%yvD2HNV;^h5ACa7Qn z4l6xa8d~sENzu#=9$>^z5lS7=w9^}ItV4-T?G8Cs*={O4S7e$eik$Ut8{$XTUwG|0 zrzx&g&Rhc=J4csUy>=Zwwt{v}^hZRS0an?oDOKvnql-`M;Qw>>aQZ(74hm@nndUj` zryWqHd|E->1$)66#!n2pe;D>pjBGfg1kNaf=lF+_-=0zgFY?zN#*fY&>!0YC*f+a3 zX2XfX^7YaA8_HHtH $P?uS=`^HW z?=&=BgH@mCzxDLr-*r6#ja94`wk z_nFDz&TabReCD>L?pjZFi iL62pdeVxTrZA=A6lc?TjO^;5>Z-S6^hAEKvZFs=o^0eNk8<^LU}t z7_#;dr{Ix)5f@XK4BNAQ*m0<$H4c|UYzUYbwFr6Y5hcX|pt4W(tnmk6It~{qHG0-R zMP++ S|ArwGg=VtVD=miklFKSTKlebTNM)^8+C}0v)^Iuz>CAIxF?NH zI?zy&btF?l0QnMXzmcY0>AD9Tn#(p_-a$!sQ0T@1F3$TGJL*T185zBhBI*H8W(N&5 z-6nqG5V`v-z|&Df1G6|`;hmIxrYo$$Jk+?P40KEdqxIv-!>_w{v|b%4@j_)}RnGTi zrf|uXg1u<8oW5^dBJBw2p$t_f5a_IO%L7KB%=l!R+rr1+PkMQyw)O{suFc9H^MKC> zTKhO87+#e@8nfrN>QOYb*T#O1%R1kYl1m_c@hWv|`IW(^? Sgt6{da$i} z+)Rn*?Y!l{O@5)O|K2q2UFPEJ8iH(n!$K>r l4R@lzvO&uSRwVz3@%fyf%|{J9 zL#DIL)< 8;Z82mEl5N)*nw>?3)##? It!$>IhvlJLGZ)`t&l{{@^(o4m54I8Ay3#s5fnR5g^9!Js&zt zM!Pz8A$w?$^)(&56Oa8e>M%I}b33;+=^F(Jb@H%V>s_aEk8I`wo4TU8i+GG1Y&I=O zjX+^pK^na4NV-}l_LDf~5PC}vb7# m}|_MGrNazIVbM$KH>dZ1YJqLe|u;V1BSc zM?=#s0Pnrb_19$M5vDe_=|*f8#qD|0uNh`s4dZBfMs^7pbe0Vf$;Jq#Xg!a0!cNl> z2yd*yJ6=nutEk7UPXNRodQ&Cw4C{RUX?066ZleXpEiBq29UtZIPtW=P>8%OfRLaHV z3 MxzhNuoBvf*KUIbW@5kTc0!Y%@O@ei=Bp9}( z_S!Ls#Wu?*SrIRp62PUn3_o9X6JKedksUi3dELR)rrjdmJpU|wTVPL(L1+cEvMumg zndqApH+6c*Nt?Zp`}7 _Yv9;ycaWT+%l^Ec%EAFJYKnUZo=2 Xk`_}|Aq{oWNs)po$=@x==5Fq6ejgI!ApnOh2j zg}4?wjb4h4nARc6EVo=e8~-KHija0XV(rL!CoV3XYg3~~TaASjYKRmX52 tXy1BL=pt^FhQK|EL z$VI>P(XzS1xNuL@ukp(zbtecu*}YNe@xK(^ue_`C @F6Am&Am zY3^G6=(nQ{Dpk4{4#Z;60S%T~oe|hmg#Rig<8DCp!J;^*R!7pHEcY+hF3kX5b^v6( z$T7>^3P0Uw!$g%{xs!XPjzfxlrcs5_%I-hXKR!txd8(Jy9|XAgx@maD-Z;_1$)>*~ z*ArRpF0$ixX&BfU&sf>Rw!Mfc^k}9vaEX}kBAt+75d-27KHT`(%9Q7C($(h#an}r* z7Oi>%7XE)=$P4ZLi;=hmRhv3Xfsf>{^a->7AB8qGs!wQvL#y71ZMllbaS`*SLj 9Xrr(0UuPhfYi91b#CO((|K{%n9!O>ks$|*F@HybQ7d$|9nk+u%&8fFqSQa z%O6J3NXj$X^UZIK4r;Xp!Zu>X2p&Og_{1C9$#AUTZ1VA=R_<1%b4FZ!BoJ>}Hc @Zgp#saufdLx_ l zV)QSv!Ryx(mSE%n5>_C!ppR)mS7%4y>B5jPRzz2-`;ckZJpxz=NKfy`JGf^eR*w~p z+Hj $SsuLC5qchyw-kZ0V-*0xFWUq? zx9jz&+E >QV1HF^>>dCr|4fYbWn{8IT}NxwM1i{a|~&@4(VV+1;AD;DDpSk 5e-&SUr!#pfllfWi z8U_2<|5RQLReCbp;TBhZkbF7d^EW^60{51K(iBJS2*1O*I~gsKe>D0xt3;Q;fRndt z-y %{wp=2jbOPlYgsGYMm z`qlD$o0Ly*!t=c&2kY=Ahb)B=I1MrSye*~EhYKinqj_gLyf=Ijq)sPN=D~5m@1q u5CCgbtg^)}>%Yz(ZQN}BtZN^If@Iu~eEKu1}%Im? GkgaAsCxb< m*{CsK*P;!}hDz#io6=KL~@n(_7~vdop!ckjCGH%npZVE2bon#9H3U z bfy#IQ*;eGi$2dW~3nC4UIo-Cp=$TakA zK_e^peVotDju{c1m_wB}vI?JpmtZqPoG?z04S5P&BxOrsyq@Ju# NSAh=A+<`jca@K>K{x z7tV0_UuuxDrQ@cIMicn1b0g582vfXa9eEr>yf<}1dQ=2x^@`OgGi|nyg4~IuQ3Z9x zNjsXIb#g{3KQyI_G4V1esi_v<@ilxp5OOeT<|EQl_--g3N@Ay3FB1(V;+)F2i68{B z8Nw&vF#j?m{$(No*~s&O0^RjBQ`Ti$lo7Ob0T2yG4^gb8BOKo`t*Fv1yXpl#OmE#1 z*N?+!7|iTp0)^=eb%GNDRX!P2bRT#S+>3c(p3HfXk>*S^&I8J0gi!cX yWu`och|PUV}LSNLo`Ffl(ZvwTlGA2<5L`%n{9iX^vU)Dfn|Ib !ikK+6z+Q`Q%y3Mqe) RZm=V`P*R|nD5p+8Eq}N}t#$hb!g!+6=AgD>+Il|}_ z78WM5%ZdlvCW6&ZG(<2z9B}BA>kjFbO_r^d2tB)Jb W69gQ?e6f^VH2w<1-ZZ0 zTS4lD&FpI{A4AXQm;NtzJG@x_3cLjFoY$^sBXz6#@hfMYZmreiQ#gbFin>+x;T3ys zEbr*+Lt$k&X J%mQM zGK4fiL|USan7oxtEpbwB^LrX^K#OtA9Yd&k{32(>0N!owW+m;qZek`s 5sb;W9N1wm@Csts3KC1v>HE>$ zWYGPR;ISCCbBzkNP*Kz!jb1{KWuC?i>dxuI)VsiP55Eon-pR{3yq+~~x54Wy<{?-P z7k3u=R&_sp*a=-ImDZ7L41Cq_j039@ESI}3)!1s^HNok-n?O&|iMIKX`iVLPQJ9Pm zTCF0GH6Bxa=g~20{Yl_IC}Cw~h+zwtfk)ll)Y%HMo{kuIcX#&TfAd7EvkwS;R~OM) zhbveo8}4b!Y@(yc%b`)xh0+Fw(@8j2zQ=C6FkY1lRr3{IXhb*6bkoIG7A_w}e){AMm0f`)%9) znUq`HZ;A$o)&p-{t!SUA_QM3+=(&-M k=UZkME7Gvq;cCldu*=|$=JZn7Q$s!R4ZecF1A z&X(XCh?@yR1&%x@r#Mj5U#>zYh{;}B!p2j8+Q1%er5vJGb(?dtXXla^I#gECN Y{@yiqJWPySrjN?a@HiqE82Lx=e=F0!1D zrgYe^kp3{?3zydNZn-G1@s&+Z>Cx;wHPo-LUL#*f=yIb|*J_%a*q6xiAfJBNgKCK1 zv64;8nS1hJsVQ1Lv;Sn4&FL+319NOxHxj;=up2M$Nc!kJML0A-o6*$HA>!-QCv9~s zQhu^xOQHd;Rpo}%t3wF68;3hvwKM^Q9alDeFNprdShv9MOcW#DiW&AyBKRvUE$eTx z*p)HUT0t 5Bs_+gA>W~pJWhu |IgJW?!CQ)XXJ<7Q!D zk^f*bKPtLdan5}7;TI`{Pxk9g Czl^Bkwtw>@1es)%C zSmMC1b;&$ld9u4B3+|&dlq^&CDadJ`ih{hX(z-h}sVAPMIRBiIzU$Uo&jKoB02RU& zUTiI{npV!2Q#$Hb$bJ}CX8#wGgm5}-Y9zH3j$t@*o(Rn(!der1B`(3r )q3IPo=mGo-0fSZ(j ztL)orgMqItG3T%(aEH)a##|Pt^GL-x({Q)+gR(}L7-nUuQZl9Itn=?{CLpxi`<&uq z1ku)tPZ}e?QwiCe(>JfZla^$A_RoAK9QP?LQKdV1D_)F-JIz5XW9*}XA3MJy+5NH* zzosphcbV-@UJhKnjrRHQI`V-LTSBzrzRZ_Wu+7wXdJ866Mg4W)4}u15b};se>ElaD zPKm+ZJNqEh!0EHK`T5v}nkzOKuZm1}EiG?sr}G^kDEoKuXm|1N5g$lU7<; $iM&85%n5hwwg7Qug;IRGHH?DF~DC0;qonxNo$d zc2e+u&ZS|q?<-@n`I4V4@!qqIyzaneEI*U`z2jd}B|n3aPr)i~nY@nQ>M4O |+BhXIr7d->Q$E(Y>EOQQmVn7l0y! *OtLrIS__y; zR7qErOuu`GV|n%pzJ)P{$->EKdw0!va=eT64+8G!Uw^#o#nRE() =m|) zf3o!fWBQhmE6=1DrveT)%dA?Pwvt<|a(FwiV78mkqJ$ZPou%a_T?1&Yu#6$X#k9V% zqqF?g+O3(C#1XhlrUYE1B{HLC_ewr%H_e^uF(YCQ3VoNymKFc$PGz{&eRnMFMSH?` zbzUQoOr^$_f};eJjqH{@SiUc;n?Rr_4GYUYC+gT^n!4PNv0+TkE3$xpBWlu*QM7>H zk5O=V;^Lp+Anue=_#p23DfO6On!4DJu~9?Ls~pPMFn0j|QS)i5jkzBgbABj@N}YK$ zh3eCks%oox4xlQ$V77ieBHd({GU4Sq+1eP9YMoZCr-C5gm>b`C{j=2ZCUfidn9que zwU&f)+Np(3g~*;VCA1M_BvCyLa@r;+cI{azb*FEaR+|o*gvC{A1xr`gNzI$wT8gfC z#Z~GF=lZX>@7Gj1Rj!17v>j|3lXBA&zvk1TIOY1D=zcZk{^n`?7RP`31IN?wjf8}R z@^|=tGU00>6A%*7qu}wH)A8_0Yg@SGGXULddl{KR_+=o>0+y(uN&k-FaQnX`0;)s1 zn-rb2_S&2ufo!n16hl@dPG5`*@(KExJg&I38e2R*rqT-XRQp^TtLxYGY4%}uXLC7* z7%IPLb|oa9qN1$zz9Mg2^b;yU6<72UXDa-ybbL96NbOzhgJ>NNdu3y!DR`m?v#5tc z)Vb=~j^?Y<)F^WLa&bf{%zOLlrgFF^VB(~=H*e~s_iNt7Ni%BM?e(f&VB^1@?v|jy z1X7F}p_)HDd)XxVhY+1Zs$JkBpyw2m{zj8fhFtsy^1K-ohNKs2U!$gvQQo^(6R%Gy z5HVp+0poT+pEsw72)8G}C^b;1K$*wFfMYF{u$2+}LN-y;#~IU3D9^U_AqA&iia~G3 zsG1G&zDGaR$dxMDfWD?$#h7?kPQ4sUc1^m#ioD-7^<8N=8lDhaHGKkO3(H`Gkg95_ z`rBUrn2xXy?i~$stQevdP$+oOQ+-4;@-sf4Tv^3%tK+ucOz-eR>P)1dbeeyx{tSYl z(W@wAqZX2w#JU)(3z7-a!WLgidJvROI_i92lJAP+u0ix??txy8D8&dVf~}q}8er6p zh$=)&RJM{zM>)EuH1I=~7F)71vtxCpOa48Z>Hy9Ki29hXGO^R#zk=vt7I|Y@8{cXT zhxMeAK}%d_slcl_5bcCGKdwPK5vSO?*z_958>Dg2f2SORARf{WDThK)(CEB@r5&$y z#4pxn%cCSvp4SmVUImTPfkJYBrI | z9$V$C?-TneIvupYjau^nc|U2HMGoN^Ly tyZO3J_4)W^^%cYbzOg>vGnYlNiopao$Cw z2Fi>{ cj{!k02j=WFu zui=D1U$yM%gW{f92|jf;+8N2wuqDyLBzmzh1`0jMz*GoRw)$~eJ _`VWEnuf-=a `C6e3{Up{{#B2q| z6s>x~=iB<9v@QA?*O7s4k^u1UBh!jA7~bj1W>YocPZW6u#$%q+9?1(dH$nob4_IQO zgw73^$d5?XVuYfM(+XSSgddggqA@A6_C#y~yg^+itTXywo}|JcmHLWt34<-Tgmmk! z1*@uIj>}WWFvl!r{~344PjTvqOWOlUS6J{V?%h_9{_O2ixAqdx=^JPOm)ozjUD^89 z)Rt)%L|Shp*Y<*hF-XpeSmK$tF*)SNjkIK#-$lj}O$@7df$QtcaJj33@V)nVYTh7e z9tH%~vlMe=DMDm>pYQFq{q=8}?CScX$bL}+N)TBGW9 E>e0M16fr#T0AvUy9_0RJRPX5K;-h4n}{u7Jm z4kZJgqbmeURb7RbFbUtpuDYe9NzOxiL!^wcm%j$yejh &E7KXmD&=vq{XDOkQ(nPW_4`i$KpH@?vx9VUqutoz2p29=;cj1v{w zLU>e;BlrLIjcY*^o6sUs$cdXXYnX5zuM*oozXdq5KIm@Xc8-L#?tjQvUSE6k9w~18 zOTgoOfq_}G?Mt4{fXmsq57Nj)rpKU4ZzYv-c>kgoVj#l5t9i`7)*p6>rVi {Q1b|7V*B6jhrSQOi;`oLy&^v(vTcz! zM{6@V{CbWWEE!=%(C}HpEWRUFGlHH$sKVbE8l{vqJ_~Ybo?{*0Vvr(|W`b4#MHiwh z#|giofecChAXpbE>%CFRZTE#G*0se3dx{5RZnfm%SuKtH(AY(ICtZp{+8`3g9W;hD zaa7Z>@n3-UH#Ty%!zzHY8YQI|T0G03Yiq)o6mixME{G<)S~*{5cHr{wi@ +MwWj@Mw1y-B^zV@?)q=}x z@$zX*ZYpOCagIP@Vh~Ll@MexXM2Xvt@#WMRv7k}GyD*4M>IeK2{QQqD`J3-Vg?KHM zzm@Z~yFr4^VmZ{KSW~pmRw6t(IpuR0-upOVc6wX~40BHA$O#p__f^Dil@zZ$0LF{* zh;3|DF@S>To@Y~(d%ANIx)a_e0 np.ndarray: return cv2.imread(os.path.join(ASSETS_DIR, "multi-fruit.jpg")) +@pytest.fixture(scope="function") +def multi_line_text_image() -> np.ndarray: + return cv2.imread(os.path.join(ASSETS_DIR, "multi_line_text.jpg")) + + @pytest.fixture(scope="function") def stitch_left_image() -> np.ndarray: return cv2.imread(os.path.join(ASSETS_DIR, "stitch", "v_left.jpeg")) diff --git a/tests/workflows/integration_tests/execution/test_workflow_with_keypoint_visualization.py b/tests/workflows/integration_tests/execution/test_workflow_with_keypoint_visualization.py index c3ce3ff94..268bdc124 100644 --- a/tests/workflows/integration_tests/execution/test_workflow_with_keypoint_visualization.py +++ b/tests/workflows/integration_tests/execution/test_workflow_with_keypoint_visualization.py @@ -1,12 +1,11 @@ -import numpy as np import cv2 +import numpy as np from inference.core.env import WORKFLOWS_MAX_CONCURRENT_STEPS from inference.core.managers.base import ModelManager from inference.core.workflows.core_steps.common.entities import StepExecutionMode from inference.core.workflows.execution_engine.core import ExecutionEngine - WORKFLOW_KEYPOINT_VISUALIZATION = { "version": "1.1", "inputs": [ diff --git a/tests/workflows/integration_tests/execution/test_workflow_with_ocr_detections_stitching.py b/tests/workflows/integration_tests/execution/test_workflow_with_ocr_detections_stitching.py new file mode 100644 index 000000000..b370602b3 --- /dev/null +++ b/tests/workflows/integration_tests/execution/test_workflow_with_ocr_detections_stitching.py @@ -0,0 +1,96 @@ +import numpy as np + +from inference.core.env import WORKFLOWS_MAX_CONCURRENT_STEPS +from inference.core.managers.base import ModelManager +from inference.core.workflows.core_steps.common.entities import StepExecutionMode +from inference.core.workflows.execution_engine.core import ExecutionEngine +from tests.workflows.integration_tests.execution.workflows_gallery_collector.decorators import ( + add_to_workflows_gallery, +) + +WORKFLOW_STITCHING_OCR_DETECTIONS = { + "version": "1.0", + "inputs": [ + {"type": "WorkflowImage", "name": "image"}, + { + "type": "WorkflowParameter", + "name": "model_id", + "default_value": "ocr-oy9a7/1", + }, + {"type": "WorkflowParameter", "name": "tolerance", "default_value": 10}, + {"type": "WorkflowParameter", "name": "confidence", "default_value": 0.4}, + ], + "steps": [ + { + "type": "roboflow_core/roboflow_object_detection_model@v1", + "name": "ocr_detection", + "image": "$inputs.image", + "model_id": "$inputs.model_id", + "confidence": "$inputs.confidence", + }, + { + "type": "roboflow_core/stitch_ocr_detections@v1", + "name": "detections_stitch", + "predictions": "$steps.ocr_detection.predictions", + "reading_direction": "left_to_right", + "tolerance": "$inputs.tolerance", + }, + ], + "outputs": [ + { + "type": "JsonField", + "name": "ocr_text", + "selector": "$steps.detections_stitch.ocr_text", + }, + ], +} + + +@add_to_workflows_gallery( + category="Workflows for OCR", + use_case_title="Workflow with model detecting individual characters and text stitching", + use_case_description=""" +This workflow extracts and organizes text from an image using OCR. It begins by analyzing the image with detection +model to detect individual characters or words and their positions. + +Then, it groups nearby text into lines based on a specified `tolerance` for spacing and arranges them in +reading order (`left-to-right`). + +The final output is a JSON field containing the structured text in readable, logical order, accurately reflecting +the layout of the original image. + """, + workflow_definition=WORKFLOW_STITCHING_OCR_DETECTIONS, + workflow_name_in_app="ocr-detections-stitch", +) +def test_detection_plus_classification_workflow_when_minimal_valid_input_provided( + model_manager: ModelManager, + multi_line_text_image: np.ndarray, + roboflow_api_key: str, +) -> None: + # given + workflow_init_parameters = { + "workflows_core.model_manager": model_manager, + "workflows_core.api_key": roboflow_api_key, + "workflows_core.step_execution_mode": StepExecutionMode.LOCAL, + } + execution_engine = ExecutionEngine.init( + workflow_definition=WORKFLOW_STITCHING_OCR_DETECTIONS, + init_parameters=workflow_init_parameters, + max_concurrent_steps=WORKFLOWS_MAX_CONCURRENT_STEPS, + ) + + # when + result = execution_engine.run( + runtime_parameters={ + "image": multi_line_text_image, + "tolerance": 20, + "confidence": 0.6, + } + ) + + assert isinstance(result, list), "Expected list to be delivered" + assert len(result) == 1, "Expected 1 element in the output for one input image" + assert set(result[0].keys()) == { + "ocr_text", + }, "Expected all declared outputs to be delivered" + assert result[0]["ocr_text"] == "MAKE\nTHISDAY\nGREAT" diff --git a/tests/workflows/unit_tests/core_steps/transformations/test_stitch_ocr_detections.py b/tests/workflows/unit_tests/core_steps/transformations/test_stitch_ocr_detections.py new file mode 100644 index 000000000..94ff58ab8 --- /dev/null +++ b/tests/workflows/unit_tests/core_steps/transformations/test_stitch_ocr_detections.py @@ -0,0 +1,195 @@ +import numpy as np +import pytest +import supervision as sv +from pydantic import ValidationError + +from inference.core.workflows.core_steps.transformations.stitch_ocr_detections.v1 import ( + BlockManifest, + stitch_ocr_detections, +) + + +def test_stitch_ocr_detections_when_valid_manifest_is_given() -> None: + # given + data = { + "type": "roboflow_core/stitch_ocr_detections@v1", + "name": "some", + "predictions": "$steps.detection.predictions", + "reading_direction": "left_to_right", + "tolerance": "$inputs.tolerance", + } + + # when + result = BlockManifest.model_validate(data) + + # then + assert result == BlockManifest( + type="roboflow_core/stitch_ocr_detections@v1", + name="some", + predictions="$steps.detection.predictions", + reading_direction="left_to_right", + tolerance="$inputs.tolerance", + ) + + +def test_stitch_ocr_detections_when_invalid_tolerance_is_given() -> None: + # given + data = { + "type": "roboflow_core/stitch_ocr_detections@v1", + "name": "some", + "predictions": "$steps.detection.predictions", + "reading_direction": "left_to_right", + "tolerance": 0, + } + + # when + with pytest.raises(ValidationError): + _ = BlockManifest.model_validate(data) + + +def create_test_detections(xyxy: np.ndarray, class_names: list) -> sv.Detections: + """Helper function to create test detection objects.""" + return sv.Detections( + xyxy=np.array(xyxy), data={"class_name": np.array(class_names)} + ) + + +def test_empty_detections(): + """Test handling of empty detections.""" + detections = create_test_detections(xyxy=np.array([]).reshape(0, 4), class_names=[]) + result = stitch_ocr_detections(detections) + assert result == {"ocr_text": ""} + + +def test_left_to_right_single_line(): + """Test basic left-to-right reading of a single line.""" + detections = create_test_detections( + xyxy=np.array( + [ + [10, 0, 20, 10], # "H" + [30, 0, 40, 10], # "E" + [50, 0, 60, 10], # "L" + [70, 0, 80, 10], # "L" + [90, 0, 100, 10], # "O" + ] + ), + class_names=["H", "E", "L", "L", "O"], + ) + result = stitch_ocr_detections(detections, reading_direction="left_to_right") + assert result == {"ocr_text": "HELLO"} + + +def test_left_to_right_multiple_lines(): + """Test left-to-right reading with multiple lines.""" + detections = create_test_detections( + xyxy=np.array( + [ + [10, 0, 20, 10], # "H" + [30, 0, 40, 10], # "I" + [10, 20, 20, 30], # "B" + [30, 20, 40, 30], # "Y" + [50, 20, 60, 30], # "E" + ] + ), + class_names=["H", "I", "B", "Y", "E"], + ) + result = stitch_ocr_detections(detections, reading_direction="left_to_right") + assert result == {"ocr_text": "HI\nBYE"} + + +def test_right_to_left_single_line(): + """Test right-to-left reading of a single line.""" + detections = create_test_detections( + xyxy=np.array( + [ + [90, 0, 100, 10], # "م" + [70, 0, 80, 10], # "ر" + [50, 0, 60, 10], # "ح" + [30, 0, 40, 10], # "ب" + [10, 0, 20, 10], # "ا" + ] + ), + class_names=["م", "ر", "ح", "ب", "ا"], + ) + result = stitch_ocr_detections(detections, reading_direction="right_to_left") + assert result == {"ocr_text": "مرحبا"} + + +def test_vertical_top_to_bottom(): + """Test vertical reading from top to bottom.""" + detections = create_test_detections( + xyxy=np.array( + [ + # First column (rightmost) + [20, 10, 30, 20], # "上" + [20, 30, 30, 40], # "下" + # Second column (leftmost) + [0, 10, 10, 20], # "左" + [0, 30, 10, 40], # "右" + ] + ), + class_names=["上", "下", "左", "右"], + ) + # With current logic, we'll group by original x-coord and sort by y + result = stitch_ocr_detections( + detections, reading_direction="vertical_top_to_bottom" + ) + assert result == {"ocr_text": "左右 上下"} + + +def test_tolerance_grouping(): + """Test that tolerance parameter correctly groups lines.""" + detections = create_test_detections( + xyxy=np.array( + [ + [10, 0, 20, 10], # "A" + [30, 2, 40, 12], # "B" (slightly offset) + [10, 20, 20, 30], # "C" (closer to D) + [30, 22, 40, 32], # "D" (slightly offset from C) + ] + ), + class_names=["A", "B", "C", "D"], + ) + + # With small tolerance, should treat as 4 separate lines + result_small = stitch_ocr_detections(detections, tolerance=1) + assert result_small == {"ocr_text": "A\nB\nC\nD"} + + # With larger tolerance, should group into 2 lines + result_large = stitch_ocr_detections(detections, tolerance=5) + assert result_large == {"ocr_text": "AB\nCD"} + + +def test_unordered_input(): + """Test that detections are correctly ordered regardless of input order.""" + detections = create_test_detections( + xyxy=np.array( + [ + [50, 0, 60, 10], # "O" + [10, 0, 20, 10], # "H" + [70, 0, 80, 10], # "W" + [30, 0, 40, 10], # "L" + ] + ), + class_names=["O", "H", "W", "L"], + ) + result = stitch_ocr_detections(detections, reading_direction="left_to_right") + assert result == {"ocr_text": "HLOW"} + + +@pytest.mark.parametrize( + "reading_direction", + [ + "left_to_right", + "right_to_left", + "vertical_top_to_bottom", + "vertical_bottom_to_top", + ], +) +def test_reading_directions(reading_direction): + """Test that all reading directions are supported.""" + detections = create_test_detections( + xyxy=np.array([[0, 0, 10, 10]]), class_names=["A"] # Single detection + ) + result = stitch_ocr_detections(detections, reading_direction=reading_direction) + assert result == {"ocr_text": "A"} # Should work with any direction From 1e5774e6f0091b0f9706133bbd3a6a7c98b7db1f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20P=C4=99czek?= Date: Fri, 1 Nov 2024 18:18:49 +0100 Subject: [PATCH 2/2] Make linters happpy --- .../models_predictions_tests/test_owlv2.py | 48 ++++++++++++++++--- 1 file changed, 41 insertions(+), 7 deletions(-) diff --git a/tests/inference/models_predictions_tests/test_owlv2.py b/tests/inference/models_predictions_tests/test_owlv2.py index 3ad5913ab..6bbcbcfd5 100644 --- a/tests/inference/models_predictions_tests/test_owlv2.py +++ b/tests/inference/models_predictions_tests/test_owlv2.py @@ -15,7 +15,14 @@ def test_owlv2(): { "image": image, "boxes": [ - {"x": 223, "y": 306, "w": 40, "h": 226, "cls": "post", "negative": False}, + { + "x": 223, + "y": 306, + "w": 40, + "h": 226, + "cls": "post", + "negative": False, + }, ], } ], @@ -42,7 +49,6 @@ def test_owlv2(): assert abs(532 - posts[3].x) < 1.5 assert abs(572 - posts[4].x) < 1.5 - # test we can handle multiple (positive and negative) prompts for the same image request = OwlV2InferenceRequest( image=image, @@ -50,9 +56,30 @@ def test_owlv2(): { "image": image, "boxes": [ - {"x": 223, "y": 306, "w": 40, "h": 226, "cls": "post", "negative": False}, - {"x": 247, "y": 294, "w": 25, "h": 165, "cls": "post", "negative": True}, - {"x": 264, "y": 327, "w": 21, "h": 74, "cls": "post", "negative": False}, + { + "x": 223, + "y": 306, + "w": 40, + "h": 226, + "cls": "post", + "negative": False, + }, + { + "x": 247, + "y": 294, + "w": 25, + "h": 165, + "cls": "post", + "negative": True, + }, + { + "x": 264, + "y": 327, + "w": 21, + "h": 74, + "cls": "post", + "negative": False, + }, ], } ], @@ -76,7 +103,14 @@ def test_owlv2(): { "image": image, "boxes": [ - {"x": 223, "y": 306, "w": 40, "h": 226, "cls": "post", "negative": False} + { + "x": 223, + "y": 306, + "w": 40, + "h": 226, + "cls": "post", + "negative": False, + } ], }, { @@ -89,4 +123,4 @@ def test_owlv2(): ) response = OwlV2().infer_from_request(request) - assert len(response.predictions) == 5 \ No newline at end of file + assert len(response.predictions) == 5