Skip to content

Commit eca3f42

Browse files
committed
ENH: Add a shell data property to DWI data class
Add a shell data property to `DWI` data class that returns a list of pairs consisting of the estimated b-value and the associated DWI data.
1 parent 7237ecf commit eca3f42

File tree

2 files changed

+61
-0
lines changed

2 files changed

+61
-0
lines changed

src/nifreeze/data/dmri.py

+39
Original file line numberDiff line numberDiff line change
@@ -241,6 +241,45 @@ def to_nifti(self, filename: Path | str, insert_b0: bool = False) -> None:
241241
np.savetxt(bvecs_file, self.gradients[:3, ...].T, fmt="%.6f")
242242
np.savetxt(bvals_file, self.gradients[:3, ...], fmt="%.6f")
243243

244+
def shells(
245+
self,
246+
num_bins: int = DEFAULT_NUM_BINS,
247+
multishell_nonempty_bin_count_thr: int = DEFAULT_MULTISHELL_BIN_COUNT_THR,
248+
bval_cap: int = DEFAULT_HIGHB_THRESHOLD,
249+
) -> list:
250+
"""Get the shell data according to the b-value groups.
251+
252+
Bin the shell data according to the b-value groups found by `~find_shelling_scheme`.
253+
254+
Parameters
255+
----------
256+
num_bins : :obj:`int`, optional
257+
Number of bins.
258+
multishell_nonempty_bin_count_thr : :obj:`int`, optional
259+
Bin count to consider a multi-shell scheme.
260+
bval_cap : :obj:`int`, optional
261+
Maximum b-value to be considered in a multi-shell scheme.
262+
263+
Returns
264+
-------
265+
:obj:`list`
266+
Tuples of binned b-values and corresponding shell data.
267+
"""
268+
269+
_, bval_groups, bval_estimated = find_shelling_scheme(
270+
self.gradients[-1, ...],
271+
num_bins=num_bins,
272+
multishell_nonempty_bin_count_thr=multishell_nonempty_bin_count_thr,
273+
bval_cap=bval_cap,
274+
)
275+
indices = [
276+
np.hstack(np.where(np.isin(self.gradients[-1, ...], bvals))) for bvals in bval_groups
277+
]
278+
return [
279+
(bval_estimated[idx], self.dataobj[indices, ...])
280+
for idx, indices in enumerate(indices)
281+
]
282+
244283

245284
def load(
246285
filename: Path | str,

test/test_data_dmri.py

+22
Original file line numberDiff line numberDiff line change
@@ -182,6 +182,28 @@ def test_equality_operator(tmp_path):
182182
assert round_trip_dwi_obj == dwi_obj
183183

184184

185+
def test_shells(datadir):
186+
dwi_h5 = load(datadir / "dwi.h5")
187+
num_bins = 3
188+
189+
_, expected_bval_groups, expected_bval_est = find_shelling_scheme(
190+
dwi_h5.gradients[-1, ...], num_bins=num_bins
191+
)
192+
193+
indices = [
194+
np.hstack(np.where(np.isin(dwi_h5.gradients[-1, ...], bvals)))
195+
for bvals in expected_bval_groups
196+
]
197+
expected_shell_data = [dwi_h5.dataobj[indices, ...] for indices in indices]
198+
199+
shell_data = dwi_h5.shells(num_bins=num_bins)
200+
obtained_bval_est, obtained_shell_data = zip(*shell_data, strict=True)
201+
202+
assert len(shell_data) == num_bins
203+
assert list(obtained_bval_est) == expected_bval_est
204+
assert np.allclose(obtained_shell_data, expected_shell_data)
205+
206+
185207
@pytest.mark.parametrize(
186208
("bvals", "exp_scheme", "exp_bval_groups", "exp_bval_estimated"),
187209
[

0 commit comments

Comments
 (0)