Skip to content

Commit

Permalink
Merge pull request #285 from facebookresearch/custom-env
Browse files: browse the repository at this point in the history
[examples] Port the explore script to the new wrappers classes.
  • Loading branch information
ChrisCummins authored Jun 3, 2021
2 parents 3f40356 + b618d23 commit 0926f15
Showing 1 changed file with 14 additions and 56 deletions.
70 changes: 14 additions & 56 deletions examples/explore.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,12 @@

from compiler_gym.util.flags.benchmark_from_flags import benchmark_from_flags
from compiler_gym.util.flags.env_from_flags import env_from_flags
from compiler_gym.wrappers import ConstrainedCommandline

flags.DEFINE_list(
"actions",
[],
"A list of action names to enumerate. If not provided, all actions are used.",
"A list of flag names to enumerate. If not provided, all actions are used.",
)
flags.DEFINE_integer("episode_length", 5, "The number of steps in each episode.")
flags.DEFINE_integer(
Expand All @@ -54,54 +55,11 @@
FLAGS = flags.FLAGS


class CustomEnv:
    """Wrap an environment built from flags, exposing only a subset of actions.

    Callers address actions by their position in the subset (``0`` up to
    ``action_count() - 1``); the wrapper translates each position into the
    wrapped environment's own action index. Doing the translation here avoids
    the easy mistake of passing ``i`` to the environment where ``actions[i]``
    was meant.
    """

    def __init__(self):
        """Create the wrapped environment and resolve ``FLAGS.actions``.

        When ``FLAGS.actions`` is empty, every action of the wrapped
        environment is exposed. Otherwise each entry is looked up in
        ``action_space.flags`` and only those actions are exposed, in the
        order given.

        Raises:
            ValueError: If ``FLAGS.actions`` contains a flag that is not in
                the wrapped environment's ``action_space.flags``.
        """
        self._env = env_from_flags(benchmark_from_flags())
        try:
            action_space = self._env.action_space
            # Project onto the subset of transformations that have
            # been specified to be used.
            if not FLAGS.actions:
                self._action_indices = list(range(len(action_space.names)))
            else:
                known_flags = action_space.flags
                unknown = [a for a in FLAGS.actions if a not in known_flags]
                if unknown:
                    # Same exception type that list.index() would raise, but
                    # naming every offending flag instead of only the first.
                    raise ValueError(
                        f"Unknown action flag(s) in --actions: {', '.join(unknown)}"
                    )
                self._action_indices = [known_flags.index(a) for a in FLAGS.actions]
            self._action_names = [
                action_space.names[a] for a in self._action_indices
            ]
        finally:
            # The program will not terminate until the environment is
            # closed, not even if there is an exception.
            # NOTE(review): this also closes the env on success, so later
            # reset()/step() calls presumably rely on the wrapped env being
            # usable again after close() - confirm against the env API.
            self._env.close()

    def action_names(self, actions):
        """Return the names of the given subset positions, in order."""
        return [self._action_names[a] for a in actions]

    def step(self, action):
        """Apply the action at subset position ``action`` to the wrapped env.

        Returns:
            Whatever the wrapped environment's ``step()`` returns.
        """
        return self._env.step(self._action_indices[action])

    def reset(self):
        """Reset the wrapped env and return whatever its ``reset()`` returns."""
        return self._env.reset()

    def close(self):
        """Close the wrapped environment."""
        self._env.close()

    def action_count(self):
        """Return the number of actions in the exposed subset."""
        return len(self._action_indices)

    def actions(self):
        """Return the valid subset positions as a ``range``."""
        return range(self.action_count())

    @property
    def observation(self):
        """The wrapped environment's ``observation`` attribute."""
        return self._env.observation
def make_env():
    """Build the environment described by the command-line flags.

    The base environment comes from ``env_from_flags`` using the benchmark
    from ``benchmark_from_flags``. When ``FLAGS.actions`` is non-empty, the
    environment is wrapped in ``ConstrainedCommandline`` restricted to those
    flags; otherwise the base environment is returned unwrapped.

    Returns:
        The environment, wrapped if ``FLAGS.actions`` is non-empty.

    Raises:
        Exception: Whatever ``ConstrainedCommandline`` raises is propagated
            after the base environment has been closed.
    """
    env = env_from_flags(benchmark=benchmark_from_flags())
    if not FLAGS.actions:
        return env
    try:
        return ConstrainedCommandline(env, flags=FLAGS.actions)
    except Exception:
        # The caller never receives a handle to `env` if wrapping fails, so
        # it must be closed here or it is leaked.
        env.close()
        raise


# Used to determine if two rewards are equal up to a small
Expand Down Expand Up @@ -201,7 +159,7 @@ def env_to_fingerprint(env):

def compute_edges(env, sequence):
edges = []
for action in env.actions():
for action in range(env.action_space.n):
env.reset()
reward_sum = 0.0
for action in sequence + [action]:
Expand Down Expand Up @@ -305,8 +263,8 @@ def number_list(stats):
# not check this in that case.
full_all_sum = sum(self._full_all_stats)
assert full_all_sum > 1e9 or full_all_sum == (
pow(env.action_count(), self._depth + 1) - 1
) / (env.action_count() - 1)
pow(env.action_space.n, self._depth + 1) - 1
) / (env.action_space.n - 1)

depth_time_in_seconds = time() - self._depth_start_time_in_seconds
print()
Expand All @@ -321,7 +279,7 @@ def number_list(stats):
):
print(
f" {graph.reward_sum(n):0.4f} ",
", ".join(env.action_names(graph.node_path(n))),
", ".join(env.action_space.flags[f] for f in graph.node_path(n)),
)

print("\n")
Expand All @@ -337,8 +295,8 @@ def compute_action_graph(envs, episode_length):
env_queue.put(env)
pool = ThreadPool(len(envs))

stats = NodeTypeStats(action_count=env.action_count())
graph = StateGraph(edges_per_node=env.action_count())
stats = NodeTypeStats(action_count=env.action_space.n)
graph = StateGraph(edges_per_node=env.action_space.n)

# Add the empty sequence of actions as the starting state.
envs[0].reset()
Expand Down Expand Up @@ -492,7 +450,7 @@ def main(argv):
try:
envs = []
for _ in range(FLAGS.nproc):
envs.append(CustomEnv())
envs.append(make_env())
compute_action_graph(envs, episode_length=FLAGS.episode_length)
finally:
for env in envs:
Expand Down

0 comments on commit 0926f15

Please sign in to comment.