Skip to content

Commit 8ca1e79

Browse files
committed
enh: parallelization
1 parent 33aa519 commit 8ca1e79

File tree

3 files changed

+56
-54
lines changed

3 files changed

+56
-54
lines changed

sdcflows/interfaces/bspline.py

Lines changed: 48 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@
2222
#
2323
"""Filtering of :math:`B_0` field mappings with B-Splines."""
2424
from pathlib import Path
25-
from functools import partial
2625
import numpy as np
2726
import nibabel as nb
2827
from nibabel.affines import apply_affine
@@ -41,7 +40,6 @@
4140
)
4241

4342
from sdcflows.transform import grid_bspline_weights as gbsw
44-
from sdcflows.utils.misc import defaultlist
4543

4644

4745
LOW_MEM_BLOCK_SIZE = 1000
@@ -238,6 +236,7 @@ class _ApplyCoeffsFieldInputSpec(BaseInterfaceInputSpec):
238236
desc="the phase-encoding direction corresponding to in_data",
239237
)
240238
)
239+
num_threads = traits.Int(nohash=True, desc="number of threads")
241240

242241

243242
class _ApplyCoeffsFieldOutputSpec(TraitedSpec):
@@ -253,13 +252,20 @@ class ApplyCoeffsField(SimpleInterface):
253252
output_spec = _ApplyCoeffsFieldOutputSpec
254253

255254
def _run_interface(self, runtime):
256-
# Load head-motion correction matrices
257-
ro_time = defaultlist(self.inputs.ro_time)
258-
pe_dir = defaultlist(self.inputs.pe_dir)
255+
n = len(self.inputs.in_data)
256+
257+
ro_time = self.inputs.ro_time
258+
if len(ro_time) == 1:
259+
ro_time *= n
260+
261+
pe_dir = self.inputs.pe_dir
262+
if len(pe_dir) == 1:
263+
pe_dir *= n
259264

260265
unwarp = None
261-
hmc_mats = defaultlist([None])
266+
hmc_mats = [None] * n
262267
if isdefined(self.inputs.in_xfms):
268+
# Split ITK matrices into separate files if they come collated
263269
hmc_mats = (
264270
list(_split_itk_file(self.inputs.in_xfms[0]))
265271
if len(self.inputs.in_xfms) == 1
@@ -268,22 +274,41 @@ def _run_interface(self, runtime):
268274
else:
269275
from sdcflows.transform import B0FieldTransform
270276

277+
# Pre-cached interpolator object
271278
unwarp = B0FieldTransform(
272279
coeffs=[nb.load(cname) for cname in self.inputs.in_coeff],
273280
)
274281

275-
outputs = [
276-
_b0_resampler(
277-
fname,
278-
self.inputs.in_coeff,
279-
pe_dir[i],
280-
ro_time[i],
281-
hmc_mats[i],
282-
unwarp,
283-
runtime.cwd,
284-
)
285-
for i, fname in enumerate(self.inputs.in_data)
286-
]
282+
if not isdefined(self.inputs.num_threads) or self.inputs.num_threads < 2:
283+
# Linear execution (1 core)
284+
outputs = [
285+
_b0_resampler(
286+
fname,
287+
self.inputs.in_coeff,
288+
pe_dir[i],
289+
ro_time[i],
290+
hmc_mats[i],
291+
unwarp, # if no HMC matrices, interpolator can be shared
292+
runtime.cwd,
293+
)
294+
for i, fname in enumerate(self.inputs.in_data)
295+
]
296+
else:
297+
# Embarrassingly parallel execution
298+
from concurrent.futures import ProcessPoolExecutor
299+
300+
outputs = [None] * len(self.inputs.in_data)
301+
with ProcessPoolExecutor(max_workers=self.inputs.num_threads) as ex:
302+
outputs = ex.map(
303+
_b0_resampler,
304+
self.inputs.in_data,
305+
[self.inputs.in_coeff] * n,
306+
pe_dir,
307+
ro_time,
308+
hmc_mats,
309+
[None] * n, # force a new interpolator for each process
310+
[runtime.cwd] * n,
311+
)
287312

288313
(
289314
self._results["out_corrected"],
@@ -292,7 +317,7 @@ def _run_interface(self, runtime):
292317
) = zip(*outputs)
293318

294319
out_fields = set(self._results["out_field"]) - set([None])
295-
if len() == 1:
320+
if len(out_fields) == 1:
296321
self._results["out_field"] = out_fields.pop()
297322

298323
return runtime
@@ -482,9 +507,12 @@ def _chunks(inlist, chunksize):
482507

483508

484509
def _b0_resampler(data, coeffs, pe, ro, hmc_xfm=None, unwarp=None, newpath=None):
510+
"""Outsource the resampler into a separate callable function to allow parallelization."""
511+
from functools import partial
512+
485513
# Prepare output names
486514
filename = partial(fname_presuffix, newpath=newpath)
487-
retval = tuple([filename(data, suffix=s) for s in ("_unwarped", "_xfm", "_field")])
515+
retval = [filename(data, suffix=s) for s in ("_unwarped", "_xfm", "_field")]
488516

489517
if unwarp is None:
490518
from sdcflows.transform import B0FieldTransform

sdcflows/utils/misc.py

Lines changed: 0 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -67,35 +67,3 @@ def get_free_mem():
6767
return round(virtual_memory().free, 1)
6868
except Exception:
6969
return None
70-
71-
72-
class defaultlist(list):
73-
"""
74-
A sort of default dict for lists.
75-
76-
Examples
77-
--------
78-
>>> defaultlist(range(3))
79-
[0, 1, 2]
80-
81-
>>> defaultlist(["abc"])[100]
82-
'abc'
83-
84-
>>> defaultlist(range(3))[1]
85-
1
86-
87-
>>> l = defaultlist(reversed(range(3)))
88-
>>> l[0]
89-
2
90-
91-
>>> _ = l.pop(0)
92-
>>> _ = l.pop(0)
93-
>>> l[4]
94-
0
95-
96-
"""
97-
98-
def __getitem__(self, i):
99-
if len(self) == 1:
100-
i = 0
101-
return super().__getitem__(i)

sdcflows/workflows/apply/correction.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -90,14 +90,20 @@ def init_unwarp_wf(omp_nthreads=1, debug=False, name="unwarp_wf"):
9090
)
9191
outputnode = pe.Node(
9292
niu.IdentityInterface(
93-
fields=["fieldmap", "fieldwarp", "corrected", "corrected_ref", "corrected_mask"]
93+
fields=[
94+
"fieldmap",
95+
"fieldwarp",
96+
"corrected",
97+
"corrected_ref",
98+
"corrected_mask",
99+
]
94100
),
95101
name="outputnode",
96102
)
97103

98104
rotime = pe.Node(GetReadoutTime(), name="rotime")
99105
rotime.interface._always_run = debug
100-
resample = pe.Node(ApplyCoeffsField(), name="resample")
106+
resample = pe.Node(ApplyCoeffsField(num_threads=omp_nthreads), name="resample")
101107
merge = pe.Node(MergeSeries(), name="merge")
102108
average = pe.Node(RobustAverage(mc_method=None), name="average")
103109

0 commit comments

Comments
 (0)