From 0a964bd91f0de164589dccf7b92adde46e5f9191 Mon Sep 17 00:00:00 2001 From: simonwardjones Date: Tue, 17 Sep 2024 10:45:11 +0100 Subject: [PATCH] feat: :sparkles: Change the Graph creation to allow subclassing --- metaflow/graph.py | 26 ++++++++------------------ 1 file changed, 8 insertions(+), 18 deletions(-) diff --git a/metaflow/graph.py b/metaflow/graph.py index 4727c1bbad6..3d5a4e45162 100644 --- a/metaflow/graph.py +++ b/metaflow/graph.py @@ -1,6 +1,7 @@ import inspect import ast import re +import textwrap from .util import to_pod @@ -156,18 +157,6 @@ def __str__(self): ) -class StepVisitor(ast.NodeVisitor): - def __init__(self, nodes, flow): - self.nodes = nodes - self.flow = flow - super(StepVisitor, self).__init__() - - def visit_FunctionDef(self, node): - func = getattr(self.flow, node.name) - if hasattr(func, "is_step"): - self.nodes[node.name] = DAGNode(node, func.decorators, func.__doc__) - - class FlowGraph(object): def __init__(self, flow): self.name = flow.__name__ @@ -179,13 +168,14 @@ def __init__(self, flow): self._postprocess() def _create_nodes(self, flow): - module = __import__(flow.__module__) - tree = ast.parse(inspect.getsource(module)).body - root = [n for n in tree if isinstance(n, ast.ClassDef) and n.name == self.name][ - 0 - ] nodes = {} - StepVisitor(nodes, flow).visit(root) + for element in dir(flow): + func = getattr(flow, element) + if hasattr(func, "is_step"): + source_code = textwrap.dedent(inspect.getsource(func)) + function_ast = ast.parse(source_code).body[0] + node = DAGNode(function_ast, func.decorators, func.__doc__) + nodes[element] = node return nodes def _postprocess(self):