Skip to content

Commit

Permalink
docs: Add docstring to all methods (#187)
Browse files Browse the repository at this point in the history
* add docstrings

* added more docstrings

* add more docstring

* mod

* revert change from other pr

* delete service account impl

* add on

* fix error

* fix typo

* fix typo

* fix typo 2

* more docstring improvement

* Update src/langchain_google_alloydb_pg/chat_message_history.py

Co-authored-by: Averi Kitsch <akitsch@google.com>

* Update src/langchain_google_alloydb_pg/engine.py

Co-authored-by: Averi Kitsch <akitsch@google.com>

* Update src/langchain_google_alloydb_pg/engine.py

Co-authored-by: Averi Kitsch <akitsch@google.com>

* resolve comment

---------

Co-authored-by: Averi Kitsch <akitsch@google.com>
  • Loading branch information
duwenxin99 and averikitsch authored Jul 22, 2024
1 parent 25c310c commit 518581e
Show file tree
Hide file tree
Showing 6 changed files with 446 additions and 12 deletions.
6 changes: 6 additions & 0 deletions samples/index_tuning_sample/index_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@


async def get_vector_store():
"""Get vector store instance."""
engine = await AlloyDBEngine.afrom_instance(
project_id=PROJECT_ID,
region=REGION,
Expand All @@ -71,6 +72,7 @@ async def get_vector_store():


async def query_vector_with_timing(vector_store, query):
"""Run a similarity search query and measure how long it takes."""
start_time = time.monotonic() # timer starts
docs = await vector_store.asimilarity_search(k=k, query=query)
end_time = time.monotonic() # timer ends
Expand All @@ -79,6 +81,7 @@ async def query_vector_with_timing(vector_store, query):


async def hnsw_search(vector_store, knn_docs):
"""Create an HNSW index and perform similarity search with the index."""
hnsw_index = HNSWIndex(name="hnsw", m=36, ef_construction=96)
await vector_store.aapply_vector_index(hnsw_index)
assert await vector_store.is_valid_index(hnsw_index.name)
Expand All @@ -99,6 +102,7 @@ async def hnsw_search(vector_store, knn_docs):


async def ivfflat_search(vector_store, knn_docs):
"""Create an IVFFlat index and perform similarity search with the index."""
ivfflat_index = IVFFlatIndex(name="ivfflat")
await vector_store.aapply_vector_index(ivfflat_index)
assert await vector_store.is_valid_index(ivfflat_index.name)
Expand All @@ -119,6 +123,7 @@ async def ivfflat_search(vector_store, knn_docs):


async def knn_search(vector_store):
"""Perform similarity search without an index."""
latencies = []
knn_docs = []
for query in queries:
Expand All @@ -130,6 +135,7 @@ async def knn_search(vector_store):


def calculate_recall(base, target):
"""Calculate recall on the target result."""
# size of intersection / total number of times
base = {doc.page_content for doc in base}
target = {doc.page_content for doc in target}
Expand Down
50 changes: 47 additions & 3 deletions src/langchain_google_alloydb_pg/chat_message_history.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
async def _aget_messages(
engine: AlloyDBEngine, session_id: str, table_name: str
) -> List[BaseMessage]:
"""Retrieve the messages from AlloyDB"""
"""Retrieve the messages from AlloyDB."""
query = f"""SELECT data, type FROM "{table_name}" WHERE session_id = :session_id ORDER BY id;"""
results = await engine._afetch(query, {"session_id": session_id})
if not results:
Expand All @@ -51,6 +51,18 @@ def __init__(
table_name: str,
messages: List[BaseMessage],
):
"""AlloyDBChatMessageHistory constructor.
Args:
key (object): Key to prevent direct constructor usage.
engine (AlloyDBEngine): database connection pool.
session_id (str): Retrieve the table content with this session ID.
table_name (str): Table name that stores the chat message history.
messages (List[BaseMessage]): Messages to store.
Raises:
Exception: If constructor is directly called by the user.
"""
if key != AlloyDBChatMessageHistory.__create_key:
raise Exception(
"Only create class through 'create' or 'create_sync' methods!"
Expand All @@ -67,6 +79,19 @@ async def create(
session_id: str,
table_name: str,
) -> AlloyDBChatMessageHistory:
"""Create a new AlloyDBChatMessageHistory instance.
Args:
engine (AlloyDBEngine): AlloyDB engine to use.
session_id (str): Retrieve the table content with this session ID.
table_name (str): Table name that stores the chat message history.
Raises:
IndexError: If the table provided does not contain required schema.
Returns:
AlloyDBChatMessageHistory: A newly created instance of AlloyDBChatMessageHistory.
"""
table_schema = await engine._aload_table_schema(table_name)
column_names = table_schema.columns.keys()

Expand Down Expand Up @@ -94,11 +119,24 @@ def create_sync(
session_id: str,
table_name: str,
) -> AlloyDBChatMessageHistory:
"""Create a new AlloyDBChatMessageHistory instance.
Args:
engine (AlloyDBEngine): AlloyDB engine to use.
session_id (str): Retrieve the table content with this session ID.
table_name (str): Table name that stores the chat message history.
Raises:
IndexError: If the table provided does not contain required schema.
Returns:
AlloyDBChatMessageHistory: A newly created instance of AlloyDBChatMessageHistory.
"""
coro = cls.create(engine, session_id, table_name)
return engine._run_as_sync(coro)

async def aadd_message(self, message: BaseMessage) -> None:
"""Append the message to the record in AlloyDB"""
"""Append the message to the record in AlloyDB."""
query = f"""INSERT INTO "{self.table_name}"(session_id, data, type)
VALUES (:session_id, :data, :type);
"""
Expand All @@ -115,28 +153,34 @@ async def aadd_message(self, message: BaseMessage) -> None:
)

def add_message(self, message: BaseMessage) -> None:
"""Append the message to the record in AlloyDB."""
self.engine._run_as_sync(self.aadd_message(message))

async def aadd_messages(self, messages: Sequence[BaseMessage]) -> None:
"""Append a list of messages to the record in AlloyDB."""
for message in messages:
await self.aadd_message(message)

def add_messages(self, messages: Sequence[BaseMessage]) -> None:
"""Append a list of messages to the record in AlloyDB."""
self.engine._run_as_sync(self.aadd_messages(messages))

async def aclear(self) -> None:
"""Clear session memory from AlloyDB"""
"""Clear session memory from AlloyDB."""
query = f"""DELETE FROM "{self.table_name}" WHERE session_id = :session_id;"""
await self.engine._aexecute(query, {"session_id": self.session_id})
self.messages = []

def clear(self) -> None:
"""Clear session memory from AlloyDB."""
self.engine._run_as_sync(self.aclear())

async def async_messages(self) -> None:
"""Retrieve the messages from AlloyDB."""
self.messages = await _aget_messages(
self.engine, self.session_id, self.table_name
)

def sync_messages(self) -> None:
"""Retrieve the messages from AlloyDB."""
self.engine._run_as_sync(self.async_messages())
137 changes: 135 additions & 2 deletions src/langchain_google_alloydb_pg/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,13 @@ class Column:
nullable: bool = True

def __post_init__(self) -> None:
"""Check if initialization parameters are valid.
Raises:
ValueError: If Column name is not string.
ValueError: If data_type is not type string.
"""

if not isinstance(self.name, str):
raise ValueError("Column name must be type string")
if not isinstance(self.data_type, str):
Expand All @@ -111,6 +118,17 @@ def __init__(
loop: Optional[asyncio.AbstractEventLoop],
thread: Optional[Thread],
) -> None:
"""AlloyDBEngine constructor.
Args:
key (object): Prevent direct constructor usage.
engine (AsyncEngine): Async engine connection pool.
loop (Optional[asyncio.AbstractEventLoop]): Async event loop used to create the engine.
thread (Optional[Thread] = None): Thread used to create the engine async.
Raises:
Exception: If the constructor is called directly by the user.
"""

if key != AlloyDBEngine.__create_key:
raise Exception(
Expand All @@ -132,6 +150,22 @@ def from_instance(
password: Optional[str] = None,
ip_type: Union[str, IPTypes] = IPTypes.PUBLIC,
) -> AlloyDBEngine:
"""Create an AlloyDBEngine from an AlloyDB instance.
Args:
project_id (str): GCP project ID.
region (str): Cloud AlloyDB instance region.
cluster (str): Cloud AlloyDB cluster name.
instance (str): Cloud AlloyDB instance name.
database (str): Database name.
user (Optional[str], optional): Cloud AlloyDB user name. Defaults to None.
password (Optional[str], optional): Cloud AlloyDB user password. Defaults to None.
ip_type (Union[str, IPTypes], optional): IP address type. Defaults to IPTypes.PUBLIC.
iam_account_email (Optional[str], optional): IAM service account email. Defaults to None.
Returns:
AlloyDBEngine: A newly created AlloyDBEngine instance.
"""
# Running a loop in a background thread allows us to support
# async methods from non-async environments
loop = asyncio.new_event_loop()
Expand Down Expand Up @@ -164,7 +198,29 @@ async def _create(
password: Optional[str] = None,
loop: Optional[asyncio.AbstractEventLoop] = None,
thread: Optional[Thread] = None,
iam_account_email: Optional[str] = None,
) -> AlloyDBEngine:
"""Create an AlloyDBEngine from an AlloyDB instance.
Args:
project_id (str): GCP project ID.
region (str): Cloud AlloyDB instance region.
cluster (str): Cloud AlloyDB cluster name.
instance (str): Cloud AlloyDB instance name.
database (str): Database name.
ip_type (Union[str, IPTypes], optional): IP address type. Defaults to IPTypes.PUBLIC.
user (Optional[str], optional): Cloud AlloyDB user name. Defaults to None.
password (Optional[str], optional): Cloud AlloyDB user password. Defaults to None.
loop (Optional[asyncio.AbstractEventLoop]): Async event loop used to create the engine.
thread (Optional[Thread] = None): Thread used to create the engine async.
iam_account_email (Optional[str], optional): IAM service account email.
Raises:
ValueError: Raises error if only one of 'user' or 'password' is specified.
Returns:
AlloyDBEngine: A newly created AlloyDBEngine instance.
"""
# error if only one of user or password is set, must be both or neither
if bool(user) ^ bool(password):
raise ValueError(
Expand Down Expand Up @@ -222,6 +278,22 @@ async def afrom_instance(
password: Optional[str] = None,
ip_type: Union[str, IPTypes] = IPTypes.PUBLIC,
) -> AlloyDBEngine:
"""Create an AlloyDBEngine from an AlloyDB instance.
Args:
project_id (str): GCP project ID.
region (str): Cloud AlloyDB instance region.
cluster (str): Cloud AlloyDB cluster name.
instance (str): Cloud AlloyDB instance name.
database (str): Cloud AlloyDB database name.
user (Optional[str], optional): Cloud AlloyDB user name. Defaults to None.
password (Optional[str], optional): Cloud AlloyDB user password. Defaults to None.
ip_type (Union[str, IPTypes], optional): IP address type. Defaults to IPTypes.PUBLIC.
iam_account_email (Optional[str], optional): IAM service account email. Defaults to None.
Returns:
AlloyDBEngine: A newly created AlloyDBEngine instance.
"""
return await cls._create(
project_id,
region,
Expand All @@ -235,6 +307,7 @@ async def afrom_instance(

@classmethod
def from_engine(cls: Type[AlloyDBEngine], engine: AsyncEngine) -> AlloyDBEngine:
"""Create an AlloyDBEngine instance from an AsyncEngine."""
return cls(cls.__create_key, engine, None, None)

async def _aexecute(self, query: str, params: Optional[dict] = None) -> None:
Expand All @@ -252,21 +325,24 @@ async def _aexecute_outside_tx(self, query: str) -> None:
async def _afetch(
self, query: str, params: Optional[dict] = None
) -> Sequence[RowMapping]:
"""Fetch results from a SQL query."""
async with self._engine.connect() as conn:
"""Fetch results from a SQL query."""
result = await conn.execute(text(query), params)
result_map = result.mappings()
result_fetch = result_map.fetchall()

return result_fetch

def _execute(self, query: str, params: Optional[dict] = None) -> None:
"""Execute a SQL query."""
return self._run_as_sync(self._aexecute(query, params))

def _fetch(self, query: str, params: Optional[dict] = None) -> Sequence[RowMapping]:
"""Fetch results from a SQL query."""
return self._run_as_sync(self._afetch(query, params))

def _run_as_sync(self, coro: Awaitable[T]) -> T:
"""Run an async coroutine synchronously."""
if not self._loop:
raise Exception("Engine was initialized async.")
return asyncio.run_coroutine_threadsafe(coro, self._loop).result()
Expand All @@ -284,7 +360,7 @@ async def ainit_vectorstore_table(
store_metadata: bool = True,
) -> None:
"""
Create a table for saving of vectors to be used with Alloy DB.
Create a table for saving of vectors to be used with AlloyDB.
If table already exists and overwrite flag is not set, a TABLE_ALREADY_EXISTS error is thrown.
Args:
Expand Down Expand Up @@ -337,6 +413,30 @@ def init_vectorstore_table(
overwrite_existing: bool = False,
store_metadata: bool = True,
) -> None:
"""
Create a table for saving of vectors to be used with AlloyDB.
If table already exists and overwrite flag is not set, a DuplicateTableError is thrown.
Args:
table_name (str): The table name.
vector_size (int): Vector size for the embedding model to be used.
content_column (str): Name of the column to store document content.
Default: "page_content".
embedding_column (str): Name of the column to store vector embeddings.
Default: "embedding".
metadata_columns (List[Column]): A list of Columns to create for custom
metadata. Default: []. Optional.
metadata_json_column (str): The column to store extra metadata in JSON format.
Default: "langchain_metadata". Optional.
id_column (str): Name of the column to store ids.
Default: "langchain_id". Optional.
overwrite_existing (bool): Whether to drop the existing table before insertion.
Default: False.
store_metadata (bool): Whether to store metadata in a JSON column if not specified by `metadata_columns`.
Default: True.
Raises:
:class:`DuplicateTableError <asyncpg.exceptions.DuplicateTableError>`: if table already exists.
"""
return self._run_as_sync(
self.ainit_vectorstore_table(
table_name,
Expand All @@ -352,6 +452,15 @@ def init_vectorstore_table(
)

async def ainit_chat_history_table(self, table_name: str) -> None:
"""
Create an AlloyDB table to save chat history messages.
Args:
table_name (str): The table name to store chat history.
Returns:
None
"""
create_table_query = f"""CREATE TABLE IF NOT EXISTS "{table_name}"(
id SERIAL PRIMARY KEY,
session_id TEXT NOT NULL,
Expand All @@ -361,6 +470,15 @@ async def ainit_chat_history_table(self, table_name: str) -> None:
await self._aexecute(create_table_query)

def init_chat_history_table(self, table_name: str) -> None:
"""
Create an AlloyDB table to save chat history messages.
Args:
table_name (str): The table name to store chat history.
Returns:
None
"""
return self._run_as_sync(
self.ainit_chat_history_table(
table_name,
Expand Down Expand Up @@ -411,6 +529,21 @@ def init_document_table(
metadata_json_column: str = "langchain_metadata",
store_metadata: bool = True,
) -> None:
"""
Create a table for saving of langchain documents.
If table already exists, a DuplicateTableError is thrown.
Args:
table_name (str): The PgSQL database table name.
content_column (str): Name of the column to store document content.
Default: "page_content".
metadata_columns (List[Column]): A list of Columns
to create for custom metadata. Optional.
metadata_json_column (str): The column to store extra metadata in JSON format.
Default: "langchain_metadata". Optional.
store_metadata (bool): Whether to store extra metadata in a metadata column
if not described in 'metadata' field list (Default: True).
"""
return self._run_as_sync(
self.ainit_document_table(
table_name,
Expand Down
Loading

0 comments on commit 518581e

Please sign in to comment.