From ee95331051f08fe2a0522637baf2ea75a82e84ae Mon Sep 17 00:00:00 2001 From: Christopher Neugebauer Date: Wed, 9 Mar 2022 13:44:44 -0800 Subject: [PATCH 01/12] Adds Java protobuf rules file # Rust tests and lints will be skipped. Delete if not intended. [ci skip-rust] --- .../backend/codegen/protobuf/java/register.py | 33 +++++++++++++++++++ 1 file changed, 33 insertions(+) create mode 100644 src/python/pants/backend/codegen/protobuf/java/register.py diff --git a/src/python/pants/backend/codegen/protobuf/java/register.py b/src/python/pants/backend/codegen/protobuf/java/register.py new file mode 100644 index 00000000000..3a180b14cc7 --- /dev/null +++ b/src/python/pants/backend/codegen/protobuf/java/register.py @@ -0,0 +1,33 @@ +# Copyright 2020 Pants project contributors (see CONTRIBUTORS.md). +# Licensed under the Apache License, Version 2.0 (see LICENSE). + +"""Generate Python sources from Protocol Buffers (Protobufs). + +See https://www.pantsbuild.org/docs/protobuf. +""" + +from pants.backend.codegen import export_codegen_goal +from pants.backend.codegen.protobuf import protobuf_dependency_inference +from pants.backend.codegen.protobuf import tailor as protobuf_tailor +from pants.backend.codegen.protobuf.java.rules import rules as java_rules +from pants.backend.codegen.protobuf.target_types import ( + ProtobufSourcesGeneratorTarget, + ProtobufSourceTarget, +) +from pants.backend.codegen.protobuf.target_types import rules as protobuf_target_rules +from pants.core.util_rules import stripped_source_files + + +def rules(): + return [ + *java_rules(), + *protobuf_dependency_inference.rules(), + *protobuf_tailor.rules(), + *export_codegen_goal.rules(), + *protobuf_target_rules(), + *stripped_source_files.rules(), + ] + + +def target_types(): + return [ProtobufSourcesGeneratorTarget, ProtobufSourceTarget] From 9e866db8d2bc694ab57a66335ba46972bfbb364c Mon Sep 17 00:00:00 2001 From: Christopher Neugebauer Date: Wed, 9 Mar 2022 13:46:34 -0800 Subject: [PATCH 02/12] Adds `JVMRequestTypes` helper, and first codegen classification helper # Rust tests and lints will be skipped. Delete if not intended. [ci skip-rust] --- src/python/pants/backend/java/bsp/rules.py | 7 ++- src/python/pants/backend/java/goals/check.py | 8 ++-- src/python/pants/backend/scala/bsp/rules.py | 9 ++-- src/python/pants/backend/scala/goals/check.py | 8 ++-- src/python/pants/jvm/classpath.py | 7 ++- src/python/pants/jvm/compile.py | 47 ++++++++++++++++--- src/python/pants/jvm/compile_test.py | 4 +- 7 files changed, 60 insertions(+), 30 deletions(-) diff --git a/src/python/pants/backend/java/bsp/rules.py b/src/python/pants/backend/java/bsp/rules.py index c5ff1a837ee..d08ee38f56d 100644 --- a/src/python/pants/backend/java/bsp/rules.py +++ b/src/python/pants/backend/java/bsp/rules.py @@ -36,7 +36,7 @@ ) from pants.engine.unions import UnionMembership, UnionRule from pants.jvm.bsp.spec import JvmBuildTarget -from pants.jvm.compile import ClasspathEntryRequest, FallibleClasspathEntry +from pants.jvm.compile import ClasspathEntryRequest, FallibleClasspathEntry, JVMRequestTypes from pants.jvm.resolve.key import CoursierResolveKey LANGUAGE_ID = "java" @@ -178,8 +178,7 @@ class JavaBSPCompileFieldSet(BSPCompileFieldSet): @rule async def bsp_java_compile_request( - request: JavaBSPCompileFieldSet, - union_membership: UnionMembership, + request: JavaBSPCompileFieldSet, jvm_request_types: JVMRequestTypes ) -> BSPCompileResult: coarsened_targets = await Get(CoarsenedTargets, Addresses([request.source.address])) assert len(coarsened_targets) == 1 @@ -190,7 +189,7 @@ async def bsp_java_compile_request( FallibleClasspathEntry, ClasspathEntryRequest, ClasspathEntryRequest.for_targets( - union_membership, component=coarsened_target, resolve=resolve + jvm_request_types, component=coarsened_target, resolve=resolve ), ) _logger.info(f"java compile result = {result}") diff --git a/src/python/pants/backend/java/goals/check.py b/src/python/pants/backend/java/goals/check.py index 10d62c8fd89..889adcafd82 100644 --- a/src/python/pants/backend/java/goals/check.py +++ b/src/python/pants/backend/java/goals/check.py @@ -11,8 +11,8 @@ from pants.engine.addresses import Addresses from pants.engine.rules import Get, MultiGet, collect_rules, rule from pants.engine.target import CoarsenedTargets -from pants.engine.unions import UnionMembership, UnionRule -from pants.jvm.compile import ClasspathEntryRequest, FallibleClasspathEntry +from pants.engine.unions import UnionRule +from pants.jvm.compile import ClasspathEntryRequest, FallibleClasspathEntry, JVMRequestTypes from pants.jvm.resolve.key import CoursierResolveKey from pants.util.logging import LogLevel @@ -27,7 +27,7 @@ class JavacCheckRequest(CheckRequest): @rule(desc="Check javac compilation", level=LogLevel.DEBUG) async def javac_check( request: JavacCheckRequest, - union_membership: UnionMembership, + jvm_request_types: JVMRequestTypes, ) -> CheckResults: coarsened_targets = await Get( CoarsenedTargets, Addresses(field_set.address for field_set in request.field_sets) @@ -43,7 +43,7 @@ async def javac_check( Get( FallibleClasspathEntry, ClasspathEntryRequest, - ClasspathEntryRequest.for_targets(union_membership, component=target, resolve=resolve), + ClasspathEntryRequest.for_targets(jvm_request_types, component=target, resolve=resolve), ) for target, resolve in zip(coarsened_targets, resolves) ) diff --git a/src/python/pants/backend/scala/bsp/rules.py b/src/python/pants/backend/scala/bsp/rules.py index 82d1fb5d408..42917cd4565 100644 --- a/src/python/pants/backend/scala/bsp/rules.py +++ b/src/python/pants/backend/scala/bsp/rules.py @@ -29,7 +29,6 @@ from pants.bsp.util_rules.lifecycle import BSPLanguageSupport from pants.bsp.util_rules.targets import BSPBuildTargets, BSPBuildTargetsRequest from pants.build_graph.address import Address, AddressInput -from pants.core.util_rules.system_binaries import BashBinary, UnzipBinary from pants.engine.addresses import Addresses from pants.engine.fs import EMPTY_DIGEST, AddPrefix, CreateDigest, Digest, DigestEntries from pants.engine.internals.selectors import Get, MultiGet @@ -42,7 +41,7 @@ WrappedTarget, ) from pants.engine.unions import UnionMembership, UnionRule -from pants.jvm.compile import ClasspathEntryRequest, FallibleClasspathEntry +from pants.jvm.compile import ClasspathEntryRequest, FallibleClasspathEntry, JVMRequestTypes from pants.jvm.resolve.key import CoursierResolveKey from pants.jvm.subsystems import JvmSubsystem from pants.jvm.target_types import JvmResolveField @@ -199,9 +198,7 @@ class ScalaBSPCompileFieldSet(BSPCompileFieldSet): @rule async def bsp_scala_compile_request( request: ScalaBSPCompileFieldSet, - union_membership: UnionMembership, - unzip: UnzipBinary, - bash: BashBinary, + jvm_request_types: JVMRequestTypes, ) -> BSPCompileResult: coarsened_targets = await Get(CoarsenedTargets, Addresses([request.source.address])) assert len(coarsened_targets) == 1 @@ -212,7 +209,7 @@ async def bsp_scala_compile_request( FallibleClasspathEntry, ClasspathEntryRequest, ClasspathEntryRequest.for_targets( - union_membership, component=coarsened_target, resolve=resolve + jvm_request_types, component=coarsened_target, resolve=resolve ), ) _logger.info(f"scala compile result = {result}") diff --git a/src/python/pants/backend/scala/goals/check.py b/src/python/pants/backend/scala/goals/check.py index 8bb26f72dac..01e842936c9 100644 --- a/src/python/pants/backend/scala/goals/check.py +++ b/src/python/pants/backend/scala/goals/check.py @@ -11,8 +11,8 @@ from pants.engine.addresses import Addresses from pants.engine.rules import Get, MultiGet, collect_rules, rule from pants.engine.target import CoarsenedTargets -from pants.engine.unions import UnionMembership, UnionRule -from pants.jvm.compile import ClasspathEntryRequest, FallibleClasspathEntry +from pants.engine.unions import UnionRule +from pants.jvm.compile import ClasspathEntryRequest, FallibleClasspathEntry, JVMRequestTypes from pants.jvm.resolve.key import CoursierResolveKey from pants.util.logging import LogLevel @@ -27,7 +27,7 @@ class ScalacCheckRequest(CheckRequest): @rule(desc="Check compilation for Scala", level=LogLevel.DEBUG) async def scalac_check( request: ScalacCheckRequest, - union_membership: UnionMembership, + jvm_request_types: JVMRequestTypes, ) -> CheckResults: coarsened_targets = await Get( CoarsenedTargets, Addresses(field_set.address for field_set in request.field_sets) @@ -43,7 +43,7 @@ async def scalac_check( Get( FallibleClasspathEntry, ClasspathEntryRequest, - ClasspathEntryRequest.for_targets(union_membership, component=target, resolve=resolve), + ClasspathEntryRequest.for_targets(jvm_request_types, component=target, resolve=resolve), ) for target, resolve in zip(coarsened_targets, resolves) ) diff --git a/src/python/pants/jvm/classpath.py b/src/python/pants/jvm/classpath.py index 1de2d9f5f8d..b11c1fe8823 100644 --- a/src/python/pants/jvm/classpath.py +++ b/src/python/pants/jvm/classpath.py @@ -10,8 +10,7 @@ from pants.engine.fs import Digest from pants.engine.rules import Get, MultiGet, collect_rules, rule from pants.engine.target import CoarsenedTargets -from pants.engine.unions import UnionMembership -from pants.jvm.compile import ClasspathEntry, ClasspathEntryRequest +from pants.jvm.compile import ClasspathEntry, ClasspathEntryRequest, JVMRequestTypes from pants.jvm.resolve.key import CoursierResolveKey logger = logging.getLogger(__name__) @@ -69,7 +68,7 @@ def root_immutable_inputs_args(self, *, prefix: str = "") -> Iterator[str]: @rule async def classpath( coarsened_targets: CoarsenedTargets, - union_membership: UnionMembership, + jvm_request_types: JVMRequestTypes, ) -> Classpath: # Compute a single shared resolve for all of the roots, which will validate that they # are compatible with one another. @@ -81,7 +80,7 @@ async def classpath( ClasspathEntry, ClasspathEntryRequest, ClasspathEntryRequest.for_targets( - union_membership, component=t, resolve=resolve, root=True + jvm_request_types, component=t, resolve=resolve, root=True ), ) for t in coarsened_targets diff --git a/src/python/pants/jvm/compile.py b/src/python/pants/jvm/compile.py index 7fda142d362..210ed0b21eb 100644 --- a/src/python/pants/jvm/compile.py +++ b/src/python/pants/jvm/compile.py @@ -6,7 +6,7 @@ import logging import os from abc import ABCMeta -from collections import deque +from collections import defaultdict, deque from dataclasses import dataclass from enum import Enum, auto from typing import ClassVar, Iterable, Iterator, Sequence @@ -17,9 +17,10 @@ from pants.engine.internals.selectors import Get, MultiGet from pants.engine.process import FallibleProcessResult from pants.engine.rules import collect_rules, rule -from pants.engine.target import CoarsenedTarget, FieldSet +from pants.engine.target import CoarsenedTarget, Field, FieldSet, GenerateSourcesRequest from pants.engine.unions import UnionMembership, union from pants.jvm.resolve.key import CoursierResolveKey +from pants.util.frozendict import FrozenDict from pants.util.logging import LogLevel from pants.util.meta import frozen_after_init from pants.util.ordered_set import FrozenOrderedSet @@ -47,6 +48,31 @@ class _ClasspathEntryRequestClassification(Enum): INCOMPATIBLE = auto() +@dataclass(frozen=True) +class JVMRequestTypes: + classpath_entry_requests: tuple[type[ClasspathEntryRequest], ...] + code_generator_requests: FrozenDict[type[GenerateSourcesRequest], type[ClasspathEntryRequest]] + + +@rule +def calculate_jvm_request_types(union_membership: UnionMembership) -> JVMRequestTypes: + cpe_impls = union_membership.get(ClasspathEntryRequest) + b: dict[type[Field], type[ClasspathEntryRequest]] = set() + for impl in cpe_impls: + for field_set in impl.field_sets: + for field in field_set.required_fields: + # Assume only one impl per field (normally sound) + b[field] = impl + + generators: Iterable[type[GenerateSourcesRequest]] = union_membership.get( + GenerateSourcesRequest + ) + + usable_generators = {g.input: (g, b[g.output]) for g in generators if g.output in b} + + return JVMRequestTypes(tuple(cpe_impls), FrozenDict(usable_generators)) + + @union @dataclass(frozen=True) class ClasspathEntryRequest(metaclass=ABCMeta): @@ -77,7 +103,7 @@ class ClasspathEntryRequest(metaclass=ABCMeta): @staticmethod def for_targets( - union_membership: UnionMembership, + jvm_request_types: JVMRequestTypes, component: CoarsenedTarget, resolve: CoursierResolveKey, *, @@ -89,10 +115,19 @@ def for_targets( request types which are marked `root_only`. """ + impls = jvm_request_types.classpath_entry_requests + usable_generators = jvm_request_types.code_generator_requests + + # TODO: filter usable generators by acceptable languages + + for (input, request_type) in usable_generators.items(): + if component.representative.get(input) is not None: + return request_type(component, resolve, None) + compatible = [] partial = [] consume_only = [] - impls = union_membership.get(ClasspathEntryRequest) + impls = jvm_request_types.classpath_entry_requests for impl in impls: classification = ClasspathEntryRequest.classify_impl(impl, component) if classification == _ClasspathEntryRequestClassification.INCOMPATIBLE: @@ -341,7 +376,7 @@ def required_classfiles(fallible_result: FallibleClasspathEntry) -> ClasspathEnt @rule def classpath_dependency_requests( - union_membership: UnionMembership, request: ClasspathDependenciesRequest + jvm_request_types: JVMRequestTypes, request: ClasspathDependenciesRequest ) -> ClasspathEntryRequests: def ignore_because_generated(coarsened_dep: CoarsenedTarget) -> bool: if len(coarsened_dep.members) == 1: @@ -352,7 +387,7 @@ def ignore_because_generated(coarsened_dep: CoarsenedTarget) -> bool: return ClasspathEntryRequests( ClasspathEntryRequest.for_targets( - union_membership, component=coarsened_dep, resolve=request.request.resolve + jvm_request_types, component=coarsened_dep, resolve=request.request.resolve ) for coarsened_dep in request.request.component.dependencies if not request.ignore_generated or not ignore_because_generated(coarsened_dep) diff --git a/src/python/pants/jvm/compile_test.py b/src/python/pants/jvm/compile_test.py index 25902069add..053871fdffb 100644 --- a/src/python/pants/jvm/compile_test.py +++ b/src/python/pants/jvm/compile_test.py @@ -39,13 +39,13 @@ from pants.engine.fs import EMPTY_DIGEST from pants.engine.internals.native_engine import FileDigest from pants.engine.target import CoarsenedTarget, Target, UnexpandedTargets -from pants.engine.unions import UnionMembership from pants.jvm import classpath, jdk_rules, testutil from pants.jvm.classpath import Classpath from pants.jvm.compile import ( ClasspathEntryRequest, ClasspathSourceAmbiguity, ClasspathSourceMissing, + JVMRequestTypes, ) from pants.jvm.goals import lockfile from pants.jvm.resolve.common import ArtifactRequirement, Coordinate, Coordinates @@ -193,7 +193,7 @@ def classify( members: Sequence[type[ClasspathEntryRequest]], ) -> tuple[type[ClasspathEntryRequest], type[ClasspathEntryRequest] | None]: req = ClasspathEntryRequest.for_targets( - UnionMembership({ClasspathEntryRequest: members}), + JVMRequestTypes(tuple(members), ()), CoarsenedTarget(targets, ()), CoursierResolveKey("example", "path", EMPTY_DIGEST), ) From 7a654d16dd1ee7786a13a990ea5518972d40ff46 Mon Sep 17 00:00:00 2001 From: Christopher Neugebauer Date: Wed, 9 Mar 2022 14:05:07 -0800 Subject: [PATCH 03/12] more progress # Rust tests and lints will be skipped. Delete if not intended. [ci skip-rust] --- src/python/pants/jvm/compile.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/src/python/pants/jvm/compile.py b/src/python/pants/jvm/compile.py index 210ed0b21eb..ef1eceebfe2 100644 --- a/src/python/pants/jvm/compile.py +++ b/src/python/pants/jvm/compile.py @@ -57,7 +57,7 @@ class JVMRequestTypes: @rule def calculate_jvm_request_types(union_membership: UnionMembership) -> JVMRequestTypes: cpe_impls = union_membership.get(ClasspathEntryRequest) - b: dict[type[Field], type[ClasspathEntryRequest]] = set() + b: dict[type[Field], type[ClasspathEntryRequest]] = {} for impl in cpe_impls: for field_set in impl.field_sets: for field in field_set.required_fields: @@ -68,7 +68,10 @@ def calculate_jvm_request_types(union_membership: UnionMembership) -> JVMRequest GenerateSourcesRequest ) - usable_generators = {g.input: (g, b[g.output]) for g in generators if g.output in b} + # TODO: Does not currently support multiple code generators per source type + # We'll need to add support for that, once it's possible to disambiguate in + # a build file + usable_generators = {g.input: b[g.output] for g in generators if g.output in b} return JVMRequestTypes(tuple(cpe_impls), FrozenDict(usable_generators)) @@ -121,7 +124,8 @@ def for_targets( # TODO: filter usable generators by acceptable languages for (input, request_type) in usable_generators.items(): - if component.representative.get(input) is not None: + logger.warning(f"{component.representative} {input} { request_type}") + if component.representative.has_field(input): return request_type(component, resolve, None) compatible = [] From 513f000a1fc3a110e5a45c4a997b4a4bd6610ef4 Mon Sep 17 00:00:00 2001 From: Christopher Neugebauer Date: Wed, 9 Mar 2022 14:51:37 -0800 Subject: [PATCH 04/12] Allow for compilation of Java protobufs # Rust tests and lints will be skipped. Delete if not intended. [ci skip-rust] --- .../backend/codegen/protobuf/target_types.py | 8 ++++++- src/python/pants/jvm/compile.py | 21 ++++++++++++------- src/python/pants/jvm/compile_test.py | 3 ++- 3 files changed, 22 insertions(+), 10 deletions(-) diff --git a/src/python/pants/backend/codegen/protobuf/target_types.py b/src/python/pants/backend/codegen/protobuf/target_types.py index b4f75e08815..b7d574a4858 100644 --- a/src/python/pants/backend/codegen/protobuf/target_types.py +++ b/src/python/pants/backend/codegen/protobuf/target_types.py @@ -19,6 +19,7 @@ generate_file_based_overrides_field_help_message, ) from pants.engine.unions import UnionRule +from pants.jvm.target_types import JvmJdkField from pants.util.docutil import doc_url from pants.util.logging import LogLevel @@ -60,7 +61,9 @@ class ProtobufSourceTarget(Target): ProtobufDependenciesField, ProtobufSourceField, ProtobufGrpcToggleField, + JvmJdkField, ) + jdk = JvmJdkField help = ( "A single Protobuf file used to generate various languages.\n\n" f"See {doc_url('protobuf')}." @@ -117,7 +120,10 @@ class ProtobufSourcesGeneratorTarget(TargetFilesGenerator): *COMMON_TARGET_FIELDS, ProtobufDependenciesField, ) - moved_fields = (ProtobufGrpcToggleField,) + moved_fields = ( + ProtobufGrpcToggleField, + JvmJdkField, + ) settings_request_cls = GeneratorSettingsRequest help = "Generate a `protobuf_source` target for each file in the `sources` field." diff --git a/src/python/pants/jvm/compile.py b/src/python/pants/jvm/compile.py index ef1eceebfe2..74a8d2ca762 100644 --- a/src/python/pants/jvm/compile.py +++ b/src/python/pants/jvm/compile.py @@ -6,7 +6,7 @@ import logging import os from abc import ABCMeta -from collections import defaultdict, deque +from collections import deque from dataclasses import dataclass from enum import Enum, auto from typing import ClassVar, Iterable, Iterator, Sequence @@ -17,7 +17,13 @@ from pants.engine.internals.selectors import Get, MultiGet from pants.engine.process import FallibleProcessResult from pants.engine.rules import collect_rules, rule -from pants.engine.target import CoarsenedTarget, Field, FieldSet, GenerateSourcesRequest +from pants.engine.target import ( + CoarsenedTarget, + Field, + FieldSet, + GenerateSourcesRequest, + SourcesField, +) from pants.engine.unions import UnionMembership, union from pants.jvm.resolve.key import CoursierResolveKey from pants.util.frozendict import FrozenDict @@ -51,7 +57,7 @@ class _ClasspathEntryRequestClassification(Enum): @dataclass(frozen=True) class JVMRequestTypes: classpath_entry_requests: tuple[type[ClasspathEntryRequest], ...] - code_generator_requests: FrozenDict[type[GenerateSourcesRequest], type[ClasspathEntryRequest]] + code_generator_requests: FrozenDict[type[SourcesField], type[ClasspathEntryRequest]] @rule @@ -71,9 +77,9 @@ def calculate_jvm_request_types(union_membership: UnionMembership) -> JVMRequest # TODO: Does not currently support multiple code generators per source type # We'll need to add support for that, once it's possible to disambiguate in # a build file - usable_generators = {g.input: b[g.output] for g in generators if g.output in b} + usable_generators = FrozenDict((g.input, b[g.output]) for g in generators if g.output in b) - return JVMRequestTypes(tuple(cpe_impls), FrozenDict(usable_generators)) + return JVMRequestTypes(tuple(cpe_impls), usable_generators) @union @@ -122,10 +128,9 @@ def for_targets( usable_generators = jvm_request_types.code_generator_requests # TODO: filter usable generators by acceptable languages - + for (input, request_type) in usable_generators.items(): - logger.warning(f"{component.representative} {input} { request_type}") - if component.representative.has_field(input): + if component.representative.has_field(input): return request_type(component, resolve, None) compatible = [] diff --git a/src/python/pants/jvm/compile_test.py b/src/python/pants/jvm/compile_test.py index 053871fdffb..205913485bc 100644 --- a/src/python/pants/jvm/compile_test.py +++ b/src/python/pants/jvm/compile_test.py @@ -66,6 +66,7 @@ ) from pants.jvm.util_rules import rules as util_rules from pants.testutil.rule_runner import PYTHON_BOOTSTRAP_ENV, QueryRule, RuleRunner +from pants.util.frozendict import FrozenDict DEFAULT_LOCKFILE = TestCoursierWrapper( CoursierResolvedLockfile( @@ -193,7 +194,7 @@ def classify( members: Sequence[type[ClasspathEntryRequest]], ) -> tuple[type[ClasspathEntryRequest], type[ClasspathEntryRequest] | None]: req = ClasspathEntryRequest.for_targets( - JVMRequestTypes(tuple(members), ()), + JVMRequestTypes(tuple(members), FrozenDict()), CoarsenedTarget(targets, ()), CoursierResolveKey("example", "path", EMPTY_DIGEST), ) From b0e78f9d9d752753fbe18104de0d320d77678b1f Mon Sep 17 00:00:00 2001 From: Christopher Neugebauer Date: Thu, 10 Mar 2022 12:38:20 -0800 Subject: [PATCH 05/12] Smallest code review changes # Rust tests and lints will be skipped. Delete if not intended. [ci skip-rust] --- .../pants/backend/codegen/protobuf/java/register.py | 2 +- src/python/pants/jvm/compile.py | 8 +++++--- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/src/python/pants/backend/codegen/protobuf/java/register.py b/src/python/pants/backend/codegen/protobuf/java/register.py index 3a180b14cc7..ae8c0eb8e88 100644 --- a/src/python/pants/backend/codegen/protobuf/java/register.py +++ b/src/python/pants/backend/codegen/protobuf/java/register.py @@ -1,7 +1,7 @@ # Copyright 2020 Pants project contributors (see CONTRIBUTORS.md). # Licensed under the Apache License, Version 2.0 (see LICENSE). -"""Generate Python sources from Protocol Buffers (Protobufs). +"""Generate Java sources from Protocol Buffers (Protobufs). See https://www.pantsbuild.org/docs/protobuf. """ diff --git a/src/python/pants/jvm/compile.py b/src/python/pants/jvm/compile.py index 74a8d2ca762..af2d6a34c8c 100644 --- a/src/python/pants/jvm/compile.py +++ b/src/python/pants/jvm/compile.py @@ -63,12 +63,12 @@ class JVMRequestTypes: @rule def calculate_jvm_request_types(union_membership: UnionMembership) -> JVMRequestTypes: cpe_impls = union_membership.get(ClasspathEntryRequest) - b: dict[type[Field], type[ClasspathEntryRequest]] = {} + impls_by_source: dict[type[Field], type[ClasspathEntryRequest]] = {} for impl in cpe_impls: for field_set in impl.field_sets: for field in field_set.required_fields: # Assume only one impl per field (normally sound) - b[field] = impl + impls_by_source[field] = impl generators: Iterable[type[GenerateSourcesRequest]] = union_membership.get( GenerateSourcesRequest @@ -77,7 +77,9 @@ def calculate_jvm_request_types(union_membership: UnionMembership) -> JVMRequest # TODO: Does not currently support multiple code generators per source type # We'll need to add support for that, once it's possible to disambiguate in # a build file - usable_generators = FrozenDict((g.input, b[g.output]) for g in generators if g.output in b) + usable_generators = FrozenDict( + (g.input, impls_by_source[g.output]) for g in generators if g.output in impls_by_source + ) return JVMRequestTypes(tuple(cpe_impls), usable_generators) From f937aaeada92e3d2f53c621645a0979580ae51b7 Mon Sep 17 00:00:00 2001 From: Christopher Neugebauer Date: Thu, 10 Mar 2022 15:44:00 -0800 Subject: [PATCH 06/12] Adds `jvm_jdk` field # Rust tests and lints will be skipped. Delete if not intended. [ci skip-rust] --- .../backend/codegen/protobuf/java/rules.py | 13 ++++++++- .../backend/codegen/protobuf/target_types.py | 8 +----- src/python/pants/jvm/compile.py | 27 ++++++++++--------- 3 files changed, 27 insertions(+), 21 deletions(-) diff --git a/src/python/pants/backend/codegen/protobuf/java/rules.py b/src/python/pants/backend/codegen/protobuf/java/rules.py index f323a412a51..0c4a4e106e0 100644 --- a/src/python/pants/backend/codegen/protobuf/java/rules.py +++ b/src/python/pants/backend/codegen/protobuf/java/rules.py @@ -3,7 +3,11 @@ from pants.backend.codegen.protobuf.protoc import Protoc -from pants.backend.codegen.protobuf.target_types import ProtobufSourceField +from pants.backend.codegen.protobuf.target_types import ( + ProtobufSourceField, + ProtobufSourcesGeneratorTarget, + ProtobufSourceTarget, +) from pants.backend.java.target_types import JavaSourceField from pants.backend.python.util_rules import pex from pants.core.util_rules.external_tool import DownloadedExternalTool, ExternalToolRequest @@ -28,6 +32,7 @@ TransitiveTargetsRequest, ) from pants.engine.unions import UnionRule +from pants.jvm.target_types import JvmJdkField from pants.source.source_root import SourceRoot, SourceRootRequest from pants.util.logging import LogLevel @@ -116,9 +121,15 @@ async def generate_java_from_protobuf( return GeneratedSources(source_root_restored) +class PrefixedJvmJdkField(JvmJdkField): + alias = "jvm_jdk" + + def rules(): return [ *collect_rules(), *pex.rules(), UnionRule(GenerateSourcesRequest, GenerateJavaFromProtobufRequest), + ProtobufSourceTarget.register_plugin_field(PrefixedJvmJdkField), + ProtobufSourcesGeneratorTarget.register_plugin_field(PrefixedJvmJdkField), ] diff --git a/src/python/pants/backend/codegen/protobuf/target_types.py b/src/python/pants/backend/codegen/protobuf/target_types.py index b7d574a4858..b4f75e08815 100644 --- a/src/python/pants/backend/codegen/protobuf/target_types.py +++ b/src/python/pants/backend/codegen/protobuf/target_types.py @@ -19,7 +19,6 @@ generate_file_based_overrides_field_help_message, ) from pants.engine.unions import UnionRule -from pants.jvm.target_types import JvmJdkField from pants.util.docutil import doc_url from pants.util.logging import LogLevel @@ -61,9 +60,7 @@ class ProtobufSourceTarget(Target): ProtobufDependenciesField, ProtobufSourceField, ProtobufGrpcToggleField, - JvmJdkField, ) - jdk = JvmJdkField help = ( "A single Protobuf file used to generate various languages.\n\n" f"See {doc_url('protobuf')}." @@ -120,10 +117,7 @@ class ProtobufSourcesGeneratorTarget(TargetFilesGenerator): *COMMON_TARGET_FIELDS, ProtobufDependenciesField, ) - moved_fields = ( - ProtobufGrpcToggleField, - JvmJdkField, - ) + moved_fields = (ProtobufGrpcToggleField,) settings_request_cls = GeneratorSettingsRequest help = "Generate a `protobuf_source` target for each file in the `sources` field." diff --git a/src/python/pants/jvm/compile.py b/src/python/pants/jvm/compile.py index af2d6a34c8c..dbe9defdf95 100644 --- a/src/python/pants/jvm/compile.py +++ b/src/python/pants/jvm/compile.py @@ -6,7 +6,7 @@ import logging import os from abc import ABCMeta -from collections import deque +from collections import defaultdict, deque from dataclasses import dataclass from enum import Enum, auto from typing import ClassVar, Iterable, Iterator, Sequence @@ -57,7 +57,7 @@ class _ClasspathEntryRequestClassification(Enum): @dataclass(frozen=True) class JVMRequestTypes: classpath_entry_requests: tuple[type[ClasspathEntryRequest], ...] - code_generator_requests: FrozenDict[type[SourcesField], type[ClasspathEntryRequest]] + code_generator_requests: FrozenDict[type[SourcesField], tuple[type[ClasspathEntryRequest], ...]] @rule @@ -68,18 +68,20 @@ def calculate_jvm_request_types(union_membership: UnionMembership) -> JVMRequest for field_set in impl.field_sets: for field in field_set.required_fields: # Assume only one impl per field (normally sound) + # (note that subsequently, we only check for `SourceFields`, so no need to filter) impls_by_source[field] = impl generators: Iterable[type[GenerateSourcesRequest]] = union_membership.get( GenerateSourcesRequest ) - # TODO: Does not currently support multiple code generators per source type - # We'll need to add support for that, once it's possible to disambiguate in - # a build file - usable_generators = FrozenDict( - (g.input, impls_by_source[g.output]) for g in generators if g.output in impls_by_source + usable_generators_: dict[type[SourcesField], list[type[ClasspathEntryRequest]]] = defaultdict( + list ) + for g in generators: + if g.output in impls_by_source: + usable_generators_[g.input].append(impls_by_source[g.output]) + usable_generators = FrozenDict((key, tuple(value)) for key, value in usable_generators_.items()) return JVMRequestTypes(tuple(cpe_impls), usable_generators) @@ -126,12 +128,11 @@ def for_targets( request types which are marked `root_only`. """ - impls = jvm_request_types.classpath_entry_requests - usable_generators = jvm_request_types.code_generator_requests - - # TODO: filter usable generators by acceptable languages - - for (input, request_type) in usable_generators.items(): + for (input, request_types) in jvm_request_types.code_generator_requests.items(): + if len(request_types) > 1: + # TODO: filter usable generators by acceptable languages + pass + request_type = request_types[0] if component.representative.has_field(input): return request_type(component, resolve, None) From 079c77551800ab763448b0c6d538e9efdbd47cce Mon Sep 17 00:00:00 2001 From: Christopher Neugebauer Date: Thu, 10 Mar 2022 16:05:52 -0800 Subject: [PATCH 07/12] Raise an exception when multiple generators are compatible with a given input codegen type # Rust tests and lints will be skipped. Delete if not intended. [ci skip-rust] --- src/python/pants/jvm/compile.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/src/python/pants/jvm/compile.py b/src/python/pants/jvm/compile.py index dbe9defdf95..2a3b72bd308 100644 --- a/src/python/pants/jvm/compile.py +++ b/src/python/pants/jvm/compile.py @@ -129,12 +129,15 @@ def for_targets( """ for (input, request_types) in jvm_request_types.code_generator_requests.items(): + if not component.representative.has_field(input): + continue if len(request_types) > 1: - # TODO: filter usable generators by acceptable languages - pass + raise ClasspathSourceAmbiguity( + f"More than one code generator ({request_types}) was compatible with the " + f"inputs:\n{component.bullet_list()}" + ) request_type = request_types[0] - if component.representative.has_field(input): - return request_type(component, resolve, None) + return request_type(component, resolve, None) compatible = [] partial = [] From 5580fc468339b2664ce7c59e6f993d0d565611d3 Mon Sep 17 00:00:00 2001 From: Christopher Neugebauer Date: Thu, 10 Mar 2022 16:12:12 -0800 Subject: [PATCH 08/12] Adds scala registration rules # Rust tests and lints will be skipped. Delete if not intended. [ci skip-rust] --- .../backend/codegen/protobuf/java/register.py | 2 +- .../codegen/protobuf/scala/register.py | 33 +++++++++++++++++++ 2 files changed, 34 insertions(+), 1 deletion(-) create mode 100644 src/python/pants/backend/codegen/protobuf/scala/register.py diff --git a/src/python/pants/backend/codegen/protobuf/java/register.py b/src/python/pants/backend/codegen/protobuf/java/register.py index ae8c0eb8e88..ef510a37ec1 100644 --- a/src/python/pants/backend/codegen/protobuf/java/register.py +++ b/src/python/pants/backend/codegen/protobuf/java/register.py @@ -1,4 +1,4 @@ -# Copyright 2020 Pants project contributors (see CONTRIBUTORS.md). +# Copyright 2022 Pants project contributors (see CONTRIBUTORS.md). # Licensed under the Apache License, Version 2.0 (see LICENSE). """Generate Java sources from Protocol Buffers (Protobufs). diff --git a/src/python/pants/backend/codegen/protobuf/scala/register.py b/src/python/pants/backend/codegen/protobuf/scala/register.py new file mode 100644 index 00000000000..dc167171f3d --- /dev/null +++ b/src/python/pants/backend/codegen/protobuf/scala/register.py @@ -0,0 +1,33 @@ +# Copyright 2022 Pants project contributors (see CONTRIBUTORS.md). +# Licensed under the Apache License, Version 2.0 (see LICENSE). + +"""Generate Scala sources from Protocol Buffers (Protobufs). + +See https://www.pantsbuild.org/docs/protobuf. +""" + +from pants.backend.codegen import export_codegen_goal +from pants.backend.codegen.protobuf import protobuf_dependency_inference +from pants.backend.codegen.protobuf import tailor as protobuf_tailor +from pants.backend.codegen.protobuf.scala.rules import rules as scala_rules +from pants.backend.codegen.protobuf.target_types import ( + ProtobufSourcesGeneratorTarget, + ProtobufSourceTarget, +) +from pants.backend.codegen.protobuf.target_types import rules as protobuf_target_rules +from pants.core.util_rules import stripped_source_files + + +def rules(): + return [ + *scala_rules(), + *protobuf_dependency_inference.rules(), + *protobuf_tailor.rules(), + *export_codegen_goal.rules(), + *protobuf_target_rules(), + *stripped_source_files.rules(), + ] + + +def target_types(): + return [ProtobufSourcesGeneratorTarget, ProtobufSourceTarget] From a22866170ee76f7a60e556093362f763d47444a9 Mon Sep 17 00:00:00 2001 From: Christopher Neugebauer Date: Fri, 11 Mar 2022 10:02:48 -0800 Subject: [PATCH 09/12] Centralises JVM construction code into `ClasspathEntryRequestFactory`. Much better. # Rust tests and lints will be skipped. Delete if not intended. [ci skip-rust] --- src/python/pants/backend/java/bsp/rules.py | 12 +-- src/python/pants/backend/java/goals/check.py | 10 ++- src/python/pants/backend/scala/bsp/rules.py | 12 +-- src/python/pants/backend/scala/goals/check.py | 10 ++- src/python/pants/jvm/classpath.py | 8 +- src/python/pants/jvm/compile.py | 82 +++++++++---------- src/python/pants/jvm/compile_test.py | 5 +- 7 files changed, 73 insertions(+), 66 deletions(-) diff --git a/src/python/pants/backend/java/bsp/rules.py b/src/python/pants/backend/java/bsp/rules.py index d08ee38f56d..58487d88ca5 100644 --- a/src/python/pants/backend/java/bsp/rules.py +++ b/src/python/pants/backend/java/bsp/rules.py @@ -36,7 +36,11 @@ ) from pants.engine.unions import UnionMembership, UnionRule from pants.jvm.bsp.spec import JvmBuildTarget -from pants.jvm.compile import ClasspathEntryRequest, FallibleClasspathEntry, JVMRequestTypes +from pants.jvm.compile import ( + ClasspathEntryRequest, + ClasspathEntryRequestFactory, + FallibleClasspathEntry, +) from pants.jvm.resolve.key import CoursierResolveKey LANGUAGE_ID = "java" @@ -178,7 +182,7 @@ class JavaBSPCompileFieldSet(BSPCompileFieldSet): @rule async def bsp_java_compile_request( - request: JavaBSPCompileFieldSet, jvm_request_types: JVMRequestTypes + request: JavaBSPCompileFieldSet, classpath_entry_request: ClasspathEntryRequestFactory ) -> BSPCompileResult: coarsened_targets = await Get(CoarsenedTargets, Addresses([request.source.address])) assert len(coarsened_targets) == 1 @@ -188,9 +192,7 @@ async def bsp_java_compile_request( result = await Get( FallibleClasspathEntry, ClasspathEntryRequest, - ClasspathEntryRequest.for_targets( - jvm_request_types, component=coarsened_target, resolve=resolve - ), + classpath_entry_request.for_targets(component=coarsened_target, resolve=resolve), ) _logger.info(f"java compile result = {result}") output_digest = EMPTY_DIGEST diff --git a/src/python/pants/backend/java/goals/check.py b/src/python/pants/backend/java/goals/check.py index 889adcafd82..0b7c110fc89 100644 --- a/src/python/pants/backend/java/goals/check.py +++ b/src/python/pants/backend/java/goals/check.py @@ -12,7 +12,11 @@ from pants.engine.rules import Get, MultiGet, collect_rules, rule from pants.engine.target import CoarsenedTargets from pants.engine.unions import UnionRule -from pants.jvm.compile import ClasspathEntryRequest, FallibleClasspathEntry, JVMRequestTypes +from pants.jvm.compile import ( + ClasspathEntryRequest, + ClasspathEntryRequestFactory, + FallibleClasspathEntry, +) from pants.jvm.resolve.key import CoursierResolveKey from pants.util.logging import LogLevel @@ -27,7 +31,7 @@ class JavacCheckRequest(CheckRequest): @rule(desc="Check javac compilation", level=LogLevel.DEBUG) async def javac_check( request: JavacCheckRequest, - jvm_request_types: JVMRequestTypes, + classpath_entry_request: ClasspathEntryRequestFactory, ) -> CheckResults: coarsened_targets = await Get( CoarsenedTargets, Addresses(field_set.address for field_set in request.field_sets) @@ -43,7 +47,7 @@ async def javac_check( Get( FallibleClasspathEntry, ClasspathEntryRequest, - ClasspathEntryRequest.for_targets(jvm_request_types, component=target, resolve=resolve), + classpath_entry_request.for_targets(component=target, resolve=resolve), ) for target, resolve in zip(coarsened_targets, resolves) ) diff --git a/src/python/pants/backend/scala/bsp/rules.py b/src/python/pants/backend/scala/bsp/rules.py index 42917cd4565..ea93df044c9 100644 --- a/src/python/pants/backend/scala/bsp/rules.py +++ b/src/python/pants/backend/scala/bsp/rules.py @@ -41,7 +41,11 @@ WrappedTarget, ) from pants.engine.unions import UnionMembership, UnionRule -from pants.jvm.compile import ClasspathEntryRequest, FallibleClasspathEntry, JVMRequestTypes +from pants.jvm.compile import ( + ClasspathEntryRequest, + ClasspathEntryRequestFactory, + FallibleClasspathEntry, +) from pants.jvm.resolve.key import CoursierResolveKey from pants.jvm.subsystems import JvmSubsystem from pants.jvm.target_types import JvmResolveField @@ -198,7 +202,7 @@ class ScalaBSPCompileFieldSet(BSPCompileFieldSet): @rule async def bsp_scala_compile_request( request: ScalaBSPCompileFieldSet, - jvm_request_types: JVMRequestTypes, + classpath_entry_request: ClasspathEntryRequestFactory, ) -> BSPCompileResult: coarsened_targets = await Get(CoarsenedTargets, Addresses([request.source.address])) assert len(coarsened_targets) == 1 @@ -208,9 +212,7 @@ async def bsp_scala_compile_request( result = await Get( FallibleClasspathEntry, ClasspathEntryRequest, - ClasspathEntryRequest.for_targets( - jvm_request_types, component=coarsened_target, resolve=resolve - ), + classpath_entry_request.for_targets(component=coarsened_target, resolve=resolve), ) _logger.info(f"scala compile result = {result}") output_digest = EMPTY_DIGEST diff --git a/src/python/pants/backend/scala/goals/check.py b/src/python/pants/backend/scala/goals/check.py index 01e842936c9..7b011c836d8 100644 --- a/src/python/pants/backend/scala/goals/check.py +++ b/src/python/pants/backend/scala/goals/check.py @@ -12,7 +12,11 @@ from pants.engine.rules import Get, MultiGet, collect_rules, rule from pants.engine.target import CoarsenedTargets from pants.engine.unions import UnionRule -from pants.jvm.compile import ClasspathEntryRequest, FallibleClasspathEntry, JVMRequestTypes +from pants.jvm.compile import ( + ClasspathEntryRequest, + ClasspathEntryRequestFactory, + FallibleClasspathEntry, +) from pants.jvm.resolve.key import CoursierResolveKey from pants.util.logging import LogLevel @@ -27,7 +31,7 @@ class ScalacCheckRequest(CheckRequest): @rule(desc="Check compilation for Scala", level=LogLevel.DEBUG) async def scalac_check( request: ScalacCheckRequest, - jvm_request_types: JVMRequestTypes, + classpath_entry_request: ClasspathEntryRequestFactory, ) -> CheckResults: coarsened_targets = await Get( CoarsenedTargets, Addresses(field_set.address for field_set in request.field_sets) @@ -43,7 +47,7 @@ async def scalac_check( Get( FallibleClasspathEntry, ClasspathEntryRequest, - ClasspathEntryRequest.for_targets(jvm_request_types, component=target, resolve=resolve), + classpath_entry_request.for_targets(component=target, resolve=resolve), ) for target, resolve in zip(coarsened_targets, resolves) ) diff --git a/src/python/pants/jvm/classpath.py b/src/python/pants/jvm/classpath.py index b11c1fe8823..2e507e18327 100644 --- a/src/python/pants/jvm/classpath.py +++ b/src/python/pants/jvm/classpath.py @@ -10,7 +10,7 @@ from pants.engine.fs import Digest from pants.engine.rules import Get, MultiGet, collect_rules, rule from pants.engine.target import CoarsenedTargets -from pants.jvm.compile import ClasspathEntry, ClasspathEntryRequest, JVMRequestTypes +from pants.jvm.compile import ClasspathEntry, ClasspathEntryRequest, ClasspathEntryRequestFactory from pants.jvm.resolve.key import CoursierResolveKey logger = logging.getLogger(__name__) @@ -68,7 +68,7 @@ def root_immutable_inputs_args(self, *, prefix: str = "") -> Iterator[str]: @rule async def classpath( coarsened_targets: CoarsenedTargets, - jvm_request_types: JVMRequestTypes, + classpath_entry_request: ClasspathEntryRequestFactory, ) -> Classpath: # Compute a single shared resolve for all of the roots, which will validate that they # are compatible with one another. @@ -79,9 +79,7 @@ async def classpath( Get( ClasspathEntry, ClasspathEntryRequest, - ClasspathEntryRequest.for_targets( - jvm_request_types, component=t, resolve=resolve, root=True - ), + classpath_entry_request.for_targets(component=t, resolve=resolve, root=True), ) for t in coarsened_targets ) diff --git a/src/python/pants/jvm/compile.py b/src/python/pants/jvm/compile.py index 2a3b72bd308..44ae83db820 100644 --- a/src/python/pants/jvm/compile.py +++ b/src/python/pants/jvm/compile.py @@ -54,38 +54,6 @@ class _ClasspathEntryRequestClassification(Enum): INCOMPATIBLE = auto() -@dataclass(frozen=True) -class JVMRequestTypes: - classpath_entry_requests: tuple[type[ClasspathEntryRequest], ...] - code_generator_requests: FrozenDict[type[SourcesField], tuple[type[ClasspathEntryRequest], ...]] - - -@rule -def calculate_jvm_request_types(union_membership: UnionMembership) -> JVMRequestTypes: - cpe_impls = union_membership.get(ClasspathEntryRequest) - impls_by_source: dict[type[Field], type[ClasspathEntryRequest]] = {} - for impl in cpe_impls: - for field_set in impl.field_sets: - for field in field_set.required_fields: - # Assume only one impl per field (normally sound) - # (note that subsequently, we only check for `SourceFields`, so no need to filter) - impls_by_source[field] = impl - - generators: Iterable[type[GenerateSourcesRequest]] = union_membership.get( - GenerateSourcesRequest - ) - - usable_generators_: dict[type[SourcesField], list[type[ClasspathEntryRequest]]] = defaultdict( - list - ) - for g in generators: - if g.output in impls_by_source: - usable_generators_[g.input].append(impls_by_source[g.output]) - usable_generators = FrozenDict((key, tuple(value)) for key, value in usable_generators_.items()) - - return JVMRequestTypes(tuple(cpe_impls), usable_generators) - - @union @dataclass(frozen=True) class ClasspathEntryRequest(metaclass=ABCMeta): @@ -114,9 +82,14 @@ class ClasspathEntryRequest(metaclass=ABCMeta): # True if this request type is only valid at the root of a compile graph. root_only: ClassVar[bool] = False - @staticmethod + +@dataclass(frozen=True) +class ClasspathEntryRequestFactory: + classpath_entry_requests: tuple[type[ClasspathEntryRequest], ...] + code_generator_requests: FrozenDict[type[SourcesField], tuple[type[ClasspathEntryRequest], ...]] + def for_targets( - jvm_request_types: JVMRequestTypes, + self, component: CoarsenedTarget, resolve: CoursierResolveKey, *, @@ -128,7 +101,7 @@ def for_targets( request types which are marked `root_only`. """ - for (input, request_types) in jvm_request_types.code_generator_requests.items(): + for (input, request_types) in self.code_generator_requests.items(): if not component.representative.has_field(input): continue if len(request_types) > 1: @@ -142,9 +115,9 @@ def for_targets( compatible = [] partial = [] consume_only = [] - impls = jvm_request_types.classpath_entry_requests + impls = self.classpath_entry_requests for impl in impls: - classification = ClasspathEntryRequest.classify_impl(impl, component) + classification = self.classify_impl(impl, component) if classification == _ClasspathEntryRequestClassification.INCOMPATIBLE: continue elif classification == _ClasspathEntryRequestClassification.COMPATIBLE: @@ -184,9 +157,8 @@ def for_targets( f"combination of inputs:\n{component.bullet_list()}" ) - @staticmethod def classify_impl( - impl: type[ClasspathEntryRequest], component: CoarsenedTarget + self, impl: type[ClasspathEntryRequest], component: CoarsenedTarget ) -> _ClasspathEntryRequestClassification: targets = component.members compatible = sum(1 for t in targets for fs in impl.field_sets if fs.is_applicable(t)) @@ -202,6 +174,32 @@ def classify_impl( return _ClasspathEntryRequestClassification.PARTIAL +@rule +def calculate_jvm_request_types(union_membership: UnionMembership) -> ClasspathEntryRequestFactory: + cpe_impls = union_membership.get(ClasspathEntryRequest) + impls_by_source: dict[type[Field], type[ClasspathEntryRequest]] = {} + for impl in cpe_impls: + for field_set in impl.field_sets: + for field in field_set.required_fields: + # Assume only one impl per field (normally sound) + # (note that subsequently, we only check for `SourceFields`, so no need to filter) + impls_by_source[field] = impl + + generators: Iterable[type[GenerateSourcesRequest]] = union_membership.get( + GenerateSourcesRequest + ) + + usable_generators_: dict[type[SourcesField], list[type[ClasspathEntryRequest]]] = defaultdict( + list + ) + for g in generators: + if g.output in impls_by_source: + usable_generators_[g.input].append(impls_by_source[g.output]) + usable_generators = FrozenDict((key, tuple(value)) for key, value in usable_generators_.items()) + + return ClasspathEntryRequestFactory(tuple(cpe_impls), usable_generators) + + @frozen_after_init @dataclass(unsafe_hash=True) class ClasspathEntry: @@ -391,7 +389,7 @@ def required_classfiles(fallible_result: FallibleClasspathEntry) -> ClasspathEnt @rule def classpath_dependency_requests( - jvm_request_types: JVMRequestTypes, request: ClasspathDependenciesRequest + classpath_entry_request: ClasspathEntryRequestFactory, request: ClasspathDependenciesRequest ) -> ClasspathEntryRequests: def ignore_because_generated(coarsened_dep: CoarsenedTarget) -> bool: if len(coarsened_dep.members) == 1: @@ -401,8 +399,8 @@ def ignore_because_generated(coarsened_dep: CoarsenedTarget) -> bool: return us.spec_path == them.spec_path and us.target_name == them.target_name return ClasspathEntryRequests( - ClasspathEntryRequest.for_targets( - jvm_request_types, component=coarsened_dep, resolve=request.request.resolve + classpath_entry_request.for_targets( + component=coarsened_dep, resolve=request.request.resolve ) for coarsened_dep in request.request.component.dependencies if not request.ignore_generated or not ignore_because_generated(coarsened_dep) diff --git a/src/python/pants/jvm/compile_test.py b/src/python/pants/jvm/compile_test.py index 205913485bc..7b66b3b70c0 100644 --- a/src/python/pants/jvm/compile_test.py +++ b/src/python/pants/jvm/compile_test.py @@ -43,9 +43,9 @@ from pants.jvm.classpath import Classpath from pants.jvm.compile import ( ClasspathEntryRequest, + ClasspathEntryRequestFactory, ClasspathSourceAmbiguity, ClasspathSourceMissing, - JVMRequestTypes, ) from pants.jvm.goals import lockfile from pants.jvm.resolve.common import ArtifactRequirement, Coordinate, Coordinates @@ -193,8 +193,7 @@ def classify( targets: Sequence[Target], members: Sequence[type[ClasspathEntryRequest]], ) -> tuple[type[ClasspathEntryRequest], type[ClasspathEntryRequest] | None]: - req = ClasspathEntryRequest.for_targets( - JVMRequestTypes(tuple(members), FrozenDict()), + req = ClasspathEntryRequestFactory(tuple(members), FrozenDict()).for_targets( CoarsenedTarget(targets, ()), CoursierResolveKey("example", "path", EMPTY_DIGEST), ) From 39890bd4cc4638b2da1e4174d7580e8e24b869d1 Mon Sep 17 00:00:00 2001 From: Christopher Neugebauer Date: Fri, 11 Mar 2022 11:30:22 -0800 Subject: [PATCH 10/12] Use `classify_impl` to handle classifying JVM code generators # Rust tests and lints will be skipped. Delete if not intended. [ci skip-rust] --- src/python/pants/jvm/compile.py | 41 ++++++++++++++------------------- 1 file changed, 17 insertions(+), 24 deletions(-) diff --git a/src/python/pants/jvm/compile.py b/src/python/pants/jvm/compile.py index 44ae83db820..ddf238523da 100644 --- a/src/python/pants/jvm/compile.py +++ b/src/python/pants/jvm/compile.py @@ -85,8 +85,8 @@ class ClasspathEntryRequest(metaclass=ABCMeta): @dataclass(frozen=True) class ClasspathEntryRequestFactory: - classpath_entry_requests: tuple[type[ClasspathEntryRequest], ...] - code_generator_requests: FrozenDict[type[SourcesField], tuple[type[ClasspathEntryRequest], ...]] + impls: tuple[type[ClasspathEntryRequest], ...] + generator_sources: FrozenDict[type[ClasspathEntryRequest], frozenset[type[SourcesField]]] def for_targets( self, @@ -101,21 +101,10 @@ def for_targets( request types which are marked `root_only`. """ - for (input, request_types) in self.code_generator_requests.items(): - if not component.representative.has_field(input): - continue - if len(request_types) > 1: - raise ClasspathSourceAmbiguity( - f"More than one code generator ({request_types}) was compatible with the " - f"inputs:\n{component.bullet_list()}" - ) - request_type = request_types[0] - return request_type(component, resolve, None) - compatible = [] partial = [] consume_only = [] - impls = self.classpath_entry_requests + impls = self.impls for impl in impls: classification = self.classify_impl(impl, component) if classification == _ClasspathEntryRequestClassification.INCOMPATIBLE: @@ -161,7 +150,12 @@ def classify_impl( self, impl: type[ClasspathEntryRequest], component: CoarsenedTarget ) -> _ClasspathEntryRequestClassification: targets = component.members - compatible = sum(1 for t in targets for fs in impl.field_sets if fs.is_applicable(t)) + generator_sources = self.generator_sources.get(impl) or frozenset() + + compatible_direct = sum(1 for t in targets for fs in impl.field_sets if fs.is_applicable(t)) + compatible_generated = sum(1 for t in targets for g in generator_sources if t.has_field(g)) + + compatible = compatible_direct + compatible_generated if compatible == 0: return _ClasspathEntryRequestClassification.INCOMPATIBLE if compatible == len(targets): @@ -177,6 +171,7 @@ def classify_impl( @rule def calculate_jvm_request_types(union_membership: UnionMembership) -> ClasspathEntryRequestFactory: cpe_impls = union_membership.get(ClasspathEntryRequest) + impls_by_source: dict[type[Field], type[ClasspathEntryRequest]] = {} for impl in cpe_impls: for field_set in impl.field_sets: @@ -185,19 +180,17 @@ def calculate_jvm_request_types(union_membership: UnionMembership) -> ClasspathE # (note that subsequently, we only check for `SourceFields`, so no need to filter) impls_by_source[field] = impl - generators: Iterable[type[GenerateSourcesRequest]] = union_membership.get( - GenerateSourcesRequest - ) - - usable_generators_: dict[type[SourcesField], list[type[ClasspathEntryRequest]]] = defaultdict( + # Classify code generator sources by their CPE impl + sources_by_impl_: dict[type[ClasspathEntryRequest], list[type[SourcesField]]] = defaultdict( list ) - for g in generators: + + for g in union_membership.get(GenerateSourcesRequest): if g.output in impls_by_source: - usable_generators_[g.input].append(impls_by_source[g.output]) - usable_generators = FrozenDict((key, tuple(value)) for key, value in usable_generators_.items()) + sources_by_impl_[impls_by_source[g.output]].append(g.input) + sources_by_impl = FrozenDict((key, frozenset(value)) for key, value in sources_by_impl_.items()) - return ClasspathEntryRequestFactory(tuple(cpe_impls), usable_generators) + return ClasspathEntryRequestFactory(tuple(cpe_impls), sources_by_impl) @frozen_after_init From c77afd11e4a362c4ed5b2713385e266a811f9344 Mon Sep 17 00:00:00 2001 From: Christopher Neugebauer Date: Fri, 11 Mar 2022 12:29:33 -0800 Subject: [PATCH 11/12] Adds test for classifying protobuf sources # Rust tests and lints will be skipped. Delete if not intended. [ci skip-rust] --- src/python/pants/jvm/compile_test.py | 78 +++++++++++++++++++++++----- 1 file changed, 65 insertions(+), 13 deletions(-) diff --git a/src/python/pants/jvm/compile_test.py b/src/python/pants/jvm/compile_test.py index 7b66b3b70c0..3c80b2ae774 100644 --- a/src/python/pants/jvm/compile_test.py +++ b/src/python/pants/jvm/compile_test.py @@ -18,6 +18,10 @@ import chevron import pytest +from pants.backend.codegen.protobuf.java.rules import GenerateJavaFromProtobufRequest +from pants.backend.codegen.protobuf.java.rules import rules as protobuf_rules +from pants.backend.codegen.protobuf.target_types import ProtobufSourceField, ProtobufSourceTarget +from pants.backend.codegen.protobuf.target_types import rules as protobuf_target_types_rules from pants.backend.java.compile.javac import CompileJavaSourceRequest from pants.backend.java.compile.javac import rules as javac_rules from pants.backend.java.dependency_inference.rules import rules as java_dep_inf_rules @@ -33,12 +37,20 @@ from pants.backend.scala.target_types import ScalaSourcesGeneratorTarget from pants.backend.scala.target_types import rules as scala_target_types_rules from pants.build_graph.address import Address -from pants.core.util_rules import config_files, source_files +from pants.core.util_rules import config_files, source_files, stripped_source_files from pants.core.util_rules.external_tool import rules as external_tool_rules from pants.engine.addresses import Addresses from pants.engine.fs import EMPTY_DIGEST from pants.engine.internals.native_engine import FileDigest -from pants.engine.target import CoarsenedTarget, Target, UnexpandedTargets +from pants.engine.target import ( + CoarsenedTarget, + GeneratedSources, + HydratedSources, + HydrateSourcesRequest, + SourcesField, + Target, + UnexpandedTargets, +) from pants.jvm import classpath, jdk_rules, testutil from pants.jvm.classpath import Classpath from pants.jvm.compile import ( @@ -126,11 +138,21 @@ def rule_runner() -> RuleRunner: *java_target_types_rules(), *util_rules(), *testutil.rules(), + *protobuf_rules(), + *stripped_source_files.rules(), + *protobuf_target_types_rules(), QueryRule(Classpath, (Addresses,)), QueryRule(RenderedClasspath, (Addresses,)), QueryRule(UnexpandedTargets, (Addresses,)), + QueryRule(HydratedSources, [HydrateSourcesRequest]), + QueryRule(GeneratedSources, [GenerateJavaFromProtobufRequest]), + ], + target_types=[ + JavaSourcesGeneratorTarget, + JvmArtifactTarget, + ProtobufSourceTarget, + ScalaSourcesGeneratorTarget, ], - target_types=[ScalaSourcesGeneratorTarget, JavaSourcesGeneratorTarget, JvmArtifactTarget], ) rule_runner.set_options(args=[], env_inherit=PYTHON_BOOTSTRAP_ENV) return rule_runner @@ -183,6 +205,22 @@ def main(args: Array[String]): Unit = { ) +def proto_source() -> str: + return dedent( + """\ + syntax = "proto3"; + + package dir1; + + message Person { + string name = 1; + int32 id = 2; + string email = 3; + } + """ + ) + + class CompileMockSourceRequest(ClasspathEntryRequest): field_sets = (JavaFieldSet, JavaGeneratorFieldSet) @@ -192,8 +230,12 @@ def test_request_classification(rule_runner: RuleRunner) -> None: def classify( targets: Sequence[Target], members: Sequence[type[ClasspathEntryRequest]], + generators: FrozenDict[type[ClasspathEntryRequest], frozenset[type[SourcesField]]], ) -> tuple[type[ClasspathEntryRequest], type[ClasspathEntryRequest] | None]: - req = ClasspathEntryRequestFactory(tuple(members), FrozenDict()).for_targets( + + factory = ClasspathEntryRequestFactory(tuple(members), generators) + + req = factory.for_targets( CoarsenedTarget(targets, ()), CoursierResolveKey("example", "path", EMPTY_DIGEST), ) @@ -206,13 +248,15 @@ def classify( scala_sources(name='scala') java_sources(name='java') jvm_artifact(name='jvm_artifact', group='ex', artifact='ex', version='0.0.0') + protobuf_source(name='proto', source="f.proto") {DEFAULT_SCALA_LIBRARY_TARGET} """ ), + "f.proto": proto_source(), "3rdparty/jvm/default.lock": DEFAULT_LOCKFILE, } ) - scala, java, jvm_artifact = rule_runner.request( + scala, java, jvm_artifact, proto = rule_runner.request( UnexpandedTargets, [ Addresses( @@ -220,33 +264,41 @@ def classify( Address("", target_name="scala"), Address("", target_name="java"), Address("", target_name="jvm_artifact"), + Address("", target_name="proto"), ] ) ], ) all_members = [CompileJavaSourceRequest, CompileScalaSourceRequest, CoursierFetchRequest] + generators = FrozenDict( + { + CompileJavaSourceRequest: frozenset([cast(type[SourcesField], ProtobufSourceField)]), + CompileScalaSourceRequest: frozenset(), + } + ) # Fully compatible. - assert (CompileJavaSourceRequest, None) == classify([java], all_members) - assert (CompileScalaSourceRequest, None) == classify([scala], all_members) - assert (CoursierFetchRequest, None) == classify([jvm_artifact], all_members) + assert (CompileJavaSourceRequest, None) == classify([java], all_members, generators) + assert (CompileScalaSourceRequest, None) == classify([scala], all_members, generators) + assert (CoursierFetchRequest, None) == classify([jvm_artifact], all_members, generators) + assert (CompileJavaSourceRequest, None) == classify([proto], all_members, generators) # Partially compatible. assert (CompileJavaSourceRequest, CompileScalaSourceRequest) == classify( - [java, scala], all_members + [java, scala], all_members, generators ) with pytest.raises(ClasspathSourceMissing): - classify([java, jvm_artifact], all_members) + classify([java, jvm_artifact], all_members, generators) # None compatible. with pytest.raises(ClasspathSourceMissing): - classify([java], []) + classify([java], [], generators) with pytest.raises(ClasspathSourceMissing): - classify([scala, java, jvm_artifact], all_members) + classify([scala, java, jvm_artifact], all_members, generators) # Too many compatible. with pytest.raises(ClasspathSourceAmbiguity): - classify([java], [CompileJavaSourceRequest, CompileMockSourceRequest]) + classify([java], [CompileJavaSourceRequest, CompileMockSourceRequest], generators) @maybe_skip_jdk_test From 66aef0a977abd02b61374f709c329a245bfd771b Mon Sep 17 00:00:00 2001 From: Christopher Neugebauer Date: Fri, 11 Mar 2022 13:06:04 -0800 Subject: [PATCH 12/12] APPEASEMENT # Rust tests and lints will be skipped. Delete if not intended. [ci skip-rust] --- src/python/pants/jvm/compile_test.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/python/pants/jvm/compile_test.py b/src/python/pants/jvm/compile_test.py index 3c80b2ae774..66b9b9b3296 100644 --- a/src/python/pants/jvm/compile_test.py +++ b/src/python/pants/jvm/compile_test.py @@ -13,7 +13,7 @@ import textwrap from textwrap import dedent -from typing import Sequence, cast +from typing import Sequence, Type, cast import chevron import pytest @@ -272,7 +272,7 @@ def classify( all_members = [CompileJavaSourceRequest, CompileScalaSourceRequest, CoursierFetchRequest] generators = FrozenDict( { - CompileJavaSourceRequest: frozenset([cast(type[SourcesField], ProtobufSourceField)]), + CompileJavaSourceRequest: frozenset([cast(Type[SourcesField], ProtobufSourceField)]), CompileScalaSourceRequest: frozenset(), } )