diff --git a/src/python/pants/backend/java/bsp/rules.py b/src/python/pants/backend/java/bsp/rules.py index 5d9b4a10070..fbe8779f4d4 100644 --- a/src/python/pants/backend/java/bsp/rules.py +++ b/src/python/pants/backend/java/bsp/rules.py @@ -6,12 +6,12 @@ from dataclasses import dataclass from pants.backend.java.bsp.spec import JavacOptionsItem, JavacOptionsParams, JavacOptionsResult -from pants.backend.java.target_types import JavaSourceField +from pants.backend.java.target_types import JavaFieldSet, JavaSourceField from pants.base.build_root import BuildRoot from pants.base.specs import AddressSpecs from pants.bsp.protocol import BSPHandlerMapping from pants.bsp.spec.base import BuildTargetIdentifier, StatusCode -from pants.bsp.util_rules.compile import BSPCompileFieldSet, BSPCompileResult +from pants.bsp.util_rules.compile import BSPCompileRequest, BSPCompileResult from pants.bsp.util_rules.lifecycle import BSPLanguageSupport from pants.bsp.util_rules.targets import ( BSPBuildTargets, @@ -20,7 +20,7 @@ ) from pants.engine.addresses import Addresses from pants.engine.fs import CreateDigest, DigestEntries -from pants.engine.internals.native_engine import EMPTY_DIGEST, AddPrefix, Digest +from pants.engine.internals.native_engine import EMPTY_DIGEST, Digest from pants.engine.internals.selectors import Get, MultiGet from pants.engine.rules import collect_rules, rule from pants.engine.target import CoarsenedTargets, FieldSet, Targets @@ -132,39 +132,52 @@ async def bsp_javac_options_request(request: JavacOptionsParams) -> JavacOptions @dataclass(frozen=True) -class JavaBSPCompileFieldSet(BSPCompileFieldSet): - required_fields = (JavaSourceField,) - source: JavaSourceField +class JavaBSPCompileRequest(BSPCompileRequest): + field_set_type = JavaFieldSet @rule async def bsp_java_compile_request( - request: JavaBSPCompileFieldSet, classpath_entry_request: ClasspathEntryRequestFactory + request: JavaBSPCompileRequest, classpath_entry_request: ClasspathEntryRequestFactory ) -> BSPCompileResult: - coarsened_targets = await Get(CoarsenedTargets, Addresses([request.source.address])) - assert len(coarsened_targets) == 1 - coarsened_target = coarsened_targets[0] - resolve = await Get(CoursierResolveKey, CoarsenedTargets([coarsened_target])) - - result = await Get( - FallibleClasspathEntry, - ClasspathEntryRequest, - classpath_entry_request.for_targets(component=coarsened_target, resolve=resolve), + coarsened_targets = await Get( + CoarsenedTargets, Addresses([fs.address for fs in request.field_sets]) ) - _logger.info(f"java compile result = {result}") - output_digest = EMPTY_DIGEST - if result.exit_code == 0 and result.output: - entries = await Get(DigestEntries, Digest, result.output.digest) - new_entires = [ - dataclasses.replace(entry, path=os.path.basename(entry.path)) for entry in entries - ] - flat_digest = await Get(Digest, CreateDigest(new_entires)) - output_digest = await Get( - Digest, AddPrefix(flat_digest, f"jvm/resolves/{resolve.name}/lib") + resolve = await Get(CoursierResolveKey, CoarsenedTargets, coarsened_targets) + + results = await MultiGet( + Get( + FallibleClasspathEntry, + ClasspathEntryRequest, + classpath_entry_request.for_targets(component=coarsened_target, resolve=resolve), ) + for coarsened_target in coarsened_targets + ) + + status = StatusCode.OK + if any(r.exit_code != 0 for r in results): + status = StatusCode.ERROR + + output_digest = EMPTY_DIGEST + if status == StatusCode.OK: + output_entries = [] + for result in results: + if not result.output: + continue + entries = await Get(DigestEntries, Digest, result.output.digest) + output_entries.extend( + [ + dataclasses.replace( + entry, + path=f"jvm/resolves/{resolve.name}/lib/{os.path.basename(entry.path)}", + ) + for entry in entries + ] + ) + output_digest = await Get(Digest, CreateDigest(output_entries)) return BSPCompileResult( - status=StatusCode.ERROR if result.exit_code != 0 else StatusCode.OK, + status=status, output_digest=output_digest, ) @@ -175,5 +188,5 @@ def rules(): UnionRule(BSPLanguageSupport, JavaBSPLanguageSupport), UnionRule(BSPBuildTargetsMetadataRequest, JavaBSPBuildTargetsMetadataRequest), UnionRule(BSPHandlerMapping, JavacOptionsHandlerMapping), - UnionRule(BSPCompileFieldSet, JavaBSPCompileFieldSet), + UnionRule(BSPCompileRequest, JavaBSPCompileRequest), ) diff --git a/src/python/pants/backend/scala/bsp/rules.py b/src/python/pants/backend/scala/bsp/rules.py index f91e0990046..db141514e47 100644 --- a/src/python/pants/backend/scala/bsp/rules.py +++ b/src/python/pants/backend/scala/bsp/rules.py @@ -19,13 +19,13 @@ ScalaTestClassesResult, ) from pants.backend.scala.subsystems.scala import ScalaSubsystem -from pants.backend.scala.target_types import ScalaSourceField +from pants.backend.scala.target_types import ScalaFieldSet, ScalaSourceField from pants.base.build_root import BuildRoot from pants.base.specs import AddressSpecs from pants.bsp.protocol import BSPHandlerMapping from pants.bsp.spec.base import BuildTarget, BuildTargetIdentifier, StatusCode from pants.bsp.spec.targets import DependencyModule -from pants.bsp.util_rules.compile import BSPCompileFieldSet, BSPCompileResult +from pants.bsp.util_rules.compile import BSPCompileRequest, BSPCompileResult from pants.bsp.util_rules.lifecycle import BSPLanguageSupport from pants.bsp.util_rules.targets import ( BSPBuildTargetInternal, @@ -405,40 +405,53 @@ async def scala_bsp_dependency_modules( @dataclass(frozen=True) -class ScalaBSPCompileFieldSet(BSPCompileFieldSet): - required_fields = (ScalaSourceField,) - source: ScalaSourceField +class ScalaBSPCompileRequest(BSPCompileRequest): + field_set_type = ScalaFieldSet @rule async def bsp_scala_compile_request( - request: ScalaBSPCompileFieldSet, + request: ScalaBSPCompileRequest, classpath_entry_request: ClasspathEntryRequestFactory, ) -> BSPCompileResult: - coarsened_targets = await Get(CoarsenedTargets, Addresses([request.source.address])) - assert len(coarsened_targets) == 1 - coarsened_target = coarsened_targets[0] - resolve = await Get(CoursierResolveKey, CoarsenedTargets([coarsened_target])) - - result = await Get( - FallibleClasspathEntry, - ClasspathEntryRequest, - classpath_entry_request.for_targets(component=coarsened_target, resolve=resolve), + coarsened_targets = await Get( + CoarsenedTargets, Addresses([fs.address for fs in request.field_sets]) ) - _logger.info(f"scala compile result = {result}") - output_digest = EMPTY_DIGEST - if result.exit_code == 0 and result.output: - entries = await Get(DigestEntries, Digest, result.output.digest) - new_entires = [ - dataclasses.replace(entry, path=os.path.basename(entry.path)) for entry in entries - ] - flat_digest = await Get(Digest, CreateDigest(new_entires)) - output_digest = await Get( - Digest, AddPrefix(flat_digest, f"jvm/resolves/{resolve.name}/lib") + resolve = await Get(CoursierResolveKey, CoarsenedTargets, coarsened_targets) + + results = await MultiGet( + Get( + FallibleClasspathEntry, + ClasspathEntryRequest, + classpath_entry_request.for_targets(component=coarsened_target, resolve=resolve), ) + for coarsened_target in coarsened_targets + ) + + status = StatusCode.OK + if any(r.exit_code != 0 for r in results): + status = StatusCode.ERROR + + output_digest = EMPTY_DIGEST + if status == StatusCode.OK: + output_entries = [] + for result in results: + if not result.output: + continue + entries = await Get(DigestEntries, Digest, result.output.digest) + output_entries.extend( + [ + dataclasses.replace( + entry, + path=f"jvm/resolves/{resolve.name}/lib/{os.path.basename(entry.path)}", + ) + for entry in entries + ] + ) + output_digest = await Get(Digest, CreateDigest(output_entries)) return BSPCompileResult( - status=StatusCode.ERROR if result.exit_code != 0 else StatusCode.OK, + status=status, output_digest=output_digest, ) @@ -451,6 +464,6 @@ def rules(): UnionRule(BSPHandlerMapping, ScalacOptionsHandlerMapping), UnionRule(BSPHandlerMapping, ScalaMainClassesHandlerMapping), UnionRule(BSPHandlerMapping, ScalaTestClassesHandlerMapping), - UnionRule(BSPCompileFieldSet, ScalaBSPCompileFieldSet), + UnionRule(BSPCompileRequest, ScalaBSPCompileRequest), UnionRule(BSPDependencyModulesRequest, ScalaBSPDependencyModulesRequest), ) diff --git a/src/python/pants/bsp/util_rules/compile.py b/src/python/pants/bsp/util_rules/compile.py index 7895221cd73..6323569ef5f 100644 --- a/src/python/pants/bsp/util_rules/compile.py +++ b/src/python/pants/bsp/util_rules/compile.py @@ -1,30 +1,43 @@ # Copyright 2022 Pants project contributors (see CONTRIBUTORS.md). # Licensed under the Apache License, Version 2.0 (see LICENSE). +from __future__ import annotations + import logging import time import uuid +from collections import defaultdict from dataclasses import dataclass +from typing import ClassVar, Generic, Type, TypeVar +from pants.base.specs import AddressSpecs from pants.bsp.context import BSPContext from pants.bsp.protocol import BSPHandlerMapping -from pants.bsp.spec.base import StatusCode, TaskId +from pants.bsp.spec.base import BuildTargetIdentifier, StatusCode, TaskId from pants.bsp.spec.compile import CompileParams, CompileReport, CompileResult, CompileTask from pants.bsp.spec.task import TaskFinishParams, TaskStartParams -from pants.build_graph.address import AddressInput +from pants.bsp.util_rules.targets import BSPBuildTargetInternal from pants.engine.fs import Workspace from pants.engine.internals.native_engine import EMPTY_DIGEST, Digest, MergeDigests -from pants.engine.internals.selectors import Get +from pants.engine.internals.selectors import Get, MultiGet from pants.engine.rules import _uncacheable_rule, collect_rules -from pants.engine.target import FieldSet, WrappedTarget +from pants.engine.target import FieldSet, Targets from pants.engine.unions import UnionMembership, UnionRule, union +from pants.util.ordered_set import FrozenOrderedSet _logger = logging.getLogger(__name__) +_FS = TypeVar("_FS", bound=FieldSet) + @union @dataclass(frozen=True) -class BSPCompileFieldSet(FieldSet): - """FieldSet used to hook BSP compilation.""" +class BSPCompileRequest(Generic[_FS]): + """Hook to allow language backends to compile targets.""" + + field_set_type: ClassVar[Type[_FS]] + + bsp_target: BSPBuildTargetInternal + field_sets: tuple[_FS, ...] @dataclass(frozen=True) @@ -41,56 +54,102 @@ class CompileRequestHandlerMapping(BSPHandlerMapping): response_type = CompileResult +@dataclass(frozen=True) +class CompileOneBSPTargetRequest: + bsp_target: BSPBuildTargetInternal + + # A unique identifier generated by the client to identify this request. + # The server may include this id in triggered notifications or responses. + origin_id: str | None = None + + # Optional arguments to the compilation process. + arguments: tuple[str, ...] | None = () + + @_uncacheable_rule -async def bsp_compile_request( - request: CompileParams, +async def compile_bsp_target( + request: CompileOneBSPTargetRequest, bsp_context: BSPContext, union_membership: UnionMembership, +) -> BSPCompileResult: + targets = await Get(Targets, AddressSpecs, request.bsp_target.specs.address_specs) + compile_request_types: FrozenOrderedSet[Type[BSPCompileRequest]] = union_membership.get( + BSPCompileRequest + ) + field_sets_by_request_type: dict[Type[BSPCompileRequest], set[FieldSet]] = defaultdict(set) + for target in targets: + for compile_request_type in compile_request_types: + field_set_type = compile_request_type.field_set_type + if field_set_type.is_applicable(target): + field_set = field_set_type.create(target) + field_sets_by_request_type[compile_request_type].add(field_set) + + task_id = TaskId(id=uuid.uuid4().hex) + + bsp_context.notify_client( + TaskStartParams( + task_id=task_id, + event_time=int(time.time() * 1000), + data=CompileTask(target=request.bsp_target.bsp_target_id), + ) + ) + + compile_results = await MultiGet( + Get( + BSPCompileResult, + BSPCompileRequest, + compile_request_type(bsp_target=request.bsp_target, field_sets=tuple(field_sets)), + ) + for compile_request_type, field_sets in field_sets_by_request_type.items() + ) + + status = StatusCode.OK + if any(r.status != StatusCode.OK for r in compile_results): + status = StatusCode.ERROR + + bsp_context.notify_client( + TaskFinishParams( + task_id=task_id, + event_time=int(time.time() * 1000), + status=status, + data=CompileReport( + target=request.bsp_target.bsp_target_id, + origin_id=request.origin_id, + errors=0, + warnings=0, + ), + ) + ) + + output_digest = await Get(Digest, MergeDigests([r.output_digest for r in compile_results])) + + return BSPCompileResult( + status=status, + output_digest=output_digest, + ) + + +@_uncacheable_rule +async def bsp_compile_request( + request: CompileParams, workspace: Workspace, ) -> CompileResult: - compile_field_sets = union_membership.get(BSPCompileFieldSet) - compile_results = [] - for bsp_target_id in request.targets: - # TODO: MultiGet these all. - - wrapped_tgt = await Get(WrappedTarget, AddressInput, bsp_target_id.address_input) - tgt = wrapped_tgt.target - _logger.info(f"tgt = {tgt}") - applicable_field_set_impls = [] - for impl in compile_field_sets: - if impl.is_applicable(tgt): - applicable_field_set_impls.append(impl) - _logger.info(f"applicable_field_sets = {applicable_field_set_impls}") - if len(applicable_field_set_impls) == 0: - raise ValueError(f"no applicable field set for: {tgt.address}") - elif len(applicable_field_set_impls) > 1: - raise ValueError(f"ambiguous field set mapping, >1 for: {tgt.address}") - - field_set = applicable_field_set_impls[0].create(tgt) - - task_id = TaskId(id=request.origin_id or uuid.uuid4().hex) - - bsp_context.notify_client( - TaskStartParams( - task_id=task_id, - event_time=int(time.time() * 1000), - data=CompileTask(target=bsp_target_id), - ) - ) + bsp_targets = await MultiGet( + Get(BSPBuildTargetInternal, BuildTargetIdentifier, bsp_target_id) + for bsp_target_id in request.targets + ) - compile_result = await Get(BSPCompileResult, BSPCompileFieldSet, field_set) - compile_results.append(compile_result) - - bsp_context.notify_client( - TaskFinishParams( - task_id=task_id, - event_time=int(time.time() * 1000), - status=compile_result.status, - data=CompileReport( - target=bsp_target_id, origin_id=request.origin_id, errors=0, warnings=0 - ), - ) + compile_results = await MultiGet( + Get( + BSPCompileResult, + CompileOneBSPTargetRequest( + bsp_target=bsp_target, + origin_id=request.origin_id, + arguments=request.arguments, + ), ) + for bsp_target in bsp_targets + ) output_digest = await Get(Digest, MergeDigests([r.output_digest for r in compile_results])) if output_digest != EMPTY_DIGEST: diff --git a/src/python/pants/bsp/util_rules/targets.py b/src/python/pants/bsp/util_rules/targets.py index 9331b9a748b..38f57e5b114 100644 --- a/src/python/pants/bsp/util_rules/targets.py +++ b/src/python/pants/bsp/util_rules/targets.py @@ -79,6 +79,10 @@ class BSPBuildTargetInternal: name: str specs: Specs + @property + def bsp_target_id(self) -> BuildTargetIdentifier: + return BuildTargetIdentifier(f"pants:{self.name}") + @dataclass(frozen=True) class BSPBuildTargetSourcesInfo: