Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow HydratedSourcesRequest to indicate which Sources types are expected #9641

Merged
merged 5 commits into from
Apr 28, 2020
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 5 additions & 13 deletions src/python/pants/backend/python/rules/importable_python_sources.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,14 @@

from pants.backend.python.rules.inject_init import InitInjectedSnapshot, InjectInitRequest
from pants.backend.python.rules.inject_init import rules as inject_init_rules
from pants.backend.python.target_types import PythonRequirementsFileSources, PythonSources
from pants.backend.python.target_types import PythonSources
from pants.core.target_types import FilesSources, ResourcesSources
from pants.core.util_rules import determine_source_files
from pants.core.util_rules.determine_source_files import AllSourceFilesRequest, SourceFiles
from pants.engine.fs import Snapshot
from pants.engine.rules import RootRule, rule
from pants.engine.selectors import Get
from pants.engine.target import Sources, Target, Targets
from pants.engine.target import Sources, Targets


@dataclass(frozen=True)
Expand All @@ -33,19 +33,11 @@ class ImportablePythonSources:

@rule
async def prepare_python_sources(targets: Targets) -> ImportablePythonSources:
def is_relevant(tgt: Target) -> bool:
# NB: PythonRequirementsFileSources is a subclass of FilesSources. We filter it out so that
# requirements.txt is not included. If the user intended for the file to be included, they
# should use a normal `files()` target rather than `python_requirements()`.
return (
tgt.has_field(PythonSources)
or tgt.has_field(ResourcesSources)
or (tgt.has_field(FilesSources) and not tgt.has_field(PythonRequirementsFileSources))
)

stripped_sources = await Get[SourceFiles](
AllSourceFilesRequest(
(tgt.get(Sources) for tgt in targets if is_relevant(tgt)), strip_source_roots=True
(tgt.get(Sources) for tgt in targets),
valid_sources_types=(PythonSources, ResourcesSources, FilesSources),
strip_source_roots=True,
)
Copy link
Contributor Author

@Eric-Arellano Eric-Arellano Apr 28, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Once we add codegen, with this PR:

         AllSourceFilesRequest(
             (tgt.get(Sources) for tgt in targets),
             valid_sources_types=(PythonSources, ResourcesSources, FilesSources),
             codegen_enabled=True,
             strip_source_roots=True,
         )

With the approach in #9634:

        AllSourceFilesRequest(
            (
                tgt.get(Sources)
                for tgt in targets
                if any(
                    tgt.has_field(sources)
                    for sources in (PythonSources, ResourcesSources, FilesSources, CodegenSources)
                )
            ),
            codegen_language=PythonSources,
            strip_source_roots=True,
        )

)
init_injected = await Get[InitInjectedSnapshot](InjectInitRequest(stripped_sources.snapshot))
Expand Down
10 changes: 8 additions & 2 deletions src/python/pants/backend/python/rules/run_setup_py.py
Original file line number Diff line number Diff line change
Expand Up @@ -472,7 +472,11 @@ async def get_sources(
) -> SetupPySources:
targets = request.targets
stripped_srcs_list = await MultiGet(
Get[SourceRootStrippedSources](StripSourcesFieldRequest(target.get(Sources)))
Get[SourceRootStrippedSources](
StripSourcesFieldRequest(
target.get(Sources), valid_sources_types=(PythonSources, ResourcesSources)
)
)
for target in targets
)

