def get_bottom_tasks(dsk: dict,
                     task: Union[str, list, dict]) -> Union[str, list, dict]:
    """Resolve task graph keys found in ``task`` to bottom-level keys.

    A string that is a key of ``dsk`` is followed through the graph: if
    ``dask.core.get_dependencies`` reports no dependencies for it, the key
    itself is returned; otherwise the value stored under that key is
    resolved recursively.  Lists and dicts are walked element by element
    (dict keys are kept, only the values are resolved).  Anything else is
    returned unchanged.

    The input is never modified.  Resolved lists and dicts are returned as
    new plain ``list``/``dict`` objects, so containers stored inside
    ``dsk`` are not rewritten as a side effect of the lookup.

    Args:
        dsk: The task graph, mapping keys to tasks.
        task: A graph key, or a list/dict that may contain graph keys.

    Returns:
        ``task`` with every resolvable graph key replaced by the result of
        following it through ``dsk``.
    """
    if isinstance(task, str) and task in dsk:
        if not get_dependencies(dsk, task):
            # Nothing further to follow: this key is already a bottom key.
            return task
        # Follow the key to the task stored under it.
        return get_bottom_tasks(dsk, dsk[task])

    if isinstance(task, list):
        # Build a new list; writing into `task` would alter the caller's
        # object, which may be a value living inside `dsk`.
        return [get_bottom_tasks(dsk, item) for item in task]

    if isinstance(task, dict):
        # Same reasoning as for lists: never write into the input dict.
        return {name: get_bottom_tasks(dsk, value)
                for name, value in task.items()}

    # Neither a graph key nor a supported collection: return as is.
    return task