Skip to content

Commit

Permalink
Merge pull request #589 from ChrisCummins/fix/587
Browse files Browse the repository at this point in the history
Add a test to demonstrate #587
  • Loading branch information
ChrisCummins authored Mar 7, 2022
2 parents 9ffdd06 + fa05914 commit 0b881cb
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 4 deletions.
4 changes: 2 additions & 2 deletions tests/llvm/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -148,9 +148,9 @@ py_test(
)

py_test(
name = "gym_interface_compatability",
name = "gym_interface_compatability_test",
timeout = "short",
srcs = ["gym_interface_compatability.py"],
srcs = ["gym_interface_compatability_test.py"],
deps = [
"//compiler_gym/envs/llvm",
"//tests:test_main",
Expand Down
4 changes: 2 additions & 2 deletions tests/llvm/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -139,9 +139,9 @@ cg_py_test(

cg_py_test(
NAME
gym_interface_compatability
gym_interface_compatability_test
SRCS
"gym_interface_compatability.py"
"gym_interface_compatability_test.py"
DEPS
compiler_gym::envs::llvm::llvm
tests::pytest_plugins::llvm
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
# LICENSE file in the root directory of this source tree.
"""Test that LlvmEnv is compatible with OpenAI gym interface."""
import gym
import pytest

from compiler_gym.envs.llvm import LlvmEnv
from tests.test_main import main
Expand Down Expand Up @@ -72,5 +73,28 @@ def reward(self, reward):
assert reward == 1


@pytest.mark.xfail(
reason="github.com/facebookresearch/CompilerGym/issues/587", strict=True
)
def test_env_spec_make(env: LlvmEnv):
"""Test that demonstrates a failure in gym compatibility: env.spec does
not encode mutable state like benchmark, reward space, and observation
space.
"""
env.reset(benchmark="cbench-v1/bitcount")
with env.spec.make() as new_env:
assert new_env.benchmark == env.benchmark


def test_env_spec_make_workaround(env: LlvmEnv):
"""Demonstrate how #587 would be fixed, by updating the 'kwargs' dict."""
env.reset(benchmark="cbench-v1/bitcount")
env.spec._kwargs[ # pylint: disable=protected-access
"benchmark"
] = "cbench-v1/bitcount"
with env.spec.make() as new_env:
assert new_env.benchmark == env.benchmark


if __name__ == "__main__":
main()

0 comments on commit 0b881cb

Please sign in to comment.