Skip to content

Commit

Permalink
[FSTORE-1202] Support similarity search in external fg (#1210)
Browse files: browse the repository at this point in the history
* external fg api

* add find_neighbor

* add db client

* create fg with embedding

* set embedding object

* add read option

* fix style

* address comment
  • Loading branch information
kennethmhc authored Feb 22, 2024
1 parent cc0a31c commit fc8b352
Show file tree
Hide file tree
Showing 2 changed files with 115 additions and 9 deletions.
122 changes: 113 additions & 9 deletions python/hsfs/feature_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -3277,6 +3277,7 @@ def __init__(
notification_topic_name=None,
spine=False,
deprecated=False,
embedding_index=None,
**kwargs,
):
super().__init__(
Expand All @@ -3287,6 +3288,7 @@ def __init__(
event_time=event_time,
online_enabled=online_enabled,
id=id,
embedding_index=embedding_index,
expectation_suite=expectation_suite,
online_topic_name=online_topic_name,
topic_name=topic_name,
Expand Down Expand Up @@ -3347,7 +3349,7 @@ def __init__(
)
else:
self._storage_connector = storage_connector

self._vector_db_client = None
self._href = href

def save(self):
Expand Down Expand Up @@ -3424,7 +3426,10 @@ def insert(
)

def read(
self, dataframe_type: Optional[str] = "default", online: Optional[bool] = False
self,
dataframe_type: Optional[str] = "default",
online: Optional[bool] = False,
read_options: Optional[dict] = None,
):
"""Get the feature group as a DataFrame.
Expand Down Expand Up @@ -3452,7 +3457,8 @@ def read(
`"pandas"`, `"numpy"` or `"python"`, defaults to `"default"`.
online: bool, optional. If `True` read from online feature store, defaults
to `False`.
read_options: Additional options as key/value pairs to pass to the spark engine.
Defaults to `None`.
# Returns
`DataFrame`: The spark dataframe containing the feature data.
`pyspark.DataFrame`. A Spark DataFrame.
Expand Down Expand Up @@ -3485,10 +3491,14 @@ def read(
self._name, self._feature_store_name
),
)
return self.select_all().read(dataframe_type=dataframe_type, online=online)
return self.select_all().read(
dataframe_type=dataframe_type,
online=online,
read_options=read_options or {},
)

def show(self, n):
"""Show the first n rows of the feature group.
def show(self, n: int, online: Optional[bool] = False):
"""Show the first `n` rows of the feature group.
!!! example
```python
Expand All @@ -3498,39 +3508,130 @@ def show(self, n):
# get the Feature Group instance
fg = fs.get_or_create_feature_group(...)
fg.show(5)
# make a query and show top 5 rows
fg.select(['date','weekly_sales','is_holiday']).show(5)
```
# Arguments
n: int. Number of rows to show.
online: bool, optional. If `True` read from online feature store, defaults
to `False`.
"""
engine.get_instance().set_job_group(
"Fetching Feature group",
"Getting feature group: {} from the featurestore {}".format(
self._name, self._feature_store_name
),
)
return self.select_all().show(n)
if online and self.embedding_index:
if self._vector_db_client is None:
self._vector_db_client = VectorDbClient(self.select_all())
results = self._vector_db_client.read(
self.id,
{},
pk=self.embedding_index.col_prefix + self.primary_key[0],
index_name=self.embedding_index.index_name,
n=n,
)
return [[result[f.name] for f in self.features] for result in results]
return self.select_all().show(n, online)

def find_neighbors(
    self,
    embedding: List[Union[int, float]],
    col: Optional[str] = None,
    k: Optional[int] = 10,
    filter: Optional[Union[Filter, Logic]] = None,
    min_score: Optional[float] = 0,
) -> List[Tuple[float, List[Any]]]:
    """
    Finds the nearest neighbors for a given embedding in the vector database.

    # Arguments
        embedding: The target embedding for which neighbors are to be found.
        col: The column name used to compute similarity score. Required only if there
            are multiple embeddings (optional).
        k: The number of nearest neighbors to retrieve (default is 10).
        filter: A filter expression to restrict the search space (optional).
        min_score: The minimum similarity score for neighbors to be considered (default is 0).

    # Returns
        A list of tuples representing the nearest neighbors.
        Each tuple contains: `(The similarity score, A list of feature values)`.
        The feature values are ordered like `self.features`.

    # Raises
        `ValueError`: If no vector database client exists yet and the feature group
            has no embedding index to create one from.

    !!! Example
        ```
        embedding_index = EmbeddingIndex()
        embedding_index.add_embedding(name="user_vector", dimension=3)
        fg = fs.create_feature_group(
            name='air_quality',
            embedding_index = embedding_index,
            version=1,
            primary_key=['id1'],
            online_enabled=True,
        )
        fg.insert(data)
        fg.find_neighbors(
            [0.1, 0.2, 0.3],
            k=5,
        )
        # apply filter
        fg.find_neighbors(
            [0.1, 0.2, 0.3],
            k=5,
            filter=(fg.id1 > 10) & (fg.id1 < 30)
        )
        ```
    """
    if self._vector_db_client is None:
        # Without an embedding index there is nothing to search; fail with an
        # explicit message instead of an AttributeError on `None` below.
        if not self._embedding_index:
            raise ValueError(
                "Cannot find neighbors: the feature group has no embedding index. "
                "Create the feature group with an `embedding_index` to enable "
                "similarity search."
            )
        # The client is created lazily on first use and cached on the instance.
        self._vector_db_client = VectorDbClient(self.select_all())

    results = self._vector_db_client.find_neighbors(
        embedding,
        # `__getattr__` is called directly (as in the original) so that `col` is
        # resolved through the feature lookup rather than normal attribute access.
        feature=(self.__getattr__(col) if col else None),
        k=k,
        filter=filter,
        min_score=min_score,
    )

    # Feature order is the same for every result, so resolve the names once.
    feature_names = [f.name for f in self.features]
    return [
        (result[0], [result[1][name] for name in feature_names])
        for result in results
    ]

@classmethod
def from_response_json(cls, json_dict):
    """Build feature group instance(s) from a JSON response.

    The payload keys are decamelized, the `type` entry is discarded and a raw
    `embedding_index` entry, when present, is converted with
    `EmbeddingIndex.from_json_response` before being passed to the constructor.

    # Arguments
        json_dict: The decoded JSON response; either a single dict or a
            collection of dicts.

    # Returns
        A single instance when the payload is a dict, otherwise a list of instances.
    """
    payload = humps.decamelize(json_dict)
    is_single = isinstance(payload, dict)
    entries = [payload] if is_single else payload

    # Normalise every entry first; instances are only constructed afterwards.
    for entry in entries:
        entry.pop("type", None)
        if "embedding_index" in entry:
            raw_index = entry["embedding_index"]
            entry["embedding_index"] = EmbeddingIndex.from_json_response(raw_index)

    if is_single:
        return cls(**payload)
    return [cls(**entry) for entry in entries]

def update_from_response_json(self, json_dict):
    """Re-initialise this instance in place from a JSON response.

    The payload keys are decamelized, the `type` entry is removed and a raw
    `embedding_index` entry, when present, is converted with
    `EmbeddingIndex.from_json_response` before `__init__` is invoked again.

    # Arguments
        json_dict: The decoded JSON response describing the feature group.

    # Returns
        This instance, after re-initialisation.
    """
    payload = humps.decamelize(json_dict)

    # `type` is not a constructor argument, so it is dropped before re-init.
    if "type" in payload:
        del payload["type"]

    if "embedding_index" in payload:
        raw_index = payload["embedding_index"]
        payload["embedding_index"] = EmbeddingIndex.from_json_response(raw_index)

    self.__init__(**payload)
    return self

def json(self):
    """Serialize this object to a JSON string using `util.FeatureStoreEncoder`."""
    serialized = json.dumps(self, cls=util.FeatureStoreEncoder)
    return serialized

def to_dict(self):
return {
fg_meta_dict = {
"id": self._id,
"name": self._name,
"description": self._description,
Expand All @@ -3554,6 +3655,9 @@ def to_dict(self):
"notificationTopicName": self.notification_topic_name,
"deprecated": self.deprecated,
}
if self.embedding_index:
fg_meta_dict["embeddingIndex"] = self.embedding_index
return fg_meta_dict

@property
def id(self):
Expand Down
2 changes: 2 additions & 0 deletions python/hsfs/feature_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -845,6 +845,7 @@ def create_external_feature_group(
version: Optional[int] = None,
description: Optional[str] = "",
primary_key: Optional[List[str]] = [],
embedding_index: Optional[EmbeddingIndex] = None,
features: Optional[List[feature.Feature]] = [],
statistics_config: Optional[Union[StatisticsConfig, bool, dict]] = None,
event_time: Optional[str] = None,
Expand Down Expand Up @@ -968,6 +969,7 @@ def create_external_feature_group(
version=version,
description=description,
primary_key=primary_key,
embedding_index=embedding_index,
featurestore_id=self._id,
featurestore_name=self._name,
features=features,
Expand Down

0 comments on commit fc8b352

Please sign in to comment.