Skip to content

Commit

Permalink
Add a CompilerEnv.validate() method.
Browse files Browse the repository at this point in the history
Add a new CompilerEnv.validate() method that replaces the previous
validate_state(env, state) call. This is a stepping stone to enabling
a more flexible API for custom benchmark validation routines.

github.com//issues/45
  • Loading branch information
ChrisCummins committed Feb 25, 2021
1 parent 5559a5f commit b32a2d4
Show file tree
Hide file tree
Showing 12 changed files with 304 additions and 138 deletions.
10 changes: 10 additions & 0 deletions compiler_gym/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,19 @@ py_library(
srcs = ["validate.py"],
visibility = ["//compiler_gym:__subpackages__"],
deps = [
":validation_result",
"//compiler_gym/envs:compiler_env",
"//compiler_gym/envs/llvm",
"//compiler_gym/spaces",
"//compiler_gym/util",
],
)

py_library(
name = "validation_result",
srcs = ["validation_result.py"],
visibility = ["//compiler_gym:__subpackages__"],
deps = [
":compiler_env_state",
],
)
4 changes: 2 additions & 2 deletions compiler_gym/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,8 @@
site_data_path,
transient_cache_path,
)
from compiler_gym.validate import ValidationResult, validate_state, validate_states
from compiler_gym.validate import validate_states
from compiler_gym.validation_result import ValidationResult

# The top-level compiler_gym API.
__all__ = [
Expand All @@ -54,6 +55,5 @@
"step_t",
"random_search",
"ValidationResult",
"validate_state",
"validate_states",
]
2 changes: 2 additions & 0 deletions compiler_gym/envs/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,12 @@ py_library(
visibility = ["//compiler_gym:__subpackages__"],
deps = [
"//compiler_gym:compiler_env_state",
"//compiler_gym:validation_result",
"//compiler_gym/datasets:dataset",
"//compiler_gym/service",
"//compiler_gym/service/proto",
"//compiler_gym/spaces",
"//compiler_gym/util",
"//compiler_gym/views",
],
)
83 changes: 82 additions & 1 deletion compiler_gym/envs/compiler_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,10 @@
import sys
import warnings
from copy import deepcopy
from math import isclose
from pathlib import Path
from time import time
from typing import Any, Dict, List, Optional, Tuple, Union
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import fasteners
import gym
Expand Down Expand Up @@ -41,6 +42,8 @@
StepRequest,
)
from compiler_gym.spaces import NamedDiscrete, Reward
from compiler_gym.util.timer import Timer
from compiler_gym.validation_result import ValidationResult
from compiler_gym.views import ObservationSpaceSpec, ObservationView, RewardView

# Type hints.
Expand Down Expand Up @@ -927,3 +930,81 @@ def apply(self, state: CompilerEnvState) -> None: # noqa
raise ValueError(
f"Environment terminated with error: `{info.get('error_details')}`"
)

def validate(self, state: Optional[CompilerEnvState] = None) -> ValidationResult:
in_place = state is not None
state = state or self.state

error_messages = []
validation = {
"state": state,
"actions_replay_failed": False,
"reward_validated": False,
"reward_validation_failed": False,
"benchmark_semantics_validated": False,
"benchmark_semantics_validation_failed": False,
}

fkd = self.fork()
try:
with Timer() as walltime:
replay_target = self if in_place else fkd
replay_target.reset(benchmark=state.benchmark)
# Use a while loop here so that we can `break` early out of the
# validation process in case a step fails.
while True:
try:
replay_target.apply(state)
except (ValueError, OSError) as e:
validation["actions_replay_failed"] = True
error_messages.append(str(e))
break

if self.reward_space and self.reward_space.deterministic:
validation["reward_validated"] = True
# If reward deviates from the expected amount record the
# error but continue with the remainder of the validation.
if not isclose(
state.reward,
replay_target.episode_reward,
rel_tol=1e-5,
abs_tol=1e-10,
):
validation["reward_validation_failed"] = True
error_messages.append(
f"Expected reward {state.reward:.4f} but "
f"received reward {replay_target.episode_reward:.4f}"
)

# TODO(https://github.com/facebookresearch/CompilerGym/issues/45):
# Call the new self.benchmark.validation_callback() method
# once implemented.
validate_semantics = self.get_benchmark_validation_callback()
if validate_semantics:
validation["benchmark_semantics_validated"] = True
semantics_error = validate_semantics(self)
if semantics_error:
validation["benchmark_semantics_validation_failed"] = True
error_messages.append(semantics_error)

# Finished all checks, break the loop.
break
finally:
fkd.close()

return ValidationResult(
walltime=walltime.time,
error_details="\n".join(error_messages),
**validation,
)

def get_benchmark_validation_callback(
self,
) -> Optional[Callable[["CompilerEnv"], Optional[str]]]:
"""Return a callback that validates benchmark semantics, if available.
TODO(https://github.com/facebookresearch/CompilerGym/issues/45): This is
a temporary placeholder for what will eventually become a method on a
new Benchmark class.
"""
return None
21 changes: 19 additions & 2 deletions compiler_gym/envs/llvm/llvm_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,17 @@
import os
import shutil
from pathlib import Path
from typing import Iterable, List, Optional, Union, cast
from typing import Callable, Iterable, List, Optional, Union, cast

import numpy as np
from gym.spaces import Dict as DictSpace

