Skip to content

Commit

Permalink
[FIX] Make recon workflows work on multi-session data (#639)
Browse files Browse the repository at this point in the history
* Move interchange to interfaces

* refactor the recon workflows again

* name workflow consistently
  • Loading branch information
mattcieslak authored Sep 29, 2023
1 parent 1844872 commit a5754c6
Show file tree
Hide file tree
Showing 15 changed files with 166 additions and 122 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
if overlapping_names:
raise Exception("Someone has added overlapping outputs between the anatomical "
"and dwi inputs: " + " ".join(overlapping_names))

recon_workflow_input_fields = qsiprep_output_names + \
recon_workflow_anatomical_input_fields
default_input_set = set(recon_workflow_input_fields)
Expand All @@ -53,10 +53,34 @@ class ReconWorkflowInputs(SimpleInterface):
output_spec = _ReconWorkflowInputsOutputSpec

def _run_interface(self, runtime):
inputs = self.inputs.get()
for name in recon_workflow_input_fields:
self._results[name] = self.inputs.get(name)
self._results[name] = inputs.get(name)
return runtime

for name in recon_workflow_input_fields:
_ReconWorkflowInputsInputSpec.add_class_trait(name, traits.Any)
_ReconWorkflowInputsOutputSpec.add_class_trait(name, traits.Any)
_ReconWorkflowInputsOutputSpec.add_class_trait(name, traits.Any)


class _ReconAnatomicalDataInputSpec(BaseInterfaceInputSpec):
pass


class _ReconAnatomicalDataOutputSpec(TraitedSpec):
pass


class ReconAnatomicalData(SimpleInterface):
input_spec = _ReconAnatomicalDataInputSpec
output_spec = _ReconAnatomicalDataOutputSpec

def _run_interface(self, runtime):
inputs = self.inputs.get()
for name in anatomical_workflow_outputs:
self._results[name] = inputs.get(name)
return runtime

for name in anatomical_workflow_outputs:
_ReconAnatomicalDataInputSpec.add_class_trait(name, traits.Any)
_ReconAnatomicalDataOutputSpec.add_class_trait(name, traits.Any)
49 changes: 26 additions & 23 deletions qsiprep/interfaces/mrtrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,16 +214,16 @@ class GenerateMasked5ttInputSpec(Generate5ttInputSpec):
mandatory=True,
desc='tissue segmentation algorithm')
in_file = traits.Either(
File(exists=True),
File(exists=True),
traits.Directory(exists=True),
argstr='%s',
mandatory=True,
position=1,
desc='input T1w image or FreeSurfer directory')
out_file = File(
argstr='%s',
genfile=True,
position=2,
argstr='%s',
genfile=True,
position=2,
desc='output image')
mask = File(exists=True, argstr='-mask %s')
amygdala_hipppocampi_subcortical_gm = traits.Bool(
Expand All @@ -232,7 +232,7 @@ class GenerateMasked5ttInputSpec(Generate5ttInputSpec):
thalami_method = traits.Enum(
"nuclei",
"first",
"aseg",
"aseg",
argstr="-thalami %s")
hippocampi_method = traits.Enum(
"subfields",
Expand Down Expand Up @@ -880,10 +880,10 @@ def _run_interface(self, runtime):
merge_weights = pe.Node(niu.Merge(num_nodes),
name='merge_weights')
merge_exemplars = pe.Node(niu.Merge(3), name='merge_exemplars')
compress_exemplars = pe.Node(CompressConnectome2Tck(),
compress_exemplars = pe.Node(CompressConnectome2Tck(),
name='compress_exemplars')
outputnode = pe.Node(
niu.IdentityInterface(fields=['matfiles', 'tckfiles', 'weights']),
niu.IdentityInterface(fields=['matfiles', 'tckfiles', 'weights']),
name='outputnode')
workflow.connect(merge_mats, 'out', outputnode, 'matfiles')
workflow.connect(merge_tcks, 'out', outputnode, 'tckfiles')
Expand Down Expand Up @@ -935,7 +935,7 @@ def _run_interface(self, runtime):
workflow.connect(c2t_nodes[-1], 'exemplar_weights',
merge_weights, 'in%d' % in_num)
in_num += 1

# Get the exemplar tcks and weights
workflow.connect([
(merge_tcks, merge_exemplars, [('out', "in1")]),
Expand All @@ -960,7 +960,7 @@ def _run_interface(self, runtime):
wf_result = workflow.run(**plugin_settings)
else:
wf_result = workflow.run()

# Merge the connectivity matrices into a single file
merge_node, = [node for node in list(wf_result.nodes) if node.name.endswith('merge_mats')]
merged_connectivity_file = op.join(cwd, "combined_connectivity.mat")
Expand Down Expand Up @@ -1061,30 +1061,30 @@ class _CompressConnectome2TckOutputSpec(TraitedSpec):
class CompressConnectome2Tck(SimpleInterface):
input_spec = _CompressConnectome2TckInputSpec
output_spec = _CompressConnectome2TckOutputSpec

def _run_interface(self, runtime):
out_zip = op.join(runtime.cwd, self.inputs.out_zip)
zipfh = zipfile.ZipFile(out_zip, "w")
# Get the matrix csvs and add them to the zip
csvfiles = [fname for fname in self.inputs.files if fname.endswith(".csv")
csvfiles = [fname for fname in self.inputs.files if fname.endswith(".csv")
and not fname.endswith("weights.csv")]
for csvfile in csvfiles:
zipfh.write(csvfile, arcname=_rename_connectome(csvfile, suffix='connectome.csv'),
compresslevel=8, compress_type=zipfile.ZIP_DEFLATED)

# Get the sift weights if they exist
weightfiles = [fname for fname in self.inputs.files if fname.endswith("weights.csv")]
for weightfile in weightfiles:
zipfh.write(weightfile, arcname=_rename_connectome(weightfile, suffix='_weights.csv'),
compresslevel=8, compress_type=zipfile.ZIP_DEFLATED)

# Get the tck files
tckfiles = [fname for fname in self.inputs.files if fname.endswith(".tck")
or fname.endswith(".tck.gz")]
for tckfile in tckfiles:
zipfh.write(tckfile, arcname=_rename_connectome(tckfile, suffix='_exemplars.tck'),
zipfh.write(tckfile, arcname=_rename_connectome(tckfile, suffix='_exemplars.tck'),
compresslevel=8, compress_type=zipfile.ZIP_DEFLATED)

zipfh.close()
self._results["out_zip"] = out_zip
return runtime
Expand All @@ -1101,7 +1101,10 @@ def _rename_connectome(connectome_csv, suffix="_connectome.csv"):
"""
parts = connectome_csv.split(os.sep)
conn_name = parts[-2]
image_name, = [part for part in parts if part.startswith("sub_") and part.endswith("recon_wf")]
try:
image_name, = [part for part in parts if part.startswith("sub_") and part.endswith("recon_wf")]
except Exception as ex:
raise Exception(f"unable to detect image name from these parts {parts}")
image_name = image_name[:-len("_recon_wf")]
return "connectome2tck/" +_rebids(image_name) + "_" + conn_name + suffix

Expand Down Expand Up @@ -1303,9 +1306,9 @@ class _ITKTransformConvertInputSpec(CommandLineInputSpec):
mandatory=True,
position=0)
operation = traits.Enum(
"itk_import",
default="itk_import",
usedefault=True,
"itk_import",
default="itk_import",
usedefault=True,
posision=1,
argstr="%s")
out_transform = traits.File(
Expand All @@ -1328,13 +1331,13 @@ class ITKTransformConvert(CommandLine):

class _TransformHeaderInputSpec(CommandLineInputSpec):
transform_file = traits.File(
exists=True,
position=0,
exists=True,
position=0,
mandatory=True,
argstr="-linear %s")
in_image = traits.File(
exists=True,
mandatory=True,
exists=True,
mandatory=True,
position=1,
argstr="%s")
out_image = traits.File(
Expand Down
6 changes: 3 additions & 3 deletions qsiprep/workflows/recon/amico.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import nipype.pipeline.engine as pe
from nipype.interfaces import afni, utility as niu
from qsiprep.interfaces.bids import ReconDerivativesDataSink
from .interchange import recon_workflow_input_fields
from ...interfaces.interchange import recon_workflow_input_fields
from ...engine import Workflow
from ...interfaces.amico import NODDI
from ...interfaces.reports import CLIReconPeaksReport
Expand Down Expand Up @@ -57,7 +57,7 @@ def init_amico_noddi_fit_wf(omp_nthreads, available_anatomical_data,
: """
noddi_fit = pe.Node(
NODDI(**params),
NODDI(**params),
name="recon_noddi",
n_procs=omp_nthreads)
desc += """\
Expand Down Expand Up @@ -91,7 +91,7 @@ def init_amico_noddi_fit_wf(omp_nthreads, available_anatomical_data,
(convert_to_fibgz, outputnode, [('fibgz_file', 'fibgz')])])
if plot_reports:
plot_peaks = pe.Node(
CLIReconPeaksReport(),
CLIReconPeaksReport(),
name='plot_peaks',
n_procs=omp_nthreads)
ds_report_peaks = pe.Node(
Expand Down
2 changes: 1 addition & 1 deletion qsiprep/workflows/recon/anatomical.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
from ...interfaces.freesurfer import find_fs_path
from ...interfaces.gradients import ExtractB0s
from ...interfaces.nilearn import MaskB0Series
from .interchange import (qsiprep_anatomical_ingressed_fields,
from ...interfaces.interchange import (qsiprep_anatomical_ingressed_fields,
FS_FILES_TO_REGISTER, anatomical_workflow_outputs, recon_workflow_input_fields)
from qsiprep.interfaces.utils import GetConnectivityAtlases

Expand Down
87 changes: 72 additions & 15 deletions qsiprep/workflows/recon/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,19 +16,23 @@
from glob import glob
from copy import deepcopy
from nipype import __version__ as nipype_ver
import nipype.pipeline.engine as pe
from nipype.utils.filemanip import split_filename
from nilearn import __version__ as nilearn_ver
from dipy import __version__ as dipy_ver
from pkg_resources import resource_filename as pkgrf
from ...engine import Workflow
from ...utils.sloppy_recon import make_sloppy
from ...__about__ import __version__

from ...interfaces.bids import QsiReconIngress
import logging
import json
from bids.layout import BIDSLayout
from .build_workflow import init_dwi_recon_workflow
from .anatomical import init_recon_anatomical_wf
from .interchange import anatomical_workflow_outputs
from .anatomical import init_recon_anatomical_wf, init_dwi_recon_anatomical_workflow
from ...interfaces.interchange import (anatomical_workflow_outputs, recon_workflow_anatomical_input_fields,
ReconWorkflowInputs,
qsiprep_output_names, recon_workflow_input_fields)

LOGGER = logging.getLogger('nipype.workflow')

Expand Down Expand Up @@ -211,7 +215,10 @@ def init_single_subject_wf(
extras_to_make=spec.get('anatomical', []),
freesurfer_dir=freesurfer_input,
name='anat_ingress_wf')


# Connect the anatomical-only inputs. NOTE this is not to the inputnode!
LOGGER.info("Anatomical (T1w) available for recon: %s", available_anatomical_data)

# Fill-in datasinks and reportlet datasinks for the anatomical workflow
for _node in anat_ingress_wf.list_node_names():
node_suffix = _node.split('.')[-1]
Expand All @@ -221,31 +228,81 @@ def init_single_subject_wf(
anat_ingress_wf.get_node(_node).inputs.source_file = \
"anat/sub-{}_desc-preproc_T1w.nii.gz".format(subject_id)

# Connect the anatomical-only inputs. NOTE this is not to the inputnode!
LOGGER.info("Anatomical (T1w) available for recon: %s", available_anatomical_data)
to_connect = [('outputnode.' + name, 'qsirecon_anat_wf.inputnode.' + name)
for name in anatomical_workflow_outputs]
# Get the anatomical data (masks, atlases, etc)
atlas_names = spec.get('atlases', [])

# create a processing pipeline for the dwis in each session
dwi_recon_wfs = {}
dwi_individual_anatomical_wfs = {}
recon_full_inputs = {}
dwi_ingress_nodes = {}
for dwi_file in dwi_files:
wf_name = _get_wf_name(dwi_file)

# Get the preprocessed DWI and all the related preprocessed images
dwi_ingress_nodes[dwi_file] = pe.Node(
QsiReconIngress(dwi_file=dwi_file),
name=wf_name + "_ingressed_dwi_data")

# Create scan-specific anatomical data (mask, atlas configs, odf ROIs for reports)
dwi_individual_anatomical_wfs[dwi_file], dwi_available_anatomical_data = \
init_dwi_recon_anatomical_workflow(
atlas_names=atlas_names,
omp_nthreads=omp_nthreads,
infant_mode=False,
prefer_dwi_mask=False,
sloppy=sloppy,
b0_threshold=b0_threshold,
freesurfer_dir=freesurfer_input,
extras_to_make=spec.get('anatomical', []),
name=wf_name + "_anat_wf",
**available_anatomical_data)

# This node holds all the inputs that will go to the recon workflow.
# It is the definitive place to check what the input files are
recon_full_inputs[dwi_file] = pe.Node(ReconWorkflowInputs(), name=wf_name + "_recon_inputs")

# This is the actual recon workflow for this dwi file
dwi_recon_wfs[dwi_file] = init_dwi_recon_workflow(
dwi_file=dwi_file,
available_anatomical_data=available_anatomical_data,
available_anatomical_data=dwi_available_anatomical_data,
workflow_spec=spec,
sloppy=sloppy,
prefer_dwi_mask=False,
infant_mode=False,
b0_threshold=b0_threshold,
name=wf_name + "_recon_wf",
reportlets_dir=reportlets_dir,
output_dir=output_dir,
omp_nthreads=omp_nthreads,
skip_odf_plots=skip_odf_plots)
workflow.connect([(anat_ingress_wf, dwi_recon_wfs[dwi_file], to_connect)])

# Connect the collected diffusion data (gradients, etc) to the inputnode
workflow.connect([

# The dwi data
(dwi_ingress_nodes[dwi_file], recon_full_inputs[dwi_file], [
(trait, trait) for trait in qsiprep_output_names]),

# subject anatomical data to dwi
(anat_ingress_wf, dwi_individual_anatomical_wfs[dwi_file],
[("outputnode."+trait, "inputnode."+trait) for trait in anatomical_workflow_outputs]),
(dwi_ingress_nodes[dwi_file], dwi_individual_anatomical_wfs[dwi_file],
[(trait, "inputnode." + trait) for trait in qsiprep_output_names]),

# subject dwi-specific anatomical to recon inputs
(dwi_individual_anatomical_wfs[dwi_file], recon_full_inputs[dwi_file], [
("outputnode." + trait, trait) for trait in recon_workflow_anatomical_input_fields]),

# recon inputs to recon workflow
(recon_full_inputs[dwi_file], dwi_recon_wfs[dwi_file],
[(trait, "inputnode." + trait) for trait in recon_workflow_input_fields])
])

return workflow


def _get_wf_name(dwi_file):
basedir, fname, ext = split_filename(dwi_file)
tokens = fname.split("_")
return "_".join(tokens[:-1]).replace("-", "_")


def _load_recon_spec(spec_name, sloppy=False):
prepackaged_dir = pkgrf("qsiprep", "data/pipelines")
prepackaged = [op.split(fname)[1][:-5] for fname in glob(prepackaged_dir+"/*.json")]
Expand Down
Loading

0 comments on commit a5754c6

Please sign in to comment.