Commit

Second Attempt - Add concurrent insertion of vector rows in the Cassandra Vector Store (#7017)

Retrying with the same improvements as in #6772, this time without mixing up the branches.

@rlancemartin, here is a fresh PR from a branch with a new name. This should do it. Thank you for your help!
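
For context, a minimal sketch of what this enables (all names below are illustrative, not taken from this diff): `add_texts` now slices the input into batches of `batch_size` rows and writes each batch concurrently through `put_async` futures instead of one synchronous `put` per row.

```python
from cassandra.cluster import Cluster
from langchain.embeddings import OpenAIEmbeddings
from langchain.vectorstores import Cassandra

# Assumed setup: a reachable local Cassandra cluster and an existing keyspace.
session = Cluster().connect()

store = Cassandra(
    embedding=OpenAIEmbeddings(),
    session=session,
    keyspace="demo_keyspace",
    table_name="demo_vectors",
)

# Each batch of 16 rows is inserted concurrently rather than one put() at a time.
ids = store.add_texts(
    texts=["first document", "second document", "third document"],
    batch_size=16,
)
```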

---------

Co-authored-by: Jonathan Ellis <jbellis@datastax.com>
Co-authored-by: rlm <pexpresss31@gmail.com>
3 people committed Jul 1, 2023
1 parent 3bfe7cf commit 8d2281a
Showing 6 changed files with 126 additions and 580 deletions.
16 changes: 13 additions & 3 deletions docs/extras/ecosystem/integrations/cassandra.mdx
@@ -1,10 +1,10 @@
# Cassandra

>[Apache Cassandra®](https://cassandra.apache.org/) is a free and open-source, distributed, wide-column
> store, NoSQL database management system designed to handle large amounts of data across many commodity servers,
> providing high availability with no single point of failure. Cassandra offers support for clusters spanning
> multiple datacenters, with asynchronous masterless replication allowing low latency operations for all clients.
> Cassandra was designed to combine _Amazon's Dynamo_ distributed storage and replication
> techniques with _Google's Bigtable_ data and storage engine model.
## Installation and Setup
@@ -16,6 +16,16 @@ pip install cassio



## Vector Store

See a [usage example](/docs/modules/data_connection/vectorstores/integrations/cassandra.html).

```python
from langchain.vectorstores import Cassandra
```
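
A minimal instantiation sketch (illustrative only, not part of this commit; it assumes an already-connected `cassandra.cluster.Session` named `session` and an existing keyspace):

```python
from langchain.embeddings import OpenAIEmbeddings
from langchain.vectorstores import Cassandra

# Placeholder names: session, keyspace and table_name are assumptions for this sketch.
vector_store = Cassandra(
    embedding=OpenAIEmbeddings(),
    session=session,
    keyspace="my_keyspace",
    table_name="my_vector_table",
)
```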



## Memory

See a [usage example](/docs/modules/memory/integrations/cassandra_chat_message_history.html).
@@ -23,7 +23,7 @@
},
"outputs": [],
"source": [
"!pip install \"cassio>=0.0.5\""
"!pip install \"cassio>=0.0.7\""
]
},
{
@@ -44,14 +44,16 @@
"import os\n",
"import getpass\n",
"\n",
"database_mode = (input('\\n(L)ocal Cassandra or (A)stra DB? ')).upper()\n",
"database_mode = (input('\\n(C)assandra or (A)stra DB? ')).upper()\n",
"\n",
"keyspace_name = input('\\nKeyspace name? ')\n",
"\n",
"if database_mode == 'A':\n",
" ASTRA_DB_APPLICATION_TOKEN = getpass.getpass('\\nAstra DB Token (\"AstraCS:...\") ')\n",
" #\n",
" ASTRA_DB_SECURE_BUNDLE_PATH = input('Full path to your Secure Connect Bundle? ')"
" ASTRA_DB_SECURE_BUNDLE_PATH = input('Full path to your Secure Connect Bundle? ')\n",
"elif database_mode == 'C':\n",
" CASSANDRA_CONTACT_POINTS = input('Contact points? (comma-separated, empty for localhost) ').strip()"
]
},
{
@@ -72,8 +74,15 @@
"from cassandra.cluster import Cluster\n",
"from cassandra.auth import PlainTextAuthProvider\n",
"\n",
"if database_mode == 'L':\n",
" cluster = Cluster()\n",
"if database_mode == 'C':\n",
" if CASSANDRA_CONTACT_POINTS:\n",
" cluster = Cluster([\n",
" cp.strip()\n",
" for cp in CASSANDRA_CONTACT_POINTS.split(',')\n",
" if cp.strip()\n",
" ])\n",
" else:\n",
" cluster = Cluster()\n",
" session = cluster.connect()\n",
"elif database_mode == 'A':\n",
" ASTRA_DB_CLIENT_ID = \"token\"\n",
@@ -261,7 +270,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.10"
"version": "3.10.6"
}
},
"nbformat": 4,
77 changes: 44 additions & 33 deletions langchain/vectorstores/cassandra.py
@@ -1,8 +1,8 @@
"""Wrapper around Cassandra vector-store capabilities, based on cassIO."""
from __future__ import annotations

import hashlib
import typing
import uuid
from typing import Any, Iterable, List, Optional, Tuple, Type, TypeVar

import numpy as np
@@ -17,14 +17,6 @@

CVST = TypeVar("CVST", bound="Cassandra")

# a positive number of seconds to expire entries, or None for no expiration.
CASSANDRA_VECTORSTORE_DEFAULT_TTL_SECONDS = None


def _hash(_input: str) -> str:
"""Use a deterministic hashing approach."""
return hashlib.md5(_input.encode()).hexdigest()


class Cassandra(VectorStore):
"""Wrapper around Cassandra embeddings platform.
@@ -46,7 +38,7 @@ class Cassandra(VectorStore):

_embedding_dimension: int | None

def _getEmbeddingDimension(self) -> int:
def _get_embedding_dimension(self) -> int:
if self._embedding_dimension is None:
self._embedding_dimension = len(
self.embedding.embed_query("This is a sample sentence.")
@@ -59,7 +51,7 @@ def __init__(
session: Session,
keyspace: str,
table_name: str,
ttl_seconds: int | None = CASSANDRA_VECTORSTORE_DEFAULT_TTL_SECONDS,
ttl_seconds: Optional[int] = None,
) -> None:
try:
from cassio.vector import VectorTable
@@ -81,8 +73,8 @@ def __init__(
session=session,
keyspace=keyspace,
table=table_name,
embedding_dimension=self._getEmbeddingDimension(),
auto_id=False, # the `add_texts` contract admits user-provided ids
embedding_dimension=self._get_embedding_dimension(),
primary_key_type="TEXT",
)

def delete_collection(self) -> None:
@@ -99,11 +91,27 @@ def clear(self) -> None:
def delete_by_document_id(self, document_id: str) -> None:
return self.table.delete(document_id)

def delete(self, ids: List[str]) -> Optional[bool]:
"""Delete by vector ID.
Args:
ids: List of ids to delete.
Returns:
Optional[bool]: True if deletion is successful,
False otherwise, None if not implemented.
"""
for document_id in ids:
self.delete_by_document_id(document_id)
return True

def add_texts(
self,
texts: Iterable[str],
metadatas: Optional[List[dict]] = None,
ids: Optional[List[str]] = None,
batch_size: int = 16,
ttl_seconds: Optional[int] = None,
**kwargs: Any,
) -> List[str]:
"""Run more texts through the embeddings and add to the vectorstore.
@@ -112,33 +120,39 @@ def add_texts(
texts (Iterable[str]): Texts to add to the vectorstore.
metadatas (Optional[List[dict]], optional): Optional list of metadatas.
ids (Optional[List[str]], optional): Optional list of IDs.
batch_size (int): Number of concurrent requests to send to the server.
ttl_seconds (Optional[int], optional): Optional time-to-live
for the added texts.
Returns:
List[str]: List of IDs of the added texts.
"""
_texts = list(texts) # lest it be a generator or something
if ids is None:
# unless otherwise specified, we have deterministic IDs:
# re-inserting an existing document will not create a duplicate.
# (and effectively update the metadata)
ids = [_hash(text) for text in _texts]
ids = [uuid.uuid4().hex for _ in _texts]
if metadatas is None:
metadatas = [{} for _ in _texts]
#
ttl_seconds = kwargs.get("ttl_seconds", self.ttl_seconds)
ttl_seconds = ttl_seconds or self.ttl_seconds
#
embedding_vectors = self.embedding.embed_documents(_texts)
for text, embedding_vector, text_id, metadata in zip(
_texts, embedding_vectors, ids, metadatas
):
self.table.put(
document=text,
embedding_vector=embedding_vector,
document_id=text_id,
metadata=metadata,
ttl_seconds=ttl_seconds,
)
#
for i in range(0, len(_texts), batch_size):
batch_texts = _texts[i : i + batch_size]
batch_embedding_vectors = embedding_vectors[i : i + batch_size]
batch_ids = ids[i : i + batch_size]
batch_metadatas = metadatas[i : i + batch_size]

futures = [
self.table.put_async(
text, embedding_vector, text_id, metadata, ttl_seconds
)
for text, embedding_vector, text_id, metadata in zip(
batch_texts, batch_embedding_vectors, batch_ids, batch_metadatas
)
]
for future in futures:
future.result()
return ids

# id-returning search facilities
@@ -181,7 +195,6 @@ def similarity_search_with_score_id(
self,
query: str,
k: int = 4,
**kwargs: Any,
) -> List[Tuple[Document, float, str]]:
embedding_vector = self.embedding.embed_query(query)
return self.similarity_search_with_score_id_by_vector(
@@ -219,12 +232,10 @@ def similarity_search(
k: int = 4,
**kwargs: Any,
) -> List[Document]:
#
embedding_vector = self.embedding.embed_query(query)
return self.similarity_search_by_vector(
embedding_vector,
k,
**kwargs,
)

def similarity_search_by_vector(
@@ -245,7 +256,6 @@ def similarity_search_with_score(
self,
query: str,
k: int = 4,
**kwargs: Any,
) -> List[Tuple[Document, float]]:
embedding_vector = self.embedding.embed_query(query)
return self.similarity_search_with_score_by_vector(
@@ -266,7 +276,6 @@ def _similarity_search_with_relevance_scores(
return self.similarity_search_with_score(
query,
k,
**kwargs,
)

def max_marginal_relevance_search_by_vector(
Expand Down Expand Up @@ -352,6 +361,7 @@ def from_texts(
texts: List[str],
embedding: Embeddings,
metadatas: Optional[List[dict]] = None,
batch_size: int = 16,
**kwargs: Any,
) -> CVST:
"""Create a Cassandra vectorstore from raw texts.
@@ -378,6 +388,7 @@ def from_documents(
cls: Type[CVST],
documents: List[Document],
embedding: Embeddings,
batch_size: int = 16,
**kwargs: Any,
) -> CVST:
"""Create a Cassandra vectorstore from a document list.
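The factory methods gain the same knob. A hedged usage sketch, assuming (as with the constructor) that `session`, `keyspace` and `table_name` are accepted as keyword arguments and that `batch_size` is forwarded down to `add_texts`:

```python
from langchain.docstore.document import Document
from langchain.embeddings import OpenAIEmbeddings
from langchain.vectorstores import Cassandra

docs = [Document(page_content="first"), Document(page_content="second")]

# batch_size is assumed to reach add_texts, i.e. 8 concurrent put_async writes per batch.
store = Cassandra.from_documents(
    documents=docs,
    embedding=OpenAIEmbeddings(),
    session=session,  # assumed: an already-connected cassandra.cluster.Session
    keyspace="demo_keyspace",
    table_name="demo_vectors",
    batch_size=8,
)
```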
