diff --git a/R-package/R/xgb.DMatrix.R b/R-package/R/xgb.DMatrix.R index b96aeba088dd..4c125832d2ea 100644 --- a/R-package/R/xgb.DMatrix.R +++ b/R-package/R/xgb.DMatrix.R @@ -243,6 +243,18 @@ setinfo.xgb.DMatrix <- function(object, name, info, ...) { .Call(XGDMatrixSetInfo_R, object, name, as.numeric(info)) return(TRUE) } + if (name == "label_lower_bound") { + if (length(info) != nrow(object)) + stop("The length of lower-bound labels must equal to the number of rows in the input data") + .Call(XGDMatrixSetInfo_R, object, name, as.numeric(info)) + return(TRUE) + } + if (name == "label_upper_bound") { + if (length(info) != nrow(object)) + stop("The length of upper-bound labels must equal to the number of rows in the input data") + .Call(XGDMatrixSetInfo_R, object, name, as.numeric(info)) + return(TRUE) + } if (name == "weight") { if (length(info) != nrow(object)) stop("The length of weights must equal to the number of rows in the input data") diff --git a/amalgamation/xgboost-all0.cc b/amalgamation/xgboost-all0.cc index f3885ea9ffbe..5f7e75fd9d7f 100644 --- a/amalgamation/xgboost-all0.cc +++ b/amalgamation/xgboost-all0.cc @@ -14,6 +14,7 @@ #include "../src/metric/elementwise_metric.cc" #include "../src/metric/multiclass_metric.cc" #include "../src/metric/rank_metric.cc" +#include "../src/metric/survival_metric.cc" // objectives #include "../src/objective/objective.cc" @@ -21,6 +22,7 @@ #include "../src/objective/multiclass_obj.cc" #include "../src/objective/rank_obj.cc" #include "../src/objective/hinge.cc" +#include "../src/objective/aft_obj.cc" // gbms #include "../src/gbm/gbm.cc" @@ -44,7 +46,7 @@ #include "../src/data/sparse_page_dmatrix.cc" #endif -// tress +// trees #include "../src/tree/param.cc" #include "../src/tree/split_evaluator.cc" #include "../src/tree/tree_model.cc" @@ -72,6 +74,8 @@ #include "../src/common/hist_util.cc" #include "../src/common/json.cc" #include "../src/common/io.cc" +#include "../src/common/survival_util.cc" +#include "../src/common/probability_distribution.cc" #include "../src/common/version.cc" // c_api diff --git a/demo/aft_survival/aft_survival_demo.py b/demo/aft_survival/aft_survival_demo.py new file mode 100644 index 000000000000..6b8181cf1060 --- /dev/null +++ b/demo/aft_survival/aft_survival_demo.py @@ -0,0 +1,54 @@ +""" +Demo for survival analysis (regression) using Accelerated Failure Time (AFT) model +""" +from sklearn.model_selection import ShuffleSplit +import pandas as pd +import numpy as np +import xgboost as xgb + +# The Veterans' Administration Lung Cancer Trial +# The Statistical Analysis of Failure Time Data by Kalbfleisch J. and Prentice R (1980) +df = pd.read_csv('../data/veterans_lung_cancer.csv') +print('Training data:') +print(df) + +# Split features and labels +y_lower_bound = df['Survival_label_lower_bound'] +y_upper_bound = df['Survival_label_upper_bound'] +X = df.drop(['Survival_label_lower_bound', 'Survival_label_upper_bound'], axis=1) + +# Split data into training and validation sets +rs = ShuffleSplit(n_splits=2, test_size=.7, random_state=0) +train_index, valid_index = next(rs.split(X)) +dtrain = xgb.DMatrix(X.values[train_index, :]) +dtrain.set_float_info('label_lower_bound', y_lower_bound[train_index]) +dtrain.set_float_info('label_upper_bound', y_upper_bound[train_index]) +dvalid = xgb.DMatrix(X.values[valid_index, :]) +dvalid.set_float_info('label_lower_bound', y_lower_bound[valid_index]) +dvalid.set_float_info('label_upper_bound', y_upper_bound[valid_index]) + +# Train gradient boosted trees using AFT loss and metric +params = {'verbosity': 0, + 'objective': 'survival:aft', + 'eval_metric': 'aft-nloglik', + 'tree_method': 'hist', + 'learning_rate': 0.05, + 'aft_loss_distribution': 'normal', + 'aft_loss_distribution_scale': 1.20, + 'max_depth': 6, + 'lambda': 0.01, + 'alpha': 0.02} +bst = xgb.train(params, dtrain, num_boost_round=10000, + evals=[(dtrain, 'train'), (dvalid, 'valid')], + early_stopping_rounds=50) + +# Run prediction on the validation set +df = pd.DataFrame({'Label (lower bound)': y_lower_bound[valid_index], + 'Label (upper bound)': y_upper_bound[valid_index], + 'Predicted label': bst.predict(dvalid)}) +print(df) +# Show only data points with right-censored labels +print(df[np.isinf(df['Label (upper bound)'])]) + +# Save trained model +bst.save_model('aft_model.json') \ No newline at end of file diff --git a/demo/aft_survival/aft_survival_demo_with_optuna.py b/demo/aft_survival/aft_survival_demo_with_optuna.py new file mode 100644 index 000000000000..998afc4816b6 --- /dev/null +++ b/demo/aft_survival/aft_survival_demo_with_optuna.py @@ -0,0 +1,78 @@ +""" +Demo for survival analysis (regression) using Accelerated Failure Time (AFT) model, using Optuna +to tune hyperparameters +""" +from sklearn.model_selection import ShuffleSplit +import pandas as pd +import numpy as np +import xgboost as xgb +import optuna + +# The Veterans' Administration Lung Cancer Trial +# The Statistical Analysis of Failure Time Data by Kalbfleisch J. and Prentice R (1980) +df = pd.read_csv('../data/veterans_lung_cancer.csv') +print('Training data:') +print(df) + +# Split features and labels +y_lower_bound = df['Survival_label_lower_bound'] +y_upper_bound = df['Survival_label_upper_bound'] +X = df.drop(['Survival_label_lower_bound', 'Survival_label_upper_bound'], axis=1) + +# Split data into training and validation sets +rs = ShuffleSplit(n_splits=2, test_size=.7, random_state=0) +train_index, valid_index = next(rs.split(X)) +dtrain = xgb.DMatrix(X.values[train_index, :]) +dtrain.set_float_info('label_lower_bound', y_lower_bound[train_index]) +dtrain.set_float_info('label_upper_bound', y_upper_bound[train_index]) +dvalid = xgb.DMatrix(X.values[valid_index, :]) +dvalid.set_float_info('label_lower_bound', y_lower_bound[valid_index]) +dvalid.set_float_info('label_upper_bound', y_upper_bound[valid_index]) + +# Define hyperparameter search space +base_params = {'verbosity': 0, + 'objective': 'survival:aft', + 'eval_metric': 'aft-nloglik', + 'tree_method': 'hist'} # Hyperparameters common to all trials +def objective(trial): + params = {'learning_rate': trial.suggest_loguniform('learning_rate', 0.01, 1.0), + 'aft_loss_distribution': trial.suggest_categorical('aft_loss_distribution', + ['normal', 'logistic', 'extreme']), + 'aft_loss_distribution_scale': trial.suggest_loguniform('aft_loss_distribution_scale', 0.1, 10.0), + 'max_depth': trial.suggest_int('max_depth', 3, 8), + 'lambda': trial.suggest_loguniform('lambda', 1e-8, 1.0), + 'alpha': trial.suggest_loguniform('alpha', 1e-8, 1.0)} # Search space + params.update(base_params) + pruning_callback = optuna.integration.XGBoostPruningCallback(trial, 'valid-aft-nloglik') + bst = xgb.train(params, dtrain, num_boost_round=10000, + evals=[(dtrain, 'train'), (dvalid, 'valid')], + early_stopping_rounds=50, verbose_eval=False, callbacks=[pruning_callback]) + if bst.best_iteration >= 25: + return bst.best_score + else: + return np.inf # Reject models with < 25 trees + +# Run hyperparameter search +study = optuna.create_study(direction='minimize') +study.optimize(objective, n_trials=200) +print('Completed hyperparameter tuning with best aft-nloglik = {}.'.format(study.best_trial.value)) +params = {} +params.update(base_params) +params.update(study.best_trial.params) + +# Re-run training with the best hyperparameter combination +print('Re-running the best trial... params = {}'.format(params)) +bst = xgb.train(params, dtrain, num_boost_round=10000, + evals=[(dtrain, 'train'), (dvalid, 'valid')], + early_stopping_rounds=50) + +# Run prediction on the validation set +df = pd.DataFrame({'Label (lower bound)': y_lower_bound[valid_index], + 'Label (upper bound)': y_upper_bound[valid_index], + 'Predicted label': bst.predict(dvalid)}) +print(df) +# Show only data points with right-censored labels +print(df[np.isinf(df['Label (upper bound)'])]) + +# Save trained model +bst.save_model('aft_best_model.json') \ No newline at end of file diff --git a/demo/aft_survival/aft_survival_viz_demo.py b/demo/aft_survival/aft_survival_viz_demo.py new file mode 100644 index 000000000000..fe622f9e23ec --- /dev/null +++ b/demo/aft_survival/aft_survival_viz_demo.py @@ -0,0 +1,97 @@ +""" +Visual demo for survival analysis (regression) with Accelerated Failure Time (AFT) model. + +This demo uses 1D toy data and visualizes how XGBoost fits a tree ensemble. The ensemble model +starts out as a flat line and evolves into a step function in order to account for all ranged +labels. +""" +import numpy as np +import xgboost as xgb +import matplotlib.pyplot as plt + +plt.rcParams.update({'font.size': 13}) + +# Function to visualize censored labels +def plot_censored_labels(X, y_lower, y_upper): + def replace_inf(x, target_value): + x[np.isinf(x)] = target_value + return x + plt.plot(X, y_lower, 'o', label='y_lower', color='blue') + plt.plot(X, y_upper, 'o', label='y_upper', color='fuchsia') + plt.vlines(X, ymin=replace_inf(y_lower, 0.01), ymax=replace_inf(y_upper, 1000), + label='Range for y', color='gray') + +# Toy data +X = np.array([1, 2, 3, 4, 5]).reshape((-1, 1)) +INF = np.inf +y_lower = np.array([ 10, 15, -INF, 30, 100]) +y_upper = np.array([INF, INF, 20, 50, INF]) + +# Visualize toy data +plt.figure(figsize=(5, 4)) +plot_censored_labels(X, y_lower, y_upper) +plt.ylim((6, 200)) +plt.legend(loc='lower right') +plt.title('Toy data') +plt.xlabel('Input feature') +plt.ylabel('Label') +plt.yscale('log') +plt.tight_layout() +plt.show(block=True) + +# Will be used to visualize XGBoost model +grid_pts = np.linspace(0.8, 5.2, 1000).reshape((-1, 1)) + +# Train AFT model using XGBoost +dmat = xgb.DMatrix(X) +dmat.set_float_info('label_lower_bound', y_lower) +dmat.set_float_info('label_upper_bound', y_upper) +params = {'max_depth': 3, 'objective':'survival:aft', 'min_child_weight': 0} + +accuracy_history = [] +def plot_intermediate_model_callback(env): + """Custom callback to plot intermediate models""" + # Compute y_pred = prediction using the intermediate model, at current boosting iteration + y_pred = env.model.predict(dmat) + # "Accuracy" = the number of data points whose ranged label (y_lower, y_upper) includes + # the corresponding predicted label (y_pred) + acc = np.sum(np.logical_and(y_pred >= y_lower, y_pred <= y_upper)/len(X) * 100) + accuracy_history.append(acc) + + # Plot ranged labels as well as predictions by the model + plt.subplot(5, 3, env.iteration + 1) + plot_censored_labels(X, y_lower, y_upper) + y_pred_grid_pts = env.model.predict(xgb.DMatrix(grid_pts)) + plt.plot(grid_pts, y_pred_grid_pts, 'r-', label='XGBoost AFT model', linewidth=4) + plt.title('Iteration {}'.format(env.iteration), x=0.5, y=0.8) + plt.xlim((0.8, 5.2)) + plt.ylim((1 if np.min(y_pred) < 6 else 6, 200)) + plt.yscale('log') + +res = {} +plt.figure(figsize=(12,13)) +bst = xgb.train(params, dmat, 15, [(dmat, 'train')], evals_result=res, + callbacks=[plot_intermediate_model_callback]) +plt.tight_layout() +plt.legend(loc='lower center', ncol=4, + bbox_to_anchor=(0.5, 0), + bbox_transform=plt.gcf().transFigure) +plt.tight_layout() + +# Plot negative log likelihood over boosting iterations +plt.figure(figsize=(8,3)) +plt.subplot(1, 2, 1) +plt.plot(res['train']['aft-nloglik'], 'b-o', label='aft-nloglik') +plt.xlabel('# Boosting Iterations') +plt.legend(loc='best') + +# Plot "accuracy" over boosting iterations +# "Accuracy" = the number of data points whose ranged label (y_lower, y_upper) includes +# the corresponding predicted label (y_pred) +plt.subplot(1, 2, 2) +plt.plot(accuracy_history, 'r-o', label='Accuracy (%)') +plt.xlabel('# Boosting Iterations') +plt.legend(loc='best') +plt.tight_layout() + +plt.show() diff --git a/demo/data/veterans_lung_cancer.csv b/demo/data/veterans_lung_cancer.csv new file mode 100644 index 000000000000..24466b579157 --- /dev/null +++ b/demo/data/veterans_lung_cancer.csv @@ -0,0 +1,138 @@ +Survival_label_lower_bound,Survival_label_upper_bound,Age_in_years,Karnofsky_score,Months_from_Diagnosis,Celltype=adeno,Celltype=large,Celltype=smallcell,Celltype=squamous,Prior_therapy=no,Prior_therapy=yes,Treatment=standard,Treatment=test +72.0,72.0,69.0,60.0,7.0,0,0,0,1,1,0,1,0 +411.0,411.0,64.0,70.0,5.0,0,0,0,1,0,1,1,0 +228.0,228.0,38.0,60.0,3.0,0,0,0,1,1,0,1,0 +126.0,126.0,63.0,60.0,9.0,0,0,0,1,0,1,1,0 +118.0,118.0,65.0,70.0,11.0,0,0,0,1,0,1,1,0 +10.0,10.0,49.0,20.0,5.0,0,0,0,1,1,0,1,0 +82.0,82.0,69.0,40.0,10.0,0,0,0,1,0,1,1,0 +110.0,110.0,68.0,80.0,29.0,0,0,0,1,1,0,1,0 +314.0,314.0,43.0,50.0,18.0,0,0,0,1,1,0,1,0 +100.0,inf,70.0,70.0,6.0,0,0,0,1,1,0,1,0 +42.0,42.0,81.0,60.0,4.0,0,0,0,1,1,0,1,0 +8.0,8.0,63.0,40.0,58.0,0,0,0,1,0,1,1,0 +144.0,144.0,63.0,30.0,4.0,0,0,0,1,1,0,1,0 +25.0,inf,52.0,80.0,9.0,0,0,0,1,0,1,1,0 +11.0,11.0,48.0,70.0,11.0,0,0,0,1,0,1,1,0 +30.0,30.0,61.0,60.0,3.0,0,0,1,0,1,0,1,0 +384.0,384.0,42.0,60.0,9.0,0,0,1,0,1,0,1,0 +4.0,4.0,35.0,40.0,2.0,0,0,1,0,1,0,1,0 +54.0,54.0,63.0,80.0,4.0,0,0,1,0,0,1,1,0 +13.0,13.0,56.0,60.0,4.0,0,0,1,0,1,0,1,0 +123.0,inf,55.0,40.0,3.0,0,0,1,0,1,0,1,0 +97.0,inf,67.0,60.0,5.0,0,0,1,0,1,0,1,0 +153.0,153.0,63.0,60.0,14.0,0,0,1,0,0,1,1,0 +59.0,59.0,65.0,30.0,2.0,0,0,1,0,1,0,1,0 +117.0,117.0,46.0,80.0,3.0,0,0,1,0,1,0,1,0 +16.0,16.0,53.0,30.0,4.0,0,0,1,0,0,1,1,0 +151.0,151.0,69.0,50.0,12.0,0,0,1,0,1,0,1,0 +22.0,22.0,68.0,60.0,4.0,0,0,1,0,1,0,1,0 +56.0,56.0,43.0,80.0,12.0,0,0,1,0,0,1,1,0 +21.0,21.0,55.0,40.0,2.0,0,0,1,0,0,1,1,0 +18.0,18.0,42.0,20.0,15.0,0,0,1,0,1,0,1,0 +139.0,139.0,64.0,80.0,2.0,0,0,1,0,1,0,1,0 +20.0,20.0,65.0,30.0,5.0,0,0,1,0,1,0,1,0 +31.0,31.0,65.0,75.0,3.0,0,0,1,0,1,0,1,0 +52.0,52.0,55.0,70.0,2.0,0,0,1,0,1,0,1,0 +287.0,287.0,66.0,60.0,25.0,0,0,1,0,0,1,1,0 +18.0,18.0,60.0,30.0,4.0,0,0,1,0,1,0,1,0 +51.0,51.0,67.0,60.0,1.0,0,0,1,0,1,0,1,0 +122.0,122.0,53.0,80.0,28.0,0,0,1,0,1,0,1,0 +27.0,27.0,62.0,60.0,8.0,0,0,1,0,1,0,1,0 +54.0,54.0,67.0,70.0,1.0,0,0,1,0,1,0,1,0 +7.0,7.0,72.0,50.0,7.0,0,0,1,0,1,0,1,0 +63.0,63.0,48.0,50.0,11.0,0,0,1,0,1,0,1,0 +392.0,392.0,68.0,40.0,4.0,0,0,1,0,1,0,1,0 +10.0,10.0,67.0,40.0,23.0,0,0,1,0,0,1,1,0 +8.0,8.0,61.0,20.0,19.0,1,0,0,0,0,1,1,0 +92.0,92.0,60.0,70.0,10.0,1,0,0,0,1,0,1,0 +35.0,35.0,62.0,40.0,6.0,1,0,0,0,1,0,1,0 +117.0,117.0,38.0,80.0,2.0,1,0,0,0,1,0,1,0 +132.0,132.0,50.0,80.0,5.0,1,0,0,0,1,0,1,0 +12.0,12.0,63.0,50.0,4.0,1,0,0,0,0,1,1,0 +162.0,162.0,64.0,80.0,5.0,1,0,0,0,1,0,1,0 +3.0,3.0,43.0,30.0,3.0,1,0,0,0,1,0,1,0 +95.0,95.0,34.0,80.0,4.0,1,0,0,0,1,0,1,0 +177.0,177.0,66.0,50.0,16.0,0,1,0,0,0,1,1,0 +162.0,162.0,62.0,80.0,5.0,0,1,0,0,1,0,1,0 +216.0,216.0,52.0,50.0,15.0,0,1,0,0,1,0,1,0 +553.0,553.0,47.0,70.0,2.0,0,1,0,0,1,0,1,0 +278.0,278.0,63.0,60.0,12.0,0,1,0,0,1,0,1,0 +12.0,12.0,68.0,40.0,12.0,0,1,0,0,0,1,1,0 +260.0,260.0,45.0,80.0,5.0,0,1,0,0,1,0,1,0 +200.0,200.0,41.0,80.0,12.0,0,1,0,0,0,1,1,0 +156.0,156.0,66.0,70.0,2.0,0,1,0,0,1,0,1,0 +182.0,inf,62.0,90.0,2.0,0,1,0,0,1,0,1,0 +143.0,143.0,60.0,90.0,8.0,0,1,0,0,1,0,1,0 +105.0,105.0,66.0,80.0,11.0,0,1,0,0,1,0,1,0 +103.0,103.0,38.0,80.0,5.0,0,1,0,0,1,0,1,0 +250.0,250.0,53.0,70.0,8.0,0,1,0,0,0,1,1,0 +100.0,100.0,37.0,60.0,13.0,0,1,0,0,0,1,1,0 +999.0,999.0,54.0,90.0,12.0,0,0,0,1,0,1,0,1 +112.0,112.0,60.0,80.0,6.0,0,0,0,1,1,0,0,1 +87.0,inf,48.0,80.0,3.0,0,0,0,1,1,0,0,1 +231.0,inf,52.0,50.0,8.0,0,0,0,1,0,1,0,1 +242.0,242.0,70.0,50.0,1.0,0,0,0,1,1,0,0,1 +991.0,991.0,50.0,70.0,7.0,0,0,0,1,0,1,0,1 +111.0,111.0,62.0,70.0,3.0,0,0,0,1,1,0,0,1 +1.0,1.0,65.0,20.0,21.0,0,0,0,1,0,1,0,1 +587.0,587.0,58.0,60.0,3.0,0,0,0,1,1,0,0,1 +389.0,389.0,62.0,90.0,2.0,0,0,0,1,1,0,0,1 +33.0,33.0,64.0,30.0,6.0,0,0,0,1,1,0,0,1 +25.0,25.0,63.0,20.0,36.0,0,0,0,1,1,0,0,1 +357.0,357.0,58.0,70.0,13.0,0,0,0,1,1,0,0,1 +467.0,467.0,64.0,90.0,2.0,0,0,0,1,1,0,0,1 +201.0,201.0,52.0,80.0,28.0,0,0,0,1,0,1,0,1 +1.0,1.0,35.0,50.0,7.0,0,0,0,1,1,0,0,1 +30.0,30.0,63.0,70.0,11.0,0,0,0,1,1,0,0,1 +44.0,44.0,70.0,60.0,13.0,0,0,0,1,0,1,0,1 +283.0,283.0,51.0,90.0,2.0,0,0,0,1,1,0,0,1 +15.0,15.0,40.0,50.0,13.0,0,0,0,1,0,1,0,1 +25.0,25.0,69.0,30.0,2.0,0,0,1,0,1,0,0,1 +103.0,inf,36.0,70.0,22.0,0,0,1,0,0,1,0,1 +21.0,21.0,71.0,20.0,4.0,0,0,1,0,1,0,0,1 +13.0,13.0,62.0,30.0,2.0,0,0,1,0,1,0,0,1 +87.0,87.0,60.0,60.0,2.0,0,0,1,0,1,0,0,1 +2.0,2.0,44.0,40.0,36.0,0,0,1,0,0,1,0,1 +20.0,20.0,54.0,30.0,9.0,0,0,1,0,0,1,0,1 +7.0,7.0,66.0,20.0,11.0,0,0,1,0,1,0,0,1 +24.0,24.0,49.0,60.0,8.0,0,0,1,0,1,0,0,1 +99.0,99.0,72.0,70.0,3.0,0,0,1,0,1,0,0,1 +8.0,8.0,68.0,80.0,2.0,0,0,1,0,1,0,0,1 +99.0,99.0,62.0,85.0,4.0,0,0,1,0,1,0,0,1 +61.0,61.0,71.0,70.0,2.0,0,0,1,0,1,0,0,1 +25.0,25.0,70.0,70.0,2.0,0,0,1,0,1,0,0,1 +95.0,95.0,61.0,70.0,1.0,0,0,1,0,1,0,0,1 +80.0,80.0,71.0,50.0,17.0,0,0,1,0,1,0,0,1 +51.0,51.0,59.0,30.0,87.0,0,0,1,0,0,1,0,1 +29.0,29.0,67.0,40.0,8.0,0,0,1,0,1,0,0,1 +24.0,24.0,60.0,40.0,2.0,1,0,0,0,1,0,0,1 +18.0,18.0,69.0,40.0,5.0,1,0,0,0,0,1,0,1 +83.0,inf,57.0,99.0,3.0,1,0,0,0,1,0,0,1 +31.0,31.0,39.0,80.0,3.0,1,0,0,0,1,0,0,1 +51.0,51.0,62.0,60.0,5.0,1,0,0,0,1,0,0,1 +90.0,90.0,50.0,60.0,22.0,1,0,0,0,0,1,0,1 +52.0,52.0,43.0,60.0,3.0,1,0,0,0,1,0,0,1 +73.0,73.0,70.0,60.0,3.0,1,0,0,0,1,0,0,1 +8.0,8.0,66.0,50.0,5.0,1,0,0,0,1,0,0,1 +36.0,36.0,61.0,70.0,8.0,1,0,0,0,1,0,0,1 +48.0,48.0,81.0,10.0,4.0,1,0,0,0,1,0,0,1 +7.0,7.0,58.0,40.0,4.0,1,0,0,0,1,0,0,1 +140.0,140.0,63.0,70.0,3.0,1,0,0,0,1,0,0,1 +186.0,186.0,60.0,90.0,3.0,1,0,0,0,1,0,0,1 +84.0,84.0,62.0,80.0,4.0,1,0,0,0,0,1,0,1 +19.0,19.0,42.0,50.0,10.0,1,0,0,0,1,0,0,1 +45.0,45.0,69.0,40.0,3.0,1,0,0,0,1,0,0,1 +80.0,80.0,63.0,40.0,4.0,1,0,0,0,1,0,0,1 +52.0,52.0,45.0,60.0,4.0,0,1,0,0,1,0,0,1 +164.0,164.0,68.0,70.0,15.0,0,1,0,0,0,1,0,1 +19.0,19.0,39.0,30.0,4.0,0,1,0,0,0,1,0,1 +53.0,53.0,66.0,60.0,12.0,0,1,0,0,1,0,0,1 +15.0,15.0,63.0,30.0,5.0,0,1,0,0,1,0,0,1 +43.0,43.0,49.0,60.0,11.0,0,1,0,0,0,1,0,1 +340.0,340.0,64.0,80.0,10.0,0,1,0,0,0,1,0,1 +133.0,133.0,65.0,75.0,1.0,0,1,0,0,1,0,0,1 +111.0,111.0,64.0,60.0,5.0,0,1,0,0,1,0,0,1 +231.0,231.0,67.0,70.0,18.0,0,1,0,0,0,1,0,1 +378.0,378.0,65.0,80.0,4.0,0,1,0,0,1,0,0,1 +49.0,49.0,37.0,30.0,3.0,0,1,0,0,1,0,0,1 diff --git a/doc/tutorials/aft_survival_analysis.rst b/doc/tutorials/aft_survival_analysis.rst new file mode 100644 index 000000000000..4f06ce54c331 --- /dev/null +++ b/doc/tutorials/aft_survival_analysis.rst @@ -0,0 +1,135 @@ +############################################### +Survival Analysis with Accelerated Failure Time +############################################### + +.. contents:: + :local: + :backlinks: none + +************************** +What is survival analysis? +************************** + +**Survival analysis (regression)** models **time to an event of interest**. Survival analysis is a special kind of regression and differs from the conventional regression task as follows: + +* The label is always positive, since you cannot wait a negative amount of time until the event occurs. +* The label may not be fully known, or **censored**, because "it takes time to measure time." + +The second bullet point is crucial and we should dwell on it more. As you may have guessed from the name, one of the earliest applications of survival analysis is to model mortality of a given population. Let's take `NCCTG Lung Cancer Dataset `_ as an example. The first 8 columns represent features and the last column, Time to death, represents the label. + +==== === === ======= ======== ========= ======== ======= ======================== +Inst Age Sex ph.ecog ph.karno pat.karno meal.cal wt.loss **Time to death (days)** +==== === === ======= ======== ========= ======== ======= ======================== +3 74 1 1 90 100 1175 N/A 306 +3 68 1 0 90 90 1225 15 455 +3 56 1 0 90 90 N/A 15 :math:`[1010, +\infty)` +5 57 1 1 90 60 1150 11 210 +1 60 1 0 100 90 N/A 0 883 +12 74 1 1 50 80 513 0 :math:`[1022, +\infty)` +7 68 2 2 70 60 384 10 310 +==== === === ======= ======== ========= ======== ======= ======================== + +Take a close look at the label for the third patient. **His label is a range, not a single number.** The third patient's label is said to be **censored**, because for some reason the experimenters could not get a complete measurement for that label. One possible scenario: the patient survived the first 1010 days and walked out of the clinic on the 1011th day, so his death was not directly observed. Another possibility: The experiment was cut short (since you cannot run it forever) before his death could be observed. In any case, his label is :math:`[1010, +\infty)`, meaning his time to death can be any number that's higher than 1010, e.g. 2000, 3000, or 10000. + +There are four kinds of censoring: + +* **Uncensored**: the label is not censored and given as a single number. +* **Right-censored**: the label is of form :math:`[a, +\infty)`, where :math:`a` is the lower bound. +* **Left-censored**: the label is of form :math:`(-\infty, b]`, where :math:`b` is the upper bound. +* **Interval-censored**: the label is of form :math:`[a, b]`, where :math:`a` and :math:`b` are the lower and upper bounds, respectively. + +Right-censoring is the most commonly used. + +****************************** +Accelerated Failure Time model +****************************** +**Accelerated Failure Time (AFT)** model is one of the most commonly used models in survival analysis. The model is of the following form: + +.. math:: + + \ln{Y} = \langle \mathbf{w}, \mathbf{x} \rangle + \sigma Z + +where + +* :math:`\mathbf{x}` is a vector in :math:`\mathbb{R}^d` representing the features. +* :math:`\mathbf{w}` is a vector consisting of :math:`d` coefficients, each corresponding to a feature. +* :math:`\langle \cdot, \cdot \rangle` is the usual dot product in :math:`\mathbb{R}^d`. +* :math:`\ln{(\cdot)}` is the natural logarithm. +* :math:`Y` and :math:`Z` are random variables. + + - :math:`Y` is the output label. + - :math:`Z` is a random variable of a known probability distribution. Common choices are the normal distribution, the logistic distribution, and the extreme distribution. Intuitively, :math:`Z` represents the "noise" that pulls the prediction :math:`\langle \mathbf{w}, \mathbf{x} \rangle` away from the true log label :math:`\ln{Y}`. + +* :math:`\sigma` is a parameter that scales the size of :math:`Z`. + +Note that this model is a generalized form of a linear regression model :math:`Y = \langle \mathbf{w}, \mathbf{x} \rangle`. In order to make AFT work with gradient boosting, we revise the model as follows: + +.. math:: + + \ln{Y} = \mathcal{T}(\mathbf{x}) + \sigma Z + +where :math:`\mathcal{T}(\mathbf{x})` represents the output from a decision tree ensemble, given input :math:`\mathbf{x}`. Since :math:`Z` is a random variable, we have a likelihood defined for the expression :math:`\ln{Y} = \mathcal{T}(\mathbf{x}) + \sigma Z`. So the goal for XGBoost is to maximize the (log) likelihood by fitting a good tree ensemble :math:`\mathbf{x}`. + +********** +How to use +********** +The first step is to express the labels in the form of a range, so that **every data point has two numbers associated with it, namely the lower and upper bounds for the label.** For uncensored labels, use a degenerate interval of form :math:`[a, a]`. + +.. |tick| unicode:: U+2714 +.. |cross| unicode:: U+2718 + +================= ==================== =================== =================== +Censoring type Interval form Lower bound finite? Upper bound finite? +================= ==================== =================== =================== +Uncensored :math:`[a, a]` |tick| |tick| +Right-censored :math:`[a, +\infty)` |tick| |cross| +Left-censored :math:`(-\infty, b]` |cross| |tick| +Interval-censored :math:`[a, b]` |tick| |tick| +================= ==================== =================== =================== + +Collect the lower bound numbers in one array (let's call it ``y_lower_bound``) and the upper bound number in another array (call it ``y_upper_bound``). The ranged labels are associated with a data matrix object via calls to :meth:`xgboost.DMatrix.set_float_info`: + +.. code-block:: python + + import numpy as np + import xgboost as xgb + + # 4-by-2 Data matrix + X = np.array([[1, -1], [-1, 1], [0, 1], [1, 0]]) + dtrain = xgb.DMatrix(X) + + # Associate ranged labels with the data matrix. + # This example shows each kind of censored labels. + # uncensored right left interval + y_lower_bound = np.array([ 2.0, 3.0, -np.inf, 4.0]) + y_upper_bound = np.array([ 2.0, +np.inf, 4.0, 5.0]) + dtrain.set_float_info('label_lower_bound', y_lower_bound) + dtrain.set_float_info('label_upper_bound', y_upper_bound) + +Now we are ready to invoke the training API: + +.. code-block:: python + + params = {'objective': 'survival:aft', + 'eval_metric': 'aft-nloglik', + 'aft_loss_distribution': 'normal', + 'aft_loss_distribution_scale': 1.20, + 'tree_method': 'hist', 'learning_rate': 0.05, 'max_depth': 2} + bst = xgb.train(params, dtrain, num_boost_round=5, + evals=[(dtrain, 'train'), (dvalid, 'valid')]) + +We set ``objective`` parameter to ``survival:aft`` and ``eval_metric`` to ``aft-nloglik``, so that the log likelihood for the AFT model would be maximized. (XGBoost will actually minimize the negative log likelihood, hence the name ``aft-nloglik``.) + +The parameter ``aft_loss_distribution`` corresponds to the distribution of the :math:`Z` term in the AFT model, and ``aft_loss_distribution_scale`` corresponds to the scaling factor :math:`\sigma`. + +Currently, you can choose from three probability distributions for ``aft_loss_distribution``: + +========================= =========================================== +``aft_loss_distribution`` Probabilty Density Function (PDF) +========================= =========================================== +``normal`` :math:`\dfrac{\exp{(-z^2/2)}}{\sqrt{2\pi}}` +``logistic`` :math:`\dfrac{e^z}{(1+e^z)^2}` +``extreme`` :math:`e^z e^{-\exp{z}}` +========================= =========================================== + +Note that it is not yet possible to set the ranged label using the scikit-learn interface (e.g. :class:`xgboost.XGBRegressor`). For now, you should use :class:`xgboost.train` with :class:`xgboost.DMatrix`. \ No newline at end of file diff --git a/include/xgboost/data.h b/include/xgboost/data.h index 0cf967cb03b1..f07dd11f1a34 100644 --- a/include/xgboost/data.h +++ b/include/xgboost/data.h @@ -39,7 +39,7 @@ enum class DataType : uint8_t { class MetaInfo { public: /*! \brief number of data fields in MetaInfo */ - static constexpr uint64_t kNumField = 7; + static constexpr uint64_t kNumField = 9; /*! \brief number of rows in the data */ uint64_t num_row_{0}; @@ -62,6 +62,14 @@ class MetaInfo { * can be used to specify initial prediction to boost from. */ HostDeviceVector base_margin_; + /*! + * \brief lower bound of the label, to be used for survival analysis (censored regression) + */ + HostDeviceVector labels_lower_bound_; + /*! + * \brief upper bound of the label, to be used for survival analysis (censored regression) + */ + HostDeviceVector labels_upper_bound_; /*! \brief default constructor */ MetaInfo() = default; diff --git a/include/xgboost/metric.h b/include/xgboost/metric.h index db32f719cbab..8ecc73c694b1 100644 --- a/include/xgboost/metric.h +++ b/include/xgboost/metric.h @@ -8,6 +8,7 @@ #define XGBOOST_METRIC_H_ #include +#include #include #include #include @@ -23,7 +24,7 @@ namespace xgboost { * \brief interface of evaluation metric used to evaluate model performance. * This has nothing to do with training, but merely act as evaluation purpose. */ -class Metric { +class Metric : public Configurable { protected: GenericParameter const* tparam_; @@ -34,6 +35,21 @@ class Metric { */ virtual void Configure( const std::vector >& args) {} + /*! + * \brief Load configuration from JSON object + * By default, metric has no internal configuration; + * override this function to maintain internal configuration + * \param in JSON object containing the configuration + */ + virtual void LoadConfig(Json const& in) {} + /*! + * \brief Save configuration to JSON object + * By default, metric has no internal configuration; + * override this function to maintain internal configuration + * \param out pointer to output JSON object + */ + virtual void SaveConfig(Json* out) const {} + /*! * \brief evaluate a specific metric * \param preds prediction diff --git a/src/c_api/c_api.cc b/src/c_api/c_api.cc index ce03f7647bfe..a93ba897078c 100644 --- a/src/c_api/c_api.cc +++ b/src/c_api/c_api.cc @@ -265,6 +265,10 @@ XGB_DLL int XGDMatrixGetFloatInfo(const DMatrixHandle handle, vec = &info.weights_.HostVector(); } else if (!std::strcmp(field, "base_margin")) { vec = &info.base_margin_.HostVector(); + } else if (!std::strcmp(field, "label_lower_bound")) { + vec = &info.labels_lower_bound_.HostVector(); + } else if (!std::strcmp(field, "label_upper_bound")) { + vec = &info.labels_upper_bound_.HostVector(); } else { LOG(FATAL) << "Unknown float field name " << field; } @@ -284,8 +288,7 @@ XGB_DLL int XGDMatrixGetUIntInfo(const DMatrixHandle handle, if (!std::strcmp(field, "group_ptr")) { vec = &info.group_ptr_; } else { - LOG(FATAL) << "Unknown comp uint field name " << field - << " with comparison " << std::strcmp(field, "group_ptr"); + LOG(FATAL) << "Unknown uint field name " << field; } *out_len = static_cast(vec->size()); *out_dptr = dmlc::BeginPtr(*vec); diff --git a/src/common/probability_distribution.cc b/src/common/probability_distribution.cc new file mode 100644 index 000000000000..51bcc495b551 --- /dev/null +++ b/src/common/probability_distribution.cc @@ -0,0 +1,107 @@ +/*! + * Copyright 2020 by Contributors + * \file probability_distribution.cc + * \brief Implementation of a few useful probability distributions + * \author Avinash Barnwal and Hyunsu Cho + */ + +#include +#include +#include "probability_distribution.h" + +namespace xgboost { +namespace common { + +ProbabilityDistribution* ProbabilityDistribution::Create(ProbabilityDistributionType dist) { + switch (dist) { + case ProbabilityDistributionType::kNormal: + return new NormalDist; + case ProbabilityDistributionType::kLogistic: + return new LogisticDist; + case ProbabilityDistributionType::kExtreme: + return new ExtremeDist; + default: + LOG(FATAL) << "Unknown distribution"; + } + return nullptr; +} + +double NormalDist::PDF(double z) { + const double pdf = std::exp(-z * z / 2) / std::sqrt(2 * probability_constant::kPI); + return pdf; +} + +double NormalDist::CDF(double z) { + const double cdf = 0.5 * (1 + std::erf(z / std::sqrt(2))); + return cdf; +} + +double NormalDist::GradPDF(double z) { + const double pdf = this->PDF(z); + const double grad = -1 * z * pdf; + return grad; +} + +double NormalDist::HessPDF(double z) { + const double pdf = this->PDF(z); + const double hess = (z * z - 1) * pdf; + return hess; +} + +double LogisticDist::PDF(double z) { + const double w = std::exp(z); + const double sqrt_denominator = 1 + w; + const double pdf + = (std::isinf(w) || std::isinf(w * w)) ? 0.0 : (w / (sqrt_denominator * sqrt_denominator)); + return pdf; +} + +double LogisticDist::CDF(double z) { + const double w = std::exp(z); + const double cdf = std::isinf(w) ? 1.0 : (w / (1 + w)); + return cdf; +} + +double LogisticDist::GradPDF(double z) { + const double pdf = this->PDF(z); + const double w = std::exp(z); + const double grad = std::isinf(w) ? 0.0 : pdf * (1 - w) / (1 + w); + return grad; +} + +double LogisticDist::HessPDF(double z) { + const double pdf = this->PDF(z); + const double w = std::exp(z); + const double hess + = (std::isinf(w) || std::isinf(w * w)) ? 0.0 : pdf * (w * w - 4 * w + 1) / ((1 + w) * (1 + w)); + return hess; +} + +double ExtremeDist::PDF(double z) { + const double w = std::exp(z); + const double pdf = std::isinf(w) ? 0.0 : (w * std::exp(-w)); + return pdf; +} + +double ExtremeDist::CDF(double z) { + const double w = std::exp(z); + const double cdf = 1 - std::exp(-w); + return cdf; +} + +double ExtremeDist::GradPDF(double z) { + const double pdf = this->PDF(z); + const double w = std::exp(z); + const double grad = std::isinf(w) ? 0.0 : ((1 - w) * pdf); + return grad; +} + +double ExtremeDist::HessPDF(double z) { + const double pdf = this->PDF(z); + const double w = std::exp(z); + const double hess = (std::isinf(w) || std::isinf(w * w)) ? 0.0 : ((w * w - 3 * w + 1) * pdf); + return hess; +} + +} // namespace common +} // namespace xgboost diff --git a/src/common/probability_distribution.h b/src/common/probability_distribution.h new file mode 100644 index 000000000000..ccf3bb96c9eb --- /dev/null +++ b/src/common/probability_distribution.h @@ -0,0 +1,94 @@ +/*! + * Copyright 2020 by Contributors + * \file probability_distribution.h + * \brief Implementation of a few useful probability distributions + * \author Avinash Barnwal and Hyunsu Cho + */ + +#ifndef XGBOOST_COMMON_PROBABILITY_DISTRIBUTION_H_ +#define XGBOOST_COMMON_PROBABILITY_DISTRIBUTION_H_ + +namespace xgboost { +namespace common { + +namespace probability_constant { + +/*! \brief Constant PI */ +const double kPI = 3.14159265358979323846; +/*! \brief The Euler-Mascheroni_constant */ +const double kEulerMascheroni = 0.57721566490153286060651209008240243104215933593992; + +} // namespace probability_constant + +/*! \brief Enum encoding possible choices of probability distribution */ +enum class ProbabilityDistributionType : int { + kNormal = 0, kLogistic = 1, kExtreme = 2 +}; + +/*! \brief Interface for a probability distribution */ +class ProbabilityDistribution { + public: + /*! + * \brief Evaluate Probability Density Function (PDF) at a particular point + * \param z point at which to evaluate PDF + * \return Value of PDF evaluated + */ + virtual double PDF(double z) = 0; + /*! + * \brief Evaluate Cumulative Distribution Function (CDF) at a particular point + * \param z point at which to evaluate CDF + * \return Value of CDF evaluated + */ + virtual double CDF(double z) = 0; + /*! + * \brief Evaluate first derivative of PDF at a particular point + * \param z point at which to evaluate first derivative of PDF + * \return Value of first derivative of PDF evaluated + */ + virtual double GradPDF(double z) = 0; + /*! + * \brief Evaluate second derivative of PDF at a particular point + * \param z point at which to evaluate second derivative of PDF + * \return Value of second derivative of PDF evaluated + */ + virtual double HessPDF(double z) = 0; + + /*! + * \brief Factory function to instantiate a new probability distribution object + * \param dist kind of probability distribution + * \return Reference to the newly created probability distribution object + */ + static ProbabilityDistribution* Create(ProbabilityDistributionType dist); +}; + +/*! \brief The (standard) normal distribution */ +class NormalDist : public ProbabilityDistribution { + public: + double PDF(double z) override; + double CDF(double z) override; + double GradPDF(double z) override; + double HessPDF(double z) override; +}; + +/*! \brief The (standard) logistic distribution */ +class LogisticDist : public ProbabilityDistribution { + public: + double PDF(double z) override; + double CDF(double z) override; + double GradPDF(double z) override; + double HessPDF(double z) override; +}; + +/*! \brief The extreme distribution, also known as the Gumbel (minimum) distribution */ +class ExtremeDist : public ProbabilityDistribution { + public: + double PDF(double z) override; + double CDF(double z) override; + double GradPDF(double z) override; + double HessPDF(double z) override; +}; + +} // namespace common +} // namespace xgboost + +#endif // XGBOOST_COMMON_PROBABILITY_DISTRIBUTION_H_ diff --git a/src/common/survival_util.cc b/src/common/survival_util.cc new file mode 100644 index 000000000000..58c5a7946af7 --- /dev/null +++ b/src/common/survival_util.cc @@ -0,0 +1,146 @@ +/*! + * Copyright 2019 by Contributors + * \file survival_util.cc + * \brief Utility functions, useful for implementing objective and metric functions for survival + * analysis + * \author Avinash Barnwal, Hyunsu Cho and Toby Hocking + */ + +#include +#include +#include +#include "survival_util.h" + +/* +- Formulas are motivated from document - + http://members.cbio.mines-paristech.fr/~thocking/survival.pdf +- Detailed Derivation of Loss/Gradient/Hessian - + https://github.com/avinashbarnwal/GSOC-2019/blob/master/doc/Accelerated_Failure_Time.pdf +*/ + +namespace xgboost { +namespace common { + +DMLC_REGISTER_PARAMETER(AFTParam); + +double AFTLoss::Loss(double y_lower, double y_upper, double y_pred, double sigma) { + const double log_y_lower = std::log(y_lower); + const double log_y_upper = std::log(y_upper); + const double eps = 1e-12; + double cost; + + if (y_lower == y_upper) { // uncensored + const double z = (log_y_lower - y_pred) / sigma; + const double pdf = dist_->PDF(z); + // Regularize the denominator with eps, to avoid INF or NAN + cost = -std::log(std::max(pdf / (sigma * y_lower), eps)); + } else { // censored; now check what type of censorship we have + double z_u, z_l, cdf_u, cdf_l; + if (std::isinf(y_upper)) { // right-censored + cdf_u = 1; + } else { // left-censored or interval-censored + z_u = (log_y_upper - y_pred) / sigma; + cdf_u = dist_->CDF(z_u); + } + if (std::isinf(y_lower)) { // left-censored + cdf_l = 0; + } else { // right-censored or interval-censored + z_l = (log_y_lower - y_pred) / sigma; + cdf_l = dist_->CDF(z_l); + } + // Regularize the denominator with eps, to avoid INF or NAN + cost = -std::log(std::max(cdf_u - cdf_l, eps)); + } + + return cost; +} + +double AFTLoss::Gradient(double y_lower, double y_upper, double y_pred, double sigma) { + const double log_y_lower = std::log(y_lower); + const double log_y_upper = std::log(y_upper); + double gradient; + const double eps = 1e-12; + + if (y_lower == y_upper) { // uncensored + const double z = (log_y_lower - y_pred) / sigma; + const double pdf = dist_->PDF(z); + const double grad_pdf = dist_->GradPDF(z); + // Regularize the denominator with eps, so that gradient doesn't get too big + gradient = grad_pdf / (sigma * std::max(pdf, eps)); + } else { // censored; now check what type of censorship we have + double z_u, z_l, pdf_u, pdf_l, cdf_u, cdf_l; + if (std::isinf(y_upper)) { // right-censored + pdf_u = 0; + cdf_u = 1; + } else { // interval-censored or left-censored + z_u = (log_y_upper - y_pred) / sigma; + pdf_u = dist_->PDF(z_u); + cdf_u = dist_->CDF(z_u); + } + if (std::isinf(y_lower)) { // left-censored + pdf_l = 0; + cdf_l = 0; + } else { // interval-censored or right-censored + z_l = (log_y_lower - y_pred) / sigma; + pdf_l = dist_->PDF(z_l); + cdf_l = dist_->CDF(z_l); + } + // Regularize the denominator with eps, so that gradient doesn't get too big + gradient = (pdf_u - pdf_l) / (sigma * std::max(cdf_u - cdf_l, eps)); + } + + return gradient; +} + +double AFTLoss::Hessian(double y_lower, double y_upper, double y_pred, double sigma) { + const double log_y_lower = std::log(y_lower); + const double log_y_upper = std::log(y_upper); + const double eps = 1e-12; + double hessian; + + if (y_lower == y_upper) { // uncensored + const double z = (log_y_lower - y_pred) / sigma; + const double pdf = dist_->PDF(z); + const double grad_pdf = dist_->GradPDF(z); + const double hess_pdf = dist_->HessPDF(z); + // Regularize the denominator with eps, so that gradient doesn't get too big + hessian = -(pdf * hess_pdf - std::pow(grad_pdf, 2)) + / (std::pow(sigma, 2) * std::pow(std::max(pdf, eps), 2)); + } else { // censored; now check what type of censorship we have + double z_u, z_l, grad_pdf_u, grad_pdf_l, pdf_u, pdf_l, cdf_u, cdf_l; + if (std::isinf(y_upper)) { // right-censored + pdf_u = 0; + cdf_u = 1; + grad_pdf_u = 0; + } else { // interval-censored or left-censored + z_u = (log_y_upper - y_pred) / sigma; + pdf_u = dist_->PDF(z_u); + cdf_u = dist_->CDF(z_u); + grad_pdf_u = dist_->GradPDF(z_u); + } + if (std::isinf(y_lower)) { // left-censored + pdf_l = 0; + cdf_l = 0; + grad_pdf_l = 0; + } else { // interval-censored or right-censored + z_l = (log_y_lower - y_pred) / sigma; + pdf_l = dist_->PDF(z_l); + cdf_l = dist_->CDF(z_l); + grad_pdf_l = dist_->GradPDF(z_l); + } + const double cdf_diff = cdf_u - cdf_l; + const double pdf_diff = pdf_u - pdf_l; + const double grad_diff = grad_pdf_u - grad_pdf_l; + // Regularize the denominator with eps, so that gradient doesn't get too big + const double cdf_diff_thresh = std::max(cdf_diff, eps); + const double numerator = -(cdf_diff * grad_diff - pdf_diff * pdf_diff); + const double sqrt_denominator = sigma * cdf_diff_thresh; + const double denominator = sqrt_denominator * sqrt_denominator; + hessian = numerator / denominator; + } + + return hessian; +} + +} // namespace common +} // namespace xgboost diff --git a/src/common/survival_util.h b/src/common/survival_util.h new file mode 100644 index 000000000000..baae99b34e00 --- /dev/null +++ b/src/common/survival_util.h @@ -0,0 +1,85 @@ +/*! + * Copyright 2019 by Contributors + * \file survival_util.h + * \brief Utility functions, useful for implementing objective and metric functions for survival + * analysis + * \author Avinash Barnwal, Hyunsu Cho and Toby Hocking + */ +#ifndef XGBOOST_COMMON_SURVIVAL_UTIL_H_ +#define XGBOOST_COMMON_SURVIVAL_UTIL_H_ + +#include +#include +#include "probability_distribution.h" + +DECLARE_FIELD_ENUM_CLASS(xgboost::common::ProbabilityDistributionType); + +namespace xgboost { +namespace common { + +/*! \brief Parameter structure for AFT loss and metric */ +struct AFTParam : public XGBoostParameter { + /*! \brief Choice of probability distribution for the noise term in AFT */ + ProbabilityDistributionType aft_loss_distribution; + /*! \brief Scaling factor to be applied to the distribution */ + float aft_loss_distribution_scale; + DMLC_DECLARE_PARAMETER(AFTParam) { + DMLC_DECLARE_FIELD(aft_loss_distribution) + .set_default(ProbabilityDistributionType::kNormal) + .add_enum("normal", ProbabilityDistributionType::kNormal) + .add_enum("logistic", ProbabilityDistributionType::kLogistic) + .add_enum("extreme", ProbabilityDistributionType::kExtreme) + .describe("Choice of distribution for the noise term in " + "Accelerated Failure Time model"); + DMLC_DECLARE_FIELD(aft_loss_distribution_scale) + .set_default(1.0f) + .describe("Scaling factor used to scale the distribution in " + "Accelerated Failure Time model"); + } +}; + +/*! \brief The AFT loss function */ +class AFTLoss { + private: + std::unique_ptr dist_; + + public: + /*! + * \brief Constructor for AFT loss function + * \param dist Choice of probability distribution for the noise term in AFT + */ + explicit AFTLoss(ProbabilityDistributionType dist) { + dist_.reset(ProbabilityDistribution::Create(dist)); + } + + public: + /*! + * \brief Compute the AFT loss + * \param y_lower Lower bound for the true label + * \param y_upper Upper bound for the true label + * \param y_pred Predicted label + * \param sigma Scaling factor to be applied to the distribution of the noise term + */ + double Loss(double y_lower, double y_upper, double y_pred, double sigma); + /*! + * \brief Compute the gradient of the AFT loss + * \param y_lower Lower bound for the true label + * \param y_upper Upper bound for the true label + * \param y_pred Predicted label + * \param sigma Scaling factor to be applied to the distribution of the noise term + */ + double Gradient(double y_lower, double y_upper, double y_pred, double sigma); + /*! + * \brief Compute the hessian of the AFT loss + * \param y_lower Lower bound for the true label + * \param y_upper Upper bound for the true label + * \param y_pred Predicted label + * \param sigma Scaling factor to be applied to the distribution of the noise term + */ + double Hessian(double y_lower, double y_upper, double y_pred, double sigma); +}; + +} // namespace common +} // namespace xgboost + +#endif // XGBOOST_COMMON_SURVIVAL_UTIL_H_ diff --git a/src/data/data.cc b/src/data/data.cc index 92b2840f4b8d..de7aa170e2ae 100644 --- a/src/data/data.cc +++ b/src/data/data.cc @@ -133,15 +133,17 @@ void MetaInfo::Clear() { /* * Binary serialization format for MetaInfo: * - * | name | type | is_scalar | num_row | num_col | value | - * |-------------+----------+-----------+---------+---------+-----------------| - * | num_row | kUInt64 | True | NA | NA | ${num_row_} | - * | num_col | kUInt64 | True | NA | NA | ${num_col_} | - * | num_nonzero | kUInt64 | True | NA | NA | ${num_nonzero_} | - * | labels | kFloat32 | False | ${size} | 1 | ${labels_} | - * | group_ptr | kUInt32 | False | ${size} | 1 | ${group_ptr_} | - * | weights | kFloat32 | False | ${size} | 1 | ${weights_} | - * | base_margin | kFloat32 | False | ${size} | 1 | ${base_margin_} | + * | name | type | is_scalar | num_row | num_col | value | + * |--------------------+----------+-----------+---------+---------+-------------------------| + * | num_row | kUInt64 | True | NA | NA | ${num_row_} | + * | num_col | kUInt64 | True | NA | NA | ${num_col_} | + * | num_nonzero | kUInt64 | True | NA | NA | ${num_nonzero_} | + * | labels | kFloat32 | False | ${size} | 1 | ${labels_} | + * | group_ptr | kUInt32 | False | ${size} | 1 | ${group_ptr_} | + * | weights | kFloat32 | False | ${size} | 1 | ${weights_} | + * | base_margin | kFloat32 | False | ${size} | 1 | ${base_margin_} | + * | labels_lower_bound | kFloat32 | False | ${size} | 1 | ${labels_lower_bound__} | + * | labels_upper_bound | kFloat32 | False | ${size} | 1 | ${labels_upper_bound__} | * * Note that the scalar fields (is_scalar=True) will have num_row and num_col missing. * Also notice the difference between the saved name and the name used in `SetInfo': @@ -164,6 +166,10 @@ void MetaInfo::SaveBinary(dmlc::Stream *fo) const { {weights_.Size(), 1}, weights_); ++field_cnt; SaveVectorField(fo, u8"base_margin", DataType::kFloat32, {base_margin_.Size(), 1}, base_margin_); ++field_cnt; + SaveVectorField(fo, u8"labels_lower_bound", DataType::kFloat32, + {labels_lower_bound_.Size(), 1}, labels_lower_bound_); ++field_cnt; + SaveVectorField(fo, u8"labels_upper_bound", DataType::kFloat32, + {labels_upper_bound_.Size(), 1}, labels_upper_bound_); ++field_cnt; CHECK_EQ(field_cnt, kNumField) << "Wrong number of fields"; } @@ -195,6 +201,8 @@ void MetaInfo::LoadBinary(dmlc::Stream *fi) { LoadVectorField(fi, u8"group_ptr", DataType::kUInt32, &group_ptr_); LoadVectorField(fi, u8"weights", DataType::kFloat32, &weights_); LoadVectorField(fi, u8"base_margin", DataType::kFloat32, &base_margin_); + LoadVectorField(fi, u8"labels_lower_bound", DataType::kFloat32, &labels_lower_bound_); + LoadVectorField(fi, u8"labels_upper_bound", DataType::kFloat32, &labels_upper_bound_); } // try to load group information from file, if exists @@ -268,8 +276,18 @@ void MetaInfo::SetInfo(const char* key, const void* dptr, DataType dtype, size_t for (size_t i = 1; i < group_ptr_.size(); ++i) { group_ptr_[i] = group_ptr_[i - 1] + group_ptr_[i]; } + } else if (!std::strcmp(key, "label_lower_bound")) { + auto& labels = labels_lower_bound_.HostVector(); + labels.resize(num); + DISPATCH_CONST_PTR(dtype, dptr, cast_dptr, + std::copy(cast_dptr, cast_dptr + num, labels.begin())); + } else if (!std::strcmp(key, "label_upper_bound")) { + auto& labels = labels_upper_bound_.HostVector(); + labels.resize(num); + DISPATCH_CONST_PTR(dtype, dptr, cast_dptr, + std::copy(cast_dptr, cast_dptr + num, labels.begin())); } else { - LOG(FATAL) << "Unknown metainfo: " << key; + LOG(FATAL) << "Unknown key for MetaInfo: " << key; } } diff --git a/src/metric/survival_metric.cc b/src/metric/survival_metric.cc new file mode 100644 index 000000000000..252ee40cd071 --- /dev/null +++ b/src/metric/survival_metric.cc @@ -0,0 +1,106 @@ +/*! + * Copyright 2019 by Contributors + * \file survival_metric.cc + * \brief Metrics for survival analysis + * \author Avinash Barnwal, Hyunsu Cho and Toby Hocking + */ + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "xgboost/json.h" + +#include "../common/math.h" +#include "../common/survival_util.h" + +using AFTParam = xgboost::common::AFTParam; +using AFTLoss = xgboost::common::AFTLoss; + +namespace xgboost { +namespace metric { +// tag the this file, used by force static link later. +DMLC_REGISTRY_FILE_TAG(survival_metric); + +/*! \brief Negative log likelihood of Accelerated Failure Time model */ +struct EvalAFT : public Metric { + public: + explicit EvalAFT(const char* param) {} + + void Configure(const Args& args) override { + param_.UpdateAllowUnknown(args); + loss_.reset(new AFTLoss(param_.aft_loss_distribution)); + } + + void SaveConfig(Json* p_out) const override { + auto& out = *p_out; + out["name"] = String(this->Name()); + out["aft_loss_param"] = toJson(param_); + } + + void LoadConfig(Json const& in) override { + fromJson(in["aft_loss_param"], ¶m_); + } + + bst_float Eval(const HostDeviceVector &preds, + const MetaInfo &info, + bool distributed) override { + CHECK_NE(info.labels_lower_bound_.Size(), 0U) + << "y_lower cannot be empty"; + CHECK_NE(info.labels_upper_bound_.Size(), 0U) + << "y_higher cannot be empty"; + CHECK_EQ(preds.Size(), info.labels_lower_bound_.Size()); + CHECK_EQ(preds.Size(), info.labels_upper_bound_.Size()); + + /* Compute negative log likelihood for each data point and compute weighted average */ + const auto& yhat = preds.HostVector(); + const auto& y_lower = info.labels_lower_bound_.HostVector(); + const auto& y_upper = info.labels_upper_bound_.HostVector(); + const auto& weights = info.weights_.HostVector(); + const bool is_null_weight = weights.empty(); + const float aft_loss_distribution_scale = param_.aft_loss_distribution_scale; + CHECK_LE(yhat.size(), static_cast(std::numeric_limits::max())) + << "yhat is too big"; + const omp_ulong nsize = static_cast(yhat.size()); + + double nloglik_sum = 0.0; + double weight_sum = 0.0; + #pragma omp parallel for default(none) \ + firstprivate(nsize, is_null_weight, aft_loss_distribution_scale) \ + shared(weights, y_lower, y_upper, yhat) reduction(+:nloglik_sum, weight_sum) + for (omp_ulong i = 0; i < nsize; ++i) { + // If weights are empty, data is unweighted so we use 1.0 everywhere + const double w = is_null_weight ? 1.0 : weights[i]; + const double loss + = loss_->Loss(y_lower[i], y_upper[i], yhat[i], aft_loss_distribution_scale); + nloglik_sum += loss; + weight_sum += w; + } + + double dat[2]{nloglik_sum, weight_sum}; + if (distributed) { + rabit::Allreduce(dat, 2); + } + return static_cast(dat[0] / dat[1]); + } + + const char* Name() const override { + return "aft-nloglik"; + } + + private: + AFTParam param_; + std::unique_ptr loss_; +}; + +XGBOOST_REGISTER_METRIC(AFT, "aft-nloglik") +.describe("Negative log likelihood of Accelerated Failure Time model.") +.set_body([](const char* param) { return new EvalAFT(param); }); + +} // namespace metric +} // namespace xgboost diff --git a/src/objective/aft_obj.cc b/src/objective/aft_obj.cc new file mode 100644 index 000000000000..1935d8b71f49 --- /dev/null +++ b/src/objective/aft_obj.cc @@ -0,0 +1,119 @@ +/*! + * Copyright 2015 by Contributors + * \file rank.cc + * \brief Definition of aft loss. + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "xgboost/json.h" + +#include "../common/math.h" +#include "../common/random.h" +#include "../common/survival_util.h" + +using AFTParam = xgboost::common::AFTParam; +using AFTLoss = xgboost::common::AFTLoss; + +namespace xgboost { +namespace obj { + +DMLC_REGISTRY_FILE_TAG(aft_obj); + +class AFTObj : public ObjFunction { + public: + void Configure(const std::vector >& args) override { + param_.UpdateAllowUnknown(args); + loss_.reset(new AFTLoss(param_.aft_loss_distribution)); + } + + void GetGradient(const HostDeviceVector& preds, + const MetaInfo& info, + int iter, + HostDeviceVector* out_gpair) override { + /* Boilerplate */ + CHECK_EQ(preds.Size(), info.labels_lower_bound_.Size()); + CHECK_EQ(preds.Size(), info.labels_upper_bound_.Size()); + + const auto& yhat = preds.HostVector(); + const auto& y_lower = info.labels_lower_bound_.HostVector(); + const auto& y_upper = info.labels_upper_bound_.HostVector(); + const auto& weights = info.weights_.HostVector(); + const bool is_null_weight = weights.empty(); + + out_gpair->Resize(yhat.size()); + std::vector& gpair = out_gpair->HostVector(); + CHECK_LE(yhat.size(), static_cast(std::numeric_limits::max())) + << "yhat is too big"; + const omp_ulong nsize = static_cast(yhat.size()); + const float aft_loss_distribution_scale = param_.aft_loss_distribution_scale; + + #pragma omp parallel for default(none) \ + firstprivate(nsize, is_null_weight, aft_loss_distribution_scale) \ + shared(weights, y_lower, y_upper, yhat, gpair) + for (omp_ulong i = 0; i < nsize; ++i) { + // If weights are empty, data is unweighted so we use 1.0 everywhere + const double w = is_null_weight ? 1.0 : weights[i]; + const double grad = loss_->Gradient(y_lower[i], y_upper[i], + yhat[i], aft_loss_distribution_scale); + const double hess = loss_->Hessian(y_lower[i], y_upper[i], + yhat[i], aft_loss_distribution_scale); + gpair[i] = GradientPair(grad * w, hess * w); + } + } + + void PredTransform(HostDeviceVector *io_preds) override { + // Trees give us a prediction in log scale, so exponentiate + std::vector &preds = io_preds->HostVector(); + const long ndata = static_cast(preds.size()); // NOLINT(*) + #pragma omp parallel for default(none) firstprivate(ndata) shared(preds) + for (long j = 0; j < ndata; ++j) { // NOLINT(*) + preds[j] = std::exp(preds[j]); + } + } + + void EvalTransform(HostDeviceVector *io_preds) override { + // do nothing here, since the AFT metric expects untransformed prediction score + } + + bst_float ProbToMargin(bst_float base_score) const override { + return std::log(base_score); + } + + const char* DefaultEvalMetric() const override { + return "aft-nloglik"; + } + + void SaveConfig(Json* p_out) const override { + auto& out = *p_out; + out["name"] = String("survival:aft"); + out["aft_loss_param"] = toJson(param_); + } + + void LoadConfig(Json const& in) override { + fromJson(in["aft_loss_param"], ¶m_); + loss_.reset(new AFTLoss(param_.aft_loss_distribution)); + } + + private: + AFTParam param_; + std::unique_ptr loss_; +}; + +// register the objective functions +XGBOOST_REGISTER_OBJECTIVE(AFTObj, "survival:aft") +.describe("AFT loss function") +.set_body([]() { return new AFTObj(); }); + +} // namespace obj +} // namespace xgboost + + diff --git a/tests/cpp/common/test_probability_distribution.cc b/tests/cpp/common/test_probability_distribution.cc new file mode 100644 index 000000000000..4363be9dfdf6 --- /dev/null +++ b/tests/cpp/common/test_probability_distribution.cc @@ -0,0 +1,121 @@ +/*! + * Copyright (c) by Contributors 2020 + */ +#include +#include +#include + +#include "xgboost/logging.h" +#include "../../../src/common/probability_distribution.h" + +namespace xgboost { +namespace common { + +TEST(ProbabilityDistribution, DistributionGeneric) { + // Assert d/dx CDF = PDF, d/dx PDF = GradPDF, d/dx GradPDF = HessPDF + // Do this for every distribution type + for (auto type : {ProbabilityDistributionType::kNormal, ProbabilityDistributionType::kLogistic, + ProbabilityDistributionType::kExtreme}) { + std::unique_ptr dist{ ProbabilityDistribution::Create(type) }; + double integral_of_pdf = dist->CDF(-2.0); + double integral_of_grad_pdf = dist->PDF(-2.0); + double integral_of_hess_pdf = dist->GradPDF(-2.0); + // Perform numerical differentiation and integration + // Enumerate 4000 grid points in range [-2, 2] + for (int i = 0; i <= 4000; ++i) { + const double x = static_cast(i) / 1000.0 - 2.0; + // Numerical differentiation (p. 246, Numerical Analysis 2nd ed. by Timothy Sauer) + EXPECT_NEAR((dist->CDF(x + 1e-5) - dist->CDF(x - 1e-5)) / 2e-5, dist->PDF(x), 6e-11); + EXPECT_NEAR((dist->PDF(x + 1e-5) - dist->PDF(x - 1e-5)) / 2e-5, dist->GradPDF(x), 6e-11); + EXPECT_NEAR((dist->GradPDF(x + 1e-5) - dist->GradPDF(x - 1e-5)) / 2e-5, + dist->HessPDF(x), 6e-11); + // Numerical integration using Trapezoid Rule (p. 257, Sauer) + integral_of_pdf += 5e-4 * (dist->PDF(x - 1e-3) + dist->PDF(x)); + integral_of_grad_pdf += 5e-4 * (dist->GradPDF(x - 1e-3) + dist->GradPDF(x)); + integral_of_hess_pdf += 5e-4 * (dist->HessPDF(x - 1e-3) + dist->HessPDF(x)); + EXPECT_NEAR(integral_of_pdf, dist->CDF(x), 2e-4); + EXPECT_NEAR(integral_of_grad_pdf, dist->PDF(x), 2e-4); + EXPECT_NEAR(integral_of_hess_pdf, dist->GradPDF(x), 2e-4); + } + } +} + +TEST(ProbabilityDistribution, NormalDist) { + std::unique_ptr dist{ + ProbabilityDistribution::Create(ProbabilityDistributionType::kNormal) + }; + + // "Three-sigma rule" (https://en.wikipedia.org/wiki/68–95–99.7_rule) + // 68% of values are within 1 standard deviation away from the mean + // 95% of values are within 2 standard deviation away from the mean + // 99.7% of values are within 3 standard deviation away from the mean + EXPECT_NEAR(dist->CDF(0.5) - dist->CDF(-0.5), 0.3829, 0.00005); + EXPECT_NEAR(dist->CDF(1.0) - dist->CDF(-1.0), 0.6827, 0.00005); + EXPECT_NEAR(dist->CDF(1.5) - dist->CDF(-1.5), 0.8664, 0.00005); + EXPECT_NEAR(dist->CDF(2.0) - dist->CDF(-2.0), 0.9545, 0.00005); + EXPECT_NEAR(dist->CDF(2.5) - dist->CDF(-2.5), 0.9876, 0.00005); + EXPECT_NEAR(dist->CDF(3.0) - dist->CDF(-3.0), 0.9973, 0.00005); + EXPECT_NEAR(dist->CDF(3.5) - dist->CDF(-3.5), 0.9995, 0.00005); + EXPECT_NEAR(dist->CDF(4.0) - dist->CDF(-4.0), 0.9999, 0.00005); +} + +TEST(ProbabilityDistribution, LogisticDist) { + std::unique_ptr dist{ + ProbabilityDistribution::Create(ProbabilityDistributionType::kLogistic) + }; + + /** + * Enforce known properties of the logistic distribution. + * (https://en.wikipedia.org/wiki/Logistic_distribution) + **/ + + // Enumerate 4000 grid points in range [-2, 2] + for (int i = 0; i <= 4000; ++i) { + const double x = static_cast(i) / 1000.0 - 2.0; + // PDF = 1/4 * sech(x/2)**2 + const double sech_x = 1.0 / std::cosh(x * 0.5); // hyperbolic secant at x/2 + EXPECT_NEAR(0.25 * sech_x * sech_x, dist->PDF(x), 1e-15); + // CDF = 1/2 + 1/2 * tanh(x/2) + EXPECT_NEAR(0.5 + 0.5 * std::tanh(x * 0.5), dist->CDF(x), 1e-15); + } +} + +TEST(ProbabilityDistribution, ExtremeDist) { + std::unique_ptr dist{ + ProbabilityDistribution::Create(ProbabilityDistributionType::kExtreme) + }; + + /** + * Enforce known properties of the extreme distribution (also known as Gumbel distribution). + * The mean is the negative of the Euler-Mascheroni constant. + * The variance is 1/6 * pi**2. (https://mathworld.wolfram.com/GumbelDistribution.html) + **/ + + // Enumerate 25000 grid points in range [-20, 5]. + // Compute the mean (expected value) of the distribution using numerical integration. + // Nearly all mass of the extreme distribution is concentrated between -20 and 5, + // so numerically integrating x*PDF(x) over [-20, 5] gives good estimate of the mean. + double mean = 0.0; + for (int i = 0; i <= 25000; ++i) { + const double x = static_cast(i) / 1000.0 - 20.0; + // Numerical integration using Trapezoid Rule (p. 257, Sauer) + mean += 5e-4 * ((x - 1e-3) * dist->PDF(x - 1e-3) + x * dist->PDF(x)); + } + EXPECT_NEAR(mean, -probability_constant::kEulerMascheroni, 1e-7); + + // Enumerate 25000 grid points in range [-20, 5]. + // Compute the variance of the distribution using numerical integration. + // Nearly all mass of the extreme distribution is concentrated between -20 and 5, + // so numerically integrating (x-mean)*PDF(x) over [-20, 5] gives good estimate of the variance. + double variance = 0.0; + for (int i = 0; i <= 25000; ++i) { + const double x = static_cast(i) / 1000.0 - 20.0; + // Numerical integration using Trapezoid Rule (p. 257, Sauer) + variance += 5e-4 * ((x - 1e-3 - mean) * (x - 1e-3 - mean) * dist->PDF(x - 1e-3) + + (x - mean) * (x - mean) * dist->PDF(x)); + } + EXPECT_NEAR(variance, probability_constant::kPI * probability_constant::kPI / 6.0, 1e-6); +} + +} // namespace common +} // namespace xgboost diff --git a/tests/cpp/metric/test_survival_metric.cc b/tests/cpp/metric/test_survival_metric.cc new file mode 100644 index 000000000000..f2c447085b8b --- /dev/null +++ b/tests/cpp/metric/test_survival_metric.cc @@ -0,0 +1,169 @@ +/*! + * Copyright (c) by Contributors 2020 + */ +#include +#include +#include +#include +#include +#include + +#include "xgboost/metric.h" +#include "xgboost/logging.h" +#include "../helpers.h" +#include "../../../src/common/survival_util.h" + +namespace xgboost { +namespace common { + +/** + * Reference values obtained from + * https://github.com/avinashbarnwal/GSOC-2019/blob/master/AFT/R/combined_assignment.R + **/ + +TEST(Metric, AFTNegLogLik) { + auto lparam = CreateEmptyGenericParam(-1); // currently AFT metric is CPU only + + /** + * Test aggregate output from the AFT metric over a small test data set. + * This is unlike AFTLoss.* tests, which verify metric values over individual data points. + **/ + MetaInfo info; + info.num_row_ = 4; + info.labels_lower_bound_.HostVector() + = { 100.0f, -std::numeric_limits::infinity(), 60.0f, 16.0f }; + info.labels_upper_bound_.HostVector() + = { 100.0f, 20.0f, std::numeric_limits::infinity(), 200.0f }; + info.weights_.HostVector() = std::vector(); + HostDeviceVector preds(4, std::log(64)); + + struct TestCase { + std::string dist_type; + bst_float reference_value; + }; + for (const auto& test_case : std::vector{ {"normal", 2.1508f}, {"logistic", 2.1804f}, + {"extreme", 2.0706f} }) { + std::unique_ptr metric(Metric::Create("aft-nloglik", &lparam)); + metric->Configure({ {"aft_loss_distribution", test_case.dist_type}, + {"aft_loss_distribution_scale", "1.0"} }); + EXPECT_NEAR(metric->Eval(preds, info, false), test_case.reference_value, 1e-4); + } +} + +// Test configuration of AFT metric +TEST(AFTNegLogLikMetric, Configuration) { + auto lparam = CreateEmptyGenericParam(-1); // currently AFT metric is CPU only + std::unique_ptr metric(Metric::Create("aft-nloglik", &lparam)); + metric->Configure({{"aft_loss_distribution", "normal"}, {"aft_loss_distribution_scale", "10"}}); + + // Configuration round-trip test + Json j_obj{ Object() }; + metric->SaveConfig(&j_obj); + auto aft_param_json = j_obj["aft_loss_param"]; + EXPECT_EQ(get(aft_param_json["aft_loss_distribution"]), "normal"); + EXPECT_EQ(get(aft_param_json["aft_loss_distribution_scale"]), "10"); +} + +/** + * AFTLoss.* tests verify metric values over individual data points. + **/ + +// Generate prediction value ranging from 2**1 to 2**15, using grid points in log scale +// Then check prediction against the reference values +static inline void CheckLossOverGridPoints( + double true_label_lower_bound, + double true_label_upper_bound, + ProbabilityDistributionType dist_type, + const std::vector& reference_values) { + const int num_point = 20; + const double log_y_low = 1.0; + const double log_y_high = 15.0; + std::unique_ptr loss(new AFTLoss(dist_type)); + CHECK_EQ(num_point, reference_values.size()); + for (int i = 0; i < num_point; ++i) { + const double y_pred + = std::pow(2.0, i * (log_y_high - log_y_low) / (num_point - 1) + log_y_low); + const double loss_val + = loss->Loss(true_label_lower_bound, true_label_upper_bound, std::log(y_pred), 1.0); + EXPECT_NEAR(loss_val, reference_values[i], 1e-4); + } +} + +TEST(AFTLoss, Uncensored) { + // Given label 100, compute the AFT loss for various prediction values + const double true_label_lower_bound = 100.0; + const double true_label_upper_bound = true_label_lower_bound; + + CheckLossOverGridPoints(true_label_lower_bound, true_label_upper_bound, + ProbabilityDistributionType::kNormal, + { 13.1761, 11.3085, 9.7017, 8.3558, 7.2708, 6.4466, 5.8833, 5.5808, 5.5392, 5.7585, 6.2386, + 6.9795, 7.9813, 9.2440, 10.7675, 12.5519, 14.5971, 16.9032, 19.4702, 22.2980 }); + CheckLossOverGridPoints(true_label_lower_bound, true_label_upper_bound, + ProbabilityDistributionType::kLogistic, + { 8.5568, 8.0720, 7.6038, 7.1620, 6.7612, 6.4211, 6.1659, 6.0197, 5.9990, 6.1064, 6.3293, + 6.6450, 7.0289, 7.4594, 7.9205, 8.4008, 8.8930, 9.3926, 9.8966, 10.4033 }); + CheckLossOverGridPoints(true_label_lower_bound, true_label_upper_bound, + ProbabilityDistributionType::kExtreme, + { 27.6310, 27.6310, 19.7177, 13.0281, 9.2183, 7.1365, 6.0916, 5.6688, 5.6195, 5.7941, 6.1031, + 6.4929, 6.9310, 7.3981, 7.8827, 8.3778, 8.8791, 9.3842, 9.8916, 10.40033 }); +} + +TEST(AFTLoss, LeftCensored) { + // Given label (-inf, 20], compute the AFT loss for various prediction values + const double true_label_lower_bound = -std::numeric_limits::infinity(); + const double true_label_upper_bound = 20.0; + + CheckLossOverGridPoints(true_label_lower_bound, true_label_upper_bound, + ProbabilityDistributionType::kNormal, + { 0.0107, 0.0373, 0.1054, 0.2492, 0.5068, 0.9141, 1.5003, 2.2869, 3.2897, 4.5196, 5.9846, + 7.6902, 9.6405, 11.8385, 14.2867, 16.9867, 19.9399, 23.1475, 26.6103, 27.6310 }); + CheckLossOverGridPoints(true_label_lower_bound, true_label_upper_bound, + ProbabilityDistributionType::kLogistic, + { 0.0953, 0.1541, 0.2451, 0.3804, 0.5717, 0.8266, 1.1449, 1.5195, 1.9387, 2.3902, 2.8636, + 3.3512, 3.8479, 4.3500, 4.8556, 5.3632, 5.8721, 6.3817, 6.8918, 7.4021 }); + CheckLossOverGridPoints(true_label_lower_bound, true_label_upper_bound, + ProbabilityDistributionType::kExtreme, + { 0.0000, 0.0025, 0.0277, 0.1225, 0.3195, 0.6150, 0.9862, 1.4094, 1.8662, 2.3441, 2.8349, + 3.3337, 3.8372, 4.3436, 4.8517, 5.3609, 5.8707, 6.3808, 6.8912, 7.4018 }); +} + +TEST(AFTLoss, RightCensored) { + // Given label [60, +inf), compute the AFT loss for various prediction values + const double true_label_lower_bound = 60.0; + const double true_label_upper_bound = std::numeric_limits::infinity(); + + CheckLossOverGridPoints(true_label_lower_bound, true_label_upper_bound, + ProbabilityDistributionType::kNormal, + { 8.0000, 6.2537, 4.7487, 3.4798, 2.4396, 1.6177, 0.9993, 0.5638, 0.2834, 0.1232, 0.0450, + 0.0134, 0.0032, 0.0006, 0.0001, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000 }); + CheckLossOverGridPoints(true_label_lower_bound, true_label_upper_bound, + ProbabilityDistributionType::kLogistic, + { 3.4340, 2.9445, 2.4683, 2.0125, 1.5871, 1.2041, 0.8756, 0.6099, 0.4083, 0.2643, 0.1668, + 0.1034, 0.0633, 0.0385, 0.0233, 0.0140, 0.0084, 0.0051, 0.0030, 0.0018 }); + CheckLossOverGridPoints(true_label_lower_bound, true_label_upper_bound, + ProbabilityDistributionType::kExtreme, + { 27.6310, 18.0015, 10.8018, 6.4817, 3.8893, 2.3338, 1.4004, 0.8403, 0.5042, 0.3026, 0.1816, + 0.1089, 0.0654, 0.0392, 0.0235, 0.0141, 0.0085, 0.0051, 0.0031, 0.0018 }); +} + +TEST(AFTLoss, IntervalCensored) { + // Given label [16, 200], compute the AFT loss for various prediction values + const double true_label_lower_bound = 16.0; + const double true_label_upper_bound = 200.0; + + CheckLossOverGridPoints(true_label_lower_bound, true_label_upper_bound, + ProbabilityDistributionType::kNormal, + { 3.9746, 2.8415, 1.9319, 1.2342, 0.7335, 0.4121, 0.2536, 0.2470, 0.3919, 0.6982, 1.1825, + 1.8622, 2.7526, 3.8656, 5.2102, 6.7928, 8.6183, 10.6901, 13.0108, 15.5826 }); + CheckLossOverGridPoints(true_label_lower_bound, true_label_upper_bound, + ProbabilityDistributionType::kLogistic, + { 2.2906, 1.8578, 1.4667, 1.1324, 0.8692, 0.6882, 0.5948, 0.5909, 0.6764, 0.8499, 1.1061, + 1.4348, 1.8215, 2.2511, 2.7104, 3.1891, 3.6802, 4.1790, 4.6825, 5.1888 }); + CheckLossOverGridPoints(true_label_lower_bound, true_label_upper_bound, + ProbabilityDistributionType::kExtreme, + { 8.0000, 4.8004, 2.8805, 1.7284, 1.0372, 0.6231, 0.3872, 0.3031, 0.3740, 0.5839, 0.8995, + 1.2878, 1.7231, 2.1878, 2.6707, 3.1647, 3.6653, 4.1699, 4.6770, 5.1856 }); +} + +} // namespace common +} // namespace xgboost diff --git a/tests/cpp/objective/test_aft_obj.cc b/tests/cpp/objective/test_aft_obj.cc new file mode 100644 index 000000000000..01e965df8d15 --- /dev/null +++ b/tests/cpp/objective/test_aft_obj.cc @@ -0,0 +1,174 @@ +/*! + * Copyright (c) by Contributors 2020 + */ +#include +#include +#include +#include +#include + +#include "xgboost/objective.h" +#include "xgboost/logging.h" +#include "../helpers.h" +#include "../../../src/common/survival_util.h" + +namespace xgboost { +namespace common { + +TEST(Objective, AFTObjConfiguration) { + auto lparam = CreateEmptyGenericParam(-1); // currently AFT objective is CPU only + std::unique_ptr objective(ObjFunction::Create("survival:aft", &lparam)); + objective->Configure({ {"aft_loss_distribution", "logistic"}, + {"aft_loss_distribution_scale", "5"} }); + + // Configuration round-trip test + Json j_obj{ Object() }; + objective->SaveConfig(&j_obj); + EXPECT_EQ(get(j_obj["name"]), "survival:aft"); + auto aft_param_json = j_obj["aft_loss_param"]; + EXPECT_EQ(get(aft_param_json["aft_loss_distribution"]), "logistic"); + EXPECT_EQ(get(aft_param_json["aft_loss_distribution_scale"]), "5"); +} + +/** + * Verify that gradient pair (gpair) is computed correctly for various prediction values. + * Reference values obtained from + * https://github.com/avinashbarnwal/GSOC-2019/blob/master/AFT/R/combined_assignment.R + **/ + +// Generate prediction value ranging from 2**1 to 2**15, using grid points in log scale +// Then check prediction against the reference values +static inline void CheckGPairOverGridPoints( + ObjFunction* obj, + bst_float true_label_lower_bound, + bst_float true_label_upper_bound, + const std::string& dist_type, + const std::vector& expected_grad, + const std::vector& expected_hess, + float ftol = 1e-4f) { + const int num_point = 20; + const double log_y_low = 1.0; + const double log_y_high = 15.0; + + obj->Configure({ {"aft_loss_distribution", dist_type}, + {"aft_loss_distribution_scale", "1"} }); + + MetaInfo info; + info.num_row_ = num_point; + info.labels_lower_bound_.HostVector() + = std::vector(num_point, true_label_lower_bound); + info.labels_upper_bound_.HostVector() + = std::vector(num_point, true_label_upper_bound); + info.weights_.HostVector() = std::vector(); + std::vector preds(num_point); + for (int i = 0; i < num_point; ++i) { + preds[i] = std::log(std::pow(2.0, i * (log_y_high - log_y_low) / (num_point - 1) + log_y_low)); + } + + HostDeviceVector out_gpair; + obj->GetGradient(HostDeviceVector(preds), info, 1, &out_gpair); + const auto& gpair = out_gpair.HostVector(); + CHECK_EQ(num_point, expected_grad.size()); + CHECK_EQ(num_point, expected_hess.size()); + for (int i = 0; i < num_point; ++i) { + EXPECT_NEAR(gpair[i].GetGrad(), expected_grad[i], ftol); + EXPECT_NEAR(gpair[i].GetHess(), expected_hess[i], ftol); + } +} + +TEST(Objective, AFTObjGPairUncensoredLabels) { + auto lparam = CreateEmptyGenericParam(-1); // currently AFT objective is CPU only + std::unique_ptr obj(ObjFunction::Create("survival:aft", &lparam)); + + CheckGPairOverGridPoints(obj.get(), 100.0f, 100.0f, "normal", + { -3.9120f, -3.4013f, -2.8905f, -2.3798f, -1.8691f, -1.3583f, -0.8476f, -0.3368f, 0.1739f, + 0.6846f, 1.1954f, 1.7061f, 2.2169f, 2.7276f, 3.2383f, 3.7491f, 4.2598f, 4.7706f, 5.2813f, + 5.7920f }, + { 1.0000f, 1.0000f, 1.0000f, 1.0000f, 1.0000f, 1.0000f, 1.0000f, 1.0000f, 1.0000f, 1.0000f, + 1.0000f, 1.0000f, 1.0000f, 1.0000f, 1.0000f, 1.0000f, 1.0000f, 1.0000f, 1.0000f, 1.0000f }); + CheckGPairOverGridPoints(obj.get(), 100.0f, 100.0f, "logistic", + { -0.9608f, -0.9355f, -0.8948f, -0.8305f, -0.7327f, -0.5910f, -0.4001f, -0.1668f, 0.0867f, + 0.3295f, 0.5354f, 0.6927f, 0.8035f, 0.8773f, 0.9245f, 0.9540f, 0.9721f, 0.9832f, 0.9899f, + 0.9939f }, + { 0.0384f, 0.0624f, 0.0997f, 0.1551f, 0.2316f, 0.3254f, 0.4200f, 0.4861f, 0.4962f, 0.4457f, + 0.3567f, 0.2601f, 0.1772f, 0.1152f, 0.0726f, 0.0449f, 0.0275f, 0.0167f, 0.0101f, 0.0061f }); + CheckGPairOverGridPoints(obj.get(), 100.0f, 100.0f, "extreme", + { -0.0000f, -29.0026f, -17.0031f, -9.8028f, -5.4822f, -2.8897f, -1.3340f, -0.4005f, 0.1596f, + 0.4957f, 0.6974f, 0.8184f, 0.8910f, 0.9346f, 0.9608f, 0.9765f, 0.9859f, 0.9915f, 0.9949f, + 0.9969f }, + { 0.0000f, 30.0026f, 18.0031f, 10.8028f, 6.4822f, 3.8897f, 2.3340f, 1.4005f, 0.8404f, 0.5043f, + 0.3026f, 0.1816f, 0.1090f, 0.0654f, 0.0392f, 0.0235f, 0.0141f, 0.0085f, 0.0051f, 0.0031f }); +} + +TEST(Objective, AFTObjGPairLeftCensoredLabels) { + auto lparam = CreateEmptyGenericParam(-1); // currently AFT objective is CPU only + std::unique_ptr obj(ObjFunction::Create("survival:aft", &lparam)); + + CheckGPairOverGridPoints(obj.get(), -std::numeric_limits::infinity(), 20.0f, "normal", + { 0.0285f, 0.0832f, 0.1951f, 0.3804f, 0.6403f, 0.9643f, 1.3379f, 1.7475f, 2.1828f, 2.6361f, + 3.1023f, 3.5779f, 4.0603f, 4.5479f, 5.0394f, 5.5340f, 6.0309f, 6.5298f, 7.0303f, 0.5072f }, + { 0.0663f, 0.1559f, 0.2881f, 0.4378f, 0.5762f, 0.6878f, 0.7707f, 0.8300f, 0.8719f, 0.9016f, + 0.9229f, 0.9385f, 0.9501f, 0.9588f, 0.9656f, 0.9709f, 0.9751f, 0.9785f, 0.9812f, 0.0045f }, + 2e-4); + CheckGPairOverGridPoints(obj.get(), -std::numeric_limits::infinity(), 20.0f, "logistic", + { 0.0909f, 0.1428f, 0.2174f, 0.3164f, 0.4355f, 0.5625f, 0.6818f, 0.7812f, 0.8561f, 0.9084f, + 0.9429f, 0.9650f, 0.9787f, 0.9871f, 0.9922f, 0.9953f, 0.9972f, 0.9983f, 0.9990f, 0.9994f }, + { 0.0826f, 0.1224f, 0.1701f, 0.2163f, 0.2458f, 0.2461f, 0.2170f, 0.1709f, 0.1232f, 0.0832f, + 0.0538f, 0.0338f, 0.0209f, 0.0127f, 0.0077f, 0.0047f, 0.0028f, 0.0017f, 0.0010f, 0.0006f }); + CheckGPairOverGridPoints(obj.get(), -std::numeric_limits::infinity(), 20.0f, "extreme", + { 0.0005f, 0.0149f, 0.1011f, 0.2815f, 0.4881f, 0.6610f, 0.7847f, 0.8665f, 0.9183f, 0.9504f, + 0.9700f, 0.9820f, 0.9891f, 0.9935f, 0.9961f, 0.9976f, 0.9986f, 0.9992f, 0.9995f, 0.9997f }, + { 0.0041f, 0.0747f, 0.2731f, 0.4059f, 0.3829f, 0.2901f, 0.1973f, 0.1270f, 0.0793f, 0.0487f, + 0.0296f, 0.0179f, 0.0108f, 0.0065f, 0.0039f, 0.0024f, 0.0014f, 0.0008f, 0.0005f, 0.0003f }); +} + +TEST(Objective, AFTObjGPairRightCensoredLabels) { + auto lparam = CreateEmptyGenericParam(-1); // currently AFT objective is CPU only + std::unique_ptr obj(ObjFunction::Create("survival:aft", &lparam)); + + CheckGPairOverGridPoints(obj.get(), 60.0f, std::numeric_limits::infinity(), "normal", + { -3.6583f, -3.1815f, -2.7135f, -2.2577f, -1.8190f, -1.4044f, -1.0239f, -0.6905f, -0.4190f, + -0.2209f, -0.0973f, -0.0346f, -0.0097f, -0.0021f, -0.0004f, -0.0000f, -0.0000f, -0.0000f, + -0.0000f, -0.0000f }, + { 0.9407f, 0.9259f, 0.9057f, 0.8776f, 0.8381f, 0.7821f, 0.7036f, 0.5970f, 0.4624f, 0.3128f, + 0.1756f, 0.0780f, 0.0265f, 0.0068f, 0.0013f, 0.0002f, 0.0000f, 0.0000f, 0.0000f, 0.0000f }); + CheckGPairOverGridPoints(obj.get(), 60.0f, std::numeric_limits::infinity(), "logistic", + { -0.9677f, -0.9474f, -0.9153f, -0.8663f, -0.7955f, -0.7000f, -0.5834f, -0.4566f, -0.3352f, + -0.2323f, -0.1537f, -0.0982f, -0.0614f, -0.0377f, -0.0230f, -0.0139f, -0.0084f, -0.0051f, + -0.0030f, -0.0018f }, + { 0.0312f, 0.0499f, 0.0776f, 0.1158f, 0.1627f, 0.2100f, 0.2430f, 0.2481f, 0.2228f, 0.1783f, + 0.1300f, 0.0886f, 0.0576f, 0.0363f, 0.0225f, 0.0137f, 0.0083f, 0.0050f, 0.0030f, 0.0018f }); + CheckGPairOverGridPoints(obj.get(), 60.0f, std::numeric_limits::infinity(), "extreme", + { -2.8073f, -18.0015f, -10.8018f, -6.4817f, -3.8893f, -2.3338f, -1.4004f, -0.8403f, -0.5042f, + -0.3026f, -0.1816f, -0.1089f, -0.0654f, -0.0392f, -0.0235f, -0.0141f, -0.0085f, -0.0051f, + -0.0031f, -0.0018f }, + { 0.2614f, 18.0015f, 10.8018f, 6.4817f, 3.8893f, 2.3338f, 1.4004f, 0.8403f, 0.5042f, 0.3026f, + 0.1816f, 0.1089f, 0.0654f, 0.0392f, 0.0235f, 0.0141f, 0.0085f, 0.0051f, 0.0031f, 0.0018f }); +} + +TEST(Objective, AFTObjGPairIntervalCensoredLabels) { + auto lparam = CreateEmptyGenericParam(-1); // currently AFT objective is CPU only + std::unique_ptr obj(ObjFunction::Create("survival:aft", &lparam)); + + CheckGPairOverGridPoints(obj.get(), 16.0f, 200.0f, "normal", + { -2.4435f, -1.9965f, -1.5691f, -1.1679f, -0.7990f, -0.4649f, -0.1596f, 0.1336f, 0.4370f, + 0.7682f, 1.1340f, 1.5326f, 1.9579f, 2.4035f, 2.8639f, 3.3351f, 3.8143f, 4.2995f, 4.7891f, + 5.2822f }, + { 0.8909f, 0.8579f, 0.8134f, 0.7557f, 0.6880f, 0.6221f, 0.5789f, 0.5769f, 0.6171f, 0.6818f, + 0.7500f, 0.8088f, 0.8545f, 0.8884f, 0.9131f, 0.9312f, 0.9446f, 0.9547f, 0.9624f, 0.9684f }); + CheckGPairOverGridPoints(obj.get(), 16.0f, 200.0f, "logistic", + { -0.8790f, -0.8112f, -0.7153f, -0.5893f, -0.4375f, -0.2697f, -0.0955f, 0.0800f, 0.2545f, + 0.4232f, 0.5768f, 0.7054f, 0.8040f, 0.8740f, 0.9210f, 0.9513f, 0.9703f, 0.9820f, 0.9891f, + 0.9934f }, + { 0.1086f, 0.1588f, 0.2176f, 0.2745f, 0.3164f, 0.3374f, 0.3433f, 0.3434f, 0.3384f, 0.3191f, + 0.2789f, 0.2229f, 0.1637f, 0.1125f, 0.0737f, 0.0467f, 0.0290f, 0.0177f, 0.0108f, 0.0065f }); + CheckGPairOverGridPoints(obj.get(), 16.0f, 200.0f, "extreme", + { -8.0000f, -4.8004f, -2.8805f, -1.7284f, -1.0371f, -0.6168f, -0.3140f, -0.0121f, 0.2841f, + 0.5261f, 0.6989f, 0.8132f, 0.8857f, 0.9306f, 0.9581f, 0.9747f, 0.9848f, 0.9909f, 0.9945f, + 0.9967f }, + { 8.0000f, 4.8004f, 2.8805f, 1.7284f, 1.0380f, 0.6567f, 0.5727f, 0.6033f, 0.5384f, 0.4051f, + 0.2757f, 0.1776f, 0.1110f, 0.0682f, 0.0415f, 0.0251f, 0.0151f, 0.0091f, 0.0055f, 0.0033f }); +} + +} // namespace common +} // namespace xgboost diff --git a/tests/python/test_survival.py b/tests/python/test_survival.py new file mode 100644 index 000000000000..12c79ed4bb33 --- /dev/null +++ b/tests/python/test_survival.py @@ -0,0 +1,90 @@ +import testing as tm +import pytest +import numpy as np +import xgboost as xgb +import json +from pathlib import Path + +dpath = Path('demo/data') + +def test_aft_survival_toy_data(): + # See demo/aft_survival/aft_survival_viz_demo.py + X = np.array([1, 2, 3, 4, 5]).reshape((-1, 1)) + INF = np.inf + y_lower = np.array([ 10, 15, -INF, 30, 100]) + y_upper = np.array([INF, INF, 20, 50, INF]) + + dmat = xgb.DMatrix(X) + dmat.set_float_info('label_lower_bound', y_lower) + dmat.set_float_info('label_upper_bound', y_upper) + + # "Accuracy" = the number of data points whose ranged label (y_lower, y_upper) includes + # the corresponding predicted label (y_pred) + acc_rec = [] + def my_callback(env): + y_pred = env.model.predict(dmat) + acc = np.sum(np.logical_and(y_pred >= y_lower, y_pred <= y_upper)/len(X)) + acc_rec.append(acc) + + evals_result = {} + params = {'max_depth': 3, 'objective':'survival:aft', 'min_child_weight': 0} + bst = xgb.train(params, dmat, 15, [(dmat, 'train')], evals_result=evals_result, + callbacks=[my_callback]) + + nloglik_rec = evals_result['train']['aft-nloglik'] + # AFT metric (negative log likelihood) improve monotonically + assert all(p >= q for p, q in zip(nloglik_rec, nloglik_rec[:1])) + # "Accuracy" improve monotonically. + # Over time, XGBoost model makes predictions that fall within given label ranges. + assert all(p <= q for p, q in zip(acc_rec, acc_rec[1:])) + assert acc_rec[-1] == 1.0 + + def gather_split_thresholds(tree): + if 'split_condition' in tree: + return (gather_split_thresholds(tree['children'][0]) + | gather_split_thresholds(tree['children'][1]) + | {tree['split_condition']}) + return set() + + # Only 2.5, 3.5, and 4.5 are used as split thresholds. + model_json = [json.loads(e) for e in bst.get_dump(dump_format='json')] + for tree in model_json: + assert gather_split_thresholds(tree).issubset({2.5, 3.5, 4.5}) + +@pytest.mark.skipif(**tm.no_pandas()) +def test_aft_survival_demo_data(): + import pandas as pd + df = pd.read_csv(dpath / 'veterans_lung_cancer.csv') + + y_lower_bound = df['Survival_label_lower_bound'] + y_upper_bound = df['Survival_label_upper_bound'] + X = df.drop(['Survival_label_lower_bound', 'Survival_label_upper_bound'], axis=1) + + dtrain = xgb.DMatrix(X) + dtrain.set_float_info('label_lower_bound', y_lower_bound) + dtrain.set_float_info('label_upper_bound', y_upper_bound) + + base_params = {'verbosity': 0, + 'objective': 'survival:aft', + 'eval_metric': 'aft-nloglik', + 'tree_method': 'hist', + 'learning_rate': 0.05, + 'aft_loss_distribution_scale': 1.20, + 'max_depth': 6, + 'lambda': 0.01, + 'alpha': 0.02} + nloglik_rec = {} + dists = ['normal', 'logistic', 'extreme'] + for dist in dists: + params = base_params + params.update({'aft_loss_distribution': dist}) + evals_result = {} + bst = xgb.train(params, dtrain, num_boost_round=500, evals=[(dtrain, 'train')], + evals_result=evals_result) + nloglik_rec[dist] = evals_result['train']['aft-nloglik'] + # AFT metric (negative log likelihood) improve monotonically + assert all(p >= q for p, q in zip(nloglik_rec[dist], nloglik_rec[dist][:1])) + # For this data, normal distribution works the best + assert nloglik_rec['normal'][-1] < 5.0 + assert nloglik_rec['logistic'][-1] > 5.0 + assert nloglik_rec['extreme'][-1] > 5.0