-
Notifications
You must be signed in to change notification settings - Fork 45
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Tweaks to the model API for supporting machine serialization #429
Conversation
Codecov Report
@@ Coverage Diff @@
## master #429 +/- ##
==========================================
+ Coverage 81.83% 81.86% +0.03%
==========================================
Files 38 38
Lines 2703 2713 +10
==========================================
+ Hits 2212 2221 +9
- Misses 491 492 +1
Continue to review full report at Codecov.
|
dict = JLSO.load(file) | ||
return dict[:model], dict[:fitresult], dict[:report] | ||
end | ||
MLJModelInterface.save(filename, model, fitresult; kwargs...) = fitresult |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
So just to be clear, this is fallback to an asumed serialisable function and if this doesnt work the method implementer overrides this to produce a serialisable fitresult?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, although I would say seriasable object not function. It may not be a function (and usually isn't).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Of course, thanks for clarification
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Clever.
LGTM
Codecov Report
@@ Coverage Diff @@
## master #429 +/- ##
==========================================
+ Coverage 81.83% 81.84% +0.01%
==========================================
Files 38 38
Lines 2703 2710 +7
==========================================
+ Hits 2212 2218 +6
- Misses 491 492 +1
Continue to review full report at Codecov.
|
@OkonSamuel @yalwan-iqvia Thanks for this feedback. |
This PR makes tweaks to the (experimental and unpublicised) API for models with special serizialization requirements. It is not breaking for users. I have made it in response to failures of the existing API to handle the XGBoost serialization issue.
The proposed API is outlined in the draft documentation given below, for adding to "Adding New Models for General Use". As a POC I have implemented and locally tested the new API at a forthcoming external glue code package for XGBoost.
Very roughly, the PR enables: (i) automatic generation of model-specific serialization formats in addition to a self-contained MLJ/JLSO format; (ii) a workaround for "non-persistent" representations of learned parameters, such as pointers generated by wrapped C code, provided the model-providing package already implements some kind of serialization.
@yalwan-iqvia @OkonSamuel Be great to get your feedback.
I will leave this open for comment until Monday 11th.
Draft documentation
Serialization
The MLJ user can serialize and deserialize a machine, which means
serializing/deserializing:
Model
object (storing hyperparameters)fitresult
(learned parameters)report
generating during trainingThese are bundled into a single file or
IO
stream specified by theuser using the package
JLSO
. There are two scenarios in which a newMLJ model API implementation will want to overload two additional
methods
save
andrestore
to support serialization:The algorithm-providing package already has it's own serialization
format for learned parameters and/or hyper-parameters, which users
may want to access. In that case the implementation overloads
save
.The
fitresult
is not a sufficiently persistent object; forexample, it is a pointer passed from wrapped C code. In that case
the implementation overloads
save
andrestore
.In case 2, 1 presumably holds also, for otherwise MLJ serialization is
probably not going to be possible without changes to the
algorithm-providing package. An example is given below.
Note that in case 1, MLJ will continue to create it's own
self-contained serialization of the machine. Below
filename
refersto the corresponding serialization file name, as specified by the
user, but with any final extension (e.g., ".jlso", ".gz") removed. If
the user has alternatively specified an
IO
object for serialization,then
filename
is a randomly generated numeric string.The save method
Implement this method to serialize using a format specific to models
of type
SomeModel
. Thefitresult
is the first return value ofMMI.fit
for such model types;kwargs
is a list of keywordarguments specified by the user and understood to relate to a some
model-specific serialization (cannot be
format=...
orcompression=...
). The value ofserializable_fitresult
should be apersistent representation of
fitresult
, from which a correct andvalid
fitresult
can be reconstructed usingrestore
(seebelow).
The fallback of
save
performs no action and returnsfitresult
.The restore method
Implement this method to reconstruct a
fitresult
(as returned byMMI.fit
) from a persistent representation constructed usingMMI.save
as described above.The fallback of
restore
returnsserializable_fitresult
.Example
Below is an example drawn from MLJ's XGBoost wrapper. In this example
the
fitresult
returned byMMI.fit
is a tuple(booster, a_target_element)
wherebooster
is theXGBoost.jl
object storingthe learned parameters (essentially a pointer to some object created
by C code) and
a_target_element
is an ordinaryCategoricalValue
used to track the target classes (a persistent object, requiring no
special treatment).