Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Prevent undefined behavior when passing handle from Treelite to cuML FIL #5849

Merged
merged 9 commits into from
Apr 20, 2024
2 changes: 1 addition & 1 deletion ci/test_wheel.sh
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ rapids-logger "pytest cuml single GPU"
./ci/run_cuml_singlegpu_pytests.sh \
--numprocesses=8 \
--dist=worksteal \
-k 'not test_sparse_pca_inputs and not test_fil_skl_classification' \
-k 'not test_sparse_pca_inputs' \
--junitxml="${RAPIDS_TESTS_DIR}/junit-cuml.xml"

# Run test_sparse_pca_inputs separately
Expand Down
29 changes: 28 additions & 1 deletion python/cuml/fil/fil.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,8 @@ cdef extern from "treelite/c_api.h":
TreeliteModelHandle* out) except +
cdef int TreeliteSerializeModelToFile(TreeliteModelHandle handle,
const char* filename) except +
cdef int TreeliteDeserializeModelFromBytes(const char* bytes_seq, size_t len,
TreeliteModelHandle* out) except +
cdef int TreeliteGetHeaderField(
TreeliteModelHandle model, const char * name, TreelitePyBufferFrame* out_frame) except +
cdef const char* TreeliteGetLastError()
Expand Down Expand Up @@ -164,6 +166,26 @@ cdef class TreeliteModel():
cdef uintptr_t model_ptr = <uintptr_t>model_handle
TreeliteFreeModel(<TreeliteModelHandle> model_ptr)

@classmethod
def from_treelite_bytes(cls, bytes bytes_seq):
"""
Returns a TreeliteModel object loaded from bytes representing a
serialized Treelite model object.

Parameters
----------
bytes_seq: bytes
bytes representing a serialized Treelite model
"""
cdef TreeliteModelHandle handle
res = TreeliteDeserializeModelFromBytes(bytes_seq, len(bytes_seq), &handle)
hcho3 marked this conversation as resolved.
Show resolved Hide resolved
if res < 0:
err = TreeliteGetLastError()
raise RuntimeError("Failed to load Treelite model from bytes (%s)" % (err))
model = TreeliteModel()
hcho3 marked this conversation as resolved.
Show resolved Hide resolved
model.set_handle(handle)
return model

@classmethod
def from_filename(cls, filename, model_type="xgboost"):
"""
Expand Down Expand Up @@ -882,8 +904,13 @@ class ForestInference(Base,
" parameters. Accuracy may degrade slightly relative to"
" native sklearn invocation.")
tl_model = tl_skl.import_model(skl_model)
# Serialize Treelite model object and de-serialize again,
# to get around C++ ABI incompatibilities (due to different compilers
# being used to build cuML pip wheel vs. Treelite pip wheel)
cdef bytes bytes_seq = tl_model.serialize_bytes()
tl_model2 = TreeliteModel.from_treelite_bytes(bytes_seq)
hcho3 marked this conversation as resolved.
Show resolved Hide resolved
cuml_fm.load_from_treelite_model(
model=tl_model,
model=tl_model2,
output_class=output_class,
threshold=threshold,
algo=algo,
Expand Down
Loading