
Add Accelerated Failure Time loss for survival analysis task #4763

Merged: 84 commits, Mar 25, 2020
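
For context before the commit log: this PR adds the Accelerated Failure Time (AFT) model for censored survival data. The following is a paraphrased sketch of the standard AFT formulation (the tutorial added by this PR is the authoritative reference): the log survival time Y is modeled as the tree-ensemble output plus scaled noise,

\ln Y = \mathcal{T}(\mathbf{x}) + \sigma Z, \qquad Z \sim \text{normal, logistic, or extreme}

where \sigma is the aft_loss_distribution_scale parameter. Writing z = (\ln y - \mathcal{T}(\mathbf{x})) / \sigma, the aft-nloglik metric is the negative log likelihood:

-\log L = \begin{cases}
  -\log\bigl[ f_Z(z) / (\sigma y) \bigr] & \text{uncensored: } y_{\mathrm{lower}} = y_{\mathrm{upper}} = y \\
  -\log\bigl[ F_Z(z_{\mathrm{upper}}) - F_Z(z_{\mathrm{lower}}) \bigr] & \text{censored: } y \in [y_{\mathrm{lower}}, y_{\mathrm{upper}}]
\end{cases}

Right censoring is the special case y_upper = +inf (so F_Z(z_upper) = 1); left censoring analogously has an unbounded lower end.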
Commits
1a0d73d
[WIP] Add lower and upper bounds on the label for survival analysis
hcho3 Jul 9, 2019
41b3015
Update test MetaInfo.SaveLoadBinary to account for extra two fields
hcho3 Jul 9, 2019
063d915
Don't clear qids_ for version 2 of MetaInfo
hcho3 Jul 9, 2019
0ca21bd
Add SetInfo() and GetInfo() method for lower and upper bounds
hcho3 Jul 24, 2019
539edd1
changes to aft
avinashbarnwal Jul 28, 2019
5022b9b
Add parameter class for AFT; use enum's to represent distribution and…
hcho3 Jul 28, 2019
adb1107
Add AFT metric
hcho3 Jul 28, 2019
4652aca
changes to neg grad to grad
avinashbarnwal Jul 29, 2019
31aa14c
changes to binomial loss
avinashbarnwal Aug 2, 2019
8690763
changes to overflow
avinashbarnwal Aug 3, 2019
b648a81
changes to eps
avinashbarnwal Aug 6, 2019
725d309
changes to code refactoring
avinashbarnwal Aug 6, 2019
c026432
changes to code refactoring
avinashbarnwal Aug 6, 2019
ff1679f
changes to code refactoring
avinashbarnwal Aug 6, 2019
1c20533
Re-factor survival analysis
avinashbarnwal Aug 6, 2019
aac363a
Remove aft namespace
hcho3 Aug 7, 2019
a60483f
Move function bodies out of AFTNormal and AFTLogistic, to reduce clutter
hcho3 Aug 7, 2019
7090fe5
Move function bodies out of AFTLoss, to reduce clutter
hcho3 Aug 7, 2019
f774635
Use smart pointer to store AFTDistribution and AFTLoss
hcho3 Aug 7, 2019
3b80118
Rename AFTNoiseDistribution enum to AFTDistributionType for clarity
hcho3 Aug 7, 2019
1305472
Add AFTDistribution::Create() method for convenience
hcho3 Aug 7, 2019
7a0c097
changes to extreme distribution
avinashbarnwal Aug 8, 2019
5620ed9
changes to extreme distribution
avinashbarnwal Aug 8, 2019
2d57a5a
changes to extreme
avinashbarnwal Aug 8, 2019
4ec680f
changes to extreme distribution
avinashbarnwal Aug 8, 2019
72eae3a
changes to left censored
avinashbarnwal Aug 9, 2019
0623be4
deleted cout
avinashbarnwal Aug 9, 2019
b894c41
changes to x,mu and sd and code refactoring
avinashbarnwal Aug 11, 2019
9372ad6
changes to print
avinashbarnwal Aug 11, 2019
31755cf
changes to hessian formula in censored and uncensored
avinashbarnwal Aug 13, 2019
9d689f2
changes to variable names and pow
avinashbarnwal Aug 13, 2019
7ae3c98
changes to Logistic Pdf
avinashbarnwal Aug 13, 2019
69c95c5
changes to parameter
avinashbarnwal Aug 15, 2019
96d8360
Expose lower and upper bound labels to R package
hcho3 Aug 18, 2019
29550db
Use example weights; normalize log likelihood metric
hcho3 Aug 19, 2019
24635bc
changes to CHECK
avinashbarnwal Aug 20, 2019
01cb150
changes to logistic hessian to standard formula
avinashbarnwal Aug 25, 2019
4002399
changes to logistic formula
avinashbarnwal Aug 26, 2019
cf272f5
Comply with coding style guideline
hcho3 Sep 19, 2019
ab1271d
Revert back Rabit submodule
hcho3 Sep 19, 2019
23890f4
Revert dmlc-core submodule
hcho3 Sep 19, 2019
2b5654a
Comply with coding style guideline (clang-tidy)
hcho3 Sep 19, 2019
d59bd92
Fix an error in AFTLoss::Gradient()
hcho3 Sep 19, 2019
f1ee3a7
Add missing files to amalgamation
hcho3 Sep 20, 2019
9ede4b4
Address @RAMitchell's comment: minimize future change in MetaInfo int…
hcho3 Oct 3, 2019
56a0530
Fix lint
hcho3 Oct 3, 2019
bfa838c
Fix compilation error on 32-bit target, when size_t == bst_uint
hcho3 Oct 4, 2019
a668d7a
Allocate sufficient memory to hold extra label info
hcho3 Dec 2, 2019
7460dca
Use OpenMP to speed up
hcho3 Dec 19, 2019
5935a5a
Merge remote-tracking branch 'upstream/master' into survival_analysis1
hcho3 Feb 13, 2020
23bb304
Fix compilation on Windows
hcho3 Feb 13, 2020
6f31c09
Address reviewer's feedback
hcho3 Mar 15, 2020
fe8ff28
Add unit tests for probability distributions
hcho3 Mar 15, 2020
c2abd04
Make Metric subclass of Configurable
hcho3 Mar 15, 2020
f3d16ef
Address reviewer's feedback: Configure() AFT metric
hcho3 Mar 15, 2020
1914e46
Add a dummy test for AFT metric configuration
hcho3 Mar 15, 2020
346241a
Complete AFT configuration test; remove debugging print
hcho3 Mar 17, 2020
c9dd101
Rename AFT parameters
hcho3 Mar 17, 2020
dca593d
Clarify test comment
hcho3 Mar 17, 2020
d2c5c56
Add a dummy test for AFT loss for uncensored case
hcho3 Mar 17, 2020
e33fab1
Fix a bug in AFT loss for uncensored labels
hcho3 Mar 17, 2020
62d7e81
Complete unit test for AFT loss metric
hcho3 Mar 19, 2020
7afed61
Simplify unit tests for AFT metric
hcho3 Mar 19, 2020
fbaed76
Add unit test to verify aggregate output from AFT metric
hcho3 Mar 19, 2020
a5fd368
Use EXPECT_* instead of ASSERT_*, so that we run all unit tests
hcho3 Mar 19, 2020
1f6e9f7
Use aft_loss_param when serializing AFTObj
hcho3 Mar 19, 2020
9155b97
Add unit tests for AFT Objective
hcho3 Mar 20, 2020
b20d81b
Fix OpenMP bug; clarify semantics for shared variables used in OpenMP…
hcho3 Mar 20, 2020
330b92b
Add comments
hcho3 Mar 20, 2020
85938c5
Remove AFT prefix from probability distribution; put probability dist…
hcho3 Mar 20, 2020
156a5d2
Add comments
hcho3 Mar 20, 2020
e1b4d99
Define kPI and kEulerMascheroni in probability_distribution.h
hcho3 Mar 20, 2020
8d471b1
Add probability_distribution.cc to amalgamation
hcho3 Mar 20, 2020
58bbb12
Remove unnecessary diff
hcho3 Mar 20, 2020
67d7f96
Address reviewer's feedback: define variables where they're used
hcho3 Mar 20, 2020
cb768cf
Merge remote-tracking branch 'upstream/master' into survival_analysis1
hcho3 Mar 20, 2020
9236393
Eliminate all INFs and NANs from AFT loss and gradient
hcho3 Mar 20, 2020
4fbbff5
Add demo
hcho3 Mar 20, 2020
4df2656
Add tutorial
hcho3 Mar 21, 2020
895fd07
Fix lint
hcho3 Mar 21, 2020
5659d3d
Use 'survival:aft' to be consistent with 'survival:cox'
hcho3 Mar 25, 2020
e2bfcec
Move sample data to demo/data
hcho3 Mar 25, 2020
aea6496
Add visual demo with 1D toy data
hcho3 Mar 25, 2020
7a9c129
Add Python tests
hcho3 Mar 25, 2020
12 changes: 12 additions & 0 deletions R-package/R/xgb.DMatrix.R
@@ -243,6 +243,18 @@ setinfo.xgb.DMatrix <- function(object, name, info, ...) {
    .Call(XGDMatrixSetInfo_R, object, name, as.numeric(info))
    return(TRUE)
  }
  if (name == "label_lower_bound") {
    if (length(info) != nrow(object))
      stop("The length of lower-bound labels must equal to the number of rows in the input data")
    .Call(XGDMatrixSetInfo_R, object, name, as.numeric(info))
    return(TRUE)
  }
  if (name == "label_upper_bound") {
    if (length(info) != nrow(object))
      stop("The length of upper-bound labels must equal to the number of rows in the input data")
    .Call(XGDMatrixSetInfo_R, object, name, as.numeric(info))
    return(TRUE)
  }
  if (name == "weight") {
    if (length(info) != nrow(object))
      stop("The length of weights must equal to the number of rows in the input data")
6 changes: 5 additions & 1 deletion amalgamation/xgboost-all0.cc
@@ -14,13 +14,15 @@
#include "../src/metric/elementwise_metric.cc"
#include "../src/metric/multiclass_metric.cc"
#include "../src/metric/rank_metric.cc"
#include "../src/metric/survival_metric.cc"

// objectives
#include "../src/objective/objective.cc"
#include "../src/objective/regression_obj.cc"
#include "../src/objective/multiclass_obj.cc"
#include "../src/objective/rank_obj.cc"
#include "../src/objective/hinge.cc"
#include "../src/objective/aft_obj.cc"

// gbms
#include "../src/gbm/gbm.cc"
@@ -44,7 +46,7 @@
#include "../src/data/sparse_page_dmatrix.cc"
#endif

-// tress
+// trees
#include "../src/tree/param.cc"
#include "../src/tree/split_evaluator.cc"
#include "../src/tree/tree_model.cc"
@@ -72,6 +74,8 @@
#include "../src/common/hist_util.cc"
#include "../src/common/json.cc"
#include "../src/common/io.cc"
#include "../src/common/survival_util.cc"
#include "../src/common/probability_distribution.cc"
#include "../src/common/version.cc"

// c_api
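(Context: xgboost-all0.cc is the amalgamation build, which compiles the entire library as a single translation unit; newly added .cc files must be registered here or the amalgamation build breaks — see the "Add missing files to amalgamation" commit above.)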
54 changes: 54 additions & 0 deletions demo/aft_survival/aft_survival_demo.py
@@ -0,0 +1,54 @@
"""
Demo for survival analysis (regression) using Accelerated Failure Time (AFT) model
"""
from sklearn.model_selection import ShuffleSplit
import pandas as pd
import numpy as np
import xgboost as xgb

# The Veterans' Administration Lung Cancer Trial
# The Statistical Analysis of Failure Time Data by Kalbfleisch J. and Prentice R (1980)
df = pd.read_csv('../data/veterans_lung_cancer.csv')
print('Training data:')
print(df)

# Split features and labels
y_lower_bound = df['Survival_label_lower_bound']
y_upper_bound = df['Survival_label_upper_bound']
X = df.drop(['Survival_label_lower_bound', 'Survival_label_upper_bound'], axis=1)

# Split data into training and validation sets
rs = ShuffleSplit(n_splits=2, test_size=.7, random_state=0)
train_index, valid_index = next(rs.split(X))
dtrain = xgb.DMatrix(X.values[train_index, :])
dtrain.set_float_info('label_lower_bound', y_lower_bound[train_index])
dtrain.set_float_info('label_upper_bound', y_upper_bound[train_index])
dvalid = xgb.DMatrix(X.values[valid_index, :])
dvalid.set_float_info('label_lower_bound', y_lower_bound[valid_index])
dvalid.set_float_info('label_upper_bound', y_upper_bound[valid_index])

# Train gradient boosted trees using AFT loss and metric
params = {'verbosity': 0,
          'objective': 'survival:aft',
          'eval_metric': 'aft-nloglik',
          'tree_method': 'hist',
          'learning_rate': 0.05,
          'aft_loss_distribution': 'normal',
          'aft_loss_distribution_scale': 1.20,
          'max_depth': 6,
          'lambda': 0.01,
          'alpha': 0.02}
bst = xgb.train(params, dtrain, num_boost_round=10000,
                evals=[(dtrain, 'train'), (dvalid, 'valid')],
                early_stopping_rounds=50)

# Run prediction on the validation set
df = pd.DataFrame({'Label (lower bound)': y_lower_bound[valid_index],
                   'Label (upper bound)': y_upper_bound[valid_index],
                   'Predicted label': bst.predict(dvalid)})
print(df)
# Show only data points with right-censored labels
print(df[np.isinf(df['Label (upper bound)'])])

# Save trained model
bst.save_model('aft_model.json')
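
As a follow-up to the demo above (not part of the PR): the saved JSON model can be reloaded for later inference. A minimal sketch, reusing the demo's X and valid_index variables and assuming a compatible XGBoost version; bst2 is a hypothetical name:

# Reload the saved model and predict again
bst2 = xgb.Booster(model_file='aft_model.json')
# For survival:aft, predict() returns survival-time estimates on the original (not log) scale
print(bst2.predict(xgb.DMatrix(X.values[valid_index, :])))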
78 changes: 78 additions & 0 deletions demo/aft_survival/aft_survival_demo_with_optuna.py
@@ -0,0 +1,78 @@
"""
Demo for survival analysis (regression) using Accelerated Failure Time (AFT) model, using Optuna
to tune hyperparameters
"""
from sklearn.model_selection import ShuffleSplit
import pandas as pd
import numpy as np
import xgboost as xgb
import optuna

# The Veterans' Administration Lung Cancer Trial
# The Statistical Analysis of Failure Time Data by Kalbfleisch J. and Prentice R (1980)
df = pd.read_csv('../data/veterans_lung_cancer.csv')
print('Training data:')
print(df)

# Split features and labels
y_lower_bound = df['Survival_label_lower_bound']
y_upper_bound = df['Survival_label_upper_bound']
X = df.drop(['Survival_label_lower_bound', 'Survival_label_upper_bound'], axis=1)

# Split data into training and validation sets
rs = ShuffleSplit(n_splits=2, test_size=.7, random_state=0)
train_index, valid_index = next(rs.split(X))
dtrain = xgb.DMatrix(X.values[train_index, :])
dtrain.set_float_info('label_lower_bound', y_lower_bound[train_index])
dtrain.set_float_info('label_upper_bound', y_upper_bound[train_index])
dvalid = xgb.DMatrix(X.values[valid_index, :])
dvalid.set_float_info('label_lower_bound', y_lower_bound[valid_index])
dvalid.set_float_info('label_upper_bound', y_upper_bound[valid_index])

# Define hyperparameter search space
base_params = {'verbosity': 0,
               'objective': 'survival:aft',
               'eval_metric': 'aft-nloglik',
               'tree_method': 'hist'}  # Hyperparameters common to all trials

def objective(trial):
    params = {'learning_rate': trial.suggest_loguniform('learning_rate', 0.01, 1.0),
              'aft_loss_distribution': trial.suggest_categorical('aft_loss_distribution',
                                                                 ['normal', 'logistic', 'extreme']),
              'aft_loss_distribution_scale': trial.suggest_loguniform('aft_loss_distribution_scale', 0.1, 10.0),
              'max_depth': trial.suggest_int('max_depth', 3, 8),
              'lambda': trial.suggest_loguniform('lambda', 1e-8, 1.0),
              'alpha': trial.suggest_loguniform('alpha', 1e-8, 1.0)}  # Search space
    params.update(base_params)
    pruning_callback = optuna.integration.XGBoostPruningCallback(trial, 'valid-aft-nloglik')
    bst = xgb.train(params, dtrain, num_boost_round=10000,
                    evals=[(dtrain, 'train'), (dvalid, 'valid')],
                    early_stopping_rounds=50, verbose_eval=False, callbacks=[pruning_callback])
    if bst.best_iteration >= 25:
        return bst.best_score
    else:
        return np.inf  # Reject models with < 25 trees

# Run hyperparameter search
study = optuna.create_study(direction='minimize')
study.optimize(objective, n_trials=200)
print('Completed hyperparameter tuning with best aft-nloglik = {}.'.format(study.best_trial.value))
params = {}
params.update(base_params)
params.update(study.best_trial.params)

# Re-run training with the best hyperparameter combination
print('Re-running the best trial... params = {}'.format(params))
bst = xgb.train(params, dtrain, num_boost_round=10000,
                evals=[(dtrain, 'train'), (dvalid, 'valid')],
                early_stopping_rounds=50)

# Run prediction on the validation set
df = pd.DataFrame({'Label (lower bound)': y_lower_bound[valid_index],
                   'Label (upper bound)': y_upper_bound[valid_index],
                   'Predicted label': bst.predict(dvalid)})
print(df)
# Show only data points with right-censored labels
print(df[np.isinf(df['Label (upper bound)'])])

# Save trained model
bst.save_model('aft_best_model.json')
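
A note on the design: the XGBoostPruningCallback reports valid-aft-nloglik to Optuna after each boosting round so unpromising trials stop early. To make the pruning behavior explicit rather than relying on the default, a sketch (MedianPruner is Optuna's default pruner; the n_warmup_steps value here is an illustrative assumption):

# Optional: configure the pruner explicitly instead of relying on the default
study = optuna.create_study(direction='minimize',
                            pruner=optuna.pruners.MedianPruner(n_warmup_steps=10))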
97 changes: 97 additions & 0 deletions demo/aft_survival/aft_survival_viz_demo.py
@@ -0,0 +1,97 @@
"""
Visual demo for survival analysis (regression) with Accelerated Failure Time (AFT) model.

This demo uses 1D toy data and visualizes how XGBoost fits a tree ensemble. The ensemble model
starts out as a flat line and evolves into a step function in order to account for all ranged
labels.
"""
import numpy as np
import xgboost as xgb
import matplotlib.pyplot as plt

plt.rcParams.update({'font.size': 13})

# Function to visualize censored labels
def plot_censored_labels(X, y_lower, y_upper):
    def replace_inf(x, target_value):
        x = x.copy()  # work on a copy so the caller's array is not mutated in place
        x[np.isinf(x)] = target_value
        return x
    plt.plot(X, y_lower, 'o', label='y_lower', color='blue')
    plt.plot(X, y_upper, 'o', label='y_upper', color='fuchsia')
    plt.vlines(X, ymin=replace_inf(y_lower, 0.01), ymax=replace_inf(y_upper, 1000),
               label='Range for y', color='gray')

# Toy data
X = np.array([1, 2, 3, 4, 5]).reshape((-1, 1))
INF = np.inf
y_lower = np.array([ 10, 15, -INF, 30, 100])
y_upper = np.array([INF, INF, 20, 50, INF])

# Visualize toy data
plt.figure(figsize=(5, 4))
plot_censored_labels(X, y_lower, y_upper)
plt.ylim((6, 200))
plt.legend(loc='lower right')
plt.title('Toy data')
plt.xlabel('Input feature')
plt.ylabel('Label')
plt.yscale('log')
plt.tight_layout()
plt.show(block=True)

# Will be used to visualize XGBoost model
grid_pts = np.linspace(0.8, 5.2, 1000).reshape((-1, 1))

# Train AFT model using XGBoost
dmat = xgb.DMatrix(X)
dmat.set_float_info('label_lower_bound', y_lower)
dmat.set_float_info('label_upper_bound', y_upper)
params = {'max_depth': 3, 'objective':'survival:aft', 'min_child_weight': 0}

accuracy_history = []
def plot_intermediate_model_callback(env):
    """Custom callback to plot intermediate models"""
    # Compute y_pred = prediction using the intermediate model, at current boosting iteration
    y_pred = env.model.predict(dmat)
    # "Accuracy" = the percentage of data points whose ranged label (y_lower, y_upper) includes
    # the corresponding predicted label (y_pred)
    acc = np.sum(np.logical_and(y_pred >= y_lower, y_pred <= y_upper)) / len(X) * 100
    accuracy_history.append(acc)

    # Plot ranged labels as well as predictions by the model
    plt.subplot(5, 3, env.iteration + 1)
    plot_censored_labels(X, y_lower, y_upper)
    y_pred_grid_pts = env.model.predict(xgb.DMatrix(grid_pts))
    plt.plot(grid_pts, y_pred_grid_pts, 'r-', label='XGBoost AFT model', linewidth=4)
    plt.title('Iteration {}'.format(env.iteration), x=0.5, y=0.8)
    plt.xlim((0.8, 5.2))
    plt.ylim((1 if np.min(y_pred) < 6 else 6, 200))
    plt.yscale('log')

res = {}
plt.figure(figsize=(12,13))
bst = xgb.train(params, dmat, 15, [(dmat, 'train')], evals_result=res,
                callbacks=[plot_intermediate_model_callback])
plt.tight_layout()
plt.legend(loc='lower center', ncol=4,
           bbox_to_anchor=(0.5, 0),
           bbox_transform=plt.gcf().transFigure)
plt.tight_layout()

# Plot negative log likelihood over boosting iterations
plt.figure(figsize=(8,3))
plt.subplot(1, 2, 1)
plt.plot(res['train']['aft-nloglik'], 'b-o', label='aft-nloglik')
plt.xlabel('# Boosting Iterations')
plt.legend(loc='best')

# Plot "accuracy" over boosting iterations
# "Accuracy" = the number of data points whose ranged label (y_lower, y_upper) includes
# the corresponding predicted label (y_pred)
plt.subplot(1, 2, 2)
plt.plot(accuracy_history, 'r-o', label='Accuracy (%)')
plt.xlabel('# Boosting Iterations')
plt.legend(loc='best')
plt.tight_layout()

plt.show()
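
As a closing check (an illustrative addition mirroring the callback's metric, not part of the PR; final_pred and coverage are hypothetical names), the final model's label coverage can be computed directly:

# Percentage of ranged labels covered by the final model's predictions
final_pred = bst.predict(dmat)
coverage = np.mean(np.logical_and(final_pred >= y_lower, final_pred <= y_upper)) * 100
print('Final label coverage: {:.1f}%'.format(coverage))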