diff --git a/docs/changes.rst b/docs/changes.rst index 936cba33..d17bcad0 100644 --- a/docs/changes.rst +++ b/docs/changes.rst @@ -21,6 +21,9 @@ all releases are available on `Anaconda.org - :gh:`39` releases v0.0.9. - :gh:`40` cleans up the capture manager and other parts of pytask. - :gh:`41` shortens the task ids in the error reports for better readability. +- :gh:`42` ensures that lists with one element and dictionaries with only a zero key as + input for ``@pytask.mark.depends_on`` and ``@pytask.mark.produces`` are preserved as a + dictionary inside the function. 0.0.8 - 2020-10-04 diff --git a/src/_pytask/nodes.py b/src/_pytask/nodes.py index 3c910462..fbb075cd 100644 --- a/src/_pytask/nodes.py +++ b/src/_pytask/nodes.py @@ -7,8 +7,10 @@ from abc import abstractmethod from pathlib import Path from typing import Any +from typing import Dict from typing import Iterable from typing import List +from typing import Tuple from typing import Union import attr @@ -82,17 +84,22 @@ class PythonFunctionTask(MetaTask): """List[MetaNode]: A list of products of task.""" markers = attr.ib(factory=list) """Optional[List[Mark]]: A list of markers attached to the task function.""" + keep_dict = attr.ib(factory=dict) _report_sections = attr.ib(factory=list) @classmethod def from_path_name_function_session(cls, path, name, function, session): """Create a task from a path, name, function, and session.""" + keep_dictionary = {} + objects = _extract_nodes_from_function_markers(function, depends_on) - nodes = _convert_objects_to_node_dictionary(objects, "depends_on") + nodes, keep_dict = _convert_objects_to_node_dictionary(objects, "depends_on") + keep_dictionary["depends_on"] = keep_dict dependencies = _collect_nodes(session, path, name, nodes) objects = _extract_nodes_from_function_markers(function, produces) - nodes = _convert_objects_to_node_dictionary(objects, "produces") + nodes, keep_dict = _convert_objects_to_node_dictionary(objects, "produces") + keep_dictionary["produces"] = keep_dict products = _collect_nodes(session, path, name, nodes) markers = [ @@ -109,6 +116,7 @@ def from_path_name_function_session(cls, path, name, function, session): depends_on=dependencies, produces=products, markers=markers, + keep_dict=keep_dictionary, ) def execute(self): @@ -124,15 +132,15 @@ def _get_kwargs_from_task_for_function(self): """Process dependencies and products to pass them as kwargs to the function.""" func_arg_names = set(inspect.signature(self.function).parameters) kwargs = {} - for name in ["depends_on", "produces"]: - if name in func_arg_names: - attribute = getattr(self, name) - kwargs[name] = ( + for arg_name in ["depends_on", "produces"]: + if arg_name in func_arg_names: + attribute = getattr(self, arg_name) + kwargs[arg_name] = ( attribute[0].value - if len(attribute) == 1 and 0 in attribute - else { - node_name: node.value for node_name, node in attribute.items() - } + if len(attribute) == 1 + and 0 in attribute + and not self.keep_dict[arg_name] + else {name: node.value for name, node in attribute.items()} ) return kwargs @@ -208,32 +216,49 @@ def _extract_nodes_from_function_markers(function, parser): def _convert_objects_to_node_dictionary(objects, when): - list_of_tuples = _convert_objects_to_list_of_tuples(objects) + """Convert objects to node dictionary.""" + list_of_tuples, keep_dict = _convert_objects_to_list_of_tuples(objects) _check_that_names_are_not_used_multiple_times(list_of_tuples, when) nodes = _convert_nodes_to_dictionary(list_of_tuples) - return nodes + return nodes, keep_dict def _convert_objects_to_list_of_tuples(objects): + """Convert objects to list of tuples. + + Examples + -------- + _convert_objects_to_list_of_tuples([{0: 0}, [4, (3, 2)], ((1, 4),)) + [(0, 0), (4,), (3, 2), (1, 4)], False + + """ + keep_dict = False + out = [] for obj in objects: if isinstance(obj, dict): obj = obj.items() if isinstance(obj, Iterable) and not isinstance(obj, str): + keep_dict = True for x in obj: if isinstance(x, Iterable) and not isinstance(x, str): tuple_x = tuple(x) if len(tuple_x) in [1, 2]: out.append(tuple_x) else: - raise ValueError("ERROR") + raise ValueError( + f"Element {x} can only have two elements at most." + ) else: out.append((x,)) else: out.append((obj,)) - return out + if len(out) > 1: + keep_dict = False + + return out, keep_dict def _check_that_names_are_not_used_multiple_times(list_of_tuples, when): @@ -263,7 +288,19 @@ def _check_that_names_are_not_used_multiple_times(list_of_tuples, when): ) -def _convert_nodes_to_dictionary(list_of_tuples): +def _convert_nodes_to_dictionary( + list_of_tuples: List[Tuple[str]], +) -> Dict[str, Union[str, Path]]: + """Convert nodes to dictionaries. + + Examples + -------- + >>> _convert_nodes_to_dictionary([(0,), (1,)]) + {0: 0, 1: 1} + >>> _convert_nodes_to_dictionary([(1, 0), (1,)]) + {1: 0, 0: 1} + + """ nodes = {} counter = itertools.count() names = [x[0] for x in list_of_tuples if len(x) == 2] diff --git a/tests/test_execute.py b/tests/test_execute.py index cbc53661..d90b337c 100644 --- a/tests/test_execute.py +++ b/tests/test_execute.py @@ -169,3 +169,32 @@ def task_dummy(depends_on, produces): result = runner.invoke(cli, [tmp_path.as_posix()]) assert result.exit_code == 0 + + +@pytest.mark.parametrize("input_type", ["list", "dict"]) +def test_preserve_input_for_dependencies_and_products(tmp_path, input_type): + """Input type for dependencies and products is preserved.""" + path = tmp_path.joinpath("in.txt") + input_ = {0: path.as_posix()} if input_type == "dict" else [path.as_posix()] + path.touch() + + path = tmp_path.joinpath("out.txt") + output = {0: path.as_posix()} if input_type == "dict" else [path.as_posix()] + + source = f""" + import pytask + from pathlib import Path + + @pytask.mark.depends_on({input_}) + @pytask.mark.produces({output}) + def task_dummy(depends_on, produces): + for nodes in [depends_on, produces]: + assert isinstance(nodes, dict) + assert len(nodes) == 1 + assert 0 in nodes + produces[0].touch() + """ + tmp_path.joinpath("task_dummy.py").write_text(textwrap.dedent(source)) + + session = main({"paths": tmp_path}) + assert session.exit_code == 0 diff --git a/tests/test_nodes.py b/tests/test_nodes.py index 39d5373d..99abc857 100644 --- a/tests/test_nodes.py +++ b/tests/test_nodes.py @@ -7,6 +7,7 @@ from _pytask.nodes import _check_that_names_are_not_used_multiple_times from _pytask.nodes import _convert_nodes_to_dictionary from _pytask.nodes import _convert_objects_to_list_of_tuples +from _pytask.nodes import _convert_objects_to_node_dictionary from _pytask.nodes import _create_task_name from _pytask.nodes import _extract_nodes_from_function_markers from _pytask.nodes import _find_closest_ancestor @@ -108,21 +109,25 @@ def state(self): @pytest.mark.unit @pytest.mark.parametrize( - ("x", "expected"), + ("x", "expected_lot", "expected_kd"), [ - (["string"], [("string",)]), - (("string",), [("string",)]), - (range(2), [(0,), (1,)]), - ([{"a": 0, "b": 1}], [("a", 0), ("b", 1)]), + (["string"], [("string",)], False), + (("string",), [("string",)], False), + (range(2), [(0,), (1,)], False), + ([{"a": 0, "b": 1}], [("a", 0), ("b", 1)], False), ( ["a", ("b", "c"), {"d": 1, "e": 1}], [("a",), ("b",), ("c",), ("d", 1), ("e", 1)], + False, ), + ([["string"]], [("string",)], True), + ([{0: "string"}], [(0, "string")], True), ], ) -def test_convert_objects_to_list_of_tuples(x, expected): - result = _convert_objects_to_list_of_tuples(x) - assert result == expected +def test_convert_objects_to_list_of_tuples(x, expected_lot, expected_kd): + list_of_tuples, keep_dict = _convert_objects_to_list_of_tuples(x) + assert list_of_tuples == expected_lot + assert keep_dict is expected_kd ERROR = "'@pytask.mark.depends_on' has nodes with the same name:" @@ -253,3 +258,30 @@ def test_shorten_node_name(node, paths, expectation, expected): with expectation: result = shorten_node_name(node, paths) assert result == expected + + +@pytest.mark.integration +@pytest.mark.parametrize("when", ["depends_on", "produces"]) +@pytest.mark.parametrize( + "objects, expectation, expected_dict, expected_kd", + [ + ([0, 1], does_not_raise, {0: 0, 1: 1}, False), + ([{0: 0}, {1: 1}], does_not_raise, {0: 0, 1: 1}, False), + ([{0: 0}], does_not_raise, {0: 0}, True), + ([[0]], does_not_raise, {0: 0}, True), + ([((0, 0),), ((0, 1),)], ValueError, None, None), + ([{0: 0}, {0: 1}], ValueError, None, None), + ], +) +def test_convert_objects_to_node_dictionary( + objects, when, expectation, expected_dict, expected_kd +): + expectation = ( + pytest.raises(expectation, match=f"'@pytask.mark.{when}' has nodes") + if expectation == ValueError + else expectation() + ) + with expectation: + node_dict, keep_dict = _convert_objects_to_node_dictionary(objects, when) + assert node_dict == expected_dict + assert keep_dict is expected_kd