diff --git a/compiler_gym/envs/llvm/__init__.py b/compiler_gym/envs/llvm/__init__.py index 535c524cc..638fd7617 100644 --- a/compiler_gym/envs/llvm/__init__.py +++ b/compiler_gym/envs/llvm/__init__.py @@ -12,6 +12,9 @@ ClangInvocation, get_system_library_flags, make_benchmark, + make_benchmark_from_source, + merge_benchmarks, + split_benchmark_by_function, ) from compiler_gym.envs.llvm.llvm_env import LlvmEnv @@ -30,8 +33,11 @@ "LLVM_SERVICE_BINARY", "LlvmEnv", "make_benchmark", + "make_benchmark_from_source", + "merge_benchmarks", "observation_spaces", "reward_spaces", + "split_benchmark_by_function", ] LLVM_SERVICE_BINARY = runfiles_path( diff --git a/compiler_gym/envs/llvm/llvm_benchmark.py b/compiler_gym/envs/llvm/llvm_benchmark.py index 454599539..193a961ed 100644 --- a/compiler_gym/envs/llvm/llvm_benchmark.py +++ b/compiler_gym/envs/llvm/llvm_benchmark.py @@ -10,6 +10,7 @@ import sys import tempfile from concurrent.futures import as_completed +from copy import deepcopy from datetime import datetime from functools import lru_cache from pathlib import Path @@ -175,6 +176,27 @@ def command(self, outpath: Path) -> List[str]: return cmd + # NOTE(cummins): There is some discussion about the best way to create a + # bitcode that is unoptimized yet does not hinder downstream + # optimization opportunities. Here we are using a configuration based on + # -O1 in which we prevent the -O1 optimization passes from running. This + # is because LLVM produces different function attributes dependening on + # the optimization level. E.g. "-O0 -Xclang -disable-llvm-optzns -Xclang + # -disable-O0-optnone" will generate code with "noinline" attributes set + # on the functions, wheras "-Oz -Xclang -disable-llvm-optzns" will + # generate functions with "minsize" and "optsize" attributes set. + # + # See also: + # + # + DEFAULT_COPT = [ + "-O1", + "-Xclang", + "-disable-llvm-passes", + "-Xclang", + "-disable-llvm-optzns", + ] + @classmethod def from_c_file( cls, @@ -184,29 +206,8 @@ def from_c_file( timeout: int = 600, ) -> "ClangInvocation": copt = copt or [] - # NOTE(cummins): There is some discussion about the best way to create a - # bitcode that is unoptimized yet does not hinder downstream - # optimization opportunities. Here we are using a configuration based on - # -O1 in which we prevent the -O1 optimization passes from running. This - # is because LLVM produces different function attributes dependening on - # the optimization level. E.g. "-O0 -Xclang -disable-llvm-optzns -Xclang - # -disable-O0-optnone" will generate code with "noinline" attributes set - # on the functions, wheras "-Oz -Xclang -disable-llvm-optzns" will - # generate functions with "minsize" and "optsize" attributes set. - # - # See also: - # - # - DEFAULT_COPT = [ - "-O1", - "-Xclang", - "-disable-llvm-passes", - "-Xclang", - "-disable-llvm-optzns", - ] - return cls( - DEFAULT_COPT + copt + [str(path)], + cls.DEFAULT_COPT + copt + [str(path)], system_includes=system_includes, timeout=timeout, ) @@ -422,3 +423,219 @@ def _add_path(path: Path): timestamp = datetime.now().strftime("%Y%m%HT%H%M%S") uri = f"benchmark://user-v0/{timestamp}-{random.randrange(16**4):04x}" return Benchmark.from_file_contents(uri, bitcode) + + +def make_benchmark_from_source( + source: str, + copt: Optional[List[str]] = None, + lang: str = "c++", + system_includes: bool = True, + timeout: int = 600, +) -> Benchmark: + """Create a benchmark from a string of source code. + + This function takes a string of source code and generates a benchmark that + can be passed to :meth:`compiler_gym.envs.LlvmEnv.reset`. + + Example usage: + + >>> benchmark = make_benchmark_from_source("int A() {return 0;}") + >>> env = gym.make("llvm-v0") + >>> env.reset(benchmark=benchmark) + + The clang invocation used is roughly equivalent to: + + .. code-block:: + + $ clang - -O0 -c -emit-llvm -o benchmark.bc + + Additional compile-time arguments to clang can be provided using the + :code:`copt` argument: + + >>> benchmark = make_benchmark_from_source("...", copt=['-O2']) + + :param source: A string of source code. + + :param copt: A list of command line options to pass to clang when compiling + source files. + + :param lang: The source language, passed to clang via the :code:`-x` + argument. Defaults to C++. + + :param system_includes: Whether to include the system standard libraries + during compilation jobs. This requires a system toolchain. See + :func:`get_system_library_flags`. + + :param timeout: The maximum number of seconds to allow clang to run before + terminating. + + :return: A :code:`Benchmark` instance. + + :raises FileNotFoundError: If any input sources are not found. + + :raises TypeError: If the inputs are of unsupported types. + + :raises OSError: If a suitable compiler cannot be found. + + :raises BenchmarkInitError: If a compilation job fails. + + :raises TimeoutExpired: If a compilation job exceeds :code:`timeout` + seconds. + """ + cmd = [ + str(llvm.clang_path()), + f"-x{lang}", + "-", + "-o", + "-", + "-c", + "-emit-llvm", + *ClangInvocation.DEFAULT_COPT, + ] + if system_includes: + cmd += get_system_library_flags() + cmd += copt or [] + + with Popen( + cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, stdin=subprocess.PIPE + ) as clang: + bitcode, stderr = clang.communicate(source.encode("utf-8"), timeout=timeout) + if clang.returncode: + raise BenchmarkInitError( + f"Failed to make benchmark with compiler error: {stderr.decode('utf-8')}" + ) + + timestamp = datetime.now().strftime("%Y%m%HT%H%M%S") + uri = f"benchmark://user-v0/{timestamp}-{random.randrange(16**4):04x}" + return Benchmark.from_file_contents(uri, bitcode) + + +def split_benchmark_by_function( + benchmark: Benchmark, maximum_function_count: int = 0, timeout: float = 300 +) -> List[Benchmark]: + """Split a benchmark into single-function benchmarks. + + This function takes a benchmark as input and divides it into a set of + independent benchmarks, where each benchmark contains a single function from + the input. + + Under the hood, this uses an extension to `llvm-extract + `__ to pull out + individual parts of programs. + + In pseudo code, this is roughly equivalent to: + + .. code-block::py + + for i in number_of_functions_in_benchmark(benchmark): + yield llvm_extract(benchmark, function_number=i) + + :param benchmark: A benchmark to split. + + :param maximum_function_count: If a positive integer, this specifies the + maximum number of single-function benchmarks to extract from the input. + If the input contains more than this number of functions, the remainder + are ignored. + + :param timeout: The maximum number of seconds to allow llvm-extract to run + before terminating. + + :return: A list of :code:`Benchmark` instances. + + :raises ValueError: If the input benchmark contains no functions, or if + llvm-extract fails. + + :raises TimeoutExpired: If any llvm-extract job exceeds :code:`timeout` + seconds. + """ + original_uri = deepcopy(benchmark.uri) + original_bitcode = benchmark.proto.program.contents + + # Count the number of functions in the benchmark. + with Popen( + [str(llvm.llvm_extract_one_path()), "-", "-count-only", "-o", "-"], + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + stdin=subprocess.PIPE, + ) as p: + stdout, stderr = p.communicate(original_bitcode, timeout=timeout) + if p.returncode: + raise ValueError( + "Failed to count number of functions in benchmark: " + f"{stderr.decode('utf-8')}" + ) + number_of_functions = int(stdout.decode("utf-8")) + if number_of_functions <= 0: + raise ValueError("No functions found!") + + # Iterate over the number of functions, extracting each one in turn. + split_benchmarks: List[Benchmark] = [] + n = min(number_of_functions, maximum_function_count or number_of_functions) + for i in range(n): + with Popen( + [str(llvm.llvm_extract_one_path()), "-", "-n", str(i), "-o", "-"], + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + stdin=subprocess.PIPE, + ) as p: + stdout, stderr = p.communicate(original_bitcode, timeout=timeout) + if p.returncode: + raise ValueError( + "Failed to extract function {i}: " f"{stderr.decode('utf-8')}" + ) + + original_uri.params["function"] = str(i) + split_benchmarks.append( + Benchmark.from_file_contents(uri=original_uri, data=stdout) + ) + logger.debug("Extracted %s", original_uri) + + return split_benchmarks + + +def merge_benchmarks(benchmarks: List[Benchmark], timeout: float = 300) -> Benchmark: + """Merge a list of benchmarks into a single benchmark. + + Under the hood, this `llvm-link + `__ to combine each of + the bitcodes of the input benchmarks into a single bitcode. + + :param benchmarks: A list of benchmarks to merge. + + :param timeout: The maximum number of seconds to allow llvm-link to run + before terminating. + + :return: A :code:`Benchmark` instance. + + :raises ValueError: If the input contains no benchmarks, or if llvm-link + fails. + + :raises TimeoutExpired: If llvm-link exceeds :code:`timeout` seconds. + """ + if not benchmarks: + raise ValueError("No benchmarks!") + + transient_cache = transient_cache_path(".") + transient_cache.mkdir(parents=True, exist_ok=True) + with tempfile.TemporaryDirectory(dir=transient_cache, prefix="llvm-link") as d: + tmpdir = Path(d) + + # Write each of the benchmark bitcodes to a temporary file. + cmd = [str(llvm.llvm_link_path()), "-o", "-", "-f"] + for i, benchmark in enumerate(benchmarks): + bitcode_path = tmpdir / f"{i}.bc" + with open(bitcode_path, "wb") as f: + f.write(benchmark.proto.program.contents) + cmd.append(str(bitcode_path)) + + # Run llvm-link on the temporary files. + with Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE) as p: + stdout, stderr = p.communicate(timeout=timeout) + if p.returncode: + raise ValueError( + f"Failed to merge benchmarks: {stderr.decode('utf-8')}" + ) + + timestamp = datetime.now().strftime("%Y%m%HT%H%M%S") + uri = f"benchmark://llvm-link-v0/{timestamp}-{random.randrange(16**4):04x}" + return Benchmark.from_file_contents(uri=uri, data=stdout) diff --git a/compiler_gym/envs/llvm/service/BUILD b/compiler_gym/envs/llvm/service/BUILD index 0cf699002..24749827a 100644 --- a/compiler_gym/envs/llvm/service/BUILD +++ b/compiler_gym/envs/llvm/service/BUILD @@ -206,6 +206,50 @@ cc_library( ], ) +filegroup( + name = "llvm-extract-one-files", + srcs = [ + ":llvm-extract-one", + ] + select({ + "@llvm//:darwin": [], + "//conditions:default": [ + ":libLLVMPolly", + ], + }), + visibility = ["//visibility:public"], +) + +cc_binary( + name = "llvm-extract-one-prelinked", + srcs = ["LlvmExtractOne.cc"], + copts = [ + "-DGOOGLE_PROTOBUF_NO_RTTI", + "-fno-rtti", + "-std=c++17", + ], + deps = [ + "@llvm//10.0.0", + ], +) + +genrule( + name = "llvm-extract-one-bin", + srcs = [":llvm-extract-one-prelinked"], + outs = ["llvm-extract-one"], + cmd = select({ + "@llvm//:darwin": ( + "cp $(location :llvm-extract-one-prelinked) $@" + ), + "//conditions:default": ( + "cp $(location :llvm-extract-one-prelinked) $@ && " + + "chmod 666 $@ && " + + "patchelf --set-rpath '$$ORIGIN' $@ && " + + "chmod 555 $@" + ), + }), + visibility = ["//visibility:public"], +) + cc_library( name = "LlvmSession", srcs = ["LlvmSession.cc"], diff --git a/compiler_gym/envs/llvm/service/CMakeLists.txt b/compiler_gym/envs/llvm/service/CMakeLists.txt index 8005474e6..ea796f5b6 100644 --- a/compiler_gym/envs/llvm/service/CMakeLists.txt +++ b/compiler_gym/envs/llvm/service/CMakeLists.txt @@ -169,6 +169,20 @@ cg_cc_library( PUBLIC ) +llvm_map_components_to_libnames(_LLVM_LIBS core support irreader) +cg_cc_binary( + NAME llvm-extract-one + SRCS LlvmExtractOne.cc + COPTS + "-fno-rtti" + ABS_DEPS + ${_LLVM_LIBS} + INCLUDES + ${LLVM_INCLUDE_DIRS} + DEFINES + ${LLVM_DEFINITIONS} +) + llvm_map_components_to_libnames(_LLVM_LIBS core analysis coroutines objcarcopts target codegen x86codegen x86asmparser #TODO(boian): can these be found programmatically diff --git a/compiler_gym/envs/llvm/service/LlvmExtractOne.cc b/compiler_gym/envs/llvm/service/LlvmExtractOne.cc new file mode 100644 index 000000000..7d108b6a8 --- /dev/null +++ b/compiler_gym/envs/llvm/service/LlvmExtractOne.cc @@ -0,0 +1,142 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree. + +//============================================================================= +// This file is a small tool to extract single functions from LLVM bitcode +#include +#include +#include +#include + +#include "llvm/Bitcode/BitcodeWriterPass.h" +#include "llvm/IR/IRPrintingPasses.h" +#include "llvm/IR/LegacyPassManager.h" +#include "llvm/IR/Module.h" +#include "llvm/IR/Verifier.h" +#include "llvm/IRReader/IRReader.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/Error.h" +#include "llvm/Support/FileSystem.h" +#include "llvm/Support/SourceMgr.h" +#include "llvm/Support/SystemUtils.h" +#include "llvm/Support/ToolOutputFile.h" +#include "llvm/Transforms/IPO.h" + +using namespace llvm; + +// Command line options +static cl::OptionCategory ExtractOneOptions("llvm-extract-one Options"); +static cl::opt InputFileName(cl::Positional, cl::desc("[input file]"), + cl::value_desc("filename"), cl::init("-")); +static cl::opt OutputFileName("o", + cl::desc("Specify output filename (default=stdout)"), + cl::value_desc("filename"), cl::init("-")); +static cl::opt Seed("seed", cl::desc("Random number seed (default=0)"), + cl::value_desc("unsigned int"), cl::init(0), + cl::cat(ExtractOneOptions)); +static cl::opt Nth("n", cl::desc("Extract the n-th function"), cl::value_desc("unsigned int"), + cl::init(-1), cl::cat(ExtractOneOptions)); +static cl::opt CountOnly("count-only", cl::desc("Only count the number of funtions"), + cl::init(false), cl::cat(ExtractOneOptions)); +static cl::opt Force("f", cl::desc("Enable binary output on terminals"), + cl::cat(ExtractOneOptions)); +static cl::opt OutputAssembly("S", cl::desc("Write output as LLVM assembly"), cl::Hidden, + cl::cat(ExtractOneOptions)); +static cl::opt PreserveBitcodeUseListOrder( + "preserve-bc-uselistorder", cl::desc("Preserve use-list order when writing LLVM bitcode."), + cl::init(true), cl::Hidden, cl::cat(ExtractOneOptions)); +static cl::opt PreserveAssemblyUseListOrder( + "preserve-ll-uselistorder", cl::desc("Preserve use-list order when writing LLVM assembly."), + cl::init(false), cl::Hidden, cl::cat(ExtractOneOptions)); + +/// Reads a module from a file. +/// On error, messages are written to stderr and null is returned. +/// +/// \param Context LLVM Context for the module. +/// \param Name Input file name. +static std::unique_ptr readModule(LLVMContext& Context, StringRef Name) { + SMDiagnostic Diag; + std::unique_ptr Module = parseIRFile(Name, Diag, Context); + + if (!Module) + Diag.print("llvm-extract-one", errs()); + + return Module; +} + +// The main tool +int main(int argc, char** argv) { + cl::ParseCommandLineOptions(argc, argv, + " llvm-extract-one\n\n" + " Extract a single, random function from a bitcode file.\n\n" + " If no input file is given, or it is given as '-', then the input " + "file is read from stdin.\n"); + + if (Seed) + std::srand(Seed); + else + std::srand(std::time(nullptr)); + + LLVMContext Context; + + std::unique_ptr Module = readModule(Context, InputFileName); + + if (!Module) + return 1; + + // Find functions that might be kept + std::vector Functions; + for (auto& Function : *Module) { + if (!Function.empty()) { + Functions.push_back(&Function); + } + } + if (CountOnly) { + std::cout << Functions.size() << std::endl; + return 0; + } + + if (Functions.empty()) { + errs() << "No suitable functions\n"; + return 1; + } + // Choose one + int keeperIndex = (Nth == -1 ? std::rand() : Nth) % Functions.size(); + Function* Keeper = Functions[keeperIndex]; + + // Extract the function + ExitOnError ExitOnErr(std::string(*argv) + ": extracting function : "); + ExitOnErr(Keeper->materialize()); + legacy::PassManager Passes; + std::vector GVs = {Keeper}; + Passes.add(createGVExtractionPass(GVs)); // Extract the one function + Passes.add(createGlobalDCEPass()); // Delete unreachable globals + Passes.add(createStripDeadDebugInfoPass()); // Remove dead debug info + Passes.add(createStripDeadPrototypesPass()); // Remove dead func decls + + if (verifyModule(*Module, &errs())) + return 1; + + std::error_code EC; + ToolOutputFile Out(OutputFileName, EC, sys::fs::OF_None); + + if (EC) { + errs() << EC.message() << '\n'; + return 1; + } + + if (OutputAssembly) { + Out.os() << "; KeeperIndex = " << keeperIndex << '\n'; + Passes.add(createPrintModulePass(Out.os(), "", PreserveAssemblyUseListOrder)); + } else if (Force || !CheckBitcodeOutputToConsole(Out.os())) { + Passes.add(createBitcodeWriterPass(Out.os(), PreserveBitcodeUseListOrder)); + } + Passes.run(*Module); + + // Declare success. + Out.keep(); + + return 0; +} diff --git a/compiler_gym/third_party/llvm/BUILD b/compiler_gym/third_party/llvm/BUILD index c4d9d269f..c5fe4afd4 100644 --- a/compiler_gym/third_party/llvm/BUILD +++ b/compiler_gym/third_party/llvm/BUILD @@ -8,6 +8,7 @@ load("@rules_cc//cc:defs.bzl", "cc_binary", "cc_library") py_library( name = "llvm", srcs = ["__init__.py"], + data = ["//compiler_gym/envs/llvm/service:llvm-extract-one-files"], visibility = ["//visibility:public"], deps = [ "//compiler_gym/util", diff --git a/compiler_gym/third_party/llvm/CMakeLists.txt b/compiler_gym/third_party/llvm/CMakeLists.txt index b2b90b835..950306377 100644 --- a/compiler_gym/third_party/llvm/CMakeLists.txt +++ b/compiler_gym/third_party/llvm/CMakeLists.txt @@ -14,6 +14,8 @@ cg_py_library( llvm SRCS "__init__.py" + DATA + compiler_gym::envs::llvm::service::llvm-extract-one DEPS compiler_gym::util::util PUBLIC diff --git a/compiler_gym/third_party/llvm/__init__.py b/compiler_gym/third_party/llvm/__init__.py index f82dfaba3..004b31758 100644 --- a/compiler_gym/third_party/llvm/__init__.py +++ b/compiler_gym/third_party/llvm/__init__.py @@ -15,7 +15,7 @@ from fasteners import InterProcessLock from compiler_gym.util.download import download -from compiler_gym.util.runfiles_path import cache_path, site_data_path +from compiler_gym.util.runfiles_path import cache_path, runfiles_path, site_data_path logger = logging.getLogger(__name__) @@ -114,6 +114,11 @@ def llvm_dis_path() -> Path: return download_llvm_files() / "bin/llvm-dis" +def llvm_extract_one_path() -> Path: + """Return the path of llvm-extract-one.""" + return runfiles_path("compiler_gym/envs/llvm/service/llvm-extract-one") + + def llvm_link_path() -> Path: """Return the path of llvm-link.""" return download_llvm_files() / "bin/llvm-link" diff --git a/docs/source/llvm/api.rst b/docs/source/llvm/api.rst index 06af7fffd..0ace44e17 100644 --- a/docs/source/llvm/api.rst +++ b/docs/source/llvm/api.rst @@ -15,6 +15,12 @@ Constructing Benchmarks .. autofunction:: make_benchmark +.. autofunction:: make_benchmark_from_source + +.. autofunction:: split_benchmark_by_function + +.. autofunction:: merge_benchmarks + .. autoclass:: BenchmarkFromCommandLine :members: diff --git a/tests/llvm/llvm_benchmark_test.py b/tests/llvm/llvm_benchmark_test.py index 337ab13c5..87494f139 100644 --- a/tests/llvm/llvm_benchmark_test.py +++ b/tests/llvm/llvm_benchmark_test.py @@ -3,6 +3,7 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. """Integrations tests for the LLVM CompilerGym environments.""" +import re import tempfile from pathlib import Path @@ -10,7 +11,8 @@ from compiler_gym.datasets import Benchmark from compiler_gym.envs import CompilerEnv -from compiler_gym.envs.llvm import llvm_benchmark +from compiler_gym.envs.llvm import llvm_benchmark as llvm +from compiler_gym.errors.dataset_errors import BenchmarkInitError from compiler_gym.service.proto import Benchmark as BenchmarkProto from compiler_gym.service.proto import File from tests.pytest_plugins.common import macos_only @@ -45,52 +47,214 @@ def test_add_benchmark_invalid_path(env: CompilerEnv): def test_get_system_library_flags_not_found(): with pytest.raises( - llvm_benchmark.HostCompilerFailure, match="Failed to invoke 'not-a-real-binary'" + llvm.HostCompilerFailure, match="Failed to invoke 'not-a-real-binary'" ): - llvm_benchmark.get_system_library_flags("not-a-real-binary") + llvm.get_system_library_flags("not-a-real-binary") def test_get_system_library_flags_nonzero_exit_status(): """Test that setting the $CXX to an invalid binary raises an error.""" - with pytest.raises( - llvm_benchmark.HostCompilerFailure, match="Failed to invoke 'false'" - ): - llvm_benchmark.get_system_library_flags("false") + with pytest.raises(llvm.HostCompilerFailure, match="Failed to invoke 'false'"): + llvm.get_system_library_flags("false") def test_get_system_library_flags_output_parse_failure(): """Test that setting the $CXX to an invalid binary raises an error.""" with pytest.raises( - llvm_benchmark.UnableToParseHostCompilerOutput, + llvm.UnableToParseHostCompilerOutput, match="Failed to parse '#include <...>' search paths from 'echo'", ): - llvm_benchmark.get_system_library_flags("echo") + llvm.get_system_library_flags("echo") def test_get_system_library_flags(): - flags = llvm_benchmark.get_system_library_flags() + flags = llvm.get_system_library_flags() assert flags assert "-isystem" in flags @macos_only def test_get_system_library_flags_system_libraries(): - flags = llvm_benchmark.get_system_library_flags() + flags = llvm.get_system_library_flags() assert flags assert flags[-1] == "-L/Library/Developer/CommandLineTools/SDKs/MacOSX.sdk/usr/lib" def test_ClangInvocation_system_libs(): - cmd = llvm_benchmark.ClangInvocation(["foo.c"]).command("a.out") + cmd = llvm.ClangInvocation(["foo.c"]).command("a.out") assert "-isystem" in cmd def test_ClangInvocation_no_system_libs(): - cmd = llvm_benchmark.ClangInvocation(["foo.c"], system_includes=False).command( - "a.out" - ) + cmd = llvm.ClangInvocation(["foo.c"], system_includes=False).command("a.out") assert "-isystem" not in cmd +@pytest.mark.parametrize( + "source", + [ + "", + "int A() {return 0;}", + """ +int A() {return 0;} +int B() {return A();} +int C() {return 0;} + """, + ], +) +@pytest.mark.parametrize("system_includes", [False, True]) +def test_make_benchmark_from_source_valid_source( + env: CompilerEnv, source: str, system_includes: bool +): + benchmark = llvm.make_benchmark_from_source(source, system_includes=system_includes) + env.reset(benchmark=benchmark) + + +@pytest.mark.parametrize( + "source", + [ + "@syntax error!!!", # invalid syntax + "int A() {return a;}", # undefined variable + '#include "missing.h"', # missing include + ], +) +@pytest.mark.parametrize("system_includes", [False, True]) +def test_make_benchmark_from_source_invalid_source(source: str, system_includes: bool): + with pytest.raises( + BenchmarkInitError, match="Failed to make benchmark with compiler error:" + ): + llvm.make_benchmark_from_source(source, system_includes=system_includes) + + +def test_make_benchmark_from_source_invalid_copt(): + with pytest.raises( + BenchmarkInitError, match="Failed to make benchmark with compiler error:" + ): + llvm.make_benchmark_from_source( + "int A() {return 0;}", copt=["-invalid-argument!"] + ) + + +def test_make_benchmark_from_source_missing_system_includes(): + with pytest.raises( + BenchmarkInitError, match="Failed to make benchmark with compiler error:" + ): + llvm.make_benchmark_from_source("#include ", system_includes=False) + + +def test_make_benchmark_from_source_with_system_includes(): + assert llvm.make_benchmark_from_source("#include ", system_includes=True) + + +def test_split_benchmark_by_function_no_functions(): + benchmark = llvm.make_benchmark_from_source("") + with pytest.raises(ValueError, match="No functions found"): + llvm.split_benchmark_by_function(benchmark) + + +def is_defined(signature: str, ir: str): + """Return whether the function signature is defined in the IR.""" + return re.search(f"^define .*{signature}", ir, re.MULTILINE) + + +def is_declared(signature: str, ir: str): + """Return whether the function signature is defined in the IR.""" + return re.search(f"^declare .*{signature}", ir, re.MULTILINE) + + +def test_split_benchmark_by_function_repeated_split_single_function(env: CompilerEnv): + benchmark = llvm.make_benchmark_from_source("int A() {return 0;}", lang="c") + for _ in range(10): + benchmarks = llvm.split_benchmark_by_function(benchmark) + assert len(benchmarks) == 1 + env.reset(benchmark=benchmarks[0]) + assert is_defined("i32 @A()", env.ir) + benchmark = benchmarks[0] + + +def test_split_benchmark_by_function_multiple_functions(env: CompilerEnv): + benchmark = llvm.make_benchmark_from_source( + """ +int A() {return 0;} +int B() {return A();} +""", + lang="c", + ) + + benchmarks = llvm.split_benchmark_by_function(benchmark) + assert len(benchmarks) == 2 + A, B = benchmarks + + env.reset(benchmark=A) + assert is_defined("i32 @A()", env.ir) + assert not is_defined("i32 @B()", env.ir) + + assert not is_declared("i32 @A()", env.ir) + assert not is_declared("i32 @B()", env.ir) + + env.reset(benchmark=B) + assert not is_defined("i32 @A()", env.ir) + assert is_defined("i32 @B()", env.ir) + + assert is_declared("i32 @A()", env.ir) + assert not is_declared("i32 @B()", env.ir) + + +def test_split_benchmark_by_function_maximum_function_count(env: CompilerEnv): + benchmark = llvm.make_benchmark_from_source( + """ +int A() {return 0;} +int B() {return A();} +""", + lang="c", + ) + + benchmarks = llvm.split_benchmark_by_function( + benchmark, + maximum_function_count=1, + ) + assert len(benchmarks) == 1 + + env.reset(benchmark=benchmarks[0]) + assert is_defined("i32 @A()", env.ir) + + +def test_merge_benchmarks_single_input(env: CompilerEnv): + A = llvm.make_benchmark_from_source("int A() {return 0;}", lang="c") + + merged = llvm.merge_benchmarks([A]) + env.reset(benchmark=merged) + + assert is_defined("i32 @A()", env.ir) + + +def test_merge_benchmarks_independent(env: CompilerEnv): + A = llvm.make_benchmark_from_source("int A() {return 0;}", lang="c") + B = llvm.make_benchmark_from_source("int B() {return 0;}", lang="c") + + merged = llvm.merge_benchmarks([A, B]) + env.reset(benchmark=merged) + + assert is_defined("i32 @A()", env.ir) + assert is_defined("i32 @B()", env.ir) + + +def test_merge_benchmarks_multiply_defined(): + A = llvm.make_benchmark_from_source("int A() {return 0;}", lang="c") + with pytest.raises(ValueError, match="symbol multiply defined"): + llvm.merge_benchmarks([A, A]) + + +def test_merge_benchmarks_declarations(env: CompilerEnv): + A = llvm.make_benchmark_from_source("int A() {return 0;}", lang="c") + B = llvm.make_benchmark_from_source("int A(); int B() {return A();}", lang="c") + + merged = llvm.merge_benchmarks([A, B]) + env.reset(benchmark=merged) + + assert is_defined("i32 @A()", env.ir) + assert is_defined("i32 @B()", env.ir) + + if __name__ == "__main__": main()