Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow for codegen targets to be used directly by JVM compiler requests #14751

Merged
merged 12 commits into from
Mar 11, 2022
33 changes: 33 additions & 0 deletions src/python/pants/backend/codegen/protobuf/java/register.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
# Copyright 2022 Pants project contributors (see CONTRIBUTORS.md).
# Licensed under the Apache License, Version 2.0 (see LICENSE).

"""Generate Java sources from Protocol Buffers (Protobufs).

See https://www.pantsbuild.org/docs/protobuf.
"""

from pants.backend.codegen import export_codegen_goal
from pants.backend.codegen.protobuf import protobuf_dependency_inference
from pants.backend.codegen.protobuf import tailor as protobuf_tailor
from pants.backend.codegen.protobuf.java.rules import rules as java_rules
from pants.backend.codegen.protobuf.target_types import (
ProtobufSourcesGeneratorTarget,
ProtobufSourceTarget,
)
from pants.backend.codegen.protobuf.target_types import rules as protobuf_target_rules
from pants.core.util_rules import stripped_source_files


def rules():
return [
*java_rules(),
*protobuf_dependency_inference.rules(),
*protobuf_tailor.rules(),
*export_codegen_goal.rules(),
*protobuf_target_rules(),
*stripped_source_files.rules(),
]


def target_types():
return [ProtobufSourcesGeneratorTarget, ProtobufSourceTarget]
13 changes: 12 additions & 1 deletion src/python/pants/backend/codegen/protobuf/java/rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,11 @@


from pants.backend.codegen.protobuf.protoc import Protoc
from pants.backend.codegen.protobuf.target_types import ProtobufSourceField
from pants.backend.codegen.protobuf.target_types import (
ProtobufSourceField,
ProtobufSourcesGeneratorTarget,
ProtobufSourceTarget,
)
from pants.backend.java.target_types import JavaSourceField
from pants.backend.python.util_rules import pex
from pants.core.util_rules.external_tool import DownloadedExternalTool, ExternalToolRequest
Expand All @@ -28,6 +32,7 @@
TransitiveTargetsRequest,
)
from pants.engine.unions import UnionRule
from pants.jvm.target_types import JvmJdkField
from pants.source.source_root import SourceRoot, SourceRootRequest
from pants.util.logging import LogLevel

Expand Down Expand Up @@ -116,9 +121,15 @@ async def generate_java_from_protobuf(
return GeneratedSources(source_root_restored)


class PrefixedJvmJdkField(JvmJdkField):
alias = "jvm_jdk"


def rules():
return [
*collect_rules(),
*pex.rules(),
UnionRule(GenerateSourcesRequest, GenerateJavaFromProtobufRequest),
ProtobufSourceTarget.register_plugin_field(PrefixedJvmJdkField),
ProtobufSourcesGeneratorTarget.register_plugin_field(PrefixedJvmJdkField),
]
33 changes: 33 additions & 0 deletions src/python/pants/backend/codegen/protobuf/scala/register.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
# Copyright 2022 Pants project contributors (see CONTRIBUTORS.md).
# Licensed under the Apache License, Version 2.0 (see LICENSE).

"""Generate Scala sources from Protocol Buffers (Protobufs).

See https://www.pantsbuild.org/docs/protobuf.
"""

from pants.backend.codegen import export_codegen_goal
from pants.backend.codegen.protobuf import protobuf_dependency_inference
from pants.backend.codegen.protobuf import tailor as protobuf_tailor
from pants.backend.codegen.protobuf.scala.rules import rules as scala_rules
from pants.backend.codegen.protobuf.target_types import (
ProtobufSourcesGeneratorTarget,
ProtobufSourceTarget,
)
from pants.backend.codegen.protobuf.target_types import rules as protobuf_target_rules
from pants.core.util_rules import stripped_source_files


def rules():
return [
*scala_rules(),
*protobuf_dependency_inference.rules(),
*protobuf_tailor.rules(),
*export_codegen_goal.rules(),
*protobuf_target_rules(),
*stripped_source_files.rules(),
]


def target_types():
return [ProtobufSourcesGeneratorTarget, ProtobufSourceTarget]
7 changes: 3 additions & 4 deletions src/python/pants/backend/java/bsp/rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
)
from pants.engine.unions import UnionMembership, UnionRule
from pants.jvm.bsp.spec import JvmBuildTarget
from pants.jvm.compile import ClasspathEntryRequest, FallibleClasspathEntry
from pants.jvm.compile import ClasspathEntryRequest, FallibleClasspathEntry, JVMRequestTypes
from pants.jvm.resolve.key import CoursierResolveKey

