Skip to content

Commit

Permalink
Merge pull request #445 from astro-informatics/mm_add_updater_fb
Browse files Browse the repository at this point in the history
Add stochastic update functionality to FB algorithm
  • Loading branch information
20DM authored Jan 23, 2025
2 parents 9fe6e32 + cd4abee commit ccb6157
Show file tree
Hide file tree
Showing 7 changed files with 240 additions and 50 deletions.
78 changes: 53 additions & 25 deletions cpp/sopt/forward_backward.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
#include "sopt/logging.h"
#include "sopt/types.h"

#include "sopt/gradient_utils.h"

namespace sopt::algorithm {

/*! \brief Forward Backward Splitting
Expand Down Expand Up @@ -41,6 +43,7 @@ class ForwardBackward {
//! Type of the gradient
// The first argument is the output vector, the others are inputs
using t_Gradient = std::function<void(t_Vector &gradient, const t_Vector &image, const t_Vector &residual, const t_LinearTransform& Phi)>;
using t_randomUpdater = std::function<std::shared_ptr<IterationState<t_Vector>>()>;

//! Values indicating how the algorithm ran
struct Diagnostic {
Expand All @@ -65,19 +68,20 @@ class ForwardBackward {
//! Setups ForwardBackward
//! \param[in] f_function: the differentiable function \f$f\f$ with a gradient
//! \param[in] g_function: the non-differentiable function \f$g\f$ with a proximal operator
template <typename DERIVED>
ForwardBackward(t_Gradient const &f_gradient, t_Proximal const &g_proximal,
Eigen::MatrixBase<DERIVED> const &target)
t_Vector const &target)
: itermax_(std::numeric_limits<t_uint>::max()),
regulariser_strength_(1e-8),
step_size_(1),
sq_op_norm_(1),
is_converged_(),
fista_(true),
Phi_(linear_transform_identity<Scalar>()),
f_gradient_(f_gradient),
g_proximal_(g_proximal),
target_(target) {}
g_proximal_(g_proximal)
{
std::shared_ptr<t_LinearTransform> Id = std::make_shared<t_LinearTransform>(linear_transform_identity<Scalar>());
problem_state = std::make_shared<IterationState<t_Vector>>(target, Id);
}
virtual ~ForwardBackward() {}

// Macro helps define properties that can be initialized as in
Expand Down Expand Up @@ -107,12 +111,30 @@ class ForwardBackward {
//! \brief A function verifying convergence
//! \details It takes as input two arguments: the current solution x and the current residual.
SOPT_MACRO(is_converged, t_IsConverged);
//! Measurement operator
SOPT_MACRO(Phi, t_LinearTransform);
//! First proximal
SOPT_MACRO(f_gradient, t_Gradient);
//! Second proximal
SOPT_MACRO(g_proximal, t_Proximal);

//! Measurement operator
t_LinearTransform const &Phi() const { return problem_state->Phi(); }
ForwardBackward<SCALAR> &Phi(t_LinearTransform const &new_phi) {
problem_state->Phi(new_phi);
return *this;
}

ForwardBackward<SCALAR> &random_updater(t_randomUpdater &rU)
{
random_updater_ = rU;
return *this;
}

ForwardBackward<SCALAR> &set_problem_state(std::shared_ptr<IterationState<t_Vector>> pS)
{
problem_state = pS;
return *this;
}

#undef SOPT_MACRO
//! \brief Simplifies calling the gradient function
void f_gradient(t_Vector &out, t_Vector const &x, t_Vector const &res, t_LinearTransform const &Phi) const { f_gradient()(out, x, res, Phi); }
Expand All @@ -127,11 +149,10 @@ class ForwardBackward {
}

//! Vector of target measurements
t_Vector const &target() const { return target_; }
t_Vector const &target() const { return problem_state->target(); }
//! Sets the vector of target measurements
template <typename DERIVED>
ForwardBackward<Scalar> &target(Eigen::MatrixBase<DERIVED> const &target) {
target_ = target;
ForwardBackward<Scalar> &target(t_Vector const &target) {
problem_state->target(target);
return *this;
}

Expand All @@ -142,50 +163,50 @@ class ForwardBackward {

//! \brief Calls Forward Backward
//! \param[out] out: Output vector x
Diagnostic operator()(t_Vector &out) const { return operator()(out, initial_guess()); }
Diagnostic operator()(t_Vector &out) { return operator()(out, initial_guess()); }
//! \brief Calls Forward Backward
//! \param[out] out: Output vector x
//! \param[in] guess: initial guess
Diagnostic operator()(t_Vector &out, std::tuple<t_Vector, t_Vector> const &guess) const {
Diagnostic operator()(t_Vector &out, std::tuple<t_Vector, t_Vector> const &guess) {
return operator()(out, std::get<0>(guess), std::get<1>(guess));
}
//! \brief Calls Forward Backward
//! \param[out] out: Output vector x
//! \param[in] guess: initial guess
Diagnostic operator()(t_Vector &out,
std::tuple<t_Vector const &, t_Vector const &> const &guess) const {
std::tuple<t_Vector const &, t_Vector const &> const &guess) {
return operator()(out, std::get<0>(guess), std::get<1>(guess));
}
//! \brief Calls Forward Backward
//! \param[in] guess: initial guess
DiagnosticAndResult operator()(std::tuple<t_Vector, t_Vector> const &guess) const {
DiagnosticAndResult operator()(std::tuple<t_Vector, t_Vector> const &guess) {
return operator()(std::tie(std::get<0>(guess), std::get<1>(guess)));
}
//! \brief Calls Forward Backward
//! \param[in] guess: initial guess
DiagnosticAndResult operator()(
std::tuple<t_Vector const &, t_Vector const &> const &guess) const {
std::tuple<t_Vector const &, t_Vector const &> const &guess) {
DiagnosticAndResult result;
static_cast<Diagnostic &>(result) = operator()(result.x, guess);
return result;
}
//! \brief Calls Forward Backward
//! \param[in] guess: initial guess
DiagnosticAndResult operator()() const {
DiagnosticAndResult operator()() {
DiagnosticAndResult result;
static_cast<Diagnostic &>(result) = operator()(result.x, initial_guess());
return result;
}
//! Makes it simple to chain different calls to FB
DiagnosticAndResult operator()(DiagnosticAndResult const &warmstart) const {
DiagnosticAndResult operator()(DiagnosticAndResult const &warmstart) {
DiagnosticAndResult result = warmstart;
static_cast<Diagnostic &>(result) = operator()(result.x, warmstart.x, warmstart.residual);
return result;
}
//! Set Φ and Φ^† using arguments that sopt::linear_transform understands
template <typename... ARGS>
typename std::enable_if<sizeof...(ARGS) >= 1, ForwardBackward &>::type Phi(ARGS &&... args) {
Phi_ = linear_transform(std::forward<ARGS>(args)...);
problem_state->Phi(linear_transform(std::forward<ARGS>(args)...));
return *this;
}

Expand Down Expand Up @@ -213,7 +234,7 @@ class ForwardBackward {

protected:
void iteration_step(t_Vector &out, t_Vector &residual, t_Vector &p, t_Vector &z,
const t_real lambda) const;
const t_real lambda);

//! Checks input makes sense
void sanity_check(t_Vector const &x_guess, t_Vector const &res_guess) const {
Expand All @@ -231,10 +252,11 @@ class ForwardBackward {
//! \param[out] out: Output vector x
//! \param[in] guess: initial guess
//! \param[in] residuals: initial residuals
Diagnostic operator()(t_Vector &out, t_Vector const &guess, t_Vector const &res) const;
Diagnostic operator()(t_Vector &out, t_Vector const &guess, t_Vector const &res);

//! Vector of measurements
t_Vector target_;
//! problem state (shared with Imaging Forward Backward)
std::shared_ptr<IterationState<t_Vector>> problem_state;
t_randomUpdater random_updater_;
};

/**
Expand All @@ -253,19 +275,25 @@ class ForwardBackward {
*/
template <typename SCALAR>
void ForwardBackward<SCALAR>::iteration_step(t_Vector &image, t_Vector &residual, t_Vector &auxilliary_image,
t_Vector &gradient_current, const t_real FISTA_step) const {
t_Vector &gradient_current, const t_real FISTA_step) {
t_Vector prev_image = image;
f_gradient(gradient_current, auxilliary_image, residual, Phi()); // assigns gradient_current (non normalised)
t_Vector auxilliary_with_step = auxilliary_image - step_size() / sq_op_norm() * gradient_current; // step to new image using gradient
const Real weight = regulariser_strength() * step_size();
g_proximal(image, weight, auxilliary_with_step); // apply proximal operator to new image
auxilliary_image = image + FISTA_step * (image - prev_image); // update auxilliary vector with FISTA acceleration step

// set up next iteration
if(random_updater_)
{
problem_state = random_updater_();
}
residual = (Phi() * auxilliary_image) - target(); // updates the residual for the NEXT iteration (new image).
}

template <typename SCALAR>
typename ForwardBackward<SCALAR>::Diagnostic ForwardBackward<SCALAR>::operator()(
t_Vector &out, t_Vector const &x_guess, t_Vector const &res_guess) const {
t_Vector &out, t_Vector const &x_guess, t_Vector const &res_guess) {
SOPT_HIGH_LOG("Performing Forward Backward Splitting");
if (fista()) {
SOPT_HIGH_LOG("Using FISTA algorithm");
Expand Down
18 changes: 13 additions & 5 deletions cpp/sopt/gradient_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,22 +12,30 @@ namespace sopt {
template <typename T>
class IterationState {
public:
IterationState() = delete;
IterationState(const T& target)
{
_Phi = std::make_shared<sopt::LinearTransform<T>>(linear_transform_identity<T>());
}

IterationState(const T& target,
std::shared_ptr<sopt::LinearTransform<T>> phi)
std::shared_ptr<sopt::LinearTransform<T>> Phi)
: _target(target) {
_phi = phi;
_Phi = Phi;
}

const T& target() const { return _target; }

const sopt::LinearTransform<T>& phi() const { return *_phi; }
const sopt::LinearTransform<T>& Phi() const { return *_Phi; }

void Phi(const sopt::LinearTransform<T> &new_phi)
{
_Phi = std::make_shared<sopt::LinearTransform<T>>(new_phi);
}

private:
const T _target;

std::shared_ptr<sopt::LinearTransform<T>> _phi;
std::shared_ptr<sopt::LinearTransform<T>> _Phi;
};

} // namespace sopt
Expand Down
Loading

0 comments on commit ccb6157

Please sign in to comment.