Skip to content

Commit

Permalink
Merge pull request #973 from 36000/curvature_updates
Browse files Browse the repository at this point in the history
[ENH] Calculate new curvature metric manually
  • Loading branch information
36000 authored Jun 20, 2023
2 parents f867811 + 4f55c51 commit 71cd73a
Show file tree
Hide file tree
Showing 3 changed files with 87 additions and 27 deletions.
100 changes: 74 additions & 26 deletions AFQ/segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,6 @@
from AFQ.api.bundle_dict import BundleDict
from AFQ.definitions.mapping import ConformedFnirtMapping

from geomstats.geometry.euclidean import Euclidean
from geomstats.geometry.discrete_curves import DiscreteCurves

__all__ = ["Segmentation", "clean_bundle", "clean_by_endpoints"]


Expand Down Expand Up @@ -563,14 +560,16 @@ def segment_afq(self, clean_params={}, tg=None):
self.img_affine))

if "curvature" in bundle_def:
ref_curve = load_tractogram(
ref_sl = load_tractogram(
bundle_def["curvature"]["path"], "same",
bbox_valid_check=False)
moved_ref_curve = self.move_streamlines(
ref_curve, "subject")
moved_ref_curve.to_vox()
moved_ref_curve = np.asarray(
moved_ref_curve.streamlines[0])
moved_ref_sl = self.move_streamlines(
ref_sl, "subject")
moved_ref_sl.to_vox()
moved_ref_sl = moved_ref_sl.streamlines[0]
moved_ref_curve = sl_curve(
moved_ref_sl,
len(moved_ref_sl))

b_sls = _SlsBeingRecognized(
tg.streamlines, self.logger,
Expand Down Expand Up @@ -708,25 +707,24 @@ def segment_afq(self, clean_params={}, tg=None):
min_dist_coords[sl_idx] = np.min(sl_dist)

if len(sl_dist) > 1:
roi_dist1 = np.argmin(sl_dist[0], 0)[0]
roi_dist2 = np.argmin(sl_dist[
len(sl_dist) - 1], 0)[0]
roi_dists[sl_idx, :len(sl_dist)] = [
np.argmin(dist, 0)[0]
for dist in sl_dist]
first_roi_idx = roi_dists[sl_idx, 0]
last_roi_idx = roi_dists[
sl_idx, len(sl_dist) - 1]
# Only accept SLs that, when cut, are meaningful
if (len(sl_dist) < 2) or abs(
roi_dists[sl_idx, 0] - roi_dists[
sl_idx, len(sl_dist) - 1]) > 1:
first_roi_idx - last_roi_idx) > 1:
# Flip sl if it is close to second ROI
# before its close to the first ROI
if flip_using_include:
this_flips = roi_dist1 > roi_dist2
to_flip[sl_idx] = this_flips
if this_flips:
to_flip[sl_idx] =\
first_roi_idx > last_roi_idx
if to_flip[sl_idx]:
roi_dists[sl_idx, :len(sl_dist)] =\
np.flip(
roi_dists[sl_idx, :len(sl_dist)])
np.flip(roi_dists[
sl_idx, :len(sl_dist)])
accept_idx[sl_idx] = 1
else:
accept_idx[sl_idx] = 1
Expand All @@ -749,14 +747,15 @@ def segment_afq(self, clean_params={}, tg=None):
# a curve in orientation and shape but not scale
if b_sls and "curvature" in bundle_def:
accept_idx = b_sls.initiate_selection("curvature")
ref_curve_threshold = bundle_def["curvature"].get("thresh", 5)
curves_r3 = DiscreteCurves(ambient_manifold=Euclidean(dim=3))
ref_curve_threshold = np.radians(bundle_def["curvature"].get(
"thresh", 10))
cut = bundle_def["curvature"].get("cut", False)
for idx, sl in enumerate(b_sls.get_selected_sls(cut=cut)):
sl = dps.set_number_of_points(
sl, moved_ref_curve.shape[0])
dist = curves_r3.square_root_velocity_metric.dist(
moved_ref_curve, sl)
if b_sls.oriented_yet\
and b_sls.sls_flipped[idx]:
sl = sl[::-1]
this_sl_curve = sl_curve(sl, len(moved_ref_sl))
dist = sl_curve_dist(this_sl_curve, moved_ref_curve)
if dist <= ref_curve_threshold:
accept_idx[idx] = 1
b_sls.select(accept_idx, "curvature", cut=cut)
Expand Down Expand Up @@ -1067,6 +1066,55 @@ def segment_reco(self, tg=None):
return fiber_groups


def sl_curve(sl, n_points):
"""
Calculate the direction of the displacement between
each point along a streamline
Parameters
----------
sl : 2d array-like
Streamline to calcualte displacements for.
n_points : int
Number of points to resample the streamline to
Returns
-------
2d array of shape (len(sl)-1, 3) with displacements
between each point in sl normalized to 1.
"""
# Resample to a standardized number of points
resampled_sl = dps.set_number_of_points(
sl,
n_points)

# displacement at each point
resampled_sl_diff = np.diff(resampled_sl, axis=0)

# normalize this displacement
resampled_sl_diff = resampled_sl_diff / np.linalg.norm(
resampled_sl_diff, axis=1)[:, None]

return resampled_sl_diff


def sl_curve_dist(curve1, curve2):
"""
Calculate the mean angle using the directions of displacement
between two streamlines
Parameters
----------
curve1, curve2 : 2d array-like
Two curves calculated from sl_curve.
Returns
-------
The mean angle between each curve across all steps, in radians
"""
return np.mean(np.arccos(np.sum(curve1 * curve2, axis=1)))


def clean_bundle(tg, n_points=100, clean_rounds=5, distance_threshold=3,
length_threshold=4, min_sl=20, stat='mean',
return_idx=False):
Expand Down Expand Up @@ -1298,7 +1346,7 @@ def clean_by_orientation(streamlines, primary_axis, tol=None):
# endpoint diff is between first and last
endpoint_diff[ii, :] = np.abs(sl[0, :] - sl[-1, :])
# axis diff is difference between the nodes, along
axis_diff[ii, :] = np.sum(np.abs(sl[0:-1, :] - sl[1:, :]), axis=0)
axis_diff[ii, :] = np.sum(np.abs(np.diff(sl, axis=0)), axis=0)

orientation_along = np.argmax(axis_diff, axis=1)
along_accepted_idx = orientation_along == primary_axis
Expand Down
13 changes: 13 additions & 0 deletions AFQ/tests/test_segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,19 @@ def test_segment_keep_space():
npt.assert_equal(tg.space, orig_space)


def test_segment_sl_curve():
sl_disp_0 = seg.sl_curve(streamlines[4], 4)
npt.assert_array_almost_equal(
sl_disp_0,
[[-0.236384, -0.763855, 0.60054 ],
[ 0.232594, -0.867859, -0.439 ],
[ 0.175343, 0.001082, -0.984507]])

sl_disp_1 = seg.sl_curve(streamlines[2], 4)
mean_angle_diff = seg.sl_curve_dist(sl_disp_0, sl_disp_1)
npt.assert_almost_equal(mean_angle_diff, 1.701458)


def test_segment_clip_edges():
sls = tg.streamlines
idx = np.arange(len(tg.streamlines))
Expand Down
1 change: 0 additions & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@ install_requires =
pimms
joblib>=0.16.0
dask>=1.1
geomstats>=2.0.0,<=2.4.2
# AWS integration packages
boto3>=1.14.0
s3fs~=0.4.2
Expand Down

0 comments on commit 71cd73a

Please sign in to comment.