Skip to content

Commit

Permalink
[python] handle arbitrary length feature names in Python-package (#4293)
Browse files (browse the repository at this point in the history)
* handle arbitrary length feature names in Python-package

* added tests
  • Loading branch information
StrikerRUS authored May 21, 2021
1 parent 41a1a24 commit 237ac29
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 19 deletions.
65 changes: 47 additions & 18 deletions python-package/lightgbm/basic.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -1875,7 +1875,7 @@ def get_feature_name(self):
tmp_out_len = ctypes.c_int(0)
reserved_string_buffer_size = 255
required_string_buffer_size = ctypes.c_size_t(0)
string_buffers = [ctypes.create_string_buffer(reserved_string_buffer_size) for i in range(num_feature)]
string_buffers = [ctypes.create_string_buffer(reserved_string_buffer_size) for _ in range(num_feature)]
ptr_string_buffers = (ctypes.c_char_p * num_feature)(*map(ctypes.addressof, string_buffers))
_safe_call(_LIB.LGBM_DatasetGetFeatureNames(
self.handle,
Expand All @@ -1886,11 +1886,18 @@ def get_feature_name(self):
ptr_string_buffers))
if num_feature != tmp_out_len.value:
raise ValueError("Length of feature names doesn't equal with num_feature")
if reserved_string_buffer_size < required_string_buffer_size.value:
raise BufferError(
f"Allocated feature name buffer size ({reserved_string_buffer_size}) was"
f"inferior to the needed size ({required_string_buffer_size.value})."
)
actual_string_buffer_size = required_string_buffer_size.value
# if buffer length is not long enough, reallocate buffers
if reserved_string_buffer_size < actual_string_buffer_size:
string_buffers = [ctypes.create_string_buffer(actual_string_buffer_size) for _ in range(num_feature)]
ptr_string_buffers = (ctypes.c_char_p * num_feature)(*map(ctypes.addressof, string_buffers))
_safe_call(_LIB.LGBM_DatasetGetFeatureNames(
self.handle,
ctypes.c_int(num_feature),
ctypes.byref(tmp_out_len),
ctypes.c_size_t(actual_string_buffer_size),
ctypes.byref(required_string_buffer_size),
ptr_string_buffers))
return [string_buffers[i].value.decode('utf-8') for i in range(num_feature)]

def get_label(self):
Expand Down Expand Up @@ -3249,7 +3256,7 @@ def feature_name(self):
tmp_out_len = ctypes.c_int(0)
reserved_string_buffer_size = 255
required_string_buffer_size = ctypes.c_size_t(0)
string_buffers = [ctypes.create_string_buffer(reserved_string_buffer_size) for i in range(num_feature)]
string_buffers = [ctypes.create_string_buffer(reserved_string_buffer_size) for _ in range(num_feature)]
ptr_string_buffers = (ctypes.c_char_p * num_feature)(*map(ctypes.addressof, string_buffers))
_safe_call(_LIB.LGBM_BoosterGetFeatureNames(
self.handle,
Expand All @@ -3260,9 +3267,18 @@ def feature_name(self):
ptr_string_buffers))
if num_feature != tmp_out_len.value:
raise ValueError("Length of feature names doesn't equal with num_feature")
if reserved_string_buffer_size < required_string_buffer_size.value:
raise BufferError(
f"Allocated feature name buffer size ({reserved_string_buffer_size}) was inferior to the needed size ({required_string_buffer_size.value}).")
actual_string_buffer_size = required_string_buffer_size.value
# if buffer length is not long enough, reallocate buffers
if reserved_string_buffer_size < actual_string_buffer_size:
string_buffers = [ctypes.create_string_buffer(actual_string_buffer_size) for _ in range(num_feature)]
ptr_string_buffers = (ctypes.c_char_p * num_feature)(*map(ctypes.addressof, string_buffers))
_safe_call(_LIB.LGBM_BoosterGetFeatureNames(
self.handle,
ctypes.c_int(num_feature),
ctypes.byref(tmp_out_len),
ctypes.c_size_t(actual_string_buffer_size),
ctypes.byref(required_string_buffer_size),
ptr_string_buffers))
return [string_buffers[i].value.decode('utf-8') for i in range(num_feature)]

def feature_importance(self, importance_type='split', iteration=None):
Expand Down Expand Up @@ -3445,7 +3461,7 @@ def __get_eval_info(self):
reserved_string_buffer_size = 255
required_string_buffer_size = ctypes.c_size_t(0)
string_buffers = [
ctypes.create_string_buffer(reserved_string_buffer_size) for i in range(self.__num_inner_eval)
ctypes.create_string_buffer(reserved_string_buffer_size) for _ in range(self.__num_inner_eval)
]
ptr_string_buffers = (ctypes.c_char_p * self.__num_inner_eval)(*map(ctypes.addressof, string_buffers))
_safe_call(_LIB.LGBM_BoosterGetEvalNames(
Expand All @@ -3457,13 +3473,26 @@ def __get_eval_info(self):
ptr_string_buffers))
if self.__num_inner_eval != tmp_out_len.value:
raise ValueError("Length of eval names doesn't equal with num_evals")
if reserved_string_buffer_size < required_string_buffer_size.value:
raise BufferError(
f"Allocated eval name buffer size ({reserved_string_buffer_size}) was inferior to the needed size ({required_string_vuffer_size.value}).")
self.__name_inner_eval = \
[string_buffers[i].value.decode('utf-8') for i in range(self.__num_inner_eval)]
self.__higher_better_inner_eval = \
[name.startswith(('auc', 'ndcg@', 'map@', 'average_precision')) for name in self.__name_inner_eval]
actual_string_buffer_size = required_string_buffer_size.value
# if buffer length is not long enough, reallocate buffers
if reserved_string_buffer_size < actual_string_buffer_size:
string_buffers = [
ctypes.create_string_buffer(actual_string_buffer_size) for _ in range(self.__num_inner_eval)
]
ptr_string_buffers = (ctypes.c_char_p * self.__num_inner_eval)(*map(ctypes.addressof, string_buffers))
_safe_call(_LIB.LGBM_BoosterGetEvalNames(
self.handle,
ctypes.c_int(self.__num_inner_eval),
ctypes.byref(tmp_out_len),
ctypes.c_size_t(actual_string_buffer_size),
ctypes.byref(required_string_buffer_size),
ptr_string_buffers))
self.__name_inner_eval = [
string_buffers[i].value.decode('utf-8') for i in range(self.__num_inner_eval)
]
self.__higher_better_inner_eval = [
name.startswith(('auc', 'ndcg@', 'map@', 'average_precision')) for name in self.__name_inner_eval
]

def attr(self, key):
"""Get attribute string from the Booster.
Expand Down
7 changes: 6 additions & 1 deletion tests/python_package_test/test_basic.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -16,7 +16,9 @@
def test_basic(tmp_path):
X_train, X_test, y_train, y_test = train_test_split(*load_breast_cancer(return_X_y=True),
test_size=0.1, random_state=2)
train_data = lgb.Dataset(X_train, label=y_train)
feature_names = [f"Column_{i}" for i in range(X_train.shape[1])]
feature_names[1] = "a" * 1000 # set one name to a value longer than default buffer size
train_data = lgb.Dataset(X_train, label=y_train, feature_name=feature_names)
valid_data = train_data.create_valid(X_test, label=y_test)

params = {
Expand All @@ -37,6 +39,8 @@ def test_basic(tmp_path):
if i % 10 == 0:
print(bst.eval_train(), bst.eval_valid())

assert train_data.get_feature_name() == feature_names

assert bst.current_iteration() == 20
assert bst.num_trees() == 20
assert bst.num_model_per_iteration() == 1
Expand All @@ -55,6 +59,7 @@ def test_basic(tmp_path):

# check saved model persistence
bst = lgb.Booster(params, model_file=model_file)
assert bst.feature_name() == feature_names
pred_from_model_file = bst.predict(X_test)
# we need to check the consistency of model file here, so test for exact equal
np.testing.assert_array_equal(pred_from_matr, pred_from_model_file)
Expand Down

0 comments on commit 237ac29

Please sign in to comment.