from compiler_gym.envs.compiler_env import CompilerEnv
from compiler_gym.envs.llvm.benchmarks import make_benchmark
from compiler_gym.envs.llvm.datasets import LLVM_DATASETS
from compiler_gym.envs.llvm.datasets import (
LLVM_DATASETS,
get_llvm_benchmark_validation_callback,
)
from compiler_gym.envs.llvm.llvm_rewards import (
BaselineImprovementNormalizedReward,
CostFunctionReward,
Expand Down Expand Up @@ -323,3 +326,17 @@ def render(
print(self.ir)
else:
return super().render(mode)

def get_benchmark_validation_callback(
self,
) -> Optional[Callable[[CompilerEnv], Optional[str]]]:
"""Return a callback for validating a given environment state.
If there is no valid callback, returns :code:`None`.
:param env: An :class:`LlvmEnv` instance.
:return: An optional callback that takes an :class:`LlvmEnv` instance as
argument and returns an optional string containing a validation error
message.
"""
return get_llvm_benchmark_validation_callback(self)
123 changes: 4 additions & 119 deletions compiler_gym/validate.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,138 +3,23 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""Validate environment states."""
import math
import multiprocessing
import multiprocessing.pool
import re
from typing import Callable, Iterable, List, NamedTuple, Optional
from typing import Callable, Iterable, List, Optional

import gym

from compiler_gym.compiler_env_state import CompilerEnvState
from compiler_gym.envs.compiler_env import CompilerEnv
from compiler_gym.envs.llvm import LlvmEnv
from compiler_gym.envs.llvm.datasets import get_llvm_benchmark_validation_callback
from compiler_gym.util.timer import Timer


class ValidationResult(NamedTuple):
"""The result of validating a compiler state."""

state: CompilerEnvState
"""The compiler environment state that was validated."""

reward_validated: bool
"""Whether the reward that was recorded in the original state was validated."""

actions_replay_failed: bool
"""Whether the commandline was unable to be reproduced."""

reward_validation_failed: bool
"""Whether the validated reward differed from the original state."""

benchmark_semantics_validated: bool
"""Whether the semantics of the benchmark were validated."""

benchmark_semantics_validation_failed: bool
"""Whether the semantics of the benchmark were found to have changed."""

walltime: float
"""The wall time in seconds that the validation took."""

error_details: str = ""
"""A description of any validation errors."""

def okay(self) -> bool:
"""Whether validation succeeded."""
return not (
self.actions_replay_failed
or self.reward_validation_failed
or self.benchmark_semantics_validation_failed
)

def __repr__(self):
# Remove default-protocol prefix to improve output readability.
benchmark = re.sub(r"^benchmark://", "", self.state.benchmark)

if not self.okay():
msg = ", ".join(self.error_details.strip().split("\n"))
return f"❌ {benchmark} {msg}"
elif self.state.reward is None:
return f"✅ {benchmark}"
else:
return f"✅ {benchmark} {self.state.reward:.4f}"

def json(self):
data = self._asdict() # pylint: disable=no-member
data["state"] = self.state.json()
return data


def validate_state(env: CompilerEnv, state: CompilerEnvState) -> ValidationResult:
"""Validate a :class:`CompilerEnvState <compiler_gym.envs.CompilerEnvState>`.
:param env: A compiler environment.
:param state: The environment state to validate.
:return: A :class:`ValidationResult <compiler_gym.ValidationResult>` instance.
"""
error_messages = []
validation = {
"state": state,
"actions_replay_failed": False,
"reward_validated": False,
"reward_validation_failed": False,
"benchmark_semantics_validated": False,
"benchmark_semantics_validation_failed": False,
}

if state.reward is not None and env.reward_space is None:
raise ValueError("Reward space not specified")

with Timer() as walltime:
env.reset(benchmark=state.benchmark)
# Use a while loop here so that we can `break` early out of the
# validation process in case a step fails.
while True:
try:
env.apply(state)
reward = env.episode_reward
except (ValueError, OSError) as e:
validation["actions_replay_failed"] = True
error_messages.append(str(e))
break

if state.reward is not None and env.reward_space.deterministic:
validation["reward_validated"] = True
# If reward deviates from the expected amount record the
# error but continue with the remainder of the validation.
if not math.isclose(reward, state.reward, rel_tol=1e-5, abs_tol=1e-10):
validation["reward_validation_failed"] = True
error_messages.append(
f"Expected reward {state.reward:.4f} but received reward {reward:.4f}"
)

validate_semantics = get_llvm_benchmark_validation_callback(env)
if validate_semantics:
validation["benchmark_semantics_validated"] = True
semantics_error = validate_semantics(env)
if semantics_error:
validation["benchmark_semantics_validation_failed"] = True
error_messages.append(semantics_error)

# Finished all checks, break the loop.
break

return ValidationResult(
walltime=walltime.time, error_details="\n".join(error_messages), **validation
)
from compiler_gym.validation_result import ValidationResult


def _validate_states_worker(args) -> ValidationResult:
reward_space, state = args
env = gym.make("llvm-v0", reward_space=reward_space)
try:
result = validate_state(env, state)
result = env.validate(state)
finally:
env.close()
return result
Expand All @@ -148,7 +33,7 @@ def validate_states(
inorder: bool = False,
) -> Iterable[ValidationResult]:
"""A parallelized implementation of
:func:`validate_state() <compiler_gym.validate_state>` for batched
:meth:`env.validate() <compiler_gym.envs.CompilerEnv.validate>` for batched
validation.
:param make_env: A callback which instantiates a compiler environment.
Expand Down
Loading

0 comments on commit b32a2d4

Please sign in to comment.