Skip to content

Commit

Permalink
Initial commit: task graph key interpolation
Browse files Browse the repository at this point in the history
  • Loading branch information
zgornel committed Jun 5, 2018
1 parent 040e590 commit 17609d3
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 3 deletions.
27 changes: 27 additions & 0 deletions graphchain/funcutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
import lz4.frame
from joblib import hash as joblib_hash
from joblib.func_inspect import get_func_code as joblib_getsource
from typing import Union
from dask.core import get_dependencies

from .errors import (GraphchainCompressionMismatch, GraphchainPicklingError,
InvalidPersistencyOption)
Expand Down Expand Up @@ -327,3 +329,28 @@ def recursive_hash(coll, prev_hash=None):
recursive_hash(val, prev_hash)

return prev_hash


def get_bottom_tasks(dsk: dict,
                     task: Union[str, list, dict]) -> Union[str, list, dict]:
    """
    Replace the task graph keys found in `task` with the lowest level
    keys of the task graph `dsk`.

    A string that is a key of `dsk` is followed through the graph until a
    key without dependencies is reached (i.e. one pointing to an actual
    value); that key is returned. Lists are resolved element by element
    and dicts value by value (dict keys are kept as they are). Any other
    object, including a string that is not a key of `dsk`, is returned
    unchanged.

    The input is never modified: lists and dicts are rebuilt as new
    `list`/`dict` objects, so collections stored inside `dsk` are not
    rewritten as a side effect of resolving a key that points to them.

    Args:
        dsk: the task graph.
        task: a graph key, a (possibly nested) list/dict that may contain
            graph keys, or any other value.

    Returns:
        `task` with every graph key replaced by its bottom-level key.
    """
    if isinstance(task, str) and task in dsk:
        if not get_dependencies(dsk, task):
            # Bottom of the graph: this key depends on nothing else
            return task
        # The key points to something else; keep descending
        return get_bottom_tasks(dsk, dsk[task])

    if isinstance(task, list):
        return [get_bottom_tasks(dsk, item) for item in task]

    if isinstance(task, dict):
        return {key: get_bottom_tasks(dsk, value)
                for key, value in task.items()}

    # Neither a graph key nor a supported collection: return value as is
    return task
13 changes: 10 additions & 3 deletions graphchain/graphchain.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,12 @@
"""
import warnings
from collections import deque

import dask
from dask.core import get_dependencies

from .funcutils import (analyze_hash_miss, get_hash, get_storage,
load_hashchain, wrap_to_load, wrap_to_store,
write_hashchain)
write_hashchain, get_bottom_tasks)


def optimize(dsk,
Expand Down Expand Up @@ -80,7 +79,7 @@ def optimize(dsk,
keyhashmaps = {} # key:hash mapping
newdsk = dsk.copy() # output task graph
hashes_to_store = set() # list of hashes that correspond # noqa
# to keys whose output will be stored # noqa
# to keys whose output will be stored # noqa
while work:
key = work.popleft()
deps = dependencies[key]
Expand All @@ -95,9 +94,17 @@ def optimize(dsk,

# Account for different task types: i.e. functions/constants
if isinstance(task, tuple):
# function call node
fno = task[0]
fnargs = task[1:]
elif isinstance(task, str) or isinstance(task, list) or \
isinstance(task, dict):
# graph key node
def identity(x): return x
fno = identity
fnargs = [get_bottom_tasks(dsk, task)]
else:
# constant value
fno = task
fnargs = []

Expand Down

0 comments on commit 17609d3

Please sign in to comment.