Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Tweaks to the model API for supporting machine serialization #429

Merged
merged 6 commits into from
Oct 13, 2020
Merged

Conversation

ablaom
Copy link
Member

@ablaom ablaom commented Oct 1, 2020

This PR makes tweaks to the (experimental and unpublicised) API for models with special serizialization requirements. It is not breaking for users. I have made it in response to failures of the existing API to handle the XGBoost serialization issue.

The proposed API is outlined in the draft documentation given below, for adding to "Adding New Models for General Use". As a POC I have implemented and locally tested the new API at a forthcoming external glue code package for XGBoost.

Very roughly, the PR enables: (i) automatic generation of model-specific serialization formats in addition to a self-contained MLJ/JLSO format; (ii) a workaround for "non-persistent" representations of learned parameters, such as pointers generated by wrapped C code, provided the model-providing package already implements some kind of serialization.

@yalwan-iqvia @OkonSamuel Be great to get your feedback.

I will leave this open for comment until Monday 11th.


Draft documentation

Serialization

The MLJ user can serialize and deserialize a machine, which means
serializing/deserializing:

  • the associated Model object (storing hyperparameters)
  • the fitresult (learned parameters)
  • the report generating during training

These are bundled into a single file or IO stream specified by the
user using the package JLSO. There are two scenarios in which a new
MLJ model API implementation will want to overload two additional
methods save and restore to support serialization:

  1. The algorithm-providing package already has it's own serialization
    format for learned parameters and/or hyper-parameters, which users
    may want to access. In that case the implementation overloads save.

  2. The fitresult is not a sufficiently persistent object; for
    example, it is a pointer passed from wrapped C code. In that case
    the implementation overloads save and restore.

In case 2, 1 presumably holds also, for otherwise MLJ serialization is
probably not going to be possible without changes to the
algorithm-providing package. An example is given below.

Note that in case 1, MLJ will continue to create it's own
self-contained serialization of the machine. Below filename refers
to the corresponding serialization file name, as specified by the
user, but with any final extension (e.g., ".jlso", ".gz") removed. If
the user has alternatively specified an IO object for serialization,
then filename is a randomly generated numeric string.

The save method

MMI.save(filename, model::SomeModel, fitresult; kwargs...) -> serializable_fitresult

Implement this method to serialize using a format specific to models
of type SomeModel. The fitresult is the first return value of
MMI.fit for such model types; kwargs is a list of keyword
arguments specified by the user and understood to relate to a some
model-specific serialization (cannot be format=... or
compression=...). The value of serializable_fitresult should be a
persistent representation of fitresult, from which a correct and
valid fitresult can be reconstructed using restore (see
below).

The fallback of save performs no action and returns fitresult.

The restore method

MMI.restore(filename, model::SomeModel, serializable_fitresult) -> fitresult

Implement this method to reconstruct a fitresult (as returned by
MMI.fit) from a persistent representation constructed using
MMI.save as described above.

The fallback of restore returns serializable_fitresult.

Example

Below is an example drawn from MLJ's XGBoost wrapper. In this example
the fitresult returned by MMI.fit is a tuple (booster, a_target_element) where booster is the XGBoost.jl object storing
the learned parameters (essentially a pointer to some object created
by C code) and a_target_element is an ordinary CategoricalValue
used to track the target classes (a persistent object, requiring no
special treatment).

function MLJModelInterface.save(filename,
                                ::XGBoostClassifier,
                                fitresult;
                                kwargs...)
    booster, a_target_element = fitresult

    xgb_filename = string(filename, ".xgboost.model")
    XGBoost.save(booster, xgb_filename)
    persistent_booster = read(xgb_filename)
    @info "Additional XGBoost serialization file \"$xgb_filename\" generated. "
    return (persistent_booster, a_target_element)
end

function MLJModelInterface.restore(filename,
                                   ::XGBoostClassifier,
                                   serializable_fitresult)
    persistent_booster, a_target_element = serializable_fitresult

    xgb_filename = string(filename, ".tmp")
    open(xgb_filename, "w") do file
        write(file, persistent_booster)
    end
    booster = XGBoost.Booster(model_file=xgb_filename)
    rm(xgb_filename)
    fitresult = (booster, a_target_element)
    return fitresult
end

@ablaom ablaom added the API label Oct 1, 2020
@codecov-commenter
Copy link

codecov-commenter commented Oct 1, 2020

Codecov Report

Merging #429 into master will increase coverage by 0.03%.
The diff coverage is 100.00%.

Impacted file tree graph

@@            Coverage Diff             @@
##           master     #429      +/-   ##
==========================================
+ Coverage   81.83%   81.86%   +0.03%     
==========================================
  Files          38       38              
  Lines        2703     2713      +10     
==========================================
+ Hits         2212     2221       +9     
- Misses        491      492       +1     
Impacted Files Coverage Δ
src/hyperparam/one_dimensional_ranges.jl 80.76% <100.00%> (-1.59%) ⬇️
src/interface/model_api.jl 92.30% <100.00%> (-0.55%) ⬇️
src/machines.jl 84.35% <100.00%> (+1.14%) ⬆️

Continue to review full report at Codecov.

Legend - Click here to learn more
Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Powered by Codecov. Last update 7a4920c...29a2797. Read the comment docs.

@ablaom ablaom marked this pull request as draft October 1, 2020 01:42
src/machines.jl Outdated Show resolved Hide resolved
dict = JLSO.load(file)
return dict[:model], dict[:fitresult], dict[:report]
end
MLJModelInterface.save(filename, model, fitresult; kwargs...) = fitresult

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So just to be clear, this is fallback to an asumed serialisable function and if this doesnt work the method implementer overrides this to produce a serialisable fitresult?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, although I would say seriasable object not function. It may not be a function (and usually isn't).

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Of course, thanks for clarification

Copy link
Member

@OkonSamuel OkonSamuel left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Clever.
LGTM

@ablaom ablaom marked this pull request as ready for review October 12, 2020 23:30
@codecov-io
Copy link

codecov-io commented Oct 12, 2020

Codecov Report

Merging #429 into master will increase coverage by 0.01%.
The diff coverage is 100.00%.

Impacted file tree graph

@@            Coverage Diff             @@
##           master     #429      +/-   ##
==========================================
+ Coverage   81.83%   81.84%   +0.01%     
==========================================
  Files          38       38              
  Lines        2703     2710       +7     
==========================================
+ Hits         2212     2218       +6     
- Misses        491      492       +1     
Impacted Files Coverage Δ
src/hyperparam/one_dimensional_ranges.jl 80.76% <100.00%> (-1.59%) ⬇️
src/interface/model_api.jl 92.30% <100.00%> (-0.55%) ⬇️
src/machines.jl 84.02% <100.00%> (+0.81%) ⬆️

Continue to review full report at Codecov.

Legend - Click here to learn more
Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Powered by Codecov. Last update 7a4920c...0e368ae. Read the comment docs.

@ablaom
Copy link
Member Author

ablaom commented Oct 13, 2020

@OkonSamuel @yalwan-iqvia Thanks for this feedback.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants