Skip to content

Commit

Permalink
feat: ✨ Change the Graph creation to allow subclassing
Browse files Browse the repository at this point in the history
  • Loading branch information
simonwardjones committed Oct 8, 2024
1 parent 294671a commit 0a964bd
Showing 1 changed file with 8 additions and 18 deletions.
26 changes: 8 additions & 18 deletions metaflow/graph.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import inspect
import ast
import re
import textwrap


from .util import to_pod
Expand Down Expand Up @@ -156,18 +157,6 @@ def __str__(self):
)


class StepVisitor(ast.NodeVisitor):
def __init__(self, nodes, flow):
self.nodes = nodes
self.flow = flow
super(StepVisitor, self).__init__()

def visit_FunctionDef(self, node):
func = getattr(self.flow, node.name)
if hasattr(func, "is_step"):
self.nodes[node.name] = DAGNode(node, func.decorators, func.__doc__)


class FlowGraph(object):
def __init__(self, flow):
self.name = flow.__name__
Expand All @@ -179,13 +168,14 @@ def __init__(self, flow):
self._postprocess()

def _create_nodes(self, flow):
module = __import__(flow.__module__)
tree = ast.parse(inspect.getsource(module)).body
root = [n for n in tree if isinstance(n, ast.ClassDef) and n.name == self.name][
0
]
nodes = {}
StepVisitor(nodes, flow).visit(root)
for element in dir(flow):
func = getattr(flow, element)
if hasattr(func, "is_step"):
source_code = textwrap.dedent(inspect.getsource(func))
function_ast = ast.parse(source_code).body[0]
node = DAGNode(function_ast, func.decorators, func.__doc__)
nodes[element] = node
return nodes

def _postprocess(self):
Expand Down

0 comments on commit 0a964bd

Please sign in to comment.