forked from letsql/letsql
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
snowflake setup session hotfix (letsql#57)
* test: add test of issues with ibis' snowflake connection as of 9.0.0
* fix(snowflake): monkeypatch snowflake setup
* add a hotfix for proper snowflake connection initialization (pending the next ibis release that includes "fix(snowflake): properly pass schema and database for sqlglot generation", ibis-project/ibis#9221)
* add a test for the expectation of initial snowflake connection state regarding catalog/database, based on create_object_udfs
- Loading branch information
Showing 3 changed files with 137 additions and 0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,86 @@ | ||
import contextlib | ||
import functools | ||
import itertools | ||
import warnings | ||
|
||
import dask | ||
import ibis | ||
import sqlglot as sg | ||
import sqlglot.expressions as sge | ||
from ibis.backends.snowflake import _SNOWFLAKE_MAP_UDFS | ||
|
||
|
||
def _setup_session(self, *, session_parameters, create_object_udfs: bool):
    """Configure the Snowflake session behind ``self.con`` for use by ibis.

    Intended as a drop-in replacement for
    ``ibis.backends.snowflake.Backend._setup_session``.

    Parameters
    ----------
    session_parameters
        Mapping of Snowflake session parameter names to values. It is
        copied before defaults and required values are applied, so the
        caller's mapping is left untouched.
    create_object_udfs
        When true, also create the ``ibis_udfs`` database and the UDFs
        listed in ``_SNOWFLAKE_MAP_UDFS``. A failure while executing that
        batch is reported as a warning rather than raised.
    """
    con = self.con

    # Work on a copy: the defaults/overrides below must not leak back into
    # the mapping the caller handed us.
    session_parameters = dict(session_parameters)

    # enable multiple SQL statements by default
    session_parameters.setdefault("MULTI_STATEMENT_COUNT", 0)
    # don't format JSON output by default
    session_parameters.setdefault("JSON_INDENT", 0)

    # overwrite session parameters that are required for ibis + snowflake
    # to work
    session_parameters.update(
        dict(
            # Use Arrow for query results
            PYTHON_CONNECTOR_QUERY_RESULT_FORMAT="arrow_force",
            # JSON output must be strict for null versus undefined
            STRICT_JSON_OUTPUT=True,
            # Timezone must be UTC
            TIMEZONE="UTC",
        ),
    )

    with contextlib.closing(con.cursor()) as cur:
        cur.execute(
            "ALTER SESSION SET {}".format(
                " ".join(f"{k} = {v!r}" for k, v in session_parameters.items())
            )
        )

    if not create_object_udfs:
        return

    dialect = self.name
    create_stmt = sge.Create(kind="DATABASE", this="ibis_udfs", exists=True).sql(
        dialect
    )

    # The connection may report "<catalog>/<schema>" as its database. Only
    # then is there something to switch back to; a missing (None) or
    # slash-free database yields no USE statement instead of an error.
    # Everything after the first "/" is treated as the schema name.
    catalog, slash, db = (con.database or "").partition("/")
    if slash:
        use_stmt = sge.Use(
            kind="SCHEMA",
            this=sg.table(db, catalog=catalog, quoted=self.compiler.quoted),
        ).sql(dialect)
    else:
        use_stmt = ""

    stmts = [
        create_stmt,
        # snowflake activates a database on creation, so reset it back
        # to the original database and schema
        use_stmt,
        *itertools.starmap(self._make_udf, _SNOWFLAKE_MAP_UDFS.items()),
    ]

    # Drop empty entries so the batch never contains an empty statement
    # (";\n;") when there is no USE statement to issue.
    stmt = ";\n".join(filter(None, stmts))
    with contextlib.closing(con.cursor()) as cur:
        try:
            cur.execute(stmt)
        except Exception as e:  # noqa: BLE001
            warnings.warn(
                f"Unable to create Ibis UDFs, some functionality will not work: {e}"
            )
|
||
|
||
@functools.cache
def monkeypatch_setup_session():
    """Swap the Snowflake backend's ``_setup_session`` for the local one.

    The method being replaced is kept on ``_setup_session.original``.
    ``functools.cache`` makes repeated calls a no-op, so the saved original
    is not overwritten by the replacement itself.
    """
    backend_cls = ibis.backends.snowflake.Backend
    # Remember the implementation we are about to replace.
    _setup_session.original = backend_cls._setup_session
    backend_cls._setup_session = _setup_session
|
||
|
||
def maybe_monkeypatch_setup_session():
    """Apply the ``_setup_session`` patch only to one specific implementation.

    The dask token of the currently installed
    ``ibis.backends.snowflake.Backend._setup_session`` is compared with a
    hard-coded token; the patch is applied only when they match.
    NOTE(review): presumably this is the token of the ibis 9.0.0
    implementation that needs the hotfix -- confirm.
    """
    expected_token = '8c96093dd6f2f759ff96fd41199f06f5'
    installed = ibis.backends.snowflake.Backend._setup_session
    if dask.base.tokenize(installed) != expected_token:
        return
    monkeypatch_setup_session()


# Runs on import, so importing this module applies the patch when needed.
maybe_monkeypatch_setup_session()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
import pytest | ||
|
||
SU = pytest.importorskip("letsql.common.utils.snowflake_utils") | ||
|
||
|
||
@pytest.mark.snowflake
def test_setup_session():
    """Check the initial connection state with and without object UDFs.

    Without ``create_object_udfs`` the test expects the connection to report
    the combined "<database>/<schema>" string as its catalog and no
    database/schema; with it, catalog and database are expected separately.
    """
    database, schema = "SNOWFLAKE_SAMPLE_DATA", "TPCH_SF1"
    state_query = (
        "SELECT CURRENT_WAREHOUSE(), CURRENT_DATABASE(), CURRENT_SCHEMA();"
    )

    def connect(create_object_udfs):
        # Open a fresh connection for the database/schema under test.
        return SU.make_ibis_connection(
            database=database,
            schema=schema,
            create_object_udfs=create_object_udfs,
        )

    def session_state(connection):
        # First (only) row of the state query as a plain dict.
        frame = connection.raw_sql(state_query).fetch_pandas_all()
        return frame.iloc[0].to_dict()

    # Case 1: no UDF creation.
    con = connect(False)
    dct = session_state(con)
    combined = f"{database}/{schema}"
    assert con.current_catalog == combined
    assert con.current_database is None
    assert con.con.database == combined
    assert con.con.schema is None
    assert dct == {
        "CURRENT_WAREHOUSE()": "COMPUTE_WH",
        "CURRENT_DATABASE()": None,
        "CURRENT_SCHEMA()": None,
    }

    # Case 2: UDF creation enabled.
    con = connect(True)
    dct = session_state(con)
    assert con.current_catalog == database
    assert con.current_database == schema
    assert con.con.database == database
    assert con.con.schema == schema
    assert dct == {
        "CURRENT_WAREHOUSE()": "COMPUTE_WH",
        "CURRENT_DATABASE()": database,
        "CURRENT_SCHEMA()": schema,
    }
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters