2222#
2323"""Filtering of :math:`B_0` field mappings with B-Splines."""
2424from pathlib import Path
25- from functools import partial
2625import numpy as np
2726import nibabel as nb
2827from nibabel .affines import apply_affine
3433 TraitedSpec ,
3534 File ,
3635 traits ,
36+ isdefined ,
3737 SimpleInterface ,
3838 InputMultiObject ,
3939 OutputMultiObject ,
4040)
4141
42- from sdcflows .transform import grid_bspline_weights as gbsw , B0FieldTransform
42+ from sdcflows .transform import grid_bspline_weights as gbsw
4343
4444
4545LOW_MEM_BLOCK_SIZE = 1000
@@ -210,14 +210,17 @@ def _run_interface(self, runtime):
210210
211211
212212class _ApplyCoeffsFieldInputSpec (BaseInterfaceInputSpec ):
213- in_target = InputMultiObject (
213+ in_data = InputMultiObject (
214214 File (exist = True , mandatory = True , desc = "input EPI data to be corrected" )
215215 )
216216 in_coeff = InputMultiObject (
217217 File (exists = True ),
218218 mandatory = True ,
219219 desc = "input coefficients, after alignment to the EPI data" ,
220220 )
221+ in_xfms = InputMultiObject (
222+ File (exists = True ), desc = "list of head-motion correction matrices"
223+ )
221224 ro_time = InputMultiObject (
222225 traits .Float (mandatory = True , desc = "EPI readout time (s)." )
223226 )
@@ -230,14 +233,15 @@ class _ApplyCoeffsFieldInputSpec(BaseInterfaceInputSpec):
230233 "k" ,
231234 "k-" ,
232235 mandatory = True ,
233- desc = "the phase-encoding direction corresponding to in_target " ,
236+ desc = "the phase-encoding direction corresponding to in_data " ,
234237 )
235238 )
239+ num_threads = traits .Int (nohash = True , desc = "number of threads" )
236240
237241
238242class _ApplyCoeffsFieldOutputSpec (TraitedSpec ):
239243 out_corrected = OutputMultiObject (File (exists = True ))
240- out_field = File (exists = True )
244+ out_field = OutputMultiObject ( File (exists = True ) )
241245 out_warp = OutputMultiObject (File (exists = True ))
242246
243247
@@ -248,38 +252,73 @@ class ApplyCoeffsField(SimpleInterface):
248252 output_spec = _ApplyCoeffsFieldOutputSpec
249253
250254 def _run_interface (self , runtime ):
251- # Prepare output names
252- filename = partial (fname_presuffix , newpath = runtime .cwd )
253-
254- self ._results ["out_field" ] = filename (self .inputs .in_coeff [0 ], suffix = "_field" )
255- self ._results ["out_warp" ] = []
256- self ._results ["out_corrected" ] = []
255+ n = len (self .inputs .in_data )
257256
258- xfm = B0FieldTransform (
259- coeffs = [nb .load (cname ) for cname in self .inputs .in_coeff ]
260- )
261- xfm .fit (self .inputs .in_target [0 ])
262- xfm .shifts .to_filename (self ._results ["out_field" ])
263-
264- n_inputs = len (self .inputs .in_target )
265257 ro_time = self .inputs .ro_time
266258 if len (ro_time ) == 1 :
267- ro_time = [ ro_time [ 0 ]] * n_inputs
259+ ro_time *= n
268260
269261 pe_dir = self .inputs .pe_dir
270262 if len (pe_dir ) == 1 :
271- pe_dir = [pe_dir [0 ]] * n_inputs
263+ pe_dir *= n
264+
265+ unwarp = None
266+ hmc_mats = [None ] * n
267+ if isdefined (self .inputs .in_xfms ):
268+ # Split ITK matrices in separate files if they come collated
269+ hmc_mats = (
270+ list (_split_itk_file (self .inputs .in_xfms [0 ]))
271+ if len (self .inputs .in_xfms ) == 1
272+ else self .inputs .in_xfms
273+ )
274+ else :
275+ from sdcflows .transform import B0FieldTransform
276+
277+ # Pre-cached interpolator object
278+ unwarp = B0FieldTransform (
279+ coeffs = [nb .load (cname ) for cname in self .inputs .in_coeff ],
280+ )
281+
282+ if not isdefined (self .inputs .num_threads ) or self .inputs .num_threads < 2 :
283+ # Linear execution (1 core)
284+ outputs = [
285+ _b0_resampler (
286+ fname ,
287+ self .inputs .in_coeff ,
288+ pe_dir [i ],
289+ ro_time [i ],
290+ hmc_mats [i ],
291+ unwarp , # if no HMC matrices, interpolator can be shared
292+ runtime .cwd ,
293+ )
294+ for i , fname in enumerate (self .inputs .in_data )
295+ ]
296+ else :
297+ # Embarrasingly parallel execution
298+ from concurrent .futures import ProcessPoolExecutor
299+
300+ outputs = [None ] * len (self .inputs .in_data )
301+ with ProcessPoolExecutor (max_workers = self .inputs .num_threads ) as ex :
302+ outputs = ex .map (
303+ _b0_resampler ,
304+ self .inputs .in_data ,
305+ [self .inputs .in_coeff ] * n ,
306+ pe_dir ,
307+ ro_time ,
308+ hmc_mats ,
309+ [None ] * n , # force a new interpolator for each process
310+ [runtime .cwd ] * n ,
311+ )
272312
273- for fname , pe , ro in zip ( self . inputs . in_target , pe_dir , ro_time ):
274- # Generate warpfield
275- warp_name = filename ( fname , suffix = "_xfm" )
276- xfm . to_displacements ( ro_time = ro , pe_dir = pe ). to_filename ( warp_name )
277- self . _results [ "out_warp" ]. append ( warp_name )
313+ (
314+ self . _results [ "out_corrected" ],
315+ self . _results [ "out_warp" ],
316+ self . _results [ "out_field" ],
317+ ) = zip ( * outputs )
278318
279- # Generate resampled
280- out_name = filename (fname , suffix = "_unwarped" )
281- xfm .apply (nb .load (fname ), ro_time = ro , pe_dir = pe ).to_filename (out_name )
282- self ._results ["out_corrected" ].append (out_name )
319+ out_fields = set (self ._results ["out_field" ]) - set ([None ])
320+ if len (out_fields ) == 1 :
321+ self ._results ["out_field" ] = out_fields .pop ()
283322
284323 return runtime
285324
@@ -303,11 +342,21 @@ class TransformCoefficients(SimpleInterface):
303342 output_spec = _TransformCoefficientsOutputSpec
304343
305344 def _run_interface (self , runtime ):
306- self ._results ["out_coeff" ] = _move_coeff (
307- self .inputs .in_coeff ,
308- self .inputs .fmap_ref ,
309- self .inputs .transform ,
310- )
345+ from sdcflows .transform import _move_coeff
346+
347+ self ._results ["out_coeff" ] = []
348+
349+ for level in self .inputs .in_coeff :
350+ movednii = _move_coeff (
351+ level ,
352+ self .inputs .fmap_ref ,
353+ self .inputs .transform ,
354+ )
355+ out_file = fname_presuffix (
356+ level , suffix = "_space-target" , newpath = runtime .cwd
357+ )
358+ movednii .to_filename (out_file )
359+ self ._results ["out_coeff" ].append (out_file )
311360 return runtime
312361
313362
@@ -408,30 +457,6 @@ def bspline_grid(img, control_zooms_mm=DEFAULT_ZOOMS_MM):
408457 return img .__class__ (np .zeros (bs_shape , dtype = "float32" ), bs_affine )
409458
410459
411- def _move_coeff (in_coeff , fmap_ref , transform ):
412- """Read in a rigid transform from ANTs, and update the coefficients field affine."""
413- from pathlib import Path
414- import nibabel as nb
415- import nitransforms as nt
416-
417- if isinstance (in_coeff , str ):
418- in_coeff = [in_coeff ]
419-
420- xfm = nt .linear .Affine (
421- nt .io .itk .ITKLinearTransform .from_filename (transform ).to_ras (),
422- reference = fmap_ref ,
423- )
424-
425- out = []
426- for i , c in enumerate (in_coeff ):
427- out .append (str (Path (f"moved_coeff_{ i :03d} .nii.gz" ).absolute ()))
428- img = nb .load (c )
429- newaff = xfm .matrix @ img .affine
430- img .__class__ (img .dataobj , newaff , img .header ).to_filename (out [- 1 ])
431-
432- return out
433-
434-
435460def _fix_topup_fieldcoeff (in_coeff , fmap_ref , refpe_reversed = False , out_file = None ):
436461 """Read in a coefficients file generated by TOPUP and fix x-form headers."""
437462 from pathlib import Path
@@ -463,3 +488,52 @@ def _fix_topup_fieldcoeff(in_coeff, fmap_ref, refpe_reversed=False, out_file=Non
463488
464489 coeffnii .__class__ (coeffnii .dataobj , newaff , header ).to_filename (out_file )
465490 return out_file
491+
492+
493+ def _split_itk_file (in_file ):
494+ from pathlib import Path
495+
496+ lines = Path (in_file ).read_text ().splitlines ()
497+ header = lines .pop (0 )
498+
499+ def _chunks (inlist , chunksize ):
500+ for i in range (0 , len (inlist ), chunksize ):
501+ yield "\n " .join ([header ] + inlist [i : i + chunksize ])
502+
503+ for i , xfm in enumerate (_chunks (lines , 4 )):
504+ p = Path (f"{ i :05} " )
505+ p .write_text (xfm )
506+ yield str (p )
507+
508+
509+ def _b0_resampler (data , coeffs , pe , ro , hmc_xfm = None , unwarp = None , newpath = None ):
510+ """Outsource the resampler into a separate callable function to allow parallelization."""
511+ from functools import partial
512+
513+ # Prepare output names
514+ filename = partial (fname_presuffix , newpath = newpath )
515+ retval = [filename (data , suffix = s ) for s in ("_unwarped" , "_xfm" , "_field" )]
516+
517+ if unwarp is None :
518+ from sdcflows .transform import B0FieldTransform
519+
520+ # Create a new unwarp object
521+ unwarp = B0FieldTransform (
522+ coeffs = [nb .load (cname ) for cname in coeffs ],
523+ )
524+
525+ if hmc_xfm is not None :
526+ from nitransforms .linear import Affine
527+ from nitransforms .io .itk import ITKLinearTransform as XFMLoader
528+
529+ unwarp .xfm = Affine (XFMLoader .from_filename (hmc_xfm ).to_ras ())
530+
531+ if unwarp .fit (data ):
532+ unwarp .shifts .to_filename (retval [2 ])
533+ else :
534+ retval [2 ] = None
535+
536+ unwarp .apply (nb .load (data ), ro_time = ro , pe_dir = pe ).to_filename (retval [0 ])
537+ unwarp .to_displacements (ro_time = ro , pe_dir = pe ).to_filename (retval [1 ])
538+
539+ return retval
0 commit comments