diff --git a/hypothesis_trio/stateful.py b/hypothesis_trio/stateful.py index e7272f1..4eedbb4 100644 --- a/hypothesis_trio/stateful.py +++ b/hypothesis_trio/stateful.py @@ -1,6 +1,8 @@ import trio from trio.testing import trio_test +import hypothesis + from hypothesis.stateful import ( # Needed for run_state_machine_as_test copy-paste check_type, @@ -24,13 +26,7 @@ # runner into run_state_machine_as_test -def run_state_machine_as_test(state_machine_factory, settings=None): - """Run a state machine definition as a test, either silently doing nothing - or printing a minimal breaking program and raising an exception. - state_machine_factory is anything which returns an instance of - GenericStateMachine when called with no arguments - it can be a class or a - function. settings will be used to control the execution of the test. - """ +def run_custom_state_machine_as_test(state_machine_factory, settings=None): if settings is None: try: settings = state_machine_factory.TestCase.settings @@ -63,30 +59,7 @@ def run_state_machine(factory, data): or current_verbosity() >= Verbosity.debug ) - # Plug a custom machinery - # This block is not in the original copy-paste - try: - machine._custom_runner - except AttributeError: - pass - else: - return machine._custom_runner(data, print_steps, should_continue) - - try: - if print_steps: - machine.print_start() - machine.check_invariants() - - while should_continue.more(): - value = data.conjecture_data.draw(machine.steps()) - if print_steps: - machine.print_step(value) - machine.execute_step(value) - machine.check_invariants() - finally: - if print_steps: - machine.print_end() - machine.teardown() + return machine._custom_runner(data, print_steps, should_continue) # Use a machine digest to identify stateful tests in the example database run_state_machine.hypothesis.inner_test._hypothesis_internal_add_digest = function_digest( @@ -205,12 +178,27 @@ async def check_invariants(self): def monkey_patch_hypothesis(): - from hypothesis import stateful - stateful._old_run_state_machine_as_test = stateful.run_state_machine_as_test - stateful.run_state_machine_as_test = run_state_machine_as_test + if hasattr(hypothesis.stateful, "original_run_state_machine_as_test"): + return + original = hypothesis.stateful.run_state_machine_as_test + + def run_state_machine_as_test(state_machine_factory, settings=None): + """Run a state machine definition as a test, either silently doing nothing + or printing a minimal breaking program and raising an exception. + state_machine_factory is anything which returns an instance of + GenericStateMachine when called with no arguments - it can be a class or a + function. settings will be used to control the execution of the test. + """ + if hasattr(state_machine_factory, '_custom_runner'): + return run_custom_state_machine_as_test( + state_machine_factory, settings=settings + ) + return original(state_machine_factory, settings=settings) + hypothesis.stateful.original_run_state_machine_as_test = original + hypothesis.stateful.run_state_machine_as_test = run_state_machine_as_test -monkey_patch_hypothesis() -# Expose all objects from original stateful module +# Monkey patch and expose all objects from original stateful module +monkey_patch_hypothesis() from hypothesis.stateful import *