From 753333e917c72037e212ed8aa664120bd478862b Mon Sep 17 00:00:00 2001
From: Han Wang
Date: Thu, 1 Feb 2024 09:01:24 +0000
Subject: [PATCH] update

---
 Makefile                            |   3 +
 fugue_snowflake/_utils.py           | 123 +++++++++++-
 fugue_snowflake/client.py           | 295 +++++++++++++++++++++-------
 tests/fugue_snowflake/__init__.py   |   0
 tests/fugue_snowflake/test_utils.py |  58 ++++++
 5 files changed, 409 insertions(+), 70 deletions(-)
 create mode 100644 tests/fugue_snowflake/__init__.py
 create mode 100644 tests/fugue_snowflake/test_utils.py

diff --git a/Makefile b/Makefile
index 713d149..a922d0d 100644
--- a/Makefile
+++ b/Makefile
@@ -42,6 +42,9 @@ trinodocker:
 testtrino:
 	python3 -bb -m pytest tests/fugue_trino --cov=fugue_trino
 
+testsf:
+	python3 -bb -m pytest tests/fugue_snowflake --cov=fugue_snowflake
+
 lab:
 	mkdir -p tmp
 	pip install .
diff --git a/fugue_snowflake/_utils.py b/fugue_snowflake/_utils.py
index 46192c4..46fcfad 100644
--- a/fugue_snowflake/_utils.py
+++ b/fugue_snowflake/_utils.py
@@ -1,7 +1,128 @@
-from fugue_ibis._utils import to_schema as _to_schema
+from typing import Any, Dict, List, Optional
+
+import pyarrow as pa
 from fugue_ibis import IbisSchema
+from fugue_ibis._utils import to_schema as _to_schema
+from snowflake.connector.constants import FIELD_TYPES
+from snowflake.connector.result_batch import ResultBatch
 from triad import Schema
+from triad.utils.pyarrow import (
+    TRIAD_DEFAULT_TIMESTAMP,
+    get_alter_func,
+    parse_json_columns,
+    replace_types_in_table,
+)
+
+_PA_TYPE_TO_SF_TYPE: Dict[pa.DataType, str] = {
+    pa.string(): "STRING",
+    pa.bool_(): "BOOLEAN",
+    pa.int8(): "BYTEINT",
+    pa.int16(): "TINYINT",
+    pa.int32(): "SMALLINT",
+    pa.int64(): "BIGINT",
+    pa.uint8(): "INT",
+    pa.uint16(): "INT",
+    pa.uint32(): "INT",
+    pa.uint64(): "INT",
+    pa.float16(): "FLOAT",
+    pa.float32(): "FLOAT",
+    pa.float64(): "FLOAT",
+    pa.date32(): "DATE",
+    pa.binary(): "BINARY",
+}
+
+
+def quote_name(name: str) -> str:
+    quote = '"'
+    return quote + name.replace(quote, quote + quote) + quote
 
 
 def to_schema(schema: IbisSchema) -> Schema:
     return _to_schema(schema)
+
+
+def pa_type_to_snowflake_type_str(tp: pa.DataType) -> str:
+    if tp in _PA_TYPE_TO_SF_TYPE:
+        return _PA_TYPE_TO_SF_TYPE[tp]
+    if pa.types.is_timestamp(tp):
+        if tp.tz is not None:
+            return "TIMESTAMP_TZ"
+        return "TIMESTAMP_NTZ"
+    if pa.types.is_decimal(tp):
+        return f"DECIMAL({tp.precision},{tp.scale})"
+    if pa.types.is_list(tp):
+        # itp = pa_type_to_snowflake_type_str(tp.value_type)
+        # return f"ARRAY({itp})"
+        return "ARRAY"
+    if pa.types.is_struct(tp):
+        # fields = []
+        # for f in tp:
+        #     fields.append(
+        #         f"{quote_name(f.name)} {pa_type_to_snowflake_type_str(f.type)}"
+        #     )
+        # return f"OBJECT({', '.join(fields)})"
+        return "OBJECT"
+    if pa.types.is_map(tp):
+        # ktp = pa_type_to_snowflake_type_str(tp.key_type)
+        # vtp = pa_type_to_snowflake_type_str(tp.item_type)
+        # return f"MAP({ktp}, {vtp})"
+        return "MAP"
+    raise NotImplementedError(f"Unsupported type {tp}")
+
+
+def fix_snowflake_arrow_result(result: pa.Table) -> pa.Table:
+    return replace_types_in_table(
+        result,
+        [
+            (lambda tp: pa.types.is_date64(tp), pa.date32()),
+            (
+                lambda tp: pa.types.is_timestamp(tp)
+                and tp.tz is None
+                and tp != TRIAD_DEFAULT_TIMESTAMP,
+                TRIAD_DEFAULT_TIMESTAMP,
+            ),
+        ],
+    )
+
+
+def to_snowflake_schema(schema: Any) -> str:
+    _s = schema if isinstance(schema, Schema) else Schema(schema)
+    fields = []
+    for f in _s.fields:
+        fields.append(f"{quote_name(f.name)} {pa_type_to_snowflake_type_str(f.type)}")
+    return ", ".join(fields)
+
+
+def get_arrow_from_batches(
+    batches: Optional[List[ResultBatch]],
+    schema: Any = None,
+    infer_nested_types: bool = False,
+) -> pa.Table:
+    if batches is None or len(batches) == 0:
+        if schema is not None:
+            return (
+                schema if isinstance(schema, Schema) else Schema(schema)
+            ).create_empty_arrow_table()
+        raise ValueError("No result")
+    nested_cols = _get_nested_columns(batches[0])
+    adf = pa.concat_tables([x.to_arrow() for x in batches])
+    if adf.num_rows == 0:
+        return fix_snowflake_arrow_result(adf)
+    if schema is None:
+        adf = fix_snowflake_arrow_result(adf)
+        if infer_nested_types and len(nested_cols) > 0:
+            adf = parse_json_columns(adf, nested_cols)
+        return adf
+    _schema = schema if isinstance(schema, Schema) else Schema(schema)
+    adf = parse_json_columns(adf, nested_cols)
+    func = get_alter_func(adf.schema, _schema.pa_schema, safe=True)
+    return func(adf)
+
+
+def _get_nested_columns(batch: ResultBatch) -> List[str]:
+    res: List[str] = []
+    for meta in batch.schema:
+        f = FIELD_TYPES[meta.type_code]
+        if f.name in ["OBJECT", "ARRAY", "MAP", "VARIANT"]:
+            res.append(meta.name)
+    return res
diff --git a/fugue_snowflake/client.py b/fugue_snowflake/client.py
index e8065be..ff0d6b2 100644
--- a/fugue_snowflake/client.py
+++ b/fugue_snowflake/client.py
@@ -1,13 +1,14 @@
+import os
+from contextlib import contextmanager
 from contextvars import ContextVar
-from typing import Any, List, Optional
+from tempfile import TemporaryDirectory
+from typing import Any, Iterator, List, Optional
 from uuid import uuid4
 
 import ibis
-import pyarrow
 import pyarrow as pa
 import snowflake.connector
 from fugue import (
-    AnyDataFrame,
     ArrayDataFrame,
     ArrowDataFrame,
     DataFrame,
@@ -16,10 +17,19 @@
     PartitionSpec,
 )
 from fugue_ibis import IbisTable
+from pyarrow.parquet import write_table as write_parquet
+from snowflake.connector.cursor import SnowflakeCursor
+from snowflake.connector.result_batch import ResultBatch
 from triad import Schema, SerializableRLock, assert_or_throw
 
 from ._constants import get_client_init_params
-from ._utils import to_schema
+from ._utils import (
+    get_arrow_from_batches,
+    pa_type_to_snowflake_type_str,
+    quote_name,
+    to_schema,
+    to_snowflake_schema,
+)
 
 _FUGUE_SNOWFLAKE_CLIENT_CONTEXT = ContextVar(
     "_FUGUE_SNOWFLAKE_CLIENT_CONTEXT", default=None
@@ -31,12 +41,12 @@ class SnowflakeClient:
     def __init__(
         self,
-        account: Optional[str] = None,
-        user: Optional[str] = None,
-        password: Optional[str] = None,
-        database: Optional[str] = None,
-        warehouse: Optional[str] = None,
-        schema: Optional[str] = None,
+        account: str,
+        user: str,
+        password: str,
+        database: str,
+        warehouse: str,
+        schema: str,
         role: Optional[str] = "ACCOUNTADMIN",
     ):
         self._temp_tables: List[str] = []
@@ -83,6 +93,33 @@ def get_current() -> "SnowflakeClient":
     def sf(self) -> snowflake.connector.SnowflakeConnection:
         return self._sf
 
+    def __setstate__(self, state):
+        for k, v in state.items():
+            setattr(self, k, v)
+        self._ibis = ibis.snowflake.connect(
+            account=self._account,
+            user=self._user,
+            password=self._password,
+            warehouse=self._warehouse,
+            database=f"{self._database}/{self._schema}",
+            role=self._role,
+        )
+
+        con = self._ibis.con.connect()
+        self._sf: Any = con.connection.dbapi_connection  # type: ignore
+
+    def __getstate__(self):
+        state = self.__dict__.copy()
+        del state["_sf"]
+        del state["_ibis"]
+        return state
+
+    @contextmanager
+    def cursor(self) -> Iterator[SnowflakeCursor]:
+        with self._ibis.con.connect() as con:
+            with con.connection.cursor() as cur:
+                yield cur
+
     def stop(self):
         # for tt in self._temp_tables:
         #     self.sf.cursor().execute(f"DROP TABLE IF EXISTS {tt}")
@@ -106,90 +143,210 @@ def ibis(self) -> ibis.BaseBackend:
     def query_to_ibis(self, query: str) -> IbisTable:
         return self.ibis.sql(query)
 
-    def query_to_arrow(self, query: str) -> pa.Table:
-        with self.sf.cursor() as cur:
-            cur.execute(
+    def query_to_result_batches(
+        self, query: str, cursor: Optional[SnowflakeCursor] = None
+    ) -> Optional[List[ResultBatch]]:
+        if cursor is None:
+            with self.cursor() as cur:
+                cur.execute(
+                    "alter session set python_connector_query_result_format='ARROW'"
+                )
+                cur.execute(query)
+                batches = cur.get_result_batches()
+        else:
+            cursor.execute(
                 "alter session set python_connector_query_result_format='ARROW'"
             )
-            cur.execute(query)
-            return cur.fetch_arrow_all()
+            cursor.execute(query)
+            batches = cursor.get_result_batches()
+        return batches
+
+    def query_to_arrow(
+        self,
+        query: str,
+        schema: Any = None,
+        infer_nested_types: bool = False,
+        cursor: Optional[SnowflakeCursor] = None,
+    ) -> pa.Table:
+        return get_arrow_from_batches(
+            self.query_to_result_batches(query, cursor=cursor),
+            schema=schema,
+            infer_nested_types=infer_nested_types,
+        )
 
-    def query_to_engine_df(self, query: str, engine: ExecutionEngine) -> DataFrame:
-        tb = self.query_to_ibis(query)
-        schema = to_schema(tb.schema())
+    def query_to_engine_df(
+        self,
+        query: str,
+        engine: ExecutionEngine,
+        schema: Any = None,
+        infer_nested_types: bool = False,
+        cursor: Optional[SnowflakeCursor] = None,
+    ) -> DataFrame:
+        if schema is not None:
+            _schema = schema if isinstance(schema, Schema) else Schema(schema)
+        else:
+            tb = self.query_to_ibis(query)
+            _schema = to_schema(tb.schema())
 
-        with self.sf.cursor() as cur:
-            cur.execute(
-                "alter session set python_connector_query_result_format='ARROW'"
-            )
-            cur.execute(query)
-            batches = cur.get_result_batches()
+        batches = self.query_to_result_batches(query, cursor=cursor)
 
         if batches is None or len(batches) == 0:
-            return ArrowDataFrame(schema=schema)
+            raise ValueError(f"No data returned from {query}")
 
         idx = ArrayDataFrame([[x] for x in range(len(batches))], "id:int")
 
         def _map(cursor: Any, df: LocalDataFrame) -> LocalDataFrame:
-            tbs: List[pa.Table] = []
-            for row in df.as_dict_iterable():
-                batch = batches[row["id"]]
-                tbs.append(batch.to_arrow())
-            res = pa.concat_tables(tbs)
-            return ArrowDataFrame(res)
+            _b = [batches[row["id"]] for row in df.as_dict_iterable()]  # type: ignore
+            adf = get_arrow_from_batches(
+                _b, schema=schema, infer_nested_types=infer_nested_types
+            )
+            return ArrowDataFrame(adf)
 
         res = engine.map_engine.map_dataframe(
             idx,
             _map,
-            output_schema=schema,
+            output_schema=_schema,
             partition_spec=PartitionSpec("per_row"),
         )
 
         return res
 
-    def load_df(self, df: DataFrame, name: str, mode: str = "overwrite") -> None:
-        if isinstance(df, ArrayDataFrame):
-            df_pandas = df.as_pandas()
-        else:
-            df_pandas = ArrowDataFrame(df).as_pandas()
+    def df_to_temp_table(
+        self, df: DataFrame, engine: ExecutionEngine, transient: bool = True
+    ) -> str:
+        with _Uploader(
+            self, self.sf.cursor(), self._database, self._schema
+        ) as uploader:
+            return uploader.to_temp_table(df, engine, transient=transient)
 
-        if mode == "overwrite":
-            snowflake.connector.pandas_tools.write_pandas(
-                self.sf, df_pandas, name, overwrite=True
+    def df_to_table(
+        self,
+        df: DataFrame,
+        table: str,
+        mode: str,
+        engine: ExecutionEngine,
+        table_type: str = "",
+    ) -> str:
+        with _Uploader(
+            self, self.sf.cursor(), self._database, self._schema
+        ) as uploader:
+            return uploader.to_table(
+                df, table, mode=mode, engine=engine, table_type=table_type
             )
-        elif mode == "append":
-            snowflake.connector.pandas_tools.write_pandas(self.sf, df_pandas, name)
-        else:
-            raise ValueError(f"Unsupported mode: {mode}")
 
-    def create_temp_table(self, schema: Schema) -> str:
-        temp_table_name = f"_temp_{uuid4().hex}"
-        df = ArrayDataFrame(schema=schema)
-        df_pandas = df.as_pandas()
-        snowflake.connector.pandas_tools.write_pandas(
-            self.sf, df_pandas, temp_table_name, overwrite=True, table_type="temporary"
+
+class _Uploader:
+    def __init__(
+        self,
+        client: SnowflakeClient,
+        cursor: SnowflakeCursor,
+        database: str,
+        schema: str,
+    ):
+        self._client = client
+        self._cursor = cursor
+        self._database = database
+        self._schema = schema
+        self._stage = self._get_full_rand_name()
+
+    def _get_full_rand_name(self) -> str:
+        return self._database + "." + self._schema + "." + _temp_rand_str().upper()
+
+    def __enter__(self) -> "_Uploader":
+        create_stage_sql = (
+            f"CREATE STAGE IF NOT EXISTS {self._stage}" " FILE_FORMAT=(TYPE=PARQUET)"
         )
+        print(create_stage_sql)
+        self._cursor.execute(create_stage_sql).fetchall()
+        return self
 
-        self._temp_tables.append(temp_table_name)
+    def __exit__(
+        self, exception_type: Any, exception_value: Any, exception_traceback: Any
+    ) -> None:
+        drop_stage_sql = f"DROP STAGE IF EXISTS {self._stage}"
+        print(drop_stage_sql)
+        self._cursor.execute(drop_stage_sql).fetchall()
+
+    def to_temp_table(
+        self, df: DataFrame, engine: ExecutionEngine, transient: bool = False
+    ) -> str:
+        files = self.upload(df, engine)
+        table = self._create_temp_table(df.schema, transient=transient)
+        return self._copy_to_table(files, table)
+
+    def to_table(
+        self,
+        df: DataFrame,
+        table: str,
+        mode: str,
+        engine: ExecutionEngine,
+        table_type: str = "",
+    ) -> str:
+        files = self.upload(df, engine)
+        assert_or_throw(
+            mode in ["overwrite", "append"], ValueError(f"Unsupported mode: {mode}")
+        )
+        if mode == "overwrite":
+            self._cursor.execute(f"DROP TABLE IF EXISTS {table}").fetchall()
+        table = self._create_table(df.schema, table, table_type=table_type)
+        return self._copy_to_table(files, table)
 
-        return temp_table_name
+    def upload(self, df: DataFrame, engine: ExecutionEngine) -> List[str]:
+        stage_location = self._stage
+        client = self._client
 
-    def register_temp_table(self, name: str):
-        self._temp_tables.append(name)
+        def _map(cursor: Any, df: LocalDataFrame) -> LocalDataFrame:
+            file = _temp_rand_str() + ".parquet"
+            with TemporaryDirectory() as f:
+                path = os.path.join(f, file)
+                write_parquet(df.as_arrow(), path)
+                with client.cursor() as cur:
+                    cur.execute(f"PUT file://{path} @{stage_location}").fetchall()
+            return ArrayDataFrame([[file]], "file:str")
 
-    def is_temp_table(self, name: str) -> bool:
-        return name in self._temp_tables
+        res = engine.map_engine.map_dataframe(
+            df,
+            _map,
+            output_schema=Schema("file:str"),
+            partition_spec=PartitionSpec(),
+            map_func_format_hint="pyarrow",
+        )
+        return res.as_pandas().file.tolist()
+
+    def _create_table(self, schema: Any, table: str, table_type: str) -> str:
+        expr = to_snowflake_schema(schema)
+        create_table_sql = f"CREATE {table_type.upper()} TABLE {table} ({expr})"
+        print(create_table_sql)
+        self._cursor.execute(create_table_sql)
+        return table
+
+    def _create_temp_table(self, schema: Any, transient: bool = False) -> str:
+        table = self._get_full_rand_name()
+        return self._create_table(schema, table, "TRANSIENT" if transient else "TEMP")
+
+    def _copy_to_table(self, files: List[str], table: str) -> str:
+        files_expr = ", ".join([f"'{x}'" for x in files])
+        copy_sql = (
+            f"COPY INTO {table} FROM"
+            f" @{self._stage}"
+            f" FILES = ({files_expr})"
+            f" FILE_FORMAT = (TYPE=PARQUET)"
+            f" MATCH_BY_COLUMN_NAME = CASE_SENSITIVE"
+        )
+        print(copy_sql)
+        res = self._cursor.execute(copy_sql).fetchall()
+        print(res)
+        return table
 
-    def df_to_table(
-        self, df: AnyDataFrame, table_name: str = None, overwrite: bool = False
-    ) -> Any:
-        if table_name is None:
-            if isinstance(df, ArrayDataFrame):
-                schema = pyarrow.Table.from_pandas(df.as_pandas()).schema
-            else:
-                schema = ArrowDataFrame(df).schema
-            table_name = self.create_temp_table(schema)
-
-        self.load_df(df, table_name, mode="overwrite" if overwrite else "append")
-
-        return table_name
+
+def _temp_rand_str() -> str:
+    return "temp_" + str(uuid4()).split("-")[0]
+
+
+def _to_snowflake_select_schema(schema: Any) -> str:
+    _s = schema if isinstance(schema, Schema) else Schema(schema)
+    fields = []
+    for f in _s.fields:
+        fields.append(
+            f"$1:{quote_name(f.name)}::{pa_type_to_snowflake_type_str(f.type)}"
+        )
+    return ", ".join(fields)
diff --git a/tests/fugue_snowflake/__init__.py b/tests/fugue_snowflake/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/tests/fugue_snowflake/test_utils.py b/tests/fugue_snowflake/test_utils.py
new file mode 100644
index 0000000..acc5260
--- /dev/null
+++ b/tests/fugue_snowflake/test_utils.py
@@ -0,0 +1,58 @@
+from fugue_snowflake._utils import to_snowflake_schema, fix_snowflake_arrow_result
+from triad import Schema
+from fugue_snowflake.client import SnowflakeClient
+from typing import Any
+from pytest import raises
+
+
+def test_to_snowflake_schema():
+    def _assert(s: Any, t: str):
+        assert to_snowflake_schema(s) == t
+
+    _assert(
+        "a:int8,b:int16,c:int32,d:int64",
+        '"a" BYTEINT, "b" TINYINT, "c" SMALLINT, "d" BIGINT',
+    )
+    _assert("a:uint8,b:uint16,c:uint32,d:uint64", '"a" INT, "b" INT, "c" INT, "d" INT')
+    _assert(Schema("a:decimal(10,2)"), '"a" DECIMAL(10,2)')
+    _assert("a:float16,b:float32,c:float64", '"a" FLOAT, "b" FLOAT, "c" FLOAT')
+
+    _assert("a:bool,b:str,d:bytes", '"a" BOOLEAN, "b" STRING, "d" BINARY')
+
+    _assert("a:date,b:datetime", '"a" DATE, "b" TIMESTAMP_NTZ')
+    _assert("a:timestamp(ns,UTC)", '"a" TIMESTAMP_TZ')
+
+    _assert("a:[int]", '"a" ARRAY')
+    _assert("a:<str,str>", '"a" MAP')
+    _assert("a:{a:int,b:str}", '"a" OBJECT')
+
+    _assert("大:{`a b`:int,b:str}", '"大" OBJECT')
+
+    with raises(NotImplementedError):
+        to_snowflake_schema("a:null")
+
+
+def test_to_snowflake_schema_with_temp_table():
+    client = SnowflakeClient.get_or_create()
+
+    def _assert(s: Any, t: str):
+        ss = to_snowflake_schema(s)
+        temp_table = "test1.s1.ttt"
+        with client.cursor() as cur:
+            cur.execute(f"CREATE OR REPLACE TEMP TABLE {temp_table} ({ss})")
+            adf = client.query_to_arrow(
+                f"SELECT * FROM {temp_table}", cursor=cur, infer_nested_types=False
+            )
+            cur.execute(f"DROP TABLE {temp_table}")
+        assert Schema(t) == Schema(adf.schema)
+
+    _assert("a:int8,b:int16,大:int32,D:int64", "a:long,b:long,大:long,D:long")
+    _assert("a:uint8,b:uint16,c:uint32,D:uint64", "a:long,b:long,c:long,D:long")
+    _assert("a:float16,b:float32,c:float64", "a:double,b:double,c:double")
+    _assert("a:bool,b:str,d:bytes", "a:bool,b:str,d:bytes")
+    _assert("a:date,b:datetime", "a:date,b:datetime")
+    _assert("a:[long]", "a:[long]")
+
+    # not supported by snowflake
+    # _assert("a:decimal(10,2)", "a:decimal(10,2)")
+    # _assert("a:timestamp(ns,UTC)", 'a:timestamp(ns,UTC)')
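
Usage sketch for the read path added above: query_to_result_batches switches the session to the ARROW result format and returns ResultBatch objects, get_arrow_from_batches concatenates and type-fixes them, and query_to_engine_df hands the batches to a Fugue ExecutionEngine. This is a minimal sketch, not part of the patch; the credentials, identifiers, and the choice of NativeExecutionEngine are illustrative assumptions.

    from fugue import NativeExecutionEngine

    from fugue_snowflake.client import SnowflakeClient

    # Placeholder credentials; replace with a real Snowflake account.
    client = SnowflakeClient(
        account="my_account",
        user="my_user",
        password="my_password",
        database="MY_DB",
        warehouse="MY_WH",
        schema="PUBLIC",
    )

    # Whole result as one Arrow table: the session result format is switched to
    # ARROW and the batches are concatenated and type-fixed.
    adf = client.query_to_arrow("SELECT CURRENT_DATE AS D")

    # Distributed read: each ResultBatch is materialized inside the engine's map tasks.
    fdf = client.query_to_engine_df(
        "SELECT CURRENT_DATE AS D", engine=NativeExecutionEngine()
    )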
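A companion sketch of the stage-based write path: _Uploader creates a temporary stage, each partition is written to Parquet and PUT into the stage, and the files are loaded with COPY INTO ... MATCH_BY_COLUMN_NAME. The dataframe contents and the target table name are invented for illustration and reuse the client and engine from the sketch above.

    import pandas as pd
    from fugue import PandasDataFrame

    engine = NativeExecutionEngine()
    df = PandasDataFrame(pd.DataFrame({"a": [1, 2], "b": ["x", "y"]}), "a:long,b:str")

    # Creates a transient table with a generated name and returns that name.
    temp_table = client.df_to_temp_table(df, engine=engine, transient=True)

    # Loads into a named table; mode can be "overwrite" or "append".
    client.df_to_table(df, "MY_DB.PUBLIC.MY_TABLE", mode="append", engine=engine)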