Skip to content

Commit

Permalink
Improve recursive caching performance
Browse files Browse the repository at this point in the history
Do not recompute the same key twice, using a key factory.
Speedup dependency computation using native ast module and short-circuiting.
  • Loading branch information
serge-sans-paille committed Jun 9, 2020
1 parent dfb22e7 commit 83dd596
Showing 1 changed file with 44 additions and 10 deletions.
54 changes: 44 additions & 10 deletions memestra/caching.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
import os
import hashlib
import yaml
import gast as ast

# not using gast because we only rely on Import and ImportFrom, which are
# portable. Not using gast prevents an extra costly conversion step.
import ast

from memestra.docparse import docparse
from memestra.utils import resolve_module
Expand All @@ -24,6 +27,37 @@ def visit_Import(self, node):
def visit_ImportFrom(self, node):
self.add_module(node.module)

def visit_stmt(self, node):
pass

visit_Assign = visit_AugAssign = visit_AnnAssign = visit_Expr = visit_stmt
visit_Return = visit_Print = visit_Raise = visit_Assert = visit_stmt
visit_Pass = visit_Break = visit_Continue = visit_Delete = visit_stmt
visit_Global = visit_Nonlocal = visit_Exec = visit_stmt

def visit_body(self, node):
for stmt in node.body:
self.visit(stmt)

visit_FunctionDef = visit_ClassDef = visit_AsyncFunctionDef = visit_body
visit_With = visit_AsyncWith = visit_body

def visit_orelse(self, node):
for stmt in node.body:
self.visit(stmt)
for stmt in node.orelse:
self.visit(stmt)

visit_For = visit_While = visit_If = visit_AsyncFor = visit_orelse

def visit_Try(self, node):
for stmt in node.body:
self.visit(stmt)
for stmt in node.orelse:
self.visit(stmt)
for stmt in node.finalbody:
self.visit(stmt)


class Format(object):

Expand Down Expand Up @@ -98,8 +132,8 @@ def __call__(self, module_path):
self.created[module_path] = key
return key

def __iter__(self):
return iter(self.created.keys())
def get(self, *args):
return self.created.get(*args)


class CacheKeyFactory(CacheKeyFactoryBase):
Expand Down Expand Up @@ -132,20 +166,20 @@ def __init__(self, module_path, factory):
dependencies_resolver = DependenciesResolver()
dependencies_resolver.visit(code)

new_deps = dependencies_resolver.result.difference(factory)
new_deps = []
for dep in dependencies_resolver.result:
if factory.get(dep, 1) is not None:
new_deps.append(dep)

module_hash = hashlib.sha256(module_content).hexdigest()

def bytes_xor(left, right):
return str(chr(ord(l) ^ ord(r))
for l, r in zip(left, right))
hashes = [module_hash]

for new_dep in sorted(new_deps):
new_dep_key = factory(new_dep)
module_hash = bytes_xor(module_hash,
new_dep_key.module_hash)
hashes.append(new_dep_key.module_hash)

self.module_hash = module_hash
self.module_hash = hashlib.sha256("".join(hashes).encode("ascii")).hexdigest()

def __init__(self):
super(RecursiveCacheKeyFactory, self).__init__(RecursiveCacheKeyFactory.CacheKey)
Expand Down

0 comments on commit 83dd596

Please sign in to comment.