Skip to content

Commit

Permalink
Merge pull request #627 from facebookresearch/feature/deprecate-actio…
Browse files Browse the repository at this point in the history
…n-list

Split CompilerEnv.step() into two methods for singular or lists of actions (take 2)
  • Loading branch information
ChrisCummins authored Mar 17, 2022
2 parents 755ba67 + 8cd9679 commit 6bf0209
Show file tree
Hide file tree
Showing 34 changed files with 408 additions and 177 deletions.
11 changes: 8 additions & 3 deletions compiler_gym/bin/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@
from compiler_gym.datasets import Dataset
from compiler_gym.envs import CompilerEnv
from compiler_gym.service.connection import ConnectionOpts
from compiler_gym.spaces import Commandline
from compiler_gym.spaces import Commandline, NamedDiscrete
from compiler_gym.util.flags.env_from_flags import env_from_flags
from compiler_gym.util.tabulate import tabulate
from compiler_gym.util.truncate import truncate
Expand Down Expand Up @@ -249,12 +249,17 @@ def print_service_capabilities(env: CompilerEnv):
],
headers=("Action", "Description"),
)
else:
print(table)
elif isinstance(action_space, NamedDiscrete):
table = tabulate(
[(a,) for a in sorted(action_space.names)],
headers=("Action",),
)
print(table)
print(table)
else:
raise NotImplementedError(
"Only Commandline and NamedDiscrete are supported."
)


def main(argv):
Expand Down
159 changes: 114 additions & 45 deletions compiler_gym/envs/compiler_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ class CompilerEnv(gym.Env):
:ivar actions: The list of actions that have been performed since the
previous call to :func:`reset`.
:vartype actions: List[int]
:vartype actions: List[ActionType]
:ivar reward_range: A tuple indicating the range of reward values. Default
range is (-inf, +inf).
Expand Down Expand Up @@ -321,7 +321,7 @@ def __init__(
self.reward_range: Tuple[float, float] = (-np.inf, np.inf)
self.episode_reward: Optional[float] = None
self.episode_start_time: float = time()
self.actions: List[int] = []
self.actions: List[ActionType] = []

# Initialize the default observation/reward spaces.
self.observation_space_spec: Optional[ObservationSpaceSpec] = None
Expand Down Expand Up @@ -375,7 +375,7 @@ def commandline(self) -> str:
"""
raise NotImplementedError("abstract method")

def commandline_to_actions(self, commandline: str) -> List[int]:
def commandline_to_actions(self, commandline: str) -> List[ActionType]:
"""Interface for :class:`CompilerEnv <compiler_gym.envs.CompilerEnv>`
subclasses to convert from a commandline invocation to a sequence of
actions.
Expand Down Expand Up @@ -409,7 +409,7 @@ def state(self) -> CompilerEnvState:
)

