feat: add sqlalchemy vectorizer_relationship (#265)
* feat: add sqlalchemy vectorizer field

* chore: add installing extras to ci

* chore: simplify interface, add simple docs

* feat: allow arbitrary primary keys on parent model

* docs: update docs with simplified vectorizer field

* chore: rename VectorizerField to Vectorizer

* chore: update alembic exclusion mechanism

* docs: update docs with review comments

* chore: align automatic table name with create_vectorizer

* chore: add option for any relationship properties

* chore: setup class event based rather than lazy so relationship works on first query

* chore: update to embedding_relationship

* chore: refactor tests add vcr mocks

* chore: rename to vectorizer_model; cleanup

* chore: fix uv lock

* chore: remove dummy key
Askir authored Dec 19, 2024
1 parent 39fcf93 commit 0230509
Showing 19 changed files with 3,203 additions and 2 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
@@ -113,7 +113,7 @@ jobs:

      - name: Install dependencies
        working-directory: ./projects/pgai
-        run: uv sync
+        run: uv sync --all-extras

      - name: Lint
        run: just pgai lint
184 changes: 184 additions & 0 deletions docs/python-integration.md
@@ -0,0 +1,184 @@
# SQLAlchemy Integration with pgai Vectorizer

The `vectorizer_relationship` is a SQLAlchemy helper that integrates pgai's vectorization capabilities directly into your SQLAlchemy models.
Think of it as a normal SQLAlchemy [relationship](https://docs.sqlalchemy.org/en/20/orm/basic_relationships.html), but with a preconfigured model instance under the hood.
This allows you to easily query vector embeddings created by pgai using familiar SQLAlchemy patterns.

## Installation

To use the SQLAlchemy integration, install pgai with the SQLAlchemy extras:

```bash
pip install "pgai[sqlalchemy]"
```

## Basic Usage

Here's a basic example of how to use the `vectorizer_relationship`:

```python
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
from pgai.sqlalchemy import vectorizer_relationship

class Base(DeclarativeBase):
    pass

class BlogPost(Base):
    __tablename__ = "blog_posts"

    id: Mapped[int] = mapped_column(primary_key=True)
    title: Mapped[str]
    content: Mapped[str]

    # Add vector embeddings for the content field
    content_embeddings = vectorizer_relationship(
        dimensions=768
    )
```
Note: if you use Alembic's autogenerate functionality for migrations, also check [Working with alembic](#working-with-alembic).

### Semantic Search

You can then perform semantic similarity search on the field using [pgvector-python's](https://github.com/pgvector/pgvector-python) distance functions:

```python
from sqlalchemy import func, text

similar_posts = (
    session.query(BlogPost.content_embeddings)
    .order_by(
        BlogPost.content_embeddings.embedding.cosine_distance(
            func.ai.openai_embed(
                "text-embedding-3-small",
                "search query",
                text("dimensions => 768")
            )
        )
    )
    .limit(5)
    .all()
)
```

Or if you already have the embeddings in your application:

```python
similar_posts = (
    session.query(BlogPost.content_embeddings)
    .order_by(
        BlogPost.content_embeddings.embedding.cosine_distance(
            [3, 1, 2]
        )
    )
    .limit(5)
    .all()
)
```

## Configuration

The `vectorizer_relationship` accepts the following parameters:

- `dimensions` (int): The size of the embedding vector (required)
- `target_schema` (str, optional): Override the schema for the embeddings table. If not provided, inherits from the parent model's schema
- `target_table` (str, optional): Override the table name for embeddings. Default is `{table_name}_embedding_store`

Additional parameters are simply forwarded to the underlying [SQLAlchemy relationship](https://docs.sqlalchemy.org/en/20/orm/relationships.html) so you can configure it as you desire.

Think of the `vectorizer_relationship` as a normal SQLAlchemy relationship, but with a preconfigured model instance under the hood.
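For example, here is a sketch that overrides the target table name and forwards a relationship option (the `Article` model and table name below are purely illustrative):

```python
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
from pgai.sqlalchemy import vectorizer_relationship

class Base(DeclarativeBase):
    pass

class Article(Base):
    __tablename__ = "articles"

    id: Mapped[int] = mapped_column(primary_key=True)
    body: Mapped[str]

    # dimensions and target_table are handled by vectorizer_relationship;
    # lazy is forwarded unchanged to the underlying sqlalchemy.orm.relationship.
    body_embeddings = vectorizer_relationship(
        dimensions=768,
        target_table="articles_body_embeddings",  # optional override
        lazy="selectin",
    )
```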


## Setting up the Vectorizer

After defining your model, you need to create the vectorizer using pgai's SQL functions:

```sql
SELECT ai.create_vectorizer(
    'blog_posts'::regclass,
    embedding => ai.embedding_openai('text-embedding-3-small', 768),
    chunking => ai.chunking_recursive_character_text_splitter(
        'content',
        50, -- chunk_size
        10 -- chunk_overlap
    )
);
```

We recommend adding this to a migration script and running it via Alembic.
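For example, a minimal Alembic migration wrapping this statement could look like the following sketch (the revision identifiers are placeholders):

```python
from alembic import op

# revision identifiers, used by Alembic (placeholders)
revision = "abc123"
down_revision = None
branch_labels = None
depends_on = None


def upgrade() -> None:
    op.execute("""
        SELECT ai.create_vectorizer(
            'blog_posts'::regclass,
            embedding => ai.embedding_openai('text-embedding-3-small', 768),
            chunking => ai.chunking_recursive_character_text_splitter(
                'content',
                50, -- chunk_size
                10 -- chunk_overlap
            )
        );
    """)


def downgrade() -> None:
    # Dropping the vectorizer is intentionally omitted here; consult the pgai
    # docs for the matching drop function if you need a reversible migration.
    pass
```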


## Querying Embeddings

The `vectorizer_relationship` provides several ways to work with embeddings:

### 1. Direct Access to Embeddings

If you access the class attribute of your model, the `vectorizer_relationship` provides a SQLAlchemy model that you can query directly:

```python
# Get all embeddings
embeddings = session.query(BlogPost.content_embeddings).all()

# Access embedding properties
for embedding in embeddings:
    print(embedding.embedding)  # The vector embedding
    print(embedding.chunk)  # The text chunk
```
The model has the primary key fields of the parent model as well as the following fields (a short access sketch follows the list):
- `chunk` (str): The text chunk that was embedded
- `embedding` (Vector): The vector embedding
- `chunk_seq` (int): The sequence number of the chunk
- `embedding_uuid` (str): The UUID of the embedding
- `parent` (ParentModel): The parent model instance
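
For example, a small sketch reading these fields (reusing the `BlogPost` model from above):

```python
first_embedding = session.query(BlogPost.content_embeddings).first()
if first_embedding is not None:
    print(first_embedding.chunk_seq)       # position of the chunk in the content
    print(first_embedding.embedding_uuid)  # primary key of the embedding row
    print(first_embedding.parent.title)    # the BlogPost this chunk belongs to
```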

### 2. Relationship Access

Access the embeddings through the relationship attribute on the parent model:
```python
blog_post = session.query(BlogPost).first()
for embedding in blog_post.content_embeddings:
    print(embedding.chunk)
```
You can access the original posts through the `parent` relationship:
```python
for embedding in similar_posts:
    print(embedding.parent.title)
```

### 3. Join Queries

You can combine embedding queries with regular SQL queries using the relationship:

```python
results = (
    session.query(BlogPost, BlogPost.content_embeddings)
    .join(BlogPost.content_embeddings)
    .filter(BlogPost.title.ilike("%search term%"))
    .all()
)

for post, embedding in results:
    print(f"Title: {post.title}")
    print(f"Chunk: {embedding.chunk}")
```

## Working with alembic


The `vectorizer_relationship` generates a new SQLAlchemy model that is available under the attribute you specify. If you are using Alembic's autogenerate functionality to generate migrations, you need to exclude these models from the autogenerate process.
Their table names are tracked in a set stored in your metadata's `info` dict under the key `pgai_managed_tables`, and you can exclude them by adding the following to your `env.py`:

```python
def include_object(object, name, type_, reflected, compare_to):
    if type_ == "table" and name in target_metadata.info.get("pgai_managed_tables", set()):
        return False
    return True

context.configure(
    connection=connection,
    target_metadata=target_metadata,
    include_object=include_object
)
```

This prevents Alembic from generating tables for these models when you run `alembic revision --autogenerate`.
184 changes: 184 additions & 0 deletions projects/pgai/pgai/sqlalchemy/__init__.py
@@ -0,0 +1,184 @@
from typing import Any, Generic, TypeVar, overload

from pgvector.sqlalchemy import Vector # type: ignore
from sqlalchemy import ForeignKeyConstraint, Integer, Text, event, inspect
from sqlalchemy.orm import (
    DeclarativeBase,
    Mapped,
    Mapper,
    Relationship,
    RelationshipProperty,
    backref,
    mapped_column,
    relationship,
)

# Type variable for the parent model
T = TypeVar("T", bound=DeclarativeBase)


def to_pascal_case(text: str):
    # Split on any non-alphanumeric character
    words = "".join(char if char.isalnum() else " " for char in text).split()
    # Capitalize first letter of all words
    return "".join(word.capitalize() for word in words)


class EmbeddingModel(DeclarativeBase, Generic[T]):
"""Base type for embedding models with required attributes"""

embedding_uuid: Mapped[str]
chunk: Mapped[str]
embedding: Mapped[Vector]
chunk_seq: Mapped[int]
parent: T # Type of the parent model


class _Vectorizer:
    def __init__(
        self,
        dimensions: int,
        target_schema: str | None = None,
        target_table: str | None = None,
        **kwargs: Any,
    ):
        self.dimensions = dimensions
        self.target_schema = target_schema
        self.target_table = target_table
        self.owner: type[DeclarativeBase] | None = None
        self.name: str | None = None
        self._embedding_class: type[EmbeddingModel[Any]] | None = None
        self._relationship: RelationshipProperty[Any] | None = None
        self._initialized = False
        self.relationship_args = kwargs
        event.listen(Mapper, "after_configured", self._initialize_all)

    def _initialize_all(self):
        """Force initialization during mapper configuration"""
        if not self._initialized and self.owner is not None:
            self.__get__(None, self.owner)

    def set_schemas_correctly(self, owner: type[DeclarativeBase]) -> None:
        table_args_schema_name = getattr(owner, "__table_args__", {}).get("schema")
        self.target_schema = (
            self.target_schema
            or table_args_schema_name
            or owner.registry.metadata.schema
            or "public"
        )

    def create_embedding_class(
        self, owner: type[DeclarativeBase]
    ) -> type[EmbeddingModel[Any]]:
        assert self.name is not None
        table_name = self.target_table or f"{owner.__tablename__}_embedding_store"
        self.set_schemas_correctly(owner)
        class_name = f"{to_pascal_case(self.name)}Embedding"
        registry_instance = owner.registry
        base: type[DeclarativeBase] = owner.__base__  # type: ignore

        # Check if table already exists in metadata
        # There is probably a better way to do this
        # than accessing the internal _class_registry
        # Not doing this ends up in a recursion because
        # creating the new class reconfigures the parent mapper
        # again triggering the after_configured event
        key = f"{self.target_schema}.{table_name}"
        if key in owner.metadata.tables:
            # Find the mapped class in the registry
            for cls in owner.registry._class_registry.values():  # type: ignore
                if hasattr(cls, "__table__") and cls.__table__.fullname == key:  # type: ignore
                    return cls  # type: ignore

        # Get primary key information from the fully initialized model
        mapper = inspect(owner)
        pk_cols = mapper.primary_key

        # Create the complete class dictionary
        class_dict: dict[str, Any] = {
            "__tablename__": table_name,
            "registry": registry_instance,
            # Add all standard columns
            "embedding_uuid": mapped_column(Text, primary_key=True),
            "chunk": mapped_column(Text, nullable=False),
            "embedding": mapped_column(Vector(self.dimensions), nullable=False),
            "chunk_seq": mapped_column(Integer, nullable=False),
        }

        # Add primary key columns to the dictionary
        for col in pk_cols:
            class_dict[col.name] = mapped_column(col.type, nullable=False)

        # Create the table args with foreign key constraint
        table_args_dict: dict[str, Any] = dict()
        if self.target_schema and self.target_schema != owner.registry.metadata.schema:
            table_args_dict["schema"] = self.target_schema

        # Create the composite foreign key constraint
        fk_constraint = ForeignKeyConstraint(
            [col.name for col in pk_cols],  # Local columns
            [
                f"{owner.__tablename__}.{col.name}" for col in pk_cols
            ],  # Referenced columns
            ondelete="CASCADE",
        )

        # Add table args to class dictionary
        class_dict["__table_args__"] = (fk_constraint, table_args_dict)

        # Create the class using type()
        Embedding = type(class_name, (base,), class_dict)

        return Embedding  # type: ignore

    @overload
    def __get__(
        self, obj: None, objtype: type[DeclarativeBase]
    ) -> type[EmbeddingModel[Any]]: ...

    @overload
    def __get__(
        self, obj: DeclarativeBase, objtype: type[DeclarativeBase] | None = None
    ) -> Relationship[EmbeddingModel[Any]]: ...

    def __get__(
        self, obj: DeclarativeBase | None, objtype: type[DeclarativeBase] | None = None
    ) -> Relationship[EmbeddingModel[Any]] | type[EmbeddingModel[Any]]:
        assert self.name is not None
        relationship_name = f"_{self.name}_relationship"
        if not self._initialized and objtype is not None:
            self._embedding_class = self.create_embedding_class(objtype)

            mapper = inspect(objtype)
            assert mapper is not None
            pk_cols = mapper.primary_key
            if not hasattr(objtype, relationship_name):
                self.relationship_instance = relationship(
                    self._embedding_class,
                    foreign_keys=[
                        getattr(self._embedding_class, col.name) for col in pk_cols
                    ],
                    backref=self.relationship_args.pop(
                        "backref", backref("parent", lazy="select")
                    ),
                    **self.relationship_args,
                )
                setattr(objtype, f"{self.name}_model", self._embedding_class)
                setattr(objtype, relationship_name, self.relationship_instance)
            self._initialized = True
        if obj is None and self._initialized:
            return self._embedding_class  # type: ignore

        return getattr(obj, relationship_name)

    def __set_name__(self, owner: type[DeclarativeBase], name: str):
        self.owner = owner
        self.name = name

        metadata = owner.registry.metadata
        if not hasattr(metadata, "info"):
            metadata.info = {}
        metadata.info.setdefault("pgai_managed_tables", set()).add(self.target_table)


vectorizer_relationship = _Vectorizer
6 changes: 6 additions & 0 deletions projects/pgai/pyproject.toml
Expand Up @@ -43,6 +43,11 @@ classifiers = [
"Operating System :: POSIX :: Linux",
]

[project.optional-dependencies]
sqlalchemy=[
"sqlalchemy>=2.0.36",
]

[project.urls]
Homepage = "https://github.com/timescale/pgai"
Repository = "https://github.com/timescale/pgai"
@@ -110,4 +115,5 @@ dev-dependencies = [
    "testcontainers==4.8.1",
    "build==1.2.2.post1",
    "twine==5.1.1",
    "psycopg2==2.9.10",
]