Skip to content

Commit

Permalink
Allow Sources subclasses to specify the # of expected files (#9466)
Browse files Browse the repository at this point in the history
Twice now, we've had `Sources` fields that require a specific number of files, so this PR formalizes a mechanism for declaring that requirement on the field class.

Error message #1:

> ERROR: The 'sources' field in target build-support/bin:check_banned_imports must have 3 files, but it had 1 file.

Error message #2: 

> ERROR: The 'sources' field in target build-support/bin:check_banned_imports must have 2 or 3 files, but it had 1 file.

Error message #3: 

> ERROR: The 'sources' field in target build-support/bin:check_banned_imports must have a number of files in the range `range(20, 20000, 10)`, but it had 1 file.


[ci skip-rust-tests]  # No Rust changes made.
[ci skip-jvm-tests]  # No JVM changes made.
  • Loading branch information
Eric-Arellano authored Apr 5, 2020
1 parent 1f0e13c commit 3e0b781
Show file tree
Hide file tree
Showing 4 changed files with 154 additions and 117 deletions.
174 changes: 91 additions & 83 deletions src/python/pants/backend/python/rules/targets.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,6 @@

from pants.backend.python.subsystems.pytest import PyTest
from pants.build_graph.address import Address
from pants.engine.fs import Snapshot
from pants.engine.objects import union
from pants.engine.target import (
COMMON_TARGET_FIELDS,
BoolField,
Expand All @@ -23,32 +21,15 @@
from pants.python.python_setup import PythonSetup
from pants.rules.core.determine_source_files import SourceFiles

# -----------------------------------------------------------------------------------------------
# Common fields
# -----------------------------------------------------------------------------------------------


@union
class PythonSources(Sources):
    """Base class for the Python-specific `Sources` fields defined in this file."""

    # Snapshot validation in `Sources` rejects any resolved file that does not end in one of
    # these extensions.
    expected_file_extensions = (".py",)


class PythonLibrarySources(PythonSources):
    """The `sources` field for Python library code."""

    # Default globs: every `.py` file except those matching the test-file patterns
    # (`test_*.py`, `*_test.py`) and `conftest.py`, which are excluded via the `!` prefix.
    default = ("*.py", "!test_*.py", "!*_test.py", "!conftest.py")


class PythonTestsSources(PythonSources):
    """The `sources` field for Python test code."""

    # Default globs: files named like tests (`test_*.py`, `*_test.py`) plus `conftest.py`.
    default = ("test_*.py", "*_test.py", "conftest.py")


class PythonBinarySources(PythonSources):
    """The `sources` field for a binary target, which may resolve to at most one file."""

    def validate_snapshot(self, snapshot: Snapshot) -> None:
        """Run the inherited validation, then reject snapshots with more than one file.

        Raises:
            InvalidFieldException: if the snapshot has any number of files other than 0 or 1.
        """
        super().validate_snapshot(snapshot)
        num_files = len(snapshot.files)
        # Guard clause: zero or one file is the only acceptable outcome.
        if num_files in (0, 1):
            return
        raise InvalidFieldException(
            f"The {repr(self.alias)} field in target {self.address} must only have 0 or 1 "
            f"files because it is a binary target, but it has {num_files} sources: "
            f"{sorted(snapshot.files)}.\n\nTo use any additional files, put them in a "
            "`python_library` and then add that `python_library` as a `dependency`."
        )


class Compatibility(StringOrStringSequenceField):
"""A string for Python interpreter constraints on this target.
Expand All @@ -67,67 +48,16 @@ def value_or_global_default(self, python_setup: PythonSetup) -> Tuple[str, ...]:
return python_setup.compatibility_or_constraints(self.value)


class Coverage(StringOrStringSequenceField):
    """The module(s) whose coverage should be generated, e.g. `['pants.util']`."""

    alias = "coverage"

    def determine_packages_to_cover(
        self, *, specified_source_files: SourceFiles
    ) -> Tuple[str, ...]:
        """Return the packages for which coverage should be generated.

        An explicitly set `coverage` field value is returned unchanged. Otherwise, the packages
        are guessed with the heuristic that tests share the package name of the code they test.

        That heuristic holds when the tests live in the same folder as their source code, or
        when there is a parallel file structure with the same top-level package names, e.g.
        `src/python/project` and `tests/python/project` (but not `tests/python/test_project`).
        """
        explicit_packages = self.value
        if explicit_packages is not None:
            return explicit_packages
        # Turn each file's directory into a dotted package name; the set drops duplicates
        # from files that share a directory.
        inferred_packages = {
            os.path.dirname(file_path).replace(os.sep, ".")
            for file_path in specified_source_files.snapshot.files
        }
        return tuple(sorted(inferred_packages))


class Timeout(IntField):
"""A timeout (in seconds) which covers the total runtime of all tests in this target.
COMMON_PYTHON_FIELDS = (*COMMON_TARGET_FIELDS, Dependencies, Compatibility, ProvidesField)

This only applies if `--pytest-timeouts` is set to True.
"""

alias = "timeout"
# -----------------------------------------------------------------------------------------------
# `python_binary` target
# -----------------------------------------------------------------------------------------------

@classmethod
def compute_value(cls, raw_value: Optional[int], *, address: Address) -> Optional[int]:
value = super().compute_value(raw_value, address=address)
if value is not None and value < 1:
raise InvalidFieldException(
f"The value for the `timeout` field in target {address} must be > 0, but was "
f"{value}."
)
return value

def calculate_from_global_options(self, pytest: PyTest) -> Optional[int]:
"""Determine the timeout (in seconds) after applying global `pytest` options."""
if not pytest.timeouts_enabled:
return None
if self.value is None:
if pytest.timeout_default is None:
return None
result = pytest.timeout_default
else:
result = self.value
if pytest.timeout_maximum is not None:
return min(result, pytest.timeout_maximum)
return result
class PythonBinarySources(PythonSources):
    """The `sources` field for a binary target."""

    # `range(0, 2)` contains only 0 and 1, so the field must resolve to at most one file.
    expected_num_files = range(0, 2)


class PythonEntryPoint(StringField):
Expand Down Expand Up @@ -236,9 +166,6 @@ class PexEmitWarnings(BoolField):
default = True


COMMON_PYTHON_FIELDS = (*COMMON_TARGET_FIELDS, Dependencies, Compatibility, ProvidesField)


class PythonBinary(Target):
"""A Python target that can be converted into an executable Pex file.
Expand All @@ -264,11 +191,92 @@ class PythonBinary(Target):
)


