Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement recursive handling of import #23

Merged
merged 5 commits into from
Jun 12, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
156 changes: 149 additions & 7 deletions memestra/caching.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,74 @@
import hashlib
import yaml

serge-sans-paille marked this conversation as resolved.
Show resolved Hide resolved
# not using gast because we only rely on Import and ImportFrom, which are
# portable. Not using gast prevents an extra costly conversion step.
import ast

from memestra.docparse import docparse
from memestra.utils import resolve_module


class DependenciesResolver(ast.NodeVisitor):
    '''
    Traverse a module and collect statically imported modules.

    After a visit, `result` holds the paths returned by `resolve_module` for
    the modules named in `import` and `from ... import` statements. Names for
    which `resolve_module` returns None are not recorded.
    '''

    def __init__(self):
        self.result = set()

    def add_module(self, module_name):
        '''Resolve `module_name` and record the resulting path, if any.'''
        module_path = resolve_module(module_name)
        if module_path is not None:
            self.result.add(module_path)

    def visit_Import(self, node):
        for alias in node.names:
            self.add_module(alias.name)

    def visit_ImportFrom(self, node):
        # `from . import name` carries no module name (node.module is None),
        # so there is nothing to resolve by name in that case.
        if node.module is not None:
            self.add_module(node.module)

    # All members below are specialized in order to improve performance:
    # It's useless to traverse leaf statements and expression when looking for
    # an import.

    def visit_stmt(self, node):
        pass

    visit_Assign = visit_AugAssign = visit_AnnAssign = visit_Expr = visit_stmt
    visit_Return = visit_Print = visit_Raise = visit_Assert = visit_stmt
    visit_Pass = visit_Break = visit_Continue = visit_Delete = visit_stmt
    visit_Global = visit_Nonlocal = visit_Exec = visit_stmt

    def visit_body(self, node):
        for stmt in node.body:
            self.visit(stmt)

    visit_FunctionDef = visit_ClassDef = visit_AsyncFunctionDef = visit_body
    visit_With = visit_AsyncWith = visit_body

    def visit_orelse(self, node):
        for stmt in node.body:
            self.visit(stmt)
        for stmt in node.orelse:
            self.visit(stmt)

    visit_For = visit_While = visit_If = visit_AsyncFor = visit_orelse

    def visit_Try(self, node):
        for stmt in node.body:
            self.visit(stmt)
        # Fallback imports typically live in `except ImportError:` handlers,
        # so the handler bodies must be traversed as well.
        for handler in node.handlers:
            for stmt in handler.body:
                self.visit(stmt)
        for stmt in node.orelse:
            self.visit(stmt)
        for stmt in node.finalbody:
            self.visit(stmt)

    # `try ... except*` nodes expose the same fields as `try ... except`.
    visit_TryStar = visit_Try


class Format(object):

version = 0
version = 1

