Skip to content

Commit

Permalink
snowflake setup session hotfix (letsql#57)
Browse files Browse the repository at this point in the history
* test: add test of issues with ibis' snowflake connection as of 9.0.0
* fix(snowflake): monkeypatch snowflake setup
* add a hotfix for proper snowflake connection initialization (pending next release that includes fix(snowflake): properly pass schema and database for sqlglot generation ibis-project/ibis#9221)
* add a test for the expectation of initial snowflake connection state re catalog/database based on create_object_udfs
  • Loading branch information
dlovell committed May 31, 2024
1 parent c4dbe51 commit af5881d
Show file tree
Hide file tree
Showing 3 changed files with 137 additions and 0 deletions.
86 changes: 86 additions & 0 deletions python/letsql/backends/snowflake/hotfix.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
import contextlib
import functools
import itertools
import warnings

import dask
import ibis
import sqlglot as sg
import sqlglot.expressions as sge
from ibis.backends.snowflake import _SNOWFLAKE_MAP_UDFS


def _setup_session(self, *, session_parameters, create_object_udfs: bool):
con = self.con

# enable multiple SQL statements by default
session_parameters.setdefault("MULTI_STATEMENT_COUNT", 0)
# don't format JSON output by default
session_parameters.setdefault("JSON_INDENT", 0)

# overwrite session parameters that are required for ibis + snowflake
# to work
session_parameters.update(
dict(
# Use Arrow for query results
PYTHON_CONNECTOR_QUERY_RESULT_FORMAT="arrow_force",
# JSON output must be strict for null versus undefined
STRICT_JSON_OUTPUT=True,
# Timezone must be UTC
TIMEZONE="UTC",
),
)

with contextlib.closing(con.cursor()) as cur:
cur.execute(
"ALTER SESSION SET {}".format(
" ".join(f"{k} = {v!r}" for k, v in session_parameters.items())
)
)

if create_object_udfs:
dialect = self.name
create_stmt = sge.Create(kind="DATABASE", this="ibis_udfs", exists=True).sql(
dialect
)
if "/" in con.database:
(catalog, db) = con.database.split("/")
use_stmt = sge.Use(
kind="SCHEMA",
this=sg.table(db, catalog=catalog, quoted=self.compiler.quoted),
).sql(dialect)
else:
use_stmt = ""

stmts = [
create_stmt,
# snowflake activates a database on creation, so reset it back
# to the original database and schema
use_stmt,
*itertools.starmap(self._make_udf, _SNOWFLAKE_MAP_UDFS.items()),
]

stmt = ";\n".join(stmts)
with contextlib.closing(con.cursor()) as cur:
try:
cur.execute(stmt)
except Exception as e: # noqa: BLE001
warnings.warn(
f"Unable to create Ibis UDFs, some functionality will not work: {e}"
)


@functools.cache
def monkeypatch_setup_session():
attrname = "_setup_session"
_setup_session.original = getattr(ibis.backends.snowflake.Backend, attrname)
setattr(ibis.backends.snowflake.Backend, attrname, _setup_session)


def maybe_monkeypatch_setup_session():
tokenized = dask.base.tokenize(ibis.backends.snowflake.Backend._setup_session)
if tokenized == '8c96093dd6f2f759ff96fd41199f06f5':
monkeypatch_setup_session()


maybe_monkeypatch_setup_session()
49 changes: 49 additions & 0 deletions python/letsql/backends/snowflake/tests/test_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
import pytest

SU = pytest.importorskip("letsql.common.utils.snowflake_utils")


@pytest.mark.snowflake
def test_setup_session():
(database, schema) = ("SNOWFLAKE_SAMPLE_DATA", "TPCH_SF1")
con = SU.make_ibis_connection(
database=database,
schema=schema,
create_object_udfs=False,
)
dct = (
con.raw_sql("SELECT CURRENT_WAREHOUSE(), CURRENT_DATABASE(), CURRENT_SCHEMA();")
.fetch_pandas_all()
.iloc[0]
.to_dict()
)
assert con.current_catalog == f"{database}/{schema}"
assert con.current_database is None
assert con.con.database == f"{database}/{schema}"
assert con.con.schema is None
assert dct == {
"CURRENT_WAREHOUSE()": "COMPUTE_WH",
"CURRENT_DATABASE()": None,
"CURRENT_SCHEMA()": None,
}

con = SU.make_ibis_connection(
database=database,
schema=schema,
create_object_udfs=True,
)
dct = (
con.raw_sql("SELECT CURRENT_WAREHOUSE(), CURRENT_DATABASE(), CURRENT_SCHEMA();")
.fetch_pandas_all()
.iloc[0]
.to_dict()
)
assert con.current_catalog == database
assert con.current_database == schema
assert con.con.database == database
assert con.con.schema == schema
assert dct == {
"CURRENT_WAREHOUSE()": "COMPUTE_WH",
"CURRENT_DATABASE()": database,
"CURRENT_SCHEMA()": schema,
}
2 changes: 2 additions & 0 deletions python/letsql/common/utils/snowflake_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
import pandas as pd
import snowflake.connector

import letsql.backends.snowflake.hotfix # noqa: F401


def make_credential_defaults():
return {
Expand Down

0 comments on commit af5881d

Please sign in to comment.