Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: ✨ Change the Graph creation to allow subclassing #2086

Open
wants to merge 11 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 9 additions & 5 deletions metaflow/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,8 @@ def check(obj, warnings=False):
@click.pass_obj
def show(obj):
echo_always("\n%s" % obj.graph.doc)
for _, node in sorted((n.func_lineno, n) for n in obj.graph):
for node_name in obj.graph.sorted_nodes:
node = obj.graph[node_name]
echo_always("\nStep *%s*" % node.name, err=False)
echo_always(node.doc if node.doc else "?", indent=True, err=False)
if node.type != "end":
Expand Down Expand Up @@ -1132,10 +1133,13 @@ def _check(graph, flow, environment, pylint=True, warnings=False, **kwargs):

def print_metaflow_exception(ex):
    """Print a Metaflow exception as a bold headline, its location, and its message.

    The headline is echoed without a trailing newline so that the location
    suffix (" in file <source_file>" and/or " on line <line_no>", followed by
    a colon) continues on the same line. The message follows, indented.

    `ex` is expected to expose `headline`, `message`, `line_no` and
    `source_file` attributes.
    """
    echo_always(ex.headline, indent=True, nl=False, bold=True)

    # Build the location suffix once; either part may be absent.
    location = ""
    if ex.source_file is not None:
        location += " in file %s" % ex.source_file
    if ex.line_no is not None:
        location += " on line %d" % ex.line_no
    location += ":"

    echo_always(location, bold=True)
    echo_always(ex.message, indent=True, bold=False, padding_bottom=True)


Expand Down
10 changes: 8 additions & 2 deletions metaflow/exception.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,13 +43,19 @@ def __str__(self):
class MetaflowException(Exception):
    """Base Metaflow exception carrying a message and an optional source location.

    Attributes:
        message: the error text passed as `msg`.
        line_no: line number associated with the error, or None.
        source_file: file associated with the error, or None.
    """

    headline = "Flow failed"

    def __init__(self, msg="", lineno=None, source_file=None):
        self.message = msg
        self.line_no = lineno
        self.source_file = source_file
        super(MetaflowException, self).__init__()

    def __str__(self):
        """Return the message prefixed by whichever location parts are set.

        Formats: "msg", "line N: msg", "file: msg" or "file, line N: msg".
        """
        # Collect the parts instead of overwriting one prefix with the other,
        # so the source file is not lost when a line number is also present
        # and the separator colon is emitted exactly once.
        parts = []
        if self.source_file:
            parts.append("%s" % self.source_file)
        if self.line_no:
            parts.append("line %d" % self.line_no)
        prefix = "%s: " % ", ".join(parts) if parts else ""
        return "%s%s" % (prefix, self.message)


Expand Down
49 changes: 23 additions & 26 deletions metaflow/graph.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import inspect
import ast
import re
import textwrap


from .util import to_pod
Expand Down Expand Up @@ -45,9 +46,12 @@ def deindent_docstring(doc):


class DAGNode(object):
def __init__(self, func_ast, decos, doc):
def __init__(self, func_ast, decos, doc, source_file, lineno):
self.name = func_ast.name
self.func_lineno = func_ast.lineno
self.source_file = source_file
# lineno is the start line of decorators in source_file
# func_ast.lineno is lines from decorators start to def of function
self.func_lineno = lineno + func_ast.lineno - 1
self.decorators = decos
self.doc = deindent_docstring(doc)
self.parallel_step = any(getattr(deco, "IS_PARALLEL", False) for deco in decos)
Expand All @@ -62,7 +66,7 @@ def __init__(self, func_ast, decos, doc):
self.foreach_param = None
self.num_parallel = 0
self.parallel_foreach = False
self._parse(func_ast)
self._parse(func_ast, lineno)

# these attributes are populated by _traverse_graph
self.in_funcs = set()
Expand All @@ -74,7 +78,7 @@ def __init__(self, func_ast, decos, doc):
def _expr_str(self, expr):
return "%s.%s" % (expr.value.id, expr.attr)

def _parse(self, func_ast):
def _parse(self, func_ast, lineno):
self.num_args = len(func_ast.args.args)
tail = func_ast.body[-1]

Expand All @@ -94,7 +98,7 @@ def _parse(self, func_ast):

self.has_tail_next = True
self.invalid_tail_next = True
self.tail_next_lineno = tail.lineno
self.tail_next_lineno = lineno + tail.lineno - 1
self.out_funcs = [e.attr for e in tail.value.args]

