docs: Add docstring to all methods #187

Merged
merged 17 commits
Jul 22, 2024
6 changes: 6 additions & 0 deletions samples/index_tuning_sample/index_search.py
@@ -51,6 +51,7 @@


async def get_vector_store():
"""Get vector store instance."""
engine = await AlloyDBEngine.afrom_instance(
project_id=PROJECT_ID,
region=REGION,
@@ -71,6 +72,7 @@ async def get_vector_store():


async def query_vector_with_timing(vector_store, query):
"""Query using the vector with timing"""
start_time = time.monotonic() # timer starts
docs = await vector_store.asimilarity_search(k=k, query=query)
end_time = time.monotonic() # timer ends
@@ -79,6 +81,7 @@ async def query_vector_with_timing(vector_store, query):


async def hnsw_search(vector_store, knn_docs):
"""Create an HNSW index and perform similaity search with the index."""
hnsw_index = HNSWIndex(name="hnsw", m=36, ef_construction=96)
await vector_store.aapply_vector_index(hnsw_index)
assert await vector_store.is_valid_index(hnsw_index.name)
@@ -99,6 +102,7 @@ async def hnsw_search(vector_store, knn_docs):


async def ivfflat_search(vector_store, knn_docs):
"""Create an IVFFlat index and perform similaity search with the index."""
ivfflat_index = IVFFlatIndex(name="ivfflat")
await vector_store.aapply_vector_index(ivfflat_index)
assert await vector_store.is_valid_index(ivfflat_index.name)
@@ -119,6 +123,7 @@ async def ivfflat_search(vector_store, knn_docs):


async def knn_search(vector_store):
"""Perform similaity search without index."""
latencies = []
knn_docs = []
for query in queries:
@@ -130,6 +135,7 @@ async def knn_search(vector_store):


def calculate_recall(base, target):
"""Calculate recall on the target result."""
# recall = size of intersection / total number of base items
base = {doc.page_content for doc in base}
target = {doc.page_content for doc in target}
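For context on the sample above, this is roughly how the truncated `calculate_recall` helper is expected to finish; the final ratio is an assumption based on the visible comment and the standard recall definition, not code taken from this diff.

```python
# Sketch only: presumed completion of calculate_recall, based on the
# visible set construction and the standard recall formula.
def calculate_recall(base, target):
    """Calculate recall of the target results against the exact (KNN) base results."""
    base = {doc.page_content for doc in base}
    target = {doc.page_content for doc in target}
    # recall = size of the intersection / number of exact (base) results
    return len(base & target) / len(base)
```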
50 changes: 47 additions & 3 deletions src/langchain_google_alloydb_pg/chat_message_history.py
@@ -27,7 +27,7 @@
async def _aget_messages(
engine: AlloyDBEngine, session_id: str, table_name: str
) -> List[BaseMessage]:
"""Retrieve the messages from AlloyDB"""
"""Retrieve the messages from AlloyDB."""
query = f"""SELECT data, type FROM "{table_name}" WHERE session_id = :session_id ORDER BY id;"""
results = await engine._afetch(query, {"session_id": session_id})
if not results:
@@ -51,6 +51,18 @@ def __init__(
table_name: str,
messages: List[BaseMessage],
):
"""AlloyDBChatMessageHistory constructor.

Args:
key (object): Key to prevent direct constructor usage.
engine (AlloyDBEngine): Database connection pool.
session_id (str): Retrieve the table content with this session ID.
table_name (str): Table name that stores the chat message history.
messages (List[BaseMessage]): Messages to store.

Raises:
Exception: If the constructor is called directly by the user.
"""
if key != AlloyDBChatMessageHistory.__create_key:
raise Exception(
"Only create class through 'create' or 'create_sync' methods!"
@@ -67,6 +79,19 @@ async def create(
session_id: str,
table_name: str,
) -> AlloyDBChatMessageHistory:
"""Create a new AlloyDBChatMessageHistory instance.

Args:
engine (AlloyDBEngine): AlloyDB engine to use.
session_id (str): Retrieve the table content with this session ID.
table_name (str): Table name that stores the chat message history.

Raises:
IndexError: If the table provided does not contain the required schema.

Returns:
AlloyDBChatMessageHistory: A newly created instance of AlloyDBChatMessageHistory.
"""
table_schema = await engine._aload_table_schema(table_name)
column_names = table_schema.columns.keys()

@@ -94,11 +119,24 @@ def create_sync(
session_id: str,
table_name: str,
) -> AlloyDBChatMessageHistory:
"""Create a new AlloyDBChatMessageHistory instance.

Args:
engine (AlloyDBEngine): AlloyDB engine to use.
session_id (str): Retrieve the table content with this session ID.
table_name (str): Table name that stores the chat message history.

Raises:
IndexError: If the table provided does not contain the required schema.

Returns:
AlloyDBChatMessageHistory: A newly created instance of AlloyDBChatMessageHistory.
"""
coro = cls.create(engine, session_id, table_name)
return engine._run_as_sync(coro)

async def aadd_message(self, message: BaseMessage) -> None:
"""Append the message to the record in AlloyDB"""
"""Append the message to the record in AlloyDB."""
query = f"""INSERT INTO "{self.table_name}"(session_id, data, type)
VALUES (:session_id, :data, :type);
"""
@@ -115,28 +153,34 @@ async def aadd_message(self, message: BaseMessage) -> None:
)

def add_message(self, message: BaseMessage) -> None:
"""Append the message to the record in AlloyDB."""
self.engine._run_as_sync(self.aadd_message(message))

async def aadd_messages(self, messages: Sequence[BaseMessage]) -> None:
"""Append a list of messages to the record in AlloyDB."""
for message in messages:
await self.aadd_message(message)

def add_messages(self, messages: Sequence[BaseMessage]) -> None:
"""Append a list of messages to the record in AlloyDB."""
self.engine._run_as_sync(self.aadd_messages(messages))

async def aclear(self) -> None:
"""Clear session memory from AlloyDB"""
"""Clear session memory from AlloyDB."""
query = f"""DELETE FROM "{self.table_name}" WHERE session_id = :session_id;"""
await self.engine._aexecute(query, {"session_id": self.session_id})
self.messages = []

def clear(self) -> None:
"""Clear session memory from AlloyDB."""
self.engine._run_as_sync(self.aclear())

async def async_messages(self) -> None:
"""Retrieve the messages from AlloyDB."""
self.messages = await _aget_messages(
self.engine, self.session_id, self.table_name
)

def sync_messages(self) -> None:
"""Retrieve the messages from AlloyDB."""
self.engine._run_as_sync(self.async_messages())
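A minimal usage sketch of the `create` flow documented above. The connection values are placeholders and the package-level imports are assumed to be exported by `langchain_google_alloydb_pg`; nothing here is taken verbatim from this PR.

```python
# Hedged usage sketch; placeholder project/cluster/table values are illustrative only.
import asyncio

from langchain_core.messages import HumanMessage
from langchain_google_alloydb_pg import AlloyDBChatMessageHistory, AlloyDBEngine


async def main() -> None:
    engine = await AlloyDBEngine.afrom_instance(
        project_id="my-project",   # placeholder
        region="us-central1",      # placeholder
        cluster="my-cluster",      # placeholder
        instance="my-instance",    # placeholder
        database="my-database",    # placeholder
    )
    # Create the backing table, then a history bound to one session.
    await engine.ainit_chat_history_table(table_name="chat_history")
    history = await AlloyDBChatMessageHistory.create(
        engine, session_id="session-1", table_name="chat_history"
    )
    await history.aadd_message(HumanMessage(content="Hello!"))


asyncio.run(main())
```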
137 changes: 135 additions & 2 deletions src/langchain_google_alloydb_pg/engine.py
@@ -92,6 +92,13 @@ class Column:
nullable: bool = True

def __post_init__(self) -> None:
"""Check if initialization parameters are valid.

Raises:
ValueError: If the column name is not a string.
ValueError: If data_type is not a string.
"""

if not isinstance(self.name, str):
raise ValueError("Column name must be type string")
if not isinstance(self.data_type, str):
@@ -111,6 +118,17 @@ def __init__(
loop: Optional[asyncio.AbstractEventLoop],
thread: Optional[Thread],
) -> None:
"""AlloyDBEngine constructor.

Args:
key (object): Prevent direct constructor usage.
engine (AsyncEngine): Async engine connection pool.
loop (Optional[asyncio.AbstractEventLoop]): Async event loop used to create the engine.
thread (Optional[Thread]): Thread used to create the engine asynchronously.

Raises:
Exception: If the constructor is called directly by the user.
"""

if key != AlloyDBEngine.__create_key:
raise Exception(
@@ -132,6 +150,22 @@ def from_instance(
password: Optional[str] = None,
ip_type: Union[str, IPTypes] = IPTypes.PUBLIC,
) -> AlloyDBEngine:
"""Create an AlloyDBEngine from an AlloyDB instance.

Args:
project_id (str): GCP project ID.
region (str): Cloud AlloyDB instance region.
cluster (str): Cloud AlloyDB cluster name.
instance (str): Cloud AlloyDB instance name.
database (str): Database name.
user (Optional[str], optional): Cloud AlloyDB user name. Defaults to None.
password (Optional[str], optional): Cloud AlloyDB user password. Defaults to None.
ip_type (Union[str, IPTypes], optional): IP address type. Defaults to IPTypes.PUBLIC.
iam_account_email (Optional[str], optional): IAM service account email. Defaults to None.

Returns:
AlloyDBEngine: A newly created AlloyDBEngine instance.
"""
# Running a loop in a background thread allows us to support
# async methods from non-async environments
loop = asyncio.new_event_loop()
@@ -164,7 +198,29 @@ async def _create(
password: Optional[str] = None,
loop: Optional[asyncio.AbstractEventLoop] = None,
thread: Optional[Thread] = None,
iam_account_email: Optional[str] = None,
) -> AlloyDBEngine:
"""Create an AlloyDBEngine from an AlloyDB instance.

Args:
project_id (str): GCP project ID.
region (str): Cloud AlloyDB instance region.
cluster (str): Cloud AlloyDB cluster name.
instance (str): Cloud AlloyDB instance name.
database (str): Database name.
ip_type (Union[str, IPTypes], optional): IP address type. Defaults to IPTypes.PUBLIC.
user (Optional[str], optional): Cloud AlloyDB user name. Defaults to None.
password (Optional[str], optional): Cloud AlloyDB user password. Defaults to None.
loop (Optional[asyncio.AbstractEventLoop]): Async event loop used to create the engine.
thread (Optional[Thread]): Thread used to create the engine asynchronously.
iam_account_email (Optional[str], optional): IAM service account email.

Raises:
ValueError: If only one of 'user' or 'password' is specified.

Returns:
AlloyDBEngine: A newly created AlloyDBEngine instance.
"""
# error if only one of user or password is set, must be both or neither
if bool(user) ^ bool(password):
raise ValueError(
@@ -222,6 +278,22 @@ async def afrom_instance(
password: Optional[str] = None,
ip_type: Union[str, IPTypes] = IPTypes.PUBLIC,
) -> AlloyDBEngine:
"""Create an AlloyDBEngine from an AlloyDB instance.

Args:
project_id (str): GCP project ID.
region (str): Cloud AlloyDB instance region.
cluster (str): Cloud AlloyDB cluster name.
instance (str): Cloud AlloyDB instance name.
database (str): Cloud AlloyDB database name.
user (Optional[str], optional): Cloud AlloyDB user name. Defaults to None.
password (Optional[str], optional): Cloud AlloyDB user password. Defaults to None.
ip_type (Union[str, IPTypes], optional): IP address type. Defaults to IPTypes.PUBLIC.
iam_account_email (Optional[str], optional): IAM service account email. Defaults to None.

Returns:
AlloyDBEngine: A newly created AlloyDBEngine instance.
"""
return await cls._create(
project_id,
region,
@@ -235,6 +307,7 @@ async def afrom_instance(

@classmethod
def from_engine(cls: Type[AlloyDBEngine], engine: AsyncEngine) -> AlloyDBEngine:
"""Create an AlloyDBEngine instance from an AsyncEngine."""
return cls(cls.__create_key, engine, None, None)

async def _aexecute(self, query: str, params: Optional[dict] = None) -> None:
@@ -252,21 +325,24 @@ async def _aexecute_outside_tx(self, query: str) -> None:
async def _afetch(
self, query: str, params: Optional[dict] = None
) -> Sequence[RowMapping]:
"""Fetch results from a SQL query."""
async with self._engine.connect() as conn:
"""Fetch results from a SQL query."""
result = await conn.execute(text(query), params)
result_map = result.mappings()
result_fetch = result_map.fetchall()

return result_fetch

def _execute(self, query: str, params: Optional[dict] = None) -> None:
"""Execute a SQL query."""
return self._run_as_sync(self._aexecute(query, params))

def _fetch(self, query: str, params: Optional[dict] = None) -> Sequence[RowMapping]:
"""Fetch results from a SQL query."""
return self._run_as_sync(self._afetch(query, params))

def _run_as_sync(self, coro: Awaitable[T]) -> T:
"""Run an async coroutine synchronously"""
if not self._loop:
raise Exception("Engine was initialized async.")
return asyncio.run_coroutine_threadsafe(coro, self._loop).result()
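For readers unfamiliar with the pattern behind `_run_as_sync`, the generic sketch below shows how a coroutine is handed to an event loop running in a background thread; the names are illustrative and not part of this library.

```python
# Generic illustration of the background-loop bridge (not library code).
import asyncio
from threading import Thread

loop = asyncio.new_event_loop()
Thread(target=loop.run_forever, daemon=True).start()


async def work() -> str:
    return "done"


# Submit the coroutine to the background loop and block until it finishes,
# which is what lets synchronous callers reuse the async implementation.
result = asyncio.run_coroutine_threadsafe(work(), loop).result()
print(result)  # -> done
```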
@@ -284,7 +360,7 @@ async def ainit_vectorstore_table(
store_metadata: bool = True,
) -> None:
"""
Create a table for saving of vectors to be used with Alloy DB.
Create a table for saving vectors to be used with AlloyDB.
If the table already exists and the overwrite flag is not set, a TABLE_ALREADY_EXISTS error is thrown.

Args:
@@ -337,6 +413,30 @@ def init_vectorstore_table(
overwrite_existing: bool = False,
store_metadata: bool = True,
) -> None:
"""
Create a table for saving vectors to be used with AlloyDB.
If the table already exists and the overwrite flag is not set, a TABLE_ALREADY_EXISTS error is thrown.

Args:
table_name (str): The table name.
vector_size (int): Vector size for the embedding model to be used.
content_column (str): Name of the column to store document content.
Default: "page_content".
embedding_column (str): Name of the column to store vector embeddings.
Default: "embedding".
metadata_columns (List[Column]): A list of Columns to create for custom
metadata. Default: []. Optional.
metadata_json_column (str): The column to store extra metadata in JSON format.
Default: "langchain_metadata". Optional.
id_column (str): Name of the column to store ids.
Default: "langchain_id". Optional,
overwrite_existing (bool): Whether to drop the existing table before insertion.
Default: False.
store_metadata (bool): Whether to store metadata in a JSON column if not specified by `metadata_columns`.
Default: True.
Raises:
:class:`DuplicateTableError <asyncpg.exceptions.DuplicateTableError>`: if table already exists.
"""
return self._run_as_sync(
self.ainit_vectorstore_table(
table_name,
@@ -352,6 +452,15 @@ def init_vectorstore_table(
)

async def ainit_chat_history_table(self, table_name: str) -> None:
"""
Create an AlloyDB table to save chat history messages.

Args:
table_name (str): The table name to store chat history.

Returns:
None
"""
create_table_query = f"""CREATE TABLE IF NOT EXISTS "{table_name}"(
id SERIAL PRIMARY KEY,
session_id TEXT NOT NULL,
@@ -361,6 +470,15 @@ async def ainit_chat_history_table(self, table_name: str) -> None:
await self._aexecute(create_table_query)

def init_chat_history_table(self, table_name: str) -> None:
"""
Create an AlloyDB table to save chat history messages.

Args:
table_name (str): The table name to store chat history.

Returns:
None
"""
return self._run_as_sync(
self.ainit_chat_history_table(
table_name,
@@ -411,6 +529,21 @@ def init_document_table(
metadata_json_column: str = "langchain_metadata",
store_metadata: bool = True,
) -> None:
"""
Create a table for saving LangChain documents.
If the table already exists, a DuplicateTableError is thrown.

Args:
table_name (str): The PgSQL database table name.
content_column (str): Name of the column to store document content.
Default: "page_content".
metadata_columns (List[Column]): A list of Columns
to create for custom metadata. Optional.
metadata_json_column (str): The column to store extra metadata in JSON format.
Default: "langchain_metadata". Optional.
store_metadata (bool): Whether to store extra metadata in a metadata column
if not described in 'metadata' field list (Default: True).
"""
return self._run_as_sync(
self.ainit_document_table(
table_name,
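To round out the engine docstrings, here is a hedged sketch of the synchronous table-initialization flow they describe; the connection values, table name, and vector size are assumptions, and `Column` is assumed to be exported at the package level.

```python
# Sketch under stated assumptions; replace placeholders with real values.
from langchain_google_alloydb_pg import AlloyDBEngine, Column

engine = AlloyDBEngine.from_instance(
    project_id="my-project",   # placeholder
    region="us-central1",      # placeholder
    cluster="my-cluster",      # placeholder
    instance="my-instance",    # placeholder
    database="my-database",    # placeholder
)

# Vector table sized for an assumed 768-dimensional embedding model,
# with one custom metadata column alongside the default JSON metadata.
engine.init_vectorstore_table(
    table_name="my_vectors",
    vector_size=768,
    metadata_columns=[Column("source", "TEXT")],
)

# Chat-history table used by AlloyDBChatMessageHistory.
engine.init_chat_history_table(table_name="chat_history")
```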