Skip to content

Commit

Permalink
Fix connection URI
Browse files Browse the repository at this point in the history
  • Loading branch information
blythed committed Jan 27, 2025
1 parent 9da92af commit c079d77
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 15 deletions.
7 changes: 4 additions & 3 deletions plugins/snowflake/superduper_snowflake/data_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@ def __init__(self, *args, **kwargs):

@staticmethod
def _get_snowpark_session(uri):
logging.info('Creating Snowpark session')
logging.info('Creating Snowpark session for'
' snowflake vector-search implementation')
if uri == 'snowflake://':
connection_parameters = dict(
host=os.environ['SNOWFLAKE_HOST'],
Expand All @@ -43,9 +44,9 @@ def _get_snowpark_session(uri):
warehouse = None
else:
match = re.match(
'snowflake://(.*):(.*)@(.*)/(.*)/(.*)?warehouse=(.*)^'
'^snowflake://(.*):(.*)@(.*)/(.*)/(.*)?warehouse=(.*)$', uri
)
password, user, account, database, schema, warehouse = match.groups()
user, password, account, database, schema, warehouse = match.groups()

connection_parameters = dict(
user=user,
Expand Down
29 changes: 17 additions & 12 deletions plugins/snowflake/superduper_snowflake/vector_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,21 +76,26 @@ def create_session(cls, vector_search_uri):
"host": host,
}
else:
pattern = r"snowflake://(?P<user>[^:]+):(?P<password>[^@]+)@(?P<account>[^/]+)/(?P<database>[^/]+)/(?P<schema>[^/]+)"
match = re.match(pattern, vector_search_uri)
schema = match.group("schema")
database = match.group("database")
if '?warehouse=' not in vector_search_uri:
match = re.match(
'^snowflake:\/\/(.*):(.*)\@(.*)\/(.*)\/(.*)$', vector_search_uri
)
user, password, account, database, schema = match.groups()
warehouse = None
else:
match = re.match(
'^snowflake://(.*):(.*)@(.*)/(.*)/(.*)?warehouse=(.*)$', vector_search_uri
)
user, password, account, database, schema, warehouse = match.groups()
if match:
connection_parameters = {
"user": match.group("user"),
"password": match.group("password"),
"account": match.group("account"),
"database": match.group("database"),
"schema": match.group("schema"),
# TODO: check warehouse
"warehouse": "base",
"user": user,
"password": password,
"account": account,
"database": database,
"schema": schema,
"warehouse": warehouse,
}

else:
raise ValueError(f"URI `{vector_search_uri}` is invalid!")

Expand Down

0 comments on commit c079d77

Please sign in to comment.