Skip to content

Commit

Permalink
Merge pull request #863 from 36000/output_more_scalars
Browse files Browse the repository at this point in the history
ENH: add lower triangular scalars from DTI (could be useful for ML)?
  • Loading branch information
arokem authored Aug 3, 2022
2 parents 8c73324 + e54a391 commit 0ad4dcb
Show file tree
Hide file tree
Showing 4 changed files with 36 additions and 8 deletions.
24 changes: 21 additions & 3 deletions AFQ/tasks/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,6 +280,24 @@ def dti_fa(dti_tf):
return dti_tf.fa


@pimms.calc("dti_lt0", "dti_lt1", "dti_lt2", "dti_lt3", "dti_lt4", "dti_lt5")
def dti_lt(dti_tf, dwi_affine):
"""
Image of first element in the DTI tensor,
Image of second element in the DTI tensor,
Image of third element in the DTI tensor,
Image of fourth element in the DTI tensor,
Image of fifth element in the DTI tensor,
Image of sixth element in the DTI tensor
"""
dti_lt_dict = {}
for ii in range(6):
dti_lt_dict[f"dti_lt{ii}"] = nib.Nifti1Image(
dti_tf.lower_triangular()[..., ii],
dwi_affine)
return dti_lt_dict


@pimms.calc("dti_cfa")
@as_file(suffix='_model-DTI_desc-DEC_FA.nii.gz')
@as_fit_deriv('DTI')
Expand Down Expand Up @@ -605,9 +623,9 @@ def get_data_plan(kwargs):
data_tasks = with_name([
get_data_gtab, b0, b0_mask, brain_mask,
dti_fit, dki_fit, anisotropic_power_map,
dti_fa, dti_cfa, dti_pdd, dti_md, dki_fa, dki_md, dki_awf, dki_mk,
dti_ga, dti_rd, dti_ad, dki_ga, dki_rd, dki_ad, dki_rk, dki_ak,
dti_params, dki_params, csd_params, get_bundle_dict])
dti_fa, dti_lt, dti_cfa, dti_pdd, dti_md, dki_fa, dki_md, dki_awf,
dki_mk, dti_ga, dti_rd, dti_ad, dki_ga, dki_rd, dki_ad, dki_rk,
dki_ak, dti_params, dki_params, csd_params, get_bundle_dict])

if "scalars" not in kwargs:
kwargs["scalars"] = ["dti_fa", "dti_md"]
Expand Down
9 changes: 7 additions & 2 deletions AFQ/tasks/segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,11 +304,16 @@ def tract_profiles(clean_bundles, data_imap,
this_sl = seg_sft.get_bundle(bundle_name).streamlines
if len(this_sl) == 0:
continue
if profile_weights == "gauss":
# calculate only once per bundle
bundle_profile_weights = gaussian_weights(this_sl)
for ii, (scalar, scalar_file) in enumerate(scalar_dict.items()):
scalar_data = nib.load(scalar_file).get_fdata()
if isinstance(scalar_file, str):
scalar_file = nib.load(scalar_file)
scalar_data = scalar_file.get_fdata()
if isinstance(profile_weights, str):
if profile_weights == "gauss":
this_prof_weights = gaussian_weights(this_sl)
this_prof_weights = bundle_profile_weights
elif profile_weights == "median":
# weights bundle to only return the mean
def _median_weight(bundle):
Expand Down
5 changes: 4 additions & 1 deletion AFQ/tasks/viz.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,10 @@

def _viz_prepare_vol(vol, xform, mapping, scalar_dict):
if vol in scalar_dict.keys():
vol = nib.load(scalar_dict[vol]).get_fdata()
vol = scalar_dict[vol]
if isinstance(vol, str):
vol = nib.load(vol)
vol = vol.get_fdata()
if isinstance(vol, str):
vol = nib.load(vol).get_fdata()
if xform:
Expand Down
6 changes: 4 additions & 2 deletions AFQ/tests/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -724,6 +724,7 @@ def test_AFQ_data_waypoint():
"dti_FA",
"dti_MD",
"dti_GA",
"dti_lt2",
ImageFile(path=t1_path_other),
TemplateImage(t1_path)],
robust_tensor_fitting=True,
Expand Down Expand Up @@ -772,7 +773,7 @@ def test_AFQ_data_waypoint():
tract_profiles = pd.read_csv(tract_profile_fname)

assert tract_profiles.select_dtypes(include=[np.number]).sum().sum() != 0
assert tract_profiles.shape == (500, 8)
assert tract_profiles.shape == (500, 9)

myafq.export("indiv_bundles_figures")
assert op.exists(op.join(
Expand Down Expand Up @@ -821,6 +822,7 @@ def test_AFQ_data_waypoint():
"dti_fa",
"dti_md",
"dti_ga",
"dti_lt2",
f"ImageFile('{t1_path_other}')",
f"TemplateImage('{t1_path}')"]),
VIZ=dict(
Expand All @@ -843,7 +845,7 @@ def test_AFQ_data_waypoint():
# The tract profiles should already exist from the CLI Run:
from_file = pd.read_csv(tract_profile_fname)

assert from_file.shape == (500, 8)
assert from_file.shape == (500, 9)
assert_series_equal(tract_profiles['dti_fa'], from_file['dti_fa'])

# Make sure the CLI did indeed generate these:
Expand Down

0 comments on commit 0ad4dcb

Please sign in to comment.