Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

reset_learning_rate->reset_parameter #131

Merged
merged 4 commits into from
Dec 18, 2016
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 88 additions & 5 deletions docs/Python-API.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,14 @@
- [LGBMClassifier](Python-API.md#lgbmclassifier)
- [LGBMRegressor](Python-API.md#lgbmregressor)
- [LGBMRanker](Python-API.md#lgbmranker)

* [Callbacks](Python-API.md#callbacks)
- [Before iteration](Python-API.md#before-iteration)
+ [reset_parameter](Python-API.md#reset_parameterkwargs)
- [After iteration](Python-API.md#after-iteration)
+ [print_evaluation](Python-API.md#print_evaluationperiod1-show_stdvtrue)
+ [record_evaluation](Python-API.md#record_evaluationeval_result)
+ [early_stopping](Python-API.md#early_stoppingstopping_rounds-verbosetrue)

The methods of each Class is in alphabetical order.

Expand Down Expand Up @@ -496,12 +504,10 @@ The methods of each Class is in alphabetical order.
an evaluation metric is printed every 4 (instead of 1) boosting stages.
learning_rates: list or function
List of learning rate for each boosting round
or a customized function that calculates learning_rate in terms of
current number of round (and the total number of boosting round)
(e.g. yields learning rate decay)
or a customized function that calculates learning_rate
in terms of current number of round (e.g. yields learning rate decay)
- list l: learning_rate = l[current_round]
- function f: learning_rate = f(current_round, total_boost_round)
or learning_rate = f(current_round)
- function f: learning_rate = f(current_round)
callbacks : list of callback functions
List of callback functions that are applied at end of each iteration.

Expand Down Expand Up @@ -805,3 +811,80 @@ The methods of each Class is in alphabetical order.
eval_at : list of int
The evaulation positions of NDCG

## Callbacks

###Before iteration

####reset_parameter(**kwargs)

Reset parameter after first iteration

NOTE: the initial parameter will still take in-effect on first iteration.

Parameters
----------
**kwargs: value should be list or function
List of parameters for each boosting round
or a customized function that calculates learning_rate in terms of
current number of round (e.g. yields learning rate decay)
- list l: parameter = l[current_round]
- function f: parameter = f(current_round)
Returns
-------
callback : function
The requested callback function.

###After iteration

####print_evaluation(period=1, show_stdv=True)

Create a callback that print evaluation result.
(Same function as `verbose_eval` in lightgbm.train())

Parameters
----------
period : int
The period to log the evaluation results

show_stdv : bool, optional
Whether show standard deviation if provided

Returns
-------
callback : function
A callback that prints evaluation every period iterations.

####record_evaluation(eval_result)

Create a call back that records the evaluation history into eval_result.
(Same function as `evals_result` in lightgbm.train())

Parameters
----------
eval_result : dict
A dictionary to store the evaluation results.

Returns
-------
callback : function
The requested callback function.

####early_stopping(stopping_rounds, verbose=True)

Create a callback that activates early stopping.
To activates early stopping, at least one validation data and one metric is required.
If there's more than one, all of them will be checked.
(Same function as `early_stopping_rounds` in lightgbm.train())

Parameters
----------
stopping_rounds : int
The stopping rounds before the trend occur.

verbose : optional, bool
Whether to print message about early stopping information.

Returns
-------
callback : function
The requested callback function.
2 changes: 1 addition & 1 deletion examples/python-guide/sklearn_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
# feature importances
print('Feature importances:', list(gbm.feature_importance()))

# other scikit-learn built-in module
# other scikit-learn modules
estimator = lgb.LGBMRegressor(num_leaves=31)

param_grid = {
Expand Down
4 changes: 3 additions & 1 deletion python-package/lightgbm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

from .basic import Dataset, Booster
from .engine import train, cv
from .callback import print_evaluation, record_evaluation, reset_parameter, early_stopping
try:
from .sklearn import LGBMModel, LGBMRegressor, LGBMClassifier, LGBMRanker
except ImportError:
Expand All @@ -20,5 +21,6 @@

__all__ = ['Dataset', 'Booster',
'train', 'cv',
'LGBMModel', 'LGBMRegressor', 'LGBMClassifier', 'LGBMRanker']
'LGBMModel', 'LGBMRegressor', 'LGBMClassifier', 'LGBMRanker',
'print_evaluation', 'record_evaluation', 'reset_parameter', 'early_stopping']

41 changes: 16 additions & 25 deletions python-package/lightgbm/callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
# pylint: disable = invalid-name, W0105, C0301
from __future__ import absolute_import
import collections
import inspect

class EarlyStopException(Exception):
"""Exception of early stopping.
Expand Down Expand Up @@ -98,47 +97,39 @@ def callback(env):
return callback


def reset_learning_rate(learning_rates):
"""Reset learning rate after first iteration
def reset_parameter(**kwargs):
"""Reset parameter after first iteration

NOTE: the initial learning rate will still take in-effect on first iteration.
NOTE: the initial parameter will still take in-effect on first iteration.

Parameters
----------
learning_rates: list or function
List of learning rate for each boosting round \
or a customized function that calculates learning_rate in terms of \
current number of round and the total number of boosting round \
(e.g. yields learning rate decay)
- list l: learning_rate = l[current_round]
- function f: learning_rate = f(current_round, total_boost_round) \
or learning_rate = f(current_round)
**kwargs: value should be list or function
List of parameters for each boosting round
or a customized function that calculates learning_rate in terms of
current number of round (e.g. yields learning rate decay)
- list l: parameter = l[current_round]
- function f: parameter = f(current_round)
Returns
-------
callback : function
The requested callback function.
"""
def callback(env):
"""internal function"""
if isinstance(learning_rates, list):
if len(learning_rates) != env.end_iteration - env.begin_iteration:
raise ValueError("Length of list 'learning_rates' has to equal to 'num_boost_round'.")
env.model.reset_parameter({'learning_rate':learning_rates[env.iteration]})
else:
argc = len(inspect.getargspec(learning_rates).args)
if argc is 1:
env.model.reset_parameter({"learning_rate": learning_rates(env.iteration - env.begin_iteration)})
elif argc is 2:
env.model.reset_parameter({"learning_rate": \
learning_rates(env.iteration - env.begin_iteration, env.end_iteration - env.begin_iteration)})
for key, value in kwargs.items():
if isinstance(value, list):
if len(value) != env.end_iteration - env.begin_iteration:
raise ValueError("Length of list {} has to equal to 'num_boost_round'.".format(repr(key)))
env.model.reset_parameter({key: value[env.iteration - env.begin_iteration]})
else:
raise ValueError("Self-defined function 'learning_rates' should have 1 or 2 arguments, got %d" %(argc))
env.model.reset_parameter({key: value(env.iteration - env.begin_iteration)})
callback.before_iteration = True
callback.order = 10
return callback


def early_stop(stopping_rounds, verbose=True):
def early_stopping(stopping_rounds, verbose=True):
"""Create a callback that activates early stopping.
Activates early stopping.
Requires at least one validation data and one metric
Expand Down
15 changes: 6 additions & 9 deletions python-package/lightgbm/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,12 +69,10 @@ def train(params, train_set, num_boost_round=100,
an evaluation metric is printed every 4 (instead of 1) boosting stages.
learning_rates: list or function
List of learning rate for each boosting round
or a customized function that calculates learning_rate in terms of
current number of round (and the total number of boosting round)
(e.g. yields learning rate decay)
or a customized function that calculates learning_rate
in terms of current number of round (e.g. yields learning rate decay)
- list l: learning_rate = l[current_round]
- function f: learning_rate = f(current_round, total_boost_round)
or learning_rate = f(current_round)
- function f: learning_rate = f(current_round)
callbacks : list of callback functions
List of callback functions that are applied at end of each iteration.

Expand Down Expand Up @@ -138,11 +136,10 @@ def train(params, train_set, num_boost_round=100,
callbacks.add(callback.print_evaluation(verbose_eval))

if early_stopping_rounds is not None:
callbacks.add(callback.early_stop(early_stopping_rounds,
verbose=bool(verbose_eval)))
callbacks.add(callback.early_stopping(early_stopping_rounds, verbose=bool(verbose_eval)))

if learning_rates is not None:
callbacks.add(callback.reset_learning_rate(learning_rates))
callbacks.add(callback.reset_parameter(learning_rate=learning_rates))

if evals_result is not None:
callbacks.add(callback.record_evaluation(evals_result))
Expand Down Expand Up @@ -355,7 +352,7 @@ def cv(params, train_set, num_boost_round=10, nfold=5, stratified=False,
cb.__dict__.setdefault('order', i - len(callbacks))
callbacks = set(callbacks)
if early_stopping_rounds is not None:
callbacks.add(callback.early_stop(early_stopping_rounds, verbose=False))
callbacks.add(callback.early_stopping(early_stopping_rounds, verbose=False))
if verbose_eval is True:
callbacks.add(callback.print_evaluation(show_stdv=show_stdv))
elif isinstance(verbose_eval, int):
Expand Down