Skip to content

Commit

Permalink
Added stable solver instead of log-scale one
Browse files Browse the repository at this point in the history
  • Loading branch information
pietrocipolla committed Jan 23, 2025
1 parent 227b3b3 commit 379b5c4
Show file tree
Hide file tree
Showing 8 changed files with 96 additions and 57 deletions.
4 changes: 2 additions & 2 deletions R/RcppExports.R
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ sinkhorn <- function(a, b, costm, numIterations, epsilon, maxErr) {
.Call(`_gsaot_sinkhorn`, a, b, costm, numIterations, epsilon, maxErr)
}

sinkhorn_log <- function(a, b, costm, numIterations, epsilon, maxErr) {
.Call(`_gsaot_sinkhorn_log`, a, b, costm, numIterations, epsilon, maxErr)
sinkhorn_stable <- function(a, b, costm, numIterations, epsilon, maxErr, tau) {
.Call(`_gsaot_sinkhorn_stable`, a, b, costm, numIterations, epsilon, maxErr, tau)
}

12 changes: 8 additions & 4 deletions R/check_solver_optns.R
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,19 @@ check_solver_optns <- function(solver, solver_optns) {
"sinkhorn" = list(numIterations = 1e3,
epsilon = 0.01,
maxErr = 1e-9),
"sinkhorn_log" = list(numIterations = 1e3,
"sinkhorn_stable" = list(numIterations = 1e3,
epsilon = 0.01,
maxErr = 1e-9),
maxErr = 1e-9,
tau = 1e4),
"transport" = list(fullreturn = TRUE)
)

return(solver_optns)
}

