Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Pgvector patch #4103

Merged
merged 4 commits into from
Apr 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,25 @@ test-python-universal-postgres-online:
not test_snowflake" \
sdk/python/tests

test-python-universal-pgvector-online:
PYTHONPATH='.' \
FULL_REPO_CONFIGS_MODULE=sdk.python.feast.infra.online_stores.contrib.pgvector_repo_configuration \
PYTEST_PLUGINS=sdk.python.tests.integration.feature_repos.universal.online_store.postgres \
python -m pytest -n 8 --integration \
-k "not test_universal_cli and \
not test_go_feature_server and \
not test_feature_logging and \
not test_reorder_columns and \
not test_logged_features_validation and \
not test_lambda_materialization_consistency and \
not test_offline_write and \
not test_push_features_to_offline_store and \
not gcs_registry and \
not s3_registry and \
not test_universal_types and \
not test_snowflake" \
sdk/python/tests

test-python-universal-mysql-online:
PYTHONPATH='.' \
FULL_REPO_CONFIGS_MODULE=sdk.python.feast.infra.online_stores.contrib.mysql_repo_configuration \
Expand Down
28 changes: 28 additions & 0 deletions docs/reference/online-stores/postgres.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ online_store:
sslkey_path: /path/to/client-key.pem
sslcert_path: /path/to/client-cert.pem
sslrootcert_path: /path/to/server-ca.pem
pgvector_enabled: false
vector_len: 512
```
{% endcode %}

Expand Down Expand Up @@ -60,3 +62,29 @@ Below is a matrix indicating which functionality is supported by the Postgres on
| collocated by entity key | no |

To compare this set of functionality against other online stores, please see the full [functionality matrix](overview.md#functionality-matrix).

## PGVector
The Postgres online store supports the use of [PGVector](https://pgvector.dev/) for storing feature values.
To enable PGVector, set `pgvector_enabled: true` in the online store configuration.
The `vector_len` parameter can be used to specify the length of the vector. The default value is 512.

Then you can use `retrieve_online_documents` to retrieve the top k closest vectors to a query vector.

{% code title="python" %}
```python
from feast import FeatureStore
from feast.infra.online_stores.postgres import retrieve_online_documents

feature_store = FeatureStore(repo_path=".")

query_vector = [0.1, 0.2, 0.3, 0.4, 0.5]
top_k = 5

