Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Accelerated Failure Time loss for survival analysis task #4763

Merged
merged 84 commits into from
Mar 25, 2020

Conversation

avinashbarnwal
Copy link
Contributor

@avinashbarnwal avinashbarnwal commented Aug 12, 2019

Hi,

Please find the Accelerated Failure time loss for Survival Modeling.

Survival analysis is a "censored regression" where the goal is to learn time-to-event function. This is similar to the common regression analysis where data-points are uncensored. Time-to-event modeling is critical for understanding users/companies behaviors not limited to credit, cancer, and attrition risks.

Supports

  • 4 kinds of datasets - Left, Right, Interval Censored and Uncensored.
  • Normal, Logistic and Extreme Distributions for underlying error distribution.

This project is part of the Google Summer of Code - 2019. AFT-Xgboost

Compact summary of AFT loss formula
Cap 2020-03-14 15-46-55-981
Cap 2020-03-14 15-47-43-993

Relevant Documents

Example in Python to run -


res    = {}
dtrain = xgboost.DMatrix(X)
dtrain.set_float_info("label_lower_bound",y_lower)
dtrain.set_float_info("label_upper_bound",y_higher)

dtest  = xgboost.DMatrix(X_val)
dtest.set_float_info("label_lower_bound",y_lower_val)
dtest.set_float_info("label_upper_bound",y_higher_val)

params = {'learning_rate':0.1, 'aft_loss_distribution' : 'normal', 'aft_loss_distribution_scale': 1.0,'eval_metric':'aft-nloglik','objective':"survival:aft"}

bst    = xgboost.train(params,dtrain,num_boost_round=100,evals=[(dtrain,"train"),(dtest,"test")],evals_result=res)

For more details - avinashbarnwal#1.

Note: as part of this PR, the Metric class became a subclass of the Configurable interface.

src/common/survival_util.cc Outdated Show resolved Hide resolved
src/common/survival_util.cc Outdated Show resolved Hide resolved
@hcho3 hcho3 changed the title Survival analysis1 Add Accelerated Failure Time loss for survival analysis task Aug 12, 2019
@hcho3 hcho3 self-assigned this Aug 12, 2019
@hcho3 hcho3 requested review from trivialfis and hcho3 and removed request for trivialfis August 12, 2019 18:38
src/common/survival_util.cc Outdated Show resolved Hide resolved
@hcho3
Copy link
Collaborator

hcho3 commented Aug 12, 2019

@avinashbarnwal In the PR description, can you add a short one-paragraph description of what survival analysis is? Something like:

survival analysis is a new kind of learning task where we would like to predict a time to certain event. The time-to-event labels are often censored, i.e. we only know which intervals the label falls in and do not know its exact value. See https://eng.uber.com/modeling-censored-time-to-event-data-using-pyro/ for a real-world example.

src/common/survival_util.cc Outdated Show resolved Hide resolved
@tdhock
Copy link

tdhock commented Aug 12, 2019

survival analysis is a new kind of learning task where we would like to predict a time to certain event

I would describe it as "censored regression" or more specifically "regression with censored outputs" because the goal is still to learn a (real-valued) regression function; this emphasizes the similarity with usual regression, where all outputs are un-censored.

@hcho3
Copy link
Collaborator

hcho3 commented Aug 12, 2019

Also add a short Python example to the description:

dtrain = xgboost.DMatrix(X)
dtrain.set_float_info("label_lower_bound", y_lower)
dtrain.set_float_info("label_upper_bound", y_higher)
    
dtest = xgboost.DMatrix(X_test)
dtest.set_float_info("label_lower_bound", y_lower_test)
dtest.set_float_info("label_upper_bound", y_higher_test)
    
bst = xgboost.train(params, dtrain, num_boost_round=100,
                    evals=[(dtrain,"train"), (dtest,"test")])

@hcho3
Copy link
Collaborator

hcho3 commented Aug 12, 2019

@tdhock Thanks for your suggestion. Yes, "censored regression" sounds reasonable.

src/common/survival_util.cc Outdated Show resolved Hide resolved
@trivialfis
Copy link
Member

trivialfis commented Aug 14, 2019

I'm not familiar with survival models, just skimmed through the survey. Are there other recommended materials concentrating on theoretical part? ;-)

@avinashbarnwal
Copy link
Contributor Author

Hi @trivialfis,

Please find the good lecture notes for learning survival modeling - https://www4.stat.ncsu.edu/~dzhang2/st745/index.html.

One of the motivating books- https://www.amazon.com/Applied-Survival-Analysis-Time-Event/dp/0471754994.

Prof. @tdhock and @hcho3 might give a better reference for understanding theoretical survival modeling.

@tdhock
Copy link

tdhock commented Aug 14, 2019

