Skip to content

Commit

Permalink
[datasets] Remove randomness and Union types from Datasets.
Browse files Browse the repository at this point in the history
This patch makes two simplifications to the Datasets API:

1) It removes the random-benchmark selection logic from
`Dataset.benchmark()`. Now, calling `benchmark()` requires a URI. If
you wish to select a benchmark randomly, you can implement this random
selection yourself. The idea is that random benchmark selection is
quite a minor use case that introduces quite a bit of complexity into
the implementation.

2) It removes the `Union[str, Dataset]` types to `Datasets`
methods. Now, only a string is permitted. This is to make it easier to
understand the argument types. If the user has a `Dataset` instance
that they would like to use, they can explicitly pass in
`dataset.name`.

Issue facebookresearch#45.
  • Loading branch information
ChrisCummins authored and bwasti committed Aug 3, 2021
1 parent 8d069c8 commit b27bbb0
Show file tree
Hide file tree
Showing 5 changed files with 28 additions and 335 deletions.
30 changes: 3 additions & 27 deletions compiler_gym/datasets/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
from typing import Dict, Iterable, List, NamedTuple, Optional, Union

import fasteners
import numpy as np
from deprecated.sphinx import deprecated

from compiler_gym.datasets.benchmark import DATASET_NAME_RE, Benchmark
Expand Down Expand Up @@ -45,7 +44,6 @@ def __init__(
site_data_base: Path,
benchmark_class=Benchmark,
references: Optional[Dict[str, str]] = None,
random: Optional[np.random.Generator] = None,
hidden: bool = False,
sort_order: int = 0,
logger: Optional[logging.Logger] = None,
Expand All @@ -70,8 +68,6 @@ def __init__(
:param references: A dictionary containing URLs for this dataset, keyed
by their name. E.g. :code:`references["Paper"] = "https://..."`.
:param random: A source of randomness for selecting benchmarks.
:param hidden: Whether the dataset should be excluded from the
:meth:`datasets() <compiler_gym.datasets.Datasets.dataset>` iterator
of any :class:`Datasets <compiler_gym.datasets.Datasets>` container.
Expand Down Expand Up @@ -103,7 +99,6 @@ def __init__(
self._hidden = hidden
self._validatable = validatable

self.random = random or np.random.default_rng()
self._logger = logger
self.sort_order = sort_order
self.benchmark_class = benchmark_class
Expand All @@ -115,17 +110,6 @@ def __init__(
def __repr__(self):
return self.name

def seed(self, seed: int):
"""Set the random state.
Setting a random state will fix the order that
:meth:`dataset.benchmark() <compiler_gym.datasets.Dataset.benchmark>`
returns benchmarks when called without arguments.
:param seed: A number.
"""
self.random = np.random.default_rng(seed)

@property
def logger(self) -> logging.Logger:
"""The logger for this dataset.
Expand Down Expand Up @@ -337,23 +321,15 @@ def benchmark_uris(self) -> Iterable[str]:
"""
raise NotImplementedError("abstract class")

def benchmark(self, uri: Optional[str] = None) -> Benchmark:
def benchmark(self, uri: str) -> Benchmark:
"""Select a benchmark.
If a URI is given, the corresponding :class:`Benchmark
<compiler_gym.datasets.Benchmark>` is returned. Otherwise, a benchmark
is selected uniformly randomly.
Use :meth:`seed() <compiler_gym.datasets.Dataset.seed>` to force a
reproducible order for randomly selected benchmarks.
:param uri: The URI of the benchmark to return. If :code:`None`, select
a benchmark randomly using :code:`self.random`.
:param uri: The URI of the benchmark to return.
:return: A :class:`Benchmark <compiler_gym.datasets.Benchmark>`
instance.
:raise LookupError: If :code:`uri` is provided but does not exist.
:raise LookupError: If :code:`uri` is not found.
"""
raise NotImplementedError("abstract class")

Expand Down
115 changes: 23 additions & 92 deletions compiler_gym/datasets/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,7 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from collections import deque
from typing import Dict, Iterable, Optional, Set, TypeVar, Union

import numpy as np
from typing import Dict, Iterable, Set, TypeVar

from compiler_gym.datasets.benchmark import (
BENCHMARK_URI_RE,
Expand Down Expand Up @@ -76,46 +74,16 @@ class Datasets(object):
If you want to exclude a dataset, delete it:
>>> del env.datasets["benchmark://b-v0"]
To iterate over the benchmarks in a random order, use :meth:`benchmark()`
and omit the URI:
>>> for i in range(100):
... benchmark = env.datasets.benchmark()
This uses uniform random selection to sample across datasets. For finite
datasets, you could weight the sample by the size of each dataset:
>>> weights = [len(d) for d in env.datasets]
>>> np.random.choice(list(env.datasets), p=weights).benchmark()
"""

def __init__(
self,
datasets: Iterable[Dataset],
random: Optional[np.random.Generator] = None,
):
self._datasets: Dict[str, Dataset] = {d.name: d for d in datasets}
self._visible_datasets: Set[str] = set(
name for name, dataset in self._datasets.items() if not dataset.hidden
)
self.random = random or np.random.default_rng()

def seed(self, seed: Optional[int] = None) -> None:
"""Set the random state.
Setting a random state will fix the order that
:meth:`datasets.benchmark() <compiler_gym.datasets.Datasets.benchmark>`
returns benchmarks when called without arguments.
Calling this method recursively calls :meth:`seed()
<compiler_gym.datasets.Dataset.seed>` on all member datasets.
:param seed: An optional seed value.
"""
self.random = np.random.default_rng(seed)
for dataset in self._datasets.values():
dataset.seed(seed)

def datasets(self, with_deprecated: bool = False) -> Iterable[Dataset]:
"""Enumerate the datasets.
Expand Down Expand Up @@ -147,46 +115,30 @@ def __iter__(self) -> Iterable[Dataset]:
"""
return self.datasets()

def dataset(self, dataset: Optional[Union[str, Dataset]] = None) -> Dataset:
def dataset(self, dataset: str) -> Dataset:
"""Get a dataset.
If a name is given, return the corresponding :meth:`Dataset
<compiler_gym.datasets.Dataset>`. Else, return a dataset uniformly
randomly from the set of available datasets.
Return the corresponding :meth:`Dataset
<compiler_gym.datasets.Dataset>`. Name lookup will succeed whether or
not the dataset is deprecated.
Use :meth:`seed() <compiler_gym.datasets.Dataset.seed>` to force a
reproducible order for randomly selected datasets.
Name lookup will succeed whether or not the dataset is deprecated.
:param dataset: A dataset name, a :class:`Dataset` instance, or
:code:`None` to select a dataset randomly.
:param dataset: A dataset name.
:return: A :meth:`Dataset <compiler_gym.datasets.Dataset>` instance.
:raises LookupError: If :code:`dataset` is not found.
"""
if dataset is None:
if not self._visible_datasets:
raise ValueError("No datasets")

return self._datasets[self.random.choice(list(self._visible_datasets))]

if isinstance(dataset, Dataset):
dataset_name = dataset.name
else:
dataset_name = resolve_uri_protocol(dataset)
dataset_name = resolve_uri_protocol(dataset)

if dataset_name not in self._datasets:
raise LookupError(f"Dataset not found: {dataset_name}")

return self._datasets[dataset_name]

def __getitem__(self, dataset: Union[str, Dataset]) -> Dataset:
def __getitem__(self, dataset: str) -> Dataset:
"""Lookup a dataset.
:param dataset: A dataset name, a :class:`Dataset` instance, or
:code:`None` to select a dataset randomly.
:param dataset: A dataset name.
:return: A :meth:`Dataset <compiler_gym.datasets.Dataset>` instance.
Expand All @@ -195,29 +147,30 @@ def __getitem__(self, dataset: Union[str, Dataset]) -> Dataset:
return self.dataset(dataset)

def __setitem__(self, key: str, dataset: Dataset):
self._datasets[key] = dataset
dataset_name = resolve_uri_protocol(key)

self._datasets[dataset_name] = dataset
if not dataset.hidden:
self._visible_datasets.add(dataset.name)
self._visible_datasets.add(dataset_name)

def __delitem__(self, dataset: Union[str, Dataset]):
def __delitem__(self, dataset: str):
"""Remove a dataset from the collection.
This does not affect any underlying storage used by dataset. See
:meth:`uninstall() <compiler_gym.datasets.Datasets.uninstall>` to clean
up.
:param dataset: A :meth:`Dataset <compiler_gym.datasets.Dataset>`
instance, or the name of a dataset.
:param dataset: The name of a dataset.
:return: :code:`True` if the dataset was removed, :code:`False` if it
was already removed.
"""
dataset_name: str = self.dataset(dataset).name
dataset_name = resolve_uri_protocol(dataset)
if dataset_name in self._visible_datasets:
self._visible_datasets.remove(dataset_name)
del self._datasets[dataset_name]

def __contains__(self, dataset: Union[str, Dataset]) -> bool:
def __contains__(self, dataset: str) -> bool:
"""Returns whether the dataset is contained."""
try:
self.dataset(dataset)
Expand Down Expand Up @@ -261,37 +214,18 @@ def benchmark_uris(self, with_deprecated: bool = False) -> Iterable[str]:
(d.benchmark_uris() for d in self.datasets(with_deprecated=with_deprecated))
)

def benchmark(self, uri: Optional[str] = None) -> Benchmark:
def benchmark(self, uri: str) -> Benchmark:
"""Select a benchmark.
If a benchmark URI is given, the corresponding :class:`Benchmark
<compiler_gym.datasets.Benchmark>` is returned, regardless of whether
the containing dataset is installed or deprecated.
Returns the corresponding :class:`Benchmark
<compiler_gym.datasets.Benchmark>`, regardless of whether the containing
dataset is installed or deprecated.
If no URI is given, a benchmark is selected randomly. First, a dataset
is selected uniformly randomly from the set of available datasets. Then
a benchmark is selected randomly from the chosen dataset.
Calling :code:`benchmark()` will yield benchmarks from all available
datasets with equal probability, regardless of how many benchmarks are
in each dataset. Given a pool of available datasets of differing sizes,
smaller datasets will be overrepresented and large datasets will be
underrepresented.
Use :meth:`seed() <compiler_gym.datasets.Dataset.seed>` to force a
reproducible order for randomly selected benchmarks.
:param uri: The URI of the benchmark to return. If :code:`None`, select
a benchmark randomly using :code:`self.random`.
:param uri: The URI of the benchmark to return.
:return: A :class:`Benchmark <compiler_gym.datasets.Benchmark>`
instance.
"""
if uri is None and not self._visible_datasets:
raise ValueError("No datasets")
elif uri is None:
return self.dataset().benchmark()

uri = resolve_uri_protocol(uri)

match = BENCHMARK_URI_RE.match(uri)
Expand All @@ -301,10 +235,7 @@ def benchmark(self, uri: Optional[str] = None) -> Benchmark:
dataset_name = match.group("dataset")
dataset = self._datasets[dataset_name]

if len(uri) > len(dataset_name) + 1:
return dataset.benchmark(uri)
else:
return dataset.benchmark()
return dataset.benchmark(uri)

@property
def size(self) -> int:
Expand Down
49 changes: 2 additions & 47 deletions compiler_gym/datasets/files_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
# LICENSE file in the root directory of this source tree.
import os
from pathlib import Path
from typing import Iterable, List, Optional
from typing import Iterable, List

from compiler_gym.datasets.dataset import Benchmark, Dataset
from compiler_gym.util.decorators import memoized_property
Expand Down Expand Up @@ -108,56 +108,11 @@ def benchmark_uris(self) -> Iterable[str]:
else:
yield from self._benchmark_uris_iter

def benchmark(self, uri: Optional[str] = None) -> Benchmark:
def benchmark(self, uri: str) -> Benchmark:
self.install()
if uri is None or len(uri) <= len(self.name) + 1:
if not self.size:
raise ValueError("No benchmarks")
return self._get_benchmark_by_index(self.random.integers(self.size))

relpath = f"{uri[len(self.name) + 1:]}{self.benchmark_file_suffix}"
abspath = self.dataset_root / relpath
if not abspath.is_file():
raise LookupError(f"Benchmark not found: {uri} (file not found: {abspath})")
return self.benchmark_class.from_file(uri, abspath)

def _get_benchmark_by_index(self, n: int) -> Benchmark:
"""Look up a benchmark using a numeric index into the list of URIs,
without bounds checking.
"""
# If we have memoized the benchmark IDs then just index into the list.
# Otherwise we will scan through the directory hierarchy.
if self.memoize_uris:
if not self._memoized_uris:
self._memoized_uris = self._benchmark_uris
return self.benchmark(self._memoized_uris[n])

i = 0
for root, dirs, files in os.walk(self.dataset_root):
reldir = root[len(str(self.dataset_root)) + 1 :]

# Filter only the files that match the target file suffix.
valid_files = [f for f in files if f.endswith(self.benchmark_file_suffix)]

if i + len(valid_files) <= n:
# There aren't enough files in this directory to bring us up to
# the target file index, so skip this directory and descend into
# subdirectories.
i += len(valid_files)
# Sort the subdirectories so that the iteration order is stable
# and consistent with benchmark_uris().
dirs.sort()
else:
valid_files.sort()
filename = valid_files[n - i]
name_stem = filename
if self.benchmark_file_suffix:
name_stem = filename[: -len(self.benchmark_file_suffix)]
# Use os.path.join() rather than simple '/' concatenation as
# reldir may be empty.
uri = os.path.join(self.name, reldir, name_stem)
return self.benchmark_class.from_file(uri, os.path.join(root, filename))

# "Unreachable". _get_benchmark_by_index() should always be called with
# in-bounds values. Perhaps files have been deleted from site_data_path?
raise IndexError(f"Could not find benchmark with index {n} / {self.size}")
Loading

0 comments on commit b27bbb0

Please sign in to comment.