Fix kwargs in PDHG (#2010)
MargaretDuff authored Dec 13, 2024
1 parent 38d7004 commit 554a7fd
Showing 3 changed files with 116 additions and 52 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -3,6 +3,7 @@
- Fix bug with 'median' and 'mean' methods in Masker averaging over the wrong axes.
- `SPDHG` `gamma` parameter is now applied correctly so that the product of the dual and primal step sizes remains constant as `gamma` varies (#1644)
- Allow MaskGenerator to be run on DataContainers (#2001)
- Fix bug passing `kwargs` to PDHG (#2010)
- Enhancements:
- Removed multiple exits from numba implementation of KullbackLeibler divergence (#1901)
- Updated the `SPDHG` algorithm to take a stochastic `Sampler` (#1644)
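The changelog entry for #2010 concerns keyword arguments such as `theta`, `check_convergence`, and `update_objective_interval` that are forwarded through `PDHG.__init__`. A minimal sketch of the now-working call — the geometry, functions, and values below are illustrative choices of ours, not part of the commit:

    from cil.framework import ImageGeometry
    from cil.optimisation.algorithms import PDHG
    from cil.optimisation.functions import L2NormSquared
    from cil.optimisation.operators import IdentityOperator

    ig = ImageGeometry(3, 3)              # tiny illustrative geometry
    data = ig.allocate('random')

    f = L2NormSquared(b=data)             # data-fidelity term
    g = L2NormSquared()                   # regularisation term
    operator = IdentityOperator(ig)

    # theta and check_convergence are consumed by PDHG itself;
    # update_objective_interval is passed on to the Algorithm base class.
    pdhg = PDHG(f=f, g=g, operator=operator,
                theta=0.5,
                check_convergence=True,
                update_objective_interval=10)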
120 changes: 70 additions & 50 deletions Wrappers/Python/cil/optimisation/algorithms/PDHG.py
@@ -38,32 +38,30 @@ class PDHG(Algorithm):
A convex function with a "simple" proximal.
operator : LinearOperator
A Linear Operator.
sigma : positive :obj:`float`, or `np.ndarray`, `DataContainer`, `BlockDataContainer`, optional, default=None
sigma : positive :obj:`float`, or `np.ndarray`, `DataContainer`, `BlockDataContainer`, optional, default is 1.0/norm(K) or 1.0/ (tau*norm(K)**2) if tau is provided
Step size for the dual problem.
tau : positive :obj:`float`, or `np.ndarray`, `DataContainer`, `BlockDataContainer`, optional, default=None
tau : positive :obj:`float`, or `np.ndarray`, `DataContainer`, `BlockDataContainer`, optional, default is 1.0/norm(K) or 1.0/ (sigma*norm(K)**2) if sigma is provided
Step size for the primal problem.
initial : DataContainer, optional, default=None
initial : DataContainer, optional, default is a DataContainer of zeros
Initial point for the PDHG algorithm.
gamma_g : positive :obj:`float`, optional, default=None
Strongly convex constant if the function g is strongly convex. Allows primal acceleration of the PDHG algorithm.
gamma_fconj : positive :obj:`float`, optional, default=None
Strongly convex constant if the convex conjugate of f is strongly convex. Allows dual acceleration of the PDHG algorithm.
**kwargs:
Keyword arguments used by the base class :class:`Algorithm`.
max_iteration : :obj:`int`, optional, default=0
Maximum number of iterations of the PDHG.
update_objective_interval : :obj:`int`, optional, default=1
Evaluates objectives, e.g., primal/dual/primal-dual gap every ``update_objective_interval``.
check_convergence : :obj:`boolean`, default=True
Checks scalar sigma and tau values satisfy convergence criterion
Checks that scalar sigma and tau values satisfy the convergence criterion and warns if not satisfied. Can be computationally expensive for custom sigma or tau values.
theta : Float between 0 and 1, default 1.0
Relaxation parameter for the over-relaxation of the primal variable.
Example
-------
In our `CIL-Demos <https://github.com/TomographicImaging/CIL-Demos/blob/main/binder/TomographyReconstruction.ipynb>`_ repository\
you can find examples using the PDHG algorithm for different imaging problems, such as Total Variation denoising, Total Generalised Variation inpainting\
In our CIL-Demos repository (https://github.com/TomographicImaging/CIL-Demos) you can find examples using the PDHG algorithm for different imaging problems, such as Total Variation denoising, Total Generalised Variation inpainting
and Total Variation Tomography reconstruction. More examples can also be found in :cite:`Jorgensen_et_al_2021`, :cite:`Papoutsellis_et_al_2021`.
Note
@@ -212,17 +210,22 @@ class PDHG(Algorithm):
Note
----
The case where both functions are strongly convex is not available at the moment.
"""

.. todo:: Implement acceleration of PDHG when both functions are strongly convex.
def __init__(self, f, g, operator, tau=None, sigma=None, initial=None, gamma_g=None, gamma_fconj=None, **kwargs):
"""Initialisation of the PDHG algorithm"""

self._theta = kwargs.pop('theta', 1.0)
if self._theta > 1 or self._theta < 0:
raise ValueError(
"The relaxation parameter theta must be in the range [0,1], passed theta = {}".format(self.theta))

"""
self._check_convergence = kwargs.pop('check_convergence', True)

def __init__(self, f, g, operator, tau=None, sigma=None, initial=None, gamma_g=None, gamma_fconj=None, **kwargs):
super().__init__(**kwargs)

self._tau = None
self._sigma = None

@@ -232,22 +235,32 @@ def __init__(self, f, g, operator, tau=None, sigma=None, initial=None, gamma_g=N
self.set_gamma_g(gamma_g)
self.set_gamma_fconj(gamma_fconj)

self.set_up(f=f, g=g, operator=operator, tau=tau, sigma=sigma, initial=initial, **kwargs)
self.set_up(f=f, g=g, operator=operator, tau=tau,
sigma=sigma, initial=initial)
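In the constructor above, `theta` and `check_convergence` are popped out of `kwargs` before `super().__init__(**kwargs)` runs, so only arguments meant for the `Algorithm` base class reach it; that is the heart of the #2010 fix. A self-contained sketch of the same pattern (class names here are illustrative, not from CIL):

    class Base:
        def __init__(self, update_objective_interval=1, **kwargs):
            self.update_objective_interval = update_objective_interval

    class Child(Base):
        def __init__(self, **kwargs):
            # consume subclass-specific kwargs first ...
            self._theta = kwargs.pop('theta', 1.0)
            if not 0.0 <= self._theta <= 1.0:
                raise ValueError(
                    "theta must be in [0, 1], got {}".format(self._theta))
            self._check_convergence = kwargs.pop('check_convergence', True)
            # ... then forward the remainder to the base class
            super().__init__(**kwargs)

    c = Child(theta=0.5, update_objective_interval=10)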

@property
def tau(self):
"""The primal step-size """
return self._tau

@property
def sigma(self):
"""The dual step-size """
return self._sigma

@property
def theta(self):
"""The relaxation parameter for the over-relaxation of the primal variable """
return self._theta

@property
def gamma_g(self):
"""The strongly convex constant for the function g """
return self._gamma_g

@property
def gamma_fconj(self):
"""The strongly convex constant for the convex conjugate of the function f """
return self._gamma_fconj

def set_gamma_g(self, value):
@@ -258,17 +271,19 @@ def set_gamma_g(self, value):
value : a positive number or None
'''
if self.gamma_fconj is not None and value is not None:
raise ValueError("The adaptive update of the PDHG stepsizes in the case where both functions are strongly convex is not implemented at the moment." +\
"Currently the strongly convex constant of the convex conjugate of the function f has been specified as ", self.gamma_fconj)
raise ValueError("The adaptive update of the PDHG stepsizes in the case where both functions are strongly convex is not implemented at the moment." +
"Currently the strongly convex constant of the convex conjugate of the function f has been specified as ", self.gamma_fconj)

if isinstance (value, Number):
if isinstance(value, Number):
if value <= 0:
raise ValueError("Strongly convex constant is a positive number, {} is passed for the strongly convex function g.".format(value))
raise ValueError(
"Strongly convex constant is a positive number, {} is passed for the strongly convex function g.".format(value))
self._gamma_g = value
elif value is None:
pass
else:
raise ValueError("Positive float is expected for the strongly convex constant of function g, {} is passed".format(value))
raise ValueError(
"Positive float is expected for the strongly convex constant of function g, {} is passed".format(value))

def set_gamma_fconj(self, value):
'''Set the value of the strongly convex constant for the convex conjugate of function `f`
@@ -278,19 +293,21 @@ def set_gamma_fconj(self, value):
value : a positive number or None
'''
if self.gamma_g is not None and value is not None:
raise ValueError("The adaptive update of the PDHG stepsizes in the case where both functions are strongly convex is not implemented at the moment." +\
"Currently the strongly convex constant of the function g has been specified as ", self.gamma_g)
raise ValueError("The adaptive update of the PDHG stepsizes in the case where both functions are strongly convex is not implemented at the moment." +
"Currently the strongly convex constant of the function g has been specified as ", self.gamma_g)

if isinstance (value, Number):
if isinstance(value, Number):
if value <= 0:
raise ValueError("Strongly convex constant is positive, {} is passed for the strongly convex conjugate function of f.".format(value))
raise ValueError(
"Strongly convex constant is positive, {} is passed for the strongly convex conjugate function of f.".format(value))
self._gamma_fconj = value
elif value is None:
pass
else:
raise ValueError("Positive float is expected for the strongly convex constant of the convex conjugate of function f, {} is passed".format(value))
raise ValueError(
"Positive float is expected for the strongly convex constant of the convex conjugate of function f, {} is passed".format(value))

def set_up(self, f, g, operator, tau=None, sigma=None, initial=None, **kwargs):
def set_up(self, f, g, operator, tau=None, sigma=None, initial=None):
"""Initialisation of the algorithm
Parameters
@@ -301,14 +318,12 @@ def set_up(self, f, g, operator, tau=None, sigma=None, initial=None, **kwargs):
A convex function with a "simple" proximal.
operator : LinearOperator
A Linear Operator.
sigma : positive :obj:`float`, or `np.ndarray`, `DataContainer`, `BlockDataContainer`, optional, default=None
sigma : positive :obj:`float`, or `np.ndarray`, `DataContainer`, `BlockDataContainer`, optional, default is 1.0/norm(K) or 1.0/ (tau*norm(K)**2) if tau is provided
Step size for the dual problem.
tau : positive :obj:`float`, or `np.ndarray`, `DataContainer`, `BlockDataContainer`, optional, default=None
tau : positive :obj:`float`, or `np.ndarray`, `DataContainer`, `BlockDataContainer`, optional, default is 1.0/norm(K) or 1.0/ (sigma*norm(K)**2) if sigma is provided
Step size for the primal problem.
initial : DataContainer, optional, default=None
Initial point for the PDHG algorithm.
theta : Relaxation parameter, Number, default 1.0
"""
initial : DataContainer, optional, default is a DataContainer of zeros
Initial point for the PDHG algorithm. """
log.info("%s setting up", self.__class__.__name__)
# Triplet (f, g, K)
self.f = f
Expand All @@ -317,7 +332,7 @@ def set_up(self, f, g, operator, tau=None, sigma=None, initial=None, **kwargs):

self.set_step_sizes(sigma=sigma, tau=tau)

if kwargs.get('check_convergence', True):
if self._check_convergence:
self.check_convergence()

if initial is None:
@@ -330,9 +345,6 @@ def set_up(self, f, g, operator, tau=None, sigma=None, initial=None, **kwargs):
self.y = self.operator.range_geometry().allocate(0)
self.y_tmp = self.operator.range_geometry().allocate(0)

# relaxation parameter, default value is 1.0
self.theta = kwargs.get('theta',1.0)

if self.gamma_g is not None:
warnings.warn("Primal Acceleration of PDHG: The function g is assumed to be strongly convex with positive parameter `gamma_g`. You need to be sure that gamma_g = {} is the correct strongly convex constant for g. ".format(self.gamma_g))

@@ -357,25 +369,26 @@ def get_output(self):

def update(self):
"""Performs a single iteration of the PDHG algorithm"""
#calculate x-bar and store in self.x_tmp
self.x_old.sapyb((self.theta + 1.0), self.x, -self.theta, out=self.x_tmp)
# calculate x-bar and store in self.x_tmp
self.x_old.sapyb((self.theta + 1.0), self.x, -
self.theta, out=self.x_tmp)

# Gradient ascent for the dual variable
self.operator.direct(self.x_tmp, out=self.y_tmp)

self.y_tmp.sapyb(self.sigma, self.y, 1.0 , out=self.y_tmp)
self.y_tmp.sapyb(self.sigma, self.y, 1.0, out=self.y_tmp)

self.f.proximal_conjugate(self.y_tmp, self.sigma, out=self.y)

# Gradient descent for the primal variable
self.operator.adjoint(self.y, out=self.x_tmp)

self.x_tmp.sapyb(-self.tau, self.x_old, 1.0 , self.x_tmp)
self.x_tmp.sapyb(-self.tau, self.x_old, 1.0, self.x_tmp)

self.g.proximal(self.x_tmp, self.tau, out=self.x)

# update_previous_solution() called after update by base class
#i.e current solution is now in x_old, previous solution is now in x
# i.e current solution is now in x_old, previous solution is now in x

# update the step sizes for special cases
self.update_step_sizes()
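Because the base class swaps buffers after each call, `self.x_old` holds the current iterate and `self.x` the previous one when `update` runs, so the first `sapyb` computes the over-relaxed point x_bar = x_n + theta * (x_n - x_{n-1}). Written without the in-place `sapyb` arithmetic, one iteration is the classical Chambolle-Pock step; a NumPy-flavoured sketch of the same arithmetic, assuming `K` is a 2-D array and `prox_f_conj`, `prox_g` implement the two proximal maps:

    def pdhg_step(K, x, x_prev, y, sigma, tau, theta, prox_f_conj, prox_g):
        # over-relaxation of the primal variable
        x_bar = (1.0 + theta) * x - theta * x_prev
        # dual ascent followed by the proximal map of f*
        y_new = prox_f_conj(y + sigma * (K @ x_bar), sigma)
        # primal descent followed by the proximal map of g
        x_new = prox_g(x - tau * (K.T @ y_new), tau)
        return x_new, y_new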
@@ -390,10 +403,12 @@ def check_convergence(self):
"""
if isinstance(self.tau, Number) and isinstance(self.sigma, Number):
if self.sigma * self.tau * self.operator.norm()**2 > 1:
warnings.warn("Convergence criterion of PDHG for scalar step-sizes is not satisfied.")
warnings.warn(
"Convergence criterion of PDHG for scalar step-sizes is not satisfied.")
return False
return True
warnings.warn("Convergence criterion can only be checked for scalar values of tau and sigma.")
warnings.warn(
"Convergence criterion can only be checked for scalar values of tau and sigma.")
return False
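`check_convergence` implements the standard scalar step-size condition for PDHG, sigma * tau * ||K||^2 <= 1. A quick numeric illustration of the same test, treating the operator norm as a plain float:

    def satisfies_pdhg_criterion(sigma, tau, operator_norm):
        # scalar convergence condition: sigma * tau * ||K||^2 <= 1
        return sigma * tau * operator_norm ** 2 <= 1

    norm_K = 1.0                                          # e.g. the identity operator
    print(satisfies_pdhg_criterion(1.0, 1.0, norm_K))     # True: 1 * 1 * 1 <= 1
    print(satisfies_pdhg_criterion(4.0, 1/3, norm_K))     # False: 4/3 > 1, CIL warns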

def set_step_sizes(self, sigma=None, tau=None):
@@ -413,16 +428,20 @@ def set_step_sizes(self, sigma=None, tau=None):
if tau is not None:
if isinstance(tau, Number):
if tau <= 0:
raise ValueError("The step-sizes of PDHG must be positive, passed tau = {}".format(tau))
raise ValueError(
"The step-sizes of PDHG must be positive, passed tau = {}".format(tau))
elif tau.shape != self.operator.domain_geometry().shape:
raise ValueError(" The shape of tau = {0} is not the same as the shape of the domain_geometry = {1}".format(tau.shape, self.operator.domain_geometry().shape))
raise ValueError(" The shape of tau = {0} is not the same as the shape of the domain_geometry = {1}".format(
tau.shape, self.operator.domain_geometry().shape))

if sigma is not None:
if isinstance(sigma, Number):
if sigma <= 0:
raise ValueError("The step-sizes of PDHG are positive, passed sigma = {}".format(sigma))
raise ValueError(
"The step-sizes of PDHG are positive, passed sigma = {}".format(sigma))
elif sigma.shape != self.operator.range_geometry().shape:
raise ValueError(" The shape of sigma = {0} is not the same as the shape of the range_geometry = {1}".format(sigma.shape, self.operator.range_geometry().shape))
raise ValueError(" The shape of sigma = {0} is not the same as the shape of the range_geometry = {1}".format(
sigma.shape, self.operator.range_geometry().shape))

# Default sigma and tau step-sizes
if tau is None and sigma is None:
Expand All @@ -438,7 +457,8 @@ def set_step_sizes(self, sigma=None, tau=None):
self._sigma = sigma
self._tau = 1./(self.sigma*self.operator.norm()**2)
else:
raise NotImplementedError("If using arrays for sigma or tau both must arrays must be provided.")
raise NotImplementedError(
"If using arrays for sigma or tau, both must be provided.")

def update_step_sizes(self):
"""
Expand All @@ -447,14 +467,14 @@ def update_step_sizes(self):
"""
# Update sigma and tau based on the strong convexity of G
if self.gamma_g is not None:
self.theta = 1.0/ np.sqrt(1 + 2 * self.gamma_g * self.tau)
self._theta = 1.0 / np.sqrt(1 + 2 * self.gamma_g * self.tau)
self._tau *= self.theta
self._sigma /= self.theta

# Update sigma and tau based on the strong convexity of F
# Following operations are reversed due to symmetry, sigma --> tau, tau --> sigma
if self.gamma_fconj is not None:
self.theta = 1.0 / np.sqrt(1 + 2 * self.gamma_fconj * self.sigma)
self._theta = 1.0 / np.sqrt(1 + 2 * self.gamma_fconj * self.sigma)
self._sigma *= self.theta
self._tau /= self.theta
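For reference, the two branches above follow the accelerated step-size rule of Chambolle and Pock (Algorithm 2 of their 2011 paper, up to notation), shown here in LaTeX for the strongly convex g case; the roles of sigma and tau swap when f* is the strongly convex one:

    \theta_n = \frac{1}{\sqrt{1 + 2\,\gamma_g\,\tau_n}}, \qquad
    \tau_{n+1} = \theta_n\,\tau_n, \qquad
    \sigma_{n+1} = \frac{\sigma_n}{\theta_n}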

47 changes: 45 additions & 2 deletions Wrappers/Python/test/test_algorithms.py
@@ -87,7 +87,7 @@
if has_cvxpy:
import cvxpy


import warnings
class TestGD(CCPiTestClass):
def setUp(self):

@@ -700,6 +700,10 @@ def test_update(self):
beta= ((self.data - self.initial-alpha*(self.data-self.initial)).norm()**2)/4
self.assertNumpyArrayEqual(self.alg.p.as_array(), ((self.data - self.initial-alpha*(self.data-self.initial))+beta*(self.data-self.initial)).as_array())
class TestPDHG(CCPiTestClass):





def test_PDHG_Denoising(self):
# adapted from demo PDHG_TV_Color_Denoising.py in CIL-Demos repository
@@ -876,13 +880,32 @@ def test_PDHG_step_sizes(self):
with self.assertRaises(AttributeError):
pdhg = PDHG(f=f, g=g, operator=operator,
tau="tau")



# check warning message if condition is not satisfied
sigma = 4
sigma = 4/operator.norm()
tau = 1/3
with self.assertWarnsRegex(UserWarning, "Convergence criterion"):
pdhg = PDHG(f=f, g=g, operator=operator, tau=tau,
sigma=sigma)

# check no warning message if check convergence is false
sigma = 4/operator.norm()
tau = 1/3
with warnings.catch_warnings(record=True) as warnings_log:
pdhg = PDHG(f=f, g=g, operator=operator, tau=tau,
sigma=sigma, check_convergence=False)
self.assertEqual(warnings_log, [])

# check no warning message if condition is satisfied
sigma = 1/operator.norm()
tau = 1/3
with warnings.catch_warnings(record=True) as warnings_log:
pdhg = PDHG(f=f, g=g, operator=operator, tau=tau,
sigma=sigma)
self.assertEqual(warnings_log, [])
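The `record=True` pattern used in these tests collects emitted warnings into a list so the test can assert silence. A minimal self-contained version of the idiom (ours, not part of this commit); adding `simplefilter("always")` is generally advisable so the once-per-location warning registry cannot hide a repeated warning:

    import warnings

    with warnings.catch_warnings(record=True) as log:
        warnings.simplefilter("always")
        pass  # code under test that should emit no warnings
    assert log == []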


def test_PDHG_strongly_convex_gamma_g(self):
ig = ImageGeometry(3, 3)
@@ -970,7 +993,27 @@ def test_PDHG_strongly_convex_both_fconj_and_g(self):
except ValueError as err:
log.info(str(err))

def test_pdhg_theta(self):
ig = ImageGeometry(3, 3)
data = ig.allocate('random')

f = L2NormSquared(b=data)
g = L2NormSquared()
operator = IdentityOperator(ig)

pdhg = PDHG(f=f, g=g, operator=operator)
self.assertEqual(pdhg.theta, 1.0)

pdhg = PDHG(f=f, g=g, operator=operator,theta=0.5)
self.assertEqual(pdhg.theta, 0.5)

with self.assertRaises(ValueError):
PDHG( f=f, g=g, operator=operator, theta=-0.5)

with self.assertRaises(ValueError):
PDHG( f=f, g=g, operator=operator, theta=5)


class TestSIRT(CCPiTestClass):

def setUp(self):
