Skip to content

Commit

Permalink
Move CompilerEnvState to package top level.
Browse files Browse the repository at this point in the history
This splits the CompilerEnvState definition out of compiler_env.py and
into a top-level package module.
  • Loading branch information
ChrisCummins committed Feb 24, 2021
1 parent 9539dd9 commit ad28768
Show file tree
Hide file tree
Showing 13 changed files with 268 additions and 207 deletions.
6 changes: 6 additions & 0 deletions compiler_gym/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,12 @@ py_library(
],
)

py_library(
name = "compiler_env_state",
srcs = ["compiler_env_state.py"],
visibility = ["//compiler_gym/envs:__subpackages__"],
)

py_library(
name = "random_replay",
srcs = ["random_replay.py"],
Expand Down
2 changes: 2 additions & 0 deletions compiler_gym/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
"compiler_gym` will work."
) from e

from compiler_gym.compiler_env_state import CompilerEnvState
from compiler_gym.envs import COMPILER_GYM_ENVS, CompilerEnv, observation_t, step_t
from compiler_gym.random_search import random_search
from compiler_gym.util.download import download
Expand All @@ -47,6 +48,7 @@
"cache_path",
"transient_cache_path",
"CompilerEnv",
"CompilerEnvState",
"COMPILER_GYM_ENVS",
"observation_t",
"step_t",
Expand Down
118 changes: 118 additions & 0 deletions compiler_gym/compiler_env_state.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""This module defines a class to represent a compiler environment state."""
import csv
from io import StringIO
from typing import Any, Dict, Iterable, NamedTuple, Optional


def _to_csv(*columns) -> str:
    """Format the given values as a single row of CSV text.

    :param columns: The values to write, one per CSV field. The :mod:`csv`
        writer emits ``None`` as an empty field.
    :return: The CSV-encoded row, without a trailing line terminator.
    """
    buf = StringIO()
    # Write the row with an empty line terminator rather than stripping the
    # output afterwards: a bare rstrip() would also remove trailing whitespace
    # that belongs to the last field, so the row would not round-trip.
    csv.writer(buf, lineterminator="").writerow(columns)
    return buf.getvalue()


class CompilerEnvState(NamedTuple):
    """The representation of a compiler environment state.

    The state of an environment is defined as a benchmark and a sequence of
    actions that has been applied to it. For a given environment, the state
    contains the information required to reproduce the result.

    Equality, inequality, and hashing all ignore :code:`walltime`: two states
    are equivalent if they define the same point in the optimization space,
    irrespective of how long it took to get there.
    """

    benchmark: str
    """The name of the benchmark used for this episode."""

    commandline: str
    """The list of actions that produced this state, as a commandline."""

    walltime: float
    """The walltime of the episode."""

    reward: Optional[float] = None
    """The cumulative reward for this episode."""

    @staticmethod
    def csv_header() -> str:
        """Return the header string for the CSV-format.

        :return: A comma-separated string.
        """
        return _to_csv("benchmark", "reward", "walltime", "commandline")

    def json(self):
        """Return the state as a dictionary mapping field names to values."""
        return self._asdict()  # pylint: disable=no-member

    def to_csv(self) -> str:
        """Serialize a state to a comma separated list of values.

        The column order matches :meth:`csv_header`.

        :return: A comma-separated string.
        """
        return _to_csv(self.benchmark, self.reward, self.walltime, self.commandline)

    @classmethod
    def from_json(cls, data: Dict[str, Any]) -> "CompilerEnvState":
        """Construct a state from a JSON dictionary.

        :param data: A dictionary whose keys are the field names of this class.
        :return: A new state.
        """
        return cls(**data)

    @classmethod
    def from_csv(cls, csv_string: str) -> "CompilerEnvState":
        """Construct a state from a comma separated list of values.

        Only the first CSV row of the input is used.

        :param csv_string: A row in the column order of :meth:`csv_header`. An
            empty reward column is read as :code:`None`.
        :return: A new state.
        :raises ValueError: If the input is empty, does not have exactly four
            columns, or a numeric column cannot be converted to float.
        """
        rows = csv.reader(StringIO(csv_string))
        try:
            benchmark, reward, walltime, commandline = next(rows)
        except StopIteration:
            # The input contained no rows at all.
            raise ValueError(f"Failed to parse input: `{csv_string}`") from None
        except ValueError as e:
            # The first row has the wrong number of columns.
            raise ValueError(f"Failed to parse input: `{csv_string}`: {e}") from e
        return cls(
            benchmark=benchmark,
            reward=None if reward == "" else float(reward),
            walltime=float(walltime),
            commandline=commandline,
        )

    @classmethod
    def read_csv_file(cls, in_file) -> Iterable["CompilerEnvState"]:
        """Read states from a CSV file.

        The first line of the file is used as the header naming the columns.

        :param in_file: A file object.
        :returns: A generator of :class:`CompilerEnvState` instances.
        :raises ValueError: If input parsing fails.
        """
        data = in_file.readlines()
        for line in csv.DictReader(data):
            try:
                line["reward"] = float(line["reward"]) if line.get("reward") else None
                line["walltime"] = (
                    float(line["walltime"]) if line.get("walltime") else None
                )
                # Use cls rather than the hard-coded class name so that the
                # alternate constructor respects subclasses.
                state = cls(**line)
            except (TypeError, KeyError) as e:
                raise ValueError(f"Failed to parse input: `{e}`") from e
            # Yield outside of the try block so that exceptions thrown into the
            # generator by the consumer are not reported as parse failures.
            yield state

    def __eq__(self, rhs) -> bool:
        if not isinstance(rhs, CompilerEnvState):
            return False
        epsilon = 1e-5
        # If only one benchmark has a reward the states cannot be equal.
        if (self.reward is None) != (rhs.reward is None):
            return False
        if self.reward is None:
            reward_equal = True
        else:
            reward_equal = abs(self.reward - rhs.reward) < epsilon
        # Note that walltime is excluded from equivalence checks as two states
        # are equivalent if they define the same point in the optimization space
        # irrespective of how long it took to get there.
        return (
            self.benchmark == rhs.benchmark
            and reward_equal
            and self.commandline == rhs.commandline
        )

    def __ne__(self, rhs) -> bool:
        # tuple defines its own __ne__ which compares every field, including
        # walltime. Without this override a pair of states could be both
        # `==` and `!=` at the same time.
        return not self == rhs

    def __hash__(self) -> int:
        # Hash only the fields that __eq__ compares exactly so that equal
        # states always hash equal. Reward is compared with a tolerance and
        # walltime is ignored, so neither can contribute to the hash.
        return hash((self.benchmark, self.commandline))
1 change: 1 addition & 0 deletions compiler_gym/envs/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ py_library(
srcs = ["compiler_env.py"],
visibility = ["//compiler_gym:__subpackages__"],
deps = [
"//compiler_gym:compiler_env_state",
"//compiler_gym/datasets:dataset",
"//compiler_gym/service",
"//compiler_gym/service/proto",
Expand Down
9 changes: 1 addition & 8 deletions compiler_gym/envs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,12 @@
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from compiler_gym.envs.compiler_env import (
CompilerEnv,
CompilerEnvState,
info_t,
observation_t,
step_t,
)
from compiler_gym.envs.compiler_env import CompilerEnv, info_t, observation_t, step_t
from compiler_gym.envs.llvm.llvm_env import LlvmEnv
from compiler_gym.util.registration import COMPILER_GYM_ENVS

__all__ = [
"CompilerEnv",
"CompilerEnvState",
"LlvmEnv",
"observation_t",
"info_t",
Expand Down
150 changes: 26 additions & 124 deletions compiler_gym/envs/compiler_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,22 +3,21 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""This module defines the OpenAI gym interface for compilers."""
import csv
import logging
import os
import sys
import warnings
from copy import deepcopy
from io import StringIO
from pathlib import Path
from time import time
from typing import Any, Dict, Iterable, List, NamedTuple, Optional, Tuple, Union
from typing import Any, Dict, List, Optional, Tuple, Union

