Mm/onnx parameter tweaks #443

Merged · 26 commits · Nov 22, 2024
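The changes below rename three algorithm parameters across the examples and headers: gamma → regulariser_strength, beta → step_size and nu → sq_op_norm. As a migration sketch (not code added by this PR; Scalar, y, beta and regulariser_strength stand for whatever the calling example already defines), the ImagingForwardBackward builder calls change as follows:

// Sketch of the renamed setters; old names shown in the trailing comments.
auto fb = sopt::algorithm::ImagingForwardBackward<Scalar>(y)
              .step_size(beta)                             // previously .beta(beta)
              .regulariser_strength(regulariser_strength)  // previously .gamma(gamma)
              .sq_op_norm(1e0);                            // previously .nu(1e0)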
6 changes: 3 additions & 3 deletions cpp/examples/forward_backward/inpainting.cc
@@ -87,14 +87,14 @@ int main(int argc, char const **argv) {
"dirty_" + output + ".tiff");
}

sopt::t_real constexpr gamma = 18;
sopt::t_real constexpr regulariser_strength = 18;
sopt::t_real const beta = sigma * sigma * 0.5;
SOPT_HIGH_LOG("Creating Foward Backward Functor");
auto fb = sopt::algorithm::ImagingForwardBackward<Scalar>(y)
.itermax(500)
.beta(beta) // stepsize
.step_size(beta) // stepsize
.sigma(sigma) // sigma
.gamma(gamma) // regularisation paramater
.regulariser_strength(regulariser_strength) // regularisation parameter
.relative_variation(1e-3)
.residual_tolerance(0)
.tight_frame(true)
10 changes: 5 additions & 5 deletions cpp/examples/forward_backward/inpainting_credible_interval.cc
@@ -88,14 +88,14 @@ int main(int argc, char const **argv) {
"dirty_" + output + ".tiff");
}

sopt::t_real constexpr gamma = 18;
sopt::t_real constexpr regulariser_strength = 18;
sopt::t_real const beta = sigma * sigma;
SOPT_HIGH_LOG("Creating Foward Backward Functor");
auto fb = sopt::algorithm::ImagingForwardBackward<Scalar>(y)
.itermax(500)
.beta(beta)
.step_size(beta)
.sigma(sigma)
.gamma(gamma)
.regulariser_strength(regulariser_strength)
.relative_variation(5e-4)
.residual_tolerance(0)
.tight_frame(true)
@@ -133,9 +133,9 @@ int main(int argc, char const **argv) {
constexpr sopt::t_real alpha = 0.99;
const sopt::t_uint grid_pixel_size = image.rows() / 16;
SOPT_HIGH_LOG("Finding credible interval");
const std::function<Scalar(Vector)> objective_function = [gamma, sigma, &y, &sampling,
const std::function<Scalar(Vector)> objective_function = [regulariser_strength, sigma, &y, &sampling,
&psi](const Vector &x) {
return sopt::l1_norm(psi.adjoint() * x) * gamma +
return sopt::l1_norm(psi.adjoint() * x) * regulariser_strength +
0.5 * std::pow(sopt::l2_norm(sampling * x - y), 2) / (sigma * sigma);
};
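For reference, this lambda evaluates the analysis-prior MAP objective that the credible-interval search uses, with Φ the sampling operator and Ψ the wavelet transform psi:

$$ f(x) \;=\; \text{regulariser\_strength}\,\lVert \Psi^{\dagger} x \rVert_{1} \;+\; \frac{1}{2\sigma^{2}}\,\lVert \Phi x - y \rVert_{2}^{2} $$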

6 changes: 3 additions & 3 deletions cpp/examples/forward_backward/inpainting_joint_map.cc
@@ -87,14 +87,14 @@ int main(int argc, char const **argv) {
"dirty_" + output + ".tiff");
}

sopt::t_real constexpr gamma = 0;
sopt::t_real constexpr regulariser_strength = 0;
sopt::t_real const beta = sigma * sigma * 0.5;
SOPT_HIGH_LOG("Creating Foward Backward Functor");
auto fb = std::make_shared<sopt::algorithm::ImagingForwardBackward<Scalar>>(y);
fb->itermax(500)
.beta(beta) // stepsize
.step_size(beta) // stepsize
.sigma(sigma) // sigma
.gamma(gamma) // regularisation paramater
.regulariser_strength(regulariser_strength) // regularisation parameter
.relative_variation(1e-3)
.residual_tolerance(0)
.tight_frame(true)
8 changes: 4 additions & 4 deletions cpp/examples/forward_backward/l2_inpainting.cc
@@ -85,14 +85,14 @@ int main(int argc, char const **argv) {
"dirty_" + output + ".tiff");
}
constexpr sopt::t_real x_sigma = 1.;
sopt::t_real constexpr gamma = 1. / (x_sigma * x_sigma * 2);
sopt::t_real const beta = sigma * sigma * 0.5;
sopt::t_real constexpr regulariser_strength = 1. / (x_sigma * x_sigma * 2);
sopt::t_real const step_size = sigma * sigma * 0.5;
SOPT_HIGH_LOG("Creating Foward Backward Functor");
auto const fb = sopt::algorithm::L2ForwardBackward<Scalar>(y)
.itermax(500)
.beta(beta) // stepsize
.step_size(step_size) // stepsize
.sigma(sigma) // sigma
.gamma(gamma) // regularisation paramater
.regulariser_strength(regulariser_strength) // regularisation parameter
.relative_variation(1e-3)
.residual_tolerance(0)
.tight_frame(true)
4 changes: 2 additions & 2 deletions cpp/examples/primal_dual/inpainting.cc
@@ -91,12 +91,12 @@ int main(int argc, char const **argv) {
sopt::utilities::write_tiff(Matrix::Map(dirty.data(), image.rows(), image.cols()),
"dirty_" + output + ".tiff");
}
sopt::t_real const gamma = (psi.adjoint() * (sampling.adjoint() * y)).real().maxCoeff() * 1e-2;
sopt::t_real const regulariser_strength = (psi.adjoint() * (sampling.adjoint() * y)).real().maxCoeff() * 1e-2;

SOPT_HIGH_LOG("Creating primal-dual Functor");
auto const pd = sopt::algorithm::ImagingPrimalDual<Scalar>(y)
.itermax(500)
.gamma(gamma)
.regulariser_strength(regulariser_strength)
.tau(0.5)
.l2ball_proximal_epsilon(epsilon)
.Psi(psi)
4 changes: 2 additions & 2 deletions cpp/examples/primal_dual/tv_inpainting.cc
@@ -87,7 +87,7 @@ int main(int argc, char const **argv) {
"dirty_" + output + ".tiff");
}
const Vector grad = psi.adjoint() * (sampling.adjoint() * y);
const sopt::t_real gamma = (grad.segment(0, image.size()).array().square() +
const sopt::t_real regulariser_strength = (grad.segment(0, image.size()).array().square() +
grad.segment(image.size(), image.size()).array().square())
.sqrt()
.real()
@@ -97,7 +97,7 @@ int main(int argc, char const **argv) {
SOPT_HIGH_LOG("Creating primal-dual Functor");
auto const pd = sopt::algorithm::TVPrimalDual<Scalar>(y)
.itermax(2000)
.gamma(gamma)
.regulariser_strength(regulariser_strength)
.tau(0.5 / (1. + 1.))
.l2ball_proximal_epsilon(epsilon)
.Psi(psi)
4 changes: 2 additions & 2 deletions cpp/examples/proximal_admm/euclidian_norm.cc
@@ -19,15 +19,15 @@ int main(int, char const **) {

// Creates the resulting proximal
// In practice g_0 and g_1 are any functions with the signature
// void(t_Vector &output, t_Vector::Scalar gamma, t_Vector const &input)
// void(t_Vector &output, t_Vector::Scalar regulariser_strength, t_Vector const &input)
// They are the proximal of ||x - x_0|| and ||x - x_1||
auto prox_g0 = sopt::proximal::translate(sopt::proximal::EuclidianNorm(), -target0);
auto prox_g1 = sopt::proximal::translate(sopt::proximal::EuclidianNorm(), -target1);

auto padmm = sopt::algorithm::ProximalADMM<t_Scalar>(prox_g0, prox_g1, t_Vector::Zero(N))
.itermax(5000)
.is_converged(sopt::RelativeVariation<t_Scalar>(1e-12))
.gamma(0.01)
.regulariser_strength(0.01)
// Phi == -1, so that we can minimize f(x) + g(x), as per problem definition in
// padmm.
.Phi(-t_Matrix::Identity(N, N));
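Any callable with the signature documented above can be supplied in place of these proximals. As a sketch (not part of this PR, and assuming the example's t_Vector alias with real-valued scalars), a hand-written proximal — here the soft-thresholding operator, i.e. the proximal of regulariser_strength * ||x||_1 — could look like:

// Sketch only: a custom proximal with the documented signature.
auto prox_l1 = [](t_Vector &output, t_Vector::Scalar regulariser_strength, t_Vector const &input) {
  // soft-threshold each coefficient: sign(x) * max(|x| - regulariser_strength, 0)
  output = (input.array().sign() * (input.array().abs() - regulariser_strength).max(0.)).matrix();
};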
4 changes: 2 additions & 2 deletions cpp/examples/proximal_admm/inpainting.cc
@@ -84,7 +84,7 @@ int main(int argc, char const **argv) {
SOPT_HIGH_LOG("Creating proximal-ADMM Functor");
auto const padmm = sopt::algorithm::ImagingProximalADMM<Scalar>(y)
.itermax(500)
.gamma(1e-1)
.regulariser_strength(1e-1)
.relative_variation(5e-4)
.l2ball_proximal_epsilon(epsilon)
.tight_frame(false)
@@ -95,7 +95,7 @@ int main(int argc, char const **argv) {
.l1_proximal_real_constraint(true)
.residual_convergence(epsilon * 1.001)
.lagrange_update_scale(0.9)
.nu(1e0)
.sq_op_norm(1e0)
.Psi(psi)
.Phi(sampling);

4 changes: 2 additions & 2 deletions cpp/examples/proximal_admm/reweighted.cc
@@ -86,7 +86,7 @@ int main(int argc, char const **argv) {
SOPT_MEDIUM_LOG("Creating proximal-ADMM Functor");
auto const padmm = sopt::algorithm::ImagingProximalADMM<Scalar>(y)
.itermax(500)
.gamma(1e-1)
.regulariser_strength(1e-1)
.relative_variation(5e-4)
.l2ball_proximal_epsilon(epsilon)
.tight_frame(false)
@@ -97,7 +97,7 @@ int main(int argc, char const **argv) {
.l1_proximal_real_constraint(true)
.residual_convergence(epsilon * 1.001)
.lagrange_update_scale(0.9)
.nu(1e0)
.sq_op_norm(1e0)
.Psi(psi)
.Phi(sampling);

2 changes: 1 addition & 1 deletion cpp/sopt/CMakeLists.txt
@@ -4,7 +4,7 @@ set(headers
imaging_padmm.h logging.h
forward_backward.h imaging_forward_backward.h
non_differentiable_func.h l1_non_diff_function.h real_indicator.h joint_map.h
differentiable_func.h
differentiable_func.h l2_differentiable_func.h
imaging_primal_dual.h primal_dual.h
maths.h proximal.h relative_variation.h sdmm.h
wavelets.h conjugate_gradient.h l1_proximal.h padmm.h proximal_expression.h
10 changes: 10 additions & 0 deletions cpp/sopt/differentiable_func.h
@@ -34,6 +34,16 @@ template <typename SCALAR> class DifferentiableFunc
// Calculate the function directly
virtual Real function(t_Vector const &image, t_Vector const &y, t_LinearTransform const &Phi) = 0;

// Get appropriate gradient step-size for FISTA algorithms
Real get_step_size() const
{
return step_size;
}

protected:

Real step_size;

// Transforms input image to a different basis.
// Return linear_transform_identity() if transform not necessary.
//virtual const t_LinearTransform &Phi() const = 0;
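The new protected step_size member lets a differentiable term advertise its own gradient step size, which ImagingForwardBackward now reads through get_step_size() instead of a user-supplied beta (see imaging_forward_backward.h below). A hypothetical subclass — not part of this PR, with the remaining pure-virtual overrides omitted — would simply set the member:

// Hypothetical sketch: a Gaussian-likelihood term publishing its preferred step size.
template <typename SCALAR>
class GaussianLikelihood : public sopt::DifferentiableFunc<SCALAR> {
 public:
  using Real = typename sopt::DifferentiableFunc<SCALAR>::Real;
  explicit GaussianLikelihood(Real sigma) : sigma_(sigma) {
    this->step_size = sigma * sigma;  // matches the default the algorithm falls back to
  }
 private:
  Real sigma_;
};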
30 changes: 15 additions & 15 deletions cpp/sopt/forward_backward.h
@@ -69,9 +69,9 @@ class ForwardBackward {
ForwardBackward(t_Gradient const &f_gradient, t_Proximal const &g_proximal,
Eigen::MatrixBase<DERIVED> const &target)
: itermax_(std::numeric_limits<t_uint>::max()),
gamma_(1e-8),
beta_(1),
nu_(1),
regulariser_strength_(1e-8),
step_size_(1),
sq_op_norm_(1),
is_converged_(),
fista_(true),
Phi_(linear_transform_identity<Scalar>()),
@@ -97,11 +97,11 @@ class ForwardBackward {
//! Maximum number of iterations
SOPT_MACRO(itermax, t_uint);
//! γ parameter
SOPT_MACRO(gamma, Real);
SOPT_MACRO(regulariser_strength, Real);
//! β parameter
SOPT_MACRO(beta, Real);
SOPT_MACRO(step_size, Real);
//! ν parameter
SOPT_MACRO(nu, Real);
SOPT_MACRO(sq_op_norm, Real);
//! flag to for FISTA Forward-Backward algorithm. True by default but should be false when using a learned g_proximal.
SOPT_MACRO(fista, bool);
//! \brief A function verifying convergence
@@ -117,8 +117,8 @@ class ForwardBackward {
//! \brief Simplifies calling the gradient function
void f_gradient(t_Vector &out, t_Vector const &x, t_Vector const &res, t_LinearTransform const &Phi) const { f_gradient()(out, x, res, Phi); }
//! \brief Simplifies calling the proximal function
void g_proximal(t_Vector &out, Real gamma, t_Vector const &x) const {
g_proximal()(out, gamma, x);
void g_proximal(t_Vector &out, Real regulariser_strength, t_Vector const &x) const {
g_proximal()(out, regulariser_strength, x);
}

//! Convergence function that takes only the output as argument
@@ -194,7 +194,7 @@ class ForwardBackward {
//! - x = Φ^T y / ν
//! - residuals = Φ x - y
std::tuple<t_Vector, t_Vector> initial_guess() const {
return ForwardBackward<SCALAR>::initial_guess(target(), Phi(), nu());
return ForwardBackward<SCALAR>::initial_guess(target(), Phi(), sq_op_norm());
}

//! \brief Computes initial guess for x and the residual using the targets
@@ -204,9 +204,9 @@
//!
//! This function simplifies creating overloads for operator() in FB wrappers.
static std::tuple<t_Vector, t_Vector> initial_guess(t_Vector const &target,
t_LinearTransform const &phi, Real nu) {
t_LinearTransform const &phi, Real sq_op_norm) {
std::tuple<t_Vector, t_Vector> guess;
std::get<0>(guess) = static_cast<t_Vector>(phi.adjoint() * target) / nu;
std::get<0>(guess) = static_cast<t_Vector>(phi.adjoint() * target) / sq_op_norm;
std::get<1>(guess) = phi * std::get<0>(guess) - target;
return guess;
}
@@ -255,9 +255,9 @@ template <typename SCALAR>
void ForwardBackward<SCALAR>::iteration_step(t_Vector &image, t_Vector &residual, t_Vector &auxilliary_image,
t_Vector &gradient_current, const t_real FISTA_step) const {
t_Vector prev_image = image;
f_gradient(gradient_current, auxilliary_image, residual, Phi()); // takes residual and calculates the grad = 1/sig^2 residual
t_Vector auxilliary_with_step = auxilliary_image - beta() / nu() * gradient_current; // step to new image using gradient
const Real weight = gamma() * beta();
f_gradient(gradient_current, auxilliary_image, residual, Phi()); // assigns gradient_current (non normalised)
t_Vector auxilliary_with_step = auxilliary_image - step_size() / sq_op_norm() * gradient_current; // step to new image using gradient
const Real weight = regulariser_strength() * step_size();
g_proximal(image, weight, auxilliary_with_step); // apply proximal operator to new image
auxilliary_image = image + FISTA_step * (image - prev_image); // update auxilliary vector with FISTA acceleration step
residual = (Phi() * auxilliary_image) - target(); // updates the residual for the NEXT iteration (new image).
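With the renamed parameters, one call to iteration_step amounts to the following (a restatement of the code above, with z the auxiliary image and r the residual):

$$ x_{k+1} = \operatorname{prox}_{\,\text{regulariser\_strength}\cdot\text{step\_size}\; g}\!\Big( z_k - \tfrac{\text{step\_size}}{\text{sq\_op\_norm}}\,\nabla f(z_k) \Big), \qquad z_{k+1} = x_{k+1} + \text{FISTA\_step}\,(x_{k+1} - x_k), \qquad r_{k+1} = \Phi z_{k+1} - y $$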
@@ -287,7 +287,7 @@ typename ForwardBackward<SCALAR>::Diagnostic ForwardBackward<SCALAR>::operator()
Real theta_new = 1.0;
Real FISTA_step = 0.0;
for (; (not converged) && (niters < itermax()); ++niters) {
SOPT_LOW_LOG(" - [FB] Iteration {}/{}", niters, itermax());
SOPT_MEDIUM_LOG(" - [FB] Iteration {}/{}", niters, itermax());
if (fista()) {
theta_new = (1 + std::sqrt(1 + 4 * theta * theta)) / 2.;
FISTA_step = (theta - 1) / (theta_new);
40 changes: 23 additions & 17 deletions cpp/sopt/imaging_forward_backward.h
@@ -71,10 +71,10 @@ class ImagingForwardBackward {
residual_convergence_(nullptr),
objective_convergence_(nullptr),
itermax_(std::numeric_limits<t_uint>::max()),
gamma_(1e-8),
beta_(1),
regulariser_strength_(1e-8),
step_size_(1),
sigma_(1),
nu_(1),
sq_op_norm_(1),
fista_(true),
is_converged_(),
Phi_(linear_transform_identity<Scalar>()),
@@ -112,13 +112,13 @@ class ImagingForwardBackward {
//! Maximum number of iterations
SOPT_MACRO(itermax, t_uint);
//! γ parameter
SOPT_MACRO(gamma, Real);
SOPT_MACRO(regulariser_strength, Real);
//! γ parameter
SOPT_MACRO(beta, Real);
SOPT_MACRO(step_size, Real);
//! γ parameter
SOPT_MACRO(sigma, Real);
//! ν parameter
SOPT_MACRO(nu, Real);
SOPT_MACRO(sq_op_norm, Real);
//! flag to for FISTA Forward-Backward algorithm. True by default but should be false when using a learned g_function.
SOPT_MACRO(fista, bool);
//! A function verifying convergence
@@ -181,7 +181,7 @@ class ImagingForwardBackward {
//! \brief Calls Forward Backward
//! \param[out] out: Output vector x
Diagnostic operator()(t_Vector &out) const {
return operator()(out, ForwardBackward<SCALAR>::initial_guess(target(), Phi(), nu()));
return operator()(out, ForwardBackward<SCALAR>::initial_guess(target(), Phi(), sq_op_norm()));
}
//! \brief Calls Forward Backward
//! \param[out] out: Output vector x
@@ -214,7 +214,7 @@ class ImagingForwardBackward {
DiagnosticAndResult operator()() {
DiagnosticAndResult result;
static_cast<Diagnostic &>(result) = operator()(
result.x, ForwardBackward<SCALAR>::initial_guess(target(), Phi(), nu()));
result.x, ForwardBackward<SCALAR>::initial_guess(target(), Phi(), sq_op_norm()));
return result;
}
//! Makes it simple to chain different calls to FB
@@ -292,13 +292,19 @@ typename ImagingForwardBackward<SCALAR>::Diagnostic ImagingForwardBackward<SCALA
Diagnostic result;
auto const g_proximal = g_function_->proximal_operator();
t_Gradient f_gradient;
if(f_function_) f_gradient = f_function_->gradient();
Real gradient_step_size;
if(f_function_)
{
f_gradient = f_function_->gradient();
gradient_step_size = f_function_->get_step_size();
}
if(!f_gradient)
{
SOPT_HIGH_LOG("Gradient function has not been set; using default (gaussian likelihood) gradient. (To set a custom gradient set_gradient() must be called before the algorithm is run.)");
SOPT_MEDIUM_LOG("Gradient function has not been set; using default (gaussian likelihood) gradient. (To set a custom gradient set_gradient() must be called before the algorithm is run.)");
f_gradient = [this](t_Vector &output, t_Vector const &x, t_Vector const &residual, t_LinearTransform const &Phi) {
output = Phi.adjoint() * (residual / (this->sigma() * this->sigma()));
};
gradient_step_size = sigma()*sigma();
}
ScalarRelativeVariation<Scalar> scalvar(relative_variation(), relative_variation(),
"Objective function");
@@ -309,9 +315,9 @@ typename ImagingForwardBackward<SCALAR>::Diagnostic ImagingForwardBackward<SCALA
};
auto const fb = ForwardBackward<SCALAR>(f_gradient, g_proximal, target())
.itermax(itermax())
.beta(beta())
.gamma(gamma())
.nu(nu())
.step_size(gradient_step_size)
.regulariser_strength(regulariser_strength())
.sq_op_norm(sq_op_norm())
.fista(fista())
.Phi(Phi())
.is_converged(convergence);
@@ -336,8 +342,8 @@ bool ImagingForwardBackward<SCALAR>::objective_convergence(ScalarRelativeVariati
t_Vector const &residual) const {
if (static_cast<bool>(objective_convergence())) return objective_convergence()(x, residual);
if (scalvar.relative_tolerance() <= 0e0) return true;
auto const current = ((gamma() > 0) ? g_function_->function(x)
* gamma() : 0) + std::pow(sopt::l2_norm(residual), 2) / (2 * sigma() * sigma());
auto const current = ((regulariser_strength() > 0) ? g_function_->function(x)
* regulariser_strength() : 0) + std::pow(sopt::l2_norm(residual), 2) / (2 * sigma() * sigma());
return scalvar(current);
}

@@ -350,8 +356,8 @@ bool ImagingForwardBackward<SCALAR>::objective_convergence(mpi::Communicator con
if (static_cast<bool>(objective_convergence())) return objective_convergence()(x, residual);
if (scalvar.relative_tolerance() <= 0e0) return true;
auto const current = obj_comm.all_sum_all<t_real>(
((gamma() > 0) ? g_function_->function(x)
* gamma() : 0) + std::pow(sopt::l2_norm(residual), 2) / (2 * sigma_ * sigma_));
((regulariser_strength() > 0) ? g_function_->function(x)
* regulariser_strength() : 0) + std::pow(sopt::l2_norm(residual), 2) / (2 * sigma_ * sigma_));
return scalvar(current);
}
#endif