XGBoostClassifier can't be serialised. Add custom serialisation? #512
Comments
cc @aviatesk. The culprit here is the |
@OkonSamuel Thanks for investigating!
This is possible. One overloads the fallbacks recalled below, somewhere in the code where the XGBoost models implement the MLJ model API:

```julia
import MLJModelInterface
import JLSO

# Generic fallback: write the model, fitresult and report to `file` with JLSO...
MLJModelInterface.save(file, model, fitresult, report; kwargs...) =
    JLSO.save(file,
              :model => model,
              :fitresult => fitresult,
              :report => report; kwargs...)

# ...and read them back as a (model, fitresult, report) tuple.
function MLJModelInterface.restore(file; kwargs...)
    dict = JLSO.load(file)
    return dict[:model], dict[:fitresult], dict[:report]
end
```

The tricky part is that, for the classifier, the MLJ |
I'm going to try this for LightGBM models. I anticipated this issue and put some code in to deal with it, but I never tested it in the context of MLJ serialisation. If it works, I'll come back and point at the code I implemented in LightGBM.jl, which the XGBoost maintainers could probably also make use of.
Do you have any suggestions about how to combine the two pieces of information: the learned XGBoost parameters and the complete pool of classes (e.g., the |
Okay, I have discovered some flaws in the API on the deserialization (restore) side. Working on a PR for that.
For LightGBM I was only referring to the part that causes the error the user reported about a disposed model. I made the internal structs carry a serialised string version of the model, so that even after a struct has been deepcopied elsewhere and its pointer is null, it can reload the model from that string before continuing with operations such as predict. I haven't got around to trying this yet, but I can point you at the code.
Here's our test that this runs correctly after deepcopying:
Here's the bit where we make sure we still have a valid pointer after a deepcopy:
And finally, for cases that somehow aren't covered by deepcopies, and in other places where we think we might continue after a serialise/deserialise (SerDe) round trip, we use this for safety:
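The "carry a serialised copy and reload when the pointer is gone" pattern described above can be sketched in plain Julia. This is a minimal, hypothetical illustration, not the LightGBM.jl code: MockHandle stands in for a native model pointer, serialise_model and load_model stand in for the C library's save-to-string and load-from-string calls, and the custom deepcopy_internal method only simulates a copy that ends up without a valid handle (nothing plays the role of a null pointer).

```julia
# Hypothetical sketch of the pattern; NOT the actual LightGBM.jl implementation.

# Stand-in for a native model handle (in real code, a Ptr to a C-side booster).
mutable struct MockHandle
    params::Vector{Float64}
end

# Hypothetical stand-ins for the library's save-to-string / load-from-string calls.
serialise_model(h::MockHandle) = join(string.(h.params), ",")
load_model(s::AbstractString) = MockHandle(parse.(Float64, split(s, ",")))

mutable struct Estimator
    handle::Union{Nothing, MockHandle}  # `nothing` plays the role of a null pointer
    model_string::String                # serialised copy kept alongside the handle
end

# Build an estimator from a live handle, recording its serialised form up front.
Estimator(h::MockHandle) = Estimator(h, serialise_model(h))

# Simulation only: make a deepcopy come out with an invalid (null) handle,
# as a copied native pointer would be unusable.
function Base.deepcopy_internal(e::Estimator, dict::IdDict)
    copied = Estimator(nothing, e.model_string)
    dict[e] = copied
    return copied
end

# Reload the native model from the string if the handle has been lost.
function ensure_loaded!(e::Estimator)
    if e.handle === nothing
        e.handle = load_model(e.model_string)
    end
    return e.handle
end

# Hypothetical operation: every call needing the native model goes through ensure_loaded!.
toy_predict(e::Estimator, x::Vector{Float64}) = sum(ensure_loaded!(e).params .* x)

original = Estimator(MockHandle([1.0, 2.0]))
copied = deepcopy(original)
println(copied.handle === nothing)        # handle lost by the (simulated) copy
println(toy_predict(copied, [3.0, 4.0]))  # reloaded from the string before predicting
```

Routing every native call through a single check such as ensure_loaded! keeps the reload logic in one place; the cost is holding a second, string copy of the model next to the native one.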
Once JuliaRegistries/General#23206 is merged, update MLJModels and serialisation should work.
Describe the bug
A deserialized XGBoostClassifier can't predict. This only happens with the MLJ machine interface: saving with XGBoost.save(xgboost::Booster, fname) and then calling predict(Booster(fname), X) works as expected.
/cc @aviks, pinging you in case you're interested in this.
To Reproduce
The final line results in the following error:
Versions