Skip to content

Commit

Permalink
fix: vectorizer_relationship for sqlalchemy models with mixins or inh…
Browse files Browse the repository at this point in the history
…eritance (#357)
  • Loading branch information
Askir authored Jan 13, 2025
1 parent 95fa797 commit cfd5f73
Show file tree
Hide file tree
Showing 2 changed files with 100 additions and 1 deletion.
19 changes: 18 additions & 1 deletion projects/pgai/pgai/sqlalchemy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,23 @@ class EmbeddingModel(DeclarativeBase, Generic[T]):
parent: T # Type of the parent model


def find_declarative_base(cls: type) -> type:
"""Find the SQLAlchemy declarative base class in the inheritance hierarchy."""
for base in cls.__mro__:
if (
(
hasattr(base, "_sa_registry") # Modern style
or hasattr(base, "__mapper__") # Mapped class
or hasattr(base, "metadata") # Legacy style
)
and
# Ensure it's the highest level base
not any(hasattr(parent, "_sa_registry") for parent in base.__bases__)
):
return base
raise ValueError("No SQLAlchemy declarative base found in class hierarchy")


class _Vectorizer:
def __init__(
self,
Expand Down Expand Up @@ -74,7 +91,7 @@ def create_embedding_class(
self.set_schemas_correctly(owner)
class_name = f"{owner.__name__}{to_pascal_case(self.name)}"
registry_instance = owner.registry
base: type[DeclarativeBase] = owner.__base__ # type: ignore
base = find_declarative_base(owner)

# Check if table already exists in metadata
# There is probably a better way to do this
Expand Down
82 changes: 82 additions & 0 deletions projects/pgai/tests/vectorizer/extensions/test_inheritance.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
from datetime import datetime
from typing import Any

import numpy as np
from sqlalchemy import Column, Engine, Integer, func, text
from sqlalchemy import Text as sa_Text
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column
from testcontainers.postgres import PostgresContainer # type: ignore

from pgai.sqlalchemy import vectorizer_relationship
from tests.vectorizer.extensions.utils import run_vectorizer_worker


class BaseModel(DeclarativeBase):
pass


class TimeStampedBase(BaseModel):
created_at: Mapped[datetime] = mapped_column(server_default=func.now())
updated_at: Mapped[datetime] = mapped_column(
server_default=func.now(), onupdate=func.now()
)
__abstract__ = True


class BlogPost(TimeStampedBase):
__tablename__ = "blog_posts"
id = Column(Integer, primary_key=True)
title = Column(sa_Text, nullable=False)
content = Column(sa_Text, nullable=False)
content_embeddings = vectorizer_relationship(dimensions=768, lazy="joined")


def test_vectorizer_embedding_creation(
postgres_container: PostgresContainer, initialized_engine: Engine, vcr_: Any
):
"""Test basic embedding creation and querying while the Model inherits from
another abstract model. This previously caused issues where the embedding model
inherited the fields as well which should not be the case."""
db_url = postgres_container.get_connection_url()
metadata = BlogPost.metadata
metadata.create_all(initialized_engine, tables=[metadata.sorted_tables[0]])
with initialized_engine.connect() as conn:
conn.execute(
text("""
SELECT ai.create_vectorizer(
'blog_posts'::regclass,
embedding =>
ai.embedding_openai('text-embedding-3-small', 768),
chunking =>
ai.chunking_recursive_character_text_splitter('content', 50, 10)
);
""")
)
conn.commit()

# Insert test data
with Session(initialized_engine) as session:
post = BlogPost(
title="Introduction to Machine Learning",
content="Machine learning is a subset of artificial intelligence that enables systems to learn and improve from experience.", # noqa
)
session.add(post)
session.commit()

# Run vectorizer worker
with vcr_.use_cassette("test_vectorizer_embedding_creation_relationship.yaml"):
run_vectorizer_worker(db_url, 1)

with Session(initialized_engine) as session:
blog_post = session.query(BlogPost).first()
assert blog_post is not None
assert blog_post.content_embeddings is not None
assert BlogPost.content_embeddings.__name__ == "BlogPostContentEmbeddings"

# Check embeddings exist and have correct properties
embedding = session.query(BlogPost.content_embeddings).first()
assert embedding is not None
assert isinstance(embedding.embedding, np.ndarray)
assert len(embedding.embedding) == 768
assert embedding.chunk is not None
assert isinstance(embedding.chunk, str)

0 comments on commit cfd5f73

Please sign in to comment.