Skip to content

Commit

Permalink
feat: adds test cases for loop component compatibility with the APIs,…
Browse files Browse the repository at this point in the history
… Loop component updates to support API (#5615)

* add loop component 🎁🎄

* [autofix.ci] apply automated fixes

* fix: add loop component to init

* [autofix.ci] apply automated fixes

* refactor(loop): rename loop input variable and improve code quality

- Renamed 'loop' input to 'loop_input' for clarity.
- Simplified logic for checking loop input and aggregating results.
- Enhanced type hints for better code readability and maintainability.

* refactor(loop): add type hint to initialize_data method for improved clarity

* adding test

* test cases added

* Update test_loop.py

* adding test

* test cases added

* Update test_loop.py

* update with the new test case method!

* Update test_loop.py

* tests  updates

* Update loop.py

* update fix

* issues loop issues

* reverting debug mode params

* solves lint errors and fix the tests

* fix: mypy error incompatible return value type

* [autofix.ci] apply automated fixes

---------

Co-authored-by: Rodrigo Nader <rodrigosilvanader@gmail.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: Gabriel Luiz Freitas Almeida <gabriel@langflow.org>
Co-authored-by: italojohnny <italojohnnydosanjos@gmail.com>
  • Loading branch information
5 people authored Jan 14, 2025
1 parent cc080fd commit 8ebe580
Show file tree
Hide file tree
Showing 6 changed files with 1,787 additions and 30 deletions.
81 changes: 51 additions & 30 deletions src/backend/base/langflow/components/logic/loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,16 +26,7 @@ def initialize_data(self) -> None:
return

# Ensure data is a list of Data objects
if isinstance(self.data, Data):
data_list: list[Data] = [self.data]
elif isinstance(self.data, list):
if not all(isinstance(item, Data) for item in self.data):
msg = "All items in the data list must be Data objects."
raise TypeError(msg)
data_list = self.data
else:
msg = "The 'data' input must be a list of Data objects or a single Data object."
raise TypeError(msg)
data_list = self._validate_data(self.data)

# Store the initial data and context variables
self.update_ctx(
Expand All @@ -47,25 +38,62 @@ def initialize_data(self) -> None:
}
)

def _validate_data(self, data):
"""Validate and return a list of Data objects."""
if isinstance(data, Data):
return [data]
if isinstance(data, list) and all(isinstance(item, Data) for item in data):
return data
msg = "The 'data' input must be a list of Data objects or a single Data object."
raise TypeError(msg)

def evaluate_stop_loop(self) -> bool:
"""Evaluate whether to stop item or done output."""
current_index = self.ctx.get(f"{self._id}_index", 0)
data_length = len(self.ctx.get(f"{self._id}_data", []))
return current_index > data_length

def item_output(self) -> Data:
"""Output the next item in the list."""
"""Output the next item in the list or stop if done."""
self.initialize_data()
current_item = Data(text="")

# Get data list and current index
data_list: list[Data] = self.ctx.get(f"{self._id}_data", [])
current_index: int = self.ctx.get(f"{self._id}_index", 0)
if self.evaluate_stop_loop():
self.stop("item")
return Data(text="")

# Get data list and current index
data_list, current_index = self.loop_variables()
if current_index < len(data_list):
# Output current item and increment index
current_item: Data = data_list[current_index]
self.update_ctx({f"{self._id}_index": current_index + 1})
return current_item

# No more items to output
self.stop("item")
return None # type: ignore [return-value]
try:
current_item = data_list[current_index]
except IndexError:
current_item = Data(text="")
self.aggregated_output()
self.update_ctx({f"{self._id}_index": current_index + 1})
return current_item

def done_output(self) -> Data:
"""Trigger the done output when iteration is complete."""
self.initialize_data()

if self.evaluate_stop_loop():
self.stop("item")
self.start("done")

return self.ctx.get(f"{self._id}_aggregated", [])
self.stop("done")
return Data(text="")

def loop_variables(self):
"""Retrieve loop variables from context."""
return (
self.ctx.get(f"{self._id}_data", []),
self.ctx.get(f"{self._id}_index", 0),
)

def aggregated_output(self) -> Data:
"""Return the aggregated list once all items are processed."""
self.initialize_data()

Expand All @@ -74,14 +102,7 @@ def done_output(self) -> Data:
aggregated = self.ctx.get(f"{self._id}_aggregated", [])

# Check if loop input is provided and append to aggregated list
if self.loop_input is not None:
if self.loop_input is not None and not isinstance(self.loop_input, str) and len(aggregated) <= len(data_list):
aggregated.append(self.loop_input)
self.update_ctx({f"{self._id}_aggregated": aggregated})

# Check if aggregation is complete
if len(aggregated) >= len(data_list):
return aggregated

# Not all items have been processed yet
self.stop("done")
return None # type: ignore [return-value]
return aggregated
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,21 @@ def stop(self, output_name: str | None = None) -> None:
msg = f"Error stopping {self.display_name}: {e}"
raise ValueError(msg) from e

def start(self, output_name: str | None = None) -> None:
if not output_name and self._vertex and len(self._vertex.outputs) == 1:
output_name = self._vertex.outputs[0]["name"]
elif not output_name:
msg = "You must specify an output name to call start"
raise ValueError(msg)
if not self._vertex:
msg = "Vertex is not set"
raise ValueError(msg)
try:
self.graph.mark_branch(vertex_id=self._vertex.id, output_name=output_name, state="ACTIVE")
except Exception as e:
msg = f"Error starting {self.display_name}: {e}"
raise ValueError(msg) from e

def append_state(self, name: str, value: Any) -> None:
if not self._vertex:
msg = "Vertex is not set"
Expand Down
7 changes: 7 additions & 0 deletions src/backend/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@ def pytest_configure(config):
pytest.VECTOR_STORE_PATH = data_path / "Vector_store.json"
pytest.SIMPLE_API_TEST = data_path / "SimpleAPITest.json"
pytest.MEMORY_CHATBOT_NO_LLM = data_path / "MemoryChatbotNoLLM.json"
pytest.LOOP_TEST = data_path / "LoopTest.json"
pytest.CODE_WITH_SYNTAX_ERROR = """
def get_text():
retun "Hello World"
Expand All @@ -121,6 +122,7 @@ def get_text():
pytest.TWO_OUTPUTS,
pytest.VECTOR_STORE_PATH,
pytest.MEMORY_CHATBOT_NO_LLM,
pytest.LOOP_TEST,
]:
assert path.exists(), f"File {path} does not exist. Available files: {list(data_path.iterdir())}"

Expand Down Expand Up @@ -324,6 +326,11 @@ def json_memory_chatbot_no_llm():
return pytest.MEMORY_CHATBOT_NO_LLM.read_text(encoding="utf-8")


@pytest.fixture
def json_loop_test():
return pytest.LOOP_TEST.read_text(encoding="utf-8")


@pytest.fixture(autouse=True)
def deactivate_tracing(monkeypatch):
monkeypatch.setenv("LANGFLOW_DEACTIVATE_TRACING", "true")
Expand Down
Loading

0 comments on commit 8ebe580

Please sign in to comment.