Skip to content

Commit

Permalink
fixed bug and updated docstring
Browse files Browse the repository at this point in the history
  • Loading branch information
sfarrens committed Feb 10, 2020
1 parent e1923ee commit 3062ec7
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 4 deletions.
21 changes: 18 additions & 3 deletions modopt/opt/algorithms.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,18 +58,33 @@ class SetUp(Observable):
"""Algorithm Set-Up
This class contains methods for checking the set-up of an optimisation
algotithm and produces warnings if they do not comply
algotithm and produces warnings if they do not comply.
Parameters
----------
metric_call_period : int, optional
Metric call period (default is ``5``)
metrics : dict, optional
Metrics to be used (default is ``{}``)
verbose : bool, optional
Option for verbose output (default is ``False``)
progress : bool, optional
Option to display progress bar (default is ``True``)
step_size : int, optional
Generic step size parameter to override default algorithm
parameter name (`e.g.` `step_size` will override the value set for
`beta_param` in `ForwardBackward`)
"""

def __init__(self, metric_call_period=5, metrics={}, verbose=False,
progress=True, **dummy_kwargs):
progress=True, step_size=None, **dummy_kwargs):

self.converge = False
self.verbose = verbose
self.progress = progress
self.metrics = metrics
self.step_size = None
self.step_size = step_size
self._op_parents = ('GradParent', 'ProximityParent', 'LinearParent',
'costObj')

Expand Down
5 changes: 4 additions & 1 deletion modopt/tests/test_opt.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ def setUp(self):
prox_list=[prox_inst,
prox_dual_inst],
cost=cost_inst,
step_size=1)
step_size=2)
self.condat1 = algorithms.Condat(self.data1, self.data2,
grad=grad_inst,
prox=prox_inst,
Expand Down Expand Up @@ -174,6 +174,9 @@ def test_gen_forward_backward(self):
npt.assert_array_equal(self.gfb3.x_final, self.data1,
err_msg='Incorrect GenForwardBackward result.')

npt.assert_equal(self.gfb3.step_size, 2,
err_msg='Incorrect step size.')

npt.assert_raises(TypeError, algorithms.GenForwardBackward,
self.data1, self.dummy, [self.dummy], weights=1)

Expand Down

0 comments on commit 3062ec7

Please sign in to comment.