-
Notifications
You must be signed in to change notification settings - Fork 181
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: add sqlalchemy vectorizer_relationship (#265)
* feat: add sqlalchemy vectorizer field * chore: add installing extras to ci * chore: simplify interface, add simple docs * feat: allow arbitrary primary keys on parent model * docs: update docs with simplified vectorizer field * chore: rename VectorizerField to Vectorizer * chore: update alembic exclusion mechanism * docs: update docs with review comments * chore: align automatic table name with create_vectorizer * chore: add option for any relationship properties * chore: setup class event based rather than lazy so relationship works on first query * chore: update to embedding_relationship * chore: refactor tests add vcr mocks * chore: rename to vectorizer_model; cleanup * chore: fix uv lock * chore: remove dummy key
- Loading branch information
Showing
19 changed files
with
3,203 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,184 @@ | ||
# SQLAlchemy Integration with pgai Vectorizer | ||
|
||
The `vectorizer_relationship` is a SQLAlchemy helper that integrates pgai's vectorization capabilities directly into your SQLAlchemy models. | ||
Think of it as a normal SQLAlchemy [relationship](https://docs.sqlalchemy.org/en/20/orm/basic_relationships.html), but with a preconfigured model instance under the hood. | ||
This allows you to easily query vector embeddings created by pgai using familiar SQLAlchemy patterns. | ||
|
||
## Installation | ||
|
||
To use the SQLAlchemy integration, install pgai with the SQLAlchemy extras: | ||
|
||
```bash | ||
pip install "pgai[sqlalchemy]" | ||
``` | ||
|
||
## Basic Usage | ||
|
||
Here's a basic example of how to use the `vectorizer_relationship`: | ||
|
||
```python | ||
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column | ||
from pgai.sqlalchemy import vectorizer_relationship | ||
|
||
class Base(DeclarativeBase): | ||
pass | ||
|
||
class BlogPost(Base): | ||
__tablename__ = "blog_posts" | ||
|
||
id: Mapped[int] = mapped_column(primary_key=True) | ||
title: Mapped[str] | ||
content: Mapped[str] | ||
|
||
# Add vector embeddings for the content field | ||
content_embeddings = vectorizer_relationship( | ||
dimensions=768 | ||
) | ||
``` | ||
Note if you work with alembics autogenerate functionality for migrations, also check [Working with alembic](#working-with-alembic). | ||
|
||
### Semantic Search | ||
|
||
You can then perform semantic similarity search on the field using [pgvector-python's](https://github.com/pgvector/pgvector-python) distance functions: | ||
|
||
```python | ||
from sqlalchemy import func, text | ||
|
||
similar_posts = ( | ||
session.query(BlogPost.content_embeddings) | ||
.order_by( | ||
BlogPost.content_embeddings.embedding.cosine_distance( | ||
func.ai.openai_embed( | ||
"text-embedding-3-small", | ||
"search query", | ||
text("dimensions => 768") | ||
) | ||
) | ||
) | ||
.limit(5) | ||
.all() | ||
) | ||
``` | ||
|
||
Or if you already have the embeddings in your application: | ||
|
||
```python | ||
similar_posts = ( | ||
session.query(BlogPost.content_embeddings) | ||
.order_by( | ||
BlogPost.content_embeddings.embedding.cosine_distance( | ||
[3, 1, 2] | ||
) | ||
) | ||
.limit(5) | ||
.all() | ||
) | ||
``` | ||
|
||
## Configuration | ||
|
||
The `vectorizer_relationship` accepts the following parameters: | ||
|
||
- `dimensions` (int): The size of the embedding vector (required) | ||
- `target_schema` (str, optional): Override the schema for the embeddings table. If not provided, inherits from the parent model's schema | ||
- `target_table` (str, optional): Override the table name for embeddings. Default is `{table_name}_embedding_store` | ||
|
||
Additional parameters are simply forwarded to the underlying [SQLAlchemy relationship](https://docs.sqlalchemy.org/en/20/orm/relationships.html) so you can configure it as you desire. | ||
|
||
Think of the `vectorizer_relationship` as a normal SQLAlchemy relationship, but with a preconfigured model instance under the hood. | ||
|
||
|
||
## Setting up the Vectorizer | ||
|
||
After defining your model, you need to create the vectorizer using pgai's SQL functions: | ||
|
||
```sql | ||
SELECT ai.create_vectorizer( | ||
'blog_posts'::regclass, | ||
embedding => ai.embedding_openai('text-embedding-3-small', 768), | ||
chunking => ai.chunking_recursive_character_text_splitter( | ||
'content', | ||
50, -- chunk_size | ||
10 -- chunk_overlap | ||
) | ||
); | ||
``` | ||
|
||
We recommend adding this to a migration script and run it via alembic. | ||
|
||
|
||
## Querying Embeddings | ||
|
||
The `vectorizer_relationship` provides several ways to work with embeddings: | ||
|
||
### 1. Direct Access to Embeddings | ||
|
||
If you access the class proeprty of your model the `vectorizer_relationship` provide a SQLAlchemy model that you can query directly: | ||
|
||
```python | ||
# Get all embeddings | ||
embeddings = session.query(BlogPost.content_embeddings).all() | ||
|
||
# Access embedding properties | ||
for embedding in embeddings: | ||
print(embedding.embedding) # The vector embedding | ||
print(embedding.chunk) # The text chunk | ||
``` | ||
The model will have the primary key fields of the parent model as well as the following fields: | ||
- `chunk` (str): The text chunk that was embedded | ||
- `embedding` (Vector): The vector embedding | ||
- `chunk_seq` (int): The sequence number of the chunk | ||
- `embedding_uuid` (str): The UUID of the embedding | ||
- `parent` (ParentModel): The parent model instance | ||
|
||
### 2. Relationship Access | ||
|
||
|
||
```python | ||
blog_post = session.query(BlogPost).first() | ||
for embedding in blog_post.content_embeddings: | ||
print(embedding.chunk) | ||
``` | ||
Access the original posts through the parent relationship | ||
```python | ||
for embedding in similar_posts: | ||
print(embedding.parent.title) | ||
``` | ||
|
||
### 3. Join Queries | ||
|
||
You can combine embedding queries with regular SQL queries using the relationship: | ||
|
||
```python | ||
results = ( | ||
session.query(BlogPost, BlogPost.content_embeddings) | ||
.join(BlogPost.content_embeddings) | ||
.filter(BlogPost.title.ilike("%search term%")) | ||
.all() | ||
) | ||
|
||
for post, embedding in results: | ||
print(f"Title: {post.title}") | ||
print(f"Chunk: {embedding.chunk}") | ||
``` | ||
|
||
## Working with alembic | ||
|
||
|
||
The `vectorizer_relationship` generates a new SQLAlchemy model, that is available under the attribute that you specify. If you are using alembic's autogenerate functionality to generate migrations, you will need to exclude these models from the autogenerate process. | ||
These are added to a list in your metadata called `pgai_managed_tables` and you can exclude them by adding the following to your `env.py`: | ||
|
||
```python | ||
def include_object(object, name, type_, reflected, compare_to): | ||
if type_ == "table" and name in target_metadata.info.get("pgai_managed_tables", set()): | ||
return False | ||
return True | ||
|
||
context.configure( | ||
connection=connection, | ||
target_metadata=target_metadata, | ||
include_object=include_object | ||
) | ||
``` | ||
|
||
This should now prevent alembic from generating tables for these models when you run `alembic revision --autogenerate`. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,184 @@ | ||
from typing import Any, Generic, TypeVar, overload | ||
|
||
from pgvector.sqlalchemy import Vector # type: ignore | ||
from sqlalchemy import ForeignKeyConstraint, Integer, Text, event, inspect | ||
from sqlalchemy.orm import ( | ||
DeclarativeBase, | ||
Mapped, | ||
Mapper, | ||
Relationship, | ||
RelationshipProperty, | ||
backref, | ||
mapped_column, | ||
relationship, | ||
) | ||
|
||
# Type variable for the parent model | ||
T = TypeVar("T", bound=DeclarativeBase) | ||
|
||
|
||
def to_pascal_case(text: str): | ||
# Split on any non-alphanumeric character | ||
words = "".join(char if char.isalnum() else " " for char in text).split() | ||
# Capitalize first letter of all words | ||
return "".join(word.capitalize() for word in words) | ||
|
||
|
||
class EmbeddingModel(DeclarativeBase, Generic[T]): | ||
"""Base type for embedding models with required attributes""" | ||
|
||
embedding_uuid: Mapped[str] | ||
chunk: Mapped[str] | ||
embedding: Mapped[Vector] | ||
chunk_seq: Mapped[int] | ||
parent: T # Type of the parent model | ||
|
||
|
||
class _Vectorizer: | ||
def __init__( | ||
self, | ||
dimensions: int, | ||
target_schema: str | None = None, | ||
target_table: str | None = None, | ||
**kwargs: Any, | ||
): | ||
self.dimensions = dimensions | ||
self.target_schema = target_schema | ||
self.target_table = target_table | ||
self.owner: type[DeclarativeBase] | None = None | ||
self.name: str | None = None | ||
self._embedding_class: type[EmbeddingModel[Any]] | None = None | ||
self._relationship: RelationshipProperty[Any] | None = None | ||
self._initialized = False | ||
self.relationship_args = kwargs | ||
event.listen(Mapper, "after_configured", self._initialize_all) | ||
|
||
def _initialize_all(self): | ||
"""Force initialization during mapper configuration""" | ||
if not self._initialized and self.owner is not None: | ||
self.__get__(None, self.owner) | ||
|
||
def set_schemas_correctly(self, owner: type[DeclarativeBase]) -> None: | ||
table_args_schema_name = getattr(owner, "__table_args__", {}).get("schema") | ||
self.target_schema = ( | ||
self.target_schema | ||
or table_args_schema_name | ||
or owner.registry.metadata.schema | ||
or "public" | ||
) | ||
|
||
def create_embedding_class( | ||
self, owner: type[DeclarativeBase] | ||
) -> type[EmbeddingModel[Any]]: | ||
assert self.name is not None | ||
table_name = self.target_table or f"{owner.__tablename__}_embedding_store" | ||
self.set_schemas_correctly(owner) | ||
class_name = f"{to_pascal_case(self.name)}Embedding" | ||
registry_instance = owner.registry | ||
base: type[DeclarativeBase] = owner.__base__ # type: ignore | ||
|
||
# Check if table already exists in metadata | ||
# There is probably a better way to do this | ||
# than accessing the internal _class_registry | ||
# Not doing this ends up in a recursion because | ||
# creating the new class reconfigures tha parent mapper | ||
# again triggering the after_configured event | ||
key = f"{self.target_schema}.{table_name}" | ||
if key in owner.metadata.tables: | ||
# Find the mapped class in the registry | ||
for cls in owner.registry._class_registry.values(): # type: ignore | ||
if hasattr(cls, "__table__") and cls.__table__.fullname == key: # type: ignore | ||
return cls # type: ignore | ||
|
||
# Get primary key information from the fully initialized model | ||
mapper = inspect(owner) | ||
pk_cols = mapper.primary_key | ||
|
||
# Create the complete class dictionary | ||
class_dict: dict[str, Any] = { | ||
"__tablename__": table_name, | ||
"registry": registry_instance, | ||
# Add all standard columns | ||
"embedding_uuid": mapped_column(Text, primary_key=True), | ||
"chunk": mapped_column(Text, nullable=False), | ||
"embedding": mapped_column(Vector(self.dimensions), nullable=False), | ||
"chunk_seq": mapped_column(Integer, nullable=False), | ||
} | ||
|
||
# Add primary key columns to the dictionary | ||
for col in pk_cols: | ||
class_dict[col.name] = mapped_column(col.type, nullable=False) | ||
|
||
# Create the table args with foreign key constraint | ||
table_args_dict: dict[str, Any] = dict() | ||
if self.target_schema and self.target_schema != owner.registry.metadata.schema: | ||
table_args_dict["schema"] = self.target_schema | ||
|
||
# Create the composite foreign key constraint | ||
fk_constraint = ForeignKeyConstraint( | ||
[col.name for col in pk_cols], # Local columns | ||
[ | ||
f"{owner.__tablename__}.{col.name}" for col in pk_cols | ||
], # Referenced columns | ||
ondelete="CASCADE", | ||
) | ||
|
||
# Add table args to class dictionary | ||
class_dict["__table_args__"] = (fk_constraint, table_args_dict) | ||
|
||
# Create the class using type() | ||
Embedding = type(class_name, (base,), class_dict) | ||
|
||
return Embedding # type: ignore | ||
|
||
@overload | ||
def __get__( | ||
self, obj: None, objtype: type[DeclarativeBase] | ||
) -> type[EmbeddingModel[Any]]: ... | ||
|
||
@overload | ||
def __get__( | ||
self, obj: DeclarativeBase, objtype: type[DeclarativeBase] | None = None | ||
) -> Relationship[EmbeddingModel[Any]]: ... | ||
|
||
def __get__( | ||
self, obj: DeclarativeBase | None, objtype: type[DeclarativeBase] | None = None | ||
) -> Relationship[EmbeddingModel[Any]] | type[EmbeddingModel[Any]]: | ||
assert self.name is not None | ||
relationship_name = f"_{self.name}_relationship" | ||
if not self._initialized and objtype is not None: | ||
self._embedding_class = self.create_embedding_class(objtype) | ||
|
||
mapper = inspect(objtype) | ||
assert mapper is not None | ||
pk_cols = mapper.primary_key | ||
if not hasattr(objtype, relationship_name): | ||
self.relationship_instance = relationship( | ||
self._embedding_class, | ||
foreign_keys=[ | ||
getattr(self._embedding_class, col.name) for col in pk_cols | ||
], | ||
backref=self.relationship_args.pop( | ||
"backref", backref("parent", lazy="select") | ||
), | ||
**self.relationship_args, | ||
) | ||
setattr(objtype, f"{self.name}_model", self._embedding_class) | ||
setattr(objtype, relationship_name, self.relationship_instance) | ||
self._initialized = True | ||
if obj is None and self._initialized: | ||
return self._embedding_class # type: ignore | ||
|
||
return getattr(obj, relationship_name) | ||
|
||
def __set_name__(self, owner: type[DeclarativeBase], name: str): | ||
self.owner = owner | ||
self.name = name | ||
|
||
metadata = owner.registry.metadata | ||
if not hasattr(metadata, "info"): | ||
metadata.info = {} | ||
metadata.info.setdefault("pgai_managed_tables", set()).add(self.target_table) | ||
|
||
|
||
vectorizer_relationship = _Vectorizer |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.