diff --git a/compiler_gym/envs/compiler_env.py b/compiler_gym/envs/compiler_env.py index fce35565e..62b7f29fe 100644 --- a/compiler_gym/envs/compiler_env.py +++ b/compiler_gym/envs/compiler_env.py @@ -735,7 +735,14 @@ def reset( # pylint: disable=arguments-differ def _retry(error) -> Optional[ObservationType]: """Abort and retry on error.""" - logger.warning("%s during reset(): %s", type(error).__name__, error) + # Log the error that we are recovering from, but treat + # ServiceIsClosed errors as unimportant since we know what causes + # them. + log_severity = ( + logger.debug if isinstance(error, ServiceIsClosed) else logger.warning + ) + log_severity("%s during reset(): %s", type(error).__name__, error) + if self.service: try: self.service.close() diff --git a/compiler_gym/service/connection.py b/compiler_gym/service/connection.py index 16257d48b..8f9295781 100644 --- a/compiler_gym/service/connection.py +++ b/compiler_gym/service/connection.py @@ -204,7 +204,7 @@ def __call__( except ValueError as e: if str(e) == "Cannot invoke RPC on closed channel!": raise ServiceIsClosed( - f"RPC communication failed with message: {e}" + "RPC communication failed because channel is closed" ) from None raise e except grpc.RpcError as e: diff --git a/tests/llvm/datasets/cbench_validate_test.py b/tests/llvm/datasets/cbench_validate_test.py index 156e47b46..377f720d3 100644 --- a/tests/llvm/datasets/cbench_validate_test.py +++ b/tests/llvm/datasets/cbench_validate_test.py @@ -3,6 +3,8 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. """Test for cBench semantics validation.""" +import pytest + from compiler_gym import ValidationResult from compiler_gym.envs.llvm import LlvmEnv from tests.test_main import main @@ -10,6 +12,7 @@ pytest_plugins = ["tests.pytest_plugins.llvm"] +@pytest.mark.timeout(600) def test_validate_benchmark_semantics(env: LlvmEnv, validatable_cbench_uri: str): """Run the validation routine on all benchmarks.""" env.reward_space = "IrInstructionCount" @@ -29,6 +32,7 @@ def test_validate_benchmark_semantics(env: LlvmEnv, validatable_cbench_uri: str) assert result.okay() +@pytest.mark.timeout(600) def test_non_validatable_benchmark_validate( env: LlvmEnv, non_validatable_cbench_uri: str ):