Skip to content

Commit

Permalink
incorporate scaling
Browse files Browse the repository at this point in the history
  • Loading branch information
thomvolker committed Aug 23, 2024
1 parent e605a01 commit 56d855c
Show file tree
Hide file tree
Showing 19 changed files with 99 additions and 107 deletions.
60 changes: 37 additions & 23 deletions R/checks.R
Original file line number Diff line number Diff line change
Expand Up @@ -6,31 +6,35 @@
# Coerce the input to a data.frame and reject data containing missing values.
#
# data: a vector (wrapped into a one-column data.frame with column `x`) or any
#       object accepted by as.data.frame().
# Returns the resulting data.frame; stops with an error if any value is NA.
check.datatype <- function(data) {
  data <- if (is.vector(data)) {
    data.frame(x = data)
  } else {
    as.data.frame(data)
  }

  # Incomplete data is not supported, so fail early with a clear message.
  has_missing <- any(is.na(data))
  if (has_missing) {
    stop("Missing data can currently not be handled, please solve the missing data problem first.")
  }

  data
}

check.dataform <- function(nu, de, centers = NULL, nullcenters, newdata = NULL, scale) {

numvars <- which(sapply(nu, is.numeric))
numvars_de <- which(sapply(de, is.numeric))
if (!is.null(centers)) {
numvars_ce <- which(sapply(centers, is.numeric))
alldat <- rbind(nu, de, centers)
ind <- c(rep("nu", nrow(nu)), rep("de", nrow(de)), rep("ce", nrow(centers)))
} else {
alldat <- rbind(nu, de)
ind <- c(rep("nu", nrow(nu)), rep("de", nrow(de)))
}

scale <- match.arg(scale, c("numerator", "denominator", FALSE))
if (!is.null(scale)) {
scale <- match.arg(scale, c("numerator", "denominator"))
}

scaledat <- if (scale == "numerator") {
scaledat <- if (identical(scale, "numerator")) {
nu[, numvars, drop = FALSE]
} else if (scale == "denominator") {
} else if (identical(scale, "denominator")) {
de[, numvars, drop = FALSE]
}

if (scale != FALSE) {
if (!is.null(scale)) {
if (!nullcenters) {
warning("Note that you provided centers while also applying scaling to the variables. The centers are scaled accordingly.")
}
Expand All @@ -51,7 +55,7 @@ check.dataform <- function(nu, de, centers = NULL, nullcenters, newdata = NULL,

if (!is.null(newdata)) {
newdata <- check.datatype(newdata)
if (scale != FALSE) {
if (!is.null(scale)) {
newdata[, numvars] <- scale(newdata[, numvars], center = means, scale = sds) |> as.data.frame()
}
alldat <- rbind(alldat, newdata)
Expand All @@ -76,17 +80,21 @@ check.variables <- function(nu, de, ce = NULL) {
numvars_nu <- which(sapply(nu, is.numeric))
numvars_de <- which(sapply(de, is.numeric))

if (!all(numvars_nu == numvars_de) |
ncol(nu) != ncol(de) |
!all(colnames(nu) == colnames(de))) {
if (
!identical(numvars_nu, numvars_de) |
!identical(ncol(nu), ncol(de)) |
!identical(colnames(nu), colnames(de))
) {
stop("The numerator and denominator data must contain the same variables.")
}
if (!is.null(ce)) {
ce <- check.datatype(ce)
numvars_ce <- which(sapply(ce, is.numeric))
if (!all(numvars_nu == numvars_ce) |
ncol(nu) != ncol(ce) |
!all(colnames(nu) == colnames(ce))) {
if (
!identical(numvars_nu, numvars_ce) |
!identical(ncol(nu), ncol(ce)) |
!identical(colnames(nu), colnames(ce))
) {
stop("The data and centers must contain the same variables.")
}
}
Expand All @@ -110,7 +118,8 @@ check.sigma <- function(nsigma, sigma_quantile, sigma, dist_nu) {
stop("If 'sigma_quantile' is specified, the values must be larger than 0 and smaller than 1.")
} else {
p <- sigma_quantile
sigma <- stats::quantile(dist_nu, p) |> sqrt()
sigma <- stats::quantile(dist_nu, p)
sigma <- sqrt(sigma/2)
}
}
# if both sigma and sigma_quantile are not specified, specify the quantiles linearly, based on nsigma
Expand All @@ -122,10 +131,12 @@ check.sigma <- function(nsigma, sigma_quantile, sigma, dist_nu) {
stop("'nsigma' must be a positive scalar.")
}
else if (nsigma == 1) {
sigma <- stats::median(dist_nu) |> sqrt()
sigma <- stats::median(dist_nu)
sigma <- sqrt(sigma/2)
} else {
p <- seq(0.05, 0.95, length.out = nsigma)
sigma <- stats::quantile(dist_nu, p) |> sqrt()
sigma <- stats::quantile(dist_nu, p)
sigma <- sqrt(sigma/2)
}
}
}
Expand Down Expand Up @@ -182,19 +193,19 @@ check.lambda <- function(nlambda, lambda) {
lambda
}

check.centers <- function(nu, centers, ncenters) {
check.centers <- function(dat, centers, ncenters) {

if (!is.null(centers)) {
centers <- check.datatype(centers)
} else {
if (!is.numeric(ncenters) | length(ncenters) != 1 | ncenters < 1) {
stop("The 'ncenters' parameter must be a positive numeric scalar.")
} else if (ncenters == nrow(nu)) {
centers <- nu
} else if (ncenters > nrow(nu)) {
centers <- nu
} else if (ncenters == nrow(dat)) {
centers <- dat
} else if (ncenters > nrow(dat)) {
centers <- dat
} else {
centers <- nu[sample(nrow(nu), ncenters), , drop = FALSE]
centers <- dat[sample(nrow(dat), ncenters), , drop = FALSE]
}
}
centers
Expand Down Expand Up @@ -424,7 +435,10 @@ check.subspace.spectral <- function(J, cv_ind_de) {

check.newdata <- function(object, newdata) {
if (!is.null(newdata)) {
check.variables(object$df_numerator, newdata)
check.variables(
check.datatype(object$df_numerator),
check.datatype(newdata)
)
newdata <- check.dataform(
check.datatype(object$df_numerator),
check.datatype(object$df_denominator),
Expand Down
6 changes: 3 additions & 3 deletions R/kliep.R
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
#' @param df_denominator \code{data.frame} with exclusively numeric variables
#' with the denominator samples (must have the same variables as
#' \code{df_numerator})
#' @param scale \code{"numerator"}, \code{"denominator"}, or \code{FALSE},
#' @param scale \code{"numerator"}, \code{"denominator"}, or \code{NULL},
#' indicating whether to standardize each numeric variable according to the
#' numerator means and standard deviations, the denominator means and standard
#' deviations, or apply no standardization at all.
Expand Down Expand Up @@ -45,15 +45,15 @@
kliep <- function(df_numerator, df_denominator, scale = "numerator", nsigma = 10,
sigma_quantile = NULL, sigma = NULL, ncenters = 200,
centers = NULL, cv = TRUE, nfold = 5, epsilon = NULL,
maxit = 2000, progressbar = TRUE) {
maxit = 5000, progressbar = TRUE) {

cl <- match.call()
nu <- check.datatype(df_numerator)
de <- check.datatype(df_denominator)

check.variables(nu, de, centers)

df_centers <- check.centers(nu, centers, ncenters)
df_centers <- check.centers(rbind(nu, de), centers, ncenters)
dat <- check.dataform(nu, de, df_centers, is.null(centers), NULL, scale)

nnu <- nrow(dat$nu)
Expand Down
3 changes: 2 additions & 1 deletion R/kmm.R
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
#' @param df_denominator \code{data.frame} with exclusively numeric variables
#' with the denominator samples (must have the same variables as
#' \code{df_numerator})
#' @param scale \code{"numerator"}, \code{"denominator"}, or \code{FALSE},
#' @param scale \code{"numerator"}, \code{"denominator"}, or \code{NULL},
#' indicating whether to standardize each numeric variable according to the
#' numerator means and standard deviations, the denominator means and standard
#' deviations, or apply no standardization at all.
Expand Down Expand Up @@ -84,6 +84,7 @@ kmm <- function(df_numerator, df_denominator, scale = "numerator",
out <- list(rhat_de = rhat_de,
sigma = sigma,
lambda = lambda,
scale = scale,
call = cl)
class(out) <- "kmm"

Expand Down
4 changes: 2 additions & 2 deletions R/lhss.R
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
#' @param m Scalar indicating the dimensionality of the reduced subspace
#' @param intercept \code{logical} Indicating whether to include an intercept
#' term in the model. Defaults to \code{TRUE}.
#' @param scale \code{"numerator"}, \code{"denominator"}, or \code{FALSE},
#' @param scale \code{"numerator"}, \code{"denominator"}, or \code{NULL},
#' indicating whether to standardize each numeric variable according to the
#' numerator means and standard deviations, the denominator means and standard
#' deviations, or apply no standardization at all.
Expand Down Expand Up @@ -64,7 +64,7 @@ lhss <- function(df_numerator, df_denominator, m = NULL, intercept = TRUE,

check.variables(nu, de, centers)

df_centers <- check.centers(nu, centers, ncenters)
df_centers <- check.centers(rbind(nu, de), centers, ncenters)
dat <- check.dataform(nu, de, df_centers, is.null(centers), NULL, scale)

p <- ncol(dat$nu)
Expand Down
8 changes: 2 additions & 6 deletions R/naive.R
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,6 @@
#' @param df_denominator \code{data.frame} with exclusively numeric variables
#' with the denominator samples (must have the same variables as
#' \code{df_numerator})
#' @param scale \code{"numerator"}, \code{"denominator"}, or \code{FALSE},
#' indicating whether to standardize each numeric variable according to the
#' numerator means and standard deviations, the denominator means and standard
#' deviations, or apply no standardization at all.
#' @param n \code{integer} the number of equally spaced points at which the density is
#' estimated. When n > 512, it is rounded up to a power of 2 during the
#' calculations (as fft is used) and the final result is interpolated by
Expand All @@ -34,14 +30,14 @@
#' naive(x, y, bw = 2)
#'
#' @export
naive <- function(df_numerator, df_denominator, scale = "numerator", n = 2L^11, ...) {
naive <- function(df_numerator, df_denominator, n = 2L^11, ...) {
cl <- match.call()

nu <- check.datatype(df_numerator)
de <- check.datatype(df_denominator)

check.variables(nu, de)
dat <- check.dataform(nu, de, nu, TRUE, NULL, scale)
dat <- check.dataform(nu, de, NULL, TRUE, NULL, NULL)

P <- ncol(dat$nu)

Expand Down
19 changes: 9 additions & 10 deletions R/naive_subspace.R
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,6 @@
#' with the denominator samples (must have the same variables as
#' \code{df_numerator})
#' @param m The size (in number of features) of the subspace
#' @param scale \code{"numerator"}, \code{"denominator"}, or \code{FALSE},
#' indicating whether to standardize each numeric variable according to the
#' numerator means and standard deviations, the denominator means and standard
#' deviations, or apply no standardization at all.
#' @param n the number of equally spaced points at which the density is to be
#' estimated. When n > 512, it is rounded up to a power of 2 during the
#' calculations (as fft is used) and the final result is interpolated by
Expand Down Expand Up @@ -48,22 +44,24 @@
#' lines(df_new[,1], predict(dr_subspace, df_new), col = "darkorange")
#'
#' @export
naivesubspace <- function(df_numerator, df_denominator, m = NULL,
scale = "numerator", n = 2L^11, ...) {
naivesubspace <- function(df_numerator, df_denominator, m = NULL, n = 2L^11, ...) {
cl <- match.call()
nu <- check.datatype(df_numerator)
de <- check.datatype(df_denominator)

check.variables(nu, de)

dat <- check.dataform(nu, de, nu, TRUE, NULL, scale)
dat <- check.dataform(nu, de, NULL, TRUE, NULL, NULL)

m <- check.subspace(m, ncol(dat$nu))

# first, use svd to compute m-dimensional subspace
V <- svd(dat$de, nu = m, nv = m)$v
de_proj <- dat$de %*% V
nu_proj <- dat$nu %*% V
nu_centered <- scale(dat$nu, scale = FALSE)
center <- attr(nu_centered, "scaled:center")
V <- svd(nu_centered, nu = m, nv = m)$v

nu_proj <- nu_centered %*% V
de_proj <- scale(dat$de, center = center, scale = FALSE) %*% V

# then, perform naive density ratio estimation
d_nu <- lapply(1:m, \(p) density(nu_proj[,p], n = n, ...))
Expand All @@ -74,6 +72,7 @@ naivesubspace <- function(df_numerator, df_denominator, m = NULL,
df_numerator = df_numerator,
df_denominator = df_denominator,
projection_matrix = V,
center = center,
subspace_dim = m,
model_matrices = list(nu = dat$nu, de = dat$de),
density_numerator = d_nu,
Expand Down
Loading

0 comments on commit 56d855c

Please sign in to comment.