From b27bbb0137bcefc6790bd535e6e8e1a3262d0f97 Mon Sep 17 00:00:00 2001
From: Chris Cummins
Date: Mon, 26 Apr 2021 15:59:16 +0100
Subject: [PATCH] [datasets] Remove randomness and Union types from Datasets.

This patch makes two simplifications to the Datasets API:

1) It removes the random-benchmark selection logic from
   `Dataset.benchmark()`. Calling `benchmark()` now requires a URI. If you
   wish to select a benchmark randomly, implement the random selection
   yourself (a client-side sketch is appended after the diff). The rationale
   is that random benchmark selection is a minor use case that introduced a
   disproportionate amount of complexity into the implementation.

2) It removes the `Union[str, Dataset]` argument types from the `Datasets`
   methods. Now only a string is permitted, which makes the argument types
   easier to understand. If you have a `Dataset` instance that you would
   like to use, pass in `dataset.name` explicitly.

Issue #45.
---
 compiler_gym/datasets/dataset.py       |  30 +------
 compiler_gym/datasets/datasets.py      | 115 +++++--------------------
 compiler_gym/datasets/files_dataset.py |  49 +----------
 tests/datasets/datasets_test.py        |  81 -----------------
 tests/datasets/files_dataset_test.py   |  88 -------------------
 5 files changed, 28 insertions(+), 335 deletions(-)

diff --git a/compiler_gym/datasets/dataset.py b/compiler_gym/datasets/dataset.py
index 2ec9de638c..7f1dad4f96 100644
--- a/compiler_gym/datasets/dataset.py
+++ b/compiler_gym/datasets/dataset.py
@@ -13,7 +13,6 @@ from typing import Dict, Iterable, List, NamedTuple, Optional, Union
 
 import fasteners
-import numpy as np
 from deprecated.sphinx import deprecated
 
 from compiler_gym.datasets.benchmark import DATASET_NAME_RE, Benchmark