# -----------------------------------------------------------------------------------------------
# `python_library` target
# -----------------------------------------------------------------------------------------------


class PythonLibrarySources(PythonSources):
    """The `sources` field for the `python_library` target."""

    # Default globs: every `.py` file except those matching the test-file patterns
    # (`test_*.py`, `*_test.py`) and `conftest.py`, which are excluded via the `!` prefix.
    default = ("*.py", "!test_*.py", "!*_test.py", "!conftest.py")


class PythonLibrary(Target):
    """The `python_library` target type."""

    # NOTE(review): presumably the name used to declare this target in BUILD files — confirm.
    alias = "python_library"
    # The fields common to the Python targets, plus this target's own `sources` field.
    core_fields = (*COMMON_PYTHON_FIELDS, PythonLibrarySources)


# -----------------------------------------------------------------------------------------------
# `python_tests` target
# -----------------------------------------------------------------------------------------------


class PythonTestsSources(PythonSources):
    """The `sources` field for Python test code."""

    # Default globs: files named like tests (`test_*.py`, `*_test.py`) plus `conftest.py`.
    default = ("test_*.py", "*_test.py", "conftest.py")


class Coverage(StringOrStringSequenceField):
    """The module(s) whose coverage should be generated, e.g. `['pants.util']`."""

    alias = "coverage"

    def determine_packages_to_cover(
        self, *, specified_source_files: SourceFiles
    ) -> Tuple[str, ...]:
        """Return the packages for which coverage should be generated.

        An explicitly set `coverage` field value is returned unchanged. Otherwise, the packages
        are guessed with the heuristic that tests share the package name of the code they test.

        That heuristic holds when the tests live in the same folder as their source code, or
        when there is a parallel file structure with the same top-level package names, e.g.
        `src/python/project` and `tests/python/project` (but not `tests/python/test_project`).
        """
        explicit_packages = self.value
        if explicit_packages is not None:
            return explicit_packages
        # Turn each file's directory into a dotted package name; the set drops duplicates
        # from files that share a directory.
        inferred_packages = {
            os.path.dirname(file_path).replace(os.sep, ".")
            for file_path in specified_source_files.snapshot.files
        }
        return tuple(sorted(inferred_packages))


