diff --git a/R/ComBat.R b/R/ComBat.R index 9d6e100..aed307c 100644 --- a/R/ComBat.R +++ b/R/ComBat.R @@ -7,7 +7,9 @@ #' #' @param dat Genomic measure matrix (dimensions probe x sample) - for example, expression matrix #' @param batch {Batch covariate (only one batch allowed)} -#' @param mod Model matrix for outcome of interest and other covariates besides batch +#' @param dat_test (Optional) Independent test data of the same kind as dat (e.g. an expression matrix) and with the same rows. Useful for machine learning applications. When supplied, ComBat returns a list with elements tr_data and te_data instead of a single matrix. +#' @param batch_test (Optional) Batch covariate of dat_test. +#' @param mod (Optional) Model matrix for outcome of interest and other covariates besides batch #' @param par.prior (Optional) TRUE indicates parametric adjustments will be used, FALSE indicates non-parametric adjustments will be used #' @param prior.plots (Optional) TRUE give prior plots with black as a kernel estimate of the empirical batch effect density and red as the parametric #' @param mean.only (Optional) FALSE If TRUE ComBat only corrects the mean of the batch effect (no scale adjustment) @@ -40,267 +42,339 @@ #' # reference-batch version, with covariates #' combat_edata3 = ComBat(dat=edata, batch=batch, mod=mod, par.prior=TRUE, ref.batch=3) #' +#' # Training and test +#' te_idx <- c(1:4,9:11,18:22,47:50) #Test indices +#' combat_edata4 = ComBat(dat=edata[,-te_idx], batch=batch[-te_idx],dat_test=edata[,te_idx], batch_test=batch[te_idx]) +#' #' @export #' -ComBat <- function(dat, batch, mod = NULL, par.prior = TRUE, prior.plots = FALSE, - mean.only = FALSE, ref.batch = NULL, BPPARAM = bpparam("SerialParam")) { - if(length(dim(batch))>1){ - stop("This version of ComBat only allows one batch variable") - } ## to be updated soon! 
- - ## coerce dat into a matrix - dat <- as.matrix(dat) - - ## find genes with zero variance in any of the batches - batch <- as.factor(batch) - zero.rows.lst <- lapply(levels(batch), function(batch_level){ - if(sum(batch==batch_level)>1){ - return(which(apply(dat[, batch==batch_level], 1, function(x){var(x)==0}))) - }else{ - return(which(rep(1,3)==2)) - } - }) - zero.rows <- Reduce(union, zero.rows.lst) - keep.rows <- setdiff(1:nrow(dat), zero.rows) - - if (length(zero.rows) > 0) { - cat(sprintf("Found %d genes with uniform expression within a single batch (all zeros); these will not be adjusted for batch.\n", length(zero.rows))) - # keep a copy of the original data matrix and remove zero var rows - dat.orig <- dat - dat <- dat[keep.rows, ] - } + + + +ComBat <- function(dat, batch, dat_test=NULL, batch_test=NULL, + mod = NULL, par.prior = TRUE, prior.plots = FALSE, mean.only = FALSE, + ref.batch = NULL, BPPARAM = bpparam("SerialParam")) { - ## make batch a factor and make a set of indicators for batch - if(any(table(batch)==1)){mean.only=TRUE} - if(mean.only==TRUE){ - message("Using the 'mean only' version of ComBat") + + if(length(dim(batch))>1||length(dim(batch_test))>1){ + stop("This version of ComBat only allows one batch variable") + } ## to be updated soon! 
+ + + ## coerce dat into a matrix + dat <- as.matrix(dat) + + ## find genes with zero variance in any of the batches + batch <- as.factor(batch) + zero.rows.lst <- lapply(levels(batch), function(batch_level){ + if(sum(batch==batch_level)>1){ + return(which(apply(dat[, batch==batch_level], 1, function(x){var(x)==0}))) + }else{ + return(which(rep(1,3)==2)) } - - batchmod <- model.matrix(~-1+batch) - if (!is.null(ref.batch)){ - ## check for reference batch, check value, and make appropriate changes - if (!(ref.batch%in%levels(batch))) { - stop("reference level ref.batch is not one of the levels of the batch variable") - } - message("Using batch =",ref.batch, "as a reference batch (this batch won't change)") - ref <- which(levels(as.factor(batch))==ref.batch) # find the reference - batchmod[,ref] <- 1 - } else { - ref <- NULL + }) + zero.rows <- Reduce(union, zero.rows.lst) + keep.rows <- setdiff(1:nrow(dat), zero.rows) + + if (length(zero.rows) > 0) { + cat(sprintf("Found %d genes with uniform expression within a single batch (all zeros); these will not be adjusted for batch.\n", length(zero.rows))) + # keep a copy of the original data matrix and remove zero var rows + dat.orig <- dat + dat <- dat[keep.rows, ] + } + + ## make batch a factor and make a set of indicators for batch + if(any(table(batch)==1)){mean.only=TRUE} + if(mean.only==TRUE){ + message("Using the 'mean only' version of ComBat") + } + + batchmod <- model.matrix(~-1+batch) + if (!is.null(ref.batch)){ + ## check for reference batch, check value, and make appropriate changes + if (!(ref.batch%in%levels(batch))) { + stop("reference level ref.batch is not one of the levels of the batch variable") } - message("Found", nlevels(batch), "batches") + message("Using batch =",ref.batch, "as a reference batch (this batch won't change)") + ref <- which(levels(as.factor(batch))==ref.batch) # find the reference + batchmod[,ref] <- 1 + } else { + ref <- NULL + } + message("Found", nlevels(batch), "batches") - ## A few 
other characteristics on the batches - n.batch <- nlevels(batch) - batches <- list() - for (i in 1:n.batch) { - batches[[i]] <- which(batch == levels(batch)[i]) - } # list of samples in each batch - n.batches <- sapply(batches, length) - if(any(n.batches==1)){ - mean.only=TRUE - message("Note: one batch has only one sample, setting mean.only=TRUE") + ## A few other characteristics on the batches + n.batch <- nlevels(batch) + batches <- list() + for (i in 1:n.batch) batches[[i]] <- which(batch == levels(batch)[i]) + # list of samples in each batch + n.batches <- sapply(batches, length) + if(any(n.batches==1)){ + mean.only=TRUE + message("Note: one batch has only one sample, setting mean.only=TRUE") + } + n.array <- sum(n.batches) + ## combine batch variable and covariates + design <- cbind(batchmod,mod) + + ## check for intercept in covariates, and drop if present + check <- apply(design, 2, function(x) all(x == 1)) + if(!is.null(ref)){ + check[ref] <- FALSE + } ## except don't throw away the reference batch indicator + design <- as.matrix(design[,!check]) + + ## Number of covariates or covariate levels + message("Adjusting for", ncol(design)-ncol(batchmod), 'covariate(s) or covariate level(s)') + + ## Check if the design is confounded + if(qr(design)$rank < ncol(design)) { + ## if(ncol(design)<=(n.batch)){stop("Batch variables are redundant! Remove one or more of the batch variables so they are no longer confounded")} + if(ncol(design)==(n.batch+1)) { + stop("The covariate is confounded with batch! 
Remove the covariate and rerun ComBat") } - n.array <- sum(n.batches) - ## combine batch variable and covariates - design <- cbind(batchmod,mod) - - ## check for intercept in covariates, and drop if present - check <- apply(design, 2, function(x) all(x == 1)) - if(!is.null(ref)){ - check[ref] <- FALSE - } ## except don't throw away the reference batch indicator - design <- as.matrix(design[,!check]) - - ## Number of covariates or covariate levels - message("Adjusting for", ncol(design)-ncol(batchmod), 'covariate(s) or covariate level(s)') - - ## Check if the design is confounded - if(qr(design)$rank < ncol(design)) { - ## if(ncol(design)<=(n.batch)){stop("Batch variables are redundant! Remove one or more of the batch variables so they are no longer confounded")} - if(ncol(design)==(n.batch+1)) { - stop("The covariate is confounded with batch! Remove the covariate and rerun ComBat") - } - if(ncol(design)>(n.batch+1)) { - if((qr(design[,-c(1:n.batch)])$rank(n.batch+1)) { + if((qr(design[,-c(1:n.batch)])$rank 0) { + dat.orig[keep.rows, ] <- bayesdata + bayesdata <- dat.orig + } + + if(!is.null(dat_test)) { + + dat2 <- as.matrix(dat_test) + if( (nlevels(as.factor(batch_test)) > nlevels (batch)) | ( length(intersect(levels(as.factor(batch_test)), levels(batch))) <2 )) { + stop("batch_test should contain at least two batches in common with batch") } - if(!is.null(ref.batch)){ - gamma.star[ref,] <- 0 ## set reference batch mean equal to 0 - delta.star[ref,] <- 1 ## set reference batch variance equal to 1 + batch2 <- factor(batch_test,levels=levels(batch)) + + if (length(zero.rows) > 0) { + # keep a copy of the original data matrix and remove zero var rows + dat.orig2 <- dat2 + dat2 <- dat2[keep.rows, ] } + batchmod2 <- model.matrix(~-1+batch2) + ## A few other characteristics on the batches + batches2 <- list() + for (i in 1:n.batch) batches2[[i]] <- which(batch2 == levels(batch2)[i]) + n.batches2 <- sapply(batches2, length) + n.array2 <- sum(n.batches2) + ## combine 
batch variable and covariates + design2 <- cbind(batchmod2,mod) + design2 <- as.matrix(design2[,!check]) + + stand.mean2 <- t(grand.mean) %*% t(rep(1,n.array2)) - ## Normalize the Data ### - message("Adjusting the Data\n") - - bayesdata <- s.data - j <- 1 - for (i in batches){ - bayesdata[,i] <- (bayesdata[,i]-t(batch.design[i,]%*%gamma.star))/(sqrt(delta.star[j,])%*%t(rep(1,n.batches[j]))) # FIXME - j <- j+1 + if(!is.null(design)){ + tmp <- design2 + tmp[,c(1:n.batch)] <- 0 + stand.mean2 <- stand.mean2+t(tmp %*% B.hat) + } + + s.data2 <- (dat2-stand.mean2)/(sqrt(var.pooled) %*% t(rep(1,n.array2))) + + ##Get regression batch effect parameters + batch.design2 <- design2[, 1:n.batch] + + bayesdata2 <- s.data2 + + for (j in 1:n.batch){ + i <- batches2[[j]] + bayesdata2[,i] <- (bayesdata2[,i]-t(batch.design2[i,]%*%gamma.star))/(sqrt(delta.star[j,])%*%t(rep(1,n.batches2[j]))) # FIXME } - bayesdata <- (bayesdata*(sqrt(var.pooled)%*%t(rep(1,n.array))))+stand.mean # FIXME - + bayesdata2 <- (bayesdata2*(sqrt(var.pooled)%*%t(rep(1,n.array2))))+stand.mean2 # FIXME + ## Do not change ref batch at all in reference version if(!is.null(ref.batch)){ - bayesdata[, batches[[ref]]] <- dat[, batches[[ref]]] + bayesdata2[, batches2[[ref]]] <- dat2[, batches2[[ref]]] } - ## put genes with 0 variance in any batch back in data if (length(zero.rows) > 0) { - dat.orig[keep.rows, ] <- bayesdata - bayesdata <- dat.orig + dat.orig2[keep.rows, ] <- bayesdata2 + bayesdata2 <- dat.orig2 } - + return(list(tr_data=bayesdata,te_data=bayesdata2)) + } else { return(bayesdata) + } }