Skip to content

Commit fc8b352

Browse files
authored
[FSTORE-1202] Support similarity search in external fg (#1210)
* external fg api * add find_neighbor * add db client * create fg with embedding * set embedding object * add read option * fix style * address comment
1 parent cc0a31c commit fc8b352

File tree

2 files changed

+115
-9
lines changed

2 files changed

+115
-9
lines changed

python/hsfs/feature_group.py

Lines changed: 113 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3277,6 +3277,7 @@ def __init__(
32773277
notification_topic_name=None,
32783278
spine=False,
32793279
deprecated=False,
3280+
embedding_index=None,
32803281
**kwargs,
32813282
):
32823283
super().__init__(
@@ -3287,6 +3288,7 @@ def __init__(
32873288
event_time=event_time,
32883289
online_enabled=online_enabled,
32893290
id=id,
3291+
embedding_index=embedding_index,
32903292
expectation_suite=expectation_suite,
32913293
online_topic_name=online_topic_name,
32923294
topic_name=topic_name,
@@ -3347,7 +3349,7 @@ def __init__(
33473349
)
33483350
else:
33493351
self._storage_connector = storage_connector
3350-
3352+
self._vector_db_client = None
33513353
self._href = href
33523354

33533355
def save(self):
@@ -3424,7 +3426,10 @@ def insert(
34243426
)
34253427

34263428
def read(
3427-
self, dataframe_type: Optional[str] = "default", online: Optional[bool] = False
3429+
self,
3430+
dataframe_type: Optional[str] = "default",
3431+
online: Optional[bool] = False,
3432+
read_options: Optional[dict] = None,
34283433
):
34293434
"""Get the feature group as a DataFrame.
34303435
@@ -3452,7 +3457,8 @@ def read(
34523457
`"pandas"`, `"numpy"` or `"python"`, defaults to `"default"`.
34533458
online: bool, optional. If `True` read from online feature store, defaults
34543459
to `False`.
3455-
3460+
read_options: Additional options as key/value pairs to pass to the spark engine.
3461+
Defaults to `None`.
34563462
# Returns
34573463
`DataFrame`: The spark dataframe containing the feature data.
34583464
`pyspark.DataFrame`. A Spark DataFrame.
@@ -3485,10 +3491,14 @@ def read(
34853491
self._name, self._feature_store_name
34863492
),
34873493
)
3488-
return self.select_all().read(dataframe_type=dataframe_type, online=online)
3494+
return self.select_all().read(
3495+
dataframe_type=dataframe_type,
3496+
online=online,
3497+
read_options=read_options or {},
3498+
)
34893499

3490-
def show(self, n):
3491-
"""Show the first n rows of the feature group.
3500+
def show(self, n: int, online: Optional[bool] = False):
3501+
"""Show the first `n` rows of the feature group.
34923502
34933503
!!! example
34943504
```python
@@ -3498,39 +3508,130 @@ def show(self, n):
34983508
# get the Feature Group instance
34993509
fg = fs.get_or_create_feature_group(...)
35003510
3501-
fg.show(5)
3511+
# make a query and show top 5 rows
3512+
fg.select(['date','weekly_sales','is_holiday']).show(5)
35023513
```
3514+
3515+
# Arguments
3516+
n: int. Number of rows to show.
3517+
online: bool, optional. If `True` read from online feature store, defaults
3518+
to `False`.
35033519
"""
35043520
engine.get_instance().set_job_group(
35053521
"Fetching Feature group",
35063522
"Getting feature group: {} from the featurestore {}".format(
35073523
self._name, self._feature_store_name
35083524
),
35093525
)
3510-
return self.select_all().show(n)
3526+
if online and self.embedding_index:
3527+
if self._vector_db_client is None:
3528+
self._vector_db_client = VectorDbClient(self.select_all())
3529+
results = self._vector_db_client.read(
3530+
self.id,
3531+
{},
3532+
pk=self.embedding_index.col_prefix + self.primary_key[0],
3533+
index_name=self.embedding_index.index_name,
3534+
n=n,
3535+
)
3536+
return [[result[f.name] for f in self.features] for result in results]
3537+
return self.select_all().show(n, online)
3538+
3539+
def find_neighbors(
    self,
    embedding: List[Union[int, float]],
    col: Optional[str] = None,
    k: Optional[int] = 10,
    filter: Optional[Union[Filter, Logic]] = None,
    min_score: Optional[float] = 0,
) -> List[Tuple[float, List[Any]]]:
    """
    Finds the nearest neighbors for a given embedding in the vector database.

    # Arguments
        embedding: The target embedding for which neighbors are to be found.
        col: The column name used to compute similarity score. Required only if there
            are multiple embeddings (optional).
        k: The number of nearest neighbors to retrieve (default is 10).
        filter: A filter expression to restrict the search space (optional).
        min_score: The minimum similarity score for neighbors to be considered (default is 0).

    # Returns
        A list of tuples representing the nearest neighbors.
        Each tuple contains: `(The similarity score, A list of feature values)`.
        The feature values are ordered like `self.features`.

    # Raises
        `ValueError`: If no vector database client exists yet and the feature group
            has no embedding index, so no client can be created.

    !!! Example
        ```
        embedding_index = EmbeddingIndex()
        embedding_index.add_embedding(name="user_vector", dimension=3)
        fg = fs.create_feature_group(
                    name='air_quality',
                    embedding_index = embedding_index,
                    version=1,
                    primary_key=['id1'],
                    online_enabled=True,
                )
        fg.insert(data)
        fg.find_neighbors(
            [0.1, 0.2, 0.3],
            k=5,
        )

        # apply filter
        fg.find_neighbors(
            [0.1, 0.2, 0.3],
            k=5,
            filter=(fg.id1 > 10) & (fg.id1 < 30)
        )
        ```
    """
    if self._vector_db_client is None:
        # Without an embedding index the client cannot be created; fail with an
        # explicit message instead of an AttributeError on `None` further down.
        if not self._embedding_index:
            raise ValueError(
                "Cannot search for neighbors: the feature group has no embedding "
                "index. Create the feature group with an `embedding_index` first."
            )
        # Created lazily on first use and cached on the instance afterwards.
        self._vector_db_client = VectorDbClient(self.select_all())

    results = self._vector_db_client.find_neighbors(
        embedding,
        # `__getattr__` is called directly (not `getattr`) to keep the original
        # lookup path for `col` unchanged.
        feature=(self.__getattr__(col) if col else None),
        k=k,
        filter=filter,
        min_score=min_score,
    )

    # Resolve the feature names once rather than once per returned neighbor.
    feature_names = [f.name for f in self.features]
    return [
        (result[0], [result[1][name] for name in feature_names])
        for result in results
    ]
35113600

35123601
@classmethod
def from_response_json(cls, json_dict):
    """Build one instance, or a list of instances, from a REST response payload.

    The payload keys are decamelized first. For every entry the `type` key is
    dropped and, when present, the raw `embedding_index` value is replaced by
    the object returned from `EmbeddingIndex.from_json_response`.

    # Arguments
        json_dict: The response payload; either a single mapping or an iterable
            of mappings.

    # Returns
        A single instance when the decamelized payload is a `dict`, otherwise a
        list with one instance per entry.
    """
    decamelized = humps.decamelize(json_dict)
    is_single = isinstance(decamelized, dict)
    entries = [decamelized] if is_single else decamelized

    # First normalise every entry in place; instances are only built afterwards.
    for entry in entries:
        entry.pop("type", None)
        if "embedding_index" in entry:
            raw_index = entry["embedding_index"]
            entry["embedding_index"] = EmbeddingIndex.from_json_response(raw_index)

    if is_single:
        return cls(**decamelized)
    return [cls(**entry) for entry in entries]
35213618

35223619
def update_from_response_json(self, json_dict):
    """Re-initialise this instance in place from a REST response payload.

    The payload keys are decamelized, the `type` key is discarded, and a raw
    `embedding_index` value (when present) is converted through
    `EmbeddingIndex.from_json_response` before `__init__` is re-run with the
    resulting keyword arguments.

    # Arguments
        json_dict: The response payload describing the feature group.

    # Returns
        This same instance, after re-initialisation.
    """
    init_kwargs = humps.decamelize(json_dict)
    init_kwargs.pop("type", None)
    if "embedding_index" in init_kwargs:
        raw_index = init_kwargs["embedding_index"]
        init_kwargs["embedding_index"] = EmbeddingIndex.from_json_response(raw_index)
    self.__init__(**init_kwargs)
    return self
35283629

35293630
def json(self):
    """Serialise this object to a JSON string using `util.FeatureStoreEncoder`."""
    encoder_cls = util.FeatureStoreEncoder
    return json.dumps(self, cls=encoder_cls)
35313632

35323633
def to_dict(self):
3533-
return {
3634+
fg_meta_dict = {
35343635
"id": self._id,
35353636
"name": self._name,
35363637
"description": self._description,
@@ -3554,6 +3655,9 @@ def to_dict(self):
35543655
"notificationTopicName": self.notification_topic_name,
35553656
"deprecated": self.deprecated,
35563657
}
3658+
if self.embedding_index:
3659+
fg_meta_dict["embeddingIndex"] = self.embedding_index
3660+
return fg_meta_dict
35573661

35583662
@property
35593663
def id(self):

python/hsfs/feature_store.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -845,6 +845,7 @@ def create_external_feature_group(
845845
version: Optional[int] = None,
846846
description: Optional[str] = "",
847847
primary_key: Optional[List[str]] = [],
848+
embedding_index: Optional[EmbeddingIndex] = None,
848849
features: Optional[List[feature.Feature]] = [],
849850
statistics_config: Optional[Union[StatisticsConfig, bool, dict]] = None,
850851
event_time: Optional[str] = None,
@@ -968,6 +969,7 @@ def create_external_feature_group(
968969
version=version,
969970
description=description,
970971
primary_key=primary_key,
972+
embedding_index=embedding_index,
971973
featurestore_id=self._id,
972974
featurestore_name=self._name,
973975
features=features,

0 commit comments

Comments
 (0)