# If options are provided, check correctness for sinkhorn solvers
if (solver == "sinkhorn" || solver == "sinkhorn_log") {
stopifnot(all(names(solver_optns) %in% c("numIterations", "epsilon", "maxErr")))
if (solver == "sinkhorn" || solver == "sinkhorn_stable") {
stopifnot(all(names(solver_optns) %in% c("numIterations", "epsilon", "maxErr", "tau")))

if (!exists("numIterations", solver_optns)) {
solver_optns[["numIterations"]] <- 1e3
Expand All @@ -27,6 +28,9 @@ check_solver_optns <- function(solver, solver_optns) {
if (!exists("maxErr", solver_optns)) {
solver_optns[["maxErr"]] <- 1e-9
}
if (!exists("tau", solver_optns) && solver == "sinkhorn_stable") {
solver_optns[["tau"]] <- 1e5
}

return(solver_optns)
}
Expand Down
4 changes: 2 additions & 2 deletions R/ot_indices.R
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@ ot_indices <- function(x,
if (!is.logical(scaling)) stop("`scaling` should be logical")

# Check if the solver is present in the pool
match.arg(solver, c("sinkhorn", "sinkhorn_log", "transport"))
match.arg(solver, c("sinkhorn", "sinkhorn_stable", "transport"))

# Check that bootstrapping is correctly set
if ((!boot & !is.null(R)) | (boot & is.null(R))) {
Expand Down Expand Up @@ -270,7 +270,7 @@ ot_indices <- function(x,
solver_fun <- switch (
solver,
"sinkhorn" = sinkhorn,
"sinkhorn_log" = sinkhorn_log,
"sinkhorn_stable" = sinkhorn_stable,
"transport" = transport::transport,
default = NULL
)
Expand Down
2 changes: 1 addition & 1 deletion R/utils.R
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ entropic_lower_bound <- function(y,
solver_fun <- switch (
solver,
"sinkhorn" = sinkhorn,
"sinkhorn_log" = sinkhorn_log,
"sinkhorn_stable" = sinkhorn_stable,
default = NULL
)

Expand Down
Binary file modified man/figures/logo.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
11 changes: 6 additions & 5 deletions src/RcppExports.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,9 @@ BEGIN_RCPP
return rcpp_result_gen;
END_RCPP
}
// sinkhorn_log
List sinkhorn_log(Eigen::VectorXd a, Eigen::VectorXd b, Eigen::MatrixXd costm, int numIterations, double epsilon, double maxErr);
RcppExport SEXP _gsaot_sinkhorn_log(SEXP aSEXP, SEXP bSEXP, SEXP costmSEXP, SEXP numIterationsSEXP, SEXP epsilonSEXP, SEXP maxErrSEXP) {
// sinkhorn_stable
List sinkhorn_stable(Eigen::VectorXd a, Eigen::VectorXd b, Eigen::MatrixXd costm, int numIterations, double epsilon, double maxErr, double tau);
RcppExport SEXP _gsaot_sinkhorn_stable(SEXP aSEXP, SEXP bSEXP, SEXP costmSEXP, SEXP numIterationsSEXP, SEXP epsilonSEXP, SEXP maxErrSEXP, SEXP tauSEXP) {
BEGIN_RCPP
Rcpp::RObject rcpp_result_gen;
Rcpp::RNGScope rcpp_rngScope_gen;
Expand All @@ -39,14 +39,15 @@ BEGIN_RCPP
Rcpp::traits::input_parameter< int >::type numIterations(numIterationsSEXP);
Rcpp::traits::input_parameter< double >::type epsilon(epsilonSEXP);
Rcpp::traits::input_parameter< double >::type maxErr(maxErrSEXP);
rcpp_result_gen = Rcpp::wrap(sinkhorn_log(a, b, costm, numIterations, epsilon, maxErr));
Rcpp::traits::input_parameter< double >::type tau(tauSEXP);
rcpp_result_gen = Rcpp::wrap(sinkhorn_stable(a, b, costm, numIterations, epsilon, maxErr, tau));
return rcpp_result_gen;
END_RCPP
}

static const R_CallMethodDef CallEntries[] = {
{"_gsaot_sinkhorn", (DL_FUNC) &_gsaot_sinkhorn, 6},
{"_gsaot_sinkhorn_log", (DL_FUNC) &_gsaot_sinkhorn_log, 6},
{"_gsaot_sinkhorn_stable", (DL_FUNC) &_gsaot_sinkhorn_stable, 7},
{NULL, NULL, 0}
};

Expand Down
19 changes: 10 additions & 9 deletions src/optimal_transport_sinkhorn.cpp → src/sinkhorn.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ List sinkhorn_cpp(Eigen::VectorXd a,
// Eigen::MatrixXd U(u.asDiagonal());
// Eigen::MatrixXd V(v.asDiagonal());
Eigen::MatrixXd P((K.array().colwise() * u.array()).rowwise() * v.array().transpose());
// Rcout << P << std::endl;
// Rcout << P.maxCoeff() << std::endl;

// Wasserstein distance
double W22_prime = (P.transpose() * costMatrix).trace();
Expand All @@ -91,11 +91,12 @@ List sinkhorn(Eigen::VectorXd a,
return sinkhorn_cpp(a, b, costm, numIterations, epsilon, maxErr);
}

// /***R
// n <- 100
// m <- 50
// a <- rep(1 / n, n)
// b <- rep(1 / m, m)
// C <- as.matrix(dist(rnorm(100)))[, 1:50]
// ret <- sinkhorn(a, b, C, 1e3, 0.1, 1e-3)
// */
// # /***R
// # set.seed(1)
// # n <- 100
// # m <- 100
// # a <- rep(1 / n, n)
// # b <- rep(1 / m, m)
// # C <- as.matrix(dist(rnorm(100)))#[, 1:50]
// # ret <- sinkhorn(a, b, C, 1e3, 0.1, 1e-3)
// # */
101 changes: 67 additions & 34 deletions src/optimal_transport_sinkhorn_log.cpp → src/sinkhorn_stable.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,60 +19,90 @@ Eigen::VectorXd logsumexp(const Eigen::MatrixXd x,
return Eigen::VectorXd::Zero(x.rows());
}

// Define the kernel matrix computation
Eigen::MatrixXd getK(Eigen::MatrixXd costMatrix,
Eigen::VectorXd alpha,
Eigen::VectorXd beta,
double epsilon) {
// Remove marginals
costMatrix.colwise() -= alpha;
costMatrix.rowwise() -= beta.transpose();

// Scale and exponentiate
costMatrix = - costMatrix / epsilon;
costMatrix = costMatrix.array().exp();

return costMatrix;
}

// Define the Sinkhorn algorithm function
List sinkhorn_log_cpp(Eigen::VectorXd a,
List sinkhorn_stable_cpp(Eigen::VectorXd a,
Eigen::VectorXd b,
Eigen::MatrixXd costMatrix,
int numIterations,
double epsilon,
double maxErr) {
double maxErr,
double tau,
double trunc_thresh) {
const int numRows = costMatrix.rows();
const int numCols = costMatrix.cols();

// Initialize the K matrix
Eigen::MatrixXd K(-costMatrix / epsilon);

// Initialize all the dual variables to vectors of 1
Eigen::VectorXd u(numRows);
Eigen::VectorXd v(numCols);
// Initialize the overloaded potentials to zeros
Eigen::VectorXd alpha = Eigen::VectorXd::Zero(numRows);
Eigen::VectorXd beta = Eigen::VectorXd::Zero(numCols);

u.setZero();
v.setZero();
// Initialize the K matrix
Eigen::MatrixXd K = getK(costMatrix, alpha, beta, epsilon);

// Build potential matrices
Eigen::MatrixXd U(u.array().exp().matrix().asDiagonal());
Eigen::MatrixXd V(v.array().exp().matrix().asDiagonal());
// Initialize all the extra dual variables to vectors of 1
Eigen::VectorXd u_tilde = Eigen::VectorXd::Ones(numRows);
Eigen::VectorXd v_tilde = Eigen::VectorXd::Ones(numCols);

// Initialize the marginal weights and the other useful variables
Eigen::VectorXd estimated_marginal;
Eigen::VectorXd loga = a.array().log();
Eigen::VectorXd logb = b.array().log();

// Initialize algorithms values
int iter = 1;
double err = std::numeric_limits<double>::infinity();

// External loop for convergence
while (iter == 1 || (iter <= numIterations &&
err > maxErr && err < std::numeric_limits<double>::infinity())) {
// First step: updates
// Compute v updates
v = logb - logsumexp(K.colwise() + u, 1);
// Rcout << "v " << v << std::endl;
v_tilde = K.transpose() * u_tilde;
v_tilde = b.array() * v_tilde.array().inverse();
// Rcout << "v " << v_tilde << std::endl;

// Compute u updates
u = loga - logsumexp(K.transpose().colwise() + v, 1);
// Rcout << "u " << u << std::endl;
u_tilde = K * v_tilde;
u_tilde = a.array() * u_tilde.array().inverse();
// Rcout << "u " << u_tilde << std::endl;

// Second step: absorption iteration
if (u_tilde.cwiseAbs().maxCoeff() > tau ||
v_tilde.cwiseAbs().maxCoeff() > tau ||
iter == numIterations) {
// Absorb the big values
alpha = alpha.array() + epsilon * u_tilde.array().log();
// Rcout << "alpha " << alpha << std::endl;
beta = beta.array() + epsilon * v_tilde.array().log();
// Rcout << "beta " << beta << std::endl;

// Reset the potentials
u_tilde = Eigen::VectorXd::Ones(numRows);
v_tilde = Eigen::VectorXd::Ones(numCols);

// Update the cost matrix
K = getK(costMatrix, alpha, beta, epsilon);
// Rcout << "K " << K << std::endl;
}

// Error control
if (iter % 10 == 0 || iter == 1) {
// Update potential matrices
U = u.array().exp().matrix().asDiagonal();
V = v.array().exp().matrix().asDiagonal();
// Rcout << "U " << U << std::endl;
// Rcout << "V " << V << std::endl;

// Update error
estimated_marginal = (V * K.transpose().array().exp().matrix() * U).array().rowwise().sum();
estimated_marginal = v_tilde.array() * (K.transpose() * u_tilde).array();
// Rcout << "marginal " << estimated_marginal << std::endl;
err = (estimated_marginal - b).cwiseAbs().sum();
// Rcout << "err " << err << std::endl;
}

// Update iteration
Expand All @@ -92,7 +122,8 @@ List sinkhorn_log_cpp(Eigen::VectorXd a,
// double W22 = f.dot(a) + g.dot(b);

// Optimal coupling
Eigen::MatrixXd P(U * K.array().exp().matrix() * V);
Eigen::MatrixXd P((K.array().colwise() * u_tilde.array()).rowwise() * v_tilde.array().transpose());
// Rcout << "P " << P.maxCoeff() << std::endl;

// Wasserstein distance
double W22_prime = (P.transpose() * costMatrix).trace();
Expand All @@ -111,20 +142,22 @@ List sinkhorn_log_cpp(Eigen::VectorXd a,
// Expose the Sinkhorn function to R
// [[Rcpp::depends(RcppEigen)]]
// [[Rcpp::export]]
List sinkhorn_log(Eigen::VectorXd a,
List sinkhorn_stable(Eigen::VectorXd a,
Eigen::VectorXd b,
Eigen::MatrixXd costm,
int numIterations,
double epsilon,
double maxErr) {
return sinkhorn_log_cpp(a, b, costm, numIterations, epsilon, maxErr);
double maxErr,
double tau) {
return sinkhorn_stable_cpp(a, b, costm, numIterations, epsilon, maxErr, tau);
}

// # /***R
// # /***R
// # set.seed(1)
// # n <- 100
// # m <- 100
// # a <- rep(1 / n, n)
// # b <- rep(1 / m, m)
// # C <- as.matrix(dist(rnorm(100)))#[, 1:50]
// # sinkhorn_log(a, b, C, 1e5, 0.1, 1e-3)
// # sinkhorn_log(a, b, C, 1e5, 0.0000001, 1e-3, 1e5)
// # */

0 comments on commit 379b5c4

Please sign in to comment.