Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixing Sandwich variance - Issue #140 #141

Draft
wants to merge 7 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion .github/workflows/R-CMD-check.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ jobs:
matrix:
config:
- {os: macOS-latest, r: 'release'}
- {os: macOS-latest, r: 'oldrel-1'}
- {os: windows-latest, r: 'release'}
- {os: windows-latest, r: 'oldrel-1'}
- {os: ubuntu-latest, r: 'release'}
Expand Down
10 changes: 6 additions & 4 deletions R/PLNfit-S3methods.R
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ standard_error <- function(object, type = c("variational", "jackknife", "sandwic

#' @describeIn standard_error Component-wise standard errors of B in [`PLNfit`]
#' @export
standard_error.PLNfit <- function(object, type = c("variational", "jackknife", "bootstrap", "sandwich"), parameter = c("B", "Omega")) {
standard_error.PLNfit <- function(object, type = c("variational", "sandwich", "jackknife", "bootstrap"), parameter = c("B", "Omega")) {
type <- match.arg(type)
par <- match.arg(parameter)
if (type == "variational" & is.null(attr(object$model_par$B, "variance_variational")))
Expand All @@ -184,14 +184,14 @@ standard_error.PLNfit <- function(object, type = c("variational", "jackknife", "
stop("Jackknife estimation not available: rerun by setting `jackknife = TRUE` in the control list.")
if (type == "bootstrap" & is.null(attr(object$model_par$B, "variance_bootstrap")))
stop("Bootstrap estimation not available: rerun by setting `bootstrap > 0` in the control list.")
if (type == "sandwich")
stop("Sandwich estimator is only available for fixed covariance / precision matrix.")
if (type == "sandwich" & is.null(attr(object$model_par$B, "variance_sandwich")))
stop("Sandwich estimator not available: rerun by setting `sandwich_var = TRUE` in the control list.")
attr(object$model_par[[par]], paste0("variance_", type)) %>% sqrt()
}

#' @describeIn standard_error Component-wise standard errors of B in [`PLNfit_fixedcov`]
#' @export
standard_error.PLNfit_fixedcov <- function(object, type = c("variational", "jackknife", "bootstrap", "sandwich"), parameter = c("B", "Omega")) {
standard_error.PLNfit_fixedcov <- function(object, type = c("variational", "sandwich", "jackknife", "bootstrap", "sandwich"), parameter = c("B", "Omega")) {
type <- match.arg(type)
par <- match.arg(parameter)
if (par == "Omega")
Expand All @@ -202,5 +202,7 @@ standard_error.PLNfit_fixedcov <- function(object, type = c("variational", "jack
stop("Jackknife estimation not available: rerun by setting `jackknife = TRUE` in the control list.")
if (type == "bootstrap" & is.null(attr(object$model_par$B, "variance_bootstrap")))
stop("Bootstrap estimation not available: rerun by setting `bootstrap > 0` in the control list.")
if (type == "sandwich" & is.null(attr(object$model_par$B, "variance_sandwich")))
stop("Sandwich estimator not available: rerun by setting `sandwich_var = TRUE` in the control list.")
attr(object$model_par[[par]], paste0("variance_", type)) %>% sqrt()
}
95 changes: 21 additions & 74 deletions R/PLNfit-class.R
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ PLNfit <- R6Class(
variance_variational = function(X, config = config_default_nlopt) {
## Variance of B for n data points
fisher <- Matrix::bdiag(lapply(1:self$p, function(j) {
crossprod(X, private$A[, j] * X) # t(X) %*% diag(A[, i]) %*% X
crossprod(X, private$A[, j] * X) # t(X) %*% diag(A[, j]) %*% X
}))
vcov_B <- tryCatch(Matrix::solve(fisher), error = function(e) {e})
if (is(vcov_B, "error")) {
Expand Down Expand Up @@ -220,22 +220,11 @@ PLNfit <- R6Class(

compute_vcov_from_resamples = function(resamples){
B_list = resamples %>% map("B")
#print (B_list)
vcov_B = lapply(seq(1, ncol(private$B)), function(B_col){
param_ests_for_col = B_list %>% map(~.x[, B_col])
param_ests_for_col = do.call(rbind, param_ests_for_col)
#print (param_ests_for_col)
row_vcov = cov(param_ests_for_col)
})
#print ("vcov blocks")
#print (vcov_B)

#B_vcov <- resamples %>% map("B") %>% map(~( . )) %>% reduce(cov)

#var_jack <- jacks %>% map("B") %>% map(~( (. - B_jack)^2)) %>% reduce(`+`) %>%
# `dimnames<-`(dimnames(private$B))
#B_hat <- private$B[,] ## strips attributes while preserving names

vcov_B = Matrix::bdiag(vcov_B) %>% as.matrix()

rownames(vcov_B) <- colnames(vcov_B) <-
Expand All @@ -244,33 +233,32 @@ PLNfit <- R6Class(
## Hack to make sure that species is first and varies slowest
apply(1, paste0, collapse = "_")

#print (pheatmap::pheatmap(vcov_B, cluster_rows=FALSE, cluster_cols=FALSE))


#names = lapply(bootstrapped_df$cov_mat, function(m){ colnames(m)}) %>% unlist()
#rownames(bootstrapped_vhat) = names
#colnames(bootstrapped_vhat) = names

vcov_B = methods::as(vcov_B, "dgCMatrix")

return(vcov_B)
},

vcov_sandwich_B = function(Y, X) {
vcov_sand <- get_sandwich_variance_B(Y, X, private$A,
private$S, private$Sigma, diag(private$Omega)
)
attr(private$B, "vcov_sandwich") <- vcov_sand
attr(private$B, "variance_sandwich") <- matrix(diag(vcov_sand), nrow = self$d, ncol = self$p,
dimnames = dimnames(private$B))
},

variance_jackknife = function(Y, X, O, w, config = config_default_nlopt) {
jacks <- future.apply::future_lapply(seq_len(self$n), function(i) {
data <- list(Y = Y[-i, , drop = FALSE],
X = X[-i, , drop = FALSE],
O = O[-i, , drop = FALSE],
w = w[-i])
args <- list(data = data,
# params = list(B = private$B,
# M = matrix(0, self$n-1, self$p),
# S = private$S[-i, , drop = FALSE]),
params = do.call(compute_PLN_starting_point, data),
config = config)
optim_out <- do.call(private$optimizer$main, args)
optim_out[c("B", "Omega")]
})
}, future.seed = TRUE, future.scheduling = structure(TRUE, ordering = "random"))

B_jack <- jacks %>% map("B") %>% reduce(`+`) / self$n
var_jack <- jacks %>% map("B") %>% map(~( (. - B_jack)^2)) %>% reduce(`+`) %>%
Expand Down Expand Up @@ -298,21 +286,17 @@ PLNfit <- R6Class(
O = O[resample, , drop = FALSE],
w = w[resample])
if (config$backend == "torch") # Convert data to torch tensors
data <- lapply(data, torch_tensor, device = config$device) # list with Y, X, O, w

#print (data$Y$device)
data <- lapply(data, torch_tensor, device = config$device)

args <- list(data = data,
# params = list(B = private$B, M = matrix(0,self$n,self$p), S = private$S[resample, ]),
params = do.call(compute_PLN_starting_point, data),
config = config)
if (config$backend == "torch") # Convert data to torch tensors
args$params <- lapply(args$params, torch_tensor, requires_grad = TRUE, device = config$device) # list with B, M, S

optim_out <- do.call(private$optimizer$main, args)
#print (optim_out)
optim_out[c("B", "Omega", "monitoring")]
})
}, future.seed = TRUE, future.scheduling = structure(TRUE, ordering = "random"))

B_boots <- boots %>% map("B") %>% reduce(`+`) / n_resamples
attr(private$B, "variance_bootstrap") <-
Expand Down Expand Up @@ -455,7 +439,7 @@ PLNfit <- R6Class(
#' * jackknife boolean indicating whether jackknife should be performed to evaluate bias and variance of the model parameters. Default is FALSE.
#' * bootstrap integer indicating the number of bootstrap resamples generated to evaluate the variance of the model parameters. Default is 0 (inactivated).
#' * variational_var boolean indicating whether variational Fisher information matrix should be computed to estimate the variance of the model parameters (highly underestimated). Default is FALSE.
#' * rsquared boolean indicating whether approximation of R2 based on deviance should be computed. Default is TRUE
#' * sandwich_var boolean indicating whether sandwich estimator should be computed to estimate the variance of the model parameters (highly underestimated). Default is FALSE.
#' * trace integer for verbosity. should be > 1 to see output in post-treatments
postTreatment = function(responses, covariates, offsets, weights = rep(1, nrow(responses)), config_post, config_optim, nullModel = NULL) {
## PARAMATERS DIMNAMES
Expand Down Expand Up @@ -496,6 +480,11 @@ PLNfit <- R6Class(
}
private$variance_bootstrap(responses, covariates, offsets, weights, n_resamples=config_post$bootstrap, config = config_optim)
}
## 5. compute and store matrix of standard variances for B with sandwich correction approximation
if (config_post$sandwich_var) {
if(config_post$trace > 1) cat("\n\tComputing sandwich estimator of the variance...")
private$vcov_sandwich_B(responses, covariates)
}
},

#' @description Predict position, scores or observations of new data.
Expand Down Expand Up @@ -920,25 +909,8 @@ PLNfit_fixedcov <- R6Class(
optim_out <- do.call(private$optimizer$main, args)
do.call(self$update, optim_out)
private$Sigma <- solve(optim_out$Omega)
},

#' @description Update R2, fisher and std_err fields after optimization
#' @param config_post a list for controlling the post-treatments (optional bootstrap, jackknife, R2, etc.). See details
#' @param config_optim a list for controlling the optimization parameter. See details
#' @details The list of parameters `config` controls the post-treatment processing, with the following entries:
#' * trace integer for verbosity. should be > 1 to see output in post-treatments
#' * jackknife boolean indicating whether jackknife should be performed to evaluate bias and variance of the model parameters. Default is FALSE.
#' * bootstrap integer indicating the number of bootstrap resamples generated to evaluate the variance of the model parameters. Default is 0 (inactivated).
#' * variational_var boolean indicating whether variational Fisher information matrix should be computed to estimate the variance of the model parameters (highly underestimated). Default is FALSE.
#' * rsquared boolean indicating whether approximation of R2 based on deviance should be computed. Default is TRUE
postTreatment = function(responses, covariates, offsets, weights = rep(1, nrow(responses)), config_post, config_optim, nullModel = NULL) {
super$postTreatment(responses, covariates, offsets, weights, config_post, config_optim, nullModel)
## 6. compute and store matrix of standard variances for B with sandwich correction approximation
if (config_post$sandwich_var) {
if(config_post$trace > 1) cat("\n\tComputing sandwich estimator of the variance...")
private$vcov_sandwich_B(responses, covariates)
}
}

),
private = list(
## %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
Expand Down Expand Up @@ -976,39 +948,14 @@ PLNfit_fixedcov <- R6Class(
config = config)
optim_out <- do.call(private$optimizer$main, args)
optim_out[c("B", "Omega")]
}, future.seed = TRUE)
}, future.seed = TRUE, future.scheduling = structure(TRUE, ordering = "random"))

B_jack <- jacks %>% map("B") %>% reduce(`+`) / self$n
var_jack <- jacks %>% map("B") %>% map(~( (. - B_jack)^2)) %>% reduce(`+`) %>%
`dimnames<-`(dimnames(private$B))
B_hat <- private$B[,] ## strips attributes while preserving names
attr(private$B, "bias") <- (self$n - 1) * (B_jack - B_hat)
attr(private$B, "variance_jackknife") <- (self$n - 1) / self$n * var_jack
},

vcov_sandwich_B = function(Y, X) {
getMat_iCnB <- function(i) {
a_i <- as.numeric(private$A[i, ])
s2_i <- as.numeric(private$S[i, ]**2)
# omega <- as.numeric(1/diag(private$Sigma))
# diag_mat_i <- diag(1/a_i + s2_i^2 / (1 + s2_i * (a_i + omega)))
diag_mat_i <- diag(1/a_i + .5 * s2_i^2)
solve(private$Sigma + diag_mat_i)
}
YmA <- Y - private$A
Dn <- matrix(0, self$d*self$p, self$d*self$p)
Cn <- matrix(0, self$d*self$p, self$d*self$p)
for (i in 1:self$n) {
xxt_i <- tcrossprod(X[i, ])
Cn <- Cn - kronecker(getMat_iCnB(i) , xxt_i) / (self$n)
Dn <- Dn + kronecker(tcrossprod(YmA[i,]), xxt_i) / (self$n)
}
Cn_inv <- solve(Cn)
dim_names <- dimnames(attr(private$B, "vcov_variational"))
vcov_sand <- ((Cn_inv %*% Dn %*% Cn_inv) / self$n) %>% `dimnames<-`(dim_names)
attr(private$B, "vcov_sandwich") <- vcov_sand
attr(private$B, "variance_sandwich") <- matrix(diag(vcov_sand), nrow = self$d, ncol = self$p,
dimnames = dimnames(private$B))
}
),
active = list(
Expand Down
4 changes: 4 additions & 0 deletions R/RcppExports.R
Original file line number Diff line number Diff line change
Expand Up @@ -89,3 +89,7 @@ cpp_test_packing <- function() {
.Call('_PLNmodels_cpp_test_packing', PACKAGE = 'PLNmodels')
}

get_sandwich_variance_B <- function(Y, X, A, S, Sigma, Diag_Omega) {
.Call('_PLNmodels_get_sandwich_variance_B', PACKAGE = 'PLNmodels', Y, X, A, S, Sigma, Diag_Omega)
}

2 changes: 1 addition & 1 deletion R/utils-zipln.R
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ extract_model_zi <- function(call, envir) {
call_args <- c(as.list(call_args), list(xlev = attr(call$formula, "xlevels"), na.action = NULL))

## Extract terms for ZI and PLN components
terms <- .extract_terms_zi(as.formula(eval(call$formula, env = envir)))
terms <- .extract_terms_zi(as.formula(eval(call$formula, envir = envir)))
## eval the call in the parent environment with adjustement due to ZI terms
call_args$formula <- terms$formula
frame <- do.call(stats::model.frame, call_args, envir = envir)
Expand Down
9 changes: 6 additions & 3 deletions R/utils.R
Original file line number Diff line number Diff line change
Expand Up @@ -55,23 +55,26 @@ config_post_default_PLNLDA <-
jackknife = FALSE,
bootstrap = 0L,
rsquared = TRUE,
variational_var = FALSE
variational_var = FALSE,
sandwich_var = FALSE
)

config_post_default_PLNPCA <-
list(
jackknife = FALSE,
bootstrap = 0L,
rsquared = TRUE,
variational_var = FALSE
variational_var = FALSE,
sandwich_var = FALSE
)

config_post_default_PLNmixture <-
list(
jackknife = FALSE,
bootstrap = 0L,
rsquared = TRUE,
variational_var = FALSE
variational_var = FALSE,
sandwich_var = FALSE
)

status_to_message <- function(status) {
Expand Down
51 changes: 37 additions & 14 deletions inst/check/variance_estimates.R
Original file line number Diff line number Diff line change
Expand Up @@ -2,32 +2,55 @@ library(tidyverse)
library(PLNmodels)
set.seed(1234)

nb_cores <- 10
options(future.fork.enable = TRUE)
rmse <- function(theta_hat, theta_star) {
sqrt(mean((theta_hat - theta_star)^2))
}

params <- PLNmodels:::create_parameters(n = 50, p = 10, d = 1, depths = 1e3)
params <- PLNmodels:::create_parameters(n = 100, p = 10, d = 1, depths = 1e3)
X <- params$X
B <- params$B
Y <- rPLN(n = nrow(X), mu = X %*% B, Sigma = params$Sigma, depths = params$depths)
conf <- list(variational_var = TRUE, jackknife = TRUE, bootstrap = FALSE, sandwich_var = TRUE)

data <- prepare_data(Y, X, offset = "none")
logO <- attr(Y, "offsets")

conf <- list(variational_var = TRUE, jackknife = TRUE, bootstrap = nrow(Y))
future::plan("multicore", workers = nb_cores)
model <- PLN(Abundance ~ 0 + . + offset(logO), data = data, control = PLN_param(config_post = conf))
future::plan("sequential")
one_simu <- function(s) {

Y <- rPLN(n = nrow(X), mu = X %*% B, Sigma = params$Sigma, depths = params$depths)
data <- prepare_data(Y, X, offset = "none")
logO <- attr(Y, "offsets")
model <- PLN(Abundance ~ 0 + . + offset(logO), data = data, control = PLN_param(trace = FALSE, config_post = conf))

B_hat <- coef(model)
vcov_sandwich <- attr(coef(model), "vcov_sandwich")
vcov_jackknife <- attr(coef(model), "vcov_sandwich")
vcov_variational <- attr(coef(model), "vcov_variational")

data.frame(rmse = rmse(B_hat, B),
cover_sandwich = mean(abs(as.numeric(B_hat - B) %*% solve(chol(vcov_sandwich))) < 1.96),
cover_jackknife = mean(abs(as.numeric(B_hat - B) %*% solve(chol(vcov_jackknife))) < 1.96),
cover_variational = mean(abs(as.numeric(B_hat - B) %*% solve(chol(vcov_variational))) < 1.96),
simu = s)
}

res <- do.call(rbind, lapply(1:50, one_simu))

boxplot(res$cover_sandwich, res$cover_jackknife, res$cover_variational)

### Single test

B_hat <- coef(model)
B_se_var <- standard_error(model, "variational")
B_se_jk <- standard_error(model, "jackknife")
B_se_bt <- standard_error(model, "bootstrap")
B_se_sw <- standard_error(model, "sandwich")

Y <- rPLN(n = nrow(X), mu = X %*% B, Sigma = params$Sigma, depths = params$depths)
data <- prepare_data(Y, X, offset = "none")
logO <- attr(Y, "offsets")
model <- PLN(Abundance ~ 0 + . + offset(logO), data = data, control = PLN_param(config_post = conf))

data.frame(
B = rep(c(B), 3),
B_hat = rep(c(B_hat), 3),
se = c(B_se_var, B_se_jk, B_se_bt),
method = rep(c("variational", "jackknife", "bootstrap"), each = length(c(B))) ) %>%
se = c(B_se_var, B_se_jk, B_se_sw),
method = rep(c("variational", "jackknife", "sandwich"), each = length(c(B))) ) %>%
ggplot(aes(x = B, y = B_hat)) +
geom_errorbar(aes(ymin = B_hat - 2 * se,
ymax = B_hat + 2 * se), color = "blue") + facet_wrap(~ method) +
Expand Down
2 changes: 1 addition & 1 deletion man/PLNfit.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading
Loading