Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor prior_formatter_result and its usage #14987

Merged
merged 2 commits into from
Apr 4, 2022
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 9 additions & 34 deletions src/python/pants/backend/codegen/protobuf/lint/buf/format_rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
)
from pants.core.goals.fmt import FmtRequest, FmtResult
from pants.core.util_rules.external_tool import DownloadedExternalTool, ExternalToolRequest
from pants.core.util_rules.source_files import SourceFiles, SourceFilesRequest
from pants.core.util_rules.system_binaries import (
BinaryShims,
BinaryShimsRequest,
Expand Down Expand Up @@ -45,14 +44,8 @@ class BufFormatRequest(FmtRequest):
name = "buf-format"


@dataclass(frozen=True)
class Setup:
process: Process
original_snapshot: Snapshot


@rule(level=LogLevel.DEBUG)
async def setup_buf_format(request: BufFormatRequest, buf: BufSubsystem) -> Setup:
async def setup_buf_format(request: BufFormatRequest, buf: BufSubsystem) -> Process:
diff_binary = await Get(DiffBinary, DiffBinaryRequest())
download_buf_get = Get(
DownloadedExternalTool, ExternalToolRequest, buf.get_request(Platform.current)
Expand All @@ -66,23 +59,11 @@ async def setup_buf_format(request: BufFormatRequest, buf: BufSubsystem) -> Setu
output_directory=".bin",
),
)
source_files_get = Get(
SourceFiles,
SourceFilesRequest(field_set.sources for field_set in request.field_sets),
)
downloaded_buf, binary_shims, source_files = await MultiGet(
download_buf_get, binary_shims_get, source_files_get
)

source_files_snapshot = (
source_files.snapshot
if request.prior_formatter_result is None
else request.prior_formatter_result
)
downloaded_buf, binary_shims = await MultiGet(download_buf_get, binary_shims_get)

input_digest = await Get(
Digest,
MergeDigests((source_files_snapshot.digest, downloaded_buf.digest, binary_shims.digest)),
MergeDigests((request.snapshot.digest, downloaded_buf.digest, binary_shims.digest)),
)

argv = [
Expand All @@ -91,35 +72,29 @@ async def setup_buf_format(request: BufFormatRequest, buf: BufSubsystem) -> Setu
"-w",
*buf.format_args,
"--path",
",".join(source_files_snapshot.files),
",".join(request.snapshot.files),
]
process = Process(
argv=argv,
input_digest=input_digest,
output_files=source_files_snapshot.files,
output_files=request.snapshot.files,
description=f"Run buf format on {pluralize(len(request.field_sets), 'file')}.",
level=LogLevel.DEBUG,
env={
"PATH": binary_shims.bin_directory,
},
)
return Setup(process, original_snapshot=source_files_snapshot)
return process


@rule(desc="Format with buf format", level=LogLevel.DEBUG)
async def run_buf_format(request: BufFormatRequest, buf: BufSubsystem) -> FmtResult:
if buf.skip_format:
return FmtResult.skip(formatter_name=request.name)
setup = await Get(Setup, BufFormatRequest, request)
result = await Get(ProcessResult, Process, setup.process)
process = await Get(Process, BufFormatRequest, request)
result = await Get(ProcessResult, Process, process)
thejcannon marked this conversation as resolved.
Show resolved Hide resolved
output_snapshot = await Get(Snapshot, Digest, result.output_digest)
return FmtResult(
setup.original_snapshot,
output_snapshot,
stdout=result.stdout.decode(),
stderr=result.stderr.decode(),
formatter_name=request.name,
)
return FmtResult.create(request, result, output_snapshot)


