[PYTHON] Refactor training API to use callback
tqchen committed May 20, 2016
1 parent 03996dd commit 868c6a9
Showing 13 changed files with 188 additions and 185 deletions.
2 changes: 1 addition & 1 deletion Makefile
@@ -131,7 +131,7 @@ rcpplint:
python2 dmlc-core/scripts/lint.py xgboost ${LINT_LANG} R-package/src

lint: rcpplint
python2 dmlc-core/scripts/lint.py xgboost ${LINT_LANG} include src plugin
python2 dmlc-core/scripts/lint.py xgboost ${LINT_LANG} include src plugin python-package

clean:
$(RM) -rf build build_plugin lib bin *~ */*~ */*/*~ */*/*/*~ */*.o */*/*.o */*/*/*.o xgboost
4 changes: 3 additions & 1 deletion python-package/xgboost/compat.py
@@ -1,5 +1,5 @@
# coding: utf-8
# pylint: disable=unused-import, invalid-name, wrong-import-position
# pylint: disable= invalid-name, unused-import
"""For compatibility"""

from __future__ import absolute_import
@@ -14,12 +14,14 @@
STRING_TYPES = str,

def py_str(x):
"""convert c string back to python string"""
return x.decode('utf-8')
else:
# pylint: disable=invalid-name
STRING_TYPES = basestring,

def py_str(x):
"""convert c string back to python string"""
return x
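# A minimal usage sketch (assumption: ctypes hands back bytes for char* on
# Python 3 and str on Python 2):
#
#   py_str(b'gbtree')   # -> 'gbtree' on Python 3
#   py_str('gbtree')    # -> 'gbtree' on Python 2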

try:
38 changes: 28 additions & 10 deletions python-package/xgboost/core.py
@@ -1,5 +1,6 @@
# coding: utf-8
# pylint: disable=too-many-arguments, too-many-branches
# pylint: disable=too-many-arguments, too-many-branches, invalid-name
# pylint: disable=too-many-branches, too-many-lines, W0141
"""Core XGBoost Library."""
from __future__ import absolute_import

@@ -21,6 +22,19 @@ class XGBoostError(Exception):
"""Error throwed by xgboost trainer."""
pass

class EarlyStopException(Exception):
"""Exception to signal early stopping."""

# Callback environment used by callbacks
CallbackEnv = collections.namedtuple(
"XGBoostCallbackEnv",
["booster",
"iteration",
"begin_iteration",
"end_iteration",
"rank",
"evaluation_result_list"])


def from_pystr_to_cstr(data):
"""Convert a list of Python str to C pointer
@@ -657,7 +671,7 @@ def __setstate__(self, state):
def __copy__(self):
return self.__deepcopy__(None)

def __deepcopy__(self, memo):
def __deepcopy__(self, _):
return Booster(model_file=self.save_raw())
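# In effect, copy.deepcopy(bst) ignores the memo dict and rebuilds the Booster
# from its serialized form, e.g. (sketch):
#
#   import copy
#   bst_copy = copy.deepcopy(bst)   # equivalent to Booster(model_file=bst.save_raw())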

def copy(self):
@@ -975,7 +989,6 @@ def load_model(self, fname):
_check_call(_LIB.XGBoosterLoadModelFromBuffer(self.handle, ptr, length))

def dump_model(self, fout, fmap='', with_stats=False):
# pylint: disable=consider-using-enumerate
"""
Dump model into a text file.
@@ -1143,10 +1156,12 @@ def _validate_features(self, data):
msg = 'feature_names mismatch: {0} {1}'

if dat_missing:
msg += '\nexpected ' + ', '.join(str(s) for s in dat_missing) + ' in input data'
msg += ('\nexpected ' + ', '.join(str(s) for s in dat_missing) +
' in input data')

if my_missing:
msg += '\ntraining data did not have the following fields: ' + ', '.join(str(s) for s in my_missing)
msg += ('\ntraining data did not have the following fields: ' +
', '.join(str(s) for s in my_missing))

raise ValueError(msg.format(self.feature_names,
data.feature_names))
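# Sketch of when this error fires (hypothetical feature names and data,
# assuming DMatrix accepts a feature_names argument as used above):
#
#   dtrain = xgb.DMatrix(X_train, label=y_train, feature_names=['f0', 'f1'])
#   bst = xgb.train(params, dtrain)
#   bst.predict(xgb.DMatrix(X_test, feature_names=['f0', 'f2']))
#   # ValueError: feature_names mismatch: ['f0', 'f1'] ['f0', 'f2'] ...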
@@ -1161,23 +1176,25 @@ def get_split_value_histogram(self, feature, fmap='', bins=None, as_pandas=True)
The name of feature map file.
bin: int, default None
The maximum number of bins.
Number of bins equals number of unique split values n_unique, if bins == None or bins > n_unique.
Number of bins equals number of unique split values n_unique,
if bins == None or bins > n_unique.
as_pandas : bool, default True
Return pd.DataFrame when pandas is installed.
If False or pandas is not installed, return numpy ndarray.
Returns
-------
a histogram of used splitting values for the specified feature either as numpy array or pandas DataFrame.
a histogram of used splitting values for the specified feature
either as numpy array or pandas DataFrame.
"""
xgdump = self.get_dump(fmap=fmap)
values = []
regexp = re.compile("\[{0}<([\d.Ee+-]+)\]".format(feature))
regexp = re.compile(r"\[{0}<([\d.Ee+-]+)\]".format(feature))
for i in range(len(xgdump)):
m = re.findall(regexp, xgdump[i])
values.extend(map(float, m))

n_unique = np.unique(values).shape[0]
n_unique = len(np.unique(values))
bins = max(min(n_unique, bins) if bins is not None else n_unique, 1)

nph = np.histogram(values, bins=bins)
@@ -1187,7 +1204,8 @@ def get_split_value_histogram(self, feature, fmap='', bins=None, as_pandas=True)
if as_pandas and PANDAS_INSTALLED:
return DataFrame(nph, columns=['SplitValue', 'Count'])
elif as_pandas and not PANDAS_INSTALLED:
sys.stderr.write("Returning histogram as ndarray (as_pandas == True, but pandas is not installed).")
sys.stderr.write(
"Returning histogram as ndarray (as_pandas == True, but pandas is not installed).")
return nph
else:
return nph
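For reference, a usage sketch of the method changed above; the data, parameters, and feature names below are made up for illustration:

import numpy as np
import xgboost as xgb

X = np.random.rand(100, 2)
y = (X[:, 0] > 0.5).astype(int)
dtrain = xgb.DMatrix(X, label=y, feature_names=['age', 'income'])
bst = xgb.train({'max_depth': 2, 'objective': 'binary:logistic'}, dtrain,
                num_boost_round=10)

# Histogram of the split values the trees used for 'age'. With pandas
# installed (and as_pandas=True) this is a DataFrame with SplitValue and
# Count columns; otherwise the raw numpy histogram is returned.
hist = bst.get_split_value_histogram('age', bins=20)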
5 changes: 4 additions & 1 deletion python-package/xgboost/rabit.py
@@ -1,3 +1,6 @@
# coding: utf-8
# pylint: disable= invalid-name

"""Distributed XGBoost Rabit related API."""
from __future__ import absolute_import
import sys
@@ -179,7 +182,7 @@ def allreduce(data, op, prepare_fun=None):
else:
func_ptr = ctypes.CFUNCTYPE(None, ctypes.c_void_p)

def pfunc(args):
def pfunc(_):
"""prepare function."""
prepare_fun(data)
_LIB.RabitAllreduce(buf.ctypes.data_as(ctypes.c_void_p),
7 changes: 4 additions & 3 deletions python-package/xgboost/sklearn.py
@@ -1,5 +1,5 @@
# coding: utf-8
# pylint: disable=too-many-arguments, too-many-locals, invalid-name, fixme
# pylint: disable=too-many-arguments, too-many-locals, invalid-name, fixme, E0012, R0912
"""Scikit-Learn Wrapper interface for XGBoost."""
from __future__ import absolute_import

@@ -42,6 +42,7 @@ def _objective_decorator(func):
``dmatrix.get_label()``
"""
def inner(preds, dmatrix):
"""internal function"""
labels = dmatrix.get_label()
return func(labels, preds)
return inner
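# Sketch of the kind of objective this decorator adapts (assumption: the
# sklearn wrapper routes a callable `objective` through _objective_decorator,
# and the callable returns (grad, hess) like xgboost's native objectives):

import numpy as np

def squared_error(y_true, y_pred):
    """Gradient and hessian of 0.5 * (y_pred - y_true) ** 2."""
    grad = y_pred - y_true
    hess = np.ones_like(y_pred)
    return grad, hess

# e.g. XGBRegressor(objective=squared_error)  # assumed usage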
@@ -183,7 +184,7 @@ def get_xgb_params(self):

def fit(self, X, y, eval_set=None, eval_metric=None,
early_stopping_rounds=None, verbose=True):
# pylint: disable=missing-docstring,invalid-name,attribute-defined-outside-init, redefined-variable-type
# pylint: disable=missing-docstring,invalid-name,attribute-defined-outside-init
"""
Fit the gradient boosting model
@@ -351,7 +352,7 @@ def __init__(self, max_depth=3, learning_rate=0.1,

def fit(self, X, y, sample_weight=None, eval_set=None, eval_metric=None,
early_stopping_rounds=None, verbose=True):
# pylint: disable = attribute-defined-outside-init,arguments-differ, redefined-variable-type
# pylint: disable = attribute-defined-outside-init,arguments-differ
"""
Fit gradient boosting classifier
