Skip to content
This repository has been archived by the owner on Oct 6, 2021. It is now read-only.

Commit

Permalink
Merge pull request #12 from adamltyson/injectionsite
Browse files Browse the repository at this point in the history
Refactor injection site finder
  • Loading branch information
adamltyson authored Feb 18, 2020
2 parents 567c825 + dff47ec commit 58d7838
Show file tree
Hide file tree
Showing 8 changed files with 120 additions and 157 deletions.
Binary file removed .DS_Store
Binary file not shown.
4 changes: 4 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -117,3 +117,7 @@ venv.bak/
.idea/
.vs/
*.~lock.*


# macOS
*.DS_Store
2 changes: 1 addition & 1 deletion docs/searchindex.js

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Binary file removed neuro/injection_finder/.DS_Store
Binary file not shown.
129 changes: 88 additions & 41 deletions neuro/injection_finder/extraction.py
Original file line number Diff line number Diff line change
@@ -1,50 +1,54 @@
import os
import numpy as np
from pathlib import Path

from skimage.filters import gaussian as gaussian_filter
from skimage.filters import threshold_otsu
from skimage import measure

from brainio import brainio
from imlib.IO.surfaces import marching_cubes_to_obj
from imlib.image.orient import reorient_image

from registration import get_registered_image
from utils import reorient_image, marching_cubes_to_obj, get_largest_component
from parsers import extraction_parser
from neuro.injection_finder.registration import get_registered_image
from neuro.injection_finder.parsers import extraction_parser

# For logging
import neuro as package_for_log
import logging
from fancylog import fancylog
import fancylog as package


class Extractor:
def __init__(
self,
img_filepath,
registration_folder,
logging,
overwrite=False,
gaussian_kernel=2,
percentile_threshold=99.95,
threshold_type="otsu",
obj_path=None,
overwrite_registration=False,
):

"""
Extractor processes a downsampled.nii image to extract the location of the injection site.
This is done by registering the image to the allen CCF, blurring, thresholding and finally a
marching cube algorithm to extract the surface of the injection site.
:param img_filepath: str, path to .nii file
:param registration_folder: str, path to the registration folder [from cellfinder or amap]
:param logging: instance of fancylog logger
:param overwrite: bool, if False it will avoid overwriting files
:gaussian_kernel: float, size of kernel used for smoothing
:param percentile_threshold: float, in range [0, 1] percentile to use for thresholding
:param threshold_type: str, either ['otsu', 'percentile'], type of threshold used
:param obj_path: path to .obj file destination.
:param overwrite_registration: if false doesn't overwrite the registration step
Extractor processes a downsampled.nii image to extract the location of
the injection site.
This is done by registering the image to the allen CCF, blurring,
thresholding and finally a marching cube algorithm to extract the
surface of the injection site.
:param img_filepath: str, path to .nii file
:param registration_folder: str, path to the registration folder
[from cellfinder or amap]
:param overwrite: bool, if False it will avoid overwriting files
:gaussian_kernel: float, size of kernel used for smoothing
:param percentile_threshold: float, in range [0, 1] percentile to use
for thresholding
:param threshold_type: str, either ['otsu', 'percentile'],
type of threshold used
:param obj_path: path to .obj file destination.
:param overwrite_registration: if false doesn't overwrite the
registration step
"""

# Get arguments
Expand Down Expand Up @@ -72,14 +76,14 @@ def setup(self):
self.img_filepath.split(".")[0] + "_thresholded.nii"
)

# Get path to obj file and check if it existsts
# Get path to obj file and check if it exists
if self.obj_path is None:
self.obj_path = self.img_filepath.split(".")[0] + ".obj"

