Skip to content

Commit

Permalink
[tests] Add a unit test to repro facebookresearch#756.
Browse files Browse the repository at this point in the history
  • Loading branch information
ChrisCummins committed Aug 22, 2022
1 parent 1c40e5b commit 5216c0e
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 0 deletions.
2 changes: 2 additions & 0 deletions tests/llvm/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -259,6 +259,8 @@ py_test(
deps = [
"//compiler_gym/envs/llvm",
"//compiler_gym/service:connection",
"//compiler_gym/spaces",
"//compiler_gym/util",
"//tests:test_main",
"//tests/pytest_plugins:llvm",
],
Expand Down
61 changes: 61 additions & 0 deletions tests/llvm/runtime_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,15 @@
# LICENSE file in the root directory of this source tree.
"""Integrations tests for LLVM runtime support."""
from pathlib import Path
from typing import List

import numpy as np
import pytest
from flaky import flaky

from compiler_gym.envs.llvm import LlvmEnv, llvm_benchmark
from compiler_gym.spaces.reward import Reward
from compiler_gym.util.gym_type_hints import ActionType, ObservationType
from tests.test_main import main

pytest_plugins = ["tests.pytest_plugins.llvm"]
Expand Down Expand Up @@ -144,5 +147,63 @@ def test_default_runtime_observation_count_fork(env: LlvmEnv):
assert fkd.runtime_warmup_runs_count == wc


class RewardDerivedFromRuntime(Reward):
"""A custom reward space that is derived from the Runtime observation space."""

def __init__(self):
super().__init__(
name="runtimeseries",
observation_spaces=["Runtime"],
default_value=0,
min=None,
max=None,
default_negates_returns=True,
deterministic=False,
platform_dependent=True,
)
self.last_runtime_observation: List[float] = None

def reset(self, benchmark, observation_view) -> None:
self.last_runtime_observation = observation_view["Runtime"]

def update(
self,
actions: List[ActionType],
observations: List[ObservationType],
observation_view,
) -> float:
del actions # unused
del observation_view # unused
self.last_runtime_observation = observations[0]
return 0


@flaky # runtime may fail
@pytest.mark.parametrize("runtime_observation_count", [1, 3, 5])
def test_correct_number_of_observations_during_reset(
env: LlvmEnv, runtime_observation_count: int
):
env.reward.add_space(RewardDerivedFromRuntime())
env.runtime_observation_count = runtime_observation_count
env.reset(reward_space="runtimeseries")
assert env.runtime_observation_count == runtime_observation_count

# Check that the number of observations that you are receive during reset()
# matches the amount that you asked for.
# FIXME(github.com/facebookresearch/CompilerGym/issues/756): This is broken.
# Only a single observation is received, irrespective of how many you ask
# for.
assert len(env.reward.spaces["runtimeseries"].last_runtime_observation) == 1

# Check that the number of observations that you are receive during step()
# matches the amount that you asked for.
env.reward.spaces["runtimeseries"].last_runtime_observation = None
env.step(0)
assert (
len(env.reward.spaces["runtimeseries"].last_runtime_observation)
== runtime_observation_count
)


if __name__ == "__main__":
main()

0 comments on commit 5216c0e

Please sign in to comment.