diff --git a/CHANGELOG.rst b/CHANGELOG.rst index aa5fbf90..5677e6e8 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -1,6 +1,7 @@ Changelog ========= +- Add tools to handle simulation errors and limit simulation time - Use kernel copy to avoid pickle issue and allow BOLFI parallelisation with non-default kernel - Restrict matplotlib version < 3.9 for compatibility with GPy - Add option to use additive or multiplicative adjustment in any acquisition method diff --git a/docs/api.rst b/docs/api.rst index e58e71d9..3f56aed8 100644 --- a/docs/api.rst +++ b/docs/api.rst @@ -126,6 +126,7 @@ Other .. autosummary:: elfi.tools.vectorize elfi.tools.external_operation + elfi.tools.unreliable_operation @@ -338,3 +339,5 @@ Other .. automethod:: elfi.tools.vectorize .. automethod:: elfi.tools.external_operation + +.. automethod:: elfi.tools.unreliable_operation diff --git a/elfi/model/tools.py b/elfi/model/tools.py index 77f3d65c..e83c55d3 100644 --- a/elfi/model/tools.py +++ b/elfi/model/tools.py @@ -1,5 +1,9 @@ """This module contains tools for ELFI graphs.""" +__all__ = ['vectorize', 'external_operation', 'unreliable_operation'] + +import logging +import signal import subprocess from functools import partial @@ -7,7 +11,7 @@ from elfi.utils import get_sub_seed, is_array -__all__ = ['vectorize', 'external_operation'] +logger = logging.getLogger(__name__) def run_vectorized(operation, *inputs, constants=None, dtype=None, batch_size=None, **kwargs): @@ -284,3 +288,111 @@ def external_operation(command, prepare_inputs=prepare_inputs, stdout=stdout, subprocess_kwargs=subprocess_kwargs) + + +def run_with_recovery(operation, known_errors, *inputs, error_output=None, **kwargs): + """Run the operation with error recovery. + + Helper that returns a predetermined output when an accepted error occurs in the operation. + This tool is still experimental and may not work in all cases. + + Parameters + ---------- + operation : callable + Operation to be executed. + known_errors : Exception or tuple + Accepted errors. + inputs + Positional arguments for the operation. + error_output : any, optional + Output to return when an accepted error occurs. Defaults to None. + kwargs + Keyword arguments for the operation. + + Returns + ------- + output : any + Operation output or error_output if operation failed with an accepted error. + + """ + try: + output = operation(*inputs, **kwargs) + except known_errors as e: + logger.warning("Exception occurred: {}".format(e)) + batch_size = kwargs.get('batch_size', None) + output = np.array([error_output] * batch_size) if batch_size else error_output + return output + + +def run_with_time_limit(operation, time_limit, *inputs, error_output=None, **kwargs): + """Run the operation with time limit. + + Helper that terminates the operation at time limit and returns a predetermined output. + This tool is still experimental and may not work in all cases. + + Parameters + ---------- + operation : callable + Operation to be executed. + time_limit : int + Operation time limit in seconds. + inputs + Positional arguments for the operation. + error_output : any, optional + Output to return when the operation exceeds time limit. Defaults to None. + kwargs + Keyword arguments for the operation. + + Returns + ------- + output : any + Operation output or error_output if operation exceeded time limit. + + """ + def timeout_handler(signum, frame): + raise TimeoutError + + try: + signal.signal(signal.SIGALRM, timeout_handler) + signal.alarm(time_limit) + output = operation(*inputs, **kwargs) + except TimeoutError: + logger.warning("Operation exceeded time limit.") + batch_size = kwargs.get('batch_size', None) + output = np.array([error_output] * batch_size) if batch_size else error_output + finally: + signal.alarm(0) # cancel the alarm + return output + + +def unreliable_operation(operation, + known_errors=None, + time_limit=None, + error_output=None): + """Wrap an operation to run with timeout and recovery options. + + This tool is still experimental and may not work in all cases. + + Parameters + ---------- + operation : callable + Operation to be executed. + known_errors : Exception or tuple + Accepted errors. Defaults to None. + time_limit : int, optional + Operation time limit in seconds. Defaults to None. + error_output : any, optional + Output to return when an accepted error occurs or the operation exceeds time limit. + Defaults to None. + + Returns + ------- + operation : callable + ELFI compatible operation that can be used e.g. as a simulator. + + """ + if time_limit is not None: + operation = partial(run_with_time_limit, operation, time_limit, error_output=error_output) + if known_errors is not None: + operation = partial(run_with_recovery, operation, known_errors, error_output=error_output) + return operation diff --git a/tests/unit/test_tools.py b/tests/unit/test_tools.py index 46ee5b38..2dde7895 100644 --- a/tests/unit/test_tools.py +++ b/tests/unit/test_tools.py @@ -1,4 +1,5 @@ import pickle +import time import numpy as np import pytest @@ -86,6 +87,37 @@ def test_vectorized_and_external_combined(): assert len(np.unique(g[:, 3]) == 1) +def test_unreliable_operation(): + def simulator(param, error=None, sleep=0, random_state=None): + if error is not None: + raise error + time.sleep(sleep) + return param * np.linspace(0, 1, 5) + + errors = RuntimeError + sim = elfi.tools.unreliable_operation(simulator, known_errors=errors) + assert(np.all(sim(2, error=None) == simulator(2))) + assert(sim(2, error=RuntimeError("Example runtime error.")) == None) + + sim = elfi.tools.unreliable_operation(simulator, known_errors=errors, error_output=np.zeros(5)) + assert(np.all(sim(2, error=RuntimeError("Example runtime error.")) == np.zeros(5))) + + errors = (RuntimeError, ArithmeticError) + sim = elfi.tools.unreliable_operation(simulator, known_errors=errors) + assert(sim(2, error=RuntimeError("Example runtime error.")) == None) + assert(sim(2, error=ZeroDivisionError("Example arithmetic error.")) == None) + + errors = Exception + sim = elfi.tools.unreliable_operation(simulator, known_errors=errors) + assert(sim(2, error=RuntimeError("Example runtime error.")) == None) + with pytest.raises(KeyboardInterrupt): + sim(2, error=KeyboardInterrupt) + + sim = elfi.tools.unreliable_operation(simulator, time_limit=1) + assert(np.all(sim(2, sleep=0) == simulator(2))) + assert(sim(2, sleep=2) == None) + + def test_progress_bar(ma2): thresholds = [.5, .2] N = 1000