
Merge pull request #25 from hheiling/M_Step
M step
hheiling authored Mar 15, 2023
2 parents b198354 + c33194b commit 345a8c2
Showing 25 changed files with 369 additions and 105 deletions.
16 changes: 9 additions & 7 deletions DESCRIPTION
@@ -2,12 +2,13 @@ Package: glmmPen
Type: Package
Title: High Dimensional Penalized Generalized Linear Mixed Models
(pGLMM)
Version: 1.5.2.11
Date: 2022-12-12
Version: 1.5.3.0
Date: 2023-03-14
Authors@R: c(
person("Hillary", "Heiling", email = "hheiling@live.unc.edu", role = c("aut", "cre")),
person("Naim", "Rashid", email = "naim@unc.edu", role = c("aut")),
person("Quefeng", "Li", email = "quefeng@email.unc.edu", role = c("aut")))
person("Naim", "Rashid", email = "nur2@email.unc.edu", role = c("aut")),
person("Quefeng", "Li", email = "quefeng@email.unc.edu", role = c("aut")),
person("Joseph", "Ibrahim", email = "ibrahim@bios.unc.edu", role = c("aut")))
Maintainer: Hillary Heiling <hheiling@live.unc.edu>
Description: Fits high dimensional penalized generalized linear
mixed models using
@@ -33,10 +34,10 @@ Imports:
ncvreg,
reshape2,
rstan (>= 2.18.1),
rstantools (>= 2.0.0),
stringr,
mvtnorm,
MASS
MASS,
coxme
Depends:
lme4,
bigmemory,
@@ -55,7 +56,8 @@ NeedsCompilation: yes
Packaged: 2019-01-25 20:03:59 UTC; hheiling
Author: Hillary Heiling [aut, cre],
Naim Rashid [aut],
Quefeng Li [aut]
Quefeng Li [aut],
Joseph Ibrahim [aut]
Suggests:
testthat,
knitr,
3 changes: 3 additions & 0 deletions NAMESPACE
@@ -20,6 +20,7 @@ S3method(sigma,pglmmObj)
S3method(summary,pglmmObj)
export(LambdaSeq)
export(adaptControl)
export(coxphControl)
export(glmm)
export(glmmPen)
export(glmmPen_FA)
@@ -43,6 +44,8 @@ importFrom(bigmemory,big.matrix)
importFrom(bigmemory,describe)
importFrom(bigmemory,read.big.matrix)
importFrom(bigmemory,write.big.matrix)
importFrom(coxme,VarCorr)
importFrom(coxme,coxme)
importFrom(lme4,VarCorr)
importFrom(lme4,factorize)
importFrom(lme4,findbars)
158 changes: 141 additions & 17 deletions R/E_step.R
@@ -4,7 +4,8 @@
# only use fixed effects
# ranef_idx: index of which random effects are non-zero as of latest M-step (will not
# sample from random effects that have been penalized out)
# y: response
# y: response. If coxph, y = event indicator
# y_times: If coxph, value of observed times (NULL if not coxph family)
# X: fixed effects covariates
# Znew2: For each group k (k = 1,...,d), Znew2 = Z * Gamma (Gamma = t(chol(sigma)))
# group: group index
@@ -18,14 +19,22 @@
# d: total number groups
# uold: posterior sample to use for initialization of E-step
# proposal_SD, batch, batch_length, offset_increment: arguments for adaptive random walk
# coxph_options: for Cox Proportional Hazards model, additional parameters of interest


#' @importFrom bigmemory attach.big.matrix describe big.matrix
#' @importFrom rstan sampling extract
E_step = function(coef, ranef_idx, y, X, Znew2, group, offset_fit,
nMC, nMC_burnin, family, link, phi, sig_g,
E_step = function(coef, ranef_idx, y, y_times = NULL, X, Znew2, group, offset_fit,
nMC, nMC_burnin, family, link, phi = 0.0, sig_g = 1.0,
sampler, d, uold, proposal_SD, batch, batch_length,
offset_increment, trace){
offset_increment, trace, coxph_options = NULL){

if((family == "coxph") & !(sampler == "stan")){
stop("'coxph' family currently only supports the 'stan' sampler")
}
if((family == "coxph") & (is.null(coxph_options))){
stop("coxph_options must be of class 'coxphControl' (see coxphControl() documentation) for the 'coxph' family")
}

gibbs_accept_rate = matrix(NA, nrow = d, ncol = nrow(Znew2)/d)

Expand Down Expand Up @@ -82,7 +91,7 @@ E_step = function(coef, ranef_idx, y, X, Znew2, group, offset_fit,
print(gibbs_accept_rate)
}

}else if(sampler == "stan"){
}else if((sampler == "stan") & (family != "coxph")){

u0 = big.matrix(nrow = nMC, ncol = ncol(Znew2), init=0) # use ', init = 0' for sampling within EM algorithm

@@ -99,10 +108,6 @@ E_step = function(coef, ranef_idx, y, X, Znew2, group, offset_fit,
# Number draws to extract in each chain (after burn-in)
nMC_chain = nMC

# Record last elements of each chain for initialization of next E step
## Restriction: single chain for each E-step
last_draws = matrix(0, nrow = 1, ncol = ncol(Znew2))

# For each group, sample from posterior distribution (sample the alpha values)
for(k in 1:d){
idx_k = which(group == k)
@@ -121,16 +126,15 @@ E_step = function(coef, ranef_idx, y, X, Znew2, group, offset_fit,
q = length(ranef_idx), # number random effects (or common factors)
eta_fef = as.array(as.numeric(X_k %*% matrix(coef[1:ncol(X)], ncol = 1)) + offset_fit[idx_k]), # fixed effects component of linear predictor
y = as.array(y_k), # outcomes for group k
Z = Z_k) # portion of Z matrix corresonding to group k
Z = Z_k) # portion of Z matrix corresponding to group k
}else{ # length(idx_k) > 1
dat_list = list(N = length(idx_k), # Number individuals in group k
q = length(ranef_idx), # number random effects
eta_fef = as.numeric(X_k %*% matrix(coef[1:ncol(X)], ncol = 1) + offset_fit[idx_k]), # fixed effects component of linear predictor
y = y_k, # outcomes for group k
Z = Z_k) # portion of Z matrix corresonding to group k
Z = Z_k) # portion of Z matrix corresponding to group k
}


if(family == "gaussian"){

dat_list$sigma = sig_g # standard deviation of normal dist of y
@@ -150,7 +154,8 @@ E_step = function(coef, ranef_idx, y, X, Znew2, group, offset_fit,
init_lst[[1]] = list(alpha = uold[cols_use])
}

# Avoid excessive warnings when nMC_chain is low in early EM iterations
# Sampling step
# suppressWarnings(): Avoid excessive warnings when nMC_chain is low in early EM iterations
stan_fit = suppressWarnings(rstan::sampling(stan_file, data = dat_list, init = init_lst,
iter = nMC_chain + nMC_burnin,
warmup = nMC_burnin, show_messages = FALSE, refresh = 0,
@@ -165,10 +170,6 @@ E_step = function(coef, ranef_idx, y, X, Znew2, group, offset_fit,
draws_mat = matrix(stan_out[,1], ncol = 1)
}


# Find last elements for each chain for initialization of next E step
last_draws[1,cols_use] = draws_mat[nMC_chain,]

if(nrow(draws_mat) == nMC){
u0[,cols_use] = draws_mat
}else{ # nrow(draws_mat) > nMC due to ceiling function in 'iter' specification
@@ -180,6 +181,129 @@

} # End k for loop

}else if((sampler == "stan") & (family == "coxph")){

# Alternative approach: https://rpubs.com/kaz_yos/surv_stan_piecewise1

# Cox Proportional Hazards family: Calculate cut-points to use for time intervals
## Divide the timeline into J = cut_num intervals such that there are an equal
## (or approximately equal) number of events in each interval
## Note: must have at least one event in each interval (preferably > 2) to be identifiable
cut_num = coxph_options$cut_num
event_total = sum(y)
event_idx = which(y == 1)
# Determine number of events per time interval, event_cuts
if((event_total %% cut_num) == 0){ # event_total is a multiple of cut_num
event_cuts = rep(event_total / cut_num, times = cut_num)
}else{
tmp = event_total %/% cut_num
event_cuts = rep(tmp, times = cut_num)
for(j in 1:(event_total - tmp*cut_num)){
event_cuts[j] = event_cuts[j] + 1
}
}

# stop if 0 events for an interval, warning if only 1 event for an interval
# (check the 0-event case first so it is not masked by the 1-event warning)
if(any(event_cuts == 0)){
stop("at least one time interval for the piecewise exponential hazard model has 0 events, ",
"please see the coxphControl() documentation for details and tips on how to fix the issue")
}else if(any(event_cuts == 1)){
warning("at least one time interval for the piecewise exponential hazard model has only 1 event, ",
"please see the coxphControl() documentation for details and tips on how to fix the issue",
immediate. = TRUE)
}

cut_pts_idx = numeric(cut_num)
for(j in 1:cut_num){
cut_pts_idx[j] = event_idx[sum(event_cuts[1:j])]
}

cut_points = numeric(cut_num)
for(j in 1:(cut_num-1)){
# midpoint between the last event time of interval j and the next observed time
cut_points[j] = mean(c(y_times[cut_pts_idx[j]], y_times[cut_pts_idx[j]+1]))
}
cut_points[cut_num] = max(y_times) + 1
# cut_points = y_times[cut_pts_idx]

u0 = big.matrix(nrow = nMC, ncol = ncol(Znew2) + length(cut_points), init=0) # use ', init = 0' for sampling within EM algorithm

stan_file = stanmodels$coxph_piecewise_exp_model

# Number draws to extract in each chain (after burn-in)
nMC_chain = nMC

# If necessary, restrict columns of Znew2 matrix to columns associated with non-zero
# latent variables (random effects / latent common factors)
# Also determine relevant columns of u0 matrix in which to save alpha samples
cols_analyze = NULL
for(k in 1:d){
cols_k = seq(from = k, to = ncol(Znew2), by = d)
cols_analyze = c(cols_analyze,cols_k[ranef_idx])
}
cols_analyze = cols_analyze[order(cols_analyze)]

# Indicator matrix:
## For subject i, determine which columns of the Znew2 matrix are relevant for analyses
## In other words, if subject i is in group k, indicate which columns of the Znew2 matrix are associated with group k
I_mat = matrix(0, nrow = nrow(Znew2), ncol = ncol(Znew2))
for(k in 1:d){
idx_k = which(group == k)
cols_k = seq(from = k, to = ncol(Znew2), by = d)
I_mat[idx_k,cols_k] = 1
}

# Sample the random effects / latent common factors 'alpha': group-specific values needed
# Also sample log-hazard values 'lhaz' for each time interval
## As opposed to other families, sample all (d*q) random effects / (d*r) latent common factors
## together instead of sampling by group. Reasoning: want log-hazard values to be
## consistent regardless of group identity
dat_list = list(N = length(y), # Number of observations
NT = length(cut_points), # Number of time intervals
H = length(ranef_idx)*d, # Number groups times number latent variables (latent random effects or latent common factors)
eta_fef = as.numeric(X %*% matrix(coef[1:ncol(X)], ncol = 1) + offset_fit), # Fixed effects portion of linear predictor
y = y, # event indicator (1 = event, 0 = censor)
obs_t = y_times, # observed times
Z = Znew2[,cols_analyze], # Z * Gamma or Z * B matrix, see calculation in fit_dat_coxph
cutpt = c(0, cut_points), # Time interval boundaries, including 0 as lower bound of first interval
I = I_mat[,cols_analyze], # Indicator matrix, see above calculation
lhaz_prior = coxph_options$lhaz_prior) # Specifies standard deviation of normal prior

# initialize posterior random draws
alpha_idx = cols_analyze
lhaz_idx = (ncol(Znew2)+1):length(uold)
# init: See "rstan::stan" documentation
## Set initial values by providing a list equal in length to the number of chains (1).
## The elements of this list should themselves be named lists, where each of these
## named lists has the name of a parameter and is used to specify the initial values
## for that parameter for the corresponding chain
init_lst = list()
init_lst[[1]] = list(alpha = uold[alpha_idx],
lhaz = uold[lhaz_idx])

# Sampling step
# suppressWarnings(): Avoid excessive warnings when nMC_chain is low in early EM iterations
stan_fit = suppressWarnings(rstan::sampling(stan_file, data = dat_list, init = init_lst,
iter = nMC_chain + nMC_burnin,
warmup = nMC_burnin, show_messages = FALSE, refresh = 0,
chains = 1, cores = 1))

stan_out = as.matrix(stan_fit)
# Check organization of samples
# print(colnames(stan_out)) # first alpha samples, then lhaz samples, then lp__ value
# Exclude lp__ column of output (log density up to a constant)
samp_idx = 1:(length(cols_analyze) + length(cut_points))
draws_mat = stan_out[,samp_idx]
# Specify column locations of u0 matrix to save samples from stan_fit object
u0_idx = c(cols_analyze, ((1:length(cut_points))+ncol(Znew2)))

if(nrow(draws_mat) == nMC){
u0[,u0_idx] = draws_mat
}else{ # nrow(draws_mat) > nMC due to ceiling function in 'iter' specification
start_row = nrow(draws_mat) - nMC + 1
rows_seq = start_row:nrow(draws_mat)
u0[,u0_idx] = draws_mat[rows_seq,]
}

} # End if-else sampler

return(list(u0 = describe(u0), proposal_SD = proposal_SD, gibbs_accept_rate = gibbs_accept_rate,
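For reference, the equal-events cut-point construction introduced in the 'coxph' branch above can be illustrated on its own. The sketch below uses made-up survival data (it is not package code), assumes y and y_times are sorted by increasing observation time as the E-step code expects, and re-expresses the remainder-spreading loop in a compact vectorized form.

# Illustration only: hypothetical data, not part of glmmPen
set.seed(1)
n <- 60
y_times <- sort(rexp(n, rate = 0.2))   # observed times, sorted
y <- rbinom(n, size = 1, prob = 0.7)   # event indicator (1 = event, 0 = censored)
cut_num <- 8

event_total <- sum(y)
event_idx <- which(y == 1)

# Spread the events as evenly as possible across the cut_num intervals
event_cuts <- rep(event_total %/% cut_num, times = cut_num)
remainder <- event_total - sum(event_cuts)
if (remainder > 0) event_cuts[seq_len(remainder)] <- event_cuts[seq_len(remainder)] + 1

# Last event index in each interval, then cut points halfway to the next observed time
cut_pts_idx <- event_idx[cumsum(event_cuts)]
cut_points <- numeric(cut_num)
for (j in 1:(cut_num - 1)) {
  cut_points[j] <- mean(c(y_times[cut_pts_idx[j]], y_times[cut_pts_idx[j] + 1]))
}
cut_points[cut_num] <- max(y_times) + 1

# Each interval now contains an (approximately) equal number of events,
# which is the property the 0-event/1-event checks above guard
table(cut(y_times[y == 1], breaks = c(0, cut_points)))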
2 changes: 1 addition & 1 deletion R/M_step.R
@@ -61,7 +61,7 @@ M_step = function(y, X, Z, u_address, M, J, group, family, link_int, coef, offse
K = numeric(J_XZ)
for(j in unique(XZ_group)){
idx = which(XZ_group == j)
K[j+1] = length(idx)
K[j+1] = length(idx) # Add 1 because smallest XZ_group value is 0
}

# Number of groups wrt observations
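The comment added in this hunk is worth a concrete illustration. The toy snippet below (hypothetical group labels, not package code) shows why the +1 offset is needed when the XZ_group labels start at 0 while R vectors are 1-indexed.

# Illustration only: hypothetical penalty-group labels
XZ_group <- c(0, 1, 1, 2, 2, 2)
J_XZ <- length(unique(XZ_group))

K <- numeric(J_XZ)
for (j in unique(XZ_group)) {
  idx <- which(XZ_group == j)
  K[j + 1] <- length(idx)   # label 0 is stored in K[1], label 1 in K[2], ...
}
K
# [1] 1 2 3   (one coefficient with label 0, two with label 1, three with label 2)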
66 changes: 66 additions & 0 deletions R/control_options.R
@@ -282,6 +282,72 @@ adaptControl = function(batch_length = 100, offset = 0){
class = "adaptControl")
}

#' @title Control of Cox Proportional Hazards Model Fitting
#'
#' @description Constructs the control structure for additional parameters needed for
#' the sampling and optimization routines involving the Cox Proportional Hazards model fit algorithm
#'
#' @param cut_num positive integer specifying the number of time intervals to include in
#' the piecewise exponential hazard model used in the sampling step. Default is 8.
#' General recommendation: use between 5 and 10 intervals. See the Details section for
#' additional information.
#' @param lhaz_prior positive numeric value specifying the standard deviation of the
#' multivariate normal prior for the log of the baseline hazard values for each time interval.
#' Default is 3. If convergence issues are encountered, the user can consider
#' increasing or decreasing this value (e.g., increase to 4 or decrease to 2, and so on).
#'
#' @return Function returns a list inheriting from class \code{coxphControl}
#' containing additional control parameters used in the Cox Proportional Hazards model fit algorithm.
#'
#' @details The sampling step (E-step) for the Cox PH family assumes a piecewise
#' exponential hazard model: the time line of the data is cut into \code{cut_num}
#' time intervals, and the baseline hazard is assumed constant within
#' each of these time intervals. In the sampling step, we need to estimate
#' these baseline hazard values (specifically, we estimate the log of the baseline
#' hazard values). We determine cut points by specifying the total number of cuts
#' to make (\code{cut_num}) and then specifying time values for cut points such
#' that each time interval has an equal number (or approximately equal number)
#' of events. Each time interval must have at least one event for the model
#' to be identifiable, but more events per time interval is better.
#' Consequently, having too many cut points could result in (i) not having enough
#' events for each time interval and/or (ii) significantly slowing down the
#' sampling step due to requiring the estimation of many log baseline hazard values.
#' Additionally, data with few events could result in too few events per time interval
#' even for a small number of cut points. We generally recommend having
#' 8 total time intervals (more broadly, between 5 and 10). Warnings or errors
#' will occur when a time interval contains only 1 or 0 events.
#' If this happens, either adjust the \code{cut_num} value appropriately,
#' or in the case when the data simply has a very small number of events,
#' consider not using this software for your estimation purposes.
#'
#' @export
coxphControl = function(cut_num = 8, lhaz_prior = 3){

#########################################################################################
# Input checks and restrictions
#########################################################################################

# cut_num
if((floor(cut_num) != cut_num) | (cut_num < 1)){
stop("cut_num must be a positive integer")
}

if((cut_num < 5) | (cut_num > 10)){
warning("the glmmPen team recommends that you keep cut_num between 5 and 10; 8 is typically a good cut_num value", immediate. = TRUE)
}

# lhaz_prior
if(lhaz_prior <= 0){
stop("lhaz_prior must be a positive numeric value")
}

# output object
structure(list(cut_num = cut_num, lhaz_prior = lhaz_prior),
class = "coxphControl")

}

#' @title Control of Penalized Generalized Linear Mixed Model Fitting
#'
#' @description Constructs the control structure for the optimization of the penalized mixed model fit algorithm.
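A brief usage sketch of the new coxphControl() constructor defined above. Only the constructor itself is exercised here; how the resulting object is passed on to the model-fitting functions is not shown in this diff.

ctrl <- coxphControl()             # defaults: cut_num = 8, lhaz_prior = 3
class(ctrl)                        # "coxphControl"
ctrl$cut_num                       # 8

coxphControl(cut_num = 12)         # immediate warning: recommended range is 5 to 10
try(coxphControl(cut_num = 2.5))   # error: cut_num must be a positive integer
try(coxphControl(lhaz_prior = 0))  # error: lhaz_prior must be a positive numeric value

Using a warning rather than an error for an out-of-range cut_num keeps the constructor permissive while still steering users toward the 5-10 interval recommendation in the documentation.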
10 changes: 6 additions & 4 deletions R/fit_dat.R
@@ -322,7 +322,7 @@ fit_dat = function(dat, lambda0 = 0, lambda1 = 0,
}else{
vars = var_start
cov = var = matrix(vars, ncol = 1)
gamma = matrix(sqrt(var), ncol = 1)
gamma = matrix(sqrt(vars), ncol = 1)
}

if(trace >= 1){
@@ -569,8 +569,10 @@ fit_dat = function(dat, lambda0 = 0, lambda1 = 0,
out$warnings = "Error in M step: coefficient values diverged"
}
}else if(randInt_issue == 1){
warning("Error in model fit: Random intercept variance is too small, indicating that this model \n
should be fit using traditional generalized linear model techniques.", immediate. = TRUE)
warning("Error in model fit: Random intercept variance is too small, indicating either that
there are high correlations among the covariates (if so, consider reducing these correlations
or changing the Elastic Net alpha value) or that this model should be fit
using traditional generalized linear model techniques.", immediate. = TRUE)
out$warnings = "Error in model fit: random intercept variance becomes too small, model should be fit using regular generalized linear model techniques"
}

@@ -661,7 +663,7 @@ fit_dat = function(dat, lambda0 = 0, lambda1 = 0,
Znew2[group == j,seq(j, ncol(Z), by = d)] = Z[group == j,seq(j, ncol(Z), by = d)]%*%gamma
}

# Initial points for Metropolis within Gibbs E step algorithms
# Initial points for E step sampling algorithms
uold = as.numeric(u0[nrow(u0),])
# if random effect penalized out in past model / in previous M-step, do not
# collect posterior samples for this random effect
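The Znew2 assignment in this last hunk builds, group by group, the product Z * Gamma (Gamma = t(chol(sigma))) that E_step() takes as input. The self-contained sketch below uses made-up dimensions and covariance values purely to show the column layout and the per-group block multiplication.

# Illustration only: hypothetical dimensions and covariance
set.seed(2)
d <- 2; q <- 2                                   # 2 groups, 2 random effects
group <- c(1, 1, 1, 2, 2, 2)
n <- length(group)

# Columns are ordered effect by effect: columns 1..d hold random effect 1
# (one column per group), columns (d+1)..2d hold random effect 2, and so on,
# so group j owns columns seq(j, ncol(Z), by = d)
Z <- matrix(0, nrow = n, ncol = q * d)
for (j in 1:d) {
  Z[group == j, seq(j, ncol(Z), by = d)] <- matrix(rnorm(sum(group == j) * q), ncol = q)
}

sigma <- matrix(c(1.0, 0.3, 0.3, 0.5), nrow = 2) # random-effect covariance matrix
gamma <- t(chol(sigma))                          # lower-triangular factor

Znew2 <- matrix(0, nrow = n, ncol = ncol(Z))
for (j in 1:d) {
  Znew2[group == j, seq(j, ncol(Z), by = d)] <-
    Z[group == j, seq(j, ncol(Z), by = d)] %*% gamma
}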