Skip to content

Commit 4b464fa

Browse files
authored
Merge pull request #234 from oesteban/enh/4d-resampling
ENH: Improve support of 4D in ``sdcflows.interfaces.bspline.ApplyCoeffsField``
2 parents bd61a37 + 8ca1e79 commit 4b464fa

File tree

6 files changed

+217
-78
lines changed

6 files changed

+217
-78
lines changed

sdcflows/interfaces/bspline.py

Lines changed: 132 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@
2222
#
2323
"""Filtering of :math:`B_0` field mappings with B-Splines."""
2424
from pathlib import Path
25-
from functools import partial
2625
import numpy as np
2726
import nibabel as nb
2827
from nibabel.affines import apply_affine
@@ -34,12 +33,13 @@
3433
TraitedSpec,
3534
File,
3635
traits,
36+
isdefined,
3737
SimpleInterface,
3838
InputMultiObject,
3939
OutputMultiObject,
4040
)
4141

42-
from sdcflows.transform import grid_bspline_weights as gbsw, B0FieldTransform
42+
from sdcflows.transform import grid_bspline_weights as gbsw
4343

4444

4545
LOW_MEM_BLOCK_SIZE = 1000
@@ -210,14 +210,17 @@ def _run_interface(self, runtime):
210210

211211

212212
class _ApplyCoeffsFieldInputSpec(BaseInterfaceInputSpec):
213-
in_target = InputMultiObject(
213+
in_data = InputMultiObject(
214214
File(exist=True, mandatory=True, desc="input EPI data to be corrected")
215215
)
216216
in_coeff = InputMultiObject(
217217
File(exists=True),
218218
mandatory=True,
219219
desc="input coefficients, after alignment to the EPI data",
220220
)
221+
in_xfms = InputMultiObject(
222+
File(exists=True), desc="list of head-motion correction matrices"
223+
)
221224
ro_time = InputMultiObject(
222225
traits.Float(mandatory=True, desc="EPI readout time (s).")
223226
)
@@ -230,14 +233,15 @@ class _ApplyCoeffsFieldInputSpec(BaseInterfaceInputSpec):
230233
"k",
231234
"k-",
232235
mandatory=True,
233-
desc="the phase-encoding direction corresponding to in_target",
236+
desc="the phase-encoding direction corresponding to in_data",
234237
)
235238
)
239+
num_threads = traits.Int(nohash=True, desc="number of threads")
236240

237241

238242
class _ApplyCoeffsFieldOutputSpec(TraitedSpec):
239243
out_corrected = OutputMultiObject(File(exists=True))
240-
out_field = File(exists=True)
244+
out_field = OutputMultiObject(File(exists=True))
241245
out_warp = OutputMultiObject(File(exists=True))
242246

243247

@@ -248,38 +252,73 @@ class ApplyCoeffsField(SimpleInterface):
248252
output_spec = _ApplyCoeffsFieldOutputSpec
249253

250254
def _run_interface(self, runtime):
251-
# Prepare output names
252-
filename = partial(fname_presuffix, newpath=runtime.cwd)
253-
254-
self._results["out_field"] = filename(self.inputs.in_coeff[0], suffix="_field")
255-
self._results["out_warp"] = []
256-
self._results["out_corrected"] = []
255+
n = len(self.inputs.in_data)
257256

258-
xfm = B0FieldTransform(
259-
coeffs=[nb.load(cname) for cname in self.inputs.in_coeff]
260-
)
261-
xfm.fit(self.inputs.in_target[0])
262-
xfm.shifts.to_filename(self._results["out_field"])
263-
264-
n_inputs = len(self.inputs.in_target)
265257
ro_time = self.inputs.ro_time
266258
if len(ro_time) == 1:
267-
ro_time = [ro_time[0]] * n_inputs
259+
ro_time *= n
268260

