diff --git a/haystack/core/pipeline/base.py b/haystack/core/pipeline/base.py index ffb1e4058d..651b406aa9 100644 --- a/haystack/core/pipeline/base.py +++ b/haystack/core/pipeline/base.py @@ -927,14 +927,23 @@ def _distribute_output( pair = (receiver_name, receiver) is_greedy = getattr(receiver, "__haystack_is_greedy__", False) - if receiver_socket.is_variadic and is_greedy: - # If the receiver is greedy, we can run it as soon as possible. - # First we remove it from the status lists it's in if it's there or we risk running it multiple times. - if pair in run_queue: - run_queue.remove(pair) - if pair in waiting_queue: - waiting_queue.remove(pair) - run_queue.append(pair) + if receiver_socket.is_variadic: + if is_greedy: + # If the receiver is greedy, we can run it as soon as possible. + # First we remove it from the status lists it's in if it's there or + # we risk running it multiple times. + if pair in run_queue: + run_queue.remove(pair) + if pair in waiting_queue: + waiting_queue.remove(pair) + run_queue.append(pair) + else: + # If the receiver Component has a variadic input that is not greedy + # we put it in the waiting queue. + # This makes sure that we don't run it earlier than necessary and we can collect + # as many inputs as we can before running it. + if pair not in waiting_queue: + waiting_queue.append(pair) if pair not in waiting_queue and pair not in run_queue: # Queue up the Component that received this input to run, only if it's not already waiting @@ -1027,12 +1036,15 @@ def _find_next_runnable_lazy_variadic_or_default_component( # The loop detection will be handled later on. return name, comp - def _find_components_that_received_no_input( + def _find_components_that_will_receive_no_input( self, component_name: str, component_result: Dict[str, Any] ) -> Set[Tuple[str, Component]]: """ Find all the Components that are connected to component_name and didn't receive any input from it. 
+ This includes the descendants of the Components that didn't receive any input from component_name. + That is necessary to avoid getting stuck in infinite loops waiting for inputs that will never arrive. + :param component_name: Name of the Component that created the output :param component_result: Output of the Component :return: A set of Components that didn't receive any input from component_name @@ -1045,6 +1057,13 @@ def _find_components_that_received_no_input( for receiver in socket.receivers: receiver_instance: Component = self.graph.nodes[receiver]["instance"] components.add((receiver, receiver_instance)) + # Get the descendants too. When we remove a Component that received no input + # it's extremely likely that its descendants will receive no input as well. + # This is fine even if the Pipeline will merge back into a single Component + # at a certain point. The merging Component will be put back into the run + # queue at a later stage. + components |= {(d, self.graph.nodes[d]["instance"]) for d in networkx.descendants(self.graph, receiver)} + return components def _is_stuck_in_a_loop(self, waiting_queue: List[Tuple[str, Component]]) -> bool: diff --git a/haystack/core/pipeline/pipeline.py b/haystack/core/pipeline/pipeline.py index fc0d36f1e1..6ae9df4914 100644 --- a/haystack/core/pipeline/pipeline.py +++ b/haystack/core/pipeline/pipeline.py @@ -245,7 +245,7 @@ def run(self, word: str): # This happens when a component was put in the waiting list but we reached it from another edge. 
_dequeue_waiting_component((name, comp), waiting_queue) - for pair in self._find_components_that_received_no_input(name, res): + for pair in self._find_components_that_will_receive_no_input(name, res): _dequeue_component(pair, run_queue, waiting_queue) res = self._distribute_output(name, res, components_inputs, run_queue, waiting_queue) diff --git a/releasenotes/notes/fix-run-loop-63bf0ffc26887e66.yaml b/releasenotes/notes/fix-run-loop-63bf0ffc26887e66.yaml new file mode 100644 index 0000000000..9ec3525cd8 --- /dev/null +++ b/releasenotes/notes/fix-run-loop-63bf0ffc26887e66.yaml @@ -0,0 +1,7 @@ +--- +fixes: + - | + Fix a bug in `Pipeline.run()` that would cause it to get stuck in an infinite loop and never return. + + This was caused by Components waiting forever for their inputs when parts of the Pipeline graph are skipped + because of a "decision" Component not returning outputs for that side of the Pipeline. diff --git a/test/core/pipeline/features/pipeline_run.feature b/test/core/pipeline/features/pipeline_run.feature index 0ea749c647..52d4c46413 100644 --- a/test/core/pipeline/features/pipeline_run.feature +++ b/test/core/pipeline/features/pipeline_run.feature @@ -38,6 +38,7 @@ Feature: Pipeline running | that has a component with default inputs that doesn't receive anything from its sender but receives input from user | | that has a loop and a component with default inputs that doesn't receive anything from its sender but receives input from user | | that has multiple components with only default inputs and are added in a different order from the order of execution | + | that is linear with conditional branching and multiple joins | Scenario Outline: Running a bad Pipeline Given a pipeline diff --git a/test/core/pipeline/features/test_run.py b/test/core/pipeline/features/test_run.py index 8524f20065..e37460df85 100644 --- a/test/core/pipeline/features/test_run.py +++ b/test/core/pipeline/features/test_run.py @@ -9,7 +9,7 @@ from haystack.components.builders 
import PromptBuilder, AnswerBuilder from haystack.components.retrievers.in_memory import InMemoryBM25Retriever from haystack.document_stores.in_memory import InMemoryDocumentStore -from haystack.components.joiners import BranchJoiner +from haystack.components.joiners import BranchJoiner, DocumentJoiner from haystack.testing.sample_components import ( Accumulate, AddFixedValue, @@ -1489,3 +1489,93 @@ def run(self, prompt: str, generation_kwargs: Optional[Dict[str, Any]] = None): ) ], ) + + +@given("a pipeline that is linear with conditional branching and multiple joins", target_fixture="pipeline_data") +def that_is_linear_with_conditional_branching_and_multiple_joins(): + pipeline = Pipeline() + + @component + class FakeRouter: + @component.output_types(LEGIT=str, INJECTION=str) + def run(self, query: str): + if "injection" in query: + return {"INJECTION": query} + return {"LEGIT": query} + + @component + class FakeEmbedder: + @component.output_types(embeddings=List[float]) + def run(self, text: str): + return {"embeddings": [1.0, 2.0, 3.0]} + + @component + class FakeRanker: + @component.output_types(documents=List[Document]) + def run(self, query: str, documents: List[Document]): + return {"documents": documents} + + @component + class FakeRetriever: + @component.output_types(documents=List[Document]) + def run(self, query: str): + if "injection" in query: + return {"documents": []} + return {"documents": [Document(content="This is a document")]} + + @component + class FakeEmbeddingRetriever: + @component.output_types(documents=List[Document]) + def run(self, query_embedding: List[float]): + return {"documents": [Document(content="This is another document")]} + + pipeline.add_component(name="router", instance=FakeRouter()) + pipeline.add_component(name="text_embedder", instance=FakeEmbedder()) + pipeline.add_component(name="retriever", instance=FakeEmbeddingRetriever()) + pipeline.add_component(name="emptyretriever", instance=FakeRetriever()) + 
pipeline.add_component(name="joinerfinal", instance=DocumentJoiner()) + pipeline.add_component(name="joinerhybrid", instance=DocumentJoiner()) + pipeline.add_component(name="ranker", instance=FakeRanker()) + pipeline.add_component(name="bm25retriever", instance=FakeRetriever()) + + pipeline.connect("router.INJECTION", "emptyretriever.query") + pipeline.connect("router.LEGIT", "text_embedder.text") + pipeline.connect("text_embedder", "retriever.query_embedding") + pipeline.connect("router.LEGIT", "ranker.query") + pipeline.connect("router.LEGIT", "bm25retriever.query") + pipeline.connect("bm25retriever", "joinerhybrid.documents") + pipeline.connect("retriever", "joinerhybrid.documents") + pipeline.connect("joinerhybrid.documents", "ranker.documents") + pipeline.connect("ranker", "joinerfinal.documents") + pipeline.connect("emptyretriever", "joinerfinal.documents") + + return ( + pipeline, + [ + PipelineRunData( + inputs={"router": {"query": "I'm a legit question"}}, + expected_outputs={ + "joinerfinal": { + "documents": [ + Document(content="This is a document"), + Document(content="This is another document"), + ] + } + }, + expected_run_order=[ + "router", + "text_embedder", + "bm25retriever", + "retriever", + "joinerhybrid", + "ranker", + "joinerfinal", + ], + ), + PipelineRunData( + inputs={"router": {"query": "I'm a nasty prompt injection"}}, + expected_outputs={"joinerfinal": {"documents": []}}, + expected_run_order=["router", "emptyretriever", "joinerfinal"], + ), + ], + ) diff --git a/test/core/pipeline/test_pipeline.py b/test/core/pipeline/test_pipeline.py index 106533603d..c0ce448d5f 100644 --- a/test/core/pipeline/test_pipeline.py +++ b/test/core/pipeline/test_pipeline.py @@ -1136,23 +1136,26 @@ def test__component_has_enough_inputs_to_run(self): "sentence_builder", {"sentence_builder": {"words": ["blah blah"]}} ) - def test__find_components_that_received_no_input(self): + def test__find_components_that_will_receive_no_input(self): sentence_builder = 
component_class( "SentenceBuilder", input_types={"words": List[str]}, output={"text": "some words"} )() document_builder = component_class( "DocumentBuilder", input_types={"text": str}, output={"doc": Document(content="some words")} )() + document_joiner = component_class("DocumentJoiner", input_types={"docs": Variadic[Document]})() pipe = Pipeline() pipe.add_component("sentence_builder", sentence_builder) pipe.add_component("document_builder", document_builder) + pipe.add_component("document_joiner", document_joiner) pipe.connect("sentence_builder.text", "document_builder.text") + pipe.connect("document_builder.doc", "document_joiner.docs") - res = pipe._find_components_that_received_no_input("sentence_builder", {}) - assert res == {("document_builder", document_builder)} + res = pipe._find_components_that_will_receive_no_input("sentence_builder", {}) + assert res == {("document_builder", document_builder), ("document_joiner", document_joiner)} - res = pipe._find_components_that_received_no_input("sentence_builder", {"text": "some text"}) + res = pipe._find_components_that_will_receive_no_input("sentence_builder", {"text": "some text"}) assert res == set() def test__distribute_output(self):