Skip to content

Commit

Permalink
Pre-filter TestTarget @union members (#8368)
Browse files Browse the repository at this point in the history
Allows running tests over globbed targets without erroring on
any non-test targets matched by the glob.

Provides a mechanism that any console_rule can use for similar
effects if needed.
  • Loading branch information
benjyw authored Oct 6, 2019
1 parent db1961c commit 30165bb
Show file tree
Hide file tree
Showing 6 changed files with 89 additions and 34 deletions.
2 changes: 1 addition & 1 deletion build-support/githooks/pre-commit
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ echo "* Checking shell scripts via our custom linter"
# fails in pants changed.
if git rev-parse --verify "${MERGE_BASE}" &>/dev/null; then
echo "* Checking imports"
./build-support/bin/isort.sh || die "To fix import sort order, run \`\"$(pwd)/build-support/bin/isort.sh\" -f\`"
./build-support/bin/isort.sh || die "To fix import sort order, run \`build-support/bin/isort.sh -f\`"

# TODO(CMLivingston) Make lint use `-q` option again after addressing proper workunit labeling:
# https://github.com/pantsbuild/pants/issues/6633
Expand Down
2 changes: 1 addition & 1 deletion src/python/pants/build_graph/build_configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ def rules(self):
return list(self._rules)

def union_rules(self):
"""Returns a mapping of registered union base types -> [a list of union member types].
"""Returns a mapping of registered union base types -> [OrderedSet of union member types].
:rtype: OrderedDict
"""
Expand Down
14 changes: 13 additions & 1 deletion src/python/pants/engine/rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,9 @@
from abc import ABC, abstractmethod
from collections import OrderedDict
from collections.abc import Iterable
from dataclasses import dataclass
from textwrap import dedent
from typing import Any, Callable, Type, cast
from typing import Any, Callable, Dict, Type, cast

import asttokens
from twitter.common.collections import OrderedSet
Expand Down Expand Up @@ -428,6 +429,17 @@ def __new__(cls, union_base, union_member):
return super().__new__(cls, union_base, union_member)


@dataclass(frozen=True)
class UnionMembership:
union_rules: Dict[type, typing.Iterable[type]]

def is_member(self, union_type, putative_member):
members = self.union_rules.get(union_type)
if members is None:
raise TypeError(f'Not a registered union type: {union_type}')
return type(putative_member) in members


class Rule(ABC):
"""Rules declare how to produce products for the product graph.
Expand Down
7 changes: 6 additions & 1 deletion src/python/pants/init/engine_initializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
from pants.engine.mapper import AddressMapper
from pants.engine.parser import SymbolTable
from pants.engine.platform import create_platform_rules
from pants.engine.rules import RootRule, rule
from pants.engine.rules import RootRule, UnionMembership, rule
from pants.engine.scheduler import Scheduler
from pants.engine.selectors import Params
from pants.init.options_initializer import BuildConfigInitializer, OptionsInitializer
Expand Down Expand Up @@ -357,6 +357,10 @@ def build_configuration_singleton() -> BuildConfiguration:
def symbol_table_singleton() -> SymbolTable:
return symbol_table

@rule
def union_membership_singleton() -> UnionMembership:
return UnionMembership(build_configuration.union_rules())

# Create a Scheduler containing graph and filesystem rules, with no installed goals. The
# LegacyBuildGraph will explicitly request the products it needs.
rules = (
Expand All @@ -365,6 +369,7 @@ def symbol_table_singleton() -> SymbolTable:
glob_match_error_behavior_singleton,
build_configuration_singleton,
symbol_table_singleton,
union_membership_singleton,
] +
create_legacy_graph_tasks() +
create_fs_rules() +
Expand Down
54 changes: 36 additions & 18 deletions src/python/pants/rules/core/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,16 @@
# Licensed under the Apache License, Version 2.0 (see LICENSE).

import logging
from dataclasses import dataclass
from typing import Optional

from pants.base.exiter import PANTS_FAILED_EXIT_CODE, PANTS_SUCCEEDED_EXIT_CODE
from pants.build_graph.address import Address
from pants.build_graph.address import Address, BuildFileAddress
from pants.engine.addressable import BuildFileAddresses
from pants.engine.console import Console
from pants.engine.goal import Goal
from pants.engine.legacy.graph import HydratedTarget
from pants.engine.rules import console_rule, rule
from pants.engine.rules import UnionMembership, console_rule, rule
from pants.engine.selectors import Get
from pants.rules.core.core_test_model import Status, TestResult, TestTarget

Expand All @@ -24,18 +26,27 @@ class Test(Goal):
name = 'test'


@dataclass(frozen=True)
class AddressAndTestResult:
address: BuildFileAddress
test_result: Optional[TestResult] # If None, target was not a test target.


@console_rule
def fast_test(console: Console, addresses: BuildFileAddresses) -> Test:
test_results = yield [Get(TestResult, Address, address.to_address()) for address in addresses]
results = yield [Get(AddressAndTestResult, Address, addr.to_address()) for addr in addresses]
did_any_fail = False
for address, test_result in zip(addresses, test_results):
filtered_results = [(x.address, x.test_result) for x in results if x.test_result is not None]

for address, test_result in filtered_results:
if test_result.status == Status.FAILURE:
did_any_fail = True
if test_result.stdout:
console.write_stdout(
"{} stdout:\n{}\n".format(
address.reference(),
console.red(test_result.stdout) if test_result.status == Status.FAILURE else test_result.stdout
(console.red(test_result.stdout) if test_result.status == Status.FAILURE
else test_result.stdout)
)
)
if test_result.stderr:
Expand All @@ -44,14 +55,16 @@ def fast_test(console: Console, addresses: BuildFileAddresses) -> Test:
console.write_stdout(
"{} stderr:\n{}\n".format(
address.reference(),
console.red(test_result.stderr) if test_result.status == Status.FAILURE else test_result.stderr
(console.red(test_result.stderr) if test_result.status == Status.FAILURE
else test_result.stderr)
)
)

console.write_stdout("\n")

for address, test_result in zip(addresses, test_results):
console.print_stdout('{0:80}.....{1:>10}'.format(address.reference(), test_result.status.value))
for address, test_result in filtered_results:
console.print_stdout('{0:80}.....{1:>10}'.format(
address.reference(), test_result.status.value))

if did_any_fail:
console.print_stderr(console.red('Tests failed'))
Expand All @@ -63,19 +76,24 @@ def fast_test(console: Console, addresses: BuildFileAddresses) -> Test:


@rule
def coordinator_of_tests(target: HydratedTarget) -> TestResult:
def coordinator_of_tests(target: HydratedTarget,
union_membership: UnionMembership) -> AddressAndTestResult:
# TODO(#6004): when streaming to live TTY, rely on V2 UI for this information. When not a
# live TTY, periodically dump heavy hitters to stderr. See
# https://github.com/pantsbuild/pants/issues/6004#issuecomment-492699898.
logger.info("Starting tests: {}".format(target.address.reference()))
# NB: This has the effect of "casting" a TargetAdaptor to a member of the TestTarget union. If the
# TargetAdaptor is not a member of the union, it will fail at runtime with a useful error message.
result = yield Get(TestResult, TestTarget, target.adaptor)
logger.info("Tests {}: {}".format(
"succeeded" if result.status == Status.SUCCESS else "failed",
target.address.reference(),
))
yield result
if union_membership.is_member(TestTarget, target.adaptor):
logger.info("Starting tests: {}".format(target.address.reference()))
# NB: This has the effect of "casting" a TargetAdaptor to a member of the TestTarget union.
# The adaptor will always be a member because of the union membership check above, but if
# it were not it would fail at runtime with a useful error message.
result = yield Get(TestResult, TestTarget, target.adaptor)
logger.info("Tests {}: {}".format(
"succeeded" if result.status == Status.SUCCESS else "failed",
target.address.reference(),
))
else:
result = None # Not a test target.
yield AddressAndTestResult(target.address, result)


def rules():
Expand Down
44 changes: 32 additions & 12 deletions tests/python/pants_test/rules/test_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,15 @@
from pants.build_graph.address import Address, BuildFileAddress
from pants.engine.legacy.graph import HydratedTarget
from pants.engine.legacy.structs import PythonTestsAdaptor
from pants.rules.core.test import Status, TestResult, coordinator_of_tests, fast_test
from pants.engine.rules import UnionMembership
from pants.rules.core.core_test_model import TestTarget
from pants.rules.core.test import (
AddressAndTestResult,
Status,
TestResult,
coordinator_of_tests,
fast_test,
)
from pants_test.engine.util import MockConsole, run_rule
from pants_test.test_base import TestBase

Expand All @@ -16,14 +24,16 @@ class TestTest(TestBase):
def single_target_test(self, result, expected_console_output, success=True):
console = MockConsole(use_colors=False)

res = run_rule(fast_test, console, (self.make_build_target_address("some/target"),), {
(TestResult, Address): lambda _: result,
addr = self.make_build_target_address("some/target")
res = run_rule(fast_test, console, (addr,), {
(AddressAndTestResult, Address): lambda _: AddressAndTestResult(addr, result),
})

self.assertEquals(console.stdout.getvalue(), expected_console_output)
self.assertEquals(0 if success else 1, res.exit_code)

def make_build_target_address(self, spec):
@staticmethod
def make_build_target_address(spec):
address = Address.parse(spec)
return BuildFileAddress(
build_file=None,
Expand Down Expand Up @@ -61,14 +71,15 @@ def test_output_mixed(self):

def make_result(target):
if target == target1:
return TestResult(status=Status.SUCCESS, stdout='I passed\n', stderr='')
tr = TestResult(status=Status.SUCCESS, stdout='I passed\n', stderr='')
elif target == target2:
return TestResult(status=Status.FAILURE, stdout='I failed\n', stderr='')
tr = TestResult(status=Status.FAILURE, stdout='I failed\n', stderr='')
else:
raise Exception("Unrecognised target")
return AddressAndTestResult(target, tr)

res = run_rule(fast_test, console, (target1, target2), {
(TestResult, Address): make_result,
(AddressAndTestResult, Address): make_result,
})

self.assertEqual(1, res.exit_code)
Expand Down Expand Up @@ -97,10 +108,19 @@ def test_stderr(self):
)

def test_coordinator_python_test(self):
addr = Address.parse("some/target")
target_adaptor = PythonTestsAdaptor(type_alias='python_tests')
with self.captured_logging(logging.INFO):
result = run_rule(coordinator_of_tests, HydratedTarget(Address.parse("some/target"), target_adaptor, ()), {
(TestResult, PythonTestsAdaptor): lambda _: TestResult(status=Status.FAILURE, stdout='foo', stderr=''),
})

self.assertEqual(result, TestResult(status=Status.FAILURE, stdout='foo', stderr=''))
result = run_rule(
coordinator_of_tests,
HydratedTarget(addr, target_adaptor, ()),
UnionMembership(union_rules={TestTarget: [PythonTestsAdaptor]}),
{
(TestResult, PythonTestsAdaptor):
lambda _: TestResult(status=Status.FAILURE, stdout='foo', stderr=''),
})

self.assertEqual(
result,
AddressAndTestResult(addr, TestResult(status=Status.FAILURE, stdout='foo', stderr=''))
)

0 comments on commit 30165bb

Please sign in to comment.