Expand Down Expand Up @@ -518,7 +522,9 @@ async def get_ancestor_init_py(
"""
source_roots = source_root_config.get_source_roots()
sources = await Get[SourceFiles](
AllSourceFilesRequest(tgt[PythonSources] for tgt in targets if tgt.has_field(PythonSources))
AllSourceFilesRequest(
(tgt.get(Sources) for tgt in targets), valid_sources_types=(PythonSources,)
)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This change will be necessary once we have codegen. tgt.has_field(PythonSources) will fail when encountering a ProtobufLibrary because it has ProtobufSources, even if that can be generated into PythonSources.

Copy link
Contributor Author

@Eric-Arellano Eric-Arellano Apr 28, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Once we add codegen, with this PR:

         AllSourceFilesRequest(
              (tgt.get(Sources) for tgt in targets),
              valid_sources_types=(PythonSources,),
              codegen_enabled=True
         )

If we go with the approach in #9634:

        AllSourceFilesRequest(
            (
                tgt[Sources]
                for tgt in targets
                if tgt.has_field(PythonSources) or tgt.has_field(CodegenSources)
            ),
            codegen_language=PythonSources,
        )

)
# Find the ancestors of all dirs containing .py files, including those dirs themselves.
source_dir_ancestors: Set[Tuple[str, str]] = set() # Items are (src_root, path incl. src_root).
Expand Down
5 changes: 1 addition & 4 deletions src/python/pants/backend/python/target_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@

from pants.backend.python.python_artifact import PythonArtifact
from pants.backend.python.subsystems.pytest import PyTest
from pants.core.target_types import FilesSources
from pants.core.util_rules.determine_source_files import SourceFiles
from pants.engine.addresses import Address
from pants.engine.fs import Snapshot
Expand Down Expand Up @@ -414,9 +413,7 @@ class PythonRequirementLibrary(Target):
# -----------------------------------------------------------------------------------------------


# NB: This subclasses FilesSources to ensure that we still properly handle stripping source roots,
# but we still new type so that we can distinguish between normal FilesSources vs. this field.
class PythonRequirementsFileSources(FilesSources):
class PythonRequirementsFileSources(Sources):
pass


Expand Down
41 changes: 33 additions & 8 deletions src/python/pants/core/util_rules/determine_source_files.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# Licensed under the Apache License, Version 2.0 (see LICENSE).

from dataclasses import dataclass
from typing import Iterable, Tuple, Union
from typing import Iterable, Tuple, Type, Union

from pants.base.specs import AddressSpec, OriginSpec
from pants.core.util_rules import strip_source_roots
Expand Down Expand Up @@ -35,28 +35,37 @@ def files(self) -> Tuple[str, ...]:
@dataclass(unsafe_hash=True)
class AllSourceFilesRequest:
sources_fields: Tuple[SourcesField, ...]
strip_source_roots: bool = False
valid_sources_types: Tuple[Type[SourcesField], ...]
strip_source_roots: bool

def __init__(
self, sources_fields: Iterable[SourcesField], *, strip_source_roots: bool = False
self,
sources_fields: Iterable[SourcesField],
*,
valid_sources_types: Iterable[Type[SourcesField]] = (SourcesField,),
strip_source_roots: bool = False
) -> None:
self.sources_fields = tuple(sources_fields)
self.valid_sources_types = tuple(valid_sources_types)
self.strip_source_roots = strip_source_roots


@frozen_after_init
@dataclass(unsafe_hash=True)
class SpecifiedSourceFilesRequest:
sources_fields_with_origins: Tuple[Tuple[SourcesField, OriginSpec], ...]
strip_source_roots: bool = False
valid_sources_types: Tuple[Type[SourcesField], ...]
strip_source_roots: bool

def __init__(
self,
sources_fields_with_origins: Iterable[Tuple[SourcesField, OriginSpec]],
*,
valid_sources_types: Iterable[Type[SourcesField]] = (SourcesField,),
strip_source_roots: bool = False
) -> None:
self.sources_fields_with_origins = tuple(sources_fields_with_origins)
self.valid_sources_types = tuple(valid_sources_types)
self.strip_source_roots = strip_source_roots


Expand All @@ -81,15 +90,23 @@ async def determine_all_source_files(request: AllSourceFilesRequest) -> SourceFi
"""Merge all `Sources` fields into one Snapshot."""
if request.strip_source_roots:
stripped_snapshots = await MultiGet(
Get[SourceRootStrippedSources](StripSourcesFieldRequest(sources_field))
Get[SourceRootStrippedSources](
StripSourcesFieldRequest(
sources_field, valid_sources_types=request.valid_sources_types
)
)
for sources_field in request.sources_fields
)
digests_to_merge = tuple(
stripped_snapshot.snapshot.directory_digest for stripped_snapshot in stripped_snapshots
)
else:
all_hydrated_sources = await MultiGet(
Get[HydratedSources](HydrateSourcesRequest(sources_field))
Get[HydratedSources](
HydrateSourcesRequest(
sources_field, valid_sources_types=request.valid_sources_types
)
)
for sources_field in request.sources_fields
)
digests_to_merge = tuple(
Expand All @@ -104,7 +121,11 @@ async def determine_specified_source_files(request: SpecifiedSourceFilesRequest)
"""Determine the specified `sources` for targets, possibly finding a subset of the original
`sources` fields if the user supplied file arguments."""
all_hydrated_sources = await MultiGet(
Get[HydratedSources](HydrateSourcesRequest(sources_field_with_origin[0]))
Get[HydratedSources](
HydrateSourcesRequest(
sources_field_with_origin[0], valid_sources_types=request.valid_sources_types
)
)
for sources_field_with_origin in request.sources_fields_with_origins
)

Expand Down Expand Up @@ -133,7 +154,11 @@ async def determine_specified_source_files(request: SpecifiedSourceFilesRequest)
all_sources_fields = (*full_snapshots.keys(), *snapshot_subset_requests.keys())
stripped_snapshots = await MultiGet(
Get[SourceRootStrippedSources](
StripSourcesFieldRequest(sources_field, specified_files_snapshot=snapshot)
StripSourcesFieldRequest(
sources_field,
specified_files_snapshot=snapshot,
valid_sources_types=request.valid_sources_types,
)
)
for sources_field, snapshot in zip(all_sources_fields, all_snapshots)
)
Expand Down
24 changes: 21 additions & 3 deletions src/python/pants/core/util_rules/strip_source_roots.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import itertools
from dataclasses import dataclass
from pathlib import PurePath
from typing import Optional, cast
from typing import Iterable, Optional, Tuple, Type, cast

from pants.core.target_types import FilesSources
from pants.engine.addresses import Address
Expand All @@ -23,6 +23,7 @@
from pants.engine.target import Sources as SourcesField
from pants.engine.target import rules as target_rules
from pants.source.source_root import NoSourceRootError, SourceRootConfig
from pants.util.meta import frozen_after_init


@dataclass(frozen=True)
Expand All @@ -46,7 +47,8 @@ class StripSnapshotRequest:
representative_path: Optional[str] = None


@dataclass(frozen=True)
@frozen_after_init
@dataclass(unsafe_hash=True)
class StripSourcesFieldRequest:
"""A request to strip source roots for every file in a `Sources` field.

Expand All @@ -56,8 +58,20 @@ class StripSourcesFieldRequest:
"""

sources_field: SourcesField
valid_sources_types: Tuple[Type[SourcesField], ...] = (SourcesField,)
specified_files_snapshot: Optional[Snapshot] = None

def __init__(
self,
sources_field: SourcesField,
*,
valid_sources_types: Iterable[Type[SourcesField]] = (SourcesField,),
specified_files_snapshot: Optional[Snapshot] = None,
) -> None:
self.sources_field = sources_field
self.valid_sources_types = tuple(valid_sources_types)
self.specified_files_snapshot = specified_files_snapshot


@rule
async def strip_source_roots_from_snapshot(
Expand Down Expand Up @@ -129,7 +143,11 @@ async def strip_source_roots_from_sources_field(
if request.specified_files_snapshot is not None:
sources_snapshot = request.specified_files_snapshot
else:
hydrated_sources = await Get[HydratedSources](HydrateSourcesRequest(request.sources_field))
hydrated_sources = await Get[HydratedSources](
HydrateSourcesRequest(
request.sources_field, valid_sources_types=request.valid_sources_types
)
)
sources_snapshot = hydrated_sources.snapshot

if not sources_snapshot.files:
Expand Down
43 changes: 39 additions & 4 deletions src/python/pants/engine/target.py
Original file line number Diff line number Diff line change
Expand Up @@ -1343,15 +1343,39 @@ def filespec(self) -> Filespec:
)


@dataclass(frozen=True)
@frozen_after_init
@dataclass(unsafe_hash=True)
class HydrateSourcesRequest:
field: Sources
valid_sources_types: Tuple[Type[Sources], ...]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Naming wise, the word "valid" would seem to imply that the request will fail if the sources can't be converted to one of these types. "desired" would be a bit verbose. Maybe just output_ ...?


def __init__(
self, field: Sources, *, valid_sources_types: Iterable[Type[Sources]] = (Sources,)
) -> None:
"""Convert raw sources globs into an instance of HydratedSources.

If you only want to convert certain Sources fields, such as only PythonSources, set
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

s/convert/handle/

You could probably also the fact that you might get subclasses of your requested types, but that the output type will always be an exact match.

`valid_sources_types`. Any invalid sources will return an empty `HydratedSources` instance,
indicated by the attribute `output_type = None`.
"""
self.field = field
self.valid_sources_types = tuple(valid_sources_types)


@dataclass(frozen=True)
class HydratedSources:
"""The result of hydrating a SourcesField.

The `output_type` will indicate which of the `HydrateSourcesRequest.valid_sources_type` the
result corresponds to, e.g. if the result comes from `FilesSources` vs. `PythonSources`. If this
value is None, then the input `Sources` field was not one of the expected types. This property
allows for switching on the result, e.g. handling hydrated files() sources differently than
hydrated Python sources.
"""

snapshot: Snapshot
filespec: Filespec
output_type: Optional[Type[Sources]]

def eager_fileset_with_spec(self, *, address: Address) -> EagerFilesetWithSpec:
return EagerFilesetWithSpec(address.spec_path, self.filespec, self.snapshot)
Expand All @@ -1362,10 +1386,21 @@ async def hydrate_sources(
request: HydrateSourcesRequest, glob_match_error_behavior: GlobMatchErrorBehavior
) -> HydratedSources:
sources_field = request.field
globs = sources_field.sanitized_raw_value

output_type = next(
(
valid_type
for valid_type in request.valid_sources_types
if isinstance(sources_field, valid_type)
),
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This means that we will upcast subclasses like PythonBinarySources into PythonSources, which makes it easier for call sites to switch on the resulting output_type.

None,
)
if output_type is None:
return HydratedSources(EMPTY_SNAPSHOT, sources_field.filespec, output_type=None)

globs = sources_field.sanitized_raw_value
if globs is None:
return HydratedSources(EMPTY_SNAPSHOT, sources_field.filespec)
return HydratedSources(EMPTY_SNAPSHOT, sources_field.filespec, output_type=output_type)

conjunction = (
GlobExpansionConjunction.all_match
Expand All @@ -1387,7 +1422,7 @@ async def hydrate_sources(
)
)
sources_field.validate_snapshot(snapshot)
return HydratedSources(snapshot, sources_field.filespec)
return HydratedSources(snapshot, sources_field.filespec, output_type=output_type)


# TODO: figure out what support looks like for this with the Target API. The expected value is an
Expand Down
25 changes: 24 additions & 1 deletion src/python/pants/engine/target_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -751,7 +751,7 @@ def assert_invalid_type(raw_value: Any) -> None:


# -----------------------------------------------------------------------------------------------
# Test common fields
# Test Sources
# -----------------------------------------------------------------------------------------------


Expand Down Expand Up @@ -791,6 +791,29 @@ def test_normal_hydration(self) -> None:
"exclude": [{"globs": ["src/fortran/**/ignore*", "src/fortran/ignored.f03"]}],
}

def test_output_type(self) -> None:
class SourcesSubclass(Sources):
pass

addr = Address.parse(":lib")
self.create_files("", files=["f1.f95"])

valid_sources = SourcesSubclass(["*"], address=addr)
hydrated_valid_sources = self.request_single_product(
HydratedSources,
HydrateSourcesRequest(valid_sources, valid_sources_types=[SourcesSubclass]),
)
assert hydrated_valid_sources.snapshot.files == ("f1.f95",)
assert hydrated_valid_sources.output_type == SourcesSubclass

invalid_sources = Sources(["*"], address=addr)
hydrated_invalid_sources = self.request_single_product(
HydratedSources,
HydrateSourcesRequest(invalid_sources, valid_sources_types=[SourcesSubclass]),
)
assert hydrated_invalid_sources.snapshot.files == ()
assert hydrated_invalid_sources.output_type is None

def test_unmatched_globs(self) -> None:
self.create_files("", files=["f1.f95"])
sources = Sources(["non_existent.f95"], address=Address.parse(":lib"))
Expand Down