From c953f32199d26be74e36973a9a14bf1dd9b13d4c Mon Sep 17 00:00:00 2001 From: RaphaelS1 Date: Wed, 1 Nov 2023 13:53:15 +0000 Subject: [PATCH 01/17] fix measure bottlenecks --- DESCRIPTION | 4 ++-- NEWS.md | 4 ++++ R/MeasureSurvDCalibration.R | 16 +++++++++++++--- R/MeasureSurvRCLL.R | 37 +++++++++++++++++++++++++++---------- 4 files changed, 46 insertions(+), 15 deletions(-) diff --git a/DESCRIPTION b/DESCRIPTION index 3fa5a5b71..dbd8d014c 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -1,6 +1,6 @@ Package: mlr3proba Title: Probabilistic Supervised Learning for 'mlr3' -Version: 0.5.3 +Version: 0.5.4 Authors@R: c(person(given = "Raphael", family = "Sonabend", @@ -43,7 +43,7 @@ Depends: Imports: checkmate, data.table, - distr6 (>= 1.8.3), + distr6 (>= 1.8.4), ggplot2, mlr3misc (>= 0.7.0), mlr3viz, diff --git a/NEWS.md b/NEWS.md index 245a29e93..fccf182aa 100644 --- a/NEWS.md +++ b/NEWS.md @@ -1,3 +1,7 @@ +# mlr3proba 0.5.4 + +* Fix bottlenecks in Dcalib and RCLL + # mlr3proba 0.5.3 * Add support for learners that can predict multiple posterior distributions by using `distr6::Arrdist` diff --git a/R/MeasureSurvDCalibration.R b/R/MeasureSurvDCalibration.R index 954548e29..2fab4ef59 100644 --- a/R/MeasureSurvDCalibration.R +++ b/R/MeasureSurvDCalibration.R @@ -64,11 +64,21 @@ MeasureSurvDCalibration = R6Class("MeasureSurvDCalibration", ps = self$param_set$values # initialize buckets bj = numeric(ps$B) + true_times = prediction$truth[, 1L] # predict individual probability of death at observed event time - if (inherits(prediction$distr, "VectorDistribution")) { - si = as.numeric(prediction$distr$survival(data = matrix(prediction$truth[, 1L], nrow = 1L))) + # bypass distr6 construction if possible + if (inherits(prediction$data$distr, "array")) { + si = diag(t(distr6:::C_Vec_WeightedDiscreteCdf(true_times, + as.numeric(colnames(prediction$data$distr)), + t(1 - prediction$data$distr), FALSE, FALSE + ))) } else { - si = diag(prediction$distr$survival(prediction$truth[, 1L])) + 
distr = prediction$distr + if (inherits(distr, "VectorDistribution")) { + si = as.numeric(distr$survival(data = matrix(true_times, nrow = 1L))) + } else { + si = diag(distr$survival(true_times)) + } } # remove zeros si = map_dbl(si, function(.x) max(.x, 1e-5)) diff --git a/R/MeasureSurvRCLL.R b/R/MeasureSurvRCLL.R index b78225317..b7014c1b2 100644 --- a/R/MeasureSurvRCLL.R +++ b/R/MeasureSurvRCLL.R @@ -73,17 +73,34 @@ MeasureSurvRCLL = R6::R6Class("MeasureSurvRCLL", event = truth[, 2] == 1 event_times = truth[event, 1] cens_times = truth[!event, 1] - distr = prediction$distr - if (!any(event)) { # all censored - # survival at outcome time (survived *at least* this long) - out[!event] = diag(as.matrix(distr[!event]$survival(cens_times))) - } else if (all(event)) { # all uncensored - # pdf at outcome time (survived *this* long) - out[event] = diag(as.matrix(distr[event]$pdf(event_times))) - } else { # mix - out[event] = diag(as.matrix(distr[event]$pdf(event_times))) - out[!event] = diag(as.matrix(distr[!event]$survival(cens_times))) + # Bypass distr6 construction if underlying distr represented by array + if (inherits(prediction$data$distr, "array")) { + surv = prediction$data$distr + times = as.numeric(colnames(surv)) + pdf = distr6:::cdfpdf(1 - surv) + if (any(!event)) { + out[!event] = diag(t(distr6:::C_Vec_WeightedDiscreteCdf( + cens_times, times, t(1 - surv), FALSE, FALSE))) + } + if (any(event)) { + out[event] = diag(t(distr6:::C_Vec_WeightedDiscretePdf( + event_times, times, t(pdf)))) + } + } else { + distr = prediction$distr + + # Splitting in this way bypasses unnecessary distr extraction + if (!any(event)) { # all censored + # survival at outcome time (survived *at least* this long) + out = diag(as.matrix(distr$survival(cens_times))) + } else if (all(event)) { # all uncensored + # pdf at outcome time (survived *this* long) + out = diag(as.matrix(distr$pdf(event_times))) + } else { # mix + out[event] = diag(as.matrix(distr[event]$pdf(event_times))) + 
out[!event] = diag(as.matrix(distr[!event]$survival(cens_times))) + } } stopifnot(!any(out == -99L)) # safety check From 54561e277e5abe79ea2441e17ae20112e6593e5b Mon Sep 17 00:00:00 2001 From: john Date: Sat, 11 Nov 2023 00:58:11 +0100 Subject: [PATCH 02/17] filter distribution prediction objects --- R/PredictionDataSurv.R | 15 ++++++++++----- tests/testthat/test_PredictionSurv.R | 16 ++++++++++++++++ 2 files changed, 26 insertions(+), 5 deletions(-) diff --git a/R/PredictionDataSurv.R b/R/PredictionDataSurv.R index b355ec609..0c098e5b9 100644 --- a/R/PredictionDataSurv.R +++ b/R/PredictionDataSurv.R @@ -129,12 +129,17 @@ filter_prediction_data.PredictionDataSurv = function(pdata, row_ids, ...) { } if (!is.null(pdata$distr)) { - if (inherits(pdata$distr, "matrix")) { - pdata$distr = pdata$distr[keep, , drop = FALSE] - } else { # array - pdata$distr = pdata$distr[keep, , , drop = FALSE] - } + distr = pdata$distr + if (testDistribution(distr)) { # distribution + pdata$distr = distr[keep] + } else { + if (length(dim(distr)) == 2) { # 2d matrix + pdata$distr = distr[keep, , drop = FALSE] + } else { # 3d array + pdata$distr = distr[keep, , , drop = FALSE] + } + } } pdata diff --git a/tests/testthat/test_PredictionSurv.R b/tests/testthat/test_PredictionSurv.R index 20872bf08..af75db1e8 100644 --- a/tests/testthat/test_PredictionSurv.R +++ b/tests/testthat/test_PredictionSurv.R @@ -178,19 +178,35 @@ test_that("as_prediction_surv", { test_that("filtering", { p = suppressWarnings(lrn("surv.coxph")$train(task)$predict(task)) p2 = reshape_distr_to_3d(p) # survival array distr + p3 = p$clone() + p4 = p2$clone() + p3$data$distr = p3$distr # Matdist + p4$data$distr = p4$distr # Arrdist p$filter(c(20, 37, 42)) p2$filter(c(20, 37, 42)) + p3$filter(c(20, 37, 42)) + p4$filter(c(20, 37, 42)) expect_prediction_surv(p) expect_prediction_surv(p2) + expect_prediction_surv(p3) + expect_prediction_surv(p4) expect_set_equal(p$data$row_ids, c(20, 37, 42)) expect_set_equal(p2$data$row_ids, 
c(20, 37, 42)) + expect_set_equal(p3$data$row_ids, c(20, 37, 42)) + expect_set_equal(p4$data$row_ids, c(20, 37, 42)) expect_numeric(p$data$crank, any.missing = FALSE, len = 3) expect_numeric(p2$data$crank, any.missing = FALSE, len = 3) + expect_numeric(p3$data$crank, any.missing = FALSE, len = 3) + expect_numeric(p4$data$crank, any.missing = FALSE, len = 3) expect_numeric(p$data$lp, any.missing = FALSE, len = 3) expect_numeric(p2$data$lp, any.missing = FALSE, len = 3) + expect_numeric(p3$data$lp, any.missing = FALSE, len = 3) + expect_numeric(p4$data$lp, any.missing = FALSE, len = 3) expect_matrix(p$data$distr, nrows = 3) expect_array(p2$data$distr, d = 3) expect_equal(nrow(p2$data$distr), 3) + expect_true(inherits(p3$data$distr, "Matdist")) + expect_true(inherits(p4$data$distr, "Arrdist")) }) From f0e6dfb11e639522d97841ebf809b110e34fecef Mon Sep 17 00:00:00 2001 From: john Date: Sat, 11 Nov 2023 23:02:54 +0100 Subject: [PATCH 03/17] fix small issues when filtering predictions + add more tests --- R/PredictionDataSurv.R | 8 ++++++- R/PredictionSurv.R | 4 +++- inst/testthat/helper_expectations.R | 4 ++-- tests/testthat/test_PredictionSurv.R | 34 ++++++++++++++++++++++++++-- 4 files changed, 44 insertions(+), 6 deletions(-) diff --git a/R/PredictionDataSurv.R b/R/PredictionDataSurv.R index 0c098e5b9..fbb54c390 100644 --- a/R/PredictionDataSurv.R +++ b/R/PredictionDataSurv.R @@ -132,7 +132,13 @@ filter_prediction_data.PredictionDataSurv = function(pdata, row_ids, ...) 
{ distr = pdata$distr if (testDistribution(distr)) { # distribution - pdata$distr = distr[keep] + ok = inherits(distr, c("VectorDistribution", "Matdist", "Arrdist")) && + length(keep) > 1 # edge case: Arrdist(1xYxZ) and keep = FALSE + if (ok) { + pdata$distr = distr[keep] # we can subset row/samples like this + } else { + pdata$distr = base::switch(keep, distr) # one distribution only + } } else { if (length(dim(distr)) == 2) { # 2d matrix pdata$distr = distr[keep, , drop = FALSE] diff --git a/R/PredictionSurv.R b/R/PredictionSurv.R index 37c50b1ac..205a0300d 100644 --- a/R/PredictionSurv.R +++ b/R/PredictionSurv.R @@ -171,10 +171,12 @@ PredictionSurv = R6Class("PredictionSurv", } }, .distrify_survarray = function(x) { - if (inherits(x, "array")) { # can be matrix as well + if (inherits(x, "array") && nrow(x) > 0) { # can be matrix as well # create Matdist or Arrdist (default => median curve) distr6::as.Distribution(1 - x, fun = "cdf", decorators = c("CoreStatistics", "ExoticStatistics")) + } else { + NULL } } ) diff --git a/inst/testthat/helper_expectations.R b/inst/testthat/helper_expectations.R index 65177d1bb..ea2ece8a8 100644 --- a/inst/testthat/helper_expectations.R +++ b/inst/testthat/helper_expectations.R @@ -30,8 +30,8 @@ expect_prediction_surv = function(p) { "response", "distr", "lp", "crank")) checkmate::expect_data_table(data.table::as.data.table(p), nrows = length(p$row_ids)) checkmate::expect_atomic_vector(p$missing) - if ("distr" %in% p$predict_types) { - expect_true(class(p$distr)[[1]] %in% c("VectorDistribution", "Matdist", "Arrdist")) + if ("distr" %in% p$predict_types && !is.null(p$distr)) { + expect_true(class(p$distr)[[1]] %in% c("VectorDistribution", "Matdist", "Arrdist", "WeightedDiscrete")) } expect_true(inherits(p, "PredictionSurv")) } diff --git a/tests/testthat/test_PredictionSurv.R b/tests/testthat/test_PredictionSurv.R index af75db1e8..0e12e137d 100644 --- a/tests/testthat/test_PredictionSurv.R +++ 
b/tests/testthat/test_PredictionSurv.R @@ -176,8 +176,8 @@ test_that("as_prediction_surv", { }) test_that("filtering", { - p = suppressWarnings(lrn("surv.coxph")$train(task)$predict(task)) - p2 = reshape_distr_to_3d(p) # survival array distr + p = suppressWarnings(lrn("surv.coxph")$train(task)$predict(task)) # survival matrix + p2 = reshape_distr_to_3d(p) # survival array p3 = p$clone() p4 = p2$clone() p3$data$distr = p3$distr # Matdist @@ -209,4 +209,34 @@ test_that("filtering", { expect_equal(nrow(p2$data$distr), 3) expect_true(inherits(p3$data$distr, "Matdist")) expect_true(inherits(p4$data$distr, "Arrdist")) + + # edge case: filter to 1 observation + p$filter(20) + p2$filter(20) + p3$filter(20) + p4$filter(20) + expect_prediction_surv(p) + expect_prediction_surv(p2) + expect_prediction_surv(p3) + expect_prediction_surv(p4) + expect_matrix(p$data$distr, nrows = 1) + expect_array(p2$data$distr, d = 3) + expect_equal(nrow(p2$data$distr), 1) + expect_true(inherits(p3$data$distr, "WeightedDiscrete")) # from Matdist! + expect_true(inherits(p4$data$distr, "Arrdist")) # remains an Arrdist! + + # filter to 0 observations using non-existent (positive) id + p$filter(42) + p2$filter(42) + p3$filter(42) + p4$filter(42) + + expect_prediction_surv(p) + expect_prediction_surv(p2) + expect_prediction_surv(p3) + expect_prediction_surv(p4) + expect_null(p$distr) + expect_null(p2$distr) + expect_null(p3$distr) + expect_null(p4$distr) }) From c3ad10219955782cf902cb9d2d508e4874b5fcf3 Mon Sep 17 00:00:00 2001 From: john Date: Sat, 11 Nov 2023 23:16:27 +0100 Subject: [PATCH 04/17] small fix --- R/PredictionDataSurv.R | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/R/PredictionDataSurv.R b/R/PredictionDataSurv.R index fbb54c390..acb7ec95c 100644 --- a/R/PredictionDataSurv.R +++ b/R/PredictionDataSurv.R @@ -133,7 +133,7 @@ filter_prediction_data.PredictionDataSurv = function(pdata, row_ids, ...) 
{ if (testDistribution(distr)) { # distribution ok = inherits(distr, c("VectorDistribution", "Matdist", "Arrdist")) && - length(keep) > 1 # edge case: Arrdist(1xYxZ) and keep = FALSE + length(keep) > 1 # e.g.: Arrdist(1xYxZ) and keep = FALSE if (ok) { pdata$distr = distr[keep] # we can subset row/samples like this } else { From 1634967ef5a7ff9a830288275c56ec96b80191b4 Mon Sep 17 00:00:00 2001 From: john Date: Sat, 11 Nov 2023 23:28:25 +0100 Subject: [PATCH 05/17] fix bugs in RCLL when distr is of array type --- R/MeasureSurvRCLL.R | 29 ++++++++++++++++++++++++----- 1 file changed, 24 insertions(+), 5 deletions(-) diff --git a/R/MeasureSurvRCLL.R b/R/MeasureSurvRCLL.R index b7014c1b2..c0d529e9b 100644 --- a/R/MeasureSurvRCLL.R +++ b/R/MeasureSurvRCLL.R @@ -77,15 +77,34 @@ MeasureSurvRCLL = R6::R6Class("MeasureSurvRCLL", # Bypass distr6 construction if underlying distr represented by array if (inherits(prediction$data$distr, "array")) { surv = prediction$data$distr + if (length(dim(surv)) == 3) { + # survival 3d array, extract median + surv = .ext_surv_mat(arr = surv, which.curve = 0.5) + } times = as.numeric(colnames(surv)) - pdf = distr6:::cdfpdf(1 - surv) + if (any(!event)) { - out[!event] = diag(t(distr6:::C_Vec_WeightedDiscreteCdf( - cens_times, times, t(1 - surv), FALSE, FALSE))) + if (sum(!event) == 1) { # fix subsetting issue in case of 1 censored + cdf = t(1 - surv) + } else { + cdf = t(1 - surv[!event, ]) + } + + out[!event] = diag( + distr6:::C_Vec_WeightedDiscreteCdf(cens_times, times, cdf = cdf, FALSE, FALSE) + ) } if (any(event)) { - out[event] = diag(t(distr6:::C_Vec_WeightedDiscretePdf( - event_times, times, t(pdf)))) + pdf = distr6:::cdfpdf(1 - surv) + if (sum(event) == 1) { # fix subsetting issue in case of 1 event + pdf = t(pdf) + } else { + pdf = t(pdf[event, ]) + } + + out[event] = diag( + distr6:::C_Vec_WeightedDiscretePdf(event_times, times, pdf = pdf) + ) } } else { distr = prediction$distr From ca01473462dd29572ba43d622f04a5034f4c051e Mon Sep 
17 00:00:00 2001 From: john Date: Sat, 11 Nov 2023 23:29:01 +0100 Subject: [PATCH 06/17] update RCLL tests --- tests/testthat/test_mlr_measures.R | 44 +++++++++++++++++++++++++----- 1 file changed, 37 insertions(+), 7 deletions(-) diff --git a/tests/testthat/test_mlr_measures.R b/tests/testthat/test_mlr_measures.R index 9503a1bb7..8661b79fd 100644 --- a/tests/testthat/test_mlr_measures.R +++ b/tests/testthat/test_mlr_measures.R @@ -194,11 +194,15 @@ test_that("rcll works", { t = tsk("rats")$filter(sample(1:300, 50)) l = lrn("surv.kaplan") p = l$train(t)$predict(t) + p2 = p$clone() + p2$data$distr = p2$distr # hack: test score via distribution m = msr("surv.rcll") expect_true(m$minimize) expect_equal(m$range, c(0, Inf)) KMscore = p$score(m) expect_numeric(KMscore) + KMscore2 = p2$score(m) + expect_equal(KMscore, KMscore2) status = t$truth()[,2] row_ids = t$row_ids @@ -207,18 +211,44 @@ test_that("rcll works", { # only censored rats in test set p = l$predict(t, row_ids = cens_ids) - expect_numeric(p$score(m)) - expect_numeric(p$filter(row_ids = cens_ids[1])$score(m)) # 1 test rat + score = p$score(m) + expect_numeric(score) + p2 = p$clone() # test score via distribution + p2$data$distr = p2$distr + score2 = p2$score(m) + expect_equal(score, score2) + + # 1 censored test rat + p = p$filter(row_ids = cens_ids[1]) + score = p$score(m) + expect_numeric(score) + p2 = p$clone() # test score via distribution + p2$data$distr = p2$distr + score2 = p2$score(m) + expect_equal(score, score2) # only dead rats in test set p = l$predict(t, row_ids = event_ids) - expect_numeric(p$score(m)) - expect_numeric(p$filter(row_ids = event_ids[1])$score(m)) # 1 test rat + score = p$score(m) + expect_numeric(score) + p2 = p$clone() # test score via distribution + p2$data$distr = p2$distr # Matdist(1xY) + score2 = p2$score(m) + expect_equal(score, score2) + + # 1 dead rat + p = p$filter(row_ids = event_ids[1]) + score = p$score(m) + expect_numeric(score) + p2 = p$clone() # test score via 
distribution + p2$data$distr = p2$distr[1] # WeightDisc + score2 = p2$score(m) + expect_equal(score, score2) # Cox is better than baseline (Kaplan-Meier) - l = lrn("surv.coxph") - p = suppressWarnings(l$train(t)$predict(t)) - expect_true(p$score(m) < KMscore) + l2 = lrn("surv.coxph") + p2 = suppressWarnings(l2$train(t)$predict(t)) + expect_true(p2$score(m) < KMscore) }) test_that("distr measures work with 3d survival array", { From 97e1c79e9050fc1478b3777aae5a3e891aee44d5 Mon Sep 17 00:00:00 2001 From: john Date: Sun, 12 Nov 2023 18:03:59 +0100 Subject: [PATCH 07/17] fixing subsetting issue (another edge case) --- R/MeasureSurvRCLL.R | 4 ++-- tests/testthat/test_mlr_measures.R | 8 ++++++++ 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/R/MeasureSurvRCLL.R b/R/MeasureSurvRCLL.R index c0d529e9b..0a4ffe051 100644 --- a/R/MeasureSurvRCLL.R +++ b/R/MeasureSurvRCLL.R @@ -85,7 +85,7 @@ MeasureSurvRCLL = R6::R6Class("MeasureSurvRCLL", if (any(!event)) { if (sum(!event) == 1) { # fix subsetting issue in case of 1 censored - cdf = t(1 - surv) + cdf = as.matrix(1 - surv[!event, ]) } else { cdf = t(1 - surv[!event, ]) } @@ -97,7 +97,7 @@ MeasureSurvRCLL = R6::R6Class("MeasureSurvRCLL", if (any(event)) { pdf = distr6:::cdfpdf(1 - surv) if (sum(event) == 1) { # fix subsetting issue in case of 1 event - pdf = t(pdf) + pdf = as.matrix(pdf[event, ]) } else { pdf = t(pdf[event, ]) } diff --git a/tests/testthat/test_mlr_measures.R b/tests/testthat/test_mlr_measures.R index 8661b79fd..5fdf26c74 100644 --- a/tests/testthat/test_mlr_measures.R +++ b/tests/testthat/test_mlr_measures.R @@ -249,6 +249,14 @@ test_that("rcll works", { l2 = lrn("surv.coxph") p2 = suppressWarnings(l2$train(t)$predict(t)) expect_true(p2$score(m) < KMscore) + + # Another edge case: some dead rats and 1 only censored + p3 = p2$filter(row_ids = c(event_ids, cens_ids[1])) + score = p3$score(m) + expect_numeric(score) + p3$data$distr = p3$distr + score2 = p3$score(m) + expect_equal(score, score2) }) 
test_that("distr measures work with 3d survival array", { From 3f52b61529e8be8ad8d4c98f86ebc06267984b28 Mon Sep 17 00:00:00 2001 From: john Date: Thu, 16 Nov 2023 13:22:02 +0100 Subject: [PATCH 08/17] small doc fix --- R/MeasureSurvDCalibration.R | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/R/MeasureSurvDCalibration.R b/R/MeasureSurvDCalibration.R index 2fab4ef59..c93f3d59a 100644 --- a/R/MeasureSurvDCalibration.R +++ b/R/MeasureSurvDCalibration.R @@ -16,7 +16,7 @@ #' @details #' This measure can either return the test statistic or the p-value from the `chisq.test`. #' The former is useful for model comparison whereas the latter is useful for determining if a model -#' is well-calibration. If `chisq = FALSE` and `m` is the predicted value then you can manually +#' is well-calibrated. If `chisq = FALSE` and `m` is the predicted value then you can manually #' compute the p.value with `pchisq(m, B - 1, lower.tail = FALSE)`. #' #' NOTE: This measure is still experimental both theoretically and in implementation. Results From c2d31f8b348b981aeccba9ef7086221d635cb7c9 Mon Sep 17 00:00:00 2001 From: john Date: Thu, 16 Nov 2023 13:29:13 +0100 Subject: [PATCH 09/17] small refactoring --- R/MeasureSurvDCalibration.R | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/R/MeasureSurvDCalibration.R b/R/MeasureSurvDCalibration.R index c93f3d59a..2b5ec3214 100644 --- a/R/MeasureSurvDCalibration.R +++ b/R/MeasureSurvDCalibration.R @@ -62,8 +62,10 @@ MeasureSurvDCalibration = R6Class("MeasureSurvDCalibration", private = list( .score = function(prediction, ...) 
{ ps = self$param_set$values + B = ps$B + # initialize buckets - bj = numeric(ps$B) + bj = numeric(B) true_times = prediction$truth[, 1L] # predict individual probability of death at observed event time # bypass distr6 construction if possible @@ -83,7 +85,7 @@ MeasureSurvDCalibration = R6Class("MeasureSurvDCalibration", # remove zeros si = map_dbl(si, function(.x) max(.x, 1e-5)) # index of associated bucket - js = ceiling(ps$B * si) + js = ceiling(B * si) # could remove loop for dead observations but needed for censored ones and minimal overhead # in combining both @@ -95,16 +97,16 @@ MeasureSurvDCalibration = R6Class("MeasureSurvDCalibration", } else { # uncensored observations spread across buckets with most weighting on penultimate for (k in seq.int(ji - 1)) { - bj[k] = bj[k] + 1 / (ps$B * si[[i]]) + bj[k] = bj[k] + 1 / (B * si[[i]]) } - bj[ji] = bj[ji] + (1 - (ji - 1) / (ps$B * si[[i]])) + bj[ji] = bj[ji] + (1 - (ji - 1) / (B * si[[i]])) } } if (ps$chisq) { return(stats::chisq.test(bj)$p.value) } else { - return((ps$B / length(si)) * sum((bj - length(si) / ps$B)^2)) + return((B / length(si)) * sum((bj - length(si) / B)^2)) } } ) From 9423416083c3b240eb574a8bdac972c9497b1d91 Mon Sep 17 00:00:00 2001 From: john Date: Thu, 16 Nov 2023 14:20:35 +0100 Subject: [PATCH 10/17] small comment fix --- R/MeasureSurvDCalibration.R | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/R/MeasureSurvDCalibration.R b/R/MeasureSurvDCalibration.R index 2b5ec3214..11b671105 100644 --- a/R/MeasureSurvDCalibration.R +++ b/R/MeasureSurvDCalibration.R @@ -67,6 +67,7 @@ MeasureSurvDCalibration = R6Class("MeasureSurvDCalibration", # initialize buckets bj = numeric(B) true_times = prediction$truth[, 1L] + # predict individual probability of death at observed event time # bypass distr6 construction if possible if (inherits(prediction$data$distr, "array")) { @@ -95,7 +96,7 @@ MeasureSurvDCalibration = R6Class("MeasureSurvDCalibration", # dead observations contribute 1 to 
their index bj[ji] = bj[ji] + 1 } else { - # uncensored observations spread across buckets with most weighting on penultimate + # censored observations spread across buckets with most weighting on penultimate for (k in seq.int(ji - 1)) { bj[k] = bj[k] + 1 / (B * si[[i]]) } From 4171edd569aff2d364f551047b09b24abd2e25c3 Mon Sep 17 00:00:00 2001 From: john Date: Thu, 16 Nov 2023 16:54:48 +0100 Subject: [PATCH 11/17] better doc for dcalib --- R/MeasureSurvDCalibration.R | 17 ++++++++++------- man/mlr_measures_surv.dcalib.Rd | 19 +++++++++++-------- 2 files changed, 21 insertions(+), 15 deletions(-) diff --git a/R/MeasureSurvDCalibration.R b/R/MeasureSurvDCalibration.R index 11b671105..d6812380e 100644 --- a/R/MeasureSurvDCalibration.R +++ b/R/MeasureSurvDCalibration.R @@ -3,15 +3,18 @@ #' @templateVar fullname MeasureSurvDCalibration #' #' @description -#' This calibration method is defined by calculating +#' This calibration method is defined by calculating the following statistic: #' \deqn{s = B/n \sum_i (P_i - n/B)^2} -#' where \eqn{B} is number of 'buckets', \eqn{n} is the number of predictions, -#' and \eqn{P_i} is the predicted number of deaths in the \eqn{i}th interval -#' [0, 100/B), [100/B, 50/B),....,[(B - 100)/B, 1). +#' where \eqn{B} is number of 'buckets' (that equally divide \eqn{[0,1]} into intervals), +#' \eqn{n} is the number of predictions, and \eqn{P_i} is the observed proportion +#' of observations in the \eqn{i}th interval. An observation is assigned to the +#' \eqn{i}th bucket, if its predicted survival probability at the time of event +#' falls within the corresponding interval. +#' This statistic assumes that censoring time is independent of death time. #' -#' A model is well-calibrated if `s ~ Unif(B)`, tested with `chisq.test` -#' (`p > 0.05` if well-calibrated). -#' Model `i` is better calibrated than model `j` if `s_i < s_j`. 
+#' A model is well-calibrated if \eqn{s \sim Unif(B)}, tested with `chisq.test` +#' (\eqn{p > 0.05} if well-calibrated). +#' Model \eqn{i} is better calibrated than model \eqn{j} if \eqn{s(i) < s(j)}. #' #' @details #' This measure can either return the test statistic or the p-value from the `chisq.test`. diff --git a/man/mlr_measures_surv.dcalib.Rd b/man/mlr_measures_surv.dcalib.Rd index 411931e45..6ed0e928f 100644 --- a/man/mlr_measures_surv.dcalib.Rd +++ b/man/mlr_measures_surv.dcalib.Rd @@ -5,20 +5,23 @@ \alias{MeasureSurvDCalibration} \title{D-Calibration Survival Measure} \description{ -This calibration method is defined by calculating +This calibration method is defined by calculating the following statistic: \deqn{s = B/n \sum_i (P_i - n/B)^2} -where \eqn{B} is number of 'buckets', \eqn{n} is the number of predictions, -and \eqn{P_i} is the predicted number of deaths in the \eqn{i}th interval -[0, 100/B), [100/B, 50/B),....,[(B - 100)/B, 1). +where \eqn{B} is number of 'buckets' (that equally divide \eqn{[0,1]} into intervals), +\eqn{n} is the number of predictions, and \eqn{P_i} is the observed proportion +of observations in the \eqn{i}th interval. An observation is assigned to the +\eqn{i}th bucket, if its predicted survival probability at the time of event +falls within the corresponding interval. +This statistic assumes that censoring time is independent of death time. -A model is well-calibrated if \code{s ~ Unif(B)}, tested with \code{chisq.test} -(\code{p > 0.05} if well-calibrated). -Model \code{i} is better calibrated than model \code{j} if \code{s_i < s_j}. +A model is well-calibrated if \eqn{s \sim Unif(B)}, tested with \code{chisq.test} +(\eqn{p > 0.05} if well-calibrated). +Model \eqn{i} is better calibrated than model \eqn{j} if \eqn{s(i) < s(j)}. } \details{ This measure can either return the test statistic or the p-value from the \code{chisq.test}. 
The former is useful for model comparison whereas the latter is useful for determining if a model -is well-calibration. If \code{chisq = FALSE} and \code{m} is the predicted value then you can manually +is well-calibrated. If \code{chisq = FALSE} and \code{m} is the predicted value then you can manually compute the p.value with \code{pchisq(m, B - 1, lower.tail = FALSE)}. NOTE: This measure is still experimental both theoretically and in implementation. Results From 5a720ed1dd44364fce6bebbe71748a75b105cc50 Mon Sep 17 00:00:00 2001 From: john Date: Thu, 16 Nov 2023 17:08:54 +0100 Subject: [PATCH 12/17] remove unnecessary transpose --- R/MeasureSurvDCalibration.R | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/R/MeasureSurvDCalibration.R b/R/MeasureSurvDCalibration.R index d6812380e..e248b091e 100644 --- a/R/MeasureSurvDCalibration.R +++ b/R/MeasureSurvDCalibration.R @@ -72,12 +72,12 @@ MeasureSurvDCalibration = R6Class("MeasureSurvDCalibration", true_times = prediction$truth[, 1L] # predict individual probability of death at observed event time - # bypass distr6 construction if possible + # bypass distr6 construction if possible if (inherits(prediction$data$distr, "array")) { - si = diag(t(distr6:::C_Vec_WeightedDiscreteCdf(true_times, + si = diag(distr6:::C_Vec_WeightedDiscreteCdf(true_times, as.numeric(colnames(prediction$data$distr)), t(1 - prediction$data$distr), FALSE, FALSE - ))) + )) } else { distr = prediction$distr if (inherits(distr, "VectorDistribution")) { From b61ded81b5bdd67bffb917f73fedd639c930916e Mon Sep 17 00:00:00 2001 From: john Date: Thu, 16 Nov 2023 17:43:00 +0100 Subject: [PATCH 13/17] add tests for dcalib --- tests/testthat/test_mlr_measures.R | 65 ++++++++++++++++++++++++++++++ 1 file changed, 65 insertions(+) diff --git a/tests/testthat/test_mlr_measures.R b/tests/testthat/test_mlr_measures.R index 5fdf26c74..617fc8131 100644 --- a/tests/testthat/test_mlr_measures.R +++ b/tests/testthat/test_mlr_measures.R @@ -259,6 
+259,71 @@ test_that("rcll works", { expect_equal(score, score2) }) +test_that("dcal works", { + set.seed(1) + t = tsk("rats")$filter(sample(1:300, 50)) + l = lrn("surv.coxph") + p = suppressWarnings(l$train(t)$predict(t)) + p2 = p$clone() + p2$data$distr = p2$distr # hack: test score via distribution + m = msr("surv.dcalib") + expect_true(m$minimize) + expect_equal(m$range, c(0, Inf)) + KMscore = p$score(m) + expect_numeric(KMscore) + KMscore2 = p2$score(m) + expect_equal(KMscore, KMscore2) + + status = t$truth()[,2] + row_ids = t$row_ids + cens_ids = row_ids[status == 0] + event_ids = row_ids[status == 1] + + # only censored rats in test set + p = l$predict(t, row_ids = cens_ids) + score = p$score(m) + expect_numeric(score) + p2 = p$clone() # test score via distribution + p2$data$distr = p2$distr + score2 = p2$score(m) + expect_equal(score, score2) + + # 1 censored test rat + p = p$filter(row_ids = cens_ids[1]) + score = p$score(m) + expect_numeric(score) + p2 = p$clone() # test score via distribution + p2$data$distr = p2$distr + score2 = p2$score(m) + expect_equal(score, score2) + + # only dead rats in test set + p = l$predict(t, row_ids = event_ids) + score = p$score(m) + expect_numeric(score) + p2 = p$clone() # test score via distribution + p2$data$distr = p2$distr # Matdist(1xY) + score2 = p2$score(m) + expect_equal(score, score2) + + # 1 dead rat + p = p$filter(row_ids = event_ids[1]) + score = p$score(m) + expect_numeric(score) + p2 = p$clone() # test score via distribution + p2$data$distr = p2$distr[1] # WeightDisc + score2 = p2$score(m) + expect_equal(score, score2) + + # Another edge case: some dead rats and 1 only censored + p3 = p2$filter(row_ids = c(event_ids, cens_ids[1])) + score = p3$score(m) + expect_numeric(score) + p3$data$distr = p3$distr + score2 = p3$score(m) + expect_equal(score, score2) +}) + test_that("distr measures work with 3d survival array", { learner = lrn("surv.kaplan")$train(task) p = learner$predict(task) From 
3288df53500e56a30c417c91d6d2ac35cf3b7e82 Mon Sep 17 00:00:00 2001 From: john Date: Thu, 16 Nov 2023 17:47:15 +0100 Subject: [PATCH 14/17] bug fixes - 3d survival array needs conversion to 2d - 1 observation distribution resulted in NaN score --- R/MeasureSurvDCalibration.R | 19 ++++++++++++------- 1 file changed, 12 insertions(+), 7 deletions(-) diff --git a/R/MeasureSurvDCalibration.R b/R/MeasureSurvDCalibration.R index e248b091e..cfbbd1e48 100644 --- a/R/MeasureSurvDCalibration.R +++ b/R/MeasureSurvDCalibration.R @@ -74,16 +74,21 @@ MeasureSurvDCalibration = R6Class("MeasureSurvDCalibration", # predict individual probability of death at observed event time # bypass distr6 construction if possible if (inherits(prediction$data$distr, "array")) { - si = diag(distr6:::C_Vec_WeightedDiscreteCdf(true_times, - as.numeric(colnames(prediction$data$distr)), - t(1 - prediction$data$distr), FALSE, FALSE - )) + surv = prediction$data$distr + if (length(dim(surv)) == 3) { + # survival 3d array, extract median + surv = .ext_surv_mat(arr = surv, which.curve = 0.5) + } + times = as.numeric(colnames(surv)) + + si = diag(distr6:::C_Vec_WeightedDiscreteCdf(true_times, times, + cdf = t(1 - surv), FALSE, FALSE)) } else { distr = prediction$distr - if (inherits(distr, "VectorDistribution")) { - si = as.numeric(distr$survival(data = matrix(true_times, nrow = 1L))) - } else { + if (inherits(distr, c("Matdist", "Arrdist"))) { si = diag(distr$survival(true_times)) + } else { # VectorDistribution or single Distribution, e.g. 
WeightDisc() + si = as.numeric(distr$survival(data = matrix(true_times, nrow = 1L))) } } # remove zeros From a14aff2e9936962401700b82db406b243a2e24ec Mon Sep 17 00:00:00 2001 From: john Date: Sat, 18 Nov 2023 02:53:48 +0100 Subject: [PATCH 15/17] add truncate parameter --- R/MeasureSurvDCalibration.R | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/R/MeasureSurvDCalibration.R b/R/MeasureSurvDCalibration.R index cfbbd1e48..54ec4aa9d 100644 --- a/R/MeasureSurvDCalibration.R +++ b/R/MeasureSurvDCalibration.R @@ -46,9 +46,10 @@ MeasureSurvDCalibration = R6Class("MeasureSurvDCalibration", initialize = function() { ps = ps( B = p_int(1, default = 10), - chisq = p_lgl(default = FALSE) + chisq = p_lgl(default = FALSE), + truncate = p_dbl(lower = 0, upper = Inf, default = 10) ) - ps$values = list(B = 10L, chisq = FALSE) + ps$values = list(B = 10L, chisq = FALSE, truncate = 10) super$initialize( id = "surv.dcalib", @@ -115,7 +116,7 @@ MeasureSurvDCalibration = R6Class("MeasureSurvDCalibration", if (ps$chisq) { return(stats::chisq.test(bj)$p.value) } else { - return((B / length(si)) * sum((bj - length(si) / B)^2)) + return(min(ps$truncate, (B / length(si)) * sum((bj - length(si) / B)^2))) } } ) From 48b29ea650cade3d9ecacbfc3ebf2823483ba5d4 Mon Sep 17 00:00:00 2001 From: john Date: Sat, 18 Nov 2023 02:54:11 +0100 Subject: [PATCH 16/17] update doc --- R/MeasureSurvDCalibration.R | 27 +++++++++++++++++++-------- man/mlr_measures_surv.dcalib.Rd | 28 ++++++++++++++++++++-------- 2 files changed, 39 insertions(+), 16 deletions(-) diff --git a/R/MeasureSurvDCalibration.R b/R/MeasureSurvDCalibration.R index 54ec4aa9d..e7d2d5c99 100644 --- a/R/MeasureSurvDCalibration.R +++ b/R/MeasureSurvDCalibration.R @@ -14,13 +14,14 @@ #' #' A model is well-calibrated if \eqn{s \sim Unif(B)}, tested with `chisq.test` #' (\eqn{p > 0.05} if well-calibrated). -#' Model \eqn{i} is better calibrated than model \eqn{j} if \eqn{s(i) < s(j)}. 
+#' Model \eqn{i} is better calibrated than model \eqn{j} if \eqn{s(i) < s(j)}, +#' meaning that *lower values* of this measure are preferred. #' #' @details #' This measure can either return the test statistic or the p-value from the `chisq.test`. #' The former is useful for model comparison whereas the latter is useful for determining if a model -#' is well-calibrated. If `chisq = FALSE` and `m` is the predicted value then you can manually -#' compute the p.value with `pchisq(m, B - 1, lower.tail = FALSE)`. +#' is well-calibrated. If `chisq = FALSE` and `s` is the predicted value then you can manually +#' compute the p.value with `pchisq(s, B - 1, lower.tail = FALSE)`. #' #' NOTE: This measure is still experimental both theoretically and in implementation. Results #' should therefore only be taken as an indicator of performance and not for @@ -37,12 +38,22 @@ MeasureSurvDCalibration = R6Class("MeasureSurvDCalibration", public = list( #' @description Creates a new instance of this [R6][R6::R6Class] class. #' @param B (`integer(1)`) \cr - #' Number of buckets to test for uniform predictions over. Default of `10` is recommended by - #' Haider et al. (2020). + #' Number of buckets to test for uniform predictions over. + #' Default of `10` is recommended by Haider et al. (2020). + #' Changing this parameter affects `truncate`. #' @param chisq (`logical(1)`) \cr - #' If `TRUE` returns the p.value of the corresponding chisq.test instead of the measure. - #' Otherwise this can be performed manually with `pchisq(m, B - 1, lower.tail = FALSE)`. - #' `p > 0.05` indicates well-calibrated. + #' If `TRUE` returns the p-value of the corresponding chisq.test instead of the measure. + #' Default is `FALSE` and returns the statistic `s`. + #' You can manually get the p-value by executing `pchisq(s, B - 1, lower.tail = FALSE)`. + #' `p > 0.05` indicates a well-calibrated model. 
+ #' @param truncate (`double(1)`) \cr + #' This parameter controls the upper bound of the output statistic, + #' when `chisq` is `FALSE`. The default `truncate` value of \eqn{10} + #' corresponds to a p-value of 0.35 for the chisq.test using \eqn{B = 10} buckets. + #' Values \eqn{>10} translate to even lower p-values and thus less calibrated + #' models. If the number of buckets \eqn{B} changes, you probably will want to + #' change the `truncate` value as well to correspond to the same p-value significance. + #' Initialize with `truncate = Inf` if no truncation is desired. initialize = function() { ps = ps( B = p_int(1, default = 10), diff --git a/man/mlr_measures_surv.dcalib.Rd b/man/mlr_measures_surv.dcalib.Rd index 6ed0e928f..6694eed71 100644 --- a/man/mlr_measures_surv.dcalib.Rd +++ b/man/mlr_measures_surv.dcalib.Rd @@ -16,13 +16,14 @@ This statistic assumes that censoring time is independent of death time. A model is well-calibrated if \eqn{s \sim Unif(B)}, tested with \code{chisq.test} (\eqn{p > 0.05} if well-calibrated). -Model \eqn{i} is better calibrated than model \eqn{j} if \eqn{s(i) < s(j)}. +Model \eqn{i} is better calibrated than model \eqn{j} if \eqn{s(i) < s(j)}, +meaning that \emph{lower values} of this measure are preferred. } \details{ This measure can either return the test statistic or the p-value from the \code{chisq.test}. The former is useful for model comparison whereas the latter is useful for determining if a model -is well-calibrated. If \code{chisq = FALSE} and \code{m} is the predicted value then you can manually -compute the p.value with \code{pchisq(m, B - 1, lower.tail = FALSE)}. +is well-calibrated. If \code{chisq = FALSE} and \code{s} is the predicted value then you can manually +compute the p.value with \code{pchisq(s, B - 1, lower.tail = FALSE)}. NOTE: This measure is still experimental both theoretically and in implementation. 
Results should therefore only be taken as an indicator of performance and not for @@ -129,13 +130,24 @@ Creates a new instance of this \link[R6:R6Class]{R6} class. \if{html}{\out{<div class="arguments">
}} \describe{ \item{\code{B}}{(\code{integer(1)}) \cr -Number of buckets to test for uniform predictions over. Default of \code{10} is recommended by -Haider et al. (2020).} +Number of buckets to test for uniform predictions over. +Default of \code{10} is recommended by Haider et al. (2020). +Changing this parameter affects \code{truncate}.} \item{\code{chisq}}{(\code{logical(1)}) \cr -If \code{TRUE} returns the p.value of the corresponding chisq.test instead of the measure. -Otherwise this can be performed manually with \code{pchisq(m, B - 1, lower.tail = FALSE)}. -\code{p > 0.05} indicates well-calibrated.} +If \code{TRUE} returns the p-value of the corresponding chisq.test instead of the measure. +Default is \code{FALSE} and returns the statistic \code{s}. +You can manually get the p-value by executing \code{pchisq(s, B - 1, lower.tail = FALSE)}. +\code{p > 0.05} indicates a well-calibrated model.} + +\item{\code{truncate}}{(\code{double(1)}) \cr +This parameter controls the upper bound of the output statistic, +when \code{chisq} is \code{FALSE}. The default \code{truncate} value of \eqn{10} +corresponds to a p-value of 0.35 for the chisq.test using \eqn{B = 10} buckets. +Values \eqn{>10} translate to even lower p-values and thus less calibrated +models. If the number of buckets \eqn{B} changes, you probably will want to +change the \code{truncate} value as well to correspond to the same p-value significance. +Initialize with \code{truncate = Inf} if no truncation is desired.} } \if{html}{\out{</div>
}} } From 148387fff226d6f6aa0d538cbed844891d1ed314 Mon Sep 17 00:00:00 2001 From: john Date: Sat, 18 Nov 2023 02:56:33 +0100 Subject: [PATCH 17/17] fix dcalib test + add truncate test --- tests/testthat/test_mlr_measures.R | 21 ++++++++++++++++----- 1 file changed, 16 insertions(+), 5 deletions(-) diff --git a/tests/testthat/test_mlr_measures.R b/tests/testthat/test_mlr_measures.R index 617fc8131..5f766256c 100644 --- a/tests/testthat/test_mlr_measures.R +++ b/tests/testthat/test_mlr_measures.R @@ -266,9 +266,12 @@ test_that("dcal works", { p = suppressWarnings(l$train(t)$predict(t)) p2 = p$clone() p2$data$distr = p2$distr # hack: test score via distribution - m = msr("surv.dcalib") + m = msr("surv.dcalib", truncate = 20) expect_true(m$minimize) expect_equal(m$range, c(0, Inf)) + expect_equal(m$param_set$values$B, 10) + expect_equal(m$param_set$values$chisq, FALSE) + expect_equal(m$param_set$values$truncate, 20) KMscore = p$score(m) expect_numeric(KMscore) KMscore2 = p2$score(m) @@ -316,12 +319,20 @@ test_that("dcal works", { expect_equal(score, score2) # Another edge case: some dead rats and 1 only censored - p3 = p2$filter(row_ids = c(event_ids, cens_ids[1])) - score = p3$score(m) + p = l$predict(t, row_ids = c(event_ids, cens_ids[1])) + score = p$score(m) expect_numeric(score) - p3$data$distr = p3$distr - score2 = p3$score(m) + p$data$distr = p$distr + score2 = p$score(m) expect_equal(score, score2) + expect_true(score > 10) + + score3 = p$score(msr("surv.dcalib")) # default truncate = 10 + expect_equal(unname(score3), 10) + score4 = p$score(msr("surv.dcalib", truncate = 5)) + expect_equal(unname(score4), 5) + score5 = p$score(msr("surv.dcalib", truncate = Inf, B = 20)) # B affects truncate + expect_true(score5 > score) }) test_that("distr measures work with 3d survival array", {