From d39346ff55c71e7a5646fffd176fb7ad8ce16650 Mon Sep 17 00:00:00 2001
From: mdumandag
Date: Mon, 5 Apr 2021 19:03:04 +0300
Subject: [PATCH 1/9] Implement SQL service

Added a new service called `sql` that executes SQL queries on Maps.
The service returns an `SqlResult` immediately, so the user can cancel
a running query at any point. Also, via `SqlResult`, the user can get
blocking or non-blocking iterators that yield results row by row. The
rows contain the row metadata along with the objects, which can be
accessed either through column names or column indexes. The row
metadata carries information about the columns returned in the rows.
The user can also get the update count, the row metadata, and whether
the result contains rows through `SqlResult`.
---
 docs/api/modules.rst                          |    1 +
 docs/api/sql.rst                              |    4 +
 hazelcast/client.py                           |    5 +
 hazelcast/connection.py                       |   32 +-
 hazelcast/protocol/builtin.py                 |  263 ++-
 .../codec/custom/sql_column_metadata_codec.py |   39 +
 .../protocol/codec/custom/sql_error_codec.py  |   35 +
 .../codec/custom/sql_query_id_codec.py        |   40 +
 hazelcast/protocol/codec/sql_close_codec.py   |   15 +
 hazelcast/protocol/codec/sql_execute_codec.py |   44 +
 .../codec/sql_execute_reserved_codec.py       |   43 +
 hazelcast/protocol/codec/sql_fetch_codec.py   |   30 +
 .../codec/sql_fetch_reserved_codec.py         |   33 +
 hazelcast/serialization/serializer.py         |   10 +-
 hazelcast/sql.py                              | 1585 +++++++++++++++++
 hazelcast/util.py                             |  139 +-
 start_rc.py                                   |    6 +-
 .../backward_compatible/cluster_test.py       |    8 +-
 .../backward_compatible/sql_test.py           |  570 ++++++
 tests/unit/load_balancer_test.py              |   85 +-
 tests/unit/sql_test.py                        |  302 ++++
 21 files changed, 3252 insertions(+), 37 deletions(-)
 create mode 100644 docs/api/sql.rst
 create mode 100644 hazelcast/protocol/codec/custom/sql_column_metadata_codec.py
 create mode 100644 hazelcast/protocol/codec/custom/sql_error_codec.py
 create mode 100644 hazelcast/protocol/codec/custom/sql_query_id_codec.py
 create mode 100644 hazelcast/protocol/codec/sql_close_codec.py
 create mode 100644 hazelcast/protocol/codec/sql_execute_codec.py
 create mode 100644 hazelcast/protocol/codec/sql_execute_reserved_codec.py
 create mode 100644 hazelcast/protocol/codec/sql_fetch_codec.py
 create mode 100644 hazelcast/protocol/codec/sql_fetch_reserved_codec.py
 create mode 100644 hazelcast/sql.py
 create mode 100644 tests/integration/backward_compatible/sql_test.py
 create mode 100644 tests/unit/sql_test.py

diff --git a/docs/api/modules.rst b/docs/api/modules.rst
index f3ccd3dca8..88dd4f964c 100644
--- a/docs/api/modules.rst
+++ b/docs/api/modules.rst
@@ -14,6 +14,7 @@ API Documentation
    predicate
    proxy/modules
    serialization
+   sql
    transaction
    util
diff --git a/docs/api/sql.rst b/docs/api/sql.rst
new file mode 100644
index 0000000000..745bb7e25b
--- /dev/null
+++ b/docs/api/sql.rst
@@ -0,0 +1,4 @@
+SQL
+===========
+
+.. 
automodule:: hazelcast.sql diff --git a/hazelcast/client.py b/hazelcast/client.py index 32c47eca8e..2680ca7aa9 100644 --- a/hazelcast/client.py +++ b/hazelcast/client.py @@ -36,6 +36,7 @@ ) from hazelcast.reactor import AsyncoreReactor from hazelcast.serialization import SerializationServiceV1 +from hazelcast.sql import _InternalSqlService, SqlService from hazelcast.statistics import Statistics from hazelcast.transaction import TWO_PHASE, TransactionManager from hazelcast.util import AtomicInteger, RoundRobinLB @@ -388,6 +389,10 @@ def __init__(self, **kwargs): self._invocation_service.init( self._internal_partition_service, self._connection_manager, self._listener_service ) + self._internal_sql_service = _InternalSqlService( + self._connection_manager, self._serialization_service, self._invocation_service + ) + self.sql = SqlService(self._internal_sql_service) self._init_context() self._start() diff --git a/hazelcast/connection.py b/hazelcast/connection.py index 02150e6eb1..6b7ed2ab6b 100644 --- a/hazelcast/connection.py +++ b/hazelcast/connection.py @@ -146,17 +146,20 @@ def add_listener(self, on_connection_opened=None, on_connection_closed=None): def get_connection(self, member_uuid): return self.active_connections.get(member_uuid, None) - def get_random_connection(self): + def get_random_connection(self, should_get_data_member=False): if self._smart_routing_enabled: - member = self._load_balancer.next() - if member: - connection = self.get_connection(member.uuid) - if connection: - return connection + connection = self._get_connection_from_load_balancer(should_get_data_member) + if connection: + return connection # We should not get to this point under normal circumstances. # Therefore, copying the list should be OK. - for connection in list(six.itervalues(self.active_connections)): + for member_uuid, connection in list(six.iteritems(self.active_connections)): + if should_get_data_member: + member = self._cluster_service.get_member(member_uuid) + if not member or member.lite_member: + continue + return connection return None @@ -256,6 +259,21 @@ def check_invocation_allowed(self): else: raise IOError("No connection found to cluster") + def _get_connection_from_load_balancer(self, should_get_data_member): + load_balancer = self._load_balancer + if should_get_data_member: + if load_balancer.can_get_next_data_member(): + member = load_balancer.next_data_member() + else: + member = None + else: + member = load_balancer.next() + + if not member: + return None + + return self.get_connection(member.uuid) + def _get_or_connect_to_address(self, address): for connection in list(six.itervalues(self.active_connections)): if connection.remote_address == address: diff --git a/hazelcast/protocol/builtin.py b/hazelcast/protocol/builtin.py index e28c5fb7a3..856f5978ab 100644 --- a/hazelcast/protocol/builtin.py +++ b/hazelcast/protocol/builtin.py @@ -1,4 +1,6 @@ import uuid +from datetime import date, time, datetime, timedelta +from decimal import Decimal from hazelcast import six from hazelcast.six.moves import range @@ -11,7 +13,7 @@ NULL_FINAL_FRAME_BUF, END_FINAL_FRAME_BUF, ) -from hazelcast.serialization import ( +from hazelcast.serialization.bits import ( LONG_SIZE_IN_BYTES, UUID_SIZE_IN_BYTES, LE_INT, @@ -23,8 +25,21 @@ LE_INT8, UUID_MSB_SHIFT, UUID_LSB_MASK, + BYTE_SIZE_IN_BYTES, + SHORT_SIZE_IN_BYTES, + LE_INT16, + FLOAT_SIZE_IN_BYTES, + LE_FLOAT, + LE_DOUBLE, + DOUBLE_SIZE_IN_BYTES, ) from hazelcast.serialization.data import Data +from hazelcast.util import int_from_bytes, timezone + 
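+# Wire-format sizes of the temporal types handled by the SQL codecs below:
+# a local date is a 2-byte year followed by 1-byte month and day fields,
+# and a local time is 1-byte hour, minute, and second fields followed by a
+# 4-byte nanosecond field (see the decode_local_* helpers further down).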
+_LOCAL_DATE_SIZE_IN_BYTES = SHORT_SIZE_IN_BYTES + BYTE_SIZE_IN_BYTES * 2 +_LOCAL_TIME_SIZE_IN_BYTES = BYTE_SIZE_IN_BYTES * 3 + INT_SIZE_IN_BYTES +_LOCAL_DATE_TIME_SIZE_IN_BYTES = _LOCAL_DATE_SIZE_IN_BYTES + _LOCAL_TIME_SIZE_IN_BYTES +_OFFSET_DATE_TIME_SIZE_IN_BYTES = _LOCAL_DATE_TIME_SIZE_IN_BYTES + INT_SIZE_IN_BYTES class CodecUtil(object): @@ -274,6 +289,49 @@ def decode_uuid(buf, offset): ) return uuid.UUID(bytes=bytes(b)) + @staticmethod + def decode_short(buf, offset): + return LE_INT16.unpack_from(buf, offset)[0] + + @staticmethod + def decode_float(buf, offset): + return LE_FLOAT.unpack_from(buf, offset)[0] + + @staticmethod + def decode_double(buf, offset): + return LE_DOUBLE.unpack_from(buf, offset)[0] + + @staticmethod + def decode_local_date(buf, offset): + year = FixSizedTypesCodec.decode_short(buf, offset) + month = FixSizedTypesCodec.decode_byte(buf, offset + SHORT_SIZE_IN_BYTES) + day = FixSizedTypesCodec.decode_byte(buf, offset + SHORT_SIZE_IN_BYTES + BYTE_SIZE_IN_BYTES) + + return date(year, month, day) + + @staticmethod + def decode_local_time(buf, offset): + hour = FixSizedTypesCodec.decode_byte(buf, offset) + minute = FixSizedTypesCodec.decode_byte(buf, offset + BYTE_SIZE_IN_BYTES) + second = FixSizedTypesCodec.decode_byte(buf, offset + BYTE_SIZE_IN_BYTES * 2) + nano = FixSizedTypesCodec.decode_int(buf, offset + BYTE_SIZE_IN_BYTES * 3) + + return time(hour, minute, second, int(nano / 1000.0)) + + @staticmethod + def decode_local_date_time(buf, offset): + date_value = FixSizedTypesCodec.decode_local_date(buf, offset) + time_value = FixSizedTypesCodec.decode_local_time(buf, offset + _LOCAL_DATE_SIZE_IN_BYTES) + + return datetime.combine(date_value, time_value) + + @staticmethod + def decode_offset_date_time(buf, offset): + datetime_value = FixSizedTypesCodec.decode_local_date_time(buf, offset) + offset_seconds = FixSizedTypesCodec.decode_int(buf, offset + _LOCAL_DATE_TIME_SIZE_IN_BYTES) + + return datetime_value.replace(tzinfo=timezone(timedelta(seconds=offset_seconds))) + class ListIntegerCodec(object): @staticmethod @@ -496,3 +554,206 @@ def encode(buf, value, is_final=False): @staticmethod def decode(msg): return msg.next_frame().buf.decode("utf-8") + + +class ListCNFixedSizeCodec(object): + _TYPE_NULL_ONLY = 1 + _TYPE_NOT_NULL_ONLY = 2 + _TYPE_MIXED = 3 + + _ITEMS_PER_BITMASK = 8 + + _HEADER_SIZE = BYTE_SIZE_IN_BYTES + INT_SIZE_IN_BYTES + + @staticmethod + def decode(msg, item_size, decoder): + frame = msg.next_frame() + type = FixSizedTypesCodec.decode_byte(frame.buf, 0) + count = FixSizedTypesCodec.decode_int(frame.buf, 1) + + response = [] + if type == ListCNFixedSizeCodec._TYPE_NULL_ONLY: + for _ in range(count): + response.append(None) + elif type == ListCNFixedSizeCodec._TYPE_NOT_NULL_ONLY: + for i in range(count): + response.append( + decoder(frame.buf, ListCNFixedSizeCodec._HEADER_SIZE + i * item_size) + ) + else: + position = ListCNFixedSizeCodec._HEADER_SIZE + read_count = 0 + + while read_count < count: + bitmask = FixSizedTypesCodec.decode_byte(frame.buf, position) + position += 1 + + i = 0 + while i < ListCNFixedSizeCodec._ITEMS_PER_BITMASK and read_count < count: + mask = 1 << i + if (bitmask & mask) == mask: + response.append(decoder(frame.buf, position)) + position += item_size + else: + response.append(None) + read_count += 1 + + i += 1 + + return response + + +class ListCNBooleanCodec(object): + @staticmethod + def decode(msg): + return ListCNFixedSizeCodec.decode( + msg, BOOLEAN_SIZE_IN_BYTES, FixSizedTypesCodec.decode_boolean + ) + + +class 
ListCNByteCodec(object): + @staticmethod + def decode(msg): + return ListCNFixedSizeCodec.decode(msg, BYTE_SIZE_IN_BYTES, FixSizedTypesCodec.decode_byte) + + +class ListCNShortCodec(object): + @staticmethod + def decode(msg): + return ListCNFixedSizeCodec.decode( + msg, SHORT_SIZE_IN_BYTES, FixSizedTypesCodec.decode_short + ) + + +class ListCNIntegerCodec(object): + @staticmethod + def decode(msg): + return ListCNFixedSizeCodec.decode(msg, INT_SIZE_IN_BYTES, FixSizedTypesCodec.decode_int) + + +class ListCNLongCodec(object): + @staticmethod + def decode(msg): + return ListCNFixedSizeCodec.decode(msg, LONG_SIZE_IN_BYTES, FixSizedTypesCodec.decode_long) + + +class ListCNFloatCodec(object): + @staticmethod + def decode(msg): + return ListCNFixedSizeCodec.decode( + msg, FLOAT_SIZE_IN_BYTES, FixSizedTypesCodec.decode_float + ) + + +class ListCNDoubleCodec(object): + @staticmethod + def decode(msg): + return ListCNFixedSizeCodec.decode( + msg, DOUBLE_SIZE_IN_BYTES, FixSizedTypesCodec.decode_double + ) + + +class ListCNLocalDateCodec(object): + @staticmethod + def decode(msg): + return ListCNFixedSizeCodec.decode( + msg, _LOCAL_DATE_SIZE_IN_BYTES, FixSizedTypesCodec.decode_local_date + ) + + +class ListCNLocalTimeCodec(object): + @staticmethod + def decode(msg): + return ListCNFixedSizeCodec.decode( + msg, _LOCAL_TIME_SIZE_IN_BYTES, FixSizedTypesCodec.decode_local_time + ) + + +class ListCNLocalDateTimeCodec(object): + @staticmethod + def decode(msg): + return ListCNFixedSizeCodec.decode( + msg, _LOCAL_DATE_TIME_SIZE_IN_BYTES, FixSizedTypesCodec.decode_local_date_time + ) + + +class ListCNOffsetDateTimeCodec(object): + @staticmethod + def decode(msg): + return ListCNFixedSizeCodec.decode( + msg, _OFFSET_DATE_TIME_SIZE_IN_BYTES, FixSizedTypesCodec.decode_offset_date_time + ) + + +class BigDecimalCodec(object): + @staticmethod + def decode(msg): + buf = msg.next_frame().buf + size = FixSizedTypesCodec.decode_int(buf, 0) + unscaled_value = int_from_bytes(buf[INT_SIZE_IN_BYTES : INT_SIZE_IN_BYTES + size]) + scale = FixSizedTypesCodec.decode_int(buf, INT_SIZE_IN_BYTES + size) + sign = 0 if unscaled_value >= 0 else 1 + return Decimal((sign, tuple(map(int, str(abs(unscaled_value)))), -1 * scale)) + + +class SqlPageCodec(object): + @staticmethod + def decode(msg): + from hazelcast.sql import SqlColumnType, _SqlPage + + # begin frame + msg.next_frame() + + # read the "last" flag + is_last = LE_INT8.unpack_from(msg.next_frame().buf, 0)[0] == 1 + + # read column types + column_type_ids = ListIntegerCodec.decode(msg) + + # read columns + columns = [] + + for column_type_id in column_type_ids: + if column_type_id == SqlColumnType.VARCHAR: + columns.append( + ListMultiFrameCodec.decode_contains_nullable(msg, StringCodec.decode) + ) + elif column_type_id == SqlColumnType.BOOLEAN: + columns.append(ListCNBooleanCodec.decode(msg)) + elif column_type_id == SqlColumnType.TINYINT: + columns.append(ListCNByteCodec.decode(msg)) + elif column_type_id == SqlColumnType.SMALLINT: + columns.append(ListCNShortCodec.decode(msg)) + elif column_type_id == SqlColumnType.INTEGER: + columns.append(ListCNIntegerCodec.decode(msg)) + elif column_type_id == SqlColumnType.BIGINT: + columns.append(ListCNLongCodec.decode(msg)) + elif column_type_id == SqlColumnType.REAL: + columns.append(ListCNFloatCodec.decode(msg)) + elif column_type_id == SqlColumnType.DOUBLE: + columns.append(ListCNDoubleCodec.decode(msg)) + elif column_type_id == SqlColumnType.DATE: + columns.append(ListCNLocalDateCodec.decode(msg)) + elif column_type_id == 
SqlColumnType.TIME:
                columns.append(ListCNLocalTimeCodec.decode(msg))
            elif column_type_id == SqlColumnType.TIMESTAMP:
                columns.append(ListCNLocalDateTimeCodec.decode(msg))
            elif column_type_id == SqlColumnType.TIMESTAMP_WITH_TIME_ZONE:
                columns.append(ListCNOffsetDateTimeCodec.decode(msg))
            elif column_type_id == SqlColumnType.DECIMAL:
                columns.append(
                    ListMultiFrameCodec.decode_contains_nullable(msg, BigDecimalCodec.decode)
                )
            elif column_type_id == SqlColumnType.NULL:
                frame = msg.next_frame()
                size = FixSizedTypesCodec.decode_int(frame.buf, 0)
                column = [None for _ in range(size)]
                columns.append(column)
            elif column_type_id == SqlColumnType.OBJECT:
                columns.append(ListMultiFrameCodec.decode_contains_nullable(msg, DataCodec.decode))
            else:
                raise ValueError("Unknown type %s" % column_type_id)

        CodecUtil.fast_forward_to_end_frame(msg)

        return _SqlPage(column_type_ids, columns, is_last)
diff --git a/hazelcast/protocol/codec/custom/sql_column_metadata_codec.py b/hazelcast/protocol/codec/custom/sql_column_metadata_codec.py
new file mode 100644
index 0000000000..4c72e71582
--- /dev/null
+++ b/hazelcast/protocol/codec/custom/sql_column_metadata_codec.py
@@ -0,0 +1,39 @@
+from hazelcast.protocol.builtin import FixSizedTypesCodec, CodecUtil
+from hazelcast.serialization.bits import *
+from hazelcast.protocol.client_message import END_FRAME_BUF, END_FINAL_FRAME_BUF, SIZE_OF_FRAME_LENGTH_AND_FLAGS, create_initial_buffer_custom
+from hazelcast.sql import SqlColumnMetadata
+from hazelcast.protocol.builtin import StringCodec
+
+_TYPE_ENCODE_OFFSET = 2 * SIZE_OF_FRAME_LENGTH_AND_FLAGS
+_TYPE_DECODE_OFFSET = 0
+_NULLABLE_ENCODE_OFFSET = _TYPE_ENCODE_OFFSET + INT_SIZE_IN_BYTES
+_NULLABLE_DECODE_OFFSET = _TYPE_DECODE_OFFSET + INT_SIZE_IN_BYTES
+_INITIAL_FRAME_SIZE = _NULLABLE_ENCODE_OFFSET + BOOLEAN_SIZE_IN_BYTES - 2 * SIZE_OF_FRAME_LENGTH_AND_FLAGS
+
+
+class SqlColumnMetadataCodec(object):
+    @staticmethod
+    def encode(buf, sql_column_metadata, is_final=False):
+        initial_frame_buf = create_initial_buffer_custom(_INITIAL_FRAME_SIZE)
+        FixSizedTypesCodec.encode_int(initial_frame_buf, _TYPE_ENCODE_OFFSET, sql_column_metadata.type)
+        FixSizedTypesCodec.encode_boolean(initial_frame_buf, _NULLABLE_ENCODE_OFFSET, sql_column_metadata.nullable)
+        buf.extend(initial_frame_buf)
+        StringCodec.encode(buf, sql_column_metadata.name)
+        if is_final:
+            buf.extend(END_FINAL_FRAME_BUF)
+        else:
+            buf.extend(END_FRAME_BUF)
+
+    @staticmethod
+    def decode(msg):
+        msg.next_frame()
+        initial_frame = msg.next_frame()
+        type = FixSizedTypesCodec.decode_int(initial_frame.buf, _TYPE_DECODE_OFFSET)
+        is_nullable_exists = False
+        nullable = False
+        if len(initial_frame.buf) >= _NULLABLE_DECODE_OFFSET + BOOLEAN_SIZE_IN_BYTES:
+            nullable = FixSizedTypesCodec.decode_boolean(initial_frame.buf, _NULLABLE_DECODE_OFFSET)
+            is_nullable_exists = True
+        name = StringCodec.decode(msg)
+        CodecUtil.fast_forward_to_end_frame(msg)
+        return SqlColumnMetadata(name, type, nullable, is_nullable_exists)
diff --git a/hazelcast/protocol/codec/custom/sql_error_codec.py b/hazelcast/protocol/codec/custom/sql_error_codec.py
new file mode 100644
index 0000000000..37bea690ef
--- /dev/null
+++ b/hazelcast/protocol/codec/custom/sql_error_codec.py
@@ -0,0 +1,35 @@
+from hazelcast.protocol.builtin import FixSizedTypesCodec, CodecUtil
+from hazelcast.serialization.bits import *
+from hazelcast.protocol.client_message import END_FRAME_BUF, END_FINAL_FRAME_BUF, SIZE_OF_FRAME_LENGTH_AND_FLAGS, create_initial_buffer_custom
+from hazelcast.sql 
import _SqlError +from hazelcast.protocol.builtin import StringCodec + +_CODE_ENCODE_OFFSET = 2 * SIZE_OF_FRAME_LENGTH_AND_FLAGS +_CODE_DECODE_OFFSET = 0 +_ORIGINATING_MEMBER_ID_ENCODE_OFFSET = _CODE_ENCODE_OFFSET + INT_SIZE_IN_BYTES +_ORIGINATING_MEMBER_ID_DECODE_OFFSET = _CODE_DECODE_OFFSET + INT_SIZE_IN_BYTES +_INITIAL_FRAME_SIZE = _ORIGINATING_MEMBER_ID_ENCODE_OFFSET + UUID_SIZE_IN_BYTES - 2 * SIZE_OF_FRAME_LENGTH_AND_FLAGS + + +class SqlErrorCodec(object): + @staticmethod + def encode(buf, sql_error, is_final=False): + initial_frame_buf = create_initial_buffer_custom(_INITIAL_FRAME_SIZE) + FixSizedTypesCodec.encode_int(initial_frame_buf, _CODE_ENCODE_OFFSET, sql_error.code) + FixSizedTypesCodec.encode_uuid(initial_frame_buf, _ORIGINATING_MEMBER_ID_ENCODE_OFFSET, sql_error.originating_member_id) + buf.extend(initial_frame_buf) + CodecUtil.encode_nullable(buf, sql_error.message, StringCodec.encode) + if is_final: + buf.extend(END_FINAL_FRAME_BUF) + else: + buf.extend(END_FRAME_BUF) + + @staticmethod + def decode(msg): + msg.next_frame() + initial_frame = msg.next_frame() + code = FixSizedTypesCodec.decode_int(initial_frame.buf, _CODE_DECODE_OFFSET) + originating_member_id = FixSizedTypesCodec.decode_uuid(initial_frame.buf, _ORIGINATING_MEMBER_ID_DECODE_OFFSET) + message = CodecUtil.decode_nullable(msg, StringCodec.decode) + CodecUtil.fast_forward_to_end_frame(msg) + return _SqlError(code, message, originating_member_id) diff --git a/hazelcast/protocol/codec/custom/sql_query_id_codec.py b/hazelcast/protocol/codec/custom/sql_query_id_codec.py new file mode 100644 index 0000000000..4d967fc72b --- /dev/null +++ b/hazelcast/protocol/codec/custom/sql_query_id_codec.py @@ -0,0 +1,40 @@ +from hazelcast.protocol.builtin import FixSizedTypesCodec, CodecUtil +from hazelcast.serialization.bits import * +from hazelcast.protocol.client_message import END_FRAME_BUF, END_FINAL_FRAME_BUF, SIZE_OF_FRAME_LENGTH_AND_FLAGS, create_initial_buffer_custom +from hazelcast.sql import _SqlQueryId + +_MEMBER_ID_HIGH_ENCODE_OFFSET = 2 * SIZE_OF_FRAME_LENGTH_AND_FLAGS +_MEMBER_ID_HIGH_DECODE_OFFSET = 0 +_MEMBER_ID_LOW_ENCODE_OFFSET = _MEMBER_ID_HIGH_ENCODE_OFFSET + LONG_SIZE_IN_BYTES +_MEMBER_ID_LOW_DECODE_OFFSET = _MEMBER_ID_HIGH_DECODE_OFFSET + LONG_SIZE_IN_BYTES +_LOCAL_ID_HIGH_ENCODE_OFFSET = _MEMBER_ID_LOW_ENCODE_OFFSET + LONG_SIZE_IN_BYTES +_LOCAL_ID_HIGH_DECODE_OFFSET = _MEMBER_ID_LOW_DECODE_OFFSET + LONG_SIZE_IN_BYTES +_LOCAL_ID_LOW_ENCODE_OFFSET = _LOCAL_ID_HIGH_ENCODE_OFFSET + LONG_SIZE_IN_BYTES +_LOCAL_ID_LOW_DECODE_OFFSET = _LOCAL_ID_HIGH_DECODE_OFFSET + LONG_SIZE_IN_BYTES +_INITIAL_FRAME_SIZE = _LOCAL_ID_LOW_ENCODE_OFFSET + LONG_SIZE_IN_BYTES - 2 * SIZE_OF_FRAME_LENGTH_AND_FLAGS + + +class SqlQueryIdCodec(object): + @staticmethod + def encode(buf, sql_query_id, is_final=False): + initial_frame_buf = create_initial_buffer_custom(_INITIAL_FRAME_SIZE) + FixSizedTypesCodec.encode_long(initial_frame_buf, _MEMBER_ID_HIGH_ENCODE_OFFSET, sql_query_id.member_id_high) + FixSizedTypesCodec.encode_long(initial_frame_buf, _MEMBER_ID_LOW_ENCODE_OFFSET, sql_query_id.member_id_low) + FixSizedTypesCodec.encode_long(initial_frame_buf, _LOCAL_ID_HIGH_ENCODE_OFFSET, sql_query_id.local_id_high) + FixSizedTypesCodec.encode_long(initial_frame_buf, _LOCAL_ID_LOW_ENCODE_OFFSET, sql_query_id.local_id_low) + buf.extend(initial_frame_buf) + if is_final: + buf.extend(END_FINAL_FRAME_BUF) + else: + buf.extend(END_FRAME_BUF) + + @staticmethod + def decode(msg): + msg.next_frame() + initial_frame = msg.next_frame() + member_id_high = 
FixSizedTypesCodec.decode_long(initial_frame.buf, _MEMBER_ID_HIGH_DECODE_OFFSET) + member_id_low = FixSizedTypesCodec.decode_long(initial_frame.buf, _MEMBER_ID_LOW_DECODE_OFFSET) + local_id_high = FixSizedTypesCodec.decode_long(initial_frame.buf, _LOCAL_ID_HIGH_DECODE_OFFSET) + local_id_low = FixSizedTypesCodec.decode_long(initial_frame.buf, _LOCAL_ID_LOW_DECODE_OFFSET) + CodecUtil.fast_forward_to_end_frame(msg) + return _SqlQueryId(member_id_high, member_id_low, local_id_high, local_id_low) diff --git a/hazelcast/protocol/codec/sql_close_codec.py b/hazelcast/protocol/codec/sql_close_codec.py new file mode 100644 index 0000000000..b6647befc4 --- /dev/null +++ b/hazelcast/protocol/codec/sql_close_codec.py @@ -0,0 +1,15 @@ +from hazelcast.protocol.client_message import OutboundMessage, REQUEST_HEADER_SIZE, create_initial_buffer +from hazelcast.protocol.codec.custom.sql_query_id_codec import SqlQueryIdCodec + +# hex: 0x210300 +_REQUEST_MESSAGE_TYPE = 2163456 +# hex: 0x210301 +_RESPONSE_MESSAGE_TYPE = 2163457 + +_REQUEST_INITIAL_FRAME_SIZE = REQUEST_HEADER_SIZE + + +def encode_request(query_id): + buf = create_initial_buffer(_REQUEST_INITIAL_FRAME_SIZE, _REQUEST_MESSAGE_TYPE) + SqlQueryIdCodec.encode(buf, query_id, True) + return OutboundMessage(buf, False) diff --git a/hazelcast/protocol/codec/sql_execute_codec.py b/hazelcast/protocol/codec/sql_execute_codec.py new file mode 100644 index 0000000000..8b93c3db73 --- /dev/null +++ b/hazelcast/protocol/codec/sql_execute_codec.py @@ -0,0 +1,44 @@ +from hazelcast.serialization.bits import * +from hazelcast.protocol.builtin import FixSizedTypesCodec +from hazelcast.protocol.client_message import OutboundMessage, REQUEST_HEADER_SIZE, create_initial_buffer, RESPONSE_HEADER_SIZE +from hazelcast.protocol.builtin import StringCodec +from hazelcast.protocol.builtin import ListMultiFrameCodec +from hazelcast.protocol.builtin import DataCodec +from hazelcast.protocol.builtin import CodecUtil +from hazelcast.protocol.codec.custom.sql_query_id_codec import SqlQueryIdCodec +from hazelcast.protocol.codec.custom.sql_column_metadata_codec import SqlColumnMetadataCodec +from hazelcast.protocol.builtin import SqlPageCodec +from hazelcast.protocol.codec.custom.sql_error_codec import SqlErrorCodec + +# hex: 0x210400 +_REQUEST_MESSAGE_TYPE = 2163712 +# hex: 0x210401 +_RESPONSE_MESSAGE_TYPE = 2163713 + +_REQUEST_TIMEOUT_MILLIS_OFFSET = REQUEST_HEADER_SIZE +_REQUEST_CURSOR_BUFFER_SIZE_OFFSET = _REQUEST_TIMEOUT_MILLIS_OFFSET + LONG_SIZE_IN_BYTES +_REQUEST_EXPECTED_RESULT_TYPE_OFFSET = _REQUEST_CURSOR_BUFFER_SIZE_OFFSET + INT_SIZE_IN_BYTES +_REQUEST_INITIAL_FRAME_SIZE = _REQUEST_EXPECTED_RESULT_TYPE_OFFSET + BYTE_SIZE_IN_BYTES +_RESPONSE_UPDATE_COUNT_OFFSET = RESPONSE_HEADER_SIZE + + +def encode_request(sql, parameters, timeout_millis, cursor_buffer_size, schema, expected_result_type, query_id): + buf = create_initial_buffer(_REQUEST_INITIAL_FRAME_SIZE, _REQUEST_MESSAGE_TYPE) + FixSizedTypesCodec.encode_long(buf, _REQUEST_TIMEOUT_MILLIS_OFFSET, timeout_millis) + FixSizedTypesCodec.encode_int(buf, _REQUEST_CURSOR_BUFFER_SIZE_OFFSET, cursor_buffer_size) + FixSizedTypesCodec.encode_byte(buf, _REQUEST_EXPECTED_RESULT_TYPE_OFFSET, expected_result_type) + StringCodec.encode(buf, sql) + ListMultiFrameCodec.encode_contains_nullable(buf, parameters, DataCodec.encode) + CodecUtil.encode_nullable(buf, schema, StringCodec.encode) + SqlQueryIdCodec.encode(buf, query_id, True) + return OutboundMessage(buf, False) + + +def decode_response(msg): + initial_frame = msg.next_frame() + 
response = dict() + response["update_count"] = FixSizedTypesCodec.decode_long(initial_frame.buf, _RESPONSE_UPDATE_COUNT_OFFSET) + response["row_metadata"] = ListMultiFrameCodec.decode_nullable(msg, SqlColumnMetadataCodec.decode) + response["row_page"] = CodecUtil.decode_nullable(msg, SqlPageCodec.decode) + response["error"] = CodecUtil.decode_nullable(msg, SqlErrorCodec.decode) + return response diff --git a/hazelcast/protocol/codec/sql_execute_reserved_codec.py b/hazelcast/protocol/codec/sql_execute_reserved_codec.py new file mode 100644 index 0000000000..ff10cccc0e --- /dev/null +++ b/hazelcast/protocol/codec/sql_execute_reserved_codec.py @@ -0,0 +1,43 @@ +from hazelcast.serialization.bits import * +from hazelcast.protocol.builtin import FixSizedTypesCodec +from hazelcast.protocol.client_message import OutboundMessage, REQUEST_HEADER_SIZE, create_initial_buffer, RESPONSE_HEADER_SIZE +from hazelcast.protocol.builtin import StringCodec +from hazelcast.protocol.builtin import ListMultiFrameCodec +from hazelcast.protocol.builtin import DataCodec +from hazelcast.protocol.codec.custom.sql_query_id_codec import SqlQueryIdCodec +from hazelcast.protocol.builtin import CodecUtil +from hazelcast.protocol.codec.custom.sql_column_metadata_codec import SqlColumnMetadataCodec +from hazelcast.protocol.builtin import ListCNDataCodec +from hazelcast.protocol.codec.custom.sql_error_codec import SqlErrorCodec + +# hex: 0x210100 +_REQUEST_MESSAGE_TYPE = 2162944 +# hex: 0x210101 +_RESPONSE_MESSAGE_TYPE = 2162945 + +_REQUEST_TIMEOUT_MILLIS_OFFSET = REQUEST_HEADER_SIZE +_REQUEST_CURSOR_BUFFER_SIZE_OFFSET = _REQUEST_TIMEOUT_MILLIS_OFFSET + LONG_SIZE_IN_BYTES +_REQUEST_INITIAL_FRAME_SIZE = _REQUEST_CURSOR_BUFFER_SIZE_OFFSET + INT_SIZE_IN_BYTES +_RESPONSE_ROW_PAGE_LAST_OFFSET = RESPONSE_HEADER_SIZE +_RESPONSE_UPDATE_COUNT_OFFSET = _RESPONSE_ROW_PAGE_LAST_OFFSET + BOOLEAN_SIZE_IN_BYTES + + +def encode_request(sql, parameters, timeout_millis, cursor_buffer_size): + buf = create_initial_buffer(_REQUEST_INITIAL_FRAME_SIZE, _REQUEST_MESSAGE_TYPE) + FixSizedTypesCodec.encode_long(buf, _REQUEST_TIMEOUT_MILLIS_OFFSET, timeout_millis) + FixSizedTypesCodec.encode_int(buf, _REQUEST_CURSOR_BUFFER_SIZE_OFFSET, cursor_buffer_size) + StringCodec.encode(buf, sql) + ListMultiFrameCodec.encode(buf, parameters, DataCodec.encode, True) + return OutboundMessage(buf, False) + + +def decode_response(msg): + initial_frame = msg.next_frame() + response = dict() + response["row_page_last"] = FixSizedTypesCodec.decode_boolean(initial_frame.buf, _RESPONSE_ROW_PAGE_LAST_OFFSET) + response["update_count"] = FixSizedTypesCodec.decode_long(initial_frame.buf, _RESPONSE_UPDATE_COUNT_OFFSET) + response["query_id"] = CodecUtil.decode_nullable(msg, SqlQueryIdCodec.decode) + response["row_metadata"] = ListMultiFrameCodec.decode_nullable(msg, SqlColumnMetadataCodec.decode) + response["row_page"] = ListMultiFrameCodec.decode_nullable(msg, ListCNDataCodec.decode) + response["error"] = CodecUtil.decode_nullable(msg, SqlErrorCodec.decode) + return response diff --git a/hazelcast/protocol/codec/sql_fetch_codec.py b/hazelcast/protocol/codec/sql_fetch_codec.py new file mode 100644 index 0000000000..139832bff9 --- /dev/null +++ b/hazelcast/protocol/codec/sql_fetch_codec.py @@ -0,0 +1,30 @@ +from hazelcast.serialization.bits import * +from hazelcast.protocol.builtin import FixSizedTypesCodec +from hazelcast.protocol.client_message import OutboundMessage, REQUEST_HEADER_SIZE, create_initial_buffer +from hazelcast.protocol.codec.custom.sql_query_id_codec import 
SqlQueryIdCodec +from hazelcast.protocol.builtin import SqlPageCodec +from hazelcast.protocol.builtin import CodecUtil +from hazelcast.protocol.codec.custom.sql_error_codec import SqlErrorCodec + +# hex: 0x210500 +_REQUEST_MESSAGE_TYPE = 2163968 +# hex: 0x210501 +_RESPONSE_MESSAGE_TYPE = 2163969 + +_REQUEST_CURSOR_BUFFER_SIZE_OFFSET = REQUEST_HEADER_SIZE +_REQUEST_INITIAL_FRAME_SIZE = _REQUEST_CURSOR_BUFFER_SIZE_OFFSET + INT_SIZE_IN_BYTES + + +def encode_request(query_id, cursor_buffer_size): + buf = create_initial_buffer(_REQUEST_INITIAL_FRAME_SIZE, _REQUEST_MESSAGE_TYPE) + FixSizedTypesCodec.encode_int(buf, _REQUEST_CURSOR_BUFFER_SIZE_OFFSET, cursor_buffer_size) + SqlQueryIdCodec.encode(buf, query_id, True) + return OutboundMessage(buf, False) + + +def decode_response(msg): + msg.next_frame() + response = dict() + response["row_page"] = CodecUtil.decode_nullable(msg, SqlPageCodec.decode) + response["error"] = CodecUtil.decode_nullable(msg, SqlErrorCodec.decode) + return response diff --git a/hazelcast/protocol/codec/sql_fetch_reserved_codec.py b/hazelcast/protocol/codec/sql_fetch_reserved_codec.py new file mode 100644 index 0000000000..1089ead68b --- /dev/null +++ b/hazelcast/protocol/codec/sql_fetch_reserved_codec.py @@ -0,0 +1,33 @@ +from hazelcast.serialization.bits import * +from hazelcast.protocol.builtin import FixSizedTypesCodec +from hazelcast.protocol.client_message import OutboundMessage, REQUEST_HEADER_SIZE, create_initial_buffer, RESPONSE_HEADER_SIZE +from hazelcast.protocol.codec.custom.sql_query_id_codec import SqlQueryIdCodec +from hazelcast.protocol.builtin import ListMultiFrameCodec +from hazelcast.protocol.builtin import ListCNDataCodec +from hazelcast.protocol.codec.custom.sql_error_codec import SqlErrorCodec +from hazelcast.protocol.builtin import CodecUtil + +# hex: 0x210200 +_REQUEST_MESSAGE_TYPE = 2163200 +# hex: 0x210201 +_RESPONSE_MESSAGE_TYPE = 2163201 + +_REQUEST_CURSOR_BUFFER_SIZE_OFFSET = REQUEST_HEADER_SIZE +_REQUEST_INITIAL_FRAME_SIZE = _REQUEST_CURSOR_BUFFER_SIZE_OFFSET + INT_SIZE_IN_BYTES +_RESPONSE_ROW_PAGE_LAST_OFFSET = RESPONSE_HEADER_SIZE + + +def encode_request(query_id, cursor_buffer_size): + buf = create_initial_buffer(_REQUEST_INITIAL_FRAME_SIZE, _REQUEST_MESSAGE_TYPE) + FixSizedTypesCodec.encode_int(buf, _REQUEST_CURSOR_BUFFER_SIZE_OFFSET, cursor_buffer_size) + SqlQueryIdCodec.encode(buf, query_id, True) + return OutboundMessage(buf, False) + + +def decode_response(msg): + initial_frame = msg.next_frame() + response = dict() + response["row_page_last"] = FixSizedTypesCodec.decode_boolean(initial_frame.buf, _RESPONSE_ROW_PAGE_LAST_OFFSET) + response["row_page"] = ListMultiFrameCodec.decode_nullable(msg, ListCNDataCodec.decode) + response["error"] = CodecUtil.decode_nullable(msg, SqlErrorCodec.decode) + return response diff --git a/hazelcast/serialization/serializer.py b/hazelcast/serialization/serializer.py index 5e50c5f33e..a2d35ee204 100644 --- a/hazelcast/serialization/serializer.py +++ b/hazelcast/serialization/serializer.py @@ -1,6 +1,5 @@ import binascii import time -import uuid from datetime import datetime from hazelcast import six @@ -10,8 +9,7 @@ from hazelcast.serialization.base import HazelcastSerializationError from hazelcast.serialization.serialization_const import * from hazelcast.six.moves import range, cPickle - -from hazelcast.util import to_signed +from hazelcast.util import UuidUtil if not six.PY2: long = int @@ -135,12 +133,10 @@ class UuidSerializer(BaseSerializer): def read(self, inp): msb = inp.read_long() lsb = 
inp.read_long()
-        return uuid.UUID(int=(((msb << UUID_MSB_SHIFT) & UUID_MSB_MASK) | (lsb & UUID_LSB_MASK)))
+        return UuidUtil.from_bits(msb, lsb)
 
     def write(self, out, obj):
-        i = obj.int
-        msb = to_signed(i >> UUID_MSB_SHIFT, 64)
-        lsb = to_signed(i & UUID_LSB_MASK, 64)
+        msb, lsb = UuidUtil.to_bits(obj)
         out.write_long(msb)
         out.write_long(lsb)
 
diff --git a/hazelcast/sql.py b/hazelcast/sql.py
new file mode 100644
index 0000000000..efab4a7488
--- /dev/null
+++ b/hazelcast/sql.py
@@ -0,0 +1,1585 @@
+import logging
+import uuid
+from threading import RLock
+
+from hazelcast.errors import HazelcastError
+from hazelcast.future import Future, ImmediateFuture
+from hazelcast.invocation import Invocation
+from hazelcast.util import (
+    UuidUtil,
+    check_not_none,
+    to_millis,
+    check_true,
+    get_attr_name,
+    try_to_get_error_message,
+)
+
+_logger = logging.getLogger(__name__)
+
+
+class SqlService(object):
+    """A service to execute SQL statements.
+
+    The service allows you to query data stored in a
+    :class:`Map <hazelcast.proxy.map.Map>`.
+
+    Warnings:
+
+        The service is in beta state. Behavior and API might change
+        in future releases.
+
+    **Querying an IMap**
+
+    Every Map instance is exposed as a table with the same name in the
+    ``partitioned`` schema. The ``partitioned`` schema is included in the
+    default search path, therefore a Map could be referenced in an
+    SQL statement with or without the schema name.
+
+    **Column resolution**
+
+    Every table backed by a Map has a set of columns that are resolved
+    automatically. Column resolution uses Map entries located on the
+    member that initiates the query. The engine extracts columns from a
+    key and a value and then merges them into a single column set.
+    In case the key and the value have columns with the same name, the
+    key takes precedence.
+
+    Columns are extracted from objects as follows (which happens on the
+    server-side):
+
+    - For non-Portable objects, public getters and fields are used to
+      populate the column list. For getters, the first letter is converted
+      to lower case. A getter takes precedence over a field in case of a
+      naming conflict
+    - For :class:`Portable <hazelcast.serialization.api.Portable>` objects,
+      field names used in the
+      :func:`write_portable() <hazelcast.serialization.api.Portable.write_portable>`
+      method are used to populate the column list
+
+    The whole key and value objects could be accessed through the special
+    fields ``__key`` and ``this``, respectively. If the key (value) object
+    has fields, then the whole key (value) field is hidden. Otherwise, it
+    is exposed as a normal field. Hidden fields can be accessed directly,
+    but are not returned by ``SELECT * FROM ...`` queries.
+
+    Consider the following key/value model: ::
+
+        class PersonKey(Portable):
+            def __init__(self, person_id=None, department_id=None):
+                self.person_id = person_id
+                self.department_id = department_id
+
+            def write_portable(self, writer):
+                writer.write_long("person_id", self.person_id)
+                writer.write_long("department_id", self.department_id)
+
+            ...
+
+        class Person(Portable):
+            def __init__(self, name=None):
+                self.name = name
+
+            def write_portable(self, writer):
+                writer.write_string("name", self.name)
+
+            ...
+
+    This model will be resolved to the following table columns:
+
+    - ``person_id`` ``BIGINT``
+    - ``department_id`` ``BIGINT``
+    - ``name`` ``VARCHAR``
+    - ``__key`` ``OBJECT`` (hidden)
+    - ``this`` ``OBJECT`` (hidden)
+
+    **Consistency**
+
+    Results returned from a Map query are weakly consistent:
+
+    - If an entry was not updated during iteration, it is guaranteed to be
+      returned exactly once
+    - If an entry was modified during iteration, it might be returned zero,
+      one or several times
+
+    **Usage**
+
+    When a query is executed, an :class:`SqlResult` is returned. You may get
+    a row iterator from the result. The result must be closed at the end. The
+    iterator will close the result automatically when it is exhausted given
+    that no error is raised during the iteration. The code snippet below
+    demonstrates a typical usage pattern: ::
+
+        client = hazelcast.HazelcastClient()
+
+        result = client.sql.execute("SELECT * FROM person")
+
+        for row in result:
+            print(row.get_object("person_id"))
+            print(row.get_object("name"))
+            ...
+
+
+    See the documentation of the :class:`SqlResult` for more information
+    about the different types of iteration methods.
+
+    Notes:
+
+        When an SQL statement is submitted to a member, it is parsed and
+        optimized by the ``hazelcast-sql`` module. The ``hazelcast-sql``
+        module must be in the classpath, otherwise an exception will be
+        thrown. If you're using the ``hazelcast-all`` or
+        ``hazelcast-enterprise-all`` packages, the ``hazelcast-sql`` module
+        is included in them by default. If not, i.e., you are using
+        ``hazelcast`` or ``hazelcast-enterprise``, then you need to have
+        ``hazelcast-sql`` in the classpath. If you are using the Docker
+        image, the SQL module is included by default.
+
+    """
+
+    def __init__(self, internal_sql_service):
+        self._service = internal_sql_service
+
+    def execute(self, sql, *params):
+        """Convenience method to execute a distributed query with the given
+        parameters.
+
+        Converts the passed SQL string and parameters into an
+        :class:`SqlStatement` object and invokes :func:`execute_statement`.
+
+        Args:
+            sql (str): SQL string.
+            *params: Query parameters that will be passed to
+                :func:`SqlStatement.add_parameter`.
+
+        Returns:
+            SqlResult: The execution result.
+
+        Raises:
+            HazelcastSqlError: In case of execution error.
+        """
+        return self._service.execute(sql, *params)
+
+    def execute_statement(self, statement):
+        """Executes an SQL statement.
+
+        Args:
+            statement (SqlStatement): Statement to be executed
+
+        Returns:
+            SqlResult: The execution result.
+
+        Raises:
+            HazelcastSqlError: In case of execution error.
+        """
+        return self._service.execute_statement(statement)
+
+
+class _SqlQueryId(object):
+    """Cluster-wide unique query ID."""
+
+    __slots__ = ("member_id_high", "member_id_low", "local_id_high", "local_id_low")
+
+    def __init__(self, member_id_high, member_id_low, local_id_high, local_id_low):
+        self.member_id_high = member_id_high
+        """int: Most significant bits of the UUID of the member
+        that the query will route to.
+        """
+
+        self.member_id_low = member_id_low
+        """int: Least significant bits of the UUID of the member
+        that the query will route to."""
+
+        self.local_id_high = local_id_high
+        """int: Most significant bits of the UUID of the local id."""
+
+        self.local_id_low = local_id_low
+        """int: Least significant bits of the UUID of the local id."""
+
+    @classmethod
+    def from_uuid(cls, member_uuid):
+        """Generates a local random UUID and creates a query id
+        out of it and the given member UUID. 
+ + Args: + member_uuid (uuid.UUID): UUID of the member. + + Returns: + _SqlQueryId: Generated unique query id. + """ + local_id = uuid.uuid4() + + member_msb, member_lsb = UuidUtil.to_bits(member_uuid) + local_msb, local_lsb = UuidUtil.to_bits(local_id) + + return cls(member_msb, member_lsb, local_msb, local_lsb) + + +class SqlColumnMetadata(object): + """Metadata for one of the columns of the returned rows.""" + + __slots__ = ("_name", "_type", "_nullable") + + def __init__(self, name, column_type, nullable, is_nullable_exists): + self._name = name + self._type = column_type + self._nullable = nullable if is_nullable_exists else True + + @property + def name(self): + """str: Name of the column.""" + return self._name + + @property + def type(self): + """SqlColumnType: Type of the column.""" + return self._type + + @property + def nullable(self): + """bool: ``True`` if the rows in this column might be ``None``, + ``False`` otherwise. + """ + return self._nullable + + def __repr__(self): + return "%s %s" % (self.name, get_attr_name(SqlColumnType, self.type)) + + +class _SqlError(object): + """Server-side error that is propagated to the client.""" + + __slots__ = ("code", "message", "originating_member_uuid") + + def __init__(self, code, message, originating_member_uuid): + self.code = code + """_SqlErrorCode: The error code.""" + + self.message = message + """str: The error message.""" + + self.originating_member_uuid = originating_member_uuid + """uuid.UUID: UUID of the member that caused or initiated an error condition.""" + + +class _SqlPage(object): + """A finite set of rows returned to the user.""" + + __slots__ = ("_column_types", "_columns", "_is_last") + + def __init__(self, column_types, columns, last): + self._column_types = column_types + self._columns = columns + self._is_last = last + + @property + def row_count(self): + """int: Number of rows in the page.""" + # Each column should have equal number of rows. + # Just check the first one. + return len(self._columns[0]) + + @property + def column_count(self): + """int: Number of columns.""" + return len(self._column_types) + + @property + def is_last(self): + """bool: Whether this is the last page or not.""" + return self._is_last + + def get_column_value(self, column_index, row_index): + """ + Args: + column_index (int): + row_index (int): + + Returns: + The value with the given indexes. + """ + return self._columns[column_index][row_index] + + +class SqlColumnType(object): + + VARCHAR = 0 + """ + Represented by ``str``. + """ + + BOOLEAN = 1 + """ + Represented by ``bool``. + """ + + TINYINT = 2 + """ + Represented by ``int``. + """ + + SMALLINT = 3 + """ + Represented by ``int``. + """ + + INTEGER = 4 + """ + Represented by ``int``. + """ + + BIGINT = 5 + """ + Represented by ``int`` (for Python 3) or ``long`` (for Python 2). + """ + + DECIMAL = 6 + """ + Represented by ``decimal.Decimal``. + """ + + REAL = 7 + """ + Represented by ``float``. + """ + + DOUBLE = 8 + """ + Represented by ``float``. + """ + + DATE = 9 + """ + Represented by ``datetime.date``. + """ + + TIME = 10 + """ + Represented by ``datetime.time``. + """ + + TIMESTAMP = 11 + """ + Represented by ``datetime.datetime``. + """ + + TIMESTAMP_WITH_TIME_ZONE = 12 + """ + Represented by ``datetime.datetime`` with ``datetime.tzinfo``. + """ + + OBJECT = 13 + """ + Could be represented by any Python class. + """ + + NULL = 14 + """ + The type of the generic SQL ``NULL`` literal. + + The only valid value of ``NULL`` type is ``None``. 
+ """ + + +class _SqlErrorCode(object): + + GENERIC = -1 + """ + Generic error. + """ + + CONNECTION_PROBLEM = 1001 + """ + A network connection problem between members, or between a client and a member. + """ + + CANCELLED_BY_USER = 1003 + """ + Query was cancelled due to user request. + """ + + TIMEOUT = 1004 + """ + Query was cancelled due to timeout. + """ + + PARTITION_DISTRIBUTION = 1005 + """ + A problem with partition distribution. + """ + + MAP_DESTROYED = 1006 + """ + An error caused by a concurrent destroy of a map. + """ + + MAP_LOADING_IN_PROGRESS = 1007 + """ + Map loading is not finished yet. + """ + + PARSING = 1008 + """ + Generic parsing error. + """ + + INDEX_INVALID = 1009 + """ + An error caused by an attempt to query an index that is not valid. + """ + + DATA_EXCEPTION = 2000 + """ + An error with data conversion or transformation. + """ + + +class HazelcastSqlError(HazelcastError): + """Represents an error occurred during the SQL query execution.""" + + def __init__(self, originating_member_uuid, code, message, cause): + super(HazelcastSqlError, self).__init__(message, cause) + self._originating_member_uuid = originating_member_uuid + + # TODO: This is private API, might be good to make it public or + # remove this information altogether. + self._code = code + + @property + def originating_member_uuid(self): + """uuid.UUID: UUID of the member that caused or initiated an error condition.""" + return self._originating_member_uuid + + +class SqlRowMetadata(object): + """Metadata for the returned rows.""" + + __slots__ = ("_columns", "_name_to_index") + + COLUMN_NOT_FOUND = -1 + """Constant indicating that the column is not found.""" + + def __init__(self, columns): + self._columns = columns + self._name_to_index = {column.name: index for index, column in enumerate(columns)} + + @property + def columns(self): + """list[SqlColumnMetadata]: List of column metadata.""" + return self._columns + + @property + def column_count(self): + """int: Number of column in the row.""" + return len(self._columns) + + def get_column(self, index): + """ + Args: + index (int): Zero-based column index. + + Returns: + SqlColumnMetadata: Metadata for the given column index. + """ + check_true(0 <= index < len(self._columns), "Column index is out of bounds: %s" % index) + return self._columns[index] + + def find_column(self, column_name): + """ + Args: + column_name (str): Name of the column. + + Returns: + int: Column index or :const:`COLUMN_NOT_FOUND` if a column + with the given name is not found. + """ + check_not_none(column_name, "Column name cannot be None") + return self._name_to_index.get(column_name, SqlRowMetadata.COLUMN_NOT_FOUND) + + def __repr__(self): + return "[%s]" % ", ".join( + map( + lambda column: "%s %s" % (column.name, get_attr_name(SqlColumnType, column.type)), + self._columns, + ) + ) + + +class SqlRow(object): + """One of the rows of the SQL query result.""" + + __slots__ = ("_row_metadata", "_row") + + def __init__(self, row_metadata, row): + self._row_metadata = row_metadata + self._row = row + + def get_object(self, column_name): + """Gets the value of the column by column name. + + Column name should be one of those defined in :class:`SqlRowMetadata`, + case-sensitive. You may also use :func:`SqlRowMetadata.find_column` to + test for column existence. + + The class of the returned value depends on the SQL type of the column. + No implicit conversions are performed on the value. + + Args: + column_name (str): + + Returns: + Value of the column. 
+
+        See Also:
+            :attr:`metadata`
+
+            :func:`SqlRowMetadata.find_column`
+
+            :attr:`SqlColumnMetadata.type`
+
+            :attr:`SqlColumnMetadata.name`
+        """
+        index = self._row_metadata.find_column(column_name)
+        if index == SqlRowMetadata.COLUMN_NOT_FOUND:
+            raise ValueError("Column '%s' doesn't exist" % column_name)
+        return self._row[index]
+
+    def get_object_with_index(self, column_index):
+        """Gets the value of the column by index.
+
+        The class of the returned value depends on the SQL type of the column.
+        No implicit conversions are performed on the value.
+
+        Args:
+            column_index (int): Zero-based column index.
+
+        Returns:
+            Value of the column.
+
+        See Also:
+            :attr:`metadata`
+
+            :attr:`SqlColumnMetadata.type`
+        """
+        check_true(
+            0 <= column_index < self._row_metadata.column_count,
+            "Column index is out of bounds: %s" % column_index,
+        )
+        return self._row[column_index]
+
+    @property
+    def metadata(self):
+        """SqlRowMetadata: The row metadata."""
+        return self._row_metadata
+
+    def __repr__(self):
+        def mapping(column_index):
+            metadata = self._row_metadata.get_column(column_index)
+            value = self._row[column_index]
+            return "%s %s=%s" % (metadata.name, get_attr_name(SqlColumnType, metadata.type), value)
+
+        return "[%s]" % ", ".join(
+            map(
+                mapping,
+                range(self._row_metadata.column_count),
+            )
+        )
+
+
+class _ExecuteResponse(object):
+    """Represents the response of the first execute request."""
+
+    __slots__ = ("row_metadata", "row_page", "update_count")
+
+    def __init__(self, row_metadata, row_page, update_count):
+        self.row_metadata = row_metadata
+        """SqlRowMetadata: Row metadata, or None if the response only
+        contains the update count."""
+
+        self.row_page = row_page
+        """_SqlPage: First page of the query response, or None if the
+        response only contains the update count.
+        """
+
+        self.update_count = update_count
+        """int: Update count, or -1 if row metadata and row page exist."""
+
+
+class _IteratorBase(object):
+    """Base class for the blocking and Future-producing
+    iterators."""
+
+    __slots__ = (
+        "row_metadata",
+        "fetch_fn",
+        "deserialize_fn",
+        "page",
+        "row_count",
+        "position",
+        "is_last",
+    )
+
+    def __init__(self, row_metadata, fetch_fn, deserialize_fn):
+        self.row_metadata = row_metadata
+        """SqlRowMetadata: Row metadata."""
+
+        self.fetch_fn = fetch_fn
+        """function: Fetches the next page. It produces a Future[_SqlPage]."""
+
+        self.deserialize_fn = deserialize_fn
+        """function: Deserializes the value."""
+
+        self.page = None
+        """_SqlPage: Current page."""
+
+        self.row_count = 0
+        """int: Number of rows in the current page."""
+
+        self.position = 0
+        """int: Index of the next row in the page."""
+
+        self.is_last = False
+        """bool: Whether this is the last page or not."""
+
+    def on_next_page(self, page):
+        """
+        Called when a new page is fetched or on the
+        initialization of the iterator to update its
+        internal state.
+
+        Args:
+            page (_SqlPage):
+        """
+        self.page = page
+        self.row_count = page.row_count
+        self.is_last = page.is_last
+        self.position = 0
+
+    def _get_current_row(self):
+        """
+        Returns:
+            list: The row pointed by the current position.
+        """
+        values = []
+        for i in range(self.page.column_count):
+            value = self.page.get_column_value(i, self.position)
+
+            # The column might contain user objects, so we have to deserialize
+            # the value. This call is a no-op if the value is not Data.
+            values.append(self.deserialize_fn(value))
+
+        return values
+
+
+class _FutureProducingIterator(_IteratorBase):
+    """An iterator that produces an infinite stream of Futures. 
It is the
+    responsibility of the user to either call them in a blocking fashion,
+    or call ``next`` only if the current call to ``next`` did not raise
+    ``StopIteration`` (possibly with callback-based code).
+    """
+
+    def __iter__(self):
+        return self
+
+    def next(self):
+        # Defined for backward-compatibility with Python 2.
+        return self.__next__()
+
+    def __next__(self):
+        return self._has_next().continue_with(self._has_next_continuation)
+
+    def _has_next_continuation(self, future):
+        """Based on the call to :func:`_has_next`, either
+        raises ``StopIteration`` error or gets the current row
+        and returns it.
+
+        Args:
+            future (hazelcast.future.Future):
+
+        Returns:
+            SqlRow:
+        """
+        has_next = future.result()
+        if not has_next:
+            # Iterator is exhausted, raise this to inform the user.
+            # If the user continues to call next, we will keep
+            # raising this.
+            raise StopIteration
+
+        row = self._get_current_row()
+        self.position += 1
+        return SqlRow(self.row_metadata, row)
+
+    def _has_next(self):
+        """Returns a Future indicating whether there are more rows
+        left to iterate.
+
+        Returns:
+            hazelcast.future.Future:
+        """
+        if self.position == self.row_count:
+            # We exhausted the current page.
+
+            if self.is_last:
+                # This was the last page, no rows left
+                # on the server side.
+                return ImmediateFuture(False)
+
+            # It seems that there are some rows left on the server.
+            # Fetch them, and then return.
+            return self.fetch_fn().continue_with(self._fetch_continuation)
+
+        # There are some elements left in the current page.
+        return ImmediateFuture(True)
+
+    def _fetch_continuation(self, future):
+        """After a new page is fetched, updates the internal state
+        of the iterator and returns whether or not there are some
+        rows in the fetched page.
+
+        Args:
+            future (hazelcast.future.Future):
+
+        Returns:
+            hazelcast.future.Future:
+        """
+        page = future.result()
+        self.on_next_page(page)
+        return self._has_next()
+
+
+class _BlockingIterator(_IteratorBase):
+    """An iterator that blocks when the current page is exhausted
+    and we need to fetch a new page from the server. Otherwise,
+    it returns immediately with an object.
+
+    This version is more performant than the Future-producing
+    counterpart in the sense that it does not box everything with
+    a Future object.
+    """
+
+    def __iter__(self):
+        return self
+
+    def next(self):
+        # Defined for backward-compatibility with Python 2.
+        return self.__next__()
+
+    def __next__(self):
+        if not self._has_next():
+            # No more rows are left.
+            raise StopIteration
+
+        row = self._get_current_row()
+        self.position += 1
+        return SqlRow(self.row_metadata, row)
+
+    def _has_next(self):
+        while self.position == self.row_count:
+            # We exhausted the current page.
+
+            if self.is_last:
+                # No more rows left on the server.
+                return False
+
+            # Block while waiting for the next page.
+            page = self.fetch_fn().result()
+
+            # Update the internal state with the next page.
+            self.on_next_page(page)
+
+        # There are some rows left in the current page.
+        return True
+
+
+class SqlResult(object):
+    """SQL query result.
+
+    Depending on the statement type it represents a stream of
+    rows or an update count.
+
+    To iterate over the stream of rows, there are two possible options.
+
+    The first, and the easiest, one is to iterate over the rows
+    in a blocking fashion. ::
+
+        result = client.sql.execute("SELECT ...")
+        for row in result:
+            # Process the row.
+            print(row)
+
+    The second option is to use the non-blocking API with callbacks. 
::
+
+        result = client.sql.execute("SELECT ...")
+        it = result.iterator()  # Future of iterator
+
+        def on_iterator_response(iterator_future):
+            iterator = iterator_future.result()
+
+            def on_next_row(row_future):
+                try:
+                    row = row_future.result()
+                    # Process the row.
+                    print(row)
+
+                    # Iterate over the next row.
+                    next(iterator).add_done_callback(on_next_row)
+                except StopIteration:
+                    # Exhausted the iterator. No more rows are left.
+                    pass
+
+            next(iterator).add_done_callback(on_next_row)
+
+        it.add_done_callback(on_iterator_response)
+
+    When in doubt, use the blocking API shown in the first code sample.
+
+    Also, one might call :func:`close` on the result object to
+    release the resources associated with the result on the server side.
+    It might also be used to cancel query execution on the server side
+    if it is still active.
+
+    To get the update count, use :func:`update_count`. ::
+
+        update_count = client.sql.execute("SELECT ...").update_count().result()
+
+    One does not have to call :func:`close` in this case.
+    """
+
+    def __init__(self, sql_service, connection, query_id, cursor_buffer_size, execute_future):
+        self._sql_service = sql_service
+        """_InternalSqlService: Reference to the SQL service."""
+
+        self._connection = connection
+        """hazelcast.connection.Connection: Reference to the connection
+        that the execute request is made to."""
+
+        self._query_id = query_id
+        """_SqlQueryId: Unique id of the SQL query."""
+
+        self._cursor_buffer_size = cursor_buffer_size
+        """int: Size of the cursor buffer measured in the number of rows."""
+
+        self._lock = RLock()
+        """RLock: Protects the shared access to instance variables below."""
+
+        self._execute_response = Future()
+        """Future: Will be resolved with :class:`_ExecuteResponse` once the
+        execute request is resolved."""
+
+        self._iterator_requested = False
+        """bool: Flag that shows whether an iterator is already requested."""
+
+        self._closed = False
+        """bool: Flag that shows whether the query execution is known to have
+        ended on the server side. When ``True``, there is no need to send the
+        "close" request to the server."""
+
+        self._fetch_future = None
+        """Future: Will be set, if there are more pages to fetch on the server
+        side. It should be set to ``None`` once the fetch is completed."""
+
+        execute_future.add_done_callback(self._handle_execute_response)
+
+    def iterator(self):
+        """
+        Returns the iterator over the result rows.
+
+        The iterator may be requested only once.
+
+        Returns:
+            Future[Iterator[Future[SqlRow]]]: Iterator that produces Futures
+            of :class:`SqlRow`. See the class documentation for the correct
+            way to use this.
+        """
+        return self._get_iterator(False)
+
+    def is_row_set(self):
+        """
+        Returns:
+            Future[bool]: Whether this result has rows to iterate.
+        """
+
+        def continuation(future):
+            response = future.result()
+            # By design, if the row_metadata (or row_page) is None,
+            # we only got the update count.
+            return response.row_metadata is not None
+
+        return self._execute_response.continue_with(continuation)
+
+    def update_count(self):
+        """Returns the number of rows updated by the statement or ``-1`` if this
+        result is a row set. In case the result doesn't contain rows but the
+        update count isn't applicable or known, ``0`` is returned.
+
+        Returns:
+            Future[int]:
+        """
+
+        def continuation(future):
+            response = future.result()
+            # This will be set to -1 when we get a row set on the client side.
+            # See _on_execute_response. 
+            return response.update_count
+
+        return self._execute_response.continue_with(continuation)
+
+    def get_row_metadata(self):
+        """Gets the row metadata.
+
+        Returns:
+            Future[SqlRowMetadata]:
+        """
+
+        def continuation(future):
+            response = future.result()
+
+            if not response.row_metadata:
+                raise ValueError("This result contains only update count")
+
+            return response.row_metadata
+
+        return self._execute_response.continue_with(continuation)
+
+    def close(self):
+        """Release the resources associated with the query result.
+
+        The query engine delivers the rows asynchronously. The query may
+        become inactive even before all rows are consumed. The invocation
+        of this command will cancel the execution of the query on all members
+        if the query is still active. Otherwise, it is a no-op. For a result
+        with an update count, it is always a no-op.
+
+        Returns:
+            Future[None]:
+        """
+
+        with self._lock:
+            if self._closed:
+                # Do nothing if the result is already closed.
+                return ImmediateFuture(None)
+
+            error = HazelcastSqlError(
+                self._sql_service.get_client_id(),
+                _SqlErrorCode.CANCELLED_BY_USER,
+                "Query was cancelled by the user",
+                None,
+            )
+
+            if not self._execute_response.done():
+                # If the cancellation is initiated before the first response is
+                # received, then throw cancellation errors on the dependent
+                # methods (update count, row metadata, iterator).
+                self._on_execute_error(error)
+
+            if not self._fetch_future:
+                # Make sure that all subsequent fetches will fail.
+                self._fetch_future = Future()
+
+            self._on_fetch_error(error)
+
+            def wrap_error_on_failure(f):
+                # If the close request fails somehow,
+                # wrap the error in a HazelcastSqlError.
+                try:
+                    return f.result()
+                except Exception as e:
+                    raise self._sql_service.re_raise(e, self._connection)
+
+            self._closed = True
+
+            # Send the close request
+            return self._sql_service.close(self._connection, self._query_id).continue_with(
+                wrap_error_on_failure
+            )
+
+    def __iter__(self):
+        # Get the blocking iterator, and wait for
+        # the first page.
+        return self._get_iterator(True).result()
+
+    def _get_iterator(self, should_get_blocking):
+        """Gets the iterator after the execute request finishes.
+
+        Args:
+            should_get_blocking (bool): Whether to get a blocking iterator.
+
+        Returns:
+            Future[Iterator]:
+        """
+
+        def continuation(future):
+            response = future.result()
+
+            with self._lock:
+                if not response.row_metadata:
+                    # Can't get an iterator when we only have the update count
+                    raise ValueError("This result contains only update count")
+
+                if self._iterator_requested:
+                    # Can't get an iterator when one is already requested
+                    raise ValueError("Iterator can be requested only once")
+
+                self._iterator_requested = True
+
+                if should_get_blocking:
+                    iterator = _BlockingIterator(
+                        response.row_metadata,
+                        self._fetch_next_page,
+                        self._sql_service.deserialize_object,
+                    )
+                else:
+                    iterator = _FutureProducingIterator(
+                        response.row_metadata,
+                        self._fetch_next_page,
+                        self._sql_service.deserialize_object,
+                    )
+
+                # Pass the first page information to the iterator
+                iterator.on_next_page(response.row_page)
+                return iterator
+
+        return self._execute_response.continue_with(continuation)
+
+    def _fetch_next_page(self):
+        """Fetches the next page, if there is no fetch request
+        in-flight.
+
+        Returns:
+            Future[_SqlPage]:
+        """
+        with self._lock:
+            if self._fetch_future:
+                # A fetch request is already in-flight, return it. 
+                return self._fetch_future
+
+            future = Future()
+            self._fetch_future = future
+
+            self._sql_service.fetch(
+                self._connection, self._query_id, self._cursor_buffer_size
+            ).add_done_callback(self._handle_fetch_response)
+
+            # Need to return future, not self._fetch_future, because through
+            # some unlucky timing, we might call _handle_fetch_response
+            # before returning, which could set self._fetch_future to
+            # None.
+            return future
+
+    def _handle_fetch_response(self, future):
+        """Handles the result of the fetch request, by either:
+
+        - setting it to an exception, so that future calls to
+          fetch fail immediately.
+        - setting it to the next page, and setting self._fetch_future
+          to ``None`` so that the next fetch request might actually
+          try to fetch something from the server.
+
+        Args:
+            future (Future): The response from the server for
+                the fetch request.
+        """
+        try:
+            response = future.result()
+
+            response_error = self._handle_response_error(response["error"])
+            if response_error:
+                # There is a server-side error sent to the client.
+                self._on_fetch_error(response_error)
+                return
+
+            # The result contains the next page, as expected.
+            self._on_fetch_response(response["row_page"])
+        except Exception as e:
+            # Something went wrong; we couldn't get a response from
+            # the server, the invocation failed.
+            self._on_fetch_error(self._sql_service.re_raise(e, self._connection))
+
+    def _on_fetch_error(self, error):
+        """Sets the fetch future with the exception, without resetting it,
+        so that subsequent fetch requests fail immediately.
+
+        Args:
+            error (Exception): The error.
+        """
+        with self._lock:
+            self._fetch_future.set_exception(error)
+
+    def _on_fetch_response(self, page):
+        """Sets the fetch future with the next page,
+        resets it, and if this is the last page,
+        marks the result as closed.
+
+        Args:
+            page (_SqlPage): The next page.
+        """
+        with self._lock:
+            self._fetch_future.set_result(page)
+            if page.is_last:
+                # This is the last page, there is nothing
+                # more on the server.
+                self._mark_closed()
+
+            self._fetch_future = None
+
+    def _handle_execute_response(self, future):
+        """Handles the result of the execute request, by either:
+
+        - setting it to an exception so that the dependent methods
+          (iterator, update_count etc.) fail immediately
+        - setting it to an execute response
+
+        Args:
+            future (Future):
+        """
+        try:
+            response = future.result()
+
+            response_error = self._handle_response_error(response["error"])
+            if response_error:
+                # There is a server-side error sent to the client.
+                self._on_execute_error(response_error)
+                return
+
+            row_metadata = response["row_metadata"]
+            if row_metadata is not None:
+                # The result contains some rows, not an update count.
+                row_metadata = SqlRowMetadata(row_metadata)
+
+            self._on_execute_response(row_metadata, response["row_page"], response["update_count"])
+        except Exception as e:
+            # Something went wrong; we couldn't get a response from
+            # the server, the invocation failed.
+            self._on_execute_error(self._sql_service.re_raise(e, self._connection))
+
+    def _handle_response_error(self, error):
+        """If the error is not ``None``, return it as
+        :class:`HazelcastSqlError` so that we can raise
+        it to the user.
+
+        Args:
+            error (_SqlError): The error or ``None``.
+
+        Returns:
+            HazelcastSqlError: The wrapped error if it is not ``None``,
+            ``None`` otherwise.
+        """
+        if error:
+            return HazelcastSqlError(error.originating_member_uuid, error.code, error.message, None)
+        return None
+
+    def _on_execute_error(self, error):
+        """Called when the first execute request fails.
+
+        Args:
+            error (HazelcastSqlError): The wrapped error that can
+                be raised to the user.
+        """
+        with self._lock:
+            if self._closed:
+                # The user might have already cancelled it.
+                return
+
+            self._execute_response.set_exception(error)
+
+    def _on_execute_response(self, row_metadata, row_page, update_count):
+        """Called when the first execute request succeeds.
+
+        Args:
+            row_metadata (SqlRowMetadata): The row metadata. Might be ``None``
+                if the response only contains the update count.
+            row_page (_SqlPage): The first page of the result. Might be
+                ``None`` if the response only contains the update count.
+            update_count (int): The update count.
+        """
+        with self._lock:
+            if self._closed:
+                # The user might have already cancelled it.
+                return
+
+            if row_metadata:
+                # Result contains the row set for the query.
+                # Set the update count to -1.
+                response = _ExecuteResponse(row_metadata, row_page, -1)
+
+                if row_page.is_last:
+                    # This is the last page, close the result.
+                    self._mark_closed()
+
+                self._execute_response.set_result(response)
+            else:
+                # Result only contains the update count.
+                response = _ExecuteResponse(None, None, update_count)
+                self._execute_response.set_result(response)
+
+                # There is nothing more we can get from the server.
+                self._mark_closed()
+
+    def _mark_closed(self):
+        """Marks the result as closed."""
+        with self._lock:
+            self._closed = True
+
+
+class _InternalSqlService(object):
+    """Internal SQL service that offers a wider API than
+    the one exposed to the user.
+    """
+
+    def __init__(self, connection_manager, serialization_service, invocation_service):
+        self._connection_manager = connection_manager
+        self._serialization_service = serialization_service
+        self._invocation_service = invocation_service
+
+    def execute(self, sql, *params):
+        """Constructs a statement and executes it.
+
+        Args:
+            sql (str): SQL string.
+            *params: Query parameters.
+
+        Returns:
+            SqlResult: The execution result.
+        """
+        statement = SqlStatement(sql)
+
+        for param in params:
+            statement.add_parameter(param)
+
+        return self.execute_statement(statement)
+
+    def execute_statement(self, statement):
+        """Executes the given statement.
+
+        Args:
+            statement (SqlStatement): The statement to execute.
+
+        Returns:
+            SqlResult: The execution result.
+        """
+
+        # Get a random data member (non-lite member)
+        connection = self._connection_manager.get_random_connection(True)
+        if not connection:
+            # Either the client is not connected to the cluster, or
+            # there are no data members in the cluster.
+            raise HazelcastSqlError(
+                self.get_client_id(),
+                _SqlErrorCode.CONNECTION_PROBLEM,
+                "Client is not currently connected to the cluster.",
+                None,
+            )
+
+        # Create a new, unique query id.
+        query_id = _SqlQueryId.from_uuid(connection.remote_uuid)
+
+        # Serialize the passed parameters.
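+        # Note that to_data raises for a parameter that the client's
+        # serialization service cannot serialize, before any request is sent.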
+        serialized_params = []
+        for param in statement.parameters:
+            serialized_params.append(self._serialization_service.to_data(param))
+
+        request = sql_execute_codec.encode_request(
+            statement.sql,
+            serialized_params,
+            # to_millis expects None to produce -1
+            to_millis(None if statement.timeout == -1 else statement.timeout),
+            statement.cursor_buffer_size,
+            statement.schema,
+            statement.expected_result_type,
+            query_id,
+        )
+
+        invocation = Invocation(
+            request, connection=connection, response_handler=sql_execute_codec.decode_response
+        )
+
+        result = SqlResult(
+            self, connection, query_id, statement.cursor_buffer_size, invocation.future
+        )
+
+        self._invocation_service.invoke(invocation)
+
+        return result
+
+    def deserialize_object(self, obj):
+        return self._serialization_service.to_object(obj)
+
+    def fetch(self, connection, query_id, cursor_buffer_size):
+        """Fetches the next page of the query execution.
+
+        Args:
+            connection (hazelcast.connection.Connection): Connection
+                that the first execute request, and hence the fetch request,
+                must be routed to.
+            query_id (_SqlQueryId): Unique id of the query.
+            cursor_buffer_size (int): Size of the cursor buffer. Same as
+                the one used in the first execute request.
+
+        Returns:
+            Future: Decoded fetch response.
+        """
+        request = sql_fetch_codec.encode_request(query_id, cursor_buffer_size)
+        invocation = Invocation(
+            request, connection=connection, response_handler=sql_fetch_codec.decode_response
+        )
+        self._invocation_service.invoke(invocation)
+        return invocation.future
+
+    def get_client_id(self):
+        """
+        Returns:
+            uuid.UUID: Unique client UUID.
+        """
+        return self._connection_manager.client_uuid
+
+    def re_raise(self, error, connection):
+        """Returns the error wrapped as the :class:`HazelcastSqlError`
+        so that it can be raised to the user.
+
+        Args:
+            error (Exception): The error to re-raise.
+            connection (hazelcast.connection.Connection): Connection
+                that the query requests are routed to. If it is not
+                live, we will inform the user about the possible
+                cluster topology change.
+
+        Returns:
+            HazelcastSqlError: The re-raised error.
+        """
+        if not connection.live:
+            return HazelcastSqlError(
+                self.get_client_id(),
+                _SqlErrorCode.CONNECTION_PROBLEM,
+                "Cluster topology changed while a query was executed: Member cannot be reached: %s"
+                % connection.remote_address,
+                error,
+            )
+
+        if isinstance(error, HazelcastSqlError):
+            return error
+
+        return HazelcastSqlError(
+            self.get_client_id(), _SqlErrorCode.GENERIC, try_to_get_error_message(error), error
+        )
+
+    def close(self, connection, query_id):
+        """Closes the remote query cursor.
+
+        Args:
+            connection (hazelcast.connection.Connection): Connection
+                that the first execute request, and hence the close request,
+                must be routed to.
+            query_id (_SqlQueryId): The query id to close.
+
+        Returns:
+            Future:
+        """
+        request = sql_close_codec.encode_request(query_id)
+        invocation = Invocation(request, connection=connection)
+        self._invocation_service.invoke(invocation)
+        return invocation.future
+
+
+class SqlExpectedResultType(object):
+    """The expected statement result type."""
+
+    ANY = 0
+    """
+    The statement may produce either rows or an update count.
+    """
+
+    ROWS = 1
+    """
+    The statement must produce rows. An exception is thrown if the statement produces an update count.
+    """
+
+    UPDATE_COUNT = 2
+    """
+    The statement must produce an update count. An exception is thrown if the statement produces rows.
+    """
+
+
+class SqlStatement(object):
+    """Definition of an SQL statement.
+ + This object is mutable. Properties are read once before the execution + is started. Changes to properties do not affect the behavior of already + running statements. + """ + + TIMEOUT_NOT_SET = -1 + + TIMEOUT_DISABLED = 0 + + DEFAULT_TIMEOUT = TIMEOUT_NOT_SET + + DEFAULT_CURSOR_BUFFER_SIZE = 4096 + + def __init__(self, sql): + self.sql = sql + self._parameters = [] + self._timeout = SqlStatement.DEFAULT_TIMEOUT + self._cursor_buffer_size = SqlStatement.DEFAULT_CURSOR_BUFFER_SIZE + self._schema = None + self._expected_result_type = SqlExpectedResultType.ANY + + @property + def sql(self): + """str: The SQL string to be executed.""" + return self._sql + + @sql.setter + def sql(self, sql): + check_not_none(sql, "SQL cannot be None") + + if not sql.strip(): + raise ValueError("SQL cannot be empty") + + self._sql = sql + + @property + def schema(self): + """str: The schema name. The engine will try to resolve the + non-qualified object identifiers from the statement in the given + schema. If not found, the default search path will be used, which + looks for objects in the predefined schemas ``partitioned`` and + ``public``. + + The schema name is case sensitive. For example, ``foo`` and ``Foo`` + are different schemas. + + The default value is ``None`` meaning only the default search path is + used. + """ + return self._schema + + @schema.setter + def schema(self, schema): + self._schema = schema + + @property + def parameters(self): + """list: Sets the statement parameters. + + You may define parameter placeholders in the statement with the ``?`` + character. For every placeholder, a parameter value must be provided. + + When the setter is called, the content of the parameters list is copied. + Subsequent changes to the original list don't change the statement parameters. + """ + return self._parameters + + @parameters.setter + def parameters(self, parameters): + if not parameters: + self._parameters = [] + else: + self._parameters = list(parameters) + + @property + def timeout(self): + """float: The execution timeout in seconds. + + If the timeout is reached for a running statement, it will be + cancelled forcefully. + + Zero value means no timeout. :const:`TIMEOUT_NOT_SET` means that + the value from the server-side config will be used. Other negative + values are prohibited. + + Defaults to :const:`TIMEOUT_NOT_SET`. + """ + return self._timeout + + @timeout.setter + def timeout(self, timeout): + if timeout < 0 and timeout != SqlStatement.TIMEOUT_NOT_SET: + raise ValueError("Timeout must be non-negative or -1, not %s" % timeout) + + self._timeout = timeout + + @property + def cursor_buffer_size(self): + """int: The cursor buffer size (measured in the number of rows). + + When a statement is submitted for execution, a :class:`SqlResult` + is returned as a result. When rows are ready to be consumed, + they are put into an internal buffer of the cursor. This parameter + defines the maximum number of rows in that buffer. When the threshold + is reached, the backpressure mechanism will slow down the execution, + possibly to a complete halt, to prevent out-of-memory. + + Only positive values are allowed. + + The default value is expected to work well for most workloads. A bigger + buffer size may give you a slight performance boost for queries with + large result sets at the cost of increased memory consumption. + + Defaults to :const:`DEFAULT_CURSOR_BUFFER_SIZE`. 
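+
+        For example, a query over a large map can be consumed in small
+        pages to bound the memory used on the client::
+
+            statement = SqlStatement("SELECT * FROM my_map")
+            statement.cursor_buffer_size = 16  # fetch 16 rows per page
+            result = client.sql.execute_statement(statement)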
+ """ + return self._cursor_buffer_size + + @cursor_buffer_size.setter + def cursor_buffer_size(self, cursor_buffer_size): + if cursor_buffer_size <= 0: + raise ValueError("Cursor buffer size must be positive, not %s" % cursor_buffer_size) + self._cursor_buffer_size = cursor_buffer_size + + @property + def expected_result_type(self): + """SqlExpectedResultType: The expected result type.""" + return self._expected_result_type + + @expected_result_type.setter + def expected_result_type(self, expected_result_type): + check_not_none(expected_result_type, "Expected result type cannot be None") + self._expected_result_type = expected_result_type + + def add_parameter(self, parameter): + """Adds a single parameter to the end of the parameters list. + + Args: + parameter: The parameter. + + See Also: + :attr:`parameters` + + :func:`clear_parameters` + """ + self._parameters.append(parameter) + + def clear_parameters(self): + """Clears statement parameters.""" + self._parameters = [] + + def copy(self): + """Creates a copy of this instance. + + Returns: + SqlStatement: + """ + copied = SqlStatement(self.sql) + copied.parameters = list(self.parameters) + copied.timeout = self.timeout + copied.cursor_buffer_size = self.cursor_buffer_size + copied.schema = self.schema + copied.expected_result_type = self.expected_result_type + return copied + + def __repr__(self): + return ( + "SqlStatement(schema=%s, sql=%s, parameters=%s, timeout=%s," + " cursor_buffer_size=%s, expected_result_type=%s)" + % ( + self.schema, + self.sql, + self.parameters, + self.timeout, + self.cursor_buffer_size, + self._expected_result_type, + ) + ) + + +# These are imported at the bottom of the page to get rid of the +# cyclic import errors. +from hazelcast.protocol.codec import sql_execute_codec, sql_fetch_codec, sql_close_codec diff --git a/hazelcast/util.py b/hazelcast/util.py index 6accedfa06..a149341f37 100644 --- a/hazelcast/util.py +++ b/hazelcast/util.py @@ -1,6 +1,10 @@ +import binascii import random import threading import time +import uuid + +from hazelcast.serialization import UUID_MSB_SHIFT, UUID_LSB_MASK, UUID_MSB_MASK try: from collections.abc import Sequence, Iterable @@ -272,18 +276,65 @@ def next(self): """ raise NotImplementedError("next") + def next_data_member(self): + """Returns the next data member to route to. + + Returns: + hazelcast.core.MemberInfo: The next data member or + ``None`` if no data member is available. + """ + return None + + def can_get_next_data_member(self): + """Returns whether this instance supports getting data members + through a call to :func:`next_data_member`. + + Returns: + bool: ``True`` if this instance supports getting data members. 
+ """ + return False + + +class _Members(object): + __slots__ = ("members", "data_members") + + def __init__(self, members, data_members): + self.members = members + self.data_members = data_members + class _AbstractLoadBalancer(LoadBalancer): def __init__(self): self._cluster_service = None - self._members = [] + self._members = _Members([], []) def init(self, cluster_service): self._cluster_service = cluster_service cluster_service.add_listener(self._listener, self._listener, True) + def next(self): + members = self._members.members + return self._next(members) + + def next_data_member(self): + members = self._members.data_members + return self._next(members) + + def can_get_next_data_member(self): + return True + def _listener(self, _): - self._members = self._cluster_service.get_members() + members = self._cluster_service.get_members() + data_members = [] + + for member in members: + if not member.lite_member: + data_members.append(member) + + self._members = _Members(members, data_members) + + def _next(self, members): + raise NotImplementedError("_next") class RoundRobinLB(_AbstractLoadBalancer): @@ -298,8 +349,7 @@ def __init__(self): super(RoundRobinLB, self).__init__() self._idx = 0 - def next(self): - members = self._members + def _next(self, members): if not members: return None @@ -312,15 +362,14 @@ def next(self): class RandomLB(_AbstractLoadBalancer): """A load balancer that selects a random member to route to.""" - def next(self): - members = self._members + def _next(self, members): if not members: return None idx = random.randrange(0, len(members)) return members[idx] -class IterationType: +class IterationType(object): """To differentiate users selection on result collection on map-wide operations like ``entry_set``, ``key_set``, ``values`` etc. """ @@ -333,3 +382,79 @@ class IterationType: ENTRY = 2 """Iterate over entries""" + + +class UuidUtil(object): + @staticmethod + def to_bits(value): + i = value.int + most_significant_bits = to_signed(i >> UUID_MSB_SHIFT, 64) + least_significant_bits = to_signed(i & UUID_LSB_MASK, 64) + return most_significant_bits, least_significant_bits + + @staticmethod + def from_bits(most_significant_bits, least_significant_bits): + return uuid.UUID( + int=( + ((most_significant_bits << UUID_MSB_SHIFT) & UUID_MSB_MASK) + | (least_significant_bits & UUID_LSB_MASK) + ) + ) + + +if hasattr(int, "from_bytes"): + + def int_from_bytes(buffer): + return int.from_bytes(buffer, "big", signed=True) + + +else: + # Compatibility with Python 2 + def int_from_bytes(buffer): + buffer = bytearray(buffer) + if buffer[0] & 0x80: + neg = bytearray() + for c in buffer: + neg.append(c ^ 0xFF) + return -1 * int(binascii.hexlify(neg), 16) - 1 + return int(binascii.hexlify(buffer), 16) + + +try: + from datetime import timezone +except ImportError: + from datetime import tzinfo, timedelta + + # There is no tzinfo implementation(timezone) in the + # Python 2. Here we provide the bare minimum + # to the user. + class FixedOffsetTimezone(tzinfo): + __slots__ = ("_offset",) + + def __init__(self, offset): + self._offset = offset + + def utcoffset(self, dt): + return self._offset + + def tzname(self, dt): + return None + + def dst(self, dt): + return timedelta(0) + + timezone = FixedOffsetTimezone + + +def try_to_get_error_message(error): + # If the error has a message attribute, + # return it. If not, almost all of the + # built-in errors (and Hazelcast Errors) + # set the exception message as the first + # parameter of args. If it is not there, + # then return None. 
+ if hasattr(error, "message"): + return error.message + elif len(error.args) > 0: + return error.args[0] + return None diff --git a/start_rc.py b/start_rc.py index 60851429c0..4bc56f55e3 100644 --- a/start_rc.py +++ b/start_rc.py @@ -3,7 +3,7 @@ import sys from os.path import isfile -SERVER_VERSION = "4.1.4-SNAPSHOT" +SERVER_VERSION = "4.2.1-SNAPSHOT" RC_VERSION = "0.8-SNAPSHOT" RELEASE_REPO = "http://repo1.maven.apache.org/maven2" @@ -72,7 +72,7 @@ def start_rc(stdout=None, stderr=None): enterprise_key = os.environ.get("HAZELCAST_ENTERPRISE_KEY", None) if enterprise_key: - server = download_if_necessary(ENTERPRISE_REPO, "hazelcast-enterprise", SERVER_VERSION) + server = download_if_necessary(ENTERPRISE_REPO, "hazelcast-enterprise-all", SERVER_VERSION) ep_tests = download_if_necessary( ENTERPRISE_REPO, "hazelcast-enterprise", SERVER_VERSION, True ) @@ -80,7 +80,7 @@ def start_rc(stdout=None, stderr=None): artifacts.append(server) artifacts.append(ep_tests) else: - server = download_if_necessary(REPO, "hazelcast", SERVER_VERSION) + server = download_if_necessary(REPO, "hazelcast-all", SERVER_VERSION) artifacts.append(server) class_path = CLASS_PATH_SEPARATOR.join(artifacts) diff --git a/tests/integration/backward_compatible/cluster_test.py b/tests/integration/backward_compatible/cluster_test.py index 89343f8187..64bca4b1f3 100644 --- a/tests/integration/backward_compatible/cluster_test.py +++ b/tests/integration/backward_compatible/cluster_test.py @@ -148,7 +148,9 @@ def test_random_load_balancer(self): lb = client._load_balancer self.assertTrue(isinstance(lb, RandomLB)) - six.assertCountEqual(self, self.addresses, list(map(lambda m: m.address, lb._members))) + six.assertCountEqual( + self, self.addresses, list(map(lambda m: m.address, lb._members.members)) + ) for _ in range(10): self.assertTrue(lb.next().address in self.addresses) @@ -161,7 +163,9 @@ def test_round_robin_load_balancer(self): lb = client._load_balancer self.assertTrue(isinstance(lb, RoundRobinLB)) - six.assertCountEqual(self, self.addresses, list(map(lambda m: m.address, lb._members))) + six.assertCountEqual( + self, self.addresses, list(map(lambda m: m.address, lb._members.members)) + ) for i in range(10): self.assertEqual(self.addresses[i % len(self.addresses)], lb.next().address) diff --git a/tests/integration/backward_compatible/sql_test.py b/tests/integration/backward_compatible/sql_test.py new file mode 100644 index 0000000000..ac8a8d9b98 --- /dev/null +++ b/tests/integration/backward_compatible/sql_test.py @@ -0,0 +1,570 @@ +import datetime +import decimal +import random +import string + +from hazelcast import six +from hazelcast.serialization.api import Portable +from hazelcast.sql import HazelcastSqlError, SqlStatement, SqlExpectedResultType, SqlColumnType +from hazelcast.util import timezone +from tests.base import SingleMemberTestCase +from tests.hzrc.ttypes import Lang + +SERVER_CONFIG = """ + + + + com.hazelcast.client.test.PortableFactory + + + + +""" + + +class SqlTestBase(SingleMemberTestCase): + @classmethod + def configure_cluster(cls): + return SERVER_CONFIG + + @classmethod + def configure_client(cls, config): + config["cluster_name"] = cls.cluster.id + config["portable_factories"] = {666: {6: Student}} + return config + + def setUp(self): + self.map_name = random_string() + self.map = self.client.get_map(self.map_name).blocking() + + def tearDown(self): + self.map.clear() + + def _populate_map(self, entry_count=10, value_factory=lambda v: v): + entries = {i: value_factory(i) for i in 
range(entry_count)} + self.map.put_all(entries) + + +class SqlServiceTest(SqlTestBase): + def test_execute(self): + entry_count = 11 + self._populate_map(entry_count) + result = self.client.sql.execute("SELECT * FROM %s" % self.map_name) + six.assertCountEqual( + self, + [(i, i) for i in range(entry_count)], + [(row.get_object("__key"), row.get_object("this")) for row in result], + ) + + def test_execute_with_params(self): + entry_count = 13 + self._populate_map(entry_count) + result = self.client.sql.execute( + "SELECT this FROM %s WHERE __key > ? AND this > ?" % self.map_name, 5, 6 + ) + six.assertCountEqual( + self, + [i for i in range(7, entry_count)], + [row.get_object("this") for row in result], + ) + + def test_execute_with_mismatched_params_when_sql_has_more(self): + self._populate_map() + result = self.client.sql.execute( + "SELECT * FROM %s WHERE __key > ? AND this > ?" % self.map_name, 5 + ) + + with self.assertRaises(HazelcastSqlError): + for _ in result: + pass + + def test_execute_with_mismatched_params_when_params_has_more(self): + self._populate_map() + result = self.client.sql.execute("SELECT * FROM %s WHERE this > ?" % self.map_name, 5, 6) + + with self.assertRaises(HazelcastSqlError): + for _ in result: + pass + + def test_execute_statement(self): + entry_count = 12 + self._populate_map(entry_count, str) + statement = SqlStatement("SELECT this FROM %s" % self.map_name) + result = self.client.sql.execute_statement(statement) + + six.assertCountEqual( + self, + [str(i) for i in range(entry_count)], + [row.get_object_with_index(0) for row in result], + ) + + def test_execute_statement_with_params(self): + entry_count = 20 + self._populate_map(entry_count, lambda v: Student(v, v)) + statement = SqlStatement( + "SELECT age FROM %s WHERE height = CAST(? AS REAL)" % self.map_name + ) + statement.add_parameter(13.0) + result = self.client.sql.execute_statement(statement) + + six.assertCountEqual(self, [13], [row.get_object("age") for row in result]) + + def test_execute_statement_with_mismatched_params_when_sql_has_more(self): + self._populate_map() + statement = SqlStatement("SELECT * FROM %s WHERE __key > ? AND this > ?" % self.map_name) + statement.parameters = [5] + result = self.client.sql.execute_statement(statement) + + with self.assertRaises(HazelcastSqlError): + for _ in result: + pass + + def test_execute_statement_with_mismatched_params_when_params_has_more(self): + self._populate_map() + statement = SqlStatement("SELECT * FROM %s WHERE this > ?" 
% self.map_name) + statement.parameters = [5, 6] + result = self.client.sql.execute_statement(statement) + + with self.assertRaises(HazelcastSqlError): + for _ in result: + pass + + def test_execute_statement_with_timeout(self): + entry_count = 100 + self._populate_map(entry_count, lambda v: Student(v, v)) + statement = SqlStatement("SELECT age FROM %s WHERE height < 10" % self.map_name) + statement.timeout = 100 + result = self.client.sql.execute_statement(statement) + + six.assertCountEqual( + self, [i for i in range(10)], [row.get_object("age") for row in result] + ) + + def test_execute_statement_with_cursor_buffer_size(self): + entry_count = 50 + self._populate_map(entry_count, lambda v: Student(v, v)) + statement = SqlStatement("SELECT age FROM %s" % self.map_name) + statement.cursor_buffer_size = 3 + result = self.client.sql.execute_statement(statement) + + six.assertCountEqual( + self, [i for i in range(entry_count)], [row.get_object("age") for row in result] + ) + + def test_execute_statement_with_copy(self): + self._populate_map() + statement = SqlStatement("SELECT __key FROM %s WHERE this >= ?" % self.map_name) + statement.parameters = [9] + copy_statement = statement.copy() + statement.clear_parameters() + + result = self.client.sql.execute_statement(copy_statement) + self.assertEqual([9], [row.get_object_with_index(0) for row in result]) + + result = self.client.sql.execute_statement(statement) + with self.assertRaises(HazelcastSqlError): + for _ in result: + pass + + # Can't test the case we would expect an update count, because the IMDG SQL + # engine does not support such query as of now. + def test_execute_statement_with_expected_result_type_of_rows_when_rows_are_expected(self): + entry_count = 100 + self._populate_map(entry_count, lambda v: Student(v, v)) + statement = SqlStatement("SELECT age FROM %s WHERE age < 3" % self.map_name) + statement.expected_result_type = SqlExpectedResultType.ROWS + result = self.client.sql.execute_statement(statement) + + six.assertCountEqual(self, [i for i in range(3)], [row.get_object("age") for row in result]) + + # Can't test the case we would expect an update count, because the IMDG SQL + # engine does not support such query as of now. + def test_execute_statement_with_expected_result_type_of_update_count_when_rows_are_expected( + self, + ): + self._populate_map() + statement = SqlStatement("SELECT * FROM %s" % self.map_name) + statement.expected_result_type = SqlExpectedResultType.UPDATE_COUNT + result = self.client.sql.execute_statement(statement) + + with self.assertRaises(HazelcastSqlError): + for _ in result: + pass + + # Can't test the schema, because the IMDG SQL engine does not support + # specifying a schema yet. 
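+    # Once the engine supports schemas, such a test would set
+    # statement.schema to a schema name and assert that non-qualified
+    # identifiers resolve against it.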
+ + +class SqlResultTest(SqlTestBase): + def test_blocking_iterator(self): + self._populate_map() + result = self.client.sql.execute("SELECT __key FROM %s" % self.map_name) + + six.assertCountEqual( + self, [i for i in range(10)], [row.get_object_with_index(0) for row in result] + ) + + def test_blocking_iterator_when_iterator_requested_more_than_once(self): + self._populate_map() + result = self.client.sql.execute("SELECT this FROM %s" % self.map_name) + + six.assertCountEqual( + self, [i for i in range(10)], [row.get_object_with_index(0) for row in result] + ) + + with self.assertRaises(ValueError): + for _ in result: + pass + + def test_blocking_iterator_with_multi_paged_result(self): + self._populate_map() + statement = SqlStatement("SELECT __key FROM %s" % self.map_name) + statement.cursor_buffer_size = 1 # Each page will contain just 1 result + result = self.client.sql.execute_statement(statement) + + six.assertCountEqual( + self, [i for i in range(10)], [row.get_object_with_index(0) for row in result] + ) + + def test_iterator(self): + self._populate_map() + result = self.client.sql.execute("SELECT __key FROM %s" % self.map_name) + + iterator_future = result.iterator() + + rows = [] + + def cb(f): + iterator = f.result() + + def iterate(row_future): + try: + row = row_future.result() + rows.append(row.get_object_with_index(0)) + next(iterator).add_done_callback(iterate) + except StopIteration: + pass + + next(iterator).add_done_callback(iterate) + + iterator_future.add_done_callback(cb) + + def assertion(): + six.assertCountEqual( + self, + [i for i in range(10)], + rows, + ) + + self.assertTrueEventually(assertion) + + def test_iterator_when_iterator_requested_more_than_once(self): + self._populate_map() + result = self.client.sql.execute("SELECT this FROM %s" % self.map_name) + + iterator = result.iterator().result() + + rows = [] + for row_future in iterator: + try: + row = row_future.result() + rows.append(row.get_object("this")) + except StopIteration: + break + + six.assertCountEqual(self, [i for i in range(10)], rows) + + with self.assertRaises(ValueError): + result.iterator().result() + + def test_iterator_with_multi_paged_result(self): + self._populate_map() + statement = SqlStatement("SELECT __key FROM %s" % self.map_name) + statement.cursor_buffer_size = 1 # Each page will contain just 1 result + result = self.client.sql.execute_statement(statement) + + iterator = result.iterator().result() + + rows = [] + for row_future in iterator: + try: + row = row_future.result() + rows.append(row.get_object_with_index(0)) + except StopIteration: + break + + six.assertCountEqual(self, [i for i in range(10)], rows) + + def test_request_blocking_iterator_after_iterator(self): + self._populate_map() + result = self.client.sql.execute("SELECT * FROM %s" % self.map_name) + + result.iterator().result() + + with self.assertRaises(ValueError): + for _ in result: + pass + + def test_request_iterator_after_blocking_iterator(self): + self._populate_map() + result = self.client.sql.execute("SELECT * FROM %s" % self.map_name) + + for _ in result: + pass + + with self.assertRaises(ValueError): + result.iterator().result() + + # Can't test the case we would expect row to be not set, because the IMDG SQL + # engine does not support update/insert queries now. 
+ def test_is_row_set_when_row_is_set(self): + self._populate_map() + result = self.client.sql.execute("SELECT * FROM %s" % self.map_name) + self.assertTrue(result.is_row_set().result()) + + # Can't test the case we would expect a non-negative updated count, because the IMDG SQL + # engine does not support update/insert queries now. + def test_update_count_when_there_is_no_update(self): + self._populate_map() + result = self.client.sql.execute("SELECT * FROM %s WHERE __key > 5" % self.map_name) + self.assertEqual(-1, result.update_count().result()) + + def test_get_row_metadata(self): + self._populate_map(value_factory=str) + result = self.client.sql.execute("SELECT * FROM %s" % self.map_name) + row_metadata = result.get_row_metadata().result() + self.assertEqual(2, row_metadata.column_count) + columns = row_metadata.columns + self.assertEqual(SqlColumnType.INTEGER, columns[0].type) + self.assertEqual(SqlColumnType.VARCHAR, columns[1].type) + self.assertTrue(columns[0].nullable) + self.assertTrue(columns[1].nullable) + + def test_close_after_query_execution(self): + self._populate_map() + result = self.client.sql.execute("SELECT * FROM %s" % self.map_name) + for _ in result: + pass + self.assertIsNone(result.close().result()) + + def test_close_when_query_is_active(self): + self._populate_map() + statement = SqlStatement("SELECT * FROM %s " % self.map_name) + statement.cursor_buffer_size = 1 # Each page will contain 1 row + result = self.client.sql.execute_statement(statement) + + # Fetch couple of pages + iterator = iter(result) + next(iterator) + + self.assertIsNone(result.close().result()) + + with self.assertRaises(HazelcastSqlError): + # Next fetch requests should fail + next(iterator) + + +class SqlColumnTypesReadTest(SqlTestBase): + def test_varchar(self): + def value_factory(key): + return "val-%s" % key + + self._populate_map(value_factory=value_factory) + self._validate_rows(SqlColumnType.VARCHAR, value_factory) + + def test_boolean(self): + def value_factory(key): + return key % 2 == 0 + + self._populate_map(value_factory=value_factory) + self._validate_rows(SqlColumnType.BOOLEAN, value_factory) + + def test_tiny_int(self): + self._populate_map_via_rc("new java.lang.Byte(key)") + self._validate_rows(SqlColumnType.TINYINT) + + def test_small_int(self): + self._populate_map_via_rc("new java.lang.Short(key)") + self._validate_rows(SqlColumnType.SMALLINT) + + def test_integer(self): + self._populate_map_via_rc("new java.lang.Integer(key)") + self._validate_rows(SqlColumnType.INTEGER) + + def test_big_int(self): + self._populate_map_via_rc("new java.lang.Long(key)") + self._validate_rows(SqlColumnType.BIGINT) + + def test_real(self): + self._populate_map_via_rc("new java.lang.Float(key)") + self._validate_rows(SqlColumnType.REAL, float) + + def test_double(self): + self._populate_map_via_rc("new java.lang.Double(key)") + self._validate_rows(SqlColumnType.DOUBLE, float) + + def test_date(self): + def value_factory(key): + return datetime.date(key + 2000, key + 1, key + 1) + + self._populate_map_via_rc("java.time.LocalDate.of(key + 2000, key + 1, key + 1)") + self._validate_rows(SqlColumnType.DATE, value_factory) + + def test_time(self): + def value_factory(key): + return datetime.time(key, key, key, key) + + self._populate_map_via_rc("java.time.LocalTime.of(key, key, key, key * 1000)") + self._validate_rows(SqlColumnType.TIME, value_factory) + + def test_timestamp(self): + def value_factory(key): + return datetime.datetime(key + 2000, key + 1, key + 1, key, key, key, key) + + 
self._populate_map_via_rc(
+            "java.time.LocalDateTime.of(key + 2000, key + 1, key + 1, key, key, key, key * 1000)"
+        )
+        self._validate_rows(SqlColumnType.TIMESTAMP, value_factory)
+
+    def test_timestamp_with_time_zone(self):
+        def value_factory(key):
+            return datetime.datetime(
+                key + 2000,
+                key + 1,
+                key + 1,
+                key,
+                key,
+                key,
+                key,
+                timezone(datetime.timedelta(hours=key)),
+            )
+
+        self._populate_map_via_rc(
+            "java.time.OffsetDateTime.of(key + 2000, key + 1, key + 1, key, key, key, key * 1000, "
+            "java.time.ZoneOffset.ofHours(key))"
+        )
+        self._validate_rows(SqlColumnType.TIMESTAMP_WITH_TIME_ZONE, value_factory)
+
+    def test_decimal(self):
+        def value_factory(key):
+            return decimal.Decimal((0, tuple(map(int, str(abs(key)))), -1 * key))
+
+        self._populate_map_via_rc(
+            'new java.math.BigDecimal(new java.math.BigInteger("" + key), key)'
+        )
+        self._validate_rows(SqlColumnType.DECIMAL, value_factory)
+
+    def test_null(self):
+        self._populate_map()
+        result = self.client.sql.execute("SELECT __key, NULL AS this FROM %s" % self.map_name)
+        self._validate_result(result, SqlColumnType.NULL, lambda _: None)
+
+    def test_object(self):
+        def value_factory(key):
+            return Student(key, key)
+
+        self._populate_map(value_factory=value_factory)
+        result = self.client.sql.execute("SELECT __key, this FROM %s" % self.map_name)
+        self._validate_result(result, SqlColumnType.OBJECT, value_factory)
+
+    def test_null_only_column(self):
+        self._populate_map()
+        result = self.client.sql.execute(
+            "SELECT __key, CAST(NULL AS INTEGER) FROM %s" % self.map_name
+        )
+        self._validate_result(result, SqlColumnType.INTEGER, lambda _: None)
+
+    def _validate_rows(self, expected_type, value_factory=lambda key: key):
+        result = self.client.sql.execute("SELECT * FROM %s " % self.map_name)
+        self._validate_result(result, expected_type, value_factory)
+
+    def _validate_result(self, result, expected_type, factory):
+        for row in result:
+            key = row.get_object("__key")
+            expected_value = factory(key)
+            row_metadata = row.metadata
+
+            self.assertEqual(2, row_metadata.column_count)
+            column_metadata = row_metadata.get_column(1)
+            self.assertEqual(expected_type, column_metadata.type)
+            self.assertEqual(expected_value, row.get_object_with_index(1))
+
+    def _populate_map_via_rc(self, new_object_literal):
+        script = """
+        var map = instance_0.getMap("%s");
+        for (var key = 0; key < 10; key++) {
+            map.set(new java.lang.Integer(key), %s);
+        }
+        """ % (
+            self.map_name,
+            new_object_literal,
+        )
+
+        response = self.rc.executeOnController(self.cluster.id, script, Lang.JAVASCRIPT)
+        self.assertTrue(response.success)
+
+
+LITE_MEMBER_CONFIG = """
+<hazelcast xmlns="http://www.hazelcast.com/schema/config">
+    <lite-member enabled="true"/>
+</hazelcast>
+"""
+
+
+class SqlServiceLiteMemberClusterTest(SingleMemberTestCase):
+    @classmethod
+    def configure_cluster(cls):
+        return LITE_MEMBER_CONFIG
+
+    @classmethod
+    def configure_client(cls, config):
+        config["cluster_name"] = cls.cluster.id
+        return config
+
+    def test_execute(self):
+        with self.assertRaises(HazelcastSqlError) as cm:
+            self.client.sql.execute("SOME QUERY")
+
+        # Make sure that the exception originates from the client
+        self.assertNotEqual(self.member.uuid, str(cm.exception.originating_member_uuid))
+
+    def test_execute_statement(self):
+        statement = SqlStatement("SOME QUERY")
+        with self.assertRaises(HazelcastSqlError) as cm:
+            self.client.sql.execute_statement(statement)
+
+        # Make sure that the exception originates from the client
+        self.assertNotEqual(self.member.uuid, str(cm.exception.originating_member_uuid))
+
+
+class Student(Portable):
+    def __init__(self,
age=None, height=None): + self.age = age + self.height = height + + def write_portable(self, writer): + writer.write_long("age", self.age) + writer.write_float("height", self.height) + + def read_portable(self, reader): + self.age = reader.read_long("age") + self.height = reader.read_float("height") + + def get_factory_id(self): + return 666 + + def get_class_id(self): + return 6 + + def __eq__(self, other): + return isinstance(other, Student) and self.age == other.age and self.height == other.height + + +def random_string(): + return "".join(random.choice(string.ascii_letters) for _ in range(random.randint(3, 20))) diff --git a/tests/unit/load_balancer_test.py b/tests/unit/load_balancer_test.py index 125ab16915..d2dc5ca809 100644 --- a/tests/unit/load_balancer_test.py +++ b/tests/unit/load_balancer_test.py @@ -3,9 +3,27 @@ from hazelcast.util import RandomLB, RoundRobinLB +class _MockMember(object): + def __init__(self, id, lite_member): + self.id = id + self.lite_member = lite_member + + class _MockClusterService(object): - def __init__(self, members): - self._members = members + def __init__(self, data_member_count, lite_member_count): + id = 0 + data_members = [] + lite_members = [] + + for _ in range(data_member_count): + data_members.append(_MockMember(id, False)) + id += 1 + + for _ in range(lite_member_count): + lite_members.append(_MockMember(id, True)) + id += 1 + + self._members = data_members + lite_members def add_listener(self, listener, *_): for m in self._members: @@ -17,27 +35,74 @@ def get_members(self): class LoadBalancersTest(unittest.TestCase): def test_random_lb_with_no_members(self): - cluster = _MockClusterService([]) + cluster = _MockClusterService(0, 0) lb = RandomLB() lb.init(cluster) self.assertIsNone(lb.next()) + self.assertIsNone(lb.next_data_member()) def test_round_robin_lb_with_no_members(self): - cluster = _MockClusterService([]) + cluster = _MockClusterService(0, 0) lb = RoundRobinLB() lb.init(cluster) self.assertIsNone(lb.next()) + self.assertIsNone(lb.next_data_member()) - def test_random_lb_with_members(self): - cluster = _MockClusterService([0, 1, 2]) + def test_random_lb_with_data_members(self): + cluster = _MockClusterService(3, 0) lb = RandomLB() lb.init(cluster) + self.assertTrue(lb.can_get_next_data_member()) for _ in range(10): - self.assertTrue(0 <= lb.next() <= 2) + self.assertTrue(0 <= lb.next().id <= 2) + self.assertTrue(0 <= lb.next_data_member().id <= 2) - def test_round_robin_lb_with_members(self): - cluster = _MockClusterService([0, 1, 2]) + def test_round_robin_lb_with_data_members(self): + cluster = _MockClusterService(5, 0) lb = RoundRobinLB() lb.init(cluster) + self.assertTrue(lb.can_get_next_data_member()) + + for i in range(10): + self.assertEqual(i % 5, lb.next().id) + for i in range(10): - self.assertEqual(i % 3, lb.next()) + self.assertEqual(i % 5, lb.next_data_member().id) + + def test_random_lb_with_lite_members(self): + cluster = _MockClusterService(0, 3) + lb = RandomLB() + lb.init(cluster) + + for _ in range(10): + self.assertTrue(0 <= lb.next().id <= 2) + self.assertIsNone(lb.next_data_member()) + + def test_round_robin_lb_with_lite_members(self): + cluster = _MockClusterService(0, 3) + lb = RoundRobinLB() + lb.init(cluster) + + for i in range(10): + self.assertEqual(i % 3, lb.next().id) + self.assertIsNone(lb.next_data_member()) + + def test_random_lb_with_mixed_members(self): + cluster = _MockClusterService(3, 3) + lb = RandomLB() + lb.init(cluster) + + for _ in range(20): + self.assertTrue(0 <= lb.next().id < 6) + 
self.assertTrue(0 <= lb.next_data_member().id < 3) + + def test_round_robin_lb_with_mixed_members(self): + cluster = _MockClusterService(3, 3) + lb = RoundRobinLB() + lb.init(cluster) + + for i in range(24): + self.assertEqual(i % 6, lb.next().id) + + for i in range(24): + self.assertEqual(i % 3, lb.next_data_member().id) diff --git a/tests/unit/sql_test.py b/tests/unit/sql_test.py new file mode 100644 index 0000000000..35a607f6a0 --- /dev/null +++ b/tests/unit/sql_test.py @@ -0,0 +1,302 @@ +import itertools +import unittest +import uuid + +from mock import MagicMock + +from hazelcast.protocol.codec import sql_execute_codec, sql_close_codec, sql_fetch_codec +from hazelcast.protocol.client_message import _OUTBOUND_MESSAGE_MESSAGE_TYPE_OFFSET +from hazelcast.serialization import LE_INT +from hazelcast.sql import ( + SqlService, + SqlColumnMetadata, + SqlColumnType, + _SqlPage, + SqlRowMetadata, + _InternalSqlService, + HazelcastSqlError, + _SqlErrorCode, + _SqlError, + SqlStatement, + SqlExpectedResultType, +) + +EXPECTED_ROWS = ["result", "result2"] +EXPECTED_UPDATE_COUNT = 42 + + +class SqlMockTest(unittest.TestCase): + def setUp(self): + + self.connection = MagicMock() + + connection_manager = MagicMock(client_uuid=uuid.uuid4()) + connection_manager.get_random_connection = MagicMock(return_value=self.connection) + + serialization_service = MagicMock() + serialization_service.to_object.side_effect = lambda arg: arg + serialization_service.to_data.side_effect = lambda arg: arg + + self.invocation_registry = {} + correlation_id_counter = itertools.count() + invocation_service = MagicMock() + + def invoke(invocation): + self.invocation_registry[next(correlation_id_counter)] = invocation + + invocation_service.invoke.side_effect = invoke + + self.internal_service = _InternalSqlService( + connection_manager, serialization_service, invocation_service + ) + self.service = SqlService(self.internal_service) + self.result = self.service.execute("SOME QUERY") + + def test_iterator_with_rows(self): + self.set_execute_response_with_rows() + self.assertEqual(-1, self.result.update_count().result()) + self.assertTrue(self.result.is_row_set().result()) + self.assertIsInstance(self.result.get_row_metadata().result(), SqlRowMetadata) + self.assertEqual(EXPECTED_ROWS, self.get_rows_from_iterator()) + + def test_blocking_iterator_with_rows(self): + self.set_execute_response_with_rows() + self.assertEqual(-1, self.result.update_count().result()) + self.assertTrue(self.result.is_row_set().result()) + self.assertIsInstance(self.result.get_row_metadata().result(), SqlRowMetadata) + self.assertEqual(EXPECTED_ROWS, self.get_rows_from_blocking_iterator()) + + def test_iterator_with_update_count(self): + self.set_execute_response_with_update_count() + self.assertEqual(EXPECTED_UPDATE_COUNT, self.result.update_count().result()) + self.assertFalse(self.result.is_row_set().result()) + + with self.assertRaises(ValueError): + self.result.get_row_metadata().result() + + with self.assertRaises(ValueError): + self.result.iterator().result() + + def test_blocking_iterator_with_update_count(self): + self.set_execute_response_with_update_count() + self.assertEqual(EXPECTED_UPDATE_COUNT, self.result.update_count().result()) + self.assertFalse(self.result.is_row_set().result()) + + with self.assertRaises(ValueError): + self.result.get_row_metadata().result() + + with self.assertRaises(ValueError): + for _ in self.result: + pass + + def test_execute_error(self): + self.set_execute_error(RuntimeError("expected")) + with 
self.assertRaises(HazelcastSqlError) as cm: + iter(self.result) + + self.assertEqual(_SqlErrorCode.GENERIC, cm.exception._code) + + def test_execute_error_when_connection_is_not_live(self): + self.connection.live = False + self.set_execute_error(RuntimeError("expected")) + with self.assertRaises(HazelcastSqlError) as cm: + iter(self.result) + + self.assertEqual(_SqlErrorCode.CONNECTION_PROBLEM, cm.exception._code) + + def test_close_when_execute_is_not_done(self): + future = self.result.close() + self.set_close_response() + self.assertIsNone(future.result()) + with self.assertRaises(HazelcastSqlError) as cm: + iter(self.result) + + self.assertEqual(_SqlErrorCode.CANCELLED_BY_USER, cm.exception._code) + + def test_close_when_close_request_fails(self): + future = self.result.close() + self.set_close_error(HazelcastSqlError(None, _SqlErrorCode.MAP_DESTROYED, "expected", None)) + + with self.assertRaises(HazelcastSqlError) as cm: + future.result() + + self.assertEqual(_SqlErrorCode.MAP_DESTROYED, cm.exception._code) + + def test_fetch_error(self): + self.set_execute_response_with_rows(is_last=False) + result = [] + i = self.result.iterator().result() + # First page contains two rows + result.append(next(i).result().get_object_with_index(0)) + result.append(next(i).result().get_object_with_index(0)) + + self.assertEqual(EXPECTED_ROWS, result) + + # initiate the fetch request + future = next(i) + + self.set_fetch_error(RuntimeError("expected")) + + with self.assertRaises(HazelcastSqlError) as cm: + future.result() + + self.assertEqual(_SqlErrorCode.GENERIC, cm.exception._code) + + def test_fetch_server_error(self): + self.set_execute_response_with_rows(is_last=False) + result = [] + i = self.result.iterator().result() + # First page contains two rows + result.append(next(i).result().get_object_with_index(0)) + result.append(next(i).result().get_object_with_index(0)) + + self.assertEqual(EXPECTED_ROWS, result) + + # initiate the fetch request + future = next(i) + + self.set_fetch_response_with_error() + + with self.assertRaises(HazelcastSqlError) as cm: + future.result() + + self.assertEqual(_SqlErrorCode.PARSING, cm.exception._code) + + def test_close_in_between_fetches(self): + self.set_execute_response_with_rows(is_last=False) + result = [] + i = self.result.iterator().result() + # First page contains two rows + result.append(next(i).result().get_object_with_index(0)) + result.append(next(i).result().get_object_with_index(0)) + + self.assertEqual(EXPECTED_ROWS, result) + + # initiate the fetch request + future = next(i) + + self.result.close() + + with self.assertRaises(HazelcastSqlError) as cm: + future.result() + + self.assertEqual(_SqlErrorCode.CANCELLED_BY_USER, cm.exception._code) + + def set_fetch_response_with_error(self): + response = {"row_page": None, "error": _SqlError(_SqlErrorCode.PARSING, "expected", None)} + self.set_future_result_or_exception(response, sql_fetch_codec._REQUEST_MESSAGE_TYPE) + + def set_fetch_error(self, error): + self.set_future_result_or_exception(error, sql_fetch_codec._REQUEST_MESSAGE_TYPE) + + def set_close_error(self, error): + self.set_future_result_or_exception(error, sql_close_codec._REQUEST_MESSAGE_TYPE) + + def set_close_response(self): + self.set_future_result_or_exception(None, sql_close_codec._REQUEST_MESSAGE_TYPE) + + def set_execute_response_with_update_count(self): + self.set_execute_response(EXPECTED_UPDATE_COUNT, None, None, None) + + def get_rows_from_blocking_iterator(self): + return [row.get_object_with_index(0) for row in self.result] + + 
def get_rows_from_iterator(self): + result = [] + for row_future in self.result.iterator().result(): + try: + row = row_future.result() + result.append(row.get_object_with_index(0)) + except StopIteration: + break + return result + + def set_execute_response_with_rows(self, is_last=True): + self.set_execute_response( + -1, + [SqlColumnMetadata("name", SqlColumnType.VARCHAR, True, True)], + _SqlPage([SqlColumnType.VARCHAR], [EXPECTED_ROWS], is_last), + None, + ) + + def set_execute_response(self, update_count, row_metadata, row_page, error): + response = { + "update_count": update_count, + "row_metadata": row_metadata, + "row_page": row_page, + "error": error, + } + + self.set_future_result_or_exception(response, sql_execute_codec._REQUEST_MESSAGE_TYPE) + + def set_execute_error(self, error): + self.set_future_result_or_exception(error, sql_execute_codec._REQUEST_MESSAGE_TYPE) + + def get_message_type(self, invocation): + return LE_INT.unpack_from(invocation.request.buf, _OUTBOUND_MESSAGE_MESSAGE_TYPE_OFFSET)[0] + + def set_future_result_or_exception(self, value, message_type): + for invocation in self.invocation_registry.values(): + if self.get_message_type(invocation) == message_type: + if isinstance(value, Exception): + invocation.future.set_exception(value) + else: + invocation.future.set_result(value) + + +class SqlInvalidInputTest(unittest.TestCase): + def test_statement_sql(self): + valid_inputs = ["a", " a", " a "] + + for valid in valid_inputs: + statement = SqlStatement(valid) + self.assertEqual(valid, statement.sql) + + invalid_inputs = ["", " ", None] + + for invalid in invalid_inputs: + expected_error = AssertionError if invalid is None else ValueError + with self.assertRaises(expected_error): + SqlStatement(invalid) + + def test_statement_timeout(self): + valid_inputs = [-1, 0, 15] + + for valid in valid_inputs: + statement = SqlStatement("sql") + statement.timeout = valid + self.assertEqual(valid, statement.timeout) + + invalid_inputs = [-10, -100] + + for invalid in invalid_inputs: + statement = SqlStatement("sql") + with self.assertRaises(ValueError): + statement.timeout = invalid + + def test_statement_cursor_buffer_size(self): + valid_inputs = [1, 10, 999999] + + for valid in valid_inputs: + statement = SqlStatement("something") + statement.cursor_buffer_size = valid + self.assertEqual(valid, statement.cursor_buffer_size) + + invalid_inputs = [0, -10, -99999] + + for invalid in invalid_inputs: + statement = SqlStatement("something") + with self.assertRaises(ValueError): + statement.cursor_buffer_size = invalid + + def test_statement_expected_result_type(self): + valid_inputs = [SqlExpectedResultType.ROWS, SqlExpectedResultType.UPDATE_COUNT] + + for valid in valid_inputs: + statement = SqlStatement("something") + statement.expected_result_type = valid + self.assertEqual(valid, statement.expected_result_type) + + with self.assertRaises(AssertionError): + statement = SqlStatement("something") + statement.expected_result_type = None From 7b8e970cfd2d4522c69683ad0d3ea40846cd9844 Mon Sep 17 00:00:00 2001 From: mdumandag Date: Thu, 6 May 2021 18:43:03 +0300 Subject: [PATCH 2/9] delete reserved codecs --- .../codec/sql_execute_reserved_codec.py | 43 ------------------- .../codec/sql_fetch_reserved_codec.py | 33 -------------- 2 files changed, 76 deletions(-) delete mode 100644 hazelcast/protocol/codec/sql_execute_reserved_codec.py delete mode 100644 hazelcast/protocol/codec/sql_fetch_reserved_codec.py diff --git a/hazelcast/protocol/codec/sql_execute_reserved_codec.py 
b/hazelcast/protocol/codec/sql_execute_reserved_codec.py deleted file mode 100644 index ff10cccc0e..0000000000 --- a/hazelcast/protocol/codec/sql_execute_reserved_codec.py +++ /dev/null @@ -1,43 +0,0 @@ -from hazelcast.serialization.bits import * -from hazelcast.protocol.builtin import FixSizedTypesCodec -from hazelcast.protocol.client_message import OutboundMessage, REQUEST_HEADER_SIZE, create_initial_buffer, RESPONSE_HEADER_SIZE -from hazelcast.protocol.builtin import StringCodec -from hazelcast.protocol.builtin import ListMultiFrameCodec -from hazelcast.protocol.builtin import DataCodec -from hazelcast.protocol.codec.custom.sql_query_id_codec import SqlQueryIdCodec -from hazelcast.protocol.builtin import CodecUtil -from hazelcast.protocol.codec.custom.sql_column_metadata_codec import SqlColumnMetadataCodec -from hazelcast.protocol.builtin import ListCNDataCodec -from hazelcast.protocol.codec.custom.sql_error_codec import SqlErrorCodec - -# hex: 0x210100 -_REQUEST_MESSAGE_TYPE = 2162944 -# hex: 0x210101 -_RESPONSE_MESSAGE_TYPE = 2162945 - -_REQUEST_TIMEOUT_MILLIS_OFFSET = REQUEST_HEADER_SIZE -_REQUEST_CURSOR_BUFFER_SIZE_OFFSET = _REQUEST_TIMEOUT_MILLIS_OFFSET + LONG_SIZE_IN_BYTES -_REQUEST_INITIAL_FRAME_SIZE = _REQUEST_CURSOR_BUFFER_SIZE_OFFSET + INT_SIZE_IN_BYTES -_RESPONSE_ROW_PAGE_LAST_OFFSET = RESPONSE_HEADER_SIZE -_RESPONSE_UPDATE_COUNT_OFFSET = _RESPONSE_ROW_PAGE_LAST_OFFSET + BOOLEAN_SIZE_IN_BYTES - - -def encode_request(sql, parameters, timeout_millis, cursor_buffer_size): - buf = create_initial_buffer(_REQUEST_INITIAL_FRAME_SIZE, _REQUEST_MESSAGE_TYPE) - FixSizedTypesCodec.encode_long(buf, _REQUEST_TIMEOUT_MILLIS_OFFSET, timeout_millis) - FixSizedTypesCodec.encode_int(buf, _REQUEST_CURSOR_BUFFER_SIZE_OFFSET, cursor_buffer_size) - StringCodec.encode(buf, sql) - ListMultiFrameCodec.encode(buf, parameters, DataCodec.encode, True) - return OutboundMessage(buf, False) - - -def decode_response(msg): - initial_frame = msg.next_frame() - response = dict() - response["row_page_last"] = FixSizedTypesCodec.decode_boolean(initial_frame.buf, _RESPONSE_ROW_PAGE_LAST_OFFSET) - response["update_count"] = FixSizedTypesCodec.decode_long(initial_frame.buf, _RESPONSE_UPDATE_COUNT_OFFSET) - response["query_id"] = CodecUtil.decode_nullable(msg, SqlQueryIdCodec.decode) - response["row_metadata"] = ListMultiFrameCodec.decode_nullable(msg, SqlColumnMetadataCodec.decode) - response["row_page"] = ListMultiFrameCodec.decode_nullable(msg, ListCNDataCodec.decode) - response["error"] = CodecUtil.decode_nullable(msg, SqlErrorCodec.decode) - return response diff --git a/hazelcast/protocol/codec/sql_fetch_reserved_codec.py b/hazelcast/protocol/codec/sql_fetch_reserved_codec.py deleted file mode 100644 index 1089ead68b..0000000000 --- a/hazelcast/protocol/codec/sql_fetch_reserved_codec.py +++ /dev/null @@ -1,33 +0,0 @@ -from hazelcast.serialization.bits import * -from hazelcast.protocol.builtin import FixSizedTypesCodec -from hazelcast.protocol.client_message import OutboundMessage, REQUEST_HEADER_SIZE, create_initial_buffer, RESPONSE_HEADER_SIZE -from hazelcast.protocol.codec.custom.sql_query_id_codec import SqlQueryIdCodec -from hazelcast.protocol.builtin import ListMultiFrameCodec -from hazelcast.protocol.builtin import ListCNDataCodec -from hazelcast.protocol.codec.custom.sql_error_codec import SqlErrorCodec -from hazelcast.protocol.builtin import CodecUtil - -# hex: 0x210200 -_REQUEST_MESSAGE_TYPE = 2163200 -# hex: 0x210201 -_RESPONSE_MESSAGE_TYPE = 2163201 - -_REQUEST_CURSOR_BUFFER_SIZE_OFFSET = 
REQUEST_HEADER_SIZE -_REQUEST_INITIAL_FRAME_SIZE = _REQUEST_CURSOR_BUFFER_SIZE_OFFSET + INT_SIZE_IN_BYTES -_RESPONSE_ROW_PAGE_LAST_OFFSET = RESPONSE_HEADER_SIZE - - -def encode_request(query_id, cursor_buffer_size): - buf = create_initial_buffer(_REQUEST_INITIAL_FRAME_SIZE, _REQUEST_MESSAGE_TYPE) - FixSizedTypesCodec.encode_int(buf, _REQUEST_CURSOR_BUFFER_SIZE_OFFSET, cursor_buffer_size) - SqlQueryIdCodec.encode(buf, query_id, True) - return OutboundMessage(buf, False) - - -def decode_response(msg): - initial_frame = msg.next_frame() - response = dict() - response["row_page_last"] = FixSizedTypesCodec.decode_boolean(initial_frame.buf, _RESPONSE_ROW_PAGE_LAST_OFFSET) - response["row_page"] = ListMultiFrameCodec.decode_nullable(msg, ListCNDataCodec.decode) - response["error"] = CodecUtil.decode_nullable(msg, SqlErrorCodec.decode) - return response From 2349a14dac62060b77091623690ee681f991ff15 Mon Sep 17 00:00:00 2001 From: mdumandag Date: Fri, 7 May 2021 12:48:24 +0300 Subject: [PATCH 3/9] add with statement support --- hazelcast/sql.py | 25 ++++++++++++++++--- .../backward_compatible/sql_test.py | 20 +++++++++++++++ 2 files changed, 42 insertions(+), 3 deletions(-) diff --git a/hazelcast/sql.py b/hazelcast/sql.py index efab4a7488..54810e6e06 100644 --- a/hazelcast/sql.py +++ b/hazelcast/sql.py @@ -837,6 +837,16 @@ def on_next_row(row_future): It might also be used to cancel query execution on the server side if it is still active. + When the blocking API is used, one might also use it with ``with`` + statement to automatically close the query even if an exception + is thrown in the iteration. :: + + with client.sql.execute("SELECT ...") as result: + for row in result: + # Process the row. + print(row) + + To get the update count, use the :func:`update_count`. :: update_count = client.sql.execute("SELECT ...").update_count().result() @@ -880,8 +890,7 @@ def __init__(self, sql_service, connection, query_id, cursor_buffer_size, execut execute_future.add_done_callback(self._handle_execute_response) def iterator(self): - """ - Returns the iterator over the result rows. + """Returns the iterator over the result rows. The iterator may be requested only once. @@ -1219,6 +1228,16 @@ def _mark_closed(self): with self._lock: self._closed = True + def __enter__(self): + # The execute request is already sent. + # There is nothing more to do. + return self + + def __exit__(self, *_): + # Ignoring the possible exception details + # since we close the query regardless of that. + self.close().result() + class _InternalSqlService(object): """Internal SQL service that offers more public API @@ -1555,7 +1574,7 @@ def copy(self): """Creates a copy of this instance. Returns: - SqlStatement: + SqlStatement: The new copy. 
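+
+        For example, a copy can be re-run with different parameters::
+
+            other = statement.copy()
+            other.parameters = [42]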
""" copied = SqlStatement(self.sql) copied.parameters = list(self.parameters) diff --git a/tests/integration/backward_compatible/sql_test.py b/tests/integration/backward_compatible/sql_test.py index ac8a8d9b98..5ed10bddb3 100644 --- a/tests/integration/backward_compatible/sql_test.py +++ b/tests/integration/backward_compatible/sql_test.py @@ -4,6 +4,7 @@ import string from hazelcast import six +from hazelcast.future import ImmediateFuture from hazelcast.serialization.api import Portable from hazelcast.sql import HazelcastSqlError, SqlStatement, SqlExpectedResultType, SqlColumnType from hazelcast.util import timezone @@ -366,6 +367,25 @@ def test_close_when_query_is_active(self): # Next fetch requests should fail next(iterator) + def test_with_statement(self): + self._populate_map() + with self.client.sql.execute("SELECT this FROM %s" % self.map_name) as result: + six.assertCountEqual( + self, [i for i in range(10)], [row.get_object_with_index(0) for row in result] + ) + + def test_with_statement_when_iteration_throws(self): + self._populate_map() + statement = SqlStatement("SELECT this FROM %s" % self.map_name) + statement.cursor_buffer_size = 1 # so that it doesn't close immediately + + with self.assertRaises(RuntimeError): + with self.client.sql.execute_statement(statement) as result: + for _ in result: + raise RuntimeError("expected") + + self.assertIsInstance(result.close(), ImmediateFuture) + class SqlColumnTypesReadTest(SqlTestBase): def test_varchar(self): From 3685e741f511fddd2a993e15b9df57f294e4edbe Mon Sep 17 00:00:00 2001 From: mdumandag Date: Tue, 1 Jun 2021 16:57:50 +0300 Subject: [PATCH 4/9] address review comments --- hazelcast/connection.py | 9 ++-- hazelcast/protocol/builtin.py | 23 +++++----- hazelcast/serialization/serializer.py | 6 +-- hazelcast/sql.py | 63 +++++++++++---------------- hazelcast/util.py | 2 +- 5 files changed, 45 insertions(+), 58 deletions(-) diff --git a/hazelcast/connection.py b/hazelcast/connection.py index 6b7ed2ab6b..8119c7379b 100644 --- a/hazelcast/connection.py +++ b/hazelcast/connection.py @@ -152,8 +152,10 @@ def get_random_connection(self, should_get_data_member=False): if connection: return connection - # We should not get to this point under normal circumstances. - # Therefore, copying the list should be OK. + # We should not get to this point under normal circumstances + # for the smart client. For uni-socket client, there would be + # a single connection in the dict. Therefore, copying the list + # should be acceptable. 
for member_uuid, connection in list(six.iteritems(self.active_connections)): if should_get_data_member: member = self._cluster_service.get_member(member_uuid) @@ -261,11 +263,10 @@ def check_invocation_allowed(self): def _get_connection_from_load_balancer(self, should_get_data_member): load_balancer = self._load_balancer + member = None if should_get_data_member: if load_balancer.can_get_next_data_member(): member = load_balancer.next_data_member() - else: - member = None else: member = load_balancer.next() diff --git a/hazelcast/protocol/builtin.py b/hazelcast/protocol/builtin.py index 856f5978ab..4effd14c5d 100644 --- a/hazelcast/protocol/builtin.py +++ b/hazelcast/protocol/builtin.py @@ -571,36 +571,33 @@ def decode(msg, item_size, decoder): type = FixSizedTypesCodec.decode_byte(frame.buf, 0) count = FixSizedTypesCodec.decode_int(frame.buf, 1) - response = [] if type == ListCNFixedSizeCodec._TYPE_NULL_ONLY: - for _ in range(count): - response.append(None) + return [None] * count elif type == ListCNFixedSizeCodec._TYPE_NOT_NULL_ONLY: - for i in range(count): - response.append( - decoder(frame.buf, ListCNFixedSizeCodec._HEADER_SIZE + i * item_size) - ) + header_size = ListCNFixedSizeCodec._HEADER_SIZE + return [decoder(frame.buf, header_size + i * item_size) for i in range(count)] else: + response = [] position = ListCNFixedSizeCodec._HEADER_SIZE read_count = 0 + items_per_bitmask = ListCNFixedSizeCodec._ITEMS_PER_BITMASK while read_count < count: bitmask = FixSizedTypesCodec.decode_byte(frame.buf, position) position += 1 - i = 0 - while i < ListCNFixedSizeCodec._ITEMS_PER_BITMASK and read_count < count: + batch_size = min(items_per_bitmask, count - read_count) + for i in range(batch_size): mask = 1 << i if (bitmask & mask) == mask: response.append(decoder(frame.buf, position)) position += item_size else: response.append(None) - read_count += 1 - i += 1 + read_count += batch_size - return response + return response class ListCNBooleanCodec(object): @@ -693,7 +690,7 @@ def decode(msg): unscaled_value = int_from_bytes(buf[INT_SIZE_IN_BYTES : INT_SIZE_IN_BYTES + size]) scale = FixSizedTypesCodec.decode_int(buf, INT_SIZE_IN_BYTES + size) sign = 0 if unscaled_value >= 0 else 1 - return Decimal((sign, tuple(map(int, str(abs(unscaled_value)))), -1 * scale)) + return Decimal((sign, tuple(int(digit) for digit in str(abs(unscaled_value))), -1 * scale)) class SqlPageCodec(object): diff --git a/hazelcast/serialization/serializer.py b/hazelcast/serialization/serializer.py index a2d35ee204..730504045d 100644 --- a/hazelcast/serialization/serializer.py +++ b/hazelcast/serialization/serializer.py @@ -9,7 +9,7 @@ from hazelcast.serialization.base import HazelcastSerializationError from hazelcast.serialization.serialization_const import * from hazelcast.six.moves import range, cPickle -from hazelcast.util import UuidUtil +from hazelcast.util import UUIDUtil if not six.PY2: long = int @@ -133,10 +133,10 @@ class UuidSerializer(BaseSerializer): def read(self, inp): msb = inp.read_long() lsb = inp.read_long() - return UuidUtil.from_bits(msb, lsb) + return UUIDUtil.from_bits(msb, lsb) def write(self, out, obj): - msb, lsb = UuidUtil.to_bits(obj) + msb, lsb = UUIDUtil.to_bits(obj) out.write_long(msb) out.write_long(lsb) diff --git a/hazelcast/sql.py b/hazelcast/sql.py index 54810e6e06..5bb188de85 100644 --- a/hazelcast/sql.py +++ b/hazelcast/sql.py @@ -6,7 +6,7 @@ from hazelcast.future import Future, ImmediateFuture from hazelcast.invocation import Invocation from hazelcast.util import ( - UuidUtil, + UUIDUtil, 
check_not_none, to_millis, check_true, @@ -207,8 +207,8 @@ def from_uuid(cls, member_uuid): """ local_id = uuid.uuid4() - member_msb, member_lsb = UuidUtil.to_bits(member_uuid) - local_msb, local_lsb = UuidUtil.to_bits(local_id) + member_msb, member_lsb = UUIDUtil.to_bits(member_uuid) + local_msb, local_lsb = UUIDUtil.to_bits(local_id) return cls(member_msb, member_lsb, local_msb, local_lsb) @@ -496,10 +496,8 @@ def find_column(self, column_name): def __repr__(self): return "[%s]" % ", ".join( - map( - lambda column: "%s %s" % (column.name, get_attr_name(SqlColumnType, column.type)), - self._columns, - ) + "%s %s" % (column.name, get_attr_name(SqlColumnType, column.type)) + for column in self._columns ) @@ -571,16 +569,14 @@ def metadata(self): return self._row_metadata def __repr__(self): - def mapping(column_index): - metadata = self._row_metadata.get_column(column_index) - value = self._row[column_index] - return "%s %s=%s" % (metadata.name, get_attr_name(SqlColumnType, metadata.type), value) - return "[%s]" % ", ".join( - map( - mapping, - range(self._row_metadata.column_count), + "%s %s=%s" + % ( + self._row_metadata.get_column(i).name, + get_attr_name(SqlColumnType, self._row_metadata.get_column(i).type), + self._row[i], ) + for i in range(self._row_metadata.column_count) ) @@ -658,15 +654,13 @@ def _get_current_row(self): Returns: list: The row pointed by the current position. """ - values = [] - for i in range(self.page.column_count): - value = self.page.get_column_value(i, self.position) - - # The column might contain user objects so we have to deserialize it. - # This call is no-op if the value is not Data. - values.append(self.deserialize_fn(value)) - return values + # The column might contain user objects so we have to deserialize it. + # Deserialization is no-op if the value is not Data. + return [ + self.deserialize_fn(self.page.get_column_value(i, self.position)) + for i in range(self.page.column_count) + ] class _FutureProducingIterator(_IteratorBase): @@ -1126,7 +1120,7 @@ def _on_fetch_response(self, page): if page.is_last: # This is the last page, there is nothing # more on the server. - self._mark_closed() + self._closed = True self._fetch_future = None @@ -1212,7 +1206,7 @@ def _on_execute_response(self, row_metadata, row_page, update_count): if row_page.is_last: # This is the last page, close the result. - self._mark_closed() + self._closed = True self._execute_response.set_result(response) else: @@ -1221,19 +1215,14 @@ def _on_execute_response(self, row_metadata, row_page, update_count): self._execute_response.set_result(response) # There is nothing more we can get from the server. - self._mark_closed() - - def _mark_closed(self): - """Marks the result as closed.""" - with self._lock: - self._closed = True + self._closed = True def __enter__(self): # The execute request is already sent. # There is nothing more to do. return self - def __exit__(self, *_): + def __exit__(self, exc_type, exc_value, traceback): # Ignoring the possible exception details # since we close the query regardless of that. self.close().result() @@ -1292,9 +1281,9 @@ def execute_statement(self, statement): query_id = _SqlQueryId.from_uuid(connection.remote_uuid) # Serialize the passed parameters. 
-        serialized_params = []
-        for param in statement.parameters:
-            serialized_params.append(self._serialization_service.to_data(param))
+        serialized_params = [
+            self._serialization_service.to_data(param) for param in statement.parameters
+        ]
 
         request = sql_execute_codec.encode_request(
             statement.sql,
@@ -1368,8 +1357,8 @@ def re_raise(self, error, connection):
             return HazelcastSqlError(
                 self.get_client_id(),
                 _SqlErrorCode.CONNECTION_PROBLEM,
-                "Cluster topology changed while a query was executed: Member cannot be reached: "
-                + connection.remote_address,
+                "Cluster topology changed while a query was executed: Member cannot be reached: %s"
+                % connection.remote_address,
                 error,
             )
 
diff --git a/hazelcast/util.py b/hazelcast/util.py
index a149341f37..4f7fd6a556 100644
--- a/hazelcast/util.py
+++ b/hazelcast/util.py
@@ -384,7 +384,7 @@ class IterationType(object):
     """Iterate over entries"""
 
 
-class UuidUtil(object):
+class UUIDUtil(object):
     @staticmethod
     def to_bits(value):
         i = value.int

From 372cf6c176cb230982f4e64c9b5f8e637a59e4e3 Mon Sep 17 00:00:00 2001
From: mdumandag
Date: Tue, 8 Jun 2021 13:37:01 +0300
Subject: [PATCH 5/9] address review comments

---
 hazelcast/protocol/builtin.py | 49 +++++++++++++--------------
 hazelcast/sql.py              | 63 ++++++++++++++++++-----------------
 hazelcast/util.py             |  8 ++---
 3 files changed, 59 insertions(+), 61 deletions(-)

diff --git a/hazelcast/protocol/builtin.py b/hazelcast/protocol/builtin.py
index 4effd14c5d..34550db8fa 100644
--- a/hazelcast/protocol/builtin.py
+++ b/hazelcast/protocol/builtin.py
@@ -577,7 +577,7 @@ def decode(msg, item_size, decoder):
             header_size = ListCNFixedSizeCodec._HEADER_SIZE
             return [decoder(frame.buf, header_size + i * item_size) for i in range(count)]
         else:
-            response = []
+            response = [None] * count
             position = ListCNFixedSizeCodec._HEADER_SIZE
             read_count = 0
             items_per_bitmask = ListCNFixedSizeCodec._ITEMS_PER_BITMASK
@@ -590,12 +590,10 @@ def decode(msg, item_size, decoder):
                 bitmask = FixSizedTypesCodec.decode_byte(frame.buf, position)
                 position += 1
                 batch_size = min(items_per_bitmask, count - read_count)
                 for i in range(batch_size):
                     mask = 1 << i
                     if (bitmask & mask) == mask:
-                        response.append(decoder(frame.buf, position))
+                        response[read_count] = decoder(frame.buf, position)
                         position += item_size
-                    else:
-                        response.append(None)
-                read_count += batch_size
+                    read_count += 1
 
             return response
@@ -706,48 +704,49 @@ def decode(msg):
         # read column types
         column_type_ids = ListIntegerCodec.decode(msg)
+        column_count = len(column_type_ids)
 
         # read columns
-        columns = []
+        columns = [None] * column_count
+
+        for i in range(column_count):
+            column_type_id = column_type_ids[i]
 
-        for column_type_id in column_type_ids:
             if column_type_id == SqlColumnType.VARCHAR:
-                columns.append(
-                    ListMultiFrameCodec.decode_contains_nullable(msg, StringCodec.decode)
-                )
+                columns[i] = ListMultiFrameCodec.decode_contains_nullable(msg, StringCodec.decode)
             elif column_type_id == SqlColumnType.BOOLEAN:
-                columns.append(ListCNBooleanCodec.decode(msg))
+                columns[i] = ListCNBooleanCodec.decode(msg)
             elif column_type_id == SqlColumnType.TINYINT:
-                columns.append(ListCNByteCodec.decode(msg))
+                columns[i] = ListCNByteCodec.decode(msg)
             elif column_type_id == SqlColumnType.SMALLINT:
-                columns.append(ListCNShortCodec.decode(msg))
+                columns[i] = ListCNShortCodec.decode(msg)
             elif column_type_id == SqlColumnType.INTEGER:
-                columns.append(ListCNIntegerCodec.decode(msg))
+                columns[i] = ListCNIntegerCodec.decode(msg)
             elif column_type_id == SqlColumnType.BIGINT:
-                columns.append(ListCNLongCodec.decode(msg))
+                columns[i] = ListCNLongCodec.decode(msg)
             elif column_type_id == SqlColumnType.REAL:
-
columns.append(ListCNFloatCodec.decode(msg)) + columns[i] = ListCNFloatCodec.decode(msg) elif column_type_id == SqlColumnType.DOUBLE: - columns.append(ListCNDoubleCodec.decode(msg)) + columns[i] = ListCNDoubleCodec.decode(msg) elif column_type_id == SqlColumnType.DATE: - columns.append(ListCNLocalDateCodec.decode(msg)) + columns[i] = ListCNLocalDateCodec.decode(msg) elif column_type_id == SqlColumnType.TIME: - columns.append(ListCNLocalTimeCodec.decode(msg)) + columns[i] = ListCNLocalTimeCodec.decode(msg) elif column_type_id == SqlColumnType.TIMESTAMP: - columns.append(ListCNLocalDateTimeCodec.decode(msg)) + columns[i] = ListCNLocalDateTimeCodec.decode(msg) elif column_type_id == SqlColumnType.TIMESTAMP_WITH_TIME_ZONE: - columns.append(ListCNOffsetDateTimeCodec.decode(msg)) + columns[i] = ListCNOffsetDateTimeCodec.decode(msg) elif column_type_id == SqlColumnType.DECIMAL: - columns.append( - ListMultiFrameCodec.decode_contains_nullable(msg, BigDecimalCodec.decode) + columns[i] = ListMultiFrameCodec.decode_contains_nullable( + msg, BigDecimalCodec.decode ) elif column_type_id == SqlColumnType.NULL: frame = msg.next_frame() size = FixSizedTypesCodec.decode_int(frame.buf, 0) column = [None for _ in range(size)] - columns.append(column) + columns[i] = column elif column_type_id == SqlColumnType.OBJECT: - columns.append(ListMultiFrameCodec.decode_contains_nullable(msg, DataCodec.decode)) + columns[i] = ListMultiFrameCodec.decode_contains_nullable(msg, DataCodec.decode) else: raise ValueError("Unknown type %s" % column_type_id) diff --git a/hazelcast/sql.py b/hazelcast/sql.py index 5bb188de85..058f7cdbe5 100644 --- a/hazelcast/sql.py +++ b/hazelcast/sql.py @@ -20,7 +20,7 @@ class SqlService(object): """A service to execute SQL statements. - The service allows you to query data stored in + The service allows you to query data stored in a :class:`Map `. Warnings: @@ -50,13 +50,13 @@ class SqlService(object): - For non-Portable objects, public getters and fields are used to populate the column list. For getters, the first letter is converted to lower case. A getter takes precedence over a field in case of naming - conflict + conflict. - For :class:`Portable ` objects, field names used in the :func:`write_portable() ` - method are used to populate the column list + method are used to populate the column list. - The whole key and value objects could be accessed through a special fields + The whole key and value objects could be accessed through special fields ``__key`` and ``this``, respectively. If key (value) object has fields, then the whole key (value) field is exposed as a normal field. Otherwise the field is hidden. Hidden fields can be accessed directly, but are not returned @@ -86,11 +86,11 @@ def write_portable(self, writer): This model will be resolved to the following table columns: - - ``person_id`` ``BIGINT`` - - ``department_id`` ``BIGINT`` - - ``name`` ``VARCHAR`` - - ``__key`` ``OBJECT`` (hidden) - - ``this`` ``OBJECT`` (hidden) + - person_id ``BIGINT`` + - department_id ``BIGINT`` + - name ``VARCHAR`` + - __key ``OBJECT`` (hidden) + - this ``OBJECT`` (hidden) **Consistency** @@ -120,7 +120,7 @@ def write_portable(self, writer): See the documentation of the :class:`SqlResult` for more information about - the different type of iteration methods. + different iteration methods. 
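For context while reading this docstring hunk, the two iteration styles it refers to look roughly like the sketch below. This is illustrative only and not part of the patch; the ``employees`` map name is an assumption, and a connected client is assumed. ::

    result = client.sql.execute("SELECT __key, this FROM employees")

    # Blocking iteration: each step may wait while the next page is fetched.
    for row in result:
        print(row.get_object("__key"), row.get_object("this"))

    # Non-blocking iteration: work with Futures instead of blocking calls.
    def on_iterator(iterator_future):
        iterator = iterator_future.result()

        def on_row(row_future):
            try:
                row = row_future.result()
                print(row.get_object("__key"))
                next(iterator).add_done_callback(on_row)
            except StopIteration:
                pass  # Iterator is exhausted, no more rows to process.

        next(iterator).add_done_callback(on_row)

    result = client.sql.execute("SELECT __key, this FROM employees")
    result.iterator().add_done_callback(on_iterator)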
Notes: @@ -214,7 +214,7 @@ def from_uuid(cls, member_uuid): class SqlColumnMetadata(object): - """Metadata for one of the columns of the returned rows.""" + """Metadata of a column in an SQL row.""" __slots__ = ("_name", "_type", "_nullable") @@ -235,8 +235,8 @@ def type(self): @property def nullable(self): - """bool: ``True`` if the rows in this column might be ``None``, - ``False`` otherwise. + """bool: ``True`` if this column values can be ``None``, ``False`` + otherwise. """ return self._nullable @@ -287,7 +287,7 @@ def is_last(self): """bool: Whether this is the last page or not.""" return self._is_last - def get_column_value(self, column_index, row_index): + def get_value(self, column_index, row_index): """ Args: column_index (int): @@ -468,7 +468,7 @@ def columns(self): @property def column_count(self): - """int: Number of column in the row.""" + """int: Number of columns in the row.""" return len(self._columns) def get_column(self, index): @@ -502,7 +502,7 @@ def __repr__(self): class SqlRow(object): - """One of the rows of the SQL query result.""" + """One of the rows of an SQL query result.""" __slots__ = ("_row_metadata", "_row") @@ -511,13 +511,13 @@ def __init__(self, row_metadata, row): self._row = row def get_object(self, column_name): - """Gets the value of the column by column name. + """Gets the value in the column indicated by the column name. Column name should be one of those defined in :class:`SqlRowMetadata`, case-sensitive. You may also use :func:`SqlRowMetadata.find_column` to test for column existence. - The class of the returned value depends on the SQL type of the column. + The type of the returned value depends on the SQL type of the column. No implicit conversions are performed on the value. Args: @@ -596,7 +596,7 @@ def __init__(self, row_metadata, row_page, update_count): """ self.update_count = update_count - """int: Update count or -1 if row metadata or row page exist.""" + """int: Update count or -1 if the result is a rowset.""" class _IteratorBase(object): @@ -658,7 +658,7 @@ def _get_current_row(self): # The column might contain user objects so we have to deserialize it. # Deserialization is no-op if the value is not Data. return [ - self.deserialize_fn(self.page.get_column_value(i, self.position)) + self.deserialize_fn(self.page.get_value(i, self.position)) for i in range(self.page.column_count) ] @@ -695,7 +695,7 @@ def _has_next_continuation(self, future): if not has_next: # Iterator is exhausted, raise this to inform the user. # If the user continues to call next, we will continuously - # will raise this. + # raise this. raise StopIteration row = self._get_current_row() @@ -826,12 +826,12 @@ def on_next_row(row_future): When in doubt, use the blocking API shown in the first code sample. - Also, one might call :func:`close` over the result object to + One can call :func:`close` method of a result object to release the resources associated with the result on the server side. It might also be used to cancel query execution on the server side if it is still active. - When the blocking API is used, one might also use it with ``with`` + When the blocking API is used, one might also use ``with`` statement to automatically close the query even if an exception is thrown in the iteration. :: @@ -841,11 +841,13 @@ def on_next_row(row_future): print(row) - To get the update count, use the :func:`update_count`. :: + To get the number of rows updated by the query, use the + :func:`update_count`. 
:: update_count = client.sql.execute("SELECT ...").update_count().result() - One does not have to call :func:`close` in this case. + One does not have to call :func:`close` in this case, because the result + will already be closed in the server-side. """ def __init__(self, sql_service, connection, query_id, cursor_buffer_size, execute_future): @@ -857,7 +859,7 @@ def __init__(self, sql_service, connection, query_id, cursor_buffer_size, execut that the execute request is made to.""" self._query_id = query_id - """_SqlQueryId: Uniuqe id of the SQL query.""" + """_SqlQueryId: Unique id of the SQL query.""" self._cursor_buffer_size = cursor_buffer_size """int: Size of the cursor buffer measured in the number of rows.""" @@ -1154,7 +1156,8 @@ def _handle_execute_response(self, future): # the server, invocation failed. self._on_execute_error(self._sql_service.re_raise(e, self._connection)) - def _handle_response_error(self, error): + @staticmethod + def _handle_response_error(error): """If the error is not ``None``, return it as :class:`HazelcastSqlError` so that we can raise it to user. @@ -1185,7 +1188,7 @@ def _on_execute_error(self, error): self._execute_response.set_exception(error) def _on_execute_response(self, row_metadata, row_page, update_count): - """Called when the first execute request is succeed. + """Called when the first execute request is succeeded. Args: row_metadata (SqlRowMetadata): The row metadata. Might be ``None`` @@ -1344,14 +1347,14 @@ def re_raise(self, error, connection): so that it can be raised to the user. Args: - error (Exception): The error to re raise. + error (Exception): The error to reraise. connection (hazelcast.connection.Connection): Connection that the query requests are routed to. If it is not live, we will inform the user about the possible cluster topology change. Returns: - HazelcastSqlError: The re raised error. + HazelcastSqlError: The reraised error. 
""" if not connection.live: return HazelcastSqlError( diff --git a/hazelcast/util.py b/hazelcast/util.py index 4f7fd6a556..4595ceab6c 100644 --- a/hazelcast/util.py +++ b/hazelcast/util.py @@ -325,11 +325,7 @@ def can_get_next_data_member(self): def _listener(self, _): members = self._cluster_service.get_members() - data_members = [] - - for member in members: - if not member.lite_member: - data_members.append(member) + data_members = [member for member in members if not member.lite_member] self._members = _Members(members, data_members) @@ -415,7 +411,7 @@ def int_from_bytes(buffer): if buffer[0] & 0x80: neg = bytearray() for c in buffer: - neg.append(c ^ 0xFF) + neg.append(~c) return -1 * int(binascii.hexlify(neg), 16) - 1 return int(binascii.hexlify(buffer), 16) From a56d876563567fd18d3430b6d522708814c2f100 Mon Sep 17 00:00:00 2001 From: mdumandag Date: Thu, 10 Jun 2021 15:40:37 +0300 Subject: [PATCH 6/9] address more review comments --- hazelcast/protocol/builtin.py | 28 +++++++-- hazelcast/sql.py | 24 ++++++-- hazelcast/util.py | 8 +-- .../backward_compatible/sql_test.py | 61 ++++++++++++------- tests/unit/sql_test.py | 24 ++++---- 5 files changed, 99 insertions(+), 46 deletions(-) diff --git a/hazelcast/protocol/builtin.py b/hazelcast/protocol/builtin.py index 34550db8fa..e221c5995b 100644 --- a/hazelcast/protocol/builtin.py +++ b/hazelcast/protocol/builtin.py @@ -652,33 +652,49 @@ class ListCNLocalDateCodec(object): @staticmethod def decode(msg): return ListCNFixedSizeCodec.decode( - msg, _LOCAL_DATE_SIZE_IN_BYTES, FixSizedTypesCodec.decode_local_date + msg, _LOCAL_DATE_SIZE_IN_BYTES, ListCNLocalDateCodec._decode_item ) + @staticmethod + def _decode_item(buf, offset): + return FixSizedTypesCodec.decode_local_date(buf, offset).isoformat() + class ListCNLocalTimeCodec(object): @staticmethod def decode(msg): return ListCNFixedSizeCodec.decode( - msg, _LOCAL_TIME_SIZE_IN_BYTES, FixSizedTypesCodec.decode_local_time + msg, _LOCAL_TIME_SIZE_IN_BYTES, ListCNLocalTimeCodec._decode_item ) + @staticmethod + def _decode_item(buf, offset): + return FixSizedTypesCodec.decode_local_time(buf, offset).isoformat() + class ListCNLocalDateTimeCodec(object): @staticmethod def decode(msg): return ListCNFixedSizeCodec.decode( - msg, _LOCAL_DATE_TIME_SIZE_IN_BYTES, FixSizedTypesCodec.decode_local_date_time + msg, _LOCAL_DATE_TIME_SIZE_IN_BYTES, ListCNLocalDateTimeCodec._decode_item ) + @staticmethod + def _decode_item(buf, offset): + return FixSizedTypesCodec.decode_local_date_time(buf, offset).isoformat() + class ListCNOffsetDateTimeCodec(object): @staticmethod def decode(msg): return ListCNFixedSizeCodec.decode( - msg, _OFFSET_DATE_TIME_SIZE_IN_BYTES, FixSizedTypesCodec.decode_offset_date_time + msg, _OFFSET_DATE_TIME_SIZE_IN_BYTES, ListCNOffsetDateTimeCodec._decode_item ) + @staticmethod + def _decode_item(buf, offset): + return FixSizedTypesCodec.decode_offset_date_time(buf, offset).isoformat() + class BigDecimalCodec(object): @staticmethod @@ -688,7 +704,9 @@ def decode(msg): unscaled_value = int_from_bytes(buf[INT_SIZE_IN_BYTES : INT_SIZE_IN_BYTES + size]) scale = FixSizedTypesCodec.decode_int(buf, INT_SIZE_IN_BYTES + size) sign = 0 if unscaled_value >= 0 else 1 - return Decimal((sign, tuple(int(digit) for digit in str(abs(unscaled_value))), -1 * scale)) + return str( + Decimal((sign, tuple(int(digit) for digit in str(abs(unscaled_value))), -1 * scale)) + ) class SqlPageCodec(object): diff --git a/hazelcast/sql.py b/hazelcast/sql.py index 058f7cdbe5..5d9de3ac8b 100644 --- a/hazelcast/sql.py +++ 
@@ -12,7 +12,11 @@
     check_true,
     get_attr_name,
     try_to_get_error_message,
+    check_is_number,
+    check_is_int,
+    try_to_get_enum_value,
 )
+from hazelcast import six
 
 _logger = logging.getLogger(__name__)
 
@@ -57,8 +61,8 @@ class SqlService(object):
     method are used to populate the column list.
 
     The whole key and value objects could be accessed through special fields
-    ``__key`` and ``this``, respectively. If key (value) object has fields,
-    then the whole key (value) field is exposed as a normal field. Otherwise the
+    ``__key`` and ``this``, respectively. If key (or value) object has fields,
+    then the whole key (or value) field is exposed as a normal field. Otherwise the
     field is hidden. Hidden fields can be accessed directly, but are not returned
     by ``SELECT * FROM ...`` queries.
 
@@ -826,6 +830,8 @@ def on_next_row(row_future):
 
     When in doubt, use the blocking API shown in the first code sample.
 
+    Note that iterators can be requested at most once per SqlResult.
+
     One can call :func:`close` method of a result object to
     release the resources associated with the result on the server side.
     It might also be used to cancel query execution on the server side
@@ -1440,7 +1446,7 @@ def sql(self):
 
     @sql.setter
     def sql(self, sql):
-        check_not_none(sql, "SQL cannot be None")
+        check_true(isinstance(sql, six.string_types) and sql is not None, "SQL must be a string")
 
         if not sql.strip():
            raise ValueError("SQL cannot be empty")
@@ -1465,6 +1471,10 @@ def schema(self):
 
     @schema.setter
     def schema(self, schema):
+        check_true(
+            isinstance(schema, six.string_types) or schema is None,
+            "Schema must be a string or None",
+        )
         self._schema = schema
 
     @property
@@ -1481,6 +1491,7 @@ def parameters(self):
 
     @parameters.setter
     def parameters(self, parameters):
+        check_true(isinstance(parameters, list), "Parameters must be a list")
         if not parameters:
             self._parameters = []
         else:
@@ -1488,7 +1499,7 @@ def parameters(self, parameters):
 
     @property
     def timeout(self):
-        """float: The execution timeout in seconds.
+        """float or int: The execution timeout in seconds.
 
         If the timeout is reached for a running statement, it will be
         cancelled forcefully.
@@ -1503,6 +1514,7 @@ def timeout(self):
 
     @timeout.setter
     def timeout(self, timeout):
+        check_is_number(timeout, "Timeout must be an integer or float")
         if timeout < 0 and timeout != SqlStatement.TIMEOUT_NOT_SET:
             raise ValueError("Timeout must be non-negative or -1, not %s" % timeout)
 
@@ -1531,6 +1543,7 @@ def cursor_buffer_size(self):
 
     @cursor_buffer_size.setter
     def cursor_buffer_size(self, cursor_buffer_size):
+        check_is_int(cursor_buffer_size, "Cursor buffer size must be an integer")
         if cursor_buffer_size <= 0:
             raise ValueError("Cursor buffer size must be positive, not %s" % cursor_buffer_size)
         self._cursor_buffer_size = cursor_buffer_size
@@ -1542,7 +1555,8 @@ def expected_result_type(self):
 
     @expected_result_type.setter
     def expected_result_type(self, expected_result_type):
-        check_not_none(expected_result_type, "Expected result type cannot be None")
+        # Ignore the result, we call this method just to type check.
+ try_to_get_enum_value(expected_result_type, SqlExpectedResultType) self._expected_result_type = expected_result_type def add_parameter(self, parameter): diff --git a/hazelcast/util.py b/hazelcast/util.py index 4595ceab6c..6ea7b67eca 100644 --- a/hazelcast/util.py +++ b/hazelcast/util.py @@ -41,14 +41,14 @@ def check_not_empty(collection, message): raise AssertionError(message) -def check_is_number(val): +def check_is_number(val, message="Number value expected"): if not isinstance(val, number_types): - raise AssertionError("Number value expected") + raise AssertionError(message) -def check_is_int(val): +def check_is_int(val, message="Int value expected"): if not isinstance(val, six.integer_types): - raise AssertionError("Int value expected") + raise AssertionError(message) def current_time(): diff --git a/tests/integration/backward_compatible/sql_test.py b/tests/integration/backward_compatible/sql_test.py index 5ed10bddb3..47b9778f34 100644 --- a/tests/integration/backward_compatible/sql_test.py +++ b/tests/integration/backward_compatible/sql_test.py @@ -1,5 +1,5 @@ -import datetime import decimal +import math import random import string @@ -7,9 +7,9 @@ from hazelcast.future import ImmediateFuture from hazelcast.serialization.api import Portable from hazelcast.sql import HazelcastSqlError, SqlStatement, SqlExpectedResultType, SqlColumnType -from hazelcast.util import timezone from tests.base import SingleMemberTestCase from tests.hzrc.ttypes import Lang +from mock import patch SERVER_CONFIG = """ Date: Fri, 11 Jun 2021 11:55:25 +0300 Subject: [PATCH 7/9] backward compatibility and non-blocking iterator fixes --- hazelcast/sql.py | 8 +++-- .../backward_compatible/cluster_test.py | 14 +++++++-- .../backward_compatible/sql_test.py | 31 +++++++++++++++++-- 3 files changed, 46 insertions(+), 7 deletions(-) diff --git a/hazelcast/sql.py b/hazelcast/sql.py index 5d9de3ac8b..8f97782344 100644 --- a/hazelcast/sql.py +++ b/hazelcast/sql.py @@ -1124,13 +1124,17 @@ def _on_fetch_response(self, page): page (_SqlPage): The next page. """ with self._lock: - self._fetch_future.set_result(page) + future = self._fetch_future + self._fetch_future = None + if page.is_last: # This is the last page, there is nothing # more on the server. 
self._closed = True - self._fetch_future = None + # Resolving the future before resetting self._fetch_future + # might result in an infinite loop for non-blocking iterators + future.set_result(page) def _handle_execute_response(self, future): """Handles the result of the execute request, by either: diff --git a/tests/integration/backward_compatible/cluster_test.py b/tests/integration/backward_compatible/cluster_test.py index 64bca4b1f3..291cd801ba 100644 --- a/tests/integration/backward_compatible/cluster_test.py +++ b/tests/integration/backward_compatible/cluster_test.py @@ -149,7 +149,7 @@ def test_random_load_balancer(self): self.assertTrue(isinstance(lb, RandomLB)) six.assertCountEqual( - self, self.addresses, list(map(lambda m: m.address, lb._members.members)) + self, self.addresses, list(map(lambda m: m.address, self._get_members_from_lb(lb))) ) for _ in range(10): self.assertTrue(lb.next().address in self.addresses) @@ -164,13 +164,23 @@ def test_round_robin_load_balancer(self): self.assertTrue(isinstance(lb, RoundRobinLB)) six.assertCountEqual( - self, self.addresses, list(map(lambda m: m.address, lb._members.members)) + self, self.addresses, list(map(lambda m: m.address, self._get_members_from_lb(lb))) ) for i in range(10): self.assertEqual(self.addresses[i % len(self.addresses)], lb.next().address) client.shutdown() + @staticmethod + def _get_members_from_lb(lb): + # For backward-compatibility + members = lb._members + if isinstance(members, list): + return members + + # 4.2+ + return members.members + @set_attr(enterprise=True) class HotRestartEventTest(HazelcastTestCase): diff --git a/tests/integration/backward_compatible/sql_test.py b/tests/integration/backward_compatible/sql_test.py index 47b9778f34..897f745240 100644 --- a/tests/integration/backward_compatible/sql_test.py +++ b/tests/integration/backward_compatible/sql_test.py @@ -2,15 +2,24 @@ import math import random import string +import unittest from hazelcast import six from hazelcast.future import ImmediateFuture from hazelcast.serialization.api import Portable -from hazelcast.sql import HazelcastSqlError, SqlStatement, SqlExpectedResultType, SqlColumnType from tests.base import SingleMemberTestCase from tests.hzrc.ttypes import Lang from mock import patch +from tests.util import is_client_version_older_than, mark_server_version_at_least + +try: + from hazelcast.sql import HazelcastSqlError, SqlStatement, SqlExpectedResultType, SqlColumnType +except ImportError: + # For backward compatibility. If we cannot import those, we won't + # be even referencing them in tests. + pass + SERVER_CONFIG = """ Date: Fri, 11 Jun 2021 13:20:55 +0300 Subject: [PATCH 8/9] document raised exceptions and add a few unit tests for invalid inputs --- hazelcast/sql.py | 83 +++++++++++++++++++++++++----------------- tests/unit/sql_test.py | 38 +++++++++++++++++++ 2 files changed, 87 insertions(+), 34 deletions(-) diff --git a/hazelcast/sql.py b/hazelcast/sql.py index 8f97782344..36703f3caa 100644 --- a/hazelcast/sql.py +++ b/hazelcast/sql.py @@ -7,7 +7,6 @@ from hazelcast.invocation import Invocation from hazelcast.util import ( UUIDUtil, - check_not_none, to_millis, check_true, get_attr_name, @@ -159,6 +158,7 @@ def execute(self, sql, *params): Raises: HazelcastSqlError: In case of execution error. + AssertionError: If the SQL parameter is not a string. """ return self._service.execute(sql, *params) @@ -482,8 +482,12 @@ def get_column(self, index): Returns: SqlColumnMetadata: Metadata for the given column index. 
+
+        Raises:
+            IndexError: If the index is out of bounds.
+            AssertionError: If the index is not an integer.
         """
-        check_true(0 <= index < len(self._columns), "Column index is out of bounds: %s" % index)
+        check_is_int(index, "Index must be an integer")
         return self._columns[index]
 
     def find_column(self, column_name):
@@ -494,8 +498,11 @@ def find_column(self, column_name):
         Returns:
             int: Column index or :const:`COLUMN_NOT_FOUND` if a column
             with the given name is not found.
+
+        Raises:
+            AssertionError: If the column name is not a string.
         """
-        check_not_none(column_name, "Column name cannot be None")
+        check_true(isinstance(column_name, six.string_types), "Column name must be a string")
         return self._name_to_index.get(column_name, SqlRowMetadata.COLUMN_NOT_FOUND)
 
     def __repr__(self):
@@ -530,6 +537,10 @@ def get_object(self, column_name):
         Returns:
             Value of the column.
 
+        Raises:
+            ValueError: If a column with the given name does not exist.
+            AssertionError: If the column name is not a string.
+
         See Also:
             :attr:`metadata`
 
@@ -556,15 +567,16 @@ def get_object_with_index(self, column_index):
         Returns:
             Value of the column.
 
+        Raises:
+            IndexError: If the column index is out of bounds.
+            AssertionError: If the column index is not an integer.
+
         See Also:
             :attr:`metadata`
 
             :attr:`SqlColumnMetadata.type`
         """
-        check_true(
-            0 <= column_index < self._row_metadata.column_count,
-            "Column index is out of bounds: %s" % column_index,
-        )
+        check_is_int(column_index, "Column index must be an integer")
         return self._row[column_index]
 
     @property
@@ -1290,36 +1302,39 @@ def execute_statement(self, statement):
                 None,
             )
 
-        # Create a new, unique query id.
-        query_id = _SqlQueryId.from_uuid(connection.remote_uuid)
-
-        # Serialize the passed parameters.
-        serialized_params = [
-            self._serialization_service.to_data(param) for param in statement.parameters
-        ]
-
-        request = sql_execute_codec.encode_request(
-            statement.sql,
-            serialized_params,
-            # to_millis expects None to produce -1
-            to_millis(None if statement.timeout == -1 else statement.timeout),
-            statement.cursor_buffer_size,
-            statement.schema,
-            statement.expected_result_type,
-            query_id,
-        )
+        try:
+            # Create a new, unique query id.
+            query_id = _SqlQueryId.from_uuid(connection.remote_uuid)
+
+            # Serialize the passed parameters.
+ serialized_params = [ + self._serialization_service.to_data(param) for param in statement.parameters + ] + + request = sql_execute_codec.encode_request( + statement.sql, + serialized_params, + # to_millis expects None to produce -1 + to_millis(None if statement.timeout == -1 else statement.timeout), + statement.cursor_buffer_size, + statement.schema, + statement.expected_result_type, + query_id, + ) - invocation = Invocation( - request, connection=connection, response_handler=sql_execute_codec.decode_response - ) + invocation = Invocation( + request, connection=connection, response_handler=sql_execute_codec.decode_response + ) - result = SqlResult( - self, connection, query_id, statement.cursor_buffer_size, invocation.future - ) + result = SqlResult( + self, connection, query_id, statement.cursor_buffer_size, invocation.future + ) - self._invocation_service.invoke(invocation) + self._invocation_service.invoke(invocation) - return result + return result + except Exception as e: + raise self.re_raise(e, connection) def deserialize_object(self, obj): return self._serialization_service.to_object(obj) @@ -1450,7 +1465,7 @@ def sql(self): @sql.setter def sql(self, sql): - check_true(isinstance(sql, six.string_types) and sql is not None, "SQL must be a string") + check_true(isinstance(sql, six.string_types), "SQL must be a string") if not sql.strip(): raise ValueError("SQL cannot be empty") diff --git a/tests/unit/sql_test.py b/tests/unit/sql_test.py index 52d8d18e95..b259e67198 100644 --- a/tests/unit/sql_test.py +++ b/tests/unit/sql_test.py @@ -302,3 +302,41 @@ def test_statement_expected_result_type(self): with self.assertRaises(TypeError): statement = SqlStatement("something") statement.expected_result_type = invalid + + def test_row_metadata_get_column(self): + row_metadata = self._create_row_metadata() + + valid_inputs = [0, 1, 2] + + for valid in valid_inputs: + column_metadata = row_metadata.get_column(valid) + self.assertEqual(str(valid), column_metadata.name) + + invalid_inputs = [4, 5, "6", None] + for invalid in invalid_inputs: + with self.assertRaises((IndexError, AssertionError)): + row_metadata.get_column(invalid) + + def test_row_metadata_find_column(self): + row_metadata = self._create_row_metadata() + + valid_inputs = ["0", "1", "2", "-1"] + + for valid in valid_inputs: + index = row_metadata.find_column(valid) + self.assertEqual(int(valid), index) + + invalid_inputs = [6, None] + for invalid in invalid_inputs: + with self.assertRaises((IndexError, AssertionError)): + row_metadata.get_column(invalid) + + @staticmethod + def _create_row_metadata(): + return SqlRowMetadata( + [ + SqlColumnMetadata("0", SqlColumnType.VARCHAR, True, True), + SqlColumnMetadata("1", SqlColumnType.TINYINT, True, True), + SqlColumnMetadata("2", SqlColumnType.OBJECT, True, True), + ] + ) From afbc382bf76413f62ef4eaf0b010351f2ee7c338 Mon Sep 17 00:00:00 2001 From: mdumandag Date: Mon, 14 Jun 2021 11:09:51 +0300 Subject: [PATCH 9/9] add more exception documentation --- hazelcast/sql.py | 81 +++++++++++++++++++++++++++++++++++------- tests/unit/sql_test.py | 12 +++++-- 2 files changed, 78 insertions(+), 15 deletions(-) diff --git a/hazelcast/sql.py b/hazelcast/sql.py index 36703f3caa..f17c8f0164 100644 --- a/hazelcast/sql.py +++ b/hazelcast/sql.py @@ -60,8 +60,8 @@ class SqlService(object): method are used to populate the column list. The whole key and value objects could be accessed through special fields - ``__key`` and ``this``, respectively. 
If key (or value) object has fields, - then the whole key (or value) field is exposed as a normal field. Otherwise the + ``__key`` and ``this``, respectively. If key (value) object has fields, + then the whole key (value) field is exposed as a normal field. Otherwise the field is hidden. Hidden fields can be accessed directly, but are not returned by ``SELECT * FROM ...`` queries. @@ -159,6 +159,7 @@ def execute(self, sql, *params): Raises: HazelcastSqlError: In case of execution error. AssertionError: If the SQL parameter is not a string. + ValueError: If the SQL parameter is an empty string. """ return self._service.execute(sql, *params) @@ -908,6 +909,12 @@ def iterator(self): The iterator may be requested only once. + The returned Future results with: + + - :class:`HazelcastSqlError`: In case of an SQL execution error. + - **ValueError**: If the result only contains an update count, or the + iterator is already requested. + Returns: Future[Iterator[Future[SqlRow]]]: Iterator that produces Future of :class:`SqlRow` s. See the class documentation for the correct @@ -916,9 +923,14 @@ def iterator(self): return self._get_iterator(False) def is_row_set(self): - """ + """Returns whether this result has rows to iterate. + + The returned Future results with: + + - :class:`HazelcastSqlError`: In case of an SQL execution error. + Returns: - Future[bool]: Whether this result has rows to iterate. + Future[bool]: """ def continuation(future): @@ -934,6 +946,10 @@ def update_count(self): result is a row set. In case the result doesn't contain rows but the update count isn't applicable or known, ``0`` is returned. + The returned Future results with: + + - :class:`HazelcastSqlError`: In case of an SQL execution error. + Returns: Future[int]: """ @@ -949,6 +965,11 @@ def continuation(future): def get_row_metadata(self): """Gets the row metadata. + The returned Future results with: + + - :class:`HazelcastSqlError`: In case of an SQL execution error. + - **ValueError**: If the result only contains an update count. + Returns: Future[SqlRowMetadata]: """ @@ -972,6 +993,11 @@ def close(self): if the query is still active. Otherwise it is no-op. For a result with an update count it is always no-op. + The returned Future results with: + + - :class:`HazelcastSqlError`: In case there is an error closing the + result. + Returns: Future[None]: """ @@ -1460,7 +1486,13 @@ def __init__(self, sql): @property def sql(self): - """str: The SQL string to be executed.""" + """str: The SQL string to be executed. + + The setter raises: + + - **AssertionError**: If the SQL parameter is not a string. + - **ValueError**: If the SQL parameter is an empty string. + """ return self._sql @sql.setter @@ -1485,6 +1517,10 @@ def schema(self): The default value is ``None`` meaning only the default search path is used. + + The setter raises: + + - **AssertionError**: If the schema is not a string or ``None``. """ return self._schema @@ -1505,16 +1541,17 @@ def parameters(self): When the setter is called, the content of the parameters list is copied. Subsequent changes to the original list don't change the statement parameters. + + The setter raises: + + - **AssertionError**: If the parameter is not a list. 
""" return self._parameters @parameters.setter def parameters(self, parameters): check_true(isinstance(parameters, list), "Parameters must be a list") - if not parameters: - self._parameters = [] - else: - self._parameters = list(parameters) + self._parameters = list(parameters) @property def timeout(self): @@ -1528,6 +1565,12 @@ def timeout(self): values are prohibited. Defaults to :const:`TIMEOUT_NOT_SET`. + + The setter raises: + + - **AssertionError**: If the timeout is not an integer or float. + - **ValueError**: If the timeout is negative and not equal to + :const:`TIMEOUT_NOT_SET`. """ return self._timeout @@ -1557,6 +1600,11 @@ def cursor_buffer_size(self): large result sets at the cost of increased memory consumption. Defaults to :const:`DEFAULT_CURSOR_BUFFER_SIZE`. + + The setter raises: + + - **AssertionError**: If the cursor buffer size is not an integer. + - **ValueError**: If the cursor buffer size is not positive. """ return self._cursor_buffer_size @@ -1569,14 +1617,21 @@ def cursor_buffer_size(self, cursor_buffer_size): @property def expected_result_type(self): - """SqlExpectedResultType: The expected result type.""" + """SqlExpectedResultType: The expected result type. + + The setter raises: + + - **TypeError**: If the expected result type does not equal to one of + the values or names of the members of the + :class:`SqlExpectedResultType`. + """ return self._expected_result_type @expected_result_type.setter def expected_result_type(self, expected_result_type): - # Ignore the result, we call this method just to type check. - try_to_get_enum_value(expected_result_type, SqlExpectedResultType) - self._expected_result_type = expected_result_type + self._expected_result_type = try_to_get_enum_value( + expected_result_type, SqlExpectedResultType + ) def add_parameter(self, parameter): """Adds a single parameter to the end of the parameters list. diff --git a/tests/unit/sql_test.py b/tests/unit/sql_test.py index b259e67198..acd56018f1 100644 --- a/tests/unit/sql_test.py +++ b/tests/unit/sql_test.py @@ -20,6 +20,7 @@ SqlStatement, SqlExpectedResultType, ) +from hazelcast.util import try_to_get_enum_value EXPECTED_ROWS = ["result", "result2"] EXPECTED_UPDATE_COUNT = 42 @@ -289,12 +290,19 @@ def test_statement_cursor_buffer_size(self): statement.cursor_buffer_size = invalid def test_statement_expected_result_type(self): - valid_inputs = [SqlExpectedResultType.ROWS, SqlExpectedResultType.UPDATE_COUNT] + valid_inputs = [ + SqlExpectedResultType.ROWS, + SqlExpectedResultType.UPDATE_COUNT, + "ROWS", + "ANY", + ] for valid in valid_inputs: statement = SqlStatement("something") statement.expected_result_type = valid - self.assertEqual(valid, statement.expected_result_type) + self.assertEqual( + try_to_get_enum_value(valid, SqlExpectedResultType), statement.expected_result_type + ) invalid_inputs = [None, 123, "hey"]