-
Notifications
You must be signed in to change notification settings - Fork 0
/
radon.py
128 lines (92 loc) · 4.1 KB
/
radon.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
from functools import partial
import numpy as np
from interpolations import nearest_neighbor_interpolate, bilinear_interpolate
import multiprocessing as mp
def rotate_theta(theta: float) -> np.ndarray:
    """Build the matrix [[cos t, sin t], [-sin t, cos t]] for the angle ``theta``.

    ``theta`` is passed directly to ``np.cos``/``np.sin``, i.e. it is in
    radians. The result is the transpose (equivalently, the inverse) of the
    conventional counter-clockwise rotation matrix
    [[cos t, -sin t], [sin t, cos t]].

    Returns:
        np.ndarray: a (2, 2) float64 rotation matrix.
    """
    cos_t, sin_t = np.cos(theta), np.sin(theta)
    first_row = [cos_t, sin_t]
    second_row = [-sin_t, cos_t]
    return np.array([first_row, second_row], dtype=np.float64)
def rotate_image(image: np.ndarray, angles, method: str = 'nearest'):
    """Rotates an image by the given angles using a chosen interpolation method.

    Uses inverse mapping: for every output pixel, its coordinates (centred on
    (N/2, N/2)) are multiplied by ``rotate_theta(angle)`` and the input image
    is interpolated at the resulting position.

    Arguments:
        image (np.ndarray): an NxN grayscale image to rotate
        angles: a rotation angle in degrees, or a list/array of such angles
        method (str): interpolation method ('nearest', 'bilinear')

    Returns:
        np.ndarray: (len(angles)xNxN) array of the rotated images, with the
        same dtype as ``image``. A single scalar angle yields a leading axis
        of length 1.

    Raises:
        ValueError: if ``method`` is unknown or ``image`` is not square.
    """
    if method == 'nearest':
        interpolate = nearest_neighbor_interpolate
    elif method == 'bilinear':
        interpolate = bilinear_interpolate
    else:
        raise ValueError('method must be either "nearest" or "bilinear"')
    height, width = image.shape
    if not height == width:
        raise ValueError('image must be square')
    cx, cy = width / 2, height / 2
    # With meshgrid's default 'xy' indexing, ``x`` varies along axis 0 and
    # ``y`` along axis 1 of the (square) grid.
    y, x = np.meshgrid(np.arange(height), np.arange(width))
    # The centred sampling grid does not depend on the angle, so build the
    # (2, N*N) coordinate stack once rather than once per angle.
    coords = np.vstack(((x - cx).ravel(), (y - cy).ravel()))
    # atleast_1d lets a single scalar angle through (len() fails on a 0-d array).
    angles_rad = np.atleast_1d(np.deg2rad(angles))
    rotated_images = np.zeros((len(angles_rad), height, width), dtype=image.dtype)
    for i, angle in enumerate(angles_rad):
        new_x, new_y = rotate_theta(angle) @ coords
        new_x = new_x.reshape(height, width) + cx
        new_y = new_y.reshape(height, width) + cy
        # Clamp to valid index range: positions that fall outside the image
        # are sampled at the border instead of being treated as zero.
        new_x = np.clip(new_x, 0, width - 1)
        new_y = np.clip(new_y, 0, height - 1)
        rotated_images[i] = interpolate(image, new_x, new_y)
    return rotated_images
