Skip to content

Commit

Permalink
Merge pull request #257 from ChrisCummins/gym-make
Browse files Browse the repository at this point in the history
Add a compiler_gym.make() wrapper.
  • Loading branch information
ChrisCummins authored May 11, 2021
2 parents 77d28bc + 756472e commit 9248260
Show file tree
Hide file tree
Showing 4 changed files with 35 additions and 0 deletions.
2 changes: 2 additions & 0 deletions compiler_gym/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
set_debug_level,
)
from compiler_gym.util.download import download
from compiler_gym.util.registration import make
from compiler_gym.util.runfiles_path import (
cache_path,
site_data_path,
Expand All @@ -56,6 +57,7 @@
"__version__",
"cache_path",
"COMPILER_GYM_ENVS",
"make",
"CompilerEnv",
"CompilerEnvState",
"CompilerEnvStateWriter",
Expand Down
6 changes: 6 additions & 0 deletions compiler_gym/util/registration.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,18 @@
# LICENSE file in the root directory of this source tree.
from typing import List

import gym
from gym.envs.registration import register as gym_register

# A list of gym environment names defined by CompilerGym.
COMPILER_GYM_ENVS: List[str] = []


def make(id: str, **kwargs):
"""Equivalent to :code:`gym.make()`."""
return gym.make(id, **kwargs)


def register(id: str, **kwargs):
COMPILER_GYM_ENVS.append(id)
gym_register(id=id, **kwargs)
10 changes: 10 additions & 0 deletions tests/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,16 @@ py_test(
],
)

py_test(
name = "make_test",
timeout = "short",
srcs = ["make_test.py"],
deps = [
"//compiler_gym",
"//tests:test_main",
],
)

py_test(
name = "random_search_test",
timeout = "short",
Expand Down
17 changes: 17 additions & 0 deletions tests/make_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import compiler_gym
from compiler_gym.envs import LlvmEnv
from tests.test_main import main


def test_compiler_gym_make():
"""Test that compiler_gym.make() is equivalent to gym.make()."""
with compiler_gym.make("llvm-v0") as env:
assert isinstance(env, LlvmEnv)


if __name__ == "__main__":
main()

0 comments on commit 9248260

Please sign in to comment.