diff --git a/nipype/interfaces/utility/tests/test_wrappers.py b/nipype/interfaces/utility/tests/test_wrappers.py index d9f1da255a..3384a5865c 100644 --- a/nipype/interfaces/utility/tests/test_wrappers.py +++ b/nipype/interfaces/utility/tests/test_wrappers.py @@ -8,6 +8,12 @@ from nipype.interfaces import utility import nipype.pipeline.engine as pe +concat_sort = """\ +def concat_sort(in_arrays): + import numpy as np + all_vals = np.concatenate([arr.flatten() for arr in in_arrays]) + return np.sort(all_vals) +""" def test_function(tmpdir): os.chdir(str(tmpdir)) @@ -24,9 +30,15 @@ def gen_random_array(size): def increment_array(in_array): return in_array + 1 - f2 = pe.MapNode(utility.Function(input_names=['in_array'], output_names=['out_array'], function=increment_array), name='increment_array', iterfield=['in_array']) + f2 = pe.MapNode(utility.Function(function=increment_array), name='increment_array', iterfield=['in_array']) wf.connect(f1, 'random_array', f2, 'in_array') + + f3 = pe.Node( + utility.Function(function=concat_sort), + name="concat_sort") + + wf.connect(f2, 'out', f3, 'in_arrays') wf.run() diff --git a/nipype/interfaces/utility/wrappers.py b/nipype/interfaces/utility/wrappers.py index b8e78a56e6..30b4e10e8f 100644 --- a/nipype/interfaces/utility/wrappers.py +++ b/nipype/interfaces/utility/wrappers.py @@ -58,18 +58,19 @@ class Function(IOBase): input_spec = FunctionInputSpec output_spec = DynamicTraitedSpec - def __init__(self, input_names, output_names, function=None, imports=None, - **inputs): + def __init__(self, input_names=None, output_names='out', function=None, + imports=None, **inputs): """ Parameters ---------- - input_names: single str or list + input_names: single str or list or None names corresponding to function inputs + if ``None``, derive input names from function argument names output_names: single str or list - names corresponding to function outputs. - has to match the number of outputs + names corresponding to function outputs (default: 'out'). + if list of length > 1, has to match the number of outputs function : callable callable python object. must be able to execute in an isolated namespace (possibly in concert with the ``imports`` @@ -88,10 +89,18 @@ def __init__(self, input_names, output_names, function=None, imports=None, raise Exception('Interface Function does not accept ' 'function objects defined interactively ' 'in a python session') + else: + if input_names is None: + fninfo = function.__code__ elif isinstance(function, (str, bytes)): self.inputs.function_str = function + if input_names is None: + fninfo = create_function_from_source( + function, imports).__code__ else: raise Exception('Unknown type of function') + if input_names is None: + input_names = fninfo.co_varnames[:fninfo.co_argcount] self.inputs.on_trait_change(self._set_function_string, 'function_str') self._input_names = filename_to_list(input_names) @@ -106,10 +115,18 @@ def _set_function_string(self, obj, name, old, new): if name == 'function_str': if hasattr(new, '__call__'): function_source = getsource(new) + fninfo = new.__code__ elif isinstance(new, (str, bytes)): function_source = new + fninfo = create_function_from_source( + new, self.imports).__code__ self.inputs.trait_set(trait_change_notify=False, **{'%s' % name: function_source}) + # Update input traits + input_names = fninfo.co_varnames[:fninfo.co_argcount] + new_names = set(input_names) - set(self._input_names) + add_traits(self.inputs, list(new_names)) + self._input_names.extend(new_names) def _add_output_traits(self, base): undefined_traits = {}