Skip to content

Store estimated models for nuisance parameters #169

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 12 commits into from
Nov 10, 2022
Merged

Store estimated models for nuisance parameters #169

merged 12 commits into from
Nov 10, 2022

Conversation

MalteKurz
Copy link
Member

@MalteKurz MalteKurz commented Oct 10, 2022

Description

This PR implements the often requested feature to store the estimated models for nuisance parameters. To use it, call the method fit() with option store_models=True. Example:

library(DoubleML)
library(mlr3)
library(mlr3learners)
library(data.table)
set.seed(2)
ml_g = lrn("regr.ranger", num.trees = 10, max.depth = 2)
ml_m = ml_g$clone()
obj_dml_data = make_plr_CCDDHNR2018(alpha = 0.5)
dml_plr_obj = DoubleMLPLR$new(obj_dml_data, ml_g, ml_m)
dml_plr_obj$fit(store_models=TRUE)

The estimated models can then be found in the attribute dml_plr_obj$models:

dml_plr_obj$models
$ml_l
$ml_l$d
$ml_l$d[[1]]
$ml_l$d[[1]][[1]]
<LearnerRegrRanger:regr.ranger>
* Model: ranger
* Parameters: num.threads=1, num.trees=10, max.depth=2
* Packages: mlr3, mlr3learners, ranger
* Predict Types:  [response], se
* Feature Types: logical, integer, numeric, character, factor, ordered
* Properties: hotstart_backward, importance, oob_error, weights

$ml_l$d[[1]][[2]]
<LearnerRegrRanger:regr.ranger>
* Model: ranger
* Parameters: num.threads=1, num.trees=10, max.depth=2
* Packages: mlr3, mlr3learners, ranger
* Predict Types:  [response], se
* Feature Types: logical, integer, numeric, character, factor, ordered
* Properties: hotstart_backward, importance, oob_error, weights

$ml_l$d[[1]][[3]]
<LearnerRegrRanger:regr.ranger>
* Model: ranger
* Parameters: num.threads=1, num.trees=10, max.depth=2
* Packages: mlr3, mlr3learners, ranger
* Predict Types:  [response], se
* Feature Types: logical, integer, numeric, character, factor, ordered
* Properties: hotstart_backward, importance, oob_error, weights

$ml_l$d[[1]][[4]]
<LearnerRegrRanger:regr.ranger>
* Model: ranger
* Parameters: num.threads=1, num.trees=10, max.depth=2
* Packages: mlr3, mlr3learners, ranger
* Predict Types:  [response], se
* Feature Types: logical, integer, numeric, character, factor, ordered
* Properties: hotstart_backward, importance, oob_error, weights

$ml_l$d[[1]][[5]]
<LearnerRegrRanger:regr.ranger>
* Model: ranger
* Parameters: num.threads=1, num.trees=10, max.depth=2
* Packages: mlr3, mlr3learners, ranger
* Predict Types:  [response], se
* Feature Types: logical, integer, numeric, character, factor, ordered
* Properties: hotstart_backward, importance, oob_error, weights




$ml_m
$ml_m$d
$ml_m$d[[1]]
$ml_m$d[[1]][[1]]
<LearnerRegrRanger:regr.ranger>
* Model: ranger
* Parameters: num.threads=1, num.trees=10, max.depth=2
* Packages: mlr3, mlr3learners, ranger
* Predict Types:  [response], se
* Feature Types: logical, integer, numeric, character, factor, ordered
* Properties: hotstart_backward, importance, oob_error, weights

$ml_m$d[[1]][[2]]
<LearnerRegrRanger:regr.ranger>
* Model: ranger
* Parameters: num.threads=1, num.trees=10, max.depth=2
* Packages: mlr3, mlr3learners, ranger
* Predict Types:  [response], se
* Feature Types: logical, integer, numeric, character, factor, ordered
* Properties: hotstart_backward, importance, oob_error, weights

$ml_m$d[[1]][[3]]
<LearnerRegrRanger:regr.ranger>
* Model: ranger
* Parameters: num.threads=1, num.trees=10, max.depth=2
* Packages: mlr3, mlr3learners, ranger
* Predict Types:  [response], se
* Feature Types: logical, integer, numeric, character, factor, ordered
* Properties: hotstart_backward, importance, oob_error, weights

$ml_m$d[[1]][[4]]
<LearnerRegrRanger:regr.ranger>
* Model: ranger
* Parameters: num.threads=1, num.trees=10, max.depth=2
* Packages: mlr3, mlr3learners, ranger
* Predict Types:  [response], se
* Feature Types: logical, integer, numeric, character, factor, ordered
* Properties: hotstart_backward, importance, oob_error, weights

$ml_m$d[[1]][[5]]
<LearnerRegrRanger:regr.ranger>
* Model: ranger
* Parameters: num.threads=1, num.trees=10, max.depth=2
* Packages: mlr3, mlr3learners, ranger
* Predict Types:  [response], se
* Feature Types: logical, integer, numeric, character, factor, ordered
* Properties: hotstart_backward, importance, oob_error, weights

Note that the number of fitted models depends on the settings and the considered model. The outer named list contains one entry for each nuisance part (here ml_l and ml_m). For each nuisance part there is a named list containing an entry for each treatment variable (here only 'd'). The next inner part is a list of length n_rep (repeated cross-fitting) and then a list of length n_folds (number of folds per repeated cross fit).

PR Checklist

  • The title of the pull request summarizes the changes made.
  • The PR contains a detailed description of all changes and additions.
  • The code passes R CMD check and all (unit) tests (see our contributing guidelines for details).
  • Enhancements or new feature are equipped with unit tests.
  • The changes adhere to the "mlr-style" standards (see our contributing guidelines for details).

Copy link
Member

@PhilippBach PhilippBach left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @MalteKurz for preparing the PR. As you indicate in the check boxes of the PR, there are still the tests missing; Either I can do them in the next days myself or I'll add my review to them later (in case you're faster 😃 ). Thanks!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants