Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add loop component 🎁🎄 #5429

Open
wants to merge 10 commits into
base: main
Choose a base branch
from
2 changes: 2 additions & 0 deletions src/backend/base/langflow/components/logic/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from .data_conditional_router import DataConditionalRouterComponent
from .flow_tool import FlowToolComponent
from .listen import ListenComponent
from .loop import LoopComponent
from .notify import NotifyComponent
from .pass_message import PassMessageComponent
from .run_flow import RunFlowComponent
Expand All @@ -12,6 +13,7 @@
"DataConditionalRouterComponent",
"FlowToolComponent",
"ListenComponent",
"LoopComponent",
"NotifyComponent",
"PassMessageComponent",
"RunFlowComponent",
Expand Down
108 changes: 108 additions & 0 deletions src/backend/base/langflow/components/logic/loop.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
from langflow.custom import Component
from langflow.io import DataInput, Output
from langflow.schema import Data


class LoopComponent(Component):
display_name = "Loop"
description = (
"Iterates over a list of Data objects, outputting one item at a time and aggregating results from loop inputs."
)
icon = "infinity"

inputs = [
DataInput(name="data", display_name="Data", info="The initial list of Data objects to iterate over."),
DataInput(name="loop_input", display_name="Loop Input", info="Data to aggregate during the iteration."),
]

outputs = [
Output(display_name="Item", name="item", method="item_output"),
Output(display_name="Done", name="done", method="done_output"),
]

def initialize_data(self) -> None:
"""Initialize the data list, context index, and aggregated list."""
if self.ctx.get(f"{self._id}_initialized", False):
return

# Ensure data is a list of Data objects
data_list = self._validate_data(self.data)

# Store the initial data and context variables
self.update_ctx(
{
f"{self._id}_data": data_list,
f"{self._id}_index": 0,
f"{self._id}_aggregated": [],
f"{self._id}_initialized": True,
}
)

def _validate_data(self, data):
"""Validate and return a list of Data objects."""
if isinstance(data, Data):
return [data]
if isinstance(data, list) and all(isinstance(item, Data) for item in data):
return data
msg = "The 'data' input must be a list of Data objects or a single Data object."
raise TypeError(msg)

def evaluate_stop_loop(self) -> bool:
"""Evaluate whether to stop item or done output."""
current_index = self.ctx.get(f"{self._id}_index", 0)
data_length = len(self.ctx.get(f"{self._id}_data", []))
return current_index > data_length

def item_output(self) -> Data:
"""Output the next item in the list or stop if done."""
self.initialize_data()
current_item = Data(text="")

if self.evaluate_stop_loop():
self.stop("item")
return Data(text="")

# Get data list and current index
data_list, current_index = self.loop_variables()
if current_index < len(data_list):
# Output current item and increment index
try:
current_item = data_list[current_index]
except IndexError:
current_item = Data(text="")
self.aggregated_output()
self.update_ctx({f"{self._id}_index": current_index + 1})
return current_item

def done_output(self) -> Data:
"""Trigger the done output when iteration is complete."""
self.initialize_data()

if self.evaluate_stop_loop():
self.stop("item")
self.start("done")

return self.ctx.get(f"{self._id}_aggregated", [])
self.stop("done")
return Data(text="")

def loop_variables(self):
"""Retrieve loop variables from context."""
return (
self.ctx.get(f"{self._id}_data", []),
self.ctx.get(f"{self._id}_index", 0),
)

def aggregated_output(self) -> Data:
"""Return the aggregated list once all items are processed."""
self.initialize_data()

# Get data list and aggregated list
data_list = self.ctx.get(f"{self._id}_data", [])
aggregated = self.ctx.get(f"{self._id}_aggregated", [])

# Check if loop input is provided and append to aggregated list
if self.loop_input is not None and not isinstance(self.loop_input, str) and len(aggregated) <= len(data_list):
aggregated.append(self.loop_input)
self.update_ctx({f"{self._id}_aggregated": aggregated})
return aggregated
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,21 @@ def stop(self, output_name: str | None = None) -> None:
msg = f"Error stopping {self.display_name}: {e}"
raise ValueError(msg) from e

def start(self, output_name: str | None = None) -> None:
if not output_name and self._vertex and len(self._vertex.outputs) == 1:
output_name = self._vertex.outputs[0]["name"]
elif not output_name:
msg = "You must specify an output name to call start"
raise ValueError(msg)
if not self._vertex:
msg = "Vertex is not set"
raise ValueError(msg)
try:
self.graph.mark_branch(vertex_id=self._vertex.id, output_name=output_name, state="ACTIVE")
except Exception as e:
msg = f"Error starting {self.display_name}: {e}"
raise ValueError(msg) from e

def append_state(self, name: str, value: Any) -> None:
if not self._vertex:
msg = "Vertex is not set"
Expand Down
7 changes: 7 additions & 0 deletions src/backend/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@ def pytest_configure(config):
pytest.VECTOR_STORE_PATH = data_path / "Vector_store.json"
pytest.SIMPLE_API_TEST = data_path / "SimpleAPITest.json"
pytest.MEMORY_CHATBOT_NO_LLM = data_path / "MemoryChatbotNoLLM.json"
pytest.LOOP_TEST = data_path / "LoopTest.json"
pytest.CODE_WITH_SYNTAX_ERROR = """
def get_text():
retun "Hello World"
Expand All @@ -121,6 +122,7 @@ def get_text():
pytest.TWO_OUTPUTS,
pytest.VECTOR_STORE_PATH,
pytest.MEMORY_CHATBOT_NO_LLM,
pytest.LOOP_TEST,
]:
assert path.exists(), f"File {path} does not exist. Available files: {list(data_path.iterdir())}"

Expand Down Expand Up @@ -324,6 +326,11 @@ def json_memory_chatbot_no_llm():
return pytest.MEMORY_CHATBOT_NO_LLM.read_text(encoding="utf-8")


@pytest.fixture
def json_loop_test():
return pytest.LOOP_TEST.read_text(encoding="utf-8")


@pytest.fixture(autouse=True)
def deactivate_tracing(monkeypatch):
monkeypatch.setenv("LANGFLOW_DEACTIVATE_TRACING", "true")
Expand Down
Loading
Loading