fields = (('version', lambda: Format.version),
('name', str),
Expand Down Expand Up @@ -64,12 +126,86 @@ def check_deprecated(data):
raise ValueError("deprecated must be a list of string")


class CacheKey(object):
class CacheKeyFactoryBase(object):
    '''
    Memoizing factory of cache keys.

    Keys are built by calling `keycls(module_path, factory)` and stored in
    `created`, indexed by module path. While a key is being built, its entry
    holds None, so a re-entrant request for the same path gets None back
    instead of triggering the creation again.
    '''

    def __init__(self, keycls):
        self.keycls = keycls
        self.created = dict()

    def __call__(self, module_path):
        '''Return the key for `module_path`, creating it on first request.'''
        if module_path in self.created:
            return self.created[module_path]

        self.created[module_path] = None  # creation in process
        try:
            key = self.keycls(module_path, self)
        except BaseException:
            # Do not leave a stale "in process" marker behind: otherwise any
            # later request for this path would silently yield None.
            self.created.pop(module_path, None)
            raise
        self.created[module_path] = key
        return key

    def get(self, *args):
        '''Forward to `dict.get` on the registry of created keys.'''
        return self.created.get(*args)


class CacheKeyFactory(CacheKeyFactoryBase):
    '''
    Factory for non-recursive keys.
    A key only depends on the content of the module itself.
    '''

    class CacheKey(object):
        '''Key made of the module base name and the SHA-256 of its bytes.'''

        def __init__(self, module_path, _):
            # The factory argument is unused: no dependency is inspected.
            base_name = os.path.basename(module_path)
            self.name = os.path.splitext(base_name)[0]
            with open(module_path, 'rb') as stream:
                digest = hashlib.sha256(stream.read())
            self.module_hash = digest.hexdigest()

    def __init__(self):
        super(CacheKeyFactory, self).__init__(CacheKeyFactory.CacheKey)


class RecursiveCacheKeyFactory(CacheKeyFactoryBase):
    '''
    Factory for recursive keys.
    A key depends on the content of the module and on the keys of the modules
    it imports, which in turn depend on their own imports. That way, a change
    in the module hierarchy implies a change in the key.
    '''

    class CacheKey(object):
        '''Key whose hash combines the module hash with its dependencies'.'''

        def __init__(self, module_path, factory):
            # Either the factory never saw this path, or it flagged it as
            # being created right now.
            assert factory.created.get(module_path) is None

            self.name = os.path.splitext(os.path.basename(module_path))[0]
            with open(module_path, 'rb') as fd:
                module_content = fd.read()

            resolver = DependenciesResolver()
            resolver.visit(ast.parse(module_content))

            # Keep dependencies the factory does not know yet (the default
            # `1` stands for "absent") and those that already have a key.
            # Entries registered as None are currently being created further
            # up the call stack and are skipped.
            pending = sorted(dep for dep in resolver.result
                             if factory.get(dep, 1) is not None)

            hashes = [hashlib.sha256(module_content).hexdigest()]
            for dep in pending:
                try:
                    dep_key = factory(dep)
                # FIXME: this only happens on windows, maybe we could do
                # better?
                except UnicodeDecodeError:
                    continue
                hashes.append(dep_key.module_hash)

            combined = "".join(hashes).encode("ascii")
            self.module_hash = hashlib.sha256(combined).hexdigest()

    def __init__(self):
        super(RecursiveCacheKeyFactory, self).__init__(RecursiveCacheKeyFactory.CacheKey)


class Cache(object):
Expand Down Expand Up @@ -128,7 +264,11 @@ def run_set(args):
data = {'generator': 'manual',
'deprecated': args.deprecated}
cache = Cache()
key = CacheKey(args.input)
if args.recursive:
key_factory = RecursiveCacheKeyFactory()
else:
key_factory = CacheKeyFactory()
key = key_factory(args.input)
cache[key] = data


Expand Down Expand Up @@ -172,6 +312,8 @@ def run():
type=str, nargs='+',
default='decorator.deprecated',
help='function to flag as deprecated')
parser_set.add_argument('--recursive', action='store_true',
help='compute a dependency-aware cache key')
parser_set.add_argument('input', type=str,
help='module.py to edit')
parser_set.set_defaults(runner=run_set)
Expand Down
120 changes: 80 additions & 40 deletions memestra/memestra.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,47 +5,56 @@
import warnings

from collections import defaultdict
from itertools import chain
from memestra.caching import Cache, CacheKey, Format
from memestra.caching import Cache, CacheKeyFactory, RecursiveCacheKeyFactory
from memestra.caching import Format
from memestra.utils import resolve_module

_defs = ast.AsyncFunctionDef, ast.ClassDef, ast.FunctionDef


# FIXME: this only handles module name not subpackages
def resolve_module(module_name, importer_path=()):
    '''
    Return the path of the file `module_name`.py, or None if it is not found.

    When `importer_path` is given, the directory that contains it is searched
    first; the entries of `sys.path` are then searched in order.
    '''
    module_path = module_name + ".py"
    bases = list(sys.path)
    if importer_path:
        # The directory must be a single search entry: chaining the bare
        # string would iterate over its characters instead.
        bases.insert(0, os.path.abspath(os.path.dirname(importer_path)))
    for base in bases:
        fullpath = os.path.join(base, module_path)
        if os.path.exists(fullpath):
            return fullpath
    return None


class SilentDefUseChains(beniget.DefUseChains):
    '''`beniget.DefUseChains` whose `unbound_identifier` hook is a no-op.'''

    def unbound_identifier(self, name, node):
        '''Ignore `name`, used at `node` without a known definition.'''
        return None


# FIXME: this is not recursive, but should be
class ImportResolver(ast.NodeVisitor):
def __init__(self, decorator, file_path=None, recursive=False, parent=None):
    '''
    Create an ImportResolver that finds deprecated identifiers.

    A deprecated identifier is an identifier which is decorated
    by `decorator', or which uses a deprecated identifier.

    `file_path' is the path of the file being analyzed, if any.

    `recursive' is a flag: when it is true, a dependency-aware key factory
    (RecursiveCacheKeyFactory) is used instead of CacheKeyFactory.

    `parent' is used internally to handle imports: a resolver created with
    a parent shares the parent's cache, visited set and key factory.
    '''
    self.deprecated = None
    self.decorator = tuple(decorator)
    self.file_path = file_path
    self.recursive = recursive
    if parent:
        # Share state with the parent so that nested resolvers see the same
        # cache, visited modules and already created keys.
        self.cache = parent.cache
        self.visited = parent.visited
        self.key_factory = parent.key_factory
    else:
        self.cache = Cache()
        self.visited = set()
        if recursive:
            self.key_factory = RecursiveCacheKeyFactory()
        else:
            self.key_factory = CacheKeyFactory()