feature_values = retrieve_online_documents(
feature_store=feature_store,
feature_view_name="document_fv:embedding_float",
query_vector=query_vector,
top_k=top_k,
)
```
{% endcode %}
9 changes: 7 additions & 2 deletions sdk/python/feast/infra/key_encoding_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,8 +74,13 @@ def serialize_entity_key(
return b"".join(output)


def get_val_str(val):
accept_value_types = ["float_list_val", "double_list_val", "int_list_val"]
def get_list_val_str(val):
accept_value_types = [
"float_list_val",
"double_list_val",
"int32_list_val",
"int64_list_val",
]
for accept_type in accept_value_types:
if val.HasField(accept_type):
return str(getattr(val, accept_type).val)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
from tests.integration.feature_repos.integration_test_repo_config import (
IntegrationTestRepoConfig,
)
from tests.integration.feature_repos.universal.online_store.postgres import (
PGVectorOnlineStoreCreator,
)

FULL_REPO_CONFIGS = [
IntegrationTestRepoConfig(
online_store="pgvector", online_store_creator=PGVectorOnlineStoreCreator
),
]
51 changes: 31 additions & 20 deletions sdk/python/feast/infra/online_stores/contrib/postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import logging
from collections import defaultdict
from datetime import datetime
from typing import Any, Callable, Dict, List, Literal, Optional, Sequence, Tuple, Union
from typing import Any, Callable, Dict, List, Literal, Optional, Sequence, Tuple

import psycopg2
import pytz
Expand All @@ -12,7 +12,7 @@

from feast import Entity
from feast.feature_view import FeatureView
from feast.infra.key_encoding_utils import get_val_str, serialize_entity_key
from feast.infra.key_encoding_utils import get_list_val_str, serialize_entity_key
from feast.infra.online_stores.online_store import OnlineStore
from feast.infra.utils.postgres.connection_utils import _get_conn, _get_connection_pool
from feast.infra.utils.postgres.postgres_config import ConnectionType, PostgreSQLConfig
Expand Down Expand Up @@ -74,19 +74,18 @@ def online_write_batch(
created_ts = _to_naive_utc(created_ts)

for feature_name, val in values.items():
val_str: Union[str, bytes]
vector_val = None
if (
"pgvector_enabled" in config.online_config
and config.online_config["pgvector_enabled"]
"pgvector_enabled" in config.online_store
and config.online_store.pgvector_enabled
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't this be some sort of a feature view-level config? This global config makes it impossible to use postgres both for traditional feature lookup and vector search at the same time, isn't that right?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's good point. Before the pr we were unable to do that because vector is mixed with feature value. Now it's doable as vector is a new feature now. Let me see how that works

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@tokoko on a second thought, I would rather not mess up the feature view at the moment. Until we figure out a better design.
Also it still won't impact the original read write flow with or without the pgvrctor_enabled.
And Postgres user has to install the extension before it's able to use the pgvector, so it's better to have this config at high level as well.

):
val_str = get_val_str(val)
else:
val_str = val.SerializeToString()
vector_val = get_list_val_str(val)
insert_values.append(
(
entity_key_bin,
feature_name,
val_str,
val.SerializeToString(),
vector_val,
timestamp,
created_ts,
)
Expand All @@ -100,11 +99,12 @@ def online_write_batch(
sql.SQL(
"""
INSERT INTO {}
(entity_key, feature_name, value, event_ts, created_ts)
(entity_key, feature_name, value, vector_value, event_ts, created_ts)
VALUES %s
ON CONFLICT (entity_key, feature_name) DO
UPDATE SET
value = EXCLUDED.value,
vector_value = EXCLUDED.vector_value,
event_ts = EXCLUDED.event_ts,
created_ts = EXCLUDED.created_ts;
""",
Expand Down Expand Up @@ -226,20 +226,23 @@ def update(

for table in tables_to_keep:
table_name = _table_id(project, table)
value_type = "BYTEA"
if (
"pgvector_enabled" in config.online_config
and config.online_config["pgvector_enabled"]
"pgvector_enabled" in config.online_store
and config.online_store.pgvector_enabled
):
value_type = f'vector({config.online_config["vector_len"]})'
vector_value_type = f"vector({config.online_store.vector_len})"
else:
# keep the vector_value_type as BYTEA if pgvector is not enabled, to maintain compatibility
vector_value_type = "BYTEA"
cur.execute(
sql.SQL(
"""
CREATE TABLE IF NOT EXISTS {}
(
entity_key BYTEA,
feature_name TEXT,
value {},
value BYTEA,
vector_value {} NULL,
event_ts TIMESTAMPTZ,
created_ts TIMESTAMPTZ,
PRIMARY KEY(entity_key, feature_name)
Expand All @@ -248,7 +251,7 @@ def update(
"""
).format(
sql.Identifier(table_name),
sql.SQL(value_type),
sql.SQL(vector_value_type),
sql.Identifier(f"{table_name}_ek"),
sql.Identifier(table_name),
)
Expand Down Expand Up @@ -294,6 +297,14 @@ def retrieve_online_documents(
"""
project = config.project

if (
"pgvector_enabled" not in config.online_store
or not config.online_store.pgvector_enabled
):
raise ValueError(
"pgvector is not enabled in the online store configuration"
)

# Convert the embedding to a string to be used in postgres vector search
query_embedding_str = f"[{','.join(str(el) for el in embedding)}]"

Expand All @@ -311,8 +322,8 @@ def retrieve_online_documents(
SELECT
entity_key,
feature_name,
value,
value <-> %s as distance,
vector_value,
vector_value <-> %s as distance,
event_ts FROM {table_name}
WHERE feature_name = {feature_name}
ORDER BY distance
Expand All @@ -327,13 +338,13 @@ def retrieve_online_documents(
)
rows = cur.fetchall()

for entity_key, feature_name, value, distance, event_ts in rows:
for entity_key, feature_name, vector_value, distance, event_ts in rows:
# TODO Deserialize entity_key to return the entity in response
# entity_key_proto = EntityKeyProto()
# entity_key_proto_bin = bytes(entity_key)

# TODO Convert to List[float] for value type proto
feature_value_proto = ValueProto(string_val=value)
feature_value_proto = ValueProto(string_val=vector_value)

distance_value_proto = ValueProto(float_val=distance)
result.append((event_ts, feature_value_proto, distance_value_proto))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,11 @@
IntegrationTestRepoConfig,
)
from tests.integration.feature_repos.universal.online_store.postgres import (
PGVectorOnlineStoreCreator,
PostgresOnlineStoreCreator,
)

FULL_REPO_CONFIGS = [
IntegrationTestRepoConfig(
online_store="postgres", online_store_creator=PostgresOnlineStoreCreator
),
IntegrationTestRepoConfig(
online_store="pgvector", online_store_creator=PGVectorOnlineStoreCreator
),
]

AVAILABLE_ONLINE_STORES = {"pgvector": PGVectorOnlineStoreCreator}
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
CREATE EXTENSION IF NOT EXISTS vector;
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import os
from typing import Dict

from testcontainers.core.container import DockerContainer
Expand Down Expand Up @@ -37,12 +38,17 @@ def teardown(self):
class PGVectorOnlineStoreCreator(OnlineStoreCreator):
def __init__(self, project_name: str, **kwargs):
super().__init__(project_name)
script_directory = os.path.dirname(os.path.abspath(__file__))
self.container = (
DockerContainer("pgvector/pgvector:pg16")
.with_env("POSTGRES_USER", "root")
.with_env("POSTGRES_PASSWORD", "test")
.with_env("POSTGRES_DB", "test")
.with_exposed_ports(5432)
.with_volume_mapping(
os.path.join(script_directory, "init.sql"),
"/docker-entrypoint-initdb.d/init.sql",
)
)

def create_online_store(self) -> Dict[str, str]:
Expand All @@ -51,8 +57,10 @@ def create_online_store(self) -> Dict[str, str]:
wait_for_logs(
container=self.container, predicate=log_string_to_wait_for, timeout=10
)
command = "psql -h localhost -p 5432 -U root -d test -c 'CREATE EXTENSION IF NOT EXISTS vector;'"
self.container.exec(command)
init_log_string_to_wait_for = "PostgreSQL init process complete"
wait_for_logs(
container=self.container, predicate=init_log_string_to_wait_for, timeout=10
)
return {
"host": "localhost",
"type": "postgres",
Expand Down
Loading