1212logger = logging .getLogger ("nipype2pydra" )
1313
1414
15+ def type2str (type_ ):
16+ """Convert a type to a string representation."""
17+ if isinstance (type_ , str ):
18+ return type_
19+ if type_ is ty .Any :
20+ return "ty.Any"
21+ elif hasattr (type_ , "__name__" ):
22+ return type_ .__name__
23+ elif hasattr (type_ , "__qualname__" ):
24+ return type_ .__qualname__
25+ else :
26+ return str (type_ ).replace ("typing." , "ty." )
27+
28+
1529@attrs .define (slots = False )
1630class PythonInterfaceConverter (BaseInterfaceConverter ):
1731
@@ -136,8 +150,8 @@ def generate_code(self, input_fields, nonstd_types, output_fields) -> ty.Tuple[
136150
137151 assert method_body , "Neither `run_interface` and `list_outputs` are defined"
138152
139- spec_str = "@python.define"
140- spec_str += f"class { self .task_name } (python.Task['{ self .task_name } .Outputs']):"
153+ spec_str = "@python.define\n "
154+ spec_str += f"class { self .task_name } (python.Task['{ self .task_name } .Outputs']):\n "
141155 spec_str += ' """\n '
142156 spec_str += self .create_doctests (
143157 input_fields = input_fields , nonstd_types = nonstd_types
@@ -147,32 +161,33 @@ def generate_code(self, input_fields, nonstd_types, output_fields) -> ty.Tuple[
147161 for inpt in input_fields :
148162 if len (inpt ) == 4 :
149163 name , type_ , default , _ = inpt
150- spec_str += f" { name } : { type_ } = { default } \n "
164+ spec_str += f" { name } : { type2str ( type_ ) } = { default } \n "
151165 else :
152166 name , type_ , _ = inpt
153- spec_str += f" { name } : { type_ } \n "
167+ spec_str += f" { name } : { type2str ( type_ ) } \n "
154168
155- spec_str += " @staticmethod\n "
169+ spec_str += "\n \n class Outputs(python.Outputs):\n "
170+ for outpt in output_fields :
171+ name , type_ , _ = outpt
172+ spec_str += f" { name } : { type2str (type_ )} \n "
173+
174+ spec_str += "\n @staticmethod\n "
156175 spec_str += (
157176 " def function("
158- + ", " .join (f"{ n } : { t } " for n , t , _ in input_fields )
177+ + ", " .join (f"{ i [ 0 ] } : { type2str ( i [ 1 ]) } " for i in input_fields )
159178 + ")"
160179 )
161- output_types = [o [1 ] for o in output_fields ]
180+ output_types = [type2str ( o [1 ]) for o in output_fields ]
162181 if any (t is not ty .Any for t in output_types ):
163- spec_str += "-> "
182+ spec_str += " -> "
164183 if len (output_types ) > 1 :
165184 spec_str += "tuples[" + ", " .join (output_types ) + "]"
166185 else :
167186 spec_str += output_types [0 ]
168187 spec_str += ":\n "
169- spec_str += method_body + "\n "
170- spec_str += "\n return {}" .format (", " .join (output_names ))
188+ spec_str += " " + method_body . replace ( " \n " , " \n " ) + "\n "
189+ spec_str += "\n return {}" .format (", " .join (output_names ))
171190
172- spec_str += " class Outputs(python.Outputs):"
173- for outpt in output_fields :
174- name , type_ , _ = outpt
175- spec_str += f" { name } : { type_ } \n "
176191
177192 for m in sorted (self .used .methods , key = attrgetter ("__name__" )):
178193 if m .__name__ not in self .included_methods :
0 commit comments