-
Notifications
You must be signed in to change notification settings - Fork 185
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: allow arbitrary primary keys on parent model
- Loading branch information
Showing
3 changed files
with
225 additions
and
52 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,105 +1,138 @@ | ||
from typing import Any, Generic, TypeVar | ||
|
||
from pgvector.sqlalchemy import Vector # type: ignore | ||
from sqlalchemy import ForeignKey, Integer, Text | ||
from sqlalchemy import ForeignKeyConstraint, Integer, Text, inspect | ||
from sqlalchemy.orm import DeclarativeBase, Mapped, backref, mapped_column, relationship | ||
|
||
# Type variable for the parent model | ||
T = TypeVar("T", bound=DeclarativeBase) | ||
|
||
|
||
def to_pascal_case(text: str): | ||
# Split on any non-alphanumeric character | ||
words = "".join(char if char.isalnum() else " " for char in text).split() | ||
|
||
# Capitalize first letter of all words | ||
return "".join(word.capitalize() for word in words) | ||
|
||
|
||
class EmbeddingModel(DeclarativeBase, Generic[T]): | ||
"""Base type for embedding models with required attributes""" | ||
|
||
embedding_uuid: Mapped[str] | ||
id: Mapped[int] | ||
chunk: Mapped[str] | ||
embedding: Mapped[Vector] | ||
chunk_seq: Mapped[int] | ||
parent: T # Type of the parent model | ||
|
||
|
||
class VectorizerField: | ||
def __init__( | ||
self, | ||
dimensions: int, | ||
target_schema: str | None = None, | ||
target_table: str | None = None, | ||
add_relationship: bool = False, | ||
self, | ||
dimensions: int, | ||
target_schema: str | None = None, | ||
target_table: str | None = None, | ||
add_relationship: bool = False, | ||
): | ||
self.add_relationship = add_relationship | ||
|
||
# Store table/view configuration | ||
self.dimensions = dimensions | ||
self.target_schema = target_schema | ||
self.target_table = target_table | ||
self.owner: type[DeclarativeBase] | None = None | ||
self.name: str | None = None | ||
self._embedding_class: type[EmbeddingModel[Any]] | None = None | ||
self._initialized = False | ||
|
||
def _relationship_property(self, obj: Any = None) -> Mapped[list[EmbeddingModel[Any]]]: | ||
# Force initialization if not done yet | ||
if not self._initialized: | ||
_ = self.__get__(obj, self.owner) | ||
# Return the actual relationship | ||
return getattr(obj, f"_{self.name}_relation") | ||
|
||
def set_schemas_correctly(self, owner: type[T]) -> None: | ||
def set_schemas_correctly(self, owner: type[DeclarativeBase]) -> None: | ||
table_args_schema_name = getattr(owner, "__table_args__", {}).get("schema") | ||
self.target_schema = ( | ||
self.target_schema | ||
or table_args_schema_name | ||
or owner.registry.metadata.schema | ||
or "public" | ||
self.target_schema | ||
or table_args_schema_name | ||
or owner.registry.metadata.schema | ||
or "public" | ||
) | ||
|
||
def create_embedding_class( | ||
self, owner: type[T], name: str | ||
) -> type[EmbeddingModel[T]]: | ||
table_name = self.target_table or f"{owner.__tablename__}_{name}_store" | ||
def create_embedding_class(self, owner: type[DeclarativeBase]) -> type[EmbeddingModel[Any]]: | ||
assert self.name is not None | ||
table_name = self.target_table or f"{owner.__tablename__}_{self.name}_store" | ||
self.set_schemas_correctly(owner) | ||
class_name = f"{to_pascal_case(name)}Embedding" | ||
class_name = f"{to_pascal_case(self.name)}Embedding" | ||
registry_instance = owner.registry | ||
base: type[DeclarativeBase] = owner.__base__ # type: ignore | ||
|
||
class Embedding(base): | ||
__tablename__ = table_name | ||
__table_args__ = ( | ||
{"info": {"pgai_managed": True}, "schema": self.target_schema} | ||
if self.target_schema | ||
and self.target_schema != owner.registry.metadata.schema | ||
else {"info": {"pgai_managed": True}} | ||
# Get primary key information from the fully initialized model | ||
mapper = inspect(owner) | ||
pk_cols = mapper.primary_key | ||
|
||
# Create the complete class dictionary | ||
class_dict: dict[str,Any] = { | ||
"__tablename__": table_name, | ||
"registry": registry_instance, | ||
# Add all standard columns | ||
"embedding_uuid": mapped_column(Text, primary_key=True), | ||
"chunk": mapped_column(Text, nullable=False), | ||
"embedding": mapped_column(Vector(self.dimensions), nullable=False), | ||
"chunk_seq": mapped_column(Integer, nullable=False), | ||
} | ||
|
||
# Add primary key columns to the dictionary | ||
for col in pk_cols: | ||
class_dict[col.name] = mapped_column( | ||
col.type, | ||
nullable=False | ||
) | ||
registry = registry_instance | ||
|
||
embedding_uuid = mapped_column(Text, primary_key=True) | ||
id = mapped_column( | ||
Integer, ForeignKey(f"{owner.__tablename__}.id", ondelete="CASCADE") | ||
) | ||
chunk = mapped_column(Text, nullable=False) | ||
embedding = mapped_column( | ||
Vector(self.dimensions), nullable=False | ||
) | ||
chunk_seq = mapped_column(Integer, nullable=False) | ||
# Create the table args with foreign key constraint | ||
table_args_dict: dict[str, Any] = {"info": {"pgai_managed": True}} | ||
if self.target_schema and self.target_schema != owner.registry.metadata.schema: | ||
table_args_dict["schema"] = self.target_schema | ||
|
||
# Create the composite foreign key constraint | ||
fk_constraint = ForeignKeyConstraint( | ||
[col.name for col in pk_cols], # Local columns | ||
[f"{owner.__tablename__}.{col.name}" for col in pk_cols], # Referenced columns | ||
ondelete="CASCADE" | ||
) | ||
|
||
# Add table args to class dictionary | ||
class_dict["__table_args__"] = (fk_constraint, table_args_dict) | ||
|
||
# Create the class using type() | ||
Embedding = type(class_name, (base,), class_dict) | ||
|
||
Embedding.__name__ = class_name | ||
return Embedding # type: ignore | ||
|
||
def __get__( | ||
self, obj: DeclarativeBase | None, objtype: type[DeclarativeBase] | None = None | ||
self, obj: DeclarativeBase | None, objtype: type[DeclarativeBase] | None = None | ||
) -> type[EmbeddingModel[Any]]: | ||
if not self._initialized and objtype is not None: | ||
self._embedding_class = self.create_embedding_class(objtype) | ||
|
||
# Set up relationship if requested | ||
if self.add_relationship: | ||
mapper = inspect(objtype) | ||
pk_cols = mapper.primary_key | ||
|
||
relationship_instance = relationship( | ||
self._embedding_class, | ||
foreign_keys=[getattr(self._embedding_class, col.name) for col in pk_cols], | ||
backref=backref("parent", lazy="select"), | ||
) | ||
# Store actual relationship under a private name | ||
setattr(objtype, f"_{self.name}_relation", relationship_instance) | ||
|
||
self._initialized = True | ||
|
||
if self._embedding_class is None: | ||
raise RuntimeError("Embedding class not properly initialized") | ||
|
||
return self._embedding_class | ||
|
||
def __set_name__(self, owner: type[DeclarativeBase], name: str): | ||
self.owner = owner | ||
self.name = name | ||
self._embedding_class = self.create_embedding_class(owner, name) | ||
|
||
# Set up relationship | ||
if self.add_relationship: | ||
relationship_instance = relationship( | ||
self._embedding_class, | ||
foreign_keys=[self._embedding_class.id], | ||
backref=backref("parent", lazy="select"), | ||
) | ||
setattr(owner, f"{name}_relation", relationship_instance) | ||
# Add the property that ensures initialization | ||
setattr(owner, f"{name}_relation", property(self._relationship_property)) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,137 @@ | ||
import numpy as np | ||
from click.testing import CliRunner | ||
from sqlalchemy import Column, Engine, Text | ||
from sqlalchemy.orm import DeclarativeBase, Mapped, Session | ||
from sqlalchemy.sql import text | ||
from testcontainers.postgres import PostgresContainer # type: ignore | ||
|
||
from pgai.cli import vectorizer_worker | ||
from pgai.sqlalchemy import EmbeddingModel, VectorizerField | ||
|
||
|
||
class Base(DeclarativeBase): | ||
pass | ||
|
||
|
||
class Author(Base): | ||
__tablename__ = "authors" | ||
first_name = Column(Text, primary_key=True) | ||
last_name = Column(Text, primary_key=True) | ||
bio = Column(Text, nullable=False) | ||
bio_embeddings = VectorizerField( | ||
dimensions=768, | ||
add_relationship=True, | ||
) | ||
|
||
bio_embeddings_relation: Mapped[list[EmbeddingModel["Author"]]] | ||
|
||
|
||
def run_vectorizer_worker(db_url: str, vectorizer_id: int) -> None: | ||
CliRunner().invoke( | ||
vectorizer_worker, | ||
[ | ||
"--db-url", | ||
db_url, | ||
"--once", | ||
"--vectorizer-id", | ||
str(vectorizer_id), | ||
"--concurrency", | ||
"1", | ||
], | ||
catch_exceptions=False, | ||
) | ||
|
||
|
||
def test_vectorizer_composite_key( | ||
postgres_container: PostgresContainer, initialized_engine: Engine | ||
): | ||
"""Test vectorizer with a composite primary key.""" | ||
db_url = postgres_container.get_connection_url() | ||
|
||
# Create tables | ||
metadata = Author.metadata | ||
metadata.create_all(initialized_engine, tables=[metadata.sorted_tables[0]]) | ||
|
||
# Create vectorizer | ||
with initialized_engine.connect() as conn: | ||
conn.execute( | ||
text(""" | ||
SELECT ai.create_vectorizer( | ||
'authors'::regclass, | ||
target_table => 'authors_bio_embeddings_store', | ||
embedding => ai.embedding_openai('text-embedding-3-small', 768), | ||
chunking => ai.chunking_recursive_character_text_splitter('bio', 50, 10) | ||
); | ||
""") | ||
) | ||
conn.commit() | ||
|
||
# Insert test data | ||
with Session(initialized_engine) as session: | ||
author = Author( | ||
first_name="Jane", | ||
last_name="Doe", | ||
bio="Jane is an accomplished researcher in artificial intelligence and machine learning. She has published numerous papers on neural networks." | ||
) | ||
session.add(author) | ||
session.commit() | ||
|
||
# Run vectorizer worker | ||
run_vectorizer_worker(db_url, 1) | ||
|
||
# Verify embeddings were created | ||
with Session(initialized_engine) as session: | ||
# Verify embedding class was created correctly | ||
assert Author.bio_embeddings.__name__ == "BioEmbeddingsEmbedding" | ||
|
||
# Check embeddings exist and have correct properties | ||
embedding = session.query(Author.bio_embeddings).first() | ||
assert embedding is not None | ||
assert isinstance(embedding.embedding, np.ndarray) | ||
assert len(embedding.embedding) == 768 | ||
assert embedding.chunk is not None | ||
assert isinstance(embedding.chunk, str) | ||
|
||
# Check composite key fields were created | ||
assert hasattr(embedding, "first_name") | ||
assert hasattr(embedding, "last_name") | ||
assert embedding.first_name == "Jane" # type: ignore | ||
assert embedding.last_name == "Doe" # type: ignore | ||
|
||
# Verify relationship works | ||
author = session.query(Author).first() | ||
assert author is not None | ||
assert hasattr(author, "bio_embeddings_relation") | ||
assert author.bio_embeddings_relation is not None | ||
assert len(author.bio_embeddings_relation) > 0 | ||
assert author.bio_embeddings_relation[0].chunk in author.bio | ||
|
||
# Test that parent relationship works | ||
embedding_entity = session.query(Author.bio_embeddings).first() | ||
assert embedding_entity is not None | ||
assert embedding_entity.chunk in author.bio | ||
assert embedding_entity.parent is not None | ||
assert embedding_entity.parent.first_name == "Jane" | ||
assert embedding_entity.parent.last_name == "Doe" | ||
|
||
# Test semantic search with composite keys | ||
from sqlalchemy import func | ||
|
||
# Search for content similar to "machine learning" | ||
similar_embeddings = ( | ||
session.query(Author.bio_embeddings) | ||
.order_by( | ||
Author.bio_embeddings.embedding.cosine_distance( | ||
func.ai.openai_embed( | ||
"text-embedding-3-small", | ||
"machine learning", | ||
text("dimensions => 768"), | ||
) | ||
) | ||
) | ||
.all() | ||
) | ||
|
||
assert len(similar_embeddings) > 0 | ||
# The bio should contain machine learning related content | ||
assert "machine learning" in similar_embeddings[0].parent.bio |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters