diff --git a/.github/actions/install-build-dependencies/action.yml b/.github/actions/install-build-dependencies/action.yml
new file mode 100644
index 000000000..de7053750
--- /dev/null
+++ b/.github/actions/install-build-dependencies/action.yml
@@ -0,0 +1,28 @@
+---
+name: Install the Linux dependencies
+description: Install build dependencies
+runs:
+ using: composite
+ steps:
+ - name: Install dependencies (linux)
+ run: |
+ if [ "$(uname)" != "Darwin" ]; then
+ curl -L "https://github.com/bazelbuild/bazelisk/releases/download/v1.6.1/bazelisk-linux-amd64" > bazel
+ chmod +x bazel
+ sudo mv bazel /usr/local/bin/bazel
+ sudo apt-get install -y clang-9 patchelf
+ python -m pip install -r compiler_gym/requirements.txt -r examples/requirements.txt -r tests/requirements.txt
+ fi
+ shell: bash
+
+ - name: Install dependencies (macos)
+ run: |
+ if [ "$(uname)" = "Darwin" ]; then
+ brew install bazelisk zlib
+ python -m pip install -r compiler_gym/requirements.txt -r examples/requirements.txt -r tests/requirements.txt
+ fi
+ shell: bash
+ env:
+ LDFLAGS: -L/usr/local/opt/zlib/lib
+ CPPFLAGS: -I/usr/local/opt/zlib/include
+ PKG_CONFIG_PATH: /usr/local/opt/zlib/lib/pkgconfig
diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml
index 88d9ef93f..257eacf44 100644
--- a/.github/workflows/ci.yaml
+++ b/.github/workflows/ci.yaml
@@ -18,13 +18,7 @@ jobs:
fail-fast: false
matrix:
os: [ubuntu-latest, macos-latest]
- python: [3.6, 3.7, 3.8, 3.9]
- exclude:
- # Only test recent python versions on macOS.
- - os: macos-latest
- python: 3.6
- - os: macos-latest
- python: 3.7
+ python: [3.9]
steps:
- uses: actions/checkout@v2
@@ -34,33 +28,17 @@ jobs:
with:
python-version: ${{ matrix.python }}
- - name: Install dependencies (linux)
- run: |
- curl -L "https://github.com/bazelbuild/bazelisk/releases/download/v1.6.1/bazelisk-linux-amd64" > bazel
- chmod +x bazel
- sudo mv bazel /usr/local/bin/bazel
- sudo apt install clang-9 patchelf
- python -m pip install -r compiler_gym/requirements.txt -r examples/requirements.txt -r tests/requirements.txt
- if: matrix.os == 'ubuntu-latest'
-
- - name: Install dependencies (macOS)
- run: |
- brew install bazelisk zlib
- python -m pip install -r compiler_gym/requirements.txt -r examples/requirements.txt -r tests/requirements.txt
- env:
- LDFLAGS: -L/usr/local/opt/zlib/lib
- CPPFLAGS: -I/usr/local/opt/zlib/include
- PKG_CONFIG_PATH: /usr/local/opt/zlib/lib/pkgconfig
- if: matrix.os == 'macos-latest'
+ - name: Install build dependencies
+ uses: ./.github/actions/install-build-dependencies
- name: Test
run: make test
env:
CC: clang
CXX: clang++
+ BAZEL_OPTS: --batch
BAZEL_TEST_OPTS: --config=ci --test_timeout=300,900,1800,7200
-
install_test:
runs-on: ${{ matrix.os }}
@@ -68,7 +46,7 @@ jobs:
fail-fast: false
matrix:
os: [ubuntu-latest, macos-latest]
- python: [3.8]
+ python: [3.6, 3.7, 3.8, 3.9]
steps:
- uses: actions/checkout@v2
@@ -78,30 +56,15 @@ jobs:
with:
python-version: ${{ matrix.python }}
- - name: Install dependencies (linux)
- run: |
- curl -L "https://github.com/bazelbuild/bazelisk/releases/download/v1.6.1/bazelisk-linux-amd64" > bazel
- chmod +x bazel
- sudo mv bazel /usr/local/bin/bazel
- sudo apt install clang-9 patchelf
- python -m pip install -r compiler_gym/requirements.txt -r examples/requirements.txt -r tests/requirements.txt
- if: matrix.os == 'ubuntu-latest'
-
- - name: Install dependencies (macos)
- run: |
- brew install bazelisk zlib
- python -m pip install -r compiler_gym/requirements.txt -r examples/requirements.txt -r tests/requirements.txt
- env:
- LDFLAGS: -L/usr/local/opt/zlib/lib
- CPPFLAGS: -I/usr/local/opt/zlib/include
- PKG_CONFIG_PATH: /usr/local/opt/zlib/lib/pkgconfig
- if: matrix.os == 'macos-latest'
+ - name: Install build dependencies
+ uses: ./.github/actions/install-build-dependencies
- name: Install
run: make install
env:
CC: clang
CXX: clang++
+ BAZEL_OPTS: --batch
BAZEL_BUILD_OPTS: --config=ci
- name: Test
@@ -109,22 +72,6 @@ jobs:
env:
CC: clang
CXX: clang++
- BAZEL_BUILD_OPTS: --config=ci
- if: matrix.os == 'macos-latest'
-
- - name: Test with coverage
- run: make install-test-cov
- env:
- CC: clang
- CXX: clang++
- BAZEL_BUILD_OPTS: --config=ci
- if: matrix.os == 'ubuntu-latest'
-
- - name: Upload coverage to Codecov
- uses: codecov/codecov-action@v1
- with:
- files: ./coverage.xml
- if: matrix.os == 'ubuntu-latest'
- name: Uninstall
run: make purge
diff --git a/.github/workflows/coverage.yaml b/.github/workflows/coverage.yaml
new file mode 100644
index 000000000..1027a9cdc
--- /dev/null
+++ b/.github/workflows/coverage.yaml
@@ -0,0 +1,46 @@
+---
+name: Test Coverage
+
+on:
+ push:
+ branches:
+ - development
+ - stable
+ pull_request:
+
+jobs:
+ pytest-cov:
+ runs-on: ubuntu-latest
+
+ steps:
+ - uses: actions/checkout@v2
+
+ - name: Setup python
+ uses: actions/setup-python@v2
+ with:
+ python-version: 3.9
+
+ - name: Install build dependencies
+ uses: ./.github/actions/install-build-dependencies
+
+ - name: Install
+ run: make install
+ env:
+ CC: clang
+ CXX: clang++
+ BAZEL_OPTS: --batch
+ BAZEL_BUILD_OPTS: --config=ci
+
+ - name: Test
+ # Note the `|| true` as we don't care about failing tests in this
+ # job.
+ run: make install-test-cov || true
+ env:
+ CC: clang
+ CXX: clang++
+
+ - name: Upload coverage to Codecov
+ uses: codecov/codecov-action@v1
+ with:
+ files: ./coverage.xml
+ if: ${{ always() }}
diff --git a/.github/workflows/fuzz.yaml b/.github/workflows/fuzz.yaml
index c5e8f238a..1ae439a14 100644
--- a/.github/workflows/fuzz.yaml
+++ b/.github/workflows/fuzz.yaml
@@ -13,7 +13,7 @@ on:
- cron: 0 9 * * 1-5 # every weekday at 9am
jobs:
- test:
+ fuzz:
runs-on: ${{ matrix.os }}
strategy:
@@ -30,27 +30,16 @@ jobs:
with:
python-version: ${{ matrix.python }}
- - name: Install build dependencies (linux)
- run: |
- curl -L "https://github.com/bazelbuild/bazelisk/releases/download/v1.6.1/bazelisk-linux-amd64" > bazel.tmp
- sudo mv bazel.tmp /usr/local/bin/bazel
- chmod +x /usr/local/bin/bazel
- sudo apt install clang-9 patchelf
- python -m pip install -r compiler_gym/requirements.txt -r tests/requirements.txt
- if: matrix.os == 'ubuntu-latest'
-
- - name: Install build dependencies (macOS)
- run: |
- brew install bazelisk
- python -m pip install -r compiler_gym/requirements.txt -r tests/requirements.txt
- if: matrix.os == 'macos-latest'
+ - name: Install build dependencies
+ uses: ./.github/actions/install-build-dependencies
- name: Install
run: make install
env:
CC: clang
CXX: clang++
+ BAZEL_OPTS: --batch
BAZEL_TEST_OPTS: --config=ci
- name: Test
- run: FUZZ_TIME=600 make fuzz
+ run: FUZZ_SECONDS=600 make install-fuzz
diff --git a/.github/workflows/pre_commit.yaml b/.github/workflows/pre_commit.yaml
index 227f61a00..b77f50302 100644
--- a/.github/workflows/pre_commit.yaml
+++ b/.github/workflows/pre_commit.yaml
@@ -22,27 +22,33 @@ jobs:
sudo rm -f /usr/bin/clang-format
sudo ln -s /usr/bin/clang-format-10 /usr/bin/clang-format
clang-format --version
+
- name: Install go
uses: actions/setup-go@v2
with:
go-version: ^1.13.1
+
- name: Install buildifier
run: |
go get github.com/bazelbuild/buildtools/buildifier
buildifier --version
+
- name: Install prototool
run: |
GO111MODULE=on go get github.com/uber/prototool/cmd/prototool@dev
prototool version
+
- name: Install hadolint
run: |
wget -O hadolint https://github.com/hadolint/hadolint/releases/download/v1.19.0/hadolint-Linux-x86_64
chmod +x hadolint
sudo mv hadolint /usr/local/bin
- - name: Install Python 3.9
+
+ - name: Set up Python
uses: actions/setup-python@v2
with:
python-version: 3.9
+
- name: Install Python dependencies
run: |
sudo apt-get install python3-setuptools
@@ -51,6 +57,7 @@ jobs:
python3 -m isort --version
python3 -m black --version
python3 -m pre_commit --version
+
- name: Run pre-commit checks
# TODO(github.com/facebookresearch/CompilerGym/issues/1): Disable
# isort due to inconsistent results when run locally versus in
diff --git a/.github/workflows/release_test.yaml b/.github/workflows/release_test.yaml
index f5066205f..f225689ab 100644
--- a/.github/workflows/release_test.yaml
+++ b/.github/workflows/release_test.yaml
@@ -14,7 +14,7 @@ jobs:
fail-fast: false
matrix:
os: [ubuntu-latest, macos-latest]
- python: [3.9]
+ python: [3.6, 3.7, 3.8, 3.9]
steps:
- uses: actions/checkout@v2
@@ -26,8 +26,16 @@ jobs:
with:
python-version: ${{ matrix.python }}
+ - name: Install runtime dependencies (macos)
+ run: brew install zlib
+ if: matrix.os == 'macos-latest'
+
- name: Install python wheel
run: python -m pip install compiler_gym
+ env:
+ LDFLAGS: -L/usr/local/opt/zlib/lib
+ CPPFLAGS: -I/usr/local/opt/zlib/include
+ PKG_CONFIG_PATH: /usr/local/opt/zlib/lib/pkgconfig
- name: Install python test dependencies
run: python -m pip install -r tests/requirements.txt
diff --git a/.github/workflows/sanitizers.yaml b/.github/workflows/sanitizers.yaml
new file mode 100644
index 000000000..2774b59e6
--- /dev/null
+++ b/.github/workflows/sanitizers.yaml
@@ -0,0 +1,50 @@
+---
+name: LLVM Sanitizers
+
+on:
+ push:
+ branches:
+ - development
+ - stable
+ pull_request:
+
+jobs:
+ llvm-service-asan:
+ runs-on: ubuntu-latest
+
+ steps:
+ - uses: actions/checkout@v2
+
+ - name: Setup python
+ uses: actions/setup-python@v2
+ with:
+ python-version: 3.9
+
+ - name: Install build dependencies
+ uses: ./.github/actions/install-build-dependencies
+
+ - name: Build pip package
+ run: make bazel-build
+ env:
+ CC: clang
+ CXX: clang++
+ BAZEL_OPTS: --batch
+ BAZEL_BUILD_OPTS: --config=ci
+
+ - name: Build address sanitized LLVM compiler service
+ run: make bazel-build BAZEL_BUILD_OPTS=--config=asan BUILD_TARGET=//compiler_gym/envs/llvm/service:compiler_gym-llvm-service
+ env:
+ CC: clang
+ CXX: clang++
+ BAZEL_OPTS: --batch
+ BAZEL_BUILD_OPTS: --config=ci
+
+ - name: Install pip package
+ run: make pip-install
+
+ - name: Test
+ run: make install-test TEST_TARGET=tests/llvm
+ env:
+ ASAN_OPTIONS: detect_leaks=1
+ CC: clang
+ CXX: clang++
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 6e6e3f75b..249ab6d03 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -25,9 +25,9 @@
#
# $ pre-commit run --all-files
#
-# Install the pre-commit hook using:
+# The pre-commit git hook is installed using:
#
-# $ pre-commit install
+# $ make init
#
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
diff --git a/BUILD.bazel b/BUILD.bazel
index 4558ba5cc..4ba330cd7 100644
--- a/BUILD.bazel
+++ b/BUILD.bazel
@@ -23,6 +23,7 @@ py_library(
"//compiler_gym/datasets",
"//compiler_gym/envs",
"//compiler_gym/service",
+ "//compiler_gym/service/runtime",
"//compiler_gym/spaces",
"//compiler_gym/views",
"//examples/sensitivity_analysis:action_sensitivity_analysis",
diff --git a/CHANGELOG.md b/CHANGELOG.md
index 73776f516..ca9117607 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,3 +1,39 @@
+## Release 0.1.9 (2021-06-03)
+
+This release of CompilerGym focuses on **backend extensibility** and adds a
+bunch of new features to make it easier to add support for new compilers:
+
+- Adds a new `CompilationSession` class that encapsulates a single incremental
+ compilation session
+ ([#261](https://github.com/facebookresearch/CompilerGym/pull/261)).
+- Adds a common runtime for CompilerGym services that takes a
+ `CompilationSession` subclass and handles all the RPC wrangling for you
+ ([#270](https://github.com/facebookresearch/CompilerGym/pull/270)).
+- Ports the LLVM service and example services to the new runtime
+ ([#277](https://github.com/facebookresearch/CompilerGym/pull/277)). This
+ provides a net performance win with fewer lines of code.
+
+Other highlights of this release include:
+
+- [Core API] Adds a new `compiler_gym.wrappers` module that makes it easy to
+ apply modular transformations to CompilerGym environments without modifying
+ the environment code
+ ([#272](https://github.com/facebookresearch/CompilerGym/pull/272)).
+- [Core API] Adds a new `Datasets.random_benchmark()` method for selecting a
+ uniform random benchmark from one or more datasets
+ ([#247](https://github.com/facebookresearch/CompilerGym/pull/247)).
+- [Core API] Adds a new `compiler_gym.make()` function, equivalent to
+ `gym.make()`
+ ([#257](https://github.com/facebookresearch/CompilerGym/pull/257)).
+- [LLVM] Adds a new `IrSha1` observation space that uses a fast, service-side
+ C++ implementation to compute a checksum of the environment state
+ ([#267](https://github.com/facebookresearch/CompilerGym/pull/267)).
+- [LLVM] Adds 12 new C programs from the CHStone benchmark suite
+ ([#284](https://github.com/facebookresearch/CompilerGym/pull/284)).
+- [LLVM] Adds the `anghabench-v1` dataset and deprecates `anghabench-v0`
+ ([#242](https://github.com/facebookresearch/CompilerGym/pull/242)).
+- Numerous bug fixes and improvements.
+
## Release 0.1.8 (2021-04-30)
This release introduces some significant changes to the way that benchmarks are
diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
index 0d327c510..8ead4ef3e 100644
--- a/CONTRIBUTING.md
+++ b/CONTRIBUTING.md
@@ -41,7 +41,7 @@ We actively welcome your pull requests.
1. Fork [the repo](https://github.com/facebookresearch/CompilerGym) and create
your branch from `development`.
2. Follow the instructions for
- [building from source](https://github.com/facebookresearch/CompilerGym#building-from-source)
+ [building from source](https://github.com/facebookresearch/CompilerGym/blob/development/INSTALL.md)
to set up your environment.
3. If you've added code that should be tested, add tests.
4. If you've changed APIs, update the [documentation](/docs/source).
diff --git a/INSTALL.md b/INSTALL.md
new file mode 100644
index 000000000..875bff712
--- /dev/null
+++ b/INSTALL.md
@@ -0,0 +1,84 @@
+# Installation
+
+Install the latest CompilerGym release using:
+
+ pip install -U compiler_gym
+
+CompilerGym requires Python >= 3.6. The binary works on macOS and Linux (on
+Ubuntu 18.04, Fedora 28, Debian 10 or newer equivalents).
+
+## Building from Source
+
+If you prefer, you may build from source. This requires a modern C++ toolchain
+and bazel.
+
+### macOS
+
+On macOS the required dependencies can be installed using
+[homebrew](https://docs.brew.sh/Installation):
+
+```sh
+brew install bazelisk zlib
+export LDFLAGS="-L/usr/local/opt/zlib/lib"
+export CPPFLAGS="-I/usr/local/opt/zlib/include"
+export PKG_CONFIG_PATH="/usr/local/opt/zlib/lib/pkgconfig"
+```
+
+Now proceed to [All platforms](#all-platforms) below.
+
+### Linux
+
+On debian-based linux systems, install the required toolchain using:
+
+```sh
+sudo apt install clang-9 libtinfo5 libjpeg-dev zlib1g-dev
+wget https://github.com/bazelbuild/bazelisk/releases/download/v1.7.5/bazelisk-linux-amd64 -O bazel
+chmod +x bazel && mkdir -p ~/.local/bin && mv -v bazel ~/.local/bin
+export PATH="$HOME/.local/bin:$PATH"
+export CC=clang
+export CXX=clang++
+```
+
+### All platforms
+
+We recommend using
+[conda](https://docs.conda.io/projects/conda/en/latest/user-guide/install/)
+to manage the remaining build dependencies. First create a conda environment
+with the required dependencies:
+
+ conda create -n compiler_gym python=3.9 cmake pandoc patchelf
+ conda activate compiler_gym
+ conda install -c conda-forge doxygen
+
+Then clone the CompilerGym source code using:
+
+ git clone https://github.com/facebookresearch/CompilerGym.git
+ cd CompilerGym
+
+There are two primary git branches: `stable` tracks the latest release;
+`development` is for bleeding edge features that may not yet be mature. Checkout
+your preferred branch and install the python development dependencies using:
+
+ git checkout stable
+ make init
+
+The `make init` target only needs to be run once on initial setup, or when
+pulling remote changes to the CompilerGym repository.
+
+Run the test suite to confirm that everything is working:
+
+ make test
+
+To build and install the `compiler_gym` python package, run:
+
+ make install
+
+**NOTE:** To use the `compiler_gym` package that is installed by `make install`
+you must leave the root directory of this repository. Attempting to import
+`compiler_gym` while in the root of this repository will cause import errors.
+
+When you are finished, you can deactivate and delete the conda
+environment using:
+
+ conda deactivate
+ conda env remove -n compiler_gym
diff --git a/Makefile b/Makefile
index 9d852972d..15e3ec046 100644
--- a/Makefile
+++ b/Makefile
@@ -112,6 +112,7 @@ export HELP
CC ?= clang
CXX ?= clang++
BAZEL ?= bazel
+DOXYGEN ?= doxygen
IBAZEL ?= ibazel
PANDOC ?= pandoc
PYTHON ?= python3
@@ -141,6 +142,7 @@ help:
init:
$(PYTHON) -m pip install -r requirements.txt
+ pre-commit install
############
@@ -150,8 +152,10 @@ init:
# Files and directories generated by python disttools.
DISTTOOLS_OUTS := dist build compiler_gym.egg-info
+BUILD_TARGET ?= //:package
+
bazel-build:
- $(BAZEL) $(BAZEL_OPTS) build $(BAZEL_BUILD_OPTS) //:package
+ $(BAZEL) $(BAZEL_OPTS) build $(BAZEL_BUILD_OPTS) $(BUILD_TARGET)
bdist_wheel: bazel-build
$(PYTHON) setup.py bdist_wheel
@@ -162,20 +166,21 @@ bdist_wheel-linux-rename:
bdist_wheel-linux:
rm -rf build
docker build -t chriscummins/compiler_gym-linux-build packaging
- docker run -v $(ROOT):/CompilerGym --rm chriscummins/compiler_gym-linux-build:latest /bin/sh -c 'cd /CompilerGym && pip3 install gym numpy requests networkx && make bdist_wheel'
+ docker run -v $(ROOT):/CompilerGym --workdir /CompilerGym --rm --shm-size=8g chriscummins/compiler_gym-linux-build:latest /bin/sh -c './packaging/container_init.sh && make bdist_wheel'
mv dist/compiler_gym-$(VERSION)-py3-none-linux_x86_64.whl dist/compiler_gym-$(VERSION)-py3-none-manylinux2014_x86_64.whl
rm -rf build
bdist_wheel-linux-shell:
- docker run -v $(ROOT):/CompilerGym --rm -it --entrypoint "/bin/bash" chriscummins/compiler_gym-linux-build:latest
+ docker run -v $(ROOT):/CompilerGym --workdir /CompilerGym --rm --shm-size=8g -it --entrypoint "/bin/bash" chriscummins/compiler_gym-linux-build:latest
bdist_wheel-linux-test:
- docker run -v $(ROOT):/CompilerGym --rm chriscummins/compiler_gym-linux-build:latest /bin/sh -c 'cd /CompilerGym && pip3 install -U pip && pip3 install dist/compiler_gym-$(VERSION)-py3-none-manylinux2014_x86_64.whl && pip install -r tests/requirements.txt && make install-test'
+ docker run -v $(ROOT):/CompilerGym --workdir /CompilerGym --rm --shm-size=8g chriscummins/compiler_gym-linux-build:latest /bin/sh -c 'cd /CompilerGym && pip3 install -U pip && pip3 install dist/compiler_gym-$(VERSION)-py3-none-manylinux2014_x86_64.whl && pip install -r tests/requirements.txt && make install-test'
all: docs bdist_wheel bdist_wheel-linux
.PHONY: bazel-build bdist_wheel bdist_wheel-linux bdist_wheel-linux-shell bdist_wheel-linux-test
+
#################
# Documentation #
#################
@@ -189,25 +194,29 @@ docs/source/contributing.rst: CONTRIBUTING.md
echo "..\n Generated from $<. Do not edit!\n" > $@
$(PANDOC) --from=markdown --to=rst $< >> $@
-docs/source/installation.rst: README.md
- echo "..\n Generated from $<. Do not edit!\n" > $@
- sed -n '/^## Installation/,$$p' $< | sed -n '/^### Building/q;p' | $(PANDOC) --from=markdown --to=rst >> $@
-
GENERATED_DOCS := \
docs/source/changelog.rst \
docs/source/contributing.rst \
- docs/source/installation.rst \
$(NULL)
gendocs: $(GENERATED_DOCS)
-docs: gendocs bazel-build
+doxygen:
+ cd docs && $(DOXYGEN) Doxyfile
+
+doxygen-rst:
+ cd docs && $(PYTHON) generate_cc_rst.py
+
+docs: gendocs bazel-build doxygen
PYTHONPATH=$(ROOT)/bazel-bin/package.runfiles/CompilerGym $(MAKE) -C docs html
-livedocs: gendocs
+livedocs: gendocs doxygen
PYTHONPATH=$(ROOT)/bazel-bin/package.runfiles/CompilerGym $(MAKE) -C docs livehtml
+.PHONY: doxygen doxygen-rst
+
+
###########
# Testing #
###########
@@ -215,40 +224,54 @@ livedocs: gendocs
COMPILER_GYM_SITE_DATA ?= "/tmp/compiler_gym/tests/site_data"
COMPILER_GYM_CACHE ?= "/tmp/compiler_gym/tests/cache"
+# A directory that is used as the working directory for running pytest tests
+# by symlinking the tests directory into it.
+INSTALL_TEST_ROOT ?= "/tmp/compiler_gym/install_tests"
+
+# The target to use. If not provided, all tests will be run. For `make test` and
+# related, this is a bazel target pattern, with default value '//...'. For `make
+# install-test` and related, this is a relative file path of the directory or
+# file to test, with default value 'tests'.
+TEST_TARGET ?=
+
+# Extra command line arguments for pytest.
+PYTEST_ARGS ?=
+
test:
- $(BAZEL) $(BAZEL_OPTS) test $(BAZEL_TEST_OPTS) //...
+ $(BAZEL) $(BAZEL_OPTS) test $(BAZEL_TEST_OPTS) $(if $(TEST_TARGET),$(TEST_TARGET),//...)
itest:
- $(IBAZEL) $(BAZEL_OPTS) test $(BAZEL_TEST_OPTS) //...
-
+ $(IBAZEL) $(BAZEL_OPTS) test $(BAZEL_TEST_OPTS) $(if $(TEST_TARGET),$(TEST_TARGET),//...)
# Since we can't run compiler_gym from the project root we need to jump through
# some hoops to run pytest "out of tree" by creating an empty directory and
# symlinking the test directory into it so that pytest can be invoked.
-define run_pytest_suite
- mkdir -p /tmp/compiler_gym/wheel_tests
- rm -f /tmp/compiler_gym/wheel_tests/tests /tmp/compiler_gym/wheel_tests/tox.ini
- ln -s $(ROOT)/tests /tmp/compiler_gym/wheel_tests
- ln -s $(ROOT)/tox.ini /tmp/compiler_gym/wheel_tests
- cd /tmp/compiler_gym/wheel_tests && pytest tests $(1) --benchmark-disable -n auto -k "not fuzz"
+install-test-setup:
+ mkdir -p "$(INSTALL_TEST_ROOT)"
+ rm -f "$(INSTALL_TEST_ROOT)/tests" "$(INSTALL_TEST_ROOT)/tox.ini"
+ ln -s "$(ROOT)/tests" "$(INSTALL_TEST_ROOT)"
+ ln -s "$(ROOT)/tox.ini" "$(INSTALL_TEST_ROOT)"
+
+define pytest
+ cd "$(INSTALL_TEST_ROOT)" && pytest $(if $(TEST_TARGET),$(TEST_TARGET),tests) $(1) $(PYTEST_ARGS)
endef
-install-test:
- $(call run_pytest_suite,)
+install-test: install-test-setup
+ $(call pytest,--benchmark-disable -n auto -k "not fuzz" --durations=5)
-install-test-cov:
- $(call run_pytest_suite,--cov=compiler_gym --cov-report=xml)
- @mv /tmp/compiler_gym/wheel_tests/coverage.xml .
+# Note we export $CI=1 so that the tests always run as if within the CI
+# environment. This is to ensure that the reported coverage matches that of
+# the value on: https://codecov.io/gh/facebookresearch/CompilerGym
+install-test-cov: install-test-setup
+ export CI=1; $(call pytest,--benchmark-disable -n auto -k "not fuzz" --durations=5 --cov=compiler_gym --cov-report=xml --cov-report=term)
+ @mv "$(INSTALL_TEST_ROOT)/coverage.xml" .
# The minimum number of seconds to run the fuzz tests in a loop for. Override
# this at the commandline, e.g. `FUZZ_SECONDS=1800 make fuzz`.
FUZZ_SECONDS ?= 300
-install-fuzz:
- mkdir -p /tmp/compiler_gym/wheel_fuzz_tests
- rm -f /tmp/compiler_gym/wheel_fuzz_tests/tests
- ln -s $(ROOT)/tests /tmp/compiler_gym/wheel_fuzz_tests
- cd /tmp/compiler_gym/wheel_fuzz_tests && pytest tests -p no:sugar -x -vv -k fuzz --seconds=$(FUZZ_SECONDS)
+install-fuzz: install-test-setup
+ $(call pytest,-p no:sugar -x -vv -k fuzz --seconds=$(FUZZ_SECONDS))
post-install-test:
$(MAKE) -C examples/makefile_integration clean
@@ -261,10 +284,12 @@ post-install-test:
# Installation #
################
-install: bazel-build
+pip-install:
$(PYTHON) setup.py install
-.PHONY: install
+install: | bazel-build pip-install
+
+.PHONY: pip-install install
##############
diff --git a/README.md b/README.md
index f6ee38e96..25b184611 100644
--- a/README.md
+++ b/README.md
@@ -1,93 +1,67 @@

----
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-CompilerGym is a toolkit for exposing compiler optimization problems
-for reinforcement learning. It allows machine learning researchers to
-experiment with program optimization techniques without requiring any
-experience in compilers, and provides a framework for compiler
-developers to expose new optimization problems for AI.
-
-
-**Table of Contents**
-
-- [Features](#features)
-- [Getting Started](#getting-started)
- - [Installation](#installation)
- - [Building from Source](#building-from-source)
- - [Trying it out](#trying-it-out)
-- [Leaderboards](#leaderboards)
- - [LLVM Instruction Count](#llvm-instruction-count)
-- [Contributing](#contributing)
-- [Citation](#citation)
-
-
-# Features
-
-With CompilerGym, building ML models for compiler research problems is as easy
-as building ML models to play video games. Here are some highlights of key
-features:
-
-* **API:** uses the popular [Gym](https://gym.openai.com/) interface from OpenAI
- — use Python to write your agent.
-
-* **Datasets:** wraps real world programs (C++ programs, TensorFlow programs,
- programs from Github, etc.) and a mainstream compiler
- ([LLVM](https://llvm.org/)), providing millions of programs for training.
-
-* **Tasks and Actions:** interfaces the [LLVM](https://llvm.org/) compiler for
- one compiler research problem: phase ordering (more to come). It has a large
- discrete action space.
-
-* **Representations:** provides raw representations of programs, as well as
- multiple kinds of pre-computed features: you can focus on end-to-end deep
- learning or features + boosted trees, all the way up to graph models.
-
-* **Rewards:** provides appropriate reward functions and loss functions out of
- the box.
-
-* **Testing:** provides a validation process for correctness of results.
-
-* **Baselines:** provides some baselines and reports their performance.
-
-* **Competition:** provides [leaderboards](#leaderboards) for you to submit your
- results.
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+ Reinforcement learning environments for compiler optimization tasks.
+
+
+
+ Check
+ the website
+ for more information.
+
+
+
+
+## Introduction
+
+CompilerGym is a library of easy to use and performant reinforcement learning
+environments for compiler tasks. It allows ML researchers to interact with
+important compiler optimization problems in a language and vocabulary with which
+they are comfortable, and provides a toolkit for systems developers to expose
+new compiler tasks for ML research. We aim to act as a catalyst for making
+compilers faster using ML. Key features include:
+
+* **Ease of use:** built on the popular [Gym](https://gym.openai.com/)
+ interface - use Python to write your agent. With CompilerGym, building ML
+ models for compiler research problems is as easy as building ML models to play
+ video games.
+
+* **Batteries included:** includes everything required to get started. Wraps
+ real world programs and compilers to provide millions of instances for
+ training. Provides multiple kinds of pre-computed program representations: you
+ can focus on end-to-end deep learning or features + boosted trees, all the way
+ up to graph models. Appropriate reward functions and loss functions for
+ optimization targets are provided out of the box.
+
+* **Reproducible:** provides validation for correctness of results, common
+ baselines, and [leaderboards](#leaderboards) for you to submit your results.
For a glimpse of what's to come, check out [our
roadmap](https://github.com/facebookresearch/CompilerGym/projects/1).
-# Getting Started
-
-Starting with CompilerGym is simple. If you not already familiar with the gym
-interface, refer to the
-[getting started guide](http://facebookresearch.github.io/CompilerGym/getting_started.html)
-for an overview of the key concepts.
-
## Installation
@@ -95,115 +69,45 @@ Install the latest CompilerGym release using:
pip install -U compiler_gym
-The binary works on macOS and Linux (on Ubuntu 18.04, Fedora 28, Debian 10 or
-newer equivalents).
-
-### Building from Source
-
-If you prefer, you may build from source. This requires a modern C++ toolchain
-and bazel.
-
-#### macOS
-
-On macOS the required dependencies can be installed using
-[homebrew](https://docs.brew.sh/Installation):
-
-```sh
-brew install bazelisk zlib
-export LDFLAGS="-L/usr/local/opt/zlib/lib"
-export CPPFLAGS="-I/usr/local/opt/zlib/include"
-export PKG_CONFIG_PATH="/usr/local/opt/zlib/lib/pkgconfig"
-```
-
-Now proceed to [All platforms](#all-platforms) below.
-
-#### Linux
-
-On debian-based linux systems, install the required toolchain using:
-
-```sh
-sudo apt install clang-9 libtinfo5 libjpeg-dev patchelf
-wget https://github.com/bazelbuild/bazelisk/releases/download/v1.7.5/bazelisk-linux-amd64 -O bazel
-chmod +x bazel && mkdir -p ~/.local/bin && mv -v bazel ~/.local/bin
-export PATH="$HOME/.local/bin:$PATH"
-export CC=clang
-export CXX=clang++
-```
-
-#### All platforms
+See [INSTALL.md](INSTALL.md) for further details.
-We recommend using
-[conda](https://docs.conda.io/projects/conda/en/latest/user-guide/install/)
-to manage the remaining build dependencies. First create a conda environment
-with the required dependencies:
- conda create -n compiler_gym python=3.9 cmake pandoc
- conda activate compiler_gym
+## Usage
-Then clone the CompilerGym source code using:
-
- git clone https://github.com/facebookresearch/CompilerGym.git
- cd CompilerGym
-
-There are two primary git branches: `stable` tracks the latest release;
-`development` is for bleeding edge features that may not yet be mature. Checkout
-your preferred branch and install the python development dependencies using:
-
- git checkout stable
- make init
-
-The `make init` target only needs to be run once on initial setup, or when
-pulling remote changes to the CompilerGym repository.
-
-Run the test suite to confirm that everything is working:
-
- make test
-
-To build and install the `compiler_gym` python package, run:
-
- make install
-
-**NOTE:** To use the `compiler_gym` package that is installed by `make install`
-you must leave the root directory of this repository. Attempting to import
-`compiler_gym` while in the root of this repository will cause import errors.
-
-When you are finished, you can deactivate and delete the conda
-environment using:
-
- conda deactivate
- conda env remove -n compiler_gym
-
-
-## Trying it out
+Starting with CompilerGym is simple. If you are not already familiar with the gym
+interface, refer to the [getting started
+guide](http://facebookresearch.github.io/CompilerGym/getting_started.html) for
+an overview of the key concepts.
In Python, import `compiler_gym` to use the environments:
```py
>>> import gym
->>> import compiler_gym # imports the CompilerGym environments
->>> env = gym.make("llvm-autophase-ic-v0") # starts a new environment
->>> env.benchmark = "benchmark://cbench-v1/qsort" # select a program to compile
->>> env.reset() # starts a new compilation session
->>> env.render() # prints the IR of the program
->>> env.step(env.action_space.sample()) # applies a random optimization, updates state/reward/actions
+>>> import compiler_gym # imports the CompilerGym environments
+>>> env = gym.make( # creates a new environment
+... "llvm-v0", # selects the compiler to use
+... benchmark="cbench-v1/qsort", # selects the program to compile
+... observation_space="Autophase", # selects the observation space
+... reward_space="IrInstructionCountOz", # selects the optimization target
+... )
+>>> env.reset() # starts a new compilation session
+>>> env.render() # prints the IR of the program
+>>> env.step(env.action_space.sample()) # applies a random optimization, updates state/reward/actions
```
See the [documentation website](http://facebookresearch.github.io/CompilerGym/)
-for tutorials, further details, and API reference. Our
-[roadmap](https://facebookresearch.github.io/CompilerGym/about.html#roadmap) of
-planned features is public, and the
-[changelog](https://github.com/facebookresearch/CompilerGym/blob/development/CHANGELOG.md)
-summarizes shipped features.
+for tutorials, further details, and API reference. See the [examples](/examples)
+directory for pytorch integration, agent implementations, etc.
-# Leaderboards
+## Leaderboards
These leaderboards track the performance of user-submitted algorithms for
CompilerGym tasks. To submit a result please see
[this document](https://github.com/facebookresearch/CompilerGym/blob/development/CONTRIBUTING.md#leaderboard-submissions).
-## LLVM Instruction Count
+### LLVM Instruction Count
LLVM is a popular open source compiler used widely in industry and research. The
`llvm-ic-v0` environment exposes LLVM's optimizing passes as a set of actions
@@ -228,13 +132,13 @@ environment on the 23 benchmarks in the `cbench-v1` dataset.
| Jiadong Guo | Tabular Q (N=2000, H=5) | [write-up](leaderboard/llvm_instcount/tabular_q/README.md), [results](leaderboard/llvm_instcount/tabular_q/results-H5-N2000.csv) | 2021-04 | 694.105 | 0.988× |
-# Contributing
+## Contributing
We welcome contributions to CompilerGym. If you are interested in contributing please see
[this document](https://github.com/facebookresearch/CompilerGym/blob/development/CONTRIBUTING.md).
-# Citation
+## Citation
If you use CompilerGym in any of your work, please cite:
diff --git a/VERSION b/VERSION
index 699c6c6d4..1a030947e 100644
--- a/VERSION
+++ b/VERSION
@@ -1 +1 @@
-0.1.8
+0.1.9
diff --git a/WORKSPACE b/WORKSPACE
index 0e113c78a..fe6c01526 100644
--- a/WORKSPACE
+++ b/WORKSPACE
@@ -339,9 +339,9 @@ boost_deps()
http_archive(
name = "programl",
- sha256 = "5a8fc6f71f1971b265cfe0c4b224b45e87ef08e40b49e693a96022063345a9c8",
- strip_prefix = "ProGraML-0476300f2db3724c7b6ecd10970b4525fccc8628",
- urls = ["https://github.com/ChrisCummins/ProGraML/archive/0476300f2db3724c7b6ecd10970b4525fccc8628.tar.gz"],
+ sha256 = "c56360aade351eda1c138a594177fcb7cd2cda2a0a6c5c0d9aa62c7f856194bd",
+ strip_prefix = "ProGraML-4f0981d7a0d27aecef3d6e918c886642b231562d",
+ urls = ["https://github.com/ChrisCummins/ProGraML/archive/4f0981d7a0d27aecef3d6e918c886642b231562d.tar.gz"],
)
load("@programl//tools:bzl/deps.bzl", "programl_deps")
diff --git a/compiler_gym/BUILD b/compiler_gym/BUILD
index 8b3da5205..059223b29 100644
--- a/compiler_gym/BUILD
+++ b/compiler_gym/BUILD
@@ -17,6 +17,7 @@ py_library(
"//compiler_gym/envs",
"//compiler_gym/leaderboard",
"//compiler_gym/util",
+ "//compiler_gym/wrappers",
],
)
diff --git a/compiler_gym/__init__.py b/compiler_gym/__init__.py
index 0110aa9fd..087f344e3 100644
--- a/compiler_gym/__init__.py
+++ b/compiler_gym/__init__.py
@@ -34,7 +34,7 @@
CompilerEnvStateReader,
CompilerEnvStateWriter,
)
-from compiler_gym.envs import COMPILER_GYM_ENVS, CompilerEnv, observation_t, step_t
+from compiler_gym.envs import COMPILER_GYM_ENVS, CompilerEnv
from compiler_gym.random_search import random_search
from compiler_gym.util.debug_util import (
get_debug_level,
@@ -42,6 +42,7 @@
set_debug_level,
)
from compiler_gym.util.download import download
+from compiler_gym.util.registration import make
from compiler_gym.util.runfiles_path import (
cache_path,
site_data_path,
@@ -56,6 +57,7 @@
"__version__",
"cache_path",
"COMPILER_GYM_ENVS",
+ "make",
"CompilerEnv",
"CompilerEnvState",
"CompilerEnvStateWriter",
@@ -63,11 +65,9 @@
"download",
"get_debug_level",
"get_logging_level",
- "observation_t",
"random_search",
"set_debug_level",
"site_data_path",
- "step_t",
"transient_cache_path",
"validate_states",
"ValidationError",
diff --git a/compiler_gym/bin/manual_env.py b/compiler_gym/bin/manual_env.py
index 45df0b0a1..babe31ebc 100644
--- a/compiler_gym/bin/manual_env.py
+++ b/compiler_gym/bin/manual_env.py
@@ -367,7 +367,7 @@ def do_set_benchmark(self, arg):
Use '-' for a random benchmark.
"""
if arg == "-":
- arg = self.env.datasets.benchmark().uri
+ arg = self.env.datasets.random_benchmark().uri
print(f"set_benchmark {arg}")
try:
diff --git a/compiler_gym/datasets/benchmark.py b/compiler_gym/datasets/benchmark.py
index 7a50c21c2..e3b93ab00 100644
--- a/compiler_gym/datasets/benchmark.py
+++ b/compiler_gym/datasets/benchmark.py
@@ -38,7 +38,7 @@ def __repr__(self) -> str:
return str(self.filename)
-class Benchmark(object):
+class Benchmark:
"""A benchmark represents a particular program that is being compiled.
A benchmark is a program that can be used by a :class:`CompilerEnv
@@ -98,6 +98,9 @@ def __init__(
def __repr__(self) -> str:
return str(self.uri)
+ def __hash__(self) -> int:
+ return hash(self.uri)
+
@property
def uri(self) -> str:
"""The URI of the benchmark.
diff --git a/compiler_gym/datasets/dataset.py b/compiler_gym/datasets/dataset.py
index 2045a42b5..0e034a0c9 100644
--- a/compiler_gym/datasets/dataset.py
+++ b/compiler_gym/datasets/dataset.py
@@ -9,6 +9,7 @@
from pathlib import Path
from typing import Dict, Iterable, Optional, Union
+import numpy as np
from deprecated.sphinx import deprecated as mark_deprecated
from compiler_gym.datasets.benchmark import Benchmark
@@ -16,7 +17,7 @@
from compiler_gym.util.debug_util import get_logging_level
-class Dataset(object):
+class Dataset:
"""A dataset is a collection of benchmarks.
The Dataset class has methods for installing and managing groups of
@@ -358,6 +359,29 @@ def benchmark(self, uri: str) -> Benchmark:
"""
raise NotImplementedError("abstract class")
+ def random_benchmark(
+ self, random_state: Optional[np.random.Generator] = None
+ ) -> Benchmark:
+ """Select a benchmark randomly.
+
+ :param random_state: A random number generator. If not provided, a
+ default :code:`np.random.default_rng()` is used.
+
+ :return: A :class:`Benchmark `
+ instance.
+ """
+ random_state = random_state or np.random.default_rng()
+ return self._random_benchmark(random_state)
+
+ def _random_benchmark(self, random_state: np.random.Generator) -> Benchmark:
+ """Private implementation of the random benchmark getter.
+
+ Subclasses must implement this method so that it selects a benchmark
+ from the available benchmarks with uniform probability, using only
+ :code:`random_state` as a source of randomness.
+ """
+ raise NotImplementedError("abstract class")
+
def __getitem__(self, uri: str) -> Benchmark:
"""Select a benchmark by URI.
diff --git a/compiler_gym/datasets/datasets.py b/compiler_gym/datasets/datasets.py
index adc97818e..fd6c285e0 100644
--- a/compiler_gym/datasets/datasets.py
+++ b/compiler_gym/datasets/datasets.py
@@ -3,7 +3,9 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from collections import deque
-from typing import Dict, Iterable, Set, TypeVar
+from typing import Dict, Iterable, Optional, Set, TypeVar
+
+import numpy as np
from compiler_gym.datasets.benchmark import Benchmark
from compiler_gym.datasets.dataset import Dataset
@@ -31,7 +33,7 @@ def round_robin_iterables(iters: Iterable[Iterable[T]]) -> Iterable[T]:
yield from iters.popleft()
-class Datasets(object):
+class Datasets:
"""A collection of datasets.
This class provides a dictionary-like interface for indexing and iterating
@@ -251,6 +253,42 @@ def benchmark(self, uri: str) -> Benchmark:
return dataset.benchmark(uri)
+ def random_benchmark(
+ self, random_state: Optional[np.random.Generator] = None
+ ) -> Benchmark:
+ """Select a benchmark randomly.
+
+ First, a dataset is selected uniformly randomly using
+ :code:`random_state.choice(list(datasets))`. The
+ :meth:`random_benchmark()
+ ` method of that dataset
+ is then called to select a benchmark.
+
+ Note that the distribution of benchmarks selected by this method is not
+ biased by the size of each dataset, since datasets are selected
+ uniformly. This means that datasets with a small number of benchmarks
+ will be overrepresented compared to datasets with many benchmarks. To
+ correct for this bias, use the number of benchmarks in each dataset as
+ a weight for the random selection:
+
+ >>> rng = np.random.default_rng()
+ >>> finite_datasets = [d for d in env.datasets if len(d) != math.inf]
+ >>> dataset = rng.choice(
+ finite_datasets,
+ p=[len(d) for d in finite_datasets]
+ )
+ >>> dataset.random_benchmark(random_state=rng)
+
+ :param random_state: A random number generator. If not provided, a
+ default :code:`np.random.default_rng()` is used.
+
+ :return: A :class:`Benchmark `
+ instance.
+ """
+ random_state = random_state or np.random.default_rng()
+ dataset = random_state.choice(list(self._visible_datasets))
+ return self[dataset].random_benchmark(random_state=random_state)
+
@property
def size(self) -> int:
return len(self._visible_datasets)
diff --git a/compiler_gym/datasets/files_dataset.py b/compiler_gym/datasets/files_dataset.py
index 4c522912c..5b4e31883 100644
--- a/compiler_gym/datasets/files_dataset.py
+++ b/compiler_gym/datasets/files_dataset.py
@@ -6,6 +6,8 @@
from pathlib import Path
from typing import Iterable, List
+import numpy as np
+
from compiler_gym.datasets.dataset import Benchmark, Dataset
from compiler_gym.util.decorators import memoized_property
@@ -117,3 +119,6 @@ def benchmark(self, uri: str) -> Benchmark:
if not abspath.is_file():
raise LookupError(f"Benchmark not found: {uri} (file not found: {abspath})")
return self.benchmark_class.from_file(uri, abspath)
+
+ def _random_benchmark(self, random_state: np.random.Generator) -> Benchmark:
+ return self.benchmark(random_state.choice(list(self.benchmark_uris())))
diff --git a/compiler_gym/datasets/tar_dataset.py b/compiler_gym/datasets/tar_dataset.py
index 55b15c73b..632ce935e 100644
--- a/compiler_gym/datasets/tar_dataset.py
+++ b/compiler_gym/datasets/tar_dataset.py
@@ -60,17 +60,13 @@ def __init__(
self.tar_compression = tar_compression
self.strip_prefix = strip_prefix
- self._installed = False
self._tar_extracted_marker = self.site_data_path / ".extracted"
self._tar_lock = Lock()
self._tar_lockfile = self.site_data_path / ".install_lock"
@property
def installed(self) -> bool:
- # Fast path for repeated checks to 'installed' without a disk op.
- if not self._installed:
- self._installed = self._tar_extracted_marker.is_file()
- return self._installed
+ return self._tar_extracted_marker.is_file()
def install(self) -> None:
super().install()
diff --git a/compiler_gym/envs/__init__.py b/compiler_gym/envs/__init__.py
index 09eeb255d..3b683d861 100644
--- a/compiler_gym/envs/__init__.py
+++ b/compiler_gym/envs/__init__.py
@@ -2,15 +2,12 @@
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
-from compiler_gym.envs.compiler_env import CompilerEnv, info_t, observation_t, step_t
+from compiler_gym.envs.compiler_env import CompilerEnv
from compiler_gym.envs.llvm.llvm_env import LlvmEnv
from compiler_gym.util.registration import COMPILER_GYM_ENVS
__all__ = [
"CompilerEnv",
"LlvmEnv",
- "observation_t",
- "info_t",
- "step_t",
"COMPILER_GYM_ENVS",
]
diff --git a/compiler_gym/envs/compiler_env.py b/compiler_gym/envs/compiler_env.py
index 9db13b4f3..d819512f4 100644
--- a/compiler_gym/envs/compiler_env.py
+++ b/compiler_gym/envs/compiler_env.py
@@ -11,7 +11,7 @@
from math import isclose
from pathlib import Path
from time import time
-from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
+from typing import Callable, Dict, Iterable, List, Optional, Tuple, Union
import gym
import numpy as np
@@ -27,9 +27,9 @@
ServiceOSError,
ServiceTransportError,
SessionNotFound,
- observation_t,
)
from compiler_gym.service.proto import (
+ Action,
AddBenchmarkRequest,
EndSessionReply,
EndSessionRequest,
@@ -43,15 +43,17 @@
)
from compiler_gym.spaces import DefaultRewardFromObservation, NamedDiscrete, Reward
from compiler_gym.util.debug_util import get_logging_level
+from compiler_gym.util.gym_type_hints import (
+ ActionType,
+ ObservationType,
+ RewardType,
+ StepType,
+)
from compiler_gym.util.timer import Timer
from compiler_gym.validation_error import ValidationError
from compiler_gym.validation_result import ValidationResult
from compiler_gym.views import ObservationSpaceSpec, ObservationView, RewardView
-# Type hints.
-info_t = Dict[str, Any]
-step_t = Tuple[Optional[observation_t], Optional[float], bool, info_t]
-
def _wrapped_step(
service: CompilerGymServiceConnection, request: StepRequest
@@ -78,11 +80,11 @@ class CompilerEnv(gym.Env):
:doc:`/compiler_gym/service` for more details on connecting to services):
>>> env = CompilerEnv(
- service="localhost:8080",
- observation_space="features",
- reward_space="runtime",
- rewards=[env_reward_spaces],
- )
+ ... service="localhost:8080",
+ ... observation_space="features",
+ ... reward_space="runtime",
+ ... rewards=[env_reward_spaces],
+ ... )
Once constructed, an environment can be used in exactly the same way as a
regular :code:`gym.Env`, e.g.
@@ -265,7 +267,7 @@ def __init__(
for space in self.service.action_spaces
]
self.observation = self._observation_view_type(
- get_observation=lambda req: _wrapped_step(self.service, req),
+ raw_step=self.raw_step,
spaces=self.service.observation_spaces,
)
self.reward = self._reward_view_type(rewards, self.observation)
@@ -522,39 +524,57 @@ def fork(self) -> "CompilerEnv":
:return: A new environment instance.
"""
if not self.in_episode:
- if self.actions and not self.in_episode:
+ actions = self.actions.copy()
+ self.reset()
+ if actions:
self.logger.warning(
"Parent service of fork() has died, replaying state"
)
- self.apply(self.state)
- else:
- self.reset()
+ _, _, done, _ = self.step(actions)
+ assert not done, "Failed to replay action sequence"
request = ForkSessionRequest(session_id=self._session_id)
- reply: ForkSessionReply = self.service(self.service.stub.ForkSession, request)
-
- # Create a new environment that shares the connection.
- new_env = type(self)(
- service=self._service_endpoint,
- action_space=self.action_space,
- connection_settings=self._connection_settings,
- service_connection=self.service,
- )
-
- # Set the session ID.
- new_env._session_id = reply.session_id # pylint: disable=protected-access
- new_env.observation.session_id = reply.session_id
+ try:
+ reply: ForkSessionReply = self.service(
+ self.service.stub.ForkSession, request
+ )
- # Now that we have initialized the environment with the current state,
- # set the benchmark so that calls to new_env.reset() will correctly
- # revert the environment to the initial benchmark state.
- #
- # pylint: disable=protected-access
- new_env._next_benchmark = self._benchmark_in_use
+ # Create a new environment that shares the connection.
+ new_env = type(self)(
+ service=self._service_endpoint,
+ action_space=self.action_space,
+ connection_settings=self._connection_settings,
+ service_connection=self.service,
+ )
- # Set the "visible" name of the current benchmark to hide the fact that
- # we loaded from a custom bitcode file.
- new_env._benchmark_in_use = self._benchmark_in_use
+ # Set the session ID.
+ new_env._session_id = reply.session_id # pylint: disable=protected-access
+ new_env.observation.session_id = reply.session_id
+
+ # Now that we have initialized the environment with the current state,
+ # set the benchmark so that calls to new_env.reset() will correctly
+ # revert the environment to the initial benchmark state.
+ #
+ # pylint: disable=protected-access
+ new_env._next_benchmark = self._benchmark_in_use
+
+ # Set the "visible" name of the current benchmark to hide the fact that
+ # we loaded from a custom bitcode file.
+ new_env._benchmark_in_use = self._benchmark_in_use
+ except NotImplementedError:
+ # Fallback implementation. If the compiler service does not support
+ # the Fork() operator then we create a new independent environment
+ # and apply the sequence of actions in the current environment to
+ # replay the state.
+ new_env = type(self)(
+ service=self._service_endpoint,
+ action_space=self.action_space,
+ benchmark=self.benchmark,
+ connection_settings=self._connection_settings,
+ )
+ new_env.reset()
+ _, _, done, _ = new_env.step(self.actions)
+ assert not done, "Failed to replay action sequence in forked environment"
# Create copies of the mutable reward and observation spaces. This
# is required to correctly calculate incremental updates.
@@ -611,8 +631,12 @@ def close(self):
# not kill it.
if reply.remaining_sessions:
close_service = False
- except: # noqa pylint: disable=bare-except
- pass # Don't feel bad, computer, you tried ;-)
+ except Exception as e:
+ self.logger.warning(
+ "Failed to end active compiler session on close(): %s (%s)",
+ e,
+ type(e).__name__,
+ )
self._session_id = None
if self.service and close_service:
@@ -632,7 +656,7 @@ def reset( # pylint: disable=arguments-differ
benchmark: Optional[Union[str, Benchmark]] = None,
action_space: Optional[str] = None,
retry_count: int = 0,
- ) -> Optional[observation_t]:
+ ) -> Optional[ObservationType]:
"""Reset the environment state.
This method must be called before :func:`step()`.
@@ -740,7 +764,7 @@ def reset( # pylint: disable=arguments-differ
self.reward.reset(benchmark=self.benchmark)
if self.reward_space:
- self.episode_reward = 0
+ self.episode_reward = 0.0
if self.observation_space:
if len(reply.observation) != 1:
@@ -751,41 +775,56 @@ def reset( # pylint: disable=arguments-differ
reply.observation[0]
)
- def step(self, action: Union[int, Iterable[int]]) -> step_t:
+ def raw_step(
+ self,
+ actions: Iterable[int],
+ observations: Iterable[ObservationSpaceSpec],
+ rewards: Iterable[Reward],
+ ) -> StepType:
"""Take a step.
- :param action: An action, or a sequence of actions. When multiple
- actions are provided the observation and reward are returned after
- running all of the actions.
+ :param actions: A list of actions to be applied.
- :return: A tuple of observation, reward, done, and info. Observation and
- reward are None if default observation/reward is not set. If done is
- True, observation and reward may also be None (e.g. because the
- service failed).
+ :param observations: A list of observations spaces to compute
+ observations from. These are evaluated after the actions are
+ applied.
+
+ :param rewards: A list of reward spaces to compute rewards from. These
+ are evaluated after the actions are applied.
+
+ :return: A tuple of observations, rewards, done, and info. Observations
+ and rewards are lists.
:raises SessionNotFound: If :meth:`reset()
` has not been called.
+
+ .. warning::
+
+ Prefer :meth:`step() ` to
+ :meth:`raw_step() `.
+ :meth:`step() ` has equivalent
+ functionality, and is less likely to change in the future.
"""
if not self.in_episode:
raise SessionNotFound("Must call reset() before step()")
- actions = action if isinstance(action, IterableType) else [action]
- observation, reward = None, None
# Build the list of observations that must be computed by the backend
- # service to generate the user-requested observation and reward.
- # TODO(cummins): We could de-duplicate this list to improve efficiency
- # when multiple redundant copies of the same observation space are
- # requested.
- observation_indices, observation_spaces = [], []
- if self.observation_space:
- observation_indices.append(self.observation_space_spec.index)
- observation_spaces.append(self.observation_space_spec.id)
- if self.reward_space:
- observation_indices += [
- self.observation.spaces[obs].index
- for obs in self.reward_space.observation_spaces
+ user_observation_spaces: List[ObservationSpaceSpec] = list(observations)
+ reward_spaces: List[Reward] = list(rewards)
+
+ reward_observation_spaces: List[ObservationSpaceSpec] = []
+ for reward_space in reward_spaces:
+ reward_observation_spaces += [
+ self.observation.spaces[obs] for obs in reward_space.observation_spaces
]
- observation_spaces += self.reward_space.observation_spaces
+
+ observations_to_compute: List[ObservationSpaceSpec] = list(
+ set(user_observation_spaces).union(set(reward_observation_spaces))
+ )
+ observation_space_index_map: Dict[ObservationSpaceSpec, int] = {
+ observation_space: i
+ for i, observation_space in enumerate(observations_to_compute)
+ }
# Record the actions.
self.actions += actions
@@ -793,8 +832,10 @@ def step(self, action: Union[int, Iterable[int]]) -> step_t:
# Send the request to the backend service.
request = StepRequest(
session_id=self._session_id,
- action=actions,
- observation_space=observation_indices,
+ action=[Action(action=a) for a in actions],
+ observation_space=[
+ observation_space.index for observation_space in observations_to_compute
+ ],
)
try:
reply = _wrapped_step(self.service, request)
@@ -809,15 +850,20 @@ def step(self, action: Union[int, Iterable[int]]) -> step_t:
# end the current episode and provide some diagnostic information to
# the user through the `info` dict.
self.close()
+
info = {
"error_type": type(e).__name__,
"error_details": str(e),
}
- if self.reward_space:
- reward = self.reward_space.reward_on_error(self.episode_reward)
- if self.observation_space:
- observation = self.observation_space_spec.default_value
- return observation, reward, True, info
+ default_observations = [
+ observation_space.default_value
+ for observation_space in user_observation_spaces
+ ]
+ default_rewards = [
+ float(reward_space.reward_on_error(self.episode_reward))
+ for reward_space in reward_spaces
+ ]
+ return default_observations, default_rewards, True, info
# If the action space has changed, update it.
if reply.HasField("new_action_space"):
@@ -826,32 +872,126 @@ def step(self, action: Union[int, Iterable[int]]) -> step_t:
)
# Translate observations to python representations.
- if len(reply.observation) != len(observation_indices):
+ if len(reply.observation) != len(observations_to_compute):
raise ServiceError(
- f"Requested {observation_indices} observations "
+ f"Requested {len(observations_to_compute)} observations "
f"but received {len(reply.observation)}"
)
- observations = [
- self.observation.spaces[obs].translate(val)
- for obs, val in zip(observation_spaces, reply.observation)
+ computed_observations = [
+ observation_space.translate(value)
+ for observation_space, value in zip(
+ observations_to_compute, reply.observation
+ )
]
- # Pop the requested observation.
- if self.observation_space:
- observation, observations = observations[0], observations[1:]
+ # Get the user-requested observation.
+ observations: List[ObservationType] = [
+ computed_observations[observation_space_index_map[observation_space]]
+ for observation_space in user_observation_spaces
+ ]
- # Compute reward.
- self.reward.previous_action = action
- if self.reward_space:
- reward = self.reward_space.update(action, observations, self.observation)
- self.episode_reward += reward
+ # Update and compute the rewards.
+ rewards: List[RewardType] = []
+ for reward_space in reward_spaces:
+ reward_observations = [
+ computed_observations[
+ observation_space_index_map[
+ self.observation.spaces[observation_space]
+ ]
+ ]
+ for observation_space in reward_space.observation_spaces
+ ]
+ rewards.append(
+ float(
+ reward_space.update(actions, reward_observations, self.observation)
+ )
+ )
info = {
"action_had_no_effect": reply.action_had_no_effect,
"new_action_space": reply.HasField("new_action_space"),
}
- return observation, reward, reply.end_of_session, info
+ return observations, rewards, reply.end_of_session, info
+
+ def step(
+ self,
+ action: Union[ActionType, Iterable[ActionType]],
+ observations: Optional[Iterable[Union[str, ObservationSpaceSpec]]] = None,
+ rewards: Optional[Iterable[Union[str, Reward]]] = None,
+ ) -> StepType:
+ """Take a step.
+
+ :param action: An action, or a sequence of actions. When multiple
+ actions are provided the observation and reward are returned after
+ running all of the actions.
+
+ :param observations: A list of observation spaces to compute
+ observations from. If provided, this changes the :code:`observation`
+ element of the return tuple to be a list of observations from the
+ requested spaces. The default :code:`env.observation_space` is not
+ returned.
+
+ :param rewards: A list of reward spaces to compute rewards from. If
+ provided, this changes the :code:`reward` element of the return
+ tuple to be a list of rewards from the requested spaces. The default
+ :code:`env.reward_space` is not returned.
+
+ :return: A tuple of observation, reward, done, and info. Observation and
+ reward are None if default observation/reward is not set.
+
+ :raises SessionNotFound: If :meth:`reset()
+ ` has not been called.
+ """
+ # Coerce actions into a list.
+ actions = action if isinstance(action, IterableType) else [action]
+
+ # Coerce observation spaces into a list of ObservationSpaceSpec instances.
+ if observations:
+ observation_spaces: List[ObservationSpaceSpec] = [
+ obs
+ if isinstance(obs, ObservationSpaceSpec)
+ else self.observation.spaces[obs]
+ for obs in observations
+ ]
+ elif self.observation_space_spec:
+ observation_spaces: List[ObservationSpaceSpec] = [
+ self.observation_space_spec
+ ]
+ else:
+ observation_spaces: List[ObservationSpaceSpec] = []
+
+ # Coerce reward spaces into a list of Reward instances.
+ if rewards:
+ reward_spaces: List[Reward] = [
+ rew if isinstance(rew, Reward) else self.reward.spaces[rew]
+ for rew in rewards
+ ]
+ elif self.reward_space:
+ reward_spaces: List[Reward] = [self.reward_space]
+ else:
+ reward_spaces: List[Reward] = []
+
+ # Perform the underlying environment step.
+ observation_values, reward_values, done, info = self.raw_step(
+ actions, observation_spaces, reward_spaces
+ )
+
+ # Translate observations lists back to the appropriate types.
+ if observations is None and self.observation_space_spec:
+ observation_values = observation_values[0]
+ elif not observation_spaces:
+ observation_values = None
+
+ # Translate reward lists back to the appropriate types.
+ if rewards is None and self.reward_space:
+ reward_values = reward_values[0]
+ # Update the cumulative episode reward
+ self.episode_reward += reward_values
+ elif not reward_spaces:
+ reward_values = None
+
+ return observation_values, reward_values, done, info
def render(
self,
diff --git a/compiler_gym/envs/llvm/BUILD b/compiler_gym/envs/llvm/BUILD
index e819ac523..de2604763 100644
--- a/compiler_gym/envs/llvm/BUILD
+++ b/compiler_gym/envs/llvm/BUILD
@@ -57,6 +57,7 @@ py_library(
deps = [
"//compiler_gym/service",
"//compiler_gym/spaces",
+ "//compiler_gym/util",
"//compiler_gym/views",
],
)
diff --git a/compiler_gym/envs/llvm/datasets/BUILD b/compiler_gym/envs/llvm/datasets/BUILD
index cd18149a0..b2129ed3c 100644
--- a/compiler_gym/envs/llvm/datasets/BUILD
+++ b/compiler_gym/envs/llvm/datasets/BUILD
@@ -10,6 +10,7 @@ py_library(
"__init__.py",
"anghabench.py",
"cbench.py",
+ "chstone.py",
"clgen.py",
"csmith.py",
"llvm_stress.py",
diff --git a/compiler_gym/envs/llvm/datasets/__init__.py b/compiler_gym/envs/llvm/datasets/__init__.py
index e83cf0d6a..b46c8d799 100644
--- a/compiler_gym/envs/llvm/datasets/__init__.py
+++ b/compiler_gym/envs/llvm/datasets/__init__.py
@@ -9,6 +9,7 @@
from compiler_gym.datasets import Dataset, TarDatasetWithManifest
from compiler_gym.envs.llvm.datasets.anghabench import AnghaBenchDataset
from compiler_gym.envs.llvm.datasets.cbench import CBenchDataset, CBenchLegacyDataset
+from compiler_gym.envs.llvm.datasets.chstone import CHStoneDataset
from compiler_gym.envs.llvm.datasets.clgen import CLgenDataset
from compiler_gym.envs.llvm.datasets.csmith import CsmithBenchmark, CsmithDataset
from compiler_gym.envs.llvm.datasets.llvm_stress import LlvmStressDataset
@@ -212,6 +213,25 @@ def get_llvm_datasets(site_data_base: Optional[Path] = None) -> Iterable[Dataset
site_data_base = site_data_base or site_data_path("llvm-v0")
yield AnghaBenchDataset(site_data_base=site_data_base, sort_order=0)
+ # Add legacy version of Anghabench using an old manifest.
+ anghabench_v0_manifest_url, anghabench_v0_manifest_sha256 = {
+ "darwin": (
+ "https://dl.fbaipublicfiles.com/compiler_gym/llvm_bitcodes-10.0.0-anghabench-v0-macos-manifest.bz2",
+ "39464256405aacefdb7550a7f990c9c578264c132804eec3daac091fa3c21bd1",
+ ),
+ "linux": (
+ "https://dl.fbaipublicfiles.com/compiler_gym/llvm_bitcodes-10.0.0-anghabench-v0-linux-manifest.bz2",
+ "a038d25d39ee9472662a9704dfff19c9e3512ff6a70f1067af85c5cb3784b477",
+ ),
+ }[sys.platform]
+ yield AnghaBenchDataset(
+ name="benchmark://anghabench-v0",
+ site_data_base=site_data_base,
+ sort_order=0,
+ manifest_url=anghabench_v0_manifest_url,
+ manifest_sha256=anghabench_v0_manifest_sha256,
+ deprecated="Please use anghabench-v1",
+ )
yield BlasDataset(site_data_base=site_data_base, sort_order=0)
yield CLgenDataset(site_data_base=site_data_base, sort_order=0)
yield CBenchDataset(site_data_base=site_data_base, sort_order=-1)
@@ -229,6 +249,7 @@ def get_llvm_datasets(site_data_base: Optional[Path] = None) -> Iterable[Dataset
sort_order=100,
)
yield CBenchLegacyDataset(site_data_base=site_data_base)
+ yield CHStoneDataset(site_data_base=site_data_base)
yield CsmithDataset(site_data_base=site_data_base, sort_order=0)
yield GitHubDataset(site_data_base=site_data_base, sort_order=0)
yield LinuxDataset(site_data_base=site_data_base, sort_order=0)
diff --git a/compiler_gym/envs/llvm/datasets/anghabench.py b/compiler_gym/envs/llvm/datasets/anghabench.py
index bfcb46a65..ecee29f6a 100644
--- a/compiler_gym/envs/llvm/datasets/anghabench.py
+++ b/compiler_gym/envs/llvm/datasets/anghabench.py
@@ -6,6 +6,7 @@
import sys
from concurrent.futures import as_completed
from pathlib import Path
+from typing import Optional
from compiler_gym.datasets import Benchmark, TarDatasetWithManifest
from compiler_gym.datasets.benchmark import BenchmarkWithSource
@@ -38,19 +39,27 @@ class AnghaBenchDataset(TarDatasetWithManifest):
overhead of compiling it from C to bitcode. This is a one-off cost.
"""
- def __init__(self, site_data_base: Path, sort_order: int = 0):
- manifest_url, manifest_sha256 = {
+ def __init__(
+ self,
+ site_data_base: Path,
+ sort_order: int = 0,
+ manifest_url: Optional[str] = None,
+ manifest_sha256: Optional[str] = None,
+ deprecated: Optional[str] = None,
+ name: Optional[str] = None,
+ ):
+ manifest_url_, manifest_sha256_ = {
"darwin": (
- "https://dl.fbaipublicfiles.com/compiler_gym/llvm_bitcodes-10.0.0-anghabench-v0-macos-manifest.bz2",
- "39464256405aacefdb7550a7f990c9c578264c132804eec3daac091fa3c21bd1",
+ "https://dl.fbaipublicfiles.com/compiler_gym/llvm_bitcodes-10.0.0-anghabench-v1-macos-manifest.bz2",
+ "96ead63da5f8efa07fd0370f0c6e452b59bed840828b8b19402102b1ce3ee109",
),
"linux": (
- "https://dl.fbaipublicfiles.com/compiler_gym/llvm_bitcodes-10.0.0-anghabench-v0-linux-manifest.bz2",
- "a038d25d39ee9472662a9704dfff19c9e3512ff6a70f1067af85c5cb3784b477",
+ "https://dl.fbaipublicfiles.com/compiler_gym/llvm_bitcodes-10.0.0-anghabench-v1-linux-manifest.bz2",
+ "14df85f650199498cf769715e9f0d7841d09f9fa62a95b8ecc242bdaf227f33a",
),
}[sys.platform]
super().__init__(
- name="benchmark://anghabench-v0",
+ name=name or "benchmark://anghabench-v1",
description="Compile-only C/C++ functions extracted from GitHub",
references={
"Paper": "https://homepages.dcc.ufmg.br/~fernando/publications/papers/FaustinoCGO21.pdf",
@@ -58,8 +67,8 @@ def __init__(self, site_data_base: Path, sort_order: int = 0):
},
license="Unknown. See: https://github.com/brenocfg/AnghaBench/issues/1",
site_data_base=site_data_base,
- manifest_urls=[manifest_url],
- manifest_sha256=manifest_sha256,
+ manifest_urls=[manifest_url or manifest_url_],
+ manifest_sha256=manifest_sha256 or manifest_sha256_,
tar_urls=[
"https://github.com/brenocfg/AnghaBench/archive/d8034ac8562b8c978376008f4b33df01b8887b19.tar.gz"
],
@@ -68,6 +77,7 @@ def __init__(self, site_data_base: Path, sort_order: int = 0):
tar_compression="gz",
benchmark_file_suffix=".bc",
sort_order=sort_order,
+ deprecated=deprecated,
)
def benchmark(self, uri: str) -> Benchmark:
diff --git a/compiler_gym/envs/llvm/datasets/chstone.py b/compiler_gym/envs/llvm/datasets/chstone.py
new file mode 100644
index 000000000..2ef3ca97d
--- /dev/null
+++ b/compiler_gym/envs/llvm/datasets/chstone.py
@@ -0,0 +1,135 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+import subprocess
+from concurrent.futures import as_completed
+from pathlib import Path
+from typing import Iterable
+
+from compiler_gym.datasets import Benchmark, TarDatasetWithManifest
+from compiler_gym.datasets.benchmark import BenchmarkWithSource
+from compiler_gym.envs.llvm.llvm_benchmark import ClangInvocation
+from compiler_gym.util import thread_pool
+from compiler_gym.util.filesystem import atomic_file_write
+
+URIS = [
+ "benchmark://chstone-v0/adpcm",
+ "benchmark://chstone-v0/aes",
+ "benchmark://chstone-v0/blowfish",
+ "benchmark://chstone-v0/dfadd",
+ "benchmark://chstone-v0/dfdiv",
+ "benchmark://chstone-v0/dfmul",
+ "benchmark://chstone-v0/dfsin",
+ "benchmark://chstone-v0/gsm",
+ "benchmark://chstone-v0/jpeg",
+ "benchmark://chstone-v0/mips",
+ "benchmark://chstone-v0/motion",
+ "benchmark://chstone-v0/sha",
+]
+
+
+class CHStoneDataset(TarDatasetWithManifest):
+ """A dataset of C programs for C-based high-level synthesis.
+
+ The dataset is from:
+
+ Hara, Yuko, Hiroyuki Tomiyama, Shinya Honda, Hiroaki Takada, and Katsuya
+ Ishii. "Chstone: A benchmark program suite for practical c-based
+ high-level synthesis." In 2008 IEEE International Symposium on Circuits
+ and Systems, pp. 1192-1195. IEEE, 2008.
+
+ And is available at:
+
+ http://www.ertl.jp/chstone/
+ """
+
+ def __init__(
+ self,
+ site_data_base: Path,
+ sort_order: int = 0,
+ ):
+ super().__init__(
+ name="benchmark://chstone-v0",
+ description="Benchmarks for C-based High-Level Synthesis",
+ references={
+ "Paper": "http://www.yxi.com/applications/iscas2008-300_1027.pdf",
+ "Homepage": "http://www.ertl.jp/chstone/",
+ },
+ license="Mixture of open source and public domain licenses",
+ site_data_base=site_data_base,
+ tar_urls=[
+ "https://github.com/ChrisCummins/patmos_HLS/archive/e62d878ceb91e5a18007ca2e0a9602ee44ff7d59.tar.gz"
+ ],
+ tar_sha256="f7acab9d3c3dc7b971e62c8454bc909d84bddb6d0a96378e41beb94231739acb",
+ strip_prefix="patmos_HLS-e62d878ceb91e5a18007ca2e0a9602ee44ff7d59/benchmarks/CHStone",
+ tar_compression="gz",
+ benchmark_file_suffix=".bc",
+ sort_order=sort_order,
+ # We provide our own manifest.
+ manifest_urls=[],
+ manifest_sha256="",
+ )
+
+ def benchmark_uris(self) -> Iterable[str]:
+ yield from URIS
+
+ def benchmark(self, uri: str) -> Benchmark:
+ self.install()
+
+ benchmark_name = uri[len(self.name) + 1 :]
+ if not benchmark_name:
+ raise LookupError(f"No benchmark specified: {uri}")
+
+ bitcode_abspath = self.dataset_root / f"{benchmark_name}.bc"
+
+ # Most of the source files are named after the parent directory, but not
+ # all.
+ c_file_name = {
+ "blowfish": "bf.c",
+ "motion": "mpeg2.c",
+ "sha": "sha_driver.c",
+ "jpeg": "main.c",
+ }.get(benchmark_name, f"{benchmark_name}.c")
+ c_file_abspath = self.dataset_root / benchmark_name / c_file_name
+
+ # If the file does not exist, compile it on-demand.
+ if not bitcode_abspath.is_file():
+ if not c_file_abspath.is_file():
+ raise LookupError(
+ f"Benchmark not found: {uri} (file not found: {c_file_abspath})"
+ )
+
+ with atomic_file_write(bitcode_abspath) as tmp_path:
+ compile_cmd = ClangInvocation.from_c_file(
+ c_file_abspath,
+ copt=[
+ "-ferror-limit=1", # Stop on first error.
+ "-w", # No warnings.
+ ],
+ ).command(outpath=tmp_path)
+ subprocess.check_call(compile_cmd, timeout=300)
+
+ return BenchmarkWithSource.create(
+ uri, bitcode_abspath, "function.c", c_file_abspath
+ )
+
+ @property
+ def size(self) -> int:
+ return len(URIS)
+
+ def compile_all(self):
+ n = self.size
+ executor = thread_pool.get_thread_pool_executor()
+ # Since the dataset is lazily compiled, simply iterating over the full
+ # set of URIs will compile everything. Do this in parallel.
+ futures = (
+ executor.submit(self.benchmark, uri) for uri in self.benchmark_uris()
+ )
+ for i, future in enumerate(as_completed(futures), start=1):
+ future.result()
+ print(
+ f"\r\033[KCompiled {i} of {n} programs ({i/n:.1%} complete)",
+ flush=True,
+ end="",
+ )
diff --git a/compiler_gym/envs/llvm/datasets/csmith.py b/compiler_gym/envs/llvm/datasets/csmith.py
index dfe51e435..184f28aa4 100644
--- a/compiler_gym/envs/llvm/datasets/csmith.py
+++ b/compiler_gym/envs/llvm/datasets/csmith.py
@@ -12,6 +12,7 @@
from threading import Lock
from typing import Iterable, List
+import numpy as np
from fasteners import InterProcessLock
from compiler_gym.datasets import Benchmark, BenchmarkSource, Dataset
@@ -21,7 +22,6 @@
from compiler_gym.util.decorators import memoized_property
from compiler_gym.util.download import download
from compiler_gym.util.runfiles_path import transient_cache_path
-from compiler_gym.util.truncate import truncate
# The maximum value for the --seed argument to csmith.
UINT_MAX = (2 ** 32) - 1
@@ -227,6 +227,10 @@ def benchmark_uris(self) -> Iterable[str]:
def benchmark(self, uri: str) -> CsmithBenchmark:
return self.benchmark_from_seed(int(uri.split("/")[-1]))
+ def _random_benchmark(self, random_state: np.random.Generator) -> Benchmark:
+ seed = random_state.integers(UINT_MAX)
+ return self.benchmark_from_seed(seed)
+
def benchmark_from_seed(self, seed: int) -> CsmithBenchmark:
"""Get a benchmark from a uint32 seed.
@@ -246,30 +250,25 @@ def benchmark_from_seed(self, seed: int) -> CsmithBenchmark:
)
# Generate the C source.
- src, stderr = csmith.communicate(timeout=300)
+ src, _ = csmith.communicate(timeout=300)
if csmith.returncode:
- error = truncate(stderr.decode("utf-8"), max_lines=20, max_line_len=100)
- raise OSError(f"Csmith failed with seed {seed}\nError: {error}")
+ raise OSError(f"Csmith failed with seed {seed}")
# Compile to IR.
clang = subprocess.Popen(
self.clang_compile_command,
stdin=subprocess.PIPE,
stdout=subprocess.PIPE,
- stderr=subprocess.PIPE,
+ stderr=subprocess.DEVNULL,
)
- stdout, stderr = clang.communicate(src, timeout=300)
+ stdout, _ = clang.communicate(src, timeout=300)
- if csmith.returncode:
- raise OSError(f"Csmith failed with seed {seed}")
if clang.returncode:
compile_cmd = " ".join(self.clang_compile_command)
- error = truncate(stderr.decode("utf-8"), max_lines=20, max_line_len=100)
raise BenchmarkInitError(
f"Compilation job failed!\n"
f"Csmith seed: {seed}\n"
f"Command: {compile_cmd}\n"
- f"Error: {error}"
)
return self.benchmark_class.create(f"{self.name}/{seed}", stdout, src)
diff --git a/compiler_gym/envs/llvm/datasets/llvm_stress.py b/compiler_gym/envs/llvm/datasets/llvm_stress.py
index 02d948bb2..6c9aeca8f 100644
--- a/compiler_gym/envs/llvm/datasets/llvm_stress.py
+++ b/compiler_gym/envs/llvm/datasets/llvm_stress.py
@@ -6,6 +6,8 @@
from pathlib import Path
from typing import Iterable
+import numpy as np
+
from compiler_gym.datasets import Benchmark, Dataset
from compiler_gym.datasets.benchmark import BenchmarkInitError
from compiler_gym.third_party import llvm
@@ -56,6 +58,10 @@ def benchmark_uris(self) -> Iterable[str]:
def benchmark(self, uri: str) -> Benchmark:
return self.benchmark_from_seed(int(uri.split("/")[-1]))
+ def _random_benchmark(self, random_state: np.random.Generator) -> Benchmark:
+ seed = random_state.integers(UINT_MAX)
+ return self.benchmark_from_seed(seed)
+
def benchmark_from_seed(self, seed: int) -> Benchmark:
"""Get a benchmark from a uint32 seed.
diff --git a/compiler_gym/envs/llvm/legacy_datasets.py b/compiler_gym/envs/llvm/legacy_datasets.py
deleted file mode 100644
index fe1090b88..000000000
--- a/compiler_gym/envs/llvm/legacy_datasets.py
+++ /dev/null
@@ -1,932 +0,0 @@
-# Copyright (c) Facebook, Inc. and its affiliates.
-#
-# This source code is licensed under the MIT license found in the
-# LICENSE file in the root directory of this source tree.
-"""This module defines the available LLVM datasets."""
-import enum
-import io
-import logging
-import os
-import re
-import shutil
-import subprocess
-import sys
-import tarfile
-import tempfile
-from collections import defaultdict
-from concurrent.futures import as_completed
-from pathlib import Path
-from threading import Lock
-from typing import Callable, Dict, Iterable, List, NamedTuple, Optional
-
-import fasteners
-
-from compiler_gym.datasets.dataset import LegacyDataset
-from compiler_gym.third_party import llvm
-from compiler_gym.util import thread_pool
-from compiler_gym.util.download import download
-from compiler_gym.util.runfiles_path import cache_path, site_data_path
-from compiler_gym.util.timer import Timer
-from compiler_gym.validation_error import ValidationError
-
-_CBENCH_DATA_URL = (
- "https://dl.fbaipublicfiles.com/compiler_gym/cBench-v0-runtime-data.tar.bz2"
-)
-_CBENCH_DATA_SHA256 = "a1b5b5d6b115e5809ccaefc2134434494271d184da67e2ee43d7f84d07329055"
-
-
-if sys.platform == "darwin":
- _COMPILE_ARGS = [
- "-L",
- "/Library/Developer/CommandLineTools/SDKs/MacOSX.sdk/usr/lib",
- ]
-else:
- _COMPILE_ARGS = []
-
-LLVM_DATASETS = [
- LegacyDataset(
- name="blas-v0",
- url="https://dl.fbaipublicfiles.com/compiler_gym/llvm_bitcodes-10.0.0-blas-v0.tar.bz2",
- license="BSD 3-Clause",
- description="https://github.com/spcl/ncc/tree/master/data",
- compiler="llvm-10.0.0",
- file_count=300,
- size_bytes=3969036,
- sha256="e724a8114709f8480adeb9873d48e426e8d9444b00cddce48e342b9f0f2b096d",
- ),
- # The difference between cBench-v0 and cBench-v1 is the arguments passed to
- # clang when preparing the LLVM bitcodes:
- #
- # - v0: `-O0 -Xclang -disable-O0-optnone`.
- # - v1: `-O1 -Xclang -Xclang -disable-llvm-passes`.
- #
- # The key difference with is that in v0, the generated IR functions were
- # annotated with a `noinline` attribute that prevented inline. In v1 that is
- # no longer the case.
- LegacyDataset(
- name="cBench-v0",
- url="https://dl.fbaipublicfiles.com/compiler_gym/llvm_bitcodes-10.0.0-cBench-v0-macos.tar.bz2",
- license="BSD 3-Clause",
- description="https://github.com/ctuning/ctuning-programs",
- compiler="llvm-10.0.0",
- file_count=23,
- size_bytes=7154448,
- sha256="072a730c86144a07bba948c49afe543e4f06351f1cb17f7de77f91d5c1a1b120",
- platforms=["macos"],
- deprecated_since="v0.1.4",
- ),
- LegacyDataset(
- name="cBench-v0",
- url="https://dl.fbaipublicfiles.com/compiler_gym/llvm_bitcodes-10.0.0-cBench-v0-linux.tar.bz2",
- license="BSD 3-Clause",
- description="https://github.com/ctuning/ctuning-programs",
- compiler="llvm-10.0.0",
- file_count=23,
- size_bytes=6940416,
- sha256="9b5838a90895579aab3b9375e8eeb3ed2ae58e0ad354fec7eb4f8b31ecb4a360",
- platforms=["linux"],
- deprecated_since="v0.1.4",
- ),
- LegacyDataset(
- name="cBench-v1",
- url="https://dl.fbaipublicfiles.com/compiler_gym/llvm_bitcodes-10.0.0-cBench-v1-macos.tar.bz2",
- license="BSD 3-Clause",
- description="https://github.com/ctuning/ctuning-programs",
- compiler="llvm-10.0.0",
- file_count=23,
- size_bytes=10292032,
- sha256="90b312b40317d9ee9ed09b4b57d378879f05e8970bb6de80dc8581ad0e36c84f",
- platforms=["macos"],
- ),
- LegacyDataset(
- name="cBench-v1",
- url="https://dl.fbaipublicfiles.com/compiler_gym/llvm_bitcodes-10.0.0-cBench-v1-linux.tar.bz2",
- license="BSD 3-Clause",
- description="https://github.com/ctuning/ctuning-programs",
- compiler="llvm-10.0.0",
- file_count=23,
- size_bytes=10075608,
- sha256="601fff3944c866f6617e653b6eb5c1521382c935f56ca1f36a9f5cf1a49f3de5",
- platforms=["linux"],
- ),
- LegacyDataset(
- name="github-v0",
- url="https://dl.fbaipublicfiles.com/compiler_gym/llvm_bitcodes-10.0.0-github-v0.tar.bz2",
- license="CC BY 4.0",
- description="https://zenodo.org/record/4122437",
- compiler="llvm-10.0.0",
- file_count=50708,
- size_bytes=725974100,
- sha256="880269dd7a5c2508ea222a2e54c318c38c8090eb105c0a87c595e9dd31720764",
- ),
- LegacyDataset(
- name="linux-v0",
- url="https://dl.fbaipublicfiles.com/compiler_gym/llvm_bitcodes-10.0.0-linux-v0.tar.bz2",
- license="GPL-2.0",
- description="https://github.com/spcl/ncc/tree/master/data",
- compiler="llvm-10.0.0",
- file_count=13920,
- size_bytes=516031044,
- sha256="a1ae5c376af30ab042c9e54dc432f89ce75f9ebaee953bc19c08aff070f12566",
- ),
- LegacyDataset(
- name="mibench-v0",
- url="https://dl.fbaipublicfiles.com/compiler_gym/llvm_bitcodes-10.0.0-mibench-v0.tar.bz2",
- license="BSD 3-Clause",
- description="https://github.com/ctuning/ctuning-programs",
- compiler="llvm-10.0.0",
- file_count=40,
- size_bytes=238480,
- sha256="128c090c40b955b99fdf766da167a5f642018fb35c16a1d082f63be2e977eb13",
- ),
- LegacyDataset(
- name="npb-v0",
- url="https://dl.fbaipublicfiles.com/compiler_gym/llvm_bitcodes-10.0.0-npb-v0.tar.bz2",
- license="NASA Open Source Agreement v1.3",
- description="https://github.com/spcl/ncc/tree/master/data",
- compiler="llvm-10.0.0",
- file_count=122,
- size_bytes=2287444,
- sha256="793ac2e7a4f4ed83709e8a270371e65b724da09eaa0095c52e7f4209f63bb1f2",
- ),
- LegacyDataset(
- name="opencv-v0",
- url="https://dl.fbaipublicfiles.com/compiler_gym/llvm_bitcodes-10.0.0-opencv-v0.tar.bz2",
- license="Apache 2.0",
- description="https://github.com/spcl/ncc/tree/master/data",
- compiler="llvm-10.0.0",
- file_count=442,
- size_bytes=21903008,
- sha256="003df853bd58df93572862ca2f934c7b129db2a3573bcae69a2e59431037205c",
- ),
- LegacyDataset(
- name="poj104-v0",
- url="https://dl.fbaipublicfiles.com/compiler_gym/llvm_bitcodes-10.0.0-poj104-v0.tar.bz2",
- license="BSD 3-Clause",
- description="https://sites.google.com/site/treebasedcnn/",
- compiler="llvm-10.0.0",
- file_count=49628,
- size_bytes=304207752,
- sha256="6254d629887f6b51efc1177788b0ce37339d5f3456fb8784415ed3b8c25cce27",
- ),
- # FIXME(github.com/facebookresearch/CompilerGym/issues/55): Polybench
- # dataset has `optnone` function attribute set, requires rebuild.
- # LegacyDataset(
- # name="polybench-v0",
- # url="https://dl.fbaipublicfiles.com/compiler_gym/llvm_bitcodes-10.0.0-polybench-v0.tar.bz2",
- # license="BSD 3-Clause",
- # description="https://github.com/ctuning/ctuning-programs",
- # compiler="llvm-10.0.0",
- # file_count=27,
- # size_bytes=162624,
- # sha256="968087e68470e5b44dc687dae195143000c7478a23d6631b27055bb3bb3116b1",
- # ),
- LegacyDataset(
- name="tensorflow-v0",
- url="https://dl.fbaipublicfiles.com/compiler_gym/llvm_bitcodes-10.0.0-tensorflow-v0.tar.bz2",
- license="Apache 2.0",
- description="https://github.com/spcl/ncc/tree/master/data",
- compiler="llvm-10.0.0",
- file_count=1985,
- size_bytes=299697312,
- sha256="f77dd1988c772e8359e1303cc9aba0d73d5eb27e0c98415ac3348076ab94efd1",
- ),
-]
-
-
-class BenchmarkExecutionResult(NamedTuple):
- """The result of running a benchmark."""
-
- walltime_seconds: float
- """The execution time in seconds."""
-
- error: Optional[ValidationError] = None
- """An error."""
-
- output: Optional[str] = None
- """The output generated by the benchmark."""
-
- def json(self):
- return self._asdict()
-
-
-class LlvmSanitizer(enum.IntEnum):
- """The LLVM sanitizers."""
-
- ASAN = 1
- TSAN = 2
- MSAN = 3
- UBSAN = 4
-
-
-# Compiler flags that are enabled by sanitizers.
-_SANITIZER_FLAGS = {
- LlvmSanitizer.ASAN: ["-O1", "-g", "-fsanitize=address", "-fno-omit-frame-pointer"],
- LlvmSanitizer.TSAN: ["-O1", "-g", "-fsanitize=thread"],
- LlvmSanitizer.MSAN: ["-O1", "-g", "-fsanitize=memory"],
- LlvmSanitizer.UBSAN: ["-fsanitize=undefined"],
-}
-
-
-def _compile_and_run_bitcode_file(
- bitcode_file: Path,
- cmd: str,
- cwd: Path,
- linkopts: List[str],
- env: Dict[str, str],
- num_runs: int,
- logger: logging.Logger,
- sanitizer: Optional[LlvmSanitizer] = None,
- timeout_seconds: float = 300,
- compilation_timeout_seconds: float = 60,
-) -> BenchmarkExecutionResult:
- """Run the given cBench benchmark."""
- # cBench benchmarks expect that a file _finfo_dataset exists in the
- # current working directory and contains the number of benchmark
- # iterations in it.
- with open(cwd / "_finfo_dataset", "w") as f:
- print(num_runs, file=f)
-
- # Create a barebones execution environment for the benchmark.
- run_env = {
- "TMPDIR": os.environ.get("TMPDIR", ""),
- "HOME": os.environ.get("HOME", ""),
- "USER": os.environ.get("USER", ""),
- # Disable all logging from GRPC. In the past I have had false-positive
- # "Wrong output" errors caused by GRPC error messages being logged to
- # stderr.
- "GRPC_VERBOSITY": "NONE",
- }
- run_env.update(env)
-
- error_data = {}
-
- if sanitizer:
- clang_path = llvm.clang_path()
- binary = cwd / "a.out"
- error_data["run_cmd"] = cmd.replace("$BIN", "./a.out")
- # Generate the a.out binary file.
- compile_cmd = (
- [clang_path.name, str(bitcode_file), "-o", str(binary)]
- + _COMPILE_ARGS
- + list(linkopts)
- + _SANITIZER_FLAGS.get(sanitizer, [])
- )
- error_data["compile_cmd"] = compile_cmd
- logger.debug("compile: %s", compile_cmd)
- assert not binary.is_file()
- clang = subprocess.Popen(
- compile_cmd,
- stdout=subprocess.PIPE,
- stderr=subprocess.STDOUT,
- universal_newlines=True,
- env={"PATH": f"{clang_path.parent}:{os.environ.get('PATH', '')}"},
- )
- try:
- output, _ = clang.communicate(timeout=compilation_timeout_seconds)
- except subprocess.TimeoutExpired:
- clang.kill()
- error_data["timeout"] = compilation_timeout_seconds
- return BenchmarkExecutionResult(
- walltime_seconds=timeout_seconds,
- error=ValidationError(
- type="Compilation timeout",
- data=error_data,
- ),
- )
- if clang.returncode:
- error_data["output"] = output
- return BenchmarkExecutionResult(
- walltime_seconds=timeout_seconds,
- error=ValidationError(
- type="Compilation failed",
- data=error_data,
- ),
- )
- assert binary.is_file()
- else:
- lli_path = llvm.lli_path()
- error_data["run_cmd"] = cmd.replace("$BIN", f"{lli_path.name} benchmark.bc")
- run_env["PATH"] = str(lli_path.parent)
-
- try:
- logger.debug("exec: %s", error_data["run_cmd"])
- with Timer() as timer:
- process = subprocess.Popen(
- error_data["run_cmd"],
- shell=True,
- stderr=subprocess.STDOUT,
- stdout=subprocess.PIPE,
- env=run_env,
- cwd=cwd,
- )
-
- stdout, _ = process.communicate(timeout=timeout_seconds)
- except subprocess.TimeoutExpired:
- process.kill()
- error_data["timeout_seconds"] = timeout_seconds
- return BenchmarkExecutionResult(
- walltime_seconds=timeout_seconds,
- error=ValidationError(
- type="Execution timeout",
- data=error_data,
- ),
- )
- finally:
- if sanitizer:
- binary.unlink()
-
- try:
- output = stdout.decode("utf-8")
- except UnicodeDecodeError:
- output = ""
-
- if process.returncode:
- # Runtime error.
- if sanitizer == LlvmSanitizer.ASAN and "LeakSanitizer" in output:
- error_type = "Memory leak"
- elif sanitizer == LlvmSanitizer.ASAN and "AddressSanitizer" in output:
- error_type = "Memory error"
- elif sanitizer == LlvmSanitizer.MSAN and "MemorySanitizer" in output:
- error_type = "Memory error"
- elif "Segmentation fault" in output:
- error_type = "Segmentation fault"
- elif "Illegal Instruction" in output:
- error_type = "Illegal Instruction"
- else:
- error_type = f"Runtime error ({process.returncode})"
-
- error_data["return_code"] = process.returncode
- error_data["output"] = output
- return BenchmarkExecutionResult(
- walltime_seconds=timer.time,
- error=ValidationError(
- type=error_type,
- data=error_data,
- ),
- )
- return BenchmarkExecutionResult(walltime_seconds=timer.time, output=output)
-
-
-def download_cBench_runtime_data() -> bool:
- """Download and unpack the cBench runtime dataset."""
- cbench_data = site_data_path("llvm/cBench-v1-runtime-data/runtime_data")
- if (cbench_data / "unpacked").is_file():
- return False
- else:
- # Clean up any partially-extracted data directory.
- if cbench_data.is_dir():
- shutil.rmtree(cbench_data)
-
- tar_contents = io.BytesIO(
- download(_CBENCH_DATA_URL, sha256=_CBENCH_DATA_SHA256)
- )
- with tarfile.open(fileobj=tar_contents, mode="r:bz2") as tar:
- cbench_data.parent.mkdir(parents=True, exist_ok=True)
- tar.extractall(cbench_data.parent)
- assert cbench_data.is_dir()
- # Create the marker file to indicate that the directory is unpacked
- # and ready to go.
- (cbench_data / "unpacked").touch()
- return True
-
-
-# Thread lock to prevent race on download_cBench_runtime_data() from
-# multi-threading. This works in tandem with the inter-process file lock - both
-# are required.
-_CBENCH_DOWNLOAD_THREAD_LOCK = Lock()
-
-
-def _make_cBench_validator(
- cmd: str,
- linkopts: List[str],
- os_env: Dict[str, str],
- num_runs: int = 1,
- compare_output: bool = True,
- input_files: Optional[List[Path]] = None,
- output_files: Optional[List[Path]] = None,
- validate_result: Optional[
- Callable[[BenchmarkExecutionResult], Optional[str]]
- ] = None,
- pre_execution_callback: Optional[Callable[[Path], None]] = None,
- sanitizer: Optional[LlvmSanitizer] = None,
- flakiness: int = 5,
-) -> Callable[["LlvmEnv"], Optional[ValidationError]]: # noqa: F821
- """Construct a validation callback for a cBench benchmark. See validator() for usage."""
- input_files = input_files or []
- output_files = output_files or []
-
- def validator_cb(env: "LlvmEnv") -> Optional[ValidationError]: # noqa: F821
- """The validation callback."""
- with _CBENCH_DOWNLOAD_THREAD_LOCK:
- with fasteners.InterProcessLock(cache_path(".cBench-v1-runtime-data.lock")):
- download_cBench_runtime_data()
-
- cbench_data = site_data_path("llvm/cBench-v1-runtime-data/runtime_data")
- for input_file_name in input_files:
- path = cbench_data / input_file_name
- if not path.is_file():
- raise FileNotFoundError(f"Required benchmark input not found: {path}")
-
- # Create a temporary working directory to execute the benchmark in.
- with tempfile.TemporaryDirectory(dir=env.service.connection.working_dir) as d:
- cwd = Path(d)
-
- # Expand shell variable substitutions in the benchmark command.
- expanded_command = cmd.replace("$D", str(cbench_data))
-
- # Translate the output file names into paths inside the working
- # directory.
- output_paths = [cwd / o for o in output_files]
-
- if pre_execution_callback:
- pre_execution_callback(cwd)
-
- # Produce a gold-standard output using a reference version of
- # the benchmark.
- if compare_output or output_files:
- gs_env = env.fork()
- try:
- # Reset to the original benchmark state and compile it.
- gs_env.reset(benchmark=env.benchmark)
- gs_env.write_bitcode(cwd / "benchmark.bc")
- gold_standard = _compile_and_run_bitcode_file(
- bitcode_file=cwd / "benchmark.bc",
- cmd=expanded_command,
- cwd=cwd,
- num_runs=1,
- # Use default optimizations for gold standard.
- linkopts=linkopts + ["-O2"],
- # Always assume safe.
- sanitizer=None,
- logger=env.logger,
- env=os_env,
- )
- if gold_standard.error:
- return ValidationError(
- type=f"Gold standard: {gold_standard.error.type}",
- data=gold_standard.error.data,
- )
- finally:
- gs_env.close()
-
- # Check that the reference run produced the expected output
- # files.
- for path in output_paths:
- if not path.is_file():
- try:
- output = gold_standard.output
- except UnicodeDecodeError:
- output = ""
- raise FileNotFoundError(
- f"Expected file '{path.name}' not generated\n"
- f"Benchmark: {env.benchmark}\n"
- f"Command: {cmd}\n"
- f"Output: {output}"
- )
- path.rename(f"{path}.gold_standard")
-
- # Serialize the benchmark to a bitcode file that will then be
- # compiled to a binary.
- env.write_bitcode(cwd / "benchmark.bc")
- outcome = _compile_and_run_bitcode_file(
- bitcode_file=cwd / "benchmark.bc",
- cmd=expanded_command,
- cwd=cwd,
- num_runs=num_runs,
- linkopts=linkopts,
- sanitizer=sanitizer,
- logger=env.logger,
- env=os_env,
- )
-
- if outcome.error:
- return outcome.error
-
- # Run a user-specified validation hook.
- if validate_result:
- validate_result(outcome)
-
- # Difftest the console output.
- if compare_output and gold_standard.output != outcome.output:
- return ValidationError(
- type="Wrong output",
- data={"expected": gold_standard.output, "actual": outcome.output},
- )
-
- # Difftest the output files.
- for i, path in enumerate(output_paths, start=1):
- if not path.is_file():
- return ValidationError(
- type="Output not generated",
- data={"path": path.name, "command": cmd},
- )
- diff = subprocess.Popen(
- ["diff", str(path), f"{path}.gold_standard"],
- stdout=subprocess.PIPE,
- stderr=subprocess.STDOUT,
- )
- stdout, _ = diff.communicate()
- if diff.returncode:
- try:
- stdout = stdout.decode("utf-8")
- return ValidationError(
- type="Wrong output (file)",
- data={"path": path.name, "diff": stdout},
- )
- except UnicodeDecodeError:
- return ValidationError(
- type="Wrong output (file)",
- data={"path": path.name, "diff": ""},
- )
-
- def flaky_wrapped_cb(env: "LlvmEnv") -> Optional[ValidationError]: # noqa: F821
- """Wrap the validation callback in a flakiness retry loop."""
- for i in range(1, max(flakiness, 1) + 1):
- try:
- error = validator_cb(env)
- if not error:
- return
- except TimeoutError:
- # Timeout errors can be raised by the environment in case of a
- # slow step / observation, and should be retried.
- pass
- env.logger.warning(
- "Validation callback failed, attempt=%d/%d", i, flakiness
- )
- return error
-
- return flaky_wrapped_cb
-
-
-# A map from benchmark name to validation callbacks. Defined below.
-VALIDATORS: Dict[
- str, List[Callable[["LlvmEnv"], Optional[str]]] # noqa: F821
-] = defaultdict(list)
-
-
-def validator(
- benchmark: str,
- cmd: str,
- data: Optional[List[str]] = None,
- outs: Optional[List[str]] = None,
- platforms: Optional[List[str]] = None,
- compare_output: bool = True,
- validate_result: Optional[
- Callable[[BenchmarkExecutionResult], Optional[str]]
- ] = None,
- linkopts: Optional[List[str]] = None,
- env: Optional[Dict[str, str]] = None,
- pre_execution_callback: Optional[Callable[[], None]] = None,
- sanitizers: Optional[List[LlvmSanitizer]] = None,
-) -> bool:
- """Declare a new benchmark validator.
-
- TODO(cummins): Pull this out into a public API.
-
- :param benchmark: The name of the benchmark that this validator supports.
- :cmd: The shell command to run the validation. Variable substitution is
- applied to this value as follows: :code:`$BIN` is replaced by the path
- of the compiled binary and :code:`$D` is replaced with the path to the
- benchmark's runtime data directory.
- :data: A list of paths to input files.
- :outs: A list of paths to output files.
- :return: :code:`True` if the new validator was registered, else :code:`False`.
- """
- platforms = platforms or ["linux", "macos"]
- if {"darwin": "macos"}.get(sys.platform, sys.platform) not in platforms:
- return False
- infiles = data or []
- outfiles = [Path(p) for p in outs or []]
- linkopts = linkopts or []
- env = env or {}
- if sanitizers is None:
- sanitizers = LlvmSanitizer
-
- VALIDATORS[benchmark].append(
- _make_cBench_validator(
- cmd=cmd,
- input_files=infiles,
- output_files=outfiles,
- compare_output=compare_output,
- validate_result=validate_result,
- linkopts=linkopts,
- os_env=env,
- pre_execution_callback=pre_execution_callback,
- )
- )
-
- # Register additional validators using the sanitizers.
- if sys.platform.startswith("linux"):
- for sanitizer in sanitizers:
- VALIDATORS[benchmark].append(
- _make_cBench_validator(
- cmd=cmd,
- input_files=infiles,
- output_files=outfiles,
- compare_output=compare_output,
- validate_result=validate_result,
- linkopts=linkopts,
- os_env=env,
- pre_execution_callback=pre_execution_callback,
- sanitizer=sanitizer,
- )
- )
-
- return True
-
-
-def get_llvm_benchmark_validation_callback(
- env: "LlvmEnv", # noqa: F821
-) -> Optional[Callable[["LlvmEnv"], Iterable[ValidationError]]]: # noqa: F821
- """Return a callback for validating a given environment state.
-
- If there is no valid callback, returns :code:`None`.
-
- :param env: An :class:`LlvmEnv ` instance.
-
- :return: An optional callback that takes an :class:`LlvmEnv
- ` instance as argument and returns an
- optional string containing a validation error message.
- """
- validators = VALIDATORS.get(env.benchmark)
-
- # No match.
- if not validators:
- return None
-
- def composed(env):
- # Validation callbacks are read-only on the environment so it is
- # safe to run validators simultaneously in parallel threads.
- executor = thread_pool.get_thread_pool_executor()
- futures = (executor.submit(validator, env) for validator in validators)
- for future in as_completed(futures):
- result = future.result()
- if result is not None:
- yield result
-
- return None
-
- return composed
-
-
-# ===============================
-# Definition of cBench validators
-# ===============================
-
-
-def validate_sha_output(result: BenchmarkExecutionResult) -> Optional[str]:
- """SHA benchmark prints 5 random hex strings. Normally these hex strings are
- 16 characters but occasionally they are less (presumably becuase of a
- leading zero being omitted).
- """
- try:
- if not re.match(
- r"[0-9a-f]{0,16} [0-9a-f]{0,16} [0-9a-f]{0,16} [0-9a-f]{0,16} [0-9a-f]{0,16}",
- result.output.rstrip(),
- ):
- return "Failed to parse hex output"
- except UnicodeDecodeError:
- return "Failed to parse unicode output"
-
-
-def setup_ghostscript_library_files(dataset_id: int) -> Callable[[Path], None]:
- """Make a pre-execution setup hook for ghostscript."""
-
- def setup(cwd: Path):
- cbench_data = site_data_path("llvm/cBench-v1-runtime-data/runtime_data")
- # Copy the input data file into the current directory since ghostscript
- # doesn't like long input paths.
- shutil.copyfile(
- cbench_data / "office_data" / f"{dataset_id}.ps", cwd / "input.ps"
- )
- # Ghostscript doesn't like the library files being symlinks so copy them
- # into the working directory as regular files.
- for path in (cbench_data / "ghostscript").iterdir():
- if path.name.endswith(".ps"):
- shutil.copyfile(path, cwd / path.name)
-
- return setup
-
-
-validator(
- benchmark="benchmark://cBench-v1/bitcount",
- cmd="$BIN 1125000",
-)
-
-validator(
- benchmark="benchmark://cBench-v1/bitcount",
- cmd="$BIN 512",
-)
-
-for i in range(1, 21):
-
- # NOTE(cummins): Disabled due to timeout errors, further investigation
- # needed.
- #
- # validator(
- # benchmark="benchmark://cBench-v1/adpcm",
- # cmd=f"$BIN $D/telecom_data/{i}.adpcm",
- # data=[f"telecom_data/{i}.adpcm"],
- # )
- #
- # validator(
- # benchmark="benchmark://cBench-v1/adpcm",
- # cmd=f"$BIN $D/telecom_data/{i}.pcm",
- # data=[f"telecom_data/{i}.pcm"],
- # )
-
- validator(
- benchmark="benchmark://cBench-v1/blowfish",
- cmd=f"$BIN d $D/office_data/{i}.benc output.txt 1234567890abcdeffedcba0987654321",
- data=[f"office_data/{i}.benc"],
- outs=["output.txt"],
- )
-
- validator(
- benchmark="benchmark://cBench-v1/bzip2",
- cmd=f"$BIN -d -k -f -c $D/bzip2_data/{i}.bz2",
- data=[f"bzip2_data/{i}.bz2"],
- )
-
- validator(
- benchmark="benchmark://cBench-v1/crc32",
- cmd=f"$BIN $D/telecom_data/{i}.pcm",
- data=[f"telecom_data/{i}.pcm"],
- )
-
- validator(
- benchmark="benchmark://cBench-v1/dijkstra",
- cmd=f"$BIN $D/network_dijkstra_data/{i}.dat",
- data=[f"network_dijkstra_data/{i}.dat"],
- )
-
- validator(
- benchmark="benchmark://cBench-v1/gsm",
- cmd=f"$BIN -fps -c $D/telecom_gsm_data/{i}.au",
- data=[f"telecom_gsm_data/{i}.au"],
- )
-
- # NOTE(cummins): ispell fails with returncode 1 and no output when run
- # under safe optimizations.
- #
- # validator(
- # benchmark="benchmark://cBench-v1/ispell",
- # cmd=f"$BIN -a -d americanmed+ $D/office_data/{i}.txt",
- # data = [f"office_data/{i}.txt"],
- # )
-
- validator(
- benchmark="benchmark://cBench-v1/jpeg-c",
- cmd=f"$BIN -dct int -progressive -outfile output.jpeg $D/consumer_jpeg_data/{i}.ppm",
- data=[f"consumer_jpeg_data/{i}.ppm"],
- outs=["output.jpeg"],
- # NOTE(cummins): AddressSanitizer disabled because of
- # global-buffer-overflow in regular build.
- sanitizers=[LlvmSanitizer.TSAN, LlvmSanitizer.UBSAN],
- )
-
- validator(
- benchmark="benchmark://cBench-v1/jpeg-d",
- cmd=f"$BIN -dct int -outfile output.ppm $D/consumer_jpeg_data/{i}.jpg",
- data=[f"consumer_jpeg_data/{i}.jpg"],
- outs=["output.ppm"],
- )
-
- validator(
- benchmark="benchmark://cBench-v1/patricia",
- cmd=f"$BIN $D/network_patricia_data/{i}.udp",
- data=[f"network_patricia_data/{i}.udp"],
- env={
- # NOTE(cummins): Benchmark leaks when executed with safe optimizations.
- "ASAN_OPTIONS": "detect_leaks=0",
- },
- )
-
- validator(
- benchmark="benchmark://cBench-v1/qsort",
- cmd=f"$BIN $D/automotive_qsort_data/{i}.dat",
- data=[f"automotive_qsort_data/{i}.dat"],
- outs=["sorted_output.dat"],
- linkopts=["-lm"],
- )
-
- # NOTE(cummins): Rijndael benchmark disabled due to memory errors under
- # basic optimizations.
- #
- # validator(benchmark="benchmark://cBench-v1/rijndael", cmd=f"$BIN
- # $D/office_data/{i}.enc output.dec d
- # 1234567890abcdeffedcba09876543211234567890abcdeffedcba0987654321",
- # data=[f"office_data/{i}.enc"], outs=["output.dec"],
- # )
- #
- # validator(benchmark="benchmark://cBench-v1/rijndael", cmd=f"$BIN
- # $D/office_data/{i}.txt output.enc e
- # 1234567890abcdeffedcba09876543211234567890abcdeffedcba0987654321",
- # data=[f"office_data/{i}.txt"], outs=["output.enc"],
- # )
-
- validator(
- benchmark="benchmark://cBench-v1/sha",
- cmd=f"$BIN $D/office_data/{i}.txt",
- data=[f"office_data/{i}.txt"],
- compare_output=False,
- validate_result=validate_sha_output,
- )
-
- validator(
- benchmark="benchmark://cBench-v1/stringsearch",
- cmd=f"$BIN $D/office_data/{i}.txt $D/office_data/{i}.s.txt output.txt",
- data=[f"office_data/{i}.txt"],
- outs=["output.txt"],
- env={
- # NOTE(cummins): Benchmark leaks when executed with safe optimizations.
- "ASAN_OPTIONS": "detect_leaks=0",
- },
- linkopts=["-lm"],
- )
-
- # NOTE(cummins): The stringsearch2 benchmark has a very long execution time.
- # Use only a single input to keep the validation time reasonable. I have
- # also observed Segmentation fault on gold standard using 4.txt and 6.txt.
- if i == 1:
- validator(
- benchmark="benchmark://cBench-v1/stringsearch2",
- cmd=f"$BIN $D/office_data/{i}.txt $D/office_data/{i}.s.txt output.txt",
- data=[f"office_data/{i}.txt"],
- outs=["output.txt"],
- env={
- # NOTE(cummins): Benchmark leaks when executed with safe optimizations.
- "ASAN_OPTIONS": "detect_leaks=0",
- },
- # TSAN disabled because of extremely long execution leading to
- # timeouts.
- sanitizers=[LlvmSanitizer.ASAN, LlvmSanitizer.MSAN, LlvmSanitizer.UBSAN],
- )
-
- validator(
- benchmark="benchmark://cBench-v1/susan",
- cmd=f"$BIN $D/automotive_susan_data/{i}.pgm output_large.corners.pgm -c",
- data=[f"automotive_susan_data/{i}.pgm"],
- outs=["output_large.corners.pgm"],
- linkopts=["-lm"],
- )
-
- validator(
- benchmark="benchmark://cBench-v1/tiff2bw",
- cmd=f"$BIN $D/consumer_tiff_data/{i}.tif output.tif",
- data=[f"consumer_tiff_data/{i}.tif"],
- outs=["output.tif"],
- linkopts=["-lm"],
- env={
- # NOTE(cummins): Benchmark leaks when executed with safe optimizations.
- "ASAN_OPTIONS": "detect_leaks=0",
- },
- )
-
- validator(
- benchmark="benchmark://cBench-v1/tiff2rgba",
- cmd=f"$BIN $D/consumer_tiff_data/{i}.tif output.tif",
- data=[f"consumer_tiff_data/{i}.tif"],
- outs=["output.tif"],
- linkopts=["-lm"],
- )
-
- validator(
- benchmark="benchmark://cBench-v1/tiffdither",
- cmd=f"$BIN $D/consumer_tiff_data/{i}.bw.tif out.tif",
- data=[f"consumer_tiff_data/{i}.bw.tif"],
- outs=["out.tif"],
- linkopts=["-lm"],
- )
-
- validator(
- benchmark="benchmark://cBench-v1/tiffmedian",
- cmd=f"$BIN $D/consumer_tiff_data/{i}.nocomp.tif output.tif",
- data=[f"consumer_tiff_data/{i}.nocomp.tif"],
- outs=["output.tif"],
- linkopts=["-lm"],
- )
-
- # NOTE(cummins): On macOS the following benchmarks abort with an illegal
- # hardware instruction error.
- # if sys.platform != "darwin":
- # validator(
- # benchmark="benchmark://cBench-v1/lame",
- # cmd=f"$BIN $D/consumer_data/{i}.wav output.mp3",
- # data=[f"consumer_data/{i}.wav"],
- # outs=["output.mp3"],
- # compare_output=False,
- # linkopts=["-lm"],
- # )
-
- # NOTE(cummins): Segfault on gold standard.
- #
- # validator(
- # benchmark="benchmark://cBench-v1/ghostscript",
- # cmd="$BIN -sDEVICE=ppm -dNOPAUSE -dQUIET -sOutputFile=output.ppm -- input.ps",
- # data=[f"office_data/{i}.ps"],
- # outs=["output.ppm"],
- # linkopts=["-lm", "-lz"],
- # pre_execution_callback=setup_ghostscript_library_files(i),
- # )
diff --git a/compiler_gym/envs/llvm/llvm_benchmark.py b/compiler_gym/envs/llvm/llvm_benchmark.py
index 31d510c3e..58cc4f10e 100644
--- a/compiler_gym/envs/llvm/llvm_benchmark.py
+++ b/compiler_gym/envs/llvm/llvm_benchmark.py
@@ -107,7 +107,7 @@ def get_system_includes() -> List[Path]:
return _SYSTEM_INCLUDES
-class ClangInvocation(object):
+class ClangInvocation:
"""Class to represent a single invocation of the clang compiler."""
def __init__(
diff --git a/compiler_gym/envs/llvm/llvm_env.py b/compiler_gym/envs/llvm/llvm_env.py
index 02cab0459..12aeda589 100644
--- a/compiler_gym/envs/llvm/llvm_env.py
+++ b/compiler_gym/envs/llvm/llvm_env.py
@@ -3,7 +3,6 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""Extensions to the CompilerEnv environment for LLVM."""
-import hashlib
import os
import shutil
from pathlib import Path
@@ -461,9 +460,7 @@ def ir_sha1(self) -> str:
:return: A 40-character hexadecimal sha1 string.
"""
- # TODO(cummins): Compute this on the service-side and add it as an
- # observation space.
- return hashlib.sha1(self.ir.encode("utf-8")).hexdigest()
+ return self.observation["IrSha1"]
def write_ir(self, path: Union[Path, str]) -> Path:
"""Write the current program state to a file.
diff --git a/compiler_gym/envs/llvm/llvm_rewards.py b/compiler_gym/envs/llvm/llvm_rewards.py
index d746b092a..17d6baf8e 100644
--- a/compiler_gym/envs/llvm/llvm_rewards.py
+++ b/compiler_gym/envs/llvm/llvm_rewards.py
@@ -6,8 +6,8 @@
from typing import List, Optional
from compiler_gym.datasets import Benchmark
-from compiler_gym.service import observation_t
from compiler_gym.spaces.reward import Reward
+from compiler_gym.util.gym_type_hints import ObservationType, RewardType
from compiler_gym.views.observation import ObservationView
@@ -34,7 +34,7 @@ def __init__(self, cost_function: str, init_cost_function: str, **kwargs):
super().__init__(observation_spaces=[cost_function], **kwargs)
self.cost_function: str = cost_function
self.init_cost_function: str = init_cost_function
- self.previous_cost: Optional[observation_t] = None
+ self.previous_cost: Optional[ObservationType] = None
def reset(self, benchmark: Benchmark) -> None:
"""Called on env.reset(). Reset incremental progress."""
@@ -43,15 +43,16 @@ def reset(self, benchmark: Benchmark) -> None:
def update(
self,
- action: int,
- observations: List[observation_t],
+ actions: List[int],
+ observations: List[ObservationType],
observation_view: ObservationView,
- ) -> float:
+ ) -> RewardType:
"""Called on env.step(). Compute and return new reward."""
- cost: float = observations[0]
+ del actions # unused
+ cost: RewardType = observations[0]
if self.previous_cost is None:
self.previous_cost = observation_view[self.init_cost_function]
- reward = float(self.previous_cost - cost)
+ reward = RewardType(self.previous_cost - cost)
self.previous_cost = cost
return reward
@@ -64,7 +65,7 @@ class NormalizedReward(CostFunctionReward):
def __init__(self, **kwargs):
"""Constructor."""
super().__init__(**kwargs)
- self.cost_norm: Optional[observation_t] = None
+ self.cost_norm: Optional[ObservationType] = None
self.benchmark: Benchmark = None
def reset(self, benchmark: str) -> None:
@@ -79,16 +80,16 @@ def reset(self, benchmark: str) -> None:
def update(
self,
- action: int,
- observations: List[observation_t],
+ actions: List[int],
+ observations: List[ObservationType],
observation_view: ObservationView,
- ) -> float:
+ ) -> RewardType:
"""Called on env.step(). Compute and return new reward."""
if self.cost_norm is None:
self.cost_norm = self.get_cost_norm(observation_view)
- return super().update(action, observations, observation_view) / self.cost_norm
+ return super().update(actions, observations, observation_view) / self.cost_norm
- def get_cost_norm(self, observation_view: ObservationView) -> float:
+ def get_cost_norm(self, observation_view: ObservationView) -> RewardType:
"""Return the value used to normalize costs."""
return observation_view[self.init_cost_function]
@@ -104,7 +105,7 @@ def __init__(self, baseline_cost_function: str, **kwargs):
super().__init__(**kwargs)
self.baseline_cost_function: str = baseline_cost_function
- def get_cost_norm(self, observation_view: ObservationView) -> float:
+ def get_cost_norm(self, observation_view: ObservationView) -> RewardType:
"""Return the value used to normalize costs."""
init_cost = observation_view[self.init_cost_function]
baseline_cost = observation_view[self.baseline_cost_function]
diff --git a/compiler_gym/envs/llvm/make_specs.py b/compiler_gym/envs/llvm/make_specs.py
index 2da8952f2..8f686419b 100644
--- a/compiler_gym/envs/llvm/make_specs.py
+++ b/compiler_gym/envs/llvm/make_specs.py
@@ -18,8 +18,7 @@ def main(argv):
assert len(argv) == 3, "Usage: make_specs.py "
service_path, output_path = argv[1:]
- env = LlvmEnv(Path(service_path))
- try:
+ with LlvmEnv(Path(service_path)) as env:
with open(output_path, "w") as f:
print("from enum import Enum", file=f)
print(file=f)
@@ -30,8 +29,6 @@ def main(argv):
print("class reward_spaces(Enum):", file=f)
for name in env.reward.spaces:
print(f' {name} = "{name}"', file=f)
- finally:
- env.close()
if __name__ == "__main__":
diff --git a/compiler_gym/envs/llvm/service/ActionSpace.h b/compiler_gym/envs/llvm/service/ActionSpace.h
index 606c9e416..9a6abe50a 100644
--- a/compiler_gym/envs/llvm/service/ActionSpace.h
+++ b/compiler_gym/envs/llvm/service/ActionSpace.h
@@ -12,21 +12,26 @@ namespace compiler_gym::llvm_service {
// LLVM transforms. Generated by //compiler_gym/envs/llvm/service/passes:action-genfiles.
#include "compiler_gym/envs/llvm/service/passes/ActionEnum.h" // @donotremove
-// The available action spaces for LLVM.
-//
-// NOTE(cummins): Housekeeping rules - to add a new action space:
-// 1. Add a new entry to this LlvmActionSpace enum.
-// 2. Add a new switch case to getLlvmActionSpaceList() to return the
-// ActionSpace.
-// 3. Add a new switch case to LlvmSession::step() to compute
-// the actual action.
-// 4. Run `bazel test //compiler_gym/...` and update the newly failing tests.
+/**
+ * The available action spaces for LLVM.
+ *
+ * \note Implementation housekeeping rules - to add a new action space:
+ * 1. Add a new entry to this LlvmActionSpace enum.
+ * 2. Add a new switch case to getLlvmActionSpaceList() to return the
+ * ActionSpace.
+ * 3. Add a new switch case to LlvmSession::step() to compute
+ * the actual action.
+ * 4. Run `bazel test //compiler_gym/...` and update the newly failing tests.
+ */
enum class LlvmActionSpace {
- // The full set of transform passes for LLVM.
- PASSES_ALL,
+ PASSES_ALL, ///< The full set of transform passes for LLVM.
};
-// Get the list of LLVM action spaces.
+/**
+ * Get the list of LLVM action spaces.
+ *
+ * @return A list of ActionSpace instances.
+ */
std::vector getLlvmActionSpaceList();
} // namespace compiler_gym::llvm_service
diff --git a/compiler_gym/envs/llvm/service/BUILD b/compiler_gym/envs/llvm/service/BUILD
index 4f08d1dd5..ab7cc0b24 100644
--- a/compiler_gym/envs/llvm/service/BUILD
+++ b/compiler_gym/envs/llvm/service/BUILD
@@ -54,8 +54,8 @@ cc_binary(
name = "compiler_gym-llvm-service-prelinked",
srcs = ["RunService.cc"],
deps = [
- ":LlvmService",
- "//compiler_gym/util:RunService",
+ ":LlvmSession",
+ "//compiler_gym/service/runtime:cc_runtime",
],
)
@@ -105,6 +105,7 @@ cc_library(
deps = [
":Benchmark",
":Cost",
+ "//compiler_gym/service/proto:compiler_gym_service_cc",
"//compiler_gym/util:GrpcStatusMacros",
"//compiler_gym/util:RunfilesPath",
"//compiler_gym/util:StrLenConstexpr",
@@ -134,30 +135,9 @@ cc_library(
],
)
-cc_library(
- name = "LlvmService",
- srcs = ["LlvmService.cc"],
- hdrs = ["LlvmService.h"],
- visibility = ["//visibility:public"],
- deps = [
- ":Benchmark",
- ":BenchmarkFactory",
- ":Cost",
- ":LlvmSession",
- ":ObservationSpaces",
- "//compiler_gym/service/proto:compiler_gym_service_cc",
- "//compiler_gym/util:GrpcStatusMacros",
- "//compiler_gym/util:Version",
- "@boost//:filesystem",
- "@llvm//10.0.0",
- ],
-)
-
cc_library(
name = "LlvmSession",
- srcs = [
- "LlvmSession.cc",
- ],
+ srcs = ["LlvmSession.cc"],
hdrs = [
"LlvmSession.h",
"//compiler_gym/envs/llvm/service/passes:ActionHeaders.h",
@@ -167,12 +147,14 @@ cc_library(
"-DGOOGLE_PROTOBUF_NO_RTTI",
"-fno-rtti",
],
- visibility = ["//tests:__subpackages__"],
+ visibility = ["//visibility:public"],
deps = [
":ActionSpace",
":Benchmark",
+ ":BenchmarkFactory",
":Cost",
":ObservationSpaces",
+ "//compiler_gym/service:CompilationSession",
"//compiler_gym/service/proto:compiler_gym_service_cc_grpc",
"//compiler_gym/third_party/autophase:InstCount",
"//compiler_gym/third_party/cpuinfo",
@@ -186,7 +168,7 @@ cc_library(
"@magic_enum",
"@nlohmann_json//:json",
"@programl//programl/graph/format:node_link_graph",
- "@programl//programl/ir/llvm",
+ "@programl//programl/ir/llvm:llvm-10",
"@programl//programl/proto:programl_cc",
"@subprocess",
],
diff --git a/compiler_gym/envs/llvm/service/Benchmark.cc b/compiler_gym/envs/llvm/service/Benchmark.cc
index 9f136dffc..31799e8a4 100644
--- a/compiler_gym/envs/llvm/service/Benchmark.cc
+++ b/compiler_gym/envs/llvm/service/Benchmark.cc
@@ -12,6 +12,8 @@
#include "llvm/ADT/SmallVector.h"
#include "llvm/Bitcode/BitcodeReader.h"
#include "llvm/Bitcode/BitcodeWriter.h"
+#include "llvm/IR/DebugInfo.h"
+#include "llvm/IR/Verifier.h"
#include "llvm/Support/SHA1.h"
namespace fs = boost::filesystem;
@@ -86,6 +88,15 @@ std::unique_ptr makeModule(llvm::LLVMContext& context, const Bitco
module->setModuleIdentifier("-");
module->setSourceFileName("-");
+ // Strip module debug info.
+ llvm::StripDebugInfo(*module);
+
+ // Erase module-level named metadata.
+ while (!module->named_metadata_empty()) {
+ llvm::NamedMDNode* nmd = &*module->named_metadata_begin();
+ module->eraseNamedMetadata(nmd);
+ }
+
return module;
} else {
*status = Status(StatusCode::INVALID_ARGUMENT,
@@ -100,7 +111,6 @@ Benchmark::Benchmark(const std::string& name, const Bitcode& bitcode,
: context_(std::make_unique()),
module_(makeModuleOrDie(*context_, bitcode, name)),
baselineCosts_(baselineCosts),
- hash_(getModuleHash(*module_)),
name_(name),
bitcodeSize_(bitcode.size()) {}
@@ -110,7 +120,6 @@ Benchmark::Benchmark(const std::string& name, std::unique_ptr
: context_(std::move(context)),
module_(std::move(module)),
baselineCosts_(baselineCosts),
- hash_(getModuleHash(*module_)),
name_(name),
bitcodeSize_(bitcodeSize) {}
@@ -122,4 +131,16 @@ std::unique_ptr Benchmark::clone(const fs::path& workingDirectory) co
return std::make_unique(name(), bitcode, workingDirectory, baselineCosts());
}
+BenchmarkHash Benchmark::module_hash() const { return getModuleHash(*module_); }
+
+Status Benchmark::verify_module() {
+ std::string errorMessage;
+ llvm::raw_string_ostream rso(errorMessage);
+ if (llvm::verifyModule(module(), &rso)) {
+ rso.flush();
+ return Status(StatusCode::DATA_LOSS, "Failed to verify module: " + errorMessage);
+ }
+ return Status::OK;
+}
+
} // namespace compiler_gym::llvm_service
diff --git a/compiler_gym/envs/llvm/service/Benchmark.h b/compiler_gym/envs/llvm/service/Benchmark.h
index 713c65b99..4074cd788 100644
--- a/compiler_gym/envs/llvm/service/Benchmark.h
+++ b/compiler_gym/envs/llvm/service/Benchmark.h
@@ -17,61 +17,138 @@
namespace compiler_gym::llvm_service {
-// We identify benchmarks using a hash of the LLVM module, which is a
-// 160 bits SHA1.
-//
-// NOTE(cummins): In the future when we extend this to support optimizing for
-// performance, we would need this
+/**
+ * A 160 bits SHA1 that identifies an LLVM module.
+ */
using BenchmarkHash = llvm::ModuleHash;
+/**
+ * A bitcode.
+ */
using Bitcode = llvm::SmallString<0>;
+/**
+ * Read a bitcode file from disk.
+ *
+ * @param path The path of the bitcode file to read.
+ * @param bitcode The destination bitcode.
+ * @return `OK` on success, `NOT_FOUND` if the file is not found, or
+ * `INVALID_ARGUMENT` if the file is invalid.
+ */
grpc::Status readBitcodeFile(const boost::filesystem::path& path, Bitcode* bitcode);
-// Parses the given bitcode into a module and strips the identifying ModuleID
-// and source_filename attributes. Returns nullptr on error and sets status.
+/**
+ * Construct an LLVM module from a bitcode.
+ *
+ * Parses the given bitcode into a module and strips the identifying `ModuleID`
+ * and `source_filename` attributes.
+ *
+ * @param context An LLVM context for the new module.
+ * @param bitcode The bitcode to parse.
+ * @param name The name of the module.
+ * @param status An error status that is set to `OK` on success or
+ * `INVALID_ARGUMENT` if the bitcode cannot be parsed.
+ * @return A unique pointer to an LLVM module, or `nullptr` on error and sets
+ * `status`.
+ */
std::unique_ptr makeModule(llvm::LLVMContext& context, const Bitcode& bitcode,
const std::string& name, grpc::Status* status);
-// A benchmark is an LLVM module and the LLVM context that owns it. A benchmark
-// is mutable and can be changed over the course of a session.
+/**
+ * An LLVM module and the LLVM context that owns it.
+ *
+ * A benchmark is mutable and can be changed over the course of a session.
+ */
class Benchmark {
public:
+ /**
+ * Construct a benchmark from a bitcode.
+ */
Benchmark(const std::string& name, const Bitcode& bitcode,
const boost::filesystem::path& workingDirectory, const BaselineCosts& baselineCosts);
+ /**
+ * Construct a benchmark from an LLVM module.
+ */
Benchmark(const std::string& name, std::unique_ptr context,
std::unique_ptr module, size_t bitcodeSize,
const boost::filesystem::path& workingDirectory, const BaselineCosts& baselineCosts);
- // Make a copy of the benchmark.
+ /**
+ * Make a copy of the benchmark.
+ *
+ * @param workingDirectory The working directory for the new benchmark.
+ * @return A copy of the benchmark.
+ */
std::unique_ptr clone(const boost::filesystem::path& workingDirectory) const;
+ /**
+ * Compute and return a SHA1 hash of the module.
+ *
+ * @return A SHA1 hash of the module.
+ */
+ BenchmarkHash module_hash() const;
+
+ /**
+ * Wrapper around `llvm::verifyModule()` which returns an error status on
+ * failure.
+ *
+ * @return `OK` on success, else `DATA_LOSS` if verification fails.
+ */
+ grpc::Status verify_module();
+
+ /**
+ * The name of the benchmark.
+ */
inline const std::string& name() const { return name_; }
+ /**
+ * The size of the bitcode that was parsed to produce the initial benchmark.
+ */
inline const size_t bitcodeSize() const { return bitcodeSize_; }
+ /**
+ * The underlying LLVM module.
+ */
inline llvm::Module& module() { return *module_; }
+ /**
+ * The underlying LLVM module.
+ */
inline const llvm::Module& module() const { return *module_; }
+ /**
+ * The underlying LLVM context.
+ */
inline llvm::LLVMContext& context() { return *context_; }
+ /**
+ * The underlying LLVM context.
+ */
inline const llvm::LLVMContext& context() const { return *context_; }
inline const BaselineCosts& baselineCosts() const { return baselineCosts_; }
// Accessors for the underlying raw pointers.
+
+ /**
+ * A pointer to the underlying LLVM context.
+ */
inline const llvm::LLVMContext* context_ptr() const { return context_.get(); }
+ /**
+ * A pointer to the underlying LLVM module.
+ */
inline const llvm::Module* module_ptr() const { return module_.get(); }
- inline const BenchmarkHash hash() const { return hash_; }
-
- // Replace the benchmark module with a new one. This is to enable
- // out-of-process modification of the IR by serializing the benchmark to a
- // file, modifying the file, then loading the modified file and updating the
- // module pointer here.
+ /** Replace the benchmark module with a new one.
+ *
+ * This is to enable out-of-process modification of the IR by serializing the
+ * benchmark to a file, modifying the file, then loading the modified file and
+ * updating the module pointer here.
+ *
+ * @param module A new module.
+ */
inline void replaceModule(std::unique_ptr module) { module_ = std::move(module); }
private:
@@ -81,7 +158,6 @@ class Benchmark {
std::unique_ptr context_;
std::unique_ptr module_;
const BaselineCosts baselineCosts_;
- const BenchmarkHash hash_;
const std::string name_;
// The length of the bitcode string for this benchmark.
const size_t bitcodeSize_;
diff --git a/compiler_gym/envs/llvm/service/BenchmarkFactory.cc b/compiler_gym/envs/llvm/service/BenchmarkFactory.cc
index 73509654c..d63af6d32 100644
--- a/compiler_gym/envs/llvm/service/BenchmarkFactory.cc
+++ b/compiler_gym/envs/llvm/service/BenchmarkFactory.cc
@@ -23,6 +23,8 @@ namespace fs = boost::filesystem;
using grpc::Status;
using grpc::StatusCode;
+using BenchmarkProto = compiler_gym::Benchmark;
+
namespace compiler_gym::llvm_service {
BenchmarkFactory::BenchmarkFactory(const boost::filesystem::path& workingDirectory,
@@ -35,16 +37,43 @@ BenchmarkFactory::BenchmarkFactory(const boost::filesystem::path& workingDirecto
VLOG(2) << "BenchmarkFactory initialized";
}
-Status BenchmarkFactory::getBenchmark(const std::string& uri,
+Status BenchmarkFactory::getBenchmark(const BenchmarkProto& benchmarkMessage,
std::unique_ptr* benchmark) {
// Check if the benchmark has already been loaded into memory.
- auto loaded = benchmarks_.find(uri);
+ auto loaded = benchmarks_.find(benchmarkMessage.uri());
if (loaded != benchmarks_.end()) {
*benchmark = loaded->second.clone(workingDirectory_);
return Status::OK;
}
- return Status(StatusCode::NOT_FOUND, "Benchmark not found");
+ // Benchmark not cached, cache it and try again.
+ const auto& programFile = benchmarkMessage.program();
+ switch (programFile.data_case()) {
+ case compiler_gym::File::DataCase::kContents: {
+ RETURN_IF_ERROR(addBitcode(
+ benchmarkMessage.uri(),
+ llvm::SmallString<0>(programFile.contents().begin(), programFile.contents().end())));
+ break;
+ }
+ case compiler_gym::File::DataCase::kUri: {
+ // Check the protocol of the benchmark URI.
+ if (programFile.uri().find("file:///") != 0) {
+ return Status(StatusCode::INVALID_ARGUMENT,
+ fmt::format("Invalid benchmark data URI. "
+ "Only the file:/// protocol is supported: \"{}\"",
+ programFile.uri()));
+ }
+
+ const fs::path path(programFile.uri().substr(util::strLen("file:///"), std::string::npos));
+ RETURN_IF_ERROR(addBitcode(benchmarkMessage.uri(), path));
+ break;
+ }
+ case compiler_gym::File::DataCase::DATA_NOT_SET:
+ return Status(StatusCode::INVALID_ARGUMENT, fmt::format("No program set in Benchmark:\n{}",
+ benchmarkMessage.DebugString()));
+ }
+
+ return getBenchmark(benchmarkMessage, benchmark);
}
Status BenchmarkFactory::addBitcode(const std::string& uri, const Bitcode& bitcode) {
diff --git a/compiler_gym/envs/llvm/service/BenchmarkFactory.h b/compiler_gym/envs/llvm/service/BenchmarkFactory.h
index 8aef66dbd..f274641af 100644
--- a/compiler_gym/envs/llvm/service/BenchmarkFactory.h
+++ b/compiler_gym/envs/llvm/service/BenchmarkFactory.h
@@ -15,51 +15,99 @@
#include "boost/filesystem.hpp"
#include "compiler_gym/envs/llvm/service/Benchmark.h"
+#include "compiler_gym/service/proto/compiler_gym_service.pb.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Module.h"
namespace compiler_gym::llvm_service {
-// Benchmarks are loaded from disk and cached in-memory so that future uses
-// do not require a disk access. The number of benchmarks that may be
-// simultaneously loaded is limited by the combined size of the bitcodes, in
-// bytes. Once this size is reached, benchmarks are offloaded so that they must
-// be re-read from disk.
+/**
+ * Maximum number of bytes before benchmark cache eviction.
+ *
+ * Benchmarks are loaded from disk and cached in-memory so that future uses do
+ * not require a disk access. The number of benchmarks that may be
+ * simultaneously loaded is limited by the combined size of the bitcodes, in
+ * bytes. Once this size is reached, benchmarks are offloaded so that they must
+ * be re-read from disk.
+ */
constexpr size_t kMaxLoadedBenchmarkSize = 512 * 1024 * 1024;
-// A factory object for instantiating LLVM modules for use in optimization
-// sessions. Example usage:
-//
-// BenchmarkFactory factory;
-// auto benchmark = factory.getBenchmark("file:////tmp/my_bitcode.bc");
-// // ... do fun stuff
+/**
+ * A factory object for instantiating LLVM modules for use in optimization
+ * sessions.
+ *
+ * Example usage:
+ *
+ * \code{.cpp}
+ * BenchmarkFactory factory;
+ * auto benchmark = factory.getBenchmark("file:////tmp/my_bitcode.bc");
+ * // ... do fun stuff
+ * \endcode
+ */
class BenchmarkFactory {
public:
- // Construct a benchmark factory. rand is a random seed used to control the
- // selection of random benchmarks. maxLoadedBenchmarkSize is the maximum
- // combined size of the bitcodes that may be cached in memory. Once this
- // size is reached, benchmarks are offloaded so that they must be re-read from
- // disk.
- BenchmarkFactory(const boost::filesystem::path& workingDirectory,
- std::optional rand = std::nullopt,
- size_t maxLoadedBenchmarkSize = kMaxLoadedBenchmarkSize);
+ /**
+ * Return the global benchmark factory singleton.
+ *
+ * @param workingDirectory The working directory.
+ * @param rand An optional random number generator. This is used for cache
+ * evictions.
+ * @param maxLoadedBenchmarkSize The maximum size in bytes of the benchmark
+ * cache before evictions.
+ * @return The benchmark factory singleton instance.
+ */
+ static BenchmarkFactory& getSingleton(const boost::filesystem::path& workingDirectory,
+ std::optional rand = std::nullopt,
+ size_t maxLoadedBenchmarkSize = kMaxLoadedBenchmarkSize) {
+ static BenchmarkFactory instance(workingDirectory, rand, maxLoadedBenchmarkSize);
+ return instance;
+ }
- // Get the requested named benchmark.
- [[nodiscard]] grpc::Status getBenchmark(const std::string& uri,
+ /**
+ * Get the requested named benchmark.
+ *
+ * @param benchmarkMessage A Benchmark protocol message.
+ * @param benchmark A benchmark instance to assign this benchmark to.
+ * @return `OK` on success, or `INVALID_ARGUMENT` if the protocol message is
+ * invalid.
+ */
+ [[nodiscard]] grpc::Status getBenchmark(const compiler_gym::Benchmark& benchmarkMessage,
std::unique_ptr* benchmark);
+ private:
[[nodiscard]] grpc::Status addBitcode(const std::string& uri, const Bitcode& bitcode);
[[nodiscard]] grpc::Status addBitcode(const std::string& uri,
const boost::filesystem::path& path);
- private:
- // A mapping from URI to benchmarks which have been loaded into memory.
+ /**
+ * Construct a benchmark factory.
+ *
+ * @param workingDirectory A filesystem directory to use for storing temporary
+ * files.
+ * @param rand is a random seed used to control the selection of random
+ * benchmarks.
+ * @param maxLoadedBenchmarkSize is the maximum combined size of the bitcodes
+ * that may be cached in memory. Once this size is reached, benchmarks are
+ * offloaded so that they must be re-read from disk.
+ */
+ BenchmarkFactory(const boost::filesystem::path& workingDirectory,
+ std::optional rand = std::nullopt,
+ size_t maxLoadedBenchmarkSize = kMaxLoadedBenchmarkSize);
+
+ BenchmarkFactory(const BenchmarkFactory&) = delete;
+ BenchmarkFactory& operator=(const BenchmarkFactory&) = delete;
+
+ /**
+ * A mapping from URI to benchmarks which have been loaded into memory.
+ */
std::unordered_map benchmarks_;
const boost::filesystem::path workingDirectory_;
std::mt19937_64 rand_;
- // The current and maximum allowed sizes of the loaded benchmarks.
+ /**
+ * The current and maximum allowed sizes of the loaded benchmarks.
+ */
size_t loadedBenchmarksSize_;
const size_t maxLoadedBenchmarkSize_;
};
diff --git a/compiler_gym/envs/llvm/service/Cost.h b/compiler_gym/envs/llvm/service/Cost.h
index 4b520cc97..c5a467100 100644
--- a/compiler_gym/envs/llvm/service/Cost.h
+++ b/compiler_gym/envs/llvm/service/Cost.h
@@ -14,22 +14,35 @@
namespace compiler_gym::llvm_service {
+/**
+ * A cost function for LLVM benchmarks.
+ */
enum class LlvmCostFunction {
- // The number of instructions in the LLVM-IR module. This is fast to compute
- // and deterministic.
+ /**
+ * The number of instructions in the LLVM-IR module.
+ *
+ * IR instruction count is fast to compute and deterministic.
+ */
IR_INSTRUCTION_COUNT,
- // Returns the size (in bytes) of the .TEXT section of the compiled module.
+ /**
+ * Returns the size (in bytes) of the .TEXT section of the compiled module.
+ */
OBJECT_TEXT_SIZE_BYTES,
#ifdef COMPILER_GYM_EXPERIMENTAL_TEXT_SIZE_COST
- // Returns the size (in bytes) of the .TEXT section of the compiled binary.
+ /**
+ * Returns the size (in bytes) of the .TEXT section of the compiled binary.
+ */
TEXT_SIZE_BYTES,
#endif
};
+/**
+ * LLVM's builtin policies.
+ */
enum class LlvmBaselinePolicy {
- O0, // No optimizations.
- O3, // -O3 optimizations.
- Oz, // -Oz optimizations.
+ O0, ///< No optimizations.
+ O3, ///< `-O3` optimizations.
+ Oz, ///< `-Oz` optimizations.
};
constexpr size_t numCosts = magic_enum::enum_count();
@@ -38,20 +51,41 @@ constexpr size_t numBaselineCosts = magic_enum::enum_count()
using BaselineCosts = std::array;
using PreviousCosts = std::array, numCosts>;
-// TODO(cummins): Refactor cost calculation to allow graceful error handling
-// by returning a grpc::Status.
-
-// Compute the cost using a given cost function. A lower cost is better.
+/**
+ * Compute the cost using a given cost function. A lower cost is better.
+ *
+ * @param costFunction The cost function to use.
+ * @param module The module to compute the cost for.
+ * @param workingDirectory A directory that can be used for temporary file
+ * storage.
+ * @param cost The cost to write.
+ * @return `OK` on success.
+ */
[[nodiscard]] grpc::Status setCost(const LlvmCostFunction& costFunction, llvm::Module& module,
const boost::filesystem::path& workingDirectory, double* cost);
-// Return a baseline cost.
+/**
+ * Return a baseline cost.
+ *
+ * @param baselineCosts The baseline costs list.
+ * @param policy The baseline policy to return the cost of.
+ * @param cost The cost function to use.
+ * @return A cost.
+ */
double getBaselineCost(const BaselineCosts& baselineCosts, LlvmBaselinePolicy policy,
LlvmCostFunction cost);
-// Compute the costs of baseline policies. The unoptimizedModule parameter is
-// unmodified, but is not const because various LLVM API calls require a mutable
-// reference.
+/**
+ * Compute the costs of baseline policies.
+ *
+ * \note The `unoptimizedModule` parameter is unmodified, but is not const
+ * because various LLVM API calls require a mutable reference.
+ *
+ * @param unoptimizedModule The module to compute the baseline costs of.
+ * @param baselineCosts The costs to write.
+ * @param workingDirectory A directory that can be used for temporary file
+ * storage.
+ */
[[nodiscard]] grpc::Status setBaselineCosts(llvm::Module& unoptimizedModule,
BaselineCosts* baselineCosts,
const boost::filesystem::path& workingDirectory);
diff --git a/compiler_gym/envs/llvm/service/LlvmService.cc b/compiler_gym/envs/llvm/service/LlvmService.cc
deleted file mode 100644
index 5627deb8b..000000000
--- a/compiler_gym/envs/llvm/service/LlvmService.cc
+++ /dev/null
@@ -1,196 +0,0 @@
-// Copyright (c) Facebook, Inc. and its affiliates.
-//
-// This source code is licensed under the MIT license found in the
-// LICENSE file in the root directory of this source tree.
-#include "compiler_gym/envs/llvm/service/LlvmService.h"
-
-#include
-
-#include
-#include
-
-#include "compiler_gym/envs/llvm/service/ActionSpace.h"
-#include "compiler_gym/envs/llvm/service/ObservationSpaces.h"
-#include "compiler_gym/service/proto/compiler_gym_service.pb.h"
-#include "compiler_gym/util/EnumUtil.h"
-#include "compiler_gym/util/GrpcStatusMacros.h"
-#include "compiler_gym/util/StrLenConstexpr.h"
-#include "compiler_gym/util/Version.h"
-#include "llvm/ADT/Triple.h"
-#include "llvm/Config/llvm-config.h"
-
-namespace compiler_gym::llvm_service {
-
-using grpc::ServerContext;
-using grpc::Status;
-using grpc::StatusCode;
-namespace fs = boost::filesystem;
-
-LlvmService::LlvmService(const fs::path& workingDirectory)
- : workingDirectory_(workingDirectory), benchmarkFactory_(workingDirectory), nextSessionId_(0) {}
-
-Status LlvmService::GetVersion(ServerContext* /* unused */, const GetVersionRequest* /* unused */,
- GetVersionReply* reply) {
- VLOG(2) << "GetSpaces()";
- reply->set_service_version(COMPILER_GYM_VERSION);
- std::stringstream ss;
- ss << LLVM_VERSION_STRING << " " << llvm::Triple::normalize(LLVM_DEFAULT_TARGET_TRIPLE);
- reply->set_compiler_version(ss.str());
- return Status::OK;
-}
-
-Status LlvmService::GetSpaces(ServerContext* /* unused */, const GetSpacesRequest* /* unused */,
- GetSpacesReply* reply) {
- VLOG(2) << "GetSpaces()";
- const auto actionSpaces = getLlvmActionSpaceList();
- *reply->mutable_action_space_list() = {actionSpaces.begin(), actionSpaces.end()};
- const auto observationSpaces = getLlvmObservationSpaceList();
- *reply->mutable_observation_space_list() = {observationSpaces.begin(), observationSpaces.end()};
-
- return Status::OK;
-}
-
-Status LlvmService::StartSession(ServerContext* /* unused */, const StartSessionRequest* request,
- StartSessionReply* reply) {
- const std::lock_guard lock(sessionsMutex_);
-
- if (!request->benchmark().size()) {
- return Status(StatusCode::INVALID_ARGUMENT, "No benchmark URI set for StartSession()");
- }
-
- std::unique_ptr benchmark;
- RETURN_IF_ERROR(benchmarkFactory_.getBenchmark(request->benchmark(), &benchmark));
-
- reply->set_benchmark(benchmark->name());
- VLOG(1) << "StartSession(" << benchmark->name() << "), [" << nextSessionId_ << "]";
-
- LlvmActionSpace actionSpace;
- RETURN_IF_ERROR(util::intToEnum(request->action_space(), &actionSpace));
-
- // Construct the environment.
- auto session =
- std::make_unique(std::move(benchmark), actionSpace, workingDirectory_);
-
- // Compute the initial observations.
- for (int i = 0; i < request->observation_space_size(); ++i) {
- LlvmObservationSpace observationSpace;
- RETURN_IF_ERROR(util::intToEnum(request->observation_space(i), &observationSpace));
- auto observation = reply->add_observation();
- RETURN_IF_ERROR(session->getObservation(observationSpace, observation));
- }
-
- reply->set_session_id(nextSessionId_);
- sessions_[nextSessionId_] = std::move(session);
- ++nextSessionId_;
-
- return Status::OK;
-}
-
-Status LlvmService::ForkSession(ServerContext* /* unused */, const ForkSessionRequest* request,
- ForkSessionReply* reply) {
- const std::lock_guard lock(sessionsMutex_);
-
- LlvmSession* environment;
- RETURN_IF_ERROR(session(request->session_id(), &environment));
- VLOG(1) << "ForkSession(" << request->session_id() << "), [" << nextSessionId_ << "]";
-
- // Construct the environment.
- reply->set_session_id(nextSessionId_);
- sessions_[nextSessionId_] =
- std::make_unique(environment->benchmark().clone(environment->workingDirectory()),
- environment->actionSpace(), environment->workingDirectory());
-
- ++nextSessionId_;
-
- return Status::OK;
-}
-
-Status LlvmService::EndSession(grpc::ServerContext* /* unused */, const EndSessionRequest* request,
- EndSessionReply* reply) {
- const std::lock_guard lock(sessionsMutex_);
-
- // Note that unlike the other methods, no error is thrown if the requested
- // session does not exist.
- if (sessions_.find(request->session_id()) != sessions_.end()) {
- const LlvmSession* environment;
- RETURN_IF_ERROR(session(request->session_id(), &environment));
- VLOG(1) << "Step " << environment->actionCount() << " EndSession("
- << environment->benchmark().name() << "), [" << request->session_id() << "]";
-
- sessions_.erase(request->session_id());
- }
-
- reply->set_remaining_sessions(sessions_.size());
- return Status::OK;
-}
-
-Status LlvmService::Step(ServerContext* /* unused */, const StepRequest* request,
- StepReply* reply) {
- LlvmSession* environment;
- RETURN_IF_ERROR(session(request->session_id(), &environment));
-
- VLOG(2) << "Step " << environment->actionCount() << " Step()";
- return environment->step(*request, reply);
-}
-
-Status LlvmService::AddBenchmark(ServerContext* /* unused */, const AddBenchmarkRequest* request,
- AddBenchmarkReply* reply) {
- VLOG(2) << "AddBenchmark()";
- for (int i = 0; i < request->benchmark_size(); ++i) {
- RETURN_IF_ERROR(addBenchmark(request->benchmark(i)));
- }
-
- return Status::OK;
-}
-
-Status LlvmService::addBenchmark(const ::compiler_gym::Benchmark& request) {
- const std::string& uri = request.uri();
- if (!uri.size()) {
- return Status(StatusCode::INVALID_ARGUMENT, "Benchmark must have a URI");
- }
-
- const auto& programFile = request.program();
- switch (programFile.data_case()) {
- case ::compiler_gym::File::DataCase::kContents:
- return benchmarkFactory_.addBitcode(
- uri, llvm::SmallString<0>(programFile.contents().begin(), programFile.contents().end()));
- case ::compiler_gym::File::DataCase::kUri: {
- // Check that protocol of the benmchmark URI.
- if (programFile.uri().find("file:///") != 0) {
- return Status(StatusCode::INVALID_ARGUMENT,
- fmt::format("Invalid benchmark data URI. "
- "Only the file:/// protocol is supported: \"{}\"",
- programFile.uri()));
- }
-
- const fs::path path(programFile.uri().substr(util::strLen("file:///"), std::string::npos));
- return benchmarkFactory_.addBitcode(uri, path);
- }
- case ::compiler_gym::File::DataCase::DATA_NOT_SET:
- return Status(StatusCode::INVALID_ARGUMENT, "No program set");
- }
-
- return Status::OK;
-}
-
-Status LlvmService::session(uint64_t id, LlvmSession** environment) {
- auto it = sessions_.find(id);
- if (it == sessions_.end()) {
- return Status(StatusCode::NOT_FOUND, fmt::format("Session not found: {}", id));
- }
-
- *environment = it->second.get();
- return Status::OK;
-}
-
-Status LlvmService::session(uint64_t id, const LlvmSession** environment) const {
- auto it = sessions_.find(id);
- if (it == sessions_.end()) {
- return Status(StatusCode::NOT_FOUND, fmt::format("Session not found: {}", id));
- }
-
- *environment = it->second.get();
- return Status::OK;
-}
-
-} // namespace compiler_gym::llvm_service
diff --git a/compiler_gym/envs/llvm/service/LlvmService.h b/compiler_gym/envs/llvm/service/LlvmService.h
deleted file mode 100644
index 4e321e1f9..000000000
--- a/compiler_gym/envs/llvm/service/LlvmService.h
+++ /dev/null
@@ -1,66 +0,0 @@
-// Copyright (c) Facebook, Inc. and its affiliates.
-//
-// This source code is licensed under the MIT license found in the
-// LICENSE file in the root directory of this source tree.
-#pragma once
-
-#include
-
-#include
-#include
-
-#include "boost/filesystem.hpp"
-#include "compiler_gym/envs/llvm/service/Benchmark.h"
-#include "compiler_gym/envs/llvm/service/BenchmarkFactory.h"
-#include "compiler_gym/envs/llvm/service/LlvmSession.h"
-#include "compiler_gym/service/proto/compiler_gym_service.grpc.pb.h"
-#include "compiler_gym/service/proto/compiler_gym_service.pb.h"
-
-namespace compiler_gym::llvm_service {
-
-// RPC service for LLVM.
-class LlvmService final : public CompilerGymService::Service {
- public:
- explicit LlvmService(const boost::filesystem::path& workingDirectory);
-
- // RPC endpoints.
- grpc::Status GetVersion(grpc::ServerContext* context, const GetVersionRequest* request,
- GetVersionReply* reply) final override;
-
- grpc::Status GetSpaces(grpc::ServerContext* context, const GetSpacesRequest* request,
- GetSpacesReply* reply) final override;
-
- grpc::Status StartSession(grpc::ServerContext* context, const StartSessionRequest* request,
- StartSessionReply* reply) final override;
-
- grpc::Status ForkSession(grpc::ServerContext* context, const ForkSessionRequest* request,
- ForkSessionReply* reply) final override;
-
- grpc::Status EndSession(grpc::ServerContext* context, const EndSessionRequest* request,
- EndSessionReply* reply) final override;
-
- // NOTE: Step() is not thread safe. The underlying assumption is that each
- // LlvmSession is managed by a single thread, so race conditions between
- // operations that affect the same LlvmSession are not protected against.
- grpc::Status Step(grpc::ServerContext* context, const StepRequest* request,
- StepReply* reply) final override;
-
- grpc::Status AddBenchmark(grpc::ServerContext* context, const AddBenchmarkRequest* request,
- AddBenchmarkReply* reply) final override;
-
- protected:
- grpc::Status session(uint64_t id, LlvmSession** environment);
- grpc::Status session(uint64_t id, const LlvmSession** environment) const;
-
- grpc::Status addBenchmark(const ::compiler_gym::Benchmark& request);
-
- private:
- const boost::filesystem::path workingDirectory_;
- std::unordered_map> sessions_;
- // Mutex used to ensure thread safety of creation and destruction of sessions.
- std::mutex sessionsMutex_;
- BenchmarkFactory benchmarkFactory_;
- uint64_t nextSessionId_;
-};
-
-} // namespace compiler_gym::llvm_service
diff --git a/compiler_gym/envs/llvm/service/LlvmSession.cc b/compiler_gym/envs/llvm/service/LlvmSession.cc
index 5a1c281f3..0052a87c6 100644
--- a/compiler_gym/envs/llvm/service/LlvmSession.cc
+++ b/compiler_gym/envs/llvm/service/LlvmSession.cc
@@ -8,12 +8,16 @@
#include
#include
+#include
#include
#include
#include "boost/filesystem.hpp"
#include "compiler_gym/envs/llvm/service/ActionSpace.h"
+#include "compiler_gym/envs/llvm/service/Benchmark.h"
+#include "compiler_gym/envs/llvm/service/BenchmarkFactory.h"
#include "compiler_gym/envs/llvm/service/Cost.h"
+#include "compiler_gym/envs/llvm/service/ObservationSpaces.h"
#include "compiler_gym/envs/llvm/service/passes/ActionHeaders.h"
#include "compiler_gym/envs/llvm/service/passes/ActionSwitch.h"
#include "compiler_gym/third_party/autophase/InstCount.h"
@@ -23,11 +27,9 @@
#include "compiler_gym/util/RunfilesPath.h"
#include "llvm/Bitcode/BitcodeWriter.h"
#include "llvm/CodeGen/Passes.h"
-#include "llvm/IR/DebugInfo.h"
#include "llvm/IR/LegacyPassManager.h"
#include "llvm/IR/Metadata.h"
#include "llvm/IR/Module.h"
-#include "llvm/IR/Verifier.h"
#include "llvm/InitializePasses.h"
#include "llvm/Pass.h"
#include "llvm/Support/TargetSelect.h"
@@ -44,6 +46,9 @@ using grpc::Status;
using grpc::StatusCode;
using nlohmann::json;
+using BenchmarkProto = compiler_gym::Benchmark;
+using ActionSpaceProto = compiler_gym::ActionSpace;
+
namespace {
// Return the target library information for a module.
@@ -52,18 +57,6 @@ llvm::TargetLibraryInfoImpl getTargetLibraryInfo(llvm::Module& module) {
return llvm::TargetLibraryInfoImpl(triple);
}
-// Wrapper around llvm::verifyModule() which raises the given exception type
-// on failure.
-Status verifyModuleStatus(const llvm::Module& module) {
- std::string errorMessage;
- llvm::raw_string_ostream rso(errorMessage);
- if (llvm::verifyModule(module, &rso)) {
- rso.flush();
- return Status(StatusCode::DATA_LOSS, "Failed to verify module: " + errorMessage);
- }
- return Status::OK;
-}
-
void initLlvm() {
llvm::InitializeAllTargets();
llvm::InitializeAllTargetMCs();
@@ -118,59 +111,98 @@ Status writeBitcodeToFile(const llvm::Module& module, const fs::path& path) {
} // anonymous namespace
-LlvmSession::LlvmSession(std::unique_ptr benchmark, LlvmActionSpace actionSpace,
- const boost::filesystem::path& workingDirectory)
- : workingDirectory_(workingDirectory),
- benchmark_(std::move(benchmark)),
- actionSpace_(actionSpace),
- tlii_(getTargetLibraryInfo(benchmark_->module())),
- actionCount_(0) {
- // Initialize LLVM.
- initLlvm();
+std::string LlvmSession::getCompilerVersion() const {
+ std::stringstream ss;
+ ss << LLVM_VERSION_STRING << " " << llvm::Triple::normalize(LLVM_DEFAULT_TARGET_TRIPLE);
+ return ss.str();
+}
- // Initialize cpuinfo
+std::vector LlvmSession::getActionSpaces() const { return getLlvmActionSpaceList(); }
+
+std::vector LlvmSession::getObservationSpaces() const {
+ return getLlvmObservationSpaceList();
+}
+
+LlvmSession::LlvmSession(const boost::filesystem::path& workingDirectory)
+ : CompilationSession(workingDirectory),
+ observationSpaceNames_(util::createPascalCaseToEnumLookupTable()) {
+ initLlvm();
cpuinfo_initialize();
+}
- // Strip module debug info.
- llvm::StripDebugInfo(benchmark_->module());
+Status LlvmSession::init(const ActionSpace& actionSpace, const BenchmarkProto& benchmark) {
+ BenchmarkFactory& benchmarkFactory = BenchmarkFactory::getSingleton(workingDirectory());
- // Erase module-level named metadata.
- while (!benchmark_->module().named_metadata_empty()) {
- llvm::NamedMDNode* nmd = &*benchmark_->module().named_metadata_begin();
- benchmark_->module().eraseNamedMetadata(nmd);
- }
+ // Get the benchmark or return an error.
+ std::unique_ptr llvmBenchmark;
+ RETURN_IF_ERROR(benchmarkFactory.getBenchmark(benchmark, &llvmBenchmark));
+
+ // Verify the benchmark now to catch errors early.
+ RETURN_IF_ERROR(llvmBenchmark->verify_module());
+
+ LlvmActionSpace actionSpaceEnum;
+ RETURN_IF_ERROR(util::pascalCaseToEnum(actionSpace.name(), &actionSpaceEnum));
+
+ return init(actionSpaceEnum, std::move(llvmBenchmark));
+}
+
+Status LlvmSession::init(CompilationSession* other) {
+ // TODO: Static cast?
+ auto llvmOther = static_cast(other);
+ return init(llvmOther->actionSpace(), llvmOther->benchmark().clone(workingDirectory()));
+}
+
+Status LlvmSession::init(const LlvmActionSpace& actionSpace, std::unique_ptr benchmark) {
+ benchmark_ = std::move(benchmark);
+ actionSpace_ = actionSpace;
+
+ tlii_ = getTargetLibraryInfo(benchmark_->module());
// Verify the module now to catch any problems early.
- CHECK(verifyModuleStatus(benchmark_->module()).ok());
+ return Status::OK;
}
-Status LlvmSession::step(const StepRequest& request, StepReply* reply) {
- // Apply the requested actions.
- actionCount_ += request.action_size();
+Status LlvmSession::applyAction(const Action& action, bool& endOfEpisode,
+ std::optional& newActionSpace,
+ bool& actionHadNoEffect) {
+ DCHECK(benchmark_) << "Calling applyAction() before init()";
+
+ // Apply the requested action.
switch (actionSpace()) {
case LlvmActionSpace::PASSES_ALL:
- for (int i = 0; i < request.action_size(); ++i) {
- LlvmAction action;
- RETURN_IF_ERROR(util::intToEnum(request.action(i), &action));
- RETURN_IF_ERROR(runAction(action, reply));
- }
+ LlvmAction actionEnum;
+ RETURN_IF_ERROR(util::intToEnum(action.action(), &actionEnum));
+ RETURN_IF_ERROR(applyPassAction(actionEnum, actionHadNoEffect));
}
- // Fail now if we have broken something.
- RETURN_IF_ERROR(verifyModuleStatus(benchmark().module()));
+ return Status::OK;
+}
- // Compute the requested observations.
- for (int i = 0; i < request.observation_space_size(); ++i) {
- LlvmObservationSpace observationSpace;
- RETURN_IF_ERROR(util::intToEnum(request.observation_space(i), &observationSpace));
- auto observation = reply->add_observation();
- RETURN_IF_ERROR(getObservation(observationSpace, observation));
+Status LlvmSession::endOfStep(bool actionHadNoEffect, bool& endOfEpisode,
+ std::optional& newActionSpace) {
+ if (actionHadNoEffect) {
+ return Status::OK;
+ } else {
+ return benchmark().verify_module();
}
+}
+
+Status LlvmSession::computeObservation(const ObservationSpace& observationSpace,
+ Observation& observation) {
+ DCHECK(benchmark_) << "Calling computeObservation() before init()";
+ const auto& it = observationSpaceNames_.find(observationSpace.name());
+ if (it == observationSpaceNames_.end()) {
+ return Status(
+ StatusCode::INVALID_ARGUMENT,
+ fmt::format("Could not interpret observation space name: {}", observationSpace.name()));
+ }
+ const LlvmObservationSpace observationSpaceEnum = it->second;
+ RETURN_IF_ERROR(computeObservation(observationSpaceEnum, observation));
return Status::OK;
}
-Status LlvmSession::runAction(LlvmAction action, StepReply* reply) {
+Status LlvmSession::applyPassAction(LlvmAction action, bool& actionHadNoEffect) {
#ifdef EXPERIMENTAL_UNSTABLE_GVN_SINK_PASS
// NOTE(https://github.com/facebookresearch/CompilerGym/issues/46): The
// -gvn-sink pass has been found to have nondeterministic behavior so has
@@ -178,28 +210,27 @@ Status LlvmSession::runAction(LlvmAction action, StepReply* reply) {
// the command line was found to produce more stable results.
if (action == LlvmAction::GVNSINK_PASS) {
RETURN_IF_ERROR(runOptWithArgs({"-gvn-sink"}));
- reply->set_action_had_no_effect(true);
+ actionHadNoEffect = true;
return Status::OK;
}
#endif
// Use the generated HANDLE_PASS() switch statement to dispatch to runPass().
-#define HANDLE_PASS(pass) runPass(pass, reply);
+#define HANDLE_PASS(pass) actionHadNoEffect = !runPass(pass);
HANDLE_ACTION(action, HANDLE_PASS)
#undef HANDLE_PASS
return Status::OK;
}
-void LlvmSession::runPass(llvm::Pass* pass, StepReply* reply) {
+bool LlvmSession::runPass(llvm::Pass* pass) {
llvm::legacy::PassManager passManager;
setupPassManager(&passManager, pass);
- const bool changed = passManager.run(benchmark().module());
- reply->set_action_had_no_effect(!changed);
+ return passManager.run(benchmark().module());
}
-void LlvmSession::runPass(llvm::FunctionPass* pass, StepReply* reply) {
+bool LlvmSession::runPass(llvm::FunctionPass* pass) {
llvm::legacy::FunctionPassManager passManager(&benchmark().module());
setupPassManager(&passManager, pass);
@@ -208,13 +239,13 @@ void LlvmSession::runPass(llvm::FunctionPass* pass, StepReply* reply) {
changed |= (passManager.run(function) ? 1 : 0);
}
changed |= (passManager.doFinalization() ? 1 : 0);
- reply->set_action_had_no_effect(!changed);
+ return changed;
}
Status LlvmSession::runOptWithArgs(const std::vector& optArgs) {
// Create temporary files for `opt` to read from and write to.
- const auto before_path = fs::unique_path(workingDirectory_ / "module-%%%%%%%%.bc");
- const auto after_path = fs::unique_path(workingDirectory_ / "module-%%%%%%%%.bc");
+ const auto before_path = fs::unique_path(workingDirectory() / "module-%%%%%%%%.bc");
+ const auto after_path = fs::unique_path(workingDirectory() / "module-%%%%%%%%.bc");
RETURN_IF_ERROR(writeBitcodeToFile(benchmark().module(), before_path));
// Build a command line invocation: `opt input.bc -o output.bc `.
@@ -260,31 +291,43 @@ Status LlvmSession::runOptWithArgs(const std::vector& optArgs) {
return Status::OK;
}
-Status LlvmSession::getObservation(LlvmObservationSpace space, Observation* reply) {
+Status LlvmSession::computeObservation(LlvmObservationSpace space, Observation& reply) {
switch (space) {
case LlvmObservationSpace::IR: {
// Serialize the LLVM module to an IR string.
std::string ir;
llvm::raw_string_ostream rso(ir);
benchmark().module().print(rso, /*AAW=*/nullptr);
- reply->set_string_value(ir);
+ reply.set_string_value(ir);
+ break;
+ }
+ case LlvmObservationSpace::IR_SHA1: {
+ std::stringstream ss;
+ const BenchmarkHash hash = benchmark().module_hash();
+ // Hex encode, zero pad, and concatenate the unsigned integers that
+ // contain the hash.
+ for (uint32_t val : hash) {
+ ss << std::setfill('0') << std::setw(sizeof(BenchmarkHash::value_type) * 2) << std::hex
+ << val;
+ }
+ reply.set_string_value(ss.str());
break;
}
case LlvmObservationSpace::BITCODE_FILE: {
// Generate an output path with 16 bits of randomness.
- const auto outpath = fs::unique_path(workingDirectory_ / "module-%%%%%%%%.bc");
+ const auto outpath = fs::unique_path(workingDirectory() / "module-%%%%%%%%.bc");
RETURN_IF_ERROR(writeBitcodeToFile(benchmark().module(), outpath));
- reply->set_string_value(outpath.string());
+ reply.set_string_value(outpath.string());
break;
}
case LlvmObservationSpace::INST_COUNT: {
const auto features = InstCount::getFeatureVector(benchmark().module());
- *reply->mutable_int64_list()->mutable_value() = {features.begin(), features.end()};
+ *reply.mutable_int64_list()->mutable_value() = {features.begin(), features.end()};
break;
}
case LlvmObservationSpace::AUTOPHASE: {
const auto features = autophase::InstCount::getFeatureVector(benchmark().module());
- *reply->mutable_int64_list()->mutable_value() = {features.begin(), features.end()};
+ *reply.mutable_int64_list()->mutable_value() = {features.begin(), features.end()};
break;
}
case LlvmObservationSpace::PROGRAML: {
@@ -302,7 +345,7 @@ Status LlvmSession::getObservation(LlvmObservationSpace space, Observation* repl
if (!status.ok()) {
return Status(StatusCode::INTERNAL, status.error_message());
}
- *reply->mutable_string_value() = nodeLinkGraph.dump();
+ *reply.mutable_string_value() = nodeLinkGraph.dump();
break;
}
case LlvmObservationSpace::CPU_INFO: {
@@ -327,83 +370,83 @@ Status LlvmSession::getObservation(LlvmObservationSpace space, Observation* repl
hwinfo["cores_count"] = cpuinfo_get_cores_count();
auto cpu = cpuinfo_get_packages();
hwinfo["name"] = cpu->name;
- *reply->mutable_string_value() = hwinfo.dump();
+ *reply.mutable_string_value() = hwinfo.dump();
break;
}
case LlvmObservationSpace::IR_INSTRUCTION_COUNT: {
double cost;
RETURN_IF_ERROR(setCost(LlvmCostFunction::IR_INSTRUCTION_COUNT, benchmark().module(),
- workingDirectory_, &cost));
- reply->set_scalar_int64(static_cast(cost));
+ workingDirectory(), &cost));
+ reply.set_scalar_int64(static_cast(cost));
break;
}
case LlvmObservationSpace::IR_INSTRUCTION_COUNT_O0: {
const auto cost = getBaselineCost(benchmark().baselineCosts(), LlvmBaselinePolicy::O0,
LlvmCostFunction::IR_INSTRUCTION_COUNT);
- reply->set_scalar_int64(static_cast(cost));
+ reply.set_scalar_int64(static_cast(cost));
break;
}
case LlvmObservationSpace::IR_INSTRUCTION_COUNT_O3: {
const auto cost = getBaselineCost(benchmark().baselineCosts(), LlvmBaselinePolicy::O3,
LlvmCostFunction::IR_INSTRUCTION_COUNT);
- reply->set_scalar_int64(static_cast(cost));
+ reply.set_scalar_int64(static_cast(cost));
break;
}
case LlvmObservationSpace::IR_INSTRUCTION_COUNT_OZ: {
const auto cost = getBaselineCost(benchmark().baselineCosts(), LlvmBaselinePolicy::Oz,
LlvmCostFunction::IR_INSTRUCTION_COUNT);
- reply->set_scalar_int64(static_cast(cost));
+ reply.set_scalar_int64(static_cast(cost));
break;
}
case LlvmObservationSpace::OBJECT_TEXT_SIZE_BYTES: {
double cost;
RETURN_IF_ERROR(setCost(LlvmCostFunction::OBJECT_TEXT_SIZE_BYTES, benchmark().module(),
- workingDirectory_, &cost));
- reply->set_scalar_int64(static_cast(cost));
+ workingDirectory(), &cost));
+ reply.set_scalar_int64(static_cast(cost));
break;
}
case LlvmObservationSpace::OBJECT_TEXT_SIZE_O0: {
const auto cost = getBaselineCost(benchmark().baselineCosts(), LlvmBaselinePolicy::O0,
LlvmCostFunction::OBJECT_TEXT_SIZE_BYTES);
- reply->set_scalar_int64(static_cast(cost));
+ reply.set_scalar_int64(static_cast(cost));
break;
}
case LlvmObservationSpace::OBJECT_TEXT_SIZE_O3: {
const auto cost = getBaselineCost(benchmark().baselineCosts(), LlvmBaselinePolicy::O3,
LlvmCostFunction::OBJECT_TEXT_SIZE_BYTES);
- reply->set_scalar_int64(static_cast(cost));
+ reply.set_scalar_int64(static_cast(cost));
break;
}
case LlvmObservationSpace::OBJECT_TEXT_SIZE_OZ: {
const auto cost = getBaselineCost(benchmark().baselineCosts(), LlvmBaselinePolicy::Oz,
LlvmCostFunction::OBJECT_TEXT_SIZE_BYTES);
- reply->set_scalar_int64(static_cast(cost));
+ reply.set_scalar_int64(static_cast(cost));
break;
}
#ifdef COMPILER_GYM_EXPERIMENTAL_TEXT_SIZE_COST
case LlvmObservationSpace::TEXT_SIZE_BYTES: {
double cost;
RETURN_IF_ERROR(setCost(LlvmCostFunction::TEXT_SIZE_BYTES, benchmark().module(),
- workingDirectory_, &cost));
- reply->set_scalar_int64(static_cast(cost));
+ workingDirectory(), &cost));
+ reply.set_scalar_int64(static_cast(cost));
break;
}
case LlvmObservationSpace::TEXT_SIZE_O0: {
const auto cost = getBaselineCost(benchmark().baselineCosts(), LlvmBaselinePolicy::O0,
LlvmCostFunction::TEXT_SIZE_BYTES);
- reply->set_scalar_int64(static_cast(cost));
+ reply.set_scalar_int64(static_cast(cost));
break;
}
case LlvmObservationSpace::TEXT_SIZE_O3: {
const auto cost = getBaselineCost(benchmark().baselineCosts(), LlvmBaselinePolicy::O3,
LlvmCostFunction::TEXT_SIZE_BYTES);
- reply->set_scalar_int64(static_cast(cost));
+ reply.set_scalar_int64(static_cast(cost));
break;
}
case LlvmObservationSpace::TEXT_SIZE_OZ: {
const auto cost = getBaselineCost(benchmark().baselineCosts(), LlvmBaselinePolicy::Oz,
LlvmCostFunction::TEXT_SIZE_BYTES);
- reply->set_scalar_int64(static_cast(cost));
+ reply.set_scalar_int64(static_cast(cost));
break;
}
#endif
diff --git a/compiler_gym/envs/llvm/service/LlvmSession.h b/compiler_gym/envs/llvm/service/LlvmSession.h
index e9d55e3b1..69e050792 100644
--- a/compiler_gym/envs/llvm/service/LlvmSession.h
+++ b/compiler_gym/envs/llvm/service/LlvmSession.h
@@ -4,16 +4,19 @@
// LICENSE file in the root directory of this source tree.
#pragma once
+#include
#include
#include
#include
#include
+#include
#include "compiler_gym/envs/llvm/service/ActionSpace.h"
#include "compiler_gym/envs/llvm/service/Benchmark.h"
#include "compiler_gym/envs/llvm/service/Cost.h"
#include "compiler_gym/envs/llvm/service/ObservationSpaces.h"
+#include "compiler_gym/service/CompilationSession.h"
#include "compiler_gym/service/proto/compiler_gym_service.grpc.pb.h"
#include "llvm/Analysis/ProfileSummaryInfo.h"
#include "llvm/Analysis/TargetLibraryInfo.h"
@@ -25,51 +28,91 @@
namespace compiler_gym::llvm_service {
-// This class exposes the LLVM optimization pipeline for an LLVM module as an
-// interactive environment.
-//
-// It can be used directly as a C++ API, or it can be accessed through an RPC
-// interface using the compiler_gym::service::LlvmService class.
-class LlvmSession {
+/**
+ * An interactive LLVM compilation session.
+ *
+ * This class exposes the LLVM optimization pipeline for an LLVM module as an
+ * interactive environment. It can be used directly as a C++ API, or it can be
+ * accessed through an RPC interface using the CompilerGym RPC runtime.
+ */
+class LlvmSession final : public CompilationSession {
public:
- // Construct an environment by taking ownership of a benchmark. Throws
- // std::invalid_argument if the benchmark's LLVM module fails verification.
- LlvmSession(std::unique_ptr benchmark, LlvmActionSpace actionSpace,
- const boost::filesystem::path& workingDirectory);
+ LlvmSession(const boost::filesystem::path& workingDirectory);
- inline const Benchmark& benchmark() const { return *benchmark_; }
- inline Benchmark& benchmark() { return *benchmark_; }
+ std::string getCompilerVersion() const final override;
- inline const LlvmActionSpace actionSpace() const { return actionSpace_; }
+ std::vector getActionSpaces() const final override;
+
+ std::vector getObservationSpaces() const final override;
+
+ [[nodiscard]] grpc::Status init(const ActionSpace& actionSpace,
+ const compiler_gym::Benchmark& benchmark) final override;
+
+ [[nodiscard]] grpc::Status init(CompilationSession* other) final override;
+
+ [[nodiscard]] grpc::Status applyAction(const Action& action, bool& endOfEpisode,
+ std::optional& newActionSpace,
+ bool& actionHadNoEffect) final override;
- inline const boost::filesystem::path& workingDirectory() const { return workingDirectory_; }
+ [[nodiscard]] grpc::Status endOfStep(bool actionHadNoEffect, bool& endOfEpisode,
+ std::optional& newActionSpace) final override;
- // Run the requested action(s) then compute the requested observation(s).
- [[nodiscard]] grpc::Status step(const StepRequest& request, StepReply* reply);
+ [[nodiscard]] grpc::Status computeObservation(const ObservationSpace& observationSpace,
+ Observation& observation) final override;
- // Returns the number of actions that have been applied in calls to step()
- // since the start of the session. This is just for logging and has no effect.
- inline int actionCount() const { return actionCount_; }
+ inline const LlvmActionSpace actionSpace() const { return actionSpace_; }
- // Run the requested action.
- [[nodiscard]] grpc::Status runAction(LlvmAction action, StepReply* reply);
+ private:
+ [[nodiscard]] grpc::Status computeObservation(LlvmObservationSpace observationSpace,
+ Observation& observation);
- // Compute the requested observation.
- [[nodiscard]] grpc::Status getObservation(LlvmObservationSpace space, Observation* reply);
+ [[nodiscard]] grpc::Status init(const LlvmActionSpace& actionSpace,
+ std::unique_ptr benchmark);
- protected:
- // Run the given pass, possibly modifying the underlying LLVM module.
- void runPass(llvm::Pass* pass, StepReply* reply);
- void runPass(llvm::FunctionPass* pass, StepReply* reply);
+ inline const Benchmark& benchmark() const {
+ DCHECK(benchmark_) << "Calling benchmark() before init()";
+ return *benchmark_;
+ }
+ inline Benchmark& benchmark() {
+ DCHECK(benchmark_) << "Calling benchmark() before init()";
+ return *benchmark_;
+ }
- // Run the commandline `opt` tool on the current LLVM module with the given
- // arguments, replacing the environment state with the generated output.
+ /**
+ * Run the requested action.
+ *
+ * @param action An action to apply.
+ * @param actionHadNoEffect Set to true if LLVM reported that any passes that
+ * were run made no modifications to the module.
+ * @return `OK` on success.
+ */
+ [[nodiscard]] grpc::Status applyPassAction(LlvmAction action, bool& actionHadNoEffect);
+
+ /**
+ * Run the given pass, possibly modifying the underlying LLVM module.
+ *
+ * @return Whether the module was modified.
+ */
+ bool runPass(llvm::Pass* pass);
+
+ /**
+ * Run the given pass, possibly modifying the underlying LLVM module.
+ *
+ * @return Whether the module was modified.
+ */
+ bool runPass(llvm::FunctionPass* pass);
+
+ /**
+ * Run the commandline `opt` tool on the current LLVM module with the given
+ * arguments, replacing the environment state with the generated output.
+ */
[[nodiscard]] grpc::Status runOptWithArgs(const std::vector& optArgs);
inline const llvm::TargetLibraryInfoImpl& tlii() const { return tlii_; }
- private:
- // Setup pass manager with depdendent passes and the specified pass.
+ /**
+ * Setup pass manager with depdendent passes and the specified pass.
+ */
template
inline void setupPassManager(PassManager* passManager, Pass* pass) {
passManager->add(new llvm::ProfileSummaryInfoWrapperPass());
@@ -78,13 +121,13 @@ class LlvmSession {
passManager->add(pass);
}
- const boost::filesystem::path workingDirectory_;
- const std::unique_ptr benchmark_;
- const LlvmActionSpace actionSpace_;
- const llvm::TargetLibraryInfoImpl tlii_;
+ // Immutable state.
const programl::ProgramGraphOptions programlOptions_;
-
- int actionCount_;
+ const std::unordered_map observationSpaceNames_;
+ // Mutable state initialized in init().
+ LlvmActionSpace actionSpace_;
+ std::unique_ptr benchmark_;
+ llvm::TargetLibraryInfoImpl tlii_;
};
} // namespace compiler_gym::llvm_service
diff --git a/compiler_gym/envs/llvm/service/ObservationSpaces.cc b/compiler_gym/envs/llvm/service/ObservationSpaces.cc
index c2c5bd8c5..fd65d0a29 100644
--- a/compiler_gym/envs/llvm/service/ObservationSpaces.cc
+++ b/compiler_gym/envs/llvm/service/ObservationSpaces.cc
@@ -37,6 +37,14 @@ std::vector getLlvmObservationSpaceList() {
space.set_platform_dependent(false);
break;
}
+ case LlvmObservationSpace::IR_SHA1: {
+ ScalarRange sha1Size;
+ space.mutable_string_size_range()->mutable_min()->set_value(40);
+ space.mutable_string_size_range()->mutable_max()->set_value(40);
+ space.set_deterministic(true);
+ space.set_platform_dependent(false);
+ break;
+ }
case LlvmObservationSpace::BITCODE_FILE: {
ScalarRange pathLength;
space.mutable_string_size_range()->mutable_min()->set_value(0);
diff --git a/compiler_gym/envs/llvm/service/ObservationSpaces.h b/compiler_gym/envs/llvm/service/ObservationSpaces.h
index 23c75f0ce..e15371e3e 100644
--- a/compiler_gym/envs/llvm/service/ObservationSpaces.h
+++ b/compiler_gym/envs/llvm/service/ObservationSpaces.h
@@ -10,58 +10,81 @@
namespace compiler_gym::llvm_service {
-// The available observation spaces for LLVM.
-//
-// NOTE(cummins): Housekeeping rules - to add a new observation space:
-// 1. Add a new entry to this LlvmObservationSpace enum.
-// 2. Add a new switch case to getLlvmObservationSpaceList() to return the
-// ObserverationSpace.
-// 3. Add a new switch case to LlvmSession::getObservation() to compute
-// the actual observation.
-// 4. Run `bazel test //compiler_gym/...` and update the newly failing tests.
+/**
+ * The available observation spaces for LLVM.
+ *
+ * \note Housekeeping rules - to add a new observation space:
+ * 1. Add a new entry to this LlvmObservationSpace enum.
+ * 2. Add a new switch case to getLlvmObservationSpaceList() to return the
+ * ObservationSpace.
+ * 3. Add a new switch case to LlvmSession::getObservation() to compute
+ * the actual observation.
+ * 4. Run `bazel test //compiler_gym/...` and update the newly failing tests.
+ */
enum class LlvmObservationSpace {
- // The entire LLVM module as an IR string. This allows the user to do its own
- // feature extraction.
+ /**
+ * The entire LLVM module as an IR string.
+ *
+ * This allows the user to do their own feature extraction.
+ */
IR,
- // Write the bitcode to a file. Returns a string, which is the path of the
- // written file.
+ /** The 40-digit hex SHA1 checksum of the LLVM module. */
+ IR_SHA1,
+ /** Write the bitcode to a file and return its path as a string. */
BITCODE_FILE,
- // The counts of all instructions in a program.
+ /** The counts of all instructions in a program. */
INST_COUNT,
- // The Autophase feature vector from:
- //
- // Huang, Q., Haj-Ali, A., Moses, W., Xiang, J., Stoica, I., Asanovic, K., &
- // Wawrzynek, J. (2019). Autophase: Compiler phase-ordering for HLS with
- // deep reinforcement learning. FCCM.
+ /**
+ * The Autophase feature vector.
+ *
+ * From:
+ *
+ * Huang, Q., Haj-Ali, A., Moses, W., Xiang, J., Stoica, I., Asanovic, K.,
+ * & Wawrzynek, J. (2019). Autophase: Compiler phase-ordering for HLS with
+ * deep reinforcement learning. FCCM.
+ */
AUTOPHASE,
- // Returns the graph representation of a program from:
- //
- // Cummins, C., Fisches, Z. V., Ben-Nun, T., Hoefler, T., & Leather, H.
- // (2020). ProGraML: Graph-based Deep Learning for Program Optimization
- // and Analysis. ArXiv:2003.10536. https://arxiv.org/abs/2003.10536
+ /**
+ * Returns the graph representation of a program.
+ *
+ * From:
+ *
+ * Cummins, C., Fisches, Z. V., Ben-Nun, T., Hoefler, T., & Leather, H.
+ * (2020). ProGraML: Graph-based Deep Learning for Program Optimization
+ * and Analysis. ArXiv:2003.10536. https://arxiv.org/abs/2003.10536
+ */
PROGRAML,
- // A JSON dictionary of properties describing the CPU.
+ /** A JSON dictionary of properties describing the CPU. */
CPU_INFO,
- // The number of LLVM-IR instructions in the current module.
+ /** The number of LLVM-IR instructions in the current module. */
IR_INSTRUCTION_COUNT,
+ /** The number of LLVM-IR instructions normalized to `-O0`. */
IR_INSTRUCTION_COUNT_O0,
+ /** The number of LLVM-IR instructions normalized to `-O3`. */
IR_INSTRUCTION_COUNT_O3,
+ /** The number of LLVM-IR instructions normalized to `-Oz`. */
IR_INSTRUCTION_COUNT_OZ,
- // The size of the .text section of the lowered module. Platform dependent.
+ /** The platform-dependent size of the .text section of the lowered module. */
OBJECT_TEXT_SIZE_BYTES,
+ /** The platform-dependent size of the .text section of the lowered module. */
OBJECT_TEXT_SIZE_O0,
+ /** The platform-dependent size of the .text section of the lowered module. */
OBJECT_TEXT_SIZE_O3,
+ /** The platform-dependent size of the .text section of the lowered module. */
OBJECT_TEXT_SIZE_OZ,
#ifdef COMPILER_GYM_EXPERIMENTAL_TEXT_SIZE_COST
- // The size of the .text section of the compiled binary. Platform dependent.
+ /** The platform-dependent size of the .text section of the compiled binary. */
TEXT_SIZE_BYTES,
+ /** The platform-dependent size of the .text section of the compiled binary. */
TEXT_SIZE_O0,
+ /** The platform-dependent size of the .text section of the compiled binary. */
TEXT_SIZE_O3,
+ /** The platform-dependent size of the .text section of the compiled binary. */
TEXT_SIZE_OZ,
#endif
};
-// Return the list of available observation spaces.
+/** Return the list of available observation spaces. */
std::vector getLlvmObservationSpaceList();
} // namespace compiler_gym::llvm_service
diff --git a/compiler_gym/envs/llvm/service/RunService.cc b/compiler_gym/envs/llvm/service/RunService.cc
index b7b6b72f0..6a76ae186 100644
--- a/compiler_gym/envs/llvm/service/RunService.cc
+++ b/compiler_gym/envs/llvm/service/RunService.cc
@@ -2,13 +2,12 @@
//
// This source code is licensed under the MIT license found in the
// LICENSE file in the root directory of this source tree.
-#include "compiler_gym/util/RunService.h"
-
-#include "compiler_gym/envs/llvm/service/LlvmService.h"
+#include "compiler_gym/envs/llvm/service/LlvmSession.h"
+#include "compiler_gym/service/runtime/Runtime.h"
const char* usage = R"(LLVM CompilerGym service)";
-using namespace compiler_gym::util;
+using namespace compiler_gym::runtime;
using namespace compiler_gym::llvm_service;
-int main(int argc, char** argv) { return runService(&argc, &argv, usage); }
+int main(int argc, char** argv) { createAndRunCompilerGymService(argc, argv, usage); }
diff --git a/compiler_gym/envs/llvm/service/passes/config.py b/compiler_gym/envs/llvm/service/passes/config.py
index 871a76c64..cb7a4db87 100644
--- a/compiler_gym/envs/llvm/service/passes/config.py
+++ b/compiler_gym/envs/llvm/service/passes/config.py
@@ -47,7 +47,8 @@
"DeadInstElimination": "DeadInstEliminationPass",
"DivRemPairsLegacyPass": "DivRemPairsPass",
"DSELegacyPass": "DeadStoreEliminationPass",
- "EarlyCSEMemSSALegacyPass": "EarlyCSEPass",
+ "EarlyCSELegacyPass": "EarlyCSEPass",
+ "EarlyCSEMemSSALegacyPass": "EarlyCSEMemSSAPass",
"EliminateAvailableExternallyLegacyPass": "EliminateAvailableExternallyPass",
"EntryExitInstrumenter": "EntryExitInstrumenterPass",
"Float2IntLegacyPass": "Float2IntPass",
@@ -174,7 +175,6 @@
"WholeProgramDevirt",
"MakeGuardsExplicitLegacyPass",
"LowerTypeTests",
- "EarlyCSELegacyPass",
# Unneeded debugging passes.
"WriteThinLTOBitcode",
"PredicateInfoPrinterLegacyPass",
diff --git a/compiler_gym/envs/llvm/service/passes/make_action_space_genfiles.py b/compiler_gym/envs/llvm/service/passes/make_action_space_genfiles.py
index a5e07cc1f..1aaac0b7d 100644
--- a/compiler_gym/envs/llvm/service/passes/make_action_space_genfiles.py
+++ b/compiler_gym/envs/llvm/service/passes/make_action_space_genfiles.py
@@ -154,6 +154,19 @@ def make_action_sources(pass_iterator, outpath: Path):
print("#pragma once", file=f)
for header in sorted(headers):
print(f'#include "{header}"', file=f)
+
+ # Inject an ad-hoc workaround for the non-standard constructor of the
+ # EarlyCSEMemSSAPass.
+ print(
+ """
+namespace llvm {
+FunctionPass* createEarlyCSEMemSSAPass() {
+ return createEarlyCSEPass(/*UseMemorySSA=*/true);
+}
+} // namespace llvm
+""",
+ file=f,
+ )
logging.debug("Generated %s", include_path.name)
with open(actions_path, "w") as f:
diff --git a/compiler_gym/leaderboard/llvm_instcount.py b/compiler_gym/leaderboard/llvm_instcount.py
index 9d913bb4a..bbda9baa2 100644
--- a/compiler_gym/leaderboard/llvm_instcount.py
+++ b/compiler_gym/leaderboard/llvm_instcount.py
@@ -131,7 +131,7 @@ def run(self):
state = self.env.state.copy()
state.walltime = timer.time
- writer.write_state(state)
+ writer.write_state(state, flush=True)
self.states.append(state)
if not self.alive:
diff --git a/compiler_gym/random_search.py b/compiler_gym/random_search.py
index 87c60d6de..7cb025443 100644
--- a/compiler_gym/random_search.py
+++ b/compiler_gym/random_search.py
@@ -113,10 +113,9 @@ def random_search(
nproc: int = cpu_count(),
skip_done: bool = False,
) -> Tuple[float, List[int]]:
- env = make_env()
- try:
+ with make_env() as env:
env.reset()
- if not isinstance(env, CompilerEnv):
+ if not isinstance(env.unwrapped, CompilerEnv):
raise TypeError(
f"random_search() requires CompilerEnv. Called with: {type(env).__name__}"
)
@@ -151,8 +150,6 @@ def random_search(
}
with open(str(metadata_path), "w") as f:
json.dump(metadata, f, sort_keys=True, indent=2)
- finally:
- env.close()
workers = [RandomAgentWorker(make_env, patience) for _ in range(nproc)]
for worker in workers:
@@ -164,7 +161,7 @@ def random_search(
last_best_returns = -float("inf")
print(
- f"Started {len(workers)} worker threads for "
+ f"Started {len(workers)} worker threads for {benchmark_uri} "
f"using reward {reward_space_name}."
)
print(f"Writing logs to {outdir}")
diff --git a/compiler_gym/service/BUILD b/compiler_gym/service/BUILD
index 9c9ad2255..280d508eb 100644
--- a/compiler_gym/service/BUILD
+++ b/compiler_gym/service/BUILD
@@ -3,33 +3,46 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
load("@rules_python//python:defs.bzl", "py_library")
+load("@rules_cc//cc:defs.bzl", "cc_library")
py_library(
name = "service",
srcs = ["__init__.py"],
visibility = ["//visibility:public"],
deps = [
+ ":compilation_session",
":connection",
- ":proto2py",
"//compiler_gym/service/proto",
],
)
py_library(
- name = "connection",
- srcs = ["connection.py"],
+ name = "compilation_session",
+ srcs = ["compilation_session.py"],
visibility = ["//visibility:public"],
deps = [
"//compiler_gym/service/proto",
- "//compiler_gym/util",
+ ],
+)
+
+cc_library(
+ name = "CompilationSession",
+ srcs = ["CompilationSession.cc"],
+ hdrs = ["CompilationSession.h"],
+ visibility = ["//visibility:public"],
+ deps = [
+ "//compiler_gym/service/proto:compiler_gym_service_cc",
+ "@boost//:filesystem",
+ "@com_github_grpc_grpc//:grpc++",
],
)
py_library(
- name = "proto2py",
- srcs = ["proto2py.py"],
+ name = "connection",
+ srcs = ["connection.py"],
visibility = ["//visibility:public"],
deps = [
"//compiler_gym/service/proto",
+ "//compiler_gym/util",
],
)
diff --git a/compiler_gym/service/CompilationSession.cc b/compiler_gym/service/CompilationSession.cc
new file mode 100644
index 000000000..5e55c1330
--- /dev/null
+++ b/compiler_gym/service/CompilationSession.cc
@@ -0,0 +1,26 @@
+// Copyright (c) Facebook, Inc. and its affiliates.
+//
+// This source code is licensed under the MIT license found in the
+// LICENSE file in the root directory of this source tree.
+#include "compiler_gym/service/CompilationSession.h"
+
+using grpc::Status;
+using grpc::StatusCode;
+
+namespace compiler_gym {
+
+std::string CompilationSession::getCompilerVersion() const { return ""; }
+
+Status CompilationSession::init(CompilationSession* other) {
+ return Status(StatusCode::UNIMPLEMENTED, "CompilationSession::init() not implemented");
+}
+
+Status CompilationSession::endOfStep(bool actionHadNoEffect, bool& endOfEpisode,
+ std::optional& newActionSpace) {
+ return Status::OK;
+}
+
+CompilationSession::CompilationSession(const boost::filesystem::path& workingDirectory)
+ : workingDirectory_(workingDirectory) {}
+
+} // namespace compiler_gym
diff --git a/compiler_gym/service/CompilationSession.h b/compiler_gym/service/CompilationSession.h
new file mode 100644
index 000000000..6bcd77ddc
--- /dev/null
+++ b/compiler_gym/service/CompilationSession.h
@@ -0,0 +1,146 @@
+// Copyright (c) Facebook, Inc. and its affiliates.
+//
+// This source code is licensed under the MIT license found in the
+// LICENSE file in the root directory of this source tree.
+#pragma once
+
+#include
+
+#include
+#include
+
+#include "boost/filesystem.hpp"
+#include "compiler_gym/service/proto/compiler_gym_service.pb.h"
+
+namespace compiler_gym {
+
+/**
+ * Base class for encapsulating an incremental compilation session.
+ *
+ * To add support for a new compiler, subclass from this base and provide
+ * implementations of the abstract methods, then call
+ * createAndRunCompilerGymService() and parametrize it with your class type:
+ *
+ * \code{.cpp}
+ * #include "compiler_gym/service/CompilationSession.h"
+ * #include "compiler_gym/service/runtime/Runtime.h"
+ *
+ * using namespace compiler_gym;
+ *
+ * class MyCompilationSession final : public CompilationSession { ... }
+ *
+ * int main(int argc, char** argv) {
+ * runtime::createAndRunCompilerGymService();
+ * }
+ * \endcode
+ */
+class CompilationSession {
+ public:
+ /**
+ * Get the compiler version.
+ *
+ * @return A string indicating the compiler version.
+ */
+ virtual std::string getCompilerVersion() const;
+
+ /**
+ * A list of action spaces describing the capabilities of the compiler.
+ *
+ * @return A list of ActionSpace instances.
+ */
+ virtual std::vector getActionSpaces() const = 0;
+
+ /**
+ * A list of feature vectors that this compiler provides.
+ *
+ * @return A list of ObservationSpace instances.
+ */
+ virtual std::vector getObservationSpaces() const = 0;
+
+ /**
+ * Start a CompilationSession.
+ *
+ * This will be called after construction and before applyAction() or
+ * computeObservation(). This will only be called once.
+ *
+ * @param actionSpace The action space to use.
+ * @param benchmark The benchmark to use.
+ * @return `OK` on success, else an error code and message.
+ */
+ [[nodiscard]] virtual grpc::Status init(const ActionSpace& actionSpace,
+ const Benchmark& benchmark) = 0;
+
+ /**
+ * Initialize a CompilationSession from another CompilationSession.
+ *
+ * Think of this like a copy constructor, except that this method is allowed
+ * to fail.
+ *
+ * This will be called after construction and before applyAction() or
+ * computeObservation(). This will only be called once.
+ *
+ * @param other The CompilationSession to initialize from.
+ * @return `OK` on success, else an error code and message.
+ */
+ [[nodiscard]] virtual grpc::Status init(CompilationSession* other);
+
+ /**
+ * Apply an action.
+ *
+ * @param action The action to apply.
+ * @param newActionSpace If applying the action mutated the action space, set
+ * this value to the new action space.
+ * @param actionHadNoEffect If the action had no effect, set this to true.
+ * @return `OK` on success, else an error code and message.
+ */
+ [[nodiscard]] virtual grpc::Status applyAction(const Action& action, bool& endOfEpisode,
+ std::optional& newActionSpace,
+ bool& actionHadNoEffect) = 0;
+
+ /**
+ * Compute an observation.
+ *
+ * @return `OK` on success, else an error code and message.
+ */
+ [[nodiscard]] virtual grpc::Status computeObservation(const ObservationSpace& observationSpace,
+ Observation& observation) = 0;
+
+ /**
+ * Optional. This will be called after all applyAction() and
+ * computeObservation() in a step. Use this method if you would like to
+ * perform post-transform validation of compiler state.
+ *
+ * @return `OK` on success, else an error code and message.
+ */
+ [[nodiscard]] virtual grpc::Status endOfStep(bool actionHadNoEffect, bool& endOfEpisode,
+ std::optional& newActionSpace);
+
+ CompilationSession(const boost::filesystem::path& workingDirectory);
+
+ virtual ~CompilationSession() = default;
+
+ protected:
+ /**
+ * Get the working directory.
+ *
+ * The working directory is a local filesystem directory that this
+ * CompilationSession can use to store temporary files such as build
+ * artifacts. The directory exists.
+ *
+ * \note If you need to store very large files for a CompilationSession then
+ * consider using an alternate filesystem path as, when possible, an
+ * in-memory filesystem will be used for the working directory.
+ *
+ * \note A single working directory may be shared by multiple
+ * CompilationSession instances. Do not assume that you have exclusive
+ * access.
+ *
+ * @return A path.
+ */
+ inline const boost::filesystem::path& workingDirectory() { return workingDirectory_; }
+
+ private:
+ const boost::filesystem::path workingDirectory_;
+};
+
+} // namespace compiler_gym
diff --git a/compiler_gym/service/__init__.py b/compiler_gym/service/__init__.py
index d3f61f792..baa8c8ea8 100644
--- a/compiler_gym/service/__init__.py
+++ b/compiler_gym/service/__init__.py
@@ -2,6 +2,7 @@
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
+from compiler_gym.service.compilation_session import CompilationSession
from compiler_gym.service.connection import (
CompilerGymServiceConnection,
ConnectionOpts,
@@ -12,13 +13,11 @@
ServiceTransportError,
SessionNotFound,
)
-from compiler_gym.service.proto2py import observation_t, scalar_range2tuple
__all__ = [
"CompilerGymServiceConnection",
+ "CompilationSession",
"ConnectionOpts",
- "observation_t",
- "scalar_range2tuple",
"ServiceError",
"ServiceInitError",
"ServiceIsClosed",
diff --git a/compiler_gym/service/compilation_session.py b/compiler_gym/service/compilation_session.py
new file mode 100644
index 000000000..278dbc396
--- /dev/null
+++ b/compiler_gym/service/compilation_session.py
@@ -0,0 +1,92 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+from pathlib import Path
+from typing import List, Optional, Tuple
+
+from compiler_gym.service.proto import (
+ Action,
+ ActionSpace,
+ Benchmark,
+ Observation,
+ ObservationSpace,
+)
+
+
+class CompilationSession:
+ """Base class for encapsulating an incremental compilation session.
+
+ To add support for a new compiler, subclass from this base and provide
+ implementations of the abstract methods, then call
+ :func:`create_and_run_compiler_service
+ ` and pass in
+ your class type:
+
+ .. code-block:: python
+
+ from compiler_gym.service import CompilationSession
+ from compiler_gym.service import runtime
+
+ class MyCompilationSession(CompilationSession):
+ ...
+
+ if __name__ == "__main__":
+ runtime.create_and_run_compiler_gym_service(MyCompilationSession)
+ """
+
+ compiler_version: str = ""
+ """The compiler version."""
+
+ action_spaces: List[ActionSpace] = []
+ """A list of action spaces describing the capabilities of the compiler."""
+
+ observation_spaces: List[ObservationSpace] = []
+ """A list of feature vectors that this compiler provides."""
+
+ def __init__(
+ self, working_dir: Path, action_space: ActionSpace, benchmark: Benchmark
+ ):
+ """Start a CompilationSession.
+
+ Subclasses should initialize the parent class first.
+
+ :param working_dir: A directory on the local filesystem that can be used
+ to store temporary files such as build artifacts.
+
+ :param action_space: The action space to use.
+
+ :param benchmark: The benchmark to use.
+ """
+ del action_space # Subclasses must use this.
+ del benchmark # Subclasses must use this.
+ self.working_dir = working_dir
+
+ def apply_action(self, action: Action) -> Tuple[bool, Optional[ActionSpace], bool]:
+ """Apply an action.
+
+ :param action: The action to apply.
+
+ :return: A tuple: :code:`(end_of_session, new_action_space,
+ action_had_no_effect)`.
+ """
+ raise NotImplementedError
+
+ def get_observation(self, observation_space: ObservationSpace) -> Observation:
+ """Compute an observation.
+
+ :param observation_space: The observation space.
+
+ :return: An observation.
+ """
+ raise NotImplementedError
+
+ def fork(self) -> "CompilationSession":
+ """Create a copy of current session state.
+
+ Implementing this method is optional.
+
+ :return: A new CompilationSession with the same state.
+ """
+ # No need to override this if you are not adding support to fork().
+ raise NotImplementedError("CompilationSession.fork() not supported")
diff --git a/compiler_gym/service/connection.py b/compiler_gym/service/connection.py
index ba123cd28..9aebe4aad 100644
--- a/compiler_gym/service/connection.py
+++ b/compiler_gym/service/connection.py
@@ -134,7 +134,7 @@ def __call__(
StubMethod = Callable[[Request], Reply]
-class Connection(object):
+class Connection:
"""Base class for service connections."""
def __init__(self, channel, url: str, logger: logging.Logger):
@@ -252,10 +252,18 @@ def make_working_dir() -> Path:
"""Make a working directory for a service. The calling code is responsible
for removing this directory when done.
"""
- random_hash = random.getrandbits(16)
- service_name = datetime.now().strftime(f"s/%m%dT%H%M%S-%f-{random_hash:04x}")
- working_dir = transient_cache_path(service_name)
- (working_dir / "logs").mkdir(parents=True, exist_ok=False)
+ while True:
+ random_hash = random.getrandbits(16)
+ service_name = datetime.now().strftime(f"s/%m%dT%H%M%S-%f-{random_hash:04x}")
+ working_dir = transient_cache_path(service_name)
+ # Guard against the unlikely scenario that there is a collision between
+ # the randomly generated working directories of multiple
+ # make_working_dir() calls.
+ try:
+ (working_dir / "logs").mkdir(parents=True, exist_ok=False)
+ break
+ except FileExistsError:
+ pass
return working_dir
@@ -281,17 +289,12 @@ def __init__(
raise FileNotFoundError(f"File not found: {local_service_binary}")
self.working_dir = make_working_dir()
- # Set environment variable COMPILER_GYM_SERVICE_ARGS to pass
- # additional arguments to the service.
- args = os.environ.get("COMPILER_GYM_SERVICE_ARGS", "")
-
# The command that will be executed. The working directory of this
# command will be set to the local_service_binary's parent, so we can
# use the relpath for a neater `ps aux` view.
cmd = [
f"./{local_service_binary.name}",
f"--working_dir={self.working_dir}",
- args,
]
# Set the root of the runfiles directory.
@@ -299,16 +302,17 @@ def __init__(
env["COMPILER_GYM_RUNFILES"] = str(runfiles_path("."))
env["COMPILER_GYM_SITE_DATA"] = str(site_data_path("."))
- # Set the verbosity of the service. The logging level of the service
- # is the debug level - 1, so that COMPILER_GYM_DEUG=3 will cause VLOG(2)
+ # Set the verbosity of the service. The logging level of the service is
+ # the debug level - 1, so that COMPILER_GYM_DEBUG=3 will cause VLOG(2)
# and lower to be logged to stdout.
debug_level = get_debug_level()
if debug_level > 0:
cmd.append("--alsologtostderr")
cmd.append(f"-v={debug_level - 1}")
# If we are debugging the backend, set the logbuflevel to a low
- # value to disable buffering of logging messages. This makes it
- # easier to `LOG(INFO) << "..."` debug things.
+ # value to disable buffering of logging messages. This removes any
+ # buffering between `LOG(INFO) << "..."` and the message being
+ # emitted to stderr.
cmd.append("--logbuflevel=-1")
else:
# Silence the gRPC logs as we will do our own error reporting, but
@@ -317,12 +321,19 @@ def __init__(
if not os.environ.get("GRPC_VERBOSITY"):
os.environ["GRPC_VERBOSITY"] = "NONE"
+ # Set environment variable COMPILER_GYM_SERVICE_ARGS to pass
+ # additional arguments to the service.
+ args = os.environ.get("COMPILER_GYM_SERVICE_ARGS", "")
+ if args:
+ cmd.append(args)
+
logger.debug("Exec %s", cmd)
self.process = subprocess.Popen(
cmd,
env=env,
cwd=local_service_binary.parent,
)
+ self._process_returncode_exception_raised = False
# Read the port from a file generated by the service.
wait_secs = 0.1
@@ -423,14 +434,30 @@ def loglines(self) -> Iterable[str]:
def close(self):
"""Terminate a local subprocess and close the connection."""
try:
- self.process.kill()
+ self.process.terminate()
self.process.communicate(timeout=self.process_exit_max_seconds)
+ if (
+ self.process.returncode
+ and not self._process_returncode_exception_raised
+ ):
+ # You can call close() multiple times but we only want to emit
+ # the exception once.
+ self._process_returncode_exception_raised = True
+ raise ServiceError(
+ f"Service exited with returncode {self.process.returncode}"
+ )
except ProcessLookupError:
self.logger.warning("Service process not found at %s", self.working_dir)
except subprocess.TimeoutExpired:
+ # Try and kill it and then walk away.
+ try:
+ self.process.kill()
+ except: # noqa
+ pass
self.logger.warning("Abandoning orphan service at %s", self.working_dir)
- shutil.rmtree(self.working_dir, ignore_errors=True)
- super().close()
+ finally:
+ shutil.rmtree(self.working_dir, ignore_errors=True)
+ super().close()
def __repr__(self):
alive_or_dead = "alive" if self.process.poll() else "dead"
@@ -477,7 +504,7 @@ def __repr__(self):
return self.url
-class CompilerGymServiceConnection(object):
+class CompilerGymServiceConnection:
"""A connection to a compiler gym service.
There are two types of service connections: managed and unmanaged. The type
diff --git a/compiler_gym/service/proto/__init__.py b/compiler_gym/service/proto/__init__.py
index 2db9c7f91..e818cce8a 100644
--- a/compiler_gym/service/proto/__init__.py
+++ b/compiler_gym/service/proto/__init__.py
@@ -3,6 +3,7 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from compiler_gym.service.proto.compiler_gym_service_pb2 import (
+ Action,
ActionSpace,
AddBenchmarkReply,
AddBenchmarkRequest,
@@ -34,6 +35,7 @@
)
__all__ = [
+ "Action",
"ActionSpace",
"AddBenchmarkReply",
"AddBenchmarkRequest",
diff --git a/compiler_gym/service/proto/compiler_gym_service.proto b/compiler_gym/service/proto/compiler_gym_service.proto
index 7916af741..19ab49bd5 100644
--- a/compiler_gym/service/proto/compiler_gym_service.proto
+++ b/compiler_gym/service/proto/compiler_gym_service.proto
@@ -9,11 +9,14 @@ syntax = "proto3";
package compiler_gym;
+option cc_enable_arenas = true;
option go_package = "compiler_gympb";
option java_multiple_files = true;
option java_outer_classname = "CompilerGymServiceProto";
option java_package = "com.compiler_gym";
+// The CompilerGymService is the interface that exposes the incremental
+// optimization of a program as an interactive environment.
service CompilerGymService {
// Request version strings from the service.
rpc GetVersion(GetVersionRequest) returns (GetVersionReply);
@@ -43,11 +46,10 @@ service CompilerGymService {
rpc AddBenchmark(AddBenchmarkRequest) returns (AddBenchmarkReply);
}
-// ===========================================================================
-// GetVersion().
-
+// A GetVersion() request.
message GetVersionRequest {}
+// The GetVersion() response.
message GetVersionReply {
// The version string for this service.
string service_version = 1;
@@ -55,9 +57,7 @@ message GetVersionReply {
string compiler_version = 2;
}
-// ===========================================================================
-// StartSession().
-
+// A StartSession() request.
message StartSessionRequest {
// The name of the benchmark to use for this session. If not provided, a
// benchmark is chosen randomly by the service.
@@ -70,6 +70,7 @@ message StartSessionRequest {
repeated int32 observation_space = 3;
}
+// A StartSession() reply.
message StartSessionReply {
// The ID that has been assigned to the session. The client must use this ID
// in all subsequent interactions with the service for this session.
@@ -86,19 +87,17 @@ message StartSessionReply {
repeated Observation observation = 4;
}
-// ===========================================================================
-// Step().
-
+// A Step() request.
message StepRequest {
// The ID of the session.
int64 session_id = 1;
- // A list of indices into the ActionSpace.action list. Actions are executed
- // in the order they appear in this list.
- repeated int32 action = 2;
+ // A list of actions to execute, in order.
+ repeated Action action = 2;
// A list of indices into the GetSpacesReply.observation_space_list
repeated int32 observation_space = 3;
}
+// A Step() reply.
message StepReply {
// Indicates that the session has ended. This could be because there are no
// further actions that can be made, or because the action has led to an
@@ -119,23 +118,27 @@ message StepReply {
repeated Observation observation = 4;
}
-// ===========================================================================
-// Actions.
-
+// A description of an action space.
+//
+// \warning This message format is likely to change. This currently only
+// supports flat action spaces of categorical values. In the future we will
+// want to replace this with a more extensible representation that supports
+// parameterized actions, and actions of different types (e.g. optimization
+// passes vs optimization contexts).
message ActionSpace {
// The name of the action space.
string name = 1;
// A list of discrete action names.
- // NOTE(cummins): This currently only supports flat action spaces of
- // categorical values. In the future we will want to replace this with a more
- // extensible representation that supports parameterized actions, and actions
- // of different types (e.g. optimization passes vs optimization contexts).
repeated string action = 2;
}
-// ===========================================================================
-// Observations.
+// An action.
+message Action {
+ // An index into the ActionSpace.action list.
+ int32 action = 1;
+}
+// An observation from a compiler.
message Observation {
// A point in an ObservationSpace is _either_ a scalar or vector of integers
// or real values, a string, or an opaque byte array.
@@ -149,14 +152,17 @@ message Observation {
}
}
+// A list of 64 bit integers.
message Int64List {
repeated int64 value = 1;
}
+// A list of doubles.
message DoubleList {
repeated double value = 1;
}
+// The [min, max] range of a scalar.
message ScalarRange {
// The minimum value (inclusive). If not set, the value is -inf.
ScalarLimit min = 1;
@@ -164,14 +170,17 @@ message ScalarRange {
ScalarLimit max = 2;
}
+// Representation of the upper or lower limit of a scalar.
message ScalarLimit {
double value = 1;
}
+// A list of scalar ranges.
message ScalarRangeList {
repeated ScalarRange range = 1;
}
+// The description of a space of observations.
message ObservationSpace {
// The name of the observation space.
string name = 1;
@@ -207,37 +216,34 @@ message ObservationSpace {
Observation default_value = 9;
}
-// ===========================================================================
-// Fork().
-
+// A Fork() request.
message ForkSessionRequest {
// The ID of the session to fork.
int64 session_id = 1;
}
+// A Fork() reply.
message ForkSessionReply {
// The ID of the newly created session.
int64 session_id = 1;
}
-// ===========================================================================
-// EndSession().
-
+// An EndSession() request.
message EndSessionRequest {
// The ID of the session.
int64 session_id = 1;
}
+// An EndSession() reply.
message EndSessionReply {
// The number of sessions that the service currently has.
int32 remaining_sessions = 1;
}
-// ===========================================================================
-// GetSpaces().
-
+// A GetSpaces() request.
message GetSpacesRequest {}
+// A GetSpaces() reply.
message GetSpacesReply {
// The initial space of actions. Subsequent calls to step() may produce
// a new action space.
@@ -247,11 +253,7 @@ message GetSpacesReply {
repeated ObservationSpace observation_space_list = 2;
}
-// ===========================================================================
-// AddBenchmark().
-
-// A Benchmark message is used to register a new benchmark with a compiler
-// service.
+// Representation of the input to a compiler.
message Benchmark {
// The name of the benchmark to add. In case of conflict with an existing
// benchmark, this new benchmark replaces the existing one.
@@ -274,8 +276,10 @@ message File {
}
}
+// An AddBenchmark() request.
message AddBenchmarkRequest {
repeated Benchmark benchmark = 1;
}
+// An AddBenchmark() reply.
message AddBenchmarkReply {}
diff --git a/compiler_gym/service/proto2py.py b/compiler_gym/service/proto2py.py
deleted file mode 100644
index 6b02986c6..000000000
--- a/compiler_gym/service/proto2py.py
+++ /dev/null
@@ -1,22 +0,0 @@
-# Copyright (c) Facebook, Inc. and its affiliates.
-#
-# This source code is licensed under the MIT license found in the
-# LICENSE file in the root directory of this source tree.
-"""Converters from protocol buffers to python-friendly types."""
-from typing import Any, Dict, List, Union
-
-import networkx as nx
-import numpy as np
-
-from compiler_gym.service.proto import ScalarRange
-
-json_t = Union[List[Any], Dict[str, Any]]
-observation_t = Union[np.ndarray, str, bytes, int, float, json_t, nx.DiGraph]
-
-
-def scalar_range2tuple(sr: ScalarRange, defaults=(-np.inf, np.inf)):
- """Convert a ScalarRange to a tuple of (min, max) bounds."""
- return (
- sr.min.value if sr.HasField("min") else defaults[0],
- sr.max.value if sr.HasField("max") else defaults[1],
- )
diff --git a/compiler_gym/service/runtime/BUILD b/compiler_gym/service/runtime/BUILD
new file mode 100644
index 000000000..9432e0dea
--- /dev/null
+++ b/compiler_gym/service/runtime/BUILD
@@ -0,0 +1,113 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+#
+# This package implements the CompilerGym service runtime, which is the utility
+# code that creates RPC servers and dispatches to CompilationServices.
+load("@rules_cc//cc:defs.bzl", "cc_library")
+load("@rules_python//python:defs.bzl", "py_library")
+
+py_library(
+ name = "runtime",
+ srcs = ["__init__.py"],
+ visibility = ["//visibility:public"],
+ deps = [
+ ":create_and_run_compiler_gym_service",
+ ],
+)
+
+cc_library(
+ name = "cc_runtime",
+ hdrs = ["Runtime.h"],
+ visibility = ["//visibility:public"],
+ deps = [
+ ":CreateAndRunCompilerGymServiceImpl",
+ ],
+)
+
+py_library(
+ name = "benchmark_cache",
+ srcs = ["benchmark_cache.py"],
+ visibility = ["//tests/service/runtime:__subpackages__"],
+ deps = [
+ "//compiler_gym/service/proto",
+ ],
+)
+
+cc_library(
+ name = "BenchmarkCache",
+ srcs = ["BenchmarkCache.cc"],
+ hdrs = ["BenchmarkCache.h"],
+ visibility = ["//tests/service/runtime:__subpackages__"],
+ deps = [
+ "//compiler_gym/service/proto:compiler_gym_service_cc",
+ "@boost//:filesystem",
+ "@com_github_grpc_grpc//:grpc++",
+ "@glog",
+ ],
+)
+
+py_library(
+ name = "compiler_gym_service",
+ srcs = ["compiler_gym_service.py"],
+ deps = [
+ ":benchmark_cache",
+ "//compiler_gym/service:compilation_session",
+ "//compiler_gym/service/proto",
+ "//compiler_gym/util",
+ ],
+)
+
+cc_library(
+ name = "CompilerGymService",
+ hdrs = [
+ "CompilerGymService.h",
+ "CompilerGymServiceImpl.h",
+ ],
+ visibility = ["//visibility:public"],
+ deps = [
+ ":BenchmarkCache",
+ ":CompilerGymServiceImpl",
+ "//compiler_gym/service:CompilationSession",
+ "//compiler_gym/service/proto:compiler_gym_service_cc",
+ "//compiler_gym/service/proto:compiler_gym_service_cc_grpc",
+ "@boost//:filesystem",
+ "@com_github_grpc_grpc//:grpc++",
+ ],
+)
+
+cc_library(
+ name = "CompilerGymServiceImpl",
+ hdrs = ["CompilerGymServiceImpl.h"],
+ deps = [
+ "//compiler_gym/util:GrpcStatusMacros",
+ "//compiler_gym/util:Version",
+ "@fmt",
+ "@glog",
+ ],
+)
+
+py_library(
+ name = "create_and_run_compiler_gym_service",
+ srcs = ["create_and_run_compiler_gym_service.py"],
+ deps = [
+ ":compiler_gym_service",
+ "//compiler_gym/service/proto",
+ "//compiler_gym/util",
+ ],
+)
+
+cc_library(
+ name = "CreateAndRunCompilerGymServiceImpl",
+ srcs = ["CreateAndRunCompilerGymServiceImpl.cc"],
+ hdrs = ["CreateAndRunCompilerGymServiceImpl.h"],
+ deps = [
+ ":CompilerGymService",
+ "//compiler_gym/util:GrpcStatusMacros",
+ "@boost//:filesystem",
+ "@com_github_grpc_grpc//:grpc++",
+ "@gflags",
+ "@glog",
+ ],
+)
diff --git a/compiler_gym/service/runtime/BenchmarkCache.cc b/compiler_gym/service/runtime/BenchmarkCache.cc
new file mode 100644
index 000000000..8cd2cd219
--- /dev/null
+++ b/compiler_gym/service/runtime/BenchmarkCache.cc
@@ -0,0 +1,83 @@
+// Copyright (c) Facebook, Inc. and its affiliates.
+//
+// This source code is licensed under the MIT license found in the
+// LICENSE file in the root directory of this source tree.
+#include "compiler_gym/service/runtime/BenchmarkCache.h"
+
+#include <glog/logging.h>
+
+using grpc::Status;
+using grpc::StatusCode;
+
+namespace compiler_gym::runtime {
+
+BenchmarkCache::BenchmarkCache(size_t maxSizeInBytes, std::optional<std::mt19937_64> rand)
+ : rand_(rand.has_value() ? *rand : std::mt19937_64(std::random_device()())),
+ maxSizeInBytes_(maxSizeInBytes),
+ sizeInBytes_(0){};
+
+const Benchmark* BenchmarkCache::get(const std::string& uri) const {
+ auto it = benchmarks_.find(uri);
+ if (it == benchmarks_.end()) {
+ return nullptr;
+ }
+
+ return &it->second;
+}
+
+void BenchmarkCache::add(const Benchmark&& benchmark) {
+ VLOG(3) << "Caching benchmark " << benchmark.uri() << ". Cache size = " << sizeInBytes()
+ << " bytes, " << size() << " items";
+
+ // Remove any existing value to keep the cache size consistent.
+ const auto it = benchmarks_.find(benchmark.uri());
+ if (it != benchmarks_.end()) {
+ const size_t replacedSize = it->second.ByteSizeLong();
+ benchmarks_.erase(it);
+ sizeInBytes_ -= replacedSize;
+ }
+
+ const size_t size = benchmark.ByteSizeLong();
+ if (sizeInBytes() + size > maxSizeInBytes()) {
+ if (size > maxSizeInBytes()) {
+ LOG(WARNING) << "Adding new benchmark with size " << size
+ << " bytes exceeds total target cache size of " << maxSizeInBytes() << " bytes";
+ } else {
+ VLOG(3) << "Adding new benchmark with size " << size << " bytes exceeds maximum size "
+ << maxSizeInBytes() << " bytes, " << this->size() << " items";
+ }
+ evictToCapacity();
+ }
+
+ benchmarks_.insert({benchmark.uri(), std::move(benchmark)});
+ sizeInBytes_ += size;
+}
+
+void BenchmarkCache::evictToCapacity(std::optional<size_t> targetSize) {
+ int evicted = 0;
+ targetSize = targetSize.has_value() ? targetSize : maxSizeInBytes() / 2;
+
+ while (size() && sizeInBytes() > targetSize) {
+ // Select a benchmark randomly.
+    std::uniform_int_distribution<size_t> distribution(0, benchmarks_.size() - 1);
+ size_t index = distribution(rand_);
+ auto iterator = std::next(std::begin(benchmarks_), index);
+
+ // Evict the benchmark from the pool of loaded benchmarks.
+ ++evicted;
+ sizeInBytes_ -= iterator->second.ByteSizeLong();
+ benchmarks_.erase(iterator);
+ }
+
+ if (evicted) {
+ VLOG(2) << "Evicted " << evicted << " benchmarks from cache. Benchmark cache "
+ << "size now " << sizeInBytes() << " bytes, " << benchmarks_.size() << " items";
+ }
+}
+
+void BenchmarkCache::setMaxSizeInBytes(size_t maxSizeInBytes) {
+ maxSizeInBytes_ = maxSizeInBytes;
+ evictToCapacity(maxSizeInBytes);
+}
+
+} // namespace compiler_gym::runtime
diff --git a/compiler_gym/service/runtime/BenchmarkCache.h b/compiler_gym/service/runtime/BenchmarkCache.h
new file mode 100644
index 000000000..f285e9f1d
--- /dev/null
+++ b/compiler_gym/service/runtime/BenchmarkCache.h
@@ -0,0 +1,103 @@
+// Copyright (c) Facebook, Inc. and its affiliates.
+//
+// This source code is licensed under the MIT license found in the
+// LICENSE file in the root directory of this source tree.
+#pragma once
+
+#include <grpcpp/grpcpp.h>
+
+#include <optional>
+#include <random>
+#include <string>
+#include <unordered_map>
+
+#include "boost/filesystem.hpp"
+#include "compiler_gym/service/proto/compiler_gym_service.pb.h"
+
+namespace compiler_gym::runtime {
+
+constexpr size_t kEvictionSizeInBytes = 512 * 1024 * 1024;
+
+/**
+ * A cache of Benchmark protocol messages.
+ *
+ * This object caches Benchmark messages by URI. Once the cache reaches a
+ * predetermined size, benchmarks are evicted randomly until the capacity is
+ * reduced to 50%.
+ */
+class BenchmarkCache {
+ public:
+ /**
+ * Constructor.
+ *
+ * @param maxSizeInBytes The maximum size of the benchmark buffer before an
+ * automated eviction is run.
+ * @param rand A random start used for selecting benchmarks for random
+ * eviction.
+ */
+  BenchmarkCache(size_t maxSizeInBytes = kEvictionSizeInBytes,
+                 std::optional<std::mt19937_64> rand = std::nullopt);
+
+ /**
+ * Lookup a benchmark. The pointer set by this method is valid only until the
+ * next call to add().
+ *
+ * @param uri The URI of the benchmark.
+ * @return A Benchmark pointer.
+ */
+ const Benchmark* get(const std::string& uri) const;
+
+ /**
+ * Move-insert the given benchmark to the cache.
+ *
+ * @param benchmark A benchmark to insert.
+ */
+ void add(const Benchmark&& benchmark);
+
+ /**
+ * Get the number of elements in the cache.
+ *
+ * @return A nonnegative integer.
+ */
+ inline size_t size() const { return benchmarks_.size(); };
+
+ /**
+ * Get the size of the cache in bytes.
+ *
+ * @return A nonnegative integer.
+ */
+ inline size_t sizeInBytes() const { return sizeInBytes_; };
+
+ /**
+ * The maximum size of the cache before an eviction.
+ *
+ * @return A nonnegative integer.
+ */
+ inline size_t maxSizeInBytes() const { return maxSizeInBytes_; };
+
+ /**
+ * Set a new maximum size of the cache.
+ *
+ * @param maxSizeInBytes A number of bytes.
+ */
+ void setMaxSizeInBytes(size_t maxSizeInBytes);
+
+ /**
+ * Evict benchmarks randomly to reduce the capacity to the given size.
+ *
+ * If `targetSizeInBytes` is not provided, benchmarks are evicted to 50% of
+ * `maxSizeInBytes`.
+ *
+ * @param targetSizeInBytes A target maximum size in bytes.
+ */
+  void evictToCapacity(std::optional<size_t> targetSizeInBytes = std::nullopt);
+
+ private:
+  std::unordered_map<std::string, const Benchmark> benchmarks_;
+
+ std::mt19937_64 rand_;
+ size_t maxSizeInBytes_;
+ size_t sizeInBytes_;
+};
+
+} // namespace compiler_gym::runtime
diff --git a/compiler_gym/service/runtime/CompilerGymService.h b/compiler_gym/service/runtime/CompilerGymService.h
new file mode 100644
index 000000000..16672caaa
--- /dev/null
+++ b/compiler_gym/service/runtime/CompilerGymService.h
@@ -0,0 +1,96 @@
+// Copyright (c) Facebook, Inc. and its affiliates.
+//
+// This source code is licensed under the MIT license found in the
+// LICENSE file in the root directory of this source tree.
+#pragma once
+
+#include <grpcpp/grpcpp.h>
+
+#include <memory>
+#include <vector>
+
+#include "boost/filesystem.hpp"
+#include "compiler_gym/service/CompilationSession.h"
+#include "compiler_gym/service/proto/compiler_gym_service.grpc.pb.h"
+#include "compiler_gym/service/proto/compiler_gym_service.pb.h"
+#include "compiler_gym/service/runtime/BenchmarkCache.h"
+
+namespace compiler_gym::runtime {
+
+/**
+ * A default implementation of the CompilerGymService.
+ *
+ * When parametrized by a CompilationSession subclass, this provides the RPC
+ * handling logic to run a gym service. User should call
+ * createAndRunCompilerGymService() rather than interacting with this class
+ * directly.
+ */
+template <typename CompilationSessionType>
+class CompilerGymService final : public compiler_gym::CompilerGymService::Service {
+ public:
+ CompilerGymService(const boost::filesystem::path& workingDirectory,
+                     std::unique_ptr<BenchmarkCache> benchmarks = nullptr);
+
+ // RPC endpoints.
+ grpc::Status GetVersion(grpc::ServerContext* context, const GetVersionRequest* request,
+ GetVersionReply* reply) final override;
+
+ grpc::Status GetSpaces(grpc::ServerContext* context, const GetSpacesRequest* request,
+ GetSpacesReply* reply) final override;
+
+ grpc::Status StartSession(grpc::ServerContext* context, const StartSessionRequest* request,
+ StartSessionReply* reply) final override;
+
+ grpc::Status ForkSession(grpc::ServerContext* context, const ForkSessionRequest* request,
+ ForkSessionReply* reply) final override;
+
+ grpc::Status EndSession(grpc::ServerContext* context, const EndSessionRequest* request,
+ EndSessionReply* reply) final override;
+
+ // NOTE: Step() is not thread safe. The underlying assumption is that each
+ // CompilationSessionType is managed by a single thread, so race conditions
+ // between operations that affect the same CompilationSessionType are not
+ // protected against.
+ grpc::Status Step(grpc::ServerContext* context, const StepRequest* request,
+ StepReply* reply) final override;
+
+ grpc::Status AddBenchmark(grpc::ServerContext* context, const AddBenchmarkRequest* request,
+ AddBenchmarkReply* reply) final override;
+
+ inline BenchmarkCache& benchmarks() { return *benchmarks_; }
+
+ // Get the number of active sessions.
+  inline int sessionCount() const { return static_cast<int>(sessions_.size()); }
+
+ protected:
+ [[nodiscard]] grpc::Status session(uint64_t id, CompilationSession** environment);
+
+ [[nodiscard]] grpc::Status session(uint64_t id, const CompilationSession** environment) const;
+
+ [[nodiscard]] grpc::Status action_space(const CompilationSession* session, int index,
+ const ActionSpace** actionSpace) const;
+
+ [[nodiscard]] grpc::Status observation_space(const CompilationSession* session, int index,
+ const ObservationSpace** observationSpace) const;
+
+ inline const boost::filesystem::path& workingDirectory() const { return workingDirectory_; }
+
+ // Add the given session and return its ID.
+  uint64_t addSession(std::unique_ptr<CompilationSession> session);
+
+ private:
+ const boost::filesystem::path workingDirectory_;
+  const std::vector<ActionSpace> actionSpaces_;
+  const std::vector<ObservationSpace> observationSpaces_;
+
+  std::unordered_map<uint64_t, std::unique_ptr<CompilationSession>> sessions_;
+  std::unique_ptr<BenchmarkCache> benchmarks_;
+
+ // Mutex used to ensure thread safety of creation and destruction of sessions.
+ std::mutex sessionsMutex_;
+ uint64_t nextSessionId_;
+};
+
+} // namespace compiler_gym::runtime
+
+#include "compiler_gym/service/runtime/CompilerGymServiceImpl.h"
diff --git a/compiler_gym/service/runtime/CompilerGymServiceImpl.h b/compiler_gym/service/runtime/CompilerGymServiceImpl.h
new file mode 100644
index 000000000..6bb976514
--- /dev/null
+++ b/compiler_gym/service/runtime/CompilerGymServiceImpl.h
@@ -0,0 +1,249 @@
+// Copyright (c) Facebook, Inc. and its affiliates.
+//
+// This source code is licensed under the MIT license found in the LICENSE file
+// in the root directory of this source tree.
+//
+// Private implementation of the CompilerGymService template class. Do not
+// include this header directly! Use
+// compiler_gym/service/runtime/CompilerGymService.h.
+#pragma once
+
+#include <fmt/format.h>
+
+#include "compiler_gym/util/GrpcStatusMacros.h"
+#include "compiler_gym/util/Version.h"
+
+namespace compiler_gym::runtime {
+
+template <typename CompilationSessionType>
+CompilerGymService<CompilationSessionType>::CompilerGymService(
+    const boost::filesystem::path& workingDirectory, std::unique_ptr<BenchmarkCache> benchmarks)
+ : workingDirectory_(workingDirectory),
+ actionSpaces_(CompilationSessionType(workingDirectory).getActionSpaces()),
+ observationSpaces_(CompilationSessionType(workingDirectory).getObservationSpaces()),
+      benchmarks_(benchmarks ? std::move(benchmarks) : std::make_unique<BenchmarkCache>()),
+ nextSessionId_(0) {}
+
+template <typename CompilationSessionType>
+grpc::Status CompilerGymService<CompilationSessionType>::GetVersion(
+ grpc::ServerContext* context, const GetVersionRequest* request, GetVersionReply* reply) {
+ VLOG(2) << "GetVersion()";
+ reply->set_service_version(COMPILER_GYM_VERSION);
+ CompilationSessionType environment(workingDirectory());
+ reply->set_compiler_version(environment.getCompilerVersion());
+ return grpc::Status::OK;
+}
+
+template <typename CompilationSessionType>
+grpc::Status CompilerGymService<CompilationSessionType>::GetSpaces(grpc::ServerContext* context,
+ const GetSpacesRequest* request,
+ GetSpacesReply* reply) {
+ VLOG(2) << "GetSpaces()";
+ for (const auto& actionSpace : actionSpaces_) {
+ *reply->add_action_space_list() = actionSpace;
+ }
+ for (const auto& observationSpace : observationSpaces_) {
+ *reply->add_observation_space_list() = observationSpace;
+ }
+ return grpc::Status::OK;
+}
+
+template <typename CompilationSessionType>
+grpc::Status CompilerGymService<CompilationSessionType>::StartSession(
+ grpc::ServerContext* context, const StartSessionRequest* request, StartSessionReply* reply) {
+ if (!request->benchmark().size()) {
+ return grpc::Status(grpc::StatusCode::INVALID_ARGUMENT,
+ "No benchmark URI set for StartSession()");
+ }
+
+  const std::lock_guard<std::mutex> lock(sessionsMutex_);
+ VLOG(1) << "StartSession(" << request->benchmark() << "), " << sessionCount()
+ << " active sessions";
+
+ const Benchmark* benchmark = benchmarks().get(request->benchmark());
+ if (!benchmark) {
+ return grpc::Status(grpc::StatusCode::NOT_FOUND, "Benchmark not found");
+ }
+
+ // Construct the new session.
+  auto environment = std::make_unique<CompilationSessionType>(workingDirectory());
+
+ // Resolve the action space.
+ const ActionSpace* actionSpace;
+ RETURN_IF_ERROR(action_space(environment.get(), request->action_space(), &actionSpace));
+
+ // Initialize the session.
+ RETURN_IF_ERROR(environment->init(*actionSpace, *benchmark));
+
+ // Compute the initial observations.
+ for (int i = 0; i < request->observation_space_size(); ++i) {
+ const ObservationSpace* observationSpace;
+ RETURN_IF_ERROR(
+ observation_space(environment.get(), request->observation_space(i), &observationSpace));
+ RETURN_IF_ERROR(environment->computeObservation(*observationSpace, *reply->add_observation()));
+ }
+
+ reply->set_session_id(addSession(std::move(environment)));
+
+ return grpc::Status::OK;
+}
+
+template <typename CompilationSessionType>
+grpc::Status CompilerGymService<CompilationSessionType>::ForkSession(
+ grpc::ServerContext* context, const ForkSessionRequest* request, ForkSessionReply* reply) {
+  const std::lock_guard<std::mutex> lock(sessionsMutex_);
+
+ CompilationSession* baseSession;
+ RETURN_IF_ERROR(session(request->session_id(), &baseSession));
+ VLOG(1) << "ForkSession(" << request->session_id() << "), [" << nextSessionId_ << "]";
+
+ // Construct the new session.
+  auto forked = std::make_unique<CompilationSessionType>(workingDirectory());
+
+ // Initialize from the base environment.
+ RETURN_IF_ERROR(forked->init(baseSession));
+
+ reply->set_session_id(addSession(std::move(forked)));
+
+ return grpc::Status::OK;
+}
+
+template <typename CompilationSessionType>
+grpc::Status CompilerGymService<CompilationSessionType>::EndSession(
+ grpc::ServerContext* context, const EndSessionRequest* request, EndSessionReply* reply) {
+ VLOG(1) << "EndSession(" << request->session_id() << "), " << sessionCount() - 1
+ << " sessions remaining";
+
+  const std::lock_guard<std::mutex> lock(sessionsMutex_);
+
+ // Note that unlike the other methods, no error is thrown if the requested
+ // session does not exist.
+ if (sessions_.find(request->session_id()) != sessions_.end()) {
+ const CompilationSession* environment;
+ RETURN_IF_ERROR(session(request->session_id(), &environment));
+ sessions_.erase(request->session_id());
+ }
+
+ reply->set_remaining_sessions(sessionCount());
+ return Status::OK;
+}
+
+template <typename CompilationSessionType>
+grpc::Status CompilerGymService<CompilationSessionType>::Step(grpc::ServerContext* context,
+ const StepRequest* request,
+ StepReply* reply) {
+ CompilationSession* environment;
+ RETURN_IF_ERROR(session(request->session_id(), &environment));
+
+ VLOG(2) << "Session " << request->session_id() << " Step()";
+
+ bool endOfEpisode = false;
+  std::optional<ActionSpace> newActionSpace;
+ bool actionsHadNoEffect = true;
+
+ // Apply the actions.
+ for (int i = 0; i < request->action_size(); ++i) {
+ bool actionHadNoEffect = false;
+    std::optional<ActionSpace> newActionSpaceFromAction;
+ RETURN_IF_ERROR(environment->applyAction(request->action(i), endOfEpisode,
+ newActionSpaceFromAction, actionHadNoEffect));
+ actionsHadNoEffect &= actionHadNoEffect;
+ if (newActionSpaceFromAction.has_value()) {
+ newActionSpace = *newActionSpaceFromAction;
+ }
+ if (endOfEpisode) {
+ break;
+ }
+ }
+
+ // Compute the requested observations.
+ for (int i = 0; i < request->observation_space_size(); ++i) {
+ const ObservationSpace* observationSpace;
+ RETURN_IF_ERROR(
+ observation_space(environment, request->observation_space(i), &observationSpace));
+ DCHECK(observationSpace) << "No observation space set";
+ RETURN_IF_ERROR(environment->computeObservation(*observationSpace, *reply->add_observation()));
+ }
+
+ // Call the end-of-step callback.
+ RETURN_IF_ERROR(environment->endOfStep(actionsHadNoEffect, endOfEpisode, newActionSpace));
+
+ reply->set_action_had_no_effect(actionsHadNoEffect);
+ if (newActionSpace.has_value()) {
+ *reply->mutable_new_action_space() = *newActionSpace;
+ }
+ reply->set_end_of_session(endOfEpisode);
+ return Status::OK;
+}
+
+template <typename CompilationSessionType>
+grpc::Status CompilerGymService<CompilationSessionType>::AddBenchmark(
+ grpc::ServerContext* context, const AddBenchmarkRequest* request, AddBenchmarkReply* reply) {
+ // We need to grab the sessions lock here to ensure thread safe access to the
+ // benchmarks cache.
+  const std::lock_guard<std::mutex> lock(sessionsMutex_);
+
+ VLOG(2) << "AddBenchmark()";
+ for (int i = 0; i < request->benchmark_size(); ++i) {
+ benchmarks().add(std::move(request->benchmark(i)));
+ }
+
+ return grpc::Status::OK;
+}
+
+template <typename CompilationSessionType>
+grpc::Status CompilerGymService<CompilationSessionType>::session(uint64_t id,
+ CompilationSession** environment) {
+ auto it = sessions_.find(id);
+ if (it == sessions_.end()) {
+ return Status(grpc::StatusCode::NOT_FOUND, fmt::format("Session not found: {}", id));
+ }
+
+ *environment = it->second.get();
+ return grpc::Status::OK;
+}
+
+template <typename CompilationSessionType>
+grpc::Status CompilerGymService<CompilationSessionType>::session(
+ uint64_t id, const CompilationSession** environment) const {
+ auto it = sessions_.find(id);
+ if (it == sessions_.end()) {
+ return grpc::Status(grpc::StatusCode::NOT_FOUND, fmt::format("Session not found: {}", id));
+ }
+
+ *environment = it->second.get();
+ return grpc::Status::OK;
+}
+
+template <typename CompilationSessionType>
+grpc::Status CompilerGymService<CompilationSessionType>::action_space(
+ const CompilationSession* session, int index, const ActionSpace** actionSpace) const {
+  if (index < 0 || index >= static_cast<int>(actionSpaces_.size())) {
+ return grpc::Status(grpc::StatusCode::INVALID_ARGUMENT,
+ fmt::format("Action space index out of range: {}", index));
+ }
+ *actionSpace = &actionSpaces_[index];
+ return Status::OK;
+}
+
+template <typename CompilationSessionType>
+grpc::Status CompilerGymService<CompilationSessionType>::observation_space(
+ const CompilationSession* session, int index, const ObservationSpace** observationSpace) const {
+  if (index < 0 || index >= static_cast<int>(observationSpaces_.size())) {
+ return grpc::Status(grpc::StatusCode::INVALID_ARGUMENT,
+ fmt::format("Observation space index out of range: {}", index));
+ }
+ *observationSpace = &observationSpaces_[index];
+ return Status::OK;
+}
+
+template <typename CompilationSessionType>
+uint64_t CompilerGymService<CompilationSessionType>::addSession(
+    std::unique_ptr<CompilationSession> session) {
+ uint64_t id = nextSessionId_;
+ sessions_[id] = std::move(session);
+ ++nextSessionId_;
+ return id;
+}
+
+} // namespace compiler_gym::runtime
diff --git a/compiler_gym/util/RunService.cc b/compiler_gym/service/runtime/CreateAndRunCompilerGymServiceImpl.cc
similarity index 61%
rename from compiler_gym/util/RunService.cc
rename to compiler_gym/service/runtime/CreateAndRunCompilerGymServiceImpl.cc
index e484c7109..28319136d 100644
--- a/compiler_gym/util/RunService.cc
+++ b/compiler_gym/service/runtime/CreateAndRunCompilerGymServiceImpl.cc
@@ -2,7 +2,7 @@
//
// This source code is licensed under the MIT license found in the
// LICENSE file in the root directory of this source tree.
-#include "compiler_gym/util/RunService.h"
+#include "compiler_gym/service/runtime/CreateAndRunCompilerGymServiceImpl.h"
DEFINE_string(
working_dir, "",
@@ -10,3 +10,14 @@ DEFINE_string(
DEFINE_string(port, "0",
"The port to listen on. If 0, an unused port will be selected. The selected port is "
"written to /port.txt.");
+
+namespace compiler_gym::runtime {
+
+std::promise<void> shutdownSignal;
+
+void shutdown_handler(int signum) {
+ VLOG(1) << "Service received signal: " << signum;
+ shutdownSignal.set_value();
+}
+
+} // namespace compiler_gym::runtime
diff --git a/compiler_gym/service/runtime/CreateAndRunCompilerGymServiceImpl.h b/compiler_gym/service/runtime/CreateAndRunCompilerGymServiceImpl.h
new file mode 100644
index 000000000..09147edac
--- /dev/null
+++ b/compiler_gym/service/runtime/CreateAndRunCompilerGymServiceImpl.h
@@ -0,0 +1,150 @@
+// Copyright (c) Facebook, Inc. and its affiliates.
+//
+// This source code is licensed under the MIT license found in the
+// LICENSE file in the root directory of this source tree.
+//
+// Private implementation of the createAndRunCompilerGymService(). Do not
+// include this header directly! Use compiler_gym/service/runtime/Runtime.h.
+#pragma once
+
+#include <fmt/format.h>
+#include <gflags/gflags.h>
+#include <glog/logging.h>
+#include <grpcpp/grpcpp.h>
+
+#include <csignal>
+#include <future>
+#include <iostream>
+#include <memory>
+#include <thread>
+
+#include "boost/filesystem.hpp"
+#include "compiler_gym/service/proto/compiler_gym_service.pb.h"
+#include "compiler_gym/service/runtime/CompilerGymService.h"
+
+DECLARE_string(port);
+DECLARE_string(working_dir);
+
+namespace compiler_gym::runtime {
+
+extern std::promise<void> shutdownSignal;
+
+// Increase maximum message size beyond the 4MB default as inbound message
+// may be larger (e.g., in the case of IR strings).
+constexpr size_t kMaxMessageSizeInBytes = 512 * 1024 * 1024;
+
+void shutdown_handler(int signum);
+
+// Create a service, configured using --port and --working_dir flags, and run
+// it. This function never returns.
+//
+// CompilationService must be a valid compiler_gym::CompilationService subclass
+// that implements the abstract methods and takes a single-argument working
+// directory constructor:
+//
+// class MyCompilationService final : public CompilationService {
+// public:
+// ...
+// }
+//
+// Usage:
+//
+// int main(int argc, char** argv) {
+// createAndRunCompilerGymServiceImpl(argc, argv, "usage string");
+// }
+template <typename CompilationSessionType>
+[[noreturn]] void createAndRunCompilerGymServiceImpl(int argc, char** argv, const char* usage) {
+ // Register a signal handler for SIGTERM that will set the shutdown_signal
+ // future value.
+ std::signal(SIGTERM, shutdown_handler);
+
+ gflags::SetUsageMessage(std::string(usage));
+
+ // Parse the command line arguments and die if any are unrecognized.
+ gflags::ParseCommandLineFlags(&argc, &argv, /*remove_flags=*/true);
+ if (argc > 1) {
+ std::cerr << "ERROR: unknown command line argument '" << argv[1] << '\'';
+ exit(1);
+ }
+
+ // Set up the working and logging directories.
+ boost::filesystem::path workingDirectory{FLAGS_working_dir};
+ bool createdWorkingDir = false;
+ if (FLAGS_working_dir.empty()) {
+ // If no working directory was set, create one.
+ workingDirectory = boost::filesystem::unique_path(boost::filesystem::temp_directory_path() /
+ "compiler_gym-service-%%%%-%%%%");
+ boost::filesystem::create_directories(workingDirectory / "logs");
+ FLAGS_working_dir = workingDirectory.string();
+ createdWorkingDir = true;
+ }
+
+ FLAGS_log_dir = workingDirectory.string() + "/logs";
+ if (!createdWorkingDir && !boost::filesystem::is_directory(FLAGS_log_dir)) {
+ std::cerr << "ERROR: logging directory '" << FLAGS_log_dir << "' not found";
+ exit(1);
+ }
+
+ google::InitGoogleLogging(argv[0]);
+
+  CompilerGymService<CompilationSessionType> service{workingDirectory};
+
+ grpc::ServerBuilder builder;
+ builder.RegisterService(&service);
+
+ builder.SetMaxMessageSize(kMaxMessageSizeInBytes);
+
+ // Start a channel on the port.
+ int port;
+ std::string serverAddress = "0.0.0.0:" + (FLAGS_port.empty() ? "0" : FLAGS_port);
+ builder.AddListeningPort(serverAddress, grpc::InsecureServerCredentials(), &port);
+
+ // Start the server.
+ std::unique_ptr server(builder.BuildAndStart());
+ CHECK(server) << "Failed to build RPC service";
+
+ {
+    // Write the port to a <workingDirectory>/port.txt file, which an external
+ // process can read to determine how to get in touch. First write the port
+ // to a temporary file and rename it, since renaming is atomic.
+ const boost::filesystem::path portPath = workingDirectory / "port.txt";
+ std::ofstream out(portPath.string() + ".tmp");
+ out << std::to_string(port) << std::endl;
+ out.close();
+ boost::filesystem::rename(portPath.string() + ".tmp", portPath);
+ }
+
+ {
+    // Write the process ID to a <workingDirectory>/pid.txt file, which an
+    // external process can later use to determine if this service is still
+ // alive.
+ const boost::filesystem::path pidPath = workingDirectory / "pid.txt";
+ std::ofstream out(pidPath.string() + ".tmp");
+ out << std::to_string(getpid()) << std::endl;
+ out.close();
+ boost::filesystem::rename(pidPath.string() + ".tmp", pidPath);
+ }
+
+ LOG(INFO) << "Service " << workingDirectory << " listening on " << port << ", PID = " << getpid();
+
+ // Block on the RPC service in a separate thread. This enables the current
+ // thread to handle the shutdown routine.
+ std::thread serverThread([&]() { server->Wait(); });
+
+ // Block until this shutdown signal is received.
+ shutdownSignal.get_future().wait();
+ VLOG(2) << "Shutting down the RPC service";
+ server->Shutdown();
+ serverThread.join();
+
+ if (service.sessionCount()) {
+ std::cerr << "ERROR: Killing a service with " << service.sessionCount()
+ << (service.sessionCount() > 1 ? " active sessions!" : " active session!")
+ << std::endl;
+ exit(6);
+ }
+
+ exit(0);
+}
+
+} // namespace compiler_gym::runtime
diff --git a/compiler_gym/service/runtime/Runtime.h b/compiler_gym/service/runtime/Runtime.h
new file mode 100644
index 000000000..ef154bb1c
--- /dev/null
+++ b/compiler_gym/service/runtime/Runtime.h
@@ -0,0 +1,39 @@
+// Copyright (c) Facebook, Inc. and its affiliates.
+//
+// This source code is licensed under the MIT license found in the
+// LICENSE file in the root directory of this source tree.
+#pragma once
+
+#include "compiler_gym/service/runtime/CompilerGymService.h"
+#include "compiler_gym/service/runtime/CreateAndRunCompilerGymServiceImpl.h"
+
+namespace compiler_gym::runtime {
+
+/**
+ * Create and run an RPC service for the given compilation session.
+ *
+ * This should be called on its own in a self contained script to implement a
+ * compilation service. Example:
+ *
+ * \code{.cpp}
+ * #include "compiler_gym/service/runtime/Runtime.h"
+ * #include "my_compiler_service/MyCompilationSession.h"
+ *
+ * int main(int argc, char** argv) {
+ *     createAndRunCompilerGymService<MyCompilationSession>(
+ *         argc, argv, "My compiler service"
+ * );
+ * }
+ * \endcode
+ *
+ * This function never returns.
+ *
+ * @tparam CompilationSessionType A sublass of CompilationSession that provides
+ * implementations of the abstract methods.
+ */
+template <typename CompilationSessionType>
+[[noreturn]] void createAndRunCompilerGymService(int argc, char** argv, const char* usage) {
+  createAndRunCompilerGymServiceImpl<CompilationSessionType>(argc, argv, usage);
+}
+
+} // namespace compiler_gym::runtime
diff --git a/compiler_gym/service/runtime/__init__.py b/compiler_gym/service/runtime/__init__.py
new file mode 100644
index 000000000..579dd8b87
--- /dev/null
+++ b/compiler_gym/service/runtime/__init__.py
@@ -0,0 +1,12 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+from compiler_gym.service.runtime.create_and_run_compiler_gym_service import (
+ create_and_run_compiler_gym_service,
+)
+
+__all__ = [
+ "create_and_run_compiler_gym_service",
+]
diff --git a/compiler_gym/service/runtime/benchmark_cache.py b/compiler_gym/service/runtime/benchmark_cache.py
new file mode 100644
index 000000000..72a862b75
--- /dev/null
+++ b/compiler_gym/service/runtime/benchmark_cache.py
@@ -0,0 +1,128 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+import logging
+from typing import Dict, Optional
+
+import numpy as np
+
+from compiler_gym.service.proto import Benchmark
+
+MAX_SIZE_IN_BYTES = 512 * 1024 * 1024
+
+
+class BenchmarkCache:
+ """An in-memory cache of Benchmark messages.
+
+ This object caches Benchmark messages by URI. Once the cache reaches a
+ predetermined size, benchmarks are evicted randomly until the capacity is
+ reduced to 50%.
+ """
+
+ def __init__(
+ self,
+ max_size_in_bytes: int = MAX_SIZE_IN_BYTES,
+ rng: Optional[np.random.Generator] = None,
+ logger: Optional[logging.Logger] = None,
+ ):
+ self._max_size_in_bytes = max_size_in_bytes
+ self.rng = rng or np.random.default_rng()
+ self.logger = logger or logging.getLogger("compiler_gym")
+
+ self._benchmarks: Dict[str, Benchmark] = {}
+ self._size_in_bytes = 0
+
+ def __getitem__(self, uri: str) -> Benchmark:
+ """Get a benchmark by URI. Raises KeyError."""
+ item = self._benchmarks.get(uri)
+ if item is None:
+ raise KeyError(uri)
+ return item
+
+ def __contains__(self, uri: str):
+ """Whether URI is in cache."""
+ return uri in self._benchmarks
+
+ def __setitem__(self, uri: str, benchmark: Benchmark):
+ """Add benchmark to cache."""
+ self.logger.debug(
+ "Caching benchmark %s. Cache size = %d bytes, %d items",
+ uri,
+ self.size_in_bytes,
+ self.size,
+ )
+
+ # Remove any existing value to keep the cache size consistent.
+ if uri in self._benchmarks:
+ self._size_in_bytes -= self._benchmarks[uri].ByteSize()
+ del self._benchmarks[uri]
+
+ size = benchmark.ByteSize()
+ if self.size_in_bytes + size > self.max_size_in_bytes:
+ if size > self.max_size_in_bytes:
+ self.logger.warning(
+ "Adding new benchmark with size %d bytes exceeds total "
+ "target cache size of %d bytes",
+ size,
+ self.max_size_in_bytes,
+ )
+ else:
+ self.logger.debug(
+ "Adding new benchmark with size %d bytes "
+ "exceeds maximum size %d bytes, %d items",
+ size,
+ self.max_size_in_bytes,
+ self.size,
+ )
+ self.evict_to_capacity()
+
+ self._benchmarks[uri] = benchmark
+ self._size_in_bytes += size
+
+ def evict_to_capacity(self, target_size_in_bytes: Optional[int] = None) -> None:
+ """Evict benchmarks randomly to reduce the capacity below 50%."""
+ evicted = 0
+ target_size_in_bytes = (
+ self.max_size_in_bytes // 2
+ if target_size_in_bytes is None
+ else target_size_in_bytes
+ )
+
+ while self.size and self.size_in_bytes > target_size_in_bytes:
+ evicted += 1
+ key = self.rng.choice(list(self._benchmarks.keys()))
+ self._size_in_bytes -= self._benchmarks[key].ByteSize()
+ del self._benchmarks[key]
+
+ if evicted:
+ self.logger.info(
+ "Evicted %d benchmarks from cache. "
+ "Benchmark cache size now %d bytes, %d items",
+ evicted,
+ self.size_in_bytes,
+ self.size,
+ )
+
+ @property
+ def size(self) -> int:
+ """The number of items in the cache."""
+ return len(self._benchmarks)
+
+ @property
+ def size_in_bytes(self) -> int:
+ """The combined size of the elements in the cache, excluding the
+ cache overhead.
+ """
+ return self._size_in_bytes
+
+ @property
+ def max_size_in_bytes(self) -> int:
+ """The maximum size of the cache."""
+ return self._max_size_in_bytes
+
+ @max_size_in_bytes.setter
+ def max_size_in_bytes(self, value: int) -> None:
+ """Set a new maximum cache size."""
+ self._max_size_in_bytes = value
+ self.evict_to_capacity(target_size_in_bytes=value)
diff --git a/compiler_gym/service/runtime/compiler_gym_service.py b/compiler_gym/service/runtime/compiler_gym_service.py
new file mode 100644
index 000000000..2a2bd8a0e
--- /dev/null
+++ b/compiler_gym/service/runtime/compiler_gym_service.py
@@ -0,0 +1,171 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+import logging
+from contextlib import contextmanager
+from pathlib import Path
+from threading import Lock
+from typing import Dict
+
+from grpc import StatusCode
+
+from compiler_gym.service.compilation_session import CompilationSession
+from compiler_gym.service.proto import AddBenchmarkReply, AddBenchmarkRequest
+from compiler_gym.service.proto import (
+ CompilerGymServiceServicer as CompilerGymServiceServicerStub,
+)
+from compiler_gym.service.proto import (
+ EndSessionReply,
+ EndSessionRequest,
+ GetSpacesReply,
+ GetSpacesRequest,
+ GetVersionReply,
+ GetVersionRequest,
+ StartSessionReply,
+ StartSessionRequest,
+ StepReply,
+ StepRequest,
+)
+from compiler_gym.service.runtime.benchmark_cache import BenchmarkCache
+from compiler_gym.util.version import __version__
+
+
+@contextmanager
+def exception_to_grpc_status(context):
+ def handle_exception_as(exception, code):
+ context.set_code(code)
+ context.set_details(str(exception))
+
+ try:
+ yield
+ except ValueError as e:
+ handle_exception_as(e, StatusCode.INVALID_ARGUMENT)
+ except LookupError as e:
+ handle_exception_as(e, StatusCode.NOT_FOUND)
+ except NotImplementedError as e:
+ handle_exception_as(e, StatusCode.UNIMPLEMENTED)
+ except FileNotFoundError as e:
+ handle_exception_as(e, StatusCode.UNIMPLEMENTED)
+ except TypeError as e:
+ handle_exception_as(e, StatusCode.FAILED_PRECONDITION)
+ except TimeoutError as e:
+ handle_exception_as(e, StatusCode.DEADLINE_EXCEEDED)
+
+
+class CompilerGymService(CompilerGymServiceServicerStub):
+ def __init__(self, working_directory: Path, compilation_session_type):
+ self.working_directory = working_directory
+ self.benchmarks = BenchmarkCache()
+
+ self.compilation_session_type = compilation_session_type
+ self.sessions: Dict[int, CompilationSession] = {}
+ self.sessions_lock = Lock()
+ self.next_session_id: int = 0
+
+ self.action_spaces = compilation_session_type.action_spaces
+ self.observation_spaces = compilation_session_type.observation_spaces
+
+ def GetVersion(self, request: GetVersionRequest, context) -> GetVersionReply:
+ del context # Unused
+ del request # Unused
+ logging.debug("GetVersion()")
+ return GetVersionReply(
+ service_version=__version__,
+ compiler_version=self.compilation_session_type.compiler_version,
+ )
+
+ def GetSpaces(self, request: GetSpacesRequest, context) -> GetSpacesReply:
+ del request # Unused
+ logging.debug("GetSpaces()")
+ with exception_to_grpc_status(context):
+ return GetSpacesReply(
+ action_space_list=self.action_spaces,
+ observation_space_list=self.observation_spaces,
+ )
+
+ def StartSession(self, request: StartSessionRequest, context) -> StartSessionReply:
+ """Create a new compilation session."""
+ logging.debug("StartSession(%s), [%d]", request.benchmark, self.next_session_id)
+ reply = StartSessionReply()
+
+ if not request.benchmark:
+ context.set_code(StatusCode.INVALID_ARGUMENT)
+ context.set_details("No benchmark URI set for StartSession()")
+ return reply
+
+ with self.sessions_lock, exception_to_grpc_status(context):
+ if request.benchmark not in self.benchmarks:
+ context.set_code(StatusCode.NOT_FOUND)
+ context.set_details("Benchmark not found")
+ return reply
+
+ session = self.compilation_session_type(
+ working_directory=self.working_directory,
+ action_space=self.action_spaces[request.action_space],
+ benchmark=self.benchmarks[request.benchmark],
+ )
+
+ # Generate the initial observations.
+ reply.observation.extend(
+ [
+ session.get_observation(self.observation_spaces[obs])
+ for obs in request.observation_space
+ ]
+ )
+
+ reply.session_id = self.next_session_id
+ self.sessions[reply.session_id] = session
+ self.next_session_id += 1
+
+ return reply
+
+ def EndSession(self, request: EndSessionRequest, context) -> EndSessionReply:
+ del context # Unused
+ logging.debug(
+ "EndSession(%d), %d sessions remaining",
+ request.session_id,
+ len(self.sessions) - 1,
+ )
+
+ with self.sessions_lock:
+ if request.session_id in self.sessions:
+ del self.sessions[request.session_id]
+ return EndSessionReply(remaining_sessions=len(self.sessions))
+
+ def Step(self, request: StepRequest, context) -> StepReply:
+ logging.debug("Step()")
+ reply = StepReply()
+
+ if request.session_id not in self.sessions:
+ context.set_code(StatusCode.NOT_FOUND)
+ context.set_details(f"Session not found: {request.session_id}")
+ return reply
+
+ session = self.sessions[request.session_id]
+
+ reply.action_had_no_effect = True
+
+ with exception_to_grpc_status(context):
+ for action in request.action:
+ reply.end_of_session, nas, ahne = session.apply_action(action)
+ reply.action_had_no_effect &= ahne
+ if nas:
+ reply.new_action_space.CopyFrom(nas)
+
+ reply.observation.extend(
+ [
+ session.get_observation(self.observation_spaces[obs])
+ for obs in request.observation_space
+ ]
+ )
+
+ return reply
+
+ def AddBenchmark(self, request: AddBenchmarkRequest, context) -> AddBenchmarkReply:
+ del context # Unused
+ reply = AddBenchmarkReply()
+ with self.sessions_lock:
+ for benchmark in request.benchmark:
+ self.benchmarks[benchmark.uri] = benchmark
+ return reply
diff --git a/compiler_gym/service/runtime/create_and_run_compiler_gym_service.py b/compiler_gym/service/runtime/create_and_run_compiler_gym_service.py
new file mode 100644
index 000000000..f02f35a03
--- /dev/null
+++ b/compiler_gym/service/runtime/create_and_run_compiler_gym_service.py
@@ -0,0 +1,142 @@
+#! /usr/bin/env python3
+#
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+"""An example CompilerGym service in python."""
+import os
+import sys
+from concurrent import futures
+from multiprocessing import cpu_count
+from pathlib import Path
+from signal import SIGTERM, signal
+from tempfile import mkdtemp
+from threading import Event, Thread
+from typing import Type
+
+import grpc
+from absl import app, flags, logging
+
+from compiler_gym.service.compilation_session import CompilationSession
+from compiler_gym.service.proto import compiler_gym_service_pb2_grpc
+from compiler_gym.service.runtime.compiler_gym_service import CompilerGymService
+from compiler_gym.util import debug_util as dbg
+from compiler_gym.util.filesystem import atomic_file_write
+from compiler_gym.util.shell_format import plural
+
+flags.DEFINE_string("working_dir", "", "Path to use as service working directory")
+flags.DEFINE_integer("port", 0, "The service listening port")
+flags.DEFINE_integer(
+ "rpc_service_threads", cpu_count(), "The number of server worker threads"
+)
+flags.DEFINE_integer("logbuflevel", 0, "Flag for compatability with C++ service.")
+FLAGS = flags.FLAGS
+
+MAX_MESSAGE_SIZE_IN_BYTES = 512 * 1024 * 1024
+
+
+shutdown_signal = Event()
+
+
+# NOTE(cummins): This script is executed in a subprocess, so code coverage
+# tracking does not work. As such we use "# pragma: no cover" annotation for all
+# functions.
+def _shutdown_handler(signal_number, stack_frame): # pragma: no cover
+ del stack_frame # Unused
+ logging.info("Service received signal: %d", signal_number)
+ shutdown_signal.set()
+
+
+def create_and_run_compiler_gym_service(
+ compilation_session_type: Type[CompilationSession],
+): # pragma: no cover
+ """Create and run an RPC service for the given compilation session.
+
+ This should be called on its own in a self contained script to implement a
+ compilation service. Example:
+
+ .. code-block:: python
+
+ from compiler_gym.service import runtime
+ from my_compiler_service import MyCompilationSession
+
+ if __name__ == "__main__":
+ runtime.create_and_run_compiler_gym_service(MyCompilationSession)
+
+ This function never returns.
+
+    :param compilation_session_type: A subclass of :class:`CompilationSession
+ ` that provides implementations
+ of the abstract methods.
+ """
+
+ def main(argv):
+ # Register a signal handler for SIGTERM that will set the shutdownSignal
+ # future value.
+ signal(SIGTERM, _shutdown_handler)
+
+ argv = [x for x in argv if x.strip()]
+ if len(argv) > 1:
+ print(
+ f"ERROR: Unrecognized command line argument '{argv[1]}'",
+ file=sys.stderr,
+ )
+ sys.exit(1)
+
+ working_dir = Path(FLAGS.working_dir or mkdtemp(prefix="compiler_gym-service-"))
+ (working_dir / "logs").mkdir(exist_ok=True, parents=True)
+
+ FLAGS.log_dir = str(working_dir / "logs")
+ logging.get_absl_handler().use_absl_log_file()
+ logging.set_verbosity(dbg.get_logging_level())
+
+ # Create the service.
+ server = grpc.server(
+ futures.ThreadPoolExecutor(max_workers=FLAGS.rpc_service_threads),
+ options=[
+ ("grpc.max_send_message_length", MAX_MESSAGE_SIZE_IN_BYTES),
+ ("grpc.max_receive_message_length", MAX_MESSAGE_SIZE_IN_BYTES),
+ ],
+ )
+ service = CompilerGymService(
+ working_directory=working_dir,
+ compilation_session_type=compilation_session_type,
+ )
+ compiler_gym_service_pb2_grpc.add_CompilerGymServiceServicer_to_server(
+ service, server
+ )
+    port = server.add_insecure_port(f"0.0.0.0:{FLAGS.port}")
+
+ with atomic_file_write(working_dir / "port.txt", fileobj=True, mode="w") as f:
+ f.write(str(port))
+
+ with atomic_file_write(working_dir / "pid.txt", fileobj=True, mode="w") as f:
+ f.write(str(os.getpid()))
+
+ logging.info(
+ "Service %s listening on %d, PID = %d", working_dir, port, os.getpid()
+ )
+
+ server.start()
+
+ # Block on the RPC service in a separate thread. This enables the
+ # current thread to handle the shutdown routine.
+ server_thread = Thread(target=server.wait_for_termination)
+ server_thread.start()
+
+ # Block until the shutdown signal is received.
+ shutdown_signal.wait()
+ logging.info("Shutting down the RPC service")
+ server.stop(60).wait()
+ server_thread.join()
+
+ if len(service.sessions):
+ print(
+ "ERROR: Killing a service with",
+            plural(len(service.sessions), "active session", "active sessions"),
+ file=sys.stderr,
+ )
+ sys.exit(6)
+
+ app.run(main)
diff --git a/compiler_gym/spaces/BUILD b/compiler_gym/spaces/BUILD
index 303b4c80e..34289a46c 100644
--- a/compiler_gym/spaces/BUILD
+++ b/compiler_gym/spaces/BUILD
@@ -38,6 +38,7 @@ py_library(
deps = [
":scalar",
"//compiler_gym/service",
+ "//compiler_gym/util",
],
)
diff --git a/compiler_gym/spaces/__init__.py b/compiler_gym/spaces/__init__.py
index 72dc9fc4f..6b06ba1a4 100644
--- a/compiler_gym/spaces/__init__.py
+++ b/compiler_gym/spaces/__init__.py
@@ -9,11 +9,11 @@
from compiler_gym.spaces.sequence import Sequence
__all__ = [
- "DefaultRewardFromObservation",
- "Scalar",
- "Sequence",
- "NamedDiscrete",
"Commandline",
"CommandlineFlag",
+ "DefaultRewardFromObservation",
+ "NamedDiscrete",
"Reward",
+ "Scalar",
+ "Sequence",
]
diff --git a/compiler_gym/spaces/named_discrete.py b/compiler_gym/spaces/named_discrete.py
index 604d4d36b..043bd192a 100644
--- a/compiler_gym/spaces/named_discrete.py
+++ b/compiler_gym/spaces/named_discrete.py
@@ -13,6 +13,7 @@ class NamedDiscrete(Discrete):
:ivar name: The name of the space. :code:`None` if the space has no name.
:vartype name: Optional[str]
+
:ivar names: A list of names for each element in the space.
:vartype names: List[str]
diff --git a/compiler_gym/spaces/reward.py b/compiler_gym/spaces/reward.py
index 2518a3971..7c2ea57ea 100644
--- a/compiler_gym/spaces/reward.py
+++ b/compiler_gym/spaces/reward.py
@@ -2,12 +2,12 @@
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
-from typing import List, Optional, Tuple
+from typing import List, Optional, Tuple, Union
import numpy as np
-from compiler_gym.service import observation_t
from compiler_gym.spaces.scalar import Scalar
+from compiler_gym.util.gym_type_hints import ObservationType, RewardType
class Reward(Scalar):
@@ -39,11 +39,11 @@ def __init__(
self,
id: str,
observation_spaces: Optional[List[str]] = None,
- default_value: float = 0,
- min: Optional[float] = None,
- max: Optional[float] = None,
+ default_value: RewardType = 0,
+ min: Optional[RewardType] = None,
+ max: Optional[RewardType] = None,
default_negates_returns: bool = False,
- success_threshold: Optional[float] = None,
+ success_threshold: Optional[RewardType] = None,
deterministic: bool = False,
platform_dependent: bool = True,
):
@@ -82,7 +82,7 @@ def __init__(
)
self.id = id
self.observation_spaces = observation_spaces or []
- self.default_value: float = default_value
+ self.default_value: RewardType = default_value
self.default_negates_returns: bool = default_negates_returns
self.success_threshold = success_threshold
self.deterministic = deterministic
@@ -99,10 +99,10 @@ def reset(self, benchmark: str) -> None:
def update(
self,
- action: int,
- observations: List[observation_t],
+ actions: List[int],
+ observations: List[ObservationType],
observation_view: "compiler_gym.views.ObservationView", # noqa: F821
- ) -> float:
+ ) -> RewardType:
"""Calculate a reward for the given action.
:param action: The action performed.
@@ -114,7 +114,7 @@ def update(
"""
raise NotImplementedError("abstract class")
- def reward_on_error(self, episode_reward: float) -> float:
+ def reward_on_error(self, episode_reward: RewardType) -> RewardType:
"""Return the reward value for an error condition.
This method should be used to produce the reward value that should be
@@ -130,20 +130,28 @@ def reward_on_error(self, episode_reward: float) -> float:
return self.default_value
@property
- def range(self) -> Tuple[float, float]:
+ def range(self) -> Tuple[RewardType, RewardType]:
"""The lower and upper bounds of the reward."""
return (self.min, self.max)
def __repr__(self):
return self.id
+ def __eq__(self, other: Union["Reward", str]) -> bool:
+ if isinstance(other, str):
+ return self.id == other
+ elif isinstance(other, Reward):
+ return self.id == other.id
+ else:
+ return False
+
class DefaultRewardFromObservation(Reward):
def __init__(self, observation_name: str, **kwargs):
super().__init__(
observation_spaces=[observation_name], id=observation_name, **kwargs
)
- self.previous_value: Optional[observation_t] = None
+ self.previous_value: Optional[ObservationType] = None
def reset(self, benchmark: str) -> None:
"""Called on env.reset(). Reset incremental progress."""
@@ -153,15 +161,15 @@ def reset(self, benchmark: str) -> None:
def update(
self,
action: int,
- observations: List[observation_t],
+ observations: List[ObservationType],
observation_view: "compiler_gym.views.ObservationView", # noqa: F821
- ) -> float:
+ ) -> RewardType:
"""Called on env.step(). Compute and return new reward."""
del action # unused
del observation_view # unused
- value: float = observations[0]
+ value: RewardType = observations[0]
if self.previous_value is None:
self.previous_value = 0
- reward = float(value - self.previous_value)
+ reward = RewardType(value - self.previous_value)
self.previous_value = value
return reward
diff --git a/compiler_gym/third_party/cbench/BUILD b/compiler_gym/third_party/cbench/BUILD
index 4b4aa2cb2..8648ae400 100644
--- a/compiler_gym/third_party/cbench/BUILD
+++ b/compiler_gym/third_party/cbench/BUILD
@@ -222,7 +222,7 @@ genrule(
],
outs = ["cbench-v1/gsm.bc"],
cmd = (
- "mkdir -p $(@D) &&rsync -rL $$(dirname $(location @cBench//:readme))/telecom_gsm/ $(@D)/telecom_gsm_src/ &&" +
+ "mkdir -p $(@D) && rsync -rL $$(dirname $(location @cBench//:readme))/telecom_gsm/ $(@D)/telecom_gsm_src/ &&" +
"patch --quiet --forward $(@D)/telecom_gsm_src/src/add.c < $(location cBench-gsm-add.c.patch);" +
"$(location :make_llvm_module) $(@D)/telecom_gsm_src $@ -DSASR -DSTUPID_COMPILER -DNeedFunctionPrototypes=1"
),
@@ -238,7 +238,7 @@ genrule(
],
outs = ["cbench-v1/ispell.bc"],
cmd = (
- "mkdir -p $(@D) &&rsync -rL $$(dirname $(location @cBench//:readme))/office_ispell/ $(@D)/office_ispell_src/ &&" +
+ "mkdir -p $(@D) && rsync -rL $$(dirname $(location @cBench//:readme))/office_ispell/ $(@D)/office_ispell_src/ &&" +
"patch --quiet --forward $(@D)/office_ispell_src/src/correct.c < $(location cBench-ispell-correct.c.patch);" +
"$(location :make_llvm_module) $(@D)/office_ispell_src $@"
),
diff --git a/compiler_gym/third_party/inst2vec/BUILD b/compiler_gym/third_party/inst2vec/BUILD
index 0586c2a2e..d0b260fa9 100644
--- a/compiler_gym/third_party/inst2vec/BUILD
+++ b/compiler_gym/third_party/inst2vec/BUILD
@@ -33,16 +33,10 @@ py_library(
name = "inst2vec_preprocess",
srcs = ["inst2vec_preprocess.py"],
deps = [
- ":inst2vec_utils",
":rgx_utils",
],
)
-py_library(
- name = "inst2vec_utils",
- srcs = ["inst2vec_utils.py"],
-)
-
py_library(
name = "rgx_utils",
srcs = ["rgx_utils.py"],
diff --git a/compiler_gym/third_party/inst2vec/__init__.py b/compiler_gym/third_party/inst2vec/__init__.py
index cfcef729e..31704fdde 100644
--- a/compiler_gym/third_party/inst2vec/__init__.py
+++ b/compiler_gym/third_party/inst2vec/__init__.py
@@ -15,7 +15,7 @@
)
-class Inst2vecEncoder(object):
+class Inst2vecEncoder:
"""An LLVM encoder for inst2vec."""
def __init__(self):
diff --git a/compiler_gym/third_party/inst2vec/inst2vec_preprocess.py b/compiler_gym/third_party/inst2vec/inst2vec_preprocess.py
index f377c6be6..7610aa95f 100644
--- a/compiler_gym/third_party/inst2vec/inst2vec_preprocess.py
+++ b/compiler_gym/third_party/inst2vec/inst2vec_preprocess.py
@@ -30,485 +30,9 @@
import networkx as nx
-from compiler_gym.third_party.inst2vec import inst2vec_utils as i2v_utils
from compiler_gym.third_party.inst2vec import rgx_utils as rgx
-########################################################################################################################
-# Helper functions: list and stmt handling
-########################################################################################################################
-def string_of_items(dic):
- """
- Return a string containing all keys of a dictionary, separated by a comma
- (Helper function for structure inlining)
- :param dic: dictionary [key=string: value=string]
- :return: string constructed of the dictionaries' keys
- """
- s = ""
- for k, v in dic.items():
- s += k + ": " + v + "\n"
- return s
-
-
-def collapse_into_one_list(data):
- """
- Collapse list of list of strings into one list of strings
- :param data: list of list of strings
- :return: list of strings
- """
- data_ = list()
- for i in range(len(data)):
- for j in range(len(data[i])):
- data_.append(data[i][j])
-
- return data_
-
-
-def string_from_list(l):
- """
- Construct a string from a list of strings
- :param l: list of strings
- :return: string containing elements of list l separated by a comma
- """
- s = l[0]
- if len(l) > 1:
- for i in range(len(l) - 1):
- # only add this string to the list if it is different from the previous strings
- e = l[i + 1]
- if e not in l[0 : i + 1]:
- s += ",\t\t" + e
- return s
-
-
-def create_list_stmts(list_graphs):
- """
- Create a unique list of statements (strings) from a list of graphs in which statements are attributes of edges
- :param list_graphs: list of context-graphs (nodes = ids, edges = statements)
- :return: list_stmts: a unique list of statements (strings)
- """
- list_stmts = list()
- for G in list_graphs:
- edges_list = [e[2]["stmt"] for e in G.edges(data=True)]
- list_stmts += edges_list
-
- return list_stmts
-
-
-########################################################################################################################
-# Counting and statistics
-########################################################################################################################
-def get_stmt_counts(data_set, data_list):
- """
- Get statement counts
- :param data_set: set containing the elements from data_list but without repetitions and ordered
- :param data_list: list of string statements with repetitions and no ordering
- :return: data_count: dictionary with pairs [stmt, number of occurrences in data_list]
- the order of the statements is the same as the one in data_set
- data_operations_count: list of tuples
- [string "tag level 1", "tag level 2", "tag level 3", int "number of occurrences"]
- """
- # Setup variables
- data_count = {x: 0 for x in data_set}
- data_operations_count = list()
-
- # Compute stmt counts (overall)
- print("Counting statement occurrences (overall)...")
- for stmt in data_list:
- data_count[stmt] += 1
-
- # Check that all stmts have been counted (for debugging purposes)
- total_stmt_count = sum(data_count.values())
- assert total_stmt_count == len(data_list), "Not all statements have been counted"
-
- # Compute stmt counts (by family)
- print("Counting statement occurrences (by family) ...")
- total_stmt_count = 0
- stmts_categorized = list()
-
- # Loop over stmt families
- for fam in rgx.llvm_IR_stmt_families:
- op_count = 0
-
- # loop on all stmts in data
- for i in range(len(data_set)):
- # if the regular expression for the family matches
- if re.match(fam[3], data_set[i], re.MULTILINE):
- # add the corresponding number of occurrences to the counter
- op_count += data_count[data_set[i]]
- stmts_categorized.append(i)
-
- # append the count to the list of number of occurrences
- data_operations_count.append([fam[0], fam[1], fam[2], op_count])
-
- # increase the total stmt count
- total_stmt_count += op_count
-
- # Check that all stmts have been categorized once and only once (debugging purposes)
- print("Starting categorization check ...")
- stmts_categorized = sorted(stmts_categorized)
- if stmts_categorized != list(range(len(data_set))):
- print("Tracking down the errors in categorization ... : ")
- for i in range(len(data_set)):
- num = stmts_categorized.count(i)
- if num == 0:
- print(data_set[i], "\n\tappears 0 times")
- if num > 1:
- print(data_set[i], "\n\tappears ", num, " times")
-
- assert stmts_categorized <= list(
- range(len(data_set))
- ), "Not all statements have been categorized"
- assert stmts_categorized >= list(
- range(len(data_set))
- ), "Some statements have been categorized multiple times"
- assert total_stmt_count == len(data_list), "Not all statements have been counted"
-
- return data_count, data_operations_count
-
-
-def data_statistics(data, descr):
- """
- Compute and print some statistics on the data
- :param data: list of lists of statements (strings)
- :param descr: string description of the current step of the pipeline to add to output
- :return: source_data_list: list of statements
- source_data sorted set of statements
- """
- # Create a list of statements (strings) collecting the statements from all files
- source_data_list = collapse_into_one_list(data)
-
- # Create a sorted set of statements appearing in our data set
- source_data = sorted(set(source_data_list))
-
- # Get number of lines and the vocabulary size
- number_lines = len(source_data_list)
- vocabulary_size = len(source_data)
-
- # Construct output
- out = (
- "After "
- + descr
- + ":\n"
- + "--- {:<26}: {:>12,d}\n".format("Number of lines", number_lines)
- + "--- {:<26}: {:>12,d}\n".format("Vocabulary size", vocabulary_size)
- )
- print(out)
-
- # Return
- return source_data_list, source_data
-
-
-########################################################################################################################
-# Reading, writing and dumping files
-########################################################################################################################
-
-
-def read_data_files_from_folder(foldername):
- """
- Read all source files in folder
- Return a list of file contents, whereby each file content is a list of strings, each string representing a line
- :param foldername: name of the folder in which the data files to be read are located
- :return: a list of files where each file is a list of strings
- """
- # Helper variables
- data = list()
- file_names = list()
- file_count = 0
-
- print("Reading data from all files in folder ", foldername)
- listing = os.listdir(foldername + "/")
- to_subtract = file_count
-
- # Loop over files in folder
- for file in listing:
- if file[0] != "." and file[-3:] == ".ll":
- # If this isn't a hidden file and it is an LLVM IR file ('.ll' extension),
- # open file and import content
- f = open(os.path.join(foldername, file), "r")
- data.append(
- f.read().splitlines()
- ) # add this file as an element to the list "data"
- f.close()
-
- # Add file name to dictionary
- file_names.append(file)
-
- # Increment counters
- file_count += 1
-
- print("Number of files read from", foldername, ": ", file_count - to_subtract)
- print("Total number of files read for dataset", foldername, ": ", file_count)
- return data, file_names
-
-
-def print_preprocessed_data(raw_data, foldername, filenames):
- """
- Write pre-processed code to file for future reference
- :param raw_data: a list of files where each file is a list of strings
- :param foldername: folder in which to print
- :param filenames: list of base file names
- :return:
- """
- # Make sure the directory exists - if not, create it
- foldername = os.path.join(foldername, "preprocessed")
- if not os.path.exists(foldername):
- os.makedirs(foldername)
-
- # Write pre-processed code to files
- i = 0
- for file in raw_data:
- filename = os.path.join(foldername, filenames[i][:-3] + "_preprocessed.txt")
- print("Writing pre-processed data to file ", filename)
- with open(filename, "w") as f:
- for l in file:
- f.write(l + "\n")
- i += 1
-
-
-def print_data(data, filename):
- """
- Write pre-processed code to file for future reference
- :param data: a list of strings
- :param filename: name of file to print this to (string)
- :return:
- """
- print("Write data to file ", filename)
- with open(filename, "w") as f:
- for l in data:
- f.write(l + "\n")
-
-
-def sort_key(x):
- """
- Helper function to sort nodes
- :param x: node
- :return: node name, node id type
- """
- id_part = x[0][1:]
-
- if id_part.isdigit():
- return x[0][0], int(x[0][1:])
- else:
- return x[0][0], 1
-
-
-def print_node_family_to_file(G, f, nodetype):
- """
- Helper function for function "print_graph_to_file"
- :param G: graph
- :param f: file handle
- :param nodetype: string corresponding to the "id" of the node family to be printed
- """
-
- # Construct node family
- if nodetype == "root":
- node_family = [
- n for n in G.nodes() if G.out_degree(n) > 0 and G.in_degree(n) == 0
- ]
- node_family = sorted(node_family, key=sort_key)
- elif nodetype == "leaf":
- node_family = [
- n for n in G.nodes() if G.out_degree(n) == 0 and G.in_degree(n) >= 1
- ]
- node_family = sorted(node_family, key=sort_key)
- elif nodetype == "isolated":
- node_family = [n for n in G.nodes() if G.degree(n) == 0]
- node_family = sorted(node_family, key=sort_key)
- else:
- node_family = [
- n[0]
- for n in sorted(list(G.nodes(data=True)), key=sort_key)
- if n[1]["id"] == nodetype
- ]
-
- # Write to file
- f.write("#nodes: " + str(len(node_family)) + "\n")
- f.write("-" * 80 + "\n")
- for n in node_family:
- f.write("{n:<60}\n".format(n=n))
-
-
-def print_graph_to_file(G, multi_edge_dic, folder, filename):
- """
- Print information about a graph to a file
- :param G: graph
- :param multi_edge_dic: dictionary of multi-edges
- = edges for which a parallel edge connecting the same two end-nodes exists
- :param folder: folder in which to write
- :param filename: base name of the graph
- """
- # Print to file
- graph_filename = os.path.join(folder, filename[:-3] + ".txt")
- print("Printing graph to file : ", graph_filename)
-
- with open(graph_filename, "w") as f:
-
- # GENERAL
- f.write("#nodes: " + str(G.number_of_nodes()) + "\n")
- f.write("#edges: " + str(G.number_of_edges()) + "\n\n")
-
- # INFORMATION ON NODES
- # all
- f.write("Nodes (" + str(G.number_of_nodes()) + "):\n")
- f.write("-" * 80 + "\n")
- for n, data in sorted(G.nodes(data=True), key=sort_key):
- f.write("{n:<60}, {w}\n".format(n=n[:60], w=data["id"]))
-
- # local
- f.write("\nLocal identifier nodes: \n")
- print_node_family_to_file(G, f, "local")
-
- # block references
- f.write("\nBlock reference nodes: \n")
- print_node_family_to_file(G, f, "label")
-
- # global
- f.write("\nGlobal nodes: \n")
- print_node_family_to_file(G, f, "global")
-
- # immediate value
- f.write("\nImmediate value nodes: \n")
- print_node_family_to_file(G, f, "imm_val")
-
- # ad_hoc
- f.write("\nAd hoc value nodes: \n")
- print_node_family_to_file(G, f, "ad_hoc")
-
- # leaf
- f.write("\nLeaf nodes: \n")
- print_node_family_to_file(G, f, "leaf")
-
- # root
- f.write("\nRoot nodes: \n")
- print_node_family_to_file(G, f, "root")
-
- # isolated
- f.write("\nIsolated nodes: \n")
- print_node_family_to_file(G, f, "isolated")
- f.write("\n\n")
-
- # INFORMATION ON EDGES
- # all
- f.write("Edges (" + str(G.number_of_edges()) + ")\n")
- f.write("-" * 80 + "\n")
- for a, b, data in sorted(G.edges(data=True), key=sort_key):
- f.write(
- "({a:<30}, {b:<30}) {w}\n".format(a=a[:30], b=b[:30], w=data["stmt"])
- )
-
- # data flow edges
- dataedges = [
- (str(n[0]), str(n[1]), str(n[2]))
- for n in sorted(list(G.edges(data=True)), key=sort_key)
- if n[2]["flow"] == "data"
- ]
- f.write("\nData flow edges: \n")
- f.write(
- "#edges: "
- + str(len(dataedges))
- + " ("
- + str(int(len(dataedges)) / G.number_of_edges() * 100)[:5]
- + "%)\n"
- )
- f.write("-" * 80 + "\n")
- for e in dataedges:
- f.write("({a:<30}, {b:<30}) {c}\n".format(a=e[0][:30], b=e[1][:30], c=e[2]))
-
- # control flow edges
- ctrledges = [
- (str(n[0]), str(n[1]), str(n[2]))
- for n in sorted(list(G.edges(data=True)), key=sort_key)
- if n[2]["flow"] == "ctrl"
- ]
- f.write("\nCtrl flow edges: \n")
- f.write(
- "#edges: "
- + str(len(ctrledges))
- + " ("
- + str(int(len(dataedges)) / G.number_of_edges() * 100)[:5]
- + "%)\n"
- )
- f.write("-" * 80 + "\n")
- for e in ctrledges:
- f.write("({a:<30}, {b:<30}) {c}\n".format(a=e[0][:30], b=e[1][:30], c=e[2]))
-
- # multi-edges
- f.write("\nMulti-edges: \n")
- multi_edge_list = list()
- for k, v in multi_edge_dic.items(): # Compile the multi-edges
- multi_edge_list += v
- f.write(
- "#multi-edges: "
- + str(len(multi_edge_list))
- + " ("
- + str(int(len(multi_edge_list)) / G.number_of_edges() * 100)[:5]
- + "%)\n"
- )
- f.write(
- "#node pairs connected by multi-edges: "
- + str(len(multi_edge_dic.keys()))
- + " ("
- + str(int(len(multi_edge_dic)) / G.number_of_edges() * 100)[:5]
- + "%)\n"
- )
- f.write("-" * 80 + "\n")
- for k, v_ in multi_edge_dic.items():
- n = re.match(r"(.*) \|\|\| (.*)", k)
- assert n is not None, "Could not identify nodes in " + k
- f.write("{m:<60} {p:<60}\n".format(m=n.group(1)[:60], p=n.group(2)[:60]))
- for v in v_:
- f.write("\t{}\n".format(v))
- f.write("\n")
-
-
-def print_structure_dictionary(dic, folder, filename):
- """
- Print the dictionary of structures to a file
- :param dic: dictionary ["structure name", [list of possible values]]
- :param folder: name of folder in which to print dictionary
- :param filename: name of file in which to print dictionary
- :return:
- """
- # Print dictionary in alphabetical order
- dic_filename = os.path.join(folder, filename[:-3] + ".txt")
- print('Printing dictionary to file "', dic_filename)
- with open(dic_filename, "w") as f:
- f.write("{:<70} {}\n\n".format("structure name", "literal value"))
- for key, value in sorted(dic.items()):
- f.write("{:<70} {}\n".format(key, string_from_list(value)))
-
-
-def PrintDualXfgToFile(D, folder, filename):
- """Print dual-XFG graph to file.
-
- :param D: dual-XFG graphs
- :param folder: name of folder in which to print dictionary
- :param filename: name of file in which to print dictionary
- """
- # Print to file
- graph_filename = os.path.join(folder, filename[:-3] + ".txt")
- print("Printing graph to file : ", graph_filename)
-
- with open(graph_filename, "w") as f:
- # GENERAL
- f.write("#nodes: " + str(D.number_of_nodes()) + "\n")
- f.write("#edges: " + str(D.number_of_edges()) + "\n\n")
-
- # INFORMATION ON NODES
- f.write("Nodes (" + str(D.number_of_nodes()) + ")\n")
- f.write("-" * 80 + "\n")
- for n, _ in sorted(D.nodes(data=True), key=sort_key):
- f.write(f"{n:<60}\n")
- f.write("\n")
- # INFORMATION ON EDGES
- f.write("Edges (" + str(D.number_of_edges()) + ")\n")
- f.write("-" * 80 + "\n")
- for a, b, data in sorted(D.edges(data=True), key=sort_key):
- f.write(
- "({a:<37}, {b:<37}) {w}\n".format(a=a[:37], b=b[:37], w=data["weight"])
- )
-
-
########################################################################################################################
# LLVM IR preprocessing
########################################################################################################################
@@ -810,2085 +334,6 @@ def preprocess(data):
return preprocessed_data, functions_declared_in_files
-########################################################################################################################
-# XFG-building
-########################################################################################################################
-def get_identifiers_from_line(line):
- """
- Extract identifiers (local, global and label) from a statement
- :param line: string: (part of) statement
- :return: lists of strings: m_loc, m_glob, m_label, m_label2
- """
- # Find label nodes
- m_label = m_label2 = list()
- if line.find("label") != -1 or re.match(rgx.local_id_no_perc + r":", line):
- m_label1 = re.findall("label (" + rgx.local_id + ")", line)
- if re.match(r";