diff --git a/compiler_gym/datasets/BUILD b/compiler_gym/datasets/BUILD
index 6bd60d6902..12e324a5e7 100644
--- a/compiler_gym/datasets/BUILD
+++ b/compiler_gym/datasets/BUILD
@@ -6,10 +6,16 @@ load("@rules_python//python:defs.bzl", "py_library")
 py_library(
     name = "datasets",
-    srcs = ["__init__.py"],
+    srcs = [
+        "__init__.py",
+        "benchmark.py",
+        "dataset.py",
+    ],
     visibility = ["//visibility:public"],
     deps = [
-        ":dataset",
+        "//compiler_gym:validation_result",
+        "//compiler_gym/service/proto",
+        "//compiler_gym/util",
     ],
 )
diff --git a/compiler_gym/datasets/__init__.py b/compiler_gym/datasets/__init__.py
index b0dc9440c5..3798cbfa15 100644
--- a/compiler_gym/datasets/__init__.py
+++ b/compiler_gym/datasets/__init__.py
@@ -3,6 +3,11 @@
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
 """Manage datasets of benchmarks."""
+from compiler_gym.datasets.benchmark import (
+    Benchmark,
+    BenchmarkInitError,
+    BenchmarkSource,
+)
 from compiler_gym.datasets.dataset import (
     LegacyDataset,
     activate,
@@ -11,4 +16,13 @@
     require,
 )
 
-__all__ = ["LegacyDataset", "require", "activate", "deactivate", "delete"]
+__all__ = [
+    "activate",
+    "Benchmark",
+    "BenchmarkInitError",
+    "BenchmarkSource",
+    "deactivate",
+    "delete",
+    "LegacyDataset",
+    "require",
+]
diff --git a/compiler_gym/datasets/benchmark.py b/compiler_gym/datasets/benchmark.py
new file mode 100644
index 0000000000..585b241406
--- /dev/null
+++ b/compiler_gym/datasets/benchmark.py
@@ -0,0 +1,359 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+import re
+from concurrent.futures import as_completed
+from pathlib import Path
+from typing import Callable, Iterable, List, NamedTuple, Optional, Union
+
+from compiler_gym.service.proto import Benchmark as BenchmarkProto
+from compiler_gym.service.proto import File
+from compiler_gym.util import thread_pool
+from compiler_gym.util.decorators import memoized_property
+from compiler_gym.validation_result import ValidationError
+
+# A validation callback is a function that takes a single CompilerEnv instance
+# as its argument and returns an iterable sequence of zero or more
+# ValidationError tuples.
+ValidationCallback = Callable[["CompilerEnv"], Iterable[ValidationError]]  # noqa: F821
+
+
+# Regular expression that matches the full two-part URI prefix of a dataset:
+#     {{protocol}}://{{dataset}}
+#
+# A trailing slash is permitted.
+#
+# Example matches: "benchmark://foo-v0", "benchmark://foo-v0/".
+DATASET_NAME_RE = re.compile(
+    r"(?P<dataset>(?P<dataset_protocol>[a-zA-z0-9-_]+)://(?P<dataset_name>[a-zA-z0-9-_]+-v(?P<dataset_version>[0-9]+)))/?"
+)
+
+# Regular expression that matches the full three-part format of a benchmark
+# URI:
+#     {{protocol}}://{{dataset}}/{{id}}
+#
+# The {{id}} is optional.
+#
+# Example matches: "benchmark://foo-v0/" or "benchmark://foo-v0/program".
+BENCHMARK_URI_RE = re.compile(
+    r"(?P<dataset>(?P<dataset_protocol>[a-zA-z0-9-_]+)://(?P<dataset_name>[a-zA-z0-9-_]+-v(?P<dataset_version>[0-9]+)))(/(?P<benchmark_name>[^\s]*))?$"
+)
+
+
+def resolve_uri_protocol(uri: str) -> str:
+    """Require that the URI has a protocol by applying a default "benchmark"
+    protocol if none is set."""
+    if "://" not in uri:
+        return f"benchmark://{uri}"
+    return uri
+
+
+class BenchmarkSource(NamedTuple):
+    """A source file that is used to generate a benchmark. A benchmark may
+    comprise many source files.
+    """
+
+    filename: str
+    """The name of the file."""
+
+    contents: bytes
+    """The contents of the file as a byte array."""
+
+    def __repr__(self) -> str:
+        return str(self.filename)
+
+
+class Benchmark(object):
+    """A benchmark represents a particular program that is being compiled.
+
+    A benchmark is a program that can be used by a :class:`CompilerEnv
+    <compiler_gym.envs.CompilerEnv>` as a program to optimize. A benchmark
+    comprises the data that is fed into the compiler, identified by a URI.
+
+    Benchmarks are not normally instantiated directly. Instead, benchmarks are
+    instantiated using :meth:`env.datasets.benchmark()
+    <compiler_gym.datasets.Datasets.benchmark>`:
+
+        >>> env.datasets.benchmark()
+        benchmark://npb-v0/20
+
+    Calling :meth:`env.datasets.benchmark()
+    <compiler_gym.datasets.Datasets.benchmark>` with no arguments will select a
+    benchmark randomly from the available datasets. To select a specific
+    benchmark, pass the URI as argument:
+
+        >>> env.datasets.benchmark("benchmark://npb-v0/20")
+        benchmark://npb-v0/20
+
+    The available benchmark URIs can be queried using
+    :meth:`env.datasets.benchmark_uris()
+    <compiler_gym.datasets.Datasets.benchmark_uris>`.
+
+    Compiler environments may provide additional helper functions for
+    generating benchmarks, such as :meth:`env.make_benchmark()
+    <compiler_gym.envs.LlvmEnv.make_benchmark>` for LLVM.
+
+    The data underlying a Benchmark instance should be considered immutable.
+    New attributes cannot be assigned to Benchmark instances.
+
+    Benchmarks may provide additional functionality such as runtime checks or
+    methods for validating the semantics of a benchmark. The benchmark for an
+    environment can be set during :meth:`env.reset()
+    <compiler_gym.envs.CompilerEnv.reset>`. The currently active benchmark can
+    be queried using :attr:`env.benchmark
+    <compiler_gym.envs.CompilerEnv.benchmark>`:
+
+        >>> env = gym.make("llvm-v0")
+        >>> env.reset(benchmark="cbench-v1/crc32")
+        >>> env.benchmark
+        cbench-v1/crc32
+
+    A Benchmark instance wraps an instance of the :code:`Benchmark` protocol
+    buffer from the `RPC interface `_ with additional functionality.
+    """
+
+    __slots__ = ["_proto", "_validation_callbacks", "_sources"]
+
+    def __init__(
+        self,
+        proto: BenchmarkProto,
+        validation_callbacks: Optional[List[ValidationCallback]] = None,
+        sources: Optional[List[BenchmarkSource]] = None,
+    ):
+        self._proto = proto
+        self._validation_callbacks = validation_callbacks or []
+        self._sources = list(sources or [])
+
+    def __repr__(self) -> str:
+        return str(self.uri)
+
+    @property
+    def uri(self) -> str:
+        """The URI of the benchmark.
+
+        Benchmark URIs should be unique: two URIs with the same value should
+        resolve to the same benchmark. However, a URI does not uniquely
+        describe a benchmark; multiple identical benchmarks could have
+        different URIs.
+
+        :return: A URI string.
+        :type: string
+        """
+        return self._proto.uri
+
+    @property
+    def proto(self) -> BenchmarkProto:
+        """The protocol buffer representing the benchmark.
+
+        :return: A Benchmark message.
+        :type: :code:`Benchmark`
+        """
+        return self._proto
+
+    @property
+    def sources(self) -> Iterable[BenchmarkSource]:
+        """The original source code used to produce this benchmark.
+
+        :return: An iterable sequence of :class:`BenchmarkSource
+            <compiler_gym.datasets.BenchmarkSource>` tuples, comprising
+            relative file paths and file contents.
+
+        :type: :code:`Iterable[BenchmarkSource]`
+        """
+        return (BenchmarkSource(*x) for x in self._sources)
+
+    def is_validatable(self) -> bool:
+        """Whether the benchmark has any validation callbacks registered.
+
+        :return: :code:`True` if the benchmark has at least one validation
+            callback.
+ """ + return self._validation_callbacks != [] + + def validate(self, env: "CompilerEnv") -> List[ValidationError]: # noqa: F821 + """Run any validation callbacks and return any errors. + + If no errors are returned, validation has succeeded: + + >>> benchmark.validate(env) + [] + + If an error occurs, a :class:`ValidationError + ` tuple will describe the type of the + error, and optionally contain other data: + + >>> benchmark.validate(env) + [ValidationError(type="RuntimeError")] + + Multiple :class:`ValidationError ` errors + may be returned to indicate multiple errors. + + This is a synchronous version of :meth:`ivalidate() + ` that blocks until all + results are ready: + + >>> benchmark.validate(env) == list(benchmark.ivalidate(env)) + True + + :param env: The :class:`CompilerEnv ` + instance that is being validated. + + :return: A list of zero or more :class:`ValidationError + ` tuples that occurred during + validation. + """ + return list(self.ivalidate(env)) + + def ivalidate(self, env: "CompilerEnv") -> Iterable[ValidationError]: # noqa: F821 + """Run any validation callbacks and return a generator of errors. + + This is an asynchronous version of :meth:`validate() + ` that returns immediately. + + :parameter env: A :class:`CompilerEnv ` + instance to validate. + + :return: A generator of :class:`ValidationError + ` tuples that occur during validation. + """ + executor = thread_pool.get_thread_pool_executor() + futures = ( + executor.submit(validator, env) for validator in self.validation_callbacks() + ) + for future in as_completed(futures): + result: Iterable[ValidationError] = future.result() + if result: + yield from result + + def validation_callbacks( + self, + ) -> List[ValidationCallback]: + """Return the list of registered validation callbacks. + + :return: A list of callables. See :meth:`add_validation_callback() + `. + """ + return self._validation_callbacks + + def add_source(self, source: BenchmarkSource) -> None: + """Register a new source file for this benchmark. + + :param source: The :class:`BenchmarkSource + ` to register. + """ + self._sources.append(source) + + def add_validation_callback( + self, + validation_callback: ValidationCallback, + ) -> None: + """Register a new validation callback that will be executed on + :meth:`validate() `. + + :param validation_callback: A callback that accepts a single + :class:`CompilerEnv ` argument and + returns an iterable sequence of zero or more :class:`ValidationError + ` tuples. Validation callbacks must be + thread safe and must not modify the environment. + """ + self._validation_callbacks.append(validation_callback) + + def write_sources_to_directory(self, directory: Path) -> int: + """Write the source files for this benchmark to the given directory. + + This writes each of the :attr:`benchmark.sources + ` files to disk. + + :param directory: The directory to write results to. If it does not + exist, it is created. + + :return: The number of files written. + """ + directory = Path(directory) + directory.mkdir(exist_ok=True, parents=True) + uniq_paths = set() + for filename, contents in self.sources: + path = directory / filename + uniq_paths.add(path) + path.parent.mkdir(exist_ok=True, parents=True) + with open(path, "wb") as f: + f.write(contents) + + return len(uniq_paths) + + @classmethod + def from_file(cls, uri: str, path: Path): + """Construct a benchmark from the path to a file. + + :param uri: The URI of the benchmark. + + :param path: A filesystem path. 
+ + :raise FileNotFoundError: If the path does not exist. + + :return: A :class:`Benchmark ` instance. + """ + path = Path(path) + if not path.is_file(): + raise FileNotFoundError(path) + return cls( + proto=BenchmarkProto( + uri=uri, program=File(uri=f"file:///{path.absolute()}") + ), + ) + + @classmethod + def from_file_contents(cls, uri: str, data: bytes): + """Construct a benchmark from a raw data array. + + :param uri: The URI of the benchmark. + + :param data: An array of bytes that will be passed to the compiler + service. + """ + return cls(proto=BenchmarkProto(uri=uri, program=File(contents=data))) + + def __eq__(self, other: Union[str, "Benchmark"]): + if isinstance(other, Benchmark): + return self.uri == other.uri + else: + return self.uri == other + + def __ne__(self, other: Union[str, "Benchmark"]): + return not self == other + + +class BenchmarkInitError(OSError): + """Base class for errors raised if a benchmark fails to initialize.""" + + +class BenchmarkWithSource(Benchmark): + """A benchmark which has a single source file.""" + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._src_name = None + self._src_path = None + + @classmethod + def create( + cls, uri: str, input_path: Path, src_name: str, src_path: Path + ) -> Benchmark: + """Create a benchmark from paths.""" + benchmark = cls.from_file(uri, input_path) + benchmark._src_name = src_name # pylint: disable=protected-access + benchmark._src_path = src_path # pylint: disable=protected-access + return benchmark + + @memoized_property + def sources(self) -> Iterable[BenchmarkSource]: + with open(self._src_path, "rb") as f: + return [ + BenchmarkSource(filename=self._src_name, contents=f.read()), + ] + + @property + def source(self) -> str: + """Return the single source file contents as a string.""" + return list(self.sources)[0].contents.decode("utf-8") diff --git a/tests/datasets/BUILD b/tests/datasets/BUILD new file mode 100644 index 0000000000..99214b1df1 --- /dev/null +++ b/tests/datasets/BUILD @@ -0,0 +1,16 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +load("@rules_python//python:defs.bzl", "py_test") + +py_test( + name = "benchmark_test", + timeout = "short", + srcs = ["benchmark_test.py"], + deps = [ + "//compiler_gym/datasets", + "//tests:test_main", + "//tests/pytest_plugins:common", + ], +) diff --git a/tests/datasets/benchmark_test.py b/tests/datasets/benchmark_test.py new file mode 100644 index 0000000000..0953a073c9 --- /dev/null +++ b/tests/datasets/benchmark_test.py @@ -0,0 +1,299 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+"""Unit tests for //compiler_gym/datasets:benchmark.""" +import re +from pathlib import Path + +import pytest + +from compiler_gym.datasets.benchmark import ( + BENCHMARK_URI_RE, + DATASET_NAME_RE, + Benchmark, + BenchmarkSource, +) +from compiler_gym.service.proto import Benchmark as BenchmarkProto +from compiler_gym.validation_result import ValidationError +from tests.test_main import main + +pytest_plugins = ["tests.pytest_plugins.common"] + + +def _rgx_match(regex, groupname, string) -> str: + """Match the regex and return a named group.""" + match = re.match(regex, string) + assert match, f"Failed to match regex '{regex}' using string '{groupname}'" + return match.group(groupname) + + +@pytest.mark.parametrize("regex", (DATASET_NAME_RE, BENCHMARK_URI_RE)) +def test_benchmark_uri_protocol(regex): + assert not regex.match("B?://cbench-v1/") # Invalid characters + assert not regex.match("cbench-v1/") # Missing protocol + + assert ( + _rgx_match(regex, "dataset_protocol", "benchmark://cbench-v1/") == "benchmark" + ) + assert ( + _rgx_match(regex, "dataset_protocol", "Generator13://gen-v11/") == "Generator13" + ) + + +def test_benchmark_uri_dataset(): + assert not BENCHMARK_URI_RE.match("benchmark://cBench?v0/") # Invalid character + assert not BENCHMARK_URI_RE.match("benchmark://cBench/") # Missing version suffix + + assert ( + _rgx_match(BENCHMARK_URI_RE, "dataset_name", "benchmark://cbench-v1/") + == "cbench-v1" + ) + assert ( + _rgx_match(BENCHMARK_URI_RE, "dataset_name", "Generator13://gen-v11/") + == "gen-v11" + ) + + +def test_benchmark_dataset_name(): + assert ( + _rgx_match(BENCHMARK_URI_RE, "dataset", "benchmark://cbench-v1/") + == "benchmark://cbench-v1" + ) + assert ( + _rgx_match(BENCHMARK_URI_RE, "dataset", "Generator13://gen-v11/") + == "Generator13://gen-v11" + ) + + +def test_benchmark_uri_id(): + assert not BENCHMARK_URI_RE.match("benchmark://cbench-v1/ whitespace") # Whitespace + assert not BENCHMARK_URI_RE.match("benchmark://cbench-v1/\t") # Whitespace + + assert ( + _rgx_match(BENCHMARK_URI_RE, "benchmark_name", "benchmark://cbench-v1") is None + ) + assert ( + _rgx_match(BENCHMARK_URI_RE, "benchmark_name", "benchmark://cbench-v1/") == "" + ) + assert ( + _rgx_match(BENCHMARK_URI_RE, "benchmark_name", "benchmark://cbench-v1/foo") + == "foo" + ) + assert ( + _rgx_match(BENCHMARK_URI_RE, "benchmark_name", "benchmark://cbench-v1/foo/123") + == "foo/123" + ) + assert ( + _rgx_match( + BENCHMARK_URI_RE, + "benchmark_name", + "benchmark://cbench-v1/foo/123?param=true&false", + ) + == "foo/123?param=true&false" + ) + + +def test_benchmark_attribute_outside_init(): + """Test that new attributes cannot be added to Benchmark.""" + benchmark = Benchmark(None) + with pytest.raises(AttributeError): + benchmark.foobar = 123 # noqa + + +def test_benchmark_subclass_attribute_outside_init(): + """Test that new attributes can be added to Benchmark subclass.""" + + class TestBenchmark(Benchmark): + pass + + benchmark = TestBenchmark(None) + benchmark.foobar = 123 + assert benchmark.foobar == 123 + + +def test_benchmark_properties(): + """Test benchmark properties.""" + benchmark = Benchmark(BenchmarkProto(uri="benchmark://example-v0/foobar")) + assert benchmark.uri == "benchmark://example-v0/foobar" + assert benchmark.proto == BenchmarkProto(uri="benchmark://example-v0/foobar") + + +def test_benchmark_immutable(): + """Test that benchmark properties are immutable.""" + benchmark = Benchmark(BenchmarkProto(uri="benchmark://example-v0/foobar")) + with pytest.raises(AttributeError): + 
benchmark.uri = 123
+    with pytest.raises(AttributeError):
+        benchmark.proto = 123
+
+
+def test_add_validation_callbacks_values():
+    """Test methods for adding and checking custom validation callbacks."""
+
+    def a(env):
+        pass
+
+    benchmark = Benchmark(BenchmarkProto(uri="benchmark://example-v0/foobar"))
+    assert benchmark.validation_callbacks() == []
+    assert not benchmark.is_validatable()
+
+    benchmark.add_validation_callback(a)
+    assert benchmark.validation_callbacks() == [a]
+    assert benchmark.is_validatable()
+
+    benchmark.add_validation_callback(a)
+    assert benchmark.validation_callbacks() == [a, a]
+
+
+def test_add_validation_callbacks_call_count():
+    """Test that custom validation callbacks are called on validate()."""
+    a_call_count = 0
+    b_call_count = 0
+
+    def a(env):
+        nonlocal a_call_count
+        a_call_count += 1
+
+    def b(env):
+        nonlocal b_call_count
+        b_call_count += 1
+
+    benchmark = Benchmark(BenchmarkProto(uri="benchmark://example-v0/foobar"))
+    benchmark.add_validation_callback(a)
+
+    errors = benchmark.validate(env=None)
+    assert errors == []
+    assert a_call_count == 1
+    assert b_call_count == 0
+
+    benchmark.add_validation_callback(b)
+    errors = benchmark.validate(env=None)
+    assert errors == []
+    assert a_call_count == 2
+    assert b_call_count == 1
+
+
+def test_validation_callback_error():
+    """Test error propagation from custom validation callback."""
+
+    def a(env):
+        yield ValidationError(type="Compilation Error")
+        yield ValidationError(type="Runtime Error")
+
+    benchmark = Benchmark(BenchmarkProto(uri="benchmark://example-v0/foobar"))
+    benchmark.add_validation_callback(a)
+
+    errors = benchmark.validate(env=None)
+    assert errors == [
+        ValidationError(type="Compilation Error"),
+        ValidationError(type="Runtime Error"),
+    ]
+
+
+def test_validation_callback_error_iter():
+    """Test error propagation from custom validation callback using iterable."""
+
+    def a(env):
+        yield ValidationError(type="Compilation Error")
+        yield ValidationError(type="Runtime Error")
+
+    benchmark = Benchmark(BenchmarkProto(uri="benchmark://example-v0/foobar"))
+    benchmark.add_validation_callback(a)
+
+    errors = benchmark.ivalidate(env=None)
+    assert next(errors) == ValidationError(type="Compilation Error")
+    assert next(errors) == ValidationError(type="Runtime Error")
+
+
+def test_validation_callback_flaky():
+    """Test error propagation on callback which *may* fail."""
+    flaky = False
+
+    def a(env):
+        nonlocal flaky
+        if flaky:
+            yield ValidationError(type="Runtime Error")
+
+    benchmark = Benchmark(BenchmarkProto(uri="benchmark://example-v0/foobar"))
+    benchmark.add_validation_callback(a)
+
+    errors = benchmark.validate(env=None)
+    assert errors == []
+
+    flaky = True
+    errors = benchmark.validate(env=None)
+    assert errors == [
+        ValidationError(type="Runtime Error"),
+    ]
+
+
+def test_eq_benchmarks():
+    a = Benchmark(BenchmarkProto(uri="benchmark://example-v0/foo"))
+    b = Benchmark(BenchmarkProto(uri="benchmark://example-v0/foo"))
+
+    assert a == b
+
+
+def test_eq_strings():
+    a = Benchmark(BenchmarkProto(uri="benchmark://example-v0/foo"))
+    b = "benchmark://example-v0/foo"
+
+    assert a == b
+
+
+def test_ne_benchmarks():
+    a = Benchmark(BenchmarkProto(uri="benchmark://example-v0/foo"))
+    b = Benchmark(BenchmarkProto(uri="benchmark://example-v0/bar"))
+
+    assert a != b
+
+
+def test_ne_strings():
+    a = Benchmark(BenchmarkProto(uri="benchmark://example-v0/foo"))
+    b = "benchmark://example-v0/bar"
+
+    assert a != b
+
+
+def test_benchmark_sources(tmpwd: Path):
+    a = Benchmark(
+        BenchmarkProto(uri="benchmark://example-v0/foo"),
+        sources=[("example.py", "Hello, world!".encode("utf-8"))],
+    )
+    a.add_source(BenchmarkSource(filename="foo.py", contents="Hi".encode("utf-8")))
+
+    assert list(a.sources) == [
+        BenchmarkSource("example.py", "Hello, world!".encode("utf-8")),
+        BenchmarkSource(filename="foo.py", contents="Hi".encode("utf-8")),
+    ]
+
+    a.write_sources_to_directory("benchmark_sources")
+
+    with open(tmpwd / "benchmark_sources" / "example.py") as f:
+        assert f.read() == "Hello, world!"
+    with open(tmpwd / "benchmark_sources" / "foo.py") as f:
+        assert f.read() == "Hi"
+
+
+def test_benchmark_from_file(tmpwd: Path):
+    path = tmpwd / "foo.txt"
+    path.touch()
+    benchmark = Benchmark.from_file("benchmark://example-v0/foo", path)
+    # Use startswith() and endswith() because macOS can add a /private prefix
+    # to paths.
+    assert benchmark.proto.program.uri.startswith("file:///")
+    assert benchmark.proto.program.uri.endswith(str(path))
+
+
+def test_benchmark_from_file_not_found(tmpwd: Path):
+    path = tmpwd / "foo.txt"
+    with pytest.raises(FileNotFoundError) as e_ctx:
+        Benchmark.from_file("benchmark://example-v0/foo", path)
+
+    # Use endswith() because macOS can add a /private prefix to paths.
+    assert str(e_ctx.value).endswith(str(path))
+
+
+if __name__ == "__main__":
+    main()
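As a quick orientation to the URI helpers added at the top of benchmark.py, the named regex groups line up with what the tests above exercise. This is a minimal sketch; the URIs are examples only and are not shipped by this change.

    from compiler_gym.datasets.benchmark import (
        BENCHMARK_URI_RE,
        resolve_uri_protocol,
    )

    # resolve_uri_protocol() only fills in a missing protocol; it does not
    # validate the URI.
    assert resolve_uri_protocol("cbench-v1/crc32") == "benchmark://cbench-v1/crc32"

    # BENCHMARK_URI_RE splits a full URI into named groups.
    m = BENCHMARK_URI_RE.match("benchmark://cbench-v1/crc32")
    assert m is not None
    assert m.group("dataset_protocol") == "benchmark"
    assert m.group("dataset_name") == "cbench-v1"
    assert m.group("dataset") == "benchmark://cbench-v1"
    assert m.group("benchmark_name") == "crc32"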
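The validation-callback API introduced in benchmark.py can be exercised without a compiler service, as the tests above do. A minimal sketch, assuming an illustrative URI, source string, and callback that are not part of this diff:

    from compiler_gym.datasets import Benchmark
    from compiler_gym.validation_result import ValidationError

    # Wrap a raw byte string in a Benchmark.
    benchmark = Benchmark.from_file_contents(
        "benchmark://example-v0/add", b"int add(int a, int b) { return a + b; }"
    )

    def always_fails(env):
        # A real callback would execute the benchmark through `env`; this one
        # just reports a canned error to show how results propagate.
        yield ValidationError(type="Runtime Error")

    benchmark.add_validation_callback(always_fails)
    assert benchmark.is_validatable()
    errors = benchmark.validate(env=None)  # Blocking; ivalidate() streams results.
    assert [e.type for e in errors] == ["Runtime Error"]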
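Similarly, the source-tracking methods (from_file(), add_source(), write_sources_to_directory()) compose as below. A sketch under assumed temporary paths; the filenames are hypothetical:

    import tempfile
    from pathlib import Path

    from compiler_gym.datasets import Benchmark, BenchmarkSource

    with tempfile.TemporaryDirectory() as d:
        src = Path(d) / "foo.c"
        src.write_text("int main() { return 0; }")

        # from_file() records a file:/// URI in the underlying protocol buffer.
        benchmark = Benchmark.from_file("benchmark://example-v0/foo", src)

        # Sources are registered separately and can be written back out, e.g.
        # for inspection outside of the compiler service.
        benchmark.add_source(
            BenchmarkSource(filename="foo.c", contents=src.read_bytes())
        )
        assert benchmark.write_sources_to_directory(Path(d) / "out") == 1
        assert (Path(d) / "out" / "foo.c").read_text() == "int main() { return 0; }"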