Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP: Experimental CLI --worker-id / --worker-close support #50

Draft
wants to merge 11 commits into
base: wavewave/extra-libraries
Choose a base branch
from
10 changes: 9 additions & 1 deletion haskell/compile.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,8 @@ def _dynamic_target_metadata_impl(actions, output, arg, pkg_deps) -> list[Provid
md_args = cmd_args(arg.md_gen)
md_args.add(packages_info.bin_paths)
md_args.add("--ghc", arg.haskell_toolchain.compiler)
if arg.haskell_toolchain.use_persistent_workers:
md_args.add("--worker-target-id", "haskell_metadata")
md_args.add(cmd_args(ghc_args, format="--ghc-arg={}"))
md_args.add(
"--source-prefix",
Expand Down Expand Up @@ -240,6 +242,9 @@ def target_metadata(
md_file = ctx.actions.declare_output(ctx.attrs.name + suffix + ".md.json")
md_gen = ctx.attrs._generate_target_metadata[RunInfo]

libname = repr(ctx.label.path).replace("//", "_").replace("/", "_") + "_" + ctx.label.name
pkgname = libname.replace("_", "-")

haskell_toolchain = ctx.attrs._haskell_toolchain[HaskellToolchainInfo]
toolchain_libs = [dep.name for dep in attr_deps_haskell_toolchain_libraries(ctx)]

Expand Down Expand Up @@ -441,9 +446,12 @@ def _common_compile_module_args(
direct_deps_info: list[HaskellLibraryInfoTSet],
pkgname: str | None = None,
) -> CommonCompileModuleArgs:

command = cmd_args(ghc_wrapper)
command.add("--ghc", haskell_toolchain.compiler)

if haskell_toolchain.use_persistent_workers and pkgname:
worker_target_id = pkgname
command.add("--worker-target-id", worker_target_id)
# Some rules pass in RTS (e.g. `+RTS ... -RTS`) options for GHC, which can't
# be parsed when inside an argsfile.
command.add(haskell_toolchain.compiler_flags)
Expand Down
5 changes: 5 additions & 0 deletions haskell/haskell.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -600,6 +600,10 @@ def _dynamic_link_shared_impl(actions, pkg_deps, lib, arg):
link_cmd = cmd_args(link_cmd_args, hidden = link_cmd_hidden)
link_cmd.add("-o", lib)

if arg.haskell_toolchain.use_persistent_workers:
link_cmd.add("--worker-target-id={}".format(arg.worker_target_id))
link_cmd.add("--worker-close")

actions.run(
link_cmd,
category = "haskell_link" + arg.artifact_suffix.replace("-", "_"),
Expand Down Expand Up @@ -693,6 +697,7 @@ def _build_haskell_lib(
objects = objects,
toolchain_libs = toolchain_libs,
use_argsfile_at_link = ctx.attrs.use_argsfile_at_link,
worker_target_id = pkgname,
),
))

Expand Down
1 change: 1 addition & 0 deletions haskell/toolchain.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ HaskellToolchainInfo = provider(
"cache_links": provider_field(typing.Any, default = None),
"script_template_processor": provider_field(typing.Any, default = None),
"packages": provider_field(typing.Any, default = None),
"use_persistent_workers": provider_field(typing.Any, default = None),
},
)

Expand Down
17 changes: 13 additions & 4 deletions haskell/tools/generate_target_metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,11 @@ def main():
required=True,
type=argparse.FileType("w"),
help="Write package metadata to this file in JSON format.")
parser.add_argument(
"--worker-target-id",
required=False,
type=str,
help="Worker id")
parser.add_argument(
"--ghc",
required=True,
Expand Down Expand Up @@ -86,7 +91,7 @@ def json_default_handler(o):

def obtain_target_metadata(args):
paths = [str(binpath) for binpath in args.bin_path if binpath.is_dir()]
ghc_depends = run_ghc_depends(args.ghc, args.ghc_arg, args.source, paths)
ghc_depends = run_ghc_depends(args.ghc, args.ghc_arg, args.source, paths, args.worker_target_id)
th_modules = determine_th_modules(ghc_depends)
module_mapping = determine_module_mapping(ghc_depends, args.source_prefix)
module_graph = determine_module_graph(ghc_depends)
Expand Down Expand Up @@ -185,20 +190,24 @@ def determine_package_deps(ghc_depends):
return package_deps


def run_ghc_depends(ghc, ghc_args, sources, aux_paths):
def run_ghc_depends(ghc, ghc_args, sources, aux_paths, worker_target_id):
with tempfile.TemporaryDirectory() as dname:
json_fname = os.path.join(dname, "depends.json")
make_fname = os.path.join(dname, "depends.make")
haskell_sources = list(filter(is_haskell_src, sources))

haskell_boot_sources = list(filter (is_haskell_boot, sources))
if worker_target_id:
worker_args = ["--worker-target-id={}".format(worker_target_id)]
else:
worker_args = []
args = [
ghc, "-M", "-include-pkg-deps",
# Note: `-outputdir '.'` removes the prefix of all targets:
# backend/src/Foo/Util.<ext> => Foo/Util.<ext>
"-outputdir", ".",
"-dep-json", json_fname,
"-dep-makefile", make_fname,
] + ghc_args + haskell_sources
] + worker_args + ghc_args + haskell_sources + haskell_boot_sources

env = os.environ.copy()
path = env.get("PATH", "")
Expand Down
25 changes: 20 additions & 5 deletions haskell/tools/ghc_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,12 @@ def main():
default=[],
help="Path to a package db that is used during the module compilation",
)
parser.add_argument(
"--worker-target-id", required=False, type=str, help="worker target id",
)
parser.add_argument(
"--worker-close", required=False, type=bool, default=False, help="worker close",
)
parser.add_argument(
"--ghc", required=True, type=str, help="Path to the Haskell compiler GHC."
)
Expand Down Expand Up @@ -75,8 +81,13 @@ def main():
)

args, ghc_args = parser.parse_known_args()

cmd = [args.ghc] + ghc_args
if args.worker_target_id:
worker_args = ["--worker-target-id={}".format(args.worker_target_id)] + (["--worker-close"] if args.worker_close else [])
use_persistent_workers = True
else:
worker_args = []
use_persistent_workers = False
cmd = [args.ghc] + worker_args + ghc_args

aux_paths = [str(binpath) for binpath in args.bin_path if binpath.is_dir()] + [str(os.path.dirname(binexepath)) for binexepath in args.bin_exe]
env = os.environ.copy()
Expand All @@ -99,7 +110,7 @@ def main():
if returncode != 0:
return returncode

recompute_abi_hash(args.ghc, args.abi_out)
recompute_abi_hash(args.ghc, args.abi_out, use_persistent_workers)

# write an empty dep file, to signal that all tagged files are unused
try:
Expand Down Expand Up @@ -127,11 +138,15 @@ def main():
return 0


def recompute_abi_hash(ghc, abi_out):
def recompute_abi_hash(ghc, abi_out, use_persistent_workers):
"""Call ghc on the hi file and write the ABI hash to abi_out."""
hi_file = abi_out.with_suffix("")
if use_persistent_workers:
worker_args = ["--worker-target-id=show-iface-abi-hash"]
else:
worker_args = []

cmd = [ghc, "-v0", "-package-env=-", "--show-iface-abi-hash", hi_file]
cmd = [ghc, "-v0", "-package-env=-", "--show-iface-abi-hash", hi_file] + worker_args

hash = subprocess.check_output(cmd, text=True).split(maxsplit=1)[0]

Expand Down