-
Notifications
You must be signed in to change notification settings - Fork 0
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat: Bump psycopg2 to psycopg3 for all Postgres components #1
Changes from all commits
c8fae69
161fd9d
df0f440
c200254
8618579
86c22e5
8fb9224
0393d6d
2a7ad9a
995fc08
6f1a436
bfb5e60
46f7c96
3d3b88c
6625a2c
d0f77ca
b7de4b2
60f5a23
f9fa158
e2d2188
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,13 +2,22 @@ | |
import logging | ||
from collections import defaultdict | ||
from datetime import datetime | ||
from typing import Any, Callable, Dict, List, Literal, Optional, Sequence, Tuple | ||
from typing import ( | ||
Any, | ||
Callable, | ||
Dict, | ||
Generator, | ||
List, | ||
Literal, | ||
Optional, | ||
Sequence, | ||
Tuple, | ||
) | ||
|
||
import psycopg2 | ||
import pytz | ||
from psycopg2 import sql | ||
from psycopg2.extras import execute_values | ||
from psycopg2.pool import SimpleConnectionPool | ||
from psycopg import sql | ||
from psycopg.connection import Connection | ||
from psycopg_pool import ConnectionPool | ||
|
||
from feast import Entity | ||
from feast.feature_view import FeatureView | ||
|
@@ -39,15 +48,17 @@ class PostgreSQLOnlineStoreConfig(PostgreSQLConfig): | |
|
||
|
||
class PostgreSQLOnlineStore(OnlineStore): | ||
_conn: Optional[psycopg2._psycopg.connection] = None | ||
_conn_pool: Optional[SimpleConnectionPool] = None | ||
_conn: Optional[Connection] = None | ||
_conn_pool: Optional[ConnectionPool] = None | ||
|
||
@contextlib.contextmanager | ||
def _get_conn(self, config: RepoConfig): | ||
def _get_conn(self, config: RepoConfig) -> Generator[Connection, Any, Any]: | ||
assert config.online_store.type == "postgres" | ||
|
||
if config.online_store.conn_type == ConnectionType.pool: | ||
if not self._conn_pool: | ||
self._conn_pool = _get_connection_pool(config.online_store) | ||
self._conn_pool.open() | ||
connection = self._conn_pool.getconn() | ||
yield connection | ||
self._conn_pool.putconn(connection) | ||
|
@@ -64,57 +75,56 @@ def online_write_batch( | |
Tuple[EntityKeyProto, Dict[str, ValueProto], datetime, Optional[datetime]] | ||
], | ||
progress: Optional[Callable[[int], Any]], | ||
batch_size: int = 5000, | ||
) -> None: | ||
project = config.project | ||
# Format insert values | ||
insert_values = [] | ||
for entity_key, values, timestamp, created_ts in data: | ||
entity_key_bin = serialize_entity_key( | ||
entity_key, | ||
entity_key_serialization_version=config.entity_key_serialization_version, | ||
) | ||
timestamp = _to_naive_utc(timestamp) | ||
if created_ts is not None: | ||
created_ts = _to_naive_utc(created_ts) | ||
|
||
with self._get_conn(config) as conn, conn.cursor() as cur: | ||
insert_values = [] | ||
for entity_key, values, timestamp, created_ts in data: | ||
entity_key_bin = serialize_entity_key( | ||
entity_key, | ||
entity_key_serialization_version=config.entity_key_serialization_version, | ||
) | ||
timestamp = _to_naive_utc(timestamp) | ||
if created_ts is not None: | ||
created_ts = _to_naive_utc(created_ts) | ||
|
||
for feature_name, val in values.items(): | ||
vector_val = None | ||
if config.online_store.pgvector_enabled: | ||
vector_val = get_list_val_str(val) | ||
insert_values.append( | ||
( | ||
entity_key_bin, | ||
feature_name, | ||
val.SerializeToString(), | ||
vector_val, | ||
timestamp, | ||
created_ts, | ||
) | ||
for feature_name, val in values.items(): | ||
vector_val = None | ||
if config.online_store.pgvector_enabled: | ||
vector_val = get_list_val_str(val) | ||
insert_values.append( | ||
( | ||
entity_key_bin, | ||
feature_name, | ||
val.SerializeToString(), | ||
vector_val, | ||
timestamp, | ||
created_ts, | ||
) | ||
# Control the batch so that we can update the progress | ||
batch_size = 5000 | ||
) | ||
|
||
# Create insert query | ||
sql_query = sql.SQL( | ||
""" | ||
INSERT INTO {} | ||
(entity_key, feature_name, value, vector_value, event_ts, created_ts) | ||
VALUES (%s, %s, %s, %s, %s, %s) | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. 1 out of 2 actual changes to the function: We need to explicitly set the number of placeholder values, i.e. one `%s` per inserted column (six in total) instead of the single `VALUES %s` that `execute_values` expanded for us. |
||
ON CONFLICT (entity_key, feature_name) DO | ||
UPDATE SET | ||
value = EXCLUDED.value, | ||
vector_value = EXCLUDED.vector_value, | ||
event_ts = EXCLUDED.event_ts, | ||
created_ts = EXCLUDED.created_ts; | ||
""" | ||
).format(sql.Identifier(_table_id(config.project, table))) | ||
Comment on lines
+80
to
+119
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. No changes here, only moving code further up in the function to make it more readable. |
||
|
||
# Push data in batches to online store | ||
with self._get_conn(config) as conn, conn.cursor() as cur: | ||
for i in range(0, len(insert_values), batch_size): | ||
cur_batch = insert_values[i : i + batch_size] | ||
execute_values( | ||
cur, | ||
sql.SQL( | ||
""" | ||
INSERT INTO {} | ||
(entity_key, feature_name, value, vector_value, event_ts, created_ts) | ||
VALUES %s | ||
ON CONFLICT (entity_key, feature_name) DO | ||
UPDATE SET | ||
value = EXCLUDED.value, | ||
vector_value = EXCLUDED.vector_value, | ||
event_ts = EXCLUDED.event_ts, | ||
created_ts = EXCLUDED.created_ts; | ||
""", | ||
).format(sql.Identifier(_table_id(project, table))), | ||
cur_batch, | ||
page_size=batch_size, | ||
) | ||
cur.executemany(sql_query, cur_batch) | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. 2 out of 2 actual changes to the function: The `execute_values` helper from `psycopg2.extras` is replaced with `cur.executemany`. |
||
conn.commit() | ||
|
||
if progress: | ||
progress(len(cur_batch)) | ||
|
||
|
@@ -172,7 +182,9 @@ def online_read( | |
# when we iterate through the keys since they are in the correct order | ||
values_dict = defaultdict(list) | ||
for row in rows if rows is not None else []: | ||
values_dict[row[0].tobytes()].append(row[1:]) | ||
values_dict[ | ||
row[0] if isinstance(row[0], bytes) else row[0].tobytes() | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Only call `tobytes()` when `row[0]` is not already a `bytes` object. |
||
].append(row[1:]) | ||
|
||
for key in keys: | ||
if key in values_dict: | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Make `batch_size` configurable, addressing feast-dev#4036.