ENH Persistence: Add support for LightGBM and XGBoost #244

Merged
3 changes: 2 additions & 1 deletion docs/changes.rst
@@ -19,7 +19,8 @@ v0.4
- Add `model_format` argument to :meth:`skops.hub_utils.init` to be stored in
  `config.json` so that we know how to load a model from the repository.
  :pr:`242` by `Merve Noyan`_.

- Persistence now supports bytes and bytearrays; tests were added to verify
  that LightGBM, XGBoost, and CatBoost now work. :pr:`244` by `Benjamin Bossan`_.

v0.3
----
20 changes: 20 additions & 0 deletions docs/persistence.rst
@@ -87,6 +87,26 @@ means if you have custom functions (say, a custom function to be used with
most ``numpy`` and ``scipy`` functions should work. Therefore, you can actually
save built-in functions like ``numpy.sqrt``.

Supported libraries
-------------------

Skops intends to support all of **scikit-learn**, that is, not only its
estimators, but also other classes like cross-validation splitters. Furthermore,
most types from **numpy** and **scipy** should be supported, such as (sparse)
arrays, dtypes, random generators, and ufuncs.

Apart from this core, we plan to support machine learning libraries commonly
used by the community. So far, those are:

- `LightGBM <https://lightgbm.readthedocs.io/>`_ (scikit-learn API)
- `XGBoost <https://xgboost.readthedocs.io/en/stable/>`_ (scikit-learn API)
- `CatBoost <https://catboost.ai/en/docs/>`_

If you run into a problem using any of the mentioned libraries, this could mean
there is a bug in skops. Please open an issue on `our issue tracker
<https://github.com/skops-dev/skops/issues>`_ (but please check first if a
corresponding issue already exists).

Roadmap
-------

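To make the documented support concrete, here is a minimal sketch of a
save/load round trip using LightGBM's scikit-learn API. This is an
illustration, not code from this PR: it assumes lightgbm and skops are
installed, the file name is arbitrary, and trusted=True is only acceptable
because we created the file ourselves.

# Hedged sketch: persist a LightGBM model with skops.
from lightgbm import LGBMClassifier
from sklearn.datasets import make_classification

from skops.io import dump, load

X, y = make_classification(random_state=0)
model = LGBMClassifier(n_estimators=10).fit(X, y)

dump(model, "model.skops")  # writes a zip-based skops archive
loaded = load("model.skops", trusted=True)  # only trust files you created
assert (model.predict(X) == loaded.predict(X)).all()
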
4 changes: 4 additions & 0 deletions skops/_min_dependencies.py
@@ -27,6 +27,10 @@
"matplotlib": ("3.3", "docs, tests", None),
"pandas": ("1", "docs, tests", None),
"typing_extensions": ("3.7", "install", "python_full_version < '3.8'"),
# required for persistence tests of external libraries
"lightgbm": ("3", "tests", None),
"xgboost": ("1.6", "tests", None),
"catboost": ("1.0", "tests", None),
}


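These pins only affect the test environment. A common way to guard tests on
such optional dependencies is pytest.importorskip; the snippet below is a
generic pytest pattern shown for illustration, not necessarily the exact
mechanism used in the skops test suite.

# Hypothetical guard: skip the whole test module if lightgbm is missing.
import pytest

lightgbm = pytest.importorskip("lightgbm")


def test_lgbm_class_available():
    # trivial smoke check on the conditionally imported module
    assert hasattr(lightgbm, "LGBMClassifier")
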
56 changes: 56 additions & 0 deletions skops/io/_general.py
@@ -1,6 +1,8 @@
from __future__ import annotations

import io
import json
import uuid
from functools import partial
from types import FunctionType, MethodType
from typing import Any, Sequence
@@ -475,12 +477,64 @@ def _construct(self):
        return json.loads(self.content)


def bytes_get_state(obj: Any, save_context: SaveContext) -> dict[str, Any]:
    f_name = f"{uuid.uuid4()}.bin"
    save_context.zip_file.writestr(f_name, obj)
    res = {
        "__class__": obj.__class__.__name__,
        "__module__": get_module(type(obj)),
        "__loader__": "BytesNode",
        "file": f_name,
    }
    return res


def bytearray_get_state(obj: Any, save_context: SaveContext) -> dict[str, Any]:
    res = bytes_get_state(obj, save_context)
    res["__loader__"] = "BytearrayNode"
    return res


class BytesNode(Node):
    def __init__(
        self,
        state: dict[str, Any],
        load_context: LoadContext,
        trusted: bool | Sequence[str] = False,
    ) -> None:
        super().__init__(state, load_context, trusted)
        self.trusted = self._get_trusted(trusted, [bytes])
        self.children = {"content": io.BytesIO(load_context.src.read(state["file"]))}

    def _construct(self):
        content = self.children["content"].getvalue()
        return content


class BytearrayNode(BytesNode):
    def __init__(
        self,
        state: dict[str, Any],
        load_context: LoadContext,
        trusted: bool | Sequence[str] = False,
    ) -> None:
        super().__init__(state, load_context, trusted)
        self.trusted = self._get_trusted(trusted, [bytearray])

    def _construct(self):
        content_bytes = super()._construct()
        content_bytearray = bytearray(content_bytes)
        return content_bytearray


# tuples of type and function that gets the state of that type
GET_STATE_DISPATCH_FUNCTIONS = [
    (dict, dict_get_state),
    (list, list_get_state),
    (set, set_get_state),
    (tuple, tuple_get_state),
    (bytes, bytes_get_state),
    (bytearray, bytearray_get_state),
    (slice, slice_get_state),
    (FunctionType, function_get_state),
    (MethodType, method_get_state),
@@ -494,6 +548,8 @@ def _construct(self):
"ListNode": ListNode,
"SetNode": SetNode,
"TupleNode": TupleNode,
"BytesNode": BytesNode,
"BytearrayNode": BytearrayNode,
"SliceNode": SliceNode,
"FunctionNode": FunctionNode,
"MethodNode": MethodNode,
170 changes: 170 additions & 0 deletions skops/io/tests/_utils.py
@@ -0,0 +1,170 @@
import sys
import warnings
Author comment: Everything in here is unmodified, just moved around.

import numpy as np
from scipy import sparse
from sklearn.base import BaseEstimator
from sklearn.utils._testing import assert_allclose_dense_sparse

# TODO: Investigate why that seems to be an issue on MacOS (only observed with
# Python 3.8)
ATOL = 1e-6 if sys.platform == "darwin" else 1e-7


def _is_steps_like(obj):
    # helper function to check if an object is something like Pipeline.steps,
    # i.e. a list of tuples of names and estimators
    if not isinstance(obj, list):  # must be a list
        return False

    if not obj:  # must not be empty
        return False

    if not isinstance(obj[0], tuple):  # must be list of tuples
        return False

    lens = set(map(len, obj))
    if not lens == {2}:  # all elements must be length 2 tuples
        return False

    keys, vals = list(zip(*obj))

    if len(keys) != len(set(keys)):  # keys must be unique
        return False

    if not all(map(lambda x: isinstance(x, (type(None), BaseEstimator)), vals)):
        # values must be BaseEstimators or None
        return False

    return True


def _assert_generic_objects_equal(val1, val2):
    def _is_builtin(val):
        # Check if value is a builtin type
        return getattr(getattr(val, "__class__", {}), "__module__", None) == "builtins"

    if isinstance(val1, (list, tuple, np.ndarray)):
        assert len(val1) == len(val2)
        for subval1, subval2 in zip(val1, val2):
            _assert_generic_objects_equal(subval1, subval2)
        return

    assert type(val1) == type(val2)
    if hasattr(val1, "__dict__"):
        assert_params_equal(val1.__dict__, val2.__dict__)
    elif _is_builtin(val1):
        assert val1 == val2
    else:
        # not a normal Python class, could be e.g. a Cython class
        assert val1.__reduce__() == val2.__reduce__()


def _assert_tuples_equal(val1, val2):
    assert len(val1) == len(val2)
    for subval1, subval2 in zip(val1, val2):
        _assert_vals_equal(subval1, subval2)


def _assert_vals_equal(val1, val2):
    if hasattr(val1, "__getstate__"):
        # This includes BaseEstimator since they implement __getstate__ and
        # that returns the parameters as well.
        #
        # Some objects return a tuple of parameters, others a dict.
        state1 = val1.__getstate__()
        state2 = val2.__getstate__()
        assert type(state1) == type(state2)
        if isinstance(state1, tuple):
            _assert_tuples_equal(state1, state2)
        else:
            assert_params_equal(state1, state2)
    elif sparse.issparse(val1):
        assert sparse.issparse(val2) and ((val1 - val2).nnz == 0)
    elif isinstance(val1, (np.ndarray, np.generic)):
        if len(val1.dtype) == 0:
            # for arrays with at least 2 dimensions, check that contiguity is
            # preserved
            if val1.squeeze().ndim > 1:
                assert val1.flags["C_CONTIGUOUS"] is val2.flags["C_CONTIGUOUS"]
                assert val1.flags["F_CONTIGUOUS"] is val2.flags["F_CONTIGUOUS"]
            if val1.dtype == object:
                assert val2.dtype == object
                assert val1.shape == val2.shape
                for subval1, subval2 in zip(val1, val2):
                    _assert_generic_objects_equal(subval1, subval2)
            else:
                # simple comparison of arrays with simple dtypes, almost all
                # arrays are of this sort.
                np.testing.assert_array_equal(val1, val2)
        elif len(val1.shape) == 1:
            # comparing arrays with structured dtypes, but they have to be 1D
            # arrays. This is what we get from the Tree's state.
            assert np.all([x == y for x, y in zip(val1, val2)])
        else:
            # we don't know what to do with these values, for now.
            assert False
    elif isinstance(val1, (tuple, list)):
        assert len(val1) == len(val2)
        for subval1, subval2 in zip(val1, val2):
            _assert_vals_equal(subval1, subval2)
    elif isinstance(val1, float) and np.isnan(val1):
        assert np.isnan(val2)
    elif isinstance(val1, dict):
        # dictionaries are compared by comparing their values recursively.
        assert set(val1.keys()) == set(val2.keys())
        for key in val1:
            _assert_vals_equal(val1[key], val2[key])
    elif hasattr(val1, "__dict__") and hasattr(val2, "__dict__"):
        _assert_vals_equal(val1.__dict__, val2.__dict__)
    elif isinstance(val1, np.ufunc):
        assert val1 == val2
    elif val1.__class__.__module__ == "builtins":
        assert val1 == val2
    else:
        _assert_generic_objects_equal(val1, val2)


def assert_params_equal(params1, params2):
    # helper function to compare estimator dictionaries of parameters
    assert len(params1) == len(params2)
    assert set(params1.keys()) == set(params2.keys())
    for key in params1:
        with warnings.catch_warnings():
            # this is to silence the deprecation warning from _DictWithDeprecatedKeys
            warnings.filterwarnings("ignore", category=FutureWarning, module="sklearn")
            val1, val2 = params1[key], params2[key]
        assert type(val1) == type(val2)

        if _is_steps_like(val1):
            # Deal with Pipeline.steps, FeatureUnion.transformer_list, etc.
            assert _is_steps_like(val2)
            val1, val2 = dict(val1), dict(val2)

        if isinstance(val1, (tuple, list)):
            assert len(val1) == len(val2)
            for subval1, subval2 in zip(val1, val2):
                _assert_vals_equal(subval1, subval2)
        elif isinstance(val1, dict):
            assert_params_equal(val1, val2)
        else:
            _assert_vals_equal(val1, val2)


def assert_method_outputs_equal(estimator, loaded, X):
    # helper function that checks the output of all supported methods
    for method in [
        "predict",
        "predict_proba",
        "decision_function",
        "transform",
        "predict_log_proba",
    ]:
        err_msg = (
            f"{estimator.__class__.__name__}.{method}() doesn't produce the same"
            " results after loading the persisted model."
        )
        if hasattr(estimator, method):
            X_out1 = getattr(estimator, method)(X)
            X_out2 = getattr(loaded, method)(X)
            assert_allclose_dense_sparse(X_out1, X_out2, err_msg=err_msg, atol=ATOL)
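
For orientation, here is a hedged sketch of how these helpers might be
combined in a round-trip test; the dumps/loads round trip and trusted=True
are assumptions for illustration, not code from this diff.

# Illustrative use of the helpers defined above.
from skops.io import dumps, loads


def check_roundtrip(estimator, X):
    loaded = loads(dumps(estimator), trusted=True)  # we trust our own output
    # fitted attributes and hyperparameters should survive the round trip ...
    assert_params_equal(estimator.__dict__, loaded.__dict__)
    # ... as should the outputs of predict, transform, etc.
    assert_method_outputs_equal(estimator, loaded, X)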