-
Notifications
You must be signed in to change notification settings - Fork 0
/
patcher.py
97 lines (89 loc) · 2.56 KB
/
patcher.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
# -*-coding:utf-8 -*-
'''
Created on 17/10/2022
@author: Carlos
'''
import numpy as np
from utils import Util
def Patcher(data, mask_data, patch_size, stride=1, p2v=0):
    '''Create sliding-window patches from a volume and its binary mask.

    If the first two axes differ in size, or the first/third axis is not a
    multiple of ``patch_size``, both ``data`` and ``mask_data`` are first
    padded with ``Util.padding``. Cubic windows of edge ``patch_size`` are
    then taken along the first three axes at every ``stride`` step where the
    window fits entirely. Windows are emitted with the third axis (h) as the
    outermost loop, then the first axis (l), with the second axis (w)
    varying fastest; ``reconstruct`` walks the patches in this same order.

    Parameters:
    - data: numpy array to be patched; indexed as data[l, w, h, :], so it
      must have at least 4 dimensions
    - mask_data: binary mask, indexed the same way as data
    - patch_size: patch edge length (same along l, w and h)
    - stride: sliding window step; must be a positive integer
    - p2v: 1 for patch2vox training (forwarded to Util.padding)
    Return:
    - Created patches, stacked along a new first axis
    - Patches' shape (patch_size, patch_size, patch_size)
    - Shape of the first three axes of the image AFTER any padding (this is
      the ``maxs`` argument ``reconstruct`` expects, not the unpadded shape)
    - Padding values (``util.pad``; ((0,0),(0,0)) when no padding was done,
      presumably set by Util.padding otherwise - verify against utils.Util)
    - Binary mask patched
    Raises:
    - ValueError: if stride is not positive (the window would never advance)
    '''
    if stride <= 0:
        # A non-positive step would make the sliding loops run forever.
        raise ValueError('stride must be a positive integer, got %r' % (stride,))
    util = Util('None', 'None')
    needs_padding = (data.shape[0] != data.shape[1]
                     or data.shape[0] % patch_size != 0
                     or data.shape[2] % patch_size != 0)
    if needs_padding:
        data = util.padding(data, patch_size, p2v)
        mask_data = util.padding(mask_data, patch_size, p2v)
    else:
        util.pad = ((0, 0), (0, 0))
    l, w, h = data.shape[0], data.shape[1], data.shape[2]
    patched_bval = []
    mask_patched = []
    # Start offsets of every window that fits completely inside each axis;
    # equivalent to advancing ``pos`` by ``stride`` while ``pos <= size``.
    for hpre in range(0, h - patch_size + 1, stride):
        hpos = hpre + patch_size
        for lpre in range(0, l - patch_size + 1, stride):
            lpos = lpre + patch_size
            for wpre in range(0, w - patch_size + 1, stride):
                wpos = wpre + patch_size
                patched_bval.append(data[lpre:lpos, wpre:wpos, hpre:hpos, :])
                mask_patched.append(mask_data[lpre:lpos, wpre:wpos, hpre:hpos, :])
    patch_data = np.array(patched_bval)
    mask_data = np.array(mask_patched)
    return (patch_data, (patch_size, patch_size, patch_size),
            (data.shape[0], data.shape[1], data.shape[2]), util.pad, mask_data)
def reconstruct(patched_data, divs, maxs, stride=1):
    '''Reconstruct an image from patches produced by ``Patcher``.

    Patches are written back in the same order ``Patcher`` extracts them:
    the third axis (h) is the outermost loop, then the first axis (l), with
    the second axis (w) varying fastest. Where windows overlap
    (stride < patch size) a later patch overwrites the earlier one.

    Parameters:
    - patched_data: patches indexed along the first axis;
      ``patched_data.shape[1]`` is used as the channel count. A patch that
      exposes ``.cpu()`` (e.g. a torch tensor) is moved to CPU and detached;
      anything else is converted with ``np.asarray``. Each patch is
      transposed with axes (2, 3, 1, 0) and must then have shape
      (ldiv, wdiv, hdiv, channels).
    - divs: patches' shape (ldiv, wdiv, hdiv)
    - maxs: (padded) image limits (l, w, h), as returned by Patcher; a tuple
      or list
    - stride: sliding window step; must match the stride used in Patcher
    Return:
    - Unpatched image as numpy array of shape maxs + (channels,)
    '''
    ldiv, wdiv, hdiv = divs[0], divs[1], divs[2]
    lmax, wmax = maxs[0], maxs[1]
    lpre = wpre = hpre = 0
    lpos, wpos, hpos = ldiv, wdiv, hdiv
    res = np.zeros(tuple(maxs) + (patched_data.shape[1],))
    for patch in patched_data:
        if hasattr(patch, 'cpu'):
            # Tensor-like input: bring to host memory and drop autograd state.
            patch = patch.cpu().detach().numpy()
        patch = np.transpose(np.asarray(patch), (2, 3, 1, 0))
        res[lpre:lpos, wpre:wpos, hpre:hpos, :] = patch
        # Step only onto windows that fit completely, mirroring Patcher's
        # ``pos <= size`` loops. Testing ``pos < size`` instead steps onto a
        # window Patcher never produced whenever stride does not evenly
        # divide (size - patch size), misplacing every following patch.
        if wpos + stride <= wmax:
            wpre += stride
            wpos += stride
        elif lpos + stride <= lmax:
            wpre, wpos = 0, wdiv
            lpre += stride
            lpos += stride
        else:
            wpre, wpos = 0, wdiv
            lpre, lpos = 0, ldiv
            hpre += stride
            hpos += stride
    return res