Getting survival probabilities for specific timepoints and generating survival curves #273

Closed
bblodfon opened this issue Apr 7, 2022 · 8 comments

@bblodfon
Collaborator

bblodfon commented Apr 7, 2022

Hi Raphael,

Could you have a look at the code below? I looked at some examples from the mlr3 book and the documentation, and I am sure you have already implemented what I want to get. I just want to make sure that what I am doing is correct and to ask whether there is a better way to do some of these things (tips are welcome :)). I may have found some potential issues along the way, so here it is:

library(mlr3verse)
#> Loading required package: mlr3

task = tsk('rats')
task$select(c("litter", "rx")) # leave out factor column `sex`
task_test = task$clone(deep = TRUE)
task$filter(1:295)
task_test$filter(296:300) # I did this task split because I can't specify
# `row_ids` in the pipe operator below when training and predicting, but let me
# know if you know of an easier way

# `glmnet` does not compute `distr` on its own so we use the `distrcompositor`
glmnet_pipe = ppl(
  "distrcompositor",
  learner = lrn("surv.glmnet"),
  estimator = "kaplan",
  form = "aft"
)

# Example from the book (https://mlr3book.mlr-org.com/special-tasks.html#composition)
# doesn't work => `graph_learner = TRUE` is the cause (maybe not a bug)
glm.distr = ppl("distrcompositor", learner = lrn("surv.glmnet"),
  estimator = "kaplan", form = "ph", overwrite = FALSE, graph_learner = TRUE)
#> Error: Learner 'surv.glmnet' does not support predict type 'distr'

# can't do this!
glmnet_pipe$train(task)$predict(task_test)
#> Error in eval(expr, envir, enclos): attempt to apply non-function

# so we split it
glmnet_pipe$train(task) # output may seem worrying...
#> $compose_distr.output
#> NULL
pred = glmnet_pipe$predict(task_test)
#> Warning: Multiple lambdas have been fit. Lambda will be set to 0.01 (see parameter 's').
#> This happened PipeOp surv.glmnet's $predict()

# There is no documentation for this I think, but the following matrix has
# for every unique time point in the training dataset (columns), the survival
# probability for each of the test rats (rows), right?
pred$compose_distr.output$data$distr[,1:16]
#>      23 32 34 39 40 45 49 50 51 53 54 55       61       62       63        64
#> [1,]  1  1  1  1  1  1  1  1  1  1  1  1 0.996587 0.996587 0.996587 0.9931741
#> [2,]  1  1  1  1  1  1  1  1  1  1  1  1 0.996587 0.996587 0.996587 0.9931741
#> [3,]  1  1  1  1  1  1  1  1  1  1  1  1 1.000000 1.000000 1.000000 1.0000000
#> [4,]  1  1  1  1  1  1  1  1  1  1  1  1 0.996587 0.996587 0.996587 0.9931741
#> [5,]  1  1  1  1  1  1  1  1  1  1  1  1 0.996587 0.996587 0.996587 0.9931741
all(task$unique_times() == colnames(pred$compose_distr.output$data$distr))
#> [1] TRUE

# So, if I am correct, the above result is only for the times in the train
# dataset, not in between or later timepoints that I might want. So, let's
# work with the following:
pred$compose_distr.output$distr
#> Matdist()

# Get the mean survival probabilities for each test rat? (a (5x1) result)
pred$compose_distr.output$distr$mean() # Inf?!
#> [1] Inf Inf Inf Inf Inf

# specific timepoints for which I want survival probability for the test rats
times = c(1,10,42,70,120)
pred$compose_distr.output$distr$mean(times) # Inf!? I was expecting a (5x5) result?
#> [1] Inf Inf Inf Inf Inf
pred$compose_distr.output$distr$median() # NA!?
#> [1] NA NA NA NA NA
pred$compose_distr.output$distr$cumHazard(times) # seems okay
#>      1 10 42         70         120
#> [1,] 0  0  0 0.01030358 0.031693993
#> [2,] 0  0  0 0.01030358 0.031693993
#> [3,] 0  0  0 0.00000000 0.003418807
#> [4,] 0  0  0 0.01030358 0.027969594
#> [5,] 0  0  0 0.01030358 0.027969594

# Are the following the survival probabilities at the specific timepoints?
pred$compose_distr.output$distr$survival(times) # seems okay
#>      1 10 42        70       120
#> [1,] 1  1  1 0.9897493 0.9688030
#> [2,] 1  1  1 0.9897493 0.9688030
#> [3,] 1  1  1 1.0000000 0.9965870
#> [4,] 1  1  1 0.9897493 0.9724179
#> [5,] 1  1  1 0.9897493 0.9724179

?Matdist # doesn't include any documentation on survival* and *hazard* methods! (and thus my questions)

##########################################################
# Survival curves for the predicted `task_test` (5 rats) #
##########################################################

learner = glmnet_pipe$pipeops$surv.glmnet$learner_model
class(learner)
#> [1] "LearnerSurvGlmnet" "LearnerSurv"       "Learner"          
#> [4] "R6"
# doesn't work?
plot(x = learner, task, fun = "survival", newdata = task_test$data(), ylim = c(0, 1), xlim = c(0,120))
#> Warning: Multiple lambdas have been fit. Lambda will be set to 0.01 (see
#> parameter 's').
#> Error in UseMethod("as.Distribution"): no applicable method for 'as.Distribution' applied to an object of class "c('double', 'numeric')"

# This seems to work, but why does it go only until time =~ 50?
plot(pred$compose_distr.output$distr, fun = "survival")

# let's try another learner
my_learner = lrn("surv.coxph")
my_learner$train(task)
# `ind` is not used anymore?
plot(my_learner, task, fun = "survival", ind = 10)
#> Warning in plot.window(...): "ind" is not a graphical parameter
#> Warning in plot.xy(xy, type, ...): "ind" is not a graphical parameter
#> Warning in axis(side = side, at = at, labels = labels, ...): "ind" is not a
#> graphical parameter

#> Warning in axis(side = side, at = at, labels = labels, ...): "ind" is not a
#> graphical parameter
#> Warning in box(...): "ind" is not a graphical parameter
#> Warning in title(...): "ind" is not a graphical parameter

plot(my_learner, task, fun = "survival", newdata = task_test$data(),
  xlim = c(0,80), ylim = c(0, 1)) # again only until ~50 but it seems to work

Created on 2022-04-07 by the reprex package (v2.0.1)

For the survival curves, as you mentioned in #253, it would be really cool if we could pipe the distr prediction output to something ggplot2-compliant.
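As a rough illustration of what a ggplot2-friendly format could look like, the sketch below reshapes a small survival matrix (rows are observations, columns are time points, the layout of the matrix printed above) into a long data frame using base R only. The values are copied from the output above for test rats 1 and 3 at times 55, 61 and 64. The helper name `surv_matrix_to_long` is hypothetical and not part of mlr3proba; the plotting step is optional and only runs if ggplot2 is installed.

```r
# Hypothetical helper (not part of mlr3proba):
# wide survival matrix (rows = observations, cols = times) -> long data frame
surv_matrix_to_long <- function(surv_mat) {
  time_points <- as.numeric(colnames(surv_mat))
  data.frame(
    id = rep(seq_len(nrow(surv_mat)), times = ncol(surv_mat)),
    time = rep(time_points, each = nrow(surv_mat)),
    survival = as.vector(surv_mat) # column-major order, matching id/time above
  )
}

# Values copied from the printed matrix (test rats 1 and 3)
surv_mat <- matrix(
  c(1, 0.996587, 0.9931741,
    1, 1,        1),
  nrow = 2, byrow = TRUE,
  dimnames = list(NULL, c("55", "61", "64"))
)
long <- surv_matrix_to_long(surv_mat)

# Optional: one step curve per observation, only if ggplot2 is available
if (requireNamespace("ggplot2", quietly = TRUE)) {
  p <- ggplot2::ggplot(long, ggplot2::aes(x = time, y = survival, group = id)) +
    ggplot2::geom_step()
}
```

With real predictions, the input matrix would have to be extracted from the prediction object first; that step is deliberately left out here because it depends on the installed mlr3proba version.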

@RaphaelS1
Collaborator

Hi, I appreciate there are a lot of questions above and I'm happy to answer them all, but currently it's too difficult to read. Can you please list all your questions one by one so that I can answer them more clearly?

@bblodfon
Collaborator Author

bblodfon commented Apr 8, 2022

Of course, the question became too large in the end! Let me summarize the questions/issues based on the code above so that it is easier for you and others to see them:

  1. Can we use row_ids to train and predict on different rows of the same task when the learner is part of a pipeline, as glmnet_pipe is? Would you do anything differently from cloning the task to task_test and selecting the rows manually?
  2. The example from the survival section of the mlr3 book doesn't work (glm.distr = ppl(..., graph_learner = TRUE))?
  3. Why doesn't glmnet_pipe$train(task)$predict(task_test) work, so that I have to split the train and predict calls?
  4. Is pred$compose_distr.output$data$distr the matrix of predicted survival probabilities (columns as the unique timepoints in the train dataset, rows as the test samples)?
  5. How do I get the mean probability of survival for every test sample? pred$compose_distr.output$distr$mean() outputs Inf? (But it might not be the correct way to do this; the same applies to $median().)
  6. The main question :) => How do I get the survival probability of the test samples at specific timepoints (even outside the range of the training timepoints)? I tried pred$compose_distr.output$distr$survival(times), is that it?
  7. Consider adding documentation on the methods of the Matdist class that match the patterns survival* and *hazard*. I could autocomplete them when typing $ after pred$compose_distr.output$distr, so I know they are implemented; I just don't know if they are exactly what I need!
  8. ?plot.LearnerSurv accepts a trained LearnerSurv and learner = glmnet_pipe$pipeops$surv.glmnet$learner_model is such an object, but plot(learner, ...) fails with Error in UseMethod("as.Distribution"): no applicable method for 'as.Distribution' applied to an object of class "c('double', 'numeric')"? With a surv.coxph learner it works fine.
  9. Why are the survival curves of the 5 test samples plotted only up to time =~ 50 with the following command? plot(pred$compose_distr.output$distr, fun = "survival")
  10. The parameter ind doesn't seem to work anymore? plot(my_learner, task, fun = "survival", ind = 10) (from the examples in ?plot.LearnerSurv)

@RaphaelS1
Collaborator

Can we use row_ids to train and predict on different rows of the same task when the learner is part of a pipeline, as glmnet_pipe is? Would you do anything differently from cloning the task to task_test and selecting the rows manually?

Yup, row_ids should work.

The example from the survival section of the mlr3 book doesn't work (glm.distr = ppl(..., graph_learner = TRUE))?

Thanks, will fix this.

Why doesn't glmnet_pipe$train(task)$predict(task_test) work, so that I have to split the train and predict calls?

Because calling $train() on the pipeline returns a list, I believe, so there is nothing to chain $predict() onto. You can fix this by wrapping it first with as_learner(glmnet_pipe)$train...

Is the pred$compose_distr.output$data$distr the matrix of predicted survival probabilities (columns as the unique timepoints in the train dataset, rows as the test samples)?

Yes, but I've just pushed a change that makes this easier to use (remotes::install_github('mlr-org/mlr3proba')). So I wouldn't recommend working with data$distr directly; use $distr instead. Then you can use functions like $survival to compute the survival probabilities, or, if you want to extract the matrix, you could use $distr$getParameterValue("cdf") (the CDF matrix, i.e. 1 minus the survival probabilities).

How do I get the mean probability of survival for every test sample? pred$compose_distr.output$distr$mean() outputs Inf? (But it might not be the correct way to do this; the same applies to $median().)

Inf is correct. The prediction is improper, i.e. the final survival probability in your prediction does not equal 0 (or even come close to it).
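To make the "improper prediction" point concrete, here is a minimal base-R sketch. It uses survival probabilities for test rat 1 copied from the outputs earlier in this thread and treats each drop in the survival curve as probability mass at that time point. Whatever mass is left over after the last time point has no finite time attached to it, which is why the mean comes out infinite and the median is undefined. This is an illustration of the reasoning under those assumptions, not a description of how distr6 computes these statistics internally.

```r
# Survival probabilities for test rat 1, copied from the outputs in this thread
times <- c(61, 64, 70, 120)
surv  <- c(0.996587, 0.9931741, 0.9897493, 0.9688030)

# Probability mass at each time point = drop in the survival curve
pmf <- -diff(c(1, surv))

# Mass not assigned to any finite time (equals the final survival value)
leftover <- 1 - sum(pmf)

# Any leftover mass sits "at infinity", so the mean survival time is infinite
mean_time <- if (leftover > 0) Inf else sum(times * pmf)

# Median: first time at which survival drops to 0.5 or below (none here, so NA)
median_time <- times[which(surv <= 0.5)[1]]
```

Here about 97% of the probability mass is left over, so both summary statistics are uninformative, which is consistent with the Inf and NA results reported above.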

The main question :) => How do I get the survival probability of the test samples at specific timepoints (even outside the range of the training timepoints)? I tried pred$compose_distr.output$distr$survival(times), is that it?

Yes, that should work, even for timepoints outside the training range.
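For intuition about what evaluating at arbitrary timepoints can mean, the base-R sketch below treats one row of the survival matrix as a right-continuous step function: before the first training time survival is 1, between training times the most recent value is carried forward, and after the last training time the final value is kept. This is an assumption that is consistent with the outputs shown in this thread, not a confirmed description of the Matdist implementation. The helper name `surv_at` is hypothetical, and only a subset of the training times (55 to 64, values for test rat 1) is used.

```r
# Hypothetical helper: evaluate a right-continuous step survival curve
surv_at <- function(new_times, train_times, surv) {
  # 0 means "before the first training time", where survival is taken as 1
  idx <- findInterval(new_times, train_times)
  c(1, surv)[idx + 1]
}

# Subset of training times and survival values for test rat 1 (from the thread)
train_times <- c(55, 61, 62, 63, 64)
surv <- c(1, 0.996587, 0.996587, 0.996587, 0.9931741)

# Times before, between, at, and after the (subset of) training times
res <- surv_at(c(1, 10, 42, 62, 200), train_times, surv)
```

Because only a subset of the training times is used here, the value at time 200 simply repeats the last entry of the subset; with the full prediction, later training times would lower it further.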

Consider adding documentation on the methods of the Matdist class that match the patterns survival* and *hazard*. I could autocomplete them when typing $ after pred$compose_distr.output$distr, so I know they are implemented; I just don't know if they are exactly what I need!

Yeah, so this was a design choice in distr6: all these methods are implemented in a separate decorator class, which means the documentation doesn't show up under the distribution itself. The documentation is identical to cdf, pdf, etc. Autocomplete should work though...

?plot.LearnerSurv accepts a trained LearnerSurv and learner = glmnet_pipe$pipeops$surv.glmnet$learner_model is such an object, but plot(learner, ...) fails with Error in UseMethod("as.Distribution"): no applicable method for 'as.Distribution' applied to an object of class "c('double', 'numeric')"? With a surv.coxph learner it works fine.

Thanks will look into this properly.

Why are the survival curves of the 5 test samples plotted only up to time =~ 50 with the following command? plot(pred$compose_distr.output$distr, fun = "survival")

Looks like a plotting bug, so everything is being plotted but the x-axis labels are wrong. I can fix this.

The parameter ind doesn't seem to work anymore? plot(my_learner, task, fun = "survival", ind = 10) (from the examples in ?plot.LearnerSurv)

Thanks, this is because of an update I made to the package, so I need to fix that in the documentation. Basically, it now plots much more neatly with matplot. So I'll just update the code so that you pass in row_ids, which will have a similar effect to ind.

@bblodfon
Collaborator Author

Thanks so much for the answers! Here is a summary for everyone interested:

  1. Wrapping the pipeline with as_learner() is very useful (you can subset by row_ids, train and predict in one line, and the resulting predictions are more easily accessible):
library(mlr3verse)
#> Loading required package: mlr3
library(dplyr, warn.conflicts = FALSE)

task = tsk('rats')
task$select(c("litter", "rx")) # leave out factor column `sex`

glmnet_lrn = ppl(
  "distrcompositor",
  learner = lrn("surv.glmnet"),
  estimator = "kaplan",
  form = "aft"
) %>% as_learner()

# Now we can do the following:
pred = glmnet_lrn$train(task, row_ids = 1:295)$predict(task, row_ids = 296:300)
#> Warning: Multiple lambdas have been fit. Lambda will be set to 0.01 (see parameter 's').
#> This happened PipeOp surv.glmnet's $predict()
pred
#> <PredictionSurv> for 5 observations:
#>  row_ids time status   crank.1      lp.1     distr
#>      296  104  FALSE 0.4837924 0.4837924 <list[1]>
#>      297   79   TRUE 0.4837924 0.4837924 <list[1]>
#>      298   92  FALSE 1.0932839 1.0932839 <list[1]>
#>      299  104  FALSE 0.4886792 0.4886792 <list[1]>
#>      300  102  FALSE 0.4886792 0.4886792 <list[1]>

Created on 2022-04-11 by the reprex package (v2.0.1)

  2. The following two are the same (matrix of survival probabilities for the test set at the training set's timepoints), so just use the pred$distr result:
all(1 - pred$distr$getParameterValue('cdf') == pred$data$distr)
  3. pred$distr is decorated with ExoticStatistics :) which implements methods such as survival, cumHazard, etc. See the documentation:
?distr6::ExoticStatistics
  4. Getting survival probabilities or the cumulative hazard at specific timepoints is as easy as:
times = c(1,10,100,150)
pred$distr$survival(times)
pred$distr$cumHazard(times)
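As a quick sanity check on how the two outputs relate, the base-R sketch below applies the standard identity H(t) = -log(S(t)) to the survival probabilities printed earlier in this thread for test rat 1 at times 1, 10, 42, 70 and 120, and compares the result with the cumulative hazards printed for the same rat. The numbers are copied from the thread; no mlr3proba or distr6 calls are involved, so this only checks the printed values against each other.

```r
# Values for test rat 1 at times 1, 10, 42, 70, 120 (copied from the thread)
surv_printed      <- c(1, 1, 1, 0.9897493, 0.9688030)
cumhaz_printed    <- c(0, 0, 0, 0.01030358, 0.031693993)

# Standard identity: cumulative hazard is minus the log of survival
cumhaz_from_surv <- -log(surv_printed)

# Largest discrepancy between the derived and the printed cumulative hazards
max_abs_diff <- max(abs(cumhaz_from_surv - cumhaz_printed))
```

The discrepancy stays within the rounding of the printed values, which suggests that $survival() and $cumHazard() are two views of the same predicted curve.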

@RaphaelS1
Collaborator

@bblodfon I am considering removing the decorators from distr6 and just including the survival and cumHazard (and other) functions directly in the distribution, would you have found this easier to work with?

@bblodfon
Collaborator Author

I don't think that it practically makes any difference since with the dollar sign ($) the methods are accessible. What was difficult was to find the documentation for these and see if they are what I thought they were.

A first step is now done with this issue, I think. Maybe it would be a nice idea to add an example to the documentation of distrcompositor, so that users know that the returned distribution is decorated with ExoticStatistics and can look there for the extra methods. That would be super helpful.

@RaphaelS1
Collaborator

Great idea, will do

@bblodfon
Collaborator Author

Hi @RaphaelS1,

I quickly re-checked this issue since you closed it. There were some things that I think still require your attention (or maybe you have already solved them?). It's up to you of course, I just wanted to let you know: see this list, numbers 2 (mlr3 book code issue) and 8-10 (plotting stuff).
