Add p_max parameter for cindex measure #384

Merged
merged 5 commits into from
May 15, 2024
2 changes: 1 addition & 1 deletion DESCRIPTION
@@ -1,6 +1,6 @@
Package: mlr3proba
Title: Probabilistic Supervised Learning for 'mlr3'
Version: 0.6.1
Version: 0.6.2
Authors@R:
c(person(given = "Raphael",
family = "Sonabend",
6 changes: 6 additions & 0 deletions NEWS.md
@@ -1,3 +1,9 @@
# mlr3proba 0.6.2

* Updates in `surv.cindex` measure
* added `p_max` (same as `surv.graf`)
* refactor `cutoff` to `t_max`

# mlr3proba 0.6.1

* Compatibility with upcoming 'paradox' release.
77 changes: 62 additions & 15 deletions R/MeasureSurvCindex.R
@@ -5,21 +5,25 @@
#' @template param_eps
#'
#' @description
#' Calculates weighted concordance statistics, which, depending on the chosen weighting method
#' and tied times solution, are equivalent to several proposed methods.
#' Calculates weighted concordance statistics, which, depending on the chosen
#' weighting method (`weight_meth`) and tied times parameter (`tiex`), are
#' equivalent to several proposed methods.
#' By default, no weighting is applied and this is equivalent to Harrell's C-index.
#'
#' @details
#' For the Kaplan-Meier estimate of the training survival distribution, `S`, and
#' the Kaplan-Meier estimate of the training censoring distribution, `G`:
#' For the Kaplan-Meier estimate of the **training survival** distribution (\eqn{S}),
#' and the Kaplan-Meier estimate of the **training censoring** distribution (\eqn{G}),
#' we have the following options for time-independent concordance statistics
#' (C-indexes) given the weighted method:
#'
#' `weight_meth`:
#'
#' - `"I"` = No weighting. (Harrell)
#' - `"GH"` = Gonen and Heller's Concordance Index
#' - `"G"` = Weights concordance by G^-1.
#' - `"G2"` = Weights concordance by G^-2. (Uno et al.)
#' - `"SG"` = Weights concordance by S/G (Schemper et al.)
#' - `"S"` = Weights concordance by S (Peto and Peto)
#' - `"G"` = Weights concordance by \eqn{1/G}.
#' - `"G2"` = Weights concordance by \eqn{1/G^2}. (Uno et al.)
#' - `"SG"` = Weights concordance by \eqn{S/G} (Schemper et al.)
#' - `"S"` = Weights concordance by \eqn{S} (Peto and Peto)
#'
#' The last three require training data. `"GH"` is only applicable to [LearnerSurvCoxPH].
#'
@@ -29,8 +33,11 @@
#' computed on the same testing data.
#'
#' @section Parameter details:
#' - `cutoff` (`numeric(1)`)\cr
#' Cut-off time to evaluate concordance up to.
#' - `t_max` (`numeric(1)`)\cr
#' Cutoff time (i.e. time horizon) to evaluate concordance up to.
#' - `p_max` (`numeric(1)`)\cr
#' Proportion of censoring in the given dataset up to which concordance is evaluated.
#' When `t_max` is specified, this parameter is ignored.
#' - `weight_meth` (`character(1)`)\cr
#' Method for weighting concordance. Default `"I"` is Harrell's C. See details.
#' - `tiex` (`numeric(1)`)\cr
@@ -44,14 +51,36 @@
#' @template param_packages
#' @template param_predict_type
#' @template param_measure_properties
#'
#' @examples
#' library(mlr3)
#' task = tsk("rats")
#' learner = lrn("surv.coxph")
#' part = partition(task) # train/test split, stratified on `status` by default
#' learner$train(task, part$train)
#' p = learner$predict(task, part$test)
#'
#' # Harrell's C-index
#' p$score(msr("surv.cindex")) # same as `p$score()`
#'
#' # Uno's C-index
#' p$score(msr("surv.cindex", weight_meth = "G2"),
#' task = task, train_set = part$train)
#'
#' # Harrell's C-index evaluated up to a specific time horizon
#' p$score(msr("surv.cindex", t_max = 97))
#' # Harrell's C-index evaluated up to the time corresponding to 30% of censoring
#' p$score(msr("surv.cindex", p_max = 0.3))
#'
#' @export
MeasureSurvCindex = R6Class("MeasureSurvCindex",
inherit = MeasureSurv,
public = list(
#' @description This is an abstract class that should not be constructed directly.
initialize = function() {
ps = ps(
cutoff = p_dbl(),
t_max = p_dbl(lower = 0),
p_max = p_dbl(0, 1),
weight_meth = p_fct(levels = c("I", "G", "G2", "SG", "S", "GH"), default = "I"),
tiex = p_dbl(0, 1, default = 0.5),
eps = p_dbl(0, 1, default = 1e-3)
@@ -77,16 +106,34 @@ MeasureSurvCindex = R6Class("MeasureSurvCindex",
private = list(
.score = function(prediction, task, train_set, ...) {
ps = self$param_set$values

# calculate t_max (cutoff time horizon)
if (is.null(ps$t_max) && !is.null(ps$p_max)) {
truth = prediction$truth
unique_times = unique(sort(truth[,"time"]))
surv = survival::survfit(truth ~ 1)
indx = which(1 - (surv$n.risk / surv$n) > ps$p_max)
if (length(indx) == 0) {
t_max = NULL # t_max calculated in `cindex()`
} else {
# first time point that surpasses the specified
# `p_max` proportion of censoring
t_max = surv$time[indx[1]]
}
} else {
t_max = ps$t_max
}

if (ps$weight_meth == "GH") {
return(gonen(prediction$crank, ps$tiex))
} else if (ps$weight_meth == "I") {
return(cindex(prediction$truth, prediction$crank, ps$cutoff, ps$weight_meth, ps$tiex))
return(cindex(prediction$truth, prediction$crank, t_max, ps$weight_meth, ps$tiex))
} else {
if (is.null(task) | is.null(train_set)) {
stop("'task' and 'train_set' required for all weighted c-index (except GH).")
stop("'task' and 'train_set' required for all weighted C-indexes (except GH).")
}
return(cindex(prediction$truth, prediction$crank, ps$cutoff, ps$weight_meth,
ps$tiex, task$truth(train_set), ps$eps))
return(cindex(prediction$truth, prediction$crank, t_max, ps$weight_meth,
ps$tiex, task$truth(train_set), ps$eps))
}
}
)
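
The `p_max` branch added to `.score()` above can be restated outside R. The following is a minimal sketch, not the package's code: `t_max_from_p_max` is a hypothetical helper, and it assumes that `n.risk` at a time point counts the observations whose time is at least that value, as `survival::survfit()` reports it. Where the R code leaves `t_max = NULL`, the sketch returns positive infinity, which likewise means no truncation.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <limits>
#include <vector>

// Hypothetical helper (not part of mlr3proba): first observed time at which
// 1 - n_risk / n exceeds p_max, mirroring
//   which(1 - (surv$n.risk / surv$n) > p_max)[1]
// Returns +infinity when the threshold is never exceeded (no truncation).
double t_max_from_p_max(std::vector<double> times, double p_max) {
  assert(p_max >= 0.0 && p_max <= 1.0);
  std::sort(times.begin(), times.end());
  const double n = static_cast<double>(times.size());
  for (std::size_t i = 0; i < times.size(); ++i) {
    // Evaluate each distinct time once, at its first position in sorted order.
    if (i > 0 && times[i] == times[i - 1]) continue;
    // Observations with time >= times[i] are still at risk; the other i
    // observations (events or censored) have already left the risk set.
    const double n_risk = n - static_cast<double>(i);
    if (1.0 - n_risk / n > p_max) return times[i];
  }
  return std::numeric_limits<double>::infinity();
}
```

With the four times 1, 3, 5, 7 and `p_max = 0.5`, the condition first holds at time 7: at time 5 the proportion is exactly 0.5, which does not exceed the threshold, while at time 7 it is 0.75.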
4 changes: 2 additions & 2 deletions R/RcppExports.R
@@ -17,8 +17,8 @@ c_score_graf_schmid <- function(truth, unique_times, cdf, power = 2L) {
.Call(`_mlr3proba_c_weight_survival_score`, score, truth, unique_times, cens, proper, eps)
}

c_concordance <- function(time, status, crank, cutoff, weight_meth, cens, surv, tiex) {
.Call(`_mlr3proba_c_concordance`, time, status, crank, cutoff, weight_meth, cens, surv, tiex)
c_concordance <- function(time, status, crank, t_max, weight_meth, cens, surv, tiex) {
.Call(`_mlr3proba_c_concordance`, time, status, crank, t_max, weight_meth, cens, surv, tiex)
}

c_gonen <- function(crank, tiex) {
8 changes: 4 additions & 4 deletions R/cindex.R
@@ -1,4 +1,4 @@
cindex = function(truth, crank, cutoff = NULL,
cindex = function(truth, crank, t_max = NULL,
weight_meth = c("I", "G", "G2", "SG", "S"),
tiex = 0.5, train = NULL, eps = 1e-3) {

@@ -32,12 +32,12 @@ cindex = function(truth, crank, cutoff = NULL,
surv = matrix(ncol = 2)
}

if (is.null(cutoff)) {
cutoff = max(time) + 1
if (is.null(t_max)) {
t_max = max(time) + 1
}

cens[cens[, 2] == 0, 2] = eps
surv[surv[, 2] == 0, 2] = eps

c_concordance(time, status, crank[ord], cutoff, weight_meth, cens, surv, tiex)
c_concordance(time, status, crank[ord], t_max, weight_meth, cens, surv, tiex)
}
1 change: 0 additions & 1 deletion R/integrated_scores.R
@@ -8,7 +8,6 @@ score_graf_schmid = function(true_times, unique_times, cdf, power = 2) {
c_score_graf_schmid(true_times, unique_times, cdf, power)
}


weighted_survival_score = function(loss, truth, distribution, times = NULL,
t_max = NULL, p_max = NULL, proper, train = NULL, eps, ...) {
assert_surv(truth)
51 changes: 40 additions & 11 deletions man/mlr_measures_surv.cindex.Rd

Generated file; diff not rendered.

8 changes: 4 additions & 4 deletions src/RcppExports.cpp
@@ -67,20 +67,20 @@ BEGIN_RCPP
END_RCPP
}
// c_concordance
float c_concordance(NumericVector time, NumericVector status, NumericVector crank, double cutoff, std::string weight_meth, NumericMatrix cens, NumericMatrix surv, float tiex);
RcppExport SEXP _mlr3proba_c_concordance(SEXP timeSEXP, SEXP statusSEXP, SEXP crankSEXP, SEXP cutoffSEXP, SEXP weight_methSEXP, SEXP censSEXP, SEXP survSEXP, SEXP tiexSEXP) {
float c_concordance(NumericVector time, NumericVector status, NumericVector crank, double t_max, std::string weight_meth, NumericMatrix cens, NumericMatrix surv, float tiex);
RcppExport SEXP _mlr3proba_c_concordance(SEXP timeSEXP, SEXP statusSEXP, SEXP crankSEXP, SEXP t_maxSEXP, SEXP weight_methSEXP, SEXP censSEXP, SEXP survSEXP, SEXP tiexSEXP) {
BEGIN_RCPP
Rcpp::RObject rcpp_result_gen;
Rcpp::RNGScope rcpp_rngScope_gen;
Rcpp::traits::input_parameter< NumericVector >::type time(timeSEXP);
Rcpp::traits::input_parameter< NumericVector >::type status(statusSEXP);
Rcpp::traits::input_parameter< NumericVector >::type crank(crankSEXP);
Rcpp::traits::input_parameter< double >::type cutoff(cutoffSEXP);
Rcpp::traits::input_parameter< double >::type t_max(t_maxSEXP);
Rcpp::traits::input_parameter< std::string >::type weight_meth(weight_methSEXP);
Rcpp::traits::input_parameter< NumericMatrix >::type cens(censSEXP);
Rcpp::traits::input_parameter< NumericMatrix >::type surv(survSEXP);
Rcpp::traits::input_parameter< float >::type tiex(tiexSEXP);
rcpp_result_gen = Rcpp::wrap(c_concordance(time, status, crank, cutoff, weight_meth, cens, surv, tiex));
rcpp_result_gen = Rcpp::wrap(c_concordance(time, status, crank, t_max, weight_meth, cens, surv, tiex));
return rcpp_result_gen;
END_RCPP
}
4 changes: 2 additions & 2 deletions src/survival_scores.cpp
@@ -149,7 +149,7 @@ NumericMatrix c_weight_survival_score(NumericMatrix score, NumericMatrix truth,

// [[Rcpp::export]]
float c_concordance(NumericVector time, NumericVector status, NumericVector crank,
double cutoff, std::string weight_meth, NumericMatrix cens,
double t_max, std::string weight_meth, NumericMatrix cens,
NumericMatrix surv, float tiex) {
double num = 0;
double den = 0;
@@ -178,7 +178,7 @@ float c_concordance(NumericVector time, NumericVector status, NumericVector cran
weight = -1;
if(status[i] == 1) {
for (int j = i + 1; j < time.length(); j++) {
if (time[i] < time[j] && time[i] < cutoff) {
if (time[i] < time[j] && time[i] < t_max) {
if (weight == -1) {
if (weight_meth == "I") {
weight = 1;
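
The effect of the `time[i] < t_max` condition in the hunk above can be seen in a standalone form. This is a sketch of an unweighted (Harrell-type) concordance only, not the package's `c_concordance`: `truncated_cindex` is a hypothetical name, the weighting methods are left out, the pair loop does not assume sorted input, and returning NaN when no comparable pair exists is an assumption of this sketch.

```cpp
#include <cassert>
#include <cstddef>
#include <limits>
#include <vector>

// Hypothetical sketch: unweighted C-index where a pair (i, j) is comparable
// only if i is an event, time[i] < time[j], and time[i] < t_max.
double truncated_cindex(const std::vector<double>& time,
                        const std::vector<int>& status,
                        const std::vector<double>& crank,
                        double t_max, double tiex = 0.5) {
  assert(time.size() == status.size() && time.size() == crank.size());
  double num = 0.0;
  double den = 0.0;
  for (std::size_t i = 0; i < time.size(); ++i) {
    if (status[i] != 1) continue;       // earlier member must be an event
    if (!(time[i] < t_max)) continue;   // truncation at the time horizon
    for (std::size_t j = 0; j < time.size(); ++j) {
      if (!(time[i] < time[j])) continue;
      den += 1.0;
      if (crank[i] > crank[j]) {
        num += 1.0;                     // higher risk failed first: concordant
      } else if (crank[i] == crank[j]) {
        num += tiex;                    // tied ranks contribute tiex
      }
    }
  }
  // Assumption of this sketch: NaN signals "no comparable pairs".
  return den > 0.0 ? num / den : std::numeric_limits<double>::quiet_NaN();
}
```

For event times 1, 2, 3, 4 with ranks 4, 3, 1, 2, the only discordant pair is the one starting at time 3. Without truncation the result is 5/6; with `t_max = 2.5` that pair is excluded and the result is 1.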
16 changes: 12 additions & 4 deletions tests/testthat/test_mlr_measures.R
@@ -143,16 +143,24 @@ test_that("t_max, p_max", {
m2 = p$score(msr("surv.graf", t_max = 100))
expect_equal(m1, m2)

s = t$kaplan()

t_max = s$time[which(1 - s$n.risk / s$n > 0.3)[1]]
s = t$kaplan() # KM
t_max = s$time[which(1 - s$n.risk / s$n > 0.3)[1]] # t_max for up to 30% cens

# graf score: t_max and p_max are the same
m1 = p$score(msr("surv.graf", t_max = t_max))
m2 = p$score(msr("surv.graf", p_max = 0.3))
m3 = p$score(msr("surv.graf", p_max = 0.5))
expect_equal(m1, m2)
expect_true(m1 != m3)

p_cox = suppressWarnings(lrn("surv.coxph")$train(t)$predict(t))
c1 = p_cox$score(msr("surv.cindex", t_max = t_max))
c2 = p_cox$score(msr("surv.cindex", p_max = 0.3))
c3 = p_cox$score(msr("surv.cindex", p_max = 0.5))
expect_equal(c1, c2)
expect_true(c1 != c3)
})


test_that("ERV works as expected", {
set.seed(1)
t = tsk("rats")$filter(sample(1:300, 50))