feat: add alembic operations for vectorizer #266
base: main
Changes from all commits
522850a
0ea53e4
5675809
149cac8
f274f3d
04bfae6
cd42edb
5dff77c
ff4c8dc
7bfedb3
@@ -164,7 +164,7 @@ for post, embedding in results:
## Working with alembic

### Excluding managed tables
The `vectorizer_relationship` generates a new SQLAlchemy model that is available under the attribute you specify. If you are using alembic's autogenerate functionality to generate migrations, you will need to exclude these models from the autogenerate process.
These models are added to a list in your metadata called `pgai_managed_tables`, and you can exclude them by adding the following to your `env.py`:
```
@@ -182,3 +182,49 @@ context.configure(
```
This should now prevent alembic from generating tables for these models when you run `alembic revision --autogenerate`.
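
Since the diff above shows only the tail of that `env.py` snippet, here is a minimal sketch of what the `include_object` hook could look like. This is an illustration, not the snippet from the PR; it assumes the managed table names are exposed via `target_metadata.info["pgai_managed_tables"]`, so check the README hunk for the authoritative version:

```python
def include_object(object, name, type_, reflected, compare_to):
    # Skip tables that pgai manages through vectorizers
    if type_ == "table" and name in target_metadata.info.get(
        "pgai_managed_tables", set()
    ):
        return False
    return True


context.configure(
    connection=connection,
    target_metadata=target_metadata,
    include_object=include_object,
)
```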

### Creating vectorizers
pgai provides native Alembic operations for managing vectorizers. For them to work, you need to call `register_operations` in your `env.py` file, which registers the pgai operations under the global op context:
```python
from pgai.alembic import register_operations

register_operations()
```

Then you can use the `create_vectorizer` operation to create a vectorizer for your model, and the `drop_vectorizer` operation to remove it.

```python
from alembic import op
from pgai.alembic.configuration import (
    OpenAIConfig,
    CharacterTextSplitterConfig,
    PythonTemplateConfig,
)


def upgrade() -> None:
    op.create_vectorizer(
        source_table="blog",
        target_table="blog_embeddings",
        embedding=OpenAIConfig(
            model="text-embedding-3-small",
            dimensions=768,
        ),
        chunking=CharacterTextSplitterConfig(
            chunk_column="content",
            chunk_size=800,
            chunk_overlap=400,
            separator=".",
            is_separator_regex=False,
        ),
        formatting=PythonTemplateConfig(template="$title - $chunk"),
    )


def downgrade() -> None:
    op.drop_vectorizer(vectorizer_id=1, drop_all=True)
```

> **Review comment on `op.drop_vectorizer`:** I think it would be better for this to take the target_table name and not the vectorizer_id (which would probably not be known when writing the migration). The target table should be unique, so we should be able to look up the id from that.

The `create_vectorizer` operation supports all configuration options available in the [SQL API](vectorizer-api-reference.md).
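
For example, indexing and scheduling can be configured with the classes added in this PR. The following is a hedged sketch: the argument values are hypothetical, and the embedding/chunking arguments from the earlier example are elided:

```python
from datetime import timedelta

from alembic import op
from pgai.alembic.configuration import (
    DiskANNIndexingConfig,
    TimescaleSchedulingConfig,
)


def upgrade() -> None:
    op.create_vectorizer(
        source_table="blog",
        # ... embedding, chunking, formatting as in the example above ...
        indexing=DiskANNIndexingConfig(min_rows=100000),
        scheduling=TimescaleSchedulingConfig(
            schedule_interval=timedelta(minutes=10),
        ),
    )
```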
@@ -52,6 +52,10 @@ WORKDIR /pgai
COPY . .
RUN just build install

RUN mkdir -p /docker-entrypoint-initdb.d && \
    echo "#!/bin/bash" > /docker-entrypoint-initdb.d/configure-timescaledb.sh && \
    echo "echo \"shared_preload_libraries = 'timescaledb'\" >> \${PGDATA}/postgresql.conf" >> /docker-entrypoint-initdb.d/configure-timescaledb.sh && \
    chmod +x /docker-entrypoint-initdb.d/configure-timescaledb.sh

> **Comment on lines +55 to +58:** I had to add this to be able to run the tests.
>
> **Reply:** Why did you need timescaledb for this PR? This is a dev image so this is fine, I'm just curious.
>
> **Reply:** Because I am actually running the migrations in tests, and creating a vectorizer with scheduling config is not allowed if timescaledb is not installed.

###############################################################################
# image for use in extension development
@@ -0,0 +1,7 @@
from pgai.alembic.operations import (
    CreateVectorizerOp,
    DropVectorizerOp,
    register_operations,
)

__all__ = ["CreateVectorizerOp", "DropVectorizerOp", "register_operations"]
@@ -0,0 +1,227 @@

> **Review comment:** To my eye this is a huge improvement from before.

import re
from dataclasses import dataclass, fields
from datetime import timedelta
from typing import ClassVar, Literal, Protocol, runtime_checkable

from pydantic import BaseModel

from pgai.vectorizer.base import (
    BaseOllamaConfig,
    BaseOpenAIConfig,
    BaseProcessing,
    BasePythonTemplate,
    BaseVoyageAIConfig,
    ChunkingCharacterTextSplitter,
    ChunkingRecursiveCharacterTextSplitter,
)


@runtime_checkable
class SQLArgumentProvider(Protocol):
    def to_sql_argument(self) -> str: ...


@runtime_checkable
class PythonArgProvider(Protocol):
    def to_python_arg(self) -> str: ...


def format_sql_params(params: dict[str, str | None | bool | list[str]]) -> str:
    """Format dictionary of parameters into SQL argument string without any quoting."""
    formatted: list[str] = []
    for key, value in params.items():
        if value is None:
            continue
        elif isinstance(value, bool):
            formatted.append(f"{key}=>{str(value).lower()}")
        elif isinstance(value, list):
            array_list = ",".join(f"E'{v}'" for v in value)
            formatted.append(f"{key}=>ARRAY[{array_list}]")
        elif isinstance(value, timedelta):
            # use total_seconds() so intervals of a day or more are not truncated
            formatted.append(f"{key}=>'{int(value.total_seconds())} seconds'")
        else:
            formatted.append(f"{key}=> '{value}'")
    return ", ".join(formatted)


class SQLArgumentMixin:
    arg_type: ClassVar[str]
    function_name: ClassVar[str | None] = None

    def to_sql_argument(self) -> str:
        # Get all fields, including those from parent classes
        params = {}
        for field_name, _field in self.model_fields.items():  # type: ignore
            if field_name != "arg_type":
                value = getattr(self, field_name)  # type: ignore
                if value is not None:
                    params[field_name] = value

        if self.function_name:
            fn_name = self.function_name
        else:
            base_name = self.__class__.__name__
            # Convert CamelCase to snake_case
            base_name = re.sub("([a-z0-9])([A-Z])", r"\1_\2", base_name).lower()
            # Remove 'config' and clean up any double underscores
            base_name = base_name.replace("config", "").strip("_")
            fn_name = f"{self.arg_type}_{base_name}"

        return f", {self.arg_type} => ai.{fn_name}({format_sql_params(params)})"  # type: ignore

> **Review comment on the `field_name != "arg_type"` check:** is the function_name field included then, how does that work?


class OpenAIConfig(BaseOpenAIConfig, SQLArgumentMixin):
    arg_type: ClassVar[str] = "embedding"
    function_name: ClassVar[str] = "embedding_openai"  # type: ignore
    chat_user: str | None = None
    api_key_name: str | None = None

> **Review comment:** Naming (and I know naming discussions are always annoying), but why not stick to the SQL convention we established and name this EmbeddingOpenAIConfig or EmbeddingConfigOpenAI (and similarly for the others)? The pro is that the name translation from SQL to Python is super easy and I think would be easier to understand; the con is that it's long. Otherwise the translation seems a bit ad hoc, e.g. indexing configs have "indexing" in the name but in a different spot than the SQL. Let's think about this some more.
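
# Illustration (not part of the PR; assumes BaseOpenAIConfig contributes the
# model/dimensions fields used in the README example):
#   OpenAIConfig(model="text-embedding-3-small", dimensions=768).to_sql_argument()
# yields roughly:
#   , embedding => ai.embedding_openai(model=> 'text-embedding-3-small', dimensions=> '768')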


class VoyageAIConfig(BaseVoyageAIConfig, SQLArgumentMixin):
    arg_type: ClassVar[str] = "embedding"
    function_name: ClassVar[str] = "embedding_voyageai"  # type: ignore
    api_key_name: str | None = None


class OllamaConfig(BaseOllamaConfig, SQLArgumentMixin):
    arg_type: ClassVar[str] = "embedding"


class CharacterTextSplitterConfig(ChunkingCharacterTextSplitter, SQLArgumentMixin):
    arg_type: ClassVar[str] = "chunking"


class RecursiveCharacterTextSplitterConfig(
    ChunkingRecursiveCharacterTextSplitter, SQLArgumentMixin
):
    arg_type: ClassVar[str] = "chunking"


class ChunkValueConfig:
    arg_type: ClassVar[str] = "formatting"

    def to_sql_argument(self) -> str:
        return f", {self.arg_type} => ai.formatting_chunk_value()"


class PythonTemplateConfig(BasePythonTemplate, SQLArgumentMixin):
    arg_type: ClassVar[str] = "formatting"


class NoIndexingConfig:
    arg_type: ClassVar[str] = "indexing"

    def to_sql_argument(self) -> str:
        return f", {self.arg_type} => ai.indexing_none()"


class DiskANNIndexingConfig(BaseModel, SQLArgumentMixin):
    arg_type: ClassVar[str] = "indexing"
    function_name: ClassVar[str] = "indexing_diskann"  # type: ignore
    min_rows: int
    storage_layout: Literal["memory_optimized", "plain"] | None = None
    num_neighbors: int | None = None
    search_list_size: int | None = None
    max_alpha: float | None = None
    num_dimensions: int | None = None
    num_bits_per_dimension: int | None = None
    create_when_queue_empty: bool | None = None


class HNSWIndexingConfig(BaseModel, SQLArgumentMixin):
    arg_type: ClassVar[str] = "indexing"
    function_name: ClassVar[str] = "indexing_hnsw"  # type: ignore
    min_rows: int | None = None
    opclass: Literal["vector_cosine_ops", "vector_l1_ops", "vector_ip_ops"] | None = (
        None
    )
    m: int | None = None
    ef_construction: int | None = None
    create_when_queue_empty: bool | None = None


class NoSchedulingConfig:
    arg_type: ClassVar[str] = "scheduling"

    def to_sql_argument(self) -> str:
        return f", {self.arg_type} => ai.scheduling_none()"


class TimescaleSchedulingConfig(BaseModel, SQLArgumentMixin):
    arg_type: ClassVar[str] = "scheduling"
    function_name: ClassVar[str] = "scheduling_timescaledb"  # type: ignore

    schedule_interval: timedelta | None = None
    initial_start: str | None = None
    job_id: int | None = None
    fixed_schedule: bool | None = None
    timezone: str | None = None


class ProcessingConfig(BaseProcessing, SQLArgumentMixin):
    arg_type: ClassVar[str] = "processing"
    function_name: ClassVar[str] = "processing_default"  # type: ignore


def format_string_param(name: str, value: str) -> str:
    return f", {name} => '{value}'"


def format_bool_param(name: str, value: bool) -> str:
    return f", {name} => {str(value).lower()}"


@dataclass
class CreateVectorizerParams:
    source_table: str | None
    embedding: OpenAIConfig | OllamaConfig | VoyageAIConfig | None = None
    chunking: (
        CharacterTextSplitterConfig | RecursiveCharacterTextSplitterConfig | None
    ) = None
    indexing: DiskANNIndexingConfig | HNSWIndexingConfig | NoIndexingConfig | None = (
        None
    )
    formatting: ChunkValueConfig | PythonTemplateConfig | None = None
    scheduling: TimescaleSchedulingConfig | NoSchedulingConfig | None = None
    processing: ProcessingConfig | None = None
    target_schema: str | None = None
    target_table: str | None = None
    view_schema: str | None = None
    view_name: str | None = None
    queue_schema: str | None = None
    queue_table: str | None = None
    grant_to: list[str] | None = None
    enqueue_existing: bool = True

    def to_sql(self) -> str:
        parts = ["SELECT ai.create_vectorizer(", f"'{self.source_table}'::regclass"]

        # Handle all config objects that implement to_sql_argument
        for field in fields(self):
            value = getattr(self, field.name)
            if isinstance(value, SQLArgumentProvider):
                parts.append(value.to_sql_argument())

        # Handle string parameters
        string_fields = [
            "target_schema",
            "target_table",
            "view_schema",
            "view_name",
            "queue_schema",
            "queue_table",
        ]
        for field in string_fields:
            value = getattr(self, field)
            if value is not None:
                parts.append(format_string_param(field, value))

        if self.grant_to:
            grant_list = ", ".join(f"'{user}'" for user in self.grant_to)
            parts.append(f", grant_to => ai.grant_to({grant_list})")

        if not self.enqueue_existing:
            parts.append(format_bool_param("enqueue_existing", False))

        parts.append(")")
        return "\n".join(parts)

> **Review comment:** I think we also need to add docs to adding-embedding-integration.md.