Clean codes for python-package; dump model to JSON #97

Merged: 2 commits, Dec 2, 2016
README.md: 2 changes (1 addition, 1 deletion)
@@ -17,7 +17,7 @@ For more details, please refer to [Features](https://github.com/Microsoft/LightG
News
----

-12/02/2012 : Release [python-package](https://github.com/Microsoft/LightGBM/tree/master/python-package) beta version, welcome to have a try and provide issues and feedback.
+12/02/2016 : Release [python-package](https://github.com/Microsoft/LightGBM/tree/master/python-package) beta version, welcome to have a try and provide issues and feedback.

Get Started
------------
examples/python-guide/README.md: 18 changes (18 additions, 0 deletions)
@@ -0,0 +1,18 @@
Python Package Example
=====================
Here is an example of how to use the LightGBM Python package.

***You should install LightGBM (both the C++ library and the Python package) first.***

For the installation, check the wiki [here](https://github.com/Microsoft/LightGBM/wiki/Installation-Guide).

You also need scikit-learn and pandas to run the examples, but they are not required for the package itself. You can install them with pip:
```
pip install -U scikit-learn
pip install -U pandas
```

Now you can run the examples in this folder, for example:
```
python simple_example.py
```
examples/python-guide/simple_example.py: 77 changes (62 additions, 15 deletions)
@@ -1,17 +1,64 @@
-import numpy as np
-import random
+# coding: utf-8
+# pylint: disable = invalid-name, C0111
+import json
import lightgbm as lgb
-from sklearn import datasets, metrics, model_selection
-
-rng = np.random.RandomState(2016)
-
-X, y = datasets.make_classification(n_samples=10000, n_features=100)
-x_train, x_test, y_train, y_test = model_selection.train_test_split(X, y, test_size=0.1, random_state=1)
-lgb_model = lgb.LGBMClassifier(n_estimators=100).fit(x_train, y_train, [(x_test, y_test)], eval_metric="auc")
-lgb_model.predict(x_test)
-# save model
-lgb_model.booster().save_model('model.txt')
-# load model
-booster = lgb.Booster(model_file='model.txt')
+import pandas as pd
+from sklearn.metrics import mean_squared_error
+
+# load or create your dataset
+df_train = pd.read_csv('../regression/regression.train', header=None, sep='\t')
+df_test = pd.read_csv('../regression/regression.test', header=None, sep='\t')
+
+y_train = df_train[0]
+y_test = df_test[0]
+X_train = df_train.drop(0, axis=1)
+X_test = df_test.drop(0, axis=1)
+
+# create dataset for lightgbm
+lgb_train = lgb.Dataset(X_train, y_train)
+lgb_eval = lgb.Dataset(X_test, y_test, reference=lgb_train)
+# or you can simply use a tuple of length 2 here
+lgb_train = (X_train, y_train)
+lgb_eval = (X_test, y_test)
+
+# specify your configurations as a dict
+params = {
+    'task': 'train',
+    'boosting_type': 'gbdt',
+    'objective': 'regression',
+    'metric': 'l2',
+    'num_leaves': 31,
+    'learning_rate': 0.05,
+    'feature_fraction': 0.9,
+    'bagging_fraction': 0.8,
+    'bagging_freq': 5,
+    # 'ndcg_eval_at': [1, 3, 5, 10],
+    # this metric is not needed for this task; it is shown only as an example
+    'verbose': 0
+}
+
+# train
+gbm = lgb.train(params,
+                lgb_train,
+                num_boost_round=100,
+                valid_datas=lgb_eval,
+                # use a list for multiple valid_datas/valid_names;
+                # a tuple represents a single dataset
+                early_stopping_rounds=10)
+
+# save model to file
+gbm.save_model('model.txt')
+
+# load model from file
+gbm = lgb.Booster(model_file='model.txt')
+
# predict
-print(booster.predict(x_test))
+y_pred = gbm.predict(X_test, num_iteration=gbm.best_iteration)
+# eval
+print('The rmse of prediction is:', mean_squared_error(y_test, y_pred) ** 0.5)
+
+# dump model to json (and save to file)
+model_json = gbm.dump_model()
+
+with open('model.json', 'w+') as f:
+    json.dump(model_json, f, indent=4)
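
The example ends by writing the dumped model to model.json, which can be read back with the standard json module. A minimal sketch of inspecting it; the key names ('tree_info', 'num_leaves') are assumptions about this PR's JSON layout, not something documented in this diff:
```
import json

with open('model.json') as f:
    model = json.load(f)

# top-level keys describe the model; per-tree dicts are assumed under 'tree_info'
print(sorted(model.keys()))
for i, tree in enumerate(model.get('tree_info', [])):
    print('tree %d has %d leaves' % (i, tree.get('num_leaves', -1)))
```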
examples/python-guide/sklearn_example.py: 28 changes (28 additions, 0 deletions)
@@ -0,0 +1,28 @@
# coding: utf-8
# pylint: disable = invalid-name, C0111
import lightgbm as lgb
import pandas as pd
from sklearn.metrics import mean_squared_error

# load or create your dataset
df_train = pd.read_csv('../regression/regression.train', header=None, sep='\t')
df_test = pd.read_csv('../regression/regression.test', header=None, sep='\t')

y_train = df_train[0]
y_test = df_test[0]
X_train = df_train.drop(0, axis=1)
X_test = df_test.drop(0, axis=1)

# train
gbm = lgb.LGBMRegressor(objective='regression',
                        num_leaves=31,
                        learning_rate=0.05,
                        n_estimators=100)
gbm.fit(X_train, y_train,
        eval_set=[(X_test, y_test)],
        early_stopping_rounds=10)

# predict
y_pred = gbm.predict(X_test, num_iteration=gbm.best_iteration)
# eval
print('The rmse of prediction is:', mean_squared_error(y_test, y_pred) ** 0.5)
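
Because LGBMRegressor follows the scikit-learn estimator interface, it should also drop into sklearn's model-selection utilities. A hedged sketch, assuming full sklearn compatibility; the grid values are illustrative, not tuned recommendations:
```
import pandas as pd
import lightgbm as lgb
from sklearn.model_selection import GridSearchCV

# same dataset as in the example above
df_train = pd.read_csv('../regression/regression.train', header=None, sep='\t')
y_train = df_train[0]
X_train = df_train.drop(0, axis=1)

estimator = lgb.LGBMRegressor(objective='regression', n_estimators=40)
param_grid = {
    'num_leaves': [15, 31, 63],
    'learning_rate': [0.01, 0.05, 0.1],
}
gbm = GridSearchCV(estimator, param_grid, cv=3)
gbm.fit(X_train, y_train)
print('Best parameters:', gbm.best_params_)
```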
include/LightGBM/boosting.h: 13 changes (10 additions, 3 deletions)
@@ -133,9 +133,16 @@ class Boosting {
                     const double* feature_values) const = 0;

  /*!
-  * \brief save model to file
-  * \param num_iterations Iterations that want to save, -1 means save all
-  * \param filename filename that want to save to
+  * \brief Dump model to json format string
+  * \return Json format string of model
+  */
+  virtual std::string DumpModel() const = 0;
+
+  /*!
+  * \brief Save model to file
+  * \param num_iterations Number of iterations to save, -1 means save all
+  * \param filename Filename to save to
   */
  virtual void SaveModelToFile(int num_iterations, const char* filename) const = 0;
include/LightGBM/c_api.h: 13 changes (12 additions, 1 deletion)
@@ -474,7 +474,18 @@ DllExport int LGBM_BoosterSaveModel(BoosterHandle handle,
                                    int num_iteration,
                                    const char* filename);

+/*!
+* \brief Dump model to json format
+* \param handle Booster handle
+* \param buffer_len string buffer length; if buffer_len < out_len, the caller should re-allocate the buffer
+* \param out_len actual output length
+* \param out_str json format string of model
+* \return 0 when it succeeds, -1 when a failure happens
+*/
+DllExport int LGBM_BoosterDumpModel(BoosterHandle handle,
+                                    int buffer_len,
+                                    int64_t* out_len,
+                                    char** out_str);

// some helper functions used to convert data

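The buffer_len/out_len contract above implies a two-call pattern on the caller's side: try a first buffer, and if out_len reports a larger requirement, re-allocate and call again. A minimal ctypes sketch of that pattern, mirroring the dump_model() method added to basic.py below; lib is assumed to be the loaded LightGBM shared library and handle a valid BoosterHandle:
```
import ctypes

def dump_model_str(lib, handle, buffer_len=1 << 20):
    out_len = ctypes.c_int64(0)
    buf = ctypes.create_string_buffer(buffer_len)
    ptr = ctypes.c_char_p(ctypes.addressof(buf))
    lib.LGBM_BoosterDumpModel(handle, buffer_len,
                              ctypes.byref(out_len), ctypes.byref(ptr))
    if out_len.value > buffer_len:
        # first buffer was too small: re-allocate to the reported size and retry
        buf = ctypes.create_string_buffer(out_len.value)
        ptr = ctypes.c_char_p(ctypes.addressof(buf))
        lib.LGBM_BoosterDumpModel(handle, out_len.value,
                                  ctypes.byref(out_len), ctypes.byref(ptr))
    return buf.value.decode()
```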
include/LightGBM/tree.h: 12 changes (9 additions, 3 deletions)
@@ -98,9 +98,12 @@ class Tree {
    }
  }

-  /*! \brief Serialize this object by string */
+  /*! \brief Serialize this object to string */
  std::string ToString();

+  /*! \brief Serialize this object to JSON */
+  std::string ToJSON();
+
private:
  /*!
   * \brief Find leaf index of which record belongs by data
@@ -118,6 +121,9 @@
   */
  inline int GetLeaf(const double* feature_values) const;

+  /*! \brief Serialize one node to JSON */
+  inline std::string NodeToJSON(int index);
+
  /*! \brief Number of max leaves */
  int max_leaves_;
  /*! \brief Number of current leaves */
@@ -137,13 +143,13 @@
  std::vector<double> threshold_;
  /*! \brief A non-leaf node's split gain */
  std::vector<double> split_gain_;
-  /*! \brief Output of internal nodes (kept so per-inference feature importance can be computed) */
-  std::vector<double> internal_value_;
  // used for leaf node
  /*! \brief The parent of leaf */
  std::vector<int> leaf_parent_;
  /*! \brief Output of leaves */
  std::vector<double> leaf_value_;
+  /*! \brief Output of internal nodes (kept so per-inference feature importance can be computed) */
+  std::vector<double> internal_value_;
  /*! \brief Depth for leaves */
  std::vector<int> leaf_depth_;
};
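ToJSON() serializes each tree recursively via NodeToJSON(), so the dumped structure can be walked the same way. A sketch under assumed node key names ('split_feature', 'threshold', 'left_child', 'right_child', 'leaf_value'); the exact names are not shown in this diff:
```
def print_tree(node, depth=0):
    indent = '  ' * depth
    if 'leaf_value' in node:  # leaf nodes carry an output value
        print('%sleaf: %g' % (indent, node['leaf_value']))
    else:  # internal nodes carry a split and two children
        print('%ssplit feature %d at %g' % (indent, node['split_feature'], node['threshold']))
        print_tree(node['left_child'], depth + 1)
        print_tree(node['right_child'], depth + 1)

# e.g.: print_tree(booster.dump_model()['tree_info'][0]['tree_structure'])
```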
python-package/lightgbm/__init__.py: 3 changes (2 additions, 1 deletion)
@@ -20,4 +20,5 @@

__all__ = ['Dataset', 'Booster',
           'train', 'cv',
-           'LGBMModel','LGBMRegressor', 'LGBMClassifier', 'LGBMRanker']
+           'LGBMModel', 'LGBMRegressor', 'LGBMClassifier', 'LGBMRanker']

python-package/lightgbm/basic.py: 68 changes (53 additions, 15 deletions)
@@ -1,16 +1,30 @@
# coding: utf-8
# pylint: disable = invalid-name, C0111, R0912, R0913, R0914, W0105
"""Wrapper c_api of LightGBM"""
from __future__ import absolute_import

import sys
import os
import ctypes
import tempfile
+import json

import numpy as np
import scipy.sparse

from .libpath import find_lib_path

+# pandas
+try:
+    from pandas import Series, DataFrame
+    IS_PANDAS_INSTALLED = True
+except ImportError:
+    IS_PANDAS_INSTALLED = False
+    class Series(object):
+        pass
+    class DataFrame(object):
+        pass
+
IS_PY3 = (sys.version_info[0] == 3)

def _load_lib():
@@ -69,6 +83,8 @@ def list_to_1d_numpy(data, dtype):
        return data.astype(dtype=dtype, copy=False)
    elif is_1d_list(data):
        return np.array(data, dtype=dtype, copy=False)
+    elif IS_PANDAS_INSTALLED and isinstance(data, Series):
+        return data.astype(dtype).values
    else:
        raise TypeError("Unknown type({})".format(type(data).__name__))

@@ -110,7 +126,7 @@ def param_dict_to_str(data):
        elif isinstance(val, (int, float, bool)):
            pairs.append(str(key)+'='+str(val))
        else:
            raise TypeError('unknown type of parameter:%s, got:%s'
                            % (key, type(val).__name__))
    return ' '.join(pairs)
"""macro definition of data type in c_api of LightGBM"""
@@ -183,7 +199,7 @@ def __init__(self, model_file=None, booster_handle=None, is_manage_handle=True):
"""Prediction task"""
out_num_iterations = ctypes.c_int64(0)
_safe_call(_LIB.LGBM_BoosterCreateFromModelfile(
c_str(model_file),
c_str(model_file),
ctypes.byref(out_num_iterations),
ctypes.byref(self.handle)))
out_num_class = ctypes.c_int64(0)
@@ -357,7 +373,7 @@ def __pred_for_csr(self, csr, num_iteration, predict_type):
            type_ptr_data,
            len(csr.indptr),
            len(csr.data),
            csr.shape[1],
            predict_type,
            num_iteration,
            ctypes.byref(out_num_preds),
@@ -367,13 +383,6 @@ def __pred_for_csr(self, csr, num_iteration, predict_type):
raise ValueError("incorrect number for predict result")
return preds, nrow

# pandas
try:
from pandas import DataFrame
except ImportError:
class DataFrame(object):
pass

PANDAS_DTYPE_MAPPER = {'int8': 'int', 'int16': 'int', 'int32': 'int',
'int64': 'int', 'uint8': 'int', 'uint16': 'int',
'uint32': 'int', 'uint64': 'int', 'float16': 'float',
@@ -467,8 +476,8 @@ def __init__(self, data, label=None, max_bin=255, reference=None,
            self.data_has_header = True
            self.handle = ctypes.c_void_p()
            _safe_call(_LIB.LGBM_DatasetCreateFromFile(
                c_str(data),
                c_str(params_str),
                ref_dataset,
                ctypes.byref(self.handle)))
        elif isinstance(data, scipy.sparse.csr_matrix):
@@ -830,6 +839,7 @@ def __init__(self, params=None, train_set=None, model_file=None, silent=False):
        self.__is_manage_handle = True
        self.__train_data_name = "training"
        self.__attr = {}
+        self.best_iteration = -1
        params = {} if params is None else params
        if silent:
            params["verbose"] = 0
@@ -1018,7 +1028,7 @@ def current_iteration(self):
            self.handle,
            ctypes.byref(out_cur_iter)))
        return out_cur_iter.value

    def eval(self, data, name, feval=None):
        """Evaluate for data

@@ -1098,6 +1108,34 @@ def save_model(self, filename, num_iteration=-1):
            num_iteration,
            c_str(filename)))

+    def dump_model(self):
+        """
+        Dump model to json format
+
+        Returns
+        -------
+        Json format of model
+        """
+        buffer_len = 1 << 20
+        tmp_out_len = ctypes.c_int64(0)
+        string_buffer = ctypes.create_string_buffer(buffer_len)
+        ptr_string_buffer = ctypes.c_char_p(*[ctypes.addressof(string_buffer)])
+        _safe_call(_LIB.LGBM_BoosterDumpModel(
+            self.handle,
+            buffer_len,
+            ctypes.byref(tmp_out_len),
+            ctypes.byref(ptr_string_buffer)))
+        actual_len = tmp_out_len.value
+        if actual_len > buffer_len:
+            string_buffer = ctypes.create_string_buffer(actual_len)
+            ptr_string_buffer = ctypes.c_char_p(*[ctypes.addressof(string_buffer)])
+            _safe_call(_LIB.LGBM_BoosterDumpModel(
+                self.handle,
+                actual_len,
+                ctypes.byref(tmp_out_len),
+                ctypes.byref(ptr_string_buffer)))
+        return json.loads(string_buffer.value.decode())

    def predict(self, data, num_iteration=-1, raw_score=False, pred_leaf=False, data_has_header=False, is_reshape=True):
        """
        Predict logic
@@ -1147,7 +1185,7 @@ def __inner_eval(self, data_name, data_idx, feval=None):
        _safe_call(_LIB.LGBM_BoosterGetEval(
            self.handle,
            data_idx,
            ctypes.byref(tmp_out_len),
            result.ctypes.data_as(ctypes.POINTER(ctypes.c_float))))
        if tmp_out_len.value != self.__num_inner_eval:
            raise ValueError("incorrect number of eval results")
@@ -1190,7 +1228,7 @@ def __inner_predict(self, data_idx):
            ctypes.byref(tmp_out_len),
            data_ptr))
        if tmp_out_len.value != len(self.__inner_predict_buffer[data_idx]):
-            raise ValueError("incorrect number of predict results for data %d" % (data_idx) )
+            raise ValueError("incorrect number of predict results for data %d" % (data_idx))
        self.__is_predicted_cur_iter[data_idx] = True
        return self.__inner_predict_buffer[data_idx]
