Skip to content

Commit

Permalink
Update TIGER
Browse files Browse the repository at this point in the history
  • Loading branch information
cchen22 committed Sep 17, 2023
1 parent c1123dc commit 0432ddb
Show file tree
Hide file tree
Showing 2 changed files with 132 additions and 61 deletions.
185 changes: 125 additions & 60 deletions R/TIGER.R
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
#' @param prior A prior regulatory network in adjacency matrix format. Rows are TFs
#' and columns target genes.
#' @param method Method used for Bayesian inference. "VB" or "MCMC". Defaults to "VB".
#' @param TFexpressed TF mRNA needs to be expressed or not. Defaults to TRUE.
#' @param signed Prior network is signed or not. Defaults to TRUE.
#' @param baseline Include baseline or not. Defaults to TRUE.
#' @param psis_loo Use pareto smoothed importance sampling leave-one-out cross
Expand All @@ -39,6 +40,7 @@
#' @param b_alpha Hyperparameter of edge weight W. Default = 1.
#' @param sigmaZ Standard deviation of TF activity Z. Default = 10.
#' @param sigmaB Standard deviation of baseline term. Default = 1.
#' @param tol Convergence tolerance on ELBO.. Default = 0.005.
#'
#' @return A TIGER list object.
#' * W is the estimated regulatory network, but different from prior network,
Expand All @@ -54,22 +56,35 @@
#' data(TIGER_expr)
#' data(TIGER_prior)
#' TIGER(TIGER_expr,TIGER_prior)
TIGER = function(expr,prior,method="VB",
signed=TRUE,baseline=TRUE,psis_loo = FALSE,
seed=123,out_path=NULL,out_size = 300,
a_sigma=1,b_sigma=1,a_alpha=1,b_alpha=1,sigmaZ=10,sigmaB=1){
TIGER = function(expr,prior,method="VB",TFexpressed = TRUE,
signed=TRUE,baseline=TRUE,psis_loo = FALSE,
seed=123,out_path=NULL,out_size = 300,
a_sigma=1,b_sigma=1,a_alpha=1,b_alpha=1,
sigmaZ=10,sigmaB=1,tol = 0.005){
# check data
sample.name = colnames(expr)
TF.name = intersect(rownames(prior),rownames(expr)) # TF needs to express
TG.name = intersect(rownames(expr),colnames(prior))
if (TFexpressed){
TF.name = sort(intersect(rownames(prior),rownames(expr))) # TF needs to express
}else{
TF.name = sort(rownames(prior))
}
TG.name = sort(intersect(rownames(expr),colnames(prior)))
if (length(TG.name)==0 | length(TF.name)==0){
stop("No matched gene names in the two inputs...")
}

#0. prepare stan input
if (signed){
prior = prior.pp(prior[TF.name,TG.name],expr)
P = prior
prior2 = prior.pp(prior[TF.name,TG.name],expr)
if (nrow(prior2)!=length(TF.name)){
TFnotExp = setdiff(TF.name,rownames(prior2))
TFnotExpEdge = prior[TFnotExp,colnames(prior2),drop=F]
TFnotExpEdge[TFnotExpEdge==1] = 1e-6
prior2 = rbind(prior2,TFnotExpEdge)
prior2 = prior2[order(rownames(prior2)),]
prior2 = prior2[rowSums(prior2!=0)>0,] # remove all zero TFs
}
P = prior2
TF.name = rownames(P)
TG.name = colnames(P)
}else{
Expand All @@ -79,7 +94,7 @@ TIGER = function(expr,prior,method="VB",
n_genes = dim(X)[1]
n_samples = dim(X)[2]
n_TFs = dim(P)[1]

P = as.vector(t(P)) ## row=TG, col=TF
P_zero = as.array(which(P==0))
P_ones = as.array(which(P!=0))
Expand All @@ -95,58 +110,82 @@ TIGER = function(expr,prior,method="VB",
sign = as.integer(signed)
baseline = as.integer(baseline)
psis_loo = as.integer(psis_loo)

data_to_model = list(n_genes = n_genes, n_samples = n_samples, n_TFs = n_TFs,X = as.matrix(X), P = P,
P_zero = P_zero, P_ones = P_ones,P_negs = P_negs, P_poss = P_poss, P_blur = P_blur,
n_zero = n_zero,n_ones = n_ones,n_negs = n_negs, n_poss = n_poss, n_blur = n_blur,
n_all = n_all,sign = sign,baseline = baseline,psis_loo = psis_loo,
sigmaZ = sigmaZ, sigmaB = sigmaB,
a_sigma = a_sigma,b_sigma = b_sigma,a_alpha = a_alpha,b_alpha = b_alpha)

#1. compile stan model, only once
f = cmdstanr::write_stan_file(TIGER_C) # save to .stan file in root folder
#mod = cmdstanr::cmdstan_model(f,cpp_options = list(stan_threads = TRUE)) # compile stan program, allow within-chain parallel
mod = cmdstanr::cmdstan_model(f)

#2. run VB or MCMC
if (method=="VB"){
fit <- mod$variational(data = data_to_model, algorithm = "meanfield",seed = seed,
iter = 50000, tol_rel_obj = 0.005,output_samples = out_size)
iter = 50000, tol_rel_obj = tol,output_samples = out_size)
}else if (method=="MCMC"){
fit <- mod$sample(data = data_to_model,chains=1,seed = seed,max_treedepth=10,
iter_warmup = 1000,iter_sampling=out_size,adapt_delta=0.99)
}

## optional: save stan object
if (!is.null(out_path)){
fit$save_object(paste0(out_path,"fit_",seed,".rds"))
fit$save_output_files(out_path)
}

#3. posterior distributions

## point summary of W
W_sample = fit$draws("W",format = "draws_matrix") ## matrix, each row is a sample of vectorized matrix
one_ind = which(colSums(W_sample==0)<nrow(W_sample)) ## index for all-zero columns
W_pos = colMeans(W_sample[,one_ind]) ## average W
W_pos2 = rep(0,ncol(W_sample))
W_pos2[one_ind] = W_pos
W_pos = matrix(W_pos2,nrow = n_genes,ncol = n_TFs) ## convert to matrix genes*TFs


## point summary of W non-zero elements
print("Draw sample from W matrix...")
W_pos = rep(0,n_all)
if (signed){
W_negs = fit$summary("W_negs","mean")$mean
W_pos[P_negs] = W_negs
rm("W_negs")
gc()

W_poss = fit$summary("W_poss","mean")$mean
W_pos[P_poss] = W_poss
rm("W_poss")
gc()

W_blur = fit$summary("W_blur","mean")$mean
W_pos[P_blur] = W_blur
rm("W_blur")
gc()

}else{
W_ones = fit$summary("W_ones","mean")$mean
gc()
W_pos[P_ones] = W_ones
rm(list = c("W_ones"))
gc()
}
W_pos = matrix(W_pos,nrow = n_genes,ncol = n_TFs)
gc()

## point summary of Z
Z_sample = fit$draws("Z",format = "draws_matrix")
Z_pos = colMeans(Z_sample) ## average Z
print("Draw sample from Z matrix...")
Z_pos = fit$summary("Z","mean")$mean
gc()
Z_pos = matrix(Z_pos,nrow = n_TFs,ncol = n_samples) ## convert to matrix TFs*samples

gc()

## rescale
IZ = Z_pos*(apply(abs(W_pos),2,sum)/apply(W_pos!=0,2,sum))
IW = t(t(W_pos)*apply(Z_pos,1,sum)/n_samples)

## output
rownames(IW) = TG.name
colnames(IW) = TF.name
rownames(IZ) = TF.name
colnames(IZ) = sample.name

# check model fitting
if (psis_loo){
message("Pareto Smooth Importance Sampling...")
Expand All @@ -159,12 +198,13 @@ TIGER = function(expr,prior,method="VB",
loocv = NA
elpd_loo = NA
}

# output
tiger_fit = list(W = IW, Z = IZ,
TF.name = TF.name, TG.name = TG.name, sample.name = sample.name,
TF.name = TF.name, TG.name = TG.name,
sample.name = sample.name,
loocv = loocv, elpd_loo=elpd_loo)

return(tiger_fit)
}

Expand Down Expand Up @@ -225,7 +265,7 @@ el2regulon = function(el) {
tfmode = stats::setNames(regulon$weight, regulon$to)
list(tfmode = tfmode, likelihood = rep(1, length(tfmode)))
})

return(viper_regulons)
}

Expand Down Expand Up @@ -253,33 +293,33 @@ adj2regulon = function(adj){
#' @export
#'
prior.pp = function(prior,expr){

# filter tfs and tgs
tf = intersect(rownames(prior),rownames(expr)) ## TF needs to express
tg = intersect(colnames(prior),rownames(expr))
all.gene = unique(c(tf,tg))

# create coexp net
coexp = GeneNet::ggm.estimate.pcor(t(expr[all.gene,]), method = "static")
diag(coexp)= 0

# prior and coexp nets
P_ij = prior[tf,tg] ## prior ij
C_ij = coexp[tf,tg]*abs(P_ij) ## coexpression ij

# signs
sign_P = sign(P_ij) ## signs in prior
sign_C = sign(C_ij) ## signs in coexp

# blurred edge index
blurs = which((sign_P*sign_C)<0,arr.ind = T) ## inconsistent edges
P_ij[blurs] = 1e-6

# remove all zero TFs (in case prior has all zero TFs)
A_ij = P_ij
A_ij = A_ij[rowSums(A_ij!=0)>0,]
A_ij = A_ij[,colSums(A_ij!=0)>0]

return(A_ij)
}

Expand Down Expand Up @@ -309,7 +349,7 @@ prior.pp = function(prior,expr){

# stan model, conditional likelihood
TIGER_C <-
'
'
data {
int<lower=0> n_genes; // Number of genes
int<lower=0> n_samples; // Number of samples
Expand Down Expand Up @@ -358,35 +398,42 @@ parameters {
}
transformed parameters {
matrix[n_genes, n_TFs] W; // Regulatory netwrok W
vector[n_all] W_vec; // Regulatory vector W_vec
vector[sign ? n_negs : 0] W_negs;
vector[sign ? n_poss : 0] W_poss;
vector[sign ? n_blur : 0] W_blur;
vector[sign ? 0 : n_ones] W_ones;
W_vec[P_zero]=rep_vector(0,n_zero);
if (sign) {
vector[n_negs] W_negs = beta3.*sqrt(alpha3); // Regulatory network negative edge weight
vector[n_poss] W_poss = beta2.*sqrt(alpha2); // Regulatory network positive edge weight
vector[n_blur] W_blur = beta0.*sqrt(alpha0); // Regulatory network blurred edge weight
W_negs = beta3.*sqrt(alpha3); // Regulatory network negative edge weight
W_poss = beta2.*sqrt(alpha2); // Regulatory network positive edge weight
W_blur = beta0.*sqrt(alpha0); // Regulatory network blurred edge weight
}else{
W_ones = beta1.*sqrt(alpha1); // Regulatory network non-zero edge weight
}
}
model {
// local parameters
vector[n_all] W_vec; // Regulatory vector W_vec
W_vec[P_zero]=rep_vector(0,n_zero);
if (sign){
W_vec[P_negs]=W_negs;
W_vec[P_poss]=W_poss;
W_vec[P_blur]=W_blur;
}else{
vector[n_ones] W_ones = beta1.*sqrt(alpha1); // Regulatory network non-zero edge weight
W_vec[P_ones]=W_ones;
}
W = to_matrix(W_vec,n_genes,n_TFs); // by column
matrix[n_genes,n_samples] mu = W*Z; // mu for gene expression X
matrix[n_genes, n_TFs] W=to_matrix(W_vec,n_genes,n_TFs); // by column
matrix[n_genes,n_samples] mu=W*Z; // mu for gene expression X
if (baseline){
matrix[n_genes,n_samples] mu0; // baseline
mu0 = rep_matrix(b0,n_samples);
mu = mu + mu0;
matrix[n_genes,n_samples] mu0=rep_matrix(b0,n_samples);
mu=mu + mu0;
}
vector[n_genes*n_samples] X_mu = to_vector(mu);
vector[n_genes*n_samples] X_sigma = to_vector(rep_matrix(sqrt(sigma2),n_samples));
}
model {
// priors
sigma2 ~ inv_gamma(a_sigma,b_sigma);
Expand Down Expand Up @@ -417,10 +464,28 @@ model {
}
generated quantities {
vector[n_genes*n_samples] log_lik;
vector[psis_loo ? n_genes*n_samples : 0] log_lik;
if (psis_loo){
// redefine X_mu, X_sigma; this is ugly because X_mu, X_sigma are temp variables
vector[n_all] W_vec; // Regulatory vector W_vec
W_vec[P_zero]=rep_vector(0,n_zero);
if (sign){
W_vec[P_negs]=W_negs;
W_vec[P_poss]=W_poss;
W_vec[P_blur]=W_blur;
}else{
W_vec[P_ones]=W_ones;
}
matrix[n_genes, n_TFs] W=to_matrix(W_vec,n_genes,n_TFs); // by column
matrix[n_genes,n_samples] mu=W*Z; // mu for gene expression X
if (baseline){
matrix[n_genes,n_samples] mu0=rep_matrix(b0,n_samples);
mu=mu + mu0;
}
vector[n_genes*n_samples] X_mu = to_vector(mu);
vector[n_genes*n_samples] X_sigma = to_vector(rep_matrix(sqrt(sigma2),n_samples));
// leave one element out
for (i in 1:n_genes*n_samples){
log_lik[i] = normal_lpdf(X_vec[i]|X_mu[i],X_sigma[i]);
Expand Down
8 changes: 7 additions & 1 deletion man/TIGER.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit 0432ddb

Please sign in to comment.