diff --git a/sdk/python/feast/feature_store.py b/sdk/python/feast/feature_store.py index 08f3a93a9d..4c7dc21346 100644 --- a/sdk/python/feast/feature_store.py +++ b/sdk/python/feast/feature_store.py @@ -274,7 +274,12 @@ def get_historical_features( feature_views = _get_requested_feature_views(feature_refs, all_feature_views) provider = self._get_provider() job = provider.get_historical_features( - self.config, feature_views, feature_refs, entity_df + self.config, + feature_views, + feature_refs, + entity_df, + self._registry, + self.project, ) return job diff --git a/sdk/python/feast/infra/gcp.py b/sdk/python/feast/infra/gcp.py index 465667254e..d8834cc711 100644 --- a/sdk/python/feast/infra/gcp.py +++ b/sdk/python/feast/infra/gcp.py @@ -185,6 +185,8 @@ def get_historical_features( feature_views: List[FeatureView], feature_refs: List[str], entity_df: Union[pandas.DataFrame, str], + registry: Registry, + project: str, ) -> RetrievalJob: offline_store = get_offline_store_from_sources( [feature_view.input for feature_view in feature_views] @@ -194,6 +196,8 @@ def get_historical_features( feature_views=feature_views, feature_refs=feature_refs, entity_df=entity_df, + registry=registry, + project=project, ) return job diff --git a/sdk/python/feast/infra/local.py b/sdk/python/feast/infra/local.py index bf2b3aa231..9686de006a 100644 --- a/sdk/python/feast/infra/local.py +++ b/sdk/python/feast/infra/local.py @@ -196,6 +196,8 @@ def get_historical_features( feature_views: List[FeatureView], feature_refs: List[str], entity_df: Union[pd.DataFrame, str], + registry: Registry, + project: str, ) -> RetrievalJob: offline_store = get_offline_store_from_sources( [feature_view.input for feature_view in feature_views] @@ -205,6 +207,8 @@ def get_historical_features( feature_views=feature_views, feature_refs=feature_refs, entity_df=entity_df, + registry=registry, + project=project, ) diff --git a/sdk/python/feast/infra/offline_stores/bigquery.py b/sdk/python/feast/infra/offline_stores/bigquery.py index 499bf0ed91..7796c4edbb 100644 --- a/sdk/python/feast/infra/offline_stores/bigquery.py +++ b/sdk/python/feast/infra/offline_stores/bigquery.py @@ -15,6 +15,7 @@ RetrievalJob, _get_requested_feature_views_to_features_dict, ) +from feast.registry import Registry from feast.repo_config import RepoConfig @@ -70,6 +71,8 @@ def get_historical_features( feature_views: List[FeatureView], feature_refs: List[str], entity_df: Union[pandas.DataFrame, str], + registry: Registry, + project: str, ) -> RetrievalJob: # TODO: Add entity_df validation in order to fail before interacting with BigQuery @@ -85,7 +88,9 @@ def get_historical_features( ) # Build a query context containing all information required to template the BigQuery SQL query - query_context = get_feature_view_query_context(feature_refs, feature_views) + query_context = get_feature_view_query_context( + feature_refs, feature_views, registry, project + ) # TODO: Infer min_timestamp and max_timestamp from entity_df # Generate the BigQuery SQL query from the query context @@ -155,7 +160,10 @@ def _upload_entity_df_into_bigquery(project, entity_df) -> str: def get_feature_view_query_context( - feature_refs: List[str], feature_views: List[FeatureView] + feature_refs: List[str], + feature_views: List[FeatureView], + registry: Registry, + project: str, ) -> List[FeatureViewQueryContext]: """Build a query context containing all information required to template a BigQuery point-in-time SQL query""" @@ -165,7 +173,10 @@ def get_feature_view_query_context( query_context = [] for feature_view, features in feature_views_to_feature_map.items(): - entity_names = [entity for entity in feature_view.entities] + join_keys = [] + for entity_name in feature_view.entities: + entity = registry.get_entity(entity_name, project) + join_keys.append(entity.join_key) if isinstance(feature_view.ttl, timedelta): ttl_seconds = int(feature_view.ttl.total_seconds()) @@ -177,7 +188,7 @@ def get_feature_view_query_context( context = FeatureViewQueryContext( name=feature_view.name, ttl=ttl_seconds, - entities=entity_names, + entities=join_keys, features=features, table_ref=feature_view.input.table_ref, event_timestamp_column=feature_view.input.event_timestamp_column, diff --git a/sdk/python/feast/infra/offline_stores/file.py b/sdk/python/feast/infra/offline_stores/file.py index 469f6dc0f8..89a9cfec34 100644 --- a/sdk/python/feast/infra/offline_stores/file.py +++ b/sdk/python/feast/infra/offline_stores/file.py @@ -12,6 +12,7 @@ ENTITY_DF_EVENT_TIMESTAMP_COL, _get_requested_feature_views_to_features_dict, ) +from feast.registry import Registry from feast.repo_config import RepoConfig @@ -35,6 +36,8 @@ def get_historical_features( feature_views: List[FeatureView], feature_refs: List[str], entity_df: Union[pd.DataFrame, str], + registry: Registry, + project: str, ) -> FileRetrievalJob: if not isinstance(entity_df, pd.DataFrame): raise ValueError( @@ -80,7 +83,11 @@ def evaluate_historical_retrieval(): ) # Build a list of entity columns to join on (from the right table) - right_entity_columns = [entity for entity in feature_view.entities] + join_keys = [] + for entity_name in feature_view.entities: + entity = registry.get_entity(entity_name, project) + join_keys.append(entity.join_key) + right_entity_columns = join_keys right_entity_key_columns = [ event_timestamp_column ] + right_entity_columns diff --git a/sdk/python/feast/infra/offline_stores/offline_store.py b/sdk/python/feast/infra/offline_stores/offline_store.py index 76c7301f86..c5bf84ee08 100644 --- a/sdk/python/feast/infra/offline_stores/offline_store.py +++ b/sdk/python/feast/infra/offline_stores/offline_store.py @@ -20,6 +20,7 @@ from feast.data_source import DataSource from feast.feature_view import FeatureView +from feast.registry import Registry from feast.repo_config import RepoConfig @@ -63,5 +64,7 @@ def get_historical_features( feature_views: List[FeatureView], feature_refs: List[str], entity_df: Union[pd.DataFrame, str], + registry: Registry, + project: str, ) -> RetrievalJob: pass diff --git a/sdk/python/feast/infra/provider.py b/sdk/python/feast/infra/provider.py index 0f2997123f..a43952c45e 100644 --- a/sdk/python/feast/infra/provider.py +++ b/sdk/python/feast/infra/provider.py @@ -99,6 +99,8 @@ def get_historical_features( feature_views: List[FeatureView], feature_refs: List[str], entity_df: Union[pandas.DataFrame, str], + registry: Registry, + project: str, ) -> RetrievalJob: pass diff --git a/sdk/python/feast/registry.py b/sdk/python/feast/registry.py index 03a214edff..a960b6836d 100644 --- a/sdk/python/feast/registry.py +++ b/sdk/python/feast/registry.py @@ -105,7 +105,7 @@ def list_entities(self, project: str, allow_cache: bool = False) -> List[Entity] entities.append(Entity.from_proto(entity_proto)) return entities - def get_entity(self, name: str, project: str) -> Entity: + def get_entity(self, name: str, project: str, allow_cache: bool = False) -> Entity: """ Retrieves an entity. @@ -117,7 +117,7 @@ def get_entity(self, name: str, project: str) -> Entity: Returns either the specified entity, or raises an exception if none is found """ - registry_proto = self._get_registry_proto() + registry_proto = self._get_registry_proto(allow_cache=allow_cache) for entity_proto in registry_proto.entities: if entity_proto.spec.name == name and entity_proto.spec.project == project: return Entity.from_proto(entity_proto) diff --git a/sdk/python/tests/test_historical_retrieval.py b/sdk/python/tests/test_historical_retrieval.py index 64735b56b9..6800b512c5 100644 --- a/sdk/python/tests/test_historical_retrieval.py +++ b/sdk/python/tests/test_historical_retrieval.py @@ -57,7 +57,7 @@ def stage_driver_hourly_stats_bigquery_source(df, table_id): def create_driver_hourly_stats_feature_view(source): driver_stats_feature_view = FeatureView( name="driver_stats", - entities=["driver_id"], + entities=["driver"], features=[ Feature(name="conv_rate", dtype=ValueType.FLOAT), Feature(name="acc_rate", dtype=ValueType.FLOAT), @@ -226,8 +226,8 @@ def test_historical_features_from_parquet_sources(): temp_dir, customer_df ) customer_fv = create_customer_daily_profile_feature_view(customer_source) - driver = Entity(name="driver", value_type=ValueType.INT64) - customer = Entity(name="customer", value_type=ValueType.INT64) + driver = Entity(name="driver", join_key="driver_id", value_type=ValueType.INT64) + customer = Entity(name="customer_id", value_type=ValueType.INT64) store = FeatureStore( config=RepoConfig( @@ -331,8 +331,8 @@ def test_historical_features_from_bigquery_sources(provider_type): ) customer_fv = create_customer_daily_profile_feature_view(customer_source) - driver = Entity(name="driver", value_type=ValueType.INT64) - customer = Entity(name="customer", value_type=ValueType.INT64) + driver = Entity(name="driver", join_key="driver_id", value_type=ValueType.INT64) + customer = Entity(name="customer_id", value_type=ValueType.INT64) if provider_type == "local": store = FeatureStore( diff --git a/sdk/python/tests/test_materialize_from_bigquery_to_datastore.py b/sdk/python/tests/test_materialize_from_bigquery_to_datastore.py new file mode 100644 index 0000000000..e69de29bb2