From 3644a979113ef2f8b6035c43407827e26c3832c9 Mon Sep 17 00:00:00 2001 From: Daniel Vaz Gaspar Date: Wed, 21 Apr 2021 16:33:00 +0100 Subject: [PATCH 01/17] feat: add HTTP and HTTPS to hive (#385) * feat: add https protocol * support HTTP --- pyhive/hive.py | 79 +++++++++++++++++++++++++++++++++++++-- pyhive/sqlalchemy_hive.py | 26 +++++++++++++ setup.py | 2 + 3 files changed, 104 insertions(+), 3 deletions(-) diff --git a/pyhive/hive.py b/pyhive/hive.py index a8635bac..66569406 100644 --- a/pyhive/hive.py +++ b/pyhive/hive.py @@ -8,9 +8,12 @@ from __future__ import absolute_import from __future__ import unicode_literals +import base64 import datetime import re from decimal import Decimal +from ssl import CERT_NONE, CERT_OPTIONAL, CERT_REQUIRED, create_default_context + from TCLIService import TCLIService from TCLIService import constants @@ -25,6 +28,7 @@ import getpass import logging import sys +import thrift.transport.THttpClient import thrift.protocol.TBinaryProtocol import thrift.transport.TSocket import thrift.transport.TTransport @@ -38,6 +42,12 @@ _TIMESTAMP_PATTERN = re.compile(r'(\d+-\d+-\d+ \d+:\d+:\d+(\.\d{,6})?)') +ssl_cert_parameter_map = { + "none": CERT_NONE, + "optional": CERT_OPTIONAL, + "required": CERT_REQUIRED, +} + def _parse_timestamp(value): if value: @@ -97,9 +107,21 @@ def connect(*args, **kwargs): class Connection(object): """Wraps a Thrift session""" - def __init__(self, host=None, port=None, username=None, database='default', auth=None, - configuration=None, kerberos_service_name=None, password=None, - thrift_transport=None): + def __init__( + self, + host=None, + port=None, + scheme=None, + username=None, + database='default', + auth=None, + configuration=None, + kerberos_service_name=None, + password=None, + check_hostname=None, + ssl_cert=None, + thrift_transport=None + ): """Connect to HiveServer2 :param host: What host HiveServer2 runs on @@ -116,6 +138,32 @@ def __init__(self, host=None, port=None, username=None, database='default', auth https://github.com/cloudera/impyla/blob/255b07ed973d47a3395214ed92d35ec0615ebf62 /impala/_thrift_api.py#L152-L160 """ + if scheme in ("https", "http") and thrift_transport is None: + ssl_context = None + if scheme == "https": + ssl_context = create_default_context() + ssl_context.check_hostname = check_hostname == "true" + ssl_cert = ssl_cert or "none" + ssl_context.verify_mode = ssl_cert_parameter_map.get(ssl_cert, CERT_NONE) + thrift_transport = thrift.transport.THttpClient.THttpClient( + uri_or_host=f"{scheme}://{host}:{port}/cliservice/", + ssl_context=ssl_context, + ) + + if auth in ("BASIC", "NOSASL", "NONE", None): + # Always needs the Authorization header + self._set_authorization_header(thrift_transport, username, password) + elif auth == "KERBEROS" and kerberos_service_name: + self._set_kerberos_header(thrift_transport, kerberos_service_name, host) + else: + raise ValueError( + "Authentication is not valid use one of:" + "BASIC, NOSASL, KERBEROS, NONE" + ) + host, port, auth, kerberos_service_name, password = ( + None, None, None, None, None + ) + username = username or getpass.getuser() configuration = configuration or {} @@ -207,6 +255,31 @@ def sasl_factory(): self._transport.close() raise + @staticmethod + def _set_authorization_header(transport, username=None, password=None): + username = username or "user" + password = password or "pass" + auth_credentials = f"{username}:{password}".encode("UTF-8") + auth_credentials_base64 = base64.standard_b64encode(auth_credentials).decode( + "UTF-8" + ) + 
transport.setCustomHeaders( + {"Authorization": f"Basic {auth_credentials_base64}"} + ) + + @staticmethod + def _set_kerberos_header(transport, kerberos_service_name, host) -> None: + import kerberos + + __, krb_context = kerberos.authGSSClientInit( + service=f"{kerberos_service_name}@{host}" + ) + kerberos.authGSSClientClean(krb_context, "") + kerberos.authGSSClientStep(krb_context, "") + auth_header = kerberos.authGSSClientResponse(krb_context) + + transport.setCustomHeaders({"Authorization": f"Negotiate {auth_header}"}) + def __enter__(self): """Transport should already be opened by __init__""" return self diff --git a/pyhive/sqlalchemy_hive.py b/pyhive/sqlalchemy_hive.py index 59e0c0ee..2ef49652 100644 --- a/pyhive/sqlalchemy_hive.py +++ b/pyhive/sqlalchemy_hive.py @@ -374,3 +374,29 @@ def _check_unicode_returns(self, connection, additional_tests=None): def _check_unicode_description(self, connection): # We decode everything as UTF-8 return True + + +class HiveHTTPDialect(HiveDialect): + + name = "hive" + scheme = "http" + driver = "rest" + + def create_connect_args(self, url): + kwargs = { + "host": url.host, + "port": url.port or 10000, + "scheme": self.scheme, + "username": url.username or None, + "password": url.password or None, + } + if url.query: + kwargs.update(url.query) + return [], kwargs + return ([], kwargs) + + +class HiveHTTPSDialect(HiveHTTPDialect): + + name = "hive" + scheme = "https" diff --git a/setup.py b/setup.py index df410dbc..ad34a38b 100755 --- a/setup.py +++ b/setup.py @@ -66,6 +66,8 @@ def run_tests(self): entry_points={ 'sqlalchemy.dialects': [ 'hive = pyhive.sqlalchemy_hive:HiveDialect', + "hive.http = pyhive.sqlalchemy_hive:HiveHTTPDialect", + "hive.https = pyhive.sqlalchemy_hive:HiveHTTPSDialect", 'presto = pyhive.sqlalchemy_presto:PrestoDialect', 'trino = pyhive.sqlalchemy_trino:TrinoDialect', ], From b21c507a24ed2f2b0cf15b0b6abb1c43f31d3ee0 Mon Sep 17 00:00:00 2001 From: Daniel Vaz Gaspar Date: Tue, 27 Apr 2021 17:37:29 +0100 Subject: [PATCH 02/17] fix: make hive https py2 compat (#389) * fix: make hive https py2 compat * fix lint --- README.rst | 5 +++++ pyhive/hive.py | 29 +++++++++++++++++++++++------ 2 files changed, 28 insertions(+), 6 deletions(-) diff --git a/README.rst b/README.rst index 8903ce78..a6f854a8 100644 --- a/README.rst +++ b/README.rst @@ -70,6 +70,11 @@ First install this package to register it with SQLAlchemy (see ``setup.py``). logs = Table('my_awesome_data', MetaData(bind=engine), autoload=True) print select([func.count('*')], from_obj=logs).scalar() + # Hive + HTTPS + LDAP or basic Auth + engine = create_engine('hive+https://username:password@localhost:10000/') + logs = Table('my_awesome_data', MetaData(bind=engine), autoload=True) + print select([func.count('*')], from_obj=logs).scalar() + Note: query generation functionality is not exhaustive or fully tested, but there should be no problem with raw SQL. 
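
For reference, the 'hive+https' URL shown above maps onto a plain DB-API connection using the keyword arguments introduced in the first patch of this series. The sketch below is not part of the patch itself: host, port and credentials are placeholders, 'ssl_cert' takes "none", "optional" or "required" (mapped through ssl_cert_parameter_map), and 'check_hostname' is compared against the string "true", exactly as in the Connection.__init__ added above.

    import contextlib
    from pyhive import hive

    conn = hive.connect(
        host='hiveserver.example.com',  # placeholder host
        port=10001,                     # placeholder HTTP transport port
        scheme='https',
        username='username',
        password='password',
        auth='BASIC',            # BASIC/NOSASL/NONE/None all set the basic Authorization header
        check_hostname='true',   # string comparison, per the patch
        ssl_cert='required',     # mapped to ssl.CERT_REQUIRED via ssl_cert_parameter_map
    )
    with contextlib.closing(conn):
        with contextlib.closing(conn.cursor()) as cursor:
            cursor.execute('SELECT 1')
            print(cursor.fetchall())

The same context-manager pattern (contextlib.closing around connection and cursor) is what the existing tests in this repository use.
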
diff --git a/pyhive/hive.py b/pyhive/hive.py index 66569406..3f71df33 100644 --- a/pyhive/hive.py +++ b/pyhive/hive.py @@ -139,6 +139,7 @@ def __init__( /impala/_thrift_api.py#L152-L160 """ if scheme in ("https", "http") and thrift_transport is None: + port = port or 1000 ssl_context = None if scheme == "https": ssl_context = create_default_context() @@ -146,7 +147,9 @@ def __init__( ssl_cert = ssl_cert or "none" ssl_context.verify_mode = ssl_cert_parameter_map.get(ssl_cert, CERT_NONE) thrift_transport = thrift.transport.THttpClient.THttpClient( - uri_or_host=f"{scheme}://{host}:{port}/cliservice/", + uri_or_host="{scheme}://{host}:{port}/cliservice/".format( + scheme=scheme, host=host, port=port + ), ssl_context=ssl_context, ) @@ -259,26 +262,40 @@ def sasl_factory(): def _set_authorization_header(transport, username=None, password=None): username = username or "user" password = password or "pass" - auth_credentials = f"{username}:{password}".encode("UTF-8") + auth_credentials = "{username}:{password}".format( + username=username, password=password + ).encode("UTF-8") auth_credentials_base64 = base64.standard_b64encode(auth_credentials).decode( "UTF-8" ) transport.setCustomHeaders( - {"Authorization": f"Basic {auth_credentials_base64}"} + { + "Authorization": "Basic {auth_credentials_base64}".format( + auth_credentials_base64=auth_credentials_base64 + ) + } ) @staticmethod - def _set_kerberos_header(transport, kerberos_service_name, host) -> None: + def _set_kerberos_header(transport, kerberos_service_name, host): import kerberos __, krb_context = kerberos.authGSSClientInit( - service=f"{kerberos_service_name}@{host}" + service="{kerberos_service_name}@{host}".format( + kerberos_service_name=kerberos_service_name, host=host + ) ) kerberos.authGSSClientClean(krb_context, "") kerberos.authGSSClientStep(krb_context, "") auth_header = kerberos.authGSSClientResponse(krb_context) - transport.setCustomHeaders({"Authorization": f"Negotiate {auth_header}"}) + transport.setCustomHeaders( + { + "Authorization": "Negotiate {auth_header}".format( + auth_header=auth_header + ) + } + ) def __enter__(self): """Transport should already be opened by __init__""" From d199a1bd55c656b5c28d0d62f2d3f2e6c9a82a54 Mon Sep 17 00:00:00 2001 From: Bogdan Date: Thu, 20 Jan 2022 09:16:14 -0800 Subject: [PATCH 03/17] Update README.rst (#423) --- README.rst | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/README.rst b/README.rst index a6f854a8..89c54532 100644 --- a/README.rst +++ b/README.rst @@ -1,3 +1,10 @@ +================================ +Project is currently unsupported +================================ + + + + .. image:: https://travis-ci.org/dropbox/PyHive.svg?branch=master :target: https://travis-ci.org/dropbox/PyHive .. 
image:: https://img.shields.io/codecov/c/github/dropbox/PyHive.svg From 8df7254c4016cbcb8a630166fdab9073955b0e48 Mon Sep 17 00:00:00 2001 From: Beto Dealmeida Date: Mon, 14 Feb 2022 09:53:54 -0800 Subject: [PATCH 04/17] chore: rename Trino entry point (#428) --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index ad34a38b..be593fc0 100755 --- a/setup.py +++ b/setup.py @@ -69,7 +69,7 @@ def run_tests(self): "hive.http = pyhive.sqlalchemy_hive:HiveHTTPDialect", "hive.https = pyhive.sqlalchemy_hive:HiveHTTPSDialect", 'presto = pyhive.sqlalchemy_presto:PrestoDialect', - 'trino = pyhive.sqlalchemy_trino:TrinoDialect', + 'trino.pyhive = pyhive.sqlalchemy_trino:TrinoDialect', ], } ) From 3547bd6cccf963a033928b73c5ed498684335c39 Mon Sep 17 00:00:00 2001 From: serenajiang Date: Mon, 7 Mar 2022 13:43:09 -0800 Subject: [PATCH 05/17] Support for Presto decimals (#430) * Support for Presto decimals * lower --- pyhive/presto.py | 18 ++++++++++++------ pyhive/tests/test_presto.py | 4 +++- pyhive/tests/test_trino.py | 4 +++- pyhive/trino.py | 2 +- 4 files changed, 19 insertions(+), 9 deletions(-) diff --git a/pyhive/presto.py b/pyhive/presto.py index a38cd891..3217f4c2 100644 --- a/pyhive/presto.py +++ b/pyhive/presto.py @@ -9,6 +9,8 @@ from __future__ import unicode_literals from builtins import object +from decimal import Decimal + from pyhive import common from pyhive.common import DBAPITypeObject # Make all exceptions visible in this module per DB-API @@ -34,6 +36,11 @@ _logger = logging.getLogger(__name__) +TYPES_CONVERTER = { + "decimal": Decimal, + # As of Presto 0.69, binary data is returned as the varbinary type in base64 format + "varbinary": base64.b64decode +} class PrestoParamEscaper(common.ParamEscaper): def escape_datetime(self, item, format): @@ -307,14 +314,13 @@ def _fetch_more(self): """Fetch the next URI and update state""" self._process_response(self._requests_session.get(self._nextUri, **self._requests_kwargs)) - def _decode_binary(self, rows): - # As of Presto 0.69, binary data is returned as the varbinary type in base64 format - # This function decodes base64 data in place + def _process_data(self, rows): for i, col in enumerate(self.description): - if col[1] == 'varbinary': + col_type = col[1].split("(")[0].lower() + if col_type in TYPES_CONVERTER: for row in rows: if row[i] is not None: - row[i] = base64.b64decode(row[i]) + row[i] = TYPES_CONVERTER[col_type](row[i]) def _process_response(self, response): """Given the JSON response from Presto's REST API, update the internal state with the next @@ -341,7 +347,7 @@ def _process_response(self, response): if 'data' in response_json: assert self._columns new_data = response_json['data'] - self._decode_binary(new_data) + self._process_data(new_data) self._data += map(tuple, new_data) if 'nextUri' not in response_json: self._state = self._STATE_FINISHED diff --git a/pyhive/tests/test_presto.py b/pyhive/tests/test_presto.py index 7c74f057..187b1c21 100644 --- a/pyhive/tests/test_presto.py +++ b/pyhive/tests/test_presto.py @@ -9,6 +9,8 @@ import contextlib import os +from decimal import Decimal + import requests from pyhive import exc @@ -93,7 +95,7 @@ def test_complex(self, cursor): {"1": 2, "3": 4}, # Presto converts all keys to strings so that they're valid JSON [1, 2], # struct is returned as a list of elements # '{0:1}', - '0.1', + Decimal('0.1'), )] self.assertEqual(rows, expected) # catch unicode/str diff --git a/pyhive/tests/test_trino.py b/pyhive/tests/test_trino.py index 
cdc8bb43..41bb489b 100644 --- a/pyhive/tests/test_trino.py +++ b/pyhive/tests/test_trino.py @@ -9,6 +9,8 @@ import contextlib import os +from decimal import Decimal + import requests from pyhive import exc @@ -89,7 +91,7 @@ def test_complex(self, cursor): {"1": 2, "3": 4}, # Trino converts all keys to strings so that they're valid JSON [1, 2], # struct is returned as a list of elements # '{0:1}', - '0.1', + Decimal('0.1'), )] self.assertEqual(rows, expected) # catch unicode/str diff --git a/pyhive/trino.py b/pyhive/trino.py index e8a1aabd..658457a3 100644 --- a/pyhive/trino.py +++ b/pyhive/trino.py @@ -124,7 +124,7 @@ def _process_response(self, response): if 'data' in response_json: assert self._columns new_data = response_json['data'] - self._decode_binary(new_data) + self._process_data(new_data) self._data += map(tuple, new_data) if 'nextUri' not in response_json: self._state = self._STATE_FINISHED From 1f99552303626cce9eb6867fb7401fc810637fd6 Mon Sep 17 00:00:00 2001 From: Usiel Riedl Date: Tue, 9 May 2023 10:05:04 +0200 Subject: [PATCH 06/17] Use str type for driver and name in HiveDialect (#450) PyHive's HiveDialect usage of bytes for the name and driver fields is not the norm is causing issues upstream: https://github.com/apache/superset/issues/22316 Even other dialects within PyHive use strings. SQLAlchemy does not strictly require a string, but all the stock dialects return a string, so I figure it is heavily implied. I think the risk of breaking something upstream with this change is low (but it is there ofc). I figure in most cases we just make someone's `str(dialect.driver)` expression redundant. Examples for some of the other stock sqlalchemy dialects (name and driver fields using str): https://github.com/zzzeek/sqlalchemy/blob/main/lib/sqlalchemy/dialects/sqlite/pysqlite.py#L501 https://github.com/zzzeek/sqlalchemy/blob/main/lib/sqlalchemy/dialects/sqlite/base.py#L1891 https://github.com/zzzeek/sqlalchemy/blob/main/lib/sqlalchemy/dialects/mysql/base.py#L2383 https://github.com/zzzeek/sqlalchemy/blob/main/lib/sqlalchemy/dialects/mysql/mysqldb.py#L113 https://github.com/zzzeek/sqlalchemy/blob/main/lib/sqlalchemy/dialects/mysql/pymysql.py#L59 --- pyhive/sqlalchemy_hive.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pyhive/sqlalchemy_hive.py b/pyhive/sqlalchemy_hive.py index 2ef49652..f39f1793 100644 --- a/pyhive/sqlalchemy_hive.py +++ b/pyhive/sqlalchemy_hive.py @@ -228,8 +228,8 @@ def _translate_colname(self, colname): class HiveDialect(default.DefaultDialect): - name = b'hive' - driver = b'thrift' + name = 'hive' + driver = 'thrift' execution_ctx_cls = HiveExecutionContext preparer = HiveIdentifierPreparer statement_compiler = HiveCompiler From 1c1da8b17bdf0e7e881e15bb731119558bd5440f Mon Sep 17 00:00:00 2001 From: Multazim Deshmukh <57723564+mdeshmu@users.noreply.github.com> Date: Wed, 17 May 2023 20:49:50 +0530 Subject: [PATCH 07/17] Correcting Iterable import for python 3.10 (#451) --- pyhive/common.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/pyhive/common.py b/pyhive/common.py index 298633a1..51692b97 100644 --- a/pyhive/common.py +++ b/pyhive/common.py @@ -18,6 +18,11 @@ from future.utils import with_metaclass from itertools import islice +try: + from collections.abc import Iterable +except ImportError: + from collections import Iterable + class DBAPICursor(with_metaclass(abc.ABCMeta, object)): """Base class for some common DB-API logic""" @@ -245,7 +250,7 @@ def escape_item(self, item): return 
self.escape_number(item) elif isinstance(item, basestring): return self.escape_string(item) - elif isinstance(item, collections.Iterable): + elif isinstance(item, Iterable): return self.escape_sequence(item) elif isinstance(item, datetime.datetime): return self.escape_datetime(item, self._DATETIME_FORMAT) From b0206d3cb8a9f9a95a36feeae311f6b0141c6675 Mon Sep 17 00:00:00 2001 From: nicholas-miles Date: Wed, 17 May 2023 08:21:07 -0700 Subject: [PATCH 08/17] changing drivers to support hive, presto and trino with sqlalchemy>=2.0 (#448) --- pyhive/sqlalchemy_hive.py | 14 +++++++++++--- pyhive/sqlalchemy_presto.py | 8 ++++++-- pyhive/sqlalchemy_trino.py | 8 ++++++-- 3 files changed, 23 insertions(+), 7 deletions(-) diff --git a/pyhive/sqlalchemy_hive.py b/pyhive/sqlalchemy_hive.py index f39f1793..34fdb648 100644 --- a/pyhive/sqlalchemy_hive.py +++ b/pyhive/sqlalchemy_hive.py @@ -13,11 +13,19 @@ import re from sqlalchemy import exc -from sqlalchemy import processors +try: + from sqlalchemy import processors +except ImportError: + # Newer versions of sqlalchemy require: + from sqlalchemy.engine import processors from sqlalchemy import types from sqlalchemy import util # TODO shouldn't use mysql type -from sqlalchemy.databases import mysql +try: + from sqlalchemy.databases.mysql import MSTinyInteger +except ImportError: + # Newer versions of sqlalchemy require: + from sqlalchemy.dialects.mysql import MSTinyInteger from sqlalchemy.engine import default from sqlalchemy.sql import compiler from sqlalchemy.sql.compiler import SQLCompiler @@ -121,7 +129,7 @@ def __init__(self, dialect): _type_map = { 'boolean': types.Boolean, - 'tinyint': mysql.MSTinyInteger, + 'tinyint': MSTinyInteger, 'smallint': types.SmallInteger, 'int': types.Integer, 'bigint': types.BigInteger, diff --git a/pyhive/sqlalchemy_presto.py b/pyhive/sqlalchemy_presto.py index a199ebe1..94d06412 100644 --- a/pyhive/sqlalchemy_presto.py +++ b/pyhive/sqlalchemy_presto.py @@ -13,7 +13,11 @@ from sqlalchemy import types from sqlalchemy import util # TODO shouldn't use mysql type -from sqlalchemy.databases import mysql +try: + from sqlalchemy.databases.mysql import MSTinyInteger +except ImportError: + # Newer versions of sqlalchemy require: + from sqlalchemy.dialects.mysql import MSTinyInteger from sqlalchemy.engine import default from sqlalchemy.sql import compiler from sqlalchemy.sql.compiler import SQLCompiler @@ -29,7 +33,7 @@ class PrestoIdentifierPreparer(compiler.IdentifierPreparer): _type_map = { 'boolean': types.Boolean, - 'tinyint': mysql.MSTinyInteger, + 'tinyint': MSTinyInteger, 'smallint': types.SmallInteger, 'integer': types.Integer, 'bigint': types.BigInteger, diff --git a/pyhive/sqlalchemy_trino.py b/pyhive/sqlalchemy_trino.py index 4b2b3698..686a42c7 100644 --- a/pyhive/sqlalchemy_trino.py +++ b/pyhive/sqlalchemy_trino.py @@ -13,7 +13,11 @@ from sqlalchemy import types from sqlalchemy import util # TODO shouldn't use mysql type -from sqlalchemy.databases import mysql +try: + from sqlalchemy.databases.mysql import MSTinyInteger +except ImportError: + # Newer versions of sqlalchemy require: + from sqlalchemy.dialects.mysql import MSTinyInteger from sqlalchemy.engine import default from sqlalchemy.sql import compiler from sqlalchemy.sql.compiler import SQLCompiler @@ -28,7 +32,7 @@ class TrinoIdentifierPreparer(PrestoIdentifierPreparer): _type_map = { 'boolean': types.Boolean, - 'tinyint': mysql.MSTinyInteger, + 'tinyint': MSTinyInteger, 'smallint': types.SmallInteger, 'integer': types.Integer, 'bigint': types.BigInteger, From 
df03bef66500541fa921ec3614ec06a15ca17615 Mon Sep 17 00:00:00 2001 From: Bogdan Date: Wed, 17 May 2023 09:32:32 -0700 Subject: [PATCH 09/17] Revert "changing drivers to support hive, presto and trino with sqlalchemy>=2.0 (#448)" (#452) This reverts commit b0206d3cb8a9f9a95a36feeae311f6b0141c6675. --- pyhive/sqlalchemy_hive.py | 14 +++----------- pyhive/sqlalchemy_presto.py | 8 ++------ pyhive/sqlalchemy_trino.py | 8 ++------ 3 files changed, 7 insertions(+), 23 deletions(-) diff --git a/pyhive/sqlalchemy_hive.py b/pyhive/sqlalchemy_hive.py index 34fdb648..f39f1793 100644 --- a/pyhive/sqlalchemy_hive.py +++ b/pyhive/sqlalchemy_hive.py @@ -13,19 +13,11 @@ import re from sqlalchemy import exc -try: - from sqlalchemy import processors -except ImportError: - # Newer versions of sqlalchemy require: - from sqlalchemy.engine import processors +from sqlalchemy import processors from sqlalchemy import types from sqlalchemy import util # TODO shouldn't use mysql type -try: - from sqlalchemy.databases.mysql import MSTinyInteger -except ImportError: - # Newer versions of sqlalchemy require: - from sqlalchemy.dialects.mysql import MSTinyInteger +from sqlalchemy.databases import mysql from sqlalchemy.engine import default from sqlalchemy.sql import compiler from sqlalchemy.sql.compiler import SQLCompiler @@ -129,7 +121,7 @@ def __init__(self, dialect): _type_map = { 'boolean': types.Boolean, - 'tinyint': MSTinyInteger, + 'tinyint': mysql.MSTinyInteger, 'smallint': types.SmallInteger, 'int': types.Integer, 'bigint': types.BigInteger, diff --git a/pyhive/sqlalchemy_presto.py b/pyhive/sqlalchemy_presto.py index 94d06412..a199ebe1 100644 --- a/pyhive/sqlalchemy_presto.py +++ b/pyhive/sqlalchemy_presto.py @@ -13,11 +13,7 @@ from sqlalchemy import types from sqlalchemy import util # TODO shouldn't use mysql type -try: - from sqlalchemy.databases.mysql import MSTinyInteger -except ImportError: - # Newer versions of sqlalchemy require: - from sqlalchemy.dialects.mysql import MSTinyInteger +from sqlalchemy.databases import mysql from sqlalchemy.engine import default from sqlalchemy.sql import compiler from sqlalchemy.sql.compiler import SQLCompiler @@ -33,7 +29,7 @@ class PrestoIdentifierPreparer(compiler.IdentifierPreparer): _type_map = { 'boolean': types.Boolean, - 'tinyint': MSTinyInteger, + 'tinyint': mysql.MSTinyInteger, 'smallint': types.SmallInteger, 'integer': types.Integer, 'bigint': types.BigInteger, diff --git a/pyhive/sqlalchemy_trino.py b/pyhive/sqlalchemy_trino.py index 686a42c7..4b2b3698 100644 --- a/pyhive/sqlalchemy_trino.py +++ b/pyhive/sqlalchemy_trino.py @@ -13,11 +13,7 @@ from sqlalchemy import types from sqlalchemy import util # TODO shouldn't use mysql type -try: - from sqlalchemy.databases.mysql import MSTinyInteger -except ImportError: - # Newer versions of sqlalchemy require: - from sqlalchemy.dialects.mysql import MSTinyInteger +from sqlalchemy.databases import mysql from sqlalchemy.engine import default from sqlalchemy.sql import compiler from sqlalchemy.sql.compiler import SQLCompiler @@ -32,7 +28,7 @@ class TrinoIdentifierPreparer(PrestoIdentifierPreparer): _type_map = { 'boolean': types.Boolean, - 'tinyint': MSTinyInteger, + 'tinyint': mysql.MSTinyInteger, 'smallint': types.SmallInteger, 'integer': types.Integer, 'bigint': types.BigInteger, From 0bd6f5b5f76f759cd01b83287cec15da9789753e Mon Sep 17 00:00:00 2001 From: Bogdan Date: Wed, 17 May 2023 10:02:09 -0700 Subject: [PATCH 10/17] Update __init__.py (#453) 
https://github.com/dropbox/PyHive/commit/1c1da8b17bdf0e7e881e15bb731119558bd5440f https://github.com/dropbox/PyHive/commit/1f99552303626cce9eb6867fb7401fc810637fd6 --- pyhive/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyhive/__init__.py b/pyhive/__init__.py index 8ede6abb..0a6bb1f6 100644 --- a/pyhive/__init__.py +++ b/pyhive/__init__.py @@ -1,3 +1,3 @@ from __future__ import absolute_import from __future__ import unicode_literals -__version__ = '0.6.3' +__version__ = '0.7.0' From 4367cc550252f9e6f85782bee7f8694325a742a6 Mon Sep 17 00:00:00 2001 From: Multazim Deshmukh <57723564+mdeshmu@users.noreply.github.com> Date: Tue, 20 Jun 2023 16:37:18 +0530 Subject: [PATCH 11/17] use pure-sasl with python 3.11 (#454) --- dev_requirements.txt | 2 + pyhive/hive.py | 56 ++++-- pyhive/sasl_compat.py | 56 ++++++ pyhive/tests/test_hive.py | 11 +- pyhive/tests/test_sasl_compat.py | 333 +++++++++++++++++++++++++++++++ setup.py | 3 + 6 files changed, 436 insertions(+), 25 deletions(-) create mode 100644 pyhive/sasl_compat.py create mode 100644 pyhive/tests/test_sasl_compat.py diff --git a/dev_requirements.txt b/dev_requirements.txt index 0bf6d8a7..40bb605a 100644 --- a/dev_requirements.txt +++ b/dev_requirements.txt @@ -12,6 +12,8 @@ pytest-timeout==1.2.0 requests>=1.0.0 requests_kerberos>=0.12.0 sasl>=0.2.1 +pure-sasl>=0.6.2 +kerberos>=1.3.0 thrift>=0.10.0 #thrift_sasl>=0.1.0 git+https://github.com/cloudera/thrift_sasl # Using master branch in order to get Python 3 SASL patches diff --git a/pyhive/hive.py b/pyhive/hive.py index 3f71df33..c1287488 100644 --- a/pyhive/hive.py +++ b/pyhive/hive.py @@ -49,6 +49,45 @@ } +def get_sasl_client(host, sasl_auth, service=None, username=None, password=None): + import sasl + sasl_client = sasl.Client() + sasl_client.setAttr('host', host) + + if sasl_auth == 'GSSAPI': + sasl_client.setAttr('service', service) + elif sasl_auth == 'PLAIN': + sasl_client.setAttr('username', username) + sasl_client.setAttr('password', password) + else: + raise ValueError("sasl_auth only supports GSSAPI and PLAIN") + + sasl_client.init() + return sasl_client + + +def get_pure_sasl_client(host, sasl_auth, service=None, username=None, password=None): + from pyhive.sasl_compat import PureSASLClient + + if sasl_auth == 'GSSAPI': + sasl_kwargs = {'service': service} + elif sasl_auth == 'PLAIN': + sasl_kwargs = {'username': username, 'password': password} + else: + raise ValueError("sasl_auth only supports GSSAPI and PLAIN") + + return PureSASLClient(host=host, **sasl_kwargs) + + +def get_installed_sasl(host, sasl_auth, service=None, username=None, password=None): + try: + return get_sasl_client(host=host, sasl_auth=sasl_auth, service=service, username=username, password=password) + # The sasl library is available + except ImportError: + # Fallback to pure-sasl library + return get_pure_sasl_client(host=host, sasl_auth=sasl_auth, service=service, username=username, password=password) + + def _parse_timestamp(value): if value: match = _TIMESTAMP_PATTERN.match(value) @@ -200,7 +239,6 @@ def __init__( self._transport = thrift.transport.TTransport.TBufferedTransport(socket) elif auth in ('LDAP', 'KERBEROS', 'NONE', 'CUSTOM'): # Defer import so package dependency is optional - import sasl import thrift_sasl if auth == 'KERBEROS': @@ -211,20 +249,8 @@ def __init__( if password is None: # Password doesn't matter in NONE mode, just needs to be nonempty. 
password = 'x' - - def sasl_factory(): - sasl_client = sasl.Client() - sasl_client.setAttr('host', host) - if sasl_auth == 'GSSAPI': - sasl_client.setAttr('service', kerberos_service_name) - elif sasl_auth == 'PLAIN': - sasl_client.setAttr('username', username) - sasl_client.setAttr('password', password) - else: - raise AssertionError - sasl_client.init() - return sasl_client - self._transport = thrift_sasl.TSaslClientTransport(sasl_factory, sasl_auth, socket) + + self._transport = thrift_sasl.TSaslClientTransport(lambda: get_installed_sasl(host=host, sasl_auth=sasl_auth, service=kerberos_service_name, username=username, password=password), sasl_auth, socket) else: # All HS2 config options: # https://cwiki.apache.org/confluence/display/Hive/Setting+Up+HiveServer2#SettingUpHiveServer2-Configuration diff --git a/pyhive/sasl_compat.py b/pyhive/sasl_compat.py new file mode 100644 index 00000000..dc65abe9 --- /dev/null +++ b/pyhive/sasl_compat.py @@ -0,0 +1,56 @@ +# Original source of this file is https://github.com/cloudera/impyla/blob/master/impala/sasl_compat.py +# which uses Apache-2.0 license as of 21 May 2023. +# This code was added to Impyla in 2016 as a compatibility layer to allow use of either python-sasl or pure-sasl +# via PR https://github.com/cloudera/impyla/pull/179 +# Even though thrift_sasl lists pure-sasl as dependency here https://github.com/cloudera/thrift_sasl/blob/master/setup.py#L34 +# but it still calls functions native to python-sasl in this file https://github.com/cloudera/thrift_sasl/blob/master/thrift_sasl/__init__.py#L82 +# Hence this code is required for the fallback to work. + + +from puresasl.client import SASLClient, SASLError +from contextlib import contextmanager + +@contextmanager +def error_catcher(self, Exc = Exception): + try: + self.error = None + yield + except Exc as e: + self.error = str(e) + + +class PureSASLClient(SASLClient): + def __init__(self, *args, **kwargs): + self.error = None + super(PureSASLClient, self).__init__(*args, **kwargs) + + def start(self, mechanism): + with error_catcher(self, SASLError): + if isinstance(mechanism, list): + self.choose_mechanism(mechanism) + else: + self.choose_mechanism([mechanism]) + return True, self.mechanism, self.process() + # else + return False, mechanism, None + + def encode(self, incoming): + with error_catcher(self): + return True, self.unwrap(incoming) + # else + return False, None + + def decode(self, outgoing): + with error_catcher(self): + return True, self.wrap(outgoing) + # else + return False, None + + def step(self, challenge=None): + with error_catcher(self): + return True, self.process(challenge) + # else + return False, None + + def getError(self): + return self.error diff --git a/pyhive/tests/test_hive.py b/pyhive/tests/test_hive.py index c70ed962..b49fc190 100644 --- a/pyhive/tests/test_hive.py +++ b/pyhive/tests/test_hive.py @@ -17,7 +17,6 @@ from decimal import Decimal import mock -import sasl import thrift.transport.TSocket import thrift.transport.TTransport import thrift_sasl @@ -204,15 +203,7 @@ def test_custom_transport(self): socket = thrift.transport.TSocket.TSocket('localhost', 10000) sasl_auth = 'PLAIN' - def sasl_factory(): - sasl_client = sasl.Client() - sasl_client.setAttr('host', 'localhost') - sasl_client.setAttr('username', 'test_username') - sasl_client.setAttr('password', 'x') - sasl_client.init() - return sasl_client - - transport = thrift_sasl.TSaslClientTransport(sasl_factory, sasl_auth, socket) + transport = thrift_sasl.TSaslClientTransport(lambda: 
hive.get_installed_sasl(host='localhost', sasl_auth=sasl_auth, username='test_username', password='x'), sasl_auth, socket) conn = hive.connect(thrift_transport=transport) with contextlib.closing(conn): with contextlib.closing(conn.cursor()) as cursor: diff --git a/pyhive/tests/test_sasl_compat.py b/pyhive/tests/test_sasl_compat.py new file mode 100644 index 00000000..49516249 --- /dev/null +++ b/pyhive/tests/test_sasl_compat.py @@ -0,0 +1,333 @@ +''' +http://www.opensource.org/licenses/mit-license.php + +Copyright 2007-2011 David Alan Cridland +Copyright 2011 Lance Stout +Copyright 2012 Tyler L Hobbs + +Permission is hereby granted, free of charge, to any person obtaining a copy of this +software and associated documentation files (the "Software"), to deal in the Software +without restriction, including without limitation the rights to use, copy, modify, merge, +publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons +to whom the Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all copies or +substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, +INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR +PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE +FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR +OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. +''' +# This file was generated by referring test cases from the pure-sasl repo i.e. https://github.com/thobbs/pure-sasl/tree/master/tests/unit +# and by refactoring them to cover wrapper functions in sasl_compat.py along with added coverage for functions exclusive to sasl_compat.py. 
+ +import unittest +import base64 +import hashlib +import hmac +import kerberos +from mock import patch +import six +import struct +from puresasl import SASLProtocolException, QOP +from puresasl.client import SASLError +from pyhive.sasl_compat import PureSASLClient, error_catcher + + +class TestPureSASLClient(unittest.TestCase): + """Test cases for initialization of SASL client using PureSASLClient class""" + + def setUp(self): + self.sasl_kwargs = {} + self.sasl = PureSASLClient('localhost', **self.sasl_kwargs) + + def test_start_no_mechanism(self): + """Test starting SASL authentication with no mechanism.""" + success, mechanism, response = self.sasl.start(mechanism=None) + self.assertFalse(success) + self.assertIsNone(mechanism) + self.assertIsNone(response) + self.assertEqual(self.sasl.getError(), 'None of the mechanisms listed meet all required properties') + + def test_start_wrong_mechanism(self): + """Test starting SASL authentication with a single unsupported mechanism.""" + success, mechanism, response = self.sasl.start(mechanism='WRONG') + self.assertFalse(success) + self.assertEqual(mechanism, 'WRONG') + self.assertIsNone(response) + self.assertEqual(self.sasl.getError(), 'None of the mechanisms listed meet all required properties') + + def test_start_list_of_invalid_mechanisms(self): + """Test starting SASL authentication with a list of unsupported mechanisms.""" + self.sasl.start(['invalid1', 'invalid2']) + self.assertEqual(self.sasl.getError(), 'None of the mechanisms listed meet all required properties') + + def test_start_list_of_valid_mechanisms(self): + """Test starting SASL authentication with a list of supported mechanisms.""" + self.sasl.start(['PLAIN', 'DIGEST-MD5', 'CRAM-MD5']) + # Validate right mechanism is chosen based on score. 
+ self.assertEqual(self.sasl._chosen_mech.name, 'DIGEST-MD5') + + def test_error_catcher_no_error(self): + """Test the error_catcher with no error.""" + with error_catcher(self.sasl): + result, _, _ = self.sasl.start(mechanism='ANONYMOUS') + + self.assertEqual(self.sasl.getError(), None) + self.assertEqual(result, True) + + def test_error_catcher_with_error(self): + """Test the error_catcher with an error.""" + with error_catcher(self.sasl): + result, _, _ = self.sasl.start(mechanism='WRONG') + + self.assertEqual(result, False) + self.assertEqual(self.sasl.getError(), 'None of the mechanisms listed meet all required properties') + +"""Assuming Client initilization went well and a mechanism is chosen, Below are the test cases for different mechanims""" + +class _BaseMechanismTests(unittest.TestCase): + """Base test case for SASL mechanisms.""" + + mechanism = 'ANONYMOUS' + sasl_kwargs = {} + + def setUp(self): + self.sasl = PureSASLClient('localhost', mechanism=self.mechanism, **self.sasl_kwargs) + self.mechanism_class = self.sasl._chosen_mech + + def test_init_basic(self, *args): + sasl = PureSASLClient('localhost', mechanism=self.mechanism, **self.sasl_kwargs) + mech = sasl._chosen_mech + self.assertIs(mech.sasl, sasl) + + def test_step_basic(self, *args): + success, response = self.sasl.step(six.b('string')) + self.assertTrue(success) + self.assertIsInstance(response, six.binary_type) + + def test_decode_encode(self, *args): + self.assertEqual(self.sasl.encode('msg'), (False, None)) + self.assertEqual(self.sasl.getError(), '') + self.assertEqual(self.sasl.decode('msg'), (False, None)) + self.assertEqual(self.sasl.getError(), '') + + +class AnonymousMechanismTest(_BaseMechanismTests): + """Test case for the Anonymous SASL mechanism.""" + + mechanism = 'ANONYMOUS' + + +class PlainTextMechanismTest(_BaseMechanismTests): + """Test case for the PlainText SASL mechanism.""" + + mechanism = 'PLAIN' + username = 'user' + password = 'pass' + sasl_kwargs = {'username': username, 'password': password} + + def test_step(self): + for challenge in (None, '', b'asdf', u"\U0001F44D"): + success, response = self.sasl.step(challenge) + self.assertTrue(success) + self.assertEqual(response, six.b(f'\x00{self.username}\x00{self.password}')) + self.assertIsInstance(response, six.binary_type) + + def test_step_with_authorization_id_or_identity(self): + challenge = u"\U0001F44D" + identity = 'user2' + + # Test that we can pass an identity + sasl_kwargs = self.sasl_kwargs.copy() + sasl_kwargs.update({'identity': identity}) + sasl = PureSASLClient('localhost', mechanism=self.mechanism, **sasl_kwargs) + success, response = sasl.step(challenge) + self.assertTrue(success) + self.assertEqual(response, six.b(f'{identity}\x00{self.username}\x00{self.password}')) + self.assertIsInstance(response, six.binary_type) + self.assertTrue(sasl.complete) + + # Test that the sasl authorization_id has priority over identity + auth_id = 'user3' + sasl_kwargs.update({'authorization_id': auth_id}) + sasl = PureSASLClient('localhost', mechanism=self.mechanism, **sasl_kwargs) + success, response = sasl.step(challenge) + self.assertTrue(success) + self.assertEqual(response, six.b(f'{auth_id}\x00{self.username}\x00{self.password}')) + self.assertIsInstance(response, six.binary_type) + self.assertTrue(sasl.complete) + + def test_decode_encode(self): + msg = 'msg' + self.assertEqual(self.sasl.decode(msg), (True, msg)) + self.assertEqual(self.sasl.encode(msg), (True, msg)) + + +class ExternalMechanismTest(_BaseMechanismTests): + """Test 
case for the External SASL mechanisms""" + + mechanism = 'EXTERNAL' + + def test_step(self): + self.assertEqual(self.sasl.step(), (True, b'')) + + def test_decode_encode(self): + msg = 'msg' + self.assertEqual(self.sasl.decode(msg), (True, msg)) + self.assertEqual(self.sasl.encode(msg), (True, msg)) + + +@patch('puresasl.mechanisms.kerberos.authGSSClientStep') +@patch('puresasl.mechanisms.kerberos.authGSSClientResponse', return_value=base64.b64encode(six.b('some\x00 response'))) +class GSSAPIMechanismTest(_BaseMechanismTests): + """Test case for the GSSAPI SASL mechanism.""" + + mechanism = 'GSSAPI' + service = 'GSSAPI' + sasl_kwargs = {'service': service} + + @patch('puresasl.mechanisms.kerberos.authGSSClientWrap') + @patch('puresasl.mechanisms.kerberos.authGSSClientUnwrap') + def test_decode_encode(self, _inner1, _inner2, authGSSClientResponse, *args): + # bypassing step setup by setting qop directly + self.mechanism_class.qop = QOP.AUTH + msg = b'msg' + self.assertEqual(self.sasl.decode(msg), (True, msg)) + self.assertEqual(self.sasl.encode(msg), (True, msg)) + + # Test for behavior with different QOP like data integrity and confidentiality for Kerberos authentication + for qop in (QOP.AUTH_INT, QOP.AUTH_CONF): + self.mechanism_class.qop = qop + with patch('puresasl.mechanisms.kerberos.authGSSClientResponseConf', return_value=1): + self.assertEqual(self.sasl.decode(msg), (True, base64.b64decode(authGSSClientResponse.return_value))) + self.assertEqual(self.sasl.encode(msg), (True, base64.b64decode(authGSSClientResponse.return_value))) + if qop == QOP.AUTH_CONF: + with patch('puresasl.mechanisms.kerberos.authGSSClientResponseConf', return_value=0): + self.assertEqual(self.sasl.encode(msg), (False, None)) + self.assertEqual(self.sasl.getError(), 'Error: confidentiality requested, but not honored by the server.') + + def test_step_no_user(self, authGSSClientResponse, *args): + msg = six.b('whatever') + + # no user + self.assertEqual(self.sasl.step(msg), (True, base64.b64decode(authGSSClientResponse.return_value))) + with patch('puresasl.mechanisms.kerberos.authGSSClientResponse', return_value=''): + self.assertEqual(self.sasl.step(msg), (True, six.b(''))) + + username = 'username' + # with user; this has to be last because it sets mechanism.user + with patch('puresasl.mechanisms.kerberos.authGSSClientStep', return_value=kerberos.AUTH_GSS_COMPLETE): + with patch('puresasl.mechanisms.kerberos.authGSSClientUserName', return_value=six.b(username)): + self.assertEqual(self.sasl.step(msg), (True, six.b(''))) + self.assertEqual(self.mechanism_class.user, six.b(username)) + + @patch('puresasl.mechanisms.kerberos.authGSSClientUnwrap') + def test_step_qop(self, *args): + self.mechanism_class._have_negotiated_details = True + self.mechanism_class.user = 'user' + msg = six.b('msg') + self.assertEqual(self.sasl.step(msg), (False, None)) + self.assertEqual(self.sasl.getError(), 'Bad response from server') + + max_len = 100 + self.assertLess(max_len, self.sasl.max_buffer) + for i, qop in QOP.bit_map.items(): + qop_size = struct.pack('!i', i << 24 | max_len) + response = base64.b64encode(qop_size) + with patch('puresasl.mechanisms.kerberos.authGSSClientResponse', return_value=response): + with patch('puresasl.mechanisms.kerberos.authGSSClientWrap') as authGSSClientWrap: + self.mechanism_class.complete = False + self.assertEqual(self.sasl.step(msg), (True, qop_size)) + self.assertTrue(self.mechanism_class.complete) + self.assertEqual(self.mechanism_class.qop, qop) + 
self.assertEqual(self.mechanism_class.max_buffer, max_len) + + args = authGSSClientWrap.call_args[0] + out_data = args[1] + out = base64.b64decode(out_data) + self.assertEqual(out[:4], qop_size) + self.assertEqual(out[4:], six.b(self.mechanism_class.user)) + + +class CramMD5MechanismTest(_BaseMechanismTests): + """Test case for the CRAM-MD5 SASL mechanism.""" + + mechanism = 'CRAM-MD5' + username = 'user' + password = 'pass' + sasl_kwargs = {'username': username, 'password': password} + + def test_step(self): + success, response = self.sasl.step(None) + self.assertTrue(success) + self.assertIsNone(response) + challenge = six.b('msg') + hash = hmac.HMAC(key=six.b(self.password), digestmod=hashlib.md5) + hash.update(challenge) + success, response = self.sasl.step(challenge) + self.assertTrue(success) + self.assertIn(six.b(self.username), response) + self.assertIn(six.b(hash.hexdigest()), response) + self.assertIsInstance(response, six.binary_type) + self.assertTrue(self.sasl.complete) + + def test_decode_encode(self): + msg = 'msg' + self.assertEqual(self.sasl.decode(msg), (True, msg)) + self.assertEqual(self.sasl.encode(msg), (True, msg)) + + +class DigestMD5MechanismTest(_BaseMechanismTests): + """Test case for the DIGEST-MD5 SASL mechanism.""" + + mechanism = 'DIGEST-MD5' + username = 'user' + password = 'pass' + sasl_kwargs = {'username': username, 'password': password} + + def test_decode_encode(self): + msg = 'msg' + self.assertEqual(self.sasl.decode(msg), (True, msg)) + self.assertEqual(self.sasl.encode(msg), (True, msg)) + + def test_step_basic(self, *args): + pass + + def test_step(self): + """Test a SASL step with dummy challenge for DIGEST-MD5 mechanism.""" + testChallenge = ( + b'nonce="rmD6R8aMYVWH+/ih9HGBr3xNGAR6o2DUxpKlgDz6gUQ=",r' + b'ealm="example.org",qop="auth,auth-int,auth-conf",cipher="rc4-40,rc' + b'4-56,rc4,des,3des",maxbuf=65536,charset=utf-8,algorithm=md5-sess' + ) + result, response = self.sasl.step(testChallenge) + self.assertTrue(result) + self.assertIsNotNone(response) + + def test_step_server_answer(self): + """Test a SASL step with a proper server answer for DIGEST-MD5 mechanism.""" + sasl_kwargs = {'username': "chris", 'password': "secret"} + sasl = PureSASLClient('elwood.innosoft.com', + service="imap", + mechanism=self.mechanism, + mutual_auth=True, + **sasl_kwargs) + testChallenge = ( + b'utf-8,username="chris",realm="elwood.innosoft.com",' + b'nonce="OA6MG9tEQGm2hh",nc=00000001,cnonce="OA6MHXh6VqTrRk",' + b'digest-uri="imap/elwood.innosoft.com",' + b'response=d388dad90d4bbd760a152321f2143af7,qop=auth' + ) + sasl.step(testChallenge) + sasl._chosen_mech.cnonce = b"OA6MHXh6VqTrRk" + + serverResponse = ( + b'rspauth=ea40f60335c427b5527b84dbabcdfffd' + ) + sasl.step(serverResponse) + # assert that step choses the only supported QOP for for DIGEST-MD5 + self.assertEqual(self.sasl.qop, QOP.AUTH) diff --git a/setup.py b/setup.py index be593fc0..d141ea1b 100755 --- a/setup.py +++ b/setup.py @@ -46,6 +46,7 @@ def run_tests(self): 'presto': ['requests>=1.0.0'], 'trino': ['requests>=1.0.0'], 'hive': ['sasl>=0.2.1', 'thrift>=0.10.0', 'thrift_sasl>=0.1.0'], + 'hive_pure_sasl': ['pure-sasl>=0.6.2', 'thrift>=0.10.0', 'thrift_sasl>=0.1.0'], 'sqlalchemy': ['sqlalchemy>=1.3.0'], 'kerberos': ['requests_kerberos>=0.12.0'], }, @@ -56,6 +57,8 @@ def run_tests(self): 'requests>=1.0.0', 'requests_kerberos>=0.12.0', 'sasl>=0.2.1', + 'pure-sasl>=0.6.2', + 'kerberos>=1.3.0', 'sqlalchemy>=1.3.0', 'thrift>=0.10.0', ], From 5f0ee1f2ad558e120474b31ef065bb42457d1208 Mon Sep 17 00:00:00 
2001 From: Multazim Deshmukh <57723564+mdeshmu@users.noreply.github.com> Date: Sat, 8 Jul 2023 15:17:51 +0530 Subject: [PATCH 12/17] minimal changes for sqlalchemy 2.0 support (#457) --- README.rst | 18 ++++- pyhive/sqlalchemy_hive.py | 30 ++++++-- pyhive/sqlalchemy_presto.py | 28 ++++++-- pyhive/sqlalchemy_trino.py | 15 +++- pyhive/tests/sqlalchemy_test_case.py | 88 +++++++++++++++-------- pyhive/tests/test_sqlalchemy_hive.py | 98 ++++++++++++++++---------- pyhive/tests/test_sqlalchemy_presto.py | 14 ++-- pyhive/tests/test_sqlalchemy_trino.py | 93 ++++++++++++++++++++++++ 8 files changed, 293 insertions(+), 91 deletions(-) create mode 100644 pyhive/tests/test_sqlalchemy_trino.py diff --git a/README.rst b/README.rst index 89c54532..5afd746c 100644 --- a/README.rst +++ b/README.rst @@ -71,9 +71,11 @@ First install this package to register it with SQLAlchemy (see ``setup.py``). # Presto engine = create_engine('presto://localhost:8080/hive/default') # Trino - engine = create_engine('trino://localhost:8080/hive/default') + engine = create_engine('trino+pyhive://localhost:8080/hive/default') # Hive engine = create_engine('hive://localhost:10000/default') + + # SQLAlchemy < 2.0 logs = Table('my_awesome_data', MetaData(bind=engine), autoload=True) print select([func.count('*')], from_obj=logs).scalar() @@ -82,6 +84,20 @@ First install this package to register it with SQLAlchemy (see ``setup.py``). logs = Table('my_awesome_data', MetaData(bind=engine), autoload=True) print select([func.count('*')], from_obj=logs).scalar() + # SQLAlchemy >= 2.0 + metadata_obj = MetaData() + books = Table("books", metadata_obj, Column("id", Integer), Column("title", String), Column("primary_author", String)) + metadata_obj.create_all(engine) + inspector = inspect(engine) + inspector.get_columns('books') + + with engine.connect() as con: + data = [{ "id": 1, "title": "The Hobbit", "primary_author": "Tolkien" }, + { "id": 2, "title": "The Silmarillion", "primary_author": "Tolkien" }] + con.execute(books.insert(), data[0]) + result = con.execute(text("select * from books")) + print(result.fetchall()) + Note: query generation functionality is not exhaustive or fully tested, but there should be no problem with raw SQL. 
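
The "SQLAlchemy >= 2.0" snippet added to the README above leaves its imports implicit. Spelled out, a self-contained version looks roughly like the sketch below; the connection URL is a placeholder and the books table simply mirrors the README example.

    from sqlalchemy import (
        Column, Integer, MetaData, String, Table, create_engine, inspect, text,
    )

    engine = create_engine('hive://localhost:10000/default')  # placeholder URL

    metadata_obj = MetaData()
    books = Table(
        'books',
        metadata_obj,
        Column('id', Integer),
        Column('title', String),
        Column('primary_author', String),
    )
    metadata_obj.create_all(engine)
    print(inspect(engine).get_columns('books'))

    with engine.connect() as con:
        con.execute(books.insert(), {'id': 1, 'title': 'The Hobbit', 'primary_author': 'Tolkien'})
        print(con.execute(text('SELECT * FROM books')).fetchall())
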
diff --git a/pyhive/sqlalchemy_hive.py b/pyhive/sqlalchemy_hive.py index f39f1793..e2244525 100644 --- a/pyhive/sqlalchemy_hive.py +++ b/pyhive/sqlalchemy_hive.py @@ -13,11 +13,22 @@ import re from sqlalchemy import exc -from sqlalchemy import processors +from sqlalchemy.sql import text +try: + from sqlalchemy import processors +except ImportError: + # Required for SQLAlchemy>=2.0 + from sqlalchemy.engine import processors from sqlalchemy import types from sqlalchemy import util # TODO shouldn't use mysql type -from sqlalchemy.databases import mysql +try: + from sqlalchemy.databases import mysql + mysql_tinyinteger = mysql.MSTinyInteger +except ImportError: + # Required for SQLAlchemy>2.0 + from sqlalchemy.dialects import mysql + mysql_tinyinteger = mysql.base.MSTinyInteger from sqlalchemy.engine import default from sqlalchemy.sql import compiler from sqlalchemy.sql.compiler import SQLCompiler @@ -121,7 +132,7 @@ def __init__(self, dialect): _type_map = { 'boolean': types.Boolean, - 'tinyint': mysql.MSTinyInteger, + 'tinyint': mysql_tinyinteger, 'smallint': types.SmallInteger, 'int': types.Integer, 'bigint': types.BigInteger, @@ -247,10 +258,15 @@ class HiveDialect(default.DefaultDialect): supports_multivalues_insert = True type_compiler = HiveTypeCompiler supports_sane_rowcount = False + supports_statement_cache = False @classmethod def dbapi(cls): return hive + + @classmethod + def import_dbapi(cls): + return hive def create_connect_args(self, url): kwargs = { @@ -265,7 +281,7 @@ def create_connect_args(self, url): def get_schema_names(self, connection, **kw): # Equivalent to SHOW DATABASES - return [row[0] for row in connection.execute('SHOW SCHEMAS')] + return [row[0] for row in connection.execute(text('SHOW SCHEMAS'))] def get_view_names(self, connection, schema=None, **kw): # Hive does not provide functionality to query tableType @@ -280,7 +296,7 @@ def _get_table_columns(self, connection, table_name, schema): # Using DESCRIBE works but is uglier. try: # This needs the table name to be unescaped (no backticks). - rows = connection.execute('DESCRIBE {}'.format(full_table)).fetchall() + rows = connection.execute(text('DESCRIBE {}'.format(full_table))).fetchall() except exc.OperationalError as e: # Does the table exist? 
regex_fmt = r'TExecuteStatementResp.*SemanticException.*Table not found {}' @@ -296,7 +312,7 @@ def _get_table_columns(self, connection, table_name, schema): raise exc.NoSuchTableError(full_table) return rows - def has_table(self, connection, table_name, schema=None): + def has_table(self, connection, table_name, schema=None, **kw): try: self._get_table_columns(connection, table_name, schema) return True @@ -361,7 +377,7 @@ def get_table_names(self, connection, schema=None, **kw): query = 'SHOW TABLES' if schema: query += ' IN ' + self.identifier_preparer.quote_identifier(schema) - return [row[0] for row in connection.execute(query)] + return [row[0] for row in connection.execute(text(query))] def do_rollback(self, dbapi_connection): # No transactions for Hive diff --git a/pyhive/sqlalchemy_presto.py b/pyhive/sqlalchemy_presto.py index a199ebe1..bfe1ba04 100644 --- a/pyhive/sqlalchemy_presto.py +++ b/pyhive/sqlalchemy_presto.py @@ -9,11 +9,19 @@ from __future__ import unicode_literals import re +import sqlalchemy from sqlalchemy import exc from sqlalchemy import types from sqlalchemy import util # TODO shouldn't use mysql type -from sqlalchemy.databases import mysql +from sqlalchemy.sql import text +try: + from sqlalchemy.databases import mysql + mysql_tinyinteger = mysql.MSTinyInteger +except ImportError: + # Required for SQLAlchemy>=2.0 + from sqlalchemy.dialects import mysql + mysql_tinyinteger = mysql.base.MSTinyInteger from sqlalchemy.engine import default from sqlalchemy.sql import compiler from sqlalchemy.sql.compiler import SQLCompiler @@ -21,6 +29,7 @@ from pyhive import presto from pyhive.common import UniversalSet +sqlalchemy_version = float(re.search(r"^([\d]+\.[\d]+)\..+", sqlalchemy.__version__).group(1)) class PrestoIdentifierPreparer(compiler.IdentifierPreparer): # Just quote everything to make things simpler / easier to upgrade @@ -29,7 +38,7 @@ class PrestoIdentifierPreparer(compiler.IdentifierPreparer): _type_map = { 'boolean': types.Boolean, - 'tinyint': mysql.MSTinyInteger, + 'tinyint': mysql_tinyinteger, 'smallint': types.SmallInteger, 'integer': types.Integer, 'bigint': types.BigInteger, @@ -80,6 +89,7 @@ class PrestoDialect(default.DefaultDialect): supports_multivalues_insert = True supports_unicode_statements = True supports_unicode_binds = True + supports_statement_cache = False returns_unicode_strings = True description_encoding = None supports_native_boolean = True @@ -88,6 +98,10 @@ class PrestoDialect(default.DefaultDialect): @classmethod def dbapi(cls): return presto + + @classmethod + def import_dbapi(cls): + return presto def create_connect_args(self, url): db_parts = (url.database or 'hive').split('/') @@ -108,14 +122,14 @@ def create_connect_args(self, url): return [], kwargs def get_schema_names(self, connection, **kw): - return [row.Schema for row in connection.execute('SHOW SCHEMAS')] + return [row.Schema for row in connection.execute(text('SHOW SCHEMAS'))] def _get_table_columns(self, connection, table_name, schema): full_table = self.identifier_preparer.quote_identifier(table_name) if schema: full_table = self.identifier_preparer.quote_identifier(schema) + '.' + full_table try: - return connection.execute('SHOW COLUMNS FROM {}'.format(full_table)) + return connection.execute(text('SHOW COLUMNS FROM {}'.format(full_table))) except (presto.DatabaseError, exc.DatabaseError) as e: # Normally SQLAlchemy should wrap this exception in sqlalchemy.exc.DatabaseError, which # it successfully does in the Hive version. 
The difference with Presto is that this @@ -134,7 +148,7 @@ def _get_table_columns(self, connection, table_name, schema): else: raise - def has_table(self, connection, table_name, schema=None): + def has_table(self, connection, table_name, schema=None, **kw): try: self._get_table_columns(connection, table_name, schema) return True @@ -176,6 +190,8 @@ def get_indexes(self, connection, table_name, schema=None, **kw): # - a boolean column named "Partition Key" # - a string in the "Comment" column # - a string in the "Extra" column + if sqlalchemy_version >= 1.4: + row = row._mapping is_partition_key = ( (part_key in row and row[part_key]) or row['Comment'].startswith(part_key) @@ -192,7 +208,7 @@ def get_table_names(self, connection, schema=None, **kw): query = 'SHOW TABLES' if schema: query += ' FROM ' + self.identifier_preparer.quote_identifier(schema) - return [row.Table for row in connection.execute(query)] + return [row.Table for row in connection.execute(text(query))] def do_rollback(self, dbapi_connection): # No transactions for Presto diff --git a/pyhive/sqlalchemy_trino.py b/pyhive/sqlalchemy_trino.py index 4b2b3698..11be2a6c 100644 --- a/pyhive/sqlalchemy_trino.py +++ b/pyhive/sqlalchemy_trino.py @@ -13,7 +13,13 @@ from sqlalchemy import types from sqlalchemy import util # TODO shouldn't use mysql type -from sqlalchemy.databases import mysql +try: + from sqlalchemy.databases import mysql + mysql_tinyinteger = mysql.MSTinyInteger +except ImportError: + # Required for SQLAlchemy>=2.0 + from sqlalchemy.dialects import mysql + mysql_tinyinteger = mysql.base.MSTinyInteger from sqlalchemy.engine import default from sqlalchemy.sql import compiler from sqlalchemy.sql.compiler import SQLCompiler @@ -28,7 +34,7 @@ class TrinoIdentifierPreparer(PrestoIdentifierPreparer): _type_map = { 'boolean': types.Boolean, - 'tinyint': mysql.MSTinyInteger, + 'tinyint': mysql_tinyinteger, 'smallint': types.SmallInteger, 'integer': types.Integer, 'bigint': types.BigInteger, @@ -67,7 +73,12 @@ def visit_TEXT(self, type_, **kw): class TrinoDialect(PrestoDialect): name = 'trino' + supports_statement_cache = False @classmethod def dbapi(cls): return trino + + @classmethod + def import_dbapi(cls): + return trino diff --git a/pyhive/tests/sqlalchemy_test_case.py b/pyhive/tests/sqlalchemy_test_case.py index 652e05f4..db89d57b 100644 --- a/pyhive/tests/sqlalchemy_test_case.py +++ b/pyhive/tests/sqlalchemy_test_case.py @@ -3,6 +3,7 @@ from __future__ import unicode_literals import abc +import re import contextlib import functools @@ -14,8 +15,10 @@ from sqlalchemy.schema import Index from sqlalchemy.schema import MetaData from sqlalchemy.schema import Table -from sqlalchemy.sql import expression +from sqlalchemy.sql import expression, text +from sqlalchemy import String +sqlalchemy_version = float(re.search(r"^([\d]+\.[\d]+)\..+", sqlalchemy.__version__).group(1)) def with_engine_connection(fn): """Pass a connection to the given function and handle cleanup. 
@@ -32,19 +35,33 @@ def wrapped_fn(self, *args, **kwargs): engine.dispose() return wrapped_fn +def reflect_table(engine, connection, table, include_columns, exclude_columns, resolve_fks): + if sqlalchemy_version >= 1.4: + insp = sqlalchemy.inspect(engine) + insp.reflect_table( + table, + include_columns=include_columns, + exclude_columns=exclude_columns, + resolve_fks=resolve_fks, + ) + else: + engine.dialect.reflecttable( + connection, table, include_columns=include_columns, + exclude_columns=exclude_columns, resolve_fks=resolve_fks) + class SqlAlchemyTestCase(with_metaclass(abc.ABCMeta, object)): @with_engine_connection def test_basic_query(self, engine, connection): - rows = connection.execute('SELECT * FROM one_row').fetchall() + rows = connection.execute(text('SELECT * FROM one_row')).fetchall() self.assertEqual(len(rows), 1) self.assertEqual(rows[0].number_of_rows, 1) # number_of_rows is the column name self.assertEqual(len(rows[0]), 1) @with_engine_connection def test_one_row_complex_null(self, engine, connection): - one_row_complex_null = Table('one_row_complex_null', MetaData(bind=engine), autoload=True) - rows = one_row_complex_null.select().execute().fetchall() + one_row_complex_null = Table('one_row_complex_null', MetaData(), autoload_with=engine) + rows = connection.execute(one_row_complex_null.select()).fetchall() self.assertEqual(len(rows), 1) self.assertEqual(list(rows[0]), [None] * len(rows[0])) @@ -53,27 +70,26 @@ def test_reflect_no_such_table(self, engine, connection): """reflecttable should throw an exception on an invalid table""" self.assertRaises( NoSuchTableError, - lambda: Table('this_does_not_exist', MetaData(bind=engine), autoload=True)) + lambda: Table('this_does_not_exist', MetaData(), autoload_with=engine)) self.assertRaises( NoSuchTableError, - lambda: Table('this_does_not_exist', MetaData(bind=engine), - schema='also_does_not_exist', autoload=True)) + lambda: Table('this_does_not_exist', MetaData(schema='also_does_not_exist'), autoload_with=engine)) @with_engine_connection def test_reflect_include_columns(self, engine, connection): """When passed include_columns, reflecttable should filter out other columns""" - one_row_complex = Table('one_row_complex', MetaData(bind=engine)) - engine.dialect.reflecttable( - connection, one_row_complex, include_columns=['int'], + + one_row_complex = Table('one_row_complex', MetaData()) + reflect_table(engine, connection, one_row_complex, include_columns=['int'], exclude_columns=[], resolve_fks=True) + self.assertEqual(len(one_row_complex.c), 1) self.assertIsNotNone(one_row_complex.c.int) self.assertRaises(AttributeError, lambda: one_row_complex.c.tinyint) @with_engine_connection def test_reflect_with_schema(self, engine, connection): - dummy = Table('dummy_table', MetaData(bind=engine), schema='pyhive_test_database', - autoload=True) + dummy = Table('dummy_table', MetaData(schema='pyhive_test_database'), autoload_with=engine) self.assertEqual(len(dummy.c), 1) self.assertIsNotNone(dummy.c.a) @@ -81,22 +97,22 @@ def test_reflect_with_schema(self, engine, connection): @with_engine_connection def test_reflect_partitions(self, engine, connection): """reflecttable should get the partition column as an index""" - many_rows = Table('many_rows', MetaData(bind=engine), autoload=True) + many_rows = Table('many_rows', MetaData(), autoload_with=engine) self.assertEqual(len(many_rows.c), 2) self.assertEqual(repr(many_rows.indexes), repr({Index('partition', many_rows.c.b)})) - many_rows = Table('many_rows', MetaData(bind=engine)) - 
engine.dialect.reflecttable( - connection, many_rows, include_columns=['a'], + many_rows = Table('many_rows', MetaData()) + reflect_table(engine, connection, many_rows, include_columns=['a'], exclude_columns=[], resolve_fks=True) + self.assertEqual(len(many_rows.c), 1) self.assertFalse(many_rows.c.a.index) self.assertFalse(many_rows.indexes) - many_rows = Table('many_rows', MetaData(bind=engine)) - engine.dialect.reflecttable( - connection, many_rows, include_columns=['b'], + many_rows = Table('many_rows', MetaData()) + reflect_table(engine, connection, many_rows, include_columns=['b'], exclude_columns=[], resolve_fks=True) + self.assertEqual(len(many_rows.c), 1) self.assertEqual(repr(many_rows.indexes), repr({Index('partition', many_rows.c.b)})) @@ -104,11 +120,15 @@ def test_reflect_partitions(self, engine, connection): def test_unicode(self, engine, connection): """Verify that unicode strings make it through SQLAlchemy and the backend""" unicode_str = "中文" - one_row = Table('one_row', MetaData(bind=engine)) - returned_str = sqlalchemy.select( - [expression.bindparam("好", unicode_str)], - from_obj=one_row, - ).scalar() + one_row = Table('one_row', MetaData()) + + if sqlalchemy_version >= 1.4: + returned_str = connection.execute(sqlalchemy.select( + expression.bindparam("好", unicode_str, type_=String())).select_from(one_row)).scalar() + else: + returned_str = connection.execute(sqlalchemy.select([ + expression.bindparam("好", unicode_str, type_=String())]).select_from(one_row)).scalar() + self.assertEqual(returned_str, unicode_str) @with_engine_connection @@ -133,13 +153,21 @@ def test_get_table_names(self, engine, connection): @with_engine_connection def test_has_table(self, engine, connection): - self.assertTrue(Table('one_row', MetaData(bind=engine)).exists()) - self.assertFalse(Table('this_table_does_not_exist', MetaData(bind=engine)).exists()) + if sqlalchemy_version >= 1.4: + insp = sqlalchemy.inspect(engine) + self.assertTrue(insp.has_table("one_row")) + self.assertFalse(insp.has_table("this_table_does_not_exist")) + else: + self.assertTrue(Table('one_row', MetaData(bind=engine)).exists()) + self.assertFalse(Table('this_table_does_not_exist', MetaData(bind=engine)).exists()) @with_engine_connection def test_char_length(self, engine, connection): - one_row_complex = Table('one_row_complex', MetaData(bind=engine), autoload=True) - result = sqlalchemy.select([ - sqlalchemy.func.char_length(one_row_complex.c.string) - ]).execute().scalar() + one_row_complex = Table('one_row_complex', MetaData(), autoload_with=engine) + + if sqlalchemy_version >= 1.4: + result = connection.execute(sqlalchemy.select(sqlalchemy.func.char_length(one_row_complex.c.string))).scalar() + else: + result = connection.execute(sqlalchemy.select([sqlalchemy.func.char_length(one_row_complex.c.string)])).scalar() + self.assertEqual(result, len('a string')) diff --git a/pyhive/tests/test_sqlalchemy_hive.py b/pyhive/tests/test_sqlalchemy_hive.py index 1ff0e817..790bec4c 100644 --- a/pyhive/tests/test_sqlalchemy_hive.py +++ b/pyhive/tests/test_sqlalchemy_hive.py @@ -4,6 +4,7 @@ from pyhive.sqlalchemy_hive import HiveDate from pyhive.sqlalchemy_hive import HiveDecimal from pyhive.sqlalchemy_hive import HiveTimestamp +from sqlalchemy.exc import NoSuchTableError, OperationalError from pyhive.tests.sqlalchemy_test_case import SqlAlchemyTestCase from pyhive.tests.sqlalchemy_test_case import with_engine_connection from sqlalchemy import types @@ -11,11 +12,15 @@ from sqlalchemy.schema import Column from sqlalchemy.schema 
import MetaData from sqlalchemy.schema import Table +from sqlalchemy.sql import text import contextlib import datetime import decimal import sqlalchemy.types import unittest +import re + +sqlalchemy_version = float(re.search(r"^([\d]+\.[\d]+)\..+", sqlalchemy.__version__).group(1)) _ONE_ROW_COMPLEX_CONTENTS = [ True, @@ -64,7 +69,11 @@ def test_dotted_column_names(self, engine, connection): """When Hive returns a dotted column name, both the non-dotted version should be available as an attribute, and the dotted version should remain available as a key. """ - row = connection.execute('SELECT * FROM one_row').fetchone() + row = connection.execute(text('SELECT * FROM one_row')).fetchone() + + if sqlalchemy_version >= 1.4: + row = row._mapping + assert row.keys() == ['number_of_rows'] assert 'number_of_rows' in row assert row.number_of_rows == 1 @@ -76,20 +85,33 @@ def test_dotted_column_names(self, engine, connection): def test_dotted_column_names_raw(self, engine, connection): """When Hive returns a dotted column name, and raw mode is on, nothing should be modified. """ - row = connection.execution_options(hive_raw_colnames=True) \ - .execute('SELECT * FROM one_row').fetchone() + row = connection.execution_options(hive_raw_colnames=True).execute(text('SELECT * FROM one_row')).fetchone() + + if sqlalchemy_version >= 1.4: + row = row._mapping + assert row.keys() == ['one_row.number_of_rows'] assert 'number_of_rows' not in row assert getattr(row, 'one_row.number_of_rows') == 1 assert row['one_row.number_of_rows'] == 1 + @with_engine_connection + def test_reflect_no_such_table(self, engine, connection): + """reflecttable should throw an exception on an invalid table""" + self.assertRaises( + NoSuchTableError, + lambda: Table('this_does_not_exist', MetaData(), autoload_with=engine)) + self.assertRaises( + OperationalError, + lambda: Table('this_does_not_exist', MetaData(schema="also_does_not_exist"), autoload_with=engine)) + @with_engine_connection def test_reflect_select(self, engine, connection): """reflecttable should be able to fill in a table from the name""" - one_row_complex = Table('one_row_complex', MetaData(bind=engine), autoload=True) + one_row_complex = Table('one_row_complex', MetaData(), autoload_with=engine) self.assertEqual(len(one_row_complex.c), 15) self.assertIsInstance(one_row_complex.c.string, Column) - row = one_row_complex.select().execute().fetchone() + row = connection.execute(one_row_complex.select()).fetchone() self.assertEqual(list(row), _ONE_ROW_COMPLEX_CONTENTS) # TODO some of these types could be filled in better @@ -112,15 +134,15 @@ def test_reflect_select(self, engine, connection): @with_engine_connection def test_type_map(self, engine, connection): """sqlalchemy should use the dbapi_type_map to infer types from raw queries""" - row = connection.execute('SELECT * FROM one_row_complex').fetchone() + row = connection.execute(text('SELECT * FROM one_row_complex')).fetchone() self.assertListEqual(list(row), _ONE_ROW_COMPLEX_CONTENTS) @with_engine_connection def test_reserved_words(self, engine, connection): """Hive uses backticks""" # Use keywords for the table/column name - fake_table = Table('select', MetaData(bind=engine), Column('map', sqlalchemy.types.String)) - query = str(fake_table.select(fake_table.c.map == 'a')) + fake_table = Table('select', MetaData(), Column('map', sqlalchemy.types.String)) + query = str(fake_table.select().where(fake_table.c.map == 'a').compile(engine)) self.assertIn('`select`', query) self.assertIn('`map`', query) 
self.assertNotIn('"select"', query) @@ -132,12 +154,12 @@ def test_switch_database(self): with contextlib.closing(engine.connect()) as connection: self.assertIn( ('dummy_table',), - connection.execute('SHOW TABLES').fetchall() + connection.execute(text('SHOW TABLES')).fetchall() ) - connection.execute('USE default') + connection.execute(text('USE default')) self.assertIn( ('one_row',), - connection.execute('SHOW TABLES').fetchall() + connection.execute(text('SHOW TABLES')).fetchall() ) finally: engine.dispose() @@ -160,13 +182,13 @@ def test_lots_of_types(self, engine, connection): cols.append(Column('hive_date', HiveDate)) cols.append(Column('hive_decimal', HiveDecimal)) cols.append(Column('hive_timestamp', HiveTimestamp)) - table = Table('test_table', MetaData(bind=engine), *cols, schema='pyhive_test_database') - table.drop(checkfirst=True) - table.create() - connection.execute('SET mapred.job.tracker=local') - connection.execute('USE pyhive_test_database') + table = Table('test_table', MetaData(schema='pyhive_test_database'), *cols,) + table.drop(checkfirst=True, bind=connection) + table.create(bind=connection) + connection.execute(text('SET mapred.job.tracker=local')) + connection.execute(text('USE pyhive_test_database')) big_number = 10 ** 10 - 1 - connection.execute(""" + connection.execute(text(""" INSERT OVERWRITE TABLE test_table SELECT 1, "a", "a", "a", "a", "a", 0.1, @@ -175,41 +197,39 @@ def test_lots_of_types(self, engine, connection): "a", 1, 1, 0.1, 0.1, 0, 0, 0, "a", false, "a", "a", - 0, %d, 123 + 2000 + 0, :big_number, 123 + 2000 FROM default.one_row - """, big_number) - row = connection.execute(table.select()).fetchone() - self.assertEqual(row.hive_date, datetime.date(1970, 1, 1)) + """), {"big_number": big_number}) + row = connection.execute(text("select * from test_table")).fetchone() + self.assertEqual(row.hive_date, datetime.datetime(1970, 1, 1, 0, 0)) self.assertEqual(row.hive_decimal, decimal.Decimal(big_number)) self.assertEqual(row.hive_timestamp, datetime.datetime(1970, 1, 1, 0, 0, 2, 123000)) - table.drop() + table.drop(bind=connection) @with_engine_connection def test_insert_select(self, engine, connection): - one_row = Table('one_row', MetaData(bind=engine), autoload=True) - table = Table('insert_test', MetaData(bind=engine), - Column('a', sqlalchemy.types.Integer), - schema='pyhive_test_database') - table.drop(checkfirst=True) - table.create() - connection.execute('SET mapred.job.tracker=local') + one_row = Table('one_row', MetaData(), autoload_with=engine) + table = Table('insert_test', MetaData(schema='pyhive_test_database'), + Column('a', sqlalchemy.types.Integer)) + table.drop(checkfirst=True, bind=connection) + table.create(bind=connection) + connection.execute(text('SET mapred.job.tracker=local')) # NOTE(jing) I'm stuck on a version of Hive without INSERT ... 
VALUES connection.execute(table.insert().from_select(['a'], one_row.select())) - - result = table.select().execute().fetchall() + + result = connection.execute(table.select()).fetchall() expected = [(1,)] self.assertEqual(result, expected) @with_engine_connection def test_insert_values(self, engine, connection): - table = Table('insert_test', MetaData(bind=engine), - Column('a', sqlalchemy.types.Integer), - schema='pyhive_test_database') - table.drop(checkfirst=True) - table.create() - connection.execute(table.insert([{'a': 1}, {'a': 2}])) - - result = table.select().execute().fetchall() + table = Table('insert_test', MetaData(schema='pyhive_test_database'), + Column('a', sqlalchemy.types.Integer),) + table.drop(checkfirst=True, bind=connection) + table.create(bind=connection) + connection.execute(table.insert().values([{'a': 1}, {'a': 2}])) + + result = connection.execute(table.select()).fetchall() expected = [(1,), (2,)] self.assertEqual(result, expected) diff --git a/pyhive/tests/test_sqlalchemy_presto.py b/pyhive/tests/test_sqlalchemy_presto.py index a01e4a35..58a5c034 100644 --- a/pyhive/tests/test_sqlalchemy_presto.py +++ b/pyhive/tests/test_sqlalchemy_presto.py @@ -8,7 +8,9 @@ from sqlalchemy.schema import Column from sqlalchemy.schema import MetaData from sqlalchemy.schema import Table +from sqlalchemy.sql import text from sqlalchemy.types import String +from decimal import Decimal import contextlib import unittest @@ -27,11 +29,11 @@ def test_bad_format(self): @with_engine_connection def test_reflect_select(self, engine, connection): """reflecttable should be able to fill in a table from the name""" - one_row_complex = Table('one_row_complex', MetaData(bind=engine), autoload=True) + one_row_complex = Table('one_row_complex', MetaData(), autoload_with=engine) # Presto ignores the union column self.assertEqual(len(one_row_complex.c), 15 - 1) self.assertIsInstance(one_row_complex.c.string, Column) - rows = one_row_complex.select().execute().fetchall() + rows = connection.execute(one_row_complex.select()).fetchall() self.assertEqual(len(rows), 1) self.assertEqual(list(rows[0]), [ True, @@ -48,7 +50,7 @@ def test_reflect_select(self, engine, connection): {"1": 2, "3": 4}, # Presto converts all keys to strings so that they're valid JSON [1, 2], # struct is returned as a list of elements # '{0:1}', - '0.1', + Decimal('0.1'), ]) # TODO some of these types could be filled in better @@ -71,7 +73,7 @@ def test_url_default(self): engine = create_engine('presto://localhost:8080/hive') try: with contextlib.closing(engine.connect()) as connection: - self.assertEqual(connection.execute('SELECT 1 AS foobar FROM one_row').scalar(), 1) + self.assertEqual(connection.execute(text('SELECT 1 AS foobar FROM one_row')).scalar(), 1) finally: engine.dispose() @@ -79,8 +81,8 @@ def test_url_default(self): def test_reserved_words(self, engine, connection): """Presto uses double quotes, not backticks""" # Use keywords for the table/column name - fake_table = Table('select', MetaData(bind=engine), Column('current_timestamp', String)) - query = str(fake_table.select(fake_table.c.current_timestamp == 'a')) + fake_table = Table('select', MetaData(), Column('current_timestamp', String)) + query = str(fake_table.select().where(fake_table.c.current_timestamp == 'a').compile(engine)) self.assertIn('"select"', query) self.assertIn('"current_timestamp"', query) self.assertNotIn('`select`', query) diff --git a/pyhive/tests/test_sqlalchemy_trino.py b/pyhive/tests/test_sqlalchemy_trino.py new file mode 100644 index 
00000000..c929f941 --- /dev/null +++ b/pyhive/tests/test_sqlalchemy_trino.py @@ -0,0 +1,93 @@ +from sqlalchemy.engine import create_engine +from pyhive.tests.sqlalchemy_test_case import SqlAlchemyTestCase +from pyhive.tests.sqlalchemy_test_case import with_engine_connection +from sqlalchemy.exc import NoSuchTableError, DatabaseError +from sqlalchemy.schema import MetaData, Table, Column +from sqlalchemy.types import String +from sqlalchemy.sql import text +from sqlalchemy import types +from decimal import Decimal + +import unittest +import contextlib + + +class TestSqlAlchemyTrino(unittest.TestCase, SqlAlchemyTestCase): + def create_engine(self): + return create_engine('trino+pyhive://localhost:18080/hive/default?source={}'.format(self.id())) + + def test_bad_format(self): + self.assertRaises( + ValueError, + lambda: create_engine('trino+pyhive://localhost:18080/hive/default/what'), + ) + + @with_engine_connection + def test_reflect_select(self, engine, connection): + """reflecttable should be able to fill in a table from the name""" + one_row_complex = Table('one_row_complex', MetaData(), autoload_with=engine) + # Presto ignores the union column + self.assertEqual(len(one_row_complex.c), 15 - 1) + self.assertIsInstance(one_row_complex.c.string, Column) + rows = connection.execute(one_row_complex.select()).fetchall() + self.assertEqual(len(rows), 1) + self.assertEqual(list(rows[0]), [ + True, + 127, + 32767, + 2147483647, + 9223372036854775807, + 0.5, + 0.25, + 'a string', + '1970-01-01 00:00:00.000', + b'123', + [1, 2], + {"1": 2, "3": 4}, + [1, 2], + Decimal('0.1'), + ]) + + self.assertIsInstance(one_row_complex.c.boolean.type, types.Boolean) + self.assertIsInstance(one_row_complex.c.tinyint.type, types.Integer) + self.assertIsInstance(one_row_complex.c.smallint.type, types.Integer) + self.assertIsInstance(one_row_complex.c.int.type, types.Integer) + self.assertIsInstance(one_row_complex.c.bigint.type, types.BigInteger) + self.assertIsInstance(one_row_complex.c.float.type, types.Float) + self.assertIsInstance(one_row_complex.c.double.type, types.Float) + self.assertIsInstance(one_row_complex.c.string.type, String) + self.assertIsInstance(one_row_complex.c.timestamp.type, types.NullType) + self.assertIsInstance(one_row_complex.c.binary.type, types.VARBINARY) + self.assertIsInstance(one_row_complex.c.array.type, types.NullType) + self.assertIsInstance(one_row_complex.c.map.type, types.NullType) + self.assertIsInstance(one_row_complex.c.struct.type, types.NullType) + self.assertIsInstance(one_row_complex.c.decimal.type, types.NullType) + + @with_engine_connection + def test_reflect_no_such_table(self, engine, connection): + """reflecttable should throw an exception on an invalid table""" + self.assertRaises( + NoSuchTableError, + lambda: Table('this_does_not_exist', MetaData(), autoload_with=engine)) + self.assertRaises( + DatabaseError, + lambda: Table('this_does_not_exist', MetaData(schema="also_does_not_exist"), autoload_with=engine)) + + def test_url_default(self): + engine = create_engine('trino+pyhive://localhost:18080/hive') + try: + with contextlib.closing(engine.connect()) as connection: + self.assertEqual(connection.execute(text('SELECT 1 AS foobar FROM one_row')).scalar(), 1) + finally: + engine.dispose() + + @with_engine_connection + def test_reserved_words(self, engine, connection): + """Trino uses double quotes, not backticks""" + # Use keywords for the table/column name + fake_table = Table('select', MetaData(), Column('current_timestamp', String)) + query = 
str(fake_table.select().where(fake_table.c.current_timestamp == 'a').compile(engine)) + self.assertIn('"select"', query) + self.assertIn('"current_timestamp"', query) + self.assertNotIn('`select`', query) + self.assertNotIn('`current_timestamp`', query) From d4ae481675ac5588ba9101596fa26f22ef0e77c4 Mon Sep 17 00:00:00 2001 From: Multazim Deshmukh <57723564+mdeshmu@users.noreply.github.com> Date: Wed, 12 Jul 2023 16:09:00 +0530 Subject: [PATCH 13/17] update readme to reflect recent changes (#459) --- README.rst | 41 ++++++++++++++++++++++++++++++++--------- 1 file changed, 32 insertions(+), 9 deletions(-) diff --git a/README.rst b/README.rst index 5afd746c..bb10e607 100644 --- a/README.rst +++ b/README.rst @@ -14,8 +14,8 @@ PyHive ====== PyHive is a collection of Python `DB-API `_ and -`SQLAlchemy `_ interfaces for `Presto `_ and -`Hive `_. +`SQLAlchemy `_ interfaces for `Presto `_ , +`Hive `_ and `Trino `_. Usage ===== @@ -25,7 +25,7 @@ DB-API .. code-block:: python from pyhive import presto # or import hive or import trino - cursor = presto.connect('localhost').cursor() + cursor = presto.connect('localhost').cursor() # or use hive.connect or use trino.connect cursor.execute('SELECT * FROM my_awesome_data LIMIT 10') print cursor.fetchone() print cursor.fetchall() @@ -61,7 +61,7 @@ In Python 3.7 `async` became a keyword; you can use `async_` instead: SQLAlchemy ---------- -First install this package to register it with SQLAlchemy (see ``setup.py``). +First install this package to register it with SQLAlchemy, see ``entry_points`` in ``setup.py``. .. code-block:: python @@ -117,7 +117,7 @@ Passing session configuration 'session_props': {'query_max_run_time': '1234m'}} ) create_engine( - 'trino://user@host:443/hive', + 'trino+pyhive://user@host:443/hive', connect_args={'protocol': 'https', 'session_props': {'query_max_run_time': '1234m'}} ) @@ -136,15 +136,18 @@ Requirements Install using -- ``pip install 'pyhive[hive]'`` for the Hive interface and -- ``pip install 'pyhive[presto]'`` for the Presto interface. +- ``pip install 'pyhive[hive]'`` or ``pip install 'pyhive[hive_pure_sasl]'`` for the Hive interface +- ``pip install 'pyhive[presto]'`` for the Presto interface - ``pip install 'pyhive[trino]'`` for the Trino interface +Note: ``'pyhive[hive]'`` extras uses `sasl `_ that doesn't support Python 3.11, See `github issue `_. +Hence PyHive also supports `pure-sasl `_ via additional extras ``'pyhive[hive_pure_sasl]'`` which support Python 3.11. + PyHive works with - Python 2.7 / Python 3 -- For Presto: Presto install -- For Trino: Trino install +- For Presto: `Presto installation `_ +- For Trino: `Trino installation `_ - For Hive: `HiveServer2 `_ daemon Changelog @@ -162,6 +165,26 @@ Contributing - We prefer having a small number of generic features over a large number of specialized, inflexible features. For example, the Presto code takes an arbitrary ``requests_session`` argument for customizing HTTP calls, as opposed to having a separate parameter/branch for each ``requests`` option. +Tips for test environment setup +================================ +You can setup test environment by following ``.travis.yaml`` in this repository. It uses `Cloudera's CDH 5 `_ which requires username and password for download. +It may not be feasible for everyone to get those credentials. Hence below are alternative instructions to setup test environment. + +You can clone `this repository `_ which has Docker Compose setup for Presto and Hive. 
+You can add below lines to its docker-compose.yaml to start Trino in same environment:: + + trino: + image: trinodb/trino:351 + ports: + - "18080:18080" + volumes: + - ./trino:/etc/trino + +Note: ``./trino`` for docker volume defined above is `trino config from PyHive repository `_ + +Then run:: + docker-compose up -d + Testing ======= .. image:: https://travis-ci.org/dropbox/PyHive.svg From 486eaefdea1326bd5f63a5dd4734c2646cf0bf84 Mon Sep 17 00:00:00 2001 From: Bogdan Date: Thu, 30 May 2024 15:26:26 -0700 Subject: [PATCH 14/17] Update README.rst (#475) --- README.rst | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/README.rst b/README.rst index bb10e607..1f4db670 100644 --- a/README.rst +++ b/README.rst @@ -1,5 +1,7 @@ ================================ -Project is currently unsupported +Pyhive project has been donated to Apache Kyuubi. + +You can follow it's development and report any issues you are experiencing here: https://github.com/apache/kyuubi/tree/master/python/pyhive ================================ From ac09074a652fd50e10b57a7f0bbc4f6410961301 Mon Sep 17 00:00:00 2001 From: Bogdan Date: Thu, 30 May 2024 15:35:22 -0700 Subject: [PATCH 15/17] Update README.rst (#476) --- README.rst | 30 ++++++++++++++---------------- 1 file changed, 14 insertions(+), 16 deletions(-) diff --git a/README.rst b/README.rst index 1f4db670..cbdcce10 100644 --- a/README.rst +++ b/README.rst @@ -1,26 +1,24 @@ -================================ -Pyhive project has been donated to Apache Kyuubi. +======================================================== +PyHive project has been donated to Apache Kyuubi +======================================================== You can follow it's development and report any issues you are experiencing here: https://github.com/apache/kyuubi/tree/master/python/pyhive -================================ +Legacy notes / instructions +=========================== -.. image:: https://travis-ci.org/dropbox/PyHive.svg?branch=master - :target: https://travis-ci.org/dropbox/PyHive -.. image:: https://img.shields.io/codecov/c/github/dropbox/PyHive.svg - -====== PyHive -====== +********** + PyHive is a collection of Python `DB-API `_ and `SQLAlchemy `_ interfaces for `Presto `_ , `Hive `_ and `Trino `_. Usage -===== +********** DB-API ------ @@ -134,7 +132,7 @@ Passing session configuration ) Requirements -============ +************ Install using @@ -153,11 +151,11 @@ PyHive works with - For Hive: `HiveServer2 `_ daemon Changelog -========= +********* See https://github.com/dropbox/PyHive/releases. Contributing -============ +************ - Please fill out the Dropbox Contributor License Agreement at https://opensource.dropbox.com/cla/ and note this in your pull request. - Changes must come with tests, with the exception of trivial things like fixing comments. See .travis.yml for the test environment setup. - Notes on project scope: @@ -168,7 +166,7 @@ Contributing For example, the Presto code takes an arbitrary ``requests_session`` argument for customizing HTTP calls, as opposed to having a separate parameter/branch for each ``requests`` option. Tips for test environment setup -================================ +**************************************** You can setup test environment by following ``.travis.yaml`` in this repository. It uses `Cloudera's CDH 5 `_ which requires username and password for download. It may not be feasible for everyone to get those credentials. Hence below are alternative instructions to setup test environment. 
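Once those services are running, the Trino container can be smoke-tested through the same SQLAlchemy dialect the test suite uses — a minimal sketch, assuming Trino is listening on ``localhost:18080`` and the test fixtures (for example the ``one_row`` table) have been loaded:

.. code-block:: python

    from sqlalchemy import create_engine, text

    # Same connection URL the Trino tests in this repository use.
    engine = create_engine('trino+pyhive://localhost:18080/hive/default')
    try:
        with engine.connect() as connection:
            # one_row is one of the fixture tables created for the test suite
            print(connection.execute(text('SELECT 1 AS foobar FROM one_row')).scalar())
    finally:
        engine.dispose()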
@@ -188,7 +186,7 @@ Then run:: docker-compose up -d Testing -======= +******* .. image:: https://travis-ci.org/dropbox/PyHive.svg :target: https://travis-ci.org/dropbox/PyHive .. image:: http://codecov.io/github/dropbox/PyHive/coverage.svg?branch=master @@ -207,7 +205,7 @@ WARNING: This drops/creates tables named ``one_row``, ``one_row_complex``, and ` database called ``pyhive_test_database``. Updating TCLIService -==================== +******************** The TCLIService module is autogenerated using a ``TCLIService.thrift`` file. To update it, the ``generate.py`` file can be used: ``python generate.py ``. When left blank, the From 9ec5bab4bc2b0e3b552703c74b9d5fd234cb31ef Mon Sep 17 00:00:00 2001 From: Beto Dealmeida Date: Wed, 7 Aug 2024 12:30:50 -0400 Subject: [PATCH 16/17] feat: JWT support --- pyhive/presto.py | 176 ++++++++++++------ pyhive/tests/test_presto.py | 241 ++++++++++++++++--------- pyhive/tests/test_sqlalchemy_presto.py | 98 +++++++--- 3 files changed, 340 insertions(+), 175 deletions(-) diff --git a/pyhive/presto.py b/pyhive/presto.py index 3217f4c2..227d69dc 100644 --- a/pyhive/presto.py +++ b/pyhive/presto.py @@ -13,6 +13,7 @@ from pyhive import common from pyhive.common import DBAPITypeObject + # Make all exceptions visible in this module per DB-API from pyhive.exc import * # noqa import base64 @@ -30,18 +31,19 @@ # PEP 249 module globals -apilevel = '2.0' +apilevel = "2.0" threadsafety = 2 # Threads may share the module and connections. -paramstyle = 'pyformat' # Python extended format codes, e.g. ...WHERE name=%(name)s +paramstyle = "pyformat" # Python extended format codes, e.g. ...WHERE name=%(name)s _logger = logging.getLogger(__name__) TYPES_CONVERTER = { "decimal": Decimal, # As of Presto 0.69, binary data is returned as the varbinary type in base64 format - "varbinary": base64.b64decode + "varbinary": base64.b64decode, } + class PrestoParamEscaper(common.ParamEscaper): def escape_datetime(self, item, format): _type = "timestamp" if isinstance(item, datetime.datetime) else "date" @@ -52,6 +54,22 @@ def escape_datetime(self, item, format): _escaper = PrestoParamEscaper() +class JWTAuth(requests.auth.AuthBase): + """ + Simple authorization handler for JWT requests. + """ + + def __init__(self, token): + self.token = token + + def __call__(self, r): + r.headers["Authorization"] = f"Bearer {self.token}" + return r + + def __eq__(self, other): + return isinstance(other, JWTAuth) and self.token == other.token + + def connect(*args, **kwargs): """Constructor for creating a connection to the database. See class :py:class:`Connection` for arguments. @@ -96,12 +114,28 @@ class Cursor(common.DBAPICursor): visible by other cursors or connections. 
""" - def __init__(self, host, port='8080', username=None, principal_username=None, catalog='hive', - schema='default', poll_interval=1, source='pyhive', session_props=None, - protocol='http', password=None, requests_session=None, requests_kwargs=None, - KerberosRemoteServiceName=None, KerberosPrincipal=None, - KerberosConfigPath=None, KerberosKeytabPath=None, - KerberosCredentialCachePath=None, KerberosUseCanonicalHostname=None): + def __init__( + self, + host, + port="8080", + username=None, + principal_username=None, + catalog="hive", + schema="default", + poll_interval=1, + source="pyhive", + session_props=None, + protocol="http", + password=None, + requests_session=None, + requests_kwargs=None, + KerberosRemoteServiceName=None, + KerberosPrincipal=None, + KerberosConfigPath=None, + KerberosKeytabPath=None, + KerberosCredentialCachePath=None, + KerberosUseCanonicalHostname=None, + ): """ :param host: hostname to connect to, e.g. ``presto.example.com`` :param port: int -- port, defaults to 8080 @@ -164,7 +198,7 @@ class will use the default requests behavior of making a new session per HTTP re self._session_props = session_props if session_props is not None else {} self.last_query_id = None - if protocol not in ('http', 'https'): + if protocol not in ("http", "https"): raise ValueError("Protocol must be http/https, was {!r}".format(protocol)) self._protocol = protocol @@ -176,31 +210,39 @@ class will use the default requests behavior of making a new session per HTTP re from requests_kerberos import HTTPKerberosAuth, OPTIONAL hostname_override = None - if KerberosUseCanonicalHostname is not None \ - and KerberosUseCanonicalHostname.lower() == 'false': + if ( + KerberosUseCanonicalHostname is not None + and KerberosUseCanonicalHostname.lower() == "false" + ): hostname_override = host if KerberosConfigPath is not None: - os.environ['KRB5_CONFIG'] = KerberosConfigPath + os.environ["KRB5_CONFIG"] = KerberosConfigPath if KerberosKeytabPath is not None: - os.environ['KRB5_CLIENT_KTNAME'] = KerberosKeytabPath + os.environ["KRB5_CLIENT_KTNAME"] = KerberosKeytabPath if KerberosCredentialCachePath is not None: - os.environ['KRB5CCNAME'] = KerberosCredentialCachePath + os.environ["KRB5CCNAME"] = KerberosCredentialCachePath - requests_kwargs['auth'] = HTTPKerberosAuth(mutual_authentication=OPTIONAL, - principal=KerberosPrincipal, - service=KerberosRemoteServiceName, - hostname_override=hostname_override) + requests_kwargs["auth"] = HTTPKerberosAuth( + mutual_authentication=OPTIONAL, + principal=KerberosPrincipal, + service=KerberosRemoteServiceName, + hostname_override=hostname_override, + ) else: - if password is not None and 'auth' in requests_kwargs: - raise ValueError("Cannot use both password and requests_kwargs authentication") - for k in ('method', 'url', 'data', 'headers'): + if password is not None and "auth" in requests_kwargs: + raise ValueError( + "Cannot use both password and requests_kwargs authentication" + ) + for k in ("method", "url", "data", "headers"): if k in requests_kwargs: raise ValueError("Cannot override requests argument {}".format(k)) if password is not None: - requests_kwargs['auth'] = HTTPBasicAuth(username, password) - if protocol != 'https': + requests_kwargs["auth"] = HTTPBasicAuth(username, password) + if protocol != "https": raise ValueError("Protocol must be https when passing a password") + if "jwt" in requests_kwargs: + requests_kwargs["auth"] = JWTAuth(requests_kwargs.pop("jwt")) self._requests_kwargs = requests_kwargs self._reset_state() @@ -230,14 +272,14 @@ 
def description(self): """ # Sleep until we're done or we got the columns self._fetch_while( - lambda: self._columns is None and - self._state not in (self._STATE_NONE, self._STATE_FINISHED) + lambda: self._columns is None + and self._state not in (self._STATE_NONE, self._STATE_FINISHED) ) if self._columns is None: return None return [ # name, type_code, display_size, internal_size, precision, scale, null_ok - (col['name'], col['type'], None, None, None, None, True) + (col["name"], col["type"], None, None, None, None, True) for col in self._columns ] @@ -247,15 +289,15 @@ def execute(self, operation, parameters=None): Return values are not defined. """ headers = { - 'X-Presto-Catalog': self._catalog, - 'X-Presto-Schema': self._schema, - 'X-Presto-Source': self._source, - 'X-Presto-User': self._username, + "X-Presto-Catalog": self._catalog, + "X-Presto-Schema": self._schema, + "X-Presto-Source": self._source, + "X-Presto-User": self._username, } if self._session_props: - headers['X-Presto-Session'] = ','.join( - '{}={}'.format(propname, propval) + headers["X-Presto-Session"] = ",".join( + "{}={}".format(propname, propval) for propname, propval in self._session_props.items() ) @@ -268,20 +310,30 @@ def execute(self, operation, parameters=None): self._reset_state() self._state = self._STATE_RUNNING - url = urlparse.urlunparse(( - self._protocol, - '{}:{}'.format(self._host, self._port), '/v1/statement', None, None, None)) - _logger.info('%s', sql) + url = urlparse.urlunparse( + ( + self._protocol, + "{}:{}".format(self._host, self._port), + "/v1/statement", + None, + None, + None, + ) + ) + _logger.info("%s", sql) _logger.debug("Headers: %s", headers) response = self._requests_session.post( - url, data=sql.encode('utf-8'), headers=headers, **self._requests_kwargs) + url, data=sql.encode("utf-8"), headers=headers, **self._requests_kwargs + ) self._process_response(response) def cancel(self): if self._state == self._STATE_NONE: raise ProgrammingError("No query yet") if self._nextUri is None: - assert self._state == self._STATE_FINISHED, "Should be finished if nextUri is None" + assert ( + self._state == self._STATE_FINISHED + ), "Should be finished if nextUri is None" return response = self._requests_session.delete(self._nextUri, **self._requests_kwargs) @@ -304,7 +356,9 @@ def poll(self): if self._state == self._STATE_NONE: raise ProgrammingError("No query yet") if self._nextUri is None: - assert self._state == self._STATE_FINISHED, "Should be finished if nextUri is None" + assert ( + self._state == self._STATE_FINISHED + ), "Should be finished if nextUri is None" return None response = self._requests_session.get(self._nextUri, **self._requests_kwargs) self._process_response(response) @@ -312,7 +366,9 @@ def poll(self): def _fetch_more(self): """Fetch the next URI and update state""" - self._process_response(self._requests_session.get(self._nextUri, **self._requests_kwargs)) + self._process_response( + self._requests_session.get(self._nextUri, **self._requests_kwargs) + ) def _process_data(self, rows): for i, col in enumerate(self.description): @@ -333,26 +389,28 @@ def _process_response(self, response): response_json = response.json() _logger.debug("Got response %s", response_json) - assert self._state == self._STATE_RUNNING, "Should be running if processing response" - self._nextUri = response_json.get('nextUri') - self._columns = response_json.get('columns') - if 'id' in response_json: - self.last_query_id = response_json['id'] - if 'X-Presto-Clear-Session' in response.headers: - propname = 
response.headers['X-Presto-Clear-Session'] + assert ( + self._state == self._STATE_RUNNING + ), "Should be running if processing response" + self._nextUri = response_json.get("nextUri") + self._columns = response_json.get("columns") + if "id" in response_json: + self.last_query_id = response_json["id"] + if "X-Presto-Clear-Session" in response.headers: + propname = response.headers["X-Presto-Clear-Session"] self._session_props.pop(propname, None) - if 'X-Presto-Set-Session' in response.headers: - propname, propval = response.headers['X-Presto-Set-Session'].split('=', 1) + if "X-Presto-Set-Session" in response.headers: + propname, propval = response.headers["X-Presto-Set-Session"].split("=", 1) self._session_props[propname] = propval - if 'data' in response_json: + if "data" in response_json: assert self._columns - new_data = response_json['data'] + new_data = response_json["data"] self._process_data(new_data) self._data += map(tuple, new_data) - if 'nextUri' not in response_json: + if "nextUri" not in response_json: self._state = self._STATE_FINISHED - if 'error' in response_json: - raise DatabaseError(response_json['error']) + if "error" in response_json: + raise DatabaseError(response_json["error"]) # @@ -361,7 +419,7 @@ def _process_response(self, response): # See types in presto-main/src/main/java/com/facebook/presto/tuple/TupleInfo.java -FIXED_INT_64 = DBAPITypeObject(['bigint']) -VARIABLE_BINARY = DBAPITypeObject(['varchar']) -DOUBLE = DBAPITypeObject(['double']) -BOOLEAN = DBAPITypeObject(['boolean']) +FIXED_INT_64 = DBAPITypeObject(["bigint"]) +VARIABLE_BINARY = DBAPITypeObject(["varchar"]) +DOUBLE = DBAPITypeObject(["double"]) +BOOLEAN = DBAPITypeObject(["boolean"]) diff --git a/pyhive/tests/test_presto.py b/pyhive/tests/test_presto.py index 187b1c21..e09b6844 100644 --- a/pyhive/tests/test_presto.py +++ b/pyhive/tests/test_presto.py @@ -21,8 +21,8 @@ import unittest import datetime -_HOST = 'localhost' -_PORT = '8080' +_HOST = "localhost" +_PORT = "8080" class TestPresto(unittest.TestCase, DBAPITestCase): @@ -32,71 +32,87 @@ def connect(self): return presto.connect(host=_HOST, port=_PORT, source=self.id()) def test_bad_protocol(self): - self.assertRaisesRegexp(ValueError, 'Protocol must be', - lambda: presto.connect('localhost', protocol='nonsense').cursor()) + self.assertRaisesRegexp( + ValueError, + "Protocol must be", + lambda: presto.connect("localhost", protocol="nonsense").cursor(), + ) def test_escape_args(self): escaper = presto.PrestoParamEscaper() - self.assertEqual(escaper.escape_args((datetime.date(2020, 4, 17),)), - ("date '2020-04-17'",)) - self.assertEqual(escaper.escape_args((datetime.datetime(2020, 4, 17, 12, 0, 0, 123456),)), - ("timestamp '2020-04-17 12:00:00.123'",)) + self.assertEqual( + escaper.escape_args((datetime.date(2020, 4, 17),)), ("date '2020-04-17'",) + ) + self.assertEqual( + escaper.escape_args((datetime.datetime(2020, 4, 17, 12, 0, 0, 123456),)), + ("timestamp '2020-04-17 12:00:00.123'",), + ) @with_cursor def test_description(self, cursor): - cursor.execute('SELECT 1 AS foobar FROM one_row') - self.assertEqual(cursor.description, [('foobar', 'integer', None, None, None, None, True)]) + cursor.execute("SELECT 1 AS foobar FROM one_row") + self.assertEqual( + cursor.description, [("foobar", "integer", None, None, None, None, True)] + ) self.assertIsNotNone(cursor.last_query_id) @with_cursor def test_complex(self, cursor): - cursor.execute('SELECT * FROM one_row_complex') + cursor.execute("SELECT * FROM one_row_complex") # TODO Presto drops the 
union field - if os.environ.get('PRESTO') == '0.147': - tinyint_type = 'integer' - smallint_type = 'integer' - float_type = 'double' + if os.environ.get("PRESTO") == "0.147": + tinyint_type = "integer" + smallint_type = "integer" + float_type = "double" else: # some later version made these map to more specific types - tinyint_type = 'tinyint' - smallint_type = 'smallint' - float_type = 'real' - self.assertEqual(cursor.description, [ - ('boolean', 'boolean', None, None, None, None, True), - ('tinyint', tinyint_type, None, None, None, None, True), - ('smallint', smallint_type, None, None, None, None, True), - ('int', 'integer', None, None, None, None, True), - ('bigint', 'bigint', None, None, None, None, True), - ('float', float_type, None, None, None, None, True), - ('double', 'double', None, None, None, None, True), - ('string', 'varchar', None, None, None, None, True), - ('timestamp', 'timestamp', None, None, None, None, True), - ('binary', 'varbinary', None, None, None, None, True), - ('array', 'array(integer)', None, None, None, None, True), - ('map', 'map(integer,integer)', None, None, None, None, True), - ('struct', 'row(a integer,b integer)', None, None, None, None, True), - # ('union', 'varchar', None, None, None, None, True), - ('decimal', 'decimal(10,1)', None, None, None, None, True), - ]) + tinyint_type = "tinyint" + smallint_type = "smallint" + float_type = "real" + self.assertEqual( + cursor.description, + [ + ("boolean", "boolean", None, None, None, None, True), + ("tinyint", tinyint_type, None, None, None, None, True), + ("smallint", smallint_type, None, None, None, None, True), + ("int", "integer", None, None, None, None, True), + ("bigint", "bigint", None, None, None, None, True), + ("float", float_type, None, None, None, None, True), + ("double", "double", None, None, None, None, True), + ("string", "varchar", None, None, None, None, True), + ("timestamp", "timestamp", None, None, None, None, True), + ("binary", "varbinary", None, None, None, None, True), + ("array", "array(integer)", None, None, None, None, True), + ("map", "map(integer,integer)", None, None, None, None, True), + ("struct", "row(a integer,b integer)", None, None, None, None, True), + # ('union', 'varchar', None, None, None, None, True), + ("decimal", "decimal(10,1)", None, None, None, None, True), + ], + ) rows = cursor.fetchall() - expected = [( - True, - 127, - 32767, - 2147483647, - 9223372036854775807, - 0.5, - 0.25, - 'a string', - '1970-01-01 00:00:00.000', - b'123', - [1, 2], - {"1": 2, "3": 4}, # Presto converts all keys to strings so that they're valid JSON - [1, 2], # struct is returned as a list of elements - # '{0:1}', - Decimal('0.1'), - )] + expected = [ + ( + True, + 127, + 32767, + 2147483647, + 9223372036854775807, + 0.5, + 0.25, + "a string", + "1970-01-01 00:00:00.000", + b"123", + [1, 2], + { + "1": 2, + "3": 4, + }, # Presto converts all keys to strings so that they're valid JSON + [1, 2], # struct is returned as a list of elements + # '{0:1}', + Decimal("0.1"), + ) + ] self.assertEqual(rows, expected) # catch unicode/str self.assertEqual(list(map(type, rows[0])), list(map(type, expected[0]))) @@ -108,8 +124,10 @@ def test_cancel(self, cursor): "FROM many_rows a " "CROSS JOIN many_rows b " ) - self.assertIn(cursor.poll()['stats']['state'], ( - 'STARTING', 'PLANNING', 'RUNNING', 'WAITING_FOR_RESOURCES', 'QUEUED')) + self.assertIn( + cursor.poll()["stats"]["state"], + ("STARTING", "PLANNING", "RUNNING", "WAITING_FOR_RESOURCES", "QUEUED"), + ) cursor.cancel() 
self.assertIsNotNone(cursor.last_query_id) self.assertIsNone(cursor.poll()) @@ -122,31 +140,33 @@ def test_noops(self): cursor = connection.cursor() self.assertEqual(cursor.rowcount, -1) cursor.setinputsizes([]) - cursor.setoutputsize(1, 'blah') + cursor.setoutputsize(1, "blah") self.assertIsNone(cursor.last_query_id) connection.commit() - @mock.patch('requests.post') + @mock.patch("requests.post") def test_non_200(self, post): cursor = self.connect().cursor() post.return_value.status_code = 404 - self.assertRaises(exc.OperationalError, lambda: cursor.execute('show tables')) + self.assertRaises(exc.OperationalError, lambda: cursor.execute("show tables")) @with_cursor def test_poll(self, cursor): self.assertRaises(presto.ProgrammingError, cursor.poll) - cursor.execute('SELECT * FROM one_row') + cursor.execute("SELECT * FROM one_row") while True: status = cursor.poll() if status is None: break - self.assertIn('stats', status) + self.assertIn("stats", status) def fail(*args, **kwargs): - self.fail("Should not need requests.get after done polling") # pragma: no cover + self.fail( + "Should not need requests.get after done polling" + ) # pragma: no cover - with mock.patch('requests.get', fail): + with mock.patch("requests.get", fail): self.assertEqual(cursor.fetchall(), [(1,)]) @with_cursor @@ -159,86 +179,100 @@ def test_set_session(self, cursor): cursor.fetchall() self.assertEqual(id, cursor.last_query_id) - cursor.execute('SHOW SESSION') + cursor.execute("SHOW SESSION") self.assertIsNotNone(cursor.last_query_id) self.assertNotEqual(id, cursor.last_query_id) id = cursor.last_query_id - rows = [r for r in cursor.fetchall() if r[0] == 'query_max_run_time'] + rows = [r for r in cursor.fetchall() if r[0] == "query_max_run_time"] self.assertEqual(len(rows), 1) session_prop = rows[0] - self.assertEqual(session_prop[1], '1234m') + self.assertEqual(session_prop[1], "1234m") self.assertEqual(id, cursor.last_query_id) - cursor.execute('RESET SESSION query_max_run_time') + cursor.execute("RESET SESSION query_max_run_time") self.assertIsNotNone(cursor.last_query_id) self.assertNotEqual(id, cursor.last_query_id) id = cursor.last_query_id cursor.fetchall() self.assertEqual(id, cursor.last_query_id) - cursor.execute('SHOW SESSION') + cursor.execute("SHOW SESSION") self.assertIsNotNone(cursor.last_query_id) self.assertNotEqual(id, cursor.last_query_id) id = cursor.last_query_id - rows = [r for r in cursor.fetchall() if r[0] == 'query_max_run_time'] + rows = [r for r in cursor.fetchall() if r[0] == "query_max_run_time"] self.assertEqual(len(rows), 1) session_prop = rows[0] - self.assertNotEqual(session_prop[1], '1234m') + self.assertNotEqual(session_prop[1], "1234m") self.assertEqual(id, cursor.last_query_id) def test_set_session_in_constructor(self): conn = presto.connect( - host=_HOST, source=self.id(), session_props={'query_max_run_time': '1234m'} + host=_HOST, source=self.id(), session_props={"query_max_run_time": "1234m"} ) with contextlib.closing(conn): with contextlib.closing(conn.cursor()) as cursor: - cursor.execute('SHOW SESSION') - rows = [r for r in cursor.fetchall() if r[0] == 'query_max_run_time'] + cursor.execute("SHOW SESSION") + rows = [r for r in cursor.fetchall() if r[0] == "query_max_run_time"] assert len(rows) == 1 session_prop = rows[0] - assert session_prop[1] == '1234m' + assert session_prop[1] == "1234m" - cursor.execute('RESET SESSION query_max_run_time') + cursor.execute("RESET SESSION query_max_run_time") cursor.fetchall() - cursor.execute('SHOW SESSION') - rows = [r for r in 
cursor.fetchall() if r[0] == 'query_max_run_time'] + cursor.execute("SHOW SESSION") + rows = [r for r in cursor.fetchall() if r[0] == "query_max_run_time"] assert len(rows) == 1 session_prop = rows[0] - assert session_prop[1] != '1234m' + assert session_prop[1] != "1234m" def test_invalid_protocol_config(self): """protocol should be https when passing password""" self.assertRaisesRegexp( - ValueError, 'Protocol.*https.*password', lambda: presto.connect( - host=_HOST, username='user', password='secret', protocol='http').cursor() + ValueError, + "Protocol.*https.*password", + lambda: presto.connect( + host=_HOST, username="user", password="secret", protocol="http" + ).cursor(), ) def test_invalid_password_and_kwargs(self): """password and requests_kwargs authentication are incompatible""" self.assertRaisesRegexp( - ValueError, 'Cannot use both', lambda: presto.connect( - host=_HOST, username='user', password='secret', protocol='https', - requests_kwargs={'auth': requests.auth.HTTPBasicAuth('user', 'secret')} - ).cursor() + ValueError, + "Cannot use both", + lambda: presto.connect( + host=_HOST, + username="user", + password="secret", + protocol="https", + requests_kwargs={"auth": requests.auth.HTTPBasicAuth("user", "secret")}, + ).cursor(), ) def test_invalid_kwargs(self): """some kwargs are reserved""" self.assertRaisesRegexp( - ValueError, 'Cannot override', lambda: presto.connect( - host=_HOST, username='user', requests_kwargs={'url': 'test'} - ).cursor() + ValueError, + "Cannot override", + lambda: presto.connect( + host=_HOST, username="user", requests_kwargs={"url": "test"} + ).cursor(), ) def test_requests_kwargs(self): connection = presto.connect( - host=_HOST, port=_PORT, source=self.id(), - requests_kwargs={'proxies': {'http': 'localhost:9999'}}, + host=_HOST, + port=_PORT, + source=self.id(), + requests_kwargs={"proxies": {"http": "localhost:9999"}}, ) cursor = connection.cursor() - self.assertRaises(requests.exceptions.ProxyError, - lambda: cursor.execute('SELECT * FROM one_row')) + self.assertRaises( + requests.exceptions.ProxyError, + lambda: cursor.execute("SELECT * FROM one_row"), + ) def test_requests_session(self): with requests.Session() as session: @@ -246,5 +280,32 @@ def test_requests_session(self): host=_HOST, port=_PORT, source=self.id(), requests_session=session ) cursor = connection.cursor() - cursor.execute('SELECT * FROM one_row') + cursor.execute("SELECT * FROM one_row") self.assertEqual(cursor.fetchall(), [(1,)]) + + def test_requests_jwt(self): + session = mock.MagicMock() + session.post().status_code = 200 + + connection = presto.connect( + host=_HOST, + port=_PORT, + source=self.id(), + requests_session=session, + requests_kwargs={"jwt": "my_jwt"}, + ) + with mock.patch("pyhive.presto.getpass.getuser", return_value="alice"): + cursor = connection.cursor() + cursor.execute("SELECT * FROM one_row") + + session.post.assert_called_with( + "http://localhost:8080/v1/statement", + auth=presto.JWTAuth("my_jwt"), + data=b"SELECT * FROM one_row", + headers={ + "X-Presto-Catalog": "hive", + "X-Presto-Schema": "default", + "X-Presto-Source": "pyhive.tests.test_presto.TestPresto.test_requests_jwt", + "X-Presto-User": "alice", + }, + ) diff --git a/pyhive/tests/test_sqlalchemy_presto.py b/pyhive/tests/test_sqlalchemy_presto.py index 58a5c034..46122119 100644 --- a/pyhive/tests/test_sqlalchemy_presto.py +++ b/pyhive/tests/test_sqlalchemy_presto.py @@ -1,6 +1,7 @@ from __future__ import absolute_import from __future__ import unicode_literals from builtins import str +from 
pyhive import presto from pyhive.tests.sqlalchemy_test_case import SqlAlchemyTestCase from pyhive.tests.sqlalchemy_test_case import with_engine_connection from sqlalchemy import types @@ -13,45 +14,54 @@ from decimal import Decimal import contextlib +import mock import unittest class TestSqlAlchemyPresto(unittest.TestCase, SqlAlchemyTestCase): def create_engine(self): - return create_engine('presto://localhost:8080/hive/default?source={}'.format(self.id())) + return create_engine( + "presto://localhost:8080/hive/default?source={}".format(self.id()) + ) def test_bad_format(self): self.assertRaises( ValueError, - lambda: create_engine('presto://localhost:8080/hive/default/what'), + lambda: create_engine("presto://localhost:8080/hive/default/what"), ) @with_engine_connection def test_reflect_select(self, engine, connection): """reflecttable should be able to fill in a table from the name""" - one_row_complex = Table('one_row_complex', MetaData(), autoload_with=engine) + one_row_complex = Table("one_row_complex", MetaData(), autoload_with=engine) # Presto ignores the union column self.assertEqual(len(one_row_complex.c), 15 - 1) self.assertIsInstance(one_row_complex.c.string, Column) rows = connection.execute(one_row_complex.select()).fetchall() self.assertEqual(len(rows), 1) - self.assertEqual(list(rows[0]), [ - True, - 127, - 32767, - 2147483647, - 9223372036854775807, - 0.5, - 0.25, - 'a string', - '1970-01-01 00:00:00.000', - b'123', - [1, 2], - {"1": 2, "3": 4}, # Presto converts all keys to strings so that they're valid JSON - [1, 2], # struct is returned as a list of elements - # '{0:1}', - Decimal('0.1'), - ]) + self.assertEqual( + list(rows[0]), + [ + True, + 127, + 32767, + 2147483647, + 9223372036854775807, + 0.5, + 0.25, + "a string", + "1970-01-01 00:00:00.000", + b"123", + [1, 2], + { + "1": 2, + "3": 4, + }, # Presto converts all keys to strings so that they're valid JSON + [1, 2], # struct is returned as a list of elements + # '{0:1}', + Decimal("0.1"), + ], + ) # TODO some of these types could be filled in better self.assertIsInstance(one_row_complex.c.boolean.type, types.Boolean) @@ -70,10 +80,15 @@ def test_reflect_select(self, engine, connection): self.assertIsInstance(one_row_complex.c.decimal.type, types.NullType) def test_url_default(self): - engine = create_engine('presto://localhost:8080/hive') + engine = create_engine("presto://localhost:8080/hive") try: with contextlib.closing(engine.connect()) as connection: - self.assertEqual(connection.execute(text('SELECT 1 AS foobar FROM one_row')).scalar(), 1) + self.assertEqual( + connection.execute( + text("SELECT 1 AS foobar FROM one_row") + ).scalar(), + 1, + ) finally: engine.dispose() @@ -81,9 +96,40 @@ def test_url_default(self): def test_reserved_words(self, engine, connection): """Presto uses double quotes, not backticks""" # Use keywords for the table/column name - fake_table = Table('select', MetaData(), Column('current_timestamp', String)) - query = str(fake_table.select().where(fake_table.c.current_timestamp == 'a').compile(engine)) + fake_table = Table("select", MetaData(), Column("current_timestamp", String)) + query = str( + fake_table.select() + .where(fake_table.c.current_timestamp == "a") + .compile(engine) + ) self.assertIn('"select"', query) self.assertIn('"current_timestamp"', query) - self.assertNotIn('`select`', query) - self.assertNotIn('`current_timestamp`', query) + self.assertNotIn("`select`", query) + self.assertNotIn("`current_timestamp`", query) + + def test_jwt(self): + session = mock.MagicMock() 
+ session.post().status_code = 200 + + engine = create_engine( + "presto://localhost:8080/hive", + connect_args={ + "requests_kwargs": {"jwt": "my_jwt"}, + "requests_session": session, + }, + ) + with mock.patch("pyhive.presto.getpass.getuser", return_value="alice"): + with contextlib.closing(engine.connect()) as connection: + connection.execute(text("SELECT 1 AS foobar FROM one_row")) + + session.post.assert_called_with( + "http://localhost:8080/v1/statement", + auth=presto.JWTAuth("my_jwt"), + data=b"SELECT 1 AS foobar FROM one_row", + headers={ + "X-Presto-Catalog": "hive", + "X-Presto-Schema": "default", + "X-Presto-Source": "pyhive", + "X-Presto-User": "alice", + }, + ), From 4b75e55eb3863c3d03666e4b58528019a0516ff3 Mon Sep 17 00:00:00 2001 From: Beto Dealmeida Date: Wed, 7 Aug 2024 15:48:34 -0400 Subject: [PATCH 17/17] Add CI to build package --- .gitignore | 1 + CHANGELOG.md | 4 ++ Jenkinsfile | 94 ++++++++++++++++++++++++++++++++++++++++++++++ pyhive/__init__.py | 2 +- 4 files changed, 100 insertions(+), 1 deletion(-) create mode 100644 CHANGELOG.md create mode 100644 Jenkinsfile diff --git a/.gitignore b/.gitignore index 2ba823c2..e9af4e25 100644 --- a/.gitignore +++ b/.gitignore @@ -14,3 +14,4 @@ cover/ .cache/ *.iml /scripts/.thrift_gen +.python-version diff --git a/CHANGELOG.md b/CHANGELOG.md new file mode 100644 index 00000000..f54471b8 --- /dev/null +++ b/CHANGELOG.md @@ -0,0 +1,4 @@ +0.7.0a +====== + +- Add support for JWT authentication. diff --git a/Jenkinsfile b/Jenkinsfile new file mode 100644 index 00000000..985863cd --- /dev/null +++ b/Jenkinsfile @@ -0,0 +1,94 @@ +LIB_NAME = 'PyHive' +String currentVersion = "" + + +podTemplate( + imagePullSecrets: ['preset-pull'], + nodeUsageMode: 'NORMAL', + containers: [ + containerTemplate( + alwaysPullImage: true, + name: 'ci', + image: 'preset/ci:latest', + ttyEnabled: true, + command: 'cat', + resourceRequestCpu: '100m', + resourceLimitCpu: '200m', + resourceRequestMemory: '1000Mi', + resourceLimitMemory: '2000Mi', + ), + containerTemplate( + alwaysPullImage: true, + name: 'py-ci', + image: 'preset/python:3.8.9-ci', + ttyEnabled: true, + command: 'cat' + ) + ] +) { + node(POD_LABEL) { + container('py-ci') { + stage('Checkout') { + checkout scm + } + + stage('Tests') { + sh(script: 'pip install -e . && pip install -r requirements-dev.txt', label: 'install dependencies') + parallel( + check: { + currentVersion = sh( + script: "python setup.py --version", + returnStdout: true, + label: 'Get current version' + ).trim() + def retVal = sh( + script: "curl -I -f https://pypi.devops.preset.zone/${LIB_NAME}/${LIB_NAME}-${currentVersion}.tar.gz", + returnStatus: true, + label: 'Check for existing tarball' + ) + // If the thing exists, we should bail as we don't want to overwrite + if (retVal == 0) { + error("Version ${currentVersion} of ${LIB_NAME} already exists! 
Version bump required.") + } + } + ) + } + } + + container('py-ci') { + stage('Package Release') { + if (env.BRANCH_NAME.startsWith("PR-")) { + def shortGitRev = sh( + returnStdout: true, + script: 'git rev-parse --short HEAD' + ).trim() + def pullRequestVersion = "${currentVersion}+${env.BRANCH_NAME}.${shortGitRev}" + sh(script:"sed -i \'s/version = ${currentVersion}/version = ${pullRequestVersion}/g\' setup.cfg", label: 'Changing version for PR') + sh(script:"echo PR version: ${pullRequestVersion}", label: 'PR Release candidate version') + } + sh(script: 'python setup.py sdist --formats=gztar', label: 'Bundling release') + sh(script: "mkdir -p dist/${LIB_NAME} && mv dist/*.gz dist/${LIB_NAME}", label: 'Setup release folder') + } + } + + container('ci') { + stage('Upload Release') { + withCredentials([ + [ + $class : 'AmazonWebServicesCredentialsBinding', + credentialsId : 'ci-user', + accessKeyVariable: 'AWS_ACCESS_KEY_ID', + secretKeyVariable: 'AWS_SECRET_ACCESS_KEY', + ] + ]) { + if ((env.BRANCH_NAME == 'master') || (env.BRANCH_NAME.startsWith("PR-"))) { + sh(script: "aws s3 sync ./dist s3://preset-pypi", label: "Uploading to s3") + } + else { + echo "Skipping upload as this isn't master..." + } + } + } + } + } +} diff --git a/pyhive/__init__.py b/pyhive/__init__.py index 0a6bb1f6..16ce1b3c 100644 --- a/pyhive/__init__.py +++ b/pyhive/__init__.py @@ -1,3 +1,3 @@ from __future__ import absolute_import from __future__ import unicode_literals -__version__ = '0.7.0' +__version__ = '0.7.0a'
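As a usage note for the JWT support introduced above (and recorded in the ``CHANGELOG.md`` entry for 0.7.0a): the token is passed through ``requests_kwargs`` and converted by ``presto.JWTAuth`` into a ``Bearer`` Authorization header. Below is a minimal sketch of both entry points, with placeholder host, port and token, assuming an HTTPS-fronted Presto coordinator:

.. code-block:: python

    from pyhive import presto
    from sqlalchemy import create_engine, text

    # DB-API: the jwt key is popped from requests_kwargs and wrapped in JWTAuth.
    cursor = presto.connect(
        host='presto.example.com',          # placeholder host
        port=8443,                          # placeholder port
        protocol='https',
        requests_kwargs={'jwt': 'my_jwt'},  # placeholder token
    ).cursor()
    cursor.execute('SELECT 1')
    print(cursor.fetchone())

    # SQLAlchemy: the same kwargs are forwarded via connect_args, mirroring the
    # pattern used in pyhive/tests/test_sqlalchemy_presto.py.
    engine = create_engine(
        'presto://presto.example.com:8443/hive',
        connect_args={'protocol': 'https', 'requests_kwargs': {'jwt': 'my_jwt'}},
    )
    with engine.connect() as connection:
        print(connection.execute(text('SELECT 1')).scalar())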