def load_deprecated_from_module(self, module_name):
module_path = resolve_module(module_name, self.file_path)

if module_path is None:
return None

module_key = CacheKey(module_path)
module_key = self.key_factory(module_path)

if module_key in self.cache:
data = self.cache[module_key]
Expand All @@ -60,19 +69,50 @@ def load_deprecated_from_module(self, module_name):
return []

with open(module_path) as fd:
module = ast.parse(fd.read())
try:
module = ast.parse(fd.read())
except UnicodeDecodeError:
return []
duc = SilentDefUseChains()
duc.visit(module)
anc = beniget.Ancestors()
anc.visit(module)

# Collect deprecated functions
if self.recursive and module_path not in self.visited:
self.visited.add(module_path)
resolver = ImportResolver(self.decorator,
self.file_path,
self.recursive,
parent=self)
resolver.visit(module)
deprecated_imports = [d for _, _, d in
resolver.get_deprecated_users(duc, anc)]
else:
deprecated_imports = []
deprecated = self.collect_deprecated(module, duc, anc)
deprecated.update(deprecated_imports)
dl = {d.name for d in deprecated}
serge-sans-paille marked this conversation as resolved.
Show resolved Hide resolved
dl = {d.name for d in deprecated}
data = {'generator': 'memestra',
'deprecated': sorted(dl)}
self.cache[module_key] = data
return dl

def get_deprecated_users(self, defuse, ancestors):
    '''
    Collect the users of the deprecated nodes held in `self.deprecated`.

    `defuse' provides `chains[node].users()' and `ancestors' provides
    `parents(node)'. A user is skipped when one of its enclosing
    definitions (see `_defs') is itself deprecated.

    Return a list of (deprecated_node, user, anchor) tuples, where anchor
    is the last enclosing definition of the user, or the user node itself
    when it has no enclosing definition.
    '''
    found = []
    for deprecated_node in self.deprecated:
        for user in defuse.chains[deprecated_node].users():
            enclosing_defs = [parent
                              for parent in ancestors.parents(user.node)
                              if isinstance(parent, _defs)]
            if not any(d in self.deprecated for d in enclosing_defs):
                anchor = enclosing_defs[-1] if enclosing_defs else user.node
                found.append((deprecated_node, user, anchor))
    return found

def visit_Import(self, node):
for alias in node.names:
deprecated = self.load_deprecated_from_module(alias.name)
Expand Down Expand Up @@ -173,45 +213,41 @@ def prettyname(node):
return repr(node)


def memestra(file_descriptor, decorator, file_path=None, recursive=False):
    '''
    Parse `file_descriptor` and return a sorted list of
    (function, filename, line, colno) tuples. Each element
    represents a code location where a deprecated function is used.
    A deprecated function is a function flagged by `decorator`, where
    `decorator` is a tuple representing an import path,
    e.g. (module, attribute)

    If `recursive` is set to `True`, deprecated uses are
    checked recursively throughout the *whole* module import tree. Otherwise,
    only one level of import is checked.
    '''

    assert not isinstance(decorator, str) and \
        len(decorator) > 1, "decorator is at least (module, attribute)"

    module = ast.parse(file_descriptor.read())

    # Collect deprecated functions
    resolver = ImportResolver(decorator, file_path, recursive)
    resolver.visit(module)

    ancestors = resolver.ancestors
    duc = resolver.def_use_chains

    # Find their users and report where each use is located; the file name
    # falls back to '<>' when the descriptor has no `name` attribute.
    return sorted(
        (prettyname(deprecated_node),
         getattr(file_descriptor, 'name', '<>'),
         user.node.lineno,
         user.node.col_offset)
        for deprecated_node, user, _
        in resolver.get_deprecated_users(duc, ancestors))


def run():
Expand All @@ -223,6 +259,9 @@ def run():
parser.add_argument('--decorator', dest='decorator',
default='decorator.deprecated',
help='Path to the decorator to check')
parser.add_argument('--recursive', dest='recursive',
action='store_true',
help='Traverse the whole module hierarchy')
parser.add_argument('input', type=argparse.FileType('r'),
help='file to scan')

Expand All @@ -236,7 +275,8 @@ def run():

deprecate_uses = dispatcher[extension](args.input,
args.decorator.split('.'),
args.input.name)
args.input.name,
args.recursive)

for fname, fd, lineno, colno in deprecate_uses:
print("{} used at {}:{}:{}".format(fname, fd, lineno, colno + 1))
Expand Down
Loading