-
Notifications
You must be signed in to change notification settings - Fork 106
/
Copy pathproximal_gradient_denoising.py
66 lines (45 loc) · 1.82 KB
/
proximal_gradient_denoising.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
"""L1-regularized denoising using the proximal gradient solvers.
Solves the optimization problem
min_x || x - g ||_1 + lam || grad(x) ||_2^2
Where ``grad`` is the spatial gradient operator and ``g`` is given noisy data.
The proximal gradient solvers are also known as ISTA and FISTA.
"""
import odl
# --- Set up problem definition --- #

# Define function space: discretized functions on the rectangle
# [-20, 20]^2 with 300 samples per dimension.
space = odl.uniform_discr(
    min_pt=[-20, -20], max_pt=[20, 20], shape=[300, 300])

# Create the phantom and corrupt it with salt-and-pepper noise; the noisy
# image is the data ``g`` in the optimization problem.
data = odl.phantom.shepp_logan(space, modified=True)
data = odl.phantom.salt_pepper_noise(data)

# Create the spatial gradient operator ``grad``.
grad = odl.Gradient(space)


# --- Set up the inverse problem --- #

# Regularization weight ``lam``: how strongly the squared-gradient term is
# weighted against the L1 data discrepancy.
lam = 0.05

# Data discrepancy || x - g ||_1, built by translating the L1 norm by the data.
l1_norm = odl.solvers.L1Norm(space)
data_discrepancy = l1_norm.translated(data)

# Regularizer lam * || grad(x) ||_2^2 (l2-squared norm of the gradient).
regularizer = lam * odl.solvers.L2NormSquared(grad.range) * grad


# --- Select solver parameters and solve using proximal gradient --- #

# Step size. The regularizer is smooth; its gradient is
# 2 * lam * grad^* grad, which is Lipschitz with constant
# L = 2 * lam * ||grad||^2. A step size gamma <= 1 / L is sufficient for
# both ISTA and FISTA to converge (see the ODL solver documentation for the
# exact condition).
# Rough check (NOTE(review): assumes odl.Gradient uses forward differences,
# its default -- confirm): ||grad||^2 <= 8 / h^2 with cell size
# h = 40 / 300, so L <= 45 and 1 / L is about 0.022 > gamma.
# Re-check this bound whenever ``lam`` or the grid changes.
gamma = 0.01

# Optionally pass a callback to the solvers to print the iteration number
# and display intermediate results.
callback = (odl.solvers.CallbackPrintIteration() &
            odl.solvers.CallbackShow())

# Run the basic algorithm (ISTA) from a zero initial guess; ``x`` is updated
# in place by the solver and holds the reconstruction afterwards.
x = space.zero()
odl.solvers.proximal_gradient(
    x, f=data_discrepancy, g=regularizer, niter=200, gamma=gamma,
    callback=callback)

# Compare to the accelerated version (FISTA), which is much faster: it is
# given 50 iterations here versus 200 for ISTA. The shared callback is reset
# before being reused for the second run.
callback.reset()
x_acc = space.zero()
odl.solvers.accelerated_proximal_gradient(
    x_acc, f=data_discrepancy, g=regularizer, niter=50, gamma=gamma,
    callback=callback)

# Display the noisy data and both reconstructions. ``force_show=True`` on the
# last call presumably makes the figures appear before the script exits --
# see the ODL ``show`` documentation.
data.show(title='Data')
x.show(title='L1 Regularized Reconstruction')
x_acc.show(title='L1 Regularized Reconstruction (Accelerated)',
           force_show=True)