diff --git a/compiler_gym/views/observation_space_spec.py b/compiler_gym/views/observation_space_spec.py index 4ca0c1964..77b50c9fd 100644 --- a/compiler_gym/views/observation_space_spec.py +++ b/compiler_gym/views/observation_space_spec.py @@ -4,12 +4,11 @@ # LICENSE file in the root directory of this source tree. from typing import Any, Callable, ClassVar, Optional, Union -# import networkx as nx -# import numpy as np from gym.spaces import Space from compiler_gym.service.proto import Event, ObservationSpace, py_converters from compiler_gym.util.gym_type_hints import ObservationType +from compiler_gym.util.shell_format import indent class ObservationSpaceSpec: @@ -95,10 +94,34 @@ def __eq__(self, rhs) -> bool: @classmethod def from_proto(cls, index: int, proto: ObservationSpace): + """Create an observation space from a ObservationSpace protocol buffer. + + :param index: The index of this observation space into the list of + observation spaces that the compiler service supports. + + :param proto: An ObservationSpace protocol buffer. + + :raises ValueError: If protocol buffer is invalid. + """ + try: + spec = ObservationSpaceSpec.message_converter(proto.space) + except ValueError as e: + raise ValueError( + f"Error interpreting description of observation space '{proto.name}'.\n" + f"Error: {e}\n" + f"ObservationSpace message:\n" + f"{indent(proto.space, n=2)}" + ) from e + + # TODO(cummins): Additional validation of the observation space + # specification would be useful here, such as making sure that the size + # of {low, high} tensors for box shapes match. At present, these errors + # tend not to show up until later, making it more difficult to debug. + return cls( id=proto.name, index=index, - space=ObservationSpaceSpec.message_converter(proto.space), + space=spec, translate=ObservationSpaceSpec.message_converter, to_string=str, deterministic=proto.deterministic,