Skip to content

Commit

Permalink
chore: Update persist_directory parameter to handle None value in Chr…
Browse files Browse the repository at this point in the history
…omaSearch and Chroma components (#2157)

* chore: Update persist_directory parameter to handle None value in ChromaSearch and Chroma components

* 🐛 (test_endpoints.py): fix assertion to check for correct key name in output results for chat and any input types
  • Loading branch information
ogabrielluiz authored Jun 13, 2024
1 parent 84df4fd commit 34b6153
Show file tree
Hide file tree
Showing 3 changed files with 5 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import chromadb
from chromadb.config import Settings
from langchain_chroma import Chroma

from langflow.components.vectorstores.base.model import LCVectorStoreComponent
from langflow.field_typing import Embeddings, Text
from langflow.schema import Record
Expand Down Expand Up @@ -104,10 +103,11 @@ def build(
client = chromadb.HttpClient(settings=chroma_settings)
if index_directory:
index_directory = self.resolve_path(index_directory)

vector_store = Chroma(
embedding_function=embedding,
collection_name=collection_name,
persist_directory=index_directory,
persist_directory=index_directory or None,
client=client,
)

Expand Down
3 changes: 1 addition & 2 deletions src/backend/base/langflow/components/vectorstores/Chroma.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from langchain_core.embeddings import Embeddings
from langchain_core.retrievers import BaseRetriever
from langchain_core.vectorstores import VectorStore

from langflow.base.vectorstores.utils import chroma_collection_to_records
from langflow.custom import CustomComponent
from langflow.schema import Record
Expand Down Expand Up @@ -107,7 +106,7 @@ def build(
index_directory = self.resolve_path(index_directory)

chroma = Chroma(
persist_directory=index_directory,
persist_directory=index_directory or None,
client=client,
embedding_function=embedding,
collection_name=collection_name,
Expand Down
4 changes: 2 additions & 2 deletions tests/test_endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -597,7 +597,7 @@ def test_successful_run_with_input_type_chat(client, starter_project, created_ap
chat_input_outputs = [output for output in outputs_dict.get("outputs") if "ChatInput" in output.get("component_id")]
assert len(chat_input_outputs) == 1
# Now we check if the input_value is correct
assert all([output.get("results").get("result") == "value1" for output in chat_input_outputs]), chat_input_outputs
assert all([output.get("results").get("text") == "value1" for output in chat_input_outputs]), chat_input_outputs


def test_successful_run_with_input_type_any(client, starter_project, created_api_key):
Expand Down Expand Up @@ -631,7 +631,7 @@ def test_successful_run_with_input_type_any(client, starter_project, created_api
]
assert len(any_input_outputs) == 1
# Now we check if the input_value is correct
assert all([output.get("results").get("result") == "value1" for output in any_input_outputs]), any_input_outputs
assert all([output.get("results").get("text") == "value1" for output in any_input_outputs]), any_input_outputs


@pytest.mark.api_key_required
Expand Down

0 comments on commit 34b6153

Please sign in to comment.