import fasteners
import gym
import numpy as np
from gym.spaces import Space

from compiler_gym.compiler_env_state import CompilerEnvState
from compiler_gym.datasets.dataset import Dataset, require
from compiler_gym.service import (
CompilerGymServiceConnection,
Expand Down Expand Up @@ -49,126 +48,6 @@
step_t = Tuple[Optional[observation_t], Optional[float], bool, info_t]


def _to_csv(*columns) -> str:
    """Format the given values as a single row of CSV text.

    :param columns: The values to write, one per CSV field. The :mod:`csv`
        writer emits ``None`` as an empty field.
    :return: The CSV-encoded row, without a trailing line terminator.
    """
    buf = StringIO()
    # Write the row with an empty line terminator rather than stripping the
    # output afterwards: a bare rstrip() would also remove trailing whitespace
    # that belongs to the last field, so the row would not round-trip.
    csv.writer(buf, lineterminator="").writerow(columns)
    return buf.getvalue()


class CompilerEnvState(NamedTuple):
    """The representation of a compiler environment state.

    The state of an environment is defined as a benchmark and a sequence of
    actions that has been applied to it. For a given environment, the state
    contains the information required to reproduce the result.

    Equality, inequality, and hashing all ignore :code:`walltime`: two states
    are equivalent if they define the same point in the optimization space,
    irrespective of how long it took to get there.
    """

    benchmark: str
    """The name of the benchmark used for this episode."""

    commandline: str
    """The list of actions that produced this state, as a commandline."""

    walltime: float
    """The walltime of the episode."""

    reward: Optional[float] = None
    """The cumulative reward for this episode."""

    @staticmethod
    def csv_header() -> str:
        """Return the header string for the CSV-format.

        :return: A comma-separated string.
        """
        return _to_csv("benchmark", "reward", "walltime", "commandline")

    def json(self):
        """Return the state as a dictionary mapping field names to values."""
        return self._asdict()  # pylint: disable=no-member

    def to_csv(self) -> str:
        """Serialize a state to a comma separated list of values.

        The column order matches :meth:`csv_header`.

        :return: A comma-separated string.
        """
        return _to_csv(self.benchmark, self.reward, self.walltime, self.commandline)

    @classmethod
    def from_json(cls, data: Dict[str, Any]) -> "CompilerEnvState":
        """Construct a state from a JSON dictionary.

        :param data: A dictionary whose keys are the field names of this class.
        :return: A new state.
        """
        return cls(**data)

    @classmethod
    def from_csv(cls, csv_string: str) -> "CompilerEnvState":
        """Construct a state from a comma separated list of values.

        Only the first CSV row of the input is used.

        :param csv_string: A row in the column order of :meth:`csv_header`. An
            empty reward column is read as :code:`None`.
        :return: A new state.
        :raises ValueError: If the input is empty, does not have exactly four
            columns, or a numeric column cannot be converted to float.
        """
        rows = csv.reader(StringIO(csv_string))
        try:
            benchmark, reward, walltime, commandline = next(rows)
        except StopIteration:
            # The input contained no rows at all.
            raise ValueError(f"Failed to parse input: `{csv_string}`") from None
        except ValueError as e:
            # The first row has the wrong number of columns.
            raise ValueError(f"Failed to parse input: `{csv_string}`: {e}") from e
        return cls(
            benchmark=benchmark,
            reward=None if reward == "" else float(reward),
            walltime=float(walltime),
            commandline=commandline,
        )

    @classmethod
    def read_csv_file(cls, in_file) -> Iterable["CompilerEnvState"]:
        """Read states from a CSV file.

        The first line of the file is used as the header naming the columns.

        :param in_file: A file object.
        :returns: A generator of :class:`CompilerEnvState` instances.
        :raises ValueError: If input parsing fails.
        """
        data = in_file.readlines()
        for line in csv.DictReader(data):
            try:
                line["reward"] = float(line["reward"]) if line.get("reward") else None
                line["walltime"] = (
                    float(line["walltime"]) if line.get("walltime") else None
                )
                # Use cls rather than the hard-coded class name so that the
                # alternate constructor respects subclasses.
                state = cls(**line)
            except (TypeError, KeyError) as e:
                raise ValueError(f"Failed to parse input: `{e}`") from e
            # Yield outside of the try block so that exceptions thrown into the
            # generator by the consumer are not reported as parse failures.
            yield state

    def apply(self, env: "CompilerEnv") -> None:
        """Replay the sequence of actions given by a commandline.

        :param env: The environment to step. Its
            :code:`commandline_to_actions()` method is used to convert this
            state's commandline into actions.
        :raises OSError: If the environment reports the episode as done while
            the actions are being replayed.
        """
        actions = env.commandline_to_actions(self.commandline)
        for action in actions:
            _, _, done, info = env.step(action)
            if done:
                raise OSError(
                    f"Environment terminated with error: `{info.get('error_details')}`"
                )

    def __eq__(self, rhs) -> bool:
        if not isinstance(rhs, CompilerEnvState):
            return False
        epsilon = 1e-5
        # If only one benchmark has a reward the states cannot be equal.
        if (self.reward is None) != (rhs.reward is None):
            return False
        if self.reward is None:
            reward_equal = True
        else:
            reward_equal = abs(self.reward - rhs.reward) < epsilon
        # Note that walltime is excluded from equivalence checks as two states
        # are equivalent if they define the same point in the optimization space
        # irrespective of how long it took to get there.
        return (
            self.benchmark == rhs.benchmark
            and reward_equal
            and self.commandline == rhs.commandline
        )

    def __ne__(self, rhs) -> bool:
        # tuple defines its own __ne__ which compares every field, including
        # walltime. Without this override a pair of states could be both
        # `==` and `!=` at the same time.
        return not self == rhs

    def __hash__(self) -> int:
        # Hash only the fields that __eq__ compares exactly so that equal
        # states always hash equal. Reward is compared with a tolerance and
        # walltime is ignored, so neither can contribute to the hash.
        return hash((self.benchmark, self.commandline))


class CompilerEnv(gym.Env):
"""An OpenAI gym environment for compiler optimizations.
Expand Down Expand Up @@ -592,7 +471,7 @@ def fork(self) -> "CompilerEnv":
if self.actions:
logging.warning("Parent service of fork() has died, replaying state")
self.reset()
state_to_replay.apply(self)
self.apply(state_to_replay)

request = ForkSessionRequest(session_id=self._session_id)
reply: ForkSessionReply = self.service(self.service.stub.ForkSession, request)
Expand Down Expand Up @@ -1025,3 +904,26 @@ def _add_custom_benchmarks(self, benchmarks: List[Benchmark]) -> None:
self.service.stub.AddBenchmark,
AddBenchmarkRequest(benchmark=benchmarks),
)

def apply(self, state: "CompilerEnvState") -> None:  # noqa
    """Replay a state on this environment.

    If no episode is in progress, the environment is first reset using the
    benchmark of the state. The commandline of the state is then converted
    to a list of actions, and each action is applied in order.

    A warning is emitted if the benchmark of the environment does not match
    the benchmark of the state; the actions are still replayed.

    :param state: The :class:`CompilerEnvState` to replay.
    :raises ValueError: If the environment reports the episode as done while
        the actions are being replayed.
    """
    if not self.in_episode:
        self.reset(benchmark=state.benchmark)

    if self.benchmark != state.benchmark:
        warnings.warn(
            f"Applying state from environment for benchmark '{state.benchmark}' "
            f"to environment for benchmark '{self.benchmark}'"
        )

    actions = self.commandline_to_actions(state.commandline)
    for action in actions:
        _, _, done, info = self.step(action)
        if done:
            raise ValueError(
                f"Environment terminated with error: `{info.get('error_details')}`"
            )
7 changes: 4 additions & 3 deletions compiler_gym/validate.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@

import gym

from compiler_gym.envs.compiler_env import CompilerEnv, CompilerEnvState
from compiler_gym.compiler_env_state import CompilerEnvState
from compiler_gym.envs.compiler_env import CompilerEnv
from compiler_gym.envs.llvm import LlvmEnv
from compiler_gym.envs.llvm.datasets import get_llvm_benchmark_validation_callback
from compiler_gym.util.timer import Timer
Expand Down Expand Up @@ -65,7 +66,7 @@ def __repr__(self):
return f"✅ {benchmark} {self.state.reward:.4f}"

def json(self):
# Return the fields of this tuple as a dict, with the "state" entry set to
# the output of self.state.json().
# NOTE(review): the two `_asdict()` lines below appear to be the old and new
# sides of a single changed line in this diff view; they differ only in the
# trailing pylint suppression comment.
data = self._asdict()
data = self._asdict() # pylint: disable=no-member
data["state"] = self.state.json()
return data

Expand Down Expand Up @@ -96,7 +97,7 @@ def validate_state(env: CompilerEnv, state: CompilerEnvState) -> ValidationResul
# validation process in case a step fails.
while True:
try:
state.apply(env)
env.apply(state)
reward = env.episode_reward
except (ValueError, OSError) as e:
validation["actions_replay_failed"] = True
Expand Down
Loading

0 comments on commit ad28768

Please sign in to comment.