would be good if @avinashbarnwal could write a latex/PDF vignette in the xgboost R pkg describing the loss functions that he implemented

@tdhock
Copy link

tdhock commented Aug 14, 2019

they are the same as in R's survival::survreg, there are some docs on that man page, but the math formulas come from http://members.cbio.mines-paristech.fr/~thocking/survival.pdf

@trivialfis
Copy link
Member

@avinashbarnwal @tdhock Thanks for the good references. Will try to catch up.

@avinashbarnwal
Copy link
Contributor Author

Hi Prof. @tdhock and @hcho3,

I will start writing loss functions in latex/PDF vignette for the xgboost R pkg.

@avinashbarnwal
Copy link
Contributor Author

Hi Prof. @tdhock,

Please let me know if it is fine to make the vignette-like this https://cran.r-project.org/web/packages/xgboost/vignettes/xgboostfromJSON.html.

@tdhock
Copy link

tdhock commented Aug 15, 2019

typically for vignettes with lots of math I prefer writing Rnw source which is rendered to tex / pdf. It is possible to include simple math in Rmd which is rendered on a web page using mathjax, but in my experience complex equations (e.g. optimization problems) do not render well on web pages.

Examples of both are here: https://github.com/tdhock/PeakSegDisk/tree/master/vignettes

include/xgboost/data.h Outdated Show resolved Hide resolved
@avinashbarnwal
Copy link
Contributor Author

Hi Prof. @tdhock and @hcho3,

Please find R-vignette below and let me know your thoughts.
http://rpubs.com/avinashbarnwal123/aft

@hcho3
Copy link
Collaborator

hcho3 commented Aug 26, 2019

@tdhock Do the datasets follow log-normal AFT distribution? The errors are not decreasing when we choose log-logistic and log-weibull. See http://rpubs.com/avinashbarnwal123/aft

@tdhock
Copy link

tdhock commented Aug 26, 2019 via email

@avinashbarnwal
Copy link
Contributor Author

avinashbarnwal commented Aug 26, 2019

Hi Prof. @tdhock and @hcho3,

I have updated the vignette - http://rpubs.com/avinashbarnwal123/aft. It works for the last dataset -
H3K36me3_AM_immune. Please check last fold. This might be not clear because of the scale. It works for both Logistic and Extreme. I think we need datasets like that where it works.

@avinashbarnwal
Copy link
Contributor Author

@avinashbarnwal I fixed the bug. See commit e33fab1. Also note that you don't need to add @10,normal suffix to the metric name aft-nloglik.

Thanks. I will change the code accordingly for our paper.

@hcho3
Copy link
Collaborator

hcho3 commented Mar 21, 2020

@trivialfis I added a demo, as you requested. A tutorial is available. Feel free to try it out.

@hcho3
Copy link
Collaborator

hcho3 commented Mar 21, 2020

Rendered output of the tutorial:
Cap 2020-03-20 17-31-49-229

{ 0.0384f, 0.0624f, 0.0997f, 0.1551f, 0.2316f, 0.3254f, 0.4200f, 0.4861f, 0.4962f, 0.4457f,
0.3567f, 0.2601f, 0.1772f, 0.1152f, 0.0726f, 0.0449f, 0.0275f, 0.0167f, 0.0101f, 0.0061f });
CheckGPairOverGridPoints(obj.get(), 100.0f, 100.0f, "extreme",
{ -0.0000f, -29.0026f, -17.0031f, -9.8028f, -5.4822f, -2.8897f, -1.3340f, -0.4005f, 0.1596f,
Copy link
Collaborator

@hcho3 hcho3 Mar 21, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@avinashbarnwal FYI, I applied the regularization scheme to the uncensored case as well, and now I'm getting a zero gradient here, where previously we'd get something like -50.0. I'm still looking at ways to avoid INF and NAN (in general) without strange behavior like this. For this example, clamping the gradient to a reasonable quantity like -30.0 would be a lot better than giving 0.0. I'll come back to this soon.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks

Copy link
Collaborator

@hcho3 hcho3 Mar 25, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I’ve merged this PR for now. I’ll file a follow-up PR to make AFT more robust in edge cases like this.

@hcho3 hcho3 removed the status: WIP label Mar 25, 2020
@hcho3
Copy link
Collaborator

hcho3 commented Mar 25, 2020

I added a toy example to visualize how XGBoost responds to censored labels:

aft_viz_demo

Copy link
Member

@trivialfis trivialfis left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM! This is exciting.

@hcho3 hcho3 merged commit dcf4399 into dmlc:master Mar 25, 2020
@hcho3
Copy link
Collaborator

hcho3 commented Mar 25, 2020

Merged. Thanks everyone!

Sign up for free to subscribe to this conversation on GitHub. Already have an account? Sign in.
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

8 participants