Skip to content

Commit

Permalink
Merge pull request #1067 from 36000/mahal_cov_woes
Browse files Browse the repository at this point in the history
  • Loading branch information
arokem authored Dec 7, 2023
2 parents 3a58f67 + 87abebf commit 0c1c7ba
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 21 deletions.
8 changes: 4 additions & 4 deletions AFQ/_fixes.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,13 +292,13 @@ def gaussian_weights(bundle, n_points=100, return_mahalnobis=False,
# This should come back as a 3D covariance matrix with the spatial
# variance covariance of this node across the different streamlines,
# reorganized as an upper diagonal matrix for expected Mahalanobis
v_inv = np.triu(np.cov(sls[:, i, :].T, ddof=0))
cov = np.cov(sls[:, i, :].T, ddof=0)

# calculate Mahalanobis for node in every fiber
if np.linalg.matrix_rank(v_inv) == n_dim:
v_inv = np.linalg.inv(v_inv)
if np.any(cov > 0):
ci = np.linalg.inv(cov)

dist = (diff[:, i, :] @ v_inv) * diff[:, i, :]
dist = (diff[:, i, :] @ ci) * diff[:, i, :]
weights[:, i] = np.sqrt(np.sum(dist, axis=1))

# In the special case where all the streamlines have the exact same
Expand Down
19 changes: 3 additions & 16 deletions AFQ/tests/test_fixes.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,23 +53,10 @@ def test_mahal_fix():
]
sls_array = np.asarray(sls).astype(float)
results = np.asarray([
[0. , 0. , 0. , 0.727923, 1.091414],
[0. , 0. , 0. , 0.687989, 0.358011],
[0. , 0. , 0. , 1.414214, 1.347267]])
[0. , 0. , 0. , 1.185854, 2.14735],
[0. , 0. , 0. , 1.185854, 1.556795],
[0. , 0. , 0. , 1.274755, 2.23296]])
npt.assert_array_almost_equal(
gaussian_weights_fast(
sls_array, n_points=5,
return_mahalnobis=True, stat=np.mean), results)

sls = Streamlines(sls)
dipy_res = gaussian_weights(
sls, n_points=5, return_mahalnobis=True, stat=np.mean)
sls = np.asarray(set_number_of_points(sls, 5))
our_res = gaussian_weights_fast(
sls, n_points=5, return_mahalnobis=True, stat=np.mean)

# note the current dipy version
# handles 0 variance differently than this implementation
npt.assert_array_almost_equal(
dipy_res[our_res!=0],
our_res[our_res!=0])
2 changes: 1 addition & 1 deletion examples/tutorial_examples/plot_001_afq_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,7 @@
bundle_counts = pd.read_csv(myafq.export("sl_counts")["01"], index_col=[0])
for ind in bundle_counts.index:
if ind == "Total Recognized":
threshold = 1500
threshold = 1000
else:
threshold = 10
if bundle_counts["n_streamlines"][ind] < threshold:
Expand Down

0 comments on commit 0c1c7ba

Please sign in to comment.