@@ -3277,6 +3277,7 @@ def __init__(
3277
3277
notification_topic_name = None ,
3278
3278
spine = False ,
3279
3279
deprecated = False ,
3280
+ embedding_index = None ,
3280
3281
** kwargs ,
3281
3282
):
3282
3283
super ().__init__ (
@@ -3287,6 +3288,7 @@ def __init__(
3287
3288
event_time = event_time ,
3288
3289
online_enabled = online_enabled ,
3289
3290
id = id ,
3291
+ embedding_index = embedding_index ,
3290
3292
expectation_suite = expectation_suite ,
3291
3293
online_topic_name = online_topic_name ,
3292
3294
topic_name = topic_name ,
@@ -3347,7 +3349,7 @@ def __init__(
3347
3349
)
3348
3350
else :
3349
3351
self ._storage_connector = storage_connector
3350
-
3352
+ self . _vector_db_client = None
3351
3353
self ._href = href
3352
3354
3353
3355
def save (self ):
@@ -3424,7 +3426,10 @@ def insert(
3424
3426
)
3425
3427
3426
3428
def read (
3427
- self , dataframe_type : Optional [str ] = "default" , online : Optional [bool ] = False
3429
+ self ,
3430
+ dataframe_type : Optional [str ] = "default" ,
3431
+ online : Optional [bool ] = False ,
3432
+ read_options : Optional [dict ] = None ,
3428
3433
):
3429
3434
"""Get the feature group as a DataFrame.
3430
3435
@@ -3452,7 +3457,8 @@ def read(
3452
3457
`"pandas"`, `"numpy"` or `"python"`, defaults to `"default"`.
3453
3458
online: bool, optional. If `True` read from online feature store, defaults
3454
3459
to `False`.
3455
-
3460
+ read_options: Additional options as key/value pairs to pass to the spark engine.
3461
+ Defaults to `None`.
3456
3462
# Returns
3457
3463
`DataFrame`: The spark dataframe containing the feature data.
3458
3464
`pyspark.DataFrame`. A Spark DataFrame.
@@ -3485,10 +3491,14 @@ def read(
3485
3491
self ._name , self ._feature_store_name
3486
3492
),
3487
3493
)
3488
- return self .select_all ().read (dataframe_type = dataframe_type , online = online )
3494
+ return self .select_all ().read (
3495
+ dataframe_type = dataframe_type ,
3496
+ online = online ,
3497
+ read_options = read_options or {},
3498
+ )
3489
3499
3490
- def show (self , n ):
3491
- """Show the first n rows of the feature group.
3500
+ def show (self , n : int , online : Optional [ bool ] = False ):
3501
+ """Show the first `n` rows of the feature group.
3492
3502
3493
3503
!!! example
3494
3504
```python
@@ -3498,39 +3508,130 @@ def show(self, n):
3498
3508
# get the Feature Group instance
3499
3509
fg = fs.get_or_create_feature_group(...)
3500
3510
3501
- fg.show(5)
3511
+ # make a query and show top 5 rows
3512
+ fg.select(['date','weekly_sales','is_holiday']).show(5)
3502
3513
```
3514
+
3515
+ # Arguments
3516
+ n: int. Number of rows to show.
3517
+ online: bool, optional. If `True` read from online feature store, defaults
3518
+ to `False`.
3503
3519
"""
3504
3520
engine .get_instance ().set_job_group (
3505
3521
"Fetching Feature group" ,
3506
3522
"Getting feature group: {} from the featurestore {}" .format (
3507
3523
self ._name , self ._feature_store_name
3508
3524
),
3509
3525
)
3510
- return self .select_all ().show (n )
3526
+ if online and self .embedding_index :
3527
+ if self ._vector_db_client is None :
3528
+ self ._vector_db_client = VectorDbClient (self .select_all ())
3529
+ results = self ._vector_db_client .read (
3530
+ self .id ,
3531
+ {},
3532
+ pk = self .embedding_index .col_prefix + self .primary_key [0 ],
3533
+ index_name = self .embedding_index .index_name ,
3534
+ n = n ,
3535
+ )
3536
+ return [[result [f .name ] for f in self .features ] for result in results ]
3537
+ return self .select_all ().show (n , online )
3538
+
3539
def find_neighbors(
    self,
    embedding: List[Union[int, float]],
    col: Optional[str] = None,
    k: Optional[int] = 10,
    filter: Optional[Union[Filter, Logic]] = None,
    min_score: Optional[float] = 0,
) -> List[Tuple[float, List[Any]]]:
    """Find the nearest neighbors of a given embedding in the vector database.

    The vector database client is created lazily on the first call and then
    reused for subsequent calls on the same feature group object.

    !!! example
        ```python
        embedding_index = EmbeddingIndex()
        embedding_index.add_embedding(name="user_vector", dimension=3)
        fg = fs.create_feature_group(
            name='air_quality',
            embedding_index=embedding_index,
            version=1,
            primary_key=['id1'],
            online_enabled=True,
        )
        fg.insert(data)
        fg.find_neighbors(
            [0.1, 0.2, 0.3],
            k=5,
        )

        # apply filter
        fg.find_neighbors(
            [0.1, 0.2, 0.3],
            k=5,
            filter=(fg.id1 > 10) & (fg.id1 < 30)
        )
        ```

    # Arguments
        embedding: The target embedding for which neighbors are to be found.
        col: The column name used to compute the similarity score. Required only
            if there are multiple embeddings (optional).
        k: The number of nearest neighbors to retrieve (default is 10).
        filter: A filter expression to restrict the search space (optional).
        min_score: The minimum similarity score for neighbors to be considered
            (default is 0).

    # Returns
        A list of tuples representing the nearest neighbors.
        Each tuple contains: `(The similarity score, A list of feature values)`,
        where the feature values follow the order of `self.features`.

    # Raises
        `ValueError`: If no vector database client exists yet and the feature
            group has no embedding index to build one from.
    """
    if self._vector_db_client is None:
        if not self._embedding_index:
            # Previously this fell through and failed with an opaque
            # AttributeError on `None.find_neighbors`; fail with a clear message.
            raise ValueError(
                "Cannot find neighbors: the feature group has no embedding index."
            )
        self._vector_db_client = VectorDbClient(self.select_all())

    # `__getattr__` is called directly on purpose so that `col` is resolved
    # through the class's dynamic attribute hook, not regular attribute lookup.
    similarity_feature = self.__getattr__(col) if col else None
    results = self._vector_db_client.find_neighbors(
        embedding,
        feature=similarity_feature,
        k=k,
        filter=filter,
        min_score=min_score,
    )
    # Each result is indexed as (score, row); project the row onto the
    # feature group's features, in schema order.
    return [
        (result[0], [result[1][f.name] for f in self.features])
        for result in results
    ]
3511
3600
3512
3601
@classmethod
def from_response_json(cls, json_dict):
    """Build one instance, or a list of instances, from a REST response payload.

    The payload keys are decamelized first. A dict payload yields a single
    instance; any other payload is iterated and yields a list of instances.
    For every entry the `"type"` key is dropped and, when present, the
    `"embedding_index"` value is replaced by the object returned from
    `EmbeddingIndex.from_json_response` before the entry is passed to the
    constructor as keyword arguments.
    """
    payload = humps.decamelize(json_dict)
    is_single = isinstance(payload, dict)
    entries = [payload] if is_single else payload

    # Normalise every entry in place before any instance is constructed.
    for entry in entries:
        entry.pop("type", None)
        if "embedding_index" in entry:
            entry["embedding_index"] = EmbeddingIndex.from_json_response(
                entry["embedding_index"]
            )

    if is_single:
        return cls(**payload)
    return [cls(**entry) for entry in entries]
3521
3618
3522
3619
def update_from_response_json(self, json_dict):
    """Re-initialise this object in place from a REST response payload.

    The payload keys are decamelized, the `"type"` key is discarded and, when
    present, the `"embedding_index"` value is replaced by the object returned
    from `EmbeddingIndex.from_json_response`. The remaining entries are passed
    to `__init__` as keyword arguments.

    # Returns
        This same object, after re-initialisation.
    """
    fields = humps.decamelize(json_dict)
    fields.pop("type", None)

    raw_index_key = "embedding_index"
    if raw_index_key in fields:
        parsed_index = EmbeddingIndex.from_json_response(fields[raw_index_key])
        fields[raw_index_key] = parsed_index

    self.__init__(**fields)
    return self
3528
3629
3529
3630
def json(self):
    """Return this object serialised to a JSON string.

    Serialisation is delegated to `json.dumps` with `util.FeatureStoreEncoder`
    as the encoder class.
    """
    encoder_cls = util.FeatureStoreEncoder
    serialized = json.dumps(self, cls=encoder_cls)
    return serialized
3531
3632
3532
3633
def to_dict (self ):
3533
- return {
3634
+ fg_meta_dict = {
3534
3635
"id" : self ._id ,
3535
3636
"name" : self ._name ,
3536
3637
"description" : self ._description ,
@@ -3554,6 +3655,9 @@ def to_dict(self):
3554
3655
"notificationTopicName" : self .notification_topic_name ,
3555
3656
"deprecated" : self .deprecated ,
3556
3657
}
3658
+ if self .embedding_index :
3659
+ fg_meta_dict ["embeddingIndex" ] = self .embedding_index
3660
+ return fg_meta_dict
3557
3661
3558
3662
@property
3559
3663
def id (self ):
0 commit comments