def rules():
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def run_buf(
fmt_result = rule_runner.request(
FmtResult,
[
BufFormatRequest(field_sets, prior_formatter_result=input_sources.snapshot),
BufFormatRequest(field_sets, snapshot=input_sources.snapshot),
],
)

Expand Down
41 changes: 9 additions & 32 deletions src/python/pants/backend/go/lint/gofmt/rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
from pants.backend.go.subsystems.golang import GoRoot
from pants.backend.go.target_types import GoPackageSourcesField
from pants.core.goals.fmt import FmtRequest, FmtResult
from pants.core.util_rules.source_files import SourceFiles, SourceFilesRequest
from pants.engine.fs import Digest
from pants.engine.internals.native_engine import Snapshot
from pants.engine.internals.selectors import Get
Expand Down Expand Up @@ -40,53 +39,31 @@ class GofmtRequest(FmtRequest):
name = GofmtSubsystem.options_scope


@dataclass(frozen=True)
class Setup:
process: Process
original_snapshot: Snapshot


@rule(level=LogLevel.DEBUG)
async def setup_gofmt(request: GofmtRequest, goroot: GoRoot) -> Setup:
source_files = await Get(
SourceFiles,
SourceFilesRequest(field_set.sources for field_set in request.field_sets),
)
source_files_snapshot = (
source_files.snapshot
if request.prior_formatter_result is None
else request.prior_formatter_result
)

async def setup_gofmt(request: GofmtRequest, goroot: GoRoot) -> Process:
argv = (
os.path.join(goroot.path, "bin/gofmt"),
"-w",
*source_files_snapshot.files,
*request.snapshot.files,
)
process = Process(
argv=argv,
input_digest=source_files_snapshot.digest,
output_files=source_files_snapshot.files,
description=f"Run gofmt on {pluralize(len(source_files_snapshot.files), 'file')}.",
input_digest=request.snapshot.digest,
output_files=request.snapshot.files,
description=f"Run gofmt on {pluralize(len(request.snapshot.files), 'file')}.",
level=LogLevel.DEBUG,
)
return Setup(process=process, original_snapshot=source_files_snapshot)
return process


@rule(desc="Format with gofmt")
async def gofmt_fmt(request: GofmtRequest, gofmt: GofmtSubsystem) -> FmtResult:
if gofmt.skip:
return FmtResult.skip(formatter_name=request.name)
setup = await Get(Setup, GofmtRequest, request)
result = await Get(ProcessResult, Process, setup.process)
process = await Get(Process, GofmtRequest, request)
result = await Get(ProcessResult, Process, process)
output_snapshot = await Get(Snapshot, Digest, result.output_digest)
return FmtResult(
setup.original_snapshot,
output_snapshot,
stdout=result.stdout.decode(),
stderr=result.stderr.decode(),
formatter_name=request.name,
)
return FmtResult.create(request, result, output_snapshot)


def rules():
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ def run_gofmt(
fmt_result = rule_runner.request(
FmtResult,
[
GofmtRequest(field_sets, prior_formatter_result=input_sources.snapshot),
GofmtRequest(field_sets, snapshot=input_sources.snapshot),
],
)
return fmt_result
Expand Down
36 changes: 8 additions & 28 deletions src/python/pants/backend/java/lint/google_java_format/rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,9 @@
from pants.backend.java.target_types import JavaSourceField
from pants.core.goals.fmt import FmtRequest, FmtResult
from pants.core.goals.generate_lockfiles import GenerateToolLockfileSentinel
from pants.core.util_rules.source_files import SourceFiles, SourceFilesRequest
from pants.engine.fs import Digest
from pants.engine.internals.native_engine import Snapshot
from pants.engine.internals.selectors import Get, MultiGet
from pants.engine.internals.selectors import Get
from pants.engine.process import ProcessResult
from pants.engine.rules import collect_rules, rule
from pants.engine.target import FieldSet, Target
Expand All @@ -21,7 +20,7 @@
from pants.jvm.resolve.coursier_fetch import ToolClasspath, ToolClasspathRequest
from pants.jvm.resolve.jvm_tool import GenerateJvmLockfileFromTool
from pants.util.logging import LogLevel
from pants.util.strutil import pluralize, strip_v2_chroot_path
from pants.util.strutil import pluralize

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -49,7 +48,6 @@ class GoogleJavaFormatToolLockfileSentinel(GenerateToolLockfileSentinel):
@dataclass(frozen=True)
class Setup:
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Trying to elide this class with the JvmProcess directly causes a graph error :|

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry for the trouble there. If you want to open a ticket about that linked to #11269 feel free.

process: JvmProcess
original_snapshot: Snapshot


@rule(level=LogLevel.DEBUG)
Expand All @@ -62,19 +60,7 @@ async def setup_google_java_format(
lockfile_request = await Get(
GenerateJvmLockfileFromTool, GoogleJavaFormatToolLockfileSentinel()
)
source_files, tool_classpath = await MultiGet(
Get(
SourceFiles,
SourceFilesRequest(field_set.source for field_set in request.field_sets),
),
Get(ToolClasspath, ToolClasspathRequest(lockfile=lockfile_request)),
)

source_files_snapshot = (
source_files.snapshot
if request.prior_formatter_result is None
else request.prior_formatter_result
)
tool_classpath = await Get(ToolClasspath, ToolClasspathRequest(lockfile=lockfile_request))

toolcp_relpath = "__toolcp"
extra_immutable_input_digests = {
Expand All @@ -96,22 +82,22 @@ async def setup_google_java_format(
"com.google.googlejavaformat.java.Main",
*(["--aosp"] if tool.aosp else []),
"--replace",
*source_files_snapshot.files,
*request.snapshot.files,
]

process = JvmProcess(
jdk=jdk,
argv=args,
classpath_entries=tool_classpath.classpath_entries(toolcp_relpath),
input_digest=source_files_snapshot.digest,
input_digest=request.snapshot.digest,
extra_immutable_input_digests=extra_immutable_input_digests,
extra_nailgun_keys=extra_immutable_input_digests,
output_files=source_files_snapshot.files,
output_files=request.snapshot.files,
description=f"Run Google Java Format on {pluralize(len(request.field_sets), 'file')}.",
level=LogLevel.DEBUG,
)

return Setup(process, original_snapshot=source_files_snapshot)
return Setup(process)


@rule(desc="Format with Google Java Format", level=LogLevel.DEBUG)
Expand All @@ -123,13 +109,7 @@ async def google_java_format_fmt(
setup = await Get(Setup, GoogleJavaFormatRequest, request)
result = await Get(ProcessResult, JvmProcess, setup.process)
output_snapshot = await Get(Snapshot, Digest, result.output_digest)
return FmtResult(
setup.original_snapshot,
output_snapshot,
stdout=strip_v2_chroot_path(result.stdout),
stderr=strip_v2_chroot_path(result.stderr),
formatter_name=request.name,
)
return FmtResult.create(request, result, output_snapshot, strip_chroot_path=True)


@rule
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ def run_google_java_format(rule_runner: RuleRunner, targets: list[Target]) -> Fm
fmt_result = rule_runner.request(
FmtResult,
[
GoogleJavaFormatRequest(field_sets, prior_formatter_result=input_sources.snapshot),
GoogleJavaFormatRequest(field_sets, snapshot=input_sources.snapshot),
],
)
return fmt_result
Expand Down
46 changes: 10 additions & 36 deletions src/python/pants/backend/python/lint/autoflake/rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,14 @@
from pants.backend.python.util_rules import pex
from pants.backend.python.util_rules.pex import PexRequest, VenvPex, VenvPexProcess
from pants.core.goals.fmt import FmtRequest, FmtResult
from pants.core.util_rules.source_files import SourceFiles, SourceFilesRequest
from pants.engine.fs import Digest
from pants.engine.internals.native_engine import Snapshot
from pants.engine.internals.selectors import MultiGet
from pants.engine.process import Process, ProcessResult
from pants.engine.rules import Get, collect_rules, rule
from pants.engine.target import FieldSet, Target
from pants.engine.unions import UnionRule
from pants.util.logging import LogLevel
from pants.util.strutil import pluralize, strip_v2_chroot_path
from pants.util.strutil import pluralize


@dataclass(frozen=True)
Expand All @@ -38,27 +36,9 @@ class AutoflakeRequest(FmtRequest):
name = Autoflake.options_scope


@dataclass(frozen=True)
class Setup:
process: Process
original_snapshot: Snapshot


@rule(level=LogLevel.DEBUG)
async def setup_autoflake(request: AutoflakeRequest, autoflake: Autoflake) -> Setup:
autoflake_pex_get = Get(VenvPex, PexRequest, autoflake.to_pex_request())

source_files_get = Get(
SourceFiles,
SourceFilesRequest(field_set.source for field_set in request.field_sets),
)

source_files, autoflake_pex = await MultiGet(source_files_get, autoflake_pex_get)
source_files_snapshot = (
source_files.snapshot
if request.prior_formatter_result is None
else request.prior_formatter_result
)
async def setup_autoflake(request: AutoflakeRequest, autoflake: Autoflake) -> Process:
autoflake_pex = await Get(VenvPex, PexRequest, autoflake.to_pex_request())

process = await Get(
Process,
Expand All @@ -68,31 +48,25 @@ async def setup_autoflake(request: AutoflakeRequest, autoflake: Autoflake) -> Se
"--in-place",
"--remove-all-unused-imports",
*autoflake.args,
*source_files_snapshot.files,
*request.snapshot.files,
),
input_digest=source_files_snapshot.digest,
output_files=source_files_snapshot.files,
input_digest=request.snapshot.digest,
output_files=request.snapshot.files,
description=f"Run Autoflake on {pluralize(len(request.field_sets), 'file')}.",
level=LogLevel.DEBUG,
),
)
return Setup(process, original_snapshot=source_files_snapshot)
return process


@rule(desc="Format with Autoflake", level=LogLevel.DEBUG)
async def autoflake_fmt(request: AutoflakeRequest, autoflake: Autoflake) -> FmtResult:
if autoflake.skip:
return FmtResult.skip(formatter_name=request.name)
setup = await Get(Setup, AutoflakeRequest, request)
result = await Get(ProcessResult, Process, setup.process)
process = await Get(Process, AutoflakeRequest, request)
result = await Get(ProcessResult, Process, process)
output_snapshot = await Get(Snapshot, Digest, result.output_digest)
return FmtResult(
setup.original_snapshot,
output_snapshot,
stdout=strip_v2_chroot_path(result.stdout),
stderr=strip_v2_chroot_path(result.stderr),
formatter_name=request.name,
)
return FmtResult.create(request, result, output_snapshot, strip_chroot_path=True)


def rules():
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def run_autoflake(
fmt_result = rule_runner.request(
FmtResult,
[
AutoflakeRequest(field_sets, prior_formatter_result=input_sources.snapshot),
AutoflakeRequest(field_sets, snapshot=input_sources.snapshot),
],
)
return fmt_result
Expand Down
Loading