keywords = dict(
Expand Down Expand Up @@ -132,6 +136,7 @@ def _parse(self, func_ast):

def __str__(self):
return """*[{0.name} {0.type} (line {0.func_lineno})]*
source_file={0.source_file}
in_funcs={in_funcs}
out_funcs={out_funcs}
split_parents={parents}
Expand All @@ -156,18 +161,6 @@ def __str__(self):
)


class StepVisitor(ast.NodeVisitor):
    """AST visitor that records a DAGNode for each step function of a flow.

    `nodes` is a dict that is filled in place, keyed by function name;
    `flow` is the object on which the functions are looked up by name.
    """

    def __init__(self, nodes, flow):
        self.nodes = nodes
        self.flow = flow
        super(StepVisitor, self).__init__()

    def visit_FunctionDef(self, node):
        """Register `node` if the same-named attribute of the flow is a step."""
        # A FunctionDef reached by the traversal (e.g. inside a nested class)
        # need not be an attribute of the flow; skip it rather than raising
        # AttributeError.
        func = getattr(self.flow, node.name, None)
        if hasattr(func, "is_step"):
            self.nodes[node.name] = DAGNode(node, func.decorators, func.__doc__)


class FlowGraph(object):
def __init__(self, flow):
self.name = flow.__name__
Expand All @@ -179,13 +172,18 @@ def __init__(self, flow):
self._postprocess()

def _create_nodes(self, flow):
    """Build a DAGNode for every step attribute found on `flow`.

    Each attribute listed by `dir(flow)` that is callable and carries an
    `is_step` attribute is parsed from its own source, so the lookup does not
    depend on the module in which the flow class itself is written.

    Returns a dict mapping attribute name to DAGNode.
    """
    nodes = {}
    for element in dir(flow):
        func = getattr(flow, element)
        if not (callable(func) and hasattr(func, "is_step")):
            continue

        source_file = inspect.getsourcefile(func)
        # lineno is the 1-based line in source_file where the returned
        # source lines begin.
        source_lines, lineno = inspect.getsourcelines(func)
        # Dedent so a method body parses as a top-level definition.
        source_code = textwrap.dedent("".join(source_lines))
        function_ast = ast.parse(source_code).body[0]

        nodes[element] = DAGNode(
            function_ast, func.decorators, func.__doc__, source_file, lineno
        )
    return nodes

def _postprocess(self):
Expand Down Expand Up @@ -240,9 +238,7 @@ def __iter__(self):
return iter(self.nodes.values())

def __str__(self):
    """Return the string forms of all nodes, one per line, in `sorted_nodes` order."""
    # Iterate by node name rather than sorting (func_lineno, node) tuples:
    # tuple sorting falls back to comparing node objects when two steps share
    # a line number, which can happen once steps come from different files.
    return "\n".join(str(self[n]) for n in self.sorted_nodes)

def output_dot(self):
def edge_specs():
Expand Down Expand Up @@ -286,6 +282,7 @@ def node_to_dict(name, node):
"name": name,
"type": node_to_type(node),
"line": node.func_lineno,
"source_file": node.source_file,
"doc": node.doc,
"decorators": [
{
Expand Down
55 changes: 35 additions & 20 deletions metaflow/lint.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def check_reserved_words(graph):
msg = "Step name *%s* is a reserved word. Choose another name for the " "step."
for node in graph:
if node.name in RESERVED:
raise LintWarn(msg % node.name)
raise LintWarn(msg % node.name, node.func_lineno, node.source_file)


@linter.ensure_fundamentals
Expand All @@ -76,9 +76,9 @@ def check_that_end_is_end(graph):
node = graph["end"]

if node.has_tail_next or node.invalid_tail_next:
raise LintWarn(msg0, node.tail_next_lineno)
raise LintWarn(msg0, node.tail_next_lineno, node.source_file)
if node.num_args > 1:
raise LintWarn(msg1, node.tail_next_lineno)
raise LintWarn(msg1, node.tail_next_lineno, node.source_file)


@linter.ensure_fundamentals
Expand All @@ -90,7 +90,7 @@ def check_step_names(graph):
)
for node in graph:
if re.search("[^a-z0-9_]", node.name) or node.name[0] == "_":
raise LintWarn(msg.format(node), node.func_lineno)
raise LintWarn(msg.format(node), node.func_lineno, node.source_file)


@linter.ensure_fundamentals
Expand All @@ -108,11 +108,11 @@ def check_num_args(graph):
msg2 = "Step *{0.name}* is missing the 'self' argument."
for node in graph:
if node.num_args > 2:
raise LintWarn(msg0.format(node), node.func_lineno)
raise LintWarn(msg0.format(node), node.func_lineno, node.source_file)
elif node.num_args == 2 and node.type != "join":
raise LintWarn(msg1.format(node), node.func_lineno)
raise LintWarn(msg1.format(node), node.func_lineno, node.source_file)
elif node.num_args == 0:
raise LintWarn(msg2.format(node), node.func_lineno)
raise LintWarn(msg2.format(node), node.func_lineno, node.source_file)


@linter.ensure_static_graph
Expand All @@ -125,7 +125,7 @@ def check_static_transitions(graph):
)
for node in graph:
if node.type != "end" and not node.has_tail_next:
raise LintWarn(msg.format(node), node.func_lineno)
raise LintWarn(msg.format(node), node.func_lineno, node.source_file)


@linter.ensure_static_graph
Expand All @@ -138,7 +138,7 @@ def check_valid_transitions(graph):
)
for node in graph:
if node.type != "end" and node.has_tail_next and node.invalid_tail_next:
raise LintWarn(msg.format(node), node.tail_next_lineno)
raise LintWarn(msg.format(node), node.tail_next_lineno, node.source_file)


@linter.ensure_static_graph
Expand All @@ -151,7 +151,11 @@ def check_unknown_transitions(graph):
for node in graph:
unknown = [n for n in node.out_funcs if n not in graph]
if unknown:
raise LintWarn(msg.format(node, step=unknown[0]), node.tail_next_lineno)
raise LintWarn(
msg.format(node, step=unknown[0]),
node.tail_next_lineno,
node.source_file,
)


@linter.ensure_acyclicity
Expand All @@ -167,7 +171,9 @@ def check_path(node, seen):
for n in node.out_funcs:
if n in seen:
path = "->".join(seen + [n])
raise LintWarn(msg.format(path), node.tail_next_lineno)
raise LintWarn(
msg.format(path), node.tail_next_lineno, node.source_file
)
else:
check_path(graph[n], seen + [n])

Expand Down Expand Up @@ -195,7 +201,7 @@ def traverse(node):
orphans = nodeset - seen
if orphans:
orphan = graph[list(orphans)[0]]
raise LintWarn(msg.format(orphan), orphan.func_lineno)
raise LintWarn(msg.format(orphan), orphan.func_lineno, orphan.source_file)


@linter.ensure_static_graph
Expand Down Expand Up @@ -230,7 +236,9 @@ def traverse(node, split_stack):
if split_stack:
_, split_roots = split_stack.pop()
roots = ", ".join(split_roots)
raise LintWarn(msg0.format(roots=roots))
raise LintWarn(
msg0.format(roots=roots), node.func_lineno, node.source_file
)
elif node.type == "join":
if split_stack:
_, split_roots = split_stack[-1]
Expand All @@ -243,9 +251,10 @@ def traverse(node, split_stack):
node, paths=paths, num_roots=len(split_roots), roots=roots
),
node.func_lineno,
node.source_file,
)
else:
raise LintWarn(msg2.format(node), node.func_lineno)
raise LintWarn(msg2.format(node), node.func_lineno, node.source_file)

# check that incoming steps come from the same lineage
# (no cross joins)
Expand All @@ -256,7 +265,7 @@ def parents(n):
return tuple(graph[n].split_parents)

if not all_equal(map(parents, node.in_funcs)):
raise LintWarn(msg3.format(node), node.func_lineno)
raise LintWarn(msg3.format(node), node.func_lineno, node.source_file)

for n in node.out_funcs:
traverse(graph[n], new_stack)
Expand All @@ -276,7 +285,9 @@ def check_empty_foreaches(graph):
if node.type == "foreach":
joins = [n for n in node.out_funcs if graph[n].type == "join"]
if joins:
raise LintWarn(msg.format(node, join=joins[0]))
raise LintWarn(
msg.format(node, join=joins[0]), node.func_lineno, node.source_file
)


@linter.ensure_static_graph
Expand All @@ -290,7 +301,7 @@ def check_parallel_step_after_next(graph):
if node.parallel_foreach and not all(
graph[out_node].parallel_step for out_node in node.out_funcs
):
raise LintWarn(msg.format(node))
raise LintWarn(msg.format(node), node.func_lineno, node.source_file)


@linter.ensure_static_graph
Expand All @@ -303,7 +314,9 @@ def check_join_followed_by_parallel_step(graph):
)
for node in graph:
if node.parallel_step and not graph[node.out_funcs[0]].type == "join":
raise LintWarn(msg.format(node.out_funcs[0]))
raise LintWarn(
msg.format(node.out_funcs[0]), node.func_lineno, node.source_file
)


@linter.ensure_static_graph
Expand All @@ -318,7 +331,9 @@ def check_parallel_foreach_calls_parallel_step(graph):
for node2 in graph:
if node2.out_funcs and node.name in node2.out_funcs:
if not node2.parallel_foreach:
raise LintWarn(msg.format(node, node2))
raise LintWarn(
msg.format(node, node2), node.func_lineno, node.source_file
)


@linter.ensure_non_nested_foreach
Expand All @@ -331,4 +346,4 @@ def check_nested_foreach(graph):
for node in graph:
if node.type == "foreach":
if any(graph[p].type == "foreach" for p in node.split_parents):
raise LintWarn(msg.format(node))
raise LintWarn(msg.format(node), node.func_lineno, node.source_file)
Loading