class Timeout(IntField):
    """A timeout (in seconds) which covers the total runtime of all tests in this target.

    This only applies if `--pytest-timeouts` is set to True.
    """

    alias = "timeout"

    @classmethod
    def compute_value(cls, raw_value: Optional[int], *, address: Address) -> Optional[int]:
        """Compute the field value via the parent class, then require it to be positive.

        Raises:
            InvalidFieldException: if a value was given and it is less than 1.
        """
        timeout = super().compute_value(raw_value, address=address)
        # An unset timeout is allowed; only an explicit value below 1 is an error.
        if timeout is None or timeout >= 1:
            return timeout
        raise InvalidFieldException(
            f"The value for the `timeout` field in target {address} must be > 0, but was "
            f"{timeout}."
        )

    def calculate_from_global_options(self, pytest: PyTest) -> Optional[int]:
        """Determine the timeout (in seconds) after applying global `pytest` options.

        Returns None when timeouts are disabled, or when neither the field nor the global
        default supplies a value. Otherwise the field value (falling back to the global
        default) is returned, capped at the global maximum when one is set.
        """
        if not pytest.timeouts_enabled:
            return None
        # The field value wins; the global default is consulted only when the field is unset.
        effective = self.value if self.value is not None else pytest.timeout_default
        if effective is None:
            return None
        maximum = pytest.timeout_maximum
        return effective if maximum is None else min(effective, maximum)


class PythonTests(Target):
"""Python tests (either Pytest-style or unittest style)."""

Expand Down
32 changes: 2 additions & 30 deletions src/python/pants/backend/python/rules/targets_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,10 @@

import pytest

from pants.backend.python.rules.targets import PythonBinarySources, Timeout
from pants.backend.python.rules.targets import Timeout
from pants.backend.python.subsystems.pytest import PyTest
from pants.build_graph.address import Address
from pants.engine.rules import RootRule
from pants.engine.scheduler import ExecutionError
from pants.engine.target import HydratedSources, HydrateSourcesRequest, InvalidFieldException
from pants.engine.target import rules as target_rules
from pants.engine.target import InvalidFieldException
from pants.testutil.subsystem.util import global_subsystem_instance
from pants.testutil.test_base import TestBase

Expand Down Expand Up @@ -65,28 +62,3 @@ def test_no_field_timeout_and_default_greater_than_max(self) -> None:

def test_timeouts_disabled(self) -> None:
self.assert_timeout_calculated(field_value=10, timeouts_enabled=False, expected=None)


class TestPythonSources(TestBase):
    """Tests for the Python-specific `Sources` fields."""

    @classmethod
    def rules(cls):
        """Register the target rules plus `HydrateSourcesRequest` as a root rule."""
        return [*target_rules(), RootRule(HydrateSourcesRequest)]

    def test_python_binary_sources_validation(self) -> None:
        """`PythonBinarySources` accepts zero or one file and rejects two."""
        self.create_files(path="", files=["f1.py", "f2.py"])
        binary_address = Address.parse(":binary")

        def hydrate(sources_field: PythonBinarySources) -> HydratedSources:
            return self.request_single_product(HydratedSources, sources_field.request)

        no_sources = PythonBinarySources(None, address=binary_address)
        assert hydrate(no_sources).snapshot.files == ()

        single_source = PythonBinarySources(["f1.py"], address=binary_address)
        assert hydrate(single_source).snapshot.files == ("f1.py",)

        # The field is built outside the `raises` block so that only hydration is expected
        # to fail.
        two_sources = PythonBinarySources(["f1.py", "f2.py"], address=binary_address)
        with pytest.raises(ExecutionError) as exc:
            hydrate(two_sources)
        assert "has 2 sources" in str(exc.value)
32 changes: 29 additions & 3 deletions src/python/pants/engine/target.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from pants.util.frozendict import FrozenDict
from pants.util.meta import frozen_after_init
from pants.util.ordered_set import FrozenOrderedSet
from pants.util.strutil import pluralize

# -----------------------------------------------------------------------------------------------
# Core Field abstractions
Expand Down Expand Up @@ -880,6 +881,7 @@ class Sources(AsyncField):
sanitized_raw_value: Optional[Tuple[str, ...]]
default: ClassVar[Optional[Tuple[str, ...]]] = None
expected_file_extensions: ClassVar[Optional[Tuple[str, ...]]] = None
expected_num_files: ClassVar[Optional[Union[int, range]]] = None

