Skip to content

Commit

Permalink
[python] raise an informative error instead of segfaulting when custo…
Browse files Browse the repository at this point in the history
…m objective produces incorrect output (#4815)

* fix for bad grads causing segfault

* adjust checking criteria to properly reflect reality of multi-class classifiers

* fix styling

* Line break before operator

* Update python-package/lightgbm/basic.py

Co-authored-by: Nikita Titov <nekit94-08@mail.ru>

* Update python-package/lightgbm/basic.py

Co-authored-by: Nikita Titov <nekit94-08@mail.ru>

* add a note to the C-API docs

* rearrange text s;ightly

* add some tests to python package

* Update include/LightGBM/c_api.h

Co-authored-by: Nikita Titov <nekit94-08@mail.ru>

* PR comments

* match argument is a regex and our expression has brackets ..

* rework tests

* isorting imports

* updating test to relfect that the python APi does not take pres/labels as a fobj function

Co-authored-by: Nikita Titov <nekit94-08@mail.ru>
  • Loading branch information
yaxxie and StrikerRUS authored Dec 30, 2021
1 parent a55ff18 commit af5b40e
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 1 deletion.
3 changes: 3 additions & 0 deletions include/LightGBM/c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -572,6 +572,9 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterRefit(BoosterHandle handle,
/*!
* \brief Update the model by specifying gradient and Hessian directly
* (this can be used to support customized loss functions).
* \note
* The length of the arrays referenced by ``grad`` and ``hess`` must be equal to
* ``num_class * num_train_data``, this is not verified by the library, the caller must ensure this.
* \param handle Handle of booster
* \param grad The first order derivative (gradient) statistics
* \param hess The second order derivative (Hessian) statistics
Expand Down
10 changes: 9 additions & 1 deletion python-package/lightgbm/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -3031,7 +3031,15 @@ def __boost(self, grad, hess):
assert grad.flags.c_contiguous
assert hess.flags.c_contiguous
if len(grad) != len(hess):
raise ValueError(f"Lengths of gradient({len(grad)}) and hessian({len(hess)}) don't match")
raise ValueError(f"Lengths of gradient ({len(grad)}) and Hessian ({len(hess)}) don't match")
num_train_data = self.train_set.num_data()
num_models = self.__num_class
if len(grad) != num_train_data * num_models:
raise ValueError(
f"Lengths of gradient ({len(grad)}) and Hessian ({len(hess)}) "
f"don't match training data length ({num_train_data}) * "
f"number of models per one iteration ({num_models})"
)
is_finished = ctypes.c_int(0)
_safe_call(_LIB.LGBM_BoosterUpdateOneIterCustom(
self.handle,
Expand Down
30 changes: 30 additions & 0 deletions tests/python_package_test/test_basic.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# coding: utf-8
import filecmp
import numbers
import re
from pathlib import Path

import numpy as np
Expand Down Expand Up @@ -579,3 +580,32 @@ def test_param_aliases():
assert all(len(i) >= 1 for i in aliases.values())
assert all(k in v for k, v in aliases.items())
assert lgb.basic._ConfigAliases.get('config', 'task') == {'config', 'config_file', 'task', 'task_type'}


def _bad_gradients(preds, _):
return np.random.randn(len(preds) + 1), np.random.rand(len(preds) + 1)


def _good_gradients(preds, _):
return np.random.randn(len(preds)), np.random.rand(len(preds))


def test_custom_objective_safety():
nrows = 100
X = np.random.randn(nrows, 5)
y_binary = np.arange(nrows) % 2
classes = [0, 1, 2]
nclass = len(classes)
y_multiclass = np.arange(nrows) % nclass
ds_binary = lgb.Dataset(X, y_binary).construct()
ds_multiclass = lgb.Dataset(X, y_multiclass).construct()
bad_bst_binary = lgb.Booster({'objective': "none"}, ds_binary)
good_bst_binary = lgb.Booster({'objective': "none"}, ds_binary)
bad_bst_multi = lgb.Booster({'objective': "none", "num_class": nclass}, ds_multiclass)
good_bst_multi = lgb.Booster({'objective': "none", "num_class": nclass}, ds_multiclass)
good_bst_binary.update(fobj=_good_gradients)
with pytest.raises(ValueError, match=re.escape("number of models per one iteration (1)")):
bad_bst_binary.update(fobj=_bad_gradients)
good_bst_multi.update(fobj=_good_gradients)
with pytest.raises(ValueError, match=re.escape(f"number of models per one iteration ({nclass})")):
bad_bst_multi.update(fobj=_bad_gradients)

0 comments on commit af5b40e

Please sign in to comment.