Merge pull request #9 from mlr-org/non-decomposable
Support CIs for non-decomposable measures
sebffischer authored Jan 8, 2025
2 parents fd823c6 + 177eddf commit 15c7771
Showing 25 changed files with 131 additions and 74 deletions.
9 changes: 5 additions & 4 deletions DESCRIPTION
@@ -27,7 +27,8 @@ Imports:
R6,
withr
Suggests:
testthat (>= 3.0.0)
testthat (>= 3.0.0),
rpart
Remotes:
mlr-org/mlr3
Config/testthat/edition: 3
@@ -38,12 +39,12 @@ RoxygenNote: 7.3.2
Collate:
'MeasureAbstractCi.R'
'aaa.R'
'MeasureCI.R'
'MeasureCIConZ.R'
'MeasureCICorT.R'
'MeasureCIHoldout.R'
'MeasureCINaiveCV.R'
'MeasureCi.R'
'MeasureCiNestedCV.R'
'MeasureCINestedCV.R'
'MeasureCIWaldCV.R'
'ResamplingNestedCV.R'
'ResamplingPairedSubsampling.R'
'bibentries.R'
2 changes: 1 addition & 1 deletion NAMESPACE
@@ -5,8 +5,8 @@ export(MeasureCi)
export(MeasureCiConZ)
export(MeasureCiCorrectedT)
export(MeasureCiHoldout)
export(MeasureCiNaiveCV)
export(MeasureCiNestedCV)
export(MeasureCiWaldCV)
export(ResamplingNestedCV)
export(ResamplingPairedSubsampling)
import(checkmate)
39 changes: 28 additions & 11 deletions R/MeasureAbstractCi.R
@@ -18,6 +18,8 @@
#' The measure for which to calculate a confidence interval. Must have `$obs_loss` if `requires_obs_loss` is `TRUE`.
#' @param resamplings (`character()`)\cr
#' To which resampling classes this measure can be applied.
#' @param requires_obs_loss (`logical(1)`)\cr
#' Whether the inference method requires a pointwise loss function.
#' @template param_param_set
#' @template param_packages
#' @template param_label
@@ -28,7 +30,8 @@
#' @section Inheriting:
#' To define a new CI method, inherit from the abstract base class and implement the private method:
#' `ci: function(tbl: data.table, rr: ResampleResult, param_vals: named `list()`) -> numeric(3)`
#' Here, `tbl` contains the columns `loss`, `row_id` and `iteration`, which are the pointwise loss,
#' If `requires_obs_loss` is set to `TRUE`, `tbl` contains the columns `loss`, `row_id` and `iteration`, which are the pointwise loss,
#' the identifier of the observation and the resampling iteration.
#' Otherwise, `tbl` contains the result of `rr$score()` with the name of the loss column set to `"loss"`.
#' It should return a vector containing the `estimate`, `lower` and `upper` boundary in that order.
#'
@@ -49,19 +52,28 @@ MeasureAbstractCi = R6Class("MeasureAbstractCi",
measure = NULL,
#' @description
#' Creates a new instance of this [R6][R6::R6Class] class.
initialize = function(measure = NULL, param_set = ps(), packages = character(), resamplings, label, delta_method = FALSE) {
initialize = function(measure = NULL, param_set = ps(), packages = character(), resamplings, label, delta_method = FALSE,
requires_obs_loss = TRUE) { # nolint
private$.delta_method = assert_flag(delta_method, na.ok = TRUE)
self$measure = if (test_string(measure)) {
msr(measure)
} else {
private$.requires_obs_loss = assert_flag(requires_obs_loss)
if (test_string(measure)) measure = msr(measure)
self$measure = measure

if (private$.requires_obs_loss) {
assert(
check_class(measure, "Measure"),
check_false(inherits(measure, "MeasureCi")),
check_function(measure$obs_loss),
combine = "and",
.var.name = "Argument measure must be a scalar Measure with a pointwise loss function (has $obs_loss field)"
)
measure
} else {
assert(
check_class(measure, "Measure"),
check_false(inherits(measure, "MeasureCi")),
combine = "and",
.var.name = "Argument measure must be a scalar Measure."
)
}

param_set = c(param_set,
@@ -108,10 +120,15 @@ MeasureAbstractCi = R6Class("MeasureAbstractCi",
}

param_vals = self$param_set$get_values()
tbl = rr$obs_loss(self$measure)
names(tbl)[names(tbl) == self$measure$id] = "loss"
tbl = if (private$.requires_obs_loss) {
rr$obs_loss(self$measure)
} else {
rr$score(self$measure)
}
setnames(tbl, self$measure$id, "loss")

ci = private$.ci(tbl, rr, param_vals)
if (!is.null(self$measure$trafo)) {
if (!is.null(self$measure$trafo) && private$.requires_obs_loss) {
ci = private$.trafo(ci)
}
if (param_vals$within_range) {
@@ -121,15 +138,15 @@ MeasureAbstractCi = R6Class("MeasureAbstractCi",
}
),
private = list(
.requires_obs_loss = NULL,
.delta_method = FALSE,
.trafo = function(ci) {
if (!private$.delta_method) {
stopf("Measure '%s' has a trafo, but the CI does handle it", self$measure$id)
stopf("Measure '%s' has a trafo, but the CI does not handle it", self$measure$id)
}
measure = self$measure
# delta-rule
multiplier = measure$trafo$deriv(ci[[1]])
ci[[1]] = measure$trafo$fn(ci[[1]])
halfwidth = (ci[[3]] - ci[[1]])
est_t = measure$trafo$fn(ci[[1]])
ci_t = c(est_t, est_t - halfwidth * multiplier, est_t + halfwidth * multiplier)
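The delta-rule step in `.trafo` can be illustrated outside the package. The following is a minimal base-R sketch under stated assumptions, not the package's implementation: `delta_ci()`, its arguments and the square-root example are hypothetical names chosen for illustration. It assumes a symmetric interval `c(estimate, lower, upper)` on the untransformed scale and an increasing transformation `fn` with derivative `deriv`.

```r
# Hypothetical helper illustrating the delta rule; not part of the package.
# ci:    numeric(3) holding estimate, lower, upper on the untransformed scale
# fn:    transformation applied to the estimate (e.g. sqrt for MSE -> RMSE)
# deriv: derivative of fn (assumed positive, i.e. fn is increasing)
delta_ci = function(ci, fn, deriv) {
  estimate = ci[[1]]
  # half-width of the symmetric interval before transformation
  halfwidth = ci[[3]] - estimate
  # first-order approximation: the half-width scales with fn'(estimate)
  multiplier = deriv(estimate)
  est_t = fn(estimate)
  c(estimate = est_t, lower = est_t - halfwidth * multiplier, upper = est_t + halfwidth * multiplier)
}

# Example: an interval for the MSE mapped to the RMSE scale
ci_mse = c(4, 3, 5)
ci_rmse = delta_ci(ci_mse, fn = sqrt, deriv = function(x) 1 / (2 * sqrt(x)))
print(ci_rmse)
```

Note that the half-width is computed from the untransformed estimate before the estimate itself is transformed, which matches the order of operations in the updated hunk above.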
File renamed without changes.
3 changes: 2 additions & 1 deletion R/MeasureCIConZ.R
@@ -3,6 +3,7 @@
#' @description
#' Conservative-z confidence intervals based on [`ResamplingPairedSubsampling`].
#' Because the variance estimate is obtained using only `n / 2` observations, it tends to be conservative.
#' This inference method can also be applied to non-decomposable losses.
#' @section Parameters:
#' Only those from [`MeasureAbstractCi`].
#' @template param_measure
@@ -22,6 +23,7 @@ MeasureCiConZ = R6Class("MeasureCiConZ",
measure = measure,
resamplings = "ResamplingPairedSubsampling",
label = "Conservative-Z CI",
requires_obs_loss = FALSE,
delta_method = TRUE
)
}
@@ -30,7 +32,6 @@ MeasureCiConZ = R6Class("MeasureCiConZ",
.ci = function(tbl, rr, param_vals) {
repeats_in = rr$resampling$param_set$values$repeats_in
repeats_out = rr$resampling$param_set$values$repeats_out
tbl = tbl[, list(loss = mean(get("loss"))), by = "iteration"]

estimate = tbl[get("iteration") <= repeats_in, mean(get("loss"))]

4 changes: 3 additions & 1 deletion R/MeasureCICorT.R
@@ -4,6 +4,7 @@
#' Corrected-T confidence intervals based on [`ResamplingSubsampling`][mlr3::ResamplingSubsampling].
#' A heuristic factor is applied to correct for the dependence between the iterations.
#' The confidence intervals tend to be liberal.
#' This inference method can also be applied to non-decomposable losses.
#' @section Parameters:
#' Only those from [`MeasureAbstractCi`].
#' @template param_measure
@@ -29,6 +30,7 @@ MeasureCiCorrectedT = R6Class("MeasureCiCorrectedT",
measure = measure,
resamplings = "ResamplingSubsampling",
label = "Corrected-T CI",
requires_obs_loss = FALSE,
delta_method = TRUE
)
}
@@ -45,7 +47,7 @@ MeasureCiCorrectedT = R6Class("MeasureCiCorrectedT",
n2 = n - n1

# the different mu in the rows are the mu_j
mus = tbl[, list(estimate = mean(get("loss"))), by = "iteration"]$estimate
mus = tbl$loss
# the global estimator
estimate = mean(mus)
# The naive SD estimate (does not take correlation between folds into account)
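The hunk above is cut off before the correction is applied. As a rough illustration of how per-iteration scores can be turned into an interval, here is a base-R sketch that assumes the heuristic factor is the corrected resampled t correction of Nadeau and Bengio; the diff does not show the formula, so this is an assumption. `corrected_t_ci()` and its arguments are hypothetical names.

```r
# Hypothetical helper; a sketch of a corrected resampled t-interval,
# not the package's code.
# mus:   per-iteration performance estimates (one value per subsampling repeat)
# n1/n2: training / test set sizes used in each repeat
corrected_t_ci = function(mus, n1, n2, alpha = 0.05) {
  J = length(mus)
  estimate = mean(mus)
  # The naive variance of the mean would be var(mus) / J. The factor
  # (1 / J + n2 / n1) inflates it to account for overlapping training sets.
  se = sqrt((1 / J + n2 / n1) * var(mus))
  q = qt(1 - alpha / 2, df = J - 1)
  c(estimate = estimate, lower = estimate - q * se, upper = estimate + q * se)
}

mus = c(0.20, 0.25, 0.22, 0.30, 0.23)
ci = corrected_t_ci(mus, n1 = 140, n2 = 70)
# naive standard error for comparison (ignores the dependence between repeats)
naive_se = sqrt(var(mus) / length(mus))
print(ci)
```

Because only one score per iteration is needed, a sketch like this works for any aggregated measure, which is consistent with the change to `mus = tbl$loss`.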
1 change: 1 addition & 0 deletions R/MeasureCIHoldout.R
@@ -2,6 +2,7 @@
#' @name mlr_measures_ci_holdout
#' @description
#' Standard holdout CI.
#' This inference method can only be applied to decomposable losses.
#' @section Parameters:
#' Only those from [`MeasureAbstractCi`].
#' @template param_measure
15 changes: 8 additions & 7 deletions R/MeasureCINaiveCV.R → R/MeasureCIWaldCV.R
@@ -1,9 +1,10 @@
#' @title Naive Cross-Validation CI
#' @name mlr_measures_ci_naive_cv
#' @title Cross-Validation CI
#' @name mlr_measures_ci_wald_cv
#' @description
#' Confidence intervals for cross-validation.
#' The method is asymptotically exact for the so-called *Test Error* as defined by Bayle et al. (2020).
#' For the (expected) risk, the confidence intervals tend to be too liberal.
#' This inference method can only be applied to decomposable losses.
#' @section Parameters:
#' Those from [`MeasureAbstractCi`], as well as:
#' * `variance` :: `"all-pairs"` or `"within-fold"`\cr
@@ -13,11 +14,11 @@
#' `r format_bib("bayle2020cross")`
#' @export
#' @examples
#' m_naivecv = msr("ci.naive_cv", "classif.ce")
#' m_naivecv
#' m_waldcv = msr("ci.wald_cv", "classif.ce")
#' m_waldcv
#' rr = resample(tsk("sonar"), lrn("classif.featureless"), rsmp("cv"))
#' rr$aggregate(m_naivecv)
MeasureCiNaiveCV = R6Class("MeasureCiNaiveCV",
#' rr$aggregate(m_waldcv)
MeasureCiWaldCV = R6Class("MeasureCiWaldCV",
inherit = MeasureAbstractCi,
public = list(
#' @description
@@ -60,4 +61,4 @@ MeasureCiNaiveCV = R6Class("MeasureCiNaiveCV",
)

#' @include aaa.R
measures[["ci.naive_cv"]] = list(MeasureCiNaiveCV, .prototype_args = list(measure = "classif.acc"))
measures[["ci.wald_cv"]] = list(MeasureCiWaldCV, .prototype_args = list(measure = "classif.acc"))
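The body of `.ci` for the renamed measure is not shown in this diff. To make the `variance` parameter (`"all-pairs"` or `"within-fold"`) more concrete, the following base-R sketch shows one way a Wald-type interval could be computed from pointwise losses. `wald_cv_ci()` is a hypothetical helper, and the normalization of the two variance estimators is an assumption that may differ from the package.

```r
# Hypothetical helper; sketch of a Wald-type CV interval from pointwise losses.
# loss: pointwise loss of every test observation across all folds
# fold: fold (iteration) in which each observation was in the test set
wald_cv_ci = function(loss, fold, variance = c("all-pairs", "within-fold"), alpha = 0.05) {
  variance = match.arg(variance)
  n = length(loss)
  estimate = mean(loss)
  sigma2 = if (variance == "all-pairs") {
    # spread of all losses around the overall CV estimate
    mean((loss - estimate)^2)
  } else {
    # average of the sample variances computed separately within each fold
    mean(tapply(loss, fold, var))
  }
  se = sqrt(sigma2 / n)
  z = qnorm(1 - alpha / 2)
  c(estimate = estimate, lower = estimate - z * se, upper = estimate + z * se)
}

loss = c(0, 1, 0, 0, 1, 0, 1, 0, 0, 0)
fold = rep(1:5, each = 2)
ci_all = wald_cv_ci(loss, fold, variance = "all-pairs")
ci_within = wald_cv_ci(loss, fold, variance = "within-fold")
print(rbind(ci_all, ci_within))
```

Both variants need one loss value per observation, which is why this method is documented as applicable only to decomposable losses.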
1 change: 1 addition & 0 deletions R/MeasureCiNestedCV.R
@@ -2,6 +2,7 @@
#' @name mlr_measures_ci_ncv
#' @description
#' Confidence Intervals based on [`ResamplingNestedCV`][ResamplingNestedCV], including bias-correction.
#' This inference method can only be applied to decomposable losses.
#' @section Parameters:
#' Those from [`MeasureAbstractCi`], as well as:
#' * `bias` :: `logical(1)`\cr
2 changes: 1 addition & 1 deletion R/zzz.R
@@ -23,7 +23,7 @@ register_mlr3 = function(...) {
mlr_reflections = mlr3::mlr_reflections
mlr_reflections$default_ci_methods = list(
ResamplingHoldout = "ci.holdout",
ResamplingCV = "ci.naive_cv",
ResamplingCV = "ci.wald_cv",
ResamplingSubsampling = "ci.cor_t",
ResamplingPairedSubsampling = "ci.con_z",
ResamplingNestedCV = "ci.ncv"
11 changes: 7 additions & 4 deletions README.Rmd
@@ -82,22 +82,25 @@ autoplot(bmr, "ci", msr("ci", "classif.ce"))

Note that:

* Confidence Intervals can only be obtained for measures that are based on pointwise loss functions, i.e. have an `$obs_loss` field.
* Some methods require measures with pointwise loss functions, i.e. measures that have an `$obs_loss` field.
* An inference method does not exist for every resampling method.
* For some combinations of datasets and learners, inference methods can fail.
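To illustrate what a pointwise loss means here, the base-R sketch below contrasts a decomposable measure (classification error, the mean of per-observation 0/1 losses) with a non-decomposable one (AUC, which depends on the ranking of all observations jointly). The data and variable names are made up for illustration and do not use the package.

```r
# Toy data: true labels (1 = positive) and predicted scores
truth = c(1, 1, 0, 0, 1, 0)
score = c(0.9, 0.8, 0.7, 0.3, 0.6, 0.2)

# Decomposable: each observation has its own loss, and the measure is their mean
pred = as.numeric(score > 0.5)
obs_loss = as.numeric(pred != truth)
ce = mean(obs_loss)

# Non-decomposable: AUC via the rank-sum (Mann-Whitney) formula; there is no
# per-observation loss whose mean gives this value
n_pos = sum(truth == 1)
n_neg = sum(truth == 0)
auc = (sum(rank(score)[truth == 1]) - n_pos * (n_pos + 1) / 2) / (n_pos * n_neg)

print(c(ce = ce, auc = auc))
```

Methods that only need one aggregated score per resampling iteration can therefore handle measures like AUC, while methods that work on `obs_loss` values cannot.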

## Features

* Additional Resampling Methods
* Confidence Intervals for the Generalization Error for some resampling methods
* Confidence Intervals for the Generalization Error for some resampling methods


## Inference Methods

```{r, echo = FALSE}
content = as.data.table(mlr3::mlr_measures, objects = TRUE)[startsWith(get("key"), "ci."),]
content$resamplings = map(content$object, "resamplings")
content = content[, c("key", "label", "resamplings")]
content$resamplings = map(content$object, function(x) paste0(gsub("Resampling", "", x$resamplings), collapse = ", "))
content[["only pointwise loss"]] = map_chr(content$object, function(object) {
if (get_private(object)$.requires_obs_loss) "yes" else "false"
})
content = content[, c("key", "label", "resamplings", "only pointwise loss")]
knitr::kable(content, format = "markdown", col.names = tools::toTitleCase(names(content)))
```

28 changes: 19 additions & 9 deletions README.md
@@ -75,8 +75,8 @@ autoplot(bmr, "ci", msr("ci", "classif.ce"))

Note that:

- Confidence Intervals can only be obtained for measures that are based
on pointwise loss functions, i.e. have an `$obs_loss` field.
- Some methods require measures with pointwise loss functions, i.e.
measures that have an `$obs_loss` field.
- An inference method does not exist for every resampling method.
- For some combinations of datasets and learners, inference methods
can fail.
@@ -89,13 +89,23 @@ Note that:

## Inference Methods

| Key | Label | Resamplings |
|:------------|:------------------|:-----------------------------|
| ci.con_z | Conservative-Z CI | ResamplingPairedSubsampling |
| ci.cor_t | Corrected-T CI | ResamplingSubsampling |
| ci.holdout | Holdout CI | ResamplingHoldout |
| ci.naive_cv | Naive CV CI | ResamplingCV , ResamplingLOO |
| ci.ncv | Nested CV CI | ResamplingNestedCV |
``` r
content = as.data.table(mlr3::mlr_measures, objects = TRUE)[startsWith(get("key"), "ci."),]
content$resamplings = map(content$object, function(x) paste0(gsub("Resampling", "", x$resamplings), collapse = ", "))
content[["only pointwise loss"]] = map_chr(content$object, function(object) {
if (get_private(object)$.requires_obs_loss) "yes" else "false"
})
content = content[, c("key", "label", "resamplings", "only pointwise loss")]
knitr::kable(content, format = "markdown", col.names = tools::toTitleCase(names(content)))
```

| Key | Label | Resamplings | Only Pointwise Loss |
|:------------|:------------------|:------------------|:--------------------|
| ci.con_z | Conservative-Z CI | PairedSubsampling | false |
| ci.cor_t | Corrected-T CI | Subsampling | false |
| ci.holdout | Holdout CI | Holdout | yes |
| ci.wald_cv | Naive CV CI | CV, LOO | yes |
| ci.ncv | Nested CV CI | NestedCV | yes |

## Bugs, Questions, Feedback

9 changes: 7 additions & 2 deletions man/mlr_measures_abstract_ci.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion man/mlr_measures_ci.Rd

1 change: 1 addition & 0 deletions man/mlr_measures_ci_con_z.Rd

1 change: 1 addition & 0 deletions man/mlr_measures_ci_cor_t.Rd

1 change: 1 addition & 0 deletions man/mlr_measures_ci_holdout.Rd

3 changes: 2 additions & 1 deletion man/mlr_measures_ci_ncv.Rd
