Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for XGBoost UBJSON in FIL #6009

Merged
merged 4 commits into from
Aug 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions dependencies.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -503,3 +503,4 @@ dependencies:
- pandas
- *scikit_learn
- seaborn
- xgboost
4 changes: 2 additions & 2 deletions notebooks/forest_inference_demo.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -275,7 +275,7 @@
" algo='BATCH_TREE_REORG',\n",
" output_class=True,\n",
" threshold=0.50,\n",
" model_type='xgboost'\n",
" model_type='xgboost_ubj'\n",
")"
]
},
Expand Down Expand Up @@ -507,7 +507,7 @@
" algo='BATCH_TREE_REORG',\n",
" output_class=True,\n",
" threshold=0.50,\n",
" model_type='xgboost'\n",
" model_type='xgboost_ubj'\n",
" )"
]
},
Expand Down
18 changes: 13 additions & 5 deletions python/cuml/cuml/experimental/fil/fil.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -803,7 +803,7 @@ class ForestInference(UniversalBase, CMajorInputTagMixin):
only for models trained in double precision and when exact
conformance between results from FIL and the original training
framework is of paramount importance.
model_type : {'xgboost', 'xgboost_json', 'lightgbm',
model_type : {'xgboost_ubj', 'xgboost_json', 'xgboost', 'lightgbm',
'treelite_checkpoint', None }, default=None
The serialization format for the model file. If None, a best-effort
guess will be made based on the file extension.
Expand Down Expand Up @@ -841,18 +841,26 @@ class ForestInference(UniversalBase, CMajorInputTagMixin):
extension = pathlib.Path(path).suffix
if extension == '.json':
model_type = 'xgboost_json'
elif extension == '.ubj':
model_type = 'xgboost_ubj'
elif extension == '.model':
model_type = 'xgboost'
elif extension == '.txt':
model_type = 'lightgbm'
else:
model_type = 'treelite_checkpoint'
if model_type == 'treelite_checkpoint':
if model_type == "treelite_checkpoint":
tl_model = treelite.frontend.Model.deserialize(path)
elif model_type == "xgboost_ubj":
tl_model = treelite.frontend.load_xgboost_model(path, format_choice="ubjson")
elif model_type == "xgboost_json":
tl_model = treelite.frontend.load_xgboost_model(path, format_choice="json")
elif model_type == "xgboost":
tl_model = treelite.frontend.load_xgboost_model_legacy_binary(path)
elif model_type == "lightgbm":
tl_model = treelite.frontend.load_lightgbm_model(path)
else:
tl_model = treelite.frontend.Model.load(
path, model_type
)
raise ValueError(f"Unknown model type: {model_type}")
if default_chunk_size is None:
default_chunk_size = threads_per_tree
return cls(
Expand Down
29 changes: 21 additions & 8 deletions python/cuml/cuml/fil/fil.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,9 @@ cdef extern from "treelite/c_api.h":
size_t nitem
ctypedef void* TreeliteModelHandle
ctypedef void* TreeliteGTILConfigHandle
cdef int TreeliteLoadXGBoostModelUBJSON(const char* filename,
const char* config_json,
TreeliteModelHandle* out) except +
cdef int TreeliteLoadXGBoostModelLegacyBinary(const char* filename,
const char* config_json,
TreeliteModelHandle* out) except +
Expand Down Expand Up @@ -188,7 +191,7 @@ cdef class TreeliteModel():
return model

@classmethod
def from_filename(cls, filename, model_type="xgboost"):
def from_filename(cls, filename, model_type="xgboost_ubj"):
"""
Returns a TreeliteModel object loaded from `filename`

Expand All @@ -198,15 +201,15 @@ cdef class TreeliteModel():
Path to treelite model file to load

model_type : string
Type of model: 'xgboost', 'xgboost_json', or 'lightgbm'
Type of model: 'xgboost_ubj', 'xgboost_json', 'xgboost' or 'lightgbm'
"""
cdef bytes filename_bytes = filename.encode("UTF-8")
cdef bytes config_bytes = b"{}"
cdef TreeliteModelHandle handle
cdef int res
cdef str err_msg
if model_type == "xgboost":
res = TreeliteLoadXGBoostModelLegacyBinary(filename_bytes, config_bytes, &handle)
if model_type == "xgboost_ubj":
res = TreeliteLoadXGBoostModelUBJSON(filename_bytes, config_bytes, &handle)
if res < 0:
err_msg = TreeliteGetLastError().decode("UTF-8")
raise RuntimeError(f"Failed to load {filename} ({err_msg})")
Expand All @@ -215,6 +218,11 @@ cdef class TreeliteModel():
if res < 0:
err_msg = TreeliteGetLastError().decode("UTF-8")
raise RuntimeError(f"Failed to load {filename} ({err_msg})")
elif model_type == "xgboost":
res = TreeliteLoadXGBoostModelLegacyBinary(filename_bytes, config_bytes, &handle)
if res < 0:
err_msg = TreeliteGetLastError().decode("UTF-8")
raise RuntimeError(f"Failed to load {filename} ({err_msg})")
elif model_type == "lightgbm":
logger.warn("Treelite currently does not support float64 model"
" parameters. Accuracy may degrade slightly relative"
Expand Down Expand Up @@ -953,7 +961,7 @@ class ForestInference(Base,
n_items=0,
compute_shape_str=False,
precision='native',
model_type="xgboost",
model_type="xgboost_ubj",
handle=None):
"""
Returns a FIL instance containing the forest saved in `filename`
Expand Down Expand Up @@ -1018,9 +1026,14 @@ class ForestInference(Base,
thresholds
- ``'float64'``: always load in float64

model_type : string (default="xgboost")
Format of the saved treelite model to be load.
It can be 'xgboost', 'xgboost_json', 'lightgbm'.
model_type : string (default="xgboost_ubj")
Format of the saved tree model to be loaded.
It can be one of the following:

- ``'xgboost_ubj'``: XGBoost model, using the UBJSON format (default in XGBoost 2.1+)
- ``'xgboost_json'``: XGBoost model, using the JSON format
- ``'xgboost'``: XGBoost model, using the legacy binary format
- ``'lightgbm'``: LightGBM model

Returns
-------
Expand Down
Loading