Skip to content
This repository has been archived by the owner on Oct 6, 2021. It is now read-only.

Refactor injection site finder #12

Merged
merged 26 commits into from
Feb 18, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Binary file removed .DS_Store
Binary file not shown.
4 changes: 4 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -117,3 +117,7 @@ venv.bak/
.idea/
.vs/
*.~lock.*


# macOS
*.DS_Store
2 changes: 1 addition & 1 deletion docs/searchindex.js

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Binary file removed neuro/injection_finder/.DS_Store
Binary file not shown.
129 changes: 88 additions & 41 deletions neuro/injection_finder/extraction.py
Original file line number Diff line number Diff line change
@@ -1,50 +1,54 @@
import os
import numpy as np
from pathlib import Path

from skimage.filters import gaussian as gaussian_filter
from skimage.filters import threshold_otsu
from skimage import measure

from brainio import brainio
from imlib.IO.surfaces import marching_cubes_to_obj
from imlib.image.orient import reorient_image

from registration import get_registered_image
from utils import reorient_image, marching_cubes_to_obj, get_largest_component
from parsers import extraction_parser
from neuro.injection_finder.registration import get_registered_image
from neuro.injection_finder.parsers import extraction_parser

# For logging
import neuro as package_for_log
import logging
from fancylog import fancylog
import fancylog as package


class Extractor:
def __init__(
self,
img_filepath,
registration_folder,
logging,
overwrite=False,
gaussian_kernel=2,
percentile_threshold=99.95,
threshold_type="otsu",
obj_path=None,
overwrite_registration=False,
):

"""
Extractor processes a downsampled.nii image to extract the location of the injection site.
This is done by registering the image to the allen CCF, blurring, thresholding and finally a
marching cube algorithm to extract the surface of the injection site.

:param img_filepath: str, path to .nii file
:param registration_folder: str, path to the registration folder [from cellfinder or amap]
:param logging: instance of fancylog logger
:param overwrite: bool, if False it will avoid overwriting files
:gaussian_kernel: float, size of kernel used for smoothing
:param percentile_threshold: float, in range [0, 1] percentile to use for thresholding
:param threshold_type: str, either ['otsu', 'percentile'], type of threshold used
:param obj_path: path to .obj file destination.
:param overwrite_registration: if false doesn't overwrite the registration step
Extractor processes a downsampled.nii image to extract the location of
the injection site.
This is done by registering the image to the allen CCF, blurring,
thresholding and finally a marching cube algorithm to extract the
surface of the injection site.

:param img_filepath: str, path to .nii file
:param registration_folder: str, path to the registration folder
[from cellfinder or amap]
:param overwrite: bool, if False it will avoid overwriting files
:gaussian_kernel: float, size of kernel used for smoothing
:param percentile_threshold: float, in range [0, 1] percentile to use
for thresholding
:param threshold_type: str, either ['otsu', 'percentile'],
type of threshold used
:param obj_path: path to .obj file destination.
:param overwrite_registration: if false doesn't overwrite the
registration step
"""

# Get arguments
Expand Down Expand Up @@ -72,14 +76,14 @@ def setup(self):
self.img_filepath.split(".")[0] + "_thresholded.nii"
)

# Get path to obj file and check if it existsts
# Get path to obj file and check if it exists
if self.obj_path is None:
self.obj_path = self.img_filepath.split(".")[0] + ".obj"

