Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Tweaks to the model API for supporting machine serialization #429

Merged
merged 6 commits into from
Oct 13, 2020
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 18 additions & 11 deletions src/hyperparam/one_dimensional_ranges.jl
Original file line number Diff line number Diff line change
Expand Up @@ -44,14 +44,7 @@ end
r = range(model, :hyper; values=nothing)

Define a one-dimensional `NominalRange` object for a field `hyper` of
`model`. Note that `r` is not directly iterable but `iterator(r)`
is.

By default, the behaviour of range methods depends on the type of the value of the
hyperparameter `:hyper` at `model` during range construction.

To override this behaviour (for instance if `model` is not available) specify a type
in place of `model` so the behaviour depends on the value of the specified type.
`model`. Note that `r` is not directly iterable but `iterator(r)` is.

A nested hyperparameter is specified using dot notation. For example,
`:(atom.max_depth)` specifies the `max_depth` hyperparameter of
Expand All @@ -60,13 +53,22 @@ the submodel `model.atom`.
r = range(model, :hyper; upper=nothing, lower=nothing,
scale=nothing, values=nothing)

Assuming `values` is not specified, defines a one-dimensional
Assuming `values` is not specified, define a one-dimensional
`NumericRange` object for a `Real` field `hyper` of `model`. Note
that `r` is not directly iteratable but `iterator(r, n)`is an iterator
of length `n`. To generate random elements from `r`, instead apply
`rand` methods to `sampler(r)`. The supported scales are `:linear`,`
:log`, `:logminus`, `:log10`, `:log2`, or a callable object.

Note that `r` is not directly iterable, but `iterator(r, n)` is, for
given resolution (length) `n`.

By default, the behaviour of the constructed object depends on the
type of the value of the hyperparameter `:hyper` at `model` *at the
time of construction.* To override this behaviour (for instance if
`model` is not available) specify a type in place of `model` so the
behaviour is determined by the value of the specified type.

A nested hyperparameter is specified using dot notation (see above).

If `scale` is unspecified, it is set to `:linear`, `:log`,
Expand All @@ -84,6 +86,11 @@ See also: [`iterator`](@ref), [`sampler`](@ref)
function Base.range(model::Union{Model, Type}, field::Union{Symbol,Expr};
values=nothing, lower=nothing, upper=nothing,
origin=nothing, unit=nothing, scale::D=nothing) where D
all(==(nothing), [values, lower, upper, origin, unit]) &&
throw(ArgumentError("You must specify at least one of these: "*
"values=..., lower=..., upper=..., origin=..., "*
"unit=..."))

if model isa Model
value = recursive_getproperty(model, field)
T = typeof(value)
Expand Down Expand Up @@ -172,13 +179,13 @@ function nominal_range(::Type{T}, field, values::AbstractVector{T}) where T
end

