From cea52e9fb02cb9e0b8f48206278474f5a5fa167e Mon Sep 17 00:00:00 2001 From: Tom Steenbergen <41334387+TomSteenbergen@users.noreply.github.com> Date: Mon, 8 Jul 2024 12:29:43 +0200 Subject: [PATCH] feat: Add async feature retrieval for Postgres Online Store (#4327) * Add async retrieval for postgres Signed-off-by: TomSteenbergen * Format Signed-off-by: TomSteenbergen * Update _prepare_keys method Signed-off-by: TomSteenbergen * Fix typo Signed-off-by: TomSteenbergen --------- Signed-off-by: TomSteenbergen --- .../infra/online_stores/contrib/postgres.py | 186 ++++++++++++------ .../infra/utils/postgres/connection_utils.py | 25 ++- .../online_store/test_universal_online.py | 2 +- 3 files changed, 150 insertions(+), 63 deletions(-) diff --git a/sdk/python/feast/infra/online_stores/contrib/postgres.py b/sdk/python/feast/infra/online_stores/contrib/postgres.py index 330b50bc78..ff73a4a347 100644 --- a/sdk/python/feast/infra/online_stores/contrib/postgres.py +++ b/sdk/python/feast/infra/online_stores/contrib/postgres.py @@ -4,6 +4,7 @@ from datetime import datetime from typing import ( Any, + AsyncGenerator, Callable, Dict, Generator, @@ -12,18 +13,24 @@ Optional, Sequence, Tuple, + Union, ) import pytz -from psycopg import sql +from psycopg import AsyncConnection, sql from psycopg.connection import Connection -from psycopg_pool import ConnectionPool +from psycopg_pool import AsyncConnectionPool, ConnectionPool from feast import Entity from feast.feature_view import FeatureView from feast.infra.key_encoding_utils import get_list_val_str, serialize_entity_key from feast.infra.online_stores.online_store import OnlineStore -from feast.infra.utils.postgres.connection_utils import _get_conn, _get_connection_pool +from feast.infra.utils.postgres.connection_utils import ( + _get_conn, + _get_conn_async, + _get_connection_pool, + _get_connection_pool_async, +) from feast.infra.utils.postgres.postgres_config import ConnectionType, PostgreSQLConfig from 
feast.protos.feast.types.EntityKey_pb2 import EntityKey as EntityKeyProto from feast.protos.feast.types.Value_pb2 import Value as ValueProto @@ -51,6 +58,9 @@ class PostgreSQLOnlineStore(OnlineStore): _conn: Optional[Connection] = None _conn_pool: Optional[ConnectionPool] = None + _conn_async: Optional[AsyncConnection] = None + _conn_pool_async: Optional[AsyncConnectionPool] = None + @contextlib.contextmanager def _get_conn(self, config: RepoConfig) -> Generator[Connection, Any, Any]: assert config.online_store.type == "postgres" @@ -67,6 +77,24 @@ def _get_conn(self, config: RepoConfig) -> Generator[Connection, Any, Any]: self._conn = _get_conn(config.online_store) yield self._conn + @contextlib.asynccontextmanager + async def _get_conn_async( + self, config: RepoConfig + ) -> AsyncGenerator[AsyncConnection, Any]: + if config.online_store.conn_type == ConnectionType.pool: + if not self._conn_pool_async: + self._conn_pool_async = await _get_connection_pool_async( + config.online_store + ) + await self._conn_pool_async.open() + connection = await self._conn_pool_async.getconn() + yield connection + await self._conn_pool_async.putconn(connection) + else: + if not self._conn_async: + self._conn_async = await _get_conn_async(config.online_store) + yield self._conn_async + def online_write_batch( self, config: RepoConfig, @@ -132,69 +160,107 @@ def online_read( entity_keys: List[EntityKeyProto], requested_features: Optional[List[str]] = None, ) -> List[Tuple[Optional[datetime], Optional[Dict[str, ValueProto]]]]: - result: List[Tuple[Optional[datetime], Optional[Dict[str, ValueProto]]]] = [] + keys = self._prepare_keys(entity_keys, config.entity_key_serialization_version) + query, params = self._construct_query_and_params( + config, table, keys, requested_features + ) - project = config.project with self._get_conn(config) as conn, conn.cursor() as cur: - # Collecting all the keys to a list allows us to make fewer round trips - # to PostgreSQL - keys = [] - for entity_key 
in entity_keys: - keys.append( - serialize_entity_key( - entity_key, - entity_key_serialization_version=config.entity_key_serialization_version, - ) - ) + cur.execute(query, params) + rows = cur.fetchall() - if not requested_features: - cur.execute( - sql.SQL( - """ - SELECT entity_key, feature_name, value, event_ts - FROM {} WHERE entity_key = ANY(%s); - """ - ).format( - sql.Identifier(_table_id(project, table)), - ), - (keys,), - ) - else: - cur.execute( - sql.SQL( - """ - SELECT entity_key, feature_name, value, event_ts - FROM {} WHERE entity_key = ANY(%s) and feature_name = ANY(%s); - """ - ).format( - sql.Identifier(_table_id(project, table)), - ), - (keys, requested_features), - ) + return self._process_rows(keys, rows) - rows = cur.fetchall() + async def online_read_async( + self, + config: RepoConfig, + table: FeatureView, + entity_keys: List[EntityKeyProto], + requested_features: Optional[List[str]] = None, + ) -> List[Tuple[Optional[datetime], Optional[Dict[str, ValueProto]]]]: + keys = self._prepare_keys(entity_keys, config.entity_key_serialization_version) + query, params = self._construct_query_and_params( + config, table, keys, requested_features + ) - # Since we don't know the order returned from PostgreSQL we'll need - # to construct a dict to be able to quickly look up the correct row - # when we iterate through the keys since they are in the correct order - values_dict = defaultdict(list) - for row in rows if rows is not None else []: - values_dict[ - row[0] if isinstance(row[0], bytes) else row[0].tobytes() - ].append(row[1:]) - - for key in keys: - if key in values_dict: - value = values_dict[key] - res = {} - for feature_name, value_bin, event_ts in value: - val = ValueProto() - val.ParseFromString(bytes(value_bin)) - res[feature_name] = val - result.append((event_ts, res)) - else: - result.append((None, None)) + async with self._get_conn_async(config) as conn: + async with conn.cursor() as cur: + await cur.execute(query, params) + rows = 
await cur.fetchall()
+
+        return self._process_rows(keys, rows)
+
+    @staticmethod
+    def _construct_query_and_params(
+        config: RepoConfig,
+        table: FeatureView,
+        keys: List[bytes],
+        requested_features: Optional[List[str]] = None,
+    ) -> Tuple[sql.Composed, Union[Tuple[List[bytes], List[str]], Tuple[List[bytes]]]]:
+        """Construct the SQL query based on the given parameters."""
+        if requested_features:
+            query = sql.SQL(
+                """
+                SELECT entity_key, feature_name, value, event_ts
+                FROM {} WHERE entity_key = ANY(%s) AND feature_name = ANY(%s);
+                """
+            ).format(
+                sql.Identifier(_table_id(config.project, table)),
+            )
+            params = (keys, requested_features)
+        else:
+            query = sql.SQL(
+                """
+                SELECT entity_key, feature_name, value, event_ts
+                FROM {} WHERE entity_key = ANY(%s);
+                """
+            ).format(
+                sql.Identifier(_table_id(config.project, table)),
+            )
+            params = (keys, [])
+        return query, params
+
+    @staticmethod
+    def _prepare_keys(
+        entity_keys: List[EntityKeyProto], entity_key_serialization_version: int
+    ) -> List[bytes]:
+        """Prepare all keys in a list to make fewer round trips to the database."""
+        return [
+            serialize_entity_key(
+                entity_key,
+                entity_key_serialization_version=entity_key_serialization_version,
+            )
+            for entity_key in entity_keys
+        ]
+
+    @staticmethod
+    def _process_rows(
+        keys: List[bytes], rows: List[Tuple]
+    ) -> List[Tuple[Optional[datetime], Optional[Dict[str, ValueProto]]]]:
+        """Transform the retrieved rows into the desired output.
+        PostgreSQL may return rows in an unpredictable order. Therefore, `values_dict`
+        is created to quickly look up the correct row using the keys, since these are
+        actually in the correct order.
+ """ + values_dict = defaultdict(list) + for row in rows if rows is not None else []: + values_dict[ + row[0] if isinstance(row[0], bytes) else row[0].tobytes() + ].append(row[1:]) + + result: List[Tuple[Optional[datetime], Optional[Dict[str, ValueProto]]]] = [] + for key in keys: + if key in values_dict: + value = values_dict[key] + res = {} + for feature_name, value_bin, event_ts in value: + val = ValueProto() + val.ParseFromString(bytes(value_bin)) + res[feature_name] = val + result.append((event_ts, res)) + else: + result.append((None, None)) return result def update( diff --git a/sdk/python/feast/infra/utils/postgres/connection_utils.py b/sdk/python/feast/infra/utils/postgres/connection_utils.py index e0599019b9..7b37ea981f 100644 --- a/sdk/python/feast/infra/utils/postgres/connection_utils.py +++ b/sdk/python/feast/infra/utils/postgres/connection_utils.py @@ -4,8 +4,8 @@ import pandas as pd import psycopg import pyarrow as pa -from psycopg.connection import Connection -from psycopg_pool import ConnectionPool +from psycopg import AsyncConnection, Connection +from psycopg_pool import AsyncConnectionPool, ConnectionPool from feast.infra.utils.postgres.postgres_config import PostgreSQLConfig from feast.type_map import arrow_to_pg_type @@ -21,6 +21,16 @@ def _get_conn(config: PostgreSQLConfig) -> Connection: return conn +async def _get_conn_async(config: PostgreSQLConfig) -> AsyncConnection: + """Get a psycopg `AsyncConnection`.""" + conn = await psycopg.AsyncConnection.connect( + conninfo=_get_conninfo(config), + keepalives_idle=config.keepalives_idle, + **_get_conn_kwargs(config), + ) + return conn + + def _get_connection_pool(config: PostgreSQLConfig) -> ConnectionPool: """Get a psycopg `ConnectionPool`.""" return ConnectionPool( @@ -32,6 +42,17 @@ def _get_connection_pool(config: PostgreSQLConfig) -> ConnectionPool: ) +async def _get_connection_pool_async(config: PostgreSQLConfig) -> AsyncConnectionPool: + """Get a psycopg `AsyncConnectionPool`.""" + return 
AsyncConnectionPool( + conninfo=_get_conninfo(config), + min_size=config.min_conn, + max_size=config.max_conn, + open=False, + kwargs=_get_conn_kwargs(config), + ) + + def _get_conninfo(config: PostgreSQLConfig) -> str: """Get the `conninfo` argument required for connection objects.""" return ( diff --git a/sdk/python/tests/integration/online_store/test_universal_online.py b/sdk/python/tests/integration/online_store/test_universal_online.py index 38656b90a9..2ffe869ef5 100644 --- a/sdk/python/tests/integration/online_store/test_universal_online.py +++ b/sdk/python/tests/integration/online_store/test_universal_online.py @@ -488,7 +488,7 @@ def test_online_retrieval_with_event_timestamps(environment, universal_data_sour @pytest.mark.integration -@pytest.mark.universal_online_stores(only=["redis", "dynamodb"]) +@pytest.mark.universal_online_stores(only=["redis", "dynamodb", "postgres"]) def test_async_online_retrieval_with_event_timestamps( environment, universal_data_sources ):