Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Additional fixes and improvements #22

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 27 additions & 0 deletions graphchain/funcutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
import lz4.frame
from joblib import hash as joblib_hash
from joblib.func_inspect import get_func_code as joblib_getsource
from typing import Union
from dask.core import get_dependencies

from .errors import GraphchainCompressionMismatch
from .logger import add_logger, mute_dependency_loggers
Expand Down Expand Up @@ -276,3 +278,28 @@ def recursive_hash(coll, prev_hash=None):
recursive_hash(val, prev_hash)

return prev_hash


def get_bottom_tasks(dsk: dict,
task: Union[str, list, dict]) -> Union[str, list, dict]:
"""
Function that iteratively replaces any task graph keys present in
an input variable `task` with the lowest level keys in the task graph
`dsk` (i.e. the ones pointing to the actual values). This allows
"""
if isinstance(task, str) and task in dsk.keys():
if not get_dependencies(dsk, task):
return task
else:
task = get_bottom_tasks(dsk, dsk[task])
elif isinstance(task, list):
for idx in range(len(task)):
task[idx] = get_bottom_tasks(dsk, task[idx])
elif isinstance(task, dict):
for key in task:
task[key] = get_bottom_tasks(dsk, task[key])
else:
# Non-key of collection, return value as is
pass

return task
13 changes: 10 additions & 3 deletions graphchain/graphchain.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,12 @@
"""
import warnings
from collections import deque

import dask
from dask.core import get_dependencies, toposort

from .funcutils import (analyze_hash_miss, get_hash, get_storage,
load_hashchain, wrap_to_load, wrap_to_store,
write_hashchain)
write_hashchain, get_bottom_tasks)


def optimize(dsk,
Expand Down Expand Up @@ -75,7 +74,7 @@ def optimize(dsk,
keyhashmaps = {} # key:hash mapping
newdsk = dsk.copy() # output task graph
hashes_to_store = set() # list of hashes that correspond # noqa
# to keys whose output will be stored # noqa
# to keys whose output will be stored # noqa
while work:
key = work.popleft()
deps = dependencies[key]
Expand All @@ -90,9 +89,17 @@ def optimize(dsk,

# Account for different task types: i.e. functions/constants
if isinstance(task, tuple):
# function call node
fno = task[0]
fnargs = task[1:]
elif isinstance(task, str) or isinstance(task, list) or \
isinstance(task, dict):
# graph key node
def identity(x): return x
fno = identity
fnargs = [get_bottom_tasks(dsk, task)]
else:
# constant value
fno = task
fnargs = []

Expand Down