@classmethod
def sanitize_raw_value(
Expand All @@ -901,13 +903,16 @@ def sanitize_raw_value(
return tuple(sorted(value_or_default))

def validate_snapshot(self, snapshot: Snapshot) -> None:
"""Perform any additional validation on the resulting snapshot, e.g. ensuring that there are
only a certain number of resolved files.
"""Perform any additional validation on the resulting snapshot, e.g. ensuring that certain
banned files are not used.
To enforce that the resulting files end in certain extensions, such as `.py` or `.java`, set
the class property `expected_file_extensions`.
To enforce that there are only a certain number of resulting files, such as binary targets
checking for only 0-1 sources, set the class property `expected_num_files`.
"""
if self.expected_file_extensions:
if self.expected_file_extensions is not None:
bad_files = [
fp
for fp in snapshot.files
Expand All @@ -923,6 +928,27 @@ def validate_snapshot(self, snapshot: Snapshot) -> None:
f"The {repr(self.alias)} field in target {self.address} must only contain "
f"files that end in {expected}, but it had these files: {sorted(bad_files)}."
)
if self.expected_num_files is not None:
num_files = len(snapshot.files)
is_bad_num_files = (
num_files not in self.expected_num_files
if isinstance(self.expected_num_files, range)
else num_files != self.expected_num_files
)
if is_bad_num_files:
if isinstance(self.expected_num_files, range):
if len(self.expected_num_files) == 2:
expected_str = (
" or ".join(str(n) for n in self.expected_num_files) + " files"
)
else:
expected_str = f"a number of files in the range `{self.expected_num_files}`"
else:
expected_str = pluralize(self.expected_num_files, "file")
raise InvalidFieldException(
f"The {repr(self.alias)} field in target {self.address} must have "
f"{expected_str}, but it had {pluralize(num_files, 'file')}."
)

@final
@property
Expand Down
33 changes: 32 additions & 1 deletion src/python/pants/engine/target_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from dataclasses import dataclass
from enum import Enum
from pathlib import PurePath
from typing import Any, ClassVar, Dict, Iterable, List, Optional, Tuple
from typing import Any, ClassVar, Dict, Iterable, List, Optional, Tuple, Type

import pytest
from typing_extensions import final
Expand Down Expand Up @@ -590,3 +590,34 @@ class ExpectedExtensionsSources(Sources):
assert self.request_single_product(
HydratedSources, valid_sources.request
).snapshot.files == ("src/fortran/s.f95",)

def test_expected_num_files(self) -> None:
    """`expected_num_files` is enforced both as an exact count and as a `range`."""
    self.create_files("", files=["f1.txt", "f2.txt", "f3.txt", "f4.txt"])

    class ExpectedNumber(Sources):
        expected_num_files = 2

    class ExpectedRange(Sources):
        # Stepping by 2 from 1 up to (but excluding) 4 permits exactly 1 or 3 files.
        expected_num_files = range(1, 4, 2)

    def hydrate(sources_cls: Type[Sources], sources: Iterable[str]) -> HydratedSources:
        sources_field = sources_cls(sources, address=Address.parse(":example"))
        return self.request_single_product(HydratedSources, sources_field.request)

    # Invalid file counts are rejected with a message naming the expected count.
    with pytest.raises(ExecutionError) as number_exc:
        hydrate(ExpectedNumber, [])
    assert "must have 2 files" in str(number_exc.value)

    with pytest.raises(ExecutionError) as range_exc:
        hydrate(ExpectedRange, ["f1.txt", "f2.txt"])
    assert "must have 1 or 3 files" in str(range_exc.value)

    # Valid file counts hydrate successfully.
    exact_result = hydrate(ExpectedNumber, ["f1.txt", "f2.txt"])
    assert exact_result.snapshot.files == ("f1.txt", "f2.txt")

    one_file_result = hydrate(ExpectedRange, ["f1.txt"])
    assert one_file_result.snapshot.files == ("f1.txt",)

    three_files_result = hydrate(ExpectedRange, ["f1.txt", "f2.txt", "f3.txt"])
    assert three_files_result.snapshot.files == ("f1.txt", "f2.txt", "f3.txt")

0 comments on commit 3e0b781

Please sign in to comment.