-
Notifications
You must be signed in to change notification settings - Fork 17
/
guass.py
43 lines (35 loc) · 1.44 KB
/
guass.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
import numpy as np
import torch.nn.functional as F
import torch.nn as nn
import torch
def gkern(kernlen=21, nsig=3):
    """Return a normalized ``kernlen x kernlen`` Gaussian-like kernel.

    The standard normal is integrated over ``kernlen`` equal-width bins. The
    bins span ``[-(nsig + step/2), nsig + step/2]`` with
    ``step = (2*nsig + 1) / kernlen``. This gives a 1D vector of bin masses.

    The 2D kernel is the element-wise square root of the outer product of that
    vector with itself, rescaled to sum to 1. Because of the square root, each
    axis follows roughly ``sqrt(mass)``. That is a Gaussian about sqrt(2) times
    wider (in units of the sampled sigma) than the bin masses themselves.

    Args:
        kernlen: number of bins per axis (the kernel's side length).
        nsig: how many standard deviations the bins cover on each side
            (before the extra half-step margin).

    Returns:
        2D NumPy array of shape ``(kernlen, kernlen)`` whose entries sum to 1.
    """
    from scipy.stats import norm

    step = (2 * nsig + 1.0) / kernlen
    half_span = nsig + step / 2.0
    # kernlen + 1 edges delimit kernlen bins.
    edges = np.linspace(-half_span, half_span, kernlen + 1)
    # Probability mass of N(0, 1) falling into each bin.
    bin_mass = np.diff(norm.cdf(edges))
    unnormalized = np.sqrt(np.outer(bin_mass, bin_mass))
    return unnormalized / unnormalized.sum()
class GaussianBlur(nn.Module):
    """Blur image channels independently with one fixed 2D kernel.

    The kernel is stored as a frozen ``nn.Parameter`` named ``weight``. It
    therefore moves with ``.to(device)`` and appears in ``state_dict``. It is
    applied to each selected channel separately, using zero padding of
    ``kernel_size // 2`` so an odd kernel preserves height and width.

    Args:
        kernel: 2D array-like kernel with an odd number of rows. It is
            presumably square; the same padding is used for both spatial axes.
        channels: number of leading input channels to blur. The default of 3
            reproduces the original behaviour: any further channels are dropped
            from the output. Pass ``None`` to blur every input channel.

    Raises:
        AssertionError: if the kernel has an even number of rows.
    """

    def __init__(self, kernel, channels=3):
        super().__init__()
        self.kernel_size = len(kernel)
        print('kernel size is {0}.'.format(self.kernel_size))
        # An explicit raise instead of ``assert``: the check must also run
        # under ``python -O``. AssertionError keeps the exception type
        # existing callers already see.
        if self.kernel_size % 2 != 1:
            raise AssertionError('kernel size must be odd.')
        self.channels = channels
        # Shape (1, 1, k, k). This plain attribute is kept for backward
        # compatibility only; it is not registered, so ``.to()`` does not move
        # it. ``forward`` uses ``self.weight``.
        self.kernel = torch.tensor(kernel, dtype=torch.float32).unsqueeze(0).unsqueeze(0)
        self.weight = nn.Parameter(data=self.kernel, requires_grad=False)

    def forward(self, x):
        """Blur the selected channels of ``x``.

        Args:
            x: 4D tensor indexed as (N, C, H, W), with C >= ``channels``.

        Returns:
            Tensor of shape (N, channels, H', W'). When ``channels`` is None it
            is (N, C, H', W'). H' == H and W' == W for odd kernels.

        Raises:
            IndexError: if ``x`` has fewer channels than requested. This is the
                same exception type the original per-channel indexing raised.
        """
        # getattr: modules pickled before ``channels`` existed still load.
        channels = getattr(self, 'channels', 3)
        if channels is None:
            channels = x.shape[1]
        if x.shape[1] < channels:
            raise IndexError(
                'expected at least {0} input channels, got {1}.'.format(channels, x.shape[1]))
        # Plain slicing replaces ``x[:, i].unsqueeze_(1)``, which mutated
        # metadata of views in place.
        x = x[:, :channels]
        # Replicate the single kernel once per channel. With groups=channels
        # the convolution is depthwise, which equals blurring each channel on
        # its own. Casting to x.dtype lets float64/half inputs work.
        weight = self.weight.to(dtype=x.dtype).repeat(channels, 1, 1, 1)
        return F.conv2d(x, weight, padding=self.kernel_size // 2, groups=channels)
def get_gaussian_blur(kernel_size, device, nsig=2):
    """Build a ``GaussianBlur`` from a ``gkern`` kernel and move it to ``device``.

    Args:
        kernel_size: side length of the square kernel. It must be odd, because
            ``GaussianBlur`` rejects even sizes.
        device: any target accepted by ``nn.Module.to`` (e.g. ``'cpu'``).
        nsig: standard deviations covered on each side, forwarded to ``gkern``.
            Defaults to 2, the value previously hard-coded here.

    Returns:
        The ``GaussianBlur`` module, placed on ``device``.
    """
    # gkern builds its kernel in NumPy's default float64; cast to float32 to
    # match the module's float32 weight.
    kernel = gkern(kernel_size, nsig).astype(np.float32)
    return GaussianBlur(kernel).to(device)