[FSTORE-1202] Support similarity search in external fg #1210

Merged
merged 9 commits into from Feb 22, 2024
122 changes: 113 additions & 9 deletions python/hsfs/feature_group.py
@@ -3277,6 +3277,7 @@ def __init__(
notification_topic_name=None,
spine=False,
deprecated=False,
embedding_index=None,
**kwargs,
):
super().__init__(
@@ -3287,6 +3288,7 @@ def __init__(
event_time=event_time,
online_enabled=online_enabled,
id=id,
embedding_index=embedding_index,
expectation_suite=expectation_suite,
online_topic_name=online_topic_name,
topic_name=topic_name,
@@ -3347,7 +3349,7 @@ def __init__(
)
else:
self._storage_connector = storage_connector

self._vector_db_client = None
self._href = href

def save(self):
@@ -3424,7 +3426,10 @@ def insert(
)

def read(
self, dataframe_type: Optional[str] = "default", online: Optional[bool] = False
self,
dataframe_type: Optional[str] = "default",
online: Optional[bool] = False,
read_options: Optional[dict] = None,
):
"""Get the feature group as a DataFrame.

@@ -3452,7 +3457,8 @@ def read(
`"pandas"`, `"numpy"` or `"python"`, defaults to `"default"`.
online: bool, optional. If `True` read from online feature store, defaults
to `False`.

read_options: Additional options as key/value pairs to pass to the Spark engine.
Defaults to `None`.
# Returns
`DataFrame`: The spark dataframe containing the feature data.
`pyspark.DataFrame`. A Spark DataFrame.
@@ -3485,10 +3491,14 @@ def read(
self._name, self._feature_store_name
),
)
return self.select_all().read(dataframe_type=dataframe_type, online=online)
return self.select_all().read(
dataframe_type=dataframe_type,
online=online,
read_options=read_options or {},
)

def show(self, n):
"""Show the first n rows of the feature group.
def show(self, n: int, online: Optional[bool] = False):
"""Show the first `n` rows of the feature group.

!!! example
```python
@@ -3498,39 +3508,130 @@ def show(self, n):
# get the Feature Group instance
fg = fs.get_or_create_feature_group(...)

fg.show(5)
# make a query and show top 5 rows
fg.select(['date','weekly_sales','is_holiday']).show(5)
```

# Arguments
n: int. Number of rows to show.
online: bool, optional. If `True` read from online feature store, defaults
to `False`.
"""
engine.get_instance().set_job_group(
"Fetching Feature group",
"Getting feature group: {} from the featurestore {}".format(
self._name, self._feature_store_name
),
)
return self.select_all().show(n)
if online and self.embedding_index:
if self._vector_db_client is None:
self._vector_db_client = VectorDbClient(self.select_all())
results = self._vector_db_client.read(
self.id,
{},
pk=self.embedding_index.col_prefix + self.primary_key[0],
index_name=self.embedding_index.index_name,
n=n,
)
return [[result[f.name] for f in self.features] for result in results]
return self.select_all().show(n, online)

def find_neighbors(
self,
embedding: List[Union[int, float]],
col: Optional[str] = None,
k: Optional[int] = 10,
filter: Optional[Union[Filter, Logic]] = None,
min_score: Optional[float] = 0,
) -> List[Tuple[float, List[Any]]]:
"""
Finds the nearest neighbors for a given embedding in the vector database.

# Arguments
embedding: The target embedding for which neighbors are to be found.
col: The column name used to compute similarity score. Required only if there
are multiple embeddings (optional).
k: The number of nearest neighbors to retrieve (default is 10).
filter: A filter expression to restrict the search space (optional).
min_score: The minimum similarity score for neighbors to be considered (default is 0).

# Returns
A list of tuples representing the nearest neighbors.
Each tuple contains: `(The similarity score, A list of feature values)`

!!! Example
```
embedding_index = EmbeddingIndex()
embedding_index.add_embedding(name="user_vector", dimension=3)
fg = fs.create_feature_group(
name='air_quality',
embedding_index = embedding_index,
version=1,
primary_key=['id1'],
online_enabled=True,
)
fg.insert(data)
fg.find_neighbors(
[0.1, 0.2, 0.3],
k=5,
)

# apply filter
fg.find_neighbors(
[0.1, 0.2, 0.3],
k=5,
filter=(fg.id1 > 10) & (fg.id1 < 30)
)
```
"""
if self._vector_db_client is None and self._embedding_index:
self._vector_db_client = VectorDbClient(self.select_all())
results = self._vector_db_client.find_neighbors(
embedding,
feature=(self.__getattr__(col) if col else None),
k=k,
filter=filter,
min_score=min_score,
)
return [
(result[0], [result[1][f.name] for f in self.features])
for result in results
]

@classmethod
def from_response_json(cls, json_dict):
json_decamelized = humps.decamelize(json_dict)
if isinstance(json_decamelized, dict):
_ = json_decamelized.pop("type", None)
if "embedding_index" in json_decamelized:
json_decamelized["embedding_index"] = EmbeddingIndex.from_json_response(
json_decamelized["embedding_index"]
)
return cls(**json_decamelized)
for fg in json_decamelized:
_ = fg.pop("type", None)
if "embedding_index" in fg:
fg["embedding_index"] = EmbeddingIndex.from_json_response(
fg["embedding_index"]
)
return [cls(**fg) for fg in json_decamelized]

def update_from_response_json(self, json_dict):
json_decamelized = humps.decamelize(json_dict)
if "type" in json_decamelized:
_ = json_decamelized.pop("type")
if "embedding_index" in json_decamelized:
json_decamelized["embedding_index"] = EmbeddingIndex.from_json_response(
json_decamelized["embedding_index"]
)
self.__init__(**json_decamelized)
return self

def json(self):
return json.dumps(self, cls=util.FeatureStoreEncoder)

def to_dict(self):
return {
fg_meta_dict = {
"id": self._id,
"name": self._name,
"description": self._description,
@@ -3554,6 +3655,9 @@ def to_dict(self):
"notificationTopicName": self.notification_topic_name,
"deprecated": self.deprecated,
}
if self.embedding_index:
fg_meta_dict["embeddingIndex"] = self.embedding_index
return fg_meta_dict

@property
def id(self):
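The changes above give `ExternalFeatureGroup` a lazily created `VectorDbClient` and surface it through `show(n, online=True)` and the new `find_neighbors` method. A minimal usage sketch follows; the feature group name, column names, and query vector are hypothetical, and it assumes the group was created with an `EmbeddingIndex` attached (as in the `create_external_feature_group` sketch at the end of this page).

```python
# Sketch only: "news_articles" and the 3-dimensional query vector are
# hypothetical; the group is assumed to carry an EmbeddingIndex.
import hsfs

connection = hsfs.connection()          # standard hsfs entry point
fs = connection.get_feature_store()

# Retrieve the external feature group that carries the embedding index.
fg = fs.get_external_feature_group("news_articles", version=1)

# New in this PR: preview rows served from the online/vector store.
print(fg.show(5, online=True))

# New in this PR: similarity search against the indexed embedding column.
neighbors = fg.find_neighbors(
    [0.1, 0.2, 0.3],   # query embedding; dimension must match the index
    k=5,               # number of nearest neighbors to return
    min_score=0.2,     # drop low-similarity hits
)
for score, feature_values in neighbors:
    print(score, feature_values)
```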
2 changes: 2 additions & 0 deletions python/hsfs/feature_store.py
@@ -845,6 +845,7 @@ def create_external_feature_group(
version: Optional[int] = None,
description: Optional[str] = "",
primary_key: Optional[List[str]] = [],
embedding_index: Optional[EmbeddingIndex] = None,
features: Optional[List[feature.Feature]] = [],
statistics_config: Optional[Union[StatisticsConfig, bool, dict]] = None,
event_time: Optional[str] = None,
@@ -968,6 +969,7 @@ def create_external_feature_group(
version=version,
description=description,
primary_key=primary_key,
embedding_index=embedding_index,
featurestore_id=self._id,
featurestore_name=self._name,
features=features,
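For completeness, a sketch of how the new `embedding_index` argument to `create_external_feature_group` might be wired up. The connector name, SQL query, column names, and the `online_enabled` flag are assumptions for illustration and are not taken from this diff; only `embedding_index` itself is added by this PR.

```python
# Sketch only: connector, query, and column names are hypothetical.
import hsfs
from hsfs.embedding import EmbeddingIndex

connection = hsfs.connection()
fs = connection.get_feature_store()

# Declare which column is indexed for similarity search and its dimension.
embedding_index = EmbeddingIndex()
embedding_index.add_embedding(name="embedding", dimension=3)

snowflake = fs.get_storage_connector("snowflake_connector")

external_fg = fs.create_external_feature_group(
    name="news_articles",
    version=1,
    description="News articles with precomputed embeddings",
    primary_key=["id"],
    embedding_index=embedding_index,  # new parameter added in this PR
    query="SELECT id, title, embedding FROM news_articles",
    storage_connector=snowflake,
    online_enabled=True,  # assumption: online storage backs the vector index
)
external_fg.save()
```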