1616 type_to_str ,
1717)
1818from fileformats .core .mixin import WithClassifiers
19- from fileformats .generic import File , Directory
19+ from fileformats .generic import File
20+ from pydra .utils .typing import is_optional
2021
2122
2223logger = logging .getLogger ("nipype2pydra" )
@@ -59,7 +60,7 @@ def generate_code(self, input_fields, nonstd_types, output_fields) -> str:
5960 set of non-standard types
6061 output_fields : list[tuple[str, type, dict]]
6162 list of output fields, each field is a tuple of (name, type, metadata)
62-
63+
6364 Returns
6465 -------
6566 converted_code : str
@@ -95,75 +96,65 @@ def unwrap_field_type(t):
9596
9697 nonstd_types = copy (nonstd_types )
9798
98- # def types_to_names(spec_fields):
99- # spec_fields_str = []
100- # for el in spec_fields:
101- # el = list(el)
102- # field_type = el[1]
103- # if inspect.isclass(field_type) and issubclass(
104- # field_type, WithClassifiers
105- # ):
106- # field_type_str = unwrap_field_type(field_type)
107- # else:
108- # field_type_str = str(field_type)
109- # if field_type_str.startswith("<class "):
110- # field_type_str = el[1].__name__
111- # else:
112- # # Alter modules in type string to match those that will be imported
113- # field_type_str = field_type_str.replace("typing", "ty")
114- # field_type_str = re.sub(
115- # r"(\w+\.)+(?<!ty\.)(\w+)", r"\2", field_type_str
116- # )
117- # if field_type_str == "File":
118- # nonstd_types.add(File)
119- # elif field_type_str == "Directory":
120- # nonstd_types.add(Directory)
121- # el[1] = "#" + field_type_str + "#"
122- # spec_fields_str.append(tuple(el))
123- # return spec_fields_str
124-
12599 input_names = [i [0 ] for i in input_fields ]
126100 output_names = [o [0 ] for o in output_fields ]
127- # input_fields_str = types_to_names(spec_fields=input_fields)
128- # input_fields_str = re.sub(
129- # r"'formatter': '(\w+)'", r"'formatter': \1", input_fields_str
130- # )
131- # output_fields_str = types_to_names(spec_fields=output_fields)
132- # output_fields_str = re.sub(
133- # r"'callable': '(\w+)'", r"'callable': \1", output_fields_str
134- # )
135- # functions_str = self.function_callables()
136- # functions_imports, functions_str = functions_str.split("\n\n", 1)
137- # spec_str = functions_str
101+
102+ # Pull out xor fields into task-level xor_sets
103+ xor_sets = set ()
104+ for inpt in input_fields :
105+ if len (inpt ) == 3 :
106+ name , _ , mdata = inpt
107+ else :
108+ name , _ , __ , mdata = inpt
109+ if "xor" in mdata :
110+ xor_sets .add (frozenset (mdata ["xor" ] + [name ]))
138111
139112 input_fields_str = ""
140113 output_fields_str = ""
141- xor_sets = set ()
114+
142115 for inpt in input_fields :
143116 if len (inpt ) == 3 :
144117 name , type_ , mdata = inpt
118+ mdata = copy (mdata ) # Copy to avoid modifying the original
145119 else :
146120 name , type_ , default , mdata = inpt
121+ mdata = copy (mdata ) # Copy to avoid modifying the original
147122 mdata ["default" ] = default
123+ if (
124+ any (name in x for x in xor_sets )
125+ and type_ is not bool
126+ and not is_optional (type_ )
127+ and (inspect .isclass (type_ ) and not issubclass (type_ , ty .Sequence ))
128+ ):
129+ type_ = type_ | None
148130 type_str = type_to_str (type_ , mdata .pop ("mandatory" , True ))
149131 if mdata .pop ("copyfile" , None ):
150132 nonstd_types .add (File )
151133 mdata ["copy_mode" ] = "File.CopyMode.copy"
152- if xor := mdata .pop ("xor" , None ):
153- xor_sets .add (frozenset (xor + [name ]))
134+ mdata .pop ("xor" , None )
154135 args_str = ", " .join (f"{ k } ={ v !r} " for k , v in mdata .items ())
155136 if "path_template" in mdata :
156- output_fields_str = f" { name } : { type_str } = shell.outarg({ args_str } )\n "
137+ output_fields_str = (
138+ f" { name } : { type_str } = shell.outarg({ args_str } )\n "
139+ )
157140 else :
158141 input_fields_str += f" { name } : { type_str } = shell.arg({ args_str } )\n "
159142
143+ callable_fields = set (n for n , _ , __ in self .callable_output_fields )
144+
160145 for outpt in output_fields :
161146 name , type_ , mdata = outpt
162- cllble = mdata .pop ("callable" , None )
147+ cllble = mdata .pop (
148+ "callable" , f"{ name } _callable" if name in callable_fields else None
149+ )
163150 args_str = ", " .join (f"{ k } ={ v !r} " for k , v in mdata .items ())
151+ if args_str :
152+ args_str += ", "
164153 if cllble :
165- args_str += f", callable={ cllble } "
166- output_fields_str += f" { name } : { type_to_str (type_ )} = shell.out({ args_str } )\n "
154+ args_str += f"callable={ cllble } "
155+ output_fields_str += (
156+ f" { name } : { type_to_str (type_ )} = shell.out({ args_str } )\n "
157+ )
167158
168159 spec_str = (
169160 self .init_code
@@ -176,7 +167,9 @@ def unwrap_field_type(t):
176167 spec_str += "@shell.define"
177168 if xor_sets :
178169 spec_str += f"(xor={ [list (x ) for x in xor_sets ]} )"
179- spec_str += f"\n class { self .task_name } (shell.Task['{ self .task_name } .Outputs']):\n "
170+ spec_str += (
171+ f"\n class { self .task_name } (shell.Task['{ self .task_name } .Outputs']):\n "
172+ )
180173 spec_str += ' """\n '
181174 spec_str += self .create_doctests (
182175 input_fields = input_fields , nonstd_types = nonstd_types
@@ -254,10 +247,7 @@ def callable_output_fields(self):
254247 return [
255248 f
256249 for f in super ().output_fields
257- if (
258- "path_template" not in f [- 1 ]
259- and f [0 ] not in INBUILT_NIPYPE_TRAIT_NAMES
260- )
250+ if ("path_template" not in f [- 1 ] and f [0 ] not in INBUILT_NIPYPE_TRAIT_NAMES )
261251 ]
262252
263253 @property
0 commit comments