From 0a964bd91f0de164589dccf7b92adde46e5f9191 Mon Sep 17 00:00:00 2001 From: simonwardjones Date: Tue, 17 Sep 2024 10:45:11 +0100 Subject: [PATCH 1/9] feat: :sparkles: Change the Graph creation to allow subclassing --- metaflow/graph.py | 26 ++++++++------------------ 1 file changed, 8 insertions(+), 18 deletions(-) diff --git a/metaflow/graph.py b/metaflow/graph.py index 4727c1bbad6..3d5a4e45162 100644 --- a/metaflow/graph.py +++ b/metaflow/graph.py @@ -1,6 +1,7 @@ import inspect import ast import re +import textwrap from .util import to_pod @@ -156,18 +157,6 @@ def __str__(self): ) -class StepVisitor(ast.NodeVisitor): - def __init__(self, nodes, flow): - self.nodes = nodes - self.flow = flow - super(StepVisitor, self).__init__() - - def visit_FunctionDef(self, node): - func = getattr(self.flow, node.name) - if hasattr(func, "is_step"): - self.nodes[node.name] = DAGNode(node, func.decorators, func.__doc__) - - class FlowGraph(object): def __init__(self, flow): self.name = flow.__name__ @@ -179,13 +168,14 @@ def __init__(self, flow): self._postprocess() def _create_nodes(self, flow): - module = __import__(flow.__module__) - tree = ast.parse(inspect.getsource(module)).body - root = [n for n in tree if isinstance(n, ast.ClassDef) and n.name == self.name][ - 0 - ] nodes = {} - StepVisitor(nodes, flow).visit(root) + for element in dir(flow): + func = getattr(flow, element) + if hasattr(func, "is_step"): + source_code = textwrap.dedent(inspect.getsource(func)) + function_ast = ast.parse(source_code).body[0] + node = DAGNode(function_ast, func.decorators, func.__doc__) + nodes[element] = node return nodes def _postprocess(self): From f0e60d98039bc6954352c794a6eeb4fb65a32e8f Mon Sep 17 00:00:00 2001 From: simonwardjones Date: Thu, 24 Oct 2024 13:25:06 +0100 Subject: [PATCH 2/9] fix: :bug: Fix the lineno for the function step --- metaflow/graph.py | 21 ++++++++++++++------- 1 file changed, 14 insertions(+), 7 deletions(-) diff --git a/metaflow/graph.py b/metaflow/graph.py index 3d5a4e45162..fde12f7e497 100644 --- a/metaflow/graph.py +++ b/metaflow/graph.py @@ -46,9 +46,12 @@ def deindent_docstring(doc): class DAGNode(object): - def __init__(self, func_ast, decos, doc): + def __init__(self, func_ast, decos, doc, source_file, lineno): self.name = func_ast.name - self.func_lineno = func_ast.lineno + self.source_file = source_file + # lineno is the start line of decorators in source_file + # func_ast.lineno is lines from decorators start to def of function + self.func_lineno = lineno + func_ast.lineno - 1 self.decorators = decos self.doc = deindent_docstring(doc) self.parallel_step = any(getattr(deco, "IS_PARALLEL", False) for deco in decos) @@ -63,7 +66,7 @@ def __init__(self, func_ast, decos, doc): self.foreach_param = None self.num_parallel = 0 self.parallel_foreach = False - self._parse(func_ast) + self._parse(func_ast, lineno) # these attributes are populated by _traverse_graph self.in_funcs = set() @@ -75,7 +78,7 @@ def __init__(self, func_ast, decos, doc): def _expr_str(self, expr): return "%s.%s" % (expr.value.id, expr.attr) - def _parse(self, func_ast): + def _parse(self, func_ast, lineno): self.num_args = len(func_ast.args.args) tail = func_ast.body[-1] @@ -95,7 +98,7 @@ def _parse(self, func_ast): self.has_tail_next = True self.invalid_tail_next = True - self.tail_next_lineno = tail.lineno + self.tail_next_lineno = lineno + tail.lineno - 1 self.out_funcs = [e.attr for e in tail.value.args] keywords = dict( @@ -172,9 +175,13 @@ def _create_nodes(self, flow): for element in dir(flow): func = getattr(flow, element) if hasattr(func, "is_step"): - source_code = textwrap.dedent(inspect.getsource(func)) + source_file = inspect.getsourcefile(func) + source_lines, lineno = inspect.getsourcelines(func) + source_code = textwrap.dedent("".join(source_lines)) function_ast = ast.parse(source_code).body[0] - node = DAGNode(function_ast, func.decorators, func.__doc__) + node = DAGNode( + function_ast, func.decorators, func.__doc__, source_file, lineno + ) nodes[element] = node return nodes From 95ef1d4954930f0570af5772627e1854553304b4 Mon Sep 17 00:00:00 2001 From: simonwardjones Date: Thu, 24 Oct 2024 14:13:46 +0100 Subject: [PATCH 3/9] fix: :bug: Use sorted nodes in graph --- metaflow/cli.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/metaflow/cli.py b/metaflow/cli.py index 64800b189b7..51c13c2ceca 100644 --- a/metaflow/cli.py +++ b/metaflow/cli.py @@ -170,7 +170,7 @@ def check(obj, warnings=False): @click.pass_obj def show(obj): echo_always("\n%s" % obj.graph.doc) - for _, node in sorted((n.func_lineno, n) for n in obj.graph): + for node in obj.graph.sorted_nodes: echo_always("\nStep *%s*" % node.name, err=False) echo_always(node.doc if node.doc else "?", indent=True, err=False) if node.type != "end": From 9315e081cce27ac10d37bb7494523e8ac711de10 Mon Sep 17 00:00:00 2001 From: simonwardjones Date: Thu, 24 Oct 2024 14:14:10 +0100 Subject: [PATCH 4/9] docs: :memo: Add source_file to the __str__ --- metaflow/graph.py | 1 + 1 file changed, 1 insertion(+) diff --git a/metaflow/graph.py b/metaflow/graph.py index fde12f7e497..34fee84e365 100644 --- a/metaflow/graph.py +++ b/metaflow/graph.py @@ -136,6 +136,7 @@ def _parse(self, func_ast, lineno): def __str__(self): return """*[{0.name} {0.type} (line {0.func_lineno})]* + source_file={0.source_file} in_funcs={in_funcs} out_funcs={out_funcs} split_parents={parents} From 49b690432ad4f5b80968be39f18df7e78f416fd1 Mon Sep 17 00:00:00 2001 From: simonwardjones Date: Fri, 25 Oct 2024 16:25:28 +0100 Subject: [PATCH 5/9] feat: :sparkles: Show the file when printing exceptions anf fix node variable --- metaflow/cli.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/metaflow/cli.py b/metaflow/cli.py index 51c13c2ceca..832eb3e4f4a 100644 --- a/metaflow/cli.py +++ b/metaflow/cli.py @@ -170,7 +170,8 @@ def check(obj, warnings=False): @click.pass_obj def show(obj): echo_always("\n%s" % obj.graph.doc) - for node in obj.graph.sorted_nodes: + for node_name in obj.graph.sorted_nodes: + node = obj.graph[node_name] echo_always("\nStep *%s*" % node.name, err=False) echo_always(node.doc if node.doc else "?", indent=True, err=False) if node.type != "end": @@ -1104,10 +1105,13 @@ def _check(graph, flow, environment, pylint=True, warnings=False, **kwargs): def print_metaflow_exception(ex): echo_always(ex.headline, indent=True, nl=False, bold=True) - if ex.line_no is None: - echo_always(":") - else: - echo_always(" on line %d:" % ex.line_no, bold=True) + location = "" + if ex.source_file is not None: + location += " in file %s" % ex.source_file + if ex.line_no is not None: + location += " on line %d" % ex.line_no + location += ":" + echo_always(location, bold=True) echo_always(ex.message, indent=True, bold=False, padding_bottom=True) From 82571bcbe28bd45ac77fce14674d401bcc3ef2cf Mon Sep 17 00:00:00 2001 From: simonwardjones Date: Fri, 25 Oct 2024 16:25:49 +0100 Subject: [PATCH 6/9] feat: :sparkles: Add source_file to the exception base class --- metaflow/exception.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/metaflow/exception.py b/metaflow/exception.py index 4b85a2d1b39..f0f021a2a05 100644 --- a/metaflow/exception.py +++ b/metaflow/exception.py @@ -43,13 +43,19 @@ def __str__(self): class MetaflowException(Exception): headline = "Flow failed" - def __init__(self, msg="", lineno=None): + def __init__(self, msg="", lineno=None, source_file=None): self.message = msg self.line_no = lineno + self.source_file = source_file super(MetaflowException, self).__init__() def __str__(self): - prefix = "line %d: " % self.line_no if self.line_no else "" + prefix = "" + if self.source_file: + prefix = "%s:" % self.source_file + if self.line_no: + prefix = "line %d:" % self.line_no + prefix = "%s: " % prefix if prefix else "" return "%s%s" % (prefix, self.message) From c48ad2853e40230f60b49e8e036d694c2394c057 Mon Sep 17 00:00:00 2001 From: simonwardjones Date: Fri, 25 Oct 2024 16:26:23 +0100 Subject: [PATCH 7/9] feat: :sparkles: Use sorted_nodes in graph __str__ --- metaflow/graph.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/metaflow/graph.py b/metaflow/graph.py index 34fee84e365..ed5f87bb52f 100644 --- a/metaflow/graph.py +++ b/metaflow/graph.py @@ -238,9 +238,7 @@ def __iter__(self): return iter(self.nodes.values()) def __str__(self): - return "\n".join( - str(n) for _, n in sorted((n.func_lineno, n) for n in self.nodes.values()) - ) + return "\n".join(str(self[n]) for n in self.sorted_nodes) def output_dot(self): def edge_specs(): @@ -284,6 +282,7 @@ def node_to_dict(name, node): "name": name, "type": node_to_type(node), "line": node.func_lineno, + "source_file": node.source_file, "doc": node.doc, "decorators": [ { From 3b1c1eb7c0bb33032d02c2fa6eb28361830de41a Mon Sep 17 00:00:00 2001 From: simonwardjones Date: Fri, 25 Oct 2024 16:26:54 +0100 Subject: [PATCH 8/9] feat: :sparkles: Update Linting to include source_file in exceptions raised --- metaflow/lint.py | 55 ++++++++++++++++++++++++++++++------------------ 1 file changed, 35 insertions(+), 20 deletions(-) diff --git a/metaflow/lint.py b/metaflow/lint.py index 9b96c271506..ecca1605405 100644 --- a/metaflow/lint.py +++ b/metaflow/lint.py @@ -52,7 +52,7 @@ def check_reserved_words(graph): msg = "Step name *%s* is a reserved word. Choose another name for the " "step." for node in graph: if node.name in RESERVED: - raise LintWarn(msg % node.name) + raise LintWarn(msg % node.name, node.func_lineno, node.source_file) @linter.ensure_fundamentals @@ -76,9 +76,9 @@ def check_that_end_is_end(graph): node = graph["end"] if node.has_tail_next or node.invalid_tail_next: - raise LintWarn(msg0, node.tail_next_lineno) + raise LintWarn(msg0, node.tail_next_lineno, node.source_file) if node.num_args > 1: - raise LintWarn(msg1, node.tail_next_lineno) + raise LintWarn(msg1, node.tail_next_lineno, node.source_file) @linter.ensure_fundamentals @@ -90,7 +90,7 @@ def check_step_names(graph): ) for node in graph: if re.search("[^a-z0-9_]", node.name) or node.name[0] == "_": - raise LintWarn(msg.format(node), node.func_lineno) + raise LintWarn(msg.format(node), node.func_lineno, node.source_file) @linter.ensure_fundamentals @@ -108,11 +108,11 @@ def check_num_args(graph): msg2 = "Step *{0.name}* is missing the 'self' argument." for node in graph: if node.num_args > 2: - raise LintWarn(msg0.format(node), node.func_lineno) + raise LintWarn(msg0.format(node), node.func_lineno, node.source_file) elif node.num_args == 2 and node.type != "join": - raise LintWarn(msg1.format(node), node.func_lineno) + raise LintWarn(msg1.format(node), node.func_lineno, node.source_file) elif node.num_args == 0: - raise LintWarn(msg2.format(node), node.func_lineno) + raise LintWarn(msg2.format(node), node.func_lineno, node.source_file) @linter.ensure_static_graph @@ -125,7 +125,7 @@ def check_static_transitions(graph): ) for node in graph: if node.type != "end" and not node.has_tail_next: - raise LintWarn(msg.format(node), node.func_lineno) + raise LintWarn(msg.format(node), node.func_lineno, node.source_file) @linter.ensure_static_graph @@ -138,7 +138,7 @@ def check_valid_transitions(graph): ) for node in graph: if node.type != "end" and node.has_tail_next and node.invalid_tail_next: - raise LintWarn(msg.format(node), node.tail_next_lineno) + raise LintWarn(msg.format(node), node.tail_next_lineno, node.source_file) @linter.ensure_static_graph @@ -151,7 +151,11 @@ def check_unknown_transitions(graph): for node in graph: unknown = [n for n in node.out_funcs if n not in graph] if unknown: - raise LintWarn(msg.format(node, step=unknown[0]), node.tail_next_lineno) + raise LintWarn( + msg.format(node, step=unknown[0]), + node.tail_next_lineno, + node.source_file, + ) @linter.ensure_acyclicity @@ -167,7 +171,9 @@ def check_path(node, seen): for n in node.out_funcs: if n in seen: path = "->".join(seen + [n]) - raise LintWarn(msg.format(path), node.tail_next_lineno) + raise LintWarn( + msg.format(path), node.tail_next_lineno, node.source_file + ) else: check_path(graph[n], seen + [n]) @@ -195,7 +201,7 @@ def traverse(node): orphans = nodeset - seen if orphans: orphan = graph[list(orphans)[0]] - raise LintWarn(msg.format(orphan), orphan.func_lineno) + raise LintWarn(msg.format(orphan), orphan.func_lineno, orphan.source_file) @linter.ensure_static_graph @@ -230,7 +236,9 @@ def traverse(node, split_stack): if split_stack: _, split_roots = split_stack.pop() roots = ", ".join(split_roots) - raise LintWarn(msg0.format(roots=roots)) + raise LintWarn( + msg0.format(roots=roots), node.func_lineno, node.source_file + ) elif node.type == "join": if split_stack: _, split_roots = split_stack[-1] @@ -243,9 +251,10 @@ def traverse(node, split_stack): node, paths=paths, num_roots=len(split_roots), roots=roots ), node.func_lineno, + node.source_file, ) else: - raise LintWarn(msg2.format(node), node.func_lineno) + raise LintWarn(msg2.format(node), node.func_lineno, node.source_file) # check that incoming steps come from the same lineage # (no cross joins) @@ -256,7 +265,7 @@ def parents(n): return tuple(graph[n].split_parents) if not all_equal(map(parents, node.in_funcs)): - raise LintWarn(msg3.format(node), node.func_lineno) + raise LintWarn(msg3.format(node), node.func_lineno, node.source_file) for n in node.out_funcs: traverse(graph[n], new_stack) @@ -276,7 +285,9 @@ def check_empty_foreaches(graph): if node.type == "foreach": joins = [n for n in node.out_funcs if graph[n].type == "join"] if joins: - raise LintWarn(msg.format(node, join=joins[0])) + raise LintWarn( + msg.format(node, join=joins[0]), node.func_lineno, node.source_file + ) @linter.ensure_static_graph @@ -290,7 +301,7 @@ def check_parallel_step_after_next(graph): if node.parallel_foreach and not all( graph[out_node].parallel_step for out_node in node.out_funcs ): - raise LintWarn(msg.format(node)) + raise LintWarn(msg.format(node), node.func_lineno, node.source_file) @linter.ensure_static_graph @@ -303,7 +314,9 @@ def check_join_followed_by_parallel_step(graph): ) for node in graph: if node.parallel_step and not graph[node.out_funcs[0]].type == "join": - raise LintWarn(msg.format(node.out_funcs[0])) + raise LintWarn( + msg.format(node.out_funcs[0]), node.func_lineno, node.source_file + ) @linter.ensure_static_graph @@ -318,7 +331,9 @@ def check_parallel_foreach_calls_parallel_step(graph): for node2 in graph: if node2.out_funcs and node.name in node2.out_funcs: if not node2.parallel_foreach: - raise LintWarn(msg.format(node, node2)) + raise LintWarn( + msg.format(node, node2), node.func_lineno, node.source_file + ) @linter.ensure_non_nested_foreach @@ -331,4 +346,4 @@ def check_nested_foreach(graph): for node in graph: if node.type == "foreach": if any(graph[p].type == "foreach" for p in node.split_parents): - raise LintWarn(msg.format(node)) + raise LintWarn(msg.format(node), node.func_lineno, node.source_file) From bb83d0ca0a41341e1ee10c92c8cb0975370f25c5 Mon Sep 17 00:00:00 2001 From: Romain Date: Tue, 17 Dec 2024 00:33:00 -0800 Subject: [PATCH 9/9] Update metaflow/graph.py Make compatible with configs. --- metaflow/graph.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/metaflow/graph.py b/metaflow/graph.py index ed5f87bb52f..16426006960 100644 --- a/metaflow/graph.py +++ b/metaflow/graph.py @@ -175,7 +175,7 @@ def _create_nodes(self, flow): nodes = {} for element in dir(flow): func = getattr(flow, element) - if hasattr(func, "is_step"): + if callable(func) and hasattr(func, "is_step"): source_file = inspect.getsourcefile(func) source_lines, lineno = inspect.getsourcelines(func) source_code = textwrap.dedent("".join(source_lines))