Merge pull request #443 from astro-informatics/mm/onnx_parameter_tweaks
Mm/onnx parameter tweaks
mmcleod89 authored Nov 22, 2024
2 parents 0c6b22e + 5d267ca commit 749d6b7
Showing 34 changed files with 301 additions and 276 deletions.
6 changes: 3 additions & 3 deletions cpp/examples/forward_backward/inpainting.cc
@@ -87,14 +87,14 @@ int main(int argc, char const **argv) {
"dirty_" + output + ".tiff");
}

- sopt::t_real constexpr gamma = 18;
+ sopt::t_real constexpr regulariser_strength = 18;
sopt::t_real const beta = sigma * sigma * 0.5;
SOPT_HIGH_LOG("Creating Foward Backward Functor");
auto fb = sopt::algorithm::ImagingForwardBackward<Scalar>(y)
.itermax(500)
- .beta(beta) // stepsize
+ .step_size(beta) // stepsize
.sigma(sigma) // sigma
- .gamma(gamma) // regularisation paramater
+ .regulariser_strength(regulariser_strength) // regularisation paramater
.relative_variation(1e-3)
.residual_tolerance(0)
.tight_frame(true)
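The pattern repeated across these example files is a pure rename of the builder setters: `.gamma(...)` becomes `.regulariser_strength(...)`, `.beta(...)` becomes `.step_size(...)`, and (in the PADMM examples below) `.nu(...)` becomes `.sq_op_norm(...)`. A minimal migration sketch of a call site, using the values from the example above — a fragment for orientation, not a complete program:

```cpp
// Hypothetical migration sketch for this PR's renames.
//   .gamma(g)  ->  .regulariser_strength(g)   // regulariser weight, was γ
//   .beta(b)   ->  .step_size(b)              // gradient step size, was β
//   .nu(n)     ->  .sq_op_norm(n)             // bound on ||Phi||^2, was ν
auto fb = sopt::algorithm::ImagingForwardBackward<Scalar>(y)
              .itermax(500)
              .step_size(sigma * sigma * 0.5)  // was .beta(...)
              .sigma(sigma)
              .regulariser_strength(18)        // was .gamma(...)
              .relative_variation(1e-3);
```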
10 changes: 5 additions & 5 deletions cpp/examples/forward_backward/inpainting_credible_interval.cc
@@ -88,14 +88,14 @@ int main(int argc, char const **argv) {
"dirty_" + output + ".tiff");
}

- sopt::t_real constexpr gamma = 18;
+ sopt::t_real constexpr regulariser_strength = 18;
sopt::t_real const beta = sigma * sigma;
SOPT_HIGH_LOG("Creating Foward Backward Functor");
auto fb = sopt::algorithm::ImagingForwardBackward<Scalar>(y)
.itermax(500)
- .beta(beta)
+ .step_size(beta)
.sigma(sigma)
- .gamma(gamma)
+ .regulariser_strength(regulariser_strength)
.relative_variation(5e-4)
.residual_tolerance(0)
.tight_frame(true)
@@ -133,9 +133,9 @@ int main(int argc, char const **argv) {
constexpr sopt::t_real alpha = 0.99;
const sopt::t_uint grid_pixel_size = image.rows() / 16;
SOPT_HIGH_LOG("Finding credible interval");
- const std::function<Scalar(Vector)> objective_function = [gamma, sigma, &y, &sampling,
+ const std::function<Scalar(Vector)> objective_function = [regulariser_strength, sigma, &y, &sampling,
&psi](const Vector &x) {
- return sopt::l1_norm(psi.adjoint() * x) * gamma +
+ return sopt::l1_norm(psi.adjoint() * x) * regulariser_strength +
0.5 * std::pow(sopt::l2_norm(sampling * x - y), 2) / (sigma * sigma);
};

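The lambda in this hunk is the MAP objective f(x) = γ‖Ψ†x‖₁ + ‖Φx − y‖²/(2σ²), with γ now named `regulariser_strength`. A self-contained restatement with plain Eigen types — dense matrices stand in for the sampling operator and wavelet transform, and the names here are illustrative, not the sopt API:

```cpp
#include <Eigen/Dense>

// Sketch only: `sampling` plays the role of Phi, `psi` the role of Psi.
double map_objective(const Eigen::VectorXd &x, const Eigen::VectorXd &y,
                     const Eigen::MatrixXd &sampling, const Eigen::MatrixXd &psi,
                     double regulariser_strength, double sigma) {
  const double l1 = (psi.adjoint() * x).lpNorm<1>();  // ||Psi^T x||_1
  const double l2 = (sampling * x - y).norm();        // ||Phi x - y||_2
  return regulariser_strength * l1 + 0.5 * l2 * l2 / (sigma * sigma);
}
```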
6 changes: 3 additions & 3 deletions cpp/examples/forward_backward/inpainting_joint_map.cc
@@ -87,14 +87,14 @@ int main(int argc, char const **argv) {
"dirty_" + output + ".tiff");
}

- sopt::t_real constexpr gamma = 0;
+ sopt::t_real constexpr regulariser_strength = 0;
sopt::t_real const beta = sigma * sigma * 0.5;
SOPT_HIGH_LOG("Creating Foward Backward Functor");
auto fb = std::make_shared<sopt::algorithm::ImagingForwardBackward<Scalar>>(y);
fb->itermax(500)
- .beta(beta) // stepsize
+ .step_size(beta) // stepsize
.sigma(sigma) // sigma
- .gamma(gamma) // regularisation paramater
+ .regulariser_strength(regulariser_strength) // regularisation paramater
.relative_variation(1e-3)
.residual_tolerance(0)
.tight_frame(true)
8 changes: 4 additions & 4 deletions cpp/examples/forward_backward/l2_inpainting.cc
@@ -85,14 +85,14 @@ int main(int argc, char const **argv) {
"dirty_" + output + ".tiff");
}
constexpr sopt::t_real x_sigma = 1.;
- sopt::t_real constexpr gamma = 1. / (x_sigma * x_sigma * 2);
- sopt::t_real const beta = sigma * sigma * 0.5;
+ sopt::t_real constexpr regulariser_strength = 1. / (x_sigma * x_sigma * 2);
+ sopt::t_real const step_size = sigma * sigma * 0.5;
SOPT_HIGH_LOG("Creating Foward Backward Functor");
auto const fb = sopt::algorithm::L2ForwardBackward<Scalar>(y)
.itermax(500)
- .beta(beta) // stepsize
+ .step_size(step_size) // stepsize
.sigma(sigma) // sigma
- .gamma(gamma) // regularisation paramater
+ .regulariser_strength(regulariser_strength) // regularisation paramater
.relative_variation(1e-3)
.residual_tolerance(0)
.tight_frame(true)
4 changes: 2 additions & 2 deletions cpp/examples/primal_dual/inpainting.cc
@@ -91,12 +91,12 @@ int main(int argc, char const **argv) {
sopt::utilities::write_tiff(Matrix::Map(dirty.data(), image.rows(), image.cols()),
"dirty_" + output + ".tiff");
}
- sopt::t_real const gamma = (psi.adjoint() * (sampling.adjoint() * y)).real().maxCoeff() * 1e-2;
+ sopt::t_real const regulariser_strength = (psi.adjoint() * (sampling.adjoint() * y)).real().maxCoeff() * 1e-2;

SOPT_HIGH_LOG("Creating primal-dual Functor");
auto const pd = sopt::algorithm::ImagingPrimalDual<Scalar>(y)
.itermax(500)
- .gamma(gamma)
+ .regulariser_strength(regulariser_strength)
.tau(0.5)
.l2ball_proximal_epsilon(epsilon)
.Psi(psi)
4 changes: 2 additions & 2 deletions cpp/examples/primal_dual/tv_inpainting.cc
@@ -87,7 +87,7 @@ int main(int argc, char const **argv) {
"dirty_" + output + ".tiff");
}
const Vector grad = psi.adjoint() * (sampling.adjoint() * y);
- const sopt::t_real gamma = (grad.segment(0, image.size()).array().square() +
+ const sopt::t_real regulariser_strength = (grad.segment(0, image.size()).array().square() +
grad.segment(image.size(), image.size()).array().square())
.sqrt()
.real()
@@ -97,7 +97,7 @@ int main(int argc, char const **argv) {
SOPT_HIGH_LOG("Creating primal-dual Functor");
auto const pd = sopt::algorithm::TVPrimalDual<Scalar>(y)
.itermax(2000)
- .gamma(gamma)
+ .regulariser_strength(regulariser_strength)
.tau(0.5 / (1. + 1.))
.l2ball_proximal_epsilon(epsilon)
.Psi(psi)
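The heuristic in this example sets the TV regulariser strength from the largest per-pixel gradient magnitude of the dirty map. A self-contained restatement of that computation with plain Eigen — `grad` stacks the x- and y-derivative images exactly as the example does:

```cpp
#include <Eigen/Dense>

// Sketch of the strength heuristic above. grad holds the two derivative
// images stacked as [d/dx image; d/dy image], length 2 * n_pixels.
double tv_strength_heuristic(const Eigen::VectorXd &grad, Eigen::Index n_pixels) {
  const Eigen::ArrayXd gx = grad.segment(0, n_pixels).array();
  const Eigen::ArrayXd gy = grad.segment(n_pixels, n_pixels).array();
  // largest per-pixel gradient magnitude, scaled down by 1e-2
  return (gx.square() + gy.square()).sqrt().maxCoeff() * 1e-2;
}
```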
4 changes: 2 additions & 2 deletions cpp/examples/proximal_admm/euclidian_norm.cc
@@ -19,15 +19,15 @@ int main(int, char const **) {

// Creates the resulting proximal
// In practice g_0 and g_1 are any functions with the signature
- // void(t_Vector &output, t_Vector::Scalar gamma, t_Vector const &input)
+ // void(t_Vector &output, t_Vector::Scalar regulariser_strength, t_Vector const &input)
// They are the proximal of ||x - x_0|| and ||x - x_1||
auto prox_g0 = sopt::proximal::translate(sopt::proximal::EuclidianNorm(), -target0);
auto prox_g1 = sopt::proximal::translate(sopt::proximal::EuclidianNorm(), -target1);

auto padmm = sopt::algorithm::ProximalADMM<t_Scalar>(prox_g0, prox_g1, t_Vector::Zero(N))
.itermax(5000)
.is_converged(sopt::RelativeVariation<t_Scalar>(1e-12))
- .gamma(0.01)
+ .regulariser_strength(0.01)
// Phi == -1, so that we can minimize f(x) + g(x), as per problem definition in
// padmm.
.Phi(-t_Matrix::Identity(N, N));
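Any callable with the renamed signature above can serve as a g-proximal. As a concrete illustration, here is the classic ℓ1 proximal (soft-thresholding) written against that signature — a sketch with plain Eigen types, not code taken from the library:

```cpp
#include <Eigen/Dense>

using Vector = Eigen::VectorXd;

// prox_{s * ||.||_1}(v)_i = sign(v_i) * max(|v_i| - s, 0)
void l1_proximal_sketch(Vector &output, Vector::Scalar regulariser_strength,
                        const Vector &input) {
  output = (input.array().sign() *
            (input.array().abs() - regulariser_strength).max(0.0))
               .matrix();
}
```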
4 changes: 2 additions & 2 deletions cpp/examples/proximal_admm/inpainting.cc
@@ -84,7 +84,7 @@ int main(int argc, char const **argv) {
SOPT_HIGH_LOG("Creating proximal-ADMM Functor");
auto const padmm = sopt::algorithm::ImagingProximalADMM<Scalar>(y)
.itermax(500)
- .gamma(1e-1)
+ .regulariser_strength(1e-1)
.relative_variation(5e-4)
.l2ball_proximal_epsilon(epsilon)
.tight_frame(false)
@@ -95,7 +95,7 @@ int main(int argc, char const **argv) {
.l1_proximal_real_constraint(true)
.residual_convergence(epsilon * 1.001)
.lagrange_update_scale(0.9)
- .nu(1e0)
+ .sq_op_norm(1e0)
.Psi(psi)
.Phi(sampling);

4 changes: 2 additions & 2 deletions cpp/examples/proximal_admm/reweighted.cc
@@ -86,7 +86,7 @@ int main(int argc, char const **argv) {
SOPT_MEDIUM_LOG("Creating proximal-ADMM Functor");
auto const padmm = sopt::algorithm::ImagingProximalADMM<Scalar>(y)
.itermax(500)
- .gamma(1e-1)
+ .regulariser_strength(1e-1)
.relative_variation(5e-4)
.l2ball_proximal_epsilon(epsilon)
.tight_frame(false)
@@ -97,7 +97,7 @@ int main(int argc, char const **argv) {
.l1_proximal_real_constraint(true)
.residual_convergence(epsilon * 1.001)
.lagrange_update_scale(0.9)
- .nu(1e0)
+ .sq_op_norm(1e0)
.Psi(psi)
.Phi(sampling);

2 changes: 1 addition & 1 deletion cpp/sopt/CMakeLists.txt
@@ -4,7 +4,7 @@ set(headers
imaging_padmm.h logging.h
forward_backward.h imaging_forward_backward.h
non_differentiable_func.h l1_non_diff_function.h real_indicator.h joint_map.h
- differentiable_func.h
+ differentiable_func.h l2_differentiable_func.h
imaging_primal_dual.h primal_dual.h
maths.h proximal.h relative_variation.h sdmm.h
wavelets.h conjugate_gradient.h l1_proximal.h padmm.h proximal_expression.h
10 changes: 10 additions & 0 deletions cpp/sopt/differentiable_func.h
@@ -34,6 +34,16 @@ template <typename SCALAR> class DifferentiableFunc
// Calculate the function directly
virtual Real function(t_Vector const &image, t_Vector const &y, t_LinearTransform const &Phi) = 0;

+ // Get appropriate gradient step-size for FISTA algorithms
+ Real get_step_size() const
+ {
+ return step_size;
+ }
+
+ protected:
+
+ Real step_size;
+
// Transforms input image to a different basis.
// Return linear_transform_identity() if transform not necessary.
//virtual const t_LinearTransform &Phi() const = 0;
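This hunk gives every `DifferentiableFunc` a protected `step_size` that subclasses set and that the imaging wrapper reads back through `get_step_size()` (used further down to seed the forward-backward step). A self-contained sketch of the pattern — the class names here are illustrative stand-ins; only the member and getter mirror the diff:

```cpp
#include <iostream>

// Minimal stand-in for sopt::DifferentiableFunc: the base owns the step
// size, subclasses set it from their own parameters.
class DifferentiableFuncSketch {
 public:
  double get_step_size() const { return step_size; }

 protected:
  double step_size = 1.0;
};

// Gaussian likelihood: step size sigma^2, matching the default wired into
// ImagingForwardBackward later in this commit.
class GaussianLikelihoodSketch : public DifferentiableFuncSketch {
 public:
  explicit GaussianLikelihoodSketch(double sigma) { step_size = sigma * sigma; }
};

int main() {
  GaussianLikelihoodSketch f(0.5);
  std::cout << f.get_step_size() << "\n";  // prints 0.25
}
```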
30 changes: 15 additions & 15 deletions cpp/sopt/forward_backward.h
@@ -69,9 +69,9 @@ class ForwardBackward {
ForwardBackward(t_Gradient const &f_gradient, t_Proximal const &g_proximal,
Eigen::MatrixBase<DERIVED> const &target)
: itermax_(std::numeric_limits<t_uint>::max()),
- gamma_(1e-8),
- beta_(1),
- nu_(1),
+ regulariser_strength_(1e-8),
+ step_size_(1),
+ sq_op_norm_(1),
is_converged_(),
fista_(true),
Phi_(linear_transform_identity<Scalar>()),
@@ -97,11 +97,11 @@
//! Maximum number of iterations
SOPT_MACRO(itermax, t_uint);
//! γ parameter
- SOPT_MACRO(gamma, Real);
+ SOPT_MACRO(regulariser_strength, Real);
//! β parameter
- SOPT_MACRO(beta, Real);
+ SOPT_MACRO(step_size, Real);
//! ν parameter
- SOPT_MACRO(nu, Real);
+ SOPT_MACRO(sq_op_norm, Real);
//! flag to for FISTA Forward-Backward algorithm. True by default but should be false when using a learned g_proximal.
SOPT_MACRO(fista, bool);
//! \brief A function verifying convergence
@@ -117,8 +117,8 @@
//! \brief Simplifies calling the gradient function
void f_gradient(t_Vector &out, t_Vector const &x, t_Vector const &res, t_LinearTransform const &Phi) const { f_gradient()(out, x, res, Phi); }
//! \brief Simplifies calling the proximal function
- void g_proximal(t_Vector &out, Real gamma, t_Vector const &x) const {
- g_proximal()(out, gamma, x);
+ void g_proximal(t_Vector &out, Real regulariser_strength, t_Vector const &x) const {
+ g_proximal()(out, regulariser_strength, x);
}

//! Convergence function that takes only the output as argument
@@ -194,7 +194,7 @@ class ForwardBackward {
//! - x = Φ^T y / ν
//! - residuals = Φ x - y
std::tuple<t_Vector, t_Vector> initial_guess() const {
- return ForwardBackward<SCALAR>::initial_guess(target(), Phi(), nu());
+ return ForwardBackward<SCALAR>::initial_guess(target(), Phi(), sq_op_norm());
}

//! \brief Computes initial guess for x and the residual using the targets
@@ -204,9 +204,9 @@
//!
//! This function simplifies creating overloads for operator() in FB wrappers.
static std::tuple<t_Vector, t_Vector> initial_guess(t_Vector const &target,
- t_LinearTransform const &phi, Real nu) {
+ t_LinearTransform const &phi, Real sq_op_norm) {
std::tuple<t_Vector, t_Vector> guess;
- std::get<0>(guess) = static_cast<t_Vector>(phi.adjoint() * target) / nu;
+ std::get<0>(guess) = static_cast<t_Vector>(phi.adjoint() * target) / sq_op_norm;
std::get<1>(guess) = phi * std::get<0>(guess) - target;
return guess;
}
@@ -255,9 +255,9 @@ template <typename SCALAR>
void ForwardBackward<SCALAR>::iteration_step(t_Vector &image, t_Vector &residual, t_Vector &auxilliary_image,
t_Vector &gradient_current, const t_real FISTA_step) const {
t_Vector prev_image = image;
- f_gradient(gradient_current, auxilliary_image, residual, Phi()); // takes residual and calculates the grad = 1/sig^2 residual
- t_Vector auxilliary_with_step = auxilliary_image - beta() / nu() * gradient_current; // step to new image using gradient
- const Real weight = gamma() * beta();
+ f_gradient(gradient_current, auxilliary_image, residual, Phi()); // assigns gradient_current (non normalised)
+ t_Vector auxilliary_with_step = auxilliary_image - step_size() / sq_op_norm() * gradient_current; // step to new image using gradient
+ const Real weight = regulariser_strength() * step_size();
g_proximal(image, weight, auxilliary_with_step); // apply proximal operator to new image
auxilliary_image = image + FISTA_step * (image - prev_image); // update auxilliary vector with FISTA acceleration step
residual = (Phi() * auxilliary_image) - target(); // updates the residual for the NEXT iteration (new image).
@@ -287,7 +287,7 @@ typename ForwardBackward<SCALAR>::Diagnostic ForwardBackward<SCALAR>::operator()
Real theta_new = 1.0;
Real FISTA_step = 0.0;
for (; (not converged) && (niters < itermax()); ++niters) {
- SOPT_LOW_LOG("    - [FB] Iteration {}/{}", niters, itermax());
+ SOPT_MEDIUM_LOG("    - [FB] Iteration {}/{}", niters, itermax());
if (fista()) {
theta_new = (1 + std::sqrt(1 + 4 * theta * theta)) / 2.;
FISTA_step = (theta - 1) / (theta_new);
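After the renames, the iteration reads: take a gradient step scaled by `step_size()/sq_op_norm()`, apply the proximal with weight `regulariser_strength() * step_size()`, then apply FISTA momentum. A self-contained sketch of the same loop for the simplest case (Φ = identity, Gaussian likelihood, ℓ1 regulariser) — plain Eigen, illustrative only, not the library's implementation:

```cpp
#include <Eigen/Dense>
#include <cmath>

using Vector = Eigen::VectorXd;

// FISTA for min_x ||x - y||^2 / (2 sigma^2) + regulariser_strength * ||x||_1,
// mirroring ForwardBackward::iteration_step after this PR's renames.
Vector fista_l1_denoise(const Vector &y, double sigma,
                        double regulariser_strength, int itermax) {
  const double step_size = sigma * sigma;  // default seeded by this commit
  const double sq_op_norm = 1.0;           // ||Phi||^2 with Phi = identity
  Vector image = y;
  Vector aux = y;
  double theta = 1.0;
  for (int i = 0; i < itermax; ++i) {
    const Vector prev = image;
    const Vector gradient = (aux - y) / (sigma * sigma);       // f_gradient
    const Vector stepped = aux - (step_size / sq_op_norm) * gradient;
    const double weight = regulariser_strength * step_size;    // prox weight
    image = (stepped.array().sign() *
             (stepped.array().abs() - weight).max(0.0)).matrix();  // g_proximal
    const double theta_new = (1 + std::sqrt(1 + 4 * theta * theta)) / 2.0;
    aux = image + ((theta - 1) / theta_new) * (image - prev);  // FISTA step
    theta = theta_new;
  }
  return image;
}
```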
40 changes: 23 additions & 17 deletions cpp/sopt/imaging_forward_backward.h
@@ -71,10 +71,10 @@ class ImagingForwardBackward {
residual_convergence_(nullptr),
objective_convergence_(nullptr),
itermax_(std::numeric_limits<t_uint>::max()),
- gamma_(1e-8),
- beta_(1),
+ regulariser_strength_(1e-8),
+ step_size_(1),
sigma_(1),
- nu_(1),
+ sq_op_norm_(1),
fista_(true),
is_converged_(),
Phi_(linear_transform_identity<Scalar>()),
@@ -112,13 +112,13 @@ class ImagingForwardBackward {
//! Maximum number of iterations
SOPT_MACRO(itermax, t_uint);
//! γ parameter
- SOPT_MACRO(gamma, Real);
+ SOPT_MACRO(regulariser_strength, Real);
//! γ parameter
- SOPT_MACRO(beta, Real);
+ SOPT_MACRO(step_size, Real);
//! γ parameter
SOPT_MACRO(sigma, Real);
//! ν parameter
- SOPT_MACRO(nu, Real);
+ SOPT_MACRO(sq_op_norm, Real);
//! flag to for FISTA Forward-Backward algorithm. True by default but should be false when using a learned g_function.
SOPT_MACRO(fista, bool);
//! A function verifying convergence
@@ -181,7 +181,7 @@ class ImagingForwardBackward {
//! \brief Calls Forward Backward
//! \param[out] out: Output vector x
Diagnostic operator()(t_Vector &out) const {
- return operator()(out, ForwardBackward<SCALAR>::initial_guess(target(), Phi(), nu()));
+ return operator()(out, ForwardBackward<SCALAR>::initial_guess(target(), Phi(), sq_op_norm()));
}
//! \brief Calls Forward Backward
//! \param[out] out: Output vector x
@@ -214,7 +214,7 @@ class ImagingForwardBackward {
DiagnosticAndResult operator()() {
DiagnosticAndResult result;
static_cast<Diagnostic &>(result) = operator()(
- result.x, ForwardBackward<SCALAR>::initial_guess(target(), Phi(), nu()));
+ result.x, ForwardBackward<SCALAR>::initial_guess(target(), Phi(), sq_op_norm()));
return result;
}
//! Makes it simple to chain different calls to FB
@@ -292,13 +292,19 @@ typename ImagingForwardBackward<SCALAR>::Diagnostic ImagingForwardBackward<SCALA
Diagnostic result;
auto const g_proximal = g_function_->proximal_operator();
t_Gradient f_gradient;
- if(f_function_) f_gradient = f_function_->gradient();
+ Real gradient_step_size;
+ if(f_function_)
+ {
+ f_gradient = f_function_->gradient();
+ gradient_step_size = f_function_->get_step_size();
+ }
if(!f_gradient)
{
- SOPT_HIGH_LOG("Gradient function has not been set; using default (gaussian likelihood) gradient. (To set a custom gradient set_gradient() must be called before the algorithm is run.)");
+ SOPT_MEDIUM_LOG("Gradient function has not been set; using default (gaussian likelihood) gradient. (To set a custom gradient set_gradient() must be called before the algorithm is run.)");
f_gradient = [this](t_Vector &output, t_Vector const &x, t_Vector const &residual, t_LinearTransform const &Phi) {
output = Phi.adjoint() * (residual / (this->sigma() * this->sigma()));
};
+ gradient_step_size = sigma()*sigma();
}
ScalarRelativeVariation<Scalar> scalvar(relative_variation(), relative_variation(),
"Objective function");
@@ -309,9 +315,9 @@ typename ImagingForwardBackward<SCALAR>::Diagnostic ImagingForwardBackward<SCALA
};
auto const fb = ForwardBackward<SCALAR>(f_gradient, g_proximal, target())
.itermax(itermax())
- .beta(beta())
- .gamma(gamma())
- .nu(nu())
+ .step_size(gradient_step_size)
+ .regulariser_strength(regulariser_strength())
+ .sq_op_norm(sq_op_norm())
.fista(fista())
.Phi(Phi())
.is_converged(convergence);
@@ -336,8 +342,8 @@ bool ImagingForwardBackward<SCALAR>::objective_convergence(ScalarRelativeVariati
t_Vector const &residual) const {
if (static_cast<bool>(objective_convergence())) return objective_convergence()(x, residual);
if (scalvar.relative_tolerance() <= 0e0) return true;
- auto const current = ((gamma() > 0) ? g_function_->function(x)
- * gamma() : 0) + std::pow(sopt::l2_norm(residual), 2) / (2 * sigma() * sigma());
+ auto const current = ((regulariser_strength() > 0) ? g_function_->function(x)
+ * regulariser_strength() : 0) + std::pow(sopt::l2_norm(residual), 2) / (2 * sigma() * sigma());
return scalvar(current);
}

@@ -350,8 +356,8 @@ bool ImagingForwardBackward<SCALAR>::objective_convergence(mpi::Communicator con
if (static_cast<bool>(objective_convergence())) return objective_convergence()(x, residual);
if (scalvar.relative_tolerance() <= 0e0) return true;
auto const current = obj_comm.all_sum_all<t_real>(
- ((gamma() > 0) ? g_function_->function(x)
- * gamma() : 0) + std::pow(sopt::l2_norm(residual), 2) / (2 * sigma_ * sigma_));
+ ((regulariser_strength() > 0) ? g_function_->function(x)
+ * regulariser_strength() : 0) + std::pow(sopt::l2_norm(residual), 2) / (2 * sigma_ * sigma_));
return scalvar(current);
}
#endif
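When no `DifferentiableFunc` is supplied, `operator()` above falls back to the Gaussian-likelihood gradient Φ†(r/σ²) and seeds the forward-backward step size with σ². A minimal sketch of that fallback with a dense matrix standing in for the linear transform (the free-function names are illustrative):

```cpp
#include <Eigen/Dense>

using Vector = Eigen::VectorXd;
using Matrix = Eigen::MatrixXd;

// Default gradient of ||Phi x - y||^2 / (2 sigma^2) with respect to x,
// evaluated from the residual r = Phi x - y, as wired in above.
Vector gaussian_gradient(const Matrix &Phi, const Vector &residual, double sigma) {
  return Phi.adjoint() * (residual / (sigma * sigma));
}

// The matching step size the fallback passes down to ForwardBackward.
double gaussian_step_size(double sigma) { return sigma * sigma; }
```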