|
19 | 19 | from _pytask.console import render_to_string |
20 | 20 | from _pytask.exceptions import ResolvingDependenciesError |
21 | 21 | from _pytask.mark import select_by_after_keyword |
| 22 | +from _pytask.mark import select_tasks_by_marks_and_expressions |
22 | 23 | from _pytask.node_protocols import PNode |
23 | 24 | from _pytask.node_protocols import PTask |
24 | 25 | from _pytask.nodes import PythonNode |
25 | | -from _pytask.pluginmanager import hookimpl |
26 | 26 | from _pytask.reports import DagReport |
27 | 27 | from _pytask.shared import reduce_names_of_multiple_nodes |
28 | 28 | from _pytask.tree_util import tree_map |
|
33 | 33 | from _pytask.session import Session |
34 | 34 |
|
35 | 35 |
|
36 | | -@hookimpl |
37 | | -def pytask_dag(session: Session) -> bool | None: |
| 36 | +__all__ = ["create_dag"] |
| 37 | + |
| 38 | + |
| 39 | +def create_dag(session: Session) -> nx.DiGraph: |
38 | 40 | """Create a directed acyclic graph (DAG) for the workflow.""" |
39 | 41 | try: |
40 | | - session.dag = session.hook.pytask_dag_create_dag( |
41 | | - session=session, tasks=session.tasks |
42 | | - ) |
43 | | - session.hook.pytask_dag_modify_dag(session=session, dag=session.dag) |
| 42 | + dag = _create_dag(tasks=session.tasks) |
| 43 | + _check_if_dag_has_cycles(dag) |
| 44 | + _check_if_tasks_have_the_same_products(dag, session.config["paths"]) |
| 45 | + _modify_dag(session=session, dag=dag) |
| 46 | + select_tasks_by_marks_and_expressions(session=session, dag=dag) |
44 | 47 |
|
45 | 48 | except Exception: # noqa: BLE001 |
46 | 49 | report = DagReport.from_exception(sys.exc_info()) |
47 | | - session.hook.pytask_dag_log(session=session, report=report) |
| 50 | + _log_dag(report=report) |
48 | 51 | session.dag_report = report |
49 | 52 |
|
50 | 53 | raise ResolvingDependenciesError from None |
51 | | - |
52 | | - else: |
53 | | - return True |
| 54 | + return dag |
54 | 55 |
|
55 | 56 |
|
56 | | -@hookimpl |
57 | | -def pytask_dag_create_dag(session: Session, tasks: list[PTask]) -> nx.DiGraph: |
| 57 | +def _create_dag(tasks: list[PTask]) -> nx.DiGraph: |
58 | 58 | """Create the DAG from tasks, dependencies and products.""" |
59 | 59 |
|
60 | 60 | def _add_dependency(dag: nx.DiGraph, task: PTask, node: PNode) -> None: |
@@ -90,15 +90,10 @@ def _add_product(dag: nx.DiGraph, task: PTask, node: PNode) -> None: |
90 | 90 | else None, |
91 | 91 | task.depends_on, |
92 | 92 | ) |
93 | | - |
94 | | - _check_if_dag_has_cycles(dag) |
95 | | - _check_if_tasks_have_the_same_products(dag, session.config["paths"]) |
96 | | - |
97 | 93 | return dag |
98 | 94 |
|
99 | 95 |
|
100 | | -@hookimpl |
101 | | -def pytask_dag_modify_dag(session: Session, dag: nx.DiGraph) -> None: |
| 96 | +def _modify_dag(session: Session, dag: nx.DiGraph) -> None: |
102 | 97 | """Create dependencies between tasks when using ``@task(after=...)``.""" |
103 | 98 | temporary_id_to_task = { |
104 | 99 | task.attributes["collection_id"]: task |
@@ -194,8 +189,7 @@ def _check_if_tasks_have_the_same_products(dag: nx.DiGraph, paths: list[Path]) - |
194 | 189 | raise ResolvingDependenciesError(msg) |
195 | 190 |
|
196 | 191 |
|
197 | | -@hookimpl |
198 | | -def pytask_dag_log(report: DagReport) -> None: |
| 192 | +def _log_dag(report: DagReport) -> None: |
199 | 193 | """Log errors which happened while resolving dependencies.""" |
200 | 194 | console.print() |
201 | 195 | console.rule( |
|
0 commit comments