11"""Deals with nodes which are dependencies or products of a task."""
22import functools
33import inspect
4+ import itertools
45import pathlib
56from abc import ABCMeta
67from abc import abstractmethod
1314from _pytask .exceptions import NodeNotCollectedError
1415from _pytask .exceptions import NodeNotFoundError
1516from _pytask .mark import get_marks_from_obj
16- from _pytask .shared import to_list
17+ from _pytask .shared import find_duplicates
1718
1819
1920def depends_on (objects : Union [Any , Iterable [Any ]]) -> Union [Any , Iterable [Any ]]:
@@ -68,22 +69,24 @@ class PythonFunctionTask(MetaTask):
6869 """pathlib.Path: Path to the file where the task was defined."""
6970 function = attr .ib (type = callable )
7071 """callable: The task function."""
71- depends_on = attr .ib (converter = to_list )
72+ depends_on = attr .ib (factory = dict )
7273 """Optional[List[MetaNode]]: A list of dependencies of task."""
73- produces = attr .ib (converter = to_list )
74+ produces = attr .ib (factory = dict )
7475 """List[MetaNode]: A list of products of task."""
75- markers = attr .ib ()
76+ markers = attr .ib (factory = list )
7677 """Optional[List[Mark]]: A list of markers attached to the task function."""
7778 _report_sections = attr .ib (factory = list )
7879
7980 @classmethod
8081 def from_path_name_function_session (cls , path , name , function , session ):
8182 """Create a task from a path, name, function, and session."""
8283 objects = _extract_nodes_from_function_markers (function , depends_on )
83- dependencies = _collect_nodes (session , path , name , objects )
84+ nodes = _convert_objects_to_node_dictionary (objects , "depends_on" )
85+ dependencies = _collect_nodes (session , path , name , nodes )
8486
8587 objects = _extract_nodes_from_function_markers (function , produces )
86- products = _collect_nodes (session , path , name , objects )
88+ nodes = _convert_objects_to_node_dictionary (objects , "produces" )
89+ products = _collect_nodes (session , path , name , nodes )
8790
8891 markers = [
8992 marker
@@ -118,8 +121,10 @@ def _get_kwargs_from_task_for_function(self):
118121 attribute = getattr (self , name )
119122 kwargs [name ] = (
120123 attribute [0 ].value
121- if len (attribute ) == 1
122- else [node .value for node in attribute ]
124+ if len (attribute ) == 1 and 0 in attribute
125+ else {
126+ node_name : node .value for node_name , node in attribute .items ()
127+ }
123128 )
124129
125130 return kwargs
@@ -169,8 +174,9 @@ def state(self):
169174
170175def _collect_nodes (session , path , name , nodes ):
171176 """Collect nodes for a task."""
172- collect_nodes = []
173- for node in nodes :
177+ collected_nodes = {}
178+
179+ for node_name , node in nodes .items ():
174180 collected_node = session .hook .pytask_collect_node (
175181 session = session , path = path , node = node
176182 )
@@ -180,9 +186,9 @@ def _collect_nodes(session, path, name, nodes):
180186 f"'{ name } ' in '{ path } '."
181187 )
182188 else :
183- collect_nodes . append ( collected_node )
189+ collected_nodes [ node_name ] = collected_node
184190
185- return collect_nodes
191+ return collected_nodes
186192
187193
188194def _extract_nodes_from_function_markers (function , parser ):
@@ -195,4 +201,82 @@ def _extract_nodes_from_function_markers(function, parser):
195201 """
196202 marker_name = parser .__name__
197203 for marker in get_marks_from_obj (function , marker_name ):
198- yield from to_list (parser (* marker .args , ** marker .kwargs ))
204+ parsed = parser (* marker .args , ** marker .kwargs )
205+ yield parsed
206+
207+
208+ def _convert_objects_to_node_dictionary (objects , when ):
209+ list_of_tuples = _convert_objects_to_list_of_tuples (objects )
210+ _check_that_names_are_not_used_multiple_times (list_of_tuples , when )
211+ nodes = _convert_nodes_to_dictionary (list_of_tuples )
212+ return nodes
213+
214+
215+ def _convert_objects_to_list_of_tuples (objects ):
216+ out = []
217+ for obj in objects :
218+ if isinstance (obj , dict ):
219+ obj = obj .items ()
220+
221+ if isinstance (obj , Iterable ) and not isinstance (obj , str ):
222+ for x in obj :
223+ if isinstance (x , Iterable ) and not isinstance (x , str ):
224+ tuple_x = tuple (x )
225+ if len (tuple_x ) in [1 , 2 ]:
226+ out .append (tuple_x )
227+ else :
228+ raise ValueError ("ERROR" )
229+ else :
230+ out .append ((x ,))
231+ else :
232+ out .append ((obj ,))
233+
234+ return out
235+
236+
237+ def _check_that_names_are_not_used_multiple_times (list_of_tuples , when ):
238+ """Check that names of nodes are not assigned multiple times.
239+
240+ Tuples in the list have either one or two elements. The first element in the two
241+ element tuples is the name and cannot occur twice.
242+
243+ Examples
244+ --------
245+ >>> _check_that_names_are_not_used_multiple_times(
246+ ... [("a",), ("a", 1)], "depends_on"
247+ ... )
248+ >>> _check_that_names_are_not_used_multiple_times(
249+ ... [("a", 0), ("a", 1)], "produces"
250+ ... )
251+ Traceback (most recent call last):
252+ ValueError: '@pytask.mark.produces' has nodes with the same name: {'a'}
253+
254+ """
255+ names = [x [0 ] for x in list_of_tuples if len (x ) == 2 ]
256+ duplicated = find_duplicates (names )
257+
258+ if duplicated :
259+ raise ValueError (
260+ f"'@pytask.mark.{ when } ' has nodes with the same name: { duplicated } "
261+ )
262+
263+
264+ def _convert_nodes_to_dictionary (list_of_tuples ):
265+ nodes = {}
266+ counter = itertools .count ()
267+ names = [x [0 ] for x in list_of_tuples if len (x ) == 2 ]
268+
269+ for tuple_ in list_of_tuples :
270+ if len (tuple_ ) == 2 :
271+ node_name , node = tuple_
272+ nodes [node_name ] = node
273+
274+ else :
275+ while True :
276+ node_name = next (counter )
277+ if node_name not in names :
278+ break
279+
280+ nodes [node_name ] = tuple_ [0 ]
281+
282+ return nodes
0 commit comments