ENH Persistence: Add support for LightGBM and XGBoost #244
Merged: adrinjalali merged 15 commits from BenjaminBossan:persistence-support-lightgbm-xgboost into skops-dev:main on Dec 14, 2022 (+595 −167).
Commits (15, changes from all commits shown):
6654d9c Add support for LightGBM and XGBoost (BenjaminBossan)
1c13827 Add entry to changes.rst (BenjaminBossan)
3196f6a Lower min version of XGBoost to 1.6 (BenjaminBossan)
63e702e Remove gpu_hist from tested parameters (BenjaminBossan)
9616b70 Trust np.int32 in lgbm test (BenjaminBossan)
2a73e49 Remove example from docstring replicating xgb bug (BenjaminBossan)
7fad01c Use bytearray instead of xgboost custom format (BenjaminBossan)
ac2c1a5 Merge branch 'main' into persistence-support-lightgbm-xgboost (BenjaminBossan)
9b4a9bf Save bytes in file, not in schema (BenjaminBossan)
3910a54 Add CatBoost to complete the holy trinity (BenjaminBossan)
1b6aca4 Merge branch 'main' into persistence-support-lightgbm-xgboost (BenjaminBossan)
4c73b80 Address reviewer comments (BenjaminBossan)
f988aa5 Merge branch 'persistence-support-lightgbm-xgboost' of github.com:Ben… (BenjaminBossan)
32adb81 Fix missing argument in test function (BenjaminBossan)
62fb8f2 Reviewer comment: improve comment (BenjaminBossan)
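Per the commit messages above, the boosters are serialized as raw bytes stored in a file inside the skops archive rather than embedded in the schema, and CatBoost support was added alongside LightGBM and XGBoost. Below is a minimal sketch of the user-facing round trip this PR enables, assuming lightgbm and skops are installed; the toy data and the `trusted=True` argument are illustrative, and the exact trusted-types API may differ between skops versions.

import numpy as np
from lightgbm import LGBMClassifier
from skops.io import dumps, loads

# Fit a small LightGBM model on random data (illustrative only).
X = np.random.rand(100, 4)
y = np.random.randint(0, 2, size=100)
model = LGBMClassifier(n_estimators=10).fit(X, y)

# Round trip through skops: serialize to bytes, then load back.
data = dumps(model)
# Only trust content from sources you control; depending on the skops
# version you may need to pass an explicit list of trusted types instead.
loaded = loads(data, trusted=True)

np.testing.assert_allclose(model.predict_proba(X), loaded.predict_proba(X))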
@@ -0,0 +1,170 @@
import sys
import warnings

import numpy as np
from scipy import sparse
from sklearn.base import BaseEstimator
from sklearn.utils._testing import assert_allclose_dense_sparse

# TODO: Investigate why that seems to be an issue on MacOS (only observed with
# Python 3.8)
ATOL = 1e-6 if sys.platform == "darwin" else 1e-7


def _is_steps_like(obj):
    # helper function to check if an object is something like Pipeline.steps,
    # i.e. a list of tuples of names and estimators
    if not isinstance(obj, list):  # must be a list
        return False

    if not obj:  # must not be empty
        return False

    if not isinstance(obj[0], tuple):  # must be list of tuples
        return False

    lens = set(map(len, obj))
    if not lens == {2}:  # all elements must be length 2 tuples
        return False

    keys, vals = list(zip(*obj))

    if len(keys) != len(set(keys)):  # keys must be unique
        return False

    if not all(map(lambda x: isinstance(x, (type(None), BaseEstimator)), vals)):
        # values must be BaseEstimators or None
        return False

    return True


def _assert_generic_objects_equal(val1, val2):
    def _is_builtin(val):
        # Check if value is a builtin type
        return getattr(getattr(val, "__class__", {}), "__module__", None) == "builtins"

    if isinstance(val1, (list, tuple, np.ndarray)):
        assert len(val1) == len(val2)
        for subval1, subval2 in zip(val1, val2):
            _assert_generic_objects_equal(subval1, subval2)
        return

    assert type(val1) == type(val2)
    if hasattr(val1, "__dict__"):
        assert_params_equal(val1.__dict__, val2.__dict__)
    elif _is_builtin(val1):
        assert val1 == val2
    else:
        # not a normal Python class, could be e.g. a Cython class
        assert val1.__reduce__() == val2.__reduce__()


def _assert_tuples_equal(val1, val2):
    assert len(val1) == len(val2)
    for subval1, subval2 in zip(val1, val2):
        _assert_vals_equal(subval1, subval2)


def _assert_vals_equal(val1, val2):
    if hasattr(val1, "__getstate__"):
        # This includes BaseEstimator since they implement __getstate__ and
        # that returns the parameters as well.
        #
        # Some objects return a tuple of parameters, others a dict.
        state1 = val1.__getstate__()
        state2 = val2.__getstate__()
        assert type(state1) == type(state2)
        if isinstance(state1, tuple):
            _assert_tuples_equal(state1, state2)
        else:
            assert_params_equal(val1.__getstate__(), val2.__getstate__())
    elif sparse.issparse(val1):
        assert sparse.issparse(val2) and ((val1 - val2).nnz == 0)
    elif isinstance(val1, (np.ndarray, np.generic)):
        if len(val1.dtype) == 0:
            # for arrays with at least 2 dimensions, check that contiguity is
            # preserved
            if val1.squeeze().ndim > 1:
                assert val1.flags["C_CONTIGUOUS"] is val2.flags["C_CONTIGUOUS"]
                assert val1.flags["F_CONTIGUOUS"] is val2.flags["F_CONTIGUOUS"]
            if val1.dtype == object:
                assert val2.dtype == object
                assert val1.shape == val2.shape
                for subval1, subval2 in zip(val1, val2):
                    _assert_generic_objects_equal(subval1, subval2)
            else:
                # simple comparison of arrays with simple dtypes, almost all
                # arrays are of this sort.
                np.testing.assert_array_equal(val1, val2)
        elif len(val1.shape) == 1:
            # comparing arrays with structured dtypes, but they have to be 1D
            # arrays. This is what we get from the Tree's state.
            assert np.all([x == y for x, y in zip(val1, val2)])
        else:
            # we don't know what to do with these values, for now.
            assert False
    elif isinstance(val1, (tuple, list)):
        assert len(val1) == len(val2)
        for subval1, subval2 in zip(val1, val2):
            _assert_vals_equal(subval1, subval2)
    elif isinstance(val1, float) and np.isnan(val1):
        assert np.isnan(val2)
    elif isinstance(val1, dict):
        # dictionaries are compared by comparing their values recursively.
        assert set(val1.keys()) == set(val2.keys())
        for key in val1:
            _assert_vals_equal(val1[key], val2[key])
    elif hasattr(val1, "__dict__") and hasattr(val2, "__dict__"):
        _assert_vals_equal(val1.__dict__, val2.__dict__)
    elif isinstance(val1, np.ufunc):
        assert val1 == val2
    elif val1.__class__.__module__ == "builtins":
        assert val1 == val2
    else:
        _assert_generic_objects_equal(val1, val2)


def assert_params_equal(params1, params2):
    # helper function to compare estimator dictionaries of parameters
    assert len(params1) == len(params2)
    assert set(params1.keys()) == set(params2.keys())
    for key in params1:
        with warnings.catch_warnings():
            # this is to silence the deprecation warning from _DictWithDeprecatedKeys
            warnings.filterwarnings("ignore", category=FutureWarning, module="sklearn")
            val1, val2 = params1[key], params2[key]
        assert type(val1) == type(val2)

        if _is_steps_like(val1):
            # Deal with Pipeline.steps, FeatureUnion.transformer_list, etc.
            assert _is_steps_like(val2)
            val1, val2 = dict(val1), dict(val2)

        if isinstance(val1, (tuple, list)):
            assert len(val1) == len(val2)
            for subval1, subval2 in zip(val1, val2):
                _assert_vals_equal(subval1, subval2)
        elif isinstance(val1, dict):
            assert_params_equal(val1, val2)
        else:
            _assert_vals_equal(val1, val2)


def assert_method_outputs_equal(estimator, loaded, X):
    # helper function that checks the output of all supported methods
    for method in [
        "predict",
        "predict_proba",
        "decision_function",
        "transform",
        "predict_log_proba",
    ]:
        err_msg = (
            f"{estimator.__class__.__name__}.{method}() doesn't produce the same"
            " results after loading the persisted model."
        )
        if hasattr(estimator, method):
            X_out1 = getattr(estimator, method)(X)
            X_out2 = getattr(loaded, method)(X)
            assert_allclose_dense_sparse(X_out1, X_out2, err_msg=err_msg, atol=ATOL)
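A hedged sketch of how these helpers might be combined in a persistence round-trip test; the `dumps`/`loads` calls, the `trusted=True` argument, and the toy estimator are assumptions for illustration, not part of this file:

import numpy as np
from sklearn.linear_model import LogisticRegression
from skops.io import dumps, loads

X = np.random.rand(50, 3)
y = np.random.randint(0, 2, size=50)
estimator = LogisticRegression().fit(X, y)

# Persist and reload, then compare both the learned attributes and the
# outputs of all supported prediction methods.
loaded = loads(dumps(estimator), trusted=True)
assert_params_equal(estimator.__dict__, loaded.__dict__)
assert_method_outputs_equal(estimator, loaded, X)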
Review comment on the file above: Everything in here is unmodified, just moved around.