From eaf2a1cd118dbad307014da05e2756836dda9c61 Mon Sep 17 00:00:00 2001 From: Xuebin Su Date: Tue, 5 Dec 2023 00:22:22 -0500 Subject: [PATCH 1/5] Fix missing unique key when searching embeddings Fix missing unique key when searching embeddings Unique key is required when searching embeddings to join the embedding table and the original data table before returing the results. Previously, search on a dataframe that created from an existing table in database failed due to lacking of unique key in the dataframe. This patch fixes the issue by recoding the unique key when `create_index()` in `pg_class` so that the info can be read when `search()`. --- greenplumpython/experimental/embedding.py | 64 +++++++++++++++-------- tests/test_embedding.py | 11 +++- 2 files changed, 51 insertions(+), 24 deletions(-) diff --git a/greenplumpython/experimental/embedding.py b/greenplumpython/experimental/embedding.py index a319e5d2..c7e0a3cc 100644 --- a/greenplumpython/experimental/embedding.py +++ b/greenplumpython/experimental/embedding.py @@ -3,7 +3,7 @@ import greenplumpython as gp from greenplumpython.row import Row -from greenplumpython.type import TypeCast +from greenplumpython.type import TypeCast, _serialize_to_expr @gp.create_function @@ -39,7 +39,7 @@ class ObjectAddress(ctypes.Structure): @gp.create_function -def _generate_embedding(content: str, model_name: str) -> gp.type_("vector"): # type: ignore reportUnknownParameterType +def create_embedding(content: str, model_name: str) -> gp.type_("vector"): # type: ignore reportUnknownParameterType import sys import sentence_transformers # type: ignore reportMissingImports @@ -134,7 +134,7 @@ def create_index( Callable[[gp.DataFrame], TypeCast], # FIXME: Modifier must be adapted to all types of model. # Can this be done with transformers.AutoConfig? - lambda t: gp.type_("vector", modifier=embedding_dimension)(_generate_embedding(t[column], model_name)), # type: ignore reportUnknownLambdaType + lambda t: gp.type_("vector", modifier=embedding_dimension)(create_embedding(t[column], model_name)), # type: ignore reportUnknownLambdaType ) }, )[embedding_df_cols] @@ -152,21 +152,34 @@ def create_index( ) assert self._dataframe._db is not None _record_dependency._create_in_db(self._dataframe._db) + query_col_names = _serialize_to_expr( + list(self._dataframe.unique_key) + [column], self._dataframe._db + ) sql_add_relationship = f""" DO $$ BEGIN SET LOCAL allow_system_table_mods TO ON; - - WITH embedding_info AS ( - SELECT '{embedding_df._qualified_table_name}'::regclass::oid AS embedding_relid, attnum, '{model_name}' AS model - FROM pg_attribute + + WITH attnum_map AS ( + SELECT attname, attnum FROM pg_attribute WHERE attrelid = '{self._dataframe._qualified_table_name}'::regclass::oid AND - attname = '{column}' + EXISTS ( + SELECT FROM unnest({query_col_names}) AS query + WHERE attname = query + ) + ), embedding_info AS ( + SELECT + '{embedding_df._qualified_table_name}'::regclass::oid AS embedding_relid, + attnum, + '{model_name}' AS model, + ARRAY(SELECT attnum FROM attnum_map WHERE attname != '{column}') AS unique_key + FROM attnum_map + WHERE attname = '{column}' ) UPDATE pg_class SET reloptions = array_append( - reloptions, + reloptions, format('_pygp_emb_%s=%s', attnum::text, to_json(embedding_info)) ) FROM embedding_info @@ -177,6 +190,7 @@ def create_index( '{self._dataframe._qualified_table_name}'::regclass::oid, '{embedding_df._qualified_table_name}'::regclass::oid ); + IF version() LIKE '%Greenplum%' THEN PERFORM {_record_dependency._qualified_name_str}( @@ -212,7 +226,7 @@ def search(self, column: str, query: Any, top_k: int) -> gp.DataFrame: WITH indexed_col_info AS ( SELECT attrelid, attnum FROM pg_attribute - WHERE + WHERE attrelid = '{self._dataframe._qualified_table_name}'::regclass::oid AND attname = '{column}' ), reloptions AS ( @@ -225,10 +239,16 @@ def search(self, column: str, query: Any, top_k: int) -> gp.DataFrame: WHERE option LIKE format('_pygp_emb_%s=%%', attnum) ), embedding_info AS ( SELECT * - FROM embedding_info_json, json_to_record(val) AS (attnum int4, embedding_relid oid, model text) + FROM embedding_info_json, json_to_record(val) AS (attnum int4, embedding_relid oid, model text, unique_key int[]) + ), unique_key_names AS ( + SELECT ARRAY( + SELECT attname FROM pg_attribute + WHERE attrelid = embedding_relid AND attnum = ANY(unique_key) + ) AS val + FROM embedding_info ) - SELECT nspname, relname, attname, model - FROM embedding_info, pg_class, pg_namespace, pg_attribute + SELECT nspname, relname, attname, model, unique_key_names.val AS unique_key + FROM embedding_info, pg_class, pg_namespace, pg_attribute, unique_key_names WHERE pg_class.oid = embedding_relid AND relnamespace = pg_namespace.oid AND @@ -236,25 +256,25 @@ def search(self, column: str, query: Any, top_k: int) -> gp.DataFrame: pg_attribute.attnum = 2; """ ) - row: Row = embdedding_info[0] - schema: str = row["nspname"] - embedding_table_name: str = row["relname"] - model = row["model"] - embedding_col_name = row["attname"] + row: Row = embdedding_info[0] # type: ignore reportUnknownVariableType + schema: str = row["nspname"] # type: ignore reportUnknownVariableType + embedding_table_name: str = row["relname"] # type: ignore reportUnknownVariableType + model = row["model"] # type: ignore reportUnknownVariableType + embedding_col_name = row["attname"] # type: ignore reportUnknownVariableType embedding_df = self._dataframe._db.create_dataframe( - table_name=embedding_table_name, schema=schema + table_name=embedding_table_name, schema=schema # type: ignore reportUnknownArgumentType ) + unique_key: list[str] = row["unique_key"] # type: ignore reportUnknownVariableType assert embedding_df is not None - assert self._dataframe.unique_key is not None distance = gp.operator("<->") # L2 distance is the default operator class in pgvector return self._dataframe.join( embedding_df.assign( distance=lambda t: distance( - embedding_df[embedding_col_name], _generate_embedding(query, model) + embedding_df[embedding_col_name], create_embedding(query, model) ) ).order_by("distance")[:top_k], how="inner", - on=self._dataframe.unique_key, + on=unique_key, # type: ignore reportUnknownArgumentType self_columns={"*"}, other_columns={}, ) diff --git a/tests/test_embedding.py b/tests/test_embedding.py index 99f397e3..adfb5b84 100644 --- a/tests/test_embedding.py +++ b/tests/test_embedding.py @@ -1,6 +1,7 @@ import pytest import greenplumpython as gp +import greenplumpython.experimental.embedding as _ from tests import db @@ -10,6 +11,7 @@ def test_embedding_query_string(db: gp.Database): t = ( db.create_dataframe(columns={"id": range(len(content)), "content": content}) .save_as( + table_name="doc", temp=True, column_names=["id", "content"], distribution_key={"id"}, @@ -19,8 +21,13 @@ def test_embedding_query_string(db: gp.Database): ) .check_unique(columns={"id"}) ) - t = t.embedding().create_index(column="content", model_name="all-MiniLM-L6-v2") - df = t.embedding().search(column="content", query="apple", top_k=1) + t.embedding().create_index(column="content", model_name="all-MiniLM-L6-v2") + + # For search, don't use the DataFrame returned by create_index(), + # but get a new clean DataFrame from table in database. + tp = db.create_dataframe("doc", schema="pg_temp") + assert t != tp + df = tp.embedding().search(column="content", query="apple", top_k=1) assert len(list(df)) == 1 for row in df: assert row["content"] == "I like eating apples." From 64daab44e784b0272622ff2de1e4d77b0933320b Mon Sep 17 00:00:00 2001 From: Xuebin Su Date: Tue, 5 Dec 2023 02:08:36 -0500 Subject: [PATCH 2/5] Add multi-column unique key case --- greenplumpython/experimental/embedding.py | 19 +++++----- tests/test_embedding.py | 45 +++++++++++++++++++---- 2 files changed, 47 insertions(+), 17 deletions(-) diff --git a/greenplumpython/experimental/embedding.py b/greenplumpython/experimental/embedding.py index c7e0a3cc..c30ed399 100644 --- a/greenplumpython/experimental/embedding.py +++ b/greenplumpython/experimental/embedding.py @@ -171,7 +171,8 @@ def create_index( ), embedding_info AS ( SELECT '{embedding_df._qualified_table_name}'::regclass::oid AS embedding_relid, - attnum, + attnum AS content_attnum, + {len(self._dataframe._unique_key) + 1} AS embedding_attnum, '{model_name}' AS model, ARRAY(SELECT attnum FROM attnum_map WHERE attname != '{column}') AS unique_key FROM attnum_map @@ -180,7 +181,7 @@ def create_index( UPDATE pg_class SET reloptions = array_append( reloptions, - format('_pygp_emb_%s=%s', attnum::text, to_json(embedding_info)) + format('_pygp_emb_%s=%s', content_attnum::text, to_json(embedding_info)) ) FROM embedding_info WHERE oid = '{self._dataframe._qualified_table_name}'::regclass::oid; @@ -224,7 +225,7 @@ def search(self, column: str, query: Any, top_k: int) -> gp.DataFrame: embdedding_info = self._dataframe._db._execute( f""" WITH indexed_col_info AS ( - SELECT attrelid, attnum + SELECT attrelid, attnum AS content_attnum FROM pg_attribute WHERE attrelid = '{self._dataframe._qualified_table_name}'::regclass::oid AND @@ -236,10 +237,12 @@ def search(self, column: str, query: Any, top_k: int) -> gp.DataFrame: ), embedding_info_json AS ( SELECT split_part(option, '=', 2)::json AS val FROM reloptions, indexed_col_info - WHERE option LIKE format('_pygp_emb_%s=%%', attnum) + WHERE option LIKE format('_pygp_emb_%s=%%', content_attnum) ), embedding_info AS ( SELECT * - FROM embedding_info_json, json_to_record(val) AS (attnum int4, embedding_relid oid, model text, unique_key int[]) + FROM embedding_info_json, json_to_record(val) AS ( + embedding_attnum int4, embedding_relid oid, model text, unique_key int[] + ) ), unique_key_names AS ( SELECT ARRAY( SELECT attname FROM pg_attribute @@ -253,7 +256,7 @@ def search(self, column: str, query: Any, top_k: int) -> gp.DataFrame: pg_class.oid = embedding_relid AND relnamespace = pg_namespace.oid AND embedding_relid = attrelid AND - pg_attribute.attnum = 2; + embedding_attnum = attnum; """ ) row: Row = embdedding_info[0] # type: ignore reportUnknownVariableType @@ -269,9 +272,7 @@ def search(self, column: str, query: Any, top_k: int) -> gp.DataFrame: distance = gp.operator("<->") # L2 distance is the default operator class in pgvector return self._dataframe.join( embedding_df.assign( - distance=lambda t: distance( - embedding_df[embedding_col_name], create_embedding(query, model) - ) + distance=lambda t: distance(t[embedding_col_name], create_embedding(query, model)) ).order_by("distance")[:top_k], how="inner", on=unique_key, # type: ignore reportUnknownArgumentType diff --git a/tests/test_embedding.py b/tests/test_embedding.py index adfb5b84..53b456b2 100644 --- a/tests/test_embedding.py +++ b/tests/test_embedding.py @@ -5,6 +5,13 @@ from tests import db +def search_embeddings(t: gp.DataFrame): + results = t.embedding().search(column="content", query="apple", top_k=1) + assert len(list(results)) == 1 + for row in results: + assert row["content"] == "I like eating apples." + + @pytest.mark.requires_pgvector def test_embedding_query_string(db: gp.Database): content = ["I have a dog.", "I like eating apples."] @@ -22,12 +29,34 @@ def test_embedding_query_string(db: gp.Database): .check_unique(columns={"id"}) ) t.embedding().create_index(column="content", model_name="all-MiniLM-L6-v2") + search_embeddings(t) - # For search, don't use the DataFrame returned by create_index(), - # but get a new clean DataFrame from table in database. - tp = db.create_dataframe("doc", schema="pg_temp") - assert t != tp - df = tp.embedding().search(column="content", query="apple", top_k=1) - assert len(list(df)) == 1 - for row in df: - assert row["content"] == "I like eating apples." + # Ensure that a new DataFrame created from table in database can also be + # searched. + search_embeddings(db.create_dataframe("doc", schema="pg_temp")) + + +@pytest.mark.requires_pgvector +def test_embedding_multi_col_unique(db: gp.Database): + content = ["I have a dog.", "I like eating apples."] + columns = {"id": range(len(content)), "id2": [1] * len(content), "content": content} + t = ( + db.create_dataframe(columns=columns) + .save_as( + temp=True, + column_names=list(columns.keys()), + distribution_key={"id"}, + distribution_type="hash", + drop_if_exists=True, + drop_cascade=True, + ) + .check_unique(columns={"id", "id2"}) + ) + t.embedding().create_index(column="content", model_name="all-MiniLM-L6-v2") + print( + "reloptions =", + db._execute( + f"SELECT reloptions FROM pg_class WHERE oid = '{t._qualified_table_name}'::regclass" + ), + ) + search_embeddings(t) From 8604cb0c5b2eb4d5a16d277bc7e3879b3059aa90 Mon Sep 17 00:00:00 2001 From: Xuebin Su Date: Tue, 5 Dec 2023 02:27:02 -0500 Subject: [PATCH 3/5] Rephrase --- tests/test_embedding.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_embedding.py b/tests/test_embedding.py index 53b456b2..4b127c94 100644 --- a/tests/test_embedding.py +++ b/tests/test_embedding.py @@ -13,7 +13,7 @@ def search_embeddings(t: gp.DataFrame): @pytest.mark.requires_pgvector -def test_embedding_query_string(db: gp.Database): +def test_embedding_query_text(db: gp.Database): content = ["I have a dog.", "I like eating apples."] t = ( db.create_dataframe(columns={"id": range(len(content)), "content": content}) From c58dd98965a7dde8fb1f7ca9d87109e085aa197c Mon Sep 17 00:00:00 2001 From: Xuebin Su Date: Wed, 6 Dec 2023 21:16:25 -0500 Subject: [PATCH 4/5] Make test more rigorous --- tests/test_embedding.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/test_embedding.py b/tests/test_embedding.py index 4b127c94..c2e3a4f6 100644 --- a/tests/test_embedding.py +++ b/tests/test_embedding.py @@ -8,8 +8,8 @@ def search_embeddings(t: gp.DataFrame): results = t.embedding().search(column="content", query="apple", top_k=1) assert len(list(results)) == 1 - for row in results: - assert row["content"] == "I like eating apples." + row = next(iter(results)) + assert row["content"] == "I like eating apples." @pytest.mark.requires_pgvector @@ -28,7 +28,7 @@ def test_embedding_query_text(db: gp.Database): ) .check_unique(columns={"id"}) ) - t.embedding().create_index(column="content", model_name="all-MiniLM-L6-v2") + t = t.embedding().create_index(column="content", model_name="all-MiniLM-L6-v2") search_embeddings(t) # Ensure that a new DataFrame created from table in database can also be From 89175b8ab192bd39aa46dfd075e2900d68a8a5bc Mon Sep 17 00:00:00 2001 From: Xuebin Su Date: Wed, 6 Dec 2023 21:30:48 -0500 Subject: [PATCH 5/5] Fix test again --- tests/test_embedding.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_embedding.py b/tests/test_embedding.py index c2e3a4f6..d95c0aaa 100644 --- a/tests/test_embedding.py +++ b/tests/test_embedding.py @@ -52,7 +52,7 @@ def test_embedding_multi_col_unique(db: gp.Database): ) .check_unique(columns={"id", "id2"}) ) - t.embedding().create_index(column="content", model_name="all-MiniLM-L6-v2") + t = t.embedding().create_index(column="content", model_name="all-MiniLM-L6-v2") print( "reloptions =", db._execute(