@@ -45,7 +44,6 @@ def __init__(
         site_data_base: Path,
         benchmark_class=Benchmark,
         references: Optional[Dict[str, str]] = None,
-        random: Optional[np.random.Generator] = None,
         hidden: bool = False,
         sort_order: int = 0,
         logger: Optional[logging.Logger] = None,
@@ -70,8 +68,6 @@ def __init__(
         :param references: A dictionary containing URLs for this dataset, keyed
             by their name. E.g. :code:`references["Paper"] = "https://..."`.
 
-        :param random: A source of randomness for selecting benchmarks.
-
         :param hidden: Whether the dataset should be excluded from the
             :meth:`datasets() ` iterator
             of any :class:`Datasets ` container.
@@ -103,7 +99,6 @@ def __init__(
 
         self._hidden = hidden
         self._validatable = validatable
-        self.random = random or np.random.default_rng()
         self._logger = logger
         self.sort_order = sort_order
         self.benchmark_class = benchmark_class
@@ -115,17 +110,6 @@ def __init__(
     def __repr__(self):
         return self.name
 
-    def seed(self, seed: int):
-        """Set the random state.
-
-        Setting a random state will fix the order that
-        :meth:`dataset.benchmark() `
-        returns benchmarks when called without arguments.
-
-        :param seed: A number.
-        """
-        self.random = np.random.default_rng(seed)
-
     @property
     def logger(self) -> logging.Logger:
         """The logger for this dataset.
@@ -337,23 +321,15 @@ def benchmark_uris(self) -> Iterable[str]:
         """
         raise NotImplementedError("abstract class")
 
-    def benchmark(self, uri: Optional[str] = None) -> Benchmark:
+    def benchmark(self, uri: str) -> Benchmark:
         """Select a benchmark.
 
-        If a URI is given, the corresponding :class:`Benchmark
-        ` is returned. Otherwise, a benchmark
-        is selected uniformly randomly.
-
-        Use :meth:`seed() ` to force a
-        reproducible order for randomly selected benchmarks.
-
-        :param uri: The URI of the benchmark to return. If :code:`None`, select
-            a benchmark randomly using :code:`self.random`.
+ :param uri: The URI of the benchmark to return. :return: A :class:`Benchmark ` instance. - :raise LookupError: If :code:`uri` is provided but does not exist. + :raise LookupError: If :code:`uri` is not found. """ raise NotImplementedError("abstract class") diff --git a/compiler_gym/datasets/datasets.py b/compiler_gym/datasets/datasets.py index 8c3d7bf7b9..e62502777c 100644 --- a/compiler_gym/datasets/datasets.py +++ b/compiler_gym/datasets/datasets.py @@ -3,9 +3,7 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. from collections import deque -from typing import Dict, Iterable, Optional, Set, TypeVar, Union - -import numpy as np +from typing import Dict, Iterable, Set, TypeVar from compiler_gym.datasets.benchmark import ( BENCHMARK_URI_RE, @@ -76,46 +74,16 @@ class Datasets(object): If you want to exclude a dataset, delete it: >>> del env.datasets["benchmark://b-v0"] - - To iterate over the benchmarks in a random order, use :meth:`benchmark()` - and omit the URI: - - >>> for i in range(100): - ... benchmark = env.datasets.benchmark() - - This uses uniform random selection to sample across datasets. For finite - datasets, you could weight the sample by the size of each dataset: - - >>> weights = [len(d) for d in env.datasets] - >>> np.random.choice(list(env.datasets), p=weights).benchmark() """ def __init__( self, datasets: Iterable[Dataset], - random: Optional[np.random.Generator] = None, ): self._datasets: Dict[str, Dataset] = {d.name: d for d in datasets} self._visible_datasets: Set[str] = set( name for name, dataset in self._datasets.items() if not dataset.hidden ) - self.random = random or np.random.default_rng() - - def seed(self, seed: Optional[int] = None) -> None: - """Set the random state. - - Setting a random state will fix the order that - :meth:`datasets.benchmark() ` - returns benchmarks when called without arguments. - - Calling this method recursively calls :meth:`seed() - ` on all member datasets. - - :param seed: An optional seed value. - """ - self.random = np.random.default_rng(seed) - for dataset in self._datasets.values(): - dataset.seed(seed) def datasets(self, with_deprecated: bool = False) -> Iterable[Dataset]: """Enumerate the datasets. @@ -147,46 +115,30 @@ def __iter__(self) -> Iterable[Dataset]: """ return self.datasets() - def dataset(self, dataset: Optional[Union[str, Dataset]] = None) -> Dataset: + def dataset(self, dataset: str) -> Dataset: """Get a dataset. - If a name is given, return the corresponding :meth:`Dataset - `. Else, return a dataset uniformly - randomly from the set of available datasets. + Return the corresponding :meth:`Dataset + `. Name lookup will succeed whether or + not the dataset is deprecated. - Use :meth:`seed() ` to force a - reproducible order for randomly selected datasets. - - Name lookup will succeed whether or not the dataset is deprecated. - - :param dataset: A dataset name, a :class:`Dataset` instance, or - :code:`None` to select a dataset randomly. + :param dataset: A dataset name. :return: A :meth:`Dataset ` instance. :raises LookupError: If :code:`dataset` is not found. 
""" - if dataset is None: - if not self._visible_datasets: - raise ValueError("No datasets") - - return self._datasets[self.random.choice(list(self._visible_datasets))] - - if isinstance(dataset, Dataset): - dataset_name = dataset.name - else: - dataset_name = resolve_uri_protocol(dataset) + dataset_name = resolve_uri_protocol(dataset) if dataset_name not in self._datasets: raise LookupError(f"Dataset not found: {dataset_name}") return self._datasets[dataset_name] - def __getitem__(self, dataset: Union[str, Dataset]) -> Dataset: + def __getitem__(self, dataset: str) -> Dataset: """Lookup a dataset. - :param dataset: A dataset name, a :class:`Dataset` instance, or - :code:`None` to select a dataset randomly. + :param dataset: A dataset name. :return: A :meth:`Dataset ` instance. @@ -195,29 +147,30 @@ def __getitem__(self, dataset: Union[str, Dataset]) -> Dataset: return self.dataset(dataset) def __setitem__(self, key: str, dataset: Dataset): - self._datasets[key] = dataset + dataset_name = resolve_uri_protocol(key) + + self._datasets[dataset_name] = dataset if not dataset.hidden: - self._visible_datasets.add(dataset.name) + self._visible_datasets.add(dataset_name) - def __delitem__(self, dataset: Union[str, Dataset]): + def __delitem__(self, dataset: str): """Remove a dataset from the collection. This does not affect any underlying storage used by dataset. See :meth:`uninstall() ` to clean up. - :param dataset: A :meth:`Dataset ` - instance, or the name of a dataset. + :param dataset: The name of a dataset. :return: :code:`True` if the dataset was removed, :code:`False` if it was already removed. """ - dataset_name: str = self.dataset(dataset).name + dataset_name = resolve_uri_protocol(dataset) if dataset_name in self._visible_datasets: self._visible_datasets.remove(dataset_name) del self._datasets[dataset_name] - def __contains__(self, dataset: Union[str, Dataset]) -> bool: + def __contains__(self, dataset: str) -> bool: """Returns whether the dataset is contained.""" try: self.dataset(dataset) @@ -261,37 +214,18 @@ def benchmark_uris(self, with_deprecated: bool = False) -> Iterable[str]: (d.benchmark_uris() for d in self.datasets(with_deprecated=with_deprecated)) ) - def benchmark(self, uri: Optional[str] = None) -> Benchmark: + def benchmark(self, uri: str) -> Benchmark: """Select a benchmark. - If a benchmark URI is given, the corresponding :class:`Benchmark - ` is returned, regardless of whether - the containing dataset is installed or deprecated. + Returns the corresponding :class:`Benchmark + `, regardless of whether the containing + dataset is installed or deprecated. - If no URI is given, a benchmark is selected randomly. First, a dataset - is selected uniformly randomly from the set of available datasets. Then - a benchmark is selected randomly from the chosen dataset. - - Calling :code:`benchmark()` will yield benchmarks from all available - datasets with equal probability, regardless of how many benchmarks are - in each dataset. Given a pool of available datasets of differing sizes, - smaller datasets will be overrepresented and large datasets will be - underrepresented. - - Use :meth:`seed() ` to force a - reproducible order for randomly selected benchmarks. - - :param uri: The URI of the benchmark to return. If :code:`None`, select - a benchmark randomly using :code:`self.random`. + :param uri: The URI of the benchmark to return. :return: A :class:`Benchmark ` instance. 
""" - if uri is None and not self._visible_datasets: - raise ValueError("No datasets") - elif uri is None: - return self.dataset().benchmark() - uri = resolve_uri_protocol(uri) match = BENCHMARK_URI_RE.match(uri) @@ -301,10 +235,7 @@ def benchmark(self, uri: Optional[str] = None) -> Benchmark: dataset_name = match.group("dataset") dataset = self._datasets[dataset_name] - if len(uri) > len(dataset_name) + 1: - return dataset.benchmark(uri) - else: - return dataset.benchmark() + return dataset.benchmark(uri) @property def size(self) -> int: diff --git a/compiler_gym/datasets/files_dataset.py b/compiler_gym/datasets/files_dataset.py index e4df8c10f8..3fd9af4b21 100644 --- a/compiler_gym/datasets/files_dataset.py +++ b/compiler_gym/datasets/files_dataset.py @@ -4,7 +4,7 @@ # LICENSE file in the root directory of this source tree. import os from pathlib import Path -from typing import Iterable, List, Optional +from typing import Iterable, List from compiler_gym.datasets.dataset import Benchmark, Dataset from compiler_gym.util.decorators import memoized_property @@ -108,56 +108,11 @@ def benchmark_uris(self) -> Iterable[str]: else: yield from self._benchmark_uris_iter - def benchmark(self, uri: Optional[str] = None) -> Benchmark: + def benchmark(self, uri: str) -> Benchmark: self.install() - if uri is None or len(uri) <= len(self.name) + 1: - if not self.size: - raise ValueError("No benchmarks") - return self._get_benchmark_by_index(self.random.integers(self.size)) relpath = f"{uri[len(self.name) + 1:]}{self.benchmark_file_suffix}" abspath = self.dataset_root / relpath if not abspath.is_file(): raise LookupError(f"Benchmark not found: {uri} (file not found: {abspath})") return self.benchmark_class.from_file(uri, abspath) - - def _get_benchmark_by_index(self, n: int) -> Benchmark: - """Look up a benchmark using a numeric index into the list of URIs, - without bounds checking. - """ - # If we have memoized the benchmark IDs then just index into the list. - # Otherwise we will scan through the directory hierarchy. - if self.memoize_uris: - if not self._memoized_uris: - self._memoized_uris = self._benchmark_uris - return self.benchmark(self._memoized_uris[n]) - - i = 0 - for root, dirs, files in os.walk(self.dataset_root): - reldir = root[len(str(self.dataset_root)) + 1 :] - - # Filter only the files that match the target file suffix. - valid_files = [f for f in files if f.endswith(self.benchmark_file_suffix)] - - if i + len(valid_files) <= n: - # There aren't enough files in this directory to bring us up to - # the target file index, so skip this directory and descend into - # subdirectories. - i += len(valid_files) - # Sort the subdirectories so that the iteration order is stable - # and consistent with benchmark_uris(). - dirs.sort() - else: - valid_files.sort() - filename = valid_files[n - i] - name_stem = filename - if self.benchmark_file_suffix: - name_stem = filename[: -len(self.benchmark_file_suffix)] - # Use os.path.join() rather than simple '/' concatenation as - # reldir may be empty. - uri = os.path.join(self.name, reldir, name_stem) - return self.benchmark_class.from_file(uri, os.path.join(root, filename)) - - # "Unreachable". _get_benchmark_by_index() should always be called with - # in-bounds values. Perhaps files have been deleted from site_data_path? 
- raise IndexError(f"Could not find benchmark with index {n} / {self.size}") diff --git a/tests/datasets/datasets_test.py b/tests/datasets/datasets_test.py index 8f574029b3..b2a0bcb298 100644 --- a/tests/datasets/datasets_test.py +++ b/tests/datasets/datasets_test.py @@ -17,7 +17,6 @@ class MockDataset: def __init__(self, name): self.name = name self.installed = False - self.seed_value = None self.hidden = False self.benchmark_values = [] self.sort_order = 0 @@ -28,9 +27,6 @@ def install(self): def uninstall(self): self.installed = False - def seed(self, seed): - self.seed_value = seed - def benchmark_uris(self): return (b.uri for b in self.benchmark_values) @@ -59,21 +55,6 @@ def __repr__(self): return str(self.name) -def test_seed_datasets_value(): - """Test that random seed is propagated to datasets.""" - da = MockDataset("a") - db = MockDataset("b") - datasets = Datasets((da, db)) - - datasets.seed(123) - - for dataset in datasets: - assert dataset.seed_value == 123 - - assert da.seed_value == 123 - assert db.seed_value == 123 - - def test_enumerate_datasets_empty(): datasets = Datasets([]) @@ -173,24 +154,6 @@ def test_datasets_get_item_lookup_miss(): assert str(e_ctx.value) == "Dataset not found: benchmark://bar-v0" -def test_dataset_empty(): - datasets = Datasets([]) - - with pytest.raises(ValueError) as e_ctx: - datasets.dataset() - - assert str(e_ctx.value) == "No datasets" - - -def test_benchmark_empty(): - datasets = Datasets([]) - - with pytest.raises(ValueError) as e_ctx: - datasets.benchmark() - - assert str(e_ctx.value) == "No datasets" - - def test_benchmark_lookup_by_uri(): da = MockDataset("benchmark://foo-v0") db = MockDataset("benchmark://bar-v0") @@ -280,49 +243,5 @@ def test_benchmarks_iter_deprecated(): ] -def test_benchmark_select_randomly(): - da = MockDataset("benchmark://foo-v0") - db = MockDataset("benchmark://bar-v0") - ba = MockBenchmark(uri="benchmark://foo-v0/abc") - bb = MockBenchmark(uri="benchmark://bar-v0/abc") - da.benchmark_values.append(ba) - db.benchmark_values.append(bb) - datasets = Datasets([da, db]) - - # Create three lists of randomly selected benchmarks. Two using the same - # seed, the third using a different seed. It is unlikely that the two - # different seeds will produce the same lists. - datasets.seed(1) - benchmarks_a = [datasets.benchmark() for i in range(50)] - datasets.seed(1) - benchmarks_b = [datasets.benchmark() for i in range(50)] - datasets.seed(2) - benchmarks_c = [datasets.benchmark() for i in range(50)] - - assert benchmarks_a == benchmarks_b - assert benchmarks_a != benchmarks_c - assert len(set(benchmarks_a)) == 2 - - -def test_dataset_select_randomly(): - da = MockDataset("benchmark://foo-v0") - db = MockDataset("benchmark://bar-v0") - datasets = Datasets([da, db]) - - # Create three lists of randomly selected datasets. Two using the same seed, - # the third using a different seed. It is unlikely that the two different - # seeds will produce the same lists. 
- datasets.seed(1) - datasets_a = [datasets.dataset() for i in range(50)] - datasets.seed(1) - datasets_b = [datasets.dataset() for i in range(50)] - datasets.seed(2) - datasets_c = [datasets.dataset() for i in range(50)] - - assert datasets_a == datasets_b - assert datasets_a != datasets_c - assert len(set(datasets_a)) == 2 - - if __name__ == "__main__": main() diff --git a/tests/datasets/files_dataset_test.py b/tests/datasets/files_dataset_test.py index 049f42bbc0..e6d4b4233a 100644 --- a/tests/datasets/files_dataset_test.py +++ b/tests/datasets/files_dataset_test.py @@ -5,7 +5,6 @@ """Unit tests for //compiler_gym/datasets:files_dataset_test.""" import tempfile from pathlib import Path -from typing import Optional import pytest @@ -64,13 +63,6 @@ def test_empty_dataset(empty_dataset: FilesDataset): assert list(empty_dataset.benchmarks()) == [] -def test_empty_dataset_benchmark(empty_dataset: FilesDataset): - with pytest.raises(ValueError) as e_ctx: - empty_dataset.benchmark() - - assert str(e_ctx.value) == "No benchmarks" - - def test_populated_dataset(populated_dataset: FilesDataset): for _ in range(2): assert list(populated_dataset.benchmark_uris()) == [ @@ -119,85 +111,5 @@ def test_populated_dataset_with_file_extension_filter(populated_dataset: FilesDa assert populated_dataset.size == 2 -@pytest.mark.parametrize( - "requested_uri", (None, "benchmark://test-v0", "benchmark://test-v0/") -) -def test_populated_dataset_random_benchmark( - populated_dataset: FilesDataset, requested_uri: Optional[str] -): - populated_dataset.benchmark_file_suffix = ".jpg" - - populated_dataset.seed(1) - benchmarks_a = [populated_dataset.benchmark(requested_uri).uri for _ in range(50)] - populated_dataset.seed(1) - benchmarks_b = [populated_dataset.benchmark(requested_uri).uri for _ in range(50)] - populated_dataset.seed(2) - benchmarks_c = [populated_dataset.benchmark(requested_uri).uri for _ in range(50)] - - assert benchmarks_a == benchmarks_b - assert benchmarks_b != benchmarks_c - - assert set(benchmarks_a) == set(populated_dataset.benchmark_uris()) - assert set(benchmarks_c) == set(populated_dataset.benchmark_uris()) - - -def test_populated_dataset_get_benchmark_by_index(populated_dataset: FilesDataset): - # pylint: disable=protected-access - - i = 0 - benchmark = populated_dataset._get_benchmark_by_index(i) - assert benchmark.uri == "benchmark://test-v0/e.txt" - assert Path(benchmark.proto.program.uri[len("file:///") :]).is_file() - - i += 1 - benchmark = populated_dataset._get_benchmark_by_index(i) - assert benchmark.uri == "benchmark://test-v0/f.txt" - assert Path(benchmark.proto.program.uri[len("file:///") :]).is_file() - - i += 1 - benchmark = populated_dataset._get_benchmark_by_index(i) - assert benchmark.uri == "benchmark://test-v0/g.jpg" - assert Path(benchmark.proto.program.uri[len("file:///") :]).is_file() - - i += 1 - benchmark = populated_dataset._get_benchmark_by_index(i) - assert benchmark.uri == "benchmark://test-v0/a/a.txt" - assert Path(benchmark.proto.program.uri[len("file:///") :]).is_file() - - i += 1 - benchmark = populated_dataset._get_benchmark_by_index(i) - assert benchmark.uri == "benchmark://test-v0/a/b.txt" - assert Path(benchmark.proto.program.uri[len("file:///") :]).is_file() - - i += 1 - benchmark = populated_dataset._get_benchmark_by_index(i) - assert benchmark.uri == "benchmark://test-v0/b/a.txt" - assert Path(benchmark.proto.program.uri[len("file:///") :]).is_file() - - i += 1 - benchmark = populated_dataset._get_benchmark_by_index(i) - assert benchmark.uri == 
"benchmark://test-v0/b/b.txt" - assert Path(benchmark.proto.program.uri[len("file:///") :]).is_file() - - i += 1 - benchmark = populated_dataset._get_benchmark_by_index(i) - assert benchmark.uri == "benchmark://test-v0/b/c.txt" - assert Path(benchmark.proto.program.uri[len("file:///") :]).is_file() - - i += 1 - benchmark = populated_dataset._get_benchmark_by_index(i) - assert benchmark.uri == "benchmark://test-v0/b/d.jpg" - assert Path(benchmark.proto.program.uri[len("file:///") :]).is_file() - - -def test_populated_dataset_get_benchmark_by_index_out_of_range( - populated_dataset: FilesDataset, -): - # pylint: disable=protected-access - with pytest.raises(IndexError): - populated_dataset._get_benchmark_by_index(-1) - populated_dataset._get_benchmark_by_index(10) - - if __name__ == "__main__": main()