From 3869fb0b5b8fd10bb8fa56522175cfd9160eb407 Mon Sep 17 00:00:00 2001 From: Max Bladen <60872845+Max-Bladen@users.noreply.github.com> Date: Thu, 17 Nov 2022 11:59:32 +1100 Subject: [PATCH] Fix for Issue #268 (#269) fix: improved nzv feature handling for block contexts, particularly via `auroc()`. Filtration applied more consistently via `Check.entry.wrapper.mint.block()`. Additional failsafe added here for zero variance features. `predict()` also now checks to see if filtration has been applied to prevent it applying filtering twice. tests: adjusted new test to ensure it passes --- R/check_entry.R | 43 +++++++++++++++++++++---------------- R/predict.R | 9 +++++++- tests/testthat/test-auroc.R | 5 +++-- 3 files changed, 36 insertions(+), 21 deletions(-) diff --git a/R/check_entry.R b/R/check_entry.R index c08bbdb9..2a393796 100644 --- a/R/check_entry.R +++ b/R/check_entry.R @@ -663,30 +663,37 @@ Check.entry.wrapper.mint.block = function(X, nzv.A = lapply(A, nearZeroVar) for(q in 1:length(A)) { - if (length(nzv.A[[q]]$Position) > 0 &&(!DA & q == indY)) - { - names.remove.X = colnames(A[[q]])[nzv.A[[q]]$Position] - A[[q]] = A[[q]][, -nzv.A[[q]]$Position, drop=FALSE] - #if (verbose) - #warning("Zero- or near-zero variance predictors.\n - #Reset predictors matrix to not near-zero variance predictors.\n - # See $nzv for problematic predictors.") - if (ncol(A[[q]]) == 0) - stop(paste0("No more variables in",A[[q]])) - - #need to check that the keepA[[q]] is now not higher than ncol(A[[q]]) - if (any(keepA[[q]] > ncol(A[[q]]))) - { - ind = which(keepA[[q]] > ncol(A[[q]])) - keepA[[q]][ind] = ncol(A[[q]]) - } - } + if (length(nzv.A[[q]]$Position) <= 0) { next } + if (DA && q == indY) { next } + + names.remove.X = colnames(A[[q]])[nzv.A[[q]]$Position] + A[[q]] = A[[q]][, -nzv.A[[q]]$Position, drop=FALSE] + #if (verbose) + #warning("Zero- or near-zero variance predictors.\n + #Reset predictors matrix to not near-zero variance predictors.\n + # See $nzv for problematic 
predictors.") + if (ncol(A[[q]]) == 0) + stop(paste0("No more variables in",A[[q]])) + + #need to check that the keepA[[q]] is now not higher than ncol(A[[q]]) + if (any(keepA[[q]] > ncol(A[[q]]))) + { + ind = which(keepA[[q]] > ncol(A[[q]])) + keepA[[q]][ind] = ncol(A[[q]]) + } } } else { nzv.A=NULL } + for(q in 1:length(A)) + { + vars <- apply(A[[q]], 2, sd)^2 + if (length(which(vars==0)) >0) { + stop(sprintf("There are features with zero variance in block '%s'. If nearZeroVar() function or 'near.zero.var' parameter hasn't been used, please use it. If you have used one of these, you may need to manually filter out these features.", names(A)[q]), call.=F) + } + } return(list(A=A, ncomp=ncomp, study=study, keepA=keepA, indY=indY, design=design, init=init, nzv.A=nzv.A)) } diff --git a/R/predict.R b/R/predict.R index 1c26b8d5..0dda09b2 100644 --- a/R/predict.R +++ b/R/predict.R @@ -317,7 +317,14 @@ predict.mixo_pls <- # deal with near.zero.var in object, to remove the same variable in newdata as in object$X (already removed in object$X) if(!is.null(object$nzv)) { - newdata = lapply(1:(length(object$nzv)-1),function(x){if(length(object$nzv[[x]]$Position>0)) {newdata[[x]][, -object$nzv[[x]]$Position,drop=FALSE]}else{newdata[[x]]}}) + # for each of the input blocks, checks to see if the nzv features have already been removed + # if not, then these features are removed here + for (x in 1:length(newdata)) { + if (nrow(object$nzv[[x]]$Metrics) == 0) { next } + if (all(!(rownames(object$nzv[[x]]$Metrics) %in% colnames(newdata[[x]])))) { next } + + newdata[[x]] <- newdata[[x]][, -object$nzv[[x]]$Position,drop=FALSE] + } } if(length(newdata)!=length(object$X)) stop("'newdata' must have as many blocks as 'object$X'") diff --git a/tests/testthat/test-auroc.R b/tests/testthat/test-auroc.R index 74981fc2..fed81340 100644 --- a/tests/testthat/test-auroc.R +++ b/tests/testthat/test-auroc.R @@ -34,11 +34,12 @@ test_that("Safely handles zero var (non-zero center) features", { 
list.keepX <- list(block1=c(15, 15), block2=c(30,30)) + set.seed(9425) X$block1[,1] <- rep(1, 100) model = suppressWarnings(block.splsda(X = X, Y = Y, ncomp = 2, - keepX = list.keepX, design = "full")) + keepX = list.keepX, design = "full", + near.zero.var = T)) - set.seed(9425) auc.splsda = .quiet(auroc(model)) .expect_numerically_close(auc.splsda$block1$comp1[[1]], 0.815)