From 3e0b78181eee374db345fcf2adb83136cb550d47 Mon Sep 17 00:00:00 2001 From: Eric Arellano <14852634+Eric-Arellano@users.noreply.github.com> Date: Sat, 4 Apr 2020 20:35:54 -0700 Subject: [PATCH] Allow `Sources` subclasses to specify the # of expected files (#9466) Twice now, we've had `Sources` fields that require a specific number of files. So, this PR formalizes a mechanism to do this. Error message #1: > ERROR: The 'sources' field in target build-support/bin:check_banned_imports must have 3 files, but it had 1 file. Error message #2: > ERROR: The 'sources' field in target build-support/bin:check_banned_imports must have 2 or 3 files, but it had 1 file. Error message #3: > ERROR: The 'sources' field in target build-support/bin:check_banned_imports must have a number of files in the range `range(20, 20000, 10)`, but it had 1 file. [ci skip-rust-tests] # No Rust changes made. [ci skip-jvm-tests] # No JVM changes made. --- .../pants/backend/python/rules/targets.py | 174 +++++++++--------- .../backend/python/rules/targets_test.py | 32 +--- src/python/pants/engine/target.py | 32 +++- src/python/pants/engine/target_test.py | 33 +++- 4 files changed, 154 insertions(+), 117 deletions(-) diff --git a/src/python/pants/backend/python/rules/targets.py b/src/python/pants/backend/python/rules/targets.py index 6936dc4a99d..5c69e2b53ff 100644 --- a/src/python/pants/backend/python/rules/targets.py +++ b/src/python/pants/backend/python/rules/targets.py @@ -6,8 +6,6 @@ from pants.backend.python.subsystems.pytest import PyTest from pants.build_graph.address import Address -from pants.engine.fs import Snapshot -from pants.engine.objects import union from pants.engine.target import ( COMMON_TARGET_FIELDS, BoolField, @@ -23,32 +21,15 @@ from pants.python.python_setup import PythonSetup from pants.rules.core.determine_source_files import SourceFiles +# ----------------------------------------------------------------------------------------------- +# Common fields +# ----------------------------------------------------------------------------------------------- + -@union class PythonSources(Sources): expected_file_extensions = (".py",) -class PythonLibrarySources(PythonSources): - default = ("*.py", "!test_*.py", "!*_test.py", "!conftest.py") - - -class PythonTestsSources(PythonSources): - default = ("test_*.py", "*_test.py", "conftest.py") - - -class PythonBinarySources(PythonSources): - def validate_snapshot(self, snapshot: Snapshot) -> None: - super().validate_snapshot(snapshot) - if len(snapshot.files) not in [0, 1]: - raise InvalidFieldException( - f"The {repr(self.alias)} field in target {self.address} must only have 0 or 1 " - f"files because it is a binary target, but it has {len(snapshot.files)} sources: " - f"{sorted(snapshot.files)}.\n\nTo use any additional files, put them in a " - "`python_library` and then add that `python_library` as a `dependency`." - ) - - class Compatibility(StringOrStringSequenceField): """A string for Python interpreter constraints on this target. @@ -67,67 +48,16 @@ def value_or_global_default(self, python_setup: PythonSetup) -> Tuple[str, ...]: return python_setup.compatibility_or_constraints(self.value) -class Coverage(StringOrStringSequenceField): - """The module(s) whose coverage should be generated, e.g. `['pants.util']`.""" - - alias = "coverage" - - def determine_packages_to_cover( - self, *, specified_source_files: SourceFiles - ) -> Tuple[str, ...]: - """Either return the specified `coverage` field value or, if not defined, attempt to - generate packages with a heuristic that tests have the same package name as their source - code. - - This heuristic about package names works when either the tests live in the same folder as - their source code, or there is a parallel file structure with the same top-level package - names, e.g. `src/python/project` and `tests/python/project` (but not - `tests/python/test_project`). - """ - if self.value is not None: - return self.value - return tuple( - sorted( - { - # Turn file paths into package names. - os.path.dirname(source_file).replace(os.sep, ".") - for source_file in specified_source_files.snapshot.files - } - ) - ) - - -class Timeout(IntField): - """A timeout (in seconds) which covers the total runtime of all tests in this target. +COMMON_PYTHON_FIELDS = (*COMMON_TARGET_FIELDS, Dependencies, Compatibility, ProvidesField) - This only applies if `--pytest-timeouts` is set to True. - """ - alias = "timeout" +# ----------------------------------------------------------------------------------------------- +# `python_binary` target +# ----------------------------------------------------------------------------------------------- - @classmethod - def compute_value(cls, raw_value: Optional[int], *, address: Address) -> Optional[int]: - value = super().compute_value(raw_value, address=address) - if value is not None and value < 1: - raise InvalidFieldException( - f"The value for the `timeout` field in target {address} must be > 0, but was " - f"{value}." - ) - return value - def calculate_from_global_options(self, pytest: PyTest) -> Optional[int]: - """Determine the timeout (in seconds) after applying global `pytest` options.""" - if not pytest.timeouts_enabled: - return None - if self.value is None: - if pytest.timeout_default is None: - return None - result = pytest.timeout_default - else: - result = self.value - if pytest.timeout_maximum is not None: - return min(result, pytest.timeout_maximum) - return result +class PythonBinarySources(PythonSources): + expected_num_files = range(0, 2) class PythonEntryPoint(StringField): @@ -236,9 +166,6 @@ class PexEmitWarnings(BoolField): default = True -COMMON_PYTHON_FIELDS = (*COMMON_TARGET_FIELDS, Dependencies, Compatibility, ProvidesField) - - class PythonBinary(Target): """A Python target that can be converted into an executable Pex file. @@ -264,11 +191,92 @@ class PythonBinary(Target): ) +# ----------------------------------------------------------------------------------------------- +# `python_library` target +# ----------------------------------------------------------------------------------------------- + + +class PythonLibrarySources(PythonSources): + default = ("*.py", "!test_*.py", "!*_test.py", "!conftest.py") + + class PythonLibrary(Target): alias = "python_library" core_fields = (*COMMON_PYTHON_FIELDS, PythonLibrarySources) +# ----------------------------------------------------------------------------------------------- +# `python_tests` target +# ----------------------------------------------------------------------------------------------- + + +class PythonTestsSources(PythonSources): + default = ("test_*.py", "*_test.py", "conftest.py") + + +class Coverage(StringOrStringSequenceField): + """The module(s) whose coverage should be generated, e.g. `['pants.util']`.""" + + alias = "coverage" + + def determine_packages_to_cover( + self, *, specified_source_files: SourceFiles + ) -> Tuple[str, ...]: + """Either return the specified `coverage` field value or, if not defined, attempt to + generate packages with a heuristic that tests have the same package name as their source + code. + + This heuristic about package names works when either the tests live in the same folder as + their source code, or there is a parallel file structure with the same top-level package + names, e.g. `src/python/project` and `tests/python/project` (but not + `tests/python/test_project`). + """ + if self.value is not None: + return self.value + return tuple( + sorted( + { + # Turn file paths into package names. + os.path.dirname(source_file).replace(os.sep, ".") + for source_file in specified_source_files.snapshot.files + } + ) + ) + + +class Timeout(IntField): + """A timeout (in seconds) which covers the total runtime of all tests in this target. + + This only applies if `--pytest-timeouts` is set to True. + """ + + alias = "timeout" + + @classmethod + def compute_value(cls, raw_value: Optional[int], *, address: Address) -> Optional[int]: + value = super().compute_value(raw_value, address=address) + if value is not None and value < 1: + raise InvalidFieldException( + f"The value for the `timeout` field in target {address} must be > 0, but was " + f"{value}." + ) + return value + + def calculate_from_global_options(self, pytest: PyTest) -> Optional[int]: + """Determine the timeout (in seconds) after applying global `pytest` options.""" + if not pytest.timeouts_enabled: + return None + if self.value is None: + if pytest.timeout_default is None: + return None + result = pytest.timeout_default + else: + result = self.value + if pytest.timeout_maximum is not None: + return min(result, pytest.timeout_maximum) + return result + + class PythonTests(Target): """Python tests (either Pytest-style or unittest style).""" diff --git a/src/python/pants/backend/python/rules/targets_test.py b/src/python/pants/backend/python/rules/targets_test.py index 5de93a0c89d..5316dde20c8 100644 --- a/src/python/pants/backend/python/rules/targets_test.py +++ b/src/python/pants/backend/python/rules/targets_test.py @@ -5,13 +5,10 @@ import pytest -from pants.backend.python.rules.targets import PythonBinarySources, Timeout +from pants.backend.python.rules.targets import Timeout from pants.backend.python.subsystems.pytest import PyTest from pants.build_graph.address import Address -from pants.engine.rules import RootRule -from pants.engine.scheduler import ExecutionError -from pants.engine.target import HydratedSources, HydrateSourcesRequest, InvalidFieldException -from pants.engine.target import rules as target_rules +from pants.engine.target import InvalidFieldException from pants.testutil.subsystem.util import global_subsystem_instance from pants.testutil.test_base import TestBase @@ -65,28 +62,3 @@ def test_no_field_timeout_and_default_greater_than_max(self) -> None: def test_timeouts_disabled(self) -> None: self.assert_timeout_calculated(field_value=10, timeouts_enabled=False, expected=None) - - -class TestPythonSources(TestBase): - @classmethod - def rules(cls): - return [*target_rules(), RootRule(HydrateSourcesRequest)] - - def test_python_binary_sources_validation(self) -> None: - self.create_files(path="", files=["f1.py", "f2.py"]) - address = Address.parse(":binary") - - zero_sources = PythonBinarySources(None, address=address) - assert ( - self.request_single_product(HydratedSources, zero_sources.request).snapshot.files == () - ) - - one_source = PythonBinarySources(["f1.py"], address=address) - assert self.request_single_product(HydratedSources, one_source.request).snapshot.files == ( - "f1.py", - ) - - multiple_sources = PythonBinarySources(["f1.py", "f2.py"], address=address) - with pytest.raises(ExecutionError) as exc: - self.request_single_product(HydratedSources, multiple_sources.request) - assert "has 2 sources" in str(exc.value) diff --git a/src/python/pants/engine/target.py b/src/python/pants/engine/target.py index 789344a5ebd..a8caf42b858 100644 --- a/src/python/pants/engine/target.py +++ b/src/python/pants/engine/target.py @@ -27,6 +27,7 @@ from pants.util.frozendict import FrozenDict from pants.util.meta import frozen_after_init from pants.util.ordered_set import FrozenOrderedSet +from pants.util.strutil import pluralize # ----------------------------------------------------------------------------------------------- # Core Field abstractions @@ -880,6 +881,7 @@ class Sources(AsyncField): sanitized_raw_value: Optional[Tuple[str, ...]] default: ClassVar[Optional[Tuple[str, ...]]] = None expected_file_extensions: ClassVar[Optional[Tuple[str, ...]]] = None + expected_num_files: ClassVar[Optional[Union[int, range]]] = None @classmethod def sanitize_raw_value( @@ -901,13 +903,16 @@ def sanitize_raw_value( return tuple(sorted(value_or_default)) def validate_snapshot(self, snapshot: Snapshot) -> None: - """Perform any additional validation on the resulting snapshot, e.g. ensuring that there are - only a certain number of resolved files. + """Perform any additional validation on the resulting snapshot, e.g. ensuring that certain + banned files are not used. To enforce that the resulting files end in certain extensions, such as `.py` or `.java`, set the class property `expected_file_extensions`. + + To enforce that there are only a certain number of resulting files, such as binary targets + checking for only 0-1 sources, set the class property `expected_num_files`. """ - if self.expected_file_extensions: + if self.expected_file_extensions is not None: bad_files = [ fp for fp in snapshot.files @@ -923,6 +928,27 @@ def validate_snapshot(self, snapshot: Snapshot) -> None: f"The {repr(self.alias)} field in target {self.address} must only contain " f"files that end in {expected}, but it had these files: {sorted(bad_files)}." ) + if self.expected_num_files is not None: + num_files = len(snapshot.files) + is_bad_num_files = ( + num_files not in self.expected_num_files + if isinstance(self.expected_num_files, range) + else num_files != self.expected_num_files + ) + if is_bad_num_files: + if isinstance(self.expected_num_files, range): + if len(self.expected_num_files) == 2: + expected_str = ( + " or ".join(str(n) for n in self.expected_num_files) + " files" + ) + else: + expected_str = f"a number of files in the range `{self.expected_num_files}`" + else: + expected_str = pluralize(self.expected_num_files, "file") + raise InvalidFieldException( + f"The {repr(self.alias)} field in target {self.address} must have " + f"{expected_str}, but it had {pluralize(num_files, 'file')}." + ) @final @property diff --git a/src/python/pants/engine/target_test.py b/src/python/pants/engine/target_test.py index 8d4adacdbaa..97e0aa1014d 100644 --- a/src/python/pants/engine/target_test.py +++ b/src/python/pants/engine/target_test.py @@ -4,7 +4,7 @@ from dataclasses import dataclass from enum import Enum from pathlib import PurePath -from typing import Any, ClassVar, Dict, Iterable, List, Optional, Tuple +from typing import Any, ClassVar, Dict, Iterable, List, Optional, Tuple, Type import pytest from typing_extensions import final @@ -590,3 +590,34 @@ class ExpectedExtensionsSources(Sources): assert self.request_single_product( HydratedSources, valid_sources.request ).snapshot.files == ("src/fortran/s.f95",) + + def test_expected_num_files(self) -> None: + class ExpectedNumber(Sources): + expected_num_files = 2 + + class ExpectedRange(Sources): + # We allow for 1 or 3 files + expected_num_files = range(1, 4, 2) + + self.create_files("", files=["f1.txt", "f2.txt", "f3.txt", "f4.txt"]) + + def hydrate(sources_cls: Type[Sources], sources: Iterable[str]) -> HydratedSources: + return self.request_single_product( + HydratedSources, sources_cls(sources, address=Address.parse(":example")).request + ) + + with pytest.raises(ExecutionError) as exc: + hydrate(ExpectedNumber, []) + assert "must have 2 files" in str(exc.value) + with pytest.raises(ExecutionError) as exc: + hydrate(ExpectedRange, ["f1.txt", "f2.txt"]) + assert "must have 1 or 3 files" in str(exc.value) + + # Also check that we support valid # files. + assert hydrate(ExpectedNumber, ["f1.txt", "f2.txt"]).snapshot.files == ("f1.txt", "f2.txt") + assert hydrate(ExpectedRange, ["f1.txt"]).snapshot.files == ("f1.txt",) + assert hydrate(ExpectedRange, ["f1.txt", "f2.txt", "f3.txt"]).snapshot.files == ( + "f1.txt", + "f2.txt", + "f3.txt", + )