From 15c77bb4bae0582951bcbc62729156d4e5da7e3f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Y=C3=BCce=20Tekol?= Date: Mon, 27 Mar 2023 20:52:42 +0300 Subject: [PATCH] DBAPI Support [API-911] (#586) Added DBAPI support --- Makefile | 11 + hazelcast/db.py | 501 +++++++++++ hazelcast/sql.py | 4 + tests/integration/dbapi/__init__.py | 0 tests/integration/dbapi/db_test.py | 149 ++++ tests/integration/dbapi/dbapi20.py | 831 ++++++++++++++++++ .../dbapi/hazelcast_dbapi20_test.py | 61 ++ tests/unit/dbapi_test.py | 135 +++ 8 files changed, 1692 insertions(+) create mode 100644 Makefile create mode 100644 hazelcast/db.py create mode 100644 tests/integration/dbapi/__init__.py create mode 100644 tests/integration/dbapi/db_test.py create mode 100644 tests/integration/dbapi/dbapi20.py create mode 100644 tests/integration/dbapi/hazelcast_dbapi20_test.py create mode 100644 tests/unit/dbapi_test.py diff --git a/Makefile b/Makefile new file mode 100644 index 0000000000..12585223c8 --- /dev/null +++ b/Makefile @@ -0,0 +1,11 @@ +.PHONY: check test test-cover + +check: + mypy --show-error-codes hazelcast + black --check --config black.toml . + +test: + pytest + +test-cover: + pytest --cov=hazelcast --cov-report=xml diff --git a/hazelcast/db.py b/hazelcast/db.py new file mode 100644 index 0000000000..a26924369a --- /dev/null +++ b/hazelcast/db.py @@ -0,0 +1,501 @@ +from datetime import date, datetime, time +from time import localtime +from typing import Any, Callable, Iterator, List, Optional, Sequence, Union, Tuple, Set, NamedTuple +import enum +import itertools +import threading +import urllib.parse + +from hazelcast import HazelcastClient +from hazelcast.config import Config +from hazelcast.sql import ( + HazelcastSqlError, + SqlColumnType, + SqlResult, + SqlRow, + SqlRowMetadata, + SqlExpectedResultType, + _DEFAULT_CURSOR_BUFFER_SIZE, +) + +apilevel = "2.0" +# Threads may share the module and connections. +threadsafety = 2 +paramstyle = "qmark" + + +class Type(enum.Enum): + NULL = 0 + STRING = 1 + BOOLEAN = 2 + DATE = 3 + TIME = 4 + DATETIME = 5 + INTEGER = 6 + FLOAT = 7 + DECIMAL = 8 + JSON = 9 + OBJECT = 10 + + +ColumnDescription = NamedTuple( + "ColumnDescription", + [ + ("name", str), + ("type", Type), + ("display_size", None), + ("internal_size", None), + ("precision", None), + ("scale", None), + ("null_ok", bool), + ], +) + + +class _DBAPIType: + def __init__(self, *values: Type): + self._values = values + + def __eq__(self, other: object) -> bool: + return other in self._values + + def __ne__(self, other: object) -> bool: + return other not in self._values + + +Date = date +Time = time +Timestamp = datetime +Binary = bytes +STRING = _DBAPIType(Type.STRING) +DATETIME = _DBAPIType(Type.DATE, Type.TIME, Type.DATETIME) +BINARY = _DBAPIType() +NUMBER = _DBAPIType(Type.INTEGER, Type.FLOAT, Type.DECIMAL) +ROWID = _DBAPIType() + + +def DateFromTicks(ticks): + return date(*localtime(ticks)[:3]) + + +def TimeFromTicks(ticks): + return time(*localtime(ticks)[3:6]) + + +def TimestampFromTicks(ticks): + return datetime(*localtime(ticks)[:6]) + + +class Cursor: + def __init__(self, conn: "Connection"): + self.arraysize = 1 + self._conn = conn + self._res: Union[SqlResult, None] = None + self._description: Union[List[ColumnDescription], None] = None + self._iter: Optional[Iterator[SqlRow]] = None + self._rownumber = -1 + self._closed = False + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self.close() + + def __iter__(self) -> Optional[Iterator[SqlRow]]: + return self._iter + + @property + def connection(self): + return self._conn + + @property + def description(self) -> Union[List[ColumnDescription], None]: + return self._description + + @property + def rowcount(self) -> int: + return -1 + + @property + def rownumber(self) -> Optional[int]: + if self._rownumber < 0: + return None + return self._rownumber + + def close(self): + if not self._closed: + self._closed = True + self._conn._close_cursor(self) + if self._res: + self._res.close() + self._res = None + + def execute(self, operation: str, params: Optional[Tuple] = None) -> None: + if params is not None and not isinstance(params, tuple): + raise InterfaceError("params must be a tuple or None") + params = params or () + self._ensure_open() + self._rownumber = -1 + self._iter = None + self._res = None + cbs = _DEFAULT_CURSOR_BUFFER_SIZE + if self.arraysize > 0: + cbs = self.arraysize + self._description = None + res = ( + self._conn._get_client() + .sql.execute(operation, *params, cursor_buffer_size=cbs) + .result() + ) + if res.is_row_set(): + self._rownumber = 0 + self._res = res + self._description = self._make_description(res.get_row_metadata()) + self._iter = res.__iter__() + + def executemany(self, operation: str, seq_of_params: Sequence[Tuple]) -> None: + self._ensure_open() + self._rownumber = -1 + self._iter = None + self._res = None + futures = [] + svc = self._conn._get_client().sql + for params in seq_of_params: + futures.append( + svc.execute( + operation, *params, expected_result_type=SqlExpectedResultType.UPDATE_COUNT + ) + ) + for fut in futures: + fut.result() + + def fetchone(self) -> Optional[SqlRow]: + if self._iter is None: + raise InterfaceError("fetch can only be called after row returning queries") + try: + row = next(self._iter) + self._rownumber += 1 + return row + except StopIteration: + return None + + def fetchmany(self, size: Optional[int] = None) -> List[SqlRow]: + if self._iter is None: + raise InterfaceError("fetchmany can only be called after row returning queries") + if size is None: + size = self.arraysize + rows = list(itertools.islice(self._iter, size)) + self._rownumber += len(rows) + return rows + + def fetchall(self) -> List[SqlRow]: + if self._iter is None: + raise InterfaceError("fetchall can only be called after row returning queries") + rows = list(self._iter) + self._rownumber += len(rows) + return rows + + def next(self) -> Optional[SqlRow]: + if self._iter is None: + return None + return next(self._iter) + + def setinputsizes(self, sizes): + pass + + def setoutputsize(self, size=None, column=None): + pass + + @classmethod + def _make_description(cls, metadata: SqlRowMetadata) -> List[ColumnDescription]: + r = [] + for col in metadata.columns: + r.append( + ColumnDescription( + name=col.name, + type=_map_type(col.type), + display_size=None, + internal_size=None, + precision=None, + scale=None, + null_ok=col.nullable, + ) + ) + return r + + def _ensure_open(self): + if self._closed: + raise self.connection.ProgrammingError("connection is closed") + + +class Connection: + def __init__(self, config: Config): + self.__mu = threading.RLock() + self.__client: Optional[HazelcastClient] = HazelcastClient(config) + self._cursors: Set[Cursor] = set() + + def __enter__(self) -> "Connection": + return self + + def __exit__(self, exc_type, exc_val, exc_tb) -> None: + self.close() + + def close(self) -> None: + if self.__client: + with self.__mu: + if self.__client: + self.__client.shutdown() + self.__client = None + return + raise InterfaceError("connection was already closed") + + def commit(self) -> None: + # transactions are not supported + # ensure an exception is thrown if there is no client + self._get_client() + + def cursor(self) -> Cursor: + with self.__mu: + if self.__client is not None: + cursor = Cursor(self) + self._cursors.add(cursor) + return cursor + raise InterfaceError("connection is already closed") + + def _get_client(self) -> HazelcastClient: + with self.__mu: + if self.__client is not None: + return self.__client + raise InterfaceError("connection is closed") + + def _close_cursor(self, cursor: Cursor) -> None: + with self.__mu: + if cursor in self._cursors: + self._cursors.remove(cursor) + + @property + def Error(self): + return Error + + @property + def Warning(self): + return Warning + + @property + def InterfaceError(self): + return InterfaceError + + @property + def DatabaseError(self): + return DatabaseError + + @property + def InternalError(self): + return InternalError + + @property + def OperationalError(self): + return OperationalError + + @property + def ProgrammingError(self): + return ProgrammingError + + @property + def IntegrityError(self): + return IntegrityError + + @property + def DataError(self): + return DataError + + @property + def NotSupportedError(self): + return NotSupportedError + + +def connect( + config=None, + *, + dsn="", + user: str = None, + password: str = None, + host: str = None, + port: int = None, + cluster_name: str = None, +) -> Connection: + c = _make_config( + config, + dsn=dsn, + user=user, + password=password, + host=host, + port=port, + cluster_name=cluster_name, + ) + return Connection(c) + + +class Error(Exception): + pass + + +class Warning(Exception): + pass + + +class InterfaceError(Error): + pass + + +class DatabaseError(Error): + pass + + +class InternalError(DatabaseError): + pass + + +class OperationalError(DatabaseError): + pass + + +class ProgrammingError(DatabaseError): + pass + + +class IntegrityError(DatabaseError): + pass + + +class DataError(DatabaseError): + pass + + +class NotSupportedError(DatabaseError): + pass + + +def _wrap_error(f: Callable) -> Any: + try: + return f() + except HazelcastSqlError as e: + raise DatabaseError(f"{e.args}") from e + except Exception as e: + raise DatabaseError from e + + +def _map_type(code: int) -> Type: + type = _type_map.get(code) + if type is None: + raise NotSupportedError(f"Unknown type code: {code}") + return type + + +_type_map = { + SqlColumnType.VARCHAR: Type.STRING, + SqlColumnType.BOOLEAN: Type.BOOLEAN, + SqlColumnType.TINYINT: Type.INTEGER, + SqlColumnType.SMALLINT: Type.INTEGER, + SqlColumnType.INTEGER: Type.INTEGER, + SqlColumnType.BIGINT: Type.INTEGER, + SqlColumnType.DECIMAL: Type.DECIMAL, + SqlColumnType.REAL: Type.FLOAT, + SqlColumnType.DOUBLE: Type.FLOAT, + SqlColumnType.DATE: Type.DATE, + SqlColumnType.TIME: Type.TIME, + SqlColumnType.TIMESTAMP: Type.DATETIME, + SqlColumnType.TIMESTAMP_WITH_TIME_ZONE: Type.DATETIME, + SqlColumnType.OBJECT: Type.OBJECT, + SqlColumnType.NULL: Type.NULL, + SqlColumnType.JSON: Type.JSON, +} + + +def _make_config( + config: Config = None, + *, + dsn="", + user: str = None, + password: str = None, + host: str = None, + port: int = None, + cluster_name: str = None, +) -> Config: + kwargs_used = user or password or host or port or cluster_name + if config is not None: + if not isinstance(config, Config): + raise InterfaceError("config must be a hazelcast.Config object") + if dsn or kwargs_used: + raise InterfaceError("config argument cannot be used with keyword arguments") + return config + if dsn: + if kwargs_used: + raise InterfaceError("dsn argument cannot be used with other keyword arguments") + return _parse_dsn(dsn) + config = Config() + if not host: + host = "localhost" + if not port: + port = 5701 + host = f"{host}:{port}" + config.cluster_members = [host] + if user is not None: + config.creds_username = user + if password is not None: + config.creds_password = password + if cluster_name is not None: + config.cluster_name = cluster_name + return config + + +def _parse_dsn(dsn: str) -> Config: + r = urllib.parse.urlparse(dsn) + if r.scheme != "hz": + raise InterfaceError(f"Scheme must be hz, but it is: {r.scheme}") + cfg = Config() + host = "localhost" + port = 5701 + if r.hostname: + host = r.hostname + if r.port: + port = r.port + cfg.cluster_members = [f"{host}:{port}"] + if r.username: + cfg.creds_username = r.username + if r.password: + cfg.creds_password = r.password + for k, v in urllib.parse.parse_qsl(r.query): + value: Any = v + if k in _parse_dsn_map: + attr_name, transform = _parse_dsn_map[k] + if transform: + try: + value = transform(value) + except ValueError as e: + raise InterfaceError from e + setattr(cfg, attr_name, value) + else: + raise InterfaceError(f"Unknown DSN attribute: {k}") + return cfg + + +def _make_bool(v: str) -> bool: + if v == "true": + return True + if v == "false": + return False + raise ValueError(f"Invalid boolean: {v}") + + +_parse_dsn_map = { + "cluster.name": ("cluster_name", None), + "cloud.token": ("cloud_discovery_token", None), + "smart": ("smart_routing", _make_bool), + "ssl": ("ssl_enabled", _make_bool), + "ssl.ca.path": ("ssl_cafile", None), + "ssl.cert.path": ("ssl_certfile", None), + "ssl.key.path": ("ssl_keyfile", None), + "ssl.key.password": ("ssl_password", None), +} diff --git a/hazelcast/sql.py b/hazelcast/sql.py index 67ed1208cb..7d74d8eec1 100644 --- a/hazelcast/sql.py +++ b/hazelcast/sql.py @@ -636,6 +636,10 @@ def __repr__(self): for i in range(self._row_metadata.column_count) ) + def __len__(self): + """Returns number of columns of the row.""" + return self._row_metadata.column_count + class _ExecuteResponse: """Represent the response of the first execute request.""" diff --git a/tests/integration/dbapi/__init__.py b/tests/integration/dbapi/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/integration/dbapi/db_test.py b/tests/integration/dbapi/db_test.py new file mode 100644 index 0000000000..3215bd47ca --- /dev/null +++ b/tests/integration/dbapi/db_test.py @@ -0,0 +1,149 @@ +import threading + +from hazelcast import HazelcastClient +from hazelcast.config import Config +from hazelcast.db import connect, Connection, Type +from tests.integration.backward_compatible.sql_test import ( + SqlTestBase, + compare_server_version_with_rc, + compare_client_version, + SERVER_CONFIG, + JET_ENABLED_CONFIG, + Student, +) + + +class DbapiTestBase(SqlTestBase): + + rc = None + cluster = None + is_v5_or_newer_server = None + is_v5_or_newer_client = None + conn: Connection = None + + @classmethod + def setUpClass(cls): + cls.rc = cls.create_rc() + cls.is_v5_or_newer_server = compare_server_version_with_rc(cls.rc, "5.0") >= 0 + cls.is_v5_or_newer_client = compare_client_version("5.0") >= 0 + # enable Jet if the server is 5.0+ + cluster_config = SERVER_CONFIG % (JET_ENABLED_CONFIG if cls.is_v5_or_newer_server else "") + cls.cluster = cls.create_cluster(cls.rc, cluster_config) + cls.member = cls.cluster.start_member() + cls.client = HazelcastClient( + cluster_name=cls.cluster.id, portable_factories={666: {6: Student}} + ) + cfg = Config() + cfg.cluster_name = cls.cluster.id + cfg.portable_factories = {666: {6: Student}} + cls.conn = connect(cfg) + + @classmethod + def tearDownClass(cls): + cls.conn.close() + cls.client.shutdown() + cls.rc.terminateCluster(cls.cluster.id) + cls.rc.exit() + + def _create_mapping(self, value_format="INTEGER"): + if not self.is_v5_or_newer_server: + # Implicit mappings are removed in 5.0 + return + q = f""" + CREATE MAPPING "{self.map_name}" ( + __key INT, + this {value_format} + ) + TYPE IMaP + OPTIONS ( + 'keyFormat' = 'int', + 'valueFormat' = '{value_format.lower()}' + ) + """ + c = self.conn.cursor() + c.execute(q) + + def _populate_map(self, entry_count=10, value_factory=lambda v: v): + entries = [(i, value_factory(i)) for i in range(entry_count)] + c = self.conn.cursor() + c.executemany(f'INSERT INTO "{self.map_name}" VALUES(?, ?)', entries) + + +class DbapiTest(DbapiTestBase): + def test_fetchone(self): + self._create_mapping() + entry_count = 11 + self._populate_map(entry_count) + c = self.conn.cursor() + c.execute(f'SELECT * FROM "{self.map_name}" where __key < ? order by __key', (5,)) + self.assertEqual(0, c.rownumber) + row = c.fetchone() + self.assertEqual((0, 0), (row.get_object("__key"), row.get_object("this"))) + self.assertEqual(1, c.rownumber) + row = c.fetchone() + self.assertEqual((1, 1), (row.get_object("__key"), row.get_object("this"))) + self.assertEqual(2, c.rownumber) + + def test_fetchmany(self): + self._create_mapping() + entry_count = 11 + self._populate_map(entry_count) + c = self.conn.cursor() + c.execute(f'SELECT * FROM "{self.map_name}" where __key < ? order by __key', (5,)) + self.assertEqual(0, c.rownumber) + result = list(c.fetchmany(3)) + self.assertCountEqual( + [(i, i) for i in range(3)], + [(row.get_object("__key"), row.get_object("this")) for row in result], + ) + self.assertEqual(3, c.rownumber) + result = list(c.fetchmany(3)) + self.assertCountEqual( + [(i, i) for i in range(3, 5)], + [(row.get_object("__key"), row.get_object("this")) for row in result], + ) + self.assertEqual(5, c.rownumber) + + def test_fetchall(self): + self._create_mapping() + entry_count = 11 + self._populate_map(entry_count) + c = self.conn.cursor() + c.execute(f'SELECT * FROM "{self.map_name}" where __key < ? order by __key', (5,)) + self.assertEqual(0, c.rownumber) + result = list(c.fetchall()) + self.assertCountEqual( + [(i, i) for i in range(5)], + [(row.get_object("__key"), row.get_object("this")) for row in result], + ) + self.assertEqual(5, c.rownumber) + + def test_cursor_connection(self): + c = self.conn.cursor() + self.assertEqual(self.conn, c.connection) + + def test_description(self): + self._create_mapping() + self._populate_map(1) + c = self.conn.cursor() + c.execute(f'SELECT * FROM "{self.map_name}"') + target = [ + ("__key", Type.INTEGER, None, None, None, None, True), + ("this", Type.INTEGER, None, None, None, None, True), + ] + self.assertEqual(target, c.description) + + def test_connection_share(self): + def f(): + c = self.conn.cursor() + c.execute("show mappings;") + c.fetchall() + + threads = [] + for i in range(100): + t = threading.Thread(target=f) + t.start() + threads.append(t) + for t in threads: + t.join() + self.assertEqual(len(threads), len(self.conn._cursors)) diff --git a/tests/integration/dbapi/dbapi20.py b/tests/integration/dbapi/dbapi20.py new file mode 100644 index 0000000000..d23d0bc70d --- /dev/null +++ b/tests/integration/dbapi/dbapi20.py @@ -0,0 +1,831 @@ +#!/usr/bin/env python +""" Python DB API 2.0 driver compliance unit test suite. + + This software is Public Domain and may be used without restrictions. + + "Now we have booze and barflies entering the discussion, plus rumours of + DBAs on drugs... and I won't tell you what flashes through my mind each + time I read the subject line with 'Anal Compliance' in it. All around + this is turning out to be a thoroughly unwholesome unit test." + + -- Ian Bicking +""" + +__version__ = "1.15.0" + +import unittest +import time +import sys + +if sys.version[0] >= "3": # python 3.x + _BaseException = Exception + + def _failUnless(self, expr, msg=None): + self.assertTrue(expr, msg) + +else: # python 2.x + from exceptions import StandardError as _BaseException + + def _failUnless(self, expr, msg=None): + self.failUnless(expr, msg) ## deprecated since Python 2.6 + + +def str2bytes(sval): + if sys.version_info < (3, 0) and isinstance(sval, str): + sval = sval.decode("latin1") + return sval.encode("latin1") # python 3 make unicode into bytes + + +class DatabaseAPI20Test(unittest.TestCase): + """Test a database self.driver for DB API 2.0 compatibility. + This implementation tests Gadfly, but the TestCase + is structured so that other self.drivers can subclass this + test case to ensure compiliance with the DB-API. It is + expected that this TestCase may be expanded in the future + if ambiguities or edge conditions are discovered. + + The 'Optional Extensions' are not yet being tested. + + self.drivers should subclass this test, overriding setUp, tearDown, + self.driver, connect_args and connect_kw_args. Class specification + should be as follows: + + import dbapi20 + class mytest(dbapi20.DatabaseAPI20Test): + [...] + + Don't 'import DatabaseAPI20Test from dbapi20', or you will + confuse the unit tester - just 'import dbapi20'. + """ + + # The self.driver module. This should be the module where the 'connect' + # method is to be found + driver = None + connect_args = () # List of arguments to pass to connect + connect_kw_args = {} # Keyword arguments for connect + table_prefix = "dbapi20test_" # If you need to specify a prefix for tables + + ddl1 = "create table %sbooze (name varchar(20))" % table_prefix + ddl2 = "create table %sbarflys (name varchar(20), drink varchar(30))" % table_prefix + xddl1 = "drop table %sbooze" % table_prefix + xddl2 = "drop table %sbarflys" % table_prefix + insert = "insert" + + lowerfunc = "lower" # Name of stored procedure to convert string->lowercase + + # Some drivers may need to override these helpers, for example adding + # a 'commit' after the execute. + def executeDDL1(self, cursor): + cursor.execute(self.ddl1) + + def executeDDL2(self, cursor): + cursor.execute(self.ddl2) + + def setUp(self): + """self.drivers should override this method to perform required setup + if any is necessary, such as creating the database. + """ + pass + + def tearDown(self): + """self.drivers should override this method to perform required cleanup + if any is necessary, such as deleting the test database. + The default drops the tables that may be created. + """ + try: + con = self._connect() + try: + cur = con.cursor() + for ddl in (self.xddl1, self.xddl2): + try: + cur.execute(ddl) + con.commit() + except self.driver.Error: + # Assume table didn't exist. Other tests will check if + # execute is busted. + pass + finally: + con.close() + except _BaseException: + pass + + def _connect(self): + try: + r = self.driver.connect(*self.connect_args, **self.connect_kw_args) + except AttributeError: + self.fail("No connect method found in self.driver module") + return r + + def test_connect(self): + con = self._connect() + con.close() + + def test_apilevel(self): + try: + # Must exist + apilevel = self.driver.apilevel + # Must equal 2.0 + self.assertEqual(apilevel, "2.0") + except AttributeError: + self.fail("Driver doesn't define apilevel") + + def test_threadsafety(self): + try: + # Must exist + threadsafety = self.driver.threadsafety + # Must be a valid value + _failUnless(self, threadsafety in (0, 1, 2, 3)) + except AttributeError: + self.fail("Driver doesn't define threadsafety") + + def test_paramstyle(self): + try: + # Must exist + paramstyle = self.driver.paramstyle + # Must be a valid value + _failUnless(self, paramstyle in ("qmark", "numeric", "named", "format", "pyformat")) + except AttributeError: + self.fail("Driver doesn't define paramstyle") + + def test_Exceptions(self): + # Make sure required exceptions exist, and are in the + # defined heirarchy. + if sys.version[0] == "3": # under Python 3 StardardError no longer exists + self.assertTrue(issubclass(self.driver.Warning, Exception)) + self.assertTrue(issubclass(self.driver.Error, Exception)) + else: + self.failUnless(issubclass(self.driver.Warning, StandardError)) + self.failUnless(issubclass(self.driver.Error, StandardError)) + + _failUnless(self, issubclass(self.driver.InterfaceError, self.driver.Error)) + _failUnless(self, issubclass(self.driver.DatabaseError, self.driver.Error)) + _failUnless(self, issubclass(self.driver.OperationalError, self.driver.Error)) + _failUnless(self, issubclass(self.driver.IntegrityError, self.driver.Error)) + _failUnless(self, issubclass(self.driver.InternalError, self.driver.Error)) + _failUnless(self, issubclass(self.driver.ProgrammingError, self.driver.Error)) + _failUnless(self, issubclass(self.driver.NotSupportedError, self.driver.Error)) + + def test_ExceptionsAsConnectionAttributes(self): + # OPTIONAL EXTENSION + # Test for the optional DB API 2.0 extension, where the exceptions + # are exposed as attributes on the Connection object + # I figure this optional extension will be implemented by any + # driver author who is using this test suite, so it is enabled + # by default. + drv = self.driver + con = self._connect() + try: + _failUnless(self, con.Warning is drv.Warning) + _failUnless(self, con.Error is drv.Error) + _failUnless(self, con.InterfaceError is drv.InterfaceError) + _failUnless(self, con.DatabaseError is drv.DatabaseError) + _failUnless(self, con.OperationalError is drv.OperationalError) + _failUnless(self, con.IntegrityError is drv.IntegrityError) + _failUnless(self, con.InternalError is drv.InternalError) + _failUnless(self, con.ProgrammingError is drv.ProgrammingError) + _failUnless(self, con.NotSupportedError is drv.NotSupportedError) + finally: + con.close() + + def test_commit(self): + con = self._connect() + try: + # Commit must work, even if it doesn't do anything + con.commit() + finally: + con.close() + + def test_rollback(self): + con = self._connect() + try: + # If rollback is defined, it should either work or throw + # the documented exception + if hasattr(con, "rollback"): + try: + con.rollback() + except self.driver.NotSupportedError: + pass + finally: + con.close() + + def test_cursor(self): + con = self._connect() + try: + cur = con.cursor() + finally: + con.close() + + def test_cursor_isolation(self): + con = self._connect() + try: + # Make sure cursors created from the same connection have + # the documented transaction isolation level + cur1 = con.cursor() + cur2 = con.cursor() + self.executeDDL1(cur1) + cur1.execute( + "%s into %sbooze values ('Victoria Bitter')" % (self.insert, self.table_prefix) + ) + cur2.execute("select name from %sbooze" % self.table_prefix) + booze = cur2.fetchall() + self.assertEqual(len(booze), 1) + self.assertEqual(len(booze[0]), 1) + self.assertEqual(booze[0][0], "Victoria Bitter") + finally: + con.close() + + def test_description(self): + con = self._connect() + try: + cur = con.cursor() + self.executeDDL1(cur) + self.assertEqual( + cur.description, + None, + "cursor.description should be none after executing a " + "statement that can return no rows (such as DDL)", + ) + cur.execute("select name from %sbooze" % self.table_prefix) + self.assertEqual( + len(cur.description), 1, "cursor.description describes too many columns" + ) + self.assertEqual( + len(cur.description[0]), 7, "cursor.description[x] tuples must have 7 elements" + ) + self.assertEqual( + cur.description[0][0].lower(), + "name", + "cursor.description[x][0] must return column name", + ) + self.assertEqual( + cur.description[0][1], + self.driver.STRING, + "cursor.description[x][1] must return column type. Got %r" % cur.description[0][1], + ) + + # Make sure self.description gets reset + self.executeDDL2(cur) + self.assertEqual( + cur.description, + None, + "cursor.description not being set to None when executing " + "no-result statements (eg. DDL)", + ) + finally: + con.close() + + def test_rowcount(self): + con = self._connect() + try: + cur = con.cursor() + self.executeDDL1(cur) + _failUnless( + self, + cur.rowcount in (-1, 0), # Bug #543885 + "cursor.rowcount should be -1 or 0 after executing no-result " "statements", + ) + cur.execute( + "%s into %sbooze values ('Victoria Bitter')" % (self.insert, self.table_prefix) + ) + _failUnless( + self, + cur.rowcount in (-1, 1), + "cursor.rowcount should == number or rows inserted, or " + "set to -1 after executing an insert statement", + ) + cur.execute("select name from %sbooze" % self.table_prefix) + _failUnless( + self, + cur.rowcount in (-1, 1), + "cursor.rowcount should == number of rows returned, or " + "set to -1 after executing a select statement", + ) + self.executeDDL2(cur) + _failUnless( + self, + cur.rowcount in (-1, 0), # Bug #543885 + "cursor.rowcount should be -1 or 0 after executing no-result " "statements", + ) + finally: + con.close() + + lower_func = "lower" + + def test_callproc(self): + con = self._connect() + try: + cur = con.cursor() + if self.lower_func and hasattr(cur, "callproc"): + r = cur.callproc(self.lower_func, ("FOO",)) + self.assertEqual(len(r), 1) + self.assertEqual(r[0], "FOO") + r = cur.fetchall() + self.assertEqual(len(r), 1, "callproc produced no result set") + self.assertEqual(len(r[0]), 1, "callproc produced invalid result set") + self.assertEqual(r[0][0], "foo", "callproc produced invalid results") + finally: + con.close() + + def test_close(self): + con = self._connect() + try: + cur = con.cursor() + finally: + con.close() + + # cursor.execute should raise an Error if called after connection + # closed + self.assertRaises(self.driver.Error, self.executeDDL1, cur) + + # connection.commit should raise an Error if called after connection' + # closed.' + self.assertRaises(self.driver.Error, con.commit) + + def test_non_idempotent_close(self): + con = self._connect() + con.close() + # connection.close should raise an Error if called more than once + #!!! reasonable persons differ about the usefulness of this test and this feature !!! + self.assertRaises(self.driver.Error, con.close) + + def test_execute(self): + con = self._connect() + try: + cur = con.cursor() + self._paraminsert(cur) + finally: + con.close() + + def _paraminsert(self, cur): + self.executeDDL2(cur) + cur.execute( + "%s into %sbarflys values ('Victoria Bitter', 'thi%%s :may ca%%(u)se? troub:1e')" + % (self.insert, self.table_prefix) + ) + _failUnless(self, cur.rowcount in (-1, 1)) + + if self.driver.paramstyle == "qmark": + cur.execute( + "%s into %sbarflys values (?, 'thi%%s :may ca%%(u)se? troub:1e')" + % (self.insert, self.table_prefix), + ("Cooper's",), + ) + elif self.driver.paramstyle == "numeric": + cur.execute( + "%s into %sbarflys values (:1, 'thi%%s :may ca%%(u)se? troub:1e')" + % (self.insert, self.table_prefix), + ("Cooper's",), + ) + elif self.driver.paramstyle == "named": + cur.execute( + "%s into %sbarflys values (:beer, 'thi%%s :may ca%%(u)se? troub:1e')" + % (self.insert, self.table_prefix), + {"beer": "Cooper's"}, + ) + elif self.driver.paramstyle == "format": + cur.execute( + "%s into %sbarflys values (%%s, 'thi%%%%s :may ca%%%%(u)se? troub:1e')" + % (self.insert, self.table_prefix), + ("Cooper's",), + ) + elif self.driver.paramstyle == "pyformat": + cur.execute( + "%s into %sbarflys values (%%(beer)s, 'thi%%%%s :may ca%%%%(u)se? troub:1e')" + % (self.insert, self.table_prefix), + {"beer": "Cooper's"}, + ) + else: + self.fail("Invalid paramstyle") + _failUnless(self, cur.rowcount in (-1, 1)) + + cur.execute("select name, drink from %sbarflys" % self.table_prefix) + res = cur.fetchall() + self.assertEqual(len(res), 2, "cursor.fetchall returned too few rows") + beers = [res[0][0], res[1][0]] + beers.sort() + self.assertEqual( + beers[0], + "Cooper's", + "cursor.fetchall retrieved incorrect data, or data inserted " "incorrectly", + ) + self.assertEqual( + beers[1], + "Victoria Bitter", + "cursor.fetchall retrieved incorrect data, or data inserted " "incorrectly", + ) + trouble = "thi%s :may ca%(u)se? troub:1e" + self.assertEqual( + res[0][1], + trouble, + "cursor.fetchall retrieved incorrect data, or data inserted " + "incorrectly. Got=%s, Expected=%s" % (repr(res[0][1]), repr(trouble)), + ) + self.assertEqual( + res[1][1], + trouble, + "cursor.fetchall retrieved incorrect data, or data inserted " + "incorrectly. Got=%s, Expected=%s" % (repr(res[1][1]), repr(trouble)), + ) + + def test_executemany(self): + con = self._connect() + try: + cur = con.cursor() + self.executeDDL1(cur) + largs = [("Cooper's",), ("Boag's",)] + margs = [{"beer": "Cooper's"}, {"beer": "Boag's"}] + if self.driver.paramstyle == "qmark": + cur.executemany( + "%s into %sbooze values (?)" % (self.insert, self.table_prefix), largs + ) + elif self.driver.paramstyle == "numeric": + cur.executemany( + "%s into %sbooze values (:1)" % (self.insert, self.table_prefix), largs + ) + elif self.driver.paramstyle == "named": + cur.executemany( + "%s into %sbooze values (:beer)" % (self.insert, self.table_prefix), margs + ) + elif self.driver.paramstyle == "format": + cur.executemany( + "%s into %sbooze values (%%s)" % (self.insert, self.table_prefix), largs + ) + elif self.driver.paramstyle == "pyformat": + cur.executemany( + "%s into %sbooze values (%%(beer)s)" % (self.insert, self.table_prefix), margs + ) + else: + self.fail("Unknown paramstyle") + _failUnless( + self, + cur.rowcount in (-1, 2), + "insert using cursor.executemany set cursor.rowcount to " + "incorrect value %r" % cur.rowcount, + ) + cur.execute("select name from %sbooze" % self.table_prefix) + res = cur.fetchall() + self.assertEqual(len(res), 2, "cursor.fetchall retrieved incorrect number of rows") + beers = [res[0][0], res[1][0]] + beers.sort() + self.assertEqual(beers[0], "Boag's", "incorrect data retrieved") + self.assertEqual(beers[1], "Cooper's", "incorrect data retrieved") + finally: + con.close() + + def test_fetchone(self): + con = self._connect() + try: + cur = con.cursor() + + # cursor.fetchone should raise an Error if called before + # executing a select-type query + self.assertRaises(self.driver.Error, cur.fetchone) + + # cursor.fetchone should raise an Error if called after + # executing a query that cannnot return rows + self.executeDDL1(cur) + self.assertRaises(self.driver.Error, cur.fetchone) + + cur.execute("select name from %sbooze" % self.table_prefix) + self.assertEqual( + cur.fetchone(), + None, + "cursor.fetchone should return None if a query retrieves " "no rows", + ) + _failUnless(self, cur.rowcount in (-1, 0)) + + # cursor.fetchone should raise an Error if called after + # executing a query that cannnot return rows + cur.execute( + "%s into %sbooze values ('Victoria Bitter')" % (self.insert, self.table_prefix) + ) + self.assertRaises(self.driver.Error, cur.fetchone) + + cur.execute("select name from %sbooze" % self.table_prefix) + r = cur.fetchone() + self.assertEqual(len(r), 1, "cursor.fetchone should have retrieved a single row") + self.assertEqual(r[0], "Victoria Bitter", "cursor.fetchone retrieved incorrect data") + self.assertEqual( + cur.fetchone(), None, "cursor.fetchone should return None if no more rows available" + ) + _failUnless(self, cur.rowcount in (-1, 1)) + finally: + con.close() + + samples = [ + "Carlton Cold", + "Carlton Draft", + "Mountain Goat", + "Redback", + "Victoria Bitter", + "XXXX", + ] + + def _populate(self): + """Return a list of sql commands to setup the DB for the fetch + tests. + """ + populate = [ + "%s into %sbooze values ('%s')" % (self.insert, self.table_prefix, s) + for s in self.samples + ] + return populate + + def test_fetchmany(self): + con = self._connect() + try: + cur = con.cursor() + + # cursor.fetchmany should raise an Error if called without + # issuing a query + self.assertRaises(self.driver.Error, cur.fetchmany, 4) + + self.executeDDL1(cur) + for sql in self._populate(): + cur.execute(sql) + + cur.execute("select name from %sbooze" % self.table_prefix) + r = cur.fetchmany() + self.assertEqual( + len(r), + 1, + "cursor.fetchmany retrieved incorrect number of rows, " + "default of arraysize is one.", + ) + cur.arraysize = 10 + r = cur.fetchmany(3) # Should get 3 rows + self.assertEqual(len(r), 3, "cursor.fetchmany retrieved incorrect number of rows") + r = cur.fetchmany(4) # Should get 2 more + self.assertEqual(len(r), 2, "cursor.fetchmany retrieved incorrect number of rows") + r = cur.fetchmany(4) # Should be an empty sequence + self.assertEqual( + len(r), + 0, + "cursor.fetchmany should return an empty sequence after " "results are exhausted", + ) + _failUnless(self, cur.rowcount in (-1, 6)) + + # Same as above, using cursor.arraysize + cur.arraysize = 4 + cur.execute("select name from %sbooze" % self.table_prefix) + r = cur.fetchmany() # Should get 4 rows + self.assertEqual(len(r), 4, "cursor.arraysize not being honoured by fetchmany") + r = cur.fetchmany() # Should get 2 more + self.assertEqual(len(r), 2) + r = cur.fetchmany() # Should be an empty sequence + self.assertEqual(len(r), 0) + _failUnless(self, cur.rowcount in (-1, 6)) + + cur.arraysize = 6 + cur.execute("select name from %sbooze" % self.table_prefix) + rows = cur.fetchmany() # Should get all rows + _failUnless(self, cur.rowcount in (-1, 6)) + self.assertEqual(len(rows), 6) + self.assertEqual(len(rows), 6) + rows = [r[0] for r in rows] + rows.sort() + + # Make sure we get the right data back out + for i in range(0, 6): + self.assertEqual( + rows[i], self.samples[i], "incorrect data retrieved by cursor.fetchmany" + ) + + rows = cur.fetchmany() # Should return an empty list + self.assertEqual( + len(rows), + 0, + "cursor.fetchmany should return an empty sequence if " + "called after the whole result set has been fetched", + ) + _failUnless(self, cur.rowcount in (-1, 6)) + + self.executeDDL2(cur) + cur.execute("select name from %sbarflys" % self.table_prefix) + r = cur.fetchmany() # Should get empty sequence + self.assertEqual( + len(r), + 0, + "cursor.fetchmany should return an empty sequence if " "query retrieved no rows", + ) + _failUnless(self, cur.rowcount in (-1, 0)) + + finally: + con.close() + + def test_fetchall(self): + con = self._connect() + try: + cur = con.cursor() + # cursor.fetchall should raise an Error if called + # without executing a query that may return rows (such + # as a select) + self.assertRaises(self.driver.Error, cur.fetchall) + + self.executeDDL1(cur) + for sql in self._populate(): + cur.execute(sql) + + # cursor.fetchall should raise an Error if called + # after executing a a statement that cannot return rows + self.assertRaises(self.driver.Error, cur.fetchall) + + cur.execute("select name from %sbooze" % self.table_prefix) + rows = cur.fetchall() + _failUnless(self, cur.rowcount in (-1, len(self.samples))) + self.assertEqual( + len(rows), len(self.samples), "cursor.fetchall did not retrieve all rows" + ) + rows = [r[0] for r in rows] + rows.sort() + for i in range(0, len(self.samples)): + self.assertEqual( + rows[i], self.samples[i], "cursor.fetchall retrieved incorrect rows" + ) + rows = cur.fetchall() + self.assertEqual( + len(rows), + 0, + "cursor.fetchall should return an empty list if called " + "after the whole result set has been fetched", + ) + _failUnless(self, cur.rowcount in (-1, len(self.samples))) + + self.executeDDL2(cur) + cur.execute("select name from %sbarflys" % self.table_prefix) + rows = cur.fetchall() + _failUnless(self, cur.rowcount in (-1, 0)) + self.assertEqual( + len(rows), + 0, + "cursor.fetchall should return an empty list if " "a select query returns no rows", + ) + + finally: + con.close() + + def test_mixedfetch(self): + con = self._connect() + try: + cur = con.cursor() + self.executeDDL1(cur) + for sql in self._populate(): + cur.execute(sql) + + cur.execute("select name from %sbooze" % self.table_prefix) + rows1 = cur.fetchone() + rows23 = cur.fetchmany(2) + rows4 = cur.fetchone() + rows56 = cur.fetchall() + _failUnless(self, cur.rowcount in (-1, 6)) + self.assertEqual(len(rows23), 2, "fetchmany returned incorrect number of rows") + self.assertEqual(len(rows56), 2, "fetchall returned incorrect number of rows") + + rows = [rows1[0]] + rows.extend([rows23[0][0], rows23[1][0]]) + rows.append(rows4[0]) + rows.extend([rows56[0][0], rows56[1][0]]) + rows.sort() + for i in range(0, len(self.samples)): + self.assertEqual(rows[i], self.samples[i], "incorrect data retrieved or inserted") + finally: + con.close() + + def help_nextset_setUp(self, cur): + """Should create a procedure called deleteme + that returns two result sets, first the + number of rows in booze then "name from booze" + """ + raise NotImplementedError("Helper not implemented") + # sql=""" + # create procedure deleteme as + # begin + # select count(*) from booze + # select name from booze + # end + # """ + # cur.execute(sql) + + def help_nextset_tearDown(self, cur): + "If cleaning up is needed after nextSetTest" + raise NotImplementedError("Helper not implemented") + # cur.execute("drop procedure deleteme") + + def test_nextset(self): + con = self._connect() + try: + cur = con.cursor() + if not hasattr(cur, "nextset"): + return + + try: + self.executeDDL1(cur) + sql = self._populate() + for sql in self._populate(): + cur.execute(sql) + + self.help_nextset_setUp(cur) + + cur.callproc("deleteme") + numberofrows = cur.fetchone() + assert numberofrows[0] == len(self.samples) + assert cur.nextset() + names = cur.fetchall() + assert len(names) == len(self.samples) + s = cur.nextset() + assert s == None, "No more return sets, should return None" + finally: + self.help_nextset_tearDown(cur) + + finally: + con.close() + + def test_nextset(self): + raise NotImplementedError("Drivers need to override this test") + + def test_arraysize(self): + # Not much here - rest of the tests for this are in test_fetchmany + con = self._connect() + try: + cur = con.cursor() + _failUnless(self, hasattr(cur, "arraysize"), "cursor.arraysize must be defined") + finally: + con.close() + + def test_setinputsizes(self): + con = self._connect() + try: + cur = con.cursor() + cur.setinputsizes((25,)) + self._paraminsert(cur) # Make sure cursor still works + finally: + con.close() + + def test_setoutputsize_basic(self): + # Basic test is to make sure setoutputsize doesn't blow up + con = self._connect() + try: + cur = con.cursor() + cur.setoutputsize(1000) + cur.setoutputsize(2000, 0) + self._paraminsert(cur) # Make sure the cursor still works + finally: + con.close() + + def test_setoutputsize(self): + # Real test for setoutputsize is driver dependant + raise NotImplementedError("Driver needed to override this test") + + def test_None(self): + con = self._connect() + try: + cur = con.cursor() + self.executeDDL2(cur) + # inserting NULL to the second column, because some drivers might + # need the first one to be primary key, which means it needs + # to have a non-NULL value + cur.execute("%s into %sbarflys values ('a', NULL)" % (self.insert, self.table_prefix)) + cur.execute("select drink from %sbarflys" % self.table_prefix) + r = cur.fetchall() + self.assertEqual(len(r), 1) + self.assertEqual(len(r[0]), 1) + self.assertEqual(r[0][0], None, "NULL value not returned as None") + finally: + con.close() + + def test_Date(self): + d1 = self.driver.Date(2002, 12, 25) + d2 = self.driver.DateFromTicks(time.mktime((2002, 12, 25, 0, 0, 0, 0, 0, 0))) + # Can we assume this? API doesn't specify, but it seems implied + # self.assertEqual(str(d1),str(d2)) + + def test_Time(self): + t1 = self.driver.Time(13, 45, 30) + t2 = self.driver.TimeFromTicks(time.mktime((2001, 1, 1, 13, 45, 30, 0, 0, 0))) + # Can we assume this? API doesn't specify, but it seems implied + # self.assertEqual(str(t1),str(t2)) + + def test_Timestamp(self): + t1 = self.driver.Timestamp(2002, 12, 25, 13, 45, 30) + t2 = self.driver.TimestampFromTicks(time.mktime((2002, 12, 25, 13, 45, 30, 0, 0, 0))) + # Can we assume this? API doesn't specify, but it seems implied + # self.assertEqual(str(t1),str(t2)) + + def test_Binary(self): + b = self.driver.Binary(str2bytes("Something")) + b = self.driver.Binary(str2bytes("")) + + def test_STRING(self): + _failUnless(self, hasattr(self.driver, "STRING"), "module.STRING must be defined") + + def test_BINARY(self): + _failUnless(self, hasattr(self.driver, "BINARY"), "module.BINARY must be defined.") + + def test_NUMBER(self): + _failUnless(self, hasattr(self.driver, "NUMBER"), "module.NUMBER must be defined.") + + def test_DATETIME(self): + _failUnless(self, hasattr(self.driver, "DATETIME"), "module.DATETIME must be defined.") + + def test_ROWID(self): + _failUnless(self, hasattr(self.driver, "ROWID"), "module.ROWID must be defined.") diff --git a/tests/integration/dbapi/hazelcast_dbapi20_test.py b/tests/integration/dbapi/hazelcast_dbapi20_test.py new file mode 100644 index 0000000000..eb4849e8e4 --- /dev/null +++ b/tests/integration/dbapi/hazelcast_dbapi20_test.py @@ -0,0 +1,61 @@ +from hazelcast.config import Config +import tests.integration.dbapi.dbapi20 as dbapi20 +from hazelcast import db +from tests.integration.backward_compatible.sql_test import SqlTestBase + + +class HazelcastDBAPI20Test(SqlTestBase, dbapi20.DatabaseAPI20Test): + + rc = None + cluster = None + member = None + driver = db + connect_kw_args = {} + table_prefix = "dbapi20test_" + ddl1 = f""" + CREATE OR REPLACE MAPPING {table_prefix}booze ( + name varchar external name "__key.name" + ) TYPE IMAP OPTIONS ( + 'keyFormat'='json-flat', + 'valueFormat'='json-flat' + ) + """ + ddl2 = f""" + CREATE OR REPLACE MAPPING {table_prefix}barflys ( + name varchar external name "__key.name", + drink varchar external name "this.drink" + ) TYPE IMAP OPTIONS ( + 'keyFormat'='json-flat', + 'valueFormat'='json-flat' + ) + """ + + @classmethod + def setUpClass(cls): + super().setUpClass() + cfg = Config() + cfg.cluster_name = cls.cluster.id + cls.connect_kw_args = { + "config": cfg, + } + + @classmethod + def tearDownClass(cls): + cls.rc.terminateCluster(cls.cluster.id) + cls.rc.exit() + + def setUp(self): + pass + + def tearDown(self): + for name in ["booze", "barflys"]: + m = self.client.get_map(f"{self.table_prefix}{name}").blocking() + m.destroy() + + def test_nextset(self): + # we don't support this. + pass + + def test_setoutputsize(self): + # we don't support this. + pass diff --git a/tests/unit/dbapi_test.py b/tests/unit/dbapi_test.py new file mode 100644 index 0000000000..b036e8eb60 --- /dev/null +++ b/tests/unit/dbapi_test.py @@ -0,0 +1,135 @@ +import unittest + +from hazelcast.config import Config +from hazelcast.db import _make_config, InterfaceError + + +class DbApiTest(unittest.TestCase): + def test_make_config_invalid(self): + test_cases = [ + ("Invalid dsn", Config(), {"dsn": "hz://"}), + ("Both config and kwarg", Config(), {"user": "some-user"}), + ("Both DSN and kwarg", None, {"dsn": "hz://", "password": "some-pass"}), + ] + for name, c, kwargs in test_cases: + with self.assertRaises(InterfaceError, msg=f"Test case '{name}' failed"): + _make_config(c, **kwargs) + + def test_make_config_default(self): + cfg = _make_config() + target = config_with_values(cluster_members=["localhost:5701"]) + self.assertEqualConfig(target, cfg) + + def test_make_config_config_object(self): + cfg = config_with_values(creds_username="joe") + target = config_with_values(creds_username="joe") + self.assertEqualConfig(target, cfg) + + def test_make_config_cluster(self): + cfg = _make_config(cluster_name="foo") + target = config_with_values( + cluster_members=["localhost:5701"], + cluster_name="foo", + ) + self.assertEqualConfig(target, cfg) + + def test_make_config_user_password(self): + cfg = _make_config(user="joe", password="jane") + target = config_with_values( + cluster_members=["localhost:5701"], + creds_username="joe", + creds_password="jane", + ) + self.assertEqualConfig(target, cfg) + + def test_make_config_host(self): + cfg = _make_config(host="foo.com") + target = config_with_values(cluster_members=["foo.com:5701"]) + self.assertEqualConfig(target, cfg) + + def test_make_config_port(self): + cfg = _make_config(port=1234) + target = config_with_values(cluster_members=["localhost:1234"]) + self.assertEqualConfig(target, cfg) + + def test_make_config_host_port(self): + cfg = _make_config(host="foo.com", port=1234) + target = config_with_values(cluster_members=["foo.com:1234"]) + self.assertEqualConfig(target, cfg) + + def test_make_config_dsn(self): + test_cases = [ + ("hz://", config_with_values(cluster_members=["localhost:5701"])), + ("hz://foo.com", config_with_values(cluster_members=["foo.com:5701"])), + ("hz://:1234", config_with_values(cluster_members=["localhost:1234"])), + ("hz://foo.com:1234", config_with_values(cluster_members=["foo.com:1234"])), + ( + "hz://user:pass@foo.com:1234", + config_with_values( + cluster_members=["foo.com:1234"], + creds_username="user", + creds_password="pass", + ), + ), + ( + "hz://foo.com:1234?cluster.name=prod", + config_with_values( + cluster_members=["foo.com:1234"], + cluster_name="prod", + ), + ), + ( + "hz://foo.com:1234?cluster.name=prod&cloud.token=token1", + config_with_values( + cluster_members=["foo.com:1234"], + cluster_name="prod", + cloud_discovery_token="token1", + ), + ), + ( + "hz://foo.com?smart=false", + config_with_values( + cluster_members=["foo.com:5701"], + smart_routing=False, + ), + ), + ( + "hz://foo.com?ssl=true&ssl.ca.path=ca.pem&ssl.cert.path=cert.pem&ssl.key.path=key.pem&ssl.key.password=123", + config_with_values( + cluster_members=["foo.com:5701"], + ssl_enabled=True, + ssl_cafile="ca.pem", + ssl_certfile="cert.pem", + ssl_keyfile="key.pem", + ssl_password="123", + ), + ), + ] + for dsn, target in test_cases: + cfg = _make_config(dsn=dsn) + self.assertEqualConfig(target, cfg, f"Test case with DSN '{dsn}' failed") + + def test_make_config_invalid_dsn(self): + test_cases = [ + "http://", + "://", + "hz://foo.com?smart=False", + "hz://foo.com?non.existing=value", + ] + for dsn in test_cases: + with self.assertRaises(InterfaceError, msg=f"Test case with DSN '{dsn}' failed"): + _make_config(dsn=dsn) + + def assertEqualConfig(self, a: Config, b: Config, msg=""): + self.assertEqual(config_to_dict(a), config_to_dict(b), msg) + + +def config_with_values(**kwargs) -> Config: + return Config.from_dict(kwargs) + + +def config_to_dict(cfg: Config) -> dict: + d = {} + for k in cfg.__slots__: + d[k] = getattr(cfg, k) + return d