review of two-phase sampling functionality #49

Open
wants to merge 4 commits into
base: master
2 changes: 1 addition & 1 deletion DESCRIPTION
@@ -1,6 +1,6 @@
Package: medoutcon
Title: Efficient Natural and Interventional Causal Mediation Analysis
Version: 0.2.2
Version: 0.2.3
Authors@R: c(
person("Nima", "Hejazi", email = "nh@nimahejazi.org",
role = c("aut", "cre", "cph"),
2 changes: 1 addition & 1 deletion LICENSE
@@ -1,2 +1,2 @@
YEAR: 2020-2022
YEAR: 2020-2024
COPYRIGHT HOLDER: Nima S. Hejazi
18 changes: 18 additions & 0 deletions NEWS.md
@@ -1,3 +1,21 @@
# medoutcon 0.2.3

* Added a new named argument `cv_strat` to `est_onestep()` and `est_tml()`
and to the `estimator_args` list-argument in `medoutcon()`, which allows for
stratified folds to be generated for cross-fitting (by passing the outcome to
the `strata_ids` argument of `make_folds()` from the `origami` package).
Stratification is applied only when `cv_strat = TRUE` and the proportion of
cases is at or below `strat_pmin` (default 0.1), a heuristic for rare outcomes.
* Increased the default number of folds for cross-fitting from 5 to 10, setting
`cv_folds = 10L` in named arguments to `est_onestep()` and `est_tml()` and to
the `estimator_args` list-argument in `medoutcon()`.
* Changed default propensity score truncation bounds specified in `g_bounds` to
`c(0.005, 0.995)` from `c(0.01, 0.99)` (in v0.2.2), based on sanity checks and
manual experimentation.
* Wrapped instances of `sl3_Task()` in which `outcome_type = "continuous"` is
specified in `suppressWarnings()` to sink warnings when the outcome variable
for a given nuisance estimation task fails `sl3`'s check for continuous-ness.
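
To illustrate why stratified folds matter for a rare binary outcome, here is a minimal base-R sketch. The helper `assign_folds()` is hypothetical and is not how `origami::make_folds()` is implemented; it only mimics the idea of assigning fold labels within each outcome stratum so that cases are spread evenly across validation folds.

```r
# Hypothetical helper (not part of origami or medoutcon): assign V fold labels,
# optionally within strata so each fold gets a near-equal share of each stratum.
assign_folds <- function(n, V, strata = NULL) {
  if (is.null(strata)) {
    # plain V-fold: shuffle balanced fold labels across all observations
    return(sample(rep_len(seq_len(V), n)))
  }
  folds <- integer(n)
  for (s in unique(strata)) {
    idx <- which(strata == s)
    # shuffle balanced fold labels separately within this stratum
    folds[idx] <- sample(rep_len(seq_len(V), length(idx)))
  }
  folds
}

set.seed(11)
n <- 500
y <- rbinom(n, 1, 0.05) # rare binary outcome, roughly 5% cases

plain <- assign_folds(n, V = 10)
strat <- assign_folds(n, V = 10, strata = y)

# number of cases landing in each validation fold under each scheme
cases_plain <- tapply(y, plain, sum)
cases_strat <- tapply(y, strat, sum)
```

Under the stratified scheme the per-fold case counts differ by at most one, whereas the plain scheme can leave some folds with very few cases by chance.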

# medoutcon 0.2.2

* Change iterative targeting procedures in `est_tml()` to use `glm2::glm2`
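
The change to the `g_bounds` default in the 0.2.3 entry widens the range that estimated propensity scores may take before truncation. A rough sketch of the effect, assuming truncation is a simple clip into the interval; `truncate_ps()` is a hypothetical helper written for this note, not a function from `medoutcon`.

```r
# Hypothetical helper: clip estimated propensity scores into [lower, upper].
truncate_ps <- function(g, bounds = c(0.005, 0.995)) {
  pmin(pmax(g, bounds[1]), bounds[2])
}

# made-up propensity score estimates, including values near 0 and 1
g_hat <- c(0.001, 0.008, 0.5, 0.992, 0.9999)

g_old <- truncate_ps(g_hat, bounds = c(0.01, 0.99)) # previous default bounds
g_new <- truncate_ps(g_hat)                         # new default bounds

# largest inverse-probability weight each choice permits
max_w_old <- max(1 / g_old) # 1 / 0.01  = 100
max_w_new <- max(1 / g_new) # 1 / 0.005 = 200
```

Looser bounds leave more estimates untouched (here `0.008` and `0.992` survive unchanged) but allow inverse weights up to twice as large, which is the usual bias-variance trade-off in truncation.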
104 changes: 81 additions & 23 deletions R/estimators.R
@@ -81,7 +81,7 @@
effect_type = c("interventional", "natural"),
w_names,
m_names,
g_bounds = c(0.01, 0.99)) {
g_bounds = c(0.005, 0.995)) {
# make training and validation data
train_data <- origami::training(data_in)
valid_data <- origami::validation(data_in)
@@ -160,9 +160,11 @@
g_star <- g_out$treat_est_valid$treat_pred_A_star[valid_data$R == 1]
h_prime <- h_out$treat_est_valid$treat_pred_A_prime
g_prime <- g_out$treat_est_valid$treat_pred_A_prime[valid_data$R == 1]
q_prime_Z_one <- q_out$moc_est_valid_Z_one$moc_pred_A_prime[valid_data$R == 1]
q_prime_Z_one <-
q_out$moc_est_valid_Z_one$moc_pred_A_prime[valid_data$R == 1]
r_prime_Z_one <- r_out$moc_est_valid_Z_one$moc_pred_A_prime
q_prime_Z_natural <- q_out$moc_est_valid_Z_natural$moc_pred_A_prime[valid_data$R == 1]
q_prime_Z_natural <-
q_out$moc_est_valid_Z_natural$moc_pred_A_prime[valid_data$R == 1]
r_prime_Z_natural <- r_out$moc_est_valid_Z_natural$moc_pred_A_prime

# need pseudo-outcome regressions with intervention set to a contrast
@@ -208,12 +210,14 @@

# predict u(z, a', w) using intervened data with treatment set A = a'
# NOTE: here, obs_weights should not include two_phase_weights (?)
u_task_valid_z_interv <- sl3::sl3_Task$new(
data = valid_data_z_interv,
weights = "obs_weights",
covariates = c("Z", "A", w_names),
outcome = "U_pseudo",
outcome_type = "continuous"
suppressWarnings(
u_task_valid_z_interv <- sl3::sl3_Task$new(
data = valid_data_z_interv,
weights = "obs_weights",
covariates = c("Z", "A", w_names),
outcome = "U_pseudo",
outcome_type = "continuous"
)
)

# return partial pseudo-outcome for v nuisance regression
@@ -352,7 +356,7 @@
# for each index in R with R == 0, add a zero at the same index in eif
new_eif <- rep(NA, length(R))
eif_idx <- 1
for (idx in seq_len(length(R))) {
for (idx in seq_along(R)) {
if (R[idx] == 1) {
new_eif[idx] <- eif[eif_idx]
eif_idx <- eif_idx + 1
@@ -443,6 +447,16 @@
#' conditions on the one-step estimator to be relaxed. For compatibility with
#' \code{\link[origami]{make_folds}}, this value specified must be greater
#' than or equal to 2; the default is to create 10 folds.
#' @param cv_strat A \code{logical} atomic vector indicating whether V-fold
#' cross-validation should stratify the folds based on the outcome variable.
#' If \code{TRUE}, the folds are stratified by passing the outcome variable to
#' the \code{strata_ids} argument of \code{\link[origami]{make_folds}}, but
#' only when the incidence of the binary outcome variable is at or below the
#' tolerance in \code{strat_pmin}. The default is \code{FALSE}.
#' @param strat_pmin A \code{numeric} atomic vector giving the tolerance for
#' the proportion of cases (for a binary outcome variable) at or below which
#' stratified V-fold cross-validation is used when \code{cv_strat} is set to
#' \code{TRUE}. The default tolerance is 0.1.
#'
#' @importFrom assertthat assert_that
#' @importFrom stats var weighted.mean
@@ -462,18 +476,35 @@
w_names,
m_names,
y_bounds,
g_bounds = c(0.01, 0.99),
g_bounds = c(0.005, 0.995),
effect_type = c("interventional", "natural"),
svy_weights = NULL,
cv_folds = 5L) {
cv_folds = 10L,
cv_strat = FALSE,
strat_pmin = 0.1) {
# make sure that more than one fold is specified
assertthat::assert_that(cv_folds > 1L)

# create cross-validation folds
folds <- origami::make_folds(data,
fold_fun = origami::folds_vfold,
V = cv_folds
)
if (cv_strat && data[, mean(Y) <= strat_pmin]) {
# check that outcome is binary for stratified V-fold cross-validation
assertthat::assert_that(data[, all(unique(Y) %in% c(0, 1))])

# if outcome is binary and rare, use stratified V-fold cross-validation
folds <- origami::make_folds(
data,
fold_fun = origami::folds_vfold,
V = cv_folds,
strata_ids = data$Y
)
} else {
# just use standard V-fold cross-validation
folds <- origami::make_folds(
data,
fold_fun = origami::folds_vfold,
V = cv_folds
)
}

# estimate the EIF on a per-fold basis
cv_eif_results <- origami::cross_validate(
@@ -599,6 +630,16 @@
#' conditions on the TML estimator to be relaxed. Note: for compatibility with
#' \code{\link[origami]{make_folds}}, this value must be greater than or
#' equal to 2; the default is to create 10 folds.
#' @param cv_strat A \code{logical} atomic vector indicating whether V-fold
#' cross-validation should stratify the folds based on the outcome variable.
#' If \code{TRUE}, the folds are stratified by passing the outcome variable to
#' the \code{strata_ids} argument of \code{\link[origami]{make_folds}}, but
#' only when the incidence of the binary outcome variable is at or below the
#' tolerance in \code{strat_pmin}. The default is \code{FALSE}.
#' @param strat_pmin A \code{numeric} atomic vector giving the tolerance for
#' the proportion of cases (for a binary outcome variable) at or below which
#' stratified V-fold cross-validation is used when \code{cv_strat} is set to
#' \code{TRUE}. The default tolerance is 0.1.
#' @param max_iter A \code{numeric} integer giving the maximum number of steps
#' to be taken for the iterative procedure to construct a TML estimator.
#' @param tiltmod_tol A \code{numeric} indicating the maximum step size to be
@@ -626,20 +667,37 @@
w_names,
m_names,
y_bounds,
g_bounds = c(0.01, 0.99),
g_bounds = c(0.005, 0.995),
effect_type = c("interventional", "natural"),
svy_weights = NULL,
cv_folds = 5L,
max_iter = 5L,
cv_folds = 10L,
cv_strat = FALSE,
strat_pmin = 0.1,
max_iter = 10L,
tiltmod_tol = 5) {
# make sure that more than one fold is specified
assertthat::assert_that(cv_folds > 1L)

# create cross-validation folds
folds <- origami::make_folds(data,
fold_fun = origami::folds_vfold,
V = cv_folds
)
if (cv_strat && data[, mean(Y) <= strat_pmin]) {
# check that outcome is binary for stratified V-fold cross-validation
assertthat::assert_that(data[, all(unique(Y) %in% c(0, 1))])

# if outcome is binary and rare, use stratified V-fold cross-validation
folds <- origami::make_folds(
data,
fold_fun = origami::folds_vfold,
V = cv_folds,
strata_ids = data$Y
)
} else {
# just use standard V-fold cross-validation
folds <- origami::make_folds(
data,
fold_fun = origami::folds_vfold,
V = cv_folds
)
}

# perform the cv_eif procedure on a per-fold basis
cv_eif_results <- origami::cross_validate(
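
The fold-construction branch added to `est_onestep()` and `est_tml()` in this file hinges on a single condition, `cv_strat && mean(Y) <= strat_pmin`. A small sketch of that condition in isolation may help reviewers; `use_stratified_folds()` is a hypothetical helper (not part of `medoutcon`) and it works on a plain vector rather than a `data.table`.

```r
# Hypothetical helper mirroring the condition in the diff:
#   cv_strat && mean(Y) <= strat_pmin
# isTRUE() is an extra guard added here; the diff uses cv_strat directly.
use_stratified_folds <- function(y, cv_strat = FALSE, strat_pmin = 0.1) {
  isTRUE(cv_strat) && mean(y) <= strat_pmin
}

y_rare <- c(rep(1, 5), rep(0, 95))    # 5% incidence
y_common <- c(rep(1, 40), rep(0, 60)) # 40% incidence

use_stratified_folds(y_rare)                    # FALSE: cv_strat is off
use_stratified_folds(y_rare, cv_strat = TRUE)   # TRUE: opted in and rare
use_stratified_folds(y_common, cv_strat = TRUE) # FALSE: not rare enough
```

As written, a rare outcome alone does not switch on stratification; the caller must also set `cv_strat = TRUE`.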
99 changes: 58 additions & 41 deletions R/fit_mechanisms.R
@@ -589,7 +589,8 @@ fit_nuisance_u <- function(train_data,
g_star <- g_out$treat_est_train$treat_pred_A_star[train_data$R == 1]
h_prime <- h_out$treat_est_train$treat_pred_A_prime
g_prime <- g_out$treat_est_train$treat_pred_A_prime[train_data$R == 1]
q_prime_Z_natural <- q_out$moc_est_train_Z_natural$moc_pred_A_prime[train_data$R == 1]
q_prime_Z_natural <-
q_out$moc_est_train_Z_natural$moc_pred_A_prime[train_data$R == 1]
r_prime_Z_natural <- r_out$moc_est_train_Z_natural$moc_pred_A_prime

# remove observations that were not sampled in second stage
@@ -618,12 +619,14 @@ fit_nuisance_u <- function(train_data,
w_names, "A", "Z", "U_pseudo",
"obs_weights"
))
u_task_train <- sl3::sl3_Task$new(
data = u_data_train,
weights = "obs_weights",
covariates = c("Z", "A", w_names),
outcome = "U_pseudo",
outcome_type = "continuous"
suppressWarnings(
u_task_train <- sl3::sl3_Task$new(
data = u_data_train,
weights = "obs_weights",
covariates = c("Z", "A", w_names),
outcome = "U_pseudo",
outcome_type = "continuous"
)
)

## fit model for nuisance parameter regression on training data
@@ -640,12 +643,14 @@ fit_nuisance_u <- function(train_data,
w_names, "A", "Z", "U_pseudo",
"obs_weights"
))
u_task_valid <- sl3::sl3_Task$new(
data = u_data_valid,
weights = "obs_weights",
covariates = c("Z", "A", w_names),
outcome = "U_pseudo",
outcome_type = "continuous"
suppressWarnings(
u_task_valid <- sl3::sl3_Task$new(
data = u_data_valid,
weights = "obs_weights",
covariates = c("Z", "A", w_names),
outcome = "U_pseudo",
outcome_type = "continuous"
)
)

## predict from nuisance parameter regression on validation and training data
@@ -702,8 +707,10 @@ fit_nuisance_v <- function(train_data,
m_names,
w_names) {
## extract nuisance estimates necessary for this routine
q_train_prime_Z_one <- q_out$moc_est_train_Z_one$moc_pred_A_prime[train_data$R == 1]
q_valid_prime_Z_one <- q_out$moc_est_valid_Z_one$moc_pred_A_prime[valid_data$R == 1]
q_train_prime_Z_one <-
q_out$moc_est_train_Z_one$moc_pred_A_prime[train_data$R == 1]
q_valid_prime_Z_one <-
q_out$moc_est_valid_Z_one$moc_pred_A_prime[valid_data$R == 1]

# remove observations that were not sampled in second stage
train_data <- train_data[R == 1, ]
@@ -799,24 +806,28 @@ fit_nuisance_v <- function(train_data,

## build regression tasks for training and validation sets
train_data[, V_pseudo := v_pseudo_train]
v_task_train <- sl3::sl3_Task$new(
data = train_data,
weights = "obs_weights", # NOTE: should not include two_phase_weights
covariates = c("A", w_names),
outcome = "V_pseudo",
outcome_type = "continuous"
suppressWarnings(
v_task_train <- sl3::sl3_Task$new(
data = train_data,
weights = "obs_weights", # NOTE: should not include two_phase_weights
covariates = c("A", w_names),
outcome = "V_pseudo",
outcome_type = "continuous"
)
)
# NOTE: independent implementation from ID sets A to a* as done below
valid_data[, `:=`(
V_pseudo = v_pseudo_valid,
A = contrast[2]
)]
v_task_valid <- sl3::sl3_Task$new(
data = valid_data,
weights = "obs_weights", # NOTE: should not include two_phase_weights
covariates = c("A", w_names),
outcome = "V_pseudo",
outcome_type = "continuous"
suppressWarnings(
v_task_valid <- sl3::sl3_Task$new(
data = valid_data,
weights = "obs_weights", # NOTE: should not include two_phase_weights
covariates = c("A", w_names),
outcome = "V_pseudo",
outcome_type = "continuous"
)
)

## fit regression model for v on training task, get predictions on validation
@@ -911,8 +922,10 @@ fit_nuisance_d <- function(train_data,
g_prime <- g_out$treat_est_train$treat_pred_A_prime[train_data$R == 1]
u_prime <- u_out$u_train_pred
v_star <- v_out$v_train_pred
q_prime_Z_one <- q_out$moc_est_train_Z_one$moc_pred_A_prime[train_data$R == 1]
q_prime_Z_natural <- q_out$moc_est_train_Z_natural$moc_pred_A_prime[train_data$R == 1]
q_prime_Z_one <-
q_out$moc_est_train_Z_one$moc_pred_A_prime[train_data$R == 1]
q_prime_Z_natural <-
q_out$moc_est_train_Z_natural$moc_pred_A_prime[train_data$R == 1]
r_prime_Z_natural <- r_out$moc_est_train_Z_natural$moc_pred_A_prime

# NOTE: assuming Z in {0,1}; other cases not supported yet
@@ -926,12 +939,14 @@ fit_nuisance_d <- function(train_data,
)]

# predict u(z, a', w) using intervened data with treatment set A = a'
u_task_train_z_interv <- sl3::sl3_Task$new(
data = train_data_z_interv,
weights = "obs_weights", # NOTE: should not include two_phase_weights
covariates = c("Z", "A", w_names),
outcome = "U_pseudo",
outcome_type = "continuous"
suppressWarnings(
u_task_train_z_interv <- sl3::sl3_Task$new(
data = train_data_z_interv,
weights = "obs_weights", # NOTE: should not include two_phase_weights
covariates = c("Z", "A", w_names),
outcome = "U_pseudo",
outcome_type = "continuous"
)
)

# return partial pseudo-outcome for v nuisance regression
@@ -966,12 +981,14 @@ fit_nuisance_d <- function(train_data,

# generate the sl3 task
# NOTE: Purposefully not adding two-phase sampling weights
d_task_train <- sl3::sl3_Task$new(
data = eif_data_train,
weights = "obs_weights", # NOTE: should not include two_phase_weights
covariates = c(w_names, "A", "Z", "Y"),
outcome = "eif",
outcome_type = "continuous"
suppressWarnings(
d_task_train <- sl3::sl3_Task$new(
data = eif_data_train,
weights = "obs_weights", # NOTE: should not include two_phase_weights
covariates = c(w_names, "A", "Z", "Y"),
outcome = "eif",
outcome_type = "continuous"
)
)

## fit model for nuisance parameter regression on training data
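
The repeated `suppressWarnings(x <- ...)` wrapping in this file silences every warning raised while building each task. The sketch below shows that the pattern still creates the object in the calling environment, and contrasts it with a narrower handler that muffles only warnings matching a pattern. `make_task()`, its warning text, and `muffle_matching()` are all hypothetical stand-ins written for this note; they do not reproduce `sl3`'s actual check or message.

```r
# Stand-in for a constructor that warns (hypothetical; not sl3's actual check
# or message).
make_task <- function(outcome) {
  if (length(unique(outcome)) < 5) {
    warning("outcome has few unique values for a continuous type")
  }
  list(outcome = outcome)
}

# Pattern used in the diff: the assignment happens inside suppressWarnings(),
# and the object is still created in the calling environment.
suppressWarnings(
  task_a <- make_task(c(0, 1, 1, 0))
)

# Narrower alternative: muffle only warnings whose message matches a pattern,
# letting any other warning propagate as usual.
muffle_matching <- function(expr, pattern) {
  withCallingHandlers(
    expr,
    warning = function(w) {
      if (grepl(pattern, conditionMessage(w))) {
        invokeRestart("muffleWarning")
      }
    }
  )
}

task_b <- muffle_matching(make_task(c(0, 1, 1, 0)), "few unique values")
```

The narrower form is one option if unrelated warnings from task construction should remain visible; whether that matters here is a judgment call for the maintainers.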