Skip to content

Commit

Permalink
[python] add parameter object_hook to method dump_model (#4533)
Browse files Browse the repository at this point in the history
* add parameter object_hook to function dump_model (python API)

* eol

* fix syntax

* lint

* better documentation

* Update python-package/lightgbm/basic.py

Co-authored-by: Nikita Titov <nekit94-08@mail.ru>

Co-authored-by: xavier dupré <xavier.dupre@gmail.com>
Co-authored-by: Nikita Titov <nekit94-08@mail.ru>
  • Loading branch information
3 people authored Aug 23, 2021
1 parent 4db10d8 commit 11d7608
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 2 deletions.
13 changes: 11 additions & 2 deletions python-package/lightgbm/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -3342,7 +3342,7 @@ def model_to_string(self, num_iteration=None, start_iteration=0, importance_type
ret += _dump_pandas_categorical(self.pandas_categorical)
return ret

def dump_model(self, num_iteration=None, start_iteration=0, importance_type='split'):
def dump_model(self, num_iteration=None, start_iteration=0, importance_type='split', object_hook=None):
"""Dump Booster to JSON format.
Parameters
Expand All @@ -3357,6 +3357,15 @@ def dump_model(self, num_iteration=None, start_iteration=0, importance_type='spl
What type of feature importance should be dumped.
If "split", result contains numbers of times the feature is used in a model.
If "gain", result contains total gains of splits which use the feature.
object_hook : callable or None, optional (default=None)
If not None, ``object_hook`` is a function called while parsing the json
string returned by the C API. It may be used to alter the json, to store
specific values while building the json structure. It avoids
walking through the structure again. It saves a significant amount
of time if the number of trees is huge.
Signature is ``def object_hook(node: dict) -> dict``.
None is equivalent to ``lambda node: node``.
See documentation of ``json.loads()`` for further details.
Returns
-------
Expand Down Expand Up @@ -3391,7 +3400,7 @@ def dump_model(self, num_iteration=None, start_iteration=0, importance_type='spl
ctypes.c_int64(actual_len),
ctypes.byref(tmp_out_len),
ptr_string_buffer))
ret = json.loads(string_buffer.value.decode('utf-8'))
ret = json.loads(string_buffer.value.decode('utf-8'), object_hook=object_hook)
ret['pandas_categorical'] = json.loads(json.dumps(self.pandas_categorical,
default=json_default_with_numpy))
return ret
Expand Down
20 changes: 20 additions & 0 deletions tests/python_package_test/test_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -2846,3 +2846,23 @@ def test_dump_model():
assert "leaf_const" in dumped_model_str
assert "leaf_value" in dumped_model_str
assert "leaf_count" in dumped_model_str


def test_dump_model_hook():

def hook(obj):
if 'leaf_value' in obj:
obj['LV'] = obj['leaf_value']
del obj['leaf_value']
return obj

X, y = load_breast_cancer(return_X_y=True)
train_data = lgb.Dataset(X, label=y)
params = {
"objective": "binary",
"verbose": -1
}
bst = lgb.train(params, train_data, num_boost_round=5)
dumped_model_str = str(bst.dump_model(5, 0, object_hook=hook))
assert "leaf_value" not in dumped_model_str
assert "LV" in dumped_model_str

0 comments on commit 11d7608

Please sign in to comment.