if os.path.isfile(self.obj_path) and not self.overwrite:
self.logging.warning(
"A file exists already at {}. \
Analysis will not run as overwrite is set disabled".format(
"A file exists already at {}."
"Analysis will not run as overwrite is set disabled".format(
self.obj_path
)
)
Expand All @@ -88,7 +92,6 @@ def setup(self):
image = get_registered_image(
self.img_filepath,
self.registration_folder,
self.logging,
overwrite=self.overwrite_registration,
)
return image
Expand All @@ -103,12 +106,12 @@ def extract(self, image, voxel_size=10):

# Gaussian filter
kernel_shape = [self.gaussian_kernel, self.gaussian_kernel, 6]
filtered = gaussian_filter(image, kernel_shape)
image = gaussian_filter(image, kernel_shape)
self.logging.info("Filtering completed")

# Thresholding
if self.threshold_type.lower() == "otsu":
thresh = threshold_otsu(filtered)
thresh = threshold_otsu(image)
self.logging.info(
"Thresholding with {} threshold type".format(
self.threshold_type
Expand All @@ -119,21 +122,20 @@ def extract(self, image, voxel_size=10):
self.threshold_type.lower() == "percentile"
or self.threshold_type.lower() == "perc"
):
thresh = np.percentile(filtered.ravel(), self.percentile_threshold)
thresh = np.percentile(image.ravel(), self.percentile_threshold)
self.logging.info(
"Thresholding with {} threshold type. {}th percentile [{}]".format(
"Thresholding with {} threshold type. "
"{}th percentile [{}]".format(
self.threshold_type, self.percentile_threshold, thresh
)
)
else:
raise valueError(
raise ValueError(
"Unrecognised thresholding type: " + self.threshold_type
)

binary = filtered > thresh
oriented_binary = reorient_image(
binary, invert_axes=[2,], orientation="coronal"
)
binary = image > thresh
binary = keep_n_largest_objects(binary)

# Save thresholded image
if not os.path.isfile(self.thresholded_savepath) or self.overwrite:
Expand All @@ -144,10 +146,14 @@ def extract(self, image, voxel_size=10):
)
brainio.to_nii(binary.astype(np.int16), self.thresholded_savepath)

binary = reorient_image(
binary, invert_axes=[2,], orientation="coronal"
)

# apply marching cubes
self.logging.info("Extracting surface from thresholded image")
verts, faces, normals, values = measure.marching_cubes_lewiner(
oriented_binary, 0, step_size=1
binary, 0, step_size=1
)

# Scale to atlas spacing
Expand All @@ -159,32 +165,73 @@ def extract(self, image, voxel_size=10):
faces = faces + 1
marching_cubes_to_obj((verts, faces, normals, values), self.obj_path)

# Keep only the largest connected component
get_largest_component(self.obj_path)

def keep_n_largest_objects(numpy_array, n=1, connectivity=None):
"""
Given an input binary numpy array, return a "clean" array with only the
n largest connected components remaining
Inspired by stackoverflow.com/questions/47540926
TODO: optimise
:param numpy_array: Binary numpy array
:param n: How many objects to keep
:param connectivity: Labelling connectivity (see skimage.measure.label)
:return: "Clean" numpy array with n largest objects
"""

labels = measure.label(numpy_array, connectivity=connectivity)
assert labels.max() != 0 # assume at least 1 CC
n_largest_objects = get_largest_non_zero_object(labels)
if n > 1:
i = 1
while i < n:
labels[n_largest_objects] = 0
n_largest_objects += get_largest_non_zero_object(labels)
i += 1
return n_largest_objects


def get_largest_non_zero_object(label_image):
"""
In a labelled (each object assigned an int) numpy array. Return the
largest object with a value >= 1.
:param label_image: Output of skimage.measure.label
:return: Boolean numpy array or largest object
"""
return label_image == np.argmax(np.bincount(label_image.flat)[1:]) + 1


def main():
args = extraction_parser().parse_args()

# Get output directory
if args.output_directory is None:
outdir = os.get_cwd()
outdir = os.getcwd()
elif not os.path.isdir(args.output_directory):
raise ValueError("Output directory invalid")
else:
outdir = args.output_directory

if args.obj_path is None:
args.obj_path = Path(args.img_filepath).with_suffix(".obj")
else:
args.obj_path = Path(args.obj_path)

# Start log
log_name = "injection_finder_{}".format(
os.path.split(args.registration_folder)[-1]
fancylog.start_logging(
outdir,
package_for_log,
filename="injection_finder",
verbose=args.debug,
log_to_file=args.save_log,
)
fancylog.start_logging(outdir, package, filename=log_name, verbose=True)

# Start extraction
Extractor(
args.img_filepath,
args.registration_folder,
logging,
overwrite=args.overwrite,
gaussian_kernel=args.gaussian_kernel,
percentile_threshold=args.percentile_threshold,
Expand Down
23 changes: 19 additions & 4 deletions neuro/injection_finder/parsers.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,8 @@ def extraction_parser():
"--obj-path",
dest="obj_path",
type=str,
default=False,
help="Path to output .obj file. Optional.",
default=None,
help="Path to output .obj file. Will default to the image directory.",
)

parser.add_argument(
Expand All @@ -57,7 +57,8 @@ def extraction_parser():
dest="percentile_threshold",
type=float,
default=99.995,
help="Float in range [0, 100]. The percentile number of pixel intensity values for tresholding",
help="Float in range [0, 100]. The percentile number of pixel "
"intensity values for tresholding",
)

parser.add_argument(
Expand All @@ -66,7 +67,8 @@ def extraction_parser():
dest="threshold_type",
type=str,
default="otsu",
help="'otsu' or 'percentile'. Determines how the threshold value is computed",
help="'otsu' or 'percentile'. Determines how the threshold "
"value is computed",
)

parser.add_argument(
Expand All @@ -77,4 +79,17 @@ def extraction_parser():
default="False",
help="If false skip running again the registration",
)
parser.add_argument(
"--debug",
dest="debug",
action="store_true",
help="Debug mode. Will increase verbosity of logging and save all "
"intermediate files for diagnosis of software issues.",
)
parser.add_argument(
"--save-log",
dest="save_log",
action="store_true",
help="Save logging to file (in addition to logging to terminal).",
)
return parser
51 changes: 8 additions & 43 deletions neuro/injection_finder/registration.py
Original file line number Diff line number Diff line change
@@ -1,55 +1,25 @@
import os
import logging

from brainio import brainio
from amap.tools import source_files
from amap.config.config import get_binary
from amap.tools.exceptions import RegistrationError
from imlib.register.niftyreg.transform import run_transform


from cellfinder.tools.system import (
safe_execute_command,
SafeExecuteCommandError,
)

PROGRAM_NAME = "reg_resample"
DEFAULT_CONTROL_POINT_FILE = "inverse_control_point_file.nii"
default_atlas_name = "brain_filtered.nii"


def prepare_segmentation_cmd(
program_path,
floating_image_path,
output_file_name,
destination_image_filename,
control_point_file,
):
cmd = "{} -cpp {} -flo {} -ref {} -res {}".format(
program_path,
control_point_file,
floating_image_path,
destination_image_filename,
output_file_name,
)
return cmd


def get_registered_image(nii_path, registration_dir, logging, overwrite=False):
# get binaries
nifty_reg_binaries_folder = source_files.get_niftyreg_binaries()
program_path = get_binary(nifty_reg_binaries_folder, PROGRAM_NAME)

def get_registered_image(nii_path, registration_dir, overwrite=False):
# get file paths
basedir = os.path.split(nii_path)[0]
output_filename = os.path.join(
basedir,
"{}_transformed.nii".format(os.path.split(nii_path)[1].split(".")[0]),
)
if os.path.isfile(output_filename) and not overwrite:
run = False
logging.info("Skipping registration as output file already exists")
else:
run = True

if run:
destination_image = os.path.join(registration_dir, default_atlas_name)
control_point_file = os.path.join(
registration_dir, DEFAULT_CONTROL_POINT_FILE
Expand All @@ -58,19 +28,14 @@ def get_registered_image(nii_path, registration_dir, logging, overwrite=False):
log_file_path = os.path.join(basedir, "registration_log.txt")
error_file_path = os.path.join(basedir, "registration_err.txt")

reg_cmd = prepare_segmentation_cmd(
program_path,
logging.info("Running registration")
run_transform(
nii_path,
output_filename,
destination_image,
control_point_file,
log_file_path,
error_file_path,
)
logging.info("Running registration")
try:
safe_execute_command(reg_cmd, log_file_path, error_file_path)
except SafeExecuteCommandError as err:
raise RegistrationError("Registration failed; {}".format(err))
else:
logging.info("Skipping registration as output file already exists")

return brainio.load_any(output_filename)
Loading

0 comments on commit 58d7838

Please sign in to comment.