feat: add alembic operations for vectorizer #266

Open · wants to merge 10 commits into `main`
docs/python-integration.md (47 additions & 1 deletion)

@@ -164,7 +164,7 @@ for post, embedding in results:

> **Collaborator:** I think we also need to add docs to `adding-embedding-integration.md`.

## Working with alembic


### Excluding managed tables
The `vectorizer_relationship` generates a new SQLAlchemy model that is available under the attribute you specify. If you are using Alembic's autogenerate functionality to create migrations, you will need to exclude these models from the autogenerate process.
They are added to a list in your metadata called `pgai_managed_tables`, and you can exclude them by adding the following to your `env.py`:

@@ -182,3 +182,49 @@ context.configure(
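The body of this snippet is collapsed in the diff view. As a minimal sketch, the exclusion hook typically looks like the following (the `include_object` callback is standard Alembic; reading the list from `metadata.info["pgai_managed_tables"]` is an assumption based on the prose above):

```python
def include_object(object, name, type_, reflected, compare_to):
    # Skip tables that pgai manages for its vectorizers (assumed to be
    # registered under metadata.info["pgai_managed_tables"]).
    if type_ == "table" and name in target_metadata.info.get(
        "pgai_managed_tables", set()
    ):
        return False
    return True


context.configure(
    connection=connection,
    target_metadata=target_metadata,
    include_object=include_object,
)
```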

This should now prevent alembic from generating tables for these models when you run `alembic revision --autogenerate`.


### Creating vectorizers
pgai provides native Alembic operations for managing vectorizers. For them to work, call `register_operations` in your `env.py` file; this registers the pgai operations under the global `op` context:

```python
from pgai.alembic import register_operations

register_operations()
```

Then you can use the `create_vectorizer` operation to create a vectorizer for your model, and the `drop_vectorizer` operation to remove it.

```python
from alembic import op
from pgai.alembic.configuration import (
OpenAIConfig,
CharacterTextSplitterConfig,
PythonTemplateConfig
)


def upgrade() -> None:
    op.create_vectorizer(
        source_table="blog",
        target_table="blog_embeddings",
        embedding=OpenAIConfig(
            model="text-embedding-3-small",
            dimensions=768,
        ),
        chunking=CharacterTextSplitterConfig(
            chunk_column="content",
            chunk_size=800,
            chunk_overlap=400,
            separator=".",
            is_separator_regex=False,
        ),
        formatting=PythonTemplateConfig(template="$title - $chunk"),
    )


def downgrade() -> None:
    op.drop_vectorizer(vectorizer_id=1, drop_all=True)
```

> **Collaborator:** I think it would be better for this to take the target_table name and not the vectorizer_id (which would probably not be known when writing the migration). The target table should be unique, so we should be able to look up the id from that.
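For illustration only: the lookup the reviewer has in mind could be a single catalog query (assuming pgai's `ai.vectorizer` catalog table records the target table, which this diff does not show):

```sql
-- Hypothetical lookup: resolve the vectorizer id from its unique target table
SELECT id FROM ai.vectorizer WHERE target_table = 'blog_embeddings';
```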

The `create_vectorizer` operation supports all configuration options available in the [SQL API](vectorizer-api-reference.md).
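For example, indexing and scheduling can be configured through the classes this PR adds; a sketch reusing the imports from the migration above (parameter values are illustrative):

```python
from datetime import timedelta

from pgai.alembic.configuration import (
    DiskANNIndexingConfig,
    TimescaleSchedulingConfig,
)


def upgrade() -> None:
    op.create_vectorizer(
        source_table="blog",
        embedding=OpenAIConfig(model="text-embedding-3-small", dimensions=768),
        chunking=CharacterTextSplitterConfig(chunk_column="content", chunk_size=800),
        indexing=DiskANNIndexingConfig(min_rows=100000),
        scheduling=TimescaleSchedulingConfig(schedule_interval=timedelta(hours=1)),
    )
```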
projects/extension/Dockerfile (4 additions & 0 deletions)

@@ -52,6 +52,10 @@ WORKDIR /pgai

```dockerfile
COPY . .
RUN just build install

RUN mkdir -p /docker-entrypoint-initdb.d && \
echo "#!/bin/bash" > /docker-entrypoint-initdb.d/configure-timescaledb.sh && \
echo "echo \"shared_preload_libraries = 'timescaledb'\" >> \${PGDATA}/postgresql.conf" >> /docker-entrypoint-initdb.d/configure-timescaledb.sh && \
chmod +x /docker-entrypoint-initdb.d/configure-timescaledb.sh
```

> **Contributor (author), on lines +55 to +58:** I had to add this to be able to run `create extension if not exists timescaledb`. I'm not sure whether this is correct?

> **Collaborator:** Why did you need timescaledb for this PR? This is a dev image, so this is fine; I'm just curious.

> **Contributor (author):** Because I am actually running the migrations in tests, and creating a vectorizer with a scheduling config is not allowed if timescaledb is not installed.
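For context, the statement from the thread that the preload enables is simply (timescaledb must be present in `shared_preload_libraries` before the extension can be created):

```sql
CREATE EXTENSION IF NOT EXISTS timescaledb;
```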


```dockerfile
###############################################################################
# image for use in extension development
```
projects/pgai/pgai/alembic/__init__.py (7 additions & 0 deletions)

@@ -0,0 +1,7 @@

```python
from pgai.alembic.operations import (
CreateVectorizerOp,
DropVectorizerOp,
register_operations,
)

__all__ = ["CreateVectorizerOp", "DropVectorizerOp", "register_operations"]
```
projects/pgai/pgai/alembic/configuration.py (227 additions & 0 deletions)

@@ -0,0 +1,227 @@
> **Collaborator:** To my eye this is a huge improvement from before.

```python
import re

from dataclasses import dataclass, fields
from datetime import timedelta
from typing import ClassVar, Literal, Protocol, runtime_checkable

from pydantic import BaseModel

from pgai.vectorizer.base import (
BaseOllamaConfig,
BaseOpenAIConfig,
BaseProcessing,
BasePythonTemplate,
BaseVoyageAIConfig,
ChunkingCharacterTextSplitter,
ChunkingRecursiveCharacterTextSplitter,
)


@runtime_checkable
class SQLArgumentProvider(Protocol):
def to_sql_argument(self) -> str: ...


@runtime_checkable
class PythonArgProvider(Protocol):
def to_python_arg(self) -> str: ...


def format_sql_params(
    params: dict[str, str | int | float | bool | timedelta | list[str] | None],
) -> str:
    """Format a dict of parameters into a SQL argument string.

    Values are interpolated inline (quoted where needed) rather than bound
    as query parameters.
    """
formatted: list[str] = []
for key, value in params.items():
if value is None:
continue
elif isinstance(value, bool):
formatted.append(f"{key}=>{str(value).lower()}")
elif isinstance(value, list):
array_list = ",".join(f"E'{v}'" for v in value)
formatted.append(f"{key}=>ARRAY[{array_list}]")
        elif isinstance(value, timedelta):
            # Use total_seconds(); .seconds alone would drop the days component.
            formatted.append(f"{key}=>'{int(value.total_seconds())} seconds'")
else:
formatted.append(f"{key}=> '{value}'")
return ", ".join(formatted)


class SQLArgumentMixin:
arg_type: ClassVar[str]
function_name: ClassVar[str | None] = None

def to_sql_argument(self) -> str:
        # Collect all pydantic fields, including inherited ones. ClassVars such
        # as arg_type and function_name are not pydantic fields, so they do not
        # appear in model_fields.
params = {}
for field_name, _field in self.model_fields.items(): # type: ignore
if field_name != "arg_type":
```

> **Collaborator:** Is the `function_name` field included then? How does that work?

```python

value = getattr(self, field_name) # type: ignore
if value is not None:
params[field_name] = value

if self.function_name:
fn_name = self.function_name
else:
base_name = self.__class__.__name__
# Convert camelCase to snake_case
base_name = re.sub("([a-z0-9])([A-Z])", r"\1_\2", base_name).lower()
# Remove 'config' and clean up any double underscores
base_name = base_name.replace("config", "").strip("_")
fn_name = f"{self.arg_type}_{base_name}"

return f", {self.arg_type} => ai.{fn_name}({format_sql_params(params)})" # type: ignore


```

> **Collaborator:** Naming (and I know naming discussions are always annoying), but why not stick to the SQL convention we established and name this `EmbeddingOpenAIConfig` or `EmbeddingConfigOpenAI` (and similar for the others)? The pro is that the name translation from SQL to Python is super easy, and I think it would be easier to understand. The con is that it's long.
>
> Otherwise the translation seems a bit ad hoc, e.g. indexing configs have "indexing" in the name but in a different spot than the SQL. Let's think about this some more.

```python
class OpenAIConfig(BaseOpenAIConfig, SQLArgumentMixin):
arg_type: ClassVar[str] = "embedding"
function_name: ClassVar[str] = "embedding_openai" # type: ignore
chat_user: str | None = None
api_key_name: str | None = None


class VoyageAIConfig(BaseVoyageAIConfig, SQLArgumentMixin):
arg_type: ClassVar[str] = "embedding"
function_name: ClassVar[str] = "embedding_voyageai" # type: ignore
api_key_name: str | None = None


class OllamaConfig(BaseOllamaConfig, SQLArgumentMixin):
arg_type: ClassVar[str] = "embedding"


class CharacterTextSplitterConfig(ChunkingCharacterTextSplitter, SQLArgumentMixin):
arg_type: ClassVar[str] = "chunking"


class RecursiveCharacterTextSplitterConfig(
ChunkingRecursiveCharacterTextSplitter, SQLArgumentMixin
):
arg_type: ClassVar[str] = "chunking"


class ChunkValueConfig:
arg_type: ClassVar[str] = "formatting"

def to_sql_argument(self) -> str:
return f", {self.arg_type: ClassVar[str]} => ai.formatting_chunk_value()"


class PythonTemplateConfig(BasePythonTemplate, SQLArgumentMixin):
arg_type: ClassVar[str] = "formatting"


class NoIndexingConfig:
arg_type: ClassVar[str] = "indexing"

def to_sql_argument(self) -> str:
return f", {self.arg_type: ClassVar[str]} => ai.indexing_none()"


class DiskANNIndexingConfig(BaseModel, SQLArgumentMixin):
arg_type: ClassVar[str] = "indexing"
function_name: ClassVar[str] = "indexing_diskann" # type: ignore
min_rows: int
storage_layout: Literal["memory_optimized", "plain"] | None = None
num_neighbors: int | None = None
search_list_size: int | None = None
max_alpha: float | None = None
num_dimensions: int | None = None
num_bits_per_dimension: int | None = None
create_when_queue_empty: bool | None = None


class HNSWIndexingConfig(BaseModel, SQLArgumentMixin):
arg_type: ClassVar[str] = "indexing"
function_name: ClassVar[str] = "indexing_hnsw" # type: ignore
min_rows: int | None = None
opclass: Literal["vector_cosine_ops", "vector_l1_ops", "vector_ip_ops"] | None = (
None
)
m: int | None = None
ef_construction: int | None = None
create_when_queue_empty: bool | None = None


class NoSchedulingConfig:
arg_type: ClassVar[str] = "scheduling"

def to_sql_argument(self) -> str:
return f", {self.arg_type: ClassVar[str]} => ai.scheduling_none()"


class TimescaleSchedulingConfig(BaseModel, SQLArgumentMixin):
arg_type: ClassVar[str] = "scheduling"
function_name: ClassVar[str] = "scheduling_timescaledb" # type: ignore

schedule_interval: timedelta | None = None
initial_start: str | None = None
job_id: int | None = None
fixed_schedule: bool | None = None
timezone: str | None = None


class ProcessingConfig(BaseProcessing, SQLArgumentMixin):
arg_type: ClassVar[str] = "processing"
function_name: ClassVar[str] = "processing_default" # type: ignore


def format_string_param(name: str, value: str) -> str:
return f", {name} => '{value}'"


def format_bool_param(name: str, value: bool) -> str:
return f", {name} => {str(value).lower()}"


@dataclass
class CreateVectorizerParams:
source_table: str | None
embedding: OpenAIConfig | OllamaConfig | VoyageAIConfig | None = None
chunking: (
CharacterTextSplitterConfig | RecursiveCharacterTextSplitterConfig | None
) = None
indexing: DiskANNIndexingConfig | HNSWIndexingConfig | NoIndexingConfig | None = (
None
)
formatting: ChunkValueConfig | PythonTemplateConfig | None = None
scheduling: TimescaleSchedulingConfig | NoSchedulingConfig | None = None
processing: ProcessingConfig | None = None
target_schema: str | None = None
target_table: str | None = None
view_schema: str | None = None
view_name: str | None = None
queue_schema: str | None = None
queue_table: str | None = None
grant_to: list[str] | None = None
enqueue_existing: bool = True

def to_sql(self) -> str:
parts = ["SELECT ai.create_vectorizer(", f"'{self.source_table}'::regclass"]

# Handle all config objects that implement to_sql_argument
for field in fields(self):
value = getattr(self, field.name)
if isinstance(value, SQLArgumentProvider):
parts.append(value.to_sql_argument())

# Handle string parameters
string_fields = [
"target_schema",
"target_table",
"view_schema",
"view_name",
"queue_schema",
"queue_table",
]
for field in string_fields:
value = getattr(self, field)
if value is not None:
parts.append(format_string_param(field, value))

if self.grant_to:
grant_list = ", ".join(f"'{user}'" for user in self.grant_to)
parts.append(f", grant_to => ai.grant_to({grant_list})")

if not self.enqueue_existing:
parts.append(format_bool_param("enqueue_existing", False))

parts.append(")")
return "\n".join(parts)
```
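Putting it together: the migration example from the docs above would render SQL along these lines (hand-traced from `to_sql`, so treat it as illustrative rather than verified output):

```sql
SELECT ai.create_vectorizer(
'blog'::regclass
, embedding => ai.embedding_openai(model=>'text-embedding-3-small', dimensions=>'768')
, chunking => ai.chunking_character_text_splitter(chunk_column=>'content', chunk_size=>'800', chunk_overlap=>'400', separator=>'.', is_separator_regex=>false)
, formatting => ai.formatting_python_template(template=>'$title - $chunk')
, target_table => 'blog_embeddings'
)
```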