269261
pe_dir = self.inputs.pe_dir
270262
if len(pe_dir) == 1:
271-
pe_dir = [pe_dir[0]] * n_inputs
263+
pe_dir *= n
264+
265+
unwarp = None
266+
hmc_mats = [None] * n
267+
if isdefined(self.inputs.in_xfms):
268+
# Split ITK matrices in separate files if they come collated
269+
hmc_mats = (
270+
list(_split_itk_file(self.inputs.in_xfms[0]))
271+
if len(self.inputs.in_xfms) == 1
272+
else self.inputs.in_xfms
273+
)
274+
else:
275+
from sdcflows.transform import B0FieldTransform
276+
277+
# Pre-cached interpolator object
278+
unwarp = B0FieldTransform(
279+
coeffs=[nb.load(cname) for cname in self.inputs.in_coeff],
280+
)
281+
282+
if not isdefined(self.inputs.num_threads) or self.inputs.num_threads < 2:
283+
# Linear execution (1 core)
284+
outputs = [
285+
_b0_resampler(
286+
fname,
287+
self.inputs.in_coeff,
288+
pe_dir[i],
289+
ro_time[i],
290+
hmc_mats[i],
291+
unwarp, # if no HMC matrices, interpolator can be shared
292+
runtime.cwd,
293+
)
294+
for i, fname in enumerate(self.inputs.in_data)
295+
]
296+
else:
297+
# Embarrasingly parallel execution
298+
from concurrent.futures import ProcessPoolExecutor
299+
300+
outputs = [None] * len(self.inputs.in_data)
301+
with ProcessPoolExecutor(max_workers=self.inputs.num_threads) as ex:
302+
outputs = ex.map(
303+
_b0_resampler,
304+
self.inputs.in_data,
305+
[self.inputs.in_coeff] * n,
306+
pe_dir,
307+
ro_time,
308+
hmc_mats,
309+
[None] * n, # force a new interpolator for each process
310+
[runtime.cwd] * n,
311+
)
272312

273-
for fname, pe, ro in zip(self.inputs.in_target, pe_dir, ro_time):
274-
# Generate warpfield
275-
warp_name = filename(fname, suffix="_xfm")
276-
xfm.to_displacements(ro_time=ro, pe_dir=pe).to_filename(warp_name)
277-
self._results["out_warp"].append(warp_name)
313+
(
314+
self._results["out_corrected"],
315+
self._results["out_warp"],
316+
self._results["out_field"],
317+
) = zip(*outputs)
278318

279-
# Generate resampled
280-
out_name = filename(fname, suffix="_unwarped")
281-
xfm.apply(nb.load(fname), ro_time=ro, pe_dir=pe).to_filename(out_name)
282-
self._results["out_corrected"].append(out_name)
319+
out_fields = set(self._results["out_field"]) - set([None])
320+
if len(out_fields) == 1:
321+
self._results["out_field"] = out_fields.pop()
283322

284323
return runtime
285324

@@ -303,11 +342,21 @@ class TransformCoefficients(SimpleInterface):
303342
output_spec = _TransformCoefficientsOutputSpec
304343

305344
def _run_interface(self, runtime):
306-
self._results["out_coeff"] = _move_coeff(
307-
self.inputs.in_coeff,
308-
self.inputs.fmap_ref,
309-
self.inputs.transform,
310-
)
345+
from sdcflows.transform import _move_coeff
346+
347+
self._results["out_coeff"] = []
348+
349+
for level in self.inputs.in_coeff:
350+
movednii = _move_coeff(
351+
level,
352+
self.inputs.fmap_ref,
353+
self.inputs.transform,
354+
)
355+
out_file = fname_presuffix(
356+
level, suffix="_space-target", newpath=runtime.cwd
357+
)
358+
movednii.to_filename(out_file)
359+
self._results["out_coeff"].append(out_file)
311360
return runtime
312361

313362

@@ -408,30 +457,6 @@ def bspline_grid(img, control_zooms_mm=DEFAULT_ZOOMS_MM):
408457
return img.__class__(np.zeros(bs_shape, dtype="float32"), bs_affine)
409458

410459

