forked from dmlc/xgboost
-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Set output margin to True for custom objective.
- Loading branch information
1 parent
d6d1035
commit c09fe63
Showing
10 changed files
with
257 additions
and
22 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,148 @@ | ||
'''Demo for creating customized multi-class objective function. This demo is | ||
only applicable after (excluding) XGBoost 1.0.0, as before this version XGBoost | ||
returns transformed prediction for multi-class objective function. More | ||
details in comments. | ||
''' | ||
|
||
import numpy as np | ||
import xgboost as xgb | ||
from matplotlib import pyplot as plt | ||
import argparse | ||
|
||
np.random.seed(1994) | ||
|
||
kRows = 100 | ||
kCols = 10 | ||
kClasses = 4 # number of classes | ||
|
||
kRounds = 10 # number of boosting rounds. | ||
|
||
# Generate some random data for demo. | ||
X = np.random.randn(kRows, kCols) | ||
y = np.random.randint(0, 4, size=kRows) | ||
|
||
m = xgb.DMatrix(X, y) | ||
|
||
|
||
def softmax(x): | ||
'''Softmax function with x as input vector.''' | ||
e = np.exp(x) | ||
return e / np.sum(e) | ||
|
||
|
||
def softprob_obj(predt: np.ndarray, data: xgb.DMatrix): | ||
'''Loss function. Computing the gradient and approximated hessian (diagonal). | ||
Reimplements the `multi:softprob` inside XGBoost. | ||
''' | ||
labels = data.get_label() | ||
if data.get_weight().size == 0: | ||
# Use 1 as weight if we don't have custom weight. | ||
weights = np.ones((kRows, 1), dtype=float) | ||
else: | ||
weights = data.get_weight() | ||
|
||
# The prediction is of shape (rows, classes), each element in a row | ||
# represents a raw prediction (leaf weight, hasn't gone through softmax | ||
# yet). In XGBoost 1.0.0, the prediction is transformed by a softmax | ||
# function, fixed in later versions. | ||
assert predt.shape == (kRows, kClasses) | ||
|
||
grad = np.zeros((kRows, kClasses), dtype=float) | ||
hess = np.zeros((kRows, kClasses), dtype=float) | ||
|
||
eps = 1e-6 | ||
|
||
# compute the gradient and hessian, slow iterations in Python, only | ||
# suitable for demo. Also the one in native XGBoost core is more robust to | ||
# numeric overflow as we don't do anything to mitigate the `exp` in | ||
# `softmax` here. | ||
for r in range(predt.shape[0]): | ||
target = labels[r] | ||
p = softmax(predt[r, :]) | ||
for c in range(predt.shape[1]): | ||
assert target >= 0 or target <= kClasses | ||
g = p[c] - 1.0 if c == target else p[c] | ||
g = g * weights[r] | ||
h = max((2.0 * p[c] * (1.0 - p[c]) * weights[r]).item(), eps) | ||
grad[r, c] = g | ||
hess[r, c] = h | ||
|
||
# Right now (XGBoost 1.0.0), reshaping is necessary | ||
grad = grad.reshape((kRows * kClasses, 1)) | ||
hess = hess.reshape((kRows * kClasses, 1)) | ||
return grad, hess | ||
|
||
|
||
def predict(booster, X): | ||
'''A customized prediction function that converts raw prediction to | ||
target class. | ||
''' | ||
# Output margin means we want to obtain the raw prediction obtained from | ||
# tree leaf weight. | ||
predt = booster.predict(X, output_margin=True) | ||
out = np.zeros(kRows) | ||
for r in range(predt.shape[0]): | ||
# the class with maximum prob (not strictly prob as it haven't gone | ||
# through softmax yet so it doesn't sum to 1, but result is the same | ||
# for argmax). | ||
i = np.argmax(predt[r]) | ||
out[r] = i | ||
return out | ||
|
||
|
||
def plot_history(custom_results, native_results): | ||
fig, axs = plt.subplots(2, 1) | ||
ax0 = axs[0] | ||
ax1 = axs[1] | ||
|
||
x = np.arange(0, kRounds, 1) | ||
ax0.plot(x, custom_results['train']['merror'], label='Custom objective') | ||
ax0.legend() | ||
ax1.plot(x, native_results['train']['merror'], label='multi:softmax') | ||
ax1.legend() | ||
|
||
plt.show() | ||
|
||
|
||
def main(args): | ||
custom_results = {} | ||
# Use our custom objective function | ||
booster_custom = xgb.train({'num_class': kClasses}, | ||
m, | ||
num_boost_round=kRounds, | ||
obj=softprob_obj, | ||
evals_result=custom_results, | ||
evals=[(m, 'train')]) | ||
|
||
predt_custom = predict(booster_custom, m) | ||
|
||
native_results = {} | ||
# Use the same objective function defined in XGBoost. | ||
booster_native = xgb.train({'num_class': kClasses}, | ||
m, | ||
num_boost_round=kRounds, | ||
evals_result=native_results, | ||
evals=[(m, 'train')]) | ||
predt_native = booster_native.predict(m) | ||
|
||
# We are reimplementing the loss function in XGBoost, so it should | ||
# be the same for normal cases. | ||
assert np.all(predt_custom == predt_native) | ||
|
||
if args.plot != 0: | ||
plot_history(custom_results, native_results) | ||
|
||
|
||
if __name__ == '__main__': | ||
parser = argparse.ArgumentParser( | ||
description='Arguments for custom softmax objective function demo.') | ||
parser.add_argument( | ||
'--plot', | ||
type=int, | ||
default=1, | ||
help='Set to 0 to disable plotting the evaluation history.') | ||
args = parser.parse_args() | ||
main(args) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
import os | ||
import subprocess | ||
|
||
|
||
CURRENT_DIR = os.path.dirname(__file__) | ||
ROOT_DIR = os.path.dirname(os.path.dirname(CURRENT_DIR)) | ||
DEMO_DIR = os.path.join(ROOT_DIR, 'demo', 'guide-python') | ||
|
||
|
||
def test_basic_walkthrough(): | ||
script = os.path.join(DEMO_DIR, 'basic_walkthrough.py') | ||
cmd = ['python', script] | ||
subprocess.check_call(cmd) | ||
os.remove('dump.nice.txt') | ||
os.remove('dump.raw.txt') | ||
|
||
|
||
def test_custom_multiclass_objective(): | ||
script = os.path.join(DEMO_DIR, 'custom_softmax.py') | ||
cmd = ['python', script, '--plot=0'] | ||
subprocess.check_call(cmd) | ||
|
||
|
||
def test_custom_rmsle_objective(): | ||
script = os.path.join(DEMO_DIR, 'custom_rmsle.py') | ||
cmd = ['python', script, '--plot=0'] | ||
subprocess.check_call(cmd) |