LANGUAGE_ID = "java"
Expand Down Expand Up @@ -178,8 +178,7 @@ class JavaBSPCompileFieldSet(BSPCompileFieldSet):

@rule
async def bsp_java_compile_request(
request: JavaBSPCompileFieldSet,
union_membership: UnionMembership,
request: JavaBSPCompileFieldSet, jvm_request_types: JVMRequestTypes
) -> BSPCompileResult:
coarsened_targets = await Get(CoarsenedTargets, Addresses([request.source.address]))
assert len(coarsened_targets) == 1
Expand All @@ -190,7 +189,7 @@ async def bsp_java_compile_request(
FallibleClasspathEntry,
ClasspathEntryRequest,
ClasspathEntryRequest.for_targets(
union_membership, component=coarsened_target, resolve=resolve
jvm_request_types, component=coarsened_target, resolve=resolve
),
)
_logger.info(f"java compile result = {result}")
Expand Down
8 changes: 4 additions & 4 deletions src/python/pants/backend/java/goals/check.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@
from pants.engine.addresses import Addresses
from pants.engine.rules import Get, MultiGet, collect_rules, rule
from pants.engine.target import CoarsenedTargets
from pants.engine.unions import UnionMembership, UnionRule
from pants.jvm.compile import ClasspathEntryRequest, FallibleClasspathEntry
from pants.engine.unions import UnionRule
from pants.jvm.compile import ClasspathEntryRequest, FallibleClasspathEntry, JVMRequestTypes
from pants.jvm.resolve.key import CoursierResolveKey
from pants.util.logging import LogLevel

Expand All @@ -27,7 +27,7 @@ class JavacCheckRequest(CheckRequest):
@rule(desc="Check javac compilation", level=LogLevel.DEBUG)
async def javac_check(
request: JavacCheckRequest,
union_membership: UnionMembership,
jvm_request_types: JVMRequestTypes,
) -> CheckResults:
coarsened_targets = await Get(
CoarsenedTargets, Addresses(field_set.address for field_set in request.field_sets)
Expand All @@ -43,7 +43,7 @@ async def javac_check(
Get(
FallibleClasspathEntry,
ClasspathEntryRequest,
ClasspathEntryRequest.for_targets(union_membership, component=target, resolve=resolve),
ClasspathEntryRequest.for_targets(jvm_request_types, component=target, resolve=resolve),
)
for target, resolve in zip(coarsened_targets, resolves)
)
Expand Down
9 changes: 3 additions & 6 deletions src/python/pants/backend/scala/bsp/rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@
from pants.bsp.util_rules.lifecycle import BSPLanguageSupport
from pants.bsp.util_rules.targets import BSPBuildTargets, BSPBuildTargetsRequest
from pants.build_graph.address import Address, AddressInput
from pants.core.util_rules.system_binaries import BashBinary, UnzipBinary
from pants.engine.addresses import Addresses
from pants.engine.fs import EMPTY_DIGEST, AddPrefix, CreateDigest, Digest, DigestEntries
from pants.engine.internals.selectors import Get, MultiGet
Expand All @@ -42,7 +41,7 @@
WrappedTarget,
)
from pants.engine.unions import UnionMembership, UnionRule
from pants.jvm.compile import ClasspathEntryRequest, FallibleClasspathEntry
from pants.jvm.compile import ClasspathEntryRequest, FallibleClasspathEntry, JVMRequestTypes
from pants.jvm.resolve.key import CoursierResolveKey
from pants.jvm.subsystems import JvmSubsystem
from pants.jvm.target_types import JvmResolveField
Expand Down Expand Up @@ -199,9 +198,7 @@ class ScalaBSPCompileFieldSet(BSPCompileFieldSet):
@rule
async def bsp_scala_compile_request(
request: ScalaBSPCompileFieldSet,
union_membership: UnionMembership,
unzip: UnzipBinary,
bash: BashBinary,
jvm_request_types: JVMRequestTypes,
) -> BSPCompileResult:
coarsened_targets = await Get(CoarsenedTargets, Addresses([request.source.address]))
assert len(coarsened_targets) == 1
Expand All @@ -212,7 +209,7 @@ async def bsp_scala_compile_request(
FallibleClasspathEntry,
ClasspathEntryRequest,
ClasspathEntryRequest.for_targets(
union_membership, component=coarsened_target, resolve=resolve
jvm_request_types, component=coarsened_target, resolve=resolve
),
)
_logger.info(f"scala compile result = {result}")
Expand Down
8 changes: 4 additions & 4 deletions src/python/pants/backend/scala/goals/check.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@
from pants.engine.addresses import Addresses
from pants.engine.rules import Get, MultiGet, collect_rules, rule
from pants.engine.target import CoarsenedTargets
from pants.engine.unions import UnionMembership, UnionRule
from pants.jvm.compile import ClasspathEntryRequest, FallibleClasspathEntry
from pants.engine.unions import UnionRule
from pants.jvm.compile import ClasspathEntryRequest, FallibleClasspathEntry, JVMRequestTypes
from pants.jvm.resolve.key import CoursierResolveKey
from pants.util.logging import LogLevel