def rotate_image_mp(image: np.ndarray, angles, method: str = 'nearest'):
    """Rotates an image by the given angles using a pool of worker processes.

    Parallel counterpart of ``rotate_image``: each angle is handed to
    ``pool_rotate`` in a ``multiprocessing.Pool`` worker.

    Arguments:
        image (np.ndarray): an NxN grayscale image to rotate
        angles: a rotation angle in degrees, or a list/array of such angles
        method (str): interpolation method ('nearest', 'bilinear')

    Returns:
        np.ndarray: (len(angles)xNxN) array of the rotated images, in the
        same order as ``angles``.

    Raises:
        ValueError: if ``image`` is not square or ``method`` is unknown.
    """
    height, width = image.shape
    # pool_rotate only works on square grids (its meshgrid has shape
    # (width, height) but is reshaped to (height, width)), so reject other
    # shapes up front instead of returning scrambled output. rotate_image
    # performs the same check.
    if height != width:
        raise ValueError('image must be square')
    if method == 'nearest':
        interpolate = nearest_neighbor_interpolate
    elif method == 'bilinear':
        interpolate = bilinear_interpolate
    else:
        raise ValueError('method must be either "nearest" or "bilinear"')
    # atleast_1d lets a single scalar angle through (a 0-d array is not iterable).
    angles_rad = np.atleast_1d(np.deg2rad(angles))
    if len(angles_rad) == 0:
        # Mapping over nothing would give np.array([]) with shape (0,);
        # return the (0, N, N) result that rotate_image produces instead.
        return np.zeros((0, height, width), dtype=image.dtype)
    # No point starting more workers than there are angles to process.
    processes = min(mp.cpu_count(), len(angles_rad))
    with mp.Pool(processes=processes) as pool:
        # Pool.map blocks until every result is ready (equivalent to
        # map_async(...).wait() then .get()) and preserves input order.
        rotated_images = pool.map(partial(pool_rotate, image, interpolate), angles_rad)
    # NOTE(review): unlike rotate_image, the result is not cast to image.dtype;
    # its dtype is whatever the interpolation function returns. Confirm
    # whether the sequential and parallel paths are meant to agree.
    return np.array(rotated_images)
def pool_rotate(image, interpolate, angle):
    """Rotate ``image`` by a single angle; the per-angle worker of ``rotate_image_mp``.

    Presumably kept at module level so it (wrapped in ``functools.partial``)
    can be pickled and sent to ``multiprocessing.Pool`` workers.

    Arguments:
        image (np.ndarray): an NxN grayscale image; callers must ensure it is
            square (the grid reshape below assumes it)
        interpolate: callable invoked as ``interpolate(image, new_x, new_y)``
        angle (float): rotation angle in radians

    Returns:
        Whatever ``interpolate`` returns for the rotated sampling grid.
    """
    height, width = image.shape
    # Same centre convention as rotate_image: cx from the width, cy from the
    # height (the previous swapped order was only equivalent for square input).
    cx, cy = width / 2, height / 2
    # With meshgrid's default 'xy' indexing, ``x`` varies along axis 0 and
    # ``y`` along axis 1 of the grid.
    y, x = np.meshgrid(np.arange(height), np.arange(width))
    coords = np.vstack(((x - cx).ravel(), (y - cy).ravel()))
    new_x, new_y = rotate_theta(angle) @ coords
    new_x = new_x.reshape(height, width) + cx
    new_y = new_y.reshape(height, width) + cy
    # Clamp to valid index range: out-of-image positions are sampled at the border.
    new_x = np.clip(new_x, 0, width - 1)
    new_y = np.clip(new_y, 0, height - 1)
    return interpolate(image, new_x, new_y)
def radon(image: np.ndarray,
          angles: "np.ndarray | None" = None,
          method: str = 'nearest',
          mp: bool = True) -> np.ndarray:
    """Apply a radon transform to an image.

    The image is rotated by every angle and each rotated copy is summed along
    its last axis, giving one projection per angle.

    Arguments:
        image (np.ndarray): an NxN grayscale image to radon transform
        angles (np.ndarray): rotation angles in degrees; defaults to
            ``np.arange(180)`` (0, 1, ..., 179) when omitted or None
        method (str): interpolation method to use in rotation for discrete image ('nearest', 'bilinear')
        mp (bool): if True, rotate in a multiprocessing pool
            (``rotate_image_mp``); otherwise rotate sequentially (``rotate_image``)

    Returns:
        np.ndarray: The sinogram result of the radon transform, of shape
        (N, len(angles)) with one column per angle.
    """
    # Build the default per call: a default ndarray in the signature is
    # created once at definition time and shared by every call.
    if angles is None:
        angles = np.arange(180)
    # ``mp`` is the boolean flag here; it shadows the multiprocessing module
    # name only inside this function, so rotate_image_mp still sees the module.
    rotate = rotate_image_mp if mp else rotate_image
    return rotate(image, angles, method).sum(axis=-1).T