This repository has been archived by the owner on Jul 16, 2024. It is now read-only.

Fix missing unique key when searching embeddings #226

Merged (5 commits) on Dec 7, 2023
Changes from 3 commits
77 changes: 49 additions & 28 deletions greenplumpython/experimental/embedding.py
@@ -3,7 +3,7 @@

import greenplumpython as gp
from greenplumpython.row import Row
from greenplumpython.type import TypeCast
from greenplumpython.type import TypeCast, _serialize_to_expr


@gp.create_function
@@ -39,7 +39,7 @@ class ObjectAddress(ctypes.Structure):


@gp.create_function
def _generate_embedding(content: str, model_name: str) -> gp.type_("vector"): # type: ignore reportUnknownParameterType
def create_embedding(content: str, model_name: str) -> gp.type_("vector"): # type: ignore reportUnknownParameterType
import sys

import sentence_transformers # type: ignore reportMissingImports
@@ -134,7 +134,7 @@ def create_index(
Callable[[gp.DataFrame], TypeCast],
# FIXME: Modifier must be adapted to all types of model.
# Can this be done with transformers.AutoConfig?
lambda t: gp.type_("vector", modifier=embedding_dimension)(_generate_embedding(t[column], model_name)), # type: ignore reportUnknownLambdaType
lambda t: gp.type_("vector", modifier=embedding_dimension)(create_embedding(t[column], model_name)), # type: ignore reportUnknownLambdaType
)
},
)[embedding_df_cols]
@@ -152,22 +152,36 @@
)
assert self._dataframe._db is not None
_record_dependency._create_in_db(self._dataframe._db)
query_col_names = _serialize_to_expr(
list(self._dataframe.unique_key) + [column], self._dataframe._db
)
sql_add_relationship = f"""
DO $$
BEGIN
SET LOCAL allow_system_table_mods TO ON;

WITH embedding_info AS (
SELECT '{embedding_df._qualified_table_name}'::regclass::oid AS embedding_relid, attnum, '{model_name}' AS model
FROM pg_attribute

WITH attnum_map AS (
SELECT attname, attnum FROM pg_attribute
WHERE
attrelid = '{self._dataframe._qualified_table_name}'::regclass::oid AND
attname = '{column}'
EXISTS (
SELECT FROM unnest({query_col_names}) AS query
WHERE attname = query
)
), embedding_info AS (
SELECT
'{embedding_df._qualified_table_name}'::regclass::oid AS embedding_relid,
attnum AS content_attnum,
{len(self._dataframe._unique_key) + 1} AS embedding_attnum,
'{model_name}' AS model,
ARRAY(SELECT attnum FROM attnum_map WHERE attname != '{column}') AS unique_key
FROM attnum_map
WHERE attname = '{column}'
)
UPDATE pg_class
SET reloptions = array_append(
reloptions,
format('_pygp_emb_%s=%s', attnum::text, to_json(embedding_info))
reloptions,
format('_pygp_emb_%s=%s', content_attnum::text, to_json(embedding_info))
)
FROM embedding_info
WHERE oid = '{self._dataframe._qualified_table_name}'::regclass::oid;
@@ -177,6 +191,7 @@ def create_index(
'{self._dataframe._qualified_table_name}'::regclass::oid,
'{embedding_df._qualified_table_name}'::regclass::oid
);

IF version() LIKE '%Greenplum%' THEN
PERFORM
{_record_dependency._qualified_name_str}(
@@ -210,9 +225,9 @@ def search(self, column: str, query: Any, top_k: int) -> gp.DataFrame:
embdedding_info = self._dataframe._db._execute(
f"""
WITH indexed_col_info AS (
SELECT attrelid, attnum
SELECT attrelid, attnum AS content_attnum
FROM pg_attribute
WHERE
WHERE
attrelid = '{self._dataframe._qualified_table_name}'::regclass::oid AND
attname = '{column}'
), reloptions AS (
@@ -222,39 +237,45 @@ def search(self, column: str, query: Any, top_k: int) -> gp.DataFrame:
), embedding_info_json AS (
SELECT split_part(option, '=', 2)::json AS val
FROM reloptions, indexed_col_info
WHERE option LIKE format('_pygp_emb_%s=%%', attnum)
WHERE option LIKE format('_pygp_emb_%s=%%', content_attnum)
), embedding_info AS (
SELECT *
FROM embedding_info_json, json_to_record(val) AS (attnum int4, embedding_relid oid, model text)
FROM embedding_info_json, json_to_record(val) AS (
embedding_attnum int4, embedding_relid oid, model text, unique_key int[]
)
), unique_key_names AS (
SELECT ARRAY(
SELECT attname FROM pg_attribute
WHERE attrelid = embedding_relid AND attnum = ANY(unique_key)
) AS val
FROM embedding_info
)
SELECT nspname, relname, attname, model
FROM embedding_info, pg_class, pg_namespace, pg_attribute
SELECT nspname, relname, attname, model, unique_key_names.val AS unique_key
FROM embedding_info, pg_class, pg_namespace, pg_attribute, unique_key_names
WHERE
pg_class.oid = embedding_relid AND
relnamespace = pg_namespace.oid AND
embedding_relid = attrelid AND
pg_attribute.attnum = 2;
embedding_attnum = attnum;
"""
)
row: Row = embdedding_info[0]
schema: str = row["nspname"]
embedding_table_name: str = row["relname"]
model = row["model"]
embedding_col_name = row["attname"]
row: Row = embdedding_info[0] # type: ignore reportUnknownVariableType
schema: str = row["nspname"] # type: ignore reportUnknownVariableType
embedding_table_name: str = row["relname"] # type: ignore reportUnknownVariableType
model = row["model"] # type: ignore reportUnknownVariableType
embedding_col_name = row["attname"] # type: ignore reportUnknownVariableType
embedding_df = self._dataframe._db.create_dataframe(
table_name=embedding_table_name, schema=schema
table_name=embedding_table_name, schema=schema # type: ignore reportUnknownArgumentType
)
unique_key: list[str] = row["unique_key"] # type: ignore reportUnknownVariableType
assert embedding_df is not None
assert self._dataframe.unique_key is not None
distance = gp.operator("<->") # L2 distance is the default operator class in pgvector
return self._dataframe.join(
embedding_df.assign(
distance=lambda t: distance(
embedding_df[embedding_col_name], _generate_embedding(query, model)
)
distance=lambda t: distance(t[embedding_col_name], create_embedding(query, model))
).order_by("distance")[:top_k],
how="inner",
on=self._dataframe.unique_key,
on=unique_key, # type: ignore reportUnknownArgumentType
self_columns={"*"},
other_columns={},
)
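For context, the substance of this change: create_index() now records the content column, the embedding column, the model name, and the attnums of the DataFrame's unique-key columns as a JSON object in pg_class.reloptions of the indexed table, and search() recovers the unique key from that metadata when joining. Below is a minimal usage sketch of the behavior this enables, reusing the table and column names from the tests that follow; it assumes an existing gp.Database connection named db (like the db fixture in the tests) and a pgvector-enabled server, and is an illustration rather than part of the diff.

import greenplumpython as gp
import greenplumpython.experimental.embedding as _  # imported for its side effect of enabling DataFrame.embedding()

# db: gp.Database, e.g. the `db` fixture used in the tests below

content = ["I have a dog.", "I like eating apples."]
t = (
    db.create_dataframe(columns={"id": range(len(content)), "content": content})
    .save_as(
        table_name="doc",
        temp=True,
        column_names=["id", "content"],
        distribution_key={"id"},
        distribution_type="hash",
        drop_if_exists=True,
        drop_cascade=True,
    )
    .check_unique(columns={"id"})
)
t.embedding().create_index(column="content", model_name="all-MiniLM-L6-v2")

# The unique key is now recovered from the stored metadata, so a DataFrame
# re-created from the table can be searched without calling check_unique() again.
reloaded = db.create_dataframe("doc", schema="pg_temp")
for row in reloaded.embedding().search(column="content", query="apple", top_k=1):
    print(row["content"])  # expected: "I like eating apples."
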
48 changes: 42 additions & 6 deletions tests/test_embedding.py
@@ -1,15 +1,24 @@
import pytest

import greenplumpython as gp
import greenplumpython.experimental.embedding as _
from tests import db


def search_embeddings(t: gp.DataFrame):
results = t.embedding().search(column="content", query="apple", top_k=1)
assert len(list(results)) == 1
for row in results:
assert row["content"] == "I like eating apples."
Contributor:

Since there is only one item here, is the for loop needed?

Contributor Author:

Thanks! Changed.



@pytest.mark.requires_pgvector
def test_embedding_query_string(db: gp.Database):
def test_embedding_query_text(db: gp.Database):
content = ["I have a dog.", "I like eating apples."]
t = (
db.create_dataframe(columns={"id": range(len(content)), "content": content})
.save_as(
table_name="doc",
temp=True,
column_names=["id", "content"],
distribution_key={"id"},
@@ -19,8 +28,35 @@ def test_embedding_query_string(db: gp.Database):
)
.check_unique(columns={"id"})
Collaborator:

  1. check_unique actually does more than *check*: it creates indexes, which is not obvious from the function name. Shall we consider renaming the function?
  2. We need tests with multiple columns for check_unique() and search(). Another PR will be fine since it is not relevant to this one.

Contributor Author:

The word "check" comes from SQL (https://www.postgresql.org/docs/current/ddl-constraints.html), as in

CREATE TABLE products (
    product_no integer,
    name text,
    price numeric CHECK (price > 0)
);

AFAIK, creating an index is the only way for the database to ensure that a set of columns contains only unique values.

Contributor Author:

I will add a test case for multi-column unique key.

Collaborator:

[5.4.1. Check Constraints](https://www.postgresql.org/docs/current/ddl-constraints.html#DDL-CONSTRAINTS-CHECK-CONSTRAINTS)
[5.4.2. Not-Null Constraints](https://www.postgresql.org/docs/current/ddl-constraints.html#DDL-CONSTRAINTS-NOT-NULL)
[5.4.3. Unique Constraints](https://www.postgresql.org/docs/current/ddl-constraints.html#DDL-CONSTRAINTS-UNIQUE-CONSTRAINTS)

5.4.1. Check Constraints

A check constraint is the most generic constraint type. It allows you to specify that the value in a certain column must satisfy a Boolean (truth-value) expression. For instance, to require positive product prices, you could use:

CREATE TABLE products (
    product_no integer,
    name text,
    price numeric CHECK (price > 0)
);

Doesn't this mean CHECK is one kind of constraint, and UNIQUE is another kind of constraint?

Is "check" in check_unique a verb? Or am I missing something?

Contributor Author:

I think "constraints" is the object of "check".

That is, "check" is used as in "check constraints". In this example, price > 0 is the constraint, and uniqueness is another type of constraint.

Therefore, I think it makes sense to call this function check_unique.

)
t = t.embedding().create_index(column="content", model_name="all-MiniLM-L6-v2")
df = t.embedding().search(column="content", query="apple", top_k=1)
assert len(list(df)) == 1
for row in df:
assert row["content"] == "I like eating apples."
t.embedding().create_index(column="content", model_name="all-MiniLM-L6-v2")
Contributor (@ruxuez, Dec 6, 2023):

We don't need to assign the result to t anymore?

Contributor Author:

Thanks, let me fix it.

search_embeddings(t)

# Ensure that a new DataFrame created from table in database can also be
# searched.
search_embeddings(db.create_dataframe("doc", schema="pg_temp"))


@pytest.mark.requires_pgvector
def test_embedding_multi_col_unique(db: gp.Database):
content = ["I have a dog.", "I like eating apples."]
columns = {"id": range(len(content)), "id2": [1] * len(content), "content": content}
t = (
db.create_dataframe(columns=columns)
.save_as(
temp=True,
column_names=list(columns.keys()),
distribution_key={"id"},
distribution_type="hash",
drop_if_exists=True,
drop_cascade=True,
)
.check_unique(columns={"id", "id2"})
)
t.embedding().create_index(column="content", model_name="all-MiniLM-L6-v2")
print(
"reloptions =",
db._execute(
f"SELECT reloptions FROM pg_class WHERE oid = '{t._qualified_table_name}'::regclass"
),
)
search_embeddings(t)
Contributor:

Does this test always pass? It seems we have no asserts here.

Contributor Author:

Asserts are in search_embeddings().
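
As a side note on the reloptions that test_embedding_multi_col_unique prints: the metadata written by create_index() has the form _pygp_emb_<content_attnum>=<json>, where the JSON object carries embedding_relid, content_attnum, embedding_attnum, model, and unique_key. Below is a small sketch of how that metadata could be decoded from Python for debugging; the helper name and the exact shape of the rows returned by db._execute() are assumptions, not part of the library.

import json

def decode_embedding_reloptions(db, qualified_table_name: str):
    # Illustrative helper: collect the embedding metadata JSON objects that
    # create_index() appends to pg_class.reloptions for the indexed table.
    rows = db._execute(
        f"""
        SELECT unnest(reloptions) AS option
        FROM pg_class
        WHERE oid = '{qualified_table_name}'::regclass
        """
    )
    metadata = []
    for row in rows:
        option: str = row["option"]
        if option.startswith("_pygp_emb_"):
            # Option format: _pygp_emb_<content_attnum>=<json>
            _, value = option.split("=", 1)
            metadata.append(json.loads(value))
    return metadata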
