Skip to content

Commit

Permalink
added auth configuration for arrow flight client
Browse files Browse the repository at this point in the history
Signed-off-by: Abdul Hameed <ahameed@redhat.com>
  • Loading branch information
redhatHameed authored and tmihalac committed Jul 12, 2024
1 parent fd6f2bc commit bc4d8d3
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 10 deletions.
50 changes: 40 additions & 10 deletions sdk/python/feast/infra/offline_stores/remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
RetrievalMetadata,
)
from feast.infra.registry.base_registry import BaseRegistry
from feast.permissions.client.auth_client_manager import create_flight_call_options
from feast.repo_config import FeastConfigBaseModel, RepoConfig
from feast.saved_dataset import SavedDatasetStorage

Expand All @@ -46,6 +47,7 @@ class RemoteRetrievalJob(RetrievalJob):
def __init__(
self,
client: fl.FlightClient,
options: pa.flight.FlightCallOptions,
api: str,
api_parameters: Dict[str, Any],
entity_df: Union[pd.DataFrame, str] = None,
Expand All @@ -54,6 +56,7 @@ def __init__(
):
# Initialize the client connection
self.client = client
self.options = options
self.api = api
self.api_parameters = api_parameters
self.entity_df = entity_df
Expand All @@ -69,7 +72,12 @@ def _to_df_internal(self, timeout: Optional[int] = None) -> pd.DataFrame:
# This is where do_get service is invoked
def _to_arrow_internal(self, timeout: Optional[int] = None) -> pa.Table:
return _send_retrieve_remote(
self.api, self.api_parameters, self.entity_df, self.table, self.client
self.api,
self.api_parameters,
self.entity_df,
self.table,
self.client,
self.options,
)

@property
Expand Down Expand Up @@ -110,6 +118,7 @@ def persist(
api=RemoteRetrievalJob.persist.__name__,
api_parameters=api_parameters,
client=self.client,
options=self.options,
table=self.table,
entity_df=self.entity_df,
)
Expand All @@ -130,6 +139,7 @@ def get_historical_features(

# Initialize the client connection
client = RemoteOfflineStore.init_client(config)
options = create_flight_call_options(config.auth_config)

feature_view_names = [fv.name for fv in feature_views]
name_aliases = [fv.projection.name_alias for fv in feature_views]
Expand All @@ -144,6 +154,7 @@ def get_historical_features(

return RemoteRetrievalJob(
client=client,
options=options,
api=OfflineStore.get_historical_features.__name__,
api_parameters=api_parameters,
entity_df=entity_df,
Expand All @@ -164,6 +175,7 @@ def pull_all_from_table_or_query(

# Initialize the client connection
client = RemoteOfflineStore.init_client(config)
options = create_flight_call_options(config.auth_config)

api_parameters = {
"data_source_name": data_source.name,
Expand All @@ -176,6 +188,7 @@ def pull_all_from_table_or_query(

return RemoteRetrievalJob(
client=client,
options=options,
api=OfflineStore.pull_all_from_table_or_query.__name__,
api_parameters=api_parameters,
)
Expand All @@ -195,6 +208,7 @@ def pull_latest_from_table_or_query(

# Initialize the client connection
client = RemoteOfflineStore.init_client(config)
options = create_flight_call_options(config.auth_config)

api_parameters = {
"data_source_name": data_source.name,
Expand All @@ -208,6 +222,7 @@ def pull_latest_from_table_or_query(

return RemoteRetrievalJob(
client=client,
options=options,
api=OfflineStore.pull_latest_from_table_or_query.__name__,
api_parameters=api_parameters,
)
Expand All @@ -228,6 +243,7 @@ def write_logged_features(

# Initialize the client connection
client = RemoteOfflineStore.init_client(config)
options = create_flight_call_options(config.auth_config)

api_parameters = {
"feature_service_name": source._feature_service.name,
Expand All @@ -237,6 +253,7 @@ def write_logged_features(
api=OfflineStore.write_logged_features.__name__,
api_parameters=api_parameters,
client=client,
options=options,
table=data,
entity_df=None,
)
Expand All @@ -252,6 +269,7 @@ def offline_write_batch(

# Initialize the client connection
client = RemoteOfflineStore.init_client(config)
options = create_flight_call_options(config.auth_config)

feature_view_names = [feature_view.name]
name_aliases = [feature_view.projection.name_alias]
Expand All @@ -266,6 +284,7 @@ def offline_write_batch(
api=OfflineStore.offline_write_batch.__name__,
api_parameters=api_parameters,
client=client,
options=options,
table=table,
entity_df=None,
)
Expand Down Expand Up @@ -330,22 +349,35 @@ def _send_retrieve_remote(
entity_df: Union[pd.DataFrame, str],
table: pa.Table,
client: fl.FlightClient,
options: pa.flight.FlightCallOptions,
):
command_descriptor = _call_put(api, api_parameters, client, entity_df, table)
return _call_get(client, command_descriptor)
command_descriptor = _call_put(
api,
api_parameters,
client,
options,
entity_df,
table,
)
return _call_get(client, options, command_descriptor)


def _call_get(client: fl.FlightClient, command_descriptor: fl.FlightDescriptor):
def _call_get(
client: fl.FlightClient,
options: pa.flight.FlightCallOptions,
command_descriptor: fl.FlightDescriptor,
):
flight = client.get_flight_info(command_descriptor)
ticket = flight.endpoints[0].ticket
reader = client.do_get(ticket)
reader = client.do_get(ticket, options)
return reader.read_all()


def _call_put(
api: str,
api_parameters: Dict[str, Any],
client: fl.FlightClient,
options: pa.flight.FlightCallOptions,
entity_df: Union[pd.DataFrame, str],
table: pa.Table,
):
Expand All @@ -365,7 +397,7 @@ def _call_put(
)
)

_put_parameters(command_descriptor, entity_df, table, client)
_put_parameters(command_descriptor, entity_df, table, client, options)
return command_descriptor


Expand All @@ -374,6 +406,7 @@ def _put_parameters(
entity_df: Union[pd.DataFrame, str],
table: pa.Table,
client: fl.FlightClient,
options: pa.flight.FlightCallOptions,
):
updatedTable: pa.Table

Expand All @@ -384,10 +417,7 @@ def _put_parameters(
else:
updatedTable = _create_empty_table()

writer, _ = client.do_put(
command_descriptor,
updatedTable.schema,
)
writer, _ = client.do_put(command_descriptor, updatedTable.schema, options)

writer.write_table(updatedTable)
writer.close()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import pandas as pd
import pyarrow
import pyarrow.flight as fl
import pytest

from feast.infra.offline_stores.contrib.athena_offline_store.athena import (
Expand Down Expand Up @@ -226,6 +227,7 @@ def retrieval_job(request, environment):

return RemoteRetrievalJob(
client=MagicMock(),
options=fl.FlightCallOptions(),
api_parameters={
"str": "str",
},
Expand Down

0 comments on commit bc4d8d3

Please sign in to comment.