Skip to content

Commit

Permalink
Added function for lower bounds (dummy + entropic)
Browse files Browse the repository at this point in the history
  • Loading branch information
pietrocipolla committed Jun 21, 2024
1 parent c347e3a commit c1b2d21
Show file tree
Hide file tree
Showing 4 changed files with 359 additions and 1 deletion.
1 change: 1 addition & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

S3method(plot,gsaot_indices)
S3method(print,gsaot_indices)
export(lower_bound)
export(ot_indices)
export(ot_indices_1d)
export(ot_indices_smap)
Expand Down
14 changes: 13 additions & 1 deletion R/gsaot_indices.R
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,8 @@ print.gsaot_indices <- function(x, data = FALSE, ...) {
#'
plot.gsaot_indices <- function(x,
ranking = NULL,
wb_all = FALSE, ...) {
wb_all = FALSE,
dummy = NULL, ...) {
# If ranking is defined, plot only the selected inputs
K <- nrow(x$indices)

Expand All @@ -169,6 +170,13 @@ plot.gsaot_indices <- function(x,
} else
inputs_to_plot <- seq(K)

# Find the threshold to be plotted if dummy is defined
if (!(class(dummy) == "gsaot_indices") & !is.double(dummy) & !is.null(dummy))
stop("`dummy` should be an object of class `gsaot_indices` or a double")

if (!is.null(dummy) & !is.double(dummy))
dummy <- dummy$indices[1]

# If the indices are not from ot_indices_wb, print only the indices
# Otherwise, plot the indices, the advective component and the diffusive one
if (!exists("adv", where = x) | !wb_all) {
Expand Down Expand Up @@ -245,5 +253,9 @@ plot.gsaot_indices <- function(x,
}
}

if (!is.null(dummy))
p <- p +
geom_hline(yintercept = dummy, linetype = 2, color = "red")

return(p)
}
234 changes: 234 additions & 0 deletions R/utils.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,234 @@
#' Calculate lower bounds for Optimal Transport sensitivity indices
#'
#' @inheritParams ot_indices
#' @param y An array or a matrix containing the output values.
#' @param bound (default `"dummy"`) A string defining the type of lower bound to
#' compute. Should be `"dummy"` or `"entropic"`. See `details` for more
#' information.
#' @param dummy_optns (default `NULL`) A list containing the options on the
#' distribution of the dummy variable. See `details` for more information.
#'
#' @details The function allows the computation of two different lower bounds.
#' With `bound="dummy"`, the function samples from a distribution defined in
#' `dummy_optns` (by default a standard normal), independent from the output
#' `y` and then computes the indices using the algorithm specified in
#' `solver`. Under the hood, `lower_bound` calls the other available functions
#' in the package:
#' * [ot_indices_1d()] (for `solver="1d`)
#' * [ot_indices_wb()] (for `solver="wasserstein-bures"`)
#' * [ot_indices()] (for `solver %in% c("sinkhorn", "sinkhorn_log", "wasserstein")`)
#' The user can choose the distribution of the dummy variable using the
#' argument `dummy_optns`. `dummy_optns` should be a named list with at least
#' a term called `"distr"` defining the sampling function. The other terms in
#' the list are used as arguments to the sampling function. With
#' `bound="entropic"`, the function computes the lower bound of the entropic
#' indices. In this case, `solver` should be either `"sinkhorn"` or
#' `"sinkhorn_log"`
#'
#' @return An object of class `gsaot_indices` for `bound="dummy"` or a scalar
#' for `bound="entropic"`.
#' @export
#'
#' @examples
#' N <- 1000
#'
#' mx <- c(1, 1, 1)
#' Sigmax <- matrix(data = c(1, 0.5, 0.5, 0.5, 1, 0.5, 0.5, 0.5, 1), nrow = 3)
#'
#' x1 <- rnorm(N)
#' x2 <- rnorm(N)
#' x3 <- rnorm(N)
#'
#' x <- cbind(x1, x2, x3)
#' x <- mx + x %*% chol(Sigmax)
#'
#' A <- matrix(data = c(4, -2, 1, 2, 5, -1), nrow = 2, byrow = TRUE)
#' y <- t(A %*% t(x))
#'
#' M <- 25
#'
#' sink_lb <- lower_bound(y, M, bound = "entropic")
#' dummy_lb <- lower_bound(y, M, bound = "dummy")
#'
#' # Custom sampling funtion and network simplex solver
#' dummy_optns <- list(distr = "rgamma", shape = 3)
#' dummy_lb_cust <- lower_bound(y, M, bound = "dummy",
#' dummy_optns = dummy_optns,
#' solver = "wasserstein")
lower_bound <- function(y,
M,
bound = "dummy",
dummy_optns = NULL,
cost = "L2",
discrete_out = FALSE,
solver = "sinkhorn",
solver_optns = NULL,
scaling = TRUE) {
# INPUT CHECKS
# ----------------------------------------------------------------------------
# Check if the output is a numerical
if (!is.numeric(y) | (is.matrix(y) & is.vector(y)))
stop("`y` must be a vector or a matrix of numerical values!")

# Check if the number of partitions is lower than the number of samples
if (nrow(as.matrix(y)) <= M)
stop("The number of partitions should be lower than the number of samples")

# Check if the bound is in the pool
match.arg(bound, c("dummy", "entropic"))

# Sinkhorn is allowed only for sinkhorn or sinkhorn log solver
if (bound == "entropic" & !(solver %in% c("sinkhorn", "sinkhorn_log")))
stop("`sinkhorn` bound is allowed only for `sinkhorn` or `sinkhorn_log` solvers")

# Check if the cost is defined correctly
cost_type <- identical(cost, "L2")
if (!cost_type & !is.function(cost))
stop("`cost` should be \"L2\" or a function")

# Check the logical values
if (!is.logical(discrete_out)) stop("`discrete_out` should be logical")
if (!is.logical(scaling)) stop("`scaling` should be logical")

# Check if the solver is present in the pool
match.arg(solver, c("1d", "wasserstein-bures", "sinkhorn", "sinkhorn_log", "wasserstein"))

# RETURN THE DUMMY INDICES
# ----------------------------------------------------------------------------
if (bound == "dummy") {
dummy_optns <- init_dummy_optns(dummy_optns)

N <- nrow(as.matrix(y))
x <- do.call(dummy_optns[["distr"]], c(n = N, within(dummy_optns, rm("distr"))))
x <- as.data.frame(x)
colnames(x) <- dummy_optns[["distr"]]

dummy_ind <- switch(solver,
"1d" = ot_indices_1d(x, y, M),
"wasserstein-bures" = ot_indices_wb(x, y, M),
"sinkhorn" = ot_indices(x,
y,
M,
cost,
discrete_out,
solver,
solver_optns,
scaling),
"sinkhorn_log" = ot_indices(x,
y,
M,
cost,
discrete_out,
solver,
solver_optns,
scaling),
"wasserstein" = ot_indices(x,
y,
M,
cost,
discrete_out,
solver,
solver_optns,
scaling),
default = NULL
)

return(dummy_ind)
}

# RETURN THE SINKHORN LOWER BOUND
sink_ind <- entropic_lower_bound(y, cost, solver, solver_optns, discrete_out, scaling)

return(sink_ind)
}

# Check that the bound_optns contain the required components, otherwise initialize them
init_dummy_optns <- function(dummy_optns) {
if (is.null(dummy_optns)) {
dummy_optns <- list(distr = "rnorm")
cat("Using `rnorm` to generate the dummy\n")
}

return(dummy_optns)
}

# Compute the sinkhorn lower bound
entropic_lower_bound <- function(y,
cost,
solver,
solver_optns,
discrete_out,
scaling) {
# Initialize useful values
N <- nrow(as.matrix(y))
cost_type <- identical(cost, "L2")

# BUILD AN HISTOGRAM IF REQUESTED (otherwise assume equal weights)
# ----------------------------------------------------------------------------
# Default values for continuous output
a <- rep(1 / N, N)
y_unique <- NULL

if (discrete_out) {
# y_unique is always needed, in order to avoid mismatches between the
# pre-computed cost matrix and the resampled outputs
y_unique <- unique(y)

# Function to build the histogram (maybe we can put this outside to make it
# easier to change?)
a <- apply(y_unique, MARGIN = 1, FUN = function(u) sum(
apply(y, 1, FUN = function(row) all(row == u)))) / N
}

if (identical(a, rep(1 / N, N)) & discrete_out)
warning("The output is continuous, consider using `discrete_out=FALSE`\n")

# BUILD COST MATRIX
# ----------------------------------------------------------------------------
if (discrete_out) {
if (cost_type) {
C <- as.matrix(stats::dist(y_unique, method = "euclidean")) ^ 2
} else {
C <- cost(y_unique)
}
} else {
if (cost_type) {
C <- as.matrix(stats::dist(y, method = "euclidean")) ^ 2
} else {
C <- cost(y)
}
}

# If the scaling is requested, scale the cost matrix by the highest value
scaling_param <- 1
if (scaling) {
scaling_param <- max(C)
C <- C / scaling_param
}

# DEFINE THE OT SOLVER
# ----------------------------------------------------------------------------
solver_fun <- switch (
solver,
"sinkhorn" = sinkhorn,
"sinkhorn_log" = sinkhorn_log,
default = NULL
)

# Check the consistency of the options provided
solver_optns <- check_solver_optns(solver, solver_optns)

# Evaluate the upper bound of the indices
# ----------------------------------------------------------------------------
V <- higher_bound(C) * scaling_param

# COMPUTE THE LOWER BOUND
# ----------------------------------------------------------------------------
ret <- do.call(solver_fun,
c(list(a = a, b = a, costm = C),
solver_optns))

sink_ind <- ret$cost * scaling_param / V

return(sink_ind)
}
111 changes: 111 additions & 0 deletions man/lower_bound.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit c1b2d21

Please sign in to comment.