Skip to content

Commit

Permalink
Update get_pad_and_patch.py
Browse files Browse the repository at this point in the history
  • Loading branch information
KaijieMo1 authored Nov 2, 2020
1 parent fd40e9c commit be87ae7
Showing 1 changed file with 137 additions and 24 deletions.
161 changes: 137 additions & 24 deletions med_io/get_pad_and_patch.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,6 @@ def get_fixed_patches_index(config, max_fix_img_size, patch_size, overlap_rate=0
it cannot iterate the specific (individual and different) image size of each single image.
Thus a fix grid of patches is created before creating pipeline.
Note:Image and label must have the same size!
:param config: type dict: config parameter
:param max_fix_img_size: type list of int: size of unpatched image,
the length must be greater than or equal to the length of :param: patch_size
:param patch_size: type list of int: patch size images
Expand All @@ -75,22 +74,29 @@ def get_fixed_patches_index(config, max_fix_img_size, patch_size, overlap_rate=0
if start is None: start = np.array([0] * dim)
assert (len(start) == len(overlap_rate) == dim)
patch_size = [tf.math.minimum(max_fix_img_size[i], patch_size[i]) for i in range(dim)]
end1 = [max_fix_img_size[i] - patch_size[i] for i in range(dim)] # end points int list

end1 = [max_fix_img_size[i] - patch_size[i] for i in range(dim)] # stop int list

if end is not None:
for i in range(dim):
if end[i] > end1[i]: end[i] = end1[i]
else:
end = end1
if not config['patch_probability_distribution']['use']:
# Patching with tiling method
step = patch_size - np.round(overlap_rate * patch_size)
step = patch_size - np.round(overlap_rate * np.array(patch_size))
for st in step:
if st <= 0: raise ValueError('step of patches must greater than 0.')

# Sampling patch index with start, end, step
slice_ = (*[slice(start[i], end[i] + step[i] - 1, step[i]) for i in range(dim)],)
index_list = np.array(np.mgrid[slice_].reshape(dim, -1).T, dtype=np.int)

indices_max_bound = [max_fix_img_size[i] - patch_size[i] for i in range(dim)]
for j, index in enumerate(index_list):
# Limiting the patching indices
index_list[j] = [max(min(index[i], indices_max_bound[i]), 0)
for i in range(dim)]
else:
# patching with probability method
index_list = [[0] * dim]
Expand All @@ -107,6 +113,8 @@ def get_fixed_patches_index(config, max_fix_img_size, patch_size, overlap_rate=0
sigma = config['patch_probability_distribution']['normal']['sigma']
else:
sigma = end - start # default std value
print(start, end, mu, sigma)

# Still some problems here, Tensorflow doesn't support type NPY_INT
lst = [
scipy.stats.truncnorm.rvs((start[i] - mu[i]) / sigma, (end[i] - mu[i]) / sigma, loc=mu[i],
Expand All @@ -118,9 +126,10 @@ def get_fixed_patches_index(config, max_fix_img_size, patch_size, overlap_rate=0
# Patching sampling with truncated uniform distribution
lst = [np.random.uniform(start[i], end[i], size=N)[:, 0] for i in range(dim)] # [:, 0]
index_list = np.stack(lst, axis=-1).astype(np.int32)
shuffle=config['index_shuffle'] # set config['index_shuffle'] =True can greatly enhance the quality of the results.
shuffle = config['index_shuffle']
if shuffle: np.random.shuffle(index_list)
if max_patch_num: index_list = index_list[:max_patch_num]

return index_list


Expand Down Expand Up @@ -193,6 +202,8 @@ def unpatch_predict_image(data_patches, indice_list, patch_size, unpatch_data_si
"""
# Data_patches list
dim = len(patch_size)
# add end

data_patch_size = np.array(data_patches[0]).shape

assert (len(data_patches) == len(indice_list))
Expand All @@ -210,6 +221,8 @@ def unpatch_predict_image(data_patches, indice_list, patch_size, unpatch_data_si
weight_patch = np.ones((*patch_size,) + (data_patch_size[-1],))
else:
weight_patch = np.ones((*output_patch_size,) + (data_patch_size[-1],))
print('weight_patch.shape', weight_patch.shape)

for data_patch, index in zip(data_patches, indice_list):

# Indexing using function slice for variable dim, Indexing last channel by slice(None, None),equivalent to [:]
Expand Down Expand Up @@ -241,45 +254,42 @@ def unpatch_predict_image(data_patches, indice_list, patch_size, unpatch_data_si
return unpatch_img


def prediction_prob_and_decision(config, patch_prob_img, indice_list):
def prediction_prob(config, patch_prob_img, indice_list):
"""
Specially for prediction the results by network body part identification
:param config: type dict: config parameter
:param patch_prob_img: type list:
:param indice_list: type list, position of corresonded :param patch_prob_img
:return: prob_map: type ndarray: predicted probability map
:return: decision_map: type ndarray: predicted decision map
:param config:
:param patch_prob_img: size(patch num1146, class6)
:param indice_list:
:return:
"""

n_classes = config['body_identification_n_classes']
# patch_prob_img size(len of indice_list, n_classes)

patch_shape = config['patch_size']
# Body identification patch size [1, X,Y]
# Initialize the map
# patch_prob_img[0] represents total num of patches.
# Initialize
#
patch_shape = config['patch_size'] # Body identification patch size [1, X,Y]
# patch_prob_img[0] is total num of patches.,=
patch_prob_maps = np.zeros([len(indice_list), patch_shape[1], patch_shape[2], n_classes])
patch_decision_maps = np.zeros([len(indice_list), patch_shape[1], patch_shape[2], n_classes])
print('patch_prob_maps,line258', patch_prob_maps.shape)

for i, pos in enumerate(indice_list):
# one hot map
# patch_decision_maps +1 at the correspondent pixels
# one hot matrix
patch_decision_maps[i, :, :, np.argmax(patch_prob_img[i, :])] += 1

# patch_prob_maps +predict results at the correspondent pixels
for class_ in range(n_classes):
patch_prob_maps[i, :, :, class_] += patch_prob_img[i, class_]

# u
# sio.savemat('t.mat',{'d':patch_decision_maps,'p':patch_prob_maps})
prob_map = unpatch_predict_image(patch_prob_maps, indice_list, patch_shape, set_zero_by_threshold=False)
decision_map = unpatch_predict_image(patch_decision_maps, indice_list, patch_shape, set_zero_by_threshold=False)

return prob_map, decision_map


def get_patches_data(data_size, patch_size, data_img, data_label, index_list, random_rate=0.3,
slice_channel_img=None, slice_channel_label=None, output_patch_size=None, random_shift_patch=True,
squeeze_channel=False, squeeze_dimension=None):
def get_patches_data_(data_size, patch_size, data_img, data_label, index_list, random_rate=0.3,
slice_channel_img=None, slice_channel_label=None, output_patch_size=None, random_shift_patch=True,
squeeze_channel=False, squeeze_dimension=None):
"""
Get patches from unpatched image and correspondent label by the list of patch positions.
:param data_size: type ndarray: data size of :param: data_img and :param data_label
Expand Down Expand Up @@ -307,6 +317,9 @@ def get_patches_data(data_size, patch_size, data_img, data_label, index_list, ra

for j, index in enumerate(index_list):

if index_list[2] > image_shape[2] - patch_size[2]:
continue

# Limiting the patching indices
index_list[j] = [max(min(index[i], indices_max_bound[i]), 0)
for i in range(dim)]
Expand All @@ -322,7 +335,6 @@ def get_patches_data(data_size, patch_size, data_img, data_label, index_list, ra

# indexing using function slice for variable dim,indexing last channel by slice(None, None),equivalent to [:]
# Get patch image data

patch_img_collection = [
data_img[(*[slice(index[i], index[i] + patch_size[i]) for i in range(dim)]
+ [slice(None, None)],)]
Expand Down Expand Up @@ -364,8 +376,109 @@ def get_patches_data(data_size, patch_size, data_img, data_label, index_list, ra

if squeeze_dimension is not None:
patch_img_collection = [img[..., 0, :] for img in patch_img_collection]

if slice_channel_label is not None:
# Select the label channel for patching

patch_label_collection = [tf.stack([label[..., i] for i in slice_channel_label], axis=-1) for label in
patch_label_collection]

if squeeze_dimension is not None:
patch_label_collection = [label[..., 0, :] for label in patch_label_collection]

return patch_img_collection, patch_label_collection, index_list


def get_patches_data(data_size, patch_size, data_img, data_label, index_list, random_rate=0.3,
slice_channel_img=None, slice_channel_label=None, output_patch_size=None, random_shift_patch=True,
squeeze_channel=False, squeeze_dimension=None, images_shape=None):
"""
Get patches from unpatched image and correspondent label by the list of patch positions.
:param data_size: type ndarray: data size of :param: data_img and :param data_label
:param patch_size: type list of int: patch size images
:param data_img: type ndarray: unpatched image data with channel,
if 3D image, then its shape is [height,width,depth,channel].
:param data_label: type ndarray: unpatch label data with channel,
if 3D image, then its shape is [height,width,depth,channel].
:param index_list: type list of list of integers: list position of each patch
:param slice_channel_img: type list of int: channel indice chosen for model inputs,
if :param squeeze_channel is true, the img dimension remains same, else reduce 1.
:param slice_channel_label: type list of int: channel indice chosen for model outputs
:param output_patch_size: type list of int: model output size
:param random_rate: type float,rate of random shift of position from :param index_list. random_rate=0 if no shift.
:param random_shift_patch: type bool, True if the patches are randomly shift for data augmentation.
:param squeeze_channel: type bool, True if select image channel. else all channel will be as input if :param slice_channel_img is False.
:return: patch_img_collection: type list of ndarray with the shape :param patch_size: list of patches images.
:return: patch_label_collection type list of ndarray with the shape :param patch_size: list of patches labels.
:return: index_list: type list of int. Position of the patch.
"""
dim = len(patch_size)
indices_max_bound = [data_size[i] - patch_size[i] for i in range(dim)]

for j, index in enumerate(index_list):

# Limiting the patching indices
index_list[j] = [max(min(index[i], indices_max_bound[i]), 0)
for i in range(dim)]

if random_shift_patch:
# Shift patches indices for data augmentation
new_index = [
index[i] + random.randint(int(-patch_size[i] * random_rate / 2), int(patch_size[i] * random_rate / 2))
for i in range(dim)]
index_list[j] = [new_index[i] if (indices_max_bound[i] >= new_index[i] >= 0)
else max(min(index[i], indices_max_bound[i]), 0)
for i in range(dim)]

# indexing using function slice for variable dim,indexing last channel by slice(None, None),equivalent to [:]
# Get patch image data

patch_img_collection = [data_img[(*[slice(index[i], index[i]) for i in range(dim)]
+ [slice(None, None)],)]
for index in index_list]

patch_label_collection = None

if output_patch_size is not None:
# If input label shape>=output label shape -> Enlarge label patch
for j in range(dim): assert patch_size[j] >= output_patch_size[j]
diff = (np.array(patch_size) - np.array(output_patch_size)) // 2

# Get label data with size= output_patch_size, keep the centre with same as image patch.
if data_label is not None:
patch_label_collection = [
data_label[(*[slice(index[i] + diff[i], index[i] + diff[i] + output_patch_size[i]) for i in range(dim)]
+ [slice(None, None)],)]
for index in index_list]
else:
# If input label shape==output label shape
if data_label is not None:
patch_label_collection = [
data_label[(*[slice(index[i], index[i] + patch_size[i]) for i in range(dim)]
+ [slice(None, None)],)]
for index in index_list]

# Select channels for input images and labels by the yaml file
if slice_channel_img is not None:
if not squeeze_channel:

# Select the image channel for patching
patch_img_collection = [tf.stack([img[..., i] for i in slice_channel_img], axis=-1) for img in
patch_img_collection]
else:

# Reduce one dimension (especially for network Body Identification)
patch_img_collection = [img[..., 0] for img in
patch_img_collection]

if squeeze_dimension is not None:
patch_img_collection = [img[..., 0, :] for img in patch_img_collection]

if slice_channel_label is not None:
# Select the label channel for patching

patch_label_collection = [tf.stack([label[..., i] for i in slice_channel_label], axis=-1) for label in
patch_label_collection]

Expand Down

0 comments on commit be87ae7

Please sign in to comment.