---
output: github_document
---
<!-- README.md is generated from README.Rmd. Please edit that file -->
```{r, include = FALSE}
knitr::opts_chunk$set(
collapse = TRUE,
cache = FALSE,
comment = "#>",
fig.path = "man/figures/README-",
out.width = "100%"
)
library(mlr3)
library(mlr3misc)
set.seed(1)
lgr::get_logger("mlr3")$set_threshold("warn")
```
# mlr3inferr <a href="https://mlr3inferr.mlr-org.com"><img src="man/figures/logo.png" align="right" height="138" alt="mlr3inferr website" /></a>
Methods for statistical **inf**erence on the generalization **err**or.
Package website: [release](https://mlr3inferr.mlr-org.com/) | [dev](https://mlr3inferr.mlr-org.com/dev/)
<!-- badges: start -->
[![RCMD Check](https://github.com/mlr-org/mlr3inferr/actions/workflows/r-cmd-check.yml/badge.svg)](https://github.com/mlr-org/mlr3inferr/actions/workflows/r-cmd-check.yml)
[![CRAN status](https://www.r-pkg.org/badges/version/mlr3inferr)](https://CRAN.R-project.org/package=mlr3inferr)
[![StackOverflow](https://img.shields.io/badge/stackoverflow-mlr3-orange.svg)](https://stackoverflow.com/questions/tagged/mlr3)
[![Mattermost](https://img.shields.io/badge/chat-mattermost-orange.svg)](https://lmmisld-lmu-stats-slds.srv.mwn.de/mlr_invite/)
<!-- badges: end -->
## Installation
```{r, eval = FALSE}
# Install release from CRAN
install.packages("mlr3inferr")
# Install development version from GitHub
pak::pkg_install("mlr-org/mlr3inferr")
```
## What is `mlr3inferr`?
The main purpose of the package is to provide confidence intervals for the generalization error for a number of resampling methods.
Below, we evaluate a decision tree on the sonar task using a holdout resampling and obtain a confidence interval for the generalization error.
This is achieved using the `msr("ci.holdout")` measure, to which we pass another `mlr3::Measure` that determines the loss function.
```{r}
library(mlr3inferr)
rr = resample(tsk("sonar"), lrn("classif.rpart"), rsmp("holdout"))
# 0.05 is also the default
ci = msr("ci.holdout", "classif.acc", alpha = 0.05)
rr$aggregate(ci)
```
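To give some intuition for what such an interval looks like, here is a minimal base-R sketch of a normal-approximation (Wald) interval around the mean of pointwise losses on a holdout set. The helper `wald_ci()` and the loss values are hypothetical and made up for illustration; whether `msr("ci.holdout")` computes exactly this is an assumption, so treat it as a sketch rather than the package's implementation.

```r
# Hypothetical 0/1 losses on a holdout set (made-up numbers, not real results).
losses = c(0, 1, 0, 0, 1, 0, 0, 0, 1, 0)

# Sketch of a Wald interval for the mean pointwise loss:
# point estimate +/- normal quantile * standard error.
wald_ci = function(losses, alpha = 0.05) {
  n = length(losses)
  estimate = mean(losses)
  se = sd(losses) / sqrt(n)
  z = qnorm(1 - alpha / 2)
  c(estimate = estimate, lower = estimate - z * se, upper = estimate + z * se)
}

wald_ci(losses)
```

The sketch only works because the loss can be evaluated per observation, which is why a pointwise loss matters for this kind of interval.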
It is also possible to let the package select the default inference method for a given `Resampling` method using `msr("ci")`:
```{r}
ci_default = msr("ci", "classif.acc")
rr$aggregate(ci_default)
```
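As a further sketch, the same `msr("ci")` call can be applied to a result from a different resampling, here 5-fold cross-validation. The variable names `rr_cv` and `ci_cv` are illustrative, and which inference method is selected for cross-validation is decided by the package; this example assumes that a default exists for `rsmp("cv")`.

```r
library(mlr3)
library(mlr3inferr)

# Resample the same learner and task with 5-fold cross-validation
rr_cv = resample(tsk("sonar"), lrn("classif.rpart"), rsmp("cv", folds = 5))

# Let msr("ci") choose the inference method for this resampling
ci_cv = rr_cv$aggregate(msr("ci", "classif.acc"))
ci_cv
```

The result is expected to contain the point estimate together with the lower and upper bounds of the interval.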
With [`mlr3viz`](https://mlr3viz.mlr-org.com), it is also possible to visualize multiple confidence intervals.
Below, we compare a random forest with a decision tree and a featureless learner:
```{r, dpi = 300, out.width = "70%", fig.align = "center"}
library(mlr3learners)
library(mlr3viz)
bmr = benchmark(benchmark_grid(
tsks(c("sonar", "german_credit")),
lrns(c("classif.rpart", "classif.ranger", "classif.featureless")),
rsmp("subsampling")
))
autoplot(bmr, "ci", msr("ci", "classif.ce"))
```
Note that:
* Some methods require a pointwise loss function, i.e. the measure must have an `$obs_loss` field.
* An inference method does not exist for every resampling method.
* There are combinations of datasets and learners for which inference methods can fail.
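
Regarding the first point, one way to check whether a measure provides a pointwise loss is to inspect its `$obs_loss` field. The helper `has_obs_loss()` below is hypothetical and not part of any package; the sketch assumes an mlr3 version in which measures without a pointwise loss have `$obs_loss` set to `NULL`.

```r
library(mlr3)

# Hypothetical helper: TRUE if the measure defines an observation-wise loss
has_obs_loss = function(measure) !is.null(measure$obs_loss)

# Accuracy is an average of per-observation 0/1 outcomes
has_obs_loss(msr("classif.acc"))

# AUC is computed on the whole prediction set, not per observation
has_obs_loss(msr("classif.auc"))
```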
## Features
* Additional Resampling Methods
* Confidence Intervals for the Generalization Error for some resampling methods
## Inference Methods
```{r, echo = FALSE}
content = as.data.table(mlr3::mlr_measures, objects = TRUE)[startsWith(get("key"), "ci."),]
content$resamplings = map(content$object, function(x) paste0(gsub("Resampling", "", x$resamplings), collapse = ", "))
content[["only pointwise loss"]] = map_chr(content$object, function(object) {
if (get_private(object)$.requires_obs_loss) "yes" else "no"
})
content = content[, c("key", "label", "resamplings", "only pointwise loss")]
knitr::kable(content, format = "markdown", col.names = tools::toTitleCase(names(content)))
```
## Citing mlr3inferr
If you use mlr3inferr, please cite our paper:
@misc{kuempelfischer2024ciforge,
title={Constructing Confidence Intervals for 'the' Generalization Error -- a Comprehensive Benchmark Study},
author={Hannah Schulz-Kümpel and Sebastian Fischer and Thomas Nagler and Anne-Laure Boulesteix and Bernd Bischl and Roman Hornung},
year={2024},
eprint={2409.18836},
archivePrefix={arXiv},
primaryClass={stat.ML},
url={https://arxiv.org/abs/2409.18836},
}
## Bugs, Questions, Feedback
*mlr3inferr* is a free and open source software project that
encourages participation and feedback. If you have any issues,
questions, suggestions or feedback, please do not hesitate to open an
“issue” about it on the GitHub page!
In case of problems / bugs, it is often helpful if you provide a
“minimum working example” that showcases the behaviour (but don’t
worry about this if the bug is obvious).
Please understand that the resources of the project are limited:
response may sometimes be delayed by a few days, and some feature
suggestions may be rejected if they are deemed too tangential to the
vision behind the project.