Expand All @@ -27,7 +27,7 @@ class ScalacCheckRequest(CheckRequest):
@rule(desc="Check compilation for Scala", level=LogLevel.DEBUG)
async def scalac_check(
request: ScalacCheckRequest,
union_membership: UnionMembership,
jvm_request_types: JVMRequestTypes,
) -> CheckResults:
coarsened_targets = await Get(
CoarsenedTargets, Addresses(field_set.address for field_set in request.field_sets)
Expand All @@ -43,7 +43,7 @@ async def scalac_check(
Get(
FallibleClasspathEntry,
ClasspathEntryRequest,
ClasspathEntryRequest.for_targets(union_membership, component=target, resolve=resolve),
ClasspathEntryRequest.for_targets(jvm_request_types, component=target, resolve=resolve),
)
for target, resolve in zip(coarsened_targets, resolves)
)
Expand Down
7 changes: 3 additions & 4 deletions src/python/pants/jvm/classpath.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,7 @@
from pants.engine.fs import Digest
from pants.engine.rules import Get, MultiGet, collect_rules, rule
from pants.engine.target import CoarsenedTargets
from pants.engine.unions import UnionMembership
from pants.jvm.compile import ClasspathEntry, ClasspathEntryRequest
from pants.jvm.compile import ClasspathEntry, ClasspathEntryRequest, JVMRequestTypes
from pants.jvm.resolve.key import CoursierResolveKey

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -69,7 +68,7 @@ def root_immutable_inputs_args(self, *, prefix: str = "") -> Iterator[str]:
@rule
async def classpath(
coarsened_targets: CoarsenedTargets,
union_membership: UnionMembership,
jvm_request_types: JVMRequestTypes,
) -> Classpath:
# Compute a single shared resolve for all of the roots, which will validate that they
# are compatible with one another.
Expand All @@ -81,7 +80,7 @@ async def classpath(
ClasspathEntry,
ClasspathEntryRequest,
ClasspathEntryRequest.for_targets(
union_membership, component=t, resolve=resolve, root=True
jvm_request_types, component=t, resolve=resolve, root=True
),
)
for t in coarsened_targets
Expand Down
62 changes: 56 additions & 6 deletions src/python/pants/jvm/compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import logging
import os
from abc import ABCMeta
from collections import deque
from collections import defaultdict, deque
from dataclasses import dataclass
from enum import Enum, auto
from typing import ClassVar, Iterable, Iterator, Sequence
Expand All @@ -17,9 +17,16 @@
from pants.engine.internals.selectors import Get, MultiGet
from pants.engine.process import FallibleProcessResult
from pants.engine.rules import collect_rules, rule
from pants.engine.target import CoarsenedTarget, FieldSet
from pants.engine.target import (
CoarsenedTarget,
Field,
FieldSet,
GenerateSourcesRequest,
SourcesField,
)
from pants.engine.unions import UnionMembership, union
from pants.jvm.resolve.key import CoursierResolveKey
from pants.util.frozendict import FrozenDict
from pants.util.logging import LogLevel
from pants.util.meta import frozen_after_init
from pants.util.ordered_set import FrozenOrderedSet
Expand Down Expand Up @@ -47,6 +54,38 @@ class _ClasspathEntryRequestClassification(Enum):
INCOMPATIBLE = auto()


@dataclass(frozen=True)
class JVMRequestTypes:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: a more meaningful name would be good.

classpath_entry_requests: tuple[type[ClasspathEntryRequest], ...]
code_generator_requests: FrozenDict[type[SourcesField], tuple[type[ClasspathEntryRequest], ...]]


@rule
def calculate_jvm_request_types(union_membership: UnionMembership) -> JVMRequestTypes:
cpe_impls = union_membership.get(ClasspathEntryRequest)
impls_by_source: dict[type[Field], type[ClasspathEntryRequest]] = {}
for impl in cpe_impls:
for field_set in impl.field_sets:
for field in field_set.required_fields:
# Assume only one impl per field (normally sound)
# (note that subsequently, we only check for `SourceFields`, so no need to filter)
impls_by_source[field] = impl

generators: Iterable[type[GenerateSourcesRequest]] = union_membership.get(
GenerateSourcesRequest
)

usable_generators_: dict[type[SourcesField], list[type[ClasspathEntryRequest]]] = defaultdict(
list
)
for g in generators:
if g.output in impls_by_source:
usable_generators_[g.input].append(impls_by_source[g.output])
usable_generators = FrozenDict((key, tuple(value)) for key, value in usable_generators_.items())

return JVMRequestTypes(tuple(cpe_impls), usable_generators)


@union
@dataclass(frozen=True)
class ClasspathEntryRequest(metaclass=ABCMeta):
Expand Down Expand Up @@ -77,7 +116,7 @@ class ClasspathEntryRequest(metaclass=ABCMeta):

@staticmethod
def for_targets(
union_membership: UnionMembership,
jvm_request_types: JVMRequestTypes,
component: CoarsenedTarget,
resolve: CoursierResolveKey,
*,
Expand All @@ -89,10 +128,21 @@ def for_targets(
request types which are marked `root_only`.
"""

for (input, request_types) in jvm_request_types.code_generator_requests.items():
if not component.representative.has_field(input):
continue
if len(request_types) > 1:
raise ClasspathSourceAmbiguity(
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

at the moment, this will result in a late failure, and only if a user attempts to use a source file with conflicting implementations as a dependency. I could re-write to fail fast, but the code will eventually need to fail here once #14041 is in place.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's still not clear to me why this matching isn't happening inside of ClasspathEntryRequest.classify_impl. If it did, both 1) only matching the representative, 2) late erroring... would be handled by the existing code below, right? Because you'd match multiple impls.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yup, that all checks out. I'll make that refactor before landing this.

f"More than one code generator ({request_types}) was compatible with the "
f"inputs:\n{component.bullet_list()}"
)
request_type = request_types[0]
return request_type(component, resolve, None)

compatible = []
partial = []
consume_only = []
impls = union_membership.get(ClasspathEntryRequest)
impls = jvm_request_types.classpath_entry_requests
for impl in impls:
classification = ClasspathEntryRequest.classify_impl(impl, component)
if classification == _ClasspathEntryRequestClassification.INCOMPATIBLE:
Expand Down Expand Up @@ -341,7 +391,7 @@ def required_classfiles(fallible_result: FallibleClasspathEntry) -> ClasspathEnt

@rule
def classpath_dependency_requests(
union_membership: UnionMembership, request: ClasspathDependenciesRequest
jvm_request_types: JVMRequestTypes, request: ClasspathDependenciesRequest
) -> ClasspathEntryRequests:
def ignore_because_generated(coarsened_dep: CoarsenedTarget) -> bool:
if len(coarsened_dep.members) == 1:
Expand All @@ -352,7 +402,7 @@ def ignore_because_generated(coarsened_dep: CoarsenedTarget) -> bool:

return ClasspathEntryRequests(
ClasspathEntryRequest.for_targets(
union_membership, component=coarsened_dep, resolve=request.request.resolve
jvm_request_types, component=coarsened_dep, resolve=request.request.resolve
)
for coarsened_dep in request.request.component.dependencies
if not request.ignore_generated or not ignore_because_generated(coarsened_dep)
Expand Down
Loading