Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Re-implement serialization methods save and restore to address new MLJ interface changes #14

Merged
merged 3 commits into from
Apr 6, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ jobs:
fail-fast: false
matrix:
version:
- '1.3'
- '1.6'
- '1'
os:
- ubuntu-latest
Expand Down
7 changes: 3 additions & 4 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "MLJXGBoostInterface"
uuid = "54119dfa-1dab-4055-a167-80440f4f7a91"
authors = ["Anthony D. Blaom <anthony.blaom@gmail.com>"]
version = "0.1.5"
version = "0.2.0"

[deps]
MLJModelInterface = "e80e1ace-859a-464e-9ed9-23947d8ae3ea"
Expand All @@ -12,14 +12,13 @@ XGBoost = "009559a3-9522-5dbb-924b-0b6ed2b22bb9"
MLJModelInterface = "0.3.5, 0.4, 1"
Tables = "1.0.5"
XGBoost = "1.1.1"
julia = "1.3"
julia = "1.6"

[extras]
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d"
MLJSerialization = "17bed46d-0ab5-4cd4-b792-a5c4b8547c6d"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Distributions", "MLJBase", "MLJSerialization", "StableRNGs", "Test"]
test = ["Distributions", "MLJBase", "StableRNGs", "Test"]
103 changes: 61 additions & 42 deletions src/MLJXGBoostInterface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -676,62 +676,81 @@ function MMI.predict(model::XGBoostClassifier
end


## SERIALIZATION - REGRESSOR AND COUNT
# # SERIALIZATION

const XGBoostInfinite = Union{XGBoostRegressor,XGBoostCount}

function MLJModelInterface.save(filename,
::XGBoostInfinite,
fitresult; # `XGBoost.Booster` object
kwargs...)
xgb_filename = string(filename, ".xgboost.model")
XGBoost.save(fitresult, xgb_filename)
serializable_fitresult = read(xgb_filename)
@info "Additional XGBoost serialization file \"$xgb_filename\" generated. "
return serializable_fitresult
# ## Helpers

"""
persistent(booster)

Private method.

Return a persistent (ie, Julia-serializable) representation of the
XGBoost.jl model `booster`.

Restore the model with [`booster`](@ref)

"""
function persistent(booster)

    # this implementation is not ideal; see
# https://github.com/dmlc/XGBoost.jl/issues/103

xgb_file, io = mktemp()
close(io)

XGBoost.save(booster, xgb_file)
persistent_booster = read(xgb_file)
rm(xgb_file)
return persistent_booster
end

function MLJModelInterface.restore(filename,
::XGBoostInfinite,
serializable_fitresult)
xgb_filename = string(filename, ".tmp")
open(xgb_filename, "w") do file
write(file, serializable_fitresult)
end
fitresult = XGBoost.Booster(model_file=xgb_filename)
rm(xgb_filename)
return fitresult
"""
booster(persistent)

Private method.

Return the XGBoost.jl model which has `persistent` as its persistent
(Julia-serializable) representation. See [`persistent`](@ref) method.

"""
function booster(persistent)

xgb_file, io = mktemp()
write(io, persistent)
close(io)
booster = XGBoost.Booster(model_file=xgb_file)
rm(xgb_file)

return booster
end


## SERIALIZATION - CLASSIFIER
# ## Regressor and Count

const XGBoostInfinite = Union{XGBoostRegressor,XGBoostCount}

function MLJModelInterface.save(filename,
::XGBoostClassifier,
MLJModelInterface.save(::XGBoostInfinite, fitresult; kwargs...) =
persistent(fitresult)

MLJModelInterface.restore(::XGBoostInfinite, serializable_fitresult) =
booster(serializable_fitresult)


# ## Classifier

function MLJModelInterface.save(::XGBoostClassifier,
fitresult;
kwargs...)
booster, a_target_element = fitresult

xgb_filename = string(filename, ".xgboost.model")
XGBoost.save(booster, xgb_filename)
persistent_booster = read(xgb_filename)
@info "Additional XGBoost serialization file \"$xgb_filename\" generated. "
return (persistent_booster, a_target_element)
return (persistent(booster), a_target_element)
end

function MLJModelInterface.restore(filename,
::XGBoostClassifier,
function MLJModelInterface.restore(::XGBoostClassifier,
serializable_fitresult)
persistent_booster, a_target_element = serializable_fitresult

xgb_filename = string(filename, ".tmp")
open(xgb_filename, "w") do file
write(file, persistent_booster)
end
booster = XGBoost.Booster(model_file=xgb_filename)
rm(xgb_filename)
fitresult = (booster, a_target_element)
return fitresult
persistent, a_target_element = serializable_fitresult
return (booster(persistent), a_target_element)
end


Expand Down
31 changes: 15 additions & 16 deletions test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
using MLJBase
using MLJSerialization
using Test
import XGBoost
using MLJXGBoostInterface
Expand All @@ -24,7 +23,7 @@ rng = StableRNGs.StableRNG(123)

#fitresult, cache, report = MLJBase.fit(plain_classifier, 1, X, ycat;)

m = machine(plain_classifier,X,ycat)
m = machine(plain_classifier, X, ycat)
fit!(m,verbosity = 0)
end

Expand All @@ -35,22 +34,21 @@ features = rand(rng, n,m);
weights = rand(rng, -1:1,m);
labels = features * weights;
features = MLJBase.table(features)
fitresultR, cacheR, reportR = MLJBase.fit(plain_regressor, 1, features, labels);
fitresultR, cacheR, reportR = MLJBase.fit(plain_regressor, 0, features, labels);
rpred = predict(plain_regressor, fitresultR, features);

plain_regressor.objective = "gamma"
labels = abs.(labels)
fitresultR, cacheR, reportR = MLJBase.fit(plain_regressor, 1, features, labels);
fitresultR, cacheR, reportR = MLJBase.fit(plain_regressor, 0, features, labels);
rpred = predict(plain_regressor, fitresultR, features);

importances = reportR.feature_importances

# serialization:
serializable_fitresult =
MLJBase.save("mymodel", plain_regressor, fitresultR)
MLJBase.save(plain_regressor, fitresultR);

restored_fitresult = MLJBase.restore("mymodel",
plain_regressor,
restored_fitresult = MLJBase.restore(plain_regressor,
serializable_fitresult)
@test predict(plain_regressor, restored_fitresult, features) ≈ rpred

Expand All @@ -67,7 +65,7 @@ Xtable = table(X)
λ = exp.(α .+ X * β)
ycount = [rand(rng, Poisson(λᵢ)) for λᵢ ∈ λ]

fitresultC, cacheC, reportC = MLJBase.fit(count_regressor, 1, Xtable, ycount);
fitresultC, cacheC, reportC = MLJBase.fit(count_regressor, 0, Xtable, ycount);
cpred = predict(count_regressor, fitresultC, Xtable);

importances = reportC.feature_importances
Expand All @@ -84,7 +82,7 @@ ycat = map(X.x1) do x
end |> categorical
y = identity.(ycat) # make plain Vector with categ. elements
train, test = partition(eachindex(y), 0.6)
fitresult, cache, report = MLJBase.fit(plain_classifier, 1,
fitresult, cache, report = MLJBase.fit(plain_classifier, 0,
selectrows(X, train), y[train];)
yhat = mode.(predict(plain_classifier, fitresult, selectrows(X, test)))
misclassification_rate = sum(yhat .!= y[test])/length(test)
Expand All @@ -102,7 +100,7 @@ end |> categorical
y = identity.(ycat) # make plain Vector with categ. elements

train, test = partition(eachindex(y), 0.6)
fitresult, cache, report = MLJBase.fit(plain_classifier, 1,
fitresult, cache, report = MLJBase.fit(plain_classifier, 0,
selectrows(X, train), y[train];)
yhat = mode.(predict(plain_classifier, fitresult, selectrows(X, test)))
misclassification_rate = sum(yhat .!= y[test])/length(test)
Expand All @@ -115,17 +113,16 @@ y = identity.(ycat)
train, test = partition(eachindex(y), 0.5)
@test length(unique(y[train])) == 2
@test length(unique(y[test])) == 1
fitresult, cache, report = MLJBase.fit(plain_classifier, 1,
fitresult, cache, report = MLJBase.fit(plain_classifier, 0,
selectrows(X, train), y[train];)
yhat = predict_mode(plain_classifier, fitresult, selectrows(X, test))
@test Set(MLJBase.classes(yhat[1])) == Set(MLJBase.classes(y[train][1]))

# serialization:
serializable_fitresult =
MLJBase.save("mymodel", plain_classifier, fitresult)
MLJBase.save(plain_classifier, fitresult)

restored_fitresult = MLJBase.restore("mymodel",
plain_classifier,
restored_fitresult = MLJBase.restore(plain_classifier,
serializable_fitresult)

@test predict_mode(plain_classifier, restored_fitresult, selectrows(X, test)) ==
Expand All @@ -137,7 +134,8 @@ restored_fitresult = MLJBase.restore("mymodel",
# count regressor (`count_regressor`, `Xtable` and `ycount`
# defined above):

mach = machine(count_regressor, Xtable, ycount) |> fit!
mach = machine(count_regressor, Xtable, ycount)
fit!(mach, verbosity=0)
yhat = predict(mach, Xtable)

# serialize:
Expand All @@ -154,7 +152,8 @@ close(io)


# classifier
mach = machine(plain_classifier, X, y) |> fit!
mach = machine(plain_classifier, X, y)
fit!(mach, verbosity=0)
yhat = predict_mode(mach, X);

# serialize:
Expand Down