From 51cc93765e6d2c01b42206f73e65e04e0bb26499 Mon Sep 17 00:00:00 2001 From: Laurent Mazare Date: Tue, 27 Dec 2022 19:30:51 +0000 Subject: [PATCH] Specialized ivalue dicts using string keys and tensor values. (#598) * Specialized ivalue dicts using string keys and tensor values. * Clippy fixes. --- tests/create_jit_models.py | 10 ++++++++++ tests/foo.pt | Bin 1902 -> 2030 bytes tests/foo1.pt | Bin 1467 -> 1659 bytes tests/foo2.pt | Bin 1595 -> 1723 bytes tests/foo3.pt | Bin 1595 -> 1723 bytes tests/foo4.pt | Bin 1467 -> 1595 bytes tests/foo5.pt | Bin 1467 -> 1531 bytes tests/foo6.pt | Bin 1531 -> 1659 bytes tests/foo7.pt | Bin 1595 -> 1787 bytes tests/foo8.pt | Bin 0 -> 1595 bytes tests/jit_tests.rs | 19 ++++++++++++++++++- torch-sys/libtch/torch_api.cpp | 22 +++++++++++++++++++--- 12 files changed, 47 insertions(+), 4 deletions(-) create mode 100644 tests/foo8.pt diff --git a/tests/create_jit_models.py b/tests/create_jit_models.py index 06fc4f48..376a76ae 100644 --- a/tests/create_jit_models.py +++ b/tests/create_jit_models.py @@ -1,4 +1,5 @@ import torch +from typing import Dict, Tuple class Foo(torch.jit.ScriptModule): def __init__(self, v): @@ -107,3 +108,12 @@ def make_input_object(self, foo, bar): foo_7 = TorchScriptExample() foo_7.save("foo7.pt") + +# https://github.com/LaurentMazare/tch-rs/issues/597 +class DictExample(torch.jit.ScriptModule): + @torch.jit.script_method + def generate(self, batch: Dict[str, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]: + return batch["foo"], batch["bar"] + +foo_8 = DictExample() +foo_8.save("foo8.pt") diff --git a/tests/foo.pt b/tests/foo.pt index 4759d17f2a2bd318076ad45fdc10ed862c613de5..ae7b294023efae893057fb496fcafb0907c15743 100644 GIT binary patch delta 608 zcmV-m0-ycv4(<=IKm&h;l3Q=nFcim=UbZ&7v>Ug9gyg-jly!-;2apDv7!sFq2uLLo zenC+#ZWBn8+HuDE#+$z--+?axI87DC_7Ka~v3&mi`G2hQqk;DnMS1;kY@OTA=aK!^ z`DBkrAKp0j`OEjOsG?(;R;Y>}%7El}!Bggk#@C1&na`p*cV>TthiRbEMx_bJ1Wi)* z-DQEEk$5shmqWA|oM^O(AFVXSV1(x^LC5GZxT zGm>y~8u2jqOzwwAEHgRDcxL*HxSTlBk~2#K5@*JILEBiBmo>VEbr}msvpm*abb_LU z50NJ~Hc7IS4bXql*=j2duH)X);L4gOt+^WQR1h}^>LK0hki!C^;e!<(GHNXs`jBum#)DgFft(mv*I; z?!hrU70?jyJ1hY21T+Qw39EBUQNxOCwbiilVI7niM?gzh&#q)Ry)TTvP)i30V_eTy z?g0P*k^_^T1RfU?0A_D*FJo_Rb97;DbaO6nYiveB4*;|A1ZDvVV_eTy?g0P*l9ONt uUIIP^lfVW&0hE*e20jrI1^@s60000`O9lr30001tlT8O31`-AU0000#^ANoN delta 453 zcmaFI|Bi2i19Sb<$$R|{8;H2g-O{?sm_I=+D9WhBQ7gCQ#se+xq!VJF9$$0J46^pi zG&kF?`u+enL)mPXo*!C)I`?Jdz5o45v~FNp(xuqN;F;F=@5OWDiqOvySNCdEI^Esw zvm^A-=CvH=E*FesU#5A5RfK-}-m%=z_0v_6_sf46%#)wJ%RdZf#bZ84rlzq=H3e)+~%e4Tq^oD zJNXo@Vbfo-!prWMd7SZ;f5&9Bi%X6yGBG`<#v3Le{Ho|?Ld(pA(_*JvIiGEva#3z? z*fz6Cx>2V*wq>M-%Ko2kotm^xAzNLpTzz7i^<_~D&rQ1elQph$zf{@aXb>_nI;1K6 z=&j^*t(Z45X$9H`r+x4(eKJqwi^`u1dH3AK-fnuquXF3J$p2SYI90zat^dAU`NoC6 ztO4Ha9Q)?{lL=sCU`S?WV3;5!A;@6Jkd~jXpPZjpT#}eqQmj{yo#W=DJF!u3a|3HQ tBPU$_+< diff --git a/tests/foo1.pt b/tests/foo1.pt index 8831620b2331d676ce0be0a3cf40e8c6f22d627d..0a6cc41e738bb2548be475e95b311c92f84c3250 100644 GIT binary patch delta 595 zcmV-Z0<8VJ3;PV90)K_l$!^m?7zc1C>5{rj(o&#BLYkg%Xc|jBfK<_nBC&*zMGBUf z3$naSG6RmCSq#mMo1c#}F8~~P24<4TfpYN3(v19l^Yedff34#K%d(!m9lXBu{P%tD zh5ybQ^xr)9y~`)Bo^s2CHA@!;YR!h;g1c0Xjo0Rb3S7g-24E zi4HosSS)p>b${HNWm=>Ob(+BD%@PUXxEk5GjvO1r#`hLTN~(=b-2T@%IycBxi4b0m zY+pytXUI-@9GQpJN(wCDV z`yNRpSUS4kHx%D@z>1dtoIhdFr-ayH&XU;I&rFz9Hg3P^d~k>wfCW~;Z_OrieWRuT zHK;=aEYq+8`vt!ZeHau_E8zRQVHMV39h$HKHf+KcY{L$;pbfk8O?$ePAHXR*ETCS% z4{Csy1vCoyNtgG3(IxzwmMIiJEP8(cP)i30wd(#8+W`OoX#)TNkQx~h03`rsZ*MU# zV{dMAbYX6Eb1raeY(_#MlfeTbvw#C30SUG0{uA2)003!|ssvvF>yruvKLS()lS>6Y h64C?!0000008mQ?2LJ#702v08n*|yM(gXki005|!3RnOD delta 432 zcmV;h0Z;z>47&@U0)Lg0y-ve06ouWCpQ-reDZ)+3~EJyx6Ck6{8?p+=CH!xJm2QV!rM zY>w8lF-weDyu)HjU9e$;70iXYm^By^F4(Naa^-@BE#-8_ur*scnN3^D$!yx1Rh#7h z++WNJw)B(1{!S%m_Ts7_S(>e7nZ~&lvPlrfQO`tfS}BKoo}UByCf-Y0`2lNJvvQI5df+EP*v{Ki z`?+&&4^Lk`bL^AHFP~7&#Hy`P9p6?3sqsjp%nt`2A~E1Ti>80VnGg}CfkA5p6Oa)a zrR=lI0+W+?)I%3NH0vH3w2tpAnS3xJGL|6o0Nq0uXdlfKi28~&XanC@pcSSQV#RL8 zNFuCpB*M(Igdg@9w*=vWTRtPMAdZUTOwxeFd@!BS6|Aex23^Ibs(G487Mm_QL{TDo z$Ws%WBs^tZ)IWb-PF1)y+@6P9#`07u*PzWp;s!yrWc{*aKbIKz?oyJFYR3ld{Oi~| z$|akHgwSe9>$2o%F4-!M!!xhi(#E&`wk&>rLdK@oU7V~__=4zDwQn2m{qJMMkESg3 z-8k3IzxobtD()c{EK0`5^opQYP0;-E@;TGU{}>a)%c6fcaMU~Fhzo;WGvAv;3_$}; z((Ch*KVQSV0Xo#64q6U007-fSUck8oUBb5oumY>F22EH812$k2TCfFe=)m@3(~d&v zE*!uE2{j4d!(0Fo>Jom0<@(RC6#oj#6u;H_8+a$7AObl#|K@KN1oJ00000002-+1_uBD003DAlN$ya1`-7T0000@wGdAL delta 451 zcmV;!0X+V@4Z94mKmvc2lR-e05Je&Z*MX$V{dMAbYX6Eb1raeY?03~vw#F40SS8;do)i0003N*ss&#H tp97N%20j54lS&3Z5|RV}0000008mQ?2LJ#70NDkTn+6&Nk^}$%007)*!@~do diff --git a/tests/foo3.pt b/tests/foo3.pt index e5e9a1fc7add380ea87cd4b339082cdfc6e9e6e2..b82ffbf335ef71b613ff166e5217230f191196c5 100644 GIT binary patch delta 561 zcmV-10?z%r47&}mfC7JAkX>ukKoo|P>{pEyR8T}v21LwCO?u@;L9mM2_G}c^6$842 zG@0#8%_ch^!Bqr%Hy7S|A$X@h!VCX~($ZhknX+_;9tLLK^PKaZnK|&18;;{Ve>IFZ zN9o(`(N_9;G~9mqA{}i$89ZZ7kZ9~MH@T%Zq~o!yc}~K8Dnoyfb2^jh6iZSUKJ${B zi<$yY(4^)cGhPG&my5S$dO6>18v@B9u$&v~^ z#Sg{-mP2sBQS9RBX4W5K0Jz{m3);|uE;y#F2WyJ0!WIk^xC;J}RoDSg;3@b|mia>~ ze`Mv4_4y6_P|#Ab=W0>FzJj)jFT2Z*_=-)gx>nCMBfd@;(NXbjBfi3cg0711+T@-Q z-;C&~_^EBMW0Qdq-;G#N@q7CQW1D1lPj1ioXEhQuubER#X z%N45#c6)=XgW#rrfTM%|hNAe_cvnjE9tZcn-_N}t!_7S(oUUUccTxu&%w4dTJs;h7 z<1-VdI#W>^45SW2ljibpHfk)P29ZAv;;JXx`U{EBYSqa>s? z5uNiwTfb8BM#-m}uD*oMYr#KIO9u$T?n+HM0RR9p0ssJz8XFP-1^{MnZ!<4rZ*FsR zVQzGDE^upXMnVDrTC;xyA^{1)?n+HM0RR9pld1(@0?GrE3kE&`6q8B@KN6Az00000 a002-+1_uBD007wqlbZ$_29g8-0000GPqlLZ diff --git a/tests/foo4.pt b/tests/foo4.pt index ea3607e32c66fe8c67fdfc768d6143f9192b44c5..420a3d2d374e6f5ab8123626bb551227491cfe5e 100644 GIT binary patch delta 476 zcmdnZy_;u)BU63ennB4(|4_OP9TAvwwkl**2xA@t%41ED7zH?FPyr=hOM7h7)c{OXE z+^(?X_Z9Pn4~Hcm;{LXXEB30M+o9yGjtyIn2s`$?o1oNP(<}aVZGGqC%k%DaA1QVJ z;pTKWE>!+{sdlgPt#_?~?5lsNGznUrF72KEX6ebRYgmqk^n59-zI!n|G{@psSC_B* z#hf$8eSD=KB}92ogaGt{I5wZ)8~9%Rl&f_naY`O`ETB5`=UR=A0>9Q zNOHzo{+s)`zUciwKiMa<*8F7*@Mh=OG5y%3Bt`}XHYNs!0B>d%1_lle1|XQo!^bSd zAU8RYS(H&~@hjK0ICq4EYGUNB*HS;o7GNm63_?`2=HcO P2hlet&tjEkgNOnEFT%#s delta 342 zcmdnZvzvQ^BU62!q5ok6f!4bwu~wg%ZNye6CFpSfUduCal}5{vgDWptmBj>_ZTp@5 zhn>yg3{QhZ#_CniR|Wk^c{$f^|K;7!6IgW8RL-g}x9ws-cc;2D`q7+{5TPif68uNNs79rwK=(NOMO=2_wX}GS2rm=g*T!BG#j`B84od3k8W&3! zKF_ne9KX?Y#hp7(o;Y0vmKVAg zbc=nNs7hlhsBsklIW0ebU^pfMM{}&zny6n2b)9td|fee=n@0ka9vvV{}RpB#bWMG&N#1o{Xgc!IP((?07^^^1S zic1pnN{aOgvUA*=*cmogGHWn$!j(^6#S+ihI+=^rp7F|LFIIcOAQlEN2=HcO2Qg+& Lp2aH71`!1SS9qy4 delta 320 zcmey(y_2@i@9jC8 z{!9BDhu5sOWxH0b{`Ktp#G_}Xr8FrTUvf5&xu2Z(Jt5m+>nYa_t2Hj{T;(@K)}nLQ z3(HLpRHtm+l@ogLdQE(t^J3>2;j@H%wf=I=*J3#Oa!ua^6VD%?r^oCP&RAd-ReEc3 z-Ji_OMMt+EF`hkH_EppS(t=&5`1*R^JYp(tp4VAtd-14h9?vheU*C`XzMpX;&5dtO zLGFg+MV_-~T74**y_LtKj)gzKo1NolW#|5L3=9kcj0_ADq$GtHj2P1L^G)@W^Ye;J z67x!m^$N0c+?@0#GqUJ!KESNV$O+dm`4vk%qugXKR(r-Nle<{$1%+6E1_42UHyb;M Nj+}gpRhkVV3IImthUNeO diff --git a/tests/foo6.pt b/tests/foo6.pt index 4e7084aeb804d539a6594cd2ebfc1933da2b5790..a87986635012200634c08edf013b3e441042cdfc 100644 GIT binary patch delta 504 zcmV~k}OH0E*6on^geZ?ob5fvPFsc0fD1S^UUe3WZc+Mu(Mp{7$4 zO`A?;3hmC_{3a3qo0C3l1up}a%lSC>%>6LQDnjV$zU-a(^r__E(28F!-Ca}v?ELnU zBb#U$af2-Db}ExJ2%Wb`IwDA;mbAD@HndPdRT;IX}M4BWk13s$^b2|7KcPu_dw#J;gJ&OX~ zJ^5py9j3|Jn8SbIV*{gYyl?Rt$t@dDFgh7+MZtT_te18a*K|Y~M?zn4)_%5`ECT|h z_*}1L?<>p%Fu;TypbW^vqvG>$2qguE@-~g|m-ar5K0X1HL-;}Q1>KGmn9BPehv&^c zJO$H3xT5$X>_JgMPI+r)xTx3YA+NmkY#1sV=Ez3!CFK(Q08mQ@2sw719Do4;06+o& z0FW9P5&%E|W^ZpcFJo_Rb97;DbaO6nYiveBJdwdYv%mu(0SY;Go*aMy002M&lj{Ut u0nn351wR5*1e1#eJ`&Od00000002-+1_uBD000>VliLLv2GRro0002Wz}0mC delta 378 zcmey(^P78vBU639zxQDWfuoasuSk7(=E?efMhDwRg&C_j=A}fitetQ>=n`+|ktM&~ zJ@W5;Gp_uw+cH`^b76mLYQ>)KvwtcXR_tt_6!)|-eJTID!^fHI(=%*l_=$%8iklrG zzG~j=3^lK)GTHrozkB}e6u$V%Pc_md&})YO)k~elanmOrPg_-br(S)c%1^~IF)yFq z2RBaI@0yuWVikXa;ocgahivn7eoUMD>*|Kn+vnd;I=9u@>Un8*q+IsJDbGuOOZlDi zudZ_c?eQ*z(I$E)w_=a3!0Ebcoa`#{>jlM5ziNrQlOMu=cC~VNgK?GUNn4gZ5tdEf zCv;e*$xh5_`Tjs4s_ac${qq(Po($#hm*WGx**Pv8l3mBj$iUFSIQbm2oVYbZT7JHn zesX?ZaYbScUz$liIk2dC(@!>qo{bFVS)pRaa1yY8I#Q`Qytj^;+)9a((*U<@JkoaohCt-<)-R{i&?F_|gk@P^i8UH`Ov_WMH^5 zS%XQQ(H0o2iojsC1O}^BeeT55S$7-+T%%`w^sZb#$JarH<;jg{Wsw#tM;dL~x{M42 zcbG3+?B1&yE#Bi(UXq-QPw5yid(ZT zmjC*^tL%IDd6lp}na|tXcs}hqw4f)tm7B|t`_yff1t#_(z7976gqLt7J0)J~T6Jzi z&Y!39I}d(mcU=8L#`PkraOwP|F%O@z8zgxBJ{EEN-_)fCRZTRxd_{b)Y^WG(swrLm}PdT+n(cSJ# zU8RR{!htO{SHz1C$7;V9*>Ipua6_L2bJkA@zlZDDyxa5F{1N+K-<*HonQ2_cPZmFg ziwUm_baVJWa96kHA2fRt^LYQ(nygbHhdtV-GtYTq{5GBu6nPsv^m=CjBd?Q*fnkD_ zq!5D|Lt1{mxqfnfUU5lcUP-ZDL3WOtlM*mCl_#+*;)G|R$*Wn6WZ+Q`)D6ml2(Xf2 zvMsBGc7Qh{P)dY>gM$Mshish!LN$=(I@yNRmnn;7ayP3T8)b` delta 876 zcmey(yPIc1nNVp`A(LK!H#>*Hu9kz=3=9m`6PL*|I!^w_D$eZYWILIW)vmrb@StCd zfxzC+qQxyNofRqvXS4IaI=ENTSTaFi%a$cmUOczfU9)Mo%(qKR*1tbnogbw=OU*(e zx_tfCx9;CRI3HeNQD<~^jfJKX!;y(Pyl-y!bp#x`UG1=XXRg0O`pWBj|1~gaB~MY{ z-R|Jqyza5$yojx>FXogxpV5mAp4ibLQ=ePeUp42h_t%qKy&pW2<6EPi+|_qx7uX30~zdA-k;ZkzwK-u=<(pZt~oFTLJ&bFa=XR#4a-k(UjXW@KQP&&0qG z;LXe;!ob170Srr2puivw3|2*8unGf%RiwUmg1z5i1Ce8sxBNMMhS6r(q9+Vf9oC$1 zN_9}0{qzW*U*@A5h3pB5O63U;cw`!^4Ga|A+yvSmSmx>oIK7I_`8#j^{`b|}9b5Ju zX%3qAVZ(JT`P+~A>{ZPj1g}h-XeIn{jX~k$my67nJaO|7yQTKna#4?I*~l&zQ{pc&6vVKgFoJqGPMwUU`~lc^7ZK5Na;;JZRD0r>0A-LL4`7^BJ98 z`H5k5^TOJtY%{iIb?AREwWt-_`#k*OX7l6sWPcaDxx8`xyty{NJa^_jm=sqO zbfdwgj5GYC;z5bV?-%Dc*l!cu8NTqA0JBci*3R>VTyk;tHR;doJp97T)+jMr-eC9{ zf5L2+sN9Wr`8@iJ<=%IhTraK4wK)C^6H`F!?X5gm!>8 zBhXkx+#y@0fTBTr@*h@TroGIQz1i#-MJ9K#*$Yi#fv5=ZW@87jZn93k%_hYL5e5Ls C7+WU* diff --git a/tests/foo8.pt b/tests/foo8.pt new file mode 100644 index 0000000000000000000000000000000000000000..e6bee10da87fbd9e15e8c1abae6798d1b7a1be35 GIT binary patch literal 1595 zcmWIWW@cev;NW1u0K5#M3~BlK7WyfPC5d_k**R`bf(%jUpn)klKE5QsC^;iOp35aO zxx}?1F}ENmm8+0Jv$0kq0%UMWQDSCZW?p(BV@HG-kRP8}9G_O2lM|nmn4DdnSd|5Emxrr=;q`ouOAy>E@({ z=~Q&5gZZI-j#)#^;MsGjPUb6Iz&e#9d zFWP&y9*XAr^{;okY4HDp#ivbX2sadKbKUdsTvhymsmv)&WSv9YGs-F%gb`s+vZIzM|VEoo%r||Q^l1H zGmeDKJ>#2qwLh)TGk%cyQdB5*c7wVZk;O>^Mi+{4lwOk0Ea_f%G z7lp7tTFIhW2OKyT`fk6QF)O89ojw1an7^EeP34977bm-w@^VyB(M zO(i*>x}NKsBwKAIP4cabIGtbReSeS?Gxd-3W-09Cb&NI^*I|Tcl5D>mqDM( zteUgma=U6397N*(>$>jS+{D1F6~G$Ak)=s%DB4`Nj{5Z9P?IP;|Im>zZw03k-)gU1&v?q8C8fu9TvZN z#U+V(CB=~3s!U9bHZW;|b6l=IXE?ARU;trIj>B#h4}%lLs`R!Hzzw{EI~^% z$XsJCkVh}hOAloN+6=-0-i#m$UM3-z1-u{$6o67fL9`+Y3*`I(DtHiJ5s(Siicxf+ zn}Hk$3J3#$tdqc4!EOd3U!i*jIhch|Ok2qeHw`Hk(e)z-6c37i5nz;oT?37p0B<%n w9jFR9W?i^iRwxTby90dz0+T>74+KDc>>&CkQ~^jjz?+o~B*qGaAoUQn0LF3To&W#< literal 0 HcmV?d00001 diff --git a/tests/jit_tests.rs b/tests/jit_tests.rs index f3381d5c..5118595b 100644 --- a/tests/jit_tests.rs +++ b/tests/jit_tests.rs @@ -1,4 +1,4 @@ -use std::convert::TryFrom; +use std::convert::{TryFrom, TryInto}; use tch::{IValue, Kind, Tensor}; #[test] @@ -160,3 +160,20 @@ fn jit_double_free() { }; assert_eq!(Vec::::from(&result), [5.0, 7.0, 9.0]) } + +// https://github.com/LaurentMazare/tch-rs/issues/597 +#[test] +fn specialized_dict() { + let mod_ = tch::CModule::load("tests/foo8.pt").unwrap(); + let input = IValue::GenericDict(vec![ + (IValue::String("bar".to_owned()), IValue::Tensor(Tensor::of_slice(&[1_f32, 7_f32]))), + ( + IValue::String("foo".to_owned()), + IValue::Tensor(Tensor::of_slice(&[1_f32, 2_f32, 3_f32])), + ), + ]); + let result = mod_.method_is("generate", &[input]).unwrap(); + let result: (Tensor, Tensor) = result.try_into().unwrap(); + assert_eq!(Vec::::from(&result.0), [1.0, 2.0, 3.0]); + assert_eq!(Vec::::from(&result.1), [1.0, 7.0]) +} diff --git a/torch-sys/libtch/torch_api.cpp b/torch-sys/libtch/torch_api.cpp index 091834f4..bc2f3ef3 100644 --- a/torch-sys/libtch/torch_api.cpp +++ b/torch-sys/libtch/torch_api.cpp @@ -1212,11 +1212,27 @@ ivalue ati_generic_list(ivalue *is, int nvalues) { return nullptr; } +using generic_dict = c10::Dict; + ivalue ati_generic_dict(ivalue *is, int nvalues) { - c10::Dict dict(c10::AnyType::get(), c10::AnyType::get()); PROTECT( - for (int i = 0; i < nvalues; ++i) dict.insert(*(is[2*i]), *(is[2*i+1])); - return new torch::jit::IValue(dict); + bool all_keys_are_str = true; + for (int i = 0; i < nvalues; ++i) { + if (!is[2*i]->isString()) all_keys_are_str = false; + } + bool all_values_are_tensor = true; + for (int i = 0; i < nvalues; ++i) { + if (!is[2*i+1]->isTensor()) all_values_are_tensor = false; + } + if (all_keys_are_str && all_values_are_tensor) { + generic_dict dict(c10::StringType::get(), c10::TensorType::get()); + for (int i = 0; i < nvalues; ++i) dict.insert(is[2*i]->toString(), is[2*i+1]->toTensor()); + return new torch::jit::IValue(dict); + } else { + generic_dict dict(c10::AnyType::get(), c10::AnyType::get()); + for (int i = 0; i < nvalues; ++i) dict.insert(*(is[2*i]), *(is[2*i+1])); + return new torch::jit::IValue(dict); + } ) return nullptr; }