Skip to content

Commit

Permalink
[PLUGINS] Bump Version [snowflake]
Browse files Browse the repository at this point in the history
  • Loading branch information
blythed committed Jan 30, 2025
1 parent b935697 commit 719b321
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 1 deletion.
2 changes: 1 addition & 1 deletion plugins/snowflake/superduper_snowflake/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from .vector_search import SnowflakeVectorSearcher as VectorSearcher
from .data_backend import SnowflakeDataBackend as DataBackend

__version__ = "0.5.12"
__version__ = "0.5.13"

__all__ = [
"VectorSearcher",
Expand Down
4 changes: 4 additions & 0 deletions plugins/snowflake/superduper_snowflake/data_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,10 @@ def _connection_callback(self, uri):
return IbisDataBackend._connection_callback(uri)
return ibis.snowflake.from_connection(self._do_connection_callback(uri), create_object_udfs=False), 'snowflake', False

def reconnect(self):
super().reconnect()
self.snowpark = self._get_snowpark_session(self.uri)

def insert(self, table_name, raw_documents):
ibis_schema = self.conn.table(table_name).schema()
df = pandas.DataFrame(raw_documents)
Expand Down
14 changes: 14 additions & 0 deletions plugins/snowflake/superduper_snowflake/vector_search.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from functools import wraps
import os
import re
import typing as t
Expand All @@ -10,6 +11,18 @@
from superduper.components.vector_index import VectorIndex


def retry(f):
@wraps(f)
def wrapper(self, *args, **kwargs):
try:
return f(self, *args, **kwargs)
except Exception as e:
if 'token' in str(e):
self.session = SnowflakeVectorSearcher.create_session(CFG.data_backend)
return f(self, *args, **kwargs)
return wrapper


class SnowflakeVectorSearcher(BaseVectorSearcher):
"""Vector searcher implementation of atlas vector search.
Expand Down Expand Up @@ -161,6 +174,7 @@ def find_nearest_from_id(self, id: str, n=100, within_ids=None):
).collect()
return self.find_nearest_from_array(result, n=n, within_ids=within_ids)

@retry
def find_nearest_from_array(self, h, n=100, within_ids=None):
"""Find the nearest vectors to the given vector.
Expand Down

0 comments on commit 719b321

Please sign in to comment.