Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 49 additions & 28 deletions smriprep/interfaces/templateflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,34 +108,9 @@ def _run_interface(self, runtime):
if isdefined(self.inputs.cohort):
specs['cohort'] = self.inputs.cohort

name = self.inputs.template.strip(':').split(':', 1)
if len(name) > 1:
specs.update(
{
k: v
for modifier in name[1].split(':')
for k, v in [tuple(modifier.split('-'))]
if k not in specs
}
)

if specs['resolution'] and not isinstance(specs['resolution'], list):
specs['resolution'] = [specs['resolution']]

available_resolutions = tf.TF_LAYOUT.get_resolutions(template=name[0])
if specs['resolution'] and not set(specs['resolution']) & set(available_resolutions):
fallback_res = available_resolutions[0] if available_resolutions else None
LOGGER.warning(
f"Template {name[0]} does not have resolution(s): {specs['resolution']}."
f"Falling back to resolution: {fallback_res}."
)
specs['resolution'] = fallback_res

self._results['t1w_file'] = tf.get(name[0], desc=None, suffix='T1w', **specs)

self._results['brain_mask'] = tf.get(
name[0], desc='brain', suffix='mask', **specs
) or tf.get(name[0], label='brain', suffix='mask', **specs)
files = fetch_template_files(self.inputs.template, specs)
self._results['t1w_file'] = files['t1w']
self._results['brain_mask'] = files['mask']
return runtime


Expand Down Expand Up @@ -186,3 +161,49 @@ def _run_interface(self, runtime):
descsplit = desc.split('-')
self._results['spec'][descsplit[0]] = descsplit[1]
return runtime


def fetch_template_files(
template: str,
specs: dict | None = None,
sloppy: bool = False,
) -> dict:
if specs is None:
specs = {}

name = template.strip(':').split(':', 1)
if len(name) > 1:
specs.update(
{
k: v
for modifier in name[1].split(':')
for k, v in [tuple(modifier.split('-'))]
if k not in specs
}
)

if res := specs.pop('res', None):
if res != 'native':
specs['resolution'] = res

if not specs.get('resolution'):
specs['resolution'] = 2 if sloppy else 1

if specs.get('resolution') and not isinstance(specs['resolution'], list):
specs['resolution'] = [specs['resolution']]

available_resolutions = tf.TF_LAYOUT.get_resolutions(template=name[0])
if specs.get('resolution') and not set(specs['resolution']) & set(available_resolutions):
fallback_res = available_resolutions[0] if available_resolutions else None
LOGGER.warning(
f"Template {name[0]} does not have resolution(s): {specs['resolution']}."
f"Falling back to resolution: {fallback_res}."
)
specs['resolution'] = fallback_res

files = {}
files['t1w'] = tf.get(name[0], desc=None, suffix='T1w', **specs)
files['mask'] = tf.get(name[0], desc='brain', suffix='mask', **specs) or tf.get(
name[0], label='brain', suffix='mask', **specs
)
return files
3 changes: 2 additions & 1 deletion smriprep/workflows/anatomical.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,7 +277,7 @@ def init_anat_preproc_wf(
omp_nthreads=omp_nthreads,
skull_strip_fixed_seed=skull_strip_fixed_seed,
)
template_iterator_wf = init_template_iterator_wf(spaces=spaces)
template_iterator_wf = init_template_iterator_wf(spaces=spaces, sloppy=sloppy)
ds_std_volumes_wf = init_ds_anat_volumes_wf(
bids_root=bids_root,
output_dir=output_dir,
Expand Down Expand Up @@ -725,6 +725,7 @@ def init_anat_fit_wf(
spaces=spaces,
freesurfer=freesurfer,
output_dir=output_dir,
sloppy=sloppy,
)
# fmt:off
workflow.connect([
Expand Down
30 changes: 16 additions & 14 deletions smriprep/workflows/outputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,15 @@
from niworkflows.interfaces.utility import KeySelect

from ..interfaces import DerivativesDataSink
from ..interfaces.templateflow import TemplateFlowSelect
from ..interfaces.templateflow import TemplateFlowSelect, fetch_template_files

if ty.TYPE_CHECKING:
from niworkflows.utils.spaces import SpatialReferences

BIDS_TISSUE_ORDER = ('GM', 'WM', 'CSF')


def init_anat_reports_wf(*, spaces, freesurfer, output_dir, name='anat_reports_wf'):
def init_anat_reports_wf(*, spaces, freesurfer, output_dir, sloppy=False, name='anat_reports_wf'):
"""
Set up a battery of datasinks to store reports in the right location.

Expand Down Expand Up @@ -131,7 +134,7 @@ def init_anat_reports_wf(*, spaces, freesurfer, output_dir, name='anat_reports_w
# fmt:on

if spaces._cached is not None and spaces.cached.references:
template_iterator_wf = init_template_iterator_wf(spaces=spaces)
template_iterator_wf = init_template_iterator_wf(spaces=spaces, sloppy=sloppy)
t1w_std = pe.Node(
ApplyTransforms(
dimension=3,
Expand Down Expand Up @@ -1112,7 +1115,9 @@ def init_anat_second_derivatives_wf(
return workflow


def init_template_iterator_wf(*, spaces, name='template_iterator_wf'):
def init_template_iterator_wf(
*, spaces: 'SpatialReferences', sloppy: bool = False, name='template_iterator_wf'
):
"""Prepare the necessary components to resample an image to a template space

This produces a workflow with an unjoined iterable named "spacesource".
Expand All @@ -1122,6 +1127,9 @@ def init_template_iterator_wf(*, spaces, name='template_iterator_wf'):

The fields in `outputnode` can be used as if they come from a single template.
"""
for template in spaces.get_spaces(nonstandard=False, dim=(3,)):
fetch_template_files(template, specs=None, sloppy=sloppy)

workflow = pe.Workflow(name=name)

inputnode = pe.Node(
Expand Down Expand Up @@ -1159,9 +1167,7 @@ def init_template_iterator_wf(*, spaces, name='template_iterator_wf'):
name='select_xfm',
run_without_submitting=True,
)
select_tpl = pe.Node(
TemplateFlowSelect(resolution=1), name='select_tpl', run_without_submitting=True
)
select_tpl = pe.Node(TemplateFlowSelect(), name='select_tpl', run_without_submitting=True)

# fmt:off
workflow.connect([
Expand All @@ -1177,7 +1183,7 @@ def init_template_iterator_wf(*, spaces, name='template_iterator_wf'):
(spacesource, select_tpl, [
('space', 'template'),
('cohort', 'cohort'),
(('resolution', _no_native), 'resolution'),
(('resolution', _no_native, sloppy), 'resolution'),
]),
(spacesource, outputnode, [
('space', 'space'),
Expand Down Expand Up @@ -1243,10 +1249,6 @@ def _pick_cohort(in_template):
return [_pick_cohort(v) for v in in_template]


def _fmt(in_template):
return in_template.replace(':', '_')


def _empty_report(in_file=None):
from pathlib import Path

Expand All @@ -1268,11 +1270,11 @@ def _is_native(value):
return value == 'native'


def _no_native(value):
def _no_native(value, sloppy=False):
try:
return int(value)
except (TypeError, ValueError):
return 1
return 2 if sloppy else 1


def _drop_path(in_path):
Expand Down