@property
def action_space(self) -> NamedDiscrete:
def action_space(self) -> Space:
"""The current action space.
:getter: Get the current action space.
Expand Down Expand Up @@ -587,7 +587,7 @@ def fork(self) -> "CompilerEnv":
self.reset()
if actions:
logger.warning("Parent service of fork() has died, replaying state")
_, _, done, _ = self.step(actions)
_, _, done, _ = self.multistep(actions)
assert not done, "Failed to replay action sequence"

request = ForkSessionRequest(session_id=self._session_id)
Expand Down Expand Up @@ -620,7 +620,7 @@ def fork(self) -> "CompilerEnv":
# replay the state.
new_env = type(self)(**self._init_kwargs())
new_env.reset()
_, _, done, _ = new_env.step(self.actions)
_, _, done, _ = new_env.multistep(self.actions)
assert not done, "Failed to replay action sequence in forked environment"

# Create copies of the mutable reward and observation spaces. This
Expand Down Expand Up @@ -885,9 +885,9 @@ def _call_with_error(

def raw_step(
self,
actions: Iterable[int],
observations: Iterable[ObservationSpaceSpec],
rewards: Iterable[Reward],
actions: Iterable[ActionType],
observation_spaces: List[ObservationSpaceSpec],
reward_spaces: List[Reward],
) -> StepType:
"""Take a step.
Expand All @@ -908,26 +908,23 @@ def raw_step(
.. warning::
Prefer :meth:`step() <compiler_gym.envs.CompilerEnv.step>` to
:meth:`raw_step() <compiler_gym.envs.CompilerEnv.step>`.
:meth:`step() <compiler_gym.envs.CompilerEnv.step>` has equivalent
functionality, and is less likely to change in the future.
Don't call this method directly, use :meth:`step()
<compiler_gym.envs.CompilerEnv.step>` or :meth:`multistep()
<compiler_gym.envs.CompilerEnv.multistep>` instead. The
:meth:`raw_step() <compiler_gym.envs.CompilerEnv.step>` method is an
implementation detail.
"""
if not self.in_episode:
raise SessionNotFound("Must call reset() before step()")

# Build the list of observations that must be computed by the backend
user_observation_spaces: List[ObservationSpaceSpec] = list(observations)
reward_spaces: List[Reward] = list(rewards)

reward_observation_spaces: List[ObservationSpaceSpec] = []
for reward_space in reward_spaces:
reward_observation_spaces += [
self.observation.spaces[obs] for obs in reward_space.observation_spaces
]

observations_to_compute: List[ObservationSpaceSpec] = list(
set(user_observation_spaces).union(set(reward_observation_spaces))
set(observation_spaces).union(set(reward_observation_spaces))
)
observation_space_index_map: Dict[ObservationSpaceSpec, int] = {
observation_space: i
Expand Down Expand Up @@ -974,7 +971,7 @@ def raw_step(

default_observations = [
observation_space.default_value
for observation_space in user_observation_spaces
for observation_space in observation_spaces
]
default_rewards = [
float(reward_space.reward_on_error(self.episode_reward))
Expand Down Expand Up @@ -1002,7 +999,7 @@ def raw_step(
# Get the user-requested observation.
observations: List[ObservationType] = [
computed_observations[observation_space_index_map[observation_space]]
for observation_space in user_observation_spaces
for observation_space in observation_spaces
]

# Update and compute the rewards.
Expand All @@ -1029,25 +1026,83 @@ def raw_step(

return observations, rewards, reply.end_of_session, info

def step(
def step( # pylint: disable=arguments-differ
self,
action: Union[ActionType, Iterable[ActionType]],
action: ActionType,
observation_spaces: Optional[Iterable[Union[str, ObservationSpaceSpec]]] = None,
reward_spaces: Optional[Iterable[Union[str, Reward]]] = None,
observations: Optional[Iterable[Union[str, ObservationSpaceSpec]]] = None,
rewards: Optional[Iterable[Union[str, Reward]]] = None,
) -> StepType:
"""Take a step.
:param action: An action, or a sequence of actions. When multiple
actions are provided the observation and reward are returned after
running all of the actions.
:param action: An action.
:param observation_spaces: A list of observation spaces to compute
observations from. If provided, this changes the :code:`observation`
element of the return tuple to be a list of observations from the
requested spaces. The default :code:`env.observation_space` is not
returned.
:param reward_spaces: A list of reward spaces to compute rewards from. If
provided, this changes the :code:`reward` element of the return
tuple to be a list of rewards from the requested spaces. The default
:code:`env.reward_space` is not returned.
:return: A tuple of observation, reward, done, and info. Observation and
reward are None if default observation/reward is not set.
:raises SessionNotFound: If :meth:`reset()
<compiler_gym.envs.CompilerEnv.reset>` has not been called.
"""
if isinstance(action, IterableType):
warnings.warn(
"Argument `action` of CompilerEnv.step no longer accepts a list "
" of actions. Please use CompilerEnv.multistep instead",
category=DeprecationWarning,
)
return self.multistep(
action,
observation_spaces=observation_spaces,
reward_spaces=reward_spaces,
observations=observations,
rewards=rewards,
)
if observations is not None:
warnings.warn(
"Argument `observations` of CompilerEnv.step has been "
"renamed `observation_spaces`. Please update your code",
category=DeprecationWarning,
)
observation_spaces = observations
if rewards is not None:
warnings.warn(
"Argument `rewards` of CompilerEnv.step has been renamed "
"`reward_spaces`. Please update your code",
category=DeprecationWarning,
)
reward_spaces = rewards
return self.multistep([action], observation_spaces, reward_spaces)

def multistep(
self,
actions: Iterable[ActionType],
observation_spaces: Optional[Iterable[Union[str, ObservationSpaceSpec]]] = None,
reward_spaces: Optional[Iterable[Union[str, Reward]]] = None,
observations: Optional[Iterable[Union[str, ObservationSpaceSpec]]] = None,
rewards: Optional[Iterable[Union[str, Reward]]] = None,
):
"""Take a sequence of steps and return the final observation and reward.
:param action: A sequence of actions to apply in order.
:param observations: A list of observation spaces to compute
:param observation_spaces: A list of observation spaces to compute
observations from. If provided, this changes the :code:`observation`
element of the return tuple to be a list of observations from the
requested spaces. The default :code:`env.observation_space` is not
returned.
:param rewards: A list of reward spaces to compute rewards from. If
:param reward_spaces: A list of reward spaces to compute rewards from. If
provided, this changes the :code:`reward` element of the return
tuple to be a list of rewards from the requested spaces. The default
:code:`env.reward_space` is not returned.
Expand All @@ -1058,52 +1113,64 @@ def step(
:raises SessionNotFound: If :meth:`reset()
<compiler_gym.envs.CompilerEnv.reset>` has not been called.
"""
# Coerce actions into a list.
actions = action if isinstance(action, IterableType) else [action]
if observations is not None:
warnings.warn(
"Argument `observations` of CompilerEnv.multistep has been "
"renamed `observation_spaces`. Please update your code",
category=DeprecationWarning,
)
observation_spaces = observations
if rewards is not None:
warnings.warn(
"Argument `rewards` of CompilerEnv.multistep has been renamed "
"`reward_spaces`. Please update your code",
category=DeprecationWarning,
)
reward_spaces = rewards

# Coerce observation spaces into a list of ObservationSpaceSpec instances.
if observations:
observation_spaces: List[ObservationSpaceSpec] = [
if observation_spaces:
observation_spaces_to_compute: List[ObservationSpaceSpec] = [
obs
if isinstance(obs, ObservationSpaceSpec)
else self.observation.spaces[obs]
for obs in observations
for obs in observation_spaces
]
elif self.observation_space_spec:
observation_spaces: List[ObservationSpaceSpec] = [
observation_spaces_to_compute: List[ObservationSpaceSpec] = [
self.observation_space_spec
]
else:
observation_spaces: List[ObservationSpaceSpec] = []
observation_spaces_to_compute: List[ObservationSpaceSpec] = []

# Coerce reward spaces into a list of Reward instances.
if rewards:
reward_spaces: List[Reward] = [
if reward_spaces:
reward_spaces_to_compute: List[Reward] = [
rew if isinstance(rew, Reward) else self.reward.spaces[rew]
for rew in rewards
for rew in reward_spaces
]
elif self.reward_space:
reward_spaces: List[Reward] = [self.reward_space]
reward_spaces_to_compute: List[Reward] = [self.reward_space]
else:
reward_spaces: List[Reward] = []
reward_spaces_to_compute: List[Reward] = []

# Perform the underlying environment step.
observation_values, reward_values, done, info = self.raw_step(
actions, observation_spaces, reward_spaces
actions, observation_spaces_to_compute, reward_spaces_to_compute
)

# Translate observations lists back to the appropriate types.
if observations is None and self.observation_space_spec:
if observation_spaces is None and self.observation_space_spec:
observation_values = observation_values[0]
elif not observation_spaces:
elif not observation_spaces_to_compute:
observation_values = None

# Translate reward lists back to the appropriate types.
if rewards is None and self.reward_space:
if reward_spaces is None and self.reward_space:
reward_values = reward_values[0]
# Update the cumulative episode reward
self.episode_reward += reward_values
elif not reward_spaces:
elif not reward_spaces_to_compute:
reward_values = None

return observation_values, reward_values, done, info
Expand Down Expand Up @@ -1176,7 +1243,9 @@ def apply(self, state: CompilerEnvState) -> None: # noqa
)

actions = self.commandline_to_actions(state.commandline)
_, _, done, info = self.step(actions)
done = False
for action in actions:
_, _, done, info = self.step(action)
if done:
raise ValueError(
f"Environment terminated with error: `{info.get('error_details')}`"
Expand Down
6 changes: 3 additions & 3 deletions compiler_gym/envs/llvm/llvm_rewards.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from compiler_gym.datasets import Benchmark
from compiler_gym.spaces.reward import Reward
from compiler_gym.util.gym_type_hints import ObservationType, RewardType
from compiler_gym.util.gym_type_hints import ActionType, ObservationType, RewardType
from compiler_gym.views.observation import ObservationView


Expand Down Expand Up @@ -44,7 +44,7 @@ def reset(self, benchmark: Benchmark, observation_view: ObservationView) -> None

def update(
self,
actions: List[int],
actions: List[ActionType],
observations: List[ObservationType],
observation_view: ObservationView,
) -> RewardType:
Expand Down Expand Up @@ -81,7 +81,7 @@ def reset(self, benchmark: str, observation_view: ObservationView) -> None:

def update(
self,
actions: List[int],
actions: List[ActionType],
observations: List[ObservationType],
observation_view: ObservationView,
) -> RewardType:
Expand Down
2 changes: 1 addition & 1 deletion compiler_gym/random_replay.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
)


@deprecated(version="0.2.1", reason="Use env.step(actions) instead")
@deprecated(version="0.2.1", reason="Use env.step(action) instead")
def replay_actions(env: CompilerEnv, action_names: List[str], outdir: Path):
return replay_actions_(env, action_names, outdir)

Expand Down
Loading

0 comments on commit 6bf0209

Please sign in to comment.