Skip to content

Commit 3f05081

Browse files
committed
adding pnp_ula notebook and required functions to sampling_tools
1 parent ba7ff35 commit 3f05081

File tree

9 files changed

+931
-7
lines changed

9 files changed

+931
-7
lines changed

PnP_ULA_notebook.ipynb

Lines changed: 681 additions & 0 deletions
Large diffs are not rendered by default.
10.5 MB
Binary file not shown.
10.5 MB
Binary file not shown.
10.5 MB
Binary file not shown.

SKROCK_notebook.ipynb

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
{
22
"cells": [
33
{
4+
"attachments": {},
45
"cell_type": "markdown",
56
"id": "5f612dcb",
67
"metadata": {},
@@ -31,6 +32,7 @@
3132
]
3233
},
3334
{
35+
"attachments": {},
3436
"cell_type": "markdown",
3537
"id": "38c5335e",
3638
"metadata": {},
@@ -63,6 +65,7 @@
6365
]
6466
},
6567
{
68+
"attachments": {},
6669
"cell_type": "markdown",
6770
"id": "e8310c03",
6871
"metadata": {},
@@ -87,6 +90,7 @@
8790
]
8891
},
8992
{
93+
"attachments": {},
9094
"cell_type": "markdown",
9195
"id": "57f67eda",
9296
"metadata": {},
@@ -110,6 +114,7 @@
110114
]
111115
},
112116
{
117+
"attachments": {},
113118
"cell_type": "markdown",
114119
"id": "bb1f66c7",
115120
"metadata": {},
@@ -174,6 +179,7 @@
174179
]
175180
},
176181
{
182+
"attachments": {},
177183
"cell_type": "markdown",
178184
"id": "cd29e8a8",
179185
"metadata": {},
@@ -195,6 +201,7 @@
195201
]
196202
},
197203
{
204+
"attachments": {},
198205
"cell_type": "markdown",
199206
"id": "652e4b2d",
200207
"metadata": {},
@@ -213,6 +220,7 @@
213220
]
214221
},
215222
{
223+
"attachments": {},
216224
"cell_type": "markdown",
217225
"id": "05b2eacc",
218226
"metadata": {},
@@ -278,6 +286,7 @@
278286
]
279287
},
280288
{
289+
"attachments": {},
281290
"cell_type": "markdown",
282291
"id": "424ac110",
283292
"metadata": {},
@@ -331,6 +340,7 @@
331340
]
332341
},
333342
{
343+
"attachments": {},
334344
"cell_type": "markdown",
335345
"id": "cae0887a",
336346
"metadata": {},
@@ -352,6 +362,7 @@
352362
]
353363
},
354364
{
365+
"attachments": {},
355366
"cell_type": "markdown",
356367
"id": "11ca71d9",
357368
"metadata": {},
@@ -458,6 +469,7 @@
458469
]
459470
},
460471
{
472+
"attachments": {},
461473
"cell_type": "markdown",
462474
"id": "6316b76d",
463475
"metadata": {},
@@ -490,6 +502,7 @@
490502
]
491503
},
492504
{
505+
"attachments": {},
493506
"cell_type": "markdown",
494507
"id": "94eb2c4d",
495508
"metadata": {},
@@ -516,6 +529,7 @@
516529
]
517530
},
518531
{
532+
"attachments": {},
519533
"cell_type": "markdown",
520534
"id": "72857df8",
521535
"metadata": {},
@@ -538,6 +552,7 @@
538552
]
539553
},
540554
{
555+
"attachments": {},
541556
"cell_type": "markdown",
542557
"id": "38392096",
543558
"metadata": {},
@@ -601,6 +616,7 @@
601616
]
602617
},
603618
{
619+
"attachments": {},
604620
"cell_type": "markdown",
605621
"id": "35f9cffe",
606622
"metadata": {},
@@ -621,6 +637,7 @@
621637
]
622638
},
623639
{
640+
"attachments": {},
624641
"cell_type": "markdown",
625642
"id": "8bf135fe",
626643
"metadata": {},
@@ -641,6 +658,7 @@
641658
]
642659
},
643660
{
661+
"attachments": {},
644662
"cell_type": "markdown",
645663
"id": "5db60f1d",
646664
"metadata": {},
@@ -660,6 +678,7 @@
660678
]
661679
},
662680
{
681+
"attachments": {},
663682
"cell_type": "markdown",
664683
"id": "ae94380a-7cc6-4293-aabb-a65faec537bc",
665684
"metadata": {},
@@ -682,6 +701,7 @@
682701
]
683702
},
684703
{
704+
"attachments": {},
685705
"cell_type": "markdown",
686706
"id": "1e9651db-2976-416e-97e6-cef3ef681fd2",
687707
"metadata": {},

sampling_tools/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,3 +8,5 @@
88
from .tv_norm import *
99
from .plots import *
1010
from .welford import *
11+
from .load_model import *
12+
from .spectral_normalize_chen import *

sampling_tools/load_model.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
import torch
2+
import torch.nn as nn
3+
import numpy as np
4+
from sampling_tools.spectral_normalize_chen import spectral_norm
5+
6+
class DnCNN(nn.Module):
7+
def __init__(self, channels, num_of_layers=17):
8+
super(DnCNN, self).__init__()
9+
kernel_size = 3
10+
padding = 1
11+
features = 64
12+
layers = []
13+
layers.append(spectral_norm(nn.Conv2d(in_channels=channels, out_channels=features, kernel_size=kernel_size, padding=padding, bias=False)))
14+
layers.append(nn.ReLU(inplace=True))
15+
for _ in range(num_of_layers-2):
16+
layers.append(spectral_norm(nn.Conv2d(in_channels=features, out_channels=features, kernel_size=kernel_size, padding=padding, bias=False)))
17+
layers.append(nn.BatchNorm2d(features))
18+
layers.append(nn.ReLU(inplace=True))
19+
layers.append(spectral_norm(nn.Conv2d(in_channels=features, out_channels=channels, kernel_size=kernel_size, padding=padding, bias=False)))
20+
self.dncnn = nn.Sequential(*layers)
21+
def forward(self, x):
22+
out = self.dncnn(x)
23+
return out
24+
25+
# ---- load the model based on the type and sigma (noise level) ----
26+
def load_model(model_type, sigma,device):
27+
28+
path = "Pretrained_models/" + model_type + "_noise" + str(sigma) + ".pth"
29+
30+
net = DnCNN(channels=1, num_of_layers=17)
31+
model = nn.DataParallel(net).cuda(device)
32+
33+
model.load_state_dict(torch.load(path))
34+
model.eval()
35+
36+
return model

sampling_tools/plots.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ def plot_trace(x, title="Title", x_label="xlabel"):
5151
plt.xlabel(x_label)
5252

5353

54-
def plots(x,y,post_meanvar,post_meanvar_absfourier, nrmse_values, psnr_values, ssim_values, logPi_trace):
54+
def plots(x,y,post_meanvar,post_meanvar_absfourier, nrmse_values, psnr_values, ssim_values, logPi_trace=None):
5555

5656
post_mean_numpy = post_meanvar.get_mean().detach().cpu().numpy()
5757
post_var_numpy = post_meanvar.get_var().detach().cpu().numpy()
@@ -138,12 +138,12 @@ def plots(x,y,post_meanvar,post_meanvar_absfourier, nrmse_values, psnr_values, s
138138
plt.close()
139139

140140
# --- log pi
141-
plot = plt.figure(figsize = (15,10))
142-
143-
plt.plot(np.arange(len(logPi_trace))[::10],logPi_trace[::10], label = "- $\log \pi$ -")
144-
plt.legend()
145-
plt.show()
146-
plt.close()
141+
if type(logPi_trace) == np.ndarray:
142+
plot = plt.figure(figsize = (15,10))
143+
plt.plot(np.arange(len(logPi_trace))[::10],logPi_trace[::10], label = "- $\log \pi$ -")
144+
plt.legend()
145+
plt.show()
146+
plt.close()
147147

148148

149149
def downsampling_variance(X_chain):

0 commit comments

Comments
 (0)