diff --git a/compiler_gym/datasets/BUILD b/compiler_gym/datasets/BUILD
index 3b816a9fc3..a51006ee87 100644
--- a/compiler_gym/datasets/BUILD
+++ b/compiler_gym/datasets/BUILD
@@ -10,6 +10,7 @@ py_library(
         "__init__.py",
         "benchmark.py",
         "dataset.py",
+        "datasets.py",
         "files_dataset.py",
         "tar_dataset.py",
     ],
diff --git a/compiler_gym/datasets/__init__.py b/compiler_gym/datasets/__init__.py
index 2643a58dd0..ea7aba216c 100644
--- a/compiler_gym/datasets/__init__.py
+++ b/compiler_gym/datasets/__init__.py
@@ -17,6 +17,7 @@
     delete,
     require,
 )
+from compiler_gym.datasets.datasets import Datasets
 from compiler_gym.datasets.files_dataset import FilesDataset
 from compiler_gym.datasets.tar_dataset import TarDataset, TarDatasetWithManifest
 
@@ -27,6 +28,7 @@
     "BenchmarkSource",
     "Dataset",
     "DatasetInitError",
+    "Datasets",
     "deactivate",
     "delete",
     "FilesDataset",
diff --git a/compiler_gym/datasets/datasets.py b/compiler_gym/datasets/datasets.py
new file mode 100644
index 0000000000..8c3d7bf7b9
--- /dev/null
+++ b/compiler_gym/datasets/datasets.py
@@ -0,0 +1,314 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+from collections import deque
+from typing import Dict, Iterable, Optional, Set, TypeVar, Union
+
+import numpy as np
+
+from compiler_gym.datasets.benchmark import (
+    BENCHMARK_URI_RE,
+    Benchmark,
+    resolve_uri_protocol,
+)
+from compiler_gym.datasets.dataset import Dataset
+
+T = TypeVar("T")
+
+
+def round_robin_iterables(iters: Iterable[Iterable[T]]) -> Iterable[T]:
+    """Yield from the given iterators in round robin order."""
+    # Use a queue of iterators to iterate over. Repeatedly pop an iterator
+    # from the queue, yield the next value from it, then put it at the back
+    # of the queue. The iterator is discarded once exhausted.
+    iters = deque(iters)
+    while len(iters) > 1:
+        it = iters.popleft()
+        try:
+            yield next(it)
+            iters.append(it)
+        except StopIteration:
+            pass
+    # Once we have only a single iterator left, yield from it directly
+    # rather than continuing with the round robin.
+    if len(iters) == 1:
+        yield from iters.popleft()
+
+
+class Datasets(object):
+    """A collection of datasets.
+
+    This class provides a dictionary-like interface for indexing and
+    iterating over multiple :class:`Dataset
+    <compiler_gym.datasets.Dataset>` objects. Select a dataset by URI using:
+
+        >>> env.datasets["benchmark://cbench-v1"]
+
+    Check whether a dataset exists using:
+
+        >>> "benchmark://cbench-v1" in env.datasets
+        True
+
+    Or iterate over the datasets using:
+
+        >>> for dataset in env.datasets:
+        ...     print(dataset.name)
+        benchmark://cbench-v1
+        benchmark://github-v0
+        benchmark://npb-v0
+
+    To select a benchmark from the datasets, use :meth:`benchmark()
+    <compiler_gym.datasets.Datasets.benchmark>`:
+
+        >>> env.datasets.benchmark("benchmark://a-v0/a")
+
+    Use the :meth:`benchmarks()
+    <compiler_gym.datasets.Datasets.benchmarks>` method to iterate over
+    every benchmark in the datasets in a stable round robin order:
+
+        >>> for benchmark in env.datasets.benchmarks():
+        ...     print(benchmark)
+        benchmark://cbench-v1/1
+        benchmark://github-v0/1
+        benchmark://npb-v0/1
+        benchmark://cbench-v1/2
+        ...
+
+    If you want to exclude a dataset, delete it:
+
+        >>> del env.datasets["benchmark://b-v0"]
+
+    To iterate over the benchmarks in a random order, use :meth:`benchmark()
+    <compiler_gym.datasets.Datasets.benchmark>` and omit the URI:
+
+        >>> for i in range(100):
+        ...     benchmark = env.datasets.benchmark()
+
+    This uses uniform random selection to sample across datasets.
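+
+    To make the random selection reproducible, seed the random state first
+    (an illustrative sketch; it assumes an environment whose datasets have
+    been installed):
+
+        >>> env.datasets.seed(123)
+        >>> benchmark = env.datasets.benchmark()
+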
+    For finite datasets, you could weight the sample by the size of each
+    dataset. Note that :code:`np.random.choice` requires :code:`p` to be a
+    probability distribution, so normalize the weights:
+
+        >>> weights = [len(d) for d in env.datasets]
+        >>> p = [w / sum(weights) for w in weights]
+        >>> np.random.choice(list(env.datasets), p=p).benchmark()
+    """
+
+    def __init__(
+        self,
+        datasets: Iterable[Dataset],
+        random: Optional[np.random.Generator] = None,
+    ):
+        self._datasets: Dict[str, Dataset] = {d.name: d for d in datasets}
+        self._visible_datasets: Set[str] = set(
+            name for name, dataset in self._datasets.items() if not dataset.hidden
+        )
+        self.random = random or np.random.default_rng()
+
+    def seed(self, seed: Optional[int] = None) -> None:
+        """Set the random state.
+
+        Setting a random state will fix the order that
+        :meth:`datasets.benchmark()
+        <compiler_gym.datasets.Datasets.benchmark>` returns benchmarks when
+        called without arguments.
+
+        Calling this method recursively calls :meth:`seed()
+        <compiler_gym.datasets.Dataset.seed>` on all member datasets.
+
+        :param seed: An optional seed value.
+        """
+        self.random = np.random.default_rng(seed)
+        for dataset in self._datasets.values():
+            dataset.seed(seed)
+
+    def datasets(self, with_deprecated: bool = False) -> Iterable[Dataset]:
+        """Enumerate the datasets.
+
+        Dataset order is consistent across runs.
+
+        :param with_deprecated: If :code:`True`, include datasets that have
+            been marked as deprecated.
+
+        :return: An iterable sequence of :class:`Dataset
+            <compiler_gym.datasets.Dataset>` instances.
+        """
+        datasets = self._datasets.values()
+        if not with_deprecated:
+            datasets = (d for d in datasets if not d.hidden)
+        yield from sorted(datasets, key=lambda d: (d.sort_order, d.name))
+
+    def __iter__(self) -> Iterable[Dataset]:
+        """Iterate over the datasets.
+
+        Dataset order is consistent across runs.
+
+        Equivalent to :meth:`datasets.datasets()
+        <compiler_gym.datasets.Datasets.datasets>`, but without the ability
+        to iterate over the deprecated datasets.
+
+        :return: An iterable sequence of :class:`Dataset
+            <compiler_gym.datasets.Dataset>` instances.
+        """
+        return self.datasets()
+
+    def dataset(self, dataset: Optional[Union[str, Dataset]] = None) -> Dataset:
+        """Get a dataset.
+
+        If a name is given, return the corresponding :class:`Dataset
+        <compiler_gym.datasets.Dataset>`. Else, return a dataset selected
+        uniformly at random from the set of available datasets.
+
+        Use :meth:`seed() <compiler_gym.datasets.Datasets.seed>` to force a
+        reproducible order for randomly selected datasets.
+
+        Name lookup will succeed whether or not the dataset is deprecated.
+
+        :param dataset: A dataset name, a :class:`Dataset
+            <compiler_gym.datasets.Dataset>` instance, or :code:`None` to
+            select a dataset randomly.
+
+        :return: A :class:`Dataset <compiler_gym.datasets.Dataset>` instance.
+
+        :raises LookupError: If :code:`dataset` is not found.
+
+        :raises ValueError: If :code:`dataset` is :code:`None` and no
+            datasets are available.
+        """
+        if dataset is None:
+            if not self._visible_datasets:
+                raise ValueError("No datasets")
+
+            return self._datasets[self.random.choice(list(self._visible_datasets))]
+
+        if isinstance(dataset, Dataset):
+            dataset_name = dataset.name
+        else:
+            dataset_name = resolve_uri_protocol(dataset)
+
+        if dataset_name not in self._datasets:
+            raise LookupError(f"Dataset not found: {dataset_name}")
+
+        return self._datasets[dataset_name]
+
+    def __getitem__(self, dataset: Union[str, Dataset]) -> Dataset:
+        """Lookup a dataset.
+
+        :param dataset: A dataset name or a :class:`Dataset
+            <compiler_gym.datasets.Dataset>` instance.
+
+        :return: A :class:`Dataset <compiler_gym.datasets.Dataset>` instance.
+
+        :raises LookupError: If :code:`dataset` is not found.
+        """
+        return self.dataset(dataset)
+
+    def __setitem__(self, key: str, dataset: Dataset):
+        """Add a dataset to the collection under the given name."""
+        self._datasets[key] = dataset
+        if not dataset.hidden:
+            self._visible_datasets.add(dataset.name)
+
+    def __delitem__(self, dataset: Union[str, Dataset]):
+        """Remove a dataset from the collection.
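+
+        For example, assuming the collection contains this dataset:
+
+            >>> del env.datasets["benchmark://b-v0"]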
+
+        This does not affect any underlying storage used by the dataset. See
+        :meth:`uninstall() <compiler_gym.datasets.Dataset.uninstall>` to
+        clean up.
+
+        :param dataset: A :class:`Dataset <compiler_gym.datasets.Dataset>`
+            instance, or the name of a dataset.
+
+        :raises LookupError: If :code:`dataset` is not found.
+        """
+        dataset_name: str = self.dataset(dataset).name
+        if dataset_name in self._visible_datasets:
+            self._visible_datasets.remove(dataset_name)
+        del self._datasets[dataset_name]
+
+    def __contains__(self, dataset: Union[str, Dataset]) -> bool:
+        """Returns whether the dataset is contained."""
+        try:
+            self.dataset(dataset)
+            return True
+        except LookupError:
+            return False
+
+    def benchmarks(self, with_deprecated: bool = False) -> Iterable[Benchmark]:
+        """Enumerate the (possibly infinite) benchmarks lazily.
+
+        Benchmark order is consistent across runs. One benchmark from each
+        dataset is returned in round robin order until all datasets have
+        been fully enumerated. The order of :meth:`benchmarks()
+        <compiler_gym.datasets.Datasets.benchmarks>` and
+        :meth:`benchmark_uris()
+        <compiler_gym.datasets.Datasets.benchmark_uris>` is the same.
+
+        :param with_deprecated: If :code:`True`, include benchmarks from
+            datasets that have been marked deprecated.
+
+        :return: An iterable sequence of :class:`Benchmark
+            <compiler_gym.datasets.Benchmark>` instances.
+        """
+        return round_robin_iterables(
+            (d.benchmarks() for d in self.datasets(with_deprecated=with_deprecated))
+        )
+
+    def benchmark_uris(self, with_deprecated: bool = False) -> Iterable[str]:
+        """Enumerate the (possibly infinite) benchmark URIs.
+
+        Benchmark URI order is consistent across runs. URIs from datasets
+        are returned in round robin order. The order of :meth:`benchmarks()
+        <compiler_gym.datasets.Datasets.benchmarks>` and
+        :meth:`benchmark_uris()
+        <compiler_gym.datasets.Datasets.benchmark_uris>` is the same.
+
+        :param with_deprecated: If :code:`True`, include benchmarks from
+            datasets that have been marked deprecated.
+
+        :return: An iterable sequence of benchmark URI strings.
+        """
+        return round_robin_iterables(
+            (d.benchmark_uris() for d in self.datasets(with_deprecated=with_deprecated))
+        )
+
+    def benchmark(self, uri: Optional[str] = None) -> Benchmark:
+        """Select a benchmark.
+
+        If a benchmark URI is given, the corresponding :class:`Benchmark
+        <compiler_gym.datasets.Benchmark>` is returned, regardless of
+        whether the containing dataset is installed or deprecated.
+
+        If no URI is given, a benchmark is selected randomly. First, a
+        dataset is selected uniformly at random from the set of available
+        datasets. Then a benchmark is selected randomly from the chosen
+        dataset.
+
+        Calling :code:`benchmark()` will yield benchmarks from all available
+        datasets with equal probability, regardless of how many benchmarks
+        are in each dataset. Given a pool of available datasets of differing
+        sizes, smaller datasets will be overrepresented and larger datasets
+        will be underrepresented.
+
+        Use :meth:`seed() <compiler_gym.datasets.Datasets.seed>` to force a
+        reproducible order for randomly selected benchmarks.
+
+        :param uri: The URI of the benchmark to return. If :code:`None`,
+            select a benchmark randomly using :code:`self.random`.
+
+        :return: A :class:`Benchmark <compiler_gym.datasets.Benchmark>`
+            instance.
+
+        :raises ValueError: If no datasets are available, or if the URI is
+            invalid.
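+
+        For example (illustrative; :code:`crc32` is assumed to be a valid
+        benchmark in the cbench-v1 dataset):
+
+            >>> env.datasets.benchmark("benchmark://cbench-v1/crc32")
+            >>> env.datasets.benchmark()  # Randomly selected.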
+ """ + if uri is None and not self._visible_datasets: + raise ValueError("No datasets") + elif uri is None: + return self.dataset().benchmark() + + uri = resolve_uri_protocol(uri) + + match = BENCHMARK_URI_RE.match(uri) + if not match: + raise ValueError(f"Invalid benchmark URI: '{uri}'") + + dataset_name = match.group("dataset") + dataset = self._datasets[dataset_name] + + if len(uri) > len(dataset_name) + 1: + return dataset.benchmark(uri) + else: + return dataset.benchmark() + + @property + def size(self) -> int: + return len(self._visible_datasets) + + def __len__(self) -> int: + return self.size diff --git a/tests/datasets/BUILD b/tests/datasets/BUILD index 4ff4cc119d..1dfa72a476 100644 --- a/tests/datasets/BUILD +++ b/tests/datasets/BUILD @@ -26,6 +26,16 @@ py_test( ], ) +py_test( + name = "datasets_test", + srcs = ["datasets_test.py"], + deps = [ + "//compiler_gym/datasets", + "//tests:test_main", + "//tests/pytest_plugins:common", + ], +) + py_test( name = "files_dataset_test", timeout = "short", diff --git a/tests/datasets/datasets_test.py b/tests/datasets/datasets_test.py new file mode 100644 index 0000000000..8f574029b3 --- /dev/null +++ b/tests/datasets/datasets_test.py @@ -0,0 +1,328 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +"""Unit tests for //compiler_gym/datasets.""" +import pytest + +from compiler_gym.datasets.datasets import Datasets, round_robin_iterables +from tests.test_main import main + +pytest_plugins = ["tests.pytest_plugins.common"] + + +class MockDataset: + """A mock Dataset class.""" + + def __init__(self, name): + self.name = name + self.installed = False + self.seed_value = None + self.hidden = False + self.benchmark_values = [] + self.sort_order = 0 + + def install(self): + self.installed = True + + def uninstall(self): + self.installed = False + + def seed(self, seed): + self.seed_value = seed + + def benchmark_uris(self): + return (b.uri for b in self.benchmark_values) + + def benchmarks(self): + yield from self.benchmark_values + + def benchmark(self, uri=None): + if uri: + for b in self.benchmark_values: + if b.uri == uri: + return b + raise KeyError(uri) + return self.benchmark_values[0] + + def __repr__(self): + return str(self.name) + + +class MockBenchmark: + """A mock Benchmark class.""" + + def __init__(self, uri): + self.uri = uri + + def __repr__(self): + return str(self.name) + + +def test_seed_datasets_value(): + """Test that random seed is propagated to datasets.""" + da = MockDataset("a") + db = MockDataset("b") + datasets = Datasets((da, db)) + + datasets.seed(123) + + for dataset in datasets: + assert dataset.seed_value == 123 + + assert da.seed_value == 123 + assert db.seed_value == 123 + + +def test_enumerate_datasets_empty(): + datasets = Datasets([]) + + assert list(datasets) == [] + + +def test_enumerate_datasets(): + da = MockDataset("benchmark://a") + db = MockDataset("benchmark://b") + datasets = Datasets((da, db)) + + assert list(datasets) == [da, db] + + +def test_enumerate_datasets_with_custom_sort_order(): + da = MockDataset("benchmark://a") + db = MockDataset("benchmark://b") + db.sort_order = -1 + datasets = Datasets((da, db)) + + assert list(datasets) == [db, da] + + +def test_enumerate_hidden_datasets(): + da = MockDataset("benchmark://a") + db = MockDataset("benchmark://b") + datasets = Datasets((da, db)) + + db.hidden = True + assert list(datasets) == [da] + assert 
+    assert list(datasets.datasets(with_deprecated=True)) == [da, db]
+
+
+def test_enumerate_datasets_hidden_at_construction_time():
+    da = MockDataset("benchmark://a")
+    db = MockDataset("benchmark://b")
+    db.hidden = True
+    datasets = Datasets((da, db))
+
+    assert list(datasets) == [da]
+    assert list(datasets.datasets(with_deprecated=True)) == [da, db]
+
+
+def test_datasets_add_dataset():
+    datasets = Datasets([])
+
+    da = MockDataset("benchmark://foo-v0")
+    datasets["benchmark://foo-v0"] = da
+
+    assert list(datasets) == [da]
+
+
+def test_datasets_add_hidden_dataset():
+    datasets = Datasets([])
+
+    da = MockDataset("benchmark://a")
+    da.hidden = True
+    datasets["benchmark://foo-v0"] = da
+
+    assert list(datasets) == []
+
+
+def test_datasets_remove():
+    da = MockDataset("benchmark://foo-v0")
+    datasets = Datasets([da])
+
+    del datasets["benchmark://foo-v0"]
+    assert list(datasets) == []
+
+
+def test_datasets_get_item():
+    da = MockDataset("benchmark://foo-v0")
+    datasets = Datasets([da])
+
+    assert datasets.dataset("benchmark://foo-v0") == da
+    assert datasets["benchmark://foo-v0"] == da
+
+
+def test_datasets_get_item_default_protocol():
+    da = MockDataset("benchmark://foo-v0")
+    datasets = Datasets([da])
+
+    assert datasets.dataset("foo-v0") == da
+    assert datasets["foo-v0"] == da
+
+
+def test_datasets_get_item_lookup_miss():
+    da = MockDataset("benchmark://foo-v0")
+    datasets = Datasets([da])
+
+    with pytest.raises(LookupError) as e_ctx:
+        datasets.dataset("benchmark://bar-v0")
+    assert str(e_ctx.value) == "Dataset not found: benchmark://bar-v0"
+
+    with pytest.raises(LookupError) as e_ctx:
+        _ = datasets["benchmark://bar-v0"]
+    assert str(e_ctx.value) == "Dataset not found: benchmark://bar-v0"
+
+
+def test_dataset_empty():
+    datasets = Datasets([])
+
+    with pytest.raises(ValueError) as e_ctx:
+        datasets.dataset()
+
+    assert str(e_ctx.value) == "No datasets"
+
+
+def test_benchmark_empty():
+    datasets = Datasets([])
+
+    with pytest.raises(ValueError) as e_ctx:
+        datasets.benchmark()
+
+    assert str(e_ctx.value) == "No datasets"
+
+
+def test_benchmark_lookup_by_uri():
+    da = MockDataset("benchmark://foo-v0")
+    db = MockDataset("benchmark://bar-v0")
+    ba = MockBenchmark(uri="benchmark://foo-v0/abc")
+    da.benchmark_values.append(ba)
+    datasets = Datasets([da, db])
+
+    assert datasets.benchmark("benchmark://foo-v0/abc") == ba
+
+
+def test_round_robin():
+    iters = iter(
+        [
+            iter([0, 1, 2, 3, 4, 5]),
+            iter(["a", "b", "c"]),
+            iter([0.5, 1.0]),
+        ]
+    )
+    assert list(round_robin_iterables(iters)) == [
+        0,
+        "a",
+        0.5,
+        1,
+        "b",
+        1.0,
+        2,
+        "c",
+        3,
+        4,
+        5,
+    ]
+
+
+def test_benchmark_uris_order():
+    da = MockDataset("benchmark://foo-v0")
+    db = MockDataset("benchmark://bar-v0")
+    ba = MockBenchmark(uri="benchmark://foo-v0/abc")
+    bb = MockBenchmark(uri="benchmark://foo-v0/123")
+    bc = MockBenchmark(uri="benchmark://bar-v0/abc")
+    bd = MockBenchmark(uri="benchmark://bar-v0/123")
+    da.benchmark_values.append(ba)
+    da.benchmark_values.append(bb)
+    db.benchmark_values.append(bc)
+    db.benchmark_values.append(bd)
+    datasets = Datasets([da, db])
+
+    assert list(datasets.benchmark_uris()) == [b.uri for b in datasets.benchmarks()]
+    # Datasets are ordered by name, so bar-v0 before foo-v0.
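+    # The round robin then interleaves one benchmark from each dataset in
+    # turn, so the URIs alternate between the two datasets.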
+    assert list(datasets.benchmark_uris()) == [
+        "benchmark://bar-v0/abc",
+        "benchmark://foo-v0/abc",
+        "benchmark://bar-v0/123",
+        "benchmark://foo-v0/123",
+    ]
+
+
+def test_benchmarks_iter_deprecated():
+    da = MockDataset("benchmark://foo-v0")
+    db = MockDataset("benchmark://bar-v0")
+    db.hidden = True
+    ba = MockBenchmark(uri="benchmark://foo-v0/abc")
+    bb = MockBenchmark(uri="benchmark://foo-v0/123")
+    bc = MockBenchmark(uri="benchmark://bar-v0/abc")
+    bd = MockBenchmark(uri="benchmark://bar-v0/123")
+    da.benchmark_values.append(ba)
+    da.benchmark_values.append(bb)
+    db.benchmark_values.append(bc)
+    db.benchmark_values.append(bd)
+    datasets = Datasets([da, db])
+
+    # Iterate over the benchmarks. The deprecated dataset is not included.
+    assert list(datasets.benchmark_uris()) == [b.uri for b in datasets.benchmarks()]
+    assert list(datasets.benchmark_uris()) == [
+        "benchmark://foo-v0/abc",
+        "benchmark://foo-v0/123",
+    ]
+
+    # Repeat the above, but include the deprecated datasets.
+    assert list(datasets.benchmark_uris(with_deprecated=True)) == [
+        b.uri for b in datasets.benchmarks(with_deprecated=True)
+    ]
+    assert list(datasets.benchmark_uris(with_deprecated=True)) == [
+        "benchmark://bar-v0/abc",
+        "benchmark://foo-v0/abc",
+        "benchmark://bar-v0/123",
+        "benchmark://foo-v0/123",
+    ]
+
+
+def test_benchmark_select_randomly():
+    da = MockDataset("benchmark://foo-v0")
+    db = MockDataset("benchmark://bar-v0")
+    ba = MockBenchmark(uri="benchmark://foo-v0/abc")
+    bb = MockBenchmark(uri="benchmark://bar-v0/abc")
+    da.benchmark_values.append(ba)
+    db.benchmark_values.append(bb)
+    datasets = Datasets([da, db])
+
+    # Create three lists of randomly selected benchmarks. Two using the same
+    # seed, the third using a different seed. It is unlikely that the two
+    # different seeds will produce the same lists.
+    datasets.seed(1)
+    benchmarks_a = [datasets.benchmark() for _ in range(50)]
+    datasets.seed(1)
+    benchmarks_b = [datasets.benchmark() for _ in range(50)]
+    datasets.seed(2)
+    benchmarks_c = [datasets.benchmark() for _ in range(50)]
+
+    assert benchmarks_a == benchmarks_b
+    assert benchmarks_a != benchmarks_c
+    assert len(set(benchmarks_a)) == 2
+
+
+def test_dataset_select_randomly():
+    da = MockDataset("benchmark://foo-v0")
+    db = MockDataset("benchmark://bar-v0")
+    datasets = Datasets([da, db])
+
+    # Create three lists of randomly selected datasets. Two using the same
+    # seed, the third using a different seed. It is unlikely that the two
+    # different seeds will produce the same lists.
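+    # (With two datasets, two independent runs of 50 uniform draws agree
+    # with probability roughly 2**-50.)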
+    datasets.seed(1)
+    datasets_a = [datasets.dataset() for _ in range(50)]
+    datasets.seed(1)
+    datasets_b = [datasets.dataset() for _ in range(50)]
+    datasets.seed(2)
+    datasets_c = [datasets.dataset() for _ in range(50)]
+
+    assert datasets_a == datasets_b
+    assert datasets_a != datasets_c
+    assert len(set(datasets_a)) == 2
+
+
+if __name__ == "__main__":
+    main()
diff --git a/tests/llvm/BUILD b/tests/llvm/BUILD
index d6d88a1a09..d04b5380b6 100644
--- a/tests/llvm/BUILD
+++ b/tests/llvm/BUILD
@@ -85,18 +85,6 @@ py_test(
     ],
 )
 
-py_test(
-    name = "datasets_test",
-    srcs = ["datasets_test.py"],
-    deps = [
-        "//compiler_gym",
-        "//compiler_gym/envs/llvm:legacy_datasets",
-        "//tests:test_main",
-        "//tests/pytest_plugins:common",
-        "//tests/pytest_plugins:llvm",
-    ],
-)
-
 py_test(
     name = "fork_env_test",
     timeout = "long",
diff --git a/tests/llvm/datasets_test.py b/tests/llvm/datasets_test.py
deleted file mode 100644
index 72569b0547..0000000000
--- a/tests/llvm/datasets_test.py
+++ /dev/null
@@ -1,39 +0,0 @@
-# Copyright (c) Facebook, Inc. and its affiliates.
-#
-# This source code is licensed under the MIT license found in the
-# LICENSE file in the root directory of this source tree.
-"""Tests for //compiler_gym/envs/llvm:legacy_datasets."""
-import pytest
-
-from compiler_gym.envs.llvm import LlvmEnv, legacy_datasets
-from tests.test_main import main
-
-pytest_plugins = ["tests.pytest_plugins.common", "tests.pytest_plugins.llvm"]
-
-
-def test_validate_sha_output_okay():
-    output = legacy_datasets.BenchmarkExecutionResult(
-        walltime_seconds=0,
-        output="1234567890abcdef 1234567890abcd 1234567890abc 1234567890 12345",
-    )
-    assert legacy_datasets.validate_sha_output(output) is None
-
-
-def test_validate_sha_output_invalid():
-    output = legacy_datasets.BenchmarkExecutionResult(walltime_seconds=0, output="abcd")
-    assert legacy_datasets.validate_sha_output(output)
-
-
-def test_cBench_v0_deprecation(env: LlvmEnv):
-    """Test that cBench-v0 emits a deprecation warning when used."""
-    with pytest.deprecated_call(
-        match=(
-            "Dataset 'cBench-v0' is deprecated as of CompilerGym release "
-            "v0.1.4, please update to the latest available version"
-        )
-    ):
-        env.require_dataset("cBench-v0")
-
-
-if __name__ == "__main__":
-    main()