diff --git a/wyvern/feature_store/feature_server.py b/wyvern/feature_store/feature_server.py
index c09a90b..28e7b8b 100644
--- a/wyvern/feature_store/feature_server.py
+++ b/wyvern/feature_store/feature_server.py
@@ -1,6 +1,7 @@
 # -*- coding: utf-8 -*-
 import importlib
 import logging
+import secrets
 import time
 import traceback
 from collections import defaultdict
@@ -28,11 +29,14 @@
 from feast.value_type import ValueType
 from google.protobuf.json_format import MessageToDict
 
+from wyvern.clients.snowflake import generate_snowflake_ctx
 from wyvern.components.features.realtime_features_component import (
     RealtimeFeatureComponent,
 )
 from wyvern.config import settings
 from wyvern.feature_store.historical_feature_util import (
+    build_and_merge_feast_tables,
+    build_and_merge_realtime_pivot_tables,
     build_historical_real_time_feature_requests,
     build_historical_registry_feature_requests,
     process_historical_real_time_features_requests,
@@ -41,7 +45,9 @@
 )
 from wyvern.feature_store.schemas import (
     GetHistoricalFeaturesRequest,
+    GetHistoricalFeaturesRequestV2,
     GetHistoricalFeaturesResponse,
+    GetHistoricalFeaturesResponseV2,
     GetOnlineFeaturesRequest,
     MaterializeRequest,
 )
@@ -456,6 +462,62 @@ async def get_historical_features(
             results=final_df.to_dict(orient="records"),
         )
 
+    @app.post(f"{settings.WYVERN_HISTORICAL_FEATURES_PATH}_v2")
+    async def get_historical_features_v2(
+        data: GetHistoricalFeaturesRequestV2,
+    ) -> GetHistoricalFeaturesResponseV2:
+        # Generate a 10-character hex id to namespace this request's temporary tables
+        random_id = secrets.token_hex(5)
+
+        # split the requested features into real-time features and feast registry features
+        realtime_features, feast_features = separate_real_time_features(data.features)
+        valid_realtime_features: List[str] = []
+        composite_entities: Dict[str, List[str]] = {}
+        for realtime_feature in realtime_features:
+            entity_type_column = RealtimeFeatureComponent.get_entity_type_column(
+                realtime_feature,
+            )
+            entity_names = RealtimeFeatureComponent.get_entity_names(realtime_feature)
+            if not entity_type_column or not entity_names:
+                logger.warning(f"feature={realtime_feature} is not found")
+                continue
+
+            if len(entity_names) == 2:
+                composite_entities[entity_type_column] = entity_names
+            valid_realtime_features.append(realtime_feature)
+
+        composite_columns = ",".join(
+            [
+                " || ':' || ".join(entities) + f" AS {entity_type_column}"
+                for entity_type_column, entities in composite_entities.items()
+            ],
+        )
+        composite_historical_feature_table = f"HISTORICAL_FEATURES_{random_id}"
+
+        select_sql = f"""
+        CREATE TEMPORARY TABLE {composite_historical_feature_table} AS
+        SELECT *, {composite_columns}, TIMESTAMP as event_timestamp
+        FROM {data.table}
+        """
+        snowflake_ctx = generate_snowflake_ctx()
+        snowflake_ctx.cursor().execute(select_sql)
+
+        result_table = build_and_merge_realtime_pivot_tables(
+            valid_realtime_features,
+            data.table,
+            composite_historical_feature_table,
+            snowflake_ctx,
+        )
+        merged_table = build_and_merge_feast_tables(
+            store,
+            feast_features,
+            result_table,
+            snowflake_ctx,
+        )
+        return GetHistoricalFeaturesResponseV2(
+            result_table=merged_table,
+        )
+
     return app
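Note: a minimal sketch of calling the new endpoint, assuming the feature server
runs locally on port 8000; the path prefix, table name, and feature names below
are hypothetical, not confirmed by this diff:

    import requests

    response = requests.post(
        "http://localhost:8000/feature-store/historical-features_v2",  # hypothetical path
        json={
            # Snowflake table with a REQUEST column, entity columns, and a TIMESTAMP column
            "table": "MY_ENTITY_TABLE",  # hypothetical
            "features": [
                "product_fv:price",             # hypothetical feast registry feature
                "realtime_product:similarity",  # hypothetical real-time feature
            ],
        },
    )
    # The response carries only the name of the Snowflake table holding the result.
    print(response.json()["result_table"])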
diff --git a/wyvern/feature_store/historical_feature_util.py b/wyvern/feature_store/historical_feature_util.py
index ac7f8ed..0d1b0c2 100644
--- a/wyvern/feature_store/historical_feature_util.py
+++ b/wyvern/feature_store/historical_feature_util.py
@@ -1,5 +1,6 @@
 # -*- coding: utf-8 -*-
 import logging
+import re
 from collections import defaultdict
 from datetime import datetime
 from typing import Any, Dict, List, Optional, Tuple
@@ -117,6 +118,100 @@ def build_historical_real_time_feature_requests(
     return result_dict
 
 
+def build_and_merge_realtime_pivot_tables(
+    full_feature_names: List[str],
+    input_table: str,
+    composite_table: str,
+    context: SnowflakeConnection,
+) -> str:
+    """
+    Pivot logged real-time features into the composite table, one entity type at a time.
+
+    Args:
+        full_feature_names: a list of full real-time feature names.
+        input_table: the table whose REQUEST column scopes the feature log rows.
+        composite_table: the temporary composite table to merge features into.
+        context: an open Snowflake connection.
+
+    Returns:
+        The name of the last temporary table, with all real-time features joined in.
+    """
+    features_grouped_by_entity = group_realtime_features_by_entity_type(
+        full_feature_names=full_feature_names,
+    )
+    counter = 0
+
+    # prev_table is the previous temporary composite table
+    prev_table = composite_table
+    # next_table is the next temporary composite table joined with the next entity type
+    next_table = f"{composite_table}_0"
+
+    # iterate through all the entity types.
+    # For each entity type, build a new temporary composite table with all the features for this entity type
+    for (
+        entity_identifier_type,
+        curr_feature_names,
+    ) in features_grouped_by_entity.items():
+        entity_list = entity_identifier_type.split(SQL_COLUMN_SEPARATOR)
+
+        if len(entity_list) > 2:
+            logger.warning(f"Invalid entity_identifier_type={entity_identifier_type}")
+            continue
+        curr_feature_names_underscore = [
+            fn.replace(":", "__", 1) for fn in curr_feature_names
+        ]
+        entity_identifier_type_val = ":".join(entity_list)
+        feature_names_sql_str = ",".join(
+            [f"'{fn}'" for fn in curr_feature_names_underscore],
+        )
+        feature_names_with_pivot_table_str = ",".join(
+            [
+                f"PIVOT_TABLE.{feature_name}"
+                for feature_name in curr_feature_names_underscore
+            ],
+        )
+        feature_names_pivot_raw = ",".join(
+            [f"\"'{fn}'\" as {fn}" for fn in curr_feature_names_underscore],
+        )
+
+        pivot_sql = f"""
+        CREATE TEMPORARY TABLE {next_table} AS (
+            WITH PIVOT_DATA AS (
+                SELECT F.REQUEST_ID AS REQUEST,
+                    F.API_SOURCE,
+                    F.EVENT_TYPE,
+                    F.FEATURE_IDENTIFIER,
+                    F.FEATURE_IDENTIFIER_TYPE,
+                    REPLACE(F.FEATURE_NAME, ':', '__') AS FEATURE_NAME,
+                    F.FEATURE_VALUE
+                FROM FEATURE_LOGS F
+                INNER JOIN (SELECT DISTINCT REQUEST FROM {input_table}) T
+                ON F.REQUEST_ID = T.REQUEST
+                WHERE F.FEATURE_IDENTIFIER_TYPE = '{entity_identifier_type_val}'
+            ), PIVOT_TABLE_RAW AS (
+                SELECT *
+                FROM PIVOT_DATA
+                PIVOT(MAX(FEATURE_VALUE) FOR FEATURE_NAME IN ({feature_names_sql_str}))
+            ), PIVOT_TABLE AS (
+                SELECT REQUEST, FEATURE_IDENTIFIER, FEATURE_IDENTIFIER_TYPE, {feature_names_pivot_raw}
+                FROM PIVOT_TABLE_RAW
+            )
+            SELECT
+                {prev_table}.*,{feature_names_with_pivot_table_str}
+            FROM
+                {prev_table}
+            LEFT JOIN PIVOT_TABLE ON
+                {prev_table}.REQUEST = PIVOT_TABLE.REQUEST AND
+                {prev_table}.{entity_identifier_type} = PIVOT_TABLE.FEATURE_IDENTIFIER
+        )
+        """
+        context.cursor().execute(pivot_sql)
+        counter += 1
+        prev_table = next_table
+        next_table = f"{composite_table}_{counter}"
+    return prev_table
+
+
 def process_historical_real_time_features_requests(
     requests: Dict[str, RequestEntityIdentifierObjects],
 ) -> Dict[str, pd.DataFrame]:
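Note: each entity type costs one CREATE TEMPORARY TABLE round trip, chaining
numeric suffixes onto the composite table name. A minimal sketch of just the
naming pattern, with hypothetical inputs:

    # Hypothetical: two entity identifier types grouped from the requested features.
    composite_table = "HISTORICAL_FEATURES_0123456789"
    entity_types = ["product", "product__user"]

    prev_table = composite_table
    for counter, entity_type in enumerate(entity_types):
        next_table = f"{composite_table}_{counter}"
        # execute: CREATE TEMPORARY TABLE {next_table} AS (... LEFT JOIN PIVOT_TABLE ...)
        prev_table = next_table

    # prev_table == "HISTORICAL_FEATURES_0123456789_1" is what the function returns;
    # build_and_merge_feast_tables then continues the same chain from it.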
@@ -316,6 +411,109 @@ def build_historical_registry_feature_requests(
     return requests
 
 
+def build_and_merge_feast_tables(
+    store: FeatureStore,
+    feature_names: List[str],
+    composite_table: str,
+    context: SnowflakeConnection,
+) -> str:
+    features_grouped_by_entities = group_registry_features_by_entities(
+        feature_names,
+        store=store,
+    )
+    counter = 0
+    prev_table = composite_table
+    next_table = f"{composite_table}_0"
+    for entity_name, feature_names in features_grouped_by_entities.items():
+        if not feature_names:
+            continue
+
+        if FULL_FEATURE_NAME_SEPARATOR in entity_name:
+            entities = entity_name.split(FULL_FEATURE_NAME_SEPARATOR)
+        else:
+            entities = [entity_name]
+
+        if len(entities) > 2:
+            raise ValueError(
+                f"Entity name should be singular or composite: {entity_name}",
+            )
+
+        feature_columns = [fn.replace(":", "__") for fn in feature_names]
+
+        # TODO: validate that all entities are in the entity_df_table
+        # for entity in entities:
+        #     if entity not in entity_values:
+        #         raise ValueError(
+        #             f"{feature_names} depends on {entity}. Could not find entity values: {entity}",
+        #         )
+        identifier_column = SQL_COLUMN_SEPARATOR.join(entities)
+        identifier_table_sql_dupe = f"""
+        SELECT
+            {identifier_column} AS IDENTIFIER,
+            event_timestamp,
+            ROW_NUMBER() OVER (PARTITION BY IDENTIFIER, event_timestamp ORDER BY (SELECT NULL)) as rn
+        FROM {composite_table}
+        WHERE {identifier_column} is NOT NULL
+        """
+
+        # dedupe (IDENTIFIER, event_timestamp)
+        identifier_table_sql = f"""
+        WITH identifier_table_sql_dupe AS ({identifier_table_sql_dupe})
+        SELECT IDENTIFIER, event_timestamp
+        FROM identifier_table_sql_dupe
+        WHERE rn = 1
+        """
+        result = store.get_historical_features(
+            entity_df=identifier_table_sql,
+            features=feature_names or [],
+            full_feature_names=True,
+        )
+        result_sql = result.to_sql()
+        # Strip only the leading "WITH " so the query can be spliced into our WITH clause below
+        result_sql = result_sql.replace("WITH ", "", 1)
+        # Replace the table name with 'identifier_tbl', assuming the table name is always
+        # in the format "feast_entity_df_" followed by a hex string (UUID without dashes)
+        result_sql = re.sub(
+            r'"feast_entity_df_[0-9a-f]{32}"',
+            "identifier_tbl",
+            result_sql,
+            flags=re.IGNORECASE,
+        )
+        new_feast_table_sql = f"""
+        CREATE TEMPORARY TABLE {next_table}_feast AS (
+            WITH identifier_tbl_dupe AS ({identifier_table_sql_dupe}),
+            identifier_tbl AS (
+                SELECT IDENTIFIER, event_timestamp
+                FROM identifier_tbl_dupe
+                WHERE rn = 1
+            ),
+            {result_sql}
+        )
+        """
+        context.cursor().execute(new_feast_table_sql)
+
+        # left join to the previous composite table
+        picked_feature_columns_str = ", ".join(
+            [f'{next_table}_feast."{c}"' for c in feature_columns],
+        )
+        new_composite_table_sql = f"""
+        CREATE TABLE {next_table} AS (
+            SELECT {prev_table}.*, {picked_feature_columns_str}
+            FROM {prev_table}
+            LEFT JOIN {next_table}_feast
+            ON {prev_table}.{identifier_column} = {next_table}_feast.IDENTIFIER and
+                {prev_table}.event_timestamp = {next_table}_feast.event_timestamp
+        )
+        """
+        context.cursor().execute(new_composite_table_sql)
+
+        counter += 1
+        prev_table = next_table
+        next_table = f"{composite_table}_{counter}"
+
+    return prev_table
+
+
 def process_historical_registry_features_requests(
     store: FeatureStore,
     requests: List[GetFeastHistoricalFeaturesRequest],
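Note: a minimal illustration of the to_sql() rewriting above, using a made-up,
abbreviated stand-in for the query Feast generates (the real output is much
longer and varies by Feast version):

    import re

    # Hypothetical shape of result.to_sql() output.
    result_sql = (
        'WITH feature_cte AS (\n'
        '    SELECT * FROM "FEAST_ENTITY_DF_0123456789abcdef0123456789abcdef"\n'
        ') SELECT * FROM feature_cte'
    )

    # Drop only the leading WITH, then point the query at the deduped
    # identifier_tbl CTE instead of the entity table Feast would upload.
    result_sql = result_sql.replace("WITH ", "", 1)
    result_sql = re.sub(
        r'"feast_entity_df_[0-9a-f]{32}"',
        "identifier_tbl",
        result_sql,
        flags=re.IGNORECASE,
    )
    # result_sql is now 'feature_cte AS (...) SELECT * FROM feature_cte',
    # ready to be spliced after the identifier_tbl CTE in new_feast_table_sql.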
@@ -354,12 +552,18 @@ def process_historical_registry_features_request(
     entity_df = pd.DataFrame(request.entities)
     # no timezone is allowed in the timestamp
     entity_df["event_timestamp"] = entity_df["event_timestamp"].dt.tz_localize(None)
+    # TODO: use sql to get the result.
+    # example:
+    # https://docs.feast.dev/getting-started/concepts/feature-retrieval
+    # #example-entity-sql-query-for-generating-training-data
     result = store.get_historical_features(
         entity_df=entity_df,
         features=request.features or [],
         full_feature_names=request.full_feature_names,
     )
+    # TODO: to_sql(); replace IDENTIFIER by the original identifier_type
     result_df = result.to_df()
+    # TODO: group by IDENTIFIER and event_timestamp
     result_df.drop_duplicates(subset=["IDENTIFIER", "event_timestamp"], inplace=True)
     return entity_df.merge(
         result_df,
diff --git a/wyvern/feature_store/schemas.py b/wyvern/feature_store/schemas.py
index 7055dee..9d0f4fd 100644
--- a/wyvern/feature_store/schemas.py
+++ b/wyvern/feature_store/schemas.py
@@ -38,6 +38,21 @@ class GetHistoricalFeaturesRequest(BaseModel):
     features: List[str] = []
 
 
+class GetHistoricalFeaturesRequestV2(BaseModel):
+    """
+    Request object for getting historical features.
+
+    Attributes:
+        table: the name of the Snowflake table holding the entity rows. The table
+            must contain a REQUEST column, the entity columns, and a TIMESTAMP
+            column that is used as the event timestamp.
+        features: A list of feature names.
+    """
+
+    table: str
+    features: List[str] = []
+
+
 class GetFeastHistoricalFeaturesRequest(BaseModel):
     """
     Request object for getting historical features from Feast.
@@ -66,6 +81,17 @@ class GetHistoricalFeaturesResponse(BaseModel):
     results: List[Dict[str, Any]] = []
 
 
+class GetHistoricalFeaturesResponseV2(BaseModel):
+    """
+    Response object for getting historical features.
+
+    Attributes:
+        result_table: the name of the transient table that holds the result.
+    """
+
+    result_table: str
+
+
 class MaterializeRequest(BaseModel):
     """
     Request object for materializing feature views.
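Note: unlike the V1 endpoint, the V2 response returns a table name instead of
rows, so callers read the result straight from Snowflake. A minimal sketch,
reusing the generate_snowflake_ctx helper imported in the diff above (calling
it from client code like this is an assumption, as is the table name):

    from wyvern.clients.snowflake import generate_snowflake_ctx

    ctx = generate_snowflake_ctx()
    cursor = ctx.cursor()
    cursor.execute('SELECT * FROM HISTORICAL_FEATURES_0123456789_1_2')  # hypothetical
    df = cursor.fetch_pandas_all()  # snowflake-connector-python pandas helper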