if os.path.isfile(self.obj_path) and not self.overwrite:
self.logging.warning(
"A file exists already at {}. \
Analysis will not run as overwrite is set disabled".format(
"A file exists already at {}."
"Analysis will not run as overwrite is set disabled".format(
self.obj_path
)
)
Expand All @@ -88,7 +92,6 @@ def setup(self):
image = get_registered_image(
self.img_filepath,
self.registration_folder,
self.logging,
overwrite=self.overwrite_registration,
)
return image
Expand All @@ -103,12 +106,12 @@ def extract(self, image, voxel_size=10):

# Gaussian filter
kernel_shape = [self.gaussian_kernel, self.gaussian_kernel, 6]
filtered = gaussian_filter(image, kernel_shape)
image = gaussian_filter(image, kernel_shape)
self.logging.info("Filtering completed")

# Thresholding
if self.threshold_type.lower() == "otsu":
thresh = threshold_otsu(filtered)
thresh = threshold_otsu(image)
self.logging.info(
"Thresholding with {} threshold type".format(
self.threshold_type
Expand All @@ -119,21 +122,20 @@ def extract(self, image, voxel_size=10):
self.threshold_type.lower() == "percentile"
or self.threshold_type.lower() == "perc"
):
thresh = np.percentile(filtered.ravel(), self.percentile_threshold)
thresh = np.percentile(image.ravel(), self.percentile_threshold)
self.logging.info(
"Thresholding with {} threshold type. {}th percentile [{}]".format(
"Thresholding with {} threshold type. "
"{}th percentile [{}]".format(
self.threshold_type, self.percentile_threshold, thresh
)
)
else:
raise valueError(
raise ValueError(
"Unrecognised thresholding type: " + self.threshold_type
)

binary = filtered > thresh
oriented_binary = reorient_image(
binary, invert_axes=[2,], orientation="coronal"
)
binary = image > thresh
binary = keep_n_largest_objects(binary)

# Save thresholded image
if not os.path.isfile(self.thresholded_savepath) or self.overwrite:
Expand All @@ -144,10 +146,14 @@ def extract(self, image, voxel_size=10):
)
brainio.to_nii(binary.astype(np.int16), self.thresholded_savepath)

binary = reorient_image(
binary, invert_axes=[2,], orientation="coronal"
)

# apply marching cubes
self.logging.info("Extracting surface from thresholded image")
verts, faces, normals, values = measure.marching_cubes_lewiner(
oriented_binary, 0, step_size=1
binary, 0, step_size=1
)

# Scale to atlas spacing
Expand All @@ -159,32 +165,73 @@ def extract(self, image, voxel_size=10):
faces = faces + 1
marching_cubes_to_obj((verts, faces, normals, values), self.obj_path)

# Keep only the largest connected component
get_largest_component(self.obj_path)

def keep_n_largest_objects(numpy_array, n=1, connectivity=None):
"""
Given an input binary numpy array, return a "clean" array with only the
n largest connected components remaining

Inspired by stackoverflow.com/questions/47540926

TODO: optimise

:param numpy_array: Binary numpy array
:param n: How many objects to keep
:param connectivity: Labelling connectivity (see skimage.measure.label)
:return: "Clean" numpy array with n largest objects
"""

labels = measure.label(numpy_array, connectivity=connectivity)
assert labels.max() != 0 # assume at least 1 CC
n_largest_objects = get_largest_non_zero_object(labels)
if n > 1:
i = 1
while i < n:
labels[n_largest_objects] = 0
n_largest_objects += get_largest_non_zero_object(labels)
i += 1
return n_largest_objects


def get_largest_non_zero_object(label_image):
"""
In a labelled (each object assigned an int) numpy array. Return the
largest object with a value >= 1.
:param label_image: Output of skimage.measure.label
:return: Boolean numpy array or largest object
"""
return label_image == np.argmax(np.bincount(label_image.flat)[1:]) + 1


def main():
args = extraction_parser().parse_args()

# Get output directory
if args.output_directory is None:
outdir = os.get_cwd()
outdir = os.getcwd()
elif not os.path.isdir(args.output_directory):
raise ValueError("Output directory invalid")
else:
outdir = args.output_directory

if args.obj_path is None:
args.obj_path = Path(args.img_filepath).with_suffix(".obj")
else:
args.obj_path = Path(args.obj_path)

# Start log
log_name = "injection_finder_{}".format(
os.path.split(args.registration_folder)[-1]
fancylog.start_logging(
outdir,
package_for_log,
filename="injection_finder",
verbose=args.debug,
log_to_file=args.save_log,
)
fancylog.start_logging(outdir, package, filename=log_name, verbose=True)

# Start extraction
Extractor(
args.img_filepath,
args.registration_folder,
logging,
overwrite=args.overwrite,
gaussian_kernel=args.gaussian_kernel,
percentile_threshold=args.percentile_threshold,
Expand Down
23 changes: 19 additions & 4 deletions neuro/injection_finder/parsers.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,8 @@ def extraction_parser():
"--obj-path",
dest="obj_path",
type=str,
default=False,
help="Path to output .obj file. Optional.",
default=None,
help="Path to output .obj file. Will default to the image directory.",
)

parser.add_argument(
Expand All @@ -57,7 +57,8 @@ def extraction_parser():
dest="percentile_threshold",
type=float,
default=99.995,
help="Float in range [0, 100]. The percentile number of pixel intensity values for tresholding",
help="Float in range [0, 100]. The percentile number of pixel "
"intensity values for tresholding",
)

parser.add_argument(
Expand All @@ -66,7 +67,8 @@ def extraction_parser():
dest="threshold_type",
type=str,
default="otsu",
help="'otsu' or 'percentile'. Determines how the threshold value is computed",
help="'otsu' or 'percentile'. Determines how the threshold "
"value is computed",
)

parser.add_argument(
Expand All @@ -77,4 +79,17 @@ def extraction_parser():
default="False",
help="If false skip running again the registration",
)
parser.add_argument(
"--debug",
dest="debug",
action="store_true",
help="Debug mode. Will increase verbosity of logging and save all "
"intermediate files for diagnosis of software issues.",
)
parser.add_argument(
"--save-log",
dest="save_log",
action="store_true",
help="Save logging to file (in addition to logging to terminal).",
)
return parser
51 changes: 8 additions & 43 deletions neuro/injection_finder/registration.py
Original file line number Diff line number Diff line change
@@ -1,55 +1,25 @@
import os
import logging

from brainio import brainio
from amap.tools import source_files
from amap.config.config import get_binary
from amap.tools.exceptions import RegistrationError
from imlib.register.niftyreg.transform import run_transform


from cellfinder.tools.system import (
safe_execute_command,
SafeExecuteCommandError,
)

PROGRAM_NAME = "reg_resample"
DEFAULT_CONTROL_POINT_FILE = "inverse_control_point_file.nii"
default_atlas_name = "brain_filtered.nii"


def prepare_segmentation_cmd(
program_path,
floating_image_path,
output_file_name,
destination_image_filename,
control_point_file,
):
cmd = "{} -cpp {} -flo {} -ref {} -res {}".format(
program_path,
control_point_file,
floating_image_path,
destination_image_filename,
output_file_name,
)
return cmd


def get_registered_image(nii_path, registration_dir, logging, overwrite=False):
# get binaries
nifty_reg_binaries_folder = source_files.get_niftyreg_binaries()
program_path = get_binary(nifty_reg_binaries_folder, PROGRAM_NAME)

def get_registered_image(nii_path, registration_dir, overwrite=False):
# get file paths
basedir = os.path.split(nii_path)[0]
output_filename = os.path.join(
basedir,
"{}_transformed.nii".format(os.path.split(nii_path)[1].split(".")[0]),
)
if os.path.isfile(output_filename) and not overwrite:
run = False
logging.info("Skipping registration as output file already exists")
else:
run = True

if run:
destination_image = os.path.join(registration_dir, default_atlas_name)
control_point_file = os.path.join(
registration_dir, DEFAULT_CONTROL_POINT_FILE
Expand All @@ -58,19 +28,14 @@ def get_registered_image(nii_path, registration_dir, logging, overwrite=False):
log_file_path = os.path.join(basedir, "registration_log.txt")
error_file_path = os.path.join(basedir, "registration_err.txt")

reg_cmd = prepare_segmentation_cmd(
program_path,
logging.info("Running registration")
run_transform(
nii_path,
output_filename,
destination_image,
control_point_file,
log_file_path,
error_file_path,
)
logging.info("Running registration")
try:
safe_execute_command(reg_cmd, log_file_path, error_file_path)
except SafeExecuteCommandError as err:
raise RegistrationError("Registration failed; {}".format(err))
else:
logging.info("Skipping registration as output file already exists")

return brainio.load_any(output_filename)
Loading