Skip to content

Commit

Permalink
Add useful validation of Gets.
Browse files Browse the repository at this point in the history
  • Loading branch information
stuhood committed Mar 30, 2018
1 parent 8e6fc55 commit 9e22ec5
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 20 deletions.
26 changes: 9 additions & 17 deletions src/python/pants/engine/rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import logging
from abc import abstractproperty
from collections import OrderedDict
from types import TypeType

from twitter.common.collections import OrderedSet

Expand All @@ -30,20 +31,7 @@ def __init__(self):
def visit_Call(self, node):
if not isinstance(node.func, ast.Name) or node.func.id != Get.__name__:
return

# TODO: Validation.
if len(node.args) == 2:
product_type, subject_constructor = node.args
if not isinstance(product_type, ast.Name) or not isinstance(subject_constructor, ast.Call):
raise Exception('TODO: Implement validation of Get shapes.')
self.gets.append((product_type.id, subject_constructor.func.id))
elif len(node.args) == 3:
product_type, subject_type, _ = node.args
if not isinstance(product_type, ast.Name) or not isinstance(subject_type, ast.Name):
raise Exception('TODO: Implement validation of Get shapes.')
self.gets.append((product_type.id, subject_type.id))
else:
raise Exception('Invalid {}: {}'.format(Get.__name__, node.args))
self.gets.append(Get.extract_constraints(node))


def rule(output_type, input_selectors):
Expand All @@ -62,15 +50,19 @@ def wrapper(func):
caller_frame = inspect.stack()[1][0]
module_ast = ast.parse(inspect.getsource(func))

def resolve(name):
return caller_frame.f_globals.get(name) or caller_frame.f_builtins.get(name)
def resolve_type(name):
resolved = caller_frame.f_globals.get(name) or caller_frame.f_builtins.get(name)
if not isinstance(resolved, (TypeType, Exactly)):
raise ValueError('Expected either a `type` constructor or TypeConstraint instance; '
'got: {}'.format(name))
return resolved

gets = []
for node in ast.iter_child_nodes(module_ast):
if isinstance(node, ast.FunctionDef) and node.name == func.__name__:
rule_visitor = _RuleVisitor()
rule_visitor.visit(node)
gets.extend(Get(resolve(p), resolve(s)) for p, s in rule_visitor.gets)
gets.extend(Get(resolve_type(p), resolve_type(s)) for p, s in rule_visitor.gets)

func._rule = TaskRule(output_type, input_selectors, func, input_gets=gets)
return func
Expand Down
36 changes: 33 additions & 3 deletions src/python/pants/engine/selectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from __future__ import (absolute_import, division, generators, nested_scopes, print_function,
unicode_literals, with_statement)

import ast
from abc import abstractproperty

import six
Expand Down Expand Up @@ -33,21 +34,50 @@ def constraint_for(type_or_constraint):


class Get(datatype('Get', ['product', 'subject'])):
"""TODO: Experimental synchronous generator API.
"""Experimental synchronous generator API.
May be called equivalently as either:
# verbose form: Get(product_type, subject_type, subject)
# shorthand form: Get(product_type, subject_type(subject))
"""

@staticmethod
def extract_constraints(call_node):
"""Parses a `Get(..)` call in one of its two legal forms to return its type constraints.
:param call_node: An `ast.Call` node representing a call to `Get(..)`.
:return: A tuple of product type id and subject type id.
"""
def render_args():
return ', '.join(a.id for a in call_node.args)

if len(call_node.args) == 2:
product_type, subject_constructor = call_node.args
if not isinstance(product_type, ast.Name) or not isinstance(subject_constructor, ast.Call):
raise Exception('Two arg form of {} expected (product_type, subject_type(subject)), but '
'got: ({})'.format(Get.__name__, render_args()))
return (product_type.id, subject_constructor.func.id)
elif len(call_node.args) == 3:
product_type, subject_type, _ = call_node.args
if not isinstance(product_type, ast.Name) or not isinstance(subject_type, ast.Name):
raise Exception('Three arg form of {} expected (product_type, subject_type, subject), but '
'got: ({})'.format(Get.__name__, render_args()))
return (product_type.id, subject_type.id)
else:
raise Exception('Invalid {}; expected either two or three args, but '
'got: ({})'.format(Get.__name__, render_args()))

def __new__(cls, *args):
if len(args) == 2:
product, subject = args
elif len(args) == 3:
product, _, subject = args
product, subject_type, subject = args
if type(subject) is not subject_type:
raise TypeError('Declared type did not match actual type for {}({}).'.format(
Get.__name__, ', '.join(str(a) for a in args)))
else:
raise Exception('Expected either two or three arguments to {}; got {}.'.format(
Get.__name__, args))
Get.__name__, args))
return super(Get, cls).__new__(cls, product, subject)


Expand Down

0 comments on commit 9e22ec5

Please sign in to comment.