From b83868562a6393e1bf7db33a78a4762790c6a1bf Mon Sep 17 00:00:00 2001
From: Liam DeVoe <orionldevoe@gmail.com>
Date: Sun, 19 Jan 2025 01:43:05 -0500
Subject: [PATCH] move to ir cache
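
The engine previously kept two caches: __data_cache, keyed by buffers,
and __data_cache_ir, keyed by choice sequences. This patch collapses
them into a single cache keyed by choice sequences via _cache_key,
pins/unpins interesting examples by _cache_key(data.choices) instead of
data.buffer, seeds generation with a NodeTemplate("simplest") prefix
rather than a zero buffer, and removes the now-unused buffer-based
helpers (cached_test_function, new_conjecture_data,
new_conjecture_data_for_buffer).

A minimal sketch of why choice sequences make good cache keys, using a
hypothetical choices_key stand-in for the engine's _cache_key (the
float normalization shown is an assumption; the real helper lives in
the conjecture internals):

    import struct
    from typing import Union

    ChoiceT = Union[int, bool, float, bytes, str]

    def choices_key(choices: "tuple[ChoiceT, ...]") -> tuple:
        # floats make poor dict keys as-is: 0.0 == -0.0 and nan != nan,
        # so substitute the (stable, hashable) bit pattern instead.
        return tuple(
            ("float", struct.pack("!d", c)) if isinstance(c, float) else c
            for c in choices
        )

    cache: dict = {}
    cache[choices_key((5, b"", 0.0))] = "result"
    # equal choice sequences always hit the same entry, regardless of
    # how many buffer bytes it originally took to generate them
    assert choices_key((5, b"", 0.0)) in cache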

---
 .../hypothesis/internal/conjecture/data.py    |  21 +-
 .../hypothesis/internal/conjecture/engine.py  | 186 +++---------------
 .../tests/conjecture/test_engine.py           |  15 +-
 3 files changed, 40 insertions(+), 182 deletions(-)
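
Note on the generate_new_examples change: bytes(BUFFER_SIZE) is
replaced by a NodeTemplate("simplest", size=BUFFER_SIZE) prefix. Below
is a hypothetical model of what such a template denotes, inferred from
its use in this patch rather than from the real class (8 * 1024 stands
in for BUFFER_SIZE):

    from dataclasses import dataclass

    @dataclass(frozen=True)
    class NodeTemplate:
        # a not-yet-filled hole in a choice sequence: instead of handing
        # the provider zero bytes and hoping they decode to minimal
        # choices, each draw is asked for its simplest value until the
        # size budget is exhausted.
        type: str  # this patch only uses "simplest"
        size: int  # maximum amount the hole may consume

    zero_prefix = (NodeTemplate("simplest", size=8 * 1024),)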

diff --git a/hypothesis-python/src/hypothesis/internal/conjecture/data.py b/hypothesis-python/src/hypothesis/internal/conjecture/data.py
index e7de1110c7..6053dbf74f 100644
--- a/hypothesis-python/src/hypothesis/internal/conjecture/data.py
+++ b/hypothesis-python/src/hypothesis/internal/conjecture/data.py
@@ -1737,19 +1737,12 @@ def _draw(self, ir_type, kwargs, *, observe, forced, fake_forced):
             debug_report(f"overrun because hit {self.max_length_ir=}")
             self.mark_overrun()
 
-        if self.ir_prefix is not None and observe:
-            if self.index_ir < len(self.ir_prefix):
-                choice = self._pop_choice(ir_type, kwargs, forced=forced)
-            else:
-                try:
-                    choice = (
-                        forced
-                        if forced is not None
-                        else draw_choice(ir_type, kwargs, random=self.__random)
-                    )
-                except StopTest:
-                    debug_report("overrun because draw_choice overran")
-                    self.mark_overrun()
+        if (
+            observe
+            and self.ir_prefix is not None
+            and self.index_ir < len(self.ir_prefix)
+        ):
+            choice = self._pop_choice(ir_type, kwargs, forced=forced)
 
             if forced is None:
                 forced = choice
@@ -2261,7 +2254,7 @@ def draw_bits(
         elif self._bytes_drawn < len(self.__prefix):
             index = self._bytes_drawn
             buf = self.__prefix[index : index + n_bytes]
-            if len(buf) < n_bytes:
+            if len(buf) < n_bytes:  # pragma: no cover # removing soon
                 assert self.__random is not None
                 buf += uniform(self.__random, n_bytes - len(buf))
         else:
diff --git a/hypothesis-python/src/hypothesis/internal/conjecture/engine.py b/hypothesis-python/src/hypothesis/internal/conjecture/engine.py
index fdcb23903c..931636f480 100644
--- a/hypothesis-python/src/hypothesis/internal/conjecture/engine.py
+++ b/hypothesis-python/src/hypothesis/internal/conjecture/engine.py
@@ -18,17 +18,7 @@
 from datetime import timedelta
 from enum import Enum
 from random import Random, getrandbits
-from typing import (
-    Callable,
-    Final,
-    List,
-    Literal,
-    NoReturn,
-    Optional,
-    Union,
-    cast,
-    overload,
-)
+from typing import Callable, Final, List, Literal, NoReturn, Optional, Union, cast
 
 import attr
 
@@ -283,7 +273,6 @@ def __init__(
         # shrinking where we need to know about the structure of the
         # executed test case.
         self.__data_cache = LRUReusedCache(CACHE_SIZE)
-        self.__data_cache_ir = LRUReusedCache(CACHE_SIZE)
 
         self.reused_previously_shrunk_test_case = False
 
@@ -359,26 +348,8 @@ def _cache_key(self, choices: Sequence[ChoiceT]) -> tuple[ChoiceKeyT, ...]:
 
     def _cache(self, data: ConjectureData) -> None:
         result = data.as_result()
-        self.__data_cache[data.buffer] = result
-
-        # interesting buffer-based data can mislead the shrinker if we cache them.
-        #
-        #   @given(st.integers())
-        #   def f(n):
-        #     assert n < 100
-        #
-        # may generate two counterexamples, n=101 and n=m > 101, in that order,
-        # where the buffer corresponding to n is large due to eg failed probes.
-        # We shrink m and eventually try n=101, but it is cached to a large buffer
-        # and so the best we can do is n=102, a non-ideal shrink.
-        #
-        # We can cache ir-based buffers fine, which always correspond to the
-        # smallest buffer via forced=. The overhead here is small because almost
-        # all interesting data are ir-based via the shrinker (and that overhead
-        # will tend towards zero as we move generation to the ir).
-        if data.ir_prefix is not None or data.status < Status.INTERESTING:
-            key = self._cache_key(data.choices)
-            self.__data_cache_ir[key] = result
+        key = self._cache_key(data.choices)
+        self.__data_cache[key] = result
 
     def cached_test_function_ir(
         self,
@@ -387,6 +358,14 @@ def cached_test_function_ir(
         error_on_discard: bool = False,
         extend: int = 0,
     ) -> Union[ConjectureResult, _Overrun]:
+        """
+        If ``error_on_discard`` is set to True this will raise ``ContainsDiscard``
+        in preference to running the actual test function. This is to allow us
+        to skip test cases we expect to be redundant. Note that we may not
+        raise ``ContainsDiscard`` even if the result has discards, when we
+        cannot determine from previous runs whether this run will contain
+        one.
+        """
         # node templates represent a not-yet-filled hole and therefore cannot
         # be cached or retrieved from the cache.
         if not any(isinstance(choice, NodeTemplate) for choice in choices):
@@ -395,7 +374,7 @@ def cached_test_function_ir(
             choices = cast(Sequence[ChoiceT], choices)
             key = self._cache_key(choices)
             try:
-                cached = self.__data_cache_ir[key]
+                cached = self.__data_cache[key]
                 # if we have a cached overrun for this key, but we're allowing extensions
                 # of the nodes, it could in fact run to valid data if we try.
                 if extend == 0 or cached.status is not Status.OVERRUN:
@@ -429,14 +408,19 @@ def kill_branch(self) -> NoReturn:
         else:
             trial_data.freeze()
             key = self._cache_key(trial_data.choices)
-            if trial_data.status is Status.OVERRUN:
+            if trial_data.status > Status.OVERRUN:
+                try:
+                    return self.__data_cache[key]
+                except KeyError:
+                    pass
+            else:
                 # if we simulated to an overrun, then our result is certainly
                 # an overrun; no need to consult the cache. (and we store this result
                 # for simulation-less lookup later).
-                self.__data_cache_ir[key] = Overrun
+                self.__data_cache[key] = Overrun
                 return Overrun
             try:
-                return self.__data_cache_ir[key]
+                return self.__data_cache[key]
             except KeyError:
                 pass
 
@@ -567,8 +551,8 @@ def test_function(self, data: ConjectureData) -> None:
 
         if data.status == Status.INTERESTING:
             if not self.using_hypothesis_backend:
-                # drive the ir tree through the test function to convert it
-                # to a buffer
+                # replay this failure on the hypothesis backend to ensure it still
+                # finds a failure. otherwise, the failure is flaky.
                 initial_origin = data.interesting_origin
                 initial_traceback = getattr(
                     data.extra_information, "_expected_traceback", None
@@ -611,13 +595,13 @@ def test_function(self, data: ConjectureData) -> None:
                 if sort_key_ir(data.ir_nodes) < sort_key_ir(existing.ir_nodes):
                     self.shrinks += 1
                     self.downgrade_buffer(ir_to_bytes(existing.choices))
-                    self.__data_cache.unpin(existing.buffer)
+                    self.__data_cache.unpin(self._cache_key(existing.choices))
                     changed = True
 
             if changed:
                 self.save_choices(data.choices)
                 self.interesting_examples[key] = data.as_result()  # type: ignore
-                self.__data_cache.pin(data.buffer, data.as_result())
+                self.__data_cache.pin(self._cache_key(data.choices), data.as_result())
                 self.shrunk_examples.discard(key)
 
             if self.shrinks >= MAX_SHRINKS:
@@ -969,11 +953,13 @@ def generate_new_examples(self) -> None:
         self.debug("Generating new examples")
 
         assert self.should_generate_more()
-        zero_data = self.cached_test_function(bytes(BUFFER_SIZE))
+        zero_data = self.cached_test_function_ir(
+            (NodeTemplate("simplest", size=BUFFER_SIZE),)
+        )
         if zero_data.status > Status.OVERRUN:
             assert isinstance(zero_data, ConjectureResult)
             self.__data_cache.pin(
-                zero_data.buffer, zero_data.as_result()
+                self._cache_key(zero_data.choices), zero_data.as_result()
             )  # Pin forever
 
         if zero_data.status == Status.OVERRUN or (
@@ -1048,7 +1034,7 @@ def generate_new_examples(self) -> None:
             # not whatever is specified by the backend. We can improve this
             # once more things are on the ir.
             if not self.using_hypothesis_backend:
-                data = self.new_conjecture_data(prefix=b"", max_length=BUFFER_SIZE)
+                data = self.new_conjecture_data_ir([], max_length=BUFFER_SIZE)
                 with suppress(BackendCannotProceed):
                     self.test_function(data)
                 continue
@@ -1228,7 +1214,7 @@ def generate_mutations_from(
                     assert isinstance(new_data, ConjectureResult)
                     if (
                         new_data.status >= data.status
-                        and data.buffer != new_data.buffer
+                        and choices_key(data.choices) != choices_key(new_data.choices)
                         and all(
                             k in new_data.target_observations
                             and new_data.target_observations[k] >= v
@@ -1332,32 +1318,6 @@ def new_conjecture_data_ir(
             random=self.random,
         )
 
-    def new_conjecture_data(
-        self,
-        prefix: Union[bytes, bytearray],
-        max_length: int = BUFFER_SIZE,
-        observer: Optional[DataObserver] = None,
-    ) -> ConjectureData:
-        provider = (
-            HypothesisProvider if self._switch_to_hypothesis_provider else self.provider
-        )
-        observer = observer or self.tree.new_observer()
-        if not self.using_hypothesis_backend:
-            observer = DataObserver()
-
-        return ConjectureData(
-            prefix=prefix,
-            max_length=max_length,
-            random=self.random,
-            observer=observer,
-            provider=provider,
-        )
-
-    def new_conjecture_data_for_buffer(
-        self, buffer: Union[bytes, bytearray]
-    ) -> ConjectureData:
-        return self.new_conjecture_data(buffer, max_length=len(buffer))
-
     def shrink_interesting_examples(self) -> None:
         """If we've found interesting examples, try to replace each of them
         with a minimal interesting example with the same interesting_origin.
@@ -1468,88 +1428,6 @@ def new_shrinker(
             in_target_phase=self._current_phase == "target",
         )
 
-    def cached_test_function(
-        self,
-        buffer: Union[bytes, bytearray],
-        *,
-        extend: int = 0,
-    ) -> Union[ConjectureResult, _Overrun]:  # pragma: no cover # removing function soon
-        """Checks the tree to see if we've tested this buffer, and returns the
-        previous result if we have.
-
-        Otherwise we call through to ``test_function``, and return a
-        fresh result.
-
-        If ``error_on_discard`` is set to True this will raise ``ContainsDiscard``
-        in preference to running the actual test function. This is to allow us
-        to skip test cases we expect to be redundant in some cases. Note that
-        it may be the case that we don't raise ``ContainsDiscard`` even if the
-        result has discards if we cannot determine from previous runs whether
-        it will have a discard.
-        """
-        buffer = bytes(buffer)[:BUFFER_SIZE]
-
-        max_length = min(BUFFER_SIZE, len(buffer) + extend)
-
-        @overload
-        def check_result(result: _Overrun) -> _Overrun: ...
-        @overload
-        def check_result(result: ConjectureResult) -> ConjectureResult: ...
-        def check_result(
-            result: Union[_Overrun, ConjectureResult],
-        ) -> Union[_Overrun, ConjectureResult]:
-            assert result is Overrun or (
-                isinstance(result, ConjectureResult) and result.status != Status.OVERRUN
-            )
-            return result
-
-        try:
-            cached = check_result(self.__data_cache[buffer])
-            if cached.status > Status.OVERRUN or extend == 0:
-                return cached
-        except KeyError:
-            pass
-
-        observer = DataObserver()
-        dummy_data = self.new_conjecture_data(
-            prefix=buffer, max_length=max_length, observer=observer
-        )
-
-        if self.using_hypothesis_backend:
-            try:
-                self.tree.simulate_test_function(dummy_data)
-            except PreviouslyUnseenBehaviour:
-                pass
-            else:
-                if dummy_data.status > Status.OVERRUN:
-                    dummy_data.freeze()
-                    try:
-                        return self.__data_cache[dummy_data.buffer]
-                    except KeyError:
-                        pass
-                else:
-                    self.__data_cache[buffer] = Overrun
-                    return Overrun
-
-        # We didn't find a match in the tree, so we need to run the test
-        # function normally. Note that test_function will automatically
-        # add this to the tree so we don't need to update the cache.
-
-        result = None
-
-        data = self.new_conjecture_data(
-            prefix=max((buffer, dummy_data.buffer), key=len), max_length=max_length
-        )
-        self.test_function(data)
-        result = check_result(data.as_result())
-        if extend == 0 or (
-            result is not Overrun
-            and not isinstance(result, _Overrun)
-            and len(result.buffer) <= len(buffer)
-        ):
-            self.__data_cache[buffer] = result
-        return result
-
     def passing_choice_sequences(
         self, prefix: Sequence[IRNode] = ()
     ) -> frozenset[bytes]:
@@ -1558,8 +1436,8 @@ def passing_choice_sequences(
         """
         return frozenset(
             result.ir_nodes
-            for key in self.__data_cache_ir
-            if (result := self.__data_cache_ir[key]).status is Status.VALID
+            for key in self.__data_cache
+            if (result := self.__data_cache[key]).status is Status.VALID
             and startswith(result.ir_nodes, prefix)
         )
 
diff --git a/hypothesis-python/tests/conjecture/test_engine.py b/hypothesis-python/tests/conjecture/test_engine.py
index edef062c12..a912c0c5e9 100644
--- a/hypothesis-python/tests/conjecture/test_engine.py
+++ b/hypothesis-python/tests/conjecture/test_engine.py
@@ -216,7 +216,7 @@ def f(data):
         data.mark_interesting()
 
     runner = ConjectureRunner(f, settings=settings(max_examples=5000, database=None))
-    with buffer_size_limit(2):
+    with buffer_size_limit(4):
         runner.run()
     assert runner.interesting_examples
 
@@ -1505,19 +1505,6 @@ def test(data):
         assert d2.status is Status.VALID
 
 
-def test_draw_bits_partly_from_prefix_and_partly_random():
-    # a draw_bits call which straddles the end of our prefix has a slightly
-    # different code branch.
-    def test(data):
-        # float consumes draw_bits(64)
-        data.draw_float()
-
-    with deterministic_PRNG():
-        runner = ConjectureRunner(test, settings=TEST_SETTINGS)
-        d = runner.cached_test_function(bytes(10), extend=100)
-        assert d.status == Status.VALID
-
-
 def test_can_be_set_to_ignore_limits():
     def test(data):
         data.draw_integer(0, 2**8 - 1)