From 66f6666265aab1b20bfa90717092279b383ef0f8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C4=90=E1=BA=B7ng=20Minh=20D=C5=A9ng?= Date: Fri, 2 Apr 2021 15:00:42 +0700 Subject: [PATCH] feat: add trino sqlalchemy dialect MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Đặng Minh Dũng --- trino/dbapi.py | 40 ++++- trino/sqlalchemy/__init__.py | 14 ++ trino/sqlalchemy/compiler.py | 103 ++++++++++++ trino/sqlalchemy/datatype.py | 169 ++++++++++++++++++++ trino/sqlalchemy/dialect.py | 301 +++++++++++++++++++++++++++++++++++ trino/sqlalchemy/error.py | 24 +++ 6 files changed, 648 insertions(+), 3 deletions(-) create mode 100644 trino/sqlalchemy/__init__.py create mode 100644 trino/sqlalchemy/compiler.py create mode 100644 trino/sqlalchemy/datatype.py create mode 100644 trino/sqlalchemy/dialect.py create mode 100644 trino/sqlalchemy/error.py diff --git a/trino/dbapi.py b/trino/dbapi.py index 616ae341..e3b18ae8 100644 --- a/trino/dbapi.py +++ b/trino/dbapi.py @@ -29,10 +29,44 @@ import trino.exceptions import trino.client import trino.logging -from trino.transaction import Transaction, IsolationLevel, NO_TRANSACTION - +from trino.transaction import ( + Transaction, + IsolationLevel, + NO_TRANSACTION +) +from trino.exceptions import ( + Warning, + Error, + InterfaceError, + DatabaseError, + DataError, + OperationalError, + IntegrityError, + InternalError, + ProgrammingError, + NotSupportedError, +) -__all__ = ["connect", "Connection", "Cursor"] +__all__ = [ + # https://www.python.org/dev/peps/pep-0249/#globals + "apilevel", + "threadsafety", + "paramstyle", + "connect", + "Connection", + "Cursor", + # https://www.python.org/dev/peps/pep-0249/#exceptions + "Warning", + "Error", + "InterfaceError", + "DatabaseError", + "DataError", + "OperationalError", + "IntegrityError", + "InternalError", + "ProgrammingError", + "NotSupportedError", +] apilevel = "2.0" diff --git a/trino/sqlalchemy/__init__.py b/trino/sqlalchemy/__init__.py new file mode 100644 index 00000000..de64cdcc --- /dev/null +++ b/trino/sqlalchemy/__init__.py @@ -0,0 +1,14 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from sqlalchemy.dialects import registry + +registry.register("trino", "trino.sqlalchemy.dialect.TrinoDialect", "TrinoDialect") diff --git a/trino/sqlalchemy/compiler.py b/trino/sqlalchemy/compiler.py new file mode 100644 index 00000000..4e0db8a4 --- /dev/null +++ b/trino/sqlalchemy/compiler.py @@ -0,0 +1,103 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from sqlalchemy.sql import compiler + +# https://trino.io/docs/current/language/reserved.html +RESERVED_WORDS = { + "alter", + "and", + "as", + "between", + "by", + "case", + "cast", + "constraint", + "create", + "cross", + "cube", + "current_date", + "current_path", + "current_role", + "current_time", + "current_timestamp", + "current_user", + "deallocate", + "delete", + "describe", + "distinct", + "drop", + "else", + "end", + "escape", + "except", + "execute", + "exists", + "extract", + "false", + "for", + "from", + "full", + "group", + "grouping", + "having", + "in", + "inner", + "insert", + "intersect", + "into", + "is", + "join", + "left", + "like", + "localtime", + "localtimestamp", + "natural", + "normalize", + "not", + "null", + "on", + "or", + "order", + "outer", + "prepare", + "recursive", + "right", + "rollup", + "select", + "table", + "then", + "true", + "uescape", + "union", + "unnest", + "using", + "values", + "when", + "where", + "with", +} + + +class TrinoSQLCompiler(compiler.SQLCompiler): + pass + + +class TrinoDDLCompiler(compiler.DDLCompiler): + pass + + +class TrinoTypeCompiler(compiler.GenericTypeCompiler): + pass + + +class TrinoIdentifierPreparer(compiler.IdentifierPreparer): + reserved_words = RESERVED_WORDS diff --git a/trino/sqlalchemy/datatype.py b/trino/sqlalchemy/datatype.py new file mode 100644 index 00000000..570d15fa --- /dev/null +++ b/trino/sqlalchemy/datatype.py @@ -0,0 +1,169 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import re +from typing import Dict, Iterator, Type, Union + +from sqlalchemy import util +from sqlalchemy.sql import sqltypes +from sqlalchemy.sql.type_api import TypeEngine + +# https://trino.io/docs/current/language/types.html +_type_map = { + # === Boolean === + 'boolean': sqltypes.BOOLEAN, + + # === Integer === + 'tinyint': sqltypes.SMALLINT, + 'smallint': sqltypes.SMALLINT, + 'integer': sqltypes.INTEGER, + 'bigint': sqltypes.BIGINT, + + # === Floating-point === + 'real': sqltypes.FLOAT, + 'double': sqltypes.FLOAT, + + # === Fixed-precision === + 'decimal': sqltypes.DECIMAL, + + # === String === + 'varchar': sqltypes.VARCHAR, + 'char': sqltypes.CHAR, + 'varbinary': sqltypes.VARBINARY, + 'json': sqltypes.JSON, + + # === Date and time === + 'date': sqltypes.DATE, + 'time': sqltypes.TIME, + 'timestamp': sqltypes.TIMESTAMP, + + # 'interval year to month': + # 'interval day to second': + # + # === Structural === + # 'array': ARRAY, + # 'map': MAP + # 'row': ROW + # + # === Mixed === + # 'ipaddress': IPADDRESS + # 'uuid': UUID, + # 'hyperloglog': HYPERLOGLOG, + # 'p4hyperloglog': P4HYPERLOGLOG, + # 'qdigest': QDIGEST, + # 'tdigest': TDIGEST, +} + +SQLType = Union[TypeEngine, Type[TypeEngine]] + + +class MAP(TypeEngine): + __visit_name__ = "MAP" + + def __init__(self, key_type: SQLType, value_type: SQLType): + if isinstance(key_type, type): + key_type = key_type() + self.key_type: TypeEngine = key_type + + if isinstance(value_type, type): + value_type = value_type() + self.value_type: TypeEngine = value_type + + @property + def python_type(self): + return dict + + +class ROW(TypeEngine): + __visit_name__ = "ROW" + + def __init__(self, attr_types: Dict[str, SQLType]): + for name, attr_type in attr_types.items(): + if isinstance(attr_type, type): + attr_type = attr_type() + attr_types[name] = attr_type + self.attr_types: Dict[str, TypeEngine] = attr_types + + @property + def python_type(self): + return dict + + +def split(string: str, delimiter: str = ',', + quote: str = '"', escaped_quote: str = r'\"', + open_bracket: str = '(', close_bracket: str = ')') -> Iterator[str]: + """ + A split function that is aware of quotes and brackets/parentheses. + + :param string: string to split + :param delimiter: string defining where to split, usually a comma or space + :param quote: string, either a single or a double quote + :param escaped_quote: string representing an escaped quote + :param open_bracket: string, either [, {, < or ( + :param close_bracket: string, either ], }, > or ) + """ + parens = 0 + quotes = False + i = 0 + for j, character in enumerate(string): + complete = parens == 0 and not quotes + if complete and character == delimiter: + yield string[i:j] + i = j + len(delimiter) + elif character == open_bracket: + parens += 1 + elif character == close_bracket: + parens -= 1 + elif character == quote: + if quotes and string[j - len(escaped_quote) + 1: j + 1] != escaped_quote: + quotes = False + elif not quotes: + quotes = True + yield string[i:] + + +def parse_sqltype(type_str: str) -> TypeEngine: + type_str = type_str.strip().lower() + match = re.match(r'^(?P\w+)\s*(?:\((?P.*)\))?', type_str) + if not match: + util.warn(f"Could not parse type name '{type_str}'") + return sqltypes.NULLTYPE + type_name = match.group("type") + type_opts = match.group("options") + + if type_name == "array": + item_type = parse_sqltype(type_opts) + if isinstance(item_type, sqltypes.ARRAY): + dimensions = (item_type.dimensions or 1) + 1 + return sqltypes.ARRAY(item_type.item_type, dimensions=dimensions) + return sqltypes.ARRAY(item_type) + elif type_name == "map": + key_type_str, value_type_str = split(type_opts) + key_type = parse_sqltype(key_type_str) + value_type = parse_sqltype(value_type_str) + return MAP(key_type, value_type) + elif type_name == "row": + attr_types: Dict[str, SQLType] = {} + for attr_str in split(type_opts): + name, attr_type_str = split(attr_str.strip(), delimiter=' ') + attr_type = parse_sqltype(attr_type_str) + attr_types[name] = attr_type + return ROW(attr_types) + + if type_name not in _type_map: + util.warn(f"Did not recognize type '{type_name}'") + return sqltypes.NULLTYPE + type_class = _type_map[type_name] + type_args = [int(o.strip()) for o in type_opts.split(',')] if type_opts else [] + if type_name in ('time', 'timestamp'): + type_kwargs = dict(timezone=type_str.endswith("with time zone")) + return type_class(**type_kwargs) # TODO: handle time/timestamp(p) precision + return type_class(*type_args) diff --git a/trino/sqlalchemy/dialect.py b/trino/sqlalchemy/dialect.py new file mode 100644 index 00000000..7c1c92b2 --- /dev/null +++ b/trino/sqlalchemy/dialect.py @@ -0,0 +1,301 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from textwrap import dedent +from typing import Any, Dict, List, Optional, Tuple + +from sqlalchemy import exc, sql +from sqlalchemy.engine.base import Connection +from sqlalchemy.engine.default import DefaultDialect +from sqlalchemy.engine.url import URL + +from trino import dbapi as trino_dbapi +from trino.auth import BasicAuthentication + +from . import compiler, datatype, error + + +class TrinoDialect(DefaultDialect): + name = 'trino' + driver = 'rest' + + statement_compiler = compiler.TrinoSQLCompiler + ddl_compiler = compiler.TrinoDDLCompiler + type_compiler = compiler.TrinoTypeCompiler + preparer = compiler.TrinoIdentifierPreparer + + # Data Type + supports_native_enum = False + supports_native_boolean = True + supports_native_decimal = True + + # Column options + supports_sequences = False + supports_comments = True + inline_comments = True + supports_default_values = False + + # DDL + supports_alter = True + + # DML + supports_empty_insert = False + supports_multivalues_insert = True + postfetch_lastrowid = False + + @classmethod + def dbapi(cls): + """ + ref: https://www.python.org/dev/peps/pep-0249/#module-interface + """ + return trino_dbapi + + def create_connect_args(self, url: URL) -> Tuple[List[Any], Dict[str, Any]]: + args, kwargs = super(TrinoDialect, self).create_connect_args(url) # type: List[Any], Dict[str, Any] + + db_parts = kwargs.pop('database', 'hive').split('/') + if len(db_parts) == 1: + kwargs['catalog'] = db_parts[0] + elif len(db_parts) == 2: + kwargs['catalog'] = db_parts[0] + kwargs['schema'] = db_parts[1] + else: + raise ValueError(f'Unexpected database format {url.database}') + + username = kwargs.pop('username', 'anonymous') + kwargs['user'] = username + + password = kwargs.pop('password', None) + if password: + kwargs['http_scheme'] = 'https' + kwargs['auth'] = BasicAuthentication(username, password) + + return args, kwargs + + def get_columns(self, connection: Connection, + table_name: str, schema: str = None, **kw) -> List[Dict[str, Any]]: + if not self.has_table(connection, table_name, schema): + raise exc.NoSuchTableError(f'schema={schema}, table={table_name}') + return self._get_columns(connection, table_name, schema, **kw) + + def _get_columns(self, connection: Connection, + table_name: str, schema: str = None, **kw) -> List[Dict[str, Any]]: + schema = schema or self._get_default_schema_name(connection) + query = dedent(''' + SELECT + "column_name", + "column_default", + "is_nullable", + "data_type" + FROM "information_schema"."columns" + WHERE "table_schema" = :schema AND "table_name" = :table + ORDER BY "ordinal_position" ASC + ''').strip() + res = connection.execute(sql.text(query), schema=schema, table=table_name) + columns = [] + for record in res: + column = dict( + name=record.column_name, + type=datatype.parse_sqltype(record.data_type), + nullable=(record.is_nullable or '').upper() == 'YES', + default=record.column_default, + ) + columns.append(column) + return columns + + def get_pk_constraint(self, connection: Connection, + table_name: str, schema: str = None, **kw) -> Dict[str, Any]: + """Trino has no support for primary keys. Returns a dummy""" + return dict(name=None, constrained_columns=[]) + + def get_primary_keys(self, connection: Connection, + table_name: str, schema: str = None, **kw) -> List[str]: + pk = self.get_pk_constraint(connection, table_name, schema) + return pk.get('constrained_columns') + + def get_foreign_keys(self, connection: Connection, + table_name: str, schema: str = None, **kw) -> List[Dict[str, Any]]: + """Trino has no support for foreign keys. Returns an empty list.""" + return [] + + def get_schema_names(self, connection: Connection, **kw) -> List[str]: + query = 'SHOW SCHEMAS' + res = connection.execute(sql.text(query)) + return [row.Schema for row in res] + + def get_table_names(self, connection: Connection, schema: str = None, **kw) -> List[str]: + query = 'SHOW TABLES' + if schema: + query = f'{query} FROM {self.identifier_preparer.quote_identifier(schema)}' + res = connection.execute(sql.text(query)) + return [row.Table for row in res] + + def get_temp_table_names(self, connection: Connection, schema: str = None, **kw) -> List[str]: + """Trino has no support for temporary tables. Returns an empty list.""" + return [] + + def get_view_names(self, connection: Connection, schema: str = None, **kw) -> List[str]: + schema = schema or self._get_default_schema_name(connection) + if schema is None: + raise exc.NoSuchTableError('schema is required') + query = dedent(''' + SELECT "table_name" + FROM "information_schema"."views" + WHERE "table_schema" = :schema + ''').strip() + res = connection.execute(sql.text(query), schema=schema) + return [row.table_name for row in res] + + def get_temp_view_names(self, connection: Connection, schema: str = None, **kw) -> List[str]: + """Trino has no support for temporary views. Returns an empty list.""" + return [] + + def get_view_definition(self, connection: Connection, view_name: str, schema: str = None, **kw) -> str: + full_view = self._get_full_table(view_name, schema) + query = f'SHOW CREATE VIEW {full_view}' + try: + res = connection.execute(sql.text(query)) + return res.scalar() + except error.TrinoQueryError as e: + if e.error_name in ( + error.TABLE_NOT_FOUND, + error.SCHEMA_NOT_FOUND, + error.CATALOG_NOT_FOUND, + ): + raise exc.NoSuchTableError(full_view) from e + raise + + def get_indexes(self, connection: Connection, + table_name: str, schema: str = None, **kw) -> List[Dict[str, Any]]: + if not self.has_table(connection, table_name, schema): + raise exc.NoSuchTableError(f'schema={schema}, table={table_name}') + + partitioned_columns = self._get_columns(connection, f'{table_name}$partitions', schema, **kw) + partition_index = dict( + name='partition', + column_names=[col['name'] for col in partitioned_columns], + unique=False + ) + return [partition_index, ] + + def get_unique_constraints(self, connection: Connection, + table_name: str, schema: str = None, **kw) -> List[Dict[str, Any]]: + """Trino has no support for unique constraints. Returns an empty list.""" + return [] + + def get_check_constraints(self, connection: Connection, + table_name: str, schema: str = None, **kw) -> List[Dict[str, Any]]: + """Trino has no support for check constraints. Returns an empty list.""" + return [] + + def get_table_comment(self, connection: Connection, + table_name: str, schema: str = None, **kw) -> Dict[str, Any]: + properties_table = self._get_full_table(f"{table_name}$properties", schema) + query = f'SELECT "comment" FROM {properties_table}' + try: + res = connection.execute(sql.text(query)) + return dict(text=res.scalar()) + except error.TrinoQueryError as e: + if e.error_name in ( + error.NOT_FOUND, + error.COLUMN_NOT_FOUND, + error.TABLE_NOT_FOUND, + ): + return dict(text=None) + raise + + def has_schema(self, connection: Connection, schema: str) -> bool: + query = f"SHOW SCHEMAS LIKE '{schema}'" + try: + res = connection.execute(sql.text(query)) + return res.first() is not None + except error.TrinoQueryError as e: + if e.error_name in ( + error.TABLE_NOT_FOUND, + error.SCHEMA_NOT_FOUND, + error.CATALOG_NOT_FOUND, + ): + return False + raise + + def has_table(self, connection: Connection, + table_name: str, schema: str = None) -> bool: + query = 'SHOW TABLES' + if schema: + query = f'{query} FROM {self.identifier_preparer.quote_identifier(schema)}' + query = f"{query} LIKE '{table_name}'" + try: + res = connection.execute(sql.text(query)) + return res.first() is not None + except error.TrinoQueryError as e: + if e.error_name in ( + error.TABLE_NOT_FOUND, + error.SCHEMA_NOT_FOUND, + error.CATALOG_NOT_FOUND, + error.MISSING_SCHEMA_NAME, + ): + return False + raise + + def has_sequence(self, connection: Connection, + sequence_name: str, schema: str = None) -> bool: + """Trino has no support for sequence. Returns False indicate that given sequence does not exists.""" + return False + + def _get_server_version_info(self, connection: Connection) -> Tuple[int, ...]: + query = 'SELECT version()' + res = connection.execute(sql.text(query)) + version = res.scalar() + return tuple([version]) + + def _get_default_schema_name(self, connection: Connection) -> Optional[str]: + dbapi_connection: trino_dbapi.Connection = connection.connection + return dbapi_connection.schema + + def do_rollback(self, dbapi_connection): + if dbapi_connection.transaction is not None: + dbapi_connection.rollback() + + def do_begin_twophase(self, connection: Connection, xid): + pass + + def do_prepare_twophase(self, connection: Connection, xid): + pass + + def do_rollback_twophase(self, connection: Connection, xid, + is_prepared: bool = True, recover: bool = False) -> None: + pass + + def do_commit_twophase(self, connection: Connection, xid, + is_prepared: bool = True, recover: bool = False) -> None: + pass + + def do_recover_twophase(self, connection: Connection) -> None: + pass + + def set_isolation_level(self, dbapi_conn: trino_dbapi.Connection, level) -> None: + dbapi_conn._isolation_level = getattr(trino_dbapi.IsolationLevel, level) + + def get_isolation_level(self, dbapi_conn: trino_dbapi.Connection) -> str: + level_names = ['AUTOCOMMIT', + 'READ_UNCOMMITTED', + 'READ_COMMITTED', + 'REPEATABLE_READ', + 'SERIALIZABLE'] + return level_names[dbapi_conn.isolation_level] + + def _get_full_table(self, table_name: str, schema: str = None, quote: bool = True) -> str: + table_part = self.identifier_preparer.quote_identifier(table_name) if quote else table_name + if schema: + schema_part = self.identifier_preparer.quote_identifier(schema) if quote else schema + return f'{schema_part}.{table_part}' + + return table_part diff --git a/trino/sqlalchemy/error.py b/trino/sqlalchemy/error.py new file mode 100644 index 00000000..3079d6eb --- /dev/null +++ b/trino/sqlalchemy/error.py @@ -0,0 +1,24 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from trino.exceptions import TrinoQueryError # noqa + +# ref: https://github.com/trinodb/trino/blob/master/core/trino-spi/src/main/java/io/trino/spi/StandardErrorCode.java +NOT_FOUND = 'NOT_FOUND' +COLUMN_NOT_FOUND = 'COLUMN_NOT_FOUND' +TABLE_NOT_FOUND = 'TABLE_NOT_FOUND' +SCHEMA_NOT_FOUND = 'SCHEMA_NOT_FOUND' +CATALOG_NOT_FOUND = 'CATALOG_NOT_FOUND' + +MISSING_TABLE = 'MISSING_TABLE' +MISSING_COLUMN_NAME = 'MISSING_COLUMN_NAME' +MISSING_SCHEMA_NAME = 'MISSING_SCHEMA_NAME' +MISSING_CATALOG_NAME = 'MISSING_CATALOG_NAME'