feat: Bump psycopg2 to psycopg3 for all Postgres components #1

Closed · wants to merge 20 commits

Changes from all commits (20 commits)
7 changes: 4 additions & 3 deletions Makefile
@@ -65,7 +65,7 @@ install-python:
python setup.py develop

lock-python-dependencies:
uv pip compile --system --no-strip-extras setup.py --output-file sdk/python/requirements/py$(PYTHON)-requirements.txt
uv pip compile --system --no-strip-extras setup.py --output-file sdk/python/requirements/py$(PYTHON)-requirements.txt

lock-python-dependencies-all:
pixi run --environment py39 --manifest-path infra/scripts/pixi/pixi.toml "uv pip compile --system --no-strip-extras setup.py --output-file sdk/python/requirements/py3.9-requirements.txt"
@@ -164,7 +164,7 @@ test-python-universal-mssql:
sdk/python/tests


# To use Athena as an offline store, you need to create an Athena database and an S3 bucket on AWS.
# To use Athena as an offline store, you need to create an Athena database and an S3 bucket on AWS.
# https://docs.aws.amazon.com/athena/latest/ug/getting-started.html
# Modify environment variables ATHENA_REGION, ATHENA_DATA_SOURCE, ATHENA_DATABASE, ATHENA_WORKGROUP or
# ATHENA_S3_BUCKET_NAME according to your needs. If tests fail with the pytest -n 8 option, change the number to 1.
@@ -191,7 +191,7 @@ test-python-universal-athena:
not s3_registry and \
not test_snowflake" \
sdk/python/tests

test-python-universal-postgres-offline:
PYTHONPATH='.' \
FULL_REPO_CONFIGS_MODULE=sdk.python.feast.infra.offline_stores.contrib.postgres_repo_configuration \
@@ -209,6 +209,7 @@ test-python-universal-postgres-offline:
not test_push_features_to_offline_store and \
not gcs_registry and \
not s3_registry and \
not test_snowflake and \
not test_universal_types" \
sdk/python/tests

2 changes: 1 addition & 1 deletion docs/tutorials/using-scalable-registry.md
@@ -49,7 +49,7 @@ When this happens, your database is likely using what is referred to as an
in `SQLAlchemy` terminology. See your database's documentation for examples on
how to set its scheme in the Database URL.

`Psycopg2`, which is the database library leveraged by the online and offline
`Psycopg`, which is the database library leveraged by the online and offline
stores, is not impacted by the need to speak a particular dialect, and so the
following only applies to the registry.
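
An aside for context (not part of the diff): SQLAlchemy resolves its dialect and driver from the URL scheme, so moving the registry to psycopg3 means changing the scheme rather than any code. A minimal sketch with placeholder credentials:

# Hypothetical illustration, not from this PR: the URL scheme selects the driver.
# psycopg2 dialect: postgresql+psycopg2://...
# psycopg3 dialect: postgresql+psycopg://...
from sqlalchemy import create_engine

engine = create_engine("postgresql+psycopg://feast:feast@localhost:5432/feast_registry")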

10 changes: 10 additions & 0 deletions sdk/python/feast/errors.py
@@ -401,3 +401,13 @@ def __init__(self, input_dict: dict):
super().__init__(
f"Failed to serialize the provided dictionary into a pandas DataFrame: {input_dict.keys()}"
)


class ZeroRowsQueryResult(Exception):
def __init__(self, query: str):
super().__init__(f"This query returned zero rows:\n{query}")


class ZeroColumnQueryResult(Exception):
def __init__(self, query: str):
super().__init__(f"This query returned zero columns:\n{query}")
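
A minimal sketch of the guard pattern these exceptions support elsewhere in this PR (hypothetical helper, not part of the diff): psycopg leaves cursor.description empty for zero-column results, and fetchone() can return None for zero rows.

# Hypothetical helper, for illustration only.
def fetch_one_strict(cur, query: str):
    cur.execute(query)
    if not cur.description:
        raise ZeroColumnQueryResult(query)
    row = cur.fetchone()
    if row is None:
        raise ZeroRowsQueryResult(query)
    return row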
sdk/python/feast/infra/offline_stores/contrib/postgres_offline_store/postgres.py
@@ -19,11 +19,11 @@
import pandas as pd
import pyarrow as pa
from jinja2 import BaseLoader, Environment
from psycopg2 import sql
from psycopg import sql
from pytz import utc

from feast.data_source import DataSource
from feast.errors import InvalidEntityType
from feast.errors import InvalidEntityType, ZeroColumnQueryResult, ZeroRowsQueryResult
from feast.feature_view import DUMMY_ENTITY_ID, DUMMY_ENTITY_VAL, FeatureView
from feast.infra.offline_stores import offline_utils
from feast.infra.offline_stores.contrib.postgres_offline_store.postgres_source import (
@@ -274,8 +274,10 @@ def to_sql(self) -> str:
def _to_arrow_internal(self, timeout: Optional[int] = None) -> pa.Table:
with self._query_generator() as query:
with _get_conn(self.config.offline_store) as conn, conn.cursor() as cur:
conn.set_session(readonly=True)
conn.read_only = True
cur.execute(query)
if not cur.description:
raise ZeroColumnQueryResult(query)
fields = [
(c.name, pg_type_code_to_arrow(c.type_code))
for c in cur.description
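
An aside on the conn.read_only change above: in psycopg3, session characteristics are plain connection attributes rather than arguments to set_session(), and they may only be changed while no transaction is open. A standalone sketch with a placeholder DSN:

# Hypothetical DSN, for illustration only.
import psycopg

with psycopg.connect("dbname=feast user=feast") as conn:
    conn.read_only = True  # set before the first statement starts a transaction
    with conn.cursor() as cur:
        cur.execute("SELECT 1")
        print(cur.fetchone())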
@@ -331,16 +333,19 @@ def _get_entity_df_event_timestamp_range(
entity_df_event_timestamp.max().to_pydatetime(),
)
elif isinstance(entity_df, str):
# If the entity_df is a string (SQL query), determine range
# from table
# If the entity_df is a string (SQL query), determine range from table
with _get_conn(config.offline_store) as conn, conn.cursor() as cur:
(
cur.execute(
f"SELECT MIN({entity_df_event_timestamp_col}) AS min, MAX({entity_df_event_timestamp_col}) AS max FROM ({entity_df}) as tmp_alias"
),
)
query = f"""
SELECT
MIN({entity_df_event_timestamp_col}) AS min,
MAX({entity_df_event_timestamp_col}) AS max
FROM ({entity_df}) AS tmp_alias
"""
cur.execute(query)
res = cur.fetchone()
entity_df_event_timestamp_range = (res[0], res[1])
if not res:
raise ZeroRowsQueryResult(query)
entity_df_event_timestamp_range = (res[0], res[1])
else:
raise InvalidEntityType(type(entity_df))

sdk/python/feast/infra/offline_stores/contrib/postgres_offline_store/postgres_source.py
@@ -4,7 +4,7 @@
from typeguard import typechecked

from feast.data_source import DataSource
from feast.errors import DataSourceNoNameException
from feast.errors import DataSourceNoNameException, ZeroColumnQueryResult
from feast.infra.utils.postgres.connection_utils import _get_conn
from feast.protos.feast.core.DataSource_pb2 import DataSource as DataSourceProto
from feast.protos.feast.core.SavedDataset_pb2 import (
@@ -111,7 +111,11 @@ def get_table_column_names_and_types(
self, config: RepoConfig
) -> Iterable[Tuple[str, str]]:
with _get_conn(config.offline_store) as conn, conn.cursor() as cur:
cur.execute(f"SELECT * FROM {self.get_table_query_string()} AS sub LIMIT 0")
query = f"SELECT * FROM {self.get_table_query_string()} AS sub LIMIT 0"
cur.execute(query)
if not cur.description:
raise ZeroColumnQueryResult(query)

return (
(c.name, pg_type_code_to_pg_type(c.type_code)) for c in cur.description
)
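
The LIMIT 0 query above is a cheap introspection trick: it returns no rows, but the driver still populates cursor.description with one entry per column. A standalone sketch with a hypothetical DSN and table:

# Hypothetical DSN and table, for illustration only.
import psycopg

with psycopg.connect("dbname=feast") as conn, conn.cursor() as cur:
    cur.execute("SELECT * FROM driver_stats AS sub LIMIT 0")
    for col in cur.description or []:  # zero rows, full column metadata
        print(col.name, col.type_code)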
120 changes: 66 additions & 54 deletions sdk/python/feast/infra/online_stores/contrib/postgres.py
@@ -2,13 +2,22 @@
import logging
from collections import defaultdict
from datetime import datetime
from typing import Any, Callable, Dict, List, Literal, Optional, Sequence, Tuple
from typing import (
Any,
Callable,
Dict,
Generator,
List,
Literal,
Optional,
Sequence,
Tuple,
)

import psycopg2
import pytz
from psycopg2 import sql
from psycopg2.extras import execute_values
from psycopg2.pool import SimpleConnectionPool
from psycopg import sql
from psycopg.connection import Connection
from psycopg_pool import ConnectionPool

from feast import Entity
from feast.feature_view import FeatureView
@@ -39,15 +48,17 @@ class PostgreSQLOnlineStoreConfig(PostgreSQLConfig):


class PostgreSQLOnlineStore(OnlineStore):
_conn: Optional[psycopg2._psycopg.connection] = None
_conn_pool: Optional[SimpleConnectionPool] = None
_conn: Optional[Connection] = None
_conn_pool: Optional[ConnectionPool] = None

@contextlib.contextmanager
def _get_conn(self, config: RepoConfig):
def _get_conn(self, config: RepoConfig) -> Generator[Connection, Any, Any]:
assert config.online_store.type == "postgres"

if config.online_store.conn_type == ConnectionType.pool:
if not self._conn_pool:
self._conn_pool = _get_connection_pool(config.online_store)
self._conn_pool.open()
connection = self._conn_pool.getconn()
yield connection
self._conn_pool.putconn(connection)
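
For reference, psycopg_pool.ConnectionPool replaces psycopg2's SimpleConnectionPool, and a pool constructed with open=False must be opened explicitly before use; the added self._conn_pool.open() call does this. A standalone sketch with a placeholder DSN:

# Hypothetical DSN, for illustration only.
from psycopg_pool import ConnectionPool

pool = ConnectionPool("dbname=feast", min_size=1, open=False)
pool.open()  # required when the pool was constructed with open=False

conn = pool.getconn()  # borrow a connection, as _get_conn does above
try:
    with conn.cursor() as cur:
        cur.execute("SELECT 1")
finally:
    pool.putconn(conn)  # always hand the connection back

pool.close()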
@@ -64,57 +75,56 @@ def online_write_batch(
Tuple[EntityKeyProto, Dict[str, ValueProto], datetime, Optional[datetime]]
],
progress: Optional[Callable[[int], Any]],
batch_size: int = 5000,
Author comment: Make configurable, addressing feast-dev#4036.

) -> None:
project = config.project
# Format insert values
insert_values = []
for entity_key, values, timestamp, created_ts in data:
entity_key_bin = serialize_entity_key(
entity_key,
entity_key_serialization_version=config.entity_key_serialization_version,
)
timestamp = _to_naive_utc(timestamp)
if created_ts is not None:
created_ts = _to_naive_utc(created_ts)

with self._get_conn(config) as conn, conn.cursor() as cur:
insert_values = []
for entity_key, values, timestamp, created_ts in data:
entity_key_bin = serialize_entity_key(
entity_key,
entity_key_serialization_version=config.entity_key_serialization_version,
)
timestamp = _to_naive_utc(timestamp)
if created_ts is not None:
created_ts = _to_naive_utc(created_ts)

for feature_name, val in values.items():
vector_val = None
if config.online_store.pgvector_enabled:
vector_val = get_list_val_str(val)
insert_values.append(
(
entity_key_bin,
feature_name,
val.SerializeToString(),
vector_val,
timestamp,
created_ts,
)
for feature_name, val in values.items():
vector_val = None
if config.online_store.pgvector_enabled:
vector_val = get_list_val_str(val)
insert_values.append(
(
entity_key_bin,
feature_name,
val.SerializeToString(),
vector_val,
timestamp,
created_ts,
)
# Control the batch so that we can update the progress
batch_size = 5000
)

# Create insert query
sql_query = sql.SQL(
"""
INSERT INTO {}
(entity_key, feature_name, value, vector_value, event_ts, created_ts)
VALUES (%s, %s, %s, %s, %s, %s)
Author comment: 1 out of 2 actual changes to the function: we need to explicitly set the number of placeholder values.

ON CONFLICT (entity_key, feature_name) DO
UPDATE SET
value = EXCLUDED.value,
vector_value = EXCLUDED.vector_value,
event_ts = EXCLUDED.event_ts,
created_ts = EXCLUDED.created_ts;
"""
).format(sql.Identifier(_table_id(config.project, table)))
Author comment on lines +80 to +119: No changes here, only moving code further up in the function to make it more readable.


# Push data in batches to online store
with self._get_conn(config) as conn, conn.cursor() as cur:
for i in range(0, len(insert_values), batch_size):
cur_batch = insert_values[i : i + batch_size]
execute_values(
cur,
sql.SQL(
"""
INSERT INTO {}
(entity_key, feature_name, value, vector_value, event_ts, created_ts)
VALUES %s
ON CONFLICT (entity_key, feature_name) DO
UPDATE SET
value = EXCLUDED.value,
vector_value = EXCLUDED.vector_value,
event_ts = EXCLUDED.event_ts,
created_ts = EXCLUDED.created_ts;
""",
).format(sql.Identifier(_table_id(project, table))),
cur_batch,
page_size=batch_size,
)
cur.executemany(sql_query, cur_batch)
Author comment: 2 out of 2 actual changes to the function: the psycopg2.extras.execute_values functionality is removed in psycopg3. The maintainer of psycopg3 advises using executemany instead. See psycopg/psycopg#576 and psycopg/psycopg#114.

conn.commit()

if progress:
progress(len(cur_batch))
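
A standalone sketch of the migration this comment describes (hypothetical table and DSN): psycopg2's execute_values expanded a single VALUES %s placeholder into many row tuples client-side, while psycopg3's executemany repeats a statement whose placeholders match exactly one row, batching efficiently (pipeline mode in psycopg >= 3.1).

# Hypothetical table and DSN, for illustration only.
import psycopg

rows = [(b"key1", "feat_a", b"\x01"), (b"key2", "feat_b", b"\x02")]

# psycopg2 (API removed in psycopg3):
#   from psycopg2.extras import execute_values
#   execute_values(cur, "INSERT INTO t (k, f, v) VALUES %s", rows)

# psycopg3: one row's worth of placeholders, repeated by executemany.
with psycopg.connect("dbname=feast") as conn, conn.cursor() as cur:
    batch_size = 2  # mirrors the now-configurable batch_size parameter
    for i in range(0, len(rows), batch_size):
        cur.executemany(
            "INSERT INTO t (k, f, v) VALUES (%s, %s, %s)",
            rows[i : i + batch_size],
        )
    conn.commit()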

@@ -172,7 +182,9 @@ def online_read(
# when we iterate through the keys since they are in the correct order
values_dict = defaultdict(list)
for row in rows if rows is not None else []:
values_dict[row[0].tobytes()].append(row[1:])
values_dict[
row[0] if isinstance(row[0], bytes) else row[0].tobytes()
Author comment: Only call tobytes() when row[0] is not already of bytes type; otherwise this will result in errors.

].append(row[1:])

for key in keys:
if key in values_dict:
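
Background for the comment above (a standalone sketch, assuming default result formats): psycopg2 returned bytea columns as memoryview, while psycopg3 commonly returns bytes, which has no tobytes() method, hence the isinstance check.

# Illustration of the defensive conversion used above.
def as_bytes(value) -> bytes:
    return value if isinstance(value, bytes) else value.tobytes()

print(as_bytes(b"abc"))              # bytes pass through unchanged
print(as_bytes(memoryview(b"abc")))  # memoryview is converted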