diff --git a/projects/pgai/pgai/sqlalchemy/__init__.py b/projects/pgai/pgai/sqlalchemy/__init__.py index bc7e4949a..dafa1a156 100644 --- a/projects/pgai/pgai/sqlalchemy/__init__.py +++ b/projects/pgai/pgai/sqlalchemy/__init__.py @@ -1,105 +1,138 @@ from typing import Any, Generic, TypeVar from pgvector.sqlalchemy import Vector # type: ignore -from sqlalchemy import ForeignKey, Integer, Text +from sqlalchemy import ForeignKeyConstraint, Integer, Text, inspect from sqlalchemy.orm import DeclarativeBase, Mapped, backref, mapped_column, relationship # Type variable for the parent model T = TypeVar("T", bound=DeclarativeBase) - def to_pascal_case(text: str): # Split on any non-alphanumeric character words = "".join(char if char.isalnum() else " " for char in text).split() - # Capitalize first letter of all words return "".join(word.capitalize() for word in words) - class EmbeddingModel(DeclarativeBase, Generic[T]): """Base type for embedding models with required attributes""" - embedding_uuid: Mapped[str] - id: Mapped[int] chunk: Mapped[str] embedding: Mapped[Vector] chunk_seq: Mapped[int] parent: T # Type of the parent model - class VectorizerField: def __init__( - self, - dimensions: int, - target_schema: str | None = None, - target_table: str | None = None, - add_relationship: bool = False, + self, + dimensions: int, + target_schema: str | None = None, + target_table: str | None = None, + add_relationship: bool = False, ): self.add_relationship = add_relationship - - # Store table/view configuration self.dimensions = dimensions self.target_schema = target_schema self.target_table = target_table self.owner: type[DeclarativeBase] | None = None self.name: str | None = None + self._embedding_class: type[EmbeddingModel[Any]] | None = None + self._initialized = False + + def _relationship_property(self, obj: Any = None) -> Mapped[list[EmbeddingModel[Any]]]: + # Force initialization if not done yet + if not self._initialized: + _ = self.__get__(obj, self.owner) + # Return the actual relationship + return getattr(obj, f"_{self.name}_relation") - def set_schemas_correctly(self, owner: type[T]) -> None: + def set_schemas_correctly(self, owner: type[DeclarativeBase]) -> None: table_args_schema_name = getattr(owner, "__table_args__", {}).get("schema") self.target_schema = ( - self.target_schema - or table_args_schema_name - or owner.registry.metadata.schema - or "public" + self.target_schema + or table_args_schema_name + or owner.registry.metadata.schema + or "public" ) - def create_embedding_class( - self, owner: type[T], name: str - ) -> type[EmbeddingModel[T]]: - table_name = self.target_table or f"{owner.__tablename__}_{name}_store" + def create_embedding_class(self, owner: type[DeclarativeBase]) -> type[EmbeddingModel[Any]]: + assert self.name is not None + table_name = self.target_table or f"{owner.__tablename__}_{self.name}_store" self.set_schemas_correctly(owner) - class_name = f"{to_pascal_case(name)}Embedding" + class_name = f"{to_pascal_case(self.name)}Embedding" registry_instance = owner.registry base: type[DeclarativeBase] = owner.__base__ # type: ignore - class Embedding(base): - __tablename__ = table_name - __table_args__ = ( - {"info": {"pgai_managed": True}, "schema": self.target_schema} - if self.target_schema - and self.target_schema != owner.registry.metadata.schema - else {"info": {"pgai_managed": True}} + # Get primary key information from the fully initialized model + mapper = inspect(owner) + pk_cols = mapper.primary_key + + # Create the complete class dictionary + class_dict: dict[str,Any] = { + "__tablename__": table_name, + "registry": registry_instance, + # Add all standard columns + "embedding_uuid": mapped_column(Text, primary_key=True), + "chunk": mapped_column(Text, nullable=False), + "embedding": mapped_column(Vector(self.dimensions), nullable=False), + "chunk_seq": mapped_column(Integer, nullable=False), + } + + # Add primary key columns to the dictionary + for col in pk_cols: + class_dict[col.name] = mapped_column( + col.type, + nullable=False ) - registry = registry_instance - embedding_uuid = mapped_column(Text, primary_key=True) - id = mapped_column( - Integer, ForeignKey(f"{owner.__tablename__}.id", ondelete="CASCADE") - ) - chunk = mapped_column(Text, nullable=False) - embedding = mapped_column( - Vector(self.dimensions), nullable=False - ) - chunk_seq = mapped_column(Integer, nullable=False) + # Create the table args with foreign key constraint + table_args_dict: dict[str, Any] = {"info": {"pgai_managed": True}} + if self.target_schema and self.target_schema != owner.registry.metadata.schema: + table_args_dict["schema"] = self.target_schema + + # Create the composite foreign key constraint + fk_constraint = ForeignKeyConstraint( + [col.name for col in pk_cols], # Local columns + [f"{owner.__tablename__}.{col.name}" for col in pk_cols], # Referenced columns + ondelete="CASCADE" + ) + + # Add table args to class dictionary + class_dict["__table_args__"] = (fk_constraint, table_args_dict) + + # Create the class using type() + Embedding = type(class_name, (base,), class_dict) - Embedding.__name__ = class_name return Embedding # type: ignore def __get__( - self, obj: DeclarativeBase | None, objtype: type[DeclarativeBase] | None = None + self, obj: DeclarativeBase | None, objtype: type[DeclarativeBase] | None = None ) -> type[EmbeddingModel[Any]]: + if not self._initialized and objtype is not None: + self._embedding_class = self.create_embedding_class(objtype) + + # Set up relationship if requested + if self.add_relationship: + mapper = inspect(objtype) + pk_cols = mapper.primary_key + + relationship_instance = relationship( + self._embedding_class, + foreign_keys=[getattr(self._embedding_class, col.name) for col in pk_cols], + backref=backref("parent", lazy="select"), + ) + # Store actual relationship under a private name + setattr(objtype, f"_{self.name}_relation", relationship_instance) + + self._initialized = True + + if self._embedding_class is None: + raise RuntimeError("Embedding class not properly initialized") + return self._embedding_class def __set_name__(self, owner: type[DeclarativeBase], name: str): self.owner = owner self.name = name - self._embedding_class = self.create_embedding_class(owner, name) - - # Set up relationship if self.add_relationship: - relationship_instance = relationship( - self._embedding_class, - foreign_keys=[self._embedding_class.id], - backref=backref("parent", lazy="select"), - ) - setattr(owner, f"{name}_relation", relationship_instance) + # Add the property that ensures initialization + setattr(owner, f"{name}_relation", property(self._relationship_property)) \ No newline at end of file diff --git a/projects/pgai/tests/vectorizer/extensions/test_sqlalchemy_composite_primary.py b/projects/pgai/tests/vectorizer/extensions/test_sqlalchemy_composite_primary.py index e69de29bb..006599952 100644 --- a/projects/pgai/tests/vectorizer/extensions/test_sqlalchemy_composite_primary.py +++ b/projects/pgai/tests/vectorizer/extensions/test_sqlalchemy_composite_primary.py @@ -0,0 +1,137 @@ +import numpy as np +from click.testing import CliRunner +from sqlalchemy import Column, Engine, Text +from sqlalchemy.orm import DeclarativeBase, Mapped, Session +from sqlalchemy.sql import text +from testcontainers.postgres import PostgresContainer # type: ignore + +from pgai.cli import vectorizer_worker +from pgai.sqlalchemy import EmbeddingModel, VectorizerField + + +class Base(DeclarativeBase): + pass + + +class Author(Base): + __tablename__ = "authors" + first_name = Column(Text, primary_key=True) + last_name = Column(Text, primary_key=True) + bio = Column(Text, nullable=False) + bio_embeddings = VectorizerField( + dimensions=768, + add_relationship=True, + ) + + bio_embeddings_relation: Mapped[list[EmbeddingModel["Author"]]] + + +def run_vectorizer_worker(db_url: str, vectorizer_id: int) -> None: + CliRunner().invoke( + vectorizer_worker, + [ + "--db-url", + db_url, + "--once", + "--vectorizer-id", + str(vectorizer_id), + "--concurrency", + "1", + ], + catch_exceptions=False, + ) + + +def test_vectorizer_composite_key( + postgres_container: PostgresContainer, initialized_engine: Engine +): + """Test vectorizer with a composite primary key.""" + db_url = postgres_container.get_connection_url() + + # Create tables + metadata = Author.metadata + metadata.create_all(initialized_engine, tables=[metadata.sorted_tables[0]]) + + # Create vectorizer + with initialized_engine.connect() as conn: + conn.execute( + text(""" + SELECT ai.create_vectorizer( + 'authors'::regclass, + target_table => 'authors_bio_embeddings_store', + embedding => ai.embedding_openai('text-embedding-3-small', 768), + chunking => ai.chunking_recursive_character_text_splitter('bio', 50, 10) + ); + """) + ) + conn.commit() + + # Insert test data + with Session(initialized_engine) as session: + author = Author( + first_name="Jane", + last_name="Doe", + bio="Jane is an accomplished researcher in artificial intelligence and machine learning. She has published numerous papers on neural networks." + ) + session.add(author) + session.commit() + + # Run vectorizer worker + run_vectorizer_worker(db_url, 1) + + # Verify embeddings were created + with Session(initialized_engine) as session: + # Verify embedding class was created correctly + assert Author.bio_embeddings.__name__ == "BioEmbeddingsEmbedding" + + # Check embeddings exist and have correct properties + embedding = session.query(Author.bio_embeddings).first() + assert embedding is not None + assert isinstance(embedding.embedding, np.ndarray) + assert len(embedding.embedding) == 768 + assert embedding.chunk is not None + assert isinstance(embedding.chunk, str) + + # Check composite key fields were created + assert hasattr(embedding, "first_name") + assert hasattr(embedding, "last_name") + assert embedding.first_name == "Jane" # type: ignore + assert embedding.last_name == "Doe" # type: ignore + + # Verify relationship works + author = session.query(Author).first() + assert author is not None + assert hasattr(author, "bio_embeddings_relation") + assert author.bio_embeddings_relation is not None + assert len(author.bio_embeddings_relation) > 0 + assert author.bio_embeddings_relation[0].chunk in author.bio + + # Test that parent relationship works + embedding_entity = session.query(Author.bio_embeddings).first() + assert embedding_entity is not None + assert embedding_entity.chunk in author.bio + assert embedding_entity.parent is not None + assert embedding_entity.parent.first_name == "Jane" + assert embedding_entity.parent.last_name == "Doe" + + # Test semantic search with composite keys + from sqlalchemy import func + + # Search for content similar to "machine learning" + similar_embeddings = ( + session.query(Author.bio_embeddings) + .order_by( + Author.bio_embeddings.embedding.cosine_distance( + func.ai.openai_embed( + "text-embedding-3-small", + "machine learning", + text("dimensions => 768"), + ) + ) + ) + .all() + ) + + assert len(similar_embeddings) > 0 + # The bio should contain machine learning related content + assert "machine learning" in similar_embeddings[0].parent.bio \ No newline at end of file diff --git a/projects/pgai/tests/vectorizer/extensions/test_sqlalchemy_relationship.py b/projects/pgai/tests/vectorizer/extensions/test_sqlalchemy_relationship.py index 62fe0a42a..78f19c293 100644 --- a/projects/pgai/tests/vectorizer/extensions/test_sqlalchemy_relationship.py +++ b/projects/pgai/tests/vectorizer/extensions/test_sqlalchemy_relationship.py @@ -80,6 +80,9 @@ def test_vectorizer_embedding_creation( # Verify embeddings were created with Session(initialized_engine) as session: # Verify embedding class was created correctly + + blog_post = session.query(BlogPost).first() + assert blog_post.content_embeddings_relation is not None assert BlogPost.content_embeddings.__name__ == "ContentEmbeddingsEmbedding" # Check embeddings exist and have correct properties