
Commit

Add R code to AFT tutorial [skip ci] (#5486)
hcho3 authored Apr 4, 2020
1 parent 1580010 commit 30e94dd
Showing 1 changed file with 33 additions and 0 deletions: doc/tutorials/aft_survival_analysis.rst
@@ -90,6 +90,7 @@ Interval-censored :math:`[a, b]` |tick| |tick|
Collect the lower bound numbers in one array (let's call it ``y_lower_bound``) and the upper bound numbers in another array (call it ``y_upper_bound``). The ranged labels are associated with a data matrix object via calls to :meth:`xgboost.DMatrix.set_float_info`:

.. code-block:: python
   :caption: Python

   import numpy as np
   import xgboost as xgb
@@ -105,10 +106,29 @@
   y_upper_bound = np.array([ 2.0, +np.inf, 4.0, 5.0])
   dtrain.set_float_info('label_lower_bound', y_lower_bound)
   dtrain.set_float_info('label_upper_bound', y_upper_bound)

.. code-block:: r
   :caption: R

   library(xgboost)

   # 4-by-2 Data matrix
   X <- matrix(c(1., -1., -1., 1., 0., 1., 1., 0.),
               nrow=4, ncol=2, byrow=TRUE)
   dtrain <- xgb.DMatrix(X)

   # Associate ranged labels with the data matrix.
   # This example shows each kind of censored label.
   #                   uncensored  right   left  interval
   y_lower_bound <- c(         2.,    3.,  -Inf,       4.)
   y_upper_bound <- c(         2.,  +Inf,    4.,       5.)
   setinfo(dtrain, 'label_lower_bound', y_lower_bound)
   setinfo(dtrain, 'label_upper_bound', y_upper_bound)

Now we are ready to invoke the training API:

.. code-block:: python
   :caption: Python

   params = {'objective': 'survival:aft',
             'eval_metric': 'aft-nloglik',
@@ -118,6 +138,19 @@
   bst = xgb.train(params, dtrain, num_boost_round=5,
                   evals=[(dtrain, 'train'), (dvalid, 'valid')])

.. code-block:: r
   :caption: R

   params <- list(objective='survival:aft',
                  eval_metric='aft-nloglik',
                  aft_loss_distribution='normal',
                  aft_loss_distribution_scale=1.20,
                  tree_method='hist',
                  learning_rate=0.05,
                  max_depth=2)
   watchlist <- list(train = dtrain)
   bst <- xgb.train(params, dtrain, nrounds=5, watchlist)

We set the ``objective`` parameter to ``survival:aft`` and ``eval_metric`` to ``aft-nloglik`` so that the log likelihood of the AFT model is maximized. (XGBoost actually minimizes the negative log likelihood, hence the name ``aft-nloglik``.)

The parameter ``aft_loss_distribution`` corresponds to the distribution of the :math:`Z` term in the AFT model, and ``aft_loss_distribution_scale`` corresponds to the scaling factor :math:`\sigma`.
