From c3a2264a0a826cce74a2ff736bd3588d509fc9cc Mon Sep 17 00:00:00 2001 From: Harry Date: Sat, 27 Apr 2024 20:39:03 +0700 Subject: [PATCH] Format all following lint rules Signed-off-by: Harry --- python/generate.py | 24 +- python/pyhive/__init__.py | 6 +- python/pyhive/common.py | 34 ++- python/pyhive/exc.py | 24 +- python/pyhive/hive.py | 242 ++++++++++------ python/pyhive/presto.py | 185 +++++++----- python/pyhive/sasl_compat.py | 14 +- python/pyhive/sqlalchemy_hive.py | 177 ++++++------ python/pyhive/sqlalchemy_presto.py | 120 ++++---- python/pyhive/sqlalchemy_trino.py | 54 ++-- python/pyhive/tests/dbapi_test_case.py | 89 +++--- python/pyhive/tests/sqlalchemy_test_case.py | 153 ++++++---- python/pyhive/tests/test_common.py | 66 +++-- python/pyhive/tests/test_hive.py | 273 +++++++++++------- python/pyhive/tests/test_presto.py | 229 ++++++++------- python/pyhive/tests/test_sasl_compat.py | 252 +++++++++------- python/pyhive/tests/test_sqlalchemy_hive.py | 200 ++++++++----- python/pyhive/tests/test_sqlalchemy_presto.py | 87 +++--- python/pyhive/tests/test_sqlalchemy_trino.py | 96 +++--- python/pyhive/tests/test_trino.py | 127 ++++---- python/pyhive/trino.py | 90 +++--- python/setup.py | 67 +++-- 22 files changed, 1550 insertions(+), 1059 deletions(-) diff --git a/python/generate.py b/python/generate.py index ee1274bd894..34678bbd2d4 100644 --- a/python/generate.py +++ b/python/generate.py @@ -12,39 +12,39 @@ python generate.py """ + import shutil +import subprocess import sys from os import path from urllib.request import urlopen -import subprocess here = path.abspath(path.dirname(__file__)) -PACKAGE = 'TCLIService' -GENERATED = 'gen-py' +PACKAGE = "TCLIService" +GENERATED = "gen-py" -HIVE_SERVER2_URL = \ - 'https://raw.githubusercontent.com/apache/hive/branch-2.3/service-rpc/if/TCLIService.thrift' +HIVE_SERVER2_URL = "https://raw.githubusercontent.com/apache/hive/branch-2.3/service-rpc/if/TCLIService.thrift" def save_url(url): data = urlopen(url).read() - file_path = path.join(here, url.rsplit('/', 1)[-1]) - with open(file_path, 'wb') as f: + file_path = path.join(here, url.rsplit("/", 1)[-1]) + with open(file_path, "wb") as f: f.write(data) def main(hive_server2_url): save_url(hive_server2_url) - hive_server2_path = path.join(here, hive_server2_url.rsplit('/', 1)[-1]) + hive_server2_path = path.join(here, hive_server2_url.rsplit("/", 1)[-1]) - subprocess.call(['thrift', '-r', '--gen', 'py', hive_server2_path]) - shutil.move(path.join(here, PACKAGE), path.join(here, PACKAGE + '.old')) + subprocess.call(["thrift", "-r", "--gen", "py", hive_server2_path]) + shutil.move(path.join(here, PACKAGE), path.join(here, PACKAGE + ".old")) shutil.move(path.join(here, GENERATED, PACKAGE), path.join(here, PACKAGE)) - shutil.rmtree(path.join(here, PACKAGE + '.old')) + shutil.rmtree(path.join(here, PACKAGE + ".old")) -if __name__ == '__main__': +if __name__ == "__main__": if len(sys.argv) > 1: url = sys.argv[1] else: diff --git a/python/pyhive/__init__.py b/python/pyhive/__init__.py index 0a6bb1f635b..5a5732eabd8 100644 --- a/python/pyhive/__init__.py +++ b/python/pyhive/__init__.py @@ -1,3 +1,3 @@ -from __future__ import absolute_import -from __future__ import unicode_literals -__version__ = '0.7.0' +from __future__ import absolute_import, unicode_literals + +__version__ = "0.7.0" diff --git a/python/pyhive/common.py b/python/pyhive/common.py index 51692b97ea6..9164612084c 100644 --- a/python/pyhive/common.py +++ b/python/pyhive/common.py @@ -3,21 +3,20 @@ Many docstrings in this file are based on PEP-249, which is in the public domain. """ -from __future__ import absolute_import -from __future__ import unicode_literals -from builtins import bytes -from builtins import int -from builtins import object -from builtins import str -from past.builtins import basestring -from pyhive import exc +from __future__ import absolute_import, unicode_literals + import abc import collections -import time import datetime -from future.utils import with_metaclass +import time +from builtins import bytes, int, object, str from itertools import islice +from future.utils import with_metaclass +from past.builtins import basestring + +from pyhive import exc + try: from collections.abc import Iterable except ImportError: @@ -108,7 +107,9 @@ def fetchone(self): raise exc.ProgrammingError("No query yet") # Sleep until we're done or we have some data to return - self._fetch_while(lambda: not self._data and self._state != self._STATE_FINISHED) + self._fetch_while( + lambda: not self._data and self._state != self._STATE_FINISHED + ) if not self._data: return None @@ -217,7 +218,9 @@ def escape_args(self, parameters): elif isinstance(parameters, (list, tuple)): return tuple(self.escape_item(x) for x in parameters) else: - raise exc.ProgrammingError("Unsupported param format: {}".format(parameters)) + raise exc.ProgrammingError( + "Unsupported param format: {}".format(parameters) + ) def escape_number(self, item): return item @@ -228,7 +231,7 @@ def escape_string(self, item): # as byte strings. The old version always encodes Unicode as byte strings, which breaks # string formatting here. if isinstance(item, bytes): - item = item.decode('utf-8') + item = item.decode("utf-8") # This is good enough when backslashes are literal, newlines are just followed, and the way # to escape a single quote is to put two single quotes. # (i.e. only special character is single quote) @@ -236,7 +239,7 @@ def escape_string(self, item): def escape_sequence(self, item): l = map(str, map(self.escape_item, item)) - return '(' + ','.join(l) + ')' + return "(" + ",".join(l) + ")" def escape_datetime(self, item, format, cutoff=0): dt_str = item.strftime(format) @@ -245,7 +248,7 @@ def escape_datetime(self, item, format, cutoff=0): def escape_item(self, item): if item is None: - return 'NULL' + return "NULL" elif isinstance(item, (int, float)): return self.escape_number(item) elif isinstance(item, basestring): @@ -262,5 +265,6 @@ def escape_item(self, item): class UniversalSet(object): """set containing everything""" + def __contains__(self, item): return True diff --git a/python/pyhive/exc.py b/python/pyhive/exc.py index 931cf211f74..4f10dcdf10c 100644 --- a/python/pyhive/exc.py +++ b/python/pyhive/exc.py @@ -1,12 +1,19 @@ """ Package private common utilities. Do not use directly. """ -from __future__ import absolute_import -from __future__ import unicode_literals + +from __future__ import absolute_import, unicode_literals __all__ = [ - 'Error', 'Warning', 'InterfaceError', 'DatabaseError', 'InternalError', 'OperationalError', - 'ProgrammingError', 'DataError', 'NotSupportedError', + "Error", + "Warning", + "InterfaceError", + "DatabaseError", + "InternalError", + "OperationalError", + "ProgrammingError", + "DataError", + "NotSupportedError", ] @@ -15,11 +22,13 @@ class Error(Exception): You can use this to catch all errors with one single except statement. """ + pass class Warning(Exception): """Exception raised for important warnings like data truncations while inserting, etc.""" + pass @@ -27,17 +36,20 @@ class InterfaceError(Error): """Exception raised for errors that are related to the database interface rather than the database itself. """ + pass class DatabaseError(Error): """Exception raised for errors that are related to the database.""" + pass class InternalError(DatabaseError): """Exception raised when the database encounters an internal error, e.g. the cursor is not valid anymore, the transaction is out of sync, etc.""" + pass @@ -47,6 +59,7 @@ class OperationalError(DatabaseError): is not found, a transaction could not be processed, a memory allocation error occurred during processing, etc. """ + pass @@ -54,6 +67,7 @@ class ProgrammingError(DatabaseError): """Exception raised for programming errors, e.g. table not found or already exists, syntax error in the SQL statement, wrong number of parameters specified, etc. """ + pass @@ -61,6 +75,7 @@ class DataError(DatabaseError): """Exception raised for errors that are due to problems with the processed data like division by zero, numeric value out of range, etc. """ + pass @@ -69,4 +84,5 @@ class NotSupportedError(DatabaseError): database, e.g. requesting a ``.rollback()`` on a connection that does not support transaction or has transactions turned off. """ + pass diff --git a/python/pyhive/hive.py b/python/pyhive/hive.py index c1287488e9d..536cdd0446a 100644 --- a/python/pyhive/hive.py +++ b/python/pyhive/hive.py @@ -5,42 +5,45 @@ Many docstrings in this file are based on the PEP, which is in the public domain. """ -from __future__ import absolute_import -from __future__ import unicode_literals +from __future__ import absolute_import, unicode_literals import base64 +import contextlib import datetime +import getpass +import logging import re +import sys +from builtins import range from decimal import Decimal from ssl import CERT_NONE, CERT_OPTIONAL, CERT_REQUIRED, create_default_context +import thrift.protocol.TBinaryProtocol +import thrift.transport.THttpClient +import thrift.transport.TSocket +import thrift.transport.TTransport +from future.utils import iteritems -from TCLIService import TCLIService -from TCLIService import constants -from TCLIService import ttypes from pyhive import common from pyhive.common import DBAPITypeObject + # Make all exceptions visible in this module per DB-API -from pyhive.exc import * # noqa -from builtins import range -import contextlib -from future.utils import iteritems -import getpass -import logging -import sys -import thrift.transport.THttpClient -import thrift.protocol.TBinaryProtocol -import thrift.transport.TSocket -import thrift.transport.TTransport +from pyhive.exc import ( + DataError, + NotSupportedError, + OperationalError, + ProgrammingError, +) +from TCLIService import TCLIService, constants, ttypes # PEP 249 module globals -apilevel = '2.0' +apilevel = "2.0" threadsafety = 2 # Threads may share the module and connections. -paramstyle = 'pyformat' # Python extended format codes, e.g. ...WHERE name=%(name)s +paramstyle = "pyformat" # Python extended format codes, e.g. ...WHERE name=%(name)s _logger = logging.getLogger(__name__) -_TIMESTAMP_PATTERN = re.compile(r'(\d+-\d+-\d+ \d+:\d+:\d+(\.\d{,6})?)') +_TIMESTAMP_PATTERN = re.compile(r"(\d+-\d+-\d+ \d+:\d+:\d+(\.\d{,6})?)") ssl_cert_parameter_map = { "none": CERT_NONE, @@ -51,14 +54,15 @@ def get_sasl_client(host, sasl_auth, service=None, username=None, password=None): import sasl + sasl_client = sasl.Client() - sasl_client.setAttr('host', host) + sasl_client.setAttr("host", host) - if sasl_auth == 'GSSAPI': - sasl_client.setAttr('service', service) - elif sasl_auth == 'PLAIN': - sasl_client.setAttr('username', username) - sasl_client.setAttr('password', password) + if sasl_auth == "GSSAPI": + sasl_client.setAttr("service", service) + elif sasl_auth == "PLAIN": + sasl_client.setAttr("username", username) + sasl_client.setAttr("password", password) else: raise ValueError("sasl_auth only supports GSSAPI and PLAIN") @@ -69,10 +73,10 @@ def get_sasl_client(host, sasl_auth, service=None, username=None, password=None) def get_pure_sasl_client(host, sasl_auth, service=None, username=None, password=None): from pyhive.sasl_compat import PureSASLClient - if sasl_auth == 'GSSAPI': - sasl_kwargs = {'service': service} - elif sasl_auth == 'PLAIN': - sasl_kwargs = {'username': username, 'password': password} + if sasl_auth == "GSSAPI": + sasl_kwargs = {"service": service} + elif sasl_auth == "PLAIN": + sasl_kwargs = {"username": username, "password": password} else: raise ValueError("sasl_auth only supports GSSAPI and PLAIN") @@ -81,34 +85,44 @@ def get_pure_sasl_client(host, sasl_auth, service=None, username=None, password= def get_installed_sasl(host, sasl_auth, service=None, username=None, password=None): try: - return get_sasl_client(host=host, sasl_auth=sasl_auth, service=service, username=username, password=password) + return get_sasl_client( + host=host, + sasl_auth=sasl_auth, + service=service, + username=username, + password=password, + ) # The sasl library is available except ImportError: # Fallback to pure-sasl library - return get_pure_sasl_client(host=host, sasl_auth=sasl_auth, service=service, username=username, password=password) - + return get_pure_sasl_client( + host=host, + sasl_auth=sasl_auth, + service=service, + username=username, + password=password, + ) + def _parse_timestamp(value): if value: match = _TIMESTAMP_PATTERN.match(value) if match: if match.group(2): - format = '%Y-%m-%d %H:%M:%S.%f' + format = "%Y-%m-%d %H:%M:%S.%f" # use the pattern to truncate the value value = match.group() else: - format = '%Y-%m-%d %H:%M:%S' + format = "%Y-%m-%d %H:%M:%S" value = datetime.datetime.strptime(value, format) else: - raise Exception( - 'Cannot convert "{}" into a datetime'.format(value)) + raise Exception('Cannot convert "{}" into a datetime'.format(value)) else: value = None return value -TYPES_CONVERTER = {"DECIMAL_TYPE": Decimal, - "TIMESTAMP_TYPE": _parse_timestamp} +TYPES_CONVERTER = {"DECIMAL_TYPE": Decimal, "TIMESTAMP_TYPE": _parse_timestamp} class HiveParamEscaper(common.ParamEscaper): @@ -120,14 +134,13 @@ def escape_string(self, item): # as byte strings. The old version always encodes Unicode as byte strings, which breaks # string formatting here. if isinstance(item, bytes): - item = item.decode('utf-8') + item = item.decode("utf-8") return "'{}'".format( - item - .replace('\\', '\\\\') + item.replace("\\", "\\\\") .replace("'", "\\'") - .replace('\r', '\\r') - .replace('\n', '\\n') - .replace('\t', '\\t') + .replace("\r", "\\r") + .replace("\n", "\\n") + .replace("\t", "\\t") ) @@ -152,14 +165,14 @@ def __init__( port=None, scheme=None, username=None, - database='default', + database="default", auth=None, configuration=None, kerberos_service_name=None, password=None, check_hostname=None, ssl_cert=None, - thrift_transport=None + thrift_transport=None, ): """Connect to HiveServer2 @@ -184,7 +197,9 @@ def __init__( ssl_context = create_default_context() ssl_context.check_hostname = check_hostname == "true" ssl_cert = ssl_cert or "none" - ssl_context.verify_mode = ssl_cert_parameter_map.get(ssl_cert, CERT_NONE) + ssl_context.verify_mode = ssl_cert_parameter_map.get( + ssl_cert, CERT_NONE + ) thrift_transport = thrift.transport.THttpClient.THttpClient( uri_or_host="{scheme}://{host}:{port}/cliservice/".format( scheme=scheme, host=host, port=port @@ -203,17 +218,25 @@ def __init__( "BASIC, NOSASL, KERBEROS, NONE" ) host, port, auth, kerberos_service_name, password = ( - None, None, None, None, None + None, + None, + None, + None, + None, ) username = username or getpass.getuser() configuration = configuration or {} - if (password is not None) != (auth in ('LDAP', 'CUSTOM')): - raise ValueError("Password should be set if and only if in LDAP or CUSTOM mode; " - "Remove password or use one of those modes") - if (kerberos_service_name is not None) != (auth == 'KERBEROS'): - raise ValueError("kerberos_service_name should be set if and only if in KERBEROS mode") + if (password is not None) != (auth in ("LDAP", "CUSTOM")): + raise ValueError( + "Password should be set if and only if in LDAP or CUSTOM mode; " + "Remove password or use one of those modes" + ) + if (kerberos_service_name is not None) != (auth == "KERBEROS"): + raise ValueError( + "kerberos_service_name should be set if and only if in KERBEROS mode" + ) if thrift_transport is not None: has_incompatible_arg = ( host is not None @@ -223,8 +246,10 @@ def __init__( or password is not None ) if has_incompatible_arg: - raise ValueError("thrift_transport cannot be used with " - "host/port/auth/kerberos_service_name/password") + raise ValueError( + "thrift_transport cannot be used with " + "host/port/auth/kerberos_service_name/password" + ) if thrift_transport is not None: self._transport = thrift_transport @@ -232,32 +257,43 @@ def __init__( if port is None: port = 10000 if auth is None: - auth = 'NONE' + auth = "NONE" socket = thrift.transport.TSocket.TSocket(host, port) - if auth == 'NOSASL': + if auth == "NOSASL": # NOSASL corresponds to hive.server2.authentication=NOSASL in hive-site.xml self._transport = thrift.transport.TTransport.TBufferedTransport(socket) - elif auth in ('LDAP', 'KERBEROS', 'NONE', 'CUSTOM'): + elif auth in ("LDAP", "KERBEROS", "NONE", "CUSTOM"): # Defer import so package dependency is optional import thrift_sasl - if auth == 'KERBEROS': + if auth == "KERBEROS": # KERBEROS mode in hive.server2.authentication is GSSAPI in sasl library - sasl_auth = 'GSSAPI' + sasl_auth = "GSSAPI" else: - sasl_auth = 'PLAIN' + sasl_auth = "PLAIN" if password is None: # Password doesn't matter in NONE mode, just needs to be nonempty. - password = 'x' - - self._transport = thrift_sasl.TSaslClientTransport(lambda: get_installed_sasl(host=host, sasl_auth=sasl_auth, service=kerberos_service_name, username=username, password=password), sasl_auth, socket) + password = "x" + + self._transport = thrift_sasl.TSaslClientTransport( + lambda: get_installed_sasl( + host=host, + sasl_auth=sasl_auth, + service=kerberos_service_name, + username=username, + password=password, + ), + sasl_auth, + socket, + ) else: # All HS2 config options: # https://cwiki.apache.org/confluence/display/Hive/Setting+Up+HiveServer2#SettingUpHiveServer2-Configuration # PAM currently left to end user via thrift_transport option. raise NotImplementedError( "Only NONE, NOSASL, LDAP, KERBEROS, CUSTOM " - "authentication are supported, got {}".format(auth)) + "authentication are supported, got {}".format(auth) + ) protocol = thrift.protocol.TBinaryProtocol.TBinaryProtocol(self._transport) self._client = TCLIService.Client(protocol) @@ -274,12 +310,17 @@ def __init__( ) response = self._client.OpenSession(open_session_req) _check_status(response) - assert response.sessionHandle is not None, "Expected a session from OpenSession" + assert ( + response.sessionHandle is not None + ), "Expected a session from OpenSession" self._sessionHandle = response.sessionHandle - assert response.serverProtocolVersion == protocol_version, \ - "Unable to handle protocol version {}".format(response.serverProtocolVersion) + assert ( + response.serverProtocolVersion == protocol_version + ), "Unable to handle protocol version {}".format( + response.serverProtocolVersion + ) with contextlib.closing(self.cursor()) as cursor: - cursor.execute('USE `{}`'.format(database)) + cursor.execute("USE `{}`".format(database)) except: self._transport.close() raise @@ -316,11 +357,7 @@ def _set_kerberos_header(transport, kerberos_service_name, host): auth_header = kerberos.authGSSClientResponse(krb_context) transport.setCustomHeaders( - { - "Authorization": "Negotiate {auth_header}".format( - auth_header=auth_header - ) - } + {"Authorization": "Negotiate {auth_header}".format(auth_header=auth_header)} ) def __enter__(self): @@ -429,15 +466,27 @@ def description(self): primary_type_entry = col.typeDesc.types[0] if primary_type_entry.primitiveEntry is None: # All fancy stuff maps to string - type_code = ttypes.TTypeId._VALUES_TO_NAMES[ttypes.TTypeId.STRING_TYPE] + type_code = ttypes.TTypeId._VALUES_TO_NAMES[ + ttypes.TTypeId.STRING_TYPE + ] else: type_id = primary_type_entry.primitiveEntry.type type_code = ttypes.TTypeId._VALUES_TO_NAMES[type_id] - self._description.append(( - col.columnName.decode('utf-8') if sys.version_info[0] == 2 else col.columnName, - type_code.decode('utf-8') if sys.version_info[0] == 2 else type_code, - None, None, None, None, True - )) + self._description.append( + ( + col.columnName.decode("utf-8") + if sys.version_info[0] == 2 + else col.columnName, + type_code.decode("utf-8") + if sys.version_info[0] == 2 + else type_code, + None, + None, + None, + None, + True, + ) + ) return self._description def __enter__(self): @@ -456,7 +505,7 @@ def execute(self, operation, parameters=None, **kwargs): Return values are not defined. """ # backward compatibility with Python < 3.7 - for kw in ['async', 'async_']: + for kw in ["async", "async_"]: if kw in kwargs: async_ = kwargs[kw] break @@ -472,10 +521,11 @@ def execute(self, operation, parameters=None, **kwargs): self._reset_state() self._state = self._STATE_RUNNING - _logger.info('%s', sql) + _logger.info("%s", sql) - req = ttypes.TExecuteStatementReq(self._connection.sessionHandle, - sql, runAsync=async_) + req = ttypes.TExecuteStatementReq( + self._connection.sessionHandle, sql, runAsync=async_ + ) _logger.debug(req) response = self._connection.client.ExecuteStatement(req) _check_status(response) @@ -490,8 +540,12 @@ def cancel(self): def _fetch_more(self): """Send another TFetchResultsReq and update state""" - assert(self._state == self._STATE_RUNNING), "Should be running when in _fetch_more" - assert(self._operationHandle is not None), "Should have an op handle in _fetch_more" + assert ( + self._state == self._STATE_RUNNING + ), "Should be running when in _fetch_more" + assert ( + self._operationHandle is not None + ), "Should have an op handle in _fetch_more" if not self._operationHandle.hasResultSet: raise ProgrammingError("No result set") req = ttypes.TFetchResultsReq( @@ -502,9 +556,11 @@ def _fetch_more(self): response = self._connection.client.FetchResults(req) _check_status(response) schema = self.description - assert not response.results.rows, 'expected data in columnar format' - columns = [_unwrap_column(col, col_schema[1]) for col, col_schema in - zip(response.results.columns, schema)] + assert not response.results.rows, "expected data in columnar format" + columns = [ + _unwrap_column(col, col_schema[1]) + for col, col_schema in zip(response.results.columns, schema) + ] new_data = list(zip(*columns)) self._data += new_data # response.hasMoreRows seems to always be False, so we instead check the number of rows @@ -546,7 +602,9 @@ def fetch_logs(self): try: # Older Hive instances require logs to be retrieved using GetLog req = ttypes.TGetLogReq(operationHandle=self._operationHandle) logs = self._connection.client.GetLog(req).log.splitlines() - except ttypes.TApplicationException as e: # Otherwise, retrieve logs using newer method + except ( + ttypes.TApplicationException + ) as e: # Otherwise, retrieve logs using newer method if e.type != ttypes.TApplicationException.UNKNOWN_METHOD: raise logs = [] @@ -555,11 +613,11 @@ def fetch_logs(self): operationHandle=self._operationHandle, orientation=ttypes.TFetchOrientation.FETCH_NEXT, maxRows=self.arraysize, - fetchType=1 # 0: results, 1: logs + fetchType=1, # 0: results, 1: logs ) response = self._connection.client.FetchResults(req) _check_status(response) - assert not response.results.rows, 'expected data in columnar format' + assert not response.results.rows, "expected data in columnar format" assert len(response.results.columns) == 1, response.results.columns new_logs = _unwrap_column(response.results.columns[0]) logs += new_logs diff --git a/python/pyhive/presto.py b/python/pyhive/presto.py index 3217f4c2a7f..ed372b26468 100644 --- a/python/pyhive/presto.py +++ b/python/pyhive/presto.py @@ -5,23 +5,29 @@ Many docstrings in this file are based on the PEP, which is in the public domain. """ -from __future__ import absolute_import -from __future__ import unicode_literals +from __future__ import absolute_import, unicode_literals +import base64 +import datetime +import getpass +import logging +import os from builtins import object from decimal import Decimal +import requests +from requests.auth import HTTPBasicAuth + from pyhive import common from pyhive.common import DBAPITypeObject + # Make all exceptions visible in this module per DB-API -from pyhive.exc import * # noqa -import base64 -import getpass -import datetime -import logging -import requests -from requests.auth import HTTPBasicAuth -import os +from pyhive.exc import ( + DatabaseError, + NotSupportedError, + OperationalError, + ProgrammingError, +) try: # Python 3 import urllib.parse as urlparse @@ -30,18 +36,19 @@ # PEP 249 module globals -apilevel = '2.0' +apilevel = "2.0" threadsafety = 2 # Threads may share the module and connections. -paramstyle = 'pyformat' # Python extended format codes, e.g. ...WHERE name=%(name)s +paramstyle = "pyformat" # Python extended format codes, e.g. ...WHERE name=%(name)s _logger = logging.getLogger(__name__) TYPES_CONVERTER = { "decimal": Decimal, # As of Presto 0.69, binary data is returned as the varbinary type in base64 format - "varbinary": base64.b64decode + "varbinary": base64.b64decode, } + class PrestoParamEscaper(common.ParamEscaper): def escape_datetime(self, item, format): _type = "timestamp" if isinstance(item, datetime.datetime) else "date" @@ -96,12 +103,28 @@ class Cursor(common.DBAPICursor): visible by other cursors or connections. """ - def __init__(self, host, port='8080', username=None, principal_username=None, catalog='hive', - schema='default', poll_interval=1, source='pyhive', session_props=None, - protocol='http', password=None, requests_session=None, requests_kwargs=None, - KerberosRemoteServiceName=None, KerberosPrincipal=None, - KerberosConfigPath=None, KerberosKeytabPath=None, - KerberosCredentialCachePath=None, KerberosUseCanonicalHostname=None): + def __init__( + self, + host, + port="8080", + username=None, + principal_username=None, + catalog="hive", + schema="default", + poll_interval=1, + source="pyhive", + session_props=None, + protocol="http", + password=None, + requests_session=None, + requests_kwargs=None, + KerberosRemoteServiceName=None, + KerberosPrincipal=None, + KerberosConfigPath=None, + KerberosKeytabPath=None, + KerberosCredentialCachePath=None, + KerberosUseCanonicalHostname=None, + ): """ :param host: hostname to connect to, e.g. ``presto.example.com`` :param port: int -- port, defaults to 8080 @@ -164,7 +187,7 @@ class will use the default requests behavior of making a new session per HTTP re self._session_props = session_props if session_props is not None else {} self.last_query_id = None - if protocol not in ('http', 'https'): + if protocol not in ("http", "https"): raise ValueError("Protocol must be http/https, was {!r}".format(protocol)) self._protocol = protocol @@ -173,33 +196,39 @@ class will use the default requests behavior of making a new session per HTTP re requests_kwargs = dict(requests_kwargs) if requests_kwargs is not None else {} if KerberosRemoteServiceName is not None: - from requests_kerberos import HTTPKerberosAuth, OPTIONAL + from requests_kerberos import OPTIONAL, HTTPKerberosAuth hostname_override = None - if KerberosUseCanonicalHostname is not None \ - and KerberosUseCanonicalHostname.lower() == 'false': + if ( + KerberosUseCanonicalHostname is not None + and KerberosUseCanonicalHostname.lower() == "false" + ): hostname_override = host if KerberosConfigPath is not None: - os.environ['KRB5_CONFIG'] = KerberosConfigPath + os.environ["KRB5_CONFIG"] = KerberosConfigPath if KerberosKeytabPath is not None: - os.environ['KRB5_CLIENT_KTNAME'] = KerberosKeytabPath + os.environ["KRB5_CLIENT_KTNAME"] = KerberosKeytabPath if KerberosCredentialCachePath is not None: - os.environ['KRB5CCNAME'] = KerberosCredentialCachePath + os.environ["KRB5CCNAME"] = KerberosCredentialCachePath - requests_kwargs['auth'] = HTTPKerberosAuth(mutual_authentication=OPTIONAL, - principal=KerberosPrincipal, - service=KerberosRemoteServiceName, - hostname_override=hostname_override) + requests_kwargs["auth"] = HTTPKerberosAuth( + mutual_authentication=OPTIONAL, + principal=KerberosPrincipal, + service=KerberosRemoteServiceName, + hostname_override=hostname_override, + ) else: - if password is not None and 'auth' in requests_kwargs: - raise ValueError("Cannot use both password and requests_kwargs authentication") - for k in ('method', 'url', 'data', 'headers'): + if password is not None and "auth" in requests_kwargs: + raise ValueError( + "Cannot use both password and requests_kwargs authentication" + ) + for k in ("method", "url", "data", "headers"): if k in requests_kwargs: raise ValueError("Cannot override requests argument {}".format(k)) if password is not None: - requests_kwargs['auth'] = HTTPBasicAuth(username, password) - if protocol != 'https': + requests_kwargs["auth"] = HTTPBasicAuth(username, password) + if protocol != "https": raise ValueError("Protocol must be https when passing a password") self._requests_kwargs = requests_kwargs @@ -230,14 +259,14 @@ def description(self): """ # Sleep until we're done or we got the columns self._fetch_while( - lambda: self._columns is None and - self._state not in (self._STATE_NONE, self._STATE_FINISHED) + lambda: self._columns is None + and self._state not in (self._STATE_NONE, self._STATE_FINISHED) ) if self._columns is None: return None return [ # name, type_code, display_size, internal_size, precision, scale, null_ok - (col['name'], col['type'], None, None, None, None, True) + (col["name"], col["type"], None, None, None, None, True) for col in self._columns ] @@ -247,15 +276,15 @@ def execute(self, operation, parameters=None): Return values are not defined. """ headers = { - 'X-Presto-Catalog': self._catalog, - 'X-Presto-Schema': self._schema, - 'X-Presto-Source': self._source, - 'X-Presto-User': self._username, + "X-Presto-Catalog": self._catalog, + "X-Presto-Schema": self._schema, + "X-Presto-Source": self._source, + "X-Presto-User": self._username, } if self._session_props: - headers['X-Presto-Session'] = ','.join( - '{}={}'.format(propname, propval) + headers["X-Presto-Session"] = ",".join( + "{}={}".format(propname, propval) for propname, propval in self._session_props.items() ) @@ -268,20 +297,30 @@ def execute(self, operation, parameters=None): self._reset_state() self._state = self._STATE_RUNNING - url = urlparse.urlunparse(( - self._protocol, - '{}:{}'.format(self._host, self._port), '/v1/statement', None, None, None)) - _logger.info('%s', sql) + url = urlparse.urlunparse( + ( + self._protocol, + "{}:{}".format(self._host, self._port), + "/v1/statement", + None, + None, + None, + ) + ) + _logger.info("%s", sql) _logger.debug("Headers: %s", headers) response = self._requests_session.post( - url, data=sql.encode('utf-8'), headers=headers, **self._requests_kwargs) + url, data=sql.encode("utf-8"), headers=headers, **self._requests_kwargs + ) self._process_response(response) def cancel(self): if self._state == self._STATE_NONE: raise ProgrammingError("No query yet") if self._nextUri is None: - assert self._state == self._STATE_FINISHED, "Should be finished if nextUri is None" + assert ( + self._state == self._STATE_FINISHED + ), "Should be finished if nextUri is None" return response = self._requests_session.delete(self._nextUri, **self._requests_kwargs) @@ -304,7 +343,9 @@ def poll(self): if self._state == self._STATE_NONE: raise ProgrammingError("No query yet") if self._nextUri is None: - assert self._state == self._STATE_FINISHED, "Should be finished if nextUri is None" + assert ( + self._state == self._STATE_FINISHED + ), "Should be finished if nextUri is None" return None response = self._requests_session.get(self._nextUri, **self._requests_kwargs) self._process_response(response) @@ -312,7 +353,9 @@ def poll(self): def _fetch_more(self): """Fetch the next URI and update state""" - self._process_response(self._requests_session.get(self._nextUri, **self._requests_kwargs)) + self._process_response( + self._requests_session.get(self._nextUri, **self._requests_kwargs) + ) def _process_data(self, rows): for i, col in enumerate(self.description): @@ -333,26 +376,28 @@ def _process_response(self, response): response_json = response.json() _logger.debug("Got response %s", response_json) - assert self._state == self._STATE_RUNNING, "Should be running if processing response" - self._nextUri = response_json.get('nextUri') - self._columns = response_json.get('columns') - if 'id' in response_json: - self.last_query_id = response_json['id'] - if 'X-Presto-Clear-Session' in response.headers: - propname = response.headers['X-Presto-Clear-Session'] + assert ( + self._state == self._STATE_RUNNING + ), "Should be running if processing response" + self._nextUri = response_json.get("nextUri") + self._columns = response_json.get("columns") + if "id" in response_json: + self.last_query_id = response_json["id"] + if "X-Presto-Clear-Session" in response.headers: + propname = response.headers["X-Presto-Clear-Session"] self._session_props.pop(propname, None) - if 'X-Presto-Set-Session' in response.headers: - propname, propval = response.headers['X-Presto-Set-Session'].split('=', 1) + if "X-Presto-Set-Session" in response.headers: + propname, propval = response.headers["X-Presto-Set-Session"].split("=", 1) self._session_props[propname] = propval - if 'data' in response_json: + if "data" in response_json: assert self._columns - new_data = response_json['data'] + new_data = response_json["data"] self._process_data(new_data) self._data += map(tuple, new_data) - if 'nextUri' not in response_json: + if "nextUri" not in response_json: self._state = self._STATE_FINISHED - if 'error' in response_json: - raise DatabaseError(response_json['error']) + if "error" in response_json: + raise DatabaseError(response_json["error"]) # @@ -361,7 +406,7 @@ def _process_response(self, response): # See types in presto-main/src/main/java/com/facebook/presto/tuple/TupleInfo.java -FIXED_INT_64 = DBAPITypeObject(['bigint']) -VARIABLE_BINARY = DBAPITypeObject(['varchar']) -DOUBLE = DBAPITypeObject(['double']) -BOOLEAN = DBAPITypeObject(['boolean']) +FIXED_INT_64 = DBAPITypeObject(["bigint"]) +VARIABLE_BINARY = DBAPITypeObject(["varchar"]) +DOUBLE = DBAPITypeObject(["double"]) +BOOLEAN = DBAPITypeObject(["boolean"]) diff --git a/python/pyhive/sasl_compat.py b/python/pyhive/sasl_compat.py index 19af6d229e6..61defcfc2c4 100644 --- a/python/pyhive/sasl_compat.py +++ b/python/pyhive/sasl_compat.py @@ -1,17 +1,19 @@ -# Original source of this file is https://github.com/cloudera/impyla/blob/master/impala/sasl_compat.py +# Original source of this file is https://github.com/cloudera/impyla/blob/master/impala/sasl_compat.py # which uses Apache-2.0 license as of 21 May 2023. -# This code was added to Impyla in 2016 as a compatibility layer to allow use of either python-sasl or pure-sasl +# This code was added to Impyla in 2016 as a compatibility layer to allow use of either python-sasl or pure-sasl # via PR https://github.com/cloudera/impyla/pull/179 -# Even though thrift_sasl lists pure-sasl as dependency here https://github.com/cloudera/thrift_sasl/blob/master/setup.py#L34 +# Even though thrift_sasl lists pure-sasl as dependency here https://github.com/cloudera/thrift_sasl/blob/master/setup.py#L34 # but it still calls functions native to python-sasl in this file https://github.com/cloudera/thrift_sasl/blob/master/thrift_sasl/__init__.py#L82 # Hence this code is required for the fallback to work. - -from puresasl.client import SASLClient, SASLError + from contextlib import contextmanager +from puresasl.client import SASLClient, SASLError + + @contextmanager -def error_catcher(self, Exc = Exception): +def error_catcher(self, Exc=Exception): try: self.error = None yield diff --git a/python/pyhive/sqlalchemy_hive.py b/python/pyhive/sqlalchemy_hive.py index e22445259df..77035096e8e 100644 --- a/python/pyhive/sqlalchemy_hive.py +++ b/python/pyhive/sqlalchemy_hive.py @@ -5,30 +5,35 @@ which is released under the MIT license. """ -from __future__ import absolute_import -from __future__ import unicode_literals +from __future__ import absolute_import, unicode_literals import datetime import decimal - import re + from sqlalchemy import exc from sqlalchemy.sql import text + try: from sqlalchemy import processors except ImportError: # Required for SQLAlchemy>=2.0 from sqlalchemy.engine import processors -from sqlalchemy import types -from sqlalchemy import util +from sqlalchemy import types, util + # TODO shouldn't use mysql type try: from sqlalchemy.databases import mysql + mysql_tinyinteger = mysql.MSTinyInteger except ImportError: # Required for SQLAlchemy>2.0 from sqlalchemy.dialects import mysql + mysql_tinyinteger = mysql.base.MSTinyInteger +from decimal import Decimal + +from dateutil.parser import parse from sqlalchemy.engine import default from sqlalchemy.sql import compiler from sqlalchemy.sql.compiler import SQLCompiler @@ -36,12 +41,10 @@ from pyhive import hive from pyhive.common import UniversalSet -from dateutil.parser import parse -from decimal import Decimal - class HiveStringTypeBase(types.TypeDecorator): """Translates strings returned by Thrift into something else""" + impl = types.String def process_bind_param(self, value, dialect): @@ -50,6 +53,7 @@ def process_bind_param(self, value, dialect): class HiveDate(HiveStringTypeBase): """Translates date strings to date objects""" + impl = types.DATE def process_result_value(self, value, dialect): @@ -74,6 +78,7 @@ def adapt(self, impltype, **kwargs): class HiveTimestamp(HiveStringTypeBase): """Translates timestamp strings to datetime objects""" + impl = types.TIMESTAMP def process_result_value(self, value, dialect): @@ -96,6 +101,7 @@ def adapt(self, impltype, **kwargs): class HiveDecimal(HiveStringTypeBase): """Translates strings to decimals""" + impl = types.DECIMAL def process_result_value(self, value, dialect): @@ -126,35 +132,38 @@ class HiveIdentifierPreparer(compiler.IdentifierPreparer): def __init__(self, dialect): super(HiveIdentifierPreparer, self).__init__( dialect, - initial_quote='`', + initial_quote="`", ) _type_map = { - 'boolean': types.Boolean, - 'tinyint': mysql_tinyinteger, - 'smallint': types.SmallInteger, - 'int': types.Integer, - 'bigint': types.BigInteger, - 'float': types.Float, - 'double': types.Float, - 'string': types.String, - 'varchar': types.String, - 'char': types.String, - 'date': HiveDate, - 'timestamp': HiveTimestamp, - 'binary': types.String, - 'array': types.String, - 'map': types.String, - 'struct': types.String, - 'uniontype': types.String, - 'decimal': HiveDecimal, + "boolean": types.Boolean, + "tinyint": mysql_tinyinteger, + "smallint": types.SmallInteger, + "int": types.Integer, + "bigint": types.BigInteger, + "float": types.Float, + "double": types.Float, + "string": types.String, + "varchar": types.String, + "char": types.String, + "date": HiveDate, + "timestamp": HiveTimestamp, + "binary": types.String, + "array": types.String, + "map": types.String, + "struct": types.String, + "uniontype": types.String, + "decimal": HiveDecimal, } class HiveCompiler(SQLCompiler): def visit_concat_op_binary(self, binary, operator, **kw): - return "concat(%s, %s)" % (self.process(binary.left), self.process(binary.right)) + return "concat(%s, %s)" % ( + self.process(binary.left), + self.process(binary.right), + ) def visit_insert(self, *args, **kwargs): result = super(HiveCompiler, self).visit_insert(*args, **kwargs) @@ -162,57 +171,61 @@ def visit_insert(self, *args, **kwargs): # INSERT INTO `pyhive_test_database`.`test_table` (`a`) SELECT ... # => # INSERT INTO TABLE `pyhive_test_database`.`test_table` SELECT ... - regex = r'^(INSERT INTO) ([^\s]+) \([^\)]*\)' - assert re.search(regex, result), "Unexpected visit_insert result: {}".format(result) - return re.sub(regex, r'\1 TABLE \2', result) + regex = r"^(INSERT INTO) ([^\s]+) \([^\)]*\)" + assert re.search(regex, result), "Unexpected visit_insert result: {}".format( + result + ) + return re.sub(regex, r"\1 TABLE \2", result) def visit_column(self, *args, **kwargs): result = super(HiveCompiler, self).visit_column(*args, **kwargs) - dot_count = result.count('.') - assert dot_count in (0, 1, 2), "Unexpected visit_column result {}".format(result) + dot_count = result.count(".") + assert dot_count in (0, 1, 2), "Unexpected visit_column result {}".format( + result + ) if dot_count == 2: # we have something of the form schema.table.column # hive doesn't like the schema in front, so chop it out - result = result[result.index('.') + 1:] + result = result[result.index(".") + 1 :] return result def visit_char_length_func(self, fn, **kw): - return 'length{}'.format(self.function_argspec(fn, **kw)) + return "length{}".format(self.function_argspec(fn, **kw)) class HiveTypeCompiler(compiler.GenericTypeCompiler): def visit_INTEGER(self, type_): - return 'INT' + return "INT" def visit_NUMERIC(self, type_): - return 'DECIMAL' + return "DECIMAL" def visit_CHAR(self, type_): - return 'STRING' + return "STRING" def visit_VARCHAR(self, type_): - return 'STRING' + return "STRING" def visit_NCHAR(self, type_): - return 'STRING' + return "STRING" def visit_TEXT(self, type_): - return 'STRING' + return "STRING" def visit_CLOB(self, type_): - return 'STRING' + return "STRING" def visit_BLOB(self, type_): - return 'BINARY' + return "BINARY" def visit_TIME(self, type_): - return 'TIMESTAMP' + return "TIMESTAMP" def visit_DATE(self, type_): - return 'TIMESTAMP' + return "TIMESTAMP" def visit_DATETIME(self, type_): - return 'TIMESTAMP' + return "TIMESTAMP" class HiveExecutionContext(default.DefaultExecutionContext): @@ -226,21 +239,21 @@ class HiveExecutionContext(default.DefaultExecutionContext): @util.memoized_property def _preserve_raw_colnames(self): # Ideally, this would also gate on hive.resultset.use.unique.column.names - return self.execution_options.get('hive_raw_colnames', False) + return self.execution_options.get("hive_raw_colnames", False) def _translate_colname(self, colname): # Adjust for dotted column names. # When hive.resultset.use.unique.column.names is true (the default), Hive returns column # names as "tablename.colname" in cursor.description. - if not self._preserve_raw_colnames and '.' in colname: - return colname.split('.')[-1], colname + if not self._preserve_raw_colnames and "." in colname: + return colname.split(".")[-1], colname else: return colname, None class HiveDialect(default.DefaultDialect): - name = 'hive' - driver = 'thrift' + name = "hive" + driver = "thrift" execution_ctx_cls = HiveExecutionContext preparer = HiveIdentifierPreparer statement_compiler = HiveCompiler @@ -263,25 +276,25 @@ class HiveDialect(default.DefaultDialect): @classmethod def dbapi(cls): return hive - + @classmethod def import_dbapi(cls): return hive def create_connect_args(self, url): kwargs = { - 'host': url.host, - 'port': url.port or 10000, - 'username': url.username, - 'password': url.password, - 'database': url.database or 'default', + "host": url.host, + "port": url.port or 10000, + "username": url.username, + "password": url.password, + "database": url.database or "default", } kwargs.update(url.query) return [], kwargs def get_schema_names(self, connection, **kw): # Equivalent to SHOW DATABASES - return [row[0] for row in connection.execute(text('SHOW SCHEMAS'))] + return [row[0] for row in connection.execute(text("SHOW SCHEMAS"))] def get_view_names(self, connection, schema=None, **kw): # Hive does not provide functionality to query tableType @@ -291,15 +304,15 @@ def get_view_names(self, connection, schema=None, **kw): def _get_table_columns(self, connection, table_name, schema): full_table = table_name if schema: - full_table = schema + '.' + table_name + full_table = schema + "." + table_name # TODO using TGetColumnsReq hangs after sending TFetchResultsReq. # Using DESCRIBE works but is uglier. try: # This needs the table name to be unescaped (no backticks). - rows = connection.execute(text('DESCRIBE {}'.format(full_table))).fetchall() + rows = connection.execute(text("DESCRIBE {}".format(full_table))).fetchall() except exc.OperationalError as e: # Does the table exist? - regex_fmt = r'TExecuteStatementResp.*SemanticException.*Table not found {}' + regex_fmt = r"TExecuteStatementResp.*SemanticException.*Table not found {}" regex = regex_fmt.format(re.escape(full_table)) if re.search(regex, e.args[0]): raise exc.NoSuchTableError(full_table) @@ -307,7 +320,7 @@ def _get_table_columns(self, connection, table_name, schema): raise else: # Hive is stupid: this is what I get from DESCRIBE some_schema.does_not_exist - regex = r'Table .* does not exist' + regex = r"Table .* does not exist" if len(rows) == 1 and re.match(regex, rows[0].col_name): raise exc.NoSuchTableError(full_table) return rows @@ -324,27 +337,31 @@ def get_columns(self, connection, table_name, schema=None, **kw): # Strip whitespace rows = [[col.strip() if col else None for col in row] for row in rows] # Filter out empty rows and comment - rows = [row for row in rows if row[0] and row[0] != '# col_name'] + rows = [row for row in rows if row[0] and row[0] != "# col_name"] result = [] - for (col_name, col_type, _comment) in rows: - if col_name == '# Partition Information': + for col_name, col_type, _comment in rows: + if col_name == "# Partition Information": break # Take out the more detailed type information # e.g. 'map' -> 'map' # 'decimal(10,1)' -> decimal - col_type = re.search(r'^\w+', col_type).group(0) + col_type = re.search(r"^\w+", col_type).group(0) try: coltype = _type_map[col_type] except KeyError: - util.warn("Did not recognize type '%s' of column '%s'" % (col_type, col_name)) + util.warn( + "Did not recognize type '%s' of column '%s'" % (col_type, col_name) + ) coltype = types.NullType - result.append({ - 'name': col_name, - 'type': coltype, - 'nullable': True, - 'default': None, - }) + result.append( + { + "name": col_name, + "type": coltype, + "nullable": True, + "default": None, + } + ) return result def get_foreign_keys(self, connection, table_name, schema=None, **kw): @@ -360,23 +377,23 @@ def get_indexes(self, connection, table_name, schema=None, **kw): # Strip whitespace rows = [[col.strip() if col else None for col in row] for row in rows] # Filter out empty rows and comment - rows = [row for row in rows if row[0] and row[0] != '# col_name'] + rows = [row for row in rows if row[0] and row[0] != "# col_name"] for i, (col_name, _col_type, _comment) in enumerate(rows): - if col_name == '# Partition Information': + if col_name == "# Partition Information": break # Handle partition columns col_names = [] - for col_name, _col_type, _comment in rows[i + 1:]: + for col_name, _col_type, _comment in rows[i + 1 :]: col_names.append(col_name) if col_names: - return [{'name': 'partition', 'column_names': col_names, 'unique': False}] + return [{"name": "partition", "column_names": col_names, "unique": False}] else: return [] def get_table_names(self, connection, schema=None, **kw): - query = 'SHOW TABLES' + query = "SHOW TABLES" if schema: - query += ' IN ' + self.identifier_preparer.quote_identifier(schema) + query += " IN " + self.identifier_preparer.quote_identifier(schema) return [row[0] for row in connection.execute(text(query))] def do_rollback(self, dbapi_connection): @@ -393,7 +410,6 @@ def _check_unicode_description(self, connection): class HiveHTTPDialect(HiveDialect): - name = "hive" scheme = "http" driver = "rest" @@ -413,6 +429,5 @@ def create_connect_args(self, url): class HiveHTTPSDialect(HiveHTTPDialect): - name = "hive" scheme = "https" diff --git a/python/pyhive/sqlalchemy_presto.py b/python/pyhive/sqlalchemy_presto.py index bfe1ba0478f..9c47490793f 100644 --- a/python/pyhive/sqlalchemy_presto.py +++ b/python/pyhive/sqlalchemy_presto.py @@ -5,22 +5,24 @@ which is released under the MIT license. """ -from __future__ import absolute_import -from __future__ import unicode_literals +from __future__ import absolute_import, unicode_literals import re + import sqlalchemy -from sqlalchemy import exc -from sqlalchemy import types -from sqlalchemy import util +from sqlalchemy import exc, types, util + # TODO shouldn't use mysql type from sqlalchemy.sql import text + try: from sqlalchemy.databases import mysql + mysql_tinyinteger = mysql.MSTinyInteger except ImportError: # Required for SQLAlchemy>=2.0 from sqlalchemy.dialects import mysql + mysql_tinyinteger = mysql.base.MSTinyInteger from sqlalchemy.engine import default from sqlalchemy.sql import compiler @@ -29,7 +31,10 @@ from pyhive import presto from pyhive.common import UniversalSet -sqlalchemy_version = float(re.search(r"^([\d]+\.[\d]+)\..+", sqlalchemy.__version__).group(1)) +sqlalchemy_version = float( + re.search(r"^([\d]+\.[\d]+)\..+", sqlalchemy.__version__).group(1) +) + class PrestoIdentifierPreparer(compiler.IdentifierPreparer): # Just quote everything to make things simpler / easier to upgrade @@ -37,23 +42,23 @@ class PrestoIdentifierPreparer(compiler.IdentifierPreparer): _type_map = { - 'boolean': types.Boolean, - 'tinyint': mysql_tinyinteger, - 'smallint': types.SmallInteger, - 'integer': types.Integer, - 'bigint': types.BigInteger, - 'real': types.Float, - 'double': types.Float, - 'varchar': types.String, - 'timestamp': types.TIMESTAMP, - 'date': types.DATE, - 'varbinary': types.VARBINARY, + "boolean": types.Boolean, + "tinyint": mysql_tinyinteger, + "smallint": types.SmallInteger, + "integer": types.Integer, + "bigint": types.BigInteger, + "real": types.Float, + "double": types.Float, + "varchar": types.String, + "timestamp": types.TIMESTAMP, + "date": types.DATE, + "varbinary": types.VARBINARY, } class PrestoCompiler(SQLCompiler): def visit_char_length_func(self, fn, **kw): - return 'length{}'.format(self.function_argspec(fn, **kw)) + return "length{}".format(self.function_argspec(fn, **kw)) class PrestoTypeCompiler(compiler.GenericTypeCompiler): @@ -67,19 +72,19 @@ def visit_DATETIME(self, type_, **kw): raise ValueError("Presto does not support the DATETIME column type.") def visit_FLOAT(self, type_, **kw): - return 'DOUBLE' + return "DOUBLE" def visit_TEXT(self, type_, **kw): if type_.length: - return 'VARCHAR({:d})'.format(type_.length) + return "VARCHAR({:d})".format(type_.length) else: - return 'VARCHAR' + return "VARCHAR" class PrestoDialect(default.DefaultDialect): - name = 'presto' - driver = 'rest' - paramstyle = 'pyformat' + name = "presto" + driver = "rest" + paramstyle = "pyformat" preparer = PrestoIdentifierPreparer statement_compiler = PrestoCompiler supports_alter = False @@ -98,38 +103,40 @@ class PrestoDialect(default.DefaultDialect): @classmethod def dbapi(cls): return presto - + @classmethod def import_dbapi(cls): return presto def create_connect_args(self, url): - db_parts = (url.database or 'hive').split('/') + db_parts = (url.database or "hive").split("/") kwargs = { - 'host': url.host, - 'port': url.port or 8080, - 'username': url.username, - 'password': url.password + "host": url.host, + "port": url.port or 8080, + "username": url.username, + "password": url.password, } kwargs.update(url.query) if len(db_parts) == 1: - kwargs['catalog'] = db_parts[0] + kwargs["catalog"] = db_parts[0] elif len(db_parts) == 2: - kwargs['catalog'] = db_parts[0] - kwargs['schema'] = db_parts[1] + kwargs["catalog"] = db_parts[0] + kwargs["schema"] = db_parts[1] else: raise ValueError("Unexpected database format {}".format(url.database)) return [], kwargs def get_schema_names(self, connection, **kw): - return [row.Schema for row in connection.execute(text('SHOW SCHEMAS'))] + return [row.Schema for row in connection.execute(text("SHOW SCHEMAS"))] def _get_table_columns(self, connection, table_name, schema): full_table = self.identifier_preparer.quote_identifier(table_name) if schema: - full_table = self.identifier_preparer.quote_identifier(schema) + '.' + full_table + full_table = ( + self.identifier_preparer.quote_identifier(schema) + "." + full_table + ) try: - return connection.execute(text('SHOW COLUMNS FROM {}'.format(full_table))) + return connection.execute(text("SHOW COLUMNS FROM {}".format(full_table))) except (presto.DatabaseError, exc.DatabaseError) as e: # Normally SQLAlchemy should wrap this exception in sqlalchemy.exc.DatabaseError, which # it successfully does in the Hive version. The difference with Presto is that this @@ -138,8 +145,10 @@ def _get_table_columns(self, connection, table_name, schema): # presto.DatabaseError here. # Does the table exist? msg = ( - e.args[0].get('message') if e.args and isinstance(e.args[0], dict) - else e.args[0] if e.args and isinstance(e.args[0], str) + e.args[0].get("message") + if e.args and isinstance(e.args[0], dict) + else e.args[0] + if e.args and isinstance(e.args[0], str) else None ) regex = r"Table\ \'.*{}\'\ does\ not\ exist".format(re.escape(table_name)) @@ -162,15 +171,20 @@ def get_columns(self, connection, table_name, schema=None, **kw): try: coltype = _type_map[row.Type] except KeyError: - util.warn("Did not recognize type '%s' of column '%s'" % (row.Type, row.Column)) + util.warn( + "Did not recognize type '%s' of column '%s'" + % (row.Type, row.Column) + ) coltype = types.NullType - result.append({ - 'name': row.Column, - 'type': coltype, - # newer Presto no longer includes this column - 'nullable': getattr(row, 'Null', True), - 'default': None, - }) + result.append( + { + "name": row.Column, + "type": coltype, + # newer Presto no longer includes this column + "nullable": getattr(row, "Null", True), + "default": None, + } + ) return result def get_foreign_keys(self, connection, table_name, schema=None, **kw): @@ -185,7 +199,7 @@ def get_indexes(self, connection, table_name, schema=None, **kw): rows = self._get_table_columns(connection, table_name, schema) col_names = [] for row in rows: - part_key = 'Partition Key' + part_key = "Partition Key" # Presto puts this information in one of 3 places depending on version # - a boolean column named "Partition Key" # - a string in the "Comment" column @@ -194,20 +208,20 @@ def get_indexes(self, connection, table_name, schema=None, **kw): row = row._mapping is_partition_key = ( (part_key in row and row[part_key]) - or row['Comment'].startswith(part_key) - or ('Extra' in row and 'partition key' in row['Extra']) + or row["Comment"].startswith(part_key) + or ("Extra" in row and "partition key" in row["Extra"]) ) if is_partition_key: - col_names.append(row['Column']) + col_names.append(row["Column"]) if col_names: - return [{'name': 'partition', 'column_names': col_names, 'unique': False}] + return [{"name": "partition", "column_names": col_names, "unique": False}] else: return [] def get_table_names(self, connection, schema=None, **kw): - query = 'SHOW TABLES' + query = "SHOW TABLES" if schema: - query += ' FROM ' + self.identifier_preparer.quote_identifier(schema) + query += " FROM " + self.identifier_preparer.quote_identifier(schema) return [row.Table for row in connection.execute(text(query))] def do_rollback(self, dbapi_connection): diff --git a/python/pyhive/sqlalchemy_trino.py b/python/pyhive/sqlalchemy_trino.py index 11be2a6ca9a..b1d199cd4b5 100644 --- a/python/pyhive/sqlalchemy_trino.py +++ b/python/pyhive/sqlalchemy_trino.py @@ -5,45 +5,45 @@ which is released under the MIT license. """ -from __future__ import absolute_import -from __future__ import unicode_literals +from __future__ import absolute_import, unicode_literals -import re -from sqlalchemy import exc from sqlalchemy import types -from sqlalchemy import util + # TODO shouldn't use mysql type try: from sqlalchemy.databases import mysql + mysql_tinyinteger = mysql.MSTinyInteger except ImportError: # Required for SQLAlchemy>=2.0 from sqlalchemy.dialects import mysql + mysql_tinyinteger = mysql.base.MSTinyInteger -from sqlalchemy.engine import default -from sqlalchemy.sql import compiler -from sqlalchemy.sql.compiler import SQLCompiler from pyhive import trino -from pyhive.common import UniversalSet -from pyhive.sqlalchemy_presto import PrestoDialect, PrestoCompiler, PrestoIdentifierPreparer +from pyhive.sqlalchemy_presto import ( + PrestoCompiler, + PrestoDialect, + PrestoIdentifierPreparer, +) + class TrinoIdentifierPreparer(PrestoIdentifierPreparer): pass _type_map = { - 'boolean': types.Boolean, - 'tinyint': mysql_tinyinteger, - 'smallint': types.SmallInteger, - 'integer': types.Integer, - 'bigint': types.BigInteger, - 'real': types.Float, - 'double': types.Float, - 'varchar': types.String, - 'timestamp': types.TIMESTAMP, - 'date': types.DATE, - 'varbinary': types.VARBINARY, + "boolean": types.Boolean, + "tinyint": mysql_tinyinteger, + "smallint": types.SmallInteger, + "integer": types.Integer, + "bigint": types.BigInteger, + "real": types.Float, + "double": types.Float, + "varchar": types.String, + "timestamp": types.TIMESTAMP, + "date": types.DATE, + "varbinary": types.VARBINARY, } @@ -62,23 +62,23 @@ def visit_DATETIME(self, type_, **kw): raise ValueError("Trino does not support the DATETIME column type.") def visit_FLOAT(self, type_, **kw): - return 'DOUBLE' + return "DOUBLE" def visit_TEXT(self, type_, **kw): if type_.length: - return 'VARCHAR({:d})'.format(type_.length) + return "VARCHAR({:d})".format(type_.length) else: - return 'VARCHAR' + return "VARCHAR" class TrinoDialect(PrestoDialect): - name = 'trino' + name = "trino" supports_statement_cache = False @classmethod def dbapi(cls): return trino - + @classmethod def import_dbapi(cls): - return trino + return trino diff --git a/python/pyhive/tests/dbapi_test_case.py b/python/pyhive/tests/dbapi_test_case.py index eda2d10928b..fcbb13ad3b6 100644 --- a/python/pyhive/tests/dbapi_test_case.py +++ b/python/pyhive/tests/dbapi_test_case.py @@ -1,15 +1,16 @@ # encoding: utf-8 """Shared DB-API test cases""" -from __future__ import absolute_import -from __future__ import unicode_literals -from builtins import object -from builtins import range -from future.utils import with_metaclass -from pyhive import exc +from __future__ import absolute_import, unicode_literals + import abc import contextlib import functools +from builtins import object, range + +from future.utils import with_metaclass + +from pyhive import exc def with_cursor(fn): @@ -17,11 +18,13 @@ def with_cursor(fn): The cursor is taken from ``self.connect()``. """ + @functools.wraps(fn) def wrapped_fn(self, *args, **kwargs): with contextlib.closing(self.connect()) as connection: with contextlib.closing(connection.cursor()) as cursor: fn(self, cursor, *args, **kwargs) + return wrapped_fn @@ -32,7 +35,7 @@ def connect(self): @with_cursor def test_fetchone(self, cursor): - cursor.execute('SELECT * FROM one_row') + cursor.execute("SELECT * FROM one_row") self.assertEqual(cursor.rownumber, 0) self.assertEqual(cursor.fetchone(), (1,)) self.assertEqual(cursor.rownumber, 1) @@ -40,19 +43,19 @@ def test_fetchone(self, cursor): @with_cursor def test_fetchall(self, cursor): - cursor.execute('SELECT * FROM one_row') + cursor.execute("SELECT * FROM one_row") self.assertEqual(cursor.fetchall(), [(1,)]) - cursor.execute('SELECT a FROM many_rows ORDER BY a') + cursor.execute("SELECT a FROM many_rows ORDER BY a") self.assertEqual(cursor.fetchall(), [(i,) for i in range(10000)]) @with_cursor def test_null_param(self, cursor): - cursor.execute('SELECT %s FROM one_row', (None,)) + cursor.execute("SELECT %s FROM one_row", (None,)) self.assertEqual(cursor.fetchall(), [(None,)]) @with_cursor def test_iterator(self, cursor): - cursor.execute('SELECT * FROM one_row') + cursor.execute("SELECT * FROM one_row") self.assertEqual(list(cursor), [(1,)]) self.assertRaises(StopIteration, cursor.__next__) @@ -63,7 +66,7 @@ def test_description_initial(self, cursor): @with_cursor def test_description_failed(self, cursor): try: - cursor.execute('blah_blah') + cursor.execute("blah_blah") self.assertIsNone(cursor.description) except exc.DatabaseError: pass @@ -71,28 +74,28 @@ def test_description_failed(self, cursor): @with_cursor def test_bad_query(self, cursor): def run(): - cursor.execute('SELECT does_not_exist FROM this_really_does_not_exist') + cursor.execute("SELECT does_not_exist FROM this_really_does_not_exist") cursor.fetchone() + self.assertRaises(exc.DatabaseError, run) @with_cursor def test_concurrent_execution(self, cursor): - cursor.execute('SELECT * FROM one_row') - cursor.execute('SELECT * FROM one_row') + cursor.execute("SELECT * FROM one_row") + cursor.execute("SELECT * FROM one_row") self.assertEqual(cursor.fetchall(), [(1,)]) @with_cursor def test_executemany(self, cursor): for length in 1, 2: cursor.executemany( - 'SELECT %(x)d FROM one_row', - [{'x': i} for i in range(1, length + 1)] + "SELECT %(x)d FROM one_row", [{"x": i} for i in range(1, length + 1)] ) self.assertEqual(cursor.fetchall(), [(length,)]) @with_cursor def test_executemany_none(self, cursor): - cursor.executemany('should_never_get_used', []) + cursor.executemany("should_never_get_used", []) self.assertIsNone(cursor.description) self.assertRaises(exc.ProgrammingError, cursor.fetchone) @@ -102,7 +105,7 @@ def test_fetchone_no_data(self, cursor): @with_cursor def test_fetchmany(self, cursor): - cursor.execute('SELECT * FROM many_rows LIMIT 15') + cursor.execute("SELECT * FROM many_rows LIMIT 15") self.assertEqual(cursor.fetchmany(0), []) self.assertEqual(len(cursor.fetchmany(10)), 10) self.assertEqual(len(cursor.fetchmany(10)), 5) @@ -110,43 +113,45 @@ def test_fetchmany(self, cursor): @with_cursor def test_arraysize(self, cursor): cursor.arraysize = 5 - cursor.execute('SELECT * FROM many_rows LIMIT 20') + cursor.execute("SELECT * FROM many_rows LIMIT 20") self.assertEqual(len(cursor.fetchmany()), 5) @with_cursor def test_polling_loop(self, cursor): """Try to trigger the polling logic in fetchone()""" cursor._poll_interval = 0 - cursor.execute('SELECT COUNT(*) FROM many_rows') + cursor.execute("SELECT COUNT(*) FROM many_rows") self.assertEqual(cursor.fetchone(), (10000,)) @with_cursor def test_no_params(self, cursor): cursor.execute("SELECT '%(x)s' FROM one_row") - self.assertEqual(cursor.fetchall(), [('%(x)s',)]) + self.assertEqual(cursor.fetchall(), [("%(x)s",)]) def test_escape(self): """Verify that funny characters can be escaped as strings and SELECTed back""" - bad_str = '''`~!@#$%^&*()_+-={}[]|\\;:'",./<>?\n\r\t ''' + bad_str = """`~!@#$%^&*()_+-={}[]|\\;:'",./<>?\n\r\t """ self.run_escape_case(bad_str) @with_cursor def run_escape_case(self, cursor, bad_str): - cursor.execute( - 'SELECT %d, %s FROM one_row', - (1, bad_str) - ) - self.assertEqual(cursor.fetchall(), [(1, bad_str,)]) - cursor.execute( - 'SELECT %(a)d, %(b)s FROM one_row', - {'a': 1, 'b': bad_str} + cursor.execute("SELECT %d, %s FROM one_row", (1, bad_str)) + self.assertEqual( + cursor.fetchall(), + [ + ( + 1, + bad_str, + ) + ], ) + cursor.execute("SELECT %(a)d, %(b)s FROM one_row", {"a": 1, "b": bad_str}) self.assertEqual(cursor.fetchall(), [(1, bad_str)]) @with_cursor def test_invalid_params(self, cursor): - self.assertRaises(exc.ProgrammingError, lambda: cursor.execute('', 'hi')) - self.assertRaises(exc.ProgrammingError, lambda: cursor.execute('', [object])) + self.assertRaises(exc.ProgrammingError, lambda: cursor.execute("", "hi")) + self.assertRaises(exc.ProgrammingError, lambda: cursor.execute("", [object])) def test_open_close(self): with contextlib.closing(self.connect()): @@ -158,23 +163,21 @@ def test_open_close(self): @with_cursor def test_unicode(self, cursor): unicode_str = "王兢" - cursor.execute( - 'SELECT %s FROM one_row', - (unicode_str,) - ) + cursor.execute("SELECT %s FROM one_row", (unicode_str,)) self.assertEqual(cursor.fetchall(), [(unicode_str,)]) @with_cursor def test_null(self, cursor): - cursor.execute('SELECT null FROM many_rows') + cursor.execute("SELECT null FROM many_rows") self.assertEqual(cursor.fetchall(), [(None,)] * 10000) - cursor.execute('SELECT IF(a % 11 = 0, null, a) FROM many_rows') - self.assertEqual(cursor.fetchall(), [(None if a % 11 == 0 else a,) for a in range(10000)]) + cursor.execute("SELECT IF(a % 11 = 0, null, a) FROM many_rows") + self.assertEqual( + cursor.fetchall(), [(None if a % 11 == 0 else a,) for a in range(10000)] + ) @with_cursor def test_sql_where_in(self, cursor): - cursor.execute('SELECT * FROM many_rows where a in %s', ([1, 2, 3],)) + cursor.execute("SELECT * FROM many_rows where a in %s", ([1, 2, 3],)) self.assertEqual(len(cursor.fetchall()), 3) - cursor.execute('SELECT * FROM many_rows where b in %s limit 10', - (['blah'],)) + cursor.execute("SELECT * FROM many_rows where b in %s limit 10", (["blah"],)) self.assertEqual(len(cursor.fetchall()), 10) diff --git a/python/pyhive/tests/sqlalchemy_test_case.py b/python/pyhive/tests/sqlalchemy_test_case.py index db89d57b510..1c321b245ca 100644 --- a/python/pyhive/tests/sqlalchemy_test_case.py +++ b/python/pyhive/tests/sqlalchemy_test_case.py @@ -1,30 +1,31 @@ # coding: utf-8 -from __future__ import absolute_import -from __future__ import unicode_literals +from __future__ import absolute_import, unicode_literals import abc -import re import contextlib import functools +import re +from builtins import object import pytest import sqlalchemy -from builtins import object from future.utils import with_metaclass +from sqlalchemy import String from sqlalchemy.exc import NoSuchTableError -from sqlalchemy.schema import Index -from sqlalchemy.schema import MetaData -from sqlalchemy.schema import Table +from sqlalchemy.schema import Index, MetaData, Table from sqlalchemy.sql import expression, text -from sqlalchemy import String -sqlalchemy_version = float(re.search(r"^([\d]+\.[\d]+)\..+", sqlalchemy.__version__).group(1)) +sqlalchemy_version = float( + re.search(r"^([\d]+\.[\d]+)\..+", sqlalchemy.__version__).group(1) +) + def with_engine_connection(fn): """Pass a connection to the given function and handle cleanup. The connection is taken from ``self.create_engine()``. """ + @functools.wraps(fn) def wrapped_fn(self, *args, **kwargs): engine = self.create_engine() @@ -33,9 +34,13 @@ def wrapped_fn(self, *args, **kwargs): fn(self, engine, connection, *args, **kwargs) finally: engine.dispose() + return wrapped_fn -def reflect_table(engine, connection, table, include_columns, exclude_columns, resolve_fks): + +def reflect_table( + engine, connection, table, include_columns, exclude_columns, resolve_fks +): if sqlalchemy_version >= 1.4: insp = sqlalchemy.inspect(engine) insp.reflect_table( @@ -46,21 +51,27 @@ def reflect_table(engine, connection, table, include_columns, exclude_columns, r ) else: engine.dialect.reflecttable( - connection, table, include_columns=include_columns, - exclude_columns=exclude_columns, resolve_fks=resolve_fks) + connection, + table, + include_columns=include_columns, + exclude_columns=exclude_columns, + resolve_fks=resolve_fks, + ) class SqlAlchemyTestCase(with_metaclass(abc.ABCMeta, object)): @with_engine_connection def test_basic_query(self, engine, connection): - rows = connection.execute(text('SELECT * FROM one_row')).fetchall() + rows = connection.execute(text("SELECT * FROM one_row")).fetchall() self.assertEqual(len(rows), 1) self.assertEqual(rows[0].number_of_rows, 1) # number_of_rows is the column name self.assertEqual(len(rows[0]), 1) @with_engine_connection def test_one_row_complex_null(self, engine, connection): - one_row_complex_null = Table('one_row_complex_null', MetaData(), autoload_with=engine) + one_row_complex_null = Table( + "one_row_complex_null", MetaData(), autoload_with=engine + ) rows = connection.execute(one_row_complex_null.select()).fetchall() self.assertEqual(len(rows), 1) self.assertEqual(list(rows[0]), [None] * len(rows[0])) @@ -70,18 +81,30 @@ def test_reflect_no_such_table(self, engine, connection): """reflecttable should throw an exception on an invalid table""" self.assertRaises( NoSuchTableError, - lambda: Table('this_does_not_exist', MetaData(), autoload_with=engine)) + lambda: Table("this_does_not_exist", MetaData(), autoload_with=engine), + ) self.assertRaises( NoSuchTableError, - lambda: Table('this_does_not_exist', MetaData(schema='also_does_not_exist'), autoload_with=engine)) + lambda: Table( + "this_does_not_exist", + MetaData(schema="also_does_not_exist"), + autoload_with=engine, + ), + ) @with_engine_connection def test_reflect_include_columns(self, engine, connection): """When passed include_columns, reflecttable should filter out other columns""" - one_row_complex = Table('one_row_complex', MetaData()) - reflect_table(engine, connection, one_row_complex, include_columns=['int'], - exclude_columns=[], resolve_fks=True) + one_row_complex = Table("one_row_complex", MetaData()) + reflect_table( + engine, + connection, + one_row_complex, + include_columns=["int"], + exclude_columns=[], + resolve_fks=True, + ) self.assertEqual(len(one_row_complex.c), 1) self.assertIsNotNone(one_row_complex.c.int) @@ -89,66 +112,90 @@ def test_reflect_include_columns(self, engine, connection): @with_engine_connection def test_reflect_with_schema(self, engine, connection): - dummy = Table('dummy_table', MetaData(schema='pyhive_test_database'), autoload_with=engine) + dummy = Table( + "dummy_table", MetaData(schema="pyhive_test_database"), autoload_with=engine + ) self.assertEqual(len(dummy.c), 1) self.assertIsNotNone(dummy.c.a) - @pytest.mark.filterwarnings('default:Omitting:sqlalchemy.exc.SAWarning') + @pytest.mark.filterwarnings("default:Omitting:sqlalchemy.exc.SAWarning") @with_engine_connection def test_reflect_partitions(self, engine, connection): """reflecttable should get the partition column as an index""" - many_rows = Table('many_rows', MetaData(), autoload_with=engine) + many_rows = Table("many_rows", MetaData(), autoload_with=engine) self.assertEqual(len(many_rows.c), 2) - self.assertEqual(repr(many_rows.indexes), repr({Index('partition', many_rows.c.b)})) + self.assertEqual( + repr(many_rows.indexes), repr({Index("partition", many_rows.c.b)}) + ) + + many_rows = Table("many_rows", MetaData()) + reflect_table( + engine, + connection, + many_rows, + include_columns=["a"], + exclude_columns=[], + resolve_fks=True, + ) - many_rows = Table('many_rows', MetaData()) - reflect_table(engine, connection, many_rows, include_columns=['a'], - exclude_columns=[], resolve_fks=True) - self.assertEqual(len(many_rows.c), 1) self.assertFalse(many_rows.c.a.index) self.assertFalse(many_rows.indexes) - many_rows = Table('many_rows', MetaData()) - reflect_table(engine, connection, many_rows, include_columns=['b'], - exclude_columns=[], resolve_fks=True) + many_rows = Table("many_rows", MetaData()) + reflect_table( + engine, + connection, + many_rows, + include_columns=["b"], + exclude_columns=[], + resolve_fks=True, + ) self.assertEqual(len(many_rows.c), 1) - self.assertEqual(repr(many_rows.indexes), repr({Index('partition', many_rows.c.b)})) + self.assertEqual( + repr(many_rows.indexes), repr({Index("partition", many_rows.c.b)}) + ) @with_engine_connection def test_unicode(self, engine, connection): """Verify that unicode strings make it through SQLAlchemy and the backend""" unicode_str = "中文" - one_row = Table('one_row', MetaData()) + one_row = Table("one_row", MetaData()) if sqlalchemy_version >= 1.4: - returned_str = connection.execute(sqlalchemy.select( - expression.bindparam("好", unicode_str, type_=String())).select_from(one_row)).scalar() + returned_str = connection.execute( + sqlalchemy.select( + expression.bindparam("好", unicode_str, type_=String()) + ).select_from(one_row) + ).scalar() else: - returned_str = connection.execute(sqlalchemy.select([ - expression.bindparam("好", unicode_str, type_=String())]).select_from(one_row)).scalar() - + returned_str = connection.execute( + sqlalchemy.select( + [expression.bindparam("好", unicode_str, type_=String())] + ).select_from(one_row) + ).scalar() + self.assertEqual(returned_str, unicode_str) @with_engine_connection def test_reflect_schemas(self, engine, connection): insp = sqlalchemy.inspect(engine) schemas = insp.get_schema_names() - self.assertIn('pyhive_test_database', schemas) - self.assertIn('default', schemas) + self.assertIn("pyhive_test_database", schemas) + self.assertIn("default", schemas) @with_engine_connection def test_get_table_names(self, engine, connection): meta = MetaData() meta.reflect(bind=engine) - self.assertIn('one_row', meta.tables) - self.assertIn('one_row_complex', meta.tables) + self.assertIn("one_row", meta.tables) + self.assertIn("one_row_complex", meta.tables) insp = sqlalchemy.inspect(engine) self.assertIn( - 'dummy_table', - insp.get_table_names(schema='pyhive_test_database'), + "dummy_table", + insp.get_table_names(schema="pyhive_test_database"), ) @with_engine_connection @@ -158,16 +205,24 @@ def test_has_table(self, engine, connection): self.assertTrue(insp.has_table("one_row")) self.assertFalse(insp.has_table("this_table_does_not_exist")) else: - self.assertTrue(Table('one_row', MetaData(bind=engine)).exists()) - self.assertFalse(Table('this_table_does_not_exist', MetaData(bind=engine)).exists()) + self.assertTrue(Table("one_row", MetaData(bind=engine)).exists()) + self.assertFalse( + Table("this_table_does_not_exist", MetaData(bind=engine)).exists() + ) @with_engine_connection def test_char_length(self, engine, connection): - one_row_complex = Table('one_row_complex', MetaData(), autoload_with=engine) + one_row_complex = Table("one_row_complex", MetaData(), autoload_with=engine) if sqlalchemy_version >= 1.4: - result = connection.execute(sqlalchemy.select(sqlalchemy.func.char_length(one_row_complex.c.string))).scalar() + result = connection.execute( + sqlalchemy.select(sqlalchemy.func.char_length(one_row_complex.c.string)) + ).scalar() else: - result = connection.execute(sqlalchemy.select([sqlalchemy.func.char_length(one_row_complex.c.string)])).scalar() + result = connection.execute( + sqlalchemy.select( + [sqlalchemy.func.char_length(one_row_complex.c.string)] + ) + ).scalar() - self.assertEqual(result, len('a string')) + self.assertEqual(result, len("a string")) diff --git a/python/pyhive/tests/test_common.py b/python/pyhive/tests/test_common.py index 715c7168cc7..2c77ab9ff30 100644 --- a/python/pyhive/tests/test_common.py +++ b/python/pyhive/tests/test_common.py @@ -1,40 +1,44 @@ # encoding: utf-8 -from __future__ import absolute_import -from __future__ import unicode_literals -from pyhive import common +from __future__ import absolute_import, unicode_literals + import datetime import unittest +from pyhive import common + class TestCommon(unittest.TestCase): def test_escape_args(self): escaper = common.ParamEscaper() - self.assertEqual(escaper.escape_args({'foo': 'bar'}), - {'foo': "'bar'"}) - self.assertEqual(escaper.escape_args({'foo': 123}), - {'foo': 123}) - self.assertEqual(escaper.escape_args({'foo': 123.456}), - {'foo': 123.456}) - self.assertEqual(escaper.escape_args({'foo': ['a', 'b', 'c']}), - {'foo': "('a','b','c')"}) - self.assertEqual(escaper.escape_args({'foo': ('a', 'b', 'c')}), - {'foo': "('a','b','c')"}) - self.assertIn(escaper.escape_args({'foo': {'a', 'b'}}), - ({'foo': "('a','b')"}, {'foo': "('b','a')"})) - self.assertIn(escaper.escape_args({'foo': frozenset(['a', 'b'])}), - ({'foo': "('a','b')"}, {'foo': "('b','a')"})) + self.assertEqual(escaper.escape_args({"foo": "bar"}), {"foo": "'bar'"}) + self.assertEqual(escaper.escape_args({"foo": 123}), {"foo": 123}) + self.assertEqual(escaper.escape_args({"foo": 123.456}), {"foo": 123.456}) + self.assertEqual( + escaper.escape_args({"foo": ["a", "b", "c"]}), {"foo": "('a','b','c')"} + ) + self.assertEqual( + escaper.escape_args({"foo": ("a", "b", "c")}), {"foo": "('a','b','c')"} + ) + self.assertIn( + escaper.escape_args({"foo": {"a", "b"}}), + ({"foo": "('a','b')"}, {"foo": "('b','a')"}), + ) + self.assertIn( + escaper.escape_args({"foo": frozenset(["a", "b"])}), + ({"foo": "('a','b')"}, {"foo": "('b','a')"}), + ) - self.assertEqual(escaper.escape_args(('bar',)), - ("'bar'",)) - self.assertEqual(escaper.escape_args([123]), - (123,)) - self.assertEqual(escaper.escape_args((123.456,)), - (123.456,)) - self.assertEqual(escaper.escape_args((['a', 'b', 'c'],)), - ("('a','b','c')",)) - self.assertEqual(escaper.escape_args((['你好', 'b', 'c'],)), - ("('你好','b','c')",)) - self.assertEqual(escaper.escape_args((datetime.date(2020, 4, 17),)), - ("'2020-04-17'",)) - self.assertEqual(escaper.escape_args((datetime.datetime(2020, 4, 17, 12, 0, 0, 123456),)), - ("'2020-04-17 12:00:00.123456'",)) + self.assertEqual(escaper.escape_args(("bar",)), ("'bar'",)) + self.assertEqual(escaper.escape_args([123]), (123,)) + self.assertEqual(escaper.escape_args((123.456,)), (123.456,)) + self.assertEqual(escaper.escape_args((["a", "b", "c"],)), ("('a','b','c')",)) + self.assertEqual( + escaper.escape_args((["你好", "b", "c"],)), ("('你好','b','c')",) + ) + self.assertEqual( + escaper.escape_args((datetime.date(2020, 4, 17),)), ("'2020-04-17'",) + ) + self.assertEqual( + escaper.escape_args((datetime.datetime(2020, 4, 17, 12, 0, 0, 123456),)), + ("'2020-04-17 12:00:00.123456'",), + ) diff --git a/python/pyhive/tests/test_hive.py b/python/pyhive/tests/test_hive.py index b49fc1904e0..e2b17d639fe 100644 --- a/python/pyhive/tests/test_hive.py +++ b/python/pyhive/tests/test_hive.py @@ -4,8 +4,7 @@ They also require a tables created by make_test_tables.sh. """ -from __future__ import absolute_import -from __future__ import unicode_literals +from __future__ import absolute_import, unicode_literals import contextlib import datetime @@ -22,72 +21,116 @@ import thrift_sasl from thrift.transport.TTransport import TTransportException -from TCLIService import ttypes from pyhive import hive -from pyhive.tests.dbapi_test_case import DBAPITestCase -from pyhive.tests.dbapi_test_case import with_cursor +from pyhive.tests.dbapi_test_case import DBAPITestCase, with_cursor +from TCLIService import ttypes -_HOST = 'localhost' +_HOST = "localhost" class TestHive(unittest.TestCase, DBAPITestCase): __test__ = True def connect(self): - return hive.connect(host=_HOST, configuration={'mapred.job.tracker': 'local'}) + return hive.connect(host=_HOST, configuration={"mapred.job.tracker": "local"}) @with_cursor def test_description(self, cursor): - cursor.execute('SELECT * FROM one_row') + cursor.execute("SELECT * FROM one_row") - desc = [('one_row.number_of_rows', 'INT_TYPE', None, None, None, None, True)] + desc = [("one_row.number_of_rows", "INT_TYPE", None, None, None, None, True)] self.assertEqual(cursor.description, desc) @with_cursor def test_complex(self, cursor): - cursor.execute('SELECT * FROM one_row_complex') - self.assertEqual(cursor.description, [ - ('one_row_complex.boolean', 'BOOLEAN_TYPE', None, None, None, None, True), - ('one_row_complex.tinyint', 'TINYINT_TYPE', None, None, None, None, True), - ('one_row_complex.smallint', 'SMALLINT_TYPE', None, None, None, None, True), - ('one_row_complex.int', 'INT_TYPE', None, None, None, None, True), - ('one_row_complex.bigint', 'BIGINT_TYPE', None, None, None, None, True), - ('one_row_complex.float', 'FLOAT_TYPE', None, None, None, None, True), - ('one_row_complex.double', 'DOUBLE_TYPE', None, None, None, None, True), - ('one_row_complex.string', 'STRING_TYPE', None, None, None, None, True), - ('one_row_complex.timestamp', 'TIMESTAMP_TYPE', None, None, None, None, True), - ('one_row_complex.binary', 'BINARY_TYPE', None, None, None, None, True), - ('one_row_complex.array', 'ARRAY_TYPE', None, None, None, None, True), - ('one_row_complex.map', 'MAP_TYPE', None, None, None, None, True), - ('one_row_complex.struct', 'STRUCT_TYPE', None, None, None, None, True), - ('one_row_complex.union', 'UNION_TYPE', None, None, None, None, True), - ('one_row_complex.decimal', 'DECIMAL_TYPE', None, None, None, None, True), - ]) + cursor.execute("SELECT * FROM one_row_complex") + self.assertEqual( + cursor.description, + [ + ( + "one_row_complex.boolean", + "BOOLEAN_TYPE", + None, + None, + None, + None, + True, + ), + ( + "one_row_complex.tinyint", + "TINYINT_TYPE", + None, + None, + None, + None, + True, + ), + ( + "one_row_complex.smallint", + "SMALLINT_TYPE", + None, + None, + None, + None, + True, + ), + ("one_row_complex.int", "INT_TYPE", None, None, None, None, True), + ("one_row_complex.bigint", "BIGINT_TYPE", None, None, None, None, True), + ("one_row_complex.float", "FLOAT_TYPE", None, None, None, None, True), + ("one_row_complex.double", "DOUBLE_TYPE", None, None, None, None, True), + ("one_row_complex.string", "STRING_TYPE", None, None, None, None, True), + ( + "one_row_complex.timestamp", + "TIMESTAMP_TYPE", + None, + None, + None, + None, + True, + ), + ("one_row_complex.binary", "BINARY_TYPE", None, None, None, None, True), + ("one_row_complex.array", "ARRAY_TYPE", None, None, None, None, True), + ("one_row_complex.map", "MAP_TYPE", None, None, None, None, True), + ("one_row_complex.struct", "STRUCT_TYPE", None, None, None, None, True), + ("one_row_complex.union", "UNION_TYPE", None, None, None, None, True), + ( + "one_row_complex.decimal", + "DECIMAL_TYPE", + None, + None, + None, + None, + True, + ), + ], + ) rows = cursor.fetchall() - expected = [( - True, - 127, - 32767, - 2147483647, - 9223372036854775807, - 0.5, - 0.25, - 'a string', - datetime.datetime(1970, 1, 1, 0, 0), - b'123', - '[1,2]', - '{1:2,3:4}', - '{"a":1,"b":2}', - '{0:1}', - Decimal('0.1'), - )] + expected = [ + ( + True, + 127, + 32767, + 2147483647, + 9223372036854775807, + 0.5, + 0.25, + "a string", + datetime.datetime(1970, 1, 1, 0, 0), + b"123", + "[1,2]", + "{1:2,3:4}", + '{"a":1,"b":2}', + "{0:1}", + Decimal("0.1"), + ) + ] self.assertEqual(rows, expected) # catch unicode/str self.assertEqual(list(map(type, rows[0])), list(map(type, expected[0]))) @with_cursor def test_async(self, cursor): - cursor.execute('SELECT * FROM one_row', async_=True) + cursor.execute("SELECT * FROM one_row", async_=True) unfinished_states = ( ttypes.TOperationState.INITIALIZED_STATE, ttypes.TOperationState.RUNNING_STATE, @@ -105,12 +148,16 @@ def test_cancel(self, cursor): cursor.execute( "SELECT reflect('java.lang.Thread', 'sleep', 1000L * 1000L * 1000L) " "FROM one_row a JOIN one_row b", - async_=True + async_=True, + ) + self.assertEqual( + cursor.poll().operationState, ttypes.TOperationState.RUNNING_STATE ) - self.assertEqual(cursor.poll().operationState, ttypes.TOperationState.RUNNING_STATE) - assert any('Stage' in line for line in cursor.fetch_logs()) + assert any("Stage" in line for line in cursor.fetch_logs()) cursor.cancel() - self.assertEqual(cursor.poll().operationState, ttypes.TOperationState.CANCELED_STATE) + self.assertEqual( + cursor.poll().operationState, ttypes.TOperationState.CANCELED_STATE + ) def test_noops(self): """The DB-API specification requires that certain actions exist, even though they might not @@ -120,124 +167,156 @@ def test_noops(self): with contextlib.closing(connection.cursor()) as cursor: self.assertEqual(cursor.rowcount, -1) cursor.setinputsizes([]) - cursor.setoutputsize(1, 'blah') + cursor.setoutputsize(1, "blah") connection.commit() - @mock.patch('TCLIService.TCLIService.Client.OpenSession') + @mock.patch("TCLIService.TCLIService.Client.OpenSession") def test_open_failed(self, open_session): - open_session.return_value.serverProtocolVersion = \ + open_session.return_value.serverProtocolVersion = ( ttypes.TProtocolVersion.HIVE_CLI_SERVICE_PROTOCOL_V1 + ) self.assertRaises(hive.OperationalError, self.connect) def test_escape(self): # Hive thrift translates newlines into multiple rows. WTF. - bad_str = '''`~!@#$%^&*()_+-={}[]|\\;:'",./<>?\t ''' + bad_str = """`~!@#$%^&*()_+-={}[]|\\;:'",./<>?\t """ self.run_escape_case(bad_str) def test_newlines(self): """Verify that newlines are passed through correctly""" cursor = self.connect().cursor() - orig = ' \r\n \r \n ' - cursor.execute( - 'SELECT %s FROM one_row', - (orig,) - ) + orig = " \r\n \r \n " + cursor.execute("SELECT %s FROM one_row", (orig,)) result = cursor.fetchall() self.assertEqual(result, [(orig,)]) @with_cursor def test_no_result_set(self, cursor): - cursor.execute('USE default') + cursor.execute("USE default") self.assertIsNone(cursor.description) self.assertRaises(hive.ProgrammingError, cursor.fetchone) def test_ldap_connection(self): rootdir = os.path.dirname(os.path.dirname(os.path.dirname(__file__))) - orig_ldap = os.path.join(rootdir, 'scripts', 'travis-conf', 'hive', 'hive-site-ldap.xml') - orig_none = os.path.join(rootdir, 'scripts', 'travis-conf', 'hive', 'hive-site.xml') - des = os.path.join('/', 'etc', 'hive', 'conf', 'hive-site.xml') + orig_ldap = os.path.join( + rootdir, "scripts", "travis-conf", "hive", "hive-site-ldap.xml" + ) + orig_none = os.path.join( + rootdir, "scripts", "travis-conf", "hive", "hive-site.xml" + ) + des = os.path.join("/", "etc", "hive", "conf", "hive-site.xml") try: - subprocess.check_call(['sudo', 'cp', orig_ldap, des]) + subprocess.check_call(["sudo", "cp", orig_ldap, des]) _restart_hs2() - with contextlib.closing(hive.connect( - host=_HOST, username='existing', auth='LDAP', password='testpw') + with contextlib.closing( + hive.connect( + host=_HOST, username="existing", auth="LDAP", password="testpw" + ) ) as connection: with contextlib.closing(connection.cursor()) as cursor: - cursor.execute('SELECT * FROM one_row') + cursor.execute("SELECT * FROM one_row") self.assertEqual(cursor.fetchall(), [(1,)]) self.assertRaisesRegexp( - TTransportException, 'Error validating the login', + TTransportException, + "Error validating the login", lambda: hive.connect( - host=_HOST, username='existing', auth='LDAP', password='wrong') + host=_HOST, username="existing", auth="LDAP", password="wrong" + ), ) finally: - subprocess.check_call(['sudo', 'cp', orig_none, des]) + subprocess.check_call(["sudo", "cp", orig_none, des]) _restart_hs2() def test_invalid_ldap_config(self): """password should be set if and only if using LDAP""" - self.assertRaisesRegexp(ValueError, 'Password.*LDAP', - lambda: hive.connect(_HOST, password='')) - self.assertRaisesRegexp(ValueError, 'Password.*LDAP', - lambda: hive.connect(_HOST, auth='LDAP')) + self.assertRaisesRegexp( + ValueError, "Password.*LDAP", lambda: hive.connect(_HOST, password="") + ) + self.assertRaisesRegexp( + ValueError, "Password.*LDAP", lambda: hive.connect(_HOST, auth="LDAP") + ) def test_invalid_kerberos_config(self): """kerberos_service_name should be set if and only if using KERBEROS""" - self.assertRaisesRegexp(ValueError, 'kerberos_service_name.*KERBEROS', - lambda: hive.connect(_HOST, kerberos_service_name='')) - self.assertRaisesRegexp(ValueError, 'kerberos_service_name.*KERBEROS', - lambda: hive.connect(_HOST, auth='KERBEROS')) + self.assertRaisesRegexp( + ValueError, + "kerberos_service_name.*KERBEROS", + lambda: hive.connect(_HOST, kerberos_service_name=""), + ) + self.assertRaisesRegexp( + ValueError, + "kerberos_service_name.*KERBEROS", + lambda: hive.connect(_HOST, auth="KERBEROS"), + ) def test_invalid_transport(self): """transport and auth are incompatible""" - socket = thrift.transport.TSocket.TSocket('localhost', 10000) + socket = thrift.transport.TSocket.TSocket("localhost", 10000) transport = thrift.transport.TTransport.TBufferedTransport(socket) self.assertRaisesRegexp( - ValueError, 'thrift_transport cannot be used with', - lambda: hive.connect(_HOST, thrift_transport=transport) + ValueError, + "thrift_transport cannot be used with", + lambda: hive.connect(_HOST, thrift_transport=transport), ) def test_custom_transport(self): - socket = thrift.transport.TSocket.TSocket('localhost', 10000) - sasl_auth = 'PLAIN' - - transport = thrift_sasl.TSaslClientTransport(lambda: hive.get_installed_sasl(host='localhost', sasl_auth=sasl_auth, username='test_username', password='x'), sasl_auth, socket) + socket = thrift.transport.TSocket.TSocket("localhost", 10000) + sasl_auth = "PLAIN" + + transport = thrift_sasl.TSaslClientTransport( + lambda: hive.get_installed_sasl( + host="localhost", + sasl_auth=sasl_auth, + username="test_username", + password="x", + ), + sasl_auth, + socket, + ) conn = hive.connect(thrift_transport=transport) with contextlib.closing(conn): with contextlib.closing(conn.cursor()) as cursor: - cursor.execute('SELECT * FROM one_row') + cursor.execute("SELECT * FROM one_row") self.assertEqual(cursor.fetchall(), [(1,)]) def test_custom_connection(self): rootdir = os.path.dirname(os.path.dirname(os.path.dirname(__file__))) - orig_ldap = os.path.join(rootdir, 'scripts', 'travis-conf', 'hive', 'hive-site-custom.xml') - orig_none = os.path.join(rootdir, 'scripts', 'travis-conf', 'hive', 'hive-site.xml') - des = os.path.join('/', 'etc', 'hive', 'conf', 'hive-site.xml') + orig_ldap = os.path.join( + rootdir, "scripts", "travis-conf", "hive", "hive-site-custom.xml" + ) + orig_none = os.path.join( + rootdir, "scripts", "travis-conf", "hive", "hive-site.xml" + ) + des = os.path.join("/", "etc", "hive", "conf", "hive-site.xml") try: - subprocess.check_call(['sudo', 'cp', orig_ldap, des]) + subprocess.check_call(["sudo", "cp", orig_ldap, des]) _restart_hs2() - with contextlib.closing(hive.connect( - host=_HOST, username='the-user', auth='CUSTOM', password='p4ssw0rd') + with contextlib.closing( + hive.connect( + host=_HOST, username="the-user", auth="CUSTOM", password="p4ssw0rd" + ) ) as connection: with contextlib.closing(connection.cursor()) as cursor: - cursor.execute('SELECT * FROM one_row') + cursor.execute("SELECT * FROM one_row") self.assertEqual(cursor.fetchall(), [(1,)]) self.assertRaisesRegexp( - TTransportException, 'Error validating the login', + TTransportException, + "Error validating the login", lambda: hive.connect( - host=_HOST, username='the-user', auth='CUSTOM', password='wrong') + host=_HOST, username="the-user", auth="CUSTOM", password="wrong" + ), ) finally: - subprocess.check_call(['sudo', 'cp', orig_none, des]) + subprocess.check_call(["sudo", "cp", orig_none, des]) _restart_hs2() def _restart_hs2(): - subprocess.check_call(['sudo', 'service', 'hive-server2', 'restart']) + subprocess.check_call(["sudo", "service", "hive-server2", "restart"]) with contextlib.closing(socket.socket()) as s: - while s.connect_ex(('localhost', 10000)) != 0: + while s.connect_ex(("localhost", 10000)) != 0: time.sleep(1) diff --git a/python/pyhive/tests/test_presto.py b/python/pyhive/tests/test_presto.py index 187b1c2140f..dea1ca4c8c7 100644 --- a/python/pyhive/tests/test_presto.py +++ b/python/pyhive/tests/test_presto.py @@ -4,25 +4,22 @@ They also require a tables created by make_test_tables.sh. """ -from __future__ import absolute_import -from __future__ import unicode_literals +from __future__ import absolute_import, unicode_literals import contextlib +import datetime import os +import unittest from decimal import Decimal +import mock import requests -from pyhive import exc -from pyhive import presto -from pyhive.tests.dbapi_test_case import DBAPITestCase -from pyhive.tests.dbapi_test_case import with_cursor -import mock -import unittest -import datetime +from pyhive import exc, presto +from pyhive.tests.dbapi_test_case import DBAPITestCase, with_cursor -_HOST = 'localhost' -_PORT = '8080' +_HOST = "localhost" +_PORT = "8080" class TestPresto(unittest.TestCase, DBAPITestCase): @@ -32,71 +29,87 @@ def connect(self): return presto.connect(host=_HOST, port=_PORT, source=self.id()) def test_bad_protocol(self): - self.assertRaisesRegexp(ValueError, 'Protocol must be', - lambda: presto.connect('localhost', protocol='nonsense').cursor()) + self.assertRaisesRegexp( + ValueError, + "Protocol must be", + lambda: presto.connect("localhost", protocol="nonsense").cursor(), + ) def test_escape_args(self): escaper = presto.PrestoParamEscaper() - self.assertEqual(escaper.escape_args((datetime.date(2020, 4, 17),)), - ("date '2020-04-17'",)) - self.assertEqual(escaper.escape_args((datetime.datetime(2020, 4, 17, 12, 0, 0, 123456),)), - ("timestamp '2020-04-17 12:00:00.123'",)) + self.assertEqual( + escaper.escape_args((datetime.date(2020, 4, 17),)), ("date '2020-04-17'",) + ) + self.assertEqual( + escaper.escape_args((datetime.datetime(2020, 4, 17, 12, 0, 0, 123456),)), + ("timestamp '2020-04-17 12:00:00.123'",), + ) @with_cursor def test_description(self, cursor): - cursor.execute('SELECT 1 AS foobar FROM one_row') - self.assertEqual(cursor.description, [('foobar', 'integer', None, None, None, None, True)]) + cursor.execute("SELECT 1 AS foobar FROM one_row") + self.assertEqual( + cursor.description, [("foobar", "integer", None, None, None, None, True)] + ) self.assertIsNotNone(cursor.last_query_id) @with_cursor def test_complex(self, cursor): - cursor.execute('SELECT * FROM one_row_complex') + cursor.execute("SELECT * FROM one_row_complex") # TODO Presto drops the union field - if os.environ.get('PRESTO') == '0.147': - tinyint_type = 'integer' - smallint_type = 'integer' - float_type = 'double' + if os.environ.get("PRESTO") == "0.147": + tinyint_type = "integer" + smallint_type = "integer" + float_type = "double" else: # some later version made these map to more specific types - tinyint_type = 'tinyint' - smallint_type = 'smallint' - float_type = 'real' - self.assertEqual(cursor.description, [ - ('boolean', 'boolean', None, None, None, None, True), - ('tinyint', tinyint_type, None, None, None, None, True), - ('smallint', smallint_type, None, None, None, None, True), - ('int', 'integer', None, None, None, None, True), - ('bigint', 'bigint', None, None, None, None, True), - ('float', float_type, None, None, None, None, True), - ('double', 'double', None, None, None, None, True), - ('string', 'varchar', None, None, None, None, True), - ('timestamp', 'timestamp', None, None, None, None, True), - ('binary', 'varbinary', None, None, None, None, True), - ('array', 'array(integer)', None, None, None, None, True), - ('map', 'map(integer,integer)', None, None, None, None, True), - ('struct', 'row(a integer,b integer)', None, None, None, None, True), - # ('union', 'varchar', None, None, None, None, True), - ('decimal', 'decimal(10,1)', None, None, None, None, True), - ]) + tinyint_type = "tinyint" + smallint_type = "smallint" + float_type = "real" + self.assertEqual( + cursor.description, + [ + ("boolean", "boolean", None, None, None, None, True), + ("tinyint", tinyint_type, None, None, None, None, True), + ("smallint", smallint_type, None, None, None, None, True), + ("int", "integer", None, None, None, None, True), + ("bigint", "bigint", None, None, None, None, True), + ("float", float_type, None, None, None, None, True), + ("double", "double", None, None, None, None, True), + ("string", "varchar", None, None, None, None, True), + ("timestamp", "timestamp", None, None, None, None, True), + ("binary", "varbinary", None, None, None, None, True), + ("array", "array(integer)", None, None, None, None, True), + ("map", "map(integer,integer)", None, None, None, None, True), + ("struct", "row(a integer,b integer)", None, None, None, None, True), + # ('union', 'varchar', None, None, None, None, True), + ("decimal", "decimal(10,1)", None, None, None, None, True), + ], + ) rows = cursor.fetchall() - expected = [( - True, - 127, - 32767, - 2147483647, - 9223372036854775807, - 0.5, - 0.25, - 'a string', - '1970-01-01 00:00:00.000', - b'123', - [1, 2], - {"1": 2, "3": 4}, # Presto converts all keys to strings so that they're valid JSON - [1, 2], # struct is returned as a list of elements - # '{0:1}', - Decimal('0.1'), - )] + expected = [ + ( + True, + 127, + 32767, + 2147483647, + 9223372036854775807, + 0.5, + 0.25, + "a string", + "1970-01-01 00:00:00.000", + b"123", + [1, 2], + { + "1": 2, + "3": 4, + }, # Presto converts all keys to strings so that they're valid JSON + [1, 2], # struct is returned as a list of elements + # '{0:1}', + Decimal("0.1"), + ) + ] self.assertEqual(rows, expected) # catch unicode/str self.assertEqual(list(map(type, rows[0])), list(map(type, expected[0]))) @@ -108,8 +121,10 @@ def test_cancel(self, cursor): "FROM many_rows a " "CROSS JOIN many_rows b " ) - self.assertIn(cursor.poll()['stats']['state'], ( - 'STARTING', 'PLANNING', 'RUNNING', 'WAITING_FOR_RESOURCES', 'QUEUED')) + self.assertIn( + cursor.poll()["stats"]["state"], + ("STARTING", "PLANNING", "RUNNING", "WAITING_FOR_RESOURCES", "QUEUED"), + ) cursor.cancel() self.assertIsNotNone(cursor.last_query_id) self.assertIsNone(cursor.poll()) @@ -122,31 +137,33 @@ def test_noops(self): cursor = connection.cursor() self.assertEqual(cursor.rowcount, -1) cursor.setinputsizes([]) - cursor.setoutputsize(1, 'blah') + cursor.setoutputsize(1, "blah") self.assertIsNone(cursor.last_query_id) connection.commit() - @mock.patch('requests.post') + @mock.patch("requests.post") def test_non_200(self, post): cursor = self.connect().cursor() post.return_value.status_code = 404 - self.assertRaises(exc.OperationalError, lambda: cursor.execute('show tables')) + self.assertRaises(exc.OperationalError, lambda: cursor.execute("show tables")) @with_cursor def test_poll(self, cursor): self.assertRaises(presto.ProgrammingError, cursor.poll) - cursor.execute('SELECT * FROM one_row') + cursor.execute("SELECT * FROM one_row") while True: status = cursor.poll() if status is None: break - self.assertIn('stats', status) + self.assertIn("stats", status) def fail(*args, **kwargs): - self.fail("Should not need requests.get after done polling") # pragma: no cover + self.fail( + "Should not need requests.get after done polling" + ) # pragma: no cover - with mock.patch('requests.get', fail): + with mock.patch("requests.get", fail): self.assertEqual(cursor.fetchall(), [(1,)]) @with_cursor @@ -159,86 +176,100 @@ def test_set_session(self, cursor): cursor.fetchall() self.assertEqual(id, cursor.last_query_id) - cursor.execute('SHOW SESSION') + cursor.execute("SHOW SESSION") self.assertIsNotNone(cursor.last_query_id) self.assertNotEqual(id, cursor.last_query_id) id = cursor.last_query_id - rows = [r for r in cursor.fetchall() if r[0] == 'query_max_run_time'] + rows = [r for r in cursor.fetchall() if r[0] == "query_max_run_time"] self.assertEqual(len(rows), 1) session_prop = rows[0] - self.assertEqual(session_prop[1], '1234m') + self.assertEqual(session_prop[1], "1234m") self.assertEqual(id, cursor.last_query_id) - cursor.execute('RESET SESSION query_max_run_time') + cursor.execute("RESET SESSION query_max_run_time") self.assertIsNotNone(cursor.last_query_id) self.assertNotEqual(id, cursor.last_query_id) id = cursor.last_query_id cursor.fetchall() self.assertEqual(id, cursor.last_query_id) - cursor.execute('SHOW SESSION') + cursor.execute("SHOW SESSION") self.assertIsNotNone(cursor.last_query_id) self.assertNotEqual(id, cursor.last_query_id) id = cursor.last_query_id - rows = [r for r in cursor.fetchall() if r[0] == 'query_max_run_time'] + rows = [r for r in cursor.fetchall() if r[0] == "query_max_run_time"] self.assertEqual(len(rows), 1) session_prop = rows[0] - self.assertNotEqual(session_prop[1], '1234m') + self.assertNotEqual(session_prop[1], "1234m") self.assertEqual(id, cursor.last_query_id) def test_set_session_in_constructor(self): conn = presto.connect( - host=_HOST, source=self.id(), session_props={'query_max_run_time': '1234m'} + host=_HOST, source=self.id(), session_props={"query_max_run_time": "1234m"} ) with contextlib.closing(conn): with contextlib.closing(conn.cursor()) as cursor: - cursor.execute('SHOW SESSION') - rows = [r for r in cursor.fetchall() if r[0] == 'query_max_run_time'] + cursor.execute("SHOW SESSION") + rows = [r for r in cursor.fetchall() if r[0] == "query_max_run_time"] assert len(rows) == 1 session_prop = rows[0] - assert session_prop[1] == '1234m' + assert session_prop[1] == "1234m" - cursor.execute('RESET SESSION query_max_run_time') + cursor.execute("RESET SESSION query_max_run_time") cursor.fetchall() - cursor.execute('SHOW SESSION') - rows = [r for r in cursor.fetchall() if r[0] == 'query_max_run_time'] + cursor.execute("SHOW SESSION") + rows = [r for r in cursor.fetchall() if r[0] == "query_max_run_time"] assert len(rows) == 1 session_prop = rows[0] - assert session_prop[1] != '1234m' + assert session_prop[1] != "1234m" def test_invalid_protocol_config(self): """protocol should be https when passing password""" self.assertRaisesRegexp( - ValueError, 'Protocol.*https.*password', lambda: presto.connect( - host=_HOST, username='user', password='secret', protocol='http').cursor() + ValueError, + "Protocol.*https.*password", + lambda: presto.connect( + host=_HOST, username="user", password="secret", protocol="http" + ).cursor(), ) def test_invalid_password_and_kwargs(self): """password and requests_kwargs authentication are incompatible""" self.assertRaisesRegexp( - ValueError, 'Cannot use both', lambda: presto.connect( - host=_HOST, username='user', password='secret', protocol='https', - requests_kwargs={'auth': requests.auth.HTTPBasicAuth('user', 'secret')} - ).cursor() + ValueError, + "Cannot use both", + lambda: presto.connect( + host=_HOST, + username="user", + password="secret", + protocol="https", + requests_kwargs={"auth": requests.auth.HTTPBasicAuth("user", "secret")}, + ).cursor(), ) def test_invalid_kwargs(self): """some kwargs are reserved""" self.assertRaisesRegexp( - ValueError, 'Cannot override', lambda: presto.connect( - host=_HOST, username='user', requests_kwargs={'url': 'test'} - ).cursor() + ValueError, + "Cannot override", + lambda: presto.connect( + host=_HOST, username="user", requests_kwargs={"url": "test"} + ).cursor(), ) def test_requests_kwargs(self): connection = presto.connect( - host=_HOST, port=_PORT, source=self.id(), - requests_kwargs={'proxies': {'http': 'localhost:9999'}}, + host=_HOST, + port=_PORT, + source=self.id(), + requests_kwargs={"proxies": {"http": "localhost:9999"}}, ) cursor = connection.cursor() - self.assertRaises(requests.exceptions.ProxyError, - lambda: cursor.execute('SELECT * FROM one_row')) + self.assertRaises( + requests.exceptions.ProxyError, + lambda: cursor.execute("SELECT * FROM one_row"), + ) def test_requests_session(self): with requests.Session() as session: @@ -246,5 +277,5 @@ def test_requests_session(self): host=_HOST, port=_PORT, source=self.id(), requests_session=session ) cursor = connection.cursor() - cursor.execute('SELECT * FROM one_row') + cursor.execute("SELECT * FROM one_row") self.assertEqual(cursor.fetchall(), [(1,)]) diff --git a/python/pyhive/tests/test_sasl_compat.py b/python/pyhive/tests/test_sasl_compat.py index 55e53fe91e5..731015d6a64 100644 --- a/python/pyhive/tests/test_sasl_compat.py +++ b/python/pyhive/tests/test_sasl_compat.py @@ -1,4 +1,4 @@ -''' +""" http://www.opensource.org/licenses/mit-license.php Copyright 2007-2011 David Alan Cridland @@ -20,20 +20,21 @@ FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. -''' -# This file was generated by referring test cases from the pure-sasl repo i.e. https://github.com/thobbs/pure-sasl/tree/master/tests/unit +""" +# This file was generated by referring test cases from the pure-sasl repo i.e. https://github.com/thobbs/pure-sasl/tree/master/tests/unit # and by refactoring them to cover wrapper functions in sasl_compat.py along with added coverage for functions exclusive to sasl_compat.py. -import unittest import base64 import hashlib import hmac +import struct +import unittest + import kerberos -from mock import patch import six -import struct -from puresasl import SASLProtocolException, QOP -from puresasl.client import SASLError +from mock import patch +from puresasl import QOP + from pyhive.sasl_compat import PureSASLClient, error_catcher @@ -42,7 +43,7 @@ class TestPureSASLClient(unittest.TestCase): def setUp(self): self.sasl_kwargs = {} - self.sasl = PureSASLClient('localhost', **self.sasl_kwargs) + self.sasl = PureSASLClient("localhost", **self.sasl_kwargs) def test_start_no_mechanism(self): """Test starting SASL authentication with no mechanism.""" @@ -50,31 +51,40 @@ def test_start_no_mechanism(self): self.assertFalse(success) self.assertIsNone(mechanism) self.assertIsNone(response) - self.assertEqual(self.sasl.getError(), 'None of the mechanisms listed meet all required properties') + self.assertEqual( + self.sasl.getError(), + "None of the mechanisms listed meet all required properties", + ) def test_start_wrong_mechanism(self): """Test starting SASL authentication with a single unsupported mechanism.""" - success, mechanism, response = self.sasl.start(mechanism='WRONG') + success, mechanism, response = self.sasl.start(mechanism="WRONG") self.assertFalse(success) - self.assertEqual(mechanism, 'WRONG') + self.assertEqual(mechanism, "WRONG") self.assertIsNone(response) - self.assertEqual(self.sasl.getError(), 'None of the mechanisms listed meet all required properties') + self.assertEqual( + self.sasl.getError(), + "None of the mechanisms listed meet all required properties", + ) def test_start_list_of_invalid_mechanisms(self): """Test starting SASL authentication with a list of unsupported mechanisms.""" - self.sasl.start(['invalid1', 'invalid2']) - self.assertEqual(self.sasl.getError(), 'None of the mechanisms listed meet all required properties') + self.sasl.start(["invalid1", "invalid2"]) + self.assertEqual( + self.sasl.getError(), + "None of the mechanisms listed meet all required properties", + ) def test_start_list_of_valid_mechanisms(self): """Test starting SASL authentication with a list of supported mechanisms.""" - self.sasl.start(['PLAIN', 'DIGEST-MD5', 'CRAM-MD5']) + self.sasl.start(["PLAIN", "DIGEST-MD5", "CRAM-MD5"]) # Validate right mechanism is chosen based on score. - self.assertEqual(self.sasl._chosen_mech.name, 'DIGEST-MD5') + self.assertEqual(self.sasl._chosen_mech.name, "DIGEST-MD5") def test_error_catcher_no_error(self): """Test the error_catcher with no error.""" with error_catcher(self.sasl): - result, _, _ = self.sasl.start(mechanism='ANONYMOUS') + result, _, _ = self.sasl.start(mechanism="ANONYMOUS") self.assertEqual(self.sasl.getError(), None) self.assertEqual(result, True) @@ -82,87 +92,98 @@ def test_error_catcher_no_error(self): def test_error_catcher_with_error(self): """Test the error_catcher with an error.""" with error_catcher(self.sasl): - result, _, _ = self.sasl.start(mechanism='WRONG') + result, _, _ = self.sasl.start(mechanism="WRONG") self.assertEqual(result, False) - self.assertEqual(self.sasl.getError(), 'None of the mechanisms listed meet all required properties') + self.assertEqual( + self.sasl.getError(), + "None of the mechanisms listed meet all required properties", + ) + """Assuming Client initilization went well and a mechanism is chosen, Below are the test cases for different mechanims""" + class _BaseMechanismTests(unittest.TestCase): """Base test case for SASL mechanisms.""" - mechanism = 'ANONYMOUS' + mechanism = "ANONYMOUS" sasl_kwargs = {} def setUp(self): - self.sasl = PureSASLClient('localhost', mechanism=self.mechanism, **self.sasl_kwargs) + self.sasl = PureSASLClient( + "localhost", mechanism=self.mechanism, **self.sasl_kwargs + ) self.mechanism_class = self.sasl._chosen_mech def test_init_basic(self, *args): - sasl = PureSASLClient('localhost', mechanism=self.mechanism, **self.sasl_kwargs) + sasl = PureSASLClient("localhost", mechanism=self.mechanism, **self.sasl_kwargs) mech = sasl._chosen_mech self.assertIs(mech.sasl, sasl) def test_step_basic(self, *args): - success, response = self.sasl.step(six.b('string')) + success, response = self.sasl.step(six.b("string")) self.assertTrue(success) self.assertIsInstance(response, six.binary_type) def test_decode_encode(self, *args): - self.assertEqual(self.sasl.encode('msg'), (False, None)) - self.assertEqual(self.sasl.getError(), '') - self.assertEqual(self.sasl.decode('msg'), (False, None)) - self.assertEqual(self.sasl.getError(), '') + self.assertEqual(self.sasl.encode("msg"), (False, None)) + self.assertEqual(self.sasl.getError(), "") + self.assertEqual(self.sasl.decode("msg"), (False, None)) + self.assertEqual(self.sasl.getError(), "") class AnonymousMechanismTest(_BaseMechanismTests): """Test case for the Anonymous SASL mechanism.""" - mechanism = 'ANONYMOUS' + mechanism = "ANONYMOUS" class PlainTextMechanismTest(_BaseMechanismTests): """Test case for the PlainText SASL mechanism.""" - mechanism = 'PLAIN' - username = 'user' - password = 'pass' - sasl_kwargs = {'username': username, 'password': password} + mechanism = "PLAIN" + username = "user" + password = "pass" + sasl_kwargs = {"username": username, "password": password} def test_step(self): - for challenge in (None, '', b'asdf', u"\U0001F44D"): + for challenge in (None, "", b"asdf", "\U0001f44d"): success, response = self.sasl.step(challenge) self.assertTrue(success) - self.assertEqual(response, six.b(f'\x00{self.username}\x00{self.password}')) + self.assertEqual(response, six.b(f"\x00{self.username}\x00{self.password}")) self.assertIsInstance(response, six.binary_type) def test_step_with_authorization_id_or_identity(self): - challenge = u"\U0001F44D" - identity = 'user2' + challenge = "\U0001f44d" + identity = "user2" # Test that we can pass an identity sasl_kwargs = self.sasl_kwargs.copy() - sasl_kwargs.update({'identity': identity}) - sasl = PureSASLClient('localhost', mechanism=self.mechanism, **sasl_kwargs) + sasl_kwargs.update({"identity": identity}) + sasl = PureSASLClient("localhost", mechanism=self.mechanism, **sasl_kwargs) success, response = sasl.step(challenge) self.assertTrue(success) - self.assertEqual(response, six.b(f'{identity}\x00{self.username}\x00{self.password}')) + self.assertEqual( + response, six.b(f"{identity}\x00{self.username}\x00{self.password}") + ) self.assertIsInstance(response, six.binary_type) self.assertTrue(sasl.complete) # Test that the sasl authorization_id has priority over identity - auth_id = 'user3' - sasl_kwargs.update({'authorization_id': auth_id}) - sasl = PureSASLClient('localhost', mechanism=self.mechanism, **sasl_kwargs) + auth_id = "user3" + sasl_kwargs.update({"authorization_id": auth_id}) + sasl = PureSASLClient("localhost", mechanism=self.mechanism, **sasl_kwargs) success, response = sasl.step(challenge) self.assertTrue(success) - self.assertEqual(response, six.b(f'{auth_id}\x00{self.username}\x00{self.password}')) + self.assertEqual( + response, six.b(f"{auth_id}\x00{self.username}\x00{self.password}") + ) self.assertIsInstance(response, six.binary_type) self.assertTrue(sasl.complete) def test_decode_encode(self): - msg = 'msg' + msg = "msg" self.assertEqual(self.sasl.decode(msg), (True, msg)) self.assertEqual(self.sasl.encode(msg), (True, msg)) @@ -170,76 +191,109 @@ def test_decode_encode(self): class ExternalMechanismTest(_BaseMechanismTests): """Test case for the External SASL mechanisms""" - mechanism = 'EXTERNAL' + mechanism = "EXTERNAL" def test_step(self): - self.assertEqual(self.sasl.step(), (True, b'')) + self.assertEqual(self.sasl.step(), (True, b"")) def test_decode_encode(self): - msg = 'msg' + msg = "msg" self.assertEqual(self.sasl.decode(msg), (True, msg)) self.assertEqual(self.sasl.encode(msg), (True, msg)) -@patch('puresasl.mechanisms.kerberos.authGSSClientStep') -@patch('puresasl.mechanisms.kerberos.authGSSClientResponse', return_value=base64.b64encode(six.b('some\x00 response'))) +@patch("puresasl.mechanisms.kerberos.authGSSClientStep") +@patch( + "puresasl.mechanisms.kerberos.authGSSClientResponse", + return_value=base64.b64encode(six.b("some\x00 response")), +) class GSSAPIMechanismTest(_BaseMechanismTests): """Test case for the GSSAPI SASL mechanism.""" - mechanism = 'GSSAPI' - service = 'GSSAPI' - sasl_kwargs = {'service': service} + mechanism = "GSSAPI" + service = "GSSAPI" + sasl_kwargs = {"service": service} - @patch('puresasl.mechanisms.kerberos.authGSSClientWrap') - @patch('puresasl.mechanisms.kerberos.authGSSClientUnwrap') + @patch("puresasl.mechanisms.kerberos.authGSSClientWrap") + @patch("puresasl.mechanisms.kerberos.authGSSClientUnwrap") def test_decode_encode(self, _inner1, _inner2, authGSSClientResponse, *args): # bypassing step setup by setting qop directly self.mechanism_class.qop = QOP.AUTH - msg = b'msg' + msg = b"msg" self.assertEqual(self.sasl.decode(msg), (True, msg)) self.assertEqual(self.sasl.encode(msg), (True, msg)) - # Test for behavior with different QOP like data integrity and confidentiality for Kerberos authentication + # Test for behavior with different QOP like data integrity and confidentiality for Kerberos authentication for qop in (QOP.AUTH_INT, QOP.AUTH_CONF): self.mechanism_class.qop = qop - with patch('puresasl.mechanisms.kerberos.authGSSClientResponseConf', return_value=1): - self.assertEqual(self.sasl.decode(msg), (True, base64.b64decode(authGSSClientResponse.return_value))) - self.assertEqual(self.sasl.encode(msg), (True, base64.b64decode(authGSSClientResponse.return_value))) + with patch( + "puresasl.mechanisms.kerberos.authGSSClientResponseConf", return_value=1 + ): + self.assertEqual( + self.sasl.decode(msg), + (True, base64.b64decode(authGSSClientResponse.return_value)), + ) + self.assertEqual( + self.sasl.encode(msg), + (True, base64.b64decode(authGSSClientResponse.return_value)), + ) if qop == QOP.AUTH_CONF: - with patch('puresasl.mechanisms.kerberos.authGSSClientResponseConf', return_value=0): + with patch( + "puresasl.mechanisms.kerberos.authGSSClientResponseConf", + return_value=0, + ): self.assertEqual(self.sasl.encode(msg), (False, None)) - self.assertEqual(self.sasl.getError(), 'Error: confidentiality requested, but not honored by the server.') + self.assertEqual( + self.sasl.getError(), + "Error: confidentiality requested, but not honored by the server.", + ) def test_step_no_user(self, authGSSClientResponse, *args): - msg = six.b('whatever') + msg = six.b("whatever") # no user - self.assertEqual(self.sasl.step(msg), (True, base64.b64decode(authGSSClientResponse.return_value))) - with patch('puresasl.mechanisms.kerberos.authGSSClientResponse', return_value=''): - self.assertEqual(self.sasl.step(msg), (True, six.b(''))) + self.assertEqual( + self.sasl.step(msg), + (True, base64.b64decode(authGSSClientResponse.return_value)), + ) + with patch( + "puresasl.mechanisms.kerberos.authGSSClientResponse", return_value="" + ): + self.assertEqual(self.sasl.step(msg), (True, six.b(""))) - username = 'username' + username = "username" # with user; this has to be last because it sets mechanism.user - with patch('puresasl.mechanisms.kerberos.authGSSClientStep', return_value=kerberos.AUTH_GSS_COMPLETE): - with patch('puresasl.mechanisms.kerberos.authGSSClientUserName', return_value=six.b(username)): - self.assertEqual(self.sasl.step(msg), (True, six.b(''))) + with patch( + "puresasl.mechanisms.kerberos.authGSSClientStep", + return_value=kerberos.AUTH_GSS_COMPLETE, + ): + with patch( + "puresasl.mechanisms.kerberos.authGSSClientUserName", + return_value=six.b(username), + ): + self.assertEqual(self.sasl.step(msg), (True, six.b(""))) self.assertEqual(self.mechanism_class.user, six.b(username)) - @patch('puresasl.mechanisms.kerberos.authGSSClientUnwrap') + @patch("puresasl.mechanisms.kerberos.authGSSClientUnwrap") def test_step_qop(self, *args): self.mechanism_class._have_negotiated_details = True - self.mechanism_class.user = 'user' - msg = six.b('msg') + self.mechanism_class.user = "user" + msg = six.b("msg") self.assertEqual(self.sasl.step(msg), (False, None)) - self.assertEqual(self.sasl.getError(), 'Bad response from server') + self.assertEqual(self.sasl.getError(), "Bad response from server") max_len = 100 self.assertLess(max_len, self.sasl.max_buffer) for i, qop in QOP.bit_map.items(): - qop_size = struct.pack('!i', i << 24 | max_len) + qop_size = struct.pack("!i", i << 24 | max_len) response = base64.b64encode(qop_size) - with patch('puresasl.mechanisms.kerberos.authGSSClientResponse', return_value=response): - with patch('puresasl.mechanisms.kerberos.authGSSClientWrap') as authGSSClientWrap: + with patch( + "puresasl.mechanisms.kerberos.authGSSClientResponse", + return_value=response, + ): + with patch( + "puresasl.mechanisms.kerberos.authGSSClientWrap" + ) as authGSSClientWrap: self.mechanism_class.complete = False self.assertEqual(self.sasl.step(msg), (True, qop_size)) self.assertTrue(self.mechanism_class.complete) @@ -256,16 +310,16 @@ def test_step_qop(self, *args): class CramMD5MechanismTest(_BaseMechanismTests): """Test case for the CRAM-MD5 SASL mechanism.""" - mechanism = 'CRAM-MD5' - username = 'user' - password = 'pass' - sasl_kwargs = {'username': username, 'password': password} + mechanism = "CRAM-MD5" + username = "user" + password = "pass" + sasl_kwargs = {"username": username, "password": password} def test_step(self): success, response = self.sasl.step(None) self.assertTrue(success) self.assertIsNone(response) - challenge = six.b('msg') + challenge = six.b("msg") hash = hmac.HMAC(key=six.b(self.password), digestmod=hashlib.md5) hash.update(challenge) success, response = self.sasl.step(challenge) @@ -276,7 +330,7 @@ def test_step(self): self.assertTrue(self.sasl.complete) def test_decode_encode(self): - msg = 'msg' + msg = "msg" self.assertEqual(self.sasl.decode(msg), (True, msg)) self.assertEqual(self.sasl.encode(msg), (True, msg)) @@ -284,13 +338,13 @@ def test_decode_encode(self): class DigestMD5MechanismTest(_BaseMechanismTests): """Test case for the DIGEST-MD5 SASL mechanism.""" - mechanism = 'DIGEST-MD5' - username = 'user' - password = 'pass' - sasl_kwargs = {'username': username, 'password': password} + mechanism = "DIGEST-MD5" + username = "user" + password = "pass" + sasl_kwargs = {"username": username, "password": password} def test_decode_encode(self): - msg = 'msg' + msg = "msg" self.assertEqual(self.sasl.decode(msg), (True, msg)) self.assertEqual(self.sasl.encode(msg), (True, msg)) @@ -310,24 +364,24 @@ def test_step(self): def test_step_server_answer(self): """Test a SASL step with a proper server answer for DIGEST-MD5 mechanism.""" - sasl_kwargs = {'username': "chris", 'password': "secret"} - sasl = PureSASLClient('elwood.innosoft.com', - service="imap", - mechanism=self.mechanism, - mutual_auth=True, - **sasl_kwargs) + sasl_kwargs = {"username": "chris", "password": "secret"} + sasl = PureSASLClient( + "elwood.innosoft.com", + service="imap", + mechanism=self.mechanism, + mutual_auth=True, + **sasl_kwargs, + ) testChallenge = ( b'utf-8,username="chris",realm="elwood.innosoft.com",' b'nonce="OA6MG9tEQGm2hh",nc=00000001,cnonce="OA6MHXh6VqTrRk",' b'digest-uri="imap/elwood.innosoft.com",' - b'response=d388dad90d4bbd760a152321f2143af7,qop=auth' + b"response=d388dad90d4bbd760a152321f2143af7,qop=auth" ) sasl.step(testChallenge) sasl._chosen_mech.cnonce = b"OA6MHXh6VqTrRk" - serverResponse = ( - b'rspauth=ea40f60335c427b5527b84dbabcdfffd' - ) + serverResponse = b"rspauth=ea40f60335c427b5527b84dbabcdfffd" sasl.step(serverResponse) - # assert that step choses the only supported QOP for for DIGEST-MD5 + # assert that step choses the only supported QOP for for DIGEST-MD5 self.assertEqual(self.sasl.qop, QOP.AUTH) diff --git a/python/pyhive/tests/test_sqlalchemy_hive.py b/python/pyhive/tests/test_sqlalchemy_hive.py index 790bec4c3be..3be0dae694b 100644 --- a/python/pyhive/tests/test_sqlalchemy_hive.py +++ b/python/pyhive/tests/test_sqlalchemy_hive.py @@ -1,26 +1,25 @@ -from __future__ import absolute_import -from __future__ import unicode_literals -from builtins import str -from pyhive.sqlalchemy_hive import HiveDate -from pyhive.sqlalchemy_hive import HiveDecimal -from pyhive.sqlalchemy_hive import HiveTimestamp -from sqlalchemy.exc import NoSuchTableError, OperationalError -from pyhive.tests.sqlalchemy_test_case import SqlAlchemyTestCase -from pyhive.tests.sqlalchemy_test_case import with_engine_connection -from sqlalchemy import types -from sqlalchemy.engine import create_engine -from sqlalchemy.schema import Column -from sqlalchemy.schema import MetaData -from sqlalchemy.schema import Table -from sqlalchemy.sql import text +from __future__ import absolute_import, unicode_literals + import contextlib import datetime import decimal -import sqlalchemy.types -import unittest import re +import unittest +from builtins import str -sqlalchemy_version = float(re.search(r"^([\d]+\.[\d]+)\..+", sqlalchemy.__version__).group(1)) +import sqlalchemy.types +from sqlalchemy import types +from sqlalchemy.engine import create_engine +from sqlalchemy.exc import NoSuchTableError, OperationalError +from sqlalchemy.schema import Column, MetaData, Table +from sqlalchemy.sql import text + +from pyhive.sqlalchemy_hive import HiveDate, HiveDecimal, HiveTimestamp +from pyhive.tests.sqlalchemy_test_case import SqlAlchemyTestCase, with_engine_connection + +sqlalchemy_version = float( + re.search(r"^([\d]+\.[\d]+)\..+", sqlalchemy.__version__).group(1) +) _ONE_ROW_COMPLEX_CONTENTS = [ True, @@ -30,14 +29,14 @@ 9223372036854775807, 0.5, 0.25, - 'a string', + "a string", datetime.datetime(1970, 1, 1), - b'123', - '[1,2]', - '{1:2,3:4}', + b"123", + "[1,2]", + "{1:2,3:4}", '{"a":1,"b":2}', - '{0:1}', - decimal.Decimal('0.1'), + "{0:1}", + decimal.Decimal("0.1"), ] @@ -62,53 +61,62 @@ class TestSqlAlchemyHive(unittest.TestCase, SqlAlchemyTestCase): def create_engine(self): - return create_engine('hive://localhost:10000/default') + return create_engine("hive://localhost:10000/default") @with_engine_connection def test_dotted_column_names(self, engine, connection): """When Hive returns a dotted column name, both the non-dotted version should be available as an attribute, and the dotted version should remain available as a key. """ - row = connection.execute(text('SELECT * FROM one_row')).fetchone() + row = connection.execute(text("SELECT * FROM one_row")).fetchone() if sqlalchemy_version >= 1.4: row = row._mapping - assert row.keys() == ['number_of_rows'] - assert 'number_of_rows' in row + assert row.keys() == ["number_of_rows"] + assert "number_of_rows" in row assert row.number_of_rows == 1 - assert row['number_of_rows'] == 1 - assert getattr(row, 'one_row.number_of_rows') == 1 - assert row['one_row.number_of_rows'] == 1 + assert row["number_of_rows"] == 1 + assert getattr(row, "one_row.number_of_rows") == 1 + assert row["one_row.number_of_rows"] == 1 @with_engine_connection def test_dotted_column_names_raw(self, engine, connection): - """When Hive returns a dotted column name, and raw mode is on, nothing should be modified. - """ - row = connection.execution_options(hive_raw_colnames=True).execute(text('SELECT * FROM one_row')).fetchone() - + """When Hive returns a dotted column name, and raw mode is on, nothing should be modified.""" + row = ( + connection.execution_options(hive_raw_colnames=True) + .execute(text("SELECT * FROM one_row")) + .fetchone() + ) + if sqlalchemy_version >= 1.4: row = row._mapping - assert row.keys() == ['one_row.number_of_rows'] - assert 'number_of_rows' not in row - assert getattr(row, 'one_row.number_of_rows') == 1 - assert row['one_row.number_of_rows'] == 1 + assert row.keys() == ["one_row.number_of_rows"] + assert "number_of_rows" not in row + assert getattr(row, "one_row.number_of_rows") == 1 + assert row["one_row.number_of_rows"] == 1 @with_engine_connection def test_reflect_no_such_table(self, engine, connection): """reflecttable should throw an exception on an invalid table""" self.assertRaises( NoSuchTableError, - lambda: Table('this_does_not_exist', MetaData(), autoload_with=engine)) + lambda: Table("this_does_not_exist", MetaData(), autoload_with=engine), + ) self.assertRaises( OperationalError, - lambda: Table('this_does_not_exist', MetaData(schema="also_does_not_exist"), autoload_with=engine)) + lambda: Table( + "this_does_not_exist", + MetaData(schema="also_does_not_exist"), + autoload_with=engine, + ), + ) @with_engine_connection def test_reflect_select(self, engine, connection): """reflecttable should be able to fill in a table from the name""" - one_row_complex = Table('one_row_complex', MetaData(), autoload_with=engine) + one_row_complex = Table("one_row_complex", MetaData(), autoload_with=engine) self.assertEqual(len(one_row_complex.c), 15) self.assertIsInstance(one_row_complex.c.string, Column) row = connection.execute(one_row_complex.select()).fetchone() @@ -134,32 +142,30 @@ def test_reflect_select(self, engine, connection): @with_engine_connection def test_type_map(self, engine, connection): """sqlalchemy should use the dbapi_type_map to infer types from raw queries""" - row = connection.execute(text('SELECT * FROM one_row_complex')).fetchone() + row = connection.execute(text("SELECT * FROM one_row_complex")).fetchone() self.assertListEqual(list(row), _ONE_ROW_COMPLEX_CONTENTS) @with_engine_connection def test_reserved_words(self, engine, connection): """Hive uses backticks""" # Use keywords for the table/column name - fake_table = Table('select', MetaData(), Column('map', sqlalchemy.types.String)) - query = str(fake_table.select().where(fake_table.c.map == 'a').compile(engine)) - self.assertIn('`select`', query) - self.assertIn('`map`', query) + fake_table = Table("select", MetaData(), Column("map", sqlalchemy.types.String)) + query = str(fake_table.select().where(fake_table.c.map == "a").compile(engine)) + self.assertIn("`select`", query) + self.assertIn("`map`", query) self.assertNotIn('"select"', query) self.assertNotIn('"map"', query) def test_switch_database(self): - engine = create_engine('hive://localhost:10000/pyhive_test_database') + engine = create_engine("hive://localhost:10000/pyhive_test_database") try: with contextlib.closing(engine.connect()) as connection: self.assertIn( - ('dummy_table',), - connection.execute(text('SHOW TABLES')).fetchall() + ("dummy_table",), connection.execute(text("SHOW TABLES")).fetchall() ) - connection.execute(text('USE default')) + connection.execute(text("USE default")) self.assertIn( - ('one_row',), - connection.execute(text('SHOW TABLES')).fetchall() + ("one_row",), connection.execute(text("SHOW TABLES")).fetchall() ) finally: engine.dispose() @@ -169,26 +175,54 @@ def test_lots_of_types(self, engine, connection): # Presto doesn't have raw CREATE TABLE support, so we ony test hive # take type list from sqlalchemy.types types = [ - 'INT', 'CHAR', 'VARCHAR', 'NCHAR', 'TEXT', 'Text', 'FLOAT', - 'NUMERIC', 'DECIMAL', 'TIMESTAMP', 'DATETIME', 'CLOB', 'BLOB', - 'BOOLEAN', 'SMALLINT', 'DATE', 'TIME', - 'String', 'Integer', 'SmallInteger', - 'Numeric', 'Float', 'DateTime', 'Date', 'Time', 'LargeBinary', - 'Boolean', 'Unicode', 'UnicodeText', + "INT", + "CHAR", + "VARCHAR", + "NCHAR", + "TEXT", + "Text", + "FLOAT", + "NUMERIC", + "DECIMAL", + "TIMESTAMP", + "DATETIME", + "CLOB", + "BLOB", + "BOOLEAN", + "SMALLINT", + "DATE", + "TIME", + "String", + "Integer", + "SmallInteger", + "Numeric", + "Float", + "DateTime", + "Date", + "Time", + "LargeBinary", + "Boolean", + "Unicode", + "UnicodeText", ] cols = [] for i, t in enumerate(types): cols.append(Column(str(i), getattr(sqlalchemy.types, t))) - cols.append(Column('hive_date', HiveDate)) - cols.append(Column('hive_decimal', HiveDecimal)) - cols.append(Column('hive_timestamp', HiveTimestamp)) - table = Table('test_table', MetaData(schema='pyhive_test_database'), *cols,) + cols.append(Column("hive_date", HiveDate)) + cols.append(Column("hive_decimal", HiveDecimal)) + cols.append(Column("hive_timestamp", HiveTimestamp)) + table = Table( + "test_table", + MetaData(schema="pyhive_test_database"), + *cols, + ) table.drop(checkfirst=True, bind=connection) table.create(bind=connection) - connection.execute(text('SET mapred.job.tracker=local')) - connection.execute(text('USE pyhive_test_database')) - big_number = 10 ** 10 - 1 - connection.execute(text(""" + connection.execute(text("SET mapred.job.tracker=local")) + connection.execute(text("USE pyhive_test_database")) + big_number = 10**10 - 1 + connection.execute( + text(""" INSERT OVERWRITE TABLE test_table SELECT 1, "a", "a", "a", "a", "a", 0.1, @@ -199,35 +233,45 @@ def test_lots_of_types(self, engine, connection): false, "a", "a", 0, :big_number, 123 + 2000 FROM default.one_row - """), {"big_number": big_number}) + """), + {"big_number": big_number}, + ) row = connection.execute(text("select * from test_table")).fetchone() self.assertEqual(row.hive_date, datetime.datetime(1970, 1, 1, 0, 0)) self.assertEqual(row.hive_decimal, decimal.Decimal(big_number)) - self.assertEqual(row.hive_timestamp, datetime.datetime(1970, 1, 1, 0, 0, 2, 123000)) + self.assertEqual( + row.hive_timestamp, datetime.datetime(1970, 1, 1, 0, 0, 2, 123000) + ) table.drop(bind=connection) @with_engine_connection def test_insert_select(self, engine, connection): - one_row = Table('one_row', MetaData(), autoload_with=engine) - table = Table('insert_test', MetaData(schema='pyhive_test_database'), - Column('a', sqlalchemy.types.Integer)) + one_row = Table("one_row", MetaData(), autoload_with=engine) + table = Table( + "insert_test", + MetaData(schema="pyhive_test_database"), + Column("a", sqlalchemy.types.Integer), + ) table.drop(checkfirst=True, bind=connection) table.create(bind=connection) - connection.execute(text('SET mapred.job.tracker=local')) + connection.execute(text("SET mapred.job.tracker=local")) # NOTE(jing) I'm stuck on a version of Hive without INSERT ... VALUES - connection.execute(table.insert().from_select(['a'], one_row.select())) - + connection.execute(table.insert().from_select(["a"], one_row.select())) + result = connection.execute(table.select()).fetchall() expected = [(1,)] self.assertEqual(result, expected) @with_engine_connection def test_insert_values(self, engine, connection): - table = Table('insert_test', MetaData(schema='pyhive_test_database'), - Column('a', sqlalchemy.types.Integer),) + table = Table( + "insert_test", + MetaData(schema="pyhive_test_database"), + Column("a", sqlalchemy.types.Integer), + ) table.drop(checkfirst=True, bind=connection) table.create(bind=connection) - connection.execute(table.insert().values([{'a': 1}, {'a': 2}])) + connection.execute(table.insert().values([{"a": 1}, {"a": 2}])) result = connection.execute(table.select()).fetchall() expected = [(1,), (2,)] diff --git a/python/pyhive/tests/test_sqlalchemy_presto.py b/python/pyhive/tests/test_sqlalchemy_presto.py index 58a5c034ed8..81f910a5c7b 100644 --- a/python/pyhive/tests/test_sqlalchemy_presto.py +++ b/python/pyhive/tests/test_sqlalchemy_presto.py @@ -1,57 +1,63 @@ -from __future__ import absolute_import -from __future__ import unicode_literals +from __future__ import absolute_import, unicode_literals + +import contextlib +import unittest from builtins import str -from pyhive.tests.sqlalchemy_test_case import SqlAlchemyTestCase -from pyhive.tests.sqlalchemy_test_case import with_engine_connection +from decimal import Decimal + from sqlalchemy import types from sqlalchemy.engine import create_engine -from sqlalchemy.schema import Column -from sqlalchemy.schema import MetaData -from sqlalchemy.schema import Table +from sqlalchemy.schema import Column, MetaData, Table from sqlalchemy.sql import text from sqlalchemy.types import String -from decimal import Decimal -import contextlib -import unittest +from pyhive.tests.sqlalchemy_test_case import SqlAlchemyTestCase, with_engine_connection class TestSqlAlchemyPresto(unittest.TestCase, SqlAlchemyTestCase): def create_engine(self): - return create_engine('presto://localhost:8080/hive/default?source={}'.format(self.id())) + return create_engine( + "presto://localhost:8080/hive/default?source={}".format(self.id()) + ) def test_bad_format(self): self.assertRaises( ValueError, - lambda: create_engine('presto://localhost:8080/hive/default/what'), + lambda: create_engine("presto://localhost:8080/hive/default/what"), ) @with_engine_connection def test_reflect_select(self, engine, connection): """reflecttable should be able to fill in a table from the name""" - one_row_complex = Table('one_row_complex', MetaData(), autoload_with=engine) + one_row_complex = Table("one_row_complex", MetaData(), autoload_with=engine) # Presto ignores the union column self.assertEqual(len(one_row_complex.c), 15 - 1) self.assertIsInstance(one_row_complex.c.string, Column) rows = connection.execute(one_row_complex.select()).fetchall() self.assertEqual(len(rows), 1) - self.assertEqual(list(rows[0]), [ - True, - 127, - 32767, - 2147483647, - 9223372036854775807, - 0.5, - 0.25, - 'a string', - '1970-01-01 00:00:00.000', - b'123', - [1, 2], - {"1": 2, "3": 4}, # Presto converts all keys to strings so that they're valid JSON - [1, 2], # struct is returned as a list of elements - # '{0:1}', - Decimal('0.1'), - ]) + self.assertEqual( + list(rows[0]), + [ + True, + 127, + 32767, + 2147483647, + 9223372036854775807, + 0.5, + 0.25, + "a string", + "1970-01-01 00:00:00.000", + b"123", + [1, 2], + { + "1": 2, + "3": 4, + }, # Presto converts all keys to strings so that they're valid JSON + [1, 2], # struct is returned as a list of elements + # '{0:1}', + Decimal("0.1"), + ], + ) # TODO some of these types could be filled in better self.assertIsInstance(one_row_complex.c.boolean.type, types.Boolean) @@ -70,10 +76,15 @@ def test_reflect_select(self, engine, connection): self.assertIsInstance(one_row_complex.c.decimal.type, types.NullType) def test_url_default(self): - engine = create_engine('presto://localhost:8080/hive') + engine = create_engine("presto://localhost:8080/hive") try: with contextlib.closing(engine.connect()) as connection: - self.assertEqual(connection.execute(text('SELECT 1 AS foobar FROM one_row')).scalar(), 1) + self.assertEqual( + connection.execute( + text("SELECT 1 AS foobar FROM one_row") + ).scalar(), + 1, + ) finally: engine.dispose() @@ -81,9 +92,13 @@ def test_url_default(self): def test_reserved_words(self, engine, connection): """Presto uses double quotes, not backticks""" # Use keywords for the table/column name - fake_table = Table('select', MetaData(), Column('current_timestamp', String)) - query = str(fake_table.select().where(fake_table.c.current_timestamp == 'a').compile(engine)) + fake_table = Table("select", MetaData(), Column("current_timestamp", String)) + query = str( + fake_table.select() + .where(fake_table.c.current_timestamp == "a") + .compile(engine) + ) self.assertIn('"select"', query) self.assertIn('"current_timestamp"', query) - self.assertNotIn('`select`', query) - self.assertNotIn('`current_timestamp`', query) + self.assertNotIn("`select`", query) + self.assertNotIn("`current_timestamp`", query) diff --git a/python/pyhive/tests/test_sqlalchemy_trino.py b/python/pyhive/tests/test_sqlalchemy_trino.py index c929f94143a..5a68d50f581 100644 --- a/python/pyhive/tests/test_sqlalchemy_trino.py +++ b/python/pyhive/tests/test_sqlalchemy_trino.py @@ -1,52 +1,57 @@ +import contextlib +import unittest +from decimal import Decimal + +from sqlalchemy import types from sqlalchemy.engine import create_engine -from pyhive.tests.sqlalchemy_test_case import SqlAlchemyTestCase -from pyhive.tests.sqlalchemy_test_case import with_engine_connection -from sqlalchemy.exc import NoSuchTableError, DatabaseError -from sqlalchemy.schema import MetaData, Table, Column -from sqlalchemy.types import String +from sqlalchemy.exc import DatabaseError, NoSuchTableError +from sqlalchemy.schema import Column, MetaData, Table from sqlalchemy.sql import text -from sqlalchemy import types -from decimal import Decimal +from sqlalchemy.types import String -import unittest -import contextlib +from pyhive.tests.sqlalchemy_test_case import SqlAlchemyTestCase, with_engine_connection class TestSqlAlchemyTrino(unittest.TestCase, SqlAlchemyTestCase): def create_engine(self): - return create_engine('trino+pyhive://localhost:18080/hive/default?source={}'.format(self.id())) - + return create_engine( + "trino+pyhive://localhost:18080/hive/default?source={}".format(self.id()) + ) + def test_bad_format(self): self.assertRaises( ValueError, - lambda: create_engine('trino+pyhive://localhost:18080/hive/default/what'), + lambda: create_engine("trino+pyhive://localhost:18080/hive/default/what"), ) @with_engine_connection def test_reflect_select(self, engine, connection): """reflecttable should be able to fill in a table from the name""" - one_row_complex = Table('one_row_complex', MetaData(), autoload_with=engine) + one_row_complex = Table("one_row_complex", MetaData(), autoload_with=engine) # Presto ignores the union column self.assertEqual(len(one_row_complex.c), 15 - 1) self.assertIsInstance(one_row_complex.c.string, Column) rows = connection.execute(one_row_complex.select()).fetchall() self.assertEqual(len(rows), 1) - self.assertEqual(list(rows[0]), [ - True, - 127, - 32767, - 2147483647, - 9223372036854775807, - 0.5, - 0.25, - 'a string', - '1970-01-01 00:00:00.000', - b'123', - [1, 2], - {"1": 2, "3": 4}, - [1, 2], - Decimal('0.1'), - ]) + self.assertEqual( + list(rows[0]), + [ + True, + 127, + 32767, + 2147483647, + 9223372036854775807, + 0.5, + 0.25, + "a string", + "1970-01-01 00:00:00.000", + b"123", + [1, 2], + {"1": 2, "3": 4}, + [1, 2], + Decimal("0.1"), + ], + ) self.assertIsInstance(one_row_complex.c.boolean.type, types.Boolean) self.assertIsInstance(one_row_complex.c.tinyint.type, types.Integer) @@ -62,22 +67,33 @@ def test_reflect_select(self, engine, connection): self.assertIsInstance(one_row_complex.c.map.type, types.NullType) self.assertIsInstance(one_row_complex.c.struct.type, types.NullType) self.assertIsInstance(one_row_complex.c.decimal.type, types.NullType) - + @with_engine_connection def test_reflect_no_such_table(self, engine, connection): """reflecttable should throw an exception on an invalid table""" self.assertRaises( NoSuchTableError, - lambda: Table('this_does_not_exist', MetaData(), autoload_with=engine)) + lambda: Table("this_does_not_exist", MetaData(), autoload_with=engine), + ) self.assertRaises( DatabaseError, - lambda: Table('this_does_not_exist', MetaData(schema="also_does_not_exist"), autoload_with=engine)) + lambda: Table( + "this_does_not_exist", + MetaData(schema="also_does_not_exist"), + autoload_with=engine, + ), + ) def test_url_default(self): - engine = create_engine('trino+pyhive://localhost:18080/hive') + engine = create_engine("trino+pyhive://localhost:18080/hive") try: with contextlib.closing(engine.connect()) as connection: - self.assertEqual(connection.execute(text('SELECT 1 AS foobar FROM one_row')).scalar(), 1) + self.assertEqual( + connection.execute( + text("SELECT 1 AS foobar FROM one_row") + ).scalar(), + 1, + ) finally: engine.dispose() @@ -85,9 +101,13 @@ def test_url_default(self): def test_reserved_words(self, engine, connection): """Trino uses double quotes, not backticks""" # Use keywords for the table/column name - fake_table = Table('select', MetaData(), Column('current_timestamp', String)) - query = str(fake_table.select().where(fake_table.c.current_timestamp == 'a').compile(engine)) + fake_table = Table("select", MetaData(), Column("current_timestamp", String)) + query = str( + fake_table.select() + .where(fake_table.c.current_timestamp == "a") + .compile(engine) + ) self.assertIn('"select"', query) self.assertIn('"current_timestamp"', query) - self.assertNotIn('`select`', query) - self.assertNotIn('`current_timestamp`', query) + self.assertNotIn("`select`", query) + self.assertNotIn("`current_timestamp`", query) diff --git a/python/pyhive/tests/test_trino.py b/python/pyhive/tests/test_trino.py index 41bb489b649..03e2561544a 100644 --- a/python/pyhive/tests/test_trino.py +++ b/python/pyhive/tests/test_trino.py @@ -4,26 +4,17 @@ They also require a tables created by make_test_tables.sh. """ -from __future__ import absolute_import -from __future__ import unicode_literals +from __future__ import absolute_import, unicode_literals -import contextlib -import os +import datetime from decimal import Decimal -import requests - -from pyhive import exc from pyhive import trino -from pyhive.tests.dbapi_test_case import DBAPITestCase from pyhive.tests.dbapi_test_case import with_cursor from pyhive.tests.test_presto import TestPresto -import mock -import unittest -import datetime -_HOST = 'localhost' -_PORT = '18080' +_HOST = "localhost" +_PORT = "18080" class TestTrino(TestPresto): @@ -33,66 +24,82 @@ def connect(self): return trino.connect(host=_HOST, port=_PORT, source=self.id()) def test_bad_protocol(self): - self.assertRaisesRegexp(ValueError, 'Protocol must be', - lambda: trino.connect('localhost', protocol='nonsense').cursor()) + self.assertRaisesRegexp( + ValueError, + "Protocol must be", + lambda: trino.connect("localhost", protocol="nonsense").cursor(), + ) def test_escape_args(self): escaper = trino.TrinoParamEscaper() - self.assertEqual(escaper.escape_args((datetime.date(2020, 4, 17),)), - ("date '2020-04-17'",)) - self.assertEqual(escaper.escape_args((datetime.datetime(2020, 4, 17, 12, 0, 0, 123456),)), - ("timestamp '2020-04-17 12:00:00.123'",)) + self.assertEqual( + escaper.escape_args((datetime.date(2020, 4, 17),)), ("date '2020-04-17'",) + ) + self.assertEqual( + escaper.escape_args((datetime.datetime(2020, 4, 17, 12, 0, 0, 123456),)), + ("timestamp '2020-04-17 12:00:00.123'",), + ) @with_cursor def test_description(self, cursor): - cursor.execute('SELECT 1 AS foobar FROM one_row') - self.assertEqual(cursor.description, [('foobar', 'integer', None, None, None, None, True)]) + cursor.execute("SELECT 1 AS foobar FROM one_row") + self.assertEqual( + cursor.description, [("foobar", "integer", None, None, None, None, True)] + ) self.assertIsNotNone(cursor.last_query_id) @with_cursor def test_complex(self, cursor): - cursor.execute('SELECT * FROM one_row_complex') + cursor.execute("SELECT * FROM one_row_complex") # TODO Trino drops the union field - tinyint_type = 'tinyint' - smallint_type = 'smallint' - float_type = 'real' - self.assertEqual(cursor.description, [ - ('boolean', 'boolean', None, None, None, None, True), - ('tinyint', tinyint_type, None, None, None, None, True), - ('smallint', smallint_type, None, None, None, None, True), - ('int', 'integer', None, None, None, None, True), - ('bigint', 'bigint', None, None, None, None, True), - ('float', float_type, None, None, None, None, True), - ('double', 'double', None, None, None, None, True), - ('string', 'varchar', None, None, None, None, True), - ('timestamp', 'timestamp', None, None, None, None, True), - ('binary', 'varbinary', None, None, None, None, True), - ('array', 'array(integer)', None, None, None, None, True), - ('map', 'map(integer,integer)', None, None, None, None, True), - ('struct', 'row(a integer,b integer)', None, None, None, None, True), - # ('union', 'varchar', None, None, None, None, True), - ('decimal', 'decimal(10,1)', None, None, None, None, True), - ]) + tinyint_type = "tinyint" + smallint_type = "smallint" + float_type = "real" + self.assertEqual( + cursor.description, + [ + ("boolean", "boolean", None, None, None, None, True), + ("tinyint", tinyint_type, None, None, None, None, True), + ("smallint", smallint_type, None, None, None, None, True), + ("int", "integer", None, None, None, None, True), + ("bigint", "bigint", None, None, None, None, True), + ("float", float_type, None, None, None, None, True), + ("double", "double", None, None, None, None, True), + ("string", "varchar", None, None, None, None, True), + ("timestamp", "timestamp", None, None, None, None, True), + ("binary", "varbinary", None, None, None, None, True), + ("array", "array(integer)", None, None, None, None, True), + ("map", "map(integer,integer)", None, None, None, None, True), + ("struct", "row(a integer,b integer)", None, None, None, None, True), + # ('union', 'varchar', None, None, None, None, True), + ("decimal", "decimal(10,1)", None, None, None, None, True), + ], + ) rows = cursor.fetchall() - expected = [( - True, - 127, - 32767, - 2147483647, - 9223372036854775807, - 0.5, - 0.25, - 'a string', - '1970-01-01 00:00:00.000', - b'123', - [1, 2], - {"1": 2, "3": 4}, # Trino converts all keys to strings so that they're valid JSON - [1, 2], # struct is returned as a list of elements - # '{0:1}', - Decimal('0.1'), - )] + expected = [ + ( + True, + 127, + 32767, + 2147483647, + 9223372036854775807, + 0.5, + 0.25, + "a string", + "1970-01-01 00:00:00.000", + b"123", + [1, 2], + { + "1": 2, + "3": 4, + }, # Trino converts all keys to strings so that they're valid JSON + [1, 2], # struct is returned as a list of elements + # '{0:1}', + Decimal("0.1"), + ) + ] self.assertEqual(rows, expected) # catch unicode/str - self.assertEqual(list(map(type, rows[0])), list(map(type, expected[0]))) \ No newline at end of file + self.assertEqual(list(map(type, rows[0])), list(map(type, expected[0]))) diff --git a/python/pyhive/trino.py b/python/pyhive/trino.py index 658457a3c58..4490aedbccf 100644 --- a/python/pyhive/trino.py +++ b/python/pyhive/trino.py @@ -5,8 +5,7 @@ Many docstrings in this file are based on the PEP, which is in the public domain. """ -from __future__ import absolute_import -from __future__ import unicode_literals +from __future__ import absolute_import, unicode_literals import logging @@ -14,8 +13,19 @@ # Make all exceptions visible in this module per DB-API from pyhive.common import DBAPITypeObject -from pyhive.exc import * # noqa -from pyhive.presto import Connection as PrestoConnection, Cursor as PrestoCursor, PrestoParamEscaper +from pyhive.exc import ( + DatabaseError, + OperationalError, +) +from pyhive.presto import ( + Connection as PrestoConnection, +) +from pyhive.presto import ( + Cursor as PrestoCursor, +) +from pyhive.presto import ( + PrestoParamEscaper, +) try: # Python 3 import urllib.parse as urlparse @@ -23,9 +33,9 @@ import urlparse # PEP 249 module globals -apilevel = '2.0' +apilevel = "2.0" threadsafety = 2 # Threads may share the module and connections. -paramstyle = 'pyformat' # Python extended format codes, e.g. ...WHERE name=%(name)s +paramstyle = "pyformat" # Python extended format codes, e.g. ...WHERE name=%(name)s _logger = logging.getLogger(__name__) @@ -69,15 +79,15 @@ def execute(self, operation, parameters=None): Return values are not defined. """ headers = { - 'X-Trino-Catalog': self._catalog, - 'X-Trino-Schema': self._schema, - 'X-Trino-Source': self._source, - 'X-Trino-User': self._username, + "X-Trino-Catalog": self._catalog, + "X-Trino-Schema": self._schema, + "X-Trino-Source": self._source, + "X-Trino-User": self._username, } if self._session_props: - headers['X-Trino-Session'] = ','.join( - '{}={}'.format(propname, propval) + headers["X-Trino-Session"] = ",".join( + "{}={}".format(propname, propval) for propname, propval in self._session_props.items() ) @@ -90,13 +100,21 @@ def execute(self, operation, parameters=None): self._reset_state() self._state = self._STATE_RUNNING - url = urlparse.urlunparse(( - self._protocol, - '{}:{}'.format(self._host, self._port), '/v1/statement', None, None, None)) - _logger.info('%s', sql) + url = urlparse.urlunparse( + ( + self._protocol, + "{}:{}".format(self._host, self._port), + "/v1/statement", + None, + None, + None, + ) + ) + _logger.info("%s", sql) _logger.debug("Headers: %s", headers) response = self._requests_session.post( - url, data=sql.encode('utf-8'), headers=headers, **self._requests_kwargs) + url, data=sql.encode("utf-8"), headers=headers, **self._requests_kwargs + ) self._process_response(response) def _process_response(self, response): @@ -110,26 +128,28 @@ def _process_response(self, response): response_json = response.json() _logger.debug("Got response %s", response_json) - assert self._state == self._STATE_RUNNING, "Should be running if processing response" - self._nextUri = response_json.get('nextUri') - self._columns = response_json.get('columns') - if 'id' in response_json: - self.last_query_id = response_json['id'] - if 'X-Trino-Clear-Session' in response.headers: - propname = response.headers['X-Trino-Clear-Session'] + assert ( + self._state == self._STATE_RUNNING + ), "Should be running if processing response" + self._nextUri = response_json.get("nextUri") + self._columns = response_json.get("columns") + if "id" in response_json: + self.last_query_id = response_json["id"] + if "X-Trino-Clear-Session" in response.headers: + propname = response.headers["X-Trino-Clear-Session"] self._session_props.pop(propname, None) - if 'X-Trino-Set-Session' in response.headers: - propname, propval = response.headers['X-Trino-Set-Session'].split('=', 1) + if "X-Trino-Set-Session" in response.headers: + propname, propval = response.headers["X-Trino-Set-Session"].split("=", 1) self._session_props[propname] = propval - if 'data' in response_json: + if "data" in response_json: assert self._columns - new_data = response_json['data'] + new_data = response_json["data"] self._process_data(new_data) self._data += map(tuple, new_data) - if 'nextUri' not in response_json: + if "nextUri" not in response_json: self._state = self._STATE_FINISHED - if 'error' in response_json: - raise DatabaseError(response_json['error']) + if "error" in response_json: + raise DatabaseError(response_json["error"]) # @@ -138,7 +158,7 @@ def _process_response(self, response): # See types in trino-main/src/main/java/com/facebook/trino/tuple/TupleInfo.java -FIXED_INT_64 = DBAPITypeObject(['bigint']) -VARIABLE_BINARY = DBAPITypeObject(['varchar']) -DOUBLE = DBAPITypeObject(['double']) -BOOLEAN = DBAPITypeObject(['boolean']) +FIXED_INT_64 = DBAPITypeObject(["bigint"]) +VARIABLE_BINARY = DBAPITypeObject(["varchar"]) +DOUBLE = DBAPITypeObject(["double"]) +BOOLEAN = DBAPITypeObject(["boolean"]) diff --git a/python/setup.py b/python/setup.py index d141ea1b377..3d14ae3ad1f 100755 --- a/python/setup.py +++ b/python/setup.py @@ -1,9 +1,11 @@ #!/usr/bin/env python +import sys + from setuptools import setup from setuptools.command.test import test as TestCommand + import pyhive -import sys class PyTest(TestCommand): @@ -15,11 +17,12 @@ def finalize_options(self): def run_tests(self): # import here, cause outside the eggs aren't loaded import pytest + errno = pytest.main(self.test_args) sys.exit(errno) -with open('README.rst') as readme: +with open("README.rst") as readme: long_description = readme.read() setup( @@ -27,11 +30,11 @@ def run_tests(self): version=pyhive.__version__, description="Python interface to Hive", long_description=long_description, - url='https://github.com/dropbox/PyHive', - author="Jing Wang", - author_email="jing@dropbox.com", + url="https://github.com/apache/kyuubi", + author="", + author_email="", license="Apache License, Version 2.0", - packages=['pyhive', 'TCLIService'], + packages=["pyhive", "TCLIService"], classifiers=[ "Intended Audience :: Developers", "License :: OSI Approved :: Apache Software License", @@ -39,40 +42,42 @@ def run_tests(self): "Topic :: Database :: Front-Ends", ], install_requires=[ - 'future', - 'python-dateutil', + "future", + "python-dateutil", ], extras_require={ - 'presto': ['requests>=1.0.0'], - 'trino': ['requests>=1.0.0'], - 'hive': ['sasl>=0.2.1', 'thrift>=0.10.0', 'thrift_sasl>=0.1.0'], - 'hive_pure_sasl': ['pure-sasl>=0.6.2', 'thrift>=0.10.0', 'thrift_sasl>=0.1.0'], - 'sqlalchemy': ['sqlalchemy>=1.3.0'], - 'kerberos': ['requests_kerberos>=0.12.0'], + "presto": ["requests>=1.0.0"], + "trino": ["requests>=1.0.0"], + "hive": ["sasl>=0.2.1", "thrift>=0.10.0", "thrift_sasl>=0.1.0"], + "hive-pure-sasl": ["pure-sasl>=0.6.2", "thrift>=0.10.0", "thrift_sasl>=0.1.0"], + "sqlalchemy": ["sqlalchemy>=1.3.0"], + "kerberos": ["requests_kerberos>=0.12.0"], }, tests_require=[ - 'mock>=1.0.0', - 'pytest', - 'pytest-cov', - 'requests>=1.0.0', - 'requests_kerberos>=0.12.0', - 'sasl>=0.2.1', - 'pure-sasl>=0.6.2', - 'kerberos>=1.3.0', - 'sqlalchemy>=1.3.0', - 'thrift>=0.10.0', + "ruff>=0.3.0", + "mock>=1.0.0", + "pytest", + "pytest-cov", + "requests>=1.0.0", + "requests_kerberos>=0.12.0", + "sasl>=0.2.1", + "pure-sasl>=0.6.2", + "kerberos>=1.3.0", + "sqlalchemy>=1.3.0", + "thrift>=0.10.0", ], - cmdclass={'test': PyTest}, + cmdclass={"test": PyTest}, package_data={ - '': ['*.rst'], + "": ["*.md"], }, entry_points={ - 'sqlalchemy.dialects': [ - 'hive = pyhive.sqlalchemy_hive:HiveDialect', + "sqlalchemy.dialects": [ + "hive = pyhive.sqlalchemy_hive:HiveDialect", "hive.http = pyhive.sqlalchemy_hive:HiveHTTPDialect", "hive.https = pyhive.sqlalchemy_hive:HiveHTTPSDialect", - 'presto = pyhive.sqlalchemy_presto:PrestoDialect', - 'trino.pyhive = pyhive.sqlalchemy_trino:TrinoDialect', + "presto = pyhive.sqlalchemy_presto:PrestoDialect", + "trino.pyhive = pyhive.sqlalchemy_trino:TrinoDialect", + "sparksql = pyhive.sqlalchemy_sparksql:SparkSqlDialect", ], - } + }, )