WIP: New dataset API.
ChrisCummins committed Apr 8, 2021
1 parent 2ead03f commit 80db8cf
Showing 70 changed files with 3,911 additions and 1,699 deletions.
8 changes: 4 additions & 4 deletions CHANGELOG.md
@@ -46,10 +46,10 @@ semantics validation, and improving the datasets. Many thanks to @JD-at-work,
- Added default reward spaces for `CompilerEnv` that are derived from scalar
observations (thanks @bwasti!)
- Added a new Q learning example (thanks @JD-at-work!).
- *Deprecation:* The next release v0.1.5 will introduce a new datasets API that
is easier to use and more flexible. In preparation for this, the `Dataset`
class has been renamed to `LegacyDataset`, the following dataset operations
have been marked deprecated: `activate()`, `deactivate()`, and `delete()`. The
- *Deprecation:* The v0.1.9 release will introduce a new datasets API that is
easier to use and more flexible. In preparation for this, the `Dataset` class
has been renamed to `LegacyDataset`, the following dataset operations have
been marked deprecated: `activate()`, `deactivate()`, and `delete()`. The
  `GetBenchmarks()` RPC interface method has also been marked deprecated.
- [llvm] Improved semantics validation using LLVM's memory, thread, address, and
undefined behavior sanitizers.
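To make the deprecation note above concrete, here is a minimal sketch of what the change means for user code. It assumes the `env.datasets` accessor, `dataset.install()`, and `datasets.benchmark()` calls that appear in the diffs further down this commit; the `llvm-v0` environment name and the `cbench-v1/qsort` benchmark URI are illustrative only.

import gym

import compiler_gym  # noqa: registers the CompilerGym environments

env = gym.make("llvm-v0")
try:
    # Deprecated path: module-level helpers operating on a site directory.
    #   from compiler_gym.datasets.dataset import activate, deactivate, delete, require
    #   require(env, "cbench-v1")
    #   activate(env, "cbench-v1")

    # New-style path (sketch): datasets hang off the environment.
    for dataset in env.datasets:
        dataset.install()  # download and install, as in bin/datasets.py below

    benchmark = env.datasets.benchmark("benchmark://cbench-v1/qsort")
    env.reset(benchmark=benchmark)
finally:
    env.close()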
5 changes: 2 additions & 3 deletions compiler_gym/bin/BUILD
@@ -22,7 +22,7 @@ py_binary(
srcs = ["datasets.py"],
visibility = ["//visibility:public"],
deps = [
"//compiler_gym/datasets:dataset",
":service",
"//compiler_gym/envs",
"//compiler_gym/util",
"//compiler_gym/util/flags:env_from_flags",
@@ -39,7 +39,6 @@ py_binary(
"//compiler_gym/util",
"//compiler_gym/util/flags:benchmark_from_flags",
"//compiler_gym/util/flags:env_from_flags",
"//compiler_gym/util/flags:ls_benchmark",
],
)

@@ -60,7 +59,6 @@ py_binary(
"//compiler_gym:random_search",
"//compiler_gym/util/flags:benchmark_from_flags",
"//compiler_gym/util/flags:env_from_flags",
"//compiler_gym/util/flags:ls_benchmark",
"//compiler_gym/util/flags:nproc",
"//compiler_gym/util/flags:output_dir",
],
@@ -83,6 +81,7 @@ py_binary(
srcs = ["service.py"],
visibility = ["//visibility:public"],
deps = [
"//compiler_gym/datasets",
"//compiler_gym/envs",
"//compiler_gym/spaces",
"//compiler_gym/util",
105 changes: 6 additions & 99 deletions compiler_gym/bin/datasets.py
@@ -33,23 +33,6 @@
+-------------------+--------------+-----------------+----------------+
These benchmarks are ready for use. Deactivate them using `--deactivate=<name>`.
+---------------------+-----------+-----------------+----------------+
| Inactive Datasets | License | #. Benchmarks | Size on disk |
+=====================+===========+=================+================+
| Total | | 0 | 0 Bytes |
+---------------------+-----------+-----------------+----------------+
These benchmarks may be activated using `--activate=<name>`.
+------------------------+---------------------------------+-----------------+----------------+
| Downloadable Dataset | License | #. Benchmarks | Size on disk |
+========================+=================================+=================+================+
| blas-v0 | BSD 3-Clause | 300 | 4.0 MB |
+------------------------+---------------------------------+-----------------+----------------+
| polybench-v0 | BSD 3-Clause | 27 | 162.6 kB |
+------------------------+---------------------------------+-----------------+----------------+
These benchmarks may be installed using `--download=<name> --activate=<name>`.
Downloading datasets
--------------------
@@ -131,23 +114,13 @@
A :code:`--delete_all` flag can be used to delete all of the locally installed
datasets.
"""
import os
import sys
from pathlib import Path
from typing import Tuple

import humanize
from absl import app, flags

from compiler_gym.datasets.dataset import (
LegacyDataset,
activate,
deactivate,
delete,
require,
)
from compiler_gym.bin.service import summarize_datasets
from compiler_gym.datasets.dataset import activate, deactivate, delete, require
from compiler_gym.util.flags.env_from_flags import env_from_flags
from compiler_gym.util.tabulate import tabulate

flags.DEFINE_list(
"download",
@@ -175,69 +148,34 @@
FLAGS = flags.FLAGS


def get_count_and_size_of_directory_contents(root: Path) -> Tuple[int, int]:
"""Return the number of files and combined size of a directory."""
count, size = 0, 0
for root, _, files in os.walk(str(root)):
count += len(files)
size += sum(os.path.getsize(f"{root}/{file}") for file in files)
return count, size


def enumerate_directory(name: str, path: Path):
rows = []
for path in path.iterdir():
if not path.is_file() or not path.name.endswith(".json"):
continue
dataset = LegacyDataset.from_json_file(path)
rows.append(
(dataset.name, dataset.license, dataset.file_count, dataset.size_bytes)
)
rows.append(("Total", "", sum(r[2] for r in rows), sum(r[3] for r in rows)))
return tabulate(
[(n, l, humanize.intcomma(f), humanize.naturalsize(s)) for n, l, f, s in rows],
headers=(name, "License", "#. Benchmarks", "Size on disk"),
)


def main(argv):
"""Main entry point."""
if len(argv) != 1:
raise app.UsageError(f"Unknown command line arguments: {argv[1:]}")

env = env_from_flags()
try:
if not env.datasets_site_path:
raise app.UsageError("Environment has no benchmarks site path")

env.datasets_site_path.mkdir(parents=True, exist_ok=True)
env.inactive_datasets_site_path.mkdir(parents=True, exist_ok=True)

invalidated_manifest = False

for name_or_url in FLAGS.download:
require(env, name_or_url)

if FLAGS.download_all:
for dataset in env.available_datasets:
require(env, dataset)
for dataset in env.datasets:
dataset.install()

for name in FLAGS.activate:
activate(env, name)
invalidated_manifest = True

if FLAGS.activate_all:
for path in env.inactive_datasets_site_path.iterdir():
activate(env, path.name)
invalidated_manifest = True

for name in FLAGS.deactivate:
deactivate(env, name)
invalidated_manifest = True

if FLAGS.deactivate_all:
for path in env.datasets_site_path.iterdir():
deactivate(env, path.name)
invalidated_manifest = True

for name in FLAGS.delete:
@@ -246,41 +184,10 @@ def main(argv):
if invalidated_manifest:
env.make_manifest_file()

print(f"{env.spec.id} benchmarks site dir: {env.datasets_site_path}")
print(f"{env.spec.id} benchmarks site dir: {env.datasets.site_data_path}")
print()
print(
enumerate_directory("Active Datasets", env.datasets_site_path),
)
print(
"These benchmarks are ready for use. Deactivate them using `--deactivate=<name>`."
)
print()
print(enumerate_directory("Inactive Datasets", env.inactive_datasets_site_path))
print("These benchmarks may be activated using `--activate=<name>`.")
print()
print(
tabulate(
sorted(
[
(
d.name,
d.license,
humanize.intcomma(d.file_count),
humanize.naturalsize(d.size_bytes),
)
for d in env.available_datasets.values()
]
),
headers=(
"Downloadable Dataset",
"License",
"#. Benchmarks",
"Size on disk",
),
)
)
print(
"These benchmarks may be installed using `--download=<name> --activate=<name>`."
summarize_datasets(env.datasets),
)
finally:
env.close()
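The removed `enumerate_directory()` and activation bookkeeping above are replaced by a single `summarize_datasets()` call. The following is a condensed, standalone sketch of the rewritten flow in this file, with the remaining flag handling and error paths trimmed; the behaviour of `summarize_datasets()` and the `--download_all` semantics are inferred from the hunks above rather than taken from a finished API.

from absl import app, flags

from compiler_gym.bin.service import summarize_datasets
from compiler_gym.util.flags.env_from_flags import env_from_flags

flags.DEFINE_boolean("download_all", False, "Download and install every dataset.")
FLAGS = flags.FLAGS


def main(argv):
    """Condensed sketch of the rewritten datasets command."""
    if len(argv) != 1:
        raise app.UsageError(f"Unknown command line arguments: {argv[1:]}")

    env = env_from_flags()
    try:
        if FLAGS.download_all:
            # Replaces the old require()/activate() bookkeeping.
            for dataset in env.datasets:
                dataset.install()

        print(f"{env.spec.id} benchmarks site dir: {env.datasets.site_data_path}")
        print()
        print(summarize_datasets(env.datasets))
    finally:
        env.close()


if __name__ == "__main__":
    app.run(main)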
45 changes: 15 additions & 30 deletions compiler_gym/bin/manual_env.py
@@ -240,10 +240,10 @@
import random
import readline
import sys
from itertools import islice

from absl import app, flags

import compiler_gym.util.flags.ls_benchmark # noqa Flag definition.
from compiler_gym.envs import CompilerEnv
from compiler_gym.util.flags.benchmark_from_flags import benchmark_from_flags
from compiler_gym.util.flags.env_from_flags import env_from_flags
@@ -304,7 +304,7 @@ def __init__(self, env: CompilerEnv):
self.init_benchmarks()

# Get the benchmarks
self.benchmarks = sorted(self.env.benchmarks)
self.benchmarks = sorted(islice(self.env.datasets.benchmark_uris(), 100))
# Strip default benchmark:// protocol.
for i, benchmark in enumerate(self.benchmarks):
if benchmark.startswith("benchmark://"):
@@ -338,7 +338,7 @@ def postloop(self):
def init_benchmarks(self):
"""Initialise the set of benchmarks"""
# Get the benchmarks
self.benchmarks = sorted(self.env.benchmarks)
self.benchmarks = sorted(islice(self.env.datasets.benchmark_uris(), 100))
# Strip default benchmark:// protocol.
for i, benchmark in enumerate(self.benchmarks):
if benchmark.startswith("benchmark://"):
@@ -364,7 +364,7 @@ def simple_complete(self, text, options):

def get_datasets(self):
"""Get the list of available datasets"""
return sorted([k for k in self.env.available_datasets])
return sorted([k.name for k in self.env.datasets.datasets(inactive=True)])

def do_list_datasets(self, arg):
"""List all of the available datasets"""
@@ -378,23 +378,17 @@ def do_require_dataset(self, arg):
"""Require dataset
The argument is the name of the dataset to require.
"""
if self.get_datasets().count(arg):
try:
with Timer(f"Downloaded dataset {arg}"):
self.env.require_dataset(arg)
self.env.datasets.require(arg)
self.init_benchmarks()
else:
except LookupError:
print("Unknown dataset, '" + arg + "'")
print("Available datasets are listed with command, list_available_datasets")

def do_list_benchmarks(self, arg):
"""List all of the available benchmarks"""
if not self.benchmarks:
doc_root_url = "https://facebookresearch.github.io/CompilerGym/"
install_url = doc_root_url + "getting_started.html#installing-benchmarks"
print("No benchmarks available. See " + install_url)
print("Datasets can be installed with command, require_dataset")
else:
print(", ".join(self.benchmarks))
print(", ".join(self.benchmarks))

def complete_set_benchmark(self, text, line, begidx, endidx):
"""Complete the set_benchmark argument"""
@@ -409,27 +403,25 @@ def do_set_benchmark(self, arg):
Use '-' for a random benchmark.
"""
if arg == "-":
arg = random.choice(self.benchmarks)
arg = self.env.datasets.benchmark().uri
print(f"set_benchmark {arg}")

if self.benchmarks.count(arg):
try:
benchmark = self.env.datasets.benchmark(arg)
self.stack.clear()

# Set the current benchmark
with Timer() as timer:
observation = self.env.reset(benchmark=arg)
observation = self.env.reset(benchmark=benchmark)
print(f"Reset {self.env.benchmark} environment in {timer}")

if self.env.observation_space and observation is not None:
print(
f"Observation: {self.env.observation_space.to_string(observation)}"
)
print("Observation:", self.env.observation_space.to_string(observation))

self.set_prompt()

else:
except LookupError:
print("Unknown benchmark, '" + arg + "'")
print("Bencmarks are listed with command, list_benchmarks")
print("Benchmarks are listed with command, list_benchmarks")

def get_actions(self):
"""Get the list of actions"""
@@ -889,13 +881,6 @@ def main(argv):
if len(argv) != 1:
raise app.UsageError(f"Unknown command line arguments: {argv[1:]}")

if FLAGS.ls_benchmark:
benchmark = benchmark_from_flags()
env = env_from_flags(benchmark)
print("\n".join(sorted(env.benchmarks)))
env.close()
return

with Timer("Initialized environment"):
# FIXME Chris, I don't seem to actually get a benchmark
benchmark = benchmark_from_flags()
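For reference, the benchmark-selection pattern that the interactive shell above now follows can be reduced to a few lines outside the Cmd machinery. This is a sketch only: the 100-URI cap mirrors the `islice()` call in the hunks above, the `LookupError` handling matches what `do_set_benchmark()` now catches, and the `llvm-v0` environment name is an assumption for illustration.

from itertools import islice

import gym

import compiler_gym  # noqa: registers the CompilerGym environments

env = gym.make("llvm-v0")
try:
    # Enumerate a bounded sample of benchmark URIs; the full generator may be large.
    uris = sorted(islice(env.datasets.benchmark_uris(), 100))
    print(f"{len(uris)} benchmark URIs sampled, e.g. {uris[0]}")

    # With no argument, benchmark() picks a random benchmark (the '-' case above).
    print(f"Random benchmark: {env.datasets.benchmark().uri}")

    # Resolving an unknown URI raises LookupError, which the shell now catches.
    try:
        env.reset(benchmark=env.datasets.benchmark(uris[0]))
    except LookupError:
        print("Unknown benchmark")
finally:
    env.close()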
6 changes: 0 additions & 6 deletions compiler_gym/bin/random_search.py
@@ -58,7 +58,6 @@

from absl import app, flags

import compiler_gym.util.flags.ls_benchmark # noqa Flag definition.
import compiler_gym.util.flags.nproc # noqa Flag definition.
import compiler_gym.util.flags.output_dir # noqa Flag definition.
from compiler_gym.random_search import random_search
@@ -93,11 +92,6 @@ def main(argv):
if len(argv) != 1:
raise app.UsageError(f"Unknown command line arguments: {argv[1:]}")

if FLAGS.ls_benchmark:
env = env_from_flags()
print("\n".join(sorted(env.benchmarks)))
env.close()
return
if FLAGS.ls_reward:
env = env_from_flags()
print("\n".join(sorted(env.reward.indices.keys())))
(Diffs for the remaining changed files are not shown here.)