411-
def _move_coeff(in_coeff, fmap_ref, transform):
412-
"""Read in a rigid transform from ANTs, and update the coefficients field affine."""
413-
from pathlib import Path
414-
import nibabel as nb
415-
import nitransforms as nt
416-
417-
if isinstance(in_coeff, str):
418-
in_coeff = [in_coeff]
419-
420-
xfm = nt.linear.Affine(
421-
nt.io.itk.ITKLinearTransform.from_filename(transform).to_ras(),
422-
reference=fmap_ref,
423-
)
424-
425-
out = []
426-
for i, c in enumerate(in_coeff):
427-
out.append(str(Path(f"moved_coeff_{i:03d}.nii.gz").absolute()))
428-
img = nb.load(c)
429-
newaff = xfm.matrix @ img.affine
430-
img.__class__(img.dataobj, newaff, img.header).to_filename(out[-1])
431-
432-
return out
433-
434-
435460
def _fix_topup_fieldcoeff(in_coeff, fmap_ref, refpe_reversed=False, out_file=None):
436461
"""Read in a coefficients file generated by TOPUP and fix x-form headers."""
437462
from pathlib import Path
@@ -463,3 +488,52 @@ def _fix_topup_fieldcoeff(in_coeff, fmap_ref, refpe_reversed=False, out_file=Non
463488

464489
coeffnii.__class__(coeffnii.dataobj, newaff, header).to_filename(out_file)
465490
return out_file
491+
492+
493+
def _split_itk_file(in_file):
494+
from pathlib import Path
495+
496+
lines = Path(in_file).read_text().splitlines()
497+
header = lines.pop(0)
498+
499+
def _chunks(inlist, chunksize):
500+
for i in range(0, len(inlist), chunksize):
501+
yield "\n".join([header] + inlist[i : i + chunksize])
502+
503+
for i, xfm in enumerate(_chunks(lines, 4)):
504+
p = Path(f"{i:05}")
505+
p.write_text(xfm)
506+
yield str(p)
507+
508+
509+
def _b0_resampler(data, coeffs, pe, ro, hmc_xfm=None, unwarp=None, newpath=None):
510+
"""Outsource the resampler into a separate callable function to allow parallelization."""
511+
from functools import partial
512+
513+
# Prepare output names
514+
filename = partial(fname_presuffix, newpath=newpath)
515+
retval = [filename(data, suffix=s) for s in ("_unwarped", "_xfm", "_field")]
516+
517+
if unwarp is None:
518+
from sdcflows.transform import B0FieldTransform
519+
520+
# Create a new unwarp object
521+
unwarp = B0FieldTransform(
522+
coeffs=[nb.load(cname) for cname in coeffs],
523+
)
524+
525+
if hmc_xfm is not None:
526+
from nitransforms.linear import Affine
527+
from nitransforms.io.itk import ITKLinearTransform as XFMLoader
528+
529+
unwarp.xfm = Affine(XFMLoader.from_filename(hmc_xfm).to_ras())
530+
531+
if unwarp.fit(data):
532+
unwarp.shifts.to_filename(retval[2])
533+
else:
534+
retval[2] = None
535+
536+
unwarp.apply(nb.load(data), ro_time=ro, pe_dir=pe).to_filename(retval[0])
537+
unwarp.to_displacements(ro_time=ro, pe_dir=pe).to_filename(retval[1])
538+
539+
return retval

sdcflows/interfaces/tests/test_bspline.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ def test_bsplines(tmp_path, testnum):
6363
os.chdir(tmp_path)
6464
# Check that we can interpolate the coefficients on a target
6565
test1 = ApplyCoeffsField(
66-
in_target=str(tmp_path / "target.nii.gz"),
66+
in_data=str(tmp_path / "target.nii.gz"),
6767
in_coeff=str(tmp_path / "coeffs.nii.gz"),
6868
pe_dir="j-",
6969
ro_time=1.0,
@@ -114,7 +114,7 @@ def test_topup_coeffs_interpolation(tmpdir, testdata_dir):
114114
"""Check that our interpolation is not far away from TOPUP's."""
115115
tmpdir.chdir()
116116
result = ApplyCoeffsField(
117-
in_target=[str(testdata_dir / "epi.nii.gz")] * 2,
117+
in_data=[str(testdata_dir / "epi.nii.gz")] * 2,
118118
in_coeff=str(testdata_dir / "topup-coeff-fixed.nii.gz"),
119119
pe_dir="j-",
120120
ro_time=1.0,

0 commit comments

Comments
 (0)