#specific def for T<:AbstractFloat(Allows conversion btw AbstractFloats and Signed types)
function nominal_range(::Type{T}, field,
function nominal_range(::Type{T}, field,
values::AbstractVector{<:Union{AbstractFloat,Signed}}) where T<: AbstractFloat
return NominalRange{T,length(values)}(field, Tuple(values))
end

#specific def for T<:Signed (Allows conversion btw Signed types)
function nominal_range(::Type{T}, field,
function nominal_range(::Type{T}, field,
values::AbstractVector{<:Signed}) where T<: Signed
return NominalRange{T,length(values)}(field, Tuple(values))
end
13 changes: 3 additions & 10 deletions src/interface/model_api.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,16 +36,9 @@ MLJModelInterface.implemented_methods(::FI, M::Type{<:MLJType}) =
getfield.(methodswith(M), :name) |> unique

# serialization fallbacks:
# Here `file` can be `String` or `IO` (eg, `file=IOBuffer()`).
MLJModelInterface.save(file, model, fitresult, report; kwargs...) =
JLSO.save(file,
:model => model,
:fitresult => fitresult,
:report => report; kwargs...)
function MLJModelInterface.restore(file; kwargs...)
dict = JLSO.load(file)
return dict[:model], dict[:fitresult], dict[:report]
end
MLJModelInterface.save(filename, model, fitresult; kwargs...) = fitresult

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So just to be clear, this is fallback to an asumed serialisable function and if this doesnt work the method implementer overrides this to produce a serialisable fitresult?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, although I would say seriasable object not function. It may not be a function (and usually isn't).

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Of course, thanks for clarification

MLJModelInterface.restore(filename, model, serializable_fitresult) =
serializable_fitresult

# to suppress inclusion of abstract types in the model registry.
for T in (:Supervised, :Unsupervised,
Expand Down
51 changes: 43 additions & 8 deletions src/machines.jl
Original file line number Diff line number Diff line change
Expand Up @@ -596,6 +596,17 @@ report(mach::Machine) = mach.report

## SERIALIZATION

# helper:
_filename(file::IO) = string(rand(UInt))
function _filename(file::String) # truncates extension if present
ablaom marked this conversation as resolved.
Show resolved Hide resolved
m = match(r"(.*)\..*", file)
if m isa Nothing
return file
end
return first(m.captures)
end


# saving:
"""
MLJ.save(filename, mach::Machine; kwargs...)
Expand All @@ -608,10 +619,12 @@ Serialize the machine `mach` to a file with path `filename`, or to an
input/output stream `io` (at least `IOBuffer` instances are
supported).

The format is JLSO (a wrapper for julia native or BSON serialization)
unless a custom format has been implemented for the model type of
`mach.model`. The keyword arguments `kwargs` are passed to
the format-specific serializer, which in the JSLO case include these:
The format is JLSO (a wrapper for julia native or BSON serialization).
For some model types, a custom serialization will be additionally performed.

### Keyword arguments

These keyword arguments are passed to the JLSO serializer:

keyword | values | default
---------------|-------------------------------|-------------------------
Expand All @@ -622,6 +635,9 @@ See (see
[https://github.com/invenia/JLSO.jl](https://github.com/invenia/JLSO.jl)
for details.

Any additional keyword arguments are passed to model-specific
serializers.

Machines are de-serialized using the `machine` constructor as shown in
the example below. Data (or nodes) may be optionally passed to the
constructor for retraining on new data using the saved model.
Expand Down Expand Up @@ -660,15 +676,34 @@ constructor for retraining on new data using the saved model.
horse](https://en.wikipedia.org/wiki/Trojan_horse_(computing)).

"""
function MMI.save(file, mach::Machine; verbosity=1, kwargs...)
function MMI.save(file::Union{String,IO},
mach::Machine;
verbosity=1,
format=:julia_serialize,
compression=:none,
kwargs...)
isdefined(mach, :fitresult) ||
error("Cannot save an untrained machine. ")
MMI.save(file, mach.model, mach.fitresult, mach.report; kwargs...)

# fallback `save` method returns `mach.fitresult` and saves nothing:
serializable_fitresult =
save(_filename(file), mach.model, mach.fitresult; kwargs...)

JLSO.save(file,
:model => mach.model,
:fitresult => serializable_fitresult,
:report => mach.report;
format=format,
compression=compression)
end

# restoring:
# deserializing:
function machine(file::Union{String,IO}, args...; kwargs...)
model, fitresult, report = MMI.restore(file; kwargs...)
dict = JLSO.load(file)
model = dict[:model]
serializable_fitresult = dict[:fitresult]
report = dict[:report]
fitresult = restore(_filename(file), model, serializable_fitresult)
if isempty(args)
mach = Machine(model)
else
Expand Down
1 change: 0 additions & 1 deletion test/hyperparam/one_dimensional_ranges.jl
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,6 @@ super_model = SuperModel(0.5, dummy1, dummy2)

@test_throws DomainError range(dummy_model, :K, origin=2)
@test_throws DomainError range(dummy_model, :K, unit=1)
@test_throws DomainError range(dummy_model, :K)

@test_throws ArgumentError range(dummy_model, :kernel)

Expand Down
37 changes: 1 addition & 36 deletions test/interface/model_api.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import MLJModelInterface
using ..Models
using Distributions
using StableRNGs
using JLSO

rng = StableRNG(661)

Expand All @@ -27,41 +28,5 @@ rng = StableRNG(661)
@test_throws ArgumentError predict_mode(rgs, fitresult, X)
end

@testset "serialization" begin

# train a model on some data:
model = @load KNNRegressor
X = (a = Float64[98, 53, 93, 67, 90, 68],
b = Float64[64, 43, 66, 47, 16, 66],)
Xnew = (a = Float64[82, 49, 16],
b = Float64[36, 13, 36],)
y = [59.1, 28.6, 96.6, 83.3, 59.1, 48.0]
fitresult, cache, report = MLJBase.fit(model, 0, X, y)
pred = predict(model, fitresult, Xnew)
filename = joinpath(@__DIR__, "test.jlso")

# save to file:
# To avoid complications to travis tests (ie, writing to file) the
# next line was run once and then commented out:
# save(filename, model, fitresult, report)

# save to buffer:
io = IOBuffer()
MLJBase.save(io, model, fitresult, report, compression=:none)
seekstart(io)

# test restoring data:
for input in [filename, io]
eval(quote
m, f, r = MLJBase.restore($input)
p = predict(m, f, $Xnew)
@test m == $model
@test r == $report
@test p ≈ $pred
end)
end

end

end
true
Binary file modified test/machine.jlso
Binary file not shown.
7 changes: 6 additions & 1 deletion test/machines.jl
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,11 @@ end
end

@testset "serialization" begin

@test MLJBase._filename("mymodel.jlso") == "mymodel"
@test MLJBase._filename("mymodel.gz") == "mymodel"
@test MLJBase._filename("mymodel") == "mymodel"

model = @load DecisionTreeRegressor

X = (a = Float64[98, 53, 93, 67, 90, 68],
Expand All @@ -121,7 +126,7 @@ end
pred = predict(mach, Xnew)
MLJBase.save(io, mach; compression=:none)
# commented out for travis testing:
# MLJBase.save(filename, mach)
#MLJBase.save(filename, mach)

# test restoring data from filename:
m = machine(filename)
Expand Down