2222#
2323"""Filtering of :math:`B_0` field mappings with B-Splines."""
2424from pathlib import Path
25- from functools import partial
2625import numpy as np
2726import nibabel as nb
2827from nibabel .affines import apply_affine
4140)
4241
4342from sdcflows .transform import grid_bspline_weights as gbsw
44- from sdcflows .utils .misc import defaultlist
4543
4644
4745LOW_MEM_BLOCK_SIZE = 1000
@@ -238,6 +236,7 @@ class _ApplyCoeffsFieldInputSpec(BaseInterfaceInputSpec):
238236 desc = "the phase-encoding direction corresponding to in_data" ,
239237 )
240238 )
239+ num_threads = traits .Int (nohash = True , desc = "number of threads" )
241240
242241
243242class _ApplyCoeffsFieldOutputSpec (TraitedSpec ):
@@ -253,13 +252,20 @@ class ApplyCoeffsField(SimpleInterface):
253252 output_spec = _ApplyCoeffsFieldOutputSpec
254253
255254 def _run_interface (self , runtime ):
256- # Load head-motion correction matrices
257- ro_time = defaultlist (self .inputs .ro_time )
258- pe_dir = defaultlist (self .inputs .pe_dir )
255+ n = len (self .inputs .in_data )
256+
257+ ro_time = self .inputs .ro_time
258+ if len (ro_time ) == 1 :
259+ ro_time *= n
260+
261+ pe_dir = self .inputs .pe_dir
262+ if len (pe_dir ) == 1 :
263+ pe_dir *= n
259264
260265 unwarp = None
261- hmc_mats = defaultlist ( [None ])
266+ hmc_mats = [None ] * n
262267 if isdefined (self .inputs .in_xfms ):
268+ # Split ITK matrices in separate files if they come collated
263269 hmc_mats = (
264270 list (_split_itk_file (self .inputs .in_xfms [0 ]))
265271 if len (self .inputs .in_xfms ) == 1
@@ -268,22 +274,41 @@ def _run_interface(self, runtime):
268274 else :
269275 from sdcflows .transform import B0FieldTransform
270276
277+ # Pre-cached interpolator object
271278 unwarp = B0FieldTransform (
272279 coeffs = [nb .load (cname ) for cname in self .inputs .in_coeff ],
273280 )
274281
275- outputs = [
276- _b0_resampler (
277- fname ,
278- self .inputs .in_coeff ,
279- pe_dir [i ],
280- ro_time [i ],
281- hmc_mats [i ],
282- unwarp ,
283- runtime .cwd ,
284- )
285- for i , fname in enumerate (self .inputs .in_data )
286- ]
282+ if not isdefined (self .inputs .num_threads ) or self .inputs .num_threads < 2 :
283+ # Linear execution (1 core)
284+ outputs = [
285+ _b0_resampler (
286+ fname ,
287+ self .inputs .in_coeff ,
288+ pe_dir [i ],
289+ ro_time [i ],
290+ hmc_mats [i ],
291+ unwarp , # if no HMC matrices, interpolator can be shared
292+ runtime .cwd ,
293+ )
294+ for i , fname in enumerate (self .inputs .in_data )
295+ ]
296+ else :
297+ # Embarrasingly parallel execution
298+ from concurrent .futures import ProcessPoolExecutor
299+
300+ outputs = [None ] * len (self .inputs .in_data )
301+ with ProcessPoolExecutor (max_workers = self .inputs .num_threads ) as ex :
302+ outputs = ex .map (
303+ _b0_resampler ,
304+ self .inputs .in_data ,
305+ [self .inputs .in_coeff ] * n ,
306+ pe_dir ,
307+ ro_time ,
308+ hmc_mats ,
309+ [None ] * n , # force a new interpolator for each process
310+ [runtime .cwd ] * n ,
311+ )
287312
288313 (
289314 self ._results ["out_corrected" ],
@@ -292,7 +317,7 @@ def _run_interface(self, runtime):
292317 ) = zip (* outputs )
293318
294319 out_fields = set (self ._results ["out_field" ]) - set ([None ])
295- if len () == 1 :
320+ if len (out_fields ) == 1 :
296321 self ._results ["out_field" ] = out_fields .pop ()
297322
298323 return runtime
@@ -482,9 +507,12 @@ def _chunks(inlist, chunksize):
482507
483508
484509def _b0_resampler (data , coeffs , pe , ro , hmc_xfm = None , unwarp = None , newpath = None ):
510+ """Outsource the resampler into a separate callable function to allow parallelization."""
511+ from functools import partial
512+
485513 # Prepare output names
486514 filename = partial (fname_presuffix , newpath = newpath )
487- retval = tuple ( [filename (data , suffix = s ) for s in ("_unwarped" , "_xfm" , "_field" )])
515+ retval = [filename (data , suffix = s ) for s in ("_unwarped" , "_xfm" , "_field" )]
488516
489517 if unwarp is None :
490518 from sdcflows .transform import B0FieldTransform
0 commit comments