From 95b3c324cea4509d71d160840c4cc232b7028bbe Mon Sep 17 00:00:00 2001 From: "Anthony D. Blaom" Date: Fri, 11 Sep 2020 12:00:36 +1200 Subject: [PATCH 1/5] add a minor check to the arguments of range --- src/hyperparam/one_dimensional_ranges.jl | 29 ++++++++++++++--------- test/hyperparam/one_dimensional_ranges.jl | 1 - 2 files changed, 18 insertions(+), 12 deletions(-) diff --git a/src/hyperparam/one_dimensional_ranges.jl b/src/hyperparam/one_dimensional_ranges.jl index d3297ea7..bc1801a0 100644 --- a/src/hyperparam/one_dimensional_ranges.jl +++ b/src/hyperparam/one_dimensional_ranges.jl @@ -44,14 +44,7 @@ end r = range(model, :hyper; values=nothing) Define a one-dimensional `NominalRange` object for a field `hyper` of -`model`. Note that `r` is not directly iterable but `iterator(r)` -is. - -By default, the behaviour of range methods depends on the type of the value of the -hyperparameter `:hyper` at `model` during range construction. - -To override this behaviour (for instance if `model` is not available) specify a type -in place of `model` so the behaviour depends on the value of the specified type. +`model`. Note that `r` is not directly iterable but `iterator(r)` is. A nested hyperparameter is specified using dot notation. For example, `:(atom.max_depth)` specifies the `max_depth` hyperparameter of @@ -60,13 +53,22 @@ the submodel `model.atom`. r = range(model, :hyper; upper=nothing, lower=nothing, scale=nothing, values=nothing) -Assuming `values` is not specified, defines a one-dimensional +Assuming `values` is not specified, define a one-dimensional `NumericRange` object for a `Real` field `hyper` of `model`. Note that `r` is not directly iteratable but `iterator(r, n)`is an iterator of length `n`. To generate random elements from `r`, instead apply `rand` methods to `sampler(r)`. The supported scales are `:linear`,` :log`, `:logminus`, `:log10`, `:log2`, or a callable object. +Note that `r` is not directly iterable, but `iterator(r, n)` is, for +given resolution (length) `n`. + +By default, the behaviour of the constructed object depends on the +type of the value of the hyperparameter `:hyper` at `model` *at the +time of construction.* To override this behaviour (for instance if +`model` is not available) specify a type in place of `model` so the +behaviour is determined by the value of the specified type. + A nested hyperparameter is specified using dot notation (see above). If `scale` is unspecified, it is set to `:linear`, `:log`, @@ -84,6 +86,11 @@ See also: [`iterator`](@ref), [`sampler`](@ref) function Base.range(model::Union{Model, Type}, field::Union{Symbol,Expr}; values=nothing, lower=nothing, upper=nothing, origin=nothing, unit=nothing, scale::D=nothing) where D + all(==(nothing), [values, lower, upper, origin, unit]) && + throw(ArgumentError("You must specify at least one of these: "* + "values=..., lower=..., upper=..., origin=..., "* + "unit=...")) + if model isa Model value = recursive_getproperty(model, field) T = typeof(value) @@ -172,13 +179,13 @@ function nominal_range(::Type{T}, field, values::AbstractVector{T}) where T end #specific def for T<:AbstractFloat(Allows conversion btw AbstractFloats and Signed types) -function nominal_range(::Type{T}, field, +function nominal_range(::Type{T}, field, values::AbstractVector{<:Union{AbstractFloat,Signed}}) where T<: AbstractFloat return NominalRange{T,length(values)}(field, Tuple(values)) end #specific def for T<:Signed (Allows conversion btw Signed types) -function nominal_range(::Type{T}, field, +function nominal_range(::Type{T}, field, values::AbstractVector{<:Signed}) where T<: Signed return NominalRange{T,length(values)}(field, Tuple(values)) end diff --git a/test/hyperparam/one_dimensional_ranges.jl b/test/hyperparam/one_dimensional_ranges.jl index 8ac231e5..dc559f23 100644 --- a/test/hyperparam/one_dimensional_ranges.jl +++ b/test/hyperparam/one_dimensional_ranges.jl @@ -46,7 +46,6 @@ super_model = SuperModel(0.5, dummy1, dummy2) @test_throws DomainError range(dummy_model, :K, origin=2) @test_throws DomainError range(dummy_model, :K, unit=1) - @test_throws DomainError range(dummy_model, :K) @test_throws ArgumentError range(dummy_model, :kernel) From 839cda8b2b07dfca2095098cdba41a0a66497ab8 Mon Sep 17 00:00:00 2001 From: "Anthony D. Blaom" Date: Thu, 1 Oct 2020 11:00:17 +1300 Subject: [PATCH 2/5] Implement new save/restore API --- src/interface/model_api.jl | 13 +++-------- src/machines.jl | 44 +++++++++++++++++++++++++++++------- test/interface/model_api.jl | 37 +----------------------------- test/machine.jlso | Bin 11198 -> 11096 bytes test/machines.jl | 2 +- 5 files changed, 41 insertions(+), 55 deletions(-) diff --git a/src/interface/model_api.jl b/src/interface/model_api.jl index d88adfa7..ed9540a1 100644 --- a/src/interface/model_api.jl +++ b/src/interface/model_api.jl @@ -36,16 +36,9 @@ MLJModelInterface.implemented_methods(::FI, M::Type{<:MLJType}) = getfield.(methodswith(M), :name) |> unique # serialization fallbacks: -# Here `file` can be `String` or `IO` (eg, `file=IOBuffer()`). -MLJModelInterface.save(file, model, fitresult, report; kwargs...) = - JLSO.save(file, - :model => model, - :fitresult => fitresult, - :report => report; kwargs...) -function MLJModelInterface.restore(file; kwargs...) - dict = JLSO.load(file) - return dict[:model], dict[:fitresult], dict[:report] -end +MLJModelInterface.save(filename, model, fitresult; kwargs...) = fitresult +MLJModelInterface.restore(filename, model, serializable_fitresult) = + serializable_fitresult # to suppress inclusion of abstract types in the model registry. for T in (:Supervised, :Unsupervised, diff --git a/src/machines.jl b/src/machines.jl index 5abd7915..76d4257c 100644 --- a/src/machines.jl +++ b/src/machines.jl @@ -596,6 +596,10 @@ report(mach::Machine) = mach.report ## SERIALIZATION +# helper: +_filename(file::IO) = string(rand(UInt)) +_filename(file::String) = file + # saving: """ MLJ.save(filename, mach::Machine; kwargs...) @@ -608,10 +612,12 @@ Serialize the machine `mach` to a file with path `filename`, or to an input/output stream `io` (at least `IOBuffer` instances are supported). -The format is JLSO (a wrapper for julia native or BSON serialization) -unless a custom format has been implemented for the model type of -`mach.model`. The keyword arguments `kwargs` are passed to -the format-specific serializer, which in the JSLO case include these: +The format is JLSO (a wrapper for julia native or BSON serialization). +For some model types, a custom serialization will be additionally performed. + +### Keyword arguments + +These keyword arguments are passed to the JLSO serializer: keyword | values | default ---------------|-------------------------------|------------------------- @@ -622,6 +628,9 @@ See (see [https://github.com/invenia/JLSO.jl](https://github.com/invenia/JLSO.jl) for details. +Any additional keyword arguments are passed to model-specific +serializers. + Machines are de-serialized using the `machine` constructor as shown in the example below. Data (or nodes) may be optionally passed to the constructor for retraining on new data using the saved model. @@ -660,15 +669,34 @@ constructor for retraining on new data using the saved model. horse](https://en.wikipedia.org/wiki/Trojan_horse_(computing)). """ -function MMI.save(file, mach::Machine; verbosity=1, kwargs...) +function MMI.save(file::Union{String,IO}, + mach::Machine; + verbosity=1, + format=:julia_serialize, + compression=:none, + kwargs...) isdefined(mach, :fitresult) || error("Cannot save an untrained machine. ") - MMI.save(file, mach.model, mach.fitresult, mach.report; kwargs...) + + # fallback `save` method returns `mach.fitresult` and saves nothing: + serializable_fitresult = + save(_filename(file), mach.model, mach.fitresult; kwargs...) + + JLSO.save(file, + :model => mach.model, + :fitresult => serializable_fitresult, + :report => mach.report; + format=format, + compression=compression) end -# restoring: +# deserializing: function machine(file::Union{String,IO}, args...; kwargs...) - model, fitresult, report = MMI.restore(file; kwargs...) + dict = JLSO.load(file) + model = dict[:model] + serializable_fitresult = dict[:fitresult] + report = dict[:report] + fitresult = restore(_filename(file), model, serializable_fitresult) if isempty(args) mach = Machine(model) else diff --git a/test/interface/model_api.jl b/test/interface/model_api.jl index 018deb1b..5dd8767d 100644 --- a/test/interface/model_api.jl +++ b/test/interface/model_api.jl @@ -6,6 +6,7 @@ import MLJModelInterface using ..Models using Distributions using StableRNGs +using JLSO rng = StableRNG(661) @@ -27,41 +28,5 @@ rng = StableRNG(661) @test_throws ArgumentError predict_mode(rgs, fitresult, X) end -@testset "serialization" begin - - # train a model on some data: - model = @load KNNRegressor - X = (a = Float64[98, 53, 93, 67, 90, 68], - b = Float64[64, 43, 66, 47, 16, 66],) - Xnew = (a = Float64[82, 49, 16], - b = Float64[36, 13, 36],) - y = [59.1, 28.6, 96.6, 83.3, 59.1, 48.0] - fitresult, cache, report = MLJBase.fit(model, 0, X, y) - pred = predict(model, fitresult, Xnew) - filename = joinpath(@__DIR__, "test.jlso") - - # save to file: - # To avoid complications to travis tests (ie, writing to file) the - # next line was run once and then commented out: - # save(filename, model, fitresult, report) - - # save to buffer: - io = IOBuffer() - MLJBase.save(io, model, fitresult, report, compression=:none) - seekstart(io) - - # test restoring data: - for input in [filename, io] - eval(quote - m, f, r = MLJBase.restore($input) - p = predict(m, f, $Xnew) - @test m == $model - @test r == $report - @test p ≈ $pred - end) - end - -end - end true diff --git a/test/machine.jlso b/test/machine.jlso index a96dc86c13feff3e81234c1cf53028836ca17e6c..fa4ac1e7bd3886b5b16a7e6c91b9dcffde86f965 100644 GIT binary patch delta 7704 zcmaiZRa9KjvTf72LlWFA1b24=!QGw49fErUo#5VRa1E}(AvnQ92*KUmgUjvY+;gst z^WOegYmHhpt7gr;KlT{&tIn)8oNfXf5pV=9L4p(|01g0vX9vhZaS`9Vpa1{>Kz0Bk zYx)P}x?gKLN0z_Fmly#AX>3xt9Y2;ur+05TLWSmyr&4Xp!qp__qBO;6zg z6$umS&;WR1I5TJ(yfB5mopXs242{SM@PKUqiYSyn8vyW<>-v1n40-`?jO{m}+vfg{ zC_D$i7;1_@1PuAhn)|=8ru)CF0mv|vhP$h+lQrxB!>$7Ul9v55QAIKsFa<}{o0a;jbnuGY*S-AKGI9PaKNh|{VR%R>~=Ik&8 z7pEmVF9@26Aqlm`;DIh<2tpMxGr(vD=FW~TAa^4)B}I89XA4UQIVX2ZS1XXY{+Fit>NrIbjJJ?jUzJNs!yW`2MBE1*4_;N6>!)VNniNc51Y@FuAOq zU2V-l4ic`eARjkM;eX7@|A&OCtFyJMrJI|QB}|0BO*mL#m0;RvnA=)9x!YRVnrr&F z{HrP#Oi)trAC+KOkSoaX5AL5WIbg%8x>~xs`}_|k{C{K|S8oDiI+Uq_-AfZLn;Z*uu9}Xvz-b{@}65eM(SJ(2nvD z9H2Y3L{rL|zCy?K9b2sZ_VzJOEO6@fX^rjqE|o&`Dbw@e^sMt?YeOgFu4!ysneBe; zsekGII%DDb{G7)TeC+k?##FuHz}M79RsHpTeXRTuyw|uA5U=C>{G|SL@px?#@N{|> zM1U zZX;g;d@iyNQm#w}7ciF`p5)rbmI6J(_&OJz_L}PPq_WD$_4oKQ(JDvK%G6#dMOY-V z7+x`Xw{(2Ix_B7D?!2vj{@NlSmQ`Q$aP#ASO!=sBbbpPW^E}xz+RUC4kK#lF9b9m| zIyP9am1KB0hgIFOW*--&*b)6MMUz#ne6>x(vamv(ucHBEz_Pay7r1r=&&%XI!S;aO zVR9u@#`oZ)kg?S3e`IpxtY2{8ZDpX;%`K#?&A0=hV>pDQb({ zaQXveBMH;8_cL8vlF<q@gO+-fHSd;8L9&>~uAsqsR<3U84B!%d% z|GfA?Y&ox53ws$_tGAXzcX@nip9I;0ai&io2&?w|sPQ7@-sHr|saXy%44yNo80sr_ zBg@m7$LcrJ3<2cg!+7d~tmz$0t(SE= z#X>EYf|?wHWgx{Umd$;-EU*sSV|tLGkdSNdU7C8m)v8~!yTXvSp{=us#m;02Z6dc`yv5IFjPNc=o`*&by?XV+Z5#7Ao|m3Y0S;~7j&6TE{yy#%Il{taAnCGv z?pZnE<-32=fdOeSpC12Ag&eIIG03UhDpfVBV^dqGI(KV_G1mA(#0y-Q%UHw{@iEu6 z$;i?`|CeS^sSoa_*PY?_9i-Y1LVLdjRvbhDE^V}PYkV1DOP(4OXKnUvMwXFu%-?U< zea2vp&_7%hT+vNES$*WWO~1RoSORSGJZbm(@YGr`w!rlQHKj_xu~k(#v1^516hz zs|Lb8jqhHXvCYr=Bx$=-#PTDNSWlz{>m`m4d!aMi!7~jPF?ddaU%DZ9g-SD~+j@be z!#g?`oQRi5tE{L5cegUCExJua<*SCut5*W=$3*pcv9WQ;uy^uY1Ixo`1FjcDo|m)q z?RaBUKWv&$%hkVcoR5!&SAbI!A1Q3hS}Of5(huB~eFsR>#cPCDO?9M;0Y9<>%CggeUw$c0+@JCqKdIDQkKqCASsC_)UG0kg8>YKX!c)X>zID=cW ziuN7-D7A@#w{t(%X5hbIC-b$Vd51CTAtQ1G|FS8|O|d@%1DkDbI}D=Qp2d&P0`X997PL38xt8R7i&CBAQi`cZ?EuU7#Or02N8 zToOArxOk;@e(~#`08WkpD{$(4QT9;P!g#0hq>p4%!-CafNqO^2KH89A|jNv_}add49?K{o+WG^b(+KGB(=wWv!lS$Fh!xL(T$$i9AhL)6Lqce zVE2$H{Jb2YNt3s@0L*$3hkLxPrz`zM%B)B%luye6Kkyiv_*O__3HR+OR`k6-TMMV= zMqHHW&lOVng&{@zcIAX`@qRk{I)Fas7ccY@chr;{EV++4twj|_#5qNjEKyx{9r7K1 z;=O%X6YKf=K<<_0i^>Nwa68gxGGL1&HCCSv9E~45%L+7BHe%ZKO*V$TKn;D6%FRN! z>L&q79-ka7s`E0BjRZxk1pQR^DtepS^_lylwzjqmt%Rg^=J8`erE(njhhV9j7*#X) zK0fRZVDZc%JyMQbjDgQ@>|Zc^$IVPWLp!PKMcA)MeQ#0BFUWZx`GwyIs9EX7AX$4Z zu5QaO#ymjj(4bTmPE-JKoSQs??;>dds#@Qj*1X7*AvYPjUInOke;zOXsNWiwNL`A_ zp}TI7_`|2z$mrw62bYk@kVo|Sb-ab}LVB01CNLdlD=HRn3IHl2~b z0k155BTslKu{h}vE#MA%cxjT2)Vj`6Dt0l~q$QTdsux!l*QPst9#qe=E>v=smWIc( z(D5dQ5V+4+zKAN)-ae$E@-{+#7)_HBcl2S9i0{FVx<_t_JIvOmIQ(J5&{ z0JzDHxuZs|K5|2FNx~5|ot>GraHC8c2E?;n-I7Ga@in|}LJ}c{!IC!qxnF!Hjlqnu z6xN8?jTr6PGz6ra`7bO!9$BTz^gefuhN%dn&GBrtxhgIeNH{=44@PkE=(>XJLr{Y^ zr$Y|$o^r=F1u~ozknLvXGX~uZ`FJo1oP-v(bIDS!CRAn~x0dkh+{3miZ^E^d!=%NZ z%;_?5g2#O_T&mgg$f?nt@D(=M4jIwmUf z>#raRLR4q3gKwJ`{G?iMV=0cwn;czh!LzH`60HxCk?J#`Z(ri1B-=&I_T=UKzEvDqpRfK5E-xB)Wbn!j8uNH2cx3Z3TRl+_ksU zc5M*o?alIRCc!gfYwr8Zu#Ory>kZ7kW0!KrPoq+Zl%#ekM^w5j(nZO~KVI|S+)?QS zcTtEHEi9>cHFXBu-CiAkU4FdVn~CEik=^hQ9$p4sORv7YfM$Z1dXr4tgUnDV+f*KFX>q5TqQ5Ntm+4-d^N>*fj z`IPA6kt{(yVY~6bZ+y>=Lc<0c8-c2bQg~%`%AZA%`|TU})*dz`!)Vp$NpPF+yCTPx zgrD^xJ1V_$-H-W)UK6jaZCx8bx>Yn|#eI!Q5VAy8Hf50-fwS4kb@iy?7f3vZWVei) z<>>)-&u@m!dV0kJR4-0n{Gg?h48;Qq?;;cwuzZg?^och91gW2*RN^N?lM1KS*SfB4bS=IabfWa3eX}Lge>s|0~f2}u!;t-k&^0goJ zNxj07PEv#dNAJU|Zo-wuj?qmynfa2-4iZU~iLdaojV z5VkkX$tVDs$AQ5>CEVapxkWoJr5NDHgIs4w+nen^f;%X`NE^N5(qec9AGuWDEyxAFxi@KgF_8$&(Z!#%t$-SVE-Vzkg#c=4 z*_YH=B4EmjDvN{wJjseP>$g`~Z+B{%y)v*L;@TFNw(U&Vm*1X&un+ogGcG@7_fB!7;y9_lEJ)+vqeM#A(DXj`n%qm(`f!zUG+H1VX1xnU6pG+tAm&rv28$b zj)ZL`o&?xe=0oolWRH?fqJrkm_{??;0qWmSW0^HjF~V!t@cJz~O=4~Me14+6T1rb_ zY&*L-gWA!QkpTbQ1%e@OajMnR)LdXyt{Tusch8C2;`eUDTW2@O;y$r1x;tiUSt(bM z(U(v`uD5dtog}Vy{0cu7eQ|_d#4wbgDHH{HU{NINA9*~#OE|UeuaFxS+nW&4?G|ef&I}lin|Xk9eRN8Z z9|4n9To!u(5g@h|*nwTBJE1>^T6T1i&Aw2M*=I4S+n|LgY2m!cMQmo19x_)b*bpaa zo?^EBN#=*^cyVFH$2Mh*{Y%TB<{h<*yD2tpk_W=P%VchlgZsKqf-GPfRGI!On);(cHY8c>lJ#ZlAPjZFfi4{Je#}1OOBA7aP_V*D)!C`foPj{&+V>C; zelnuxSKl;x1fHIjkdKAsm1qRB9P1l!~HNL4hV1s|)R(>4_h<9-s;eet0> zU9+6&?{<9{&PeLI{l;f}B{b|R)#re@v5oh2AYpr)=PMVt>!%8m2dUKWDdb^JT9OO; z$3$o};tf-Vsg*|TikPOjDh`g>i1jBMa`&G7HW zV>e|6?Kgyjy3KhAO88fi+N8+b%yN8+$n2-gWhA`~-;9&KbV%8Uu=&34AvkU6=Q{sQ zU53qf@jKX0M|iompY5{oap_$f+{gjO?}%{vks5c_SYAOQ@6SlR`k}vtql(kOju6`) zWo`)He99uDo65HAaw+9|Pf|4>-Rr)s4ju=Z zI#eb%A8`JTXF!NVNXmO@+;xWmZh6s+M8m?iSRwNwcwo_C3bU0tP8$IGEdzy8i=#BD z!HwKd%BSN`A<wi=D{8OhgJtB@}9%}*Ed2!}=(jp9mv?`9C{GCwB{5!uqK z?zmeEGSGfMY?P1lRdD@K2JSJLTp(`CMqI#*_=Uwg_|rIq<14f~!t^jhaM0x?4!sMA zorS?N;LX(WXXK}fxdR6I5t8*mgN^uu#YB5STI#Thf^1>UZ)yFLL24XY`bD#-@wxS7 z7^v3z7iO2@gajQ~1Xs|v>}2|S-+x}5U?~s_t&Zl4i&JpC?vp^%R{fc*W)O?S{$u;C zYkZ*SFG`59*-&1U3rVifd$2D2es$wAKJKP zynO+Q0vX3Wh7BaCNEOzu0zp|F+H;hs{JKXk{3C`uOXrU?J94`0J*L^ELno&jd)a{E zi#GUy&0oS*D1H$h&|vaaOC@keCxr8R1yI`Yk5%d8{L}{(o!7naC>O|n#{D_5!;N(zQ>@0rEZktD1g|(mO(qe}v&R>c#v=vRnS8bFsA6q#I!Kw! z{Y86M6{T;S33DTA(T#ZID>aXb-s889hdtxy(Dkxr0u&rk#SKEP$Y^``=j>pYF%w?JQem_O}54!9>CY#g)P}plxDy*m`@AO!~T(Dc#3Ll z|Jy3%CeAdr2QBAME49{J1Ocz2cPq8axKCn+;-p z@m+5~pauOqT)#X7FB3m!AgS>iBcT{wzHW|wITSBTe3hGQ+|tV3Tk=6fT-%pY z!!$S~e^nb7d?heUTl9@0wVRxNHcMcHY0hWNeB4es*`-dKsqKVvn*T#tmk#+RoS8K* z^2j455c}pD(|P$>^928u?^GkfhN%6PBr=ZKhs-H?b>C=#}3Uf5EDs#A$ zma`zjVTE_dR86A~xYV}lkcXJ_F;EQFU9^iV=TZ13SIr5BF-m-InTgo)A#bOt(Zraa zc8mY1Lj(2Mg9i5Z3yqr9l_0~@u4!E6-VCV>#^#}W=9j=)({e6EQHNM~b~{Sv!rpj1 zQ!soaJtRd$&cj~MR`%lCWEJLfHNxqWYqpP-;}Nq}+*gq>A|6bB>=b-9WA)z*-iCy{ z1jw=#aKLxPg5}0D$7XqOg{xM8WU{i}Rl=?z7Mm>432TmH#ye>$%FU*envnXX=VkFp zH~h=?La@(~j))FkYCM`JQDy8T0mJER7x=BK_BQ?}r0R@YNH~!dv+n29xIVWA#xZEg zVs3blXX{L|+SZW!nE6Fq^T8YqJA6xap6Z-FVU{Q`R#fm8m15ZqdT}oeyT<$X+X=#a zm_j{4o{C;HSXn|rCQ0nFBT?E}`g zn5GJH#nh^wG~2?Z5CMa1_UbZhmSn)ayJI2!}tvl~^LbOgd*(OGQ zkXz`Bylo5Hmcy zz(NvEKCl%Awkc0{(oRDKsKxi8H&N5O6-8Y7N+{YvoK-K9zQJa1qT3T VkpAOAaDrJ@fPVw{>yQ1-`wxE3uW1oy@vK+xb4+@*1C0tAN_cM=?e26y+M!6kT*Kydfq1VT7` z-+!^s&H2YZH?`KNsyVCHs!>S1gAk$r=F)IiD-a@2ozFW2gr%fo1P(|A$+OR0(mD$% zg4S~rcvzoj{@+aip*$~ZJKMpXpCgi`xw|g4j^O#LWhHi#} zi2;)}=@$;fJf~?9{=c?)6y{Ys>Ui7v%f9P?ZH0agCq@_6pZ~dkbdgV3-^e>wPAS)b z%0|(1uB;VJuTQ@2{(G{*)$OPA+xapLcp96W>?-7d=R7W(=5!kEbXPU|Jk^@0;Wz_p zorWtNM>wX__bomTfZ*f#(e*#9Tr>t&j#D&*YPdY3}Lqh$Zm*cm=R$Yj&^> zQq%BH|I$?GrcO#wo7bvS(NI311A`1MlxC(Fao0syLHL`bGGvrZ&}}O4%EFkXw40KP zl~gx=*mPg|gnYUs8D{fkC^7FY@*9UQG%h(0)&cE1{7&$NHJJ8%96E4~np)EB*TA=3 zSt>%!>rLXoE}Th`w0o$W41`u~;1bCW4AExds#FW3E3u$Q7!!QQN82-=jiQih&cf6; z9(j!K_I|jz7Mky#3GX~n@*OUP!*xt+s=Y^ZzFnV;32Nf5q^^0j3kB(E4)U}omREMo zGg`#_u*)_`6y~F3Ph?xN7L@{h)G0};W#nG>r(i&Lgr!`?kl&6fyEi)^u z47D%|+A@-H^pYE;7`gwRRnc&*`*8hqgQ5m+3-Hy{9Zq_;pEg)x#QkM)y>MZ0u*(@j zvP^a~JPnl%@pIY{0>mqr*+{C?=dSLCRp=QyZ54@3jlMM(%)|Ud@!PxEW8;)(iG2+l z?e8d5wKdh78FZE2Pp41Sf0&>iB7Uzt!ouPRK{KqWjO@{-#WLJd;x!mwI5l^SRQG8x zUQ0W9ymvn+p2pAL@ce7n{w1gPPE<9Ocp-tm+(50(&Ux|p5GYS>Q-sxIc$PQrJ4W>( zb?ih_5V54pyzGh6dr$PXB$0su_f6J7HM*|7U>KM48d&{xVxZ~8^`RK3{)X?TSi_Rz zLWSSQJr7bJ@fBIBEX}^*@S4@xmHz1N&Ie}U9H`>Z^%t9D1d3hDSzvWjJM*O65nI@6 z{m0*RYi(yXF2D&(S4PZA>f*w zcUG3I%_B<}lL=}E*_qNfudP%=AC6XVcm?#r9(PEWlKOCv)&Xm0G%`%xEA{p`= zG+U2;9>Cfseb0FDBbz1Ag~($QvtY^3|Mk}A+sJ5RP8Z3|_1iQANGJ!}VuNcXhLmLN_ic%L|7fNV6lK3^5BkQ88&|SJ7G&+e7I2Z9;kN!X zp0|L5-is32dCT!7c?RMEp=X^vo@&L8!1NAC)sT-78Gp`gG~VF@JY>C+Wx66F81pH%x136Yg{#B_yvlO?dN0-)O6YV_M(!)_7osye;f&LCy z4)$sfVyj}tpngYu)b4#uG%T^c>a(I}jo@xm8U2>-o=~O=?F@BlMk;de;xQ) zwMJJ>ywia6>iHdIdb~;rM=-0vc1eM6vFUQ~y#@O4X(3-;t0AI(4>8kxsXL*Ey8BJ zMnn~xCZ3SHBR7$OQL=*#JTrqdoMF;CuV z*YGDOFNWK+TX;0n{yTMbC5)}S=1wi-3}zkqw-6C|!dd<@bF0rM8(|G(tt+MZ9DOE< z(Rwe)L9_`CkP#-<;rh~L3Mt^qgi@t=TMD|7Nh19-1tsqV3uGOQ@Wl*akW(T=3xmE` zDf*aTU}Zu`<+$rUrd5Q8>=lo5!HT=>Xo+9Aqo?i0ykR zDY9%eR_h%X#)agefqD^`eZ32{)7?(+AlMh-_!Z7rdA2cVj1#gs?L4`ec8mqM>Dt*$ z<8^uiA))1|7`mXsy$=Iz0#WMgQpkv6+93m!>vhe?SgZLGjTopnG&7dG1(PaYv_f)b z1>UrVT)<5*m4^sPD**K#_hN}q%|2gSCdlfu(pLMfEFIe$IueP;F){bRC4G>{kEB4v z!5EtM^=Vy1PlkY;xJ+V4d`?dd&x!6+=gmVu2TDsK8zxtwg}uo4-Gq}ko(?tb<9uB* z+|zIxEo}>Gg>54x&pCNURt>%)wj&rO?5eW4V|lX%zEGP07@Kh}KFQi3T!MdrZ{IP0 z?R#|?SecMqsZ%cHBze(F`!`+*8u~em9;;~%fn6>ojAx2!`dzb4{q!mCMX47*N(7j`=_-hV_){TX6e_@Hyt|NdQsa+Y|7PVU-7=-nNvwa&SZ}c@2==lzq&aG79{B-4Ohit%@!rvt)KP60nICDNycCYkMd^ z0qp``Z-X18rR-OlkolaDB)5i~@hiugmxGuPhmFy^TbdVIj#va#88(&OL`_Zt<9GC8 zbY~ir849fpAYJ?F9RemgFG+=XlWA2LSbRAK8Cvz^Oc0ZXV-P!!+Gz!>IoD>uj3*Q2 zR3F367rjUbrBy{;dv|CK9f#N9FQ%gb#h1XW&-;uxdl>@~tgpg)`bj*}5iyt^C`^W- z(i$j6I3yGq2Q~>dJoY36Om@CfDkWZ?pD@z38saP+K3HH&?x55eZ4hLRAviKv{o=u? z7d{A0X_5_$uf3zsqrx;|Wbe4*S7O2#AML*PVs#^o;Lj zBc4Hn{@EOjf~HEf9VeZormst8fGd67t^mut!EP2O`-QX0P!P48hf5!OK}<@h;0$Cz z2@G8{N!c;@lwiXzs$J9TJa!uy9<&1d`D-)Qif4CL|0Y=9)Q5J-JDo7K%}QTSCmT&_ z?IY9yV|kaIJ_YqS+{A>rJQAgF8s*C}DgW#GUlAgw@~KQBDEU^C)pBSa6HOMhzj{LU zp!oN8+TWu5v0Q2i{bf^0KfRnOAH{_Q-BW_p2LmFFXh+R{+!rQ(Ug|+8EMsN^%0`5p z`#b%^5Fc!RyBD`mP1PP%EI?wxTh!R7zANxfJqj^3e9=E##cKiEC_dSQ6N!<7QeLvZ z4yZI0bqt8aWN{IODLBI$$HG-ruVeQPEeLotNv7%X7Ut+;lDQ-r15A8(ldh^4&(YJx zlcpkqf6nYnJPM2GqZt`;P^C=)YC#*fnsqE~JtHiA(u6Qn6@lt#`=ryoANus*A-nR~ zxQC8!4BdY`u5+H=`S!Uz)fvtm?8BF0Rwz{2xNPL=KhkKwhWF4d^a%hGN$FJ)gXfErX4AiNkuISlyU=K~j`MA{S?$;!0r!Db3*?<+H=Xr-qH)N9LE zd}Ns+4z&C$ep0xC4|-n#W~yn#7)k`Wc2|#60tabT7L3#B&M3(@2ty^Io&$Ckncr^8 zCj5)9v#?*5NM^^ME2!g?XiXWsmBvlvEX&1kC$CD5B<6VwldbSA*Z3Q-bMsB72DRve zo+{b{d#qce8g&+ENpRYg?}ZI>ZLK_N}KIPONl0pKlY$ape@g_DN7K!IL;b0Z6S%GrXgG+ zhJMeS@nPZgwsZ&s#4jxi`1gqV9+7tCf-a+H>YNVTvD12FGj?Io$jM0tTBA@tuw~qh z7E54ehFkM|P=rLW%GXx(#LtQg9Yn?oQI=@p6%toN-nxN5q;C|sF%fcv5c45XcF1fZ`KQo{pOhi~;S*+%U(a=e<%YxoRK_3mBfL(`l4BEwe2ss)OKc>+Bt^zUZB%!Elo% z9$?qxw?A%s3w-`-jVvU7q9jYZ8*ouUgHW|21gG$hSMlXDBvj!w@MGS|7Ocg>a-bGi zXgO1@FvRI~#MgO9{dGy9GUJm(B*E|or*Gj^M8`FxJe+~1O8Iq_OF?Dy@G^>-`5UP)h@zw9hZ_=z$}lPJ#1tNz}{0*_=IVO+mGjf zi1Sh;53*d(^|gfx@zUQfvN>Y67ef&H12Hf~XzXOtt9jD1eo&Jc)dK_m=I|C@I7FM|1&MZ+Ah_23E><2p##CHs) zmEJ?MJ*LFo+KlkfviieL{W0wwFb1Vy;pY35gLhdpnR-u*V}a&qumTOaOFk_#8f^5oSW?J( z4#=z|#fUg>fuVY?tv73*jtlG2n<-l*kmkswzYospq1QAxuXQ6cU{NSQYn3eOY5bId`VzHoI_7I^pXtu!NK0C7FmRx2HQicJ#RFg*?T2^vn!Z@ z32SOhBEQwddj#76X5t>ub1e3e-E~}6#dzIB`OH@3aXZg*06A9YCZHDkHL@*BHb94mRoI)!5u?LZ_pw|-%BF2+Z zoIPkiRt50z4&;NQF~=va=g&*t0nHEj1J#D!x)x&7qGPQD_2ugQfm=p&DAf2rO%L|i z5&leb^$A%8$xMw&z7VEt*2X8q`W6&1n@=zK#*L#G$J5KBsw>ykTaNe_yM$~C z**y!beCHSG%d=CE1l=eO;%py!nLeVUlo`B180Kj7r@65bx zJyMjSwmufw#y=B@OR+CA4FxxYuvF&NtZqNxw!pcB#FV?IlKEftQ;ERxEyEUQ0}0TM zoX++m^5+=!OHC`8CTBM$bEPLI9;ih=|GmE$$b9X8hu{>qQ$ve4Q_UDLC)mlo>Q#R& zzpwliXA#X_Cegxm;2>at4Ny#|Kt`7R;ag-R5MbgVpu2HdJ9(GqyJ1m$ub%qp)$!IF z%ms9DI>$*749+5zN+=)ayi4g|D};ABlJYaEPm544J=4yhCgP81=Q!l7+Oyx7zb@w? zibVAlJuOp4er;l9VvvQ`4Hi{oYU1hroe8F32e2*1=+80MrdtG&Dqx9Ek1%~>eVDc^ z1Mw5{Y;n6xK8n>B^90R?!{aNpvdOKtaU3OIB}+Npqz=K}*Ui8!U%*p%^_GaCqG+lt zW!4jnm_08NmDN{?UG`hI)bM3MwB-KX$2R9<-lumFOHU)-51X#1WGs`~%Bcx#?1;hL z8LOx#Dre2{>y(huMZn?Xh0MTqGED>C`J&-^%cweZ+;gZz#E7_T+-RV+!fOd(?3$qF zLk>w2j8Q0nBP_>9ydG{$CPBSOvth@LaIG~^ST{Ib!V4$KpeCTWFM4y*_JXa=i^z(X zjI3Fnu^$RSB_#=B)zc&~gAE?N7S=cpVvy*FPfVuWqLstr1{Ai4X!m%yJ}o`k?byK= zgmy=QV%f5ZvBonMlCkMrna5iH=EKcc1;|RF zWN~exUJtX^Hb2!FXAk7@lLLv9aa=1S)I1t+=&8^Rl5<~g4)^WF2z0A8u)MT$kfdFgD)n^COouG#@J zQp{~(bL#u12t>*R`|H2Gsae;TDqpg33hH&-83G}#Ft9OBs013 zY;`X2Kn5tWSTNk;?_=N>+ygi#q0t6b=|26+dBsR$QELM#OrhkqO> zuc%Z$W;ZZ3JS}Qp3>{*Zl(_CK2b^0{;E@MD73tUTG_Mhkb_OG&9;M7FD9r!5wIB$) z+HS4xGm7)Qhxa{P(ZIVEK}+sdK5P!7tnZ1TJ0tHSs_%=tjgTY!A;K_E~66e zb_*!0y&rg>&n<3eYPY{L>*XLplN;?V&e_r9Ng%j}XFxJQM1+~Xbfi_-oIRhnO_z`sYQD?TSa73{SQptZVr>4*v@;_$r}w{Q1{dglN~h%}Y{r=( zA`^L2Yden56BGSai8G6kx;w?@?az-lw|KzWwna_NT_Ia=h}+b!7k5wuM!rGx6I-78 z)wrQoZ!4;K8+hv-dO+C+Da<-j@lu`-UaV5eByVz$9%ak&+D>{s<;TyyP)@k0cGtoV z(mzWOIAD*oTe7=0^u#!n82r*Wg4muDoAGE7arCHe=}LtRgI@jgfxRTCO_HZmR-FK} zO@y>^R3CpwIR@134*eRZyH^N4U^==GuAo2mgl(|VQ5Afi zRh*MrDZD#Ibih$j6w%Rzb%&(UID`K*w{-37My#WX6isyyiWAYqN0~?j)ef7yd&;4+ zna#3;!C4;w{55jPAmwL>#|L@_kVz5DY>*Aae`%=BGZTceKcle!!V~`ip8s$<9dp;` z>={LFKBGt}gq$Qr7DWV&X9F(g|89Wvye#SL{X9O)dHydseUcn20x+M$PfxxBN>@C0$+3|8vd%13O0d{Qv*} diff --git a/test/machines.jl b/test/machines.jl index bd3493af..d0fcedfc 100644 --- a/test/machines.jl +++ b/test/machines.jl @@ -121,7 +121,7 @@ end pred = predict(mach, Xnew) MLJBase.save(io, mach; compression=:none) # commented out for travis testing: - # MLJBase.save(filename, mach) + #MLJBase.save(filename, mach) # test restoring data from filename: m = machine(filename) From 29b1fa2829524ec3847b72e8c34ce83a95f02a67 Mon Sep 17 00:00:00 2001 From: "Anthony D. Blaom" Date: Thu, 1 Oct 2020 12:33:12 +1300 Subject: [PATCH 3/5] truncate .jlso from filename --- src/machines.jl | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/src/machines.jl b/src/machines.jl index 76d4257c..621beb30 100644 --- a/src/machines.jl +++ b/src/machines.jl @@ -598,7 +598,14 @@ report(mach::Machine) = mach.report # helper: _filename(file::IO) = string(rand(UInt)) -_filename(file::String) = file +function _filename(file::String) # truncates ".jlso" if present + m = match(r"(.*)\.jlso", file) + if m isa Nothing + return file + end + return first(m.captures) +end + # saving: """ @@ -615,7 +622,7 @@ supported). The format is JLSO (a wrapper for julia native or BSON serialization). For some model types, a custom serialization will be additionally performed. -### Keyword arguments +### Keyword arguments These keyword arguments are passed to the JLSO serializer: From 29a279798d7c1e2405e5f9e80417e88fc621d56d Mon Sep 17 00:00:00 2001 From: "Anthony D. Blaom" Date: Thu, 1 Oct 2020 13:50:53 +1300 Subject: [PATCH 4/5] no, truncate any final extension --- src/machines.jl | 4 ++-- test/machines.jl | 5 +++++ 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/src/machines.jl b/src/machines.jl index 621beb30..2b7e1b73 100644 --- a/src/machines.jl +++ b/src/machines.jl @@ -598,8 +598,8 @@ report(mach::Machine) = mach.report # helper: _filename(file::IO) = string(rand(UInt)) -function _filename(file::String) # truncates ".jlso" if present - m = match(r"(.*)\.jlso", file) +function _filename(file::String) # truncates extension if present + m = match(r"(.*)\..*", file) if m isa Nothing return file end diff --git a/test/machines.jl b/test/machines.jl index d0fcedfc..c31c4cc7 100644 --- a/test/machines.jl +++ b/test/machines.jl @@ -103,6 +103,11 @@ end end @testset "serialization" begin + + @test MLJBase._filename("mymodel.jlso") == "mymodel" + @test MLJBase._filename("mymodel.gz") == "mymodel" + @test MLJBase._filename("mymodel") == "mymodel" + model = @load DecisionTreeRegressor X = (a = Float64[98, 53, 93, 67, 90, 68], From 0e368ae7a2da26102b1e3677a13555679aaeb855 Mon Sep 17 00:00:00 2001 From: "Anthony D. Blaom" Date: Tue, 13 Oct 2020 12:31:27 +1300 Subject: [PATCH 5/5] use splitext --- src/machines.jl | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) diff --git a/src/machines.jl b/src/machines.jl index 2b7e1b73..b18c6ccd 100644 --- a/src/machines.jl +++ b/src/machines.jl @@ -598,14 +598,7 @@ report(mach::Machine) = mach.report # helper: _filename(file::IO) = string(rand(UInt)) -function _filename(file::String) # truncates extension if present - m = match(r"(.*)\..*", file) - if m isa Nothing - return file - end - return first(m.captures) -end - +_filename(file::String) = splitext(file)[1] # saving: """