diff --git a/docs/api/modules.rst b/docs/api/modules.rst index f3ccd3dca8..88dd4f964c 100644 --- a/docs/api/modules.rst +++ b/docs/api/modules.rst @@ -14,6 +14,7 @@ API Documentation predicate proxy/modules serialization + sql transaction util diff --git a/docs/api/sql.rst b/docs/api/sql.rst new file mode 100644 index 0000000000..745bb7e25b --- /dev/null +++ b/docs/api/sql.rst @@ -0,0 +1,4 @@ +SQL +=========== + +.. automodule:: hazelcast.sql diff --git a/hazelcast/client.py b/hazelcast/client.py index 32c47eca8e..2680ca7aa9 100644 --- a/hazelcast/client.py +++ b/hazelcast/client.py @@ -36,6 +36,7 @@ ) from hazelcast.reactor import AsyncoreReactor from hazelcast.serialization import SerializationServiceV1 +from hazelcast.sql import _InternalSqlService, SqlService from hazelcast.statistics import Statistics from hazelcast.transaction import TWO_PHASE, TransactionManager from hazelcast.util import AtomicInteger, RoundRobinLB @@ -388,6 +389,10 @@ def __init__(self, **kwargs): self._invocation_service.init( self._internal_partition_service, self._connection_manager, self._listener_service ) + self._internal_sql_service = _InternalSqlService( + self._connection_manager, self._serialization_service, self._invocation_service + ) + self.sql = SqlService(self._internal_sql_service) self._init_context() self._start() diff --git a/hazelcast/connection.py b/hazelcast/connection.py index 02150e6eb1..8119c7379b 100644 --- a/hazelcast/connection.py +++ b/hazelcast/connection.py @@ -146,17 +146,22 @@ def add_listener(self, on_connection_opened=None, on_connection_closed=None): def get_connection(self, member_uuid): return self.active_connections.get(member_uuid, None) - def get_random_connection(self): + def get_random_connection(self, should_get_data_member=False): if self._smart_routing_enabled: - member = self._load_balancer.next() - if member: - connection = self.get_connection(member.uuid) - if connection: - return connection - - # We should not get to this point under normal circumstances. - # Therefore, copying the list should be OK. - for connection in list(six.itervalues(self.active_connections)): + connection = self._get_connection_from_load_balancer(should_get_data_member) + if connection: + return connection + + # We should not get to this point under normal circumstances + # for the smart client. For uni-socket client, there would be + # a single connection in the dict. Therefore, copying the list + # should be acceptable. 
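+        # Fall back to iterating over all active connections. When a data
+        # member is required (e.g. for SQL queries, which are routed to
+        # data members), connections to lite members are skipped below.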
+ for member_uuid, connection in list(six.iteritems(self.active_connections)): + if should_get_data_member: + member = self._cluster_service.get_member(member_uuid) + if not member or member.lite_member: + continue + return connection return None @@ -256,6 +261,20 @@ def check_invocation_allowed(self): else: raise IOError("No connection found to cluster") + def _get_connection_from_load_balancer(self, should_get_data_member): + load_balancer = self._load_balancer + member = None + if should_get_data_member: + if load_balancer.can_get_next_data_member(): + member = load_balancer.next_data_member() + else: + member = load_balancer.next() + + if not member: + return None + + return self.get_connection(member.uuid) + def _get_or_connect_to_address(self, address): for connection in list(six.itervalues(self.active_connections)): if connection.remote_address == address: diff --git a/hazelcast/protocol/builtin.py b/hazelcast/protocol/builtin.py index e28c5fb7a3..e221c5995b 100644 --- a/hazelcast/protocol/builtin.py +++ b/hazelcast/protocol/builtin.py @@ -1,4 +1,6 @@ import uuid +from datetime import date, time, datetime, timedelta +from decimal import Decimal from hazelcast import six from hazelcast.six.moves import range @@ -11,7 +13,7 @@ NULL_FINAL_FRAME_BUF, END_FINAL_FRAME_BUF, ) -from hazelcast.serialization import ( +from hazelcast.serialization.bits import ( LONG_SIZE_IN_BYTES, UUID_SIZE_IN_BYTES, LE_INT, @@ -23,8 +25,21 @@ LE_INT8, UUID_MSB_SHIFT, UUID_LSB_MASK, + BYTE_SIZE_IN_BYTES, + SHORT_SIZE_IN_BYTES, + LE_INT16, + FLOAT_SIZE_IN_BYTES, + LE_FLOAT, + LE_DOUBLE, + DOUBLE_SIZE_IN_BYTES, ) from hazelcast.serialization.data import Data +from hazelcast.util import int_from_bytes, timezone + +_LOCAL_DATE_SIZE_IN_BYTES = SHORT_SIZE_IN_BYTES + BYTE_SIZE_IN_BYTES * 2 +_LOCAL_TIME_SIZE_IN_BYTES = BYTE_SIZE_IN_BYTES * 3 + INT_SIZE_IN_BYTES +_LOCAL_DATE_TIME_SIZE_IN_BYTES = _LOCAL_DATE_SIZE_IN_BYTES + _LOCAL_TIME_SIZE_IN_BYTES +_OFFSET_DATE_TIME_SIZE_IN_BYTES = _LOCAL_DATE_TIME_SIZE_IN_BYTES + INT_SIZE_IN_BYTES class CodecUtil(object): @@ -274,6 +289,49 @@ def decode_uuid(buf, offset): ) return uuid.UUID(bytes=bytes(b)) + @staticmethod + def decode_short(buf, offset): + return LE_INT16.unpack_from(buf, offset)[0] + + @staticmethod + def decode_float(buf, offset): + return LE_FLOAT.unpack_from(buf, offset)[0] + + @staticmethod + def decode_double(buf, offset): + return LE_DOUBLE.unpack_from(buf, offset)[0] + + @staticmethod + def decode_local_date(buf, offset): + year = FixSizedTypesCodec.decode_short(buf, offset) + month = FixSizedTypesCodec.decode_byte(buf, offset + SHORT_SIZE_IN_BYTES) + day = FixSizedTypesCodec.decode_byte(buf, offset + SHORT_SIZE_IN_BYTES + BYTE_SIZE_IN_BYTES) + + return date(year, month, day) + + @staticmethod + def decode_local_time(buf, offset): + hour = FixSizedTypesCodec.decode_byte(buf, offset) + minute = FixSizedTypesCodec.decode_byte(buf, offset + BYTE_SIZE_IN_BYTES) + second = FixSizedTypesCodec.decode_byte(buf, offset + BYTE_SIZE_IN_BYTES * 2) + nano = FixSizedTypesCodec.decode_int(buf, offset + BYTE_SIZE_IN_BYTES * 3) + + return time(hour, minute, second, int(nano / 1000.0)) + + @staticmethod + def decode_local_date_time(buf, offset): + date_value = FixSizedTypesCodec.decode_local_date(buf, offset) + time_value = FixSizedTypesCodec.decode_local_time(buf, offset + _LOCAL_DATE_SIZE_IN_BYTES) + + return datetime.combine(date_value, time_value) + + @staticmethod + def decode_offset_date_time(buf, offset): + datetime_value = FixSizedTypesCodec.decode_local_date_time(buf, 
offset) + offset_seconds = FixSizedTypesCodec.decode_int(buf, offset + _LOCAL_DATE_TIME_SIZE_IN_BYTES) + + return datetime_value.replace(tzinfo=timezone(timedelta(seconds=offset_seconds))) + class ListIntegerCodec(object): @staticmethod @@ -496,3 +554,220 @@ def encode(buf, value, is_final=False): @staticmethod def decode(msg): return msg.next_frame().buf.decode("utf-8") + + +class ListCNFixedSizeCodec(object): + _TYPE_NULL_ONLY = 1 + _TYPE_NOT_NULL_ONLY = 2 + _TYPE_MIXED = 3 + + _ITEMS_PER_BITMASK = 8 + + _HEADER_SIZE = BYTE_SIZE_IN_BYTES + INT_SIZE_IN_BYTES + + @staticmethod + def decode(msg, item_size, decoder): + frame = msg.next_frame() + type = FixSizedTypesCodec.decode_byte(frame.buf, 0) + count = FixSizedTypesCodec.decode_int(frame.buf, 1) + + if type == ListCNFixedSizeCodec._TYPE_NULL_ONLY: + return [None] * count + elif type == ListCNFixedSizeCodec._TYPE_NOT_NULL_ONLY: + header_size = ListCNFixedSizeCodec._HEADER_SIZE + return [decoder(frame.buf, header_size + i * item_size) for i in range(count)] + else: + response = [None] * count + position = ListCNFixedSizeCodec._HEADER_SIZE + read_count = 0 + items_per_bitmask = ListCNFixedSizeCodec._ITEMS_PER_BITMASK + + while read_count < count: + bitmask = FixSizedTypesCodec.decode_byte(frame.buf, position) + position += 1 + + batch_size = min(items_per_bitmask, count - read_count) + for i in range(batch_size): + mask = 1 << i + if (bitmask & mask) == mask: + response[read_count] = decoder(frame.buf, position) + position += item_size + + read_count += 1 + + return response + + +class ListCNBooleanCodec(object): + @staticmethod + def decode(msg): + return ListCNFixedSizeCodec.decode( + msg, BOOLEAN_SIZE_IN_BYTES, FixSizedTypesCodec.decode_boolean + ) + + +class ListCNByteCodec(object): + @staticmethod + def decode(msg): + return ListCNFixedSizeCodec.decode(msg, BYTE_SIZE_IN_BYTES, FixSizedTypesCodec.decode_byte) + + +class ListCNShortCodec(object): + @staticmethod + def decode(msg): + return ListCNFixedSizeCodec.decode( + msg, SHORT_SIZE_IN_BYTES, FixSizedTypesCodec.decode_short + ) + + +class ListCNIntegerCodec(object): + @staticmethod + def decode(msg): + return ListCNFixedSizeCodec.decode(msg, INT_SIZE_IN_BYTES, FixSizedTypesCodec.decode_int) + + +class ListCNLongCodec(object): + @staticmethod + def decode(msg): + return ListCNFixedSizeCodec.decode(msg, LONG_SIZE_IN_BYTES, FixSizedTypesCodec.decode_long) + + +class ListCNFloatCodec(object): + @staticmethod + def decode(msg): + return ListCNFixedSizeCodec.decode( + msg, FLOAT_SIZE_IN_BYTES, FixSizedTypesCodec.decode_float + ) + + +class ListCNDoubleCodec(object): + @staticmethod + def decode(msg): + return ListCNFixedSizeCodec.decode( + msg, DOUBLE_SIZE_IN_BYTES, FixSizedTypesCodec.decode_double + ) + + +class ListCNLocalDateCodec(object): + @staticmethod + def decode(msg): + return ListCNFixedSizeCodec.decode( + msg, _LOCAL_DATE_SIZE_IN_BYTES, ListCNLocalDateCodec._decode_item + ) + + @staticmethod + def _decode_item(buf, offset): + return FixSizedTypesCodec.decode_local_date(buf, offset).isoformat() + + +class ListCNLocalTimeCodec(object): + @staticmethod + def decode(msg): + return ListCNFixedSizeCodec.decode( + msg, _LOCAL_TIME_SIZE_IN_BYTES, ListCNLocalTimeCodec._decode_item + ) + + @staticmethod + def _decode_item(buf, offset): + return FixSizedTypesCodec.decode_local_time(buf, offset).isoformat() + + +class ListCNLocalDateTimeCodec(object): + @staticmethod + def decode(msg): + return ListCNFixedSizeCodec.decode( + msg, _LOCAL_DATE_TIME_SIZE_IN_BYTES, 
ListCNLocalDateTimeCodec._decode_item + ) + + @staticmethod + def _decode_item(buf, offset): + return FixSizedTypesCodec.decode_local_date_time(buf, offset).isoformat() + + +class ListCNOffsetDateTimeCodec(object): + @staticmethod + def decode(msg): + return ListCNFixedSizeCodec.decode( + msg, _OFFSET_DATE_TIME_SIZE_IN_BYTES, ListCNOffsetDateTimeCodec._decode_item + ) + + @staticmethod + def _decode_item(buf, offset): + return FixSizedTypesCodec.decode_offset_date_time(buf, offset).isoformat() + + +class BigDecimalCodec(object): + @staticmethod + def decode(msg): + buf = msg.next_frame().buf + size = FixSizedTypesCodec.decode_int(buf, 0) + unscaled_value = int_from_bytes(buf[INT_SIZE_IN_BYTES : INT_SIZE_IN_BYTES + size]) + scale = FixSizedTypesCodec.decode_int(buf, INT_SIZE_IN_BYTES + size) + sign = 0 if unscaled_value >= 0 else 1 + return str( + Decimal((sign, tuple(int(digit) for digit in str(abs(unscaled_value))), -1 * scale)) + ) + + +class SqlPageCodec(object): + @staticmethod + def decode(msg): + from hazelcast.sql import SqlColumnType, _SqlPage + + # begin frame + msg.next_frame() + + # read the "last" flag + is_last = LE_INT8.unpack_from(msg.next_frame().buf, 0)[0] == 1 + + # read column types + column_type_ids = ListIntegerCodec.decode(msg) + column_count = len(column_type_ids) + + # read columns + columns = [None] * column_count + + for i in range(column_count): + column_type_id = column_type_ids[i] + + if column_type_id == SqlColumnType.VARCHAR: + columns[i] = ListMultiFrameCodec.decode_contains_nullable(msg, StringCodec.decode) + elif column_type_id == SqlColumnType.BOOLEAN: + columns[i] = ListCNBooleanCodec.decode(msg) + elif column_type_id == SqlColumnType.TINYINT: + columns[i] = ListCNByteCodec.decode(msg) + elif column_type_id == SqlColumnType.SMALLINT: + columns[i] = ListCNShortCodec.decode(msg) + elif column_type_id == SqlColumnType.INTEGER: + columns[i] = ListCNIntegerCodec.decode(msg) + elif column_type_id == SqlColumnType.BIGINT: + columns[i] = ListCNLongCodec.decode(msg) + elif column_type_id == SqlColumnType.REAL: + columns[i] = ListCNFloatCodec.decode(msg) + elif column_type_id == SqlColumnType.DOUBLE: + columns[i] = ListCNDoubleCodec.decode(msg) + elif column_type_id == SqlColumnType.DATE: + columns[i] = ListCNLocalDateCodec.decode(msg) + elif column_type_id == SqlColumnType.TIME: + columns[i] = ListCNLocalTimeCodec.decode(msg) + elif column_type_id == SqlColumnType.TIMESTAMP: + columns[i] = ListCNLocalDateTimeCodec.decode(msg) + elif column_type_id == SqlColumnType.TIMESTAMP_WITH_TIME_ZONE: + columns[i] = ListCNOffsetDateTimeCodec.decode(msg) + elif column_type_id == SqlColumnType.DECIMAL: + columns[i] = ListMultiFrameCodec.decode_contains_nullable( + msg, BigDecimalCodec.decode + ) + elif column_type_id == SqlColumnType.NULL: + frame = msg.next_frame() + size = FixSizedTypesCodec.decode_int(frame.buf, 0) + column = [None for _ in range(size)] + columns[i] = column + elif column_type_id == SqlColumnType.OBJECT: + columns[i] = ListMultiFrameCodec.decode_contains_nullable(msg, DataCodec.decode) + else: + raise ValueError("Unknown type %s" % column_type_id) + + CodecUtil.fast_forward_to_end_frame(msg) + + return _SqlPage(column_type_ids, columns, is_last) diff --git a/hazelcast/protocol/codec/custom/sql_column_metadata_codec.py b/hazelcast/protocol/codec/custom/sql_column_metadata_codec.py new file mode 100644 index 0000000000..4c72e71582 --- /dev/null +++ b/hazelcast/protocol/codec/custom/sql_column_metadata_codec.py @@ -0,0 +1,39 @@ +from hazelcast.protocol.builtin 
import FixSizedTypesCodec, CodecUtil
+from hazelcast.serialization.bits import *
+from hazelcast.protocol.client_message import END_FRAME_BUF, END_FINAL_FRAME_BUF, SIZE_OF_FRAME_LENGTH_AND_FLAGS, create_initial_buffer_custom
+from hazelcast.sql import SqlColumnMetadata
+from hazelcast.protocol.builtin import StringCodec
+
+_TYPE_ENCODE_OFFSET = 2 * SIZE_OF_FRAME_LENGTH_AND_FLAGS
+_TYPE_DECODE_OFFSET = 0
+_NULLABLE_ENCODE_OFFSET = _TYPE_ENCODE_OFFSET + INT_SIZE_IN_BYTES
+_NULLABLE_DECODE_OFFSET = _TYPE_DECODE_OFFSET + INT_SIZE_IN_BYTES
+_INITIAL_FRAME_SIZE = _NULLABLE_ENCODE_OFFSET + BOOLEAN_SIZE_IN_BYTES - 2 * SIZE_OF_FRAME_LENGTH_AND_FLAGS
+
+
+class SqlColumnMetadataCodec(object):
+    @staticmethod
+    def encode(buf, sql_column_metadata, is_final=False):
+        initial_frame_buf = create_initial_buffer_custom(_INITIAL_FRAME_SIZE)
+        FixSizedTypesCodec.encode_int(initial_frame_buf, _TYPE_ENCODE_OFFSET, sql_column_metadata.type)
+        FixSizedTypesCodec.encode_boolean(initial_frame_buf, _NULLABLE_ENCODE_OFFSET, sql_column_metadata.nullable)
+        buf.extend(initial_frame_buf)
+        StringCodec.encode(buf, sql_column_metadata.name)
+        if is_final:
+            buf.extend(END_FINAL_FRAME_BUF)
+        else:
+            buf.extend(END_FRAME_BUF)
+
+    @staticmethod
+    def decode(msg):
+        msg.next_frame()
+        initial_frame = msg.next_frame()
+        type = FixSizedTypesCodec.decode_int(initial_frame.buf, _TYPE_DECODE_OFFSET)
+        is_nullable_exists = False
+        nullable = False
+        if len(initial_frame.buf) >= _NULLABLE_DECODE_OFFSET + BOOLEAN_SIZE_IN_BYTES:
+            nullable = FixSizedTypesCodec.decode_boolean(initial_frame.buf, _NULLABLE_DECODE_OFFSET)
+            is_nullable_exists = True
+        name = StringCodec.decode(msg)
+        CodecUtil.fast_forward_to_end_frame(msg)
+        return SqlColumnMetadata(name, type, nullable, is_nullable_exists)
diff --git a/hazelcast/protocol/codec/custom/sql_error_codec.py b/hazelcast/protocol/codec/custom/sql_error_codec.py
new file mode 100644
index 0000000000..37bea690ef
--- /dev/null
+++ b/hazelcast/protocol/codec/custom/sql_error_codec.py
@@ -0,0 +1,35 @@
+from hazelcast.protocol.builtin import FixSizedTypesCodec, CodecUtil
+from hazelcast.serialization.bits import *
+from hazelcast.protocol.client_message import END_FRAME_BUF, END_FINAL_FRAME_BUF, SIZE_OF_FRAME_LENGTH_AND_FLAGS, create_initial_buffer_custom
+from hazelcast.sql import _SqlError
+from hazelcast.protocol.builtin import StringCodec
+
+_CODE_ENCODE_OFFSET = 2 * SIZE_OF_FRAME_LENGTH_AND_FLAGS
+_CODE_DECODE_OFFSET = 0
+_ORIGINATING_MEMBER_ID_ENCODE_OFFSET = _CODE_ENCODE_OFFSET + INT_SIZE_IN_BYTES
+_ORIGINATING_MEMBER_ID_DECODE_OFFSET = _CODE_DECODE_OFFSET + INT_SIZE_IN_BYTES
+_INITIAL_FRAME_SIZE = _ORIGINATING_MEMBER_ID_ENCODE_OFFSET + UUID_SIZE_IN_BYTES - 2 * SIZE_OF_FRAME_LENGTH_AND_FLAGS
+
+
+class SqlErrorCodec(object):
+    @staticmethod
+    def encode(buf, sql_error, is_final=False):
+        initial_frame_buf = create_initial_buffer_custom(_INITIAL_FRAME_SIZE)
+        FixSizedTypesCodec.encode_int(initial_frame_buf, _CODE_ENCODE_OFFSET, sql_error.code)
+        FixSizedTypesCodec.encode_uuid(initial_frame_buf, _ORIGINATING_MEMBER_ID_ENCODE_OFFSET, sql_error.originating_member_uuid)
+        buf.extend(initial_frame_buf)
+        CodecUtil.encode_nullable(buf, sql_error.message, StringCodec.encode)
+        if is_final:
+            buf.extend(END_FINAL_FRAME_BUF)
+        else:
+            buf.extend(END_FRAME_BUF)
+
+    @staticmethod
+    def decode(msg):
+        msg.next_frame()
+        initial_frame = msg.next_frame()
+        code = FixSizedTypesCodec.decode_int(initial_frame.buf, _CODE_DECODE_OFFSET)
+        originating_member_id = FixSizedTypesCodec.decode_uuid(initial_frame.buf, 
_ORIGINATING_MEMBER_ID_DECODE_OFFSET) + message = CodecUtil.decode_nullable(msg, StringCodec.decode) + CodecUtil.fast_forward_to_end_frame(msg) + return _SqlError(code, message, originating_member_id) diff --git a/hazelcast/protocol/codec/custom/sql_query_id_codec.py b/hazelcast/protocol/codec/custom/sql_query_id_codec.py new file mode 100644 index 0000000000..4d967fc72b --- /dev/null +++ b/hazelcast/protocol/codec/custom/sql_query_id_codec.py @@ -0,0 +1,40 @@ +from hazelcast.protocol.builtin import FixSizedTypesCodec, CodecUtil +from hazelcast.serialization.bits import * +from hazelcast.protocol.client_message import END_FRAME_BUF, END_FINAL_FRAME_BUF, SIZE_OF_FRAME_LENGTH_AND_FLAGS, create_initial_buffer_custom +from hazelcast.sql import _SqlQueryId + +_MEMBER_ID_HIGH_ENCODE_OFFSET = 2 * SIZE_OF_FRAME_LENGTH_AND_FLAGS +_MEMBER_ID_HIGH_DECODE_OFFSET = 0 +_MEMBER_ID_LOW_ENCODE_OFFSET = _MEMBER_ID_HIGH_ENCODE_OFFSET + LONG_SIZE_IN_BYTES +_MEMBER_ID_LOW_DECODE_OFFSET = _MEMBER_ID_HIGH_DECODE_OFFSET + LONG_SIZE_IN_BYTES +_LOCAL_ID_HIGH_ENCODE_OFFSET = _MEMBER_ID_LOW_ENCODE_OFFSET + LONG_SIZE_IN_BYTES +_LOCAL_ID_HIGH_DECODE_OFFSET = _MEMBER_ID_LOW_DECODE_OFFSET + LONG_SIZE_IN_BYTES +_LOCAL_ID_LOW_ENCODE_OFFSET = _LOCAL_ID_HIGH_ENCODE_OFFSET + LONG_SIZE_IN_BYTES +_LOCAL_ID_LOW_DECODE_OFFSET = _LOCAL_ID_HIGH_DECODE_OFFSET + LONG_SIZE_IN_BYTES +_INITIAL_FRAME_SIZE = _LOCAL_ID_LOW_ENCODE_OFFSET + LONG_SIZE_IN_BYTES - 2 * SIZE_OF_FRAME_LENGTH_AND_FLAGS + + +class SqlQueryIdCodec(object): + @staticmethod + def encode(buf, sql_query_id, is_final=False): + initial_frame_buf = create_initial_buffer_custom(_INITIAL_FRAME_SIZE) + FixSizedTypesCodec.encode_long(initial_frame_buf, _MEMBER_ID_HIGH_ENCODE_OFFSET, sql_query_id.member_id_high) + FixSizedTypesCodec.encode_long(initial_frame_buf, _MEMBER_ID_LOW_ENCODE_OFFSET, sql_query_id.member_id_low) + FixSizedTypesCodec.encode_long(initial_frame_buf, _LOCAL_ID_HIGH_ENCODE_OFFSET, sql_query_id.local_id_high) + FixSizedTypesCodec.encode_long(initial_frame_buf, _LOCAL_ID_LOW_ENCODE_OFFSET, sql_query_id.local_id_low) + buf.extend(initial_frame_buf) + if is_final: + buf.extend(END_FINAL_FRAME_BUF) + else: + buf.extend(END_FRAME_BUF) + + @staticmethod + def decode(msg): + msg.next_frame() + initial_frame = msg.next_frame() + member_id_high = FixSizedTypesCodec.decode_long(initial_frame.buf, _MEMBER_ID_HIGH_DECODE_OFFSET) + member_id_low = FixSizedTypesCodec.decode_long(initial_frame.buf, _MEMBER_ID_LOW_DECODE_OFFSET) + local_id_high = FixSizedTypesCodec.decode_long(initial_frame.buf, _LOCAL_ID_HIGH_DECODE_OFFSET) + local_id_low = FixSizedTypesCodec.decode_long(initial_frame.buf, _LOCAL_ID_LOW_DECODE_OFFSET) + CodecUtil.fast_forward_to_end_frame(msg) + return _SqlQueryId(member_id_high, member_id_low, local_id_high, local_id_low) diff --git a/hazelcast/protocol/codec/sql_close_codec.py b/hazelcast/protocol/codec/sql_close_codec.py new file mode 100644 index 0000000000..b6647befc4 --- /dev/null +++ b/hazelcast/protocol/codec/sql_close_codec.py @@ -0,0 +1,15 @@ +from hazelcast.protocol.client_message import OutboundMessage, REQUEST_HEADER_SIZE, create_initial_buffer +from hazelcast.protocol.codec.custom.sql_query_id_codec import SqlQueryIdCodec + +# hex: 0x210300 +_REQUEST_MESSAGE_TYPE = 2163456 +# hex: 0x210301 +_RESPONSE_MESSAGE_TYPE = 2163457 + +_REQUEST_INITIAL_FRAME_SIZE = REQUEST_HEADER_SIZE + + +def encode_request(query_id): + buf = create_initial_buffer(_REQUEST_INITIAL_FRAME_SIZE, _REQUEST_MESSAGE_TYPE) + SqlQueryIdCodec.encode(buf, query_id, True) + return 
OutboundMessage(buf, False) diff --git a/hazelcast/protocol/codec/sql_execute_codec.py b/hazelcast/protocol/codec/sql_execute_codec.py new file mode 100644 index 0000000000..8b93c3db73 --- /dev/null +++ b/hazelcast/protocol/codec/sql_execute_codec.py @@ -0,0 +1,44 @@ +from hazelcast.serialization.bits import * +from hazelcast.protocol.builtin import FixSizedTypesCodec +from hazelcast.protocol.client_message import OutboundMessage, REQUEST_HEADER_SIZE, create_initial_buffer, RESPONSE_HEADER_SIZE +from hazelcast.protocol.builtin import StringCodec +from hazelcast.protocol.builtin import ListMultiFrameCodec +from hazelcast.protocol.builtin import DataCodec +from hazelcast.protocol.builtin import CodecUtil +from hazelcast.protocol.codec.custom.sql_query_id_codec import SqlQueryIdCodec +from hazelcast.protocol.codec.custom.sql_column_metadata_codec import SqlColumnMetadataCodec +from hazelcast.protocol.builtin import SqlPageCodec +from hazelcast.protocol.codec.custom.sql_error_codec import SqlErrorCodec + +# hex: 0x210400 +_REQUEST_MESSAGE_TYPE = 2163712 +# hex: 0x210401 +_RESPONSE_MESSAGE_TYPE = 2163713 + +_REQUEST_TIMEOUT_MILLIS_OFFSET = REQUEST_HEADER_SIZE +_REQUEST_CURSOR_BUFFER_SIZE_OFFSET = _REQUEST_TIMEOUT_MILLIS_OFFSET + LONG_SIZE_IN_BYTES +_REQUEST_EXPECTED_RESULT_TYPE_OFFSET = _REQUEST_CURSOR_BUFFER_SIZE_OFFSET + INT_SIZE_IN_BYTES +_REQUEST_INITIAL_FRAME_SIZE = _REQUEST_EXPECTED_RESULT_TYPE_OFFSET + BYTE_SIZE_IN_BYTES +_RESPONSE_UPDATE_COUNT_OFFSET = RESPONSE_HEADER_SIZE + + +def encode_request(sql, parameters, timeout_millis, cursor_buffer_size, schema, expected_result_type, query_id): + buf = create_initial_buffer(_REQUEST_INITIAL_FRAME_SIZE, _REQUEST_MESSAGE_TYPE) + FixSizedTypesCodec.encode_long(buf, _REQUEST_TIMEOUT_MILLIS_OFFSET, timeout_millis) + FixSizedTypesCodec.encode_int(buf, _REQUEST_CURSOR_BUFFER_SIZE_OFFSET, cursor_buffer_size) + FixSizedTypesCodec.encode_byte(buf, _REQUEST_EXPECTED_RESULT_TYPE_OFFSET, expected_result_type) + StringCodec.encode(buf, sql) + ListMultiFrameCodec.encode_contains_nullable(buf, parameters, DataCodec.encode) + CodecUtil.encode_nullable(buf, schema, StringCodec.encode) + SqlQueryIdCodec.encode(buf, query_id, True) + return OutboundMessage(buf, False) + + +def decode_response(msg): + initial_frame = msg.next_frame() + response = dict() + response["update_count"] = FixSizedTypesCodec.decode_long(initial_frame.buf, _RESPONSE_UPDATE_COUNT_OFFSET) + response["row_metadata"] = ListMultiFrameCodec.decode_nullable(msg, SqlColumnMetadataCodec.decode) + response["row_page"] = CodecUtil.decode_nullable(msg, SqlPageCodec.decode) + response["error"] = CodecUtil.decode_nullable(msg, SqlErrorCodec.decode) + return response diff --git a/hazelcast/protocol/codec/sql_fetch_codec.py b/hazelcast/protocol/codec/sql_fetch_codec.py new file mode 100644 index 0000000000..139832bff9 --- /dev/null +++ b/hazelcast/protocol/codec/sql_fetch_codec.py @@ -0,0 +1,30 @@ +from hazelcast.serialization.bits import * +from hazelcast.protocol.builtin import FixSizedTypesCodec +from hazelcast.protocol.client_message import OutboundMessage, REQUEST_HEADER_SIZE, create_initial_buffer +from hazelcast.protocol.codec.custom.sql_query_id_codec import SqlQueryIdCodec +from hazelcast.protocol.builtin import SqlPageCodec +from hazelcast.protocol.builtin import CodecUtil +from hazelcast.protocol.codec.custom.sql_error_codec import SqlErrorCodec + +# hex: 0x210500 +_REQUEST_MESSAGE_TYPE = 2163968 +# hex: 0x210501 +_RESPONSE_MESSAGE_TYPE = 2163969 + +_REQUEST_CURSOR_BUFFER_SIZE_OFFSET = 
REQUEST_HEADER_SIZE
+_REQUEST_INITIAL_FRAME_SIZE = _REQUEST_CURSOR_BUFFER_SIZE_OFFSET + INT_SIZE_IN_BYTES
+
+
+def encode_request(query_id, cursor_buffer_size):
+    buf = create_initial_buffer(_REQUEST_INITIAL_FRAME_SIZE, _REQUEST_MESSAGE_TYPE)
+    FixSizedTypesCodec.encode_int(buf, _REQUEST_CURSOR_BUFFER_SIZE_OFFSET, cursor_buffer_size)
+    SqlQueryIdCodec.encode(buf, query_id, True)
+    return OutboundMessage(buf, False)
+
+
+def decode_response(msg):
+    msg.next_frame()
+    response = dict()
+    response["row_page"] = CodecUtil.decode_nullable(msg, SqlPageCodec.decode)
+    response["error"] = CodecUtil.decode_nullable(msg, SqlErrorCodec.decode)
+    return response
diff --git a/hazelcast/serialization/serializer.py b/hazelcast/serialization/serializer.py
index 5e50c5f33e..730504045d 100644
--- a/hazelcast/serialization/serializer.py
+++ b/hazelcast/serialization/serializer.py
@@ -1,6 +1,5 @@
 import binascii
 import time
-import uuid
 from datetime import datetime
 
 from hazelcast import six
@@ -10,8 +9,7 @@
 from hazelcast.serialization.base import HazelcastSerializationError
 from hazelcast.serialization.serialization_const import *
 from hazelcast.six.moves import range, cPickle
-
-from hazelcast.util import to_signed
+from hazelcast.util import UUIDUtil
 
 if not six.PY2:
     long = int
@@ -135,12 +133,10 @@ class UuidSerializer(BaseSerializer):
     def read(self, inp):
         msb = inp.read_long()
         lsb = inp.read_long()
-        return uuid.UUID(int=(((msb << UUID_MSB_SHIFT) & UUID_MSB_MASK) | (lsb & UUID_LSB_MASK)))
+        return UUIDUtil.from_bits(msb, lsb)
 
     def write(self, out, obj):
-        i = obj.int
-        msb = to_signed(i >> UUID_MSB_SHIFT, 64)
-        lsb = to_signed(i & UUID_LSB_MASK, 64)
+        msb, lsb = UUIDUtil.to_bits(obj)
         out.write_long(msb)
         out.write_long(lsb)
 
diff --git a/hazelcast/sql.py b/hazelcast/sql.py
new file mode 100644
index 0000000000..f17c8f0164
--- /dev/null
+++ b/hazelcast/sql.py
@@ -0,0 +1,1684 @@
+import logging
+import uuid
+from threading import RLock
+
+from hazelcast.errors import HazelcastError
+from hazelcast.future import Future, ImmediateFuture
+from hazelcast.invocation import Invocation
+from hazelcast.util import (
+    UUIDUtil,
+    to_millis,
+    check_true,
+    get_attr_name,
+    try_to_get_error_message,
+    check_is_number,
+    check_is_int,
+    try_to_get_enum_value,
+)
+from hazelcast import six
+
+_logger = logging.getLogger(__name__)
+
+
+class SqlService(object):
+    """A service to execute SQL statements.
+
+    The service allows you to query data stored in a
+    :class:`Map <hazelcast.proxy.map.Map>`.
+
+    Warnings:
+
+        The service is in beta state. Behavior and API might change
+        in future releases.
+
+    **Querying an IMap**
+
+    Every Map instance is exposed as a table with the same name in the
+    ``partitioned`` schema. The ``partitioned`` schema is included in the
+    default search path, therefore a Map can be referenced in an
+    SQL statement with or without the schema name.
+
+    **Column resolution**
+
+    Every table backed by a Map has a set of columns that are resolved
+    automatically. Column resolution uses Map entries located on the
+    member that initiates the query. The engine extracts columns from a
+    key and a value and then merges them into a single column set.
+    In case the key and the value have columns with the same name, the
+    key takes precedence.
+
+    Columns are extracted from objects as follows (which happens on the
+    server-side):
+
+    - For non-Portable objects, public getters and fields are used to
+      populate the column list. For getters, the first letter is converted
+      to lower case.
A getter takes precedence over a field in case of naming
+      conflict.
+    - For :class:`Portable <hazelcast.serialization.api.Portable>` objects,
+      field names used in the
+      :func:`write_portable() <hazelcast.serialization.api.Portable.write_portable>`
+      method are used to populate the column list.
+
+    The whole key and value objects can be accessed through the special fields
+    ``__key`` and ``this``, respectively. If the key (value) object has fields,
+    then the whole key (value) field is exposed as a normal field. Otherwise the
+    field is hidden. Hidden fields can be accessed directly, but are not returned
+    by ``SELECT * FROM ...`` queries.
+
+    Consider the following key/value model: ::
+
+        class PersonKey(Portable):
+            def __init__(self, person_id=None, department_id=None):
+                self.person_id = person_id
+                self.department_id = department_id
+
+            def write_portable(self, writer):
+                writer.write_long("person_id", self.person_id)
+                writer.write_long("department_id", self.department_id)
+
+            ...
+
+        class Person(Portable):
+            def __init__(self, name=None):
+                self.name = name
+
+            def write_portable(self, writer):
+                writer.write_string("name", self.name)
+
+            ...
+
+    This model will be resolved to the following table columns:
+
+    - person_id ``BIGINT``
+    - department_id ``BIGINT``
+    - name ``VARCHAR``
+    - __key ``OBJECT`` (hidden)
+    - this ``OBJECT`` (hidden)
+
+    **Consistency**
+
+    Results returned from a Map query are weakly consistent:
+
+    - If an entry was not updated during iteration, it is guaranteed to be
+      returned exactly once
+    - If an entry was modified during iteration, it might be returned zero,
+      one or several times
+
+    **Usage**
+
+    When a query is executed, an :class:`SqlResult` is returned. You may get
+    a row iterator from the result. The result must be closed at the end. The
+    iterator will close the result automatically when it is exhausted, given
+    that no error is raised during the iteration. The code snippet below
+    demonstrates a typical usage pattern: ::
+
+        client = hazelcast.HazelcastClient()
+
+        result = client.sql.execute("SELECT * FROM person")
+
+        for row in result:
+            print(row.get_object("person_id"))
+            print(row.get_object("name"))
+            ...
+
+    See the documentation of the :class:`SqlResult` for more information about
+    different iteration methods.
+
+    Notes:
+
+        When an SQL statement is submitted to a member, it is parsed and
+        optimized by the ``hazelcast-sql`` module. The ``hazelcast-sql`` module
+        must be on the classpath, otherwise an exception will be thrown. If
+        you're using the ``hazelcast-all`` or ``hazelcast-enterprise-all``
+        packages, the ``hazelcast-sql`` module is included in them by default.
+        If not, i.e., you are using ``hazelcast`` or ``hazelcast-enterprise``,
+        then you need to have ``hazelcast-sql`` on the classpath. If you are
+        using the Docker image, the SQL module is included by default.
+
+    """
+
+    def __init__(self, internal_sql_service):
+        self._service = internal_sql_service
+
+    def execute(self, sql, *params):
+        """Convenience method to execute a distributed query with the given
+        parameters.
+
+        Converts the passed SQL string and parameters into an :class:`SqlStatement`
+        object and invokes :func:`execute_statement`.
+
+        Args:
+            sql (str): SQL string.
+            *params: Query parameters that will be passed to
+                :func:`SqlStatement.add_parameter`.
+
+        Returns:
+            SqlResult: The execution result.
+
+        Raises:
+            HazelcastSqlError: In case of execution error.
+            AssertionError: If the SQL parameter is not a string.
+            ValueError: If the SQL parameter is an empty string.
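+
+        Example (a minimal sketch; the ``person`` mapping and its
+        columns are illustrative): ::
+
+            result = client.sql.execute(
+                "SELECT name FROM person WHERE person_id > ?", 10
+            )
+
+            for row in result:
+                print(row.get_object("name"))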
+        """
+        return self._service.execute(sql, *params)
+
+    def execute_statement(self, statement):
+        """Executes an SQL statement.
+
+        Args:
+            statement (SqlStatement): Statement to be executed.
+
+        Returns:
+            SqlResult: The execution result.
+
+        Raises:
+            HazelcastSqlError: In case of execution error.
+        """
+        return self._service.execute_statement(statement)
+
+
+class _SqlQueryId(object):
+    """Cluster-wide unique query ID."""
+
+    __slots__ = ("member_id_high", "member_id_low", "local_id_high", "local_id_low")
+
+    def __init__(self, member_id_high, member_id_low, local_id_high, local_id_low):
+        self.member_id_high = member_id_high
+        """int: Most significant bits of the UUID of the member
+        that the query will route to.
+        """
+
+        self.member_id_low = member_id_low
+        """int: Least significant bits of the UUID of the member
+        that the query will route to."""
+
+        self.local_id_high = local_id_high
+        """int: Most significant bits of the UUID of the local id."""
+
+        self.local_id_low = local_id_low
+        """int: Least significant bits of the UUID of the local id."""
+
+    @classmethod
+    def from_uuid(cls, member_uuid):
+        """Generates a local random UUID and creates a query id
+        out of it and the given member UUID.
+
+        Args:
+            member_uuid (uuid.UUID): UUID of the member.
+
+        Returns:
+            _SqlQueryId: Generated unique query id.
+        """
+        local_id = uuid.uuid4()
+
+        member_msb, member_lsb = UUIDUtil.to_bits(member_uuid)
+        local_msb, local_lsb = UUIDUtil.to_bits(local_id)
+
+        return cls(member_msb, member_lsb, local_msb, local_lsb)
+
+
+class SqlColumnMetadata(object):
+    """Metadata of a column in an SQL row."""
+
+    __slots__ = ("_name", "_type", "_nullable")
+
+    def __init__(self, name, column_type, nullable, is_nullable_exists):
+        self._name = name
+        self._type = column_type
+        self._nullable = nullable if is_nullable_exists else True
+
+    @property
+    def name(self):
+        """str: Name of the column."""
+        return self._name
+
+    @property
+    def type(self):
+        """SqlColumnType: Type of the column."""
+        return self._type
+
+    @property
+    def nullable(self):
+        """bool: ``True`` if this column's values can be ``None``, ``False``
+        otherwise.
+        """
+        return self._nullable
+
+    def __repr__(self):
+        return "%s %s" % (self.name, get_attr_name(SqlColumnType, self.type))
+
+
+class _SqlError(object):
+    """Server-side error that is propagated to the client."""
+
+    __slots__ = ("code", "message", "originating_member_uuid")
+
+    def __init__(self, code, message, originating_member_uuid):
+        self.code = code
+        """_SqlErrorCode: The error code."""
+
+        self.message = message
+        """str: The error message."""
+
+        self.originating_member_uuid = originating_member_uuid
+        """uuid.UUID: UUID of the member that caused or initiated an error condition."""
+
+
+class _SqlPage(object):
+    """A finite set of rows returned to the user."""
+
+    __slots__ = ("_column_types", "_columns", "_is_last")
+
+    def __init__(self, column_types, columns, last):
+        self._column_types = column_types
+        self._columns = columns
+        self._is_last = last
+
+    @property
+    def row_count(self):
+        """int: Number of rows in the page."""
+        # Each column should have equal number of rows.
+        # Just check the first one.
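+        # (A page of a row-returning result always carries at least one
+        # column, so indexing the first column is safe here.)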
+ return len(self._columns[0]) + + @property + def column_count(self): + """int: Number of columns.""" + return len(self._column_types) + + @property + def is_last(self): + """bool: Whether this is the last page or not.""" + return self._is_last + + def get_value(self, column_index, row_index): + """ + Args: + column_index (int): + row_index (int): + + Returns: + The value with the given indexes. + """ + return self._columns[column_index][row_index] + + +class SqlColumnType(object): + + VARCHAR = 0 + """ + Represented by ``str``. + """ + + BOOLEAN = 1 + """ + Represented by ``bool``. + """ + + TINYINT = 2 + """ + Represented by ``int``. + """ + + SMALLINT = 3 + """ + Represented by ``int``. + """ + + INTEGER = 4 + """ + Represented by ``int``. + """ + + BIGINT = 5 + """ + Represented by ``int`` (for Python 3) or ``long`` (for Python 2). + """ + + DECIMAL = 6 + """ + Represented by ``decimal.Decimal``. + """ + + REAL = 7 + """ + Represented by ``float``. + """ + + DOUBLE = 8 + """ + Represented by ``float``. + """ + + DATE = 9 + """ + Represented by ``datetime.date``. + """ + + TIME = 10 + """ + Represented by ``datetime.time``. + """ + + TIMESTAMP = 11 + """ + Represented by ``datetime.datetime``. + """ + + TIMESTAMP_WITH_TIME_ZONE = 12 + """ + Represented by ``datetime.datetime`` with ``datetime.tzinfo``. + """ + + OBJECT = 13 + """ + Could be represented by any Python class. + """ + + NULL = 14 + """ + The type of the generic SQL ``NULL`` literal. + + The only valid value of ``NULL`` type is ``None``. + """ + + +class _SqlErrorCode(object): + + GENERIC = -1 + """ + Generic error. + """ + + CONNECTION_PROBLEM = 1001 + """ + A network connection problem between members, or between a client and a member. + """ + + CANCELLED_BY_USER = 1003 + """ + Query was cancelled due to user request. + """ + + TIMEOUT = 1004 + """ + Query was cancelled due to timeout. + """ + + PARTITION_DISTRIBUTION = 1005 + """ + A problem with partition distribution. + """ + + MAP_DESTROYED = 1006 + """ + An error caused by a concurrent destroy of a map. + """ + + MAP_LOADING_IN_PROGRESS = 1007 + """ + Map loading is not finished yet. + """ + + PARSING = 1008 + """ + Generic parsing error. + """ + + INDEX_INVALID = 1009 + """ + An error caused by an attempt to query an index that is not valid. + """ + + DATA_EXCEPTION = 2000 + """ + An error with data conversion or transformation. + """ + + +class HazelcastSqlError(HazelcastError): + """Represents an error occurred during the SQL query execution.""" + + def __init__(self, originating_member_uuid, code, message, cause): + super(HazelcastSqlError, self).__init__(message, cause) + self._originating_member_uuid = originating_member_uuid + + # TODO: This is private API, might be good to make it public or + # remove this information altogether. 
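+        # SQL error code received from the server. See the _SqlErrorCode
+        # constants for the possible values.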
self._code = code
+
+    @property
+    def originating_member_uuid(self):
+        """uuid.UUID: UUID of the member that caused or initiated an error condition."""
+        return self._originating_member_uuid
+
+
+class SqlRowMetadata(object):
+    """Metadata for the returned rows."""
+
+    __slots__ = ("_columns", "_name_to_index")
+
+    COLUMN_NOT_FOUND = -1
+    """Constant indicating that the column is not found."""
+
+    def __init__(self, columns):
+        self._columns = columns
+        self._name_to_index = {column.name: index for index, column in enumerate(columns)}
+
+    @property
+    def columns(self):
+        """list[SqlColumnMetadata]: List of column metadata."""
+        return self._columns
+
+    @property
+    def column_count(self):
+        """int: Number of columns in the row."""
+        return len(self._columns)
+
+    def get_column(self, index):
+        """
+        Args:
+            index (int): Zero-based column index.
+
+        Returns:
+            SqlColumnMetadata: Metadata for the given column index.
+
+        Raises:
+            IndexError: If the index is out of bounds.
+            AssertionError: If the index is not an integer.
+        """
+        check_is_int(index, "Index must be an integer")
+        return self._columns[index]
+
+    def find_column(self, column_name):
+        """
+        Args:
+            column_name (str): Name of the column.
+
+        Returns:
+            int: Column index or :const:`COLUMN_NOT_FOUND` if a column
+            with the given name is not found.
+
+        Raises:
+            AssertionError: If the column name is not a string.
+        """
+        check_true(isinstance(column_name, six.string_types), "Column name must be a string")
+        return self._name_to_index.get(column_name, SqlRowMetadata.COLUMN_NOT_FOUND)
+
+    def __repr__(self):
+        return "[%s]" % ", ".join(
+            "%s %s" % (column.name, get_attr_name(SqlColumnType, column.type))
+            for column in self._columns
+        )
+
+
+class SqlRow(object):
+    """One of the rows of an SQL query result."""
+
+    __slots__ = ("_row_metadata", "_row")
+
+    def __init__(self, row_metadata, row):
+        self._row_metadata = row_metadata
+        self._row = row
+
+    def get_object(self, column_name):
+        """Gets the value in the column indicated by the column name.
+
+        Column name should be one of those defined in :class:`SqlRowMetadata`,
+        case-sensitive. You may also use :func:`SqlRowMetadata.find_column` to
+        test for column existence.
+
+        The type of the returned value depends on the SQL type of the column.
+        No implicit conversions are performed on the value.
+
+        Args:
+            column_name (str):
+
+        Returns:
+            Value of the column.
+
+        Raises:
+            ValueError: If a column with the given name does not exist.
+            AssertionError: If the column name is not a string.
+
+        See Also:
+            :attr:`metadata`
+
+            :func:`SqlRowMetadata.find_column`
+
+            :attr:`SqlColumnMetadata.type`
+
+            :attr:`SqlColumnMetadata.name`
+        """
+        index = self._row_metadata.find_column(column_name)
+        if index == SqlRowMetadata.COLUMN_NOT_FOUND:
+            raise ValueError("Column '%s' doesn't exist" % column_name)
+        return self._row[index]
+
+    def get_object_with_index(self, column_index):
+        """Gets the value of the column by index.
+
+        The type of the returned value depends on the SQL type of the column.
+        No implicit conversions are performed on the value.
+
+        Args:
+            column_index (int): Zero-based column index.
+
+        Returns:
+            Value of the column.
+
+        Raises:
+            IndexError: If the column index is out of bounds.
+            AssertionError: If the column index is not an integer.
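+
+        Example (a minimal sketch; assumes the query selected at
+        least one column): ::
+
+            for row in result:
+                value = row.get_object_with_index(0)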
+ + See Also: + :attr:`metadata` + + :attr:`SqlColumnMetadata.type` + """ + check_is_int(column_index, "Column index must be an integer") + return self._row[column_index] + + @property + def metadata(self): + """SqlRowMetadata: The row metadata.""" + return self._row_metadata + + def __repr__(self): + return "[%s]" % ", ".join( + "%s %s=%s" + % ( + self._row_metadata.get_column(i).name, + get_attr_name(SqlColumnType, self._row_metadata.get_column(i).type), + self._row[i], + ) + for i in range(self._row_metadata.column_count) + ) + + +class _ExecuteResponse(object): + """Represent the response of the first execute request.""" + + __slots__ = ("row_metadata", "row_page", "update_count") + + def __init__(self, row_metadata, row_page, update_count): + self.row_metadata = row_metadata + """SqlRowMetadata: Row metadata or None, if the response only + contains update count.""" + + self.row_page = row_page + """_SqlPage: First page of the query response or None, if the + response only contains update count. + """ + + self.update_count = update_count + """int: Update count or -1 if the result is a rowset.""" + + +class _IteratorBase(object): + """Base class for the blocking and Future-producing + iterators to use.""" + + __slots__ = ( + "row_metadata", + "fetch_fn", + "deserialize_fn", + "page", + "row_count", + "position", + "is_last", + ) + + def __init__(self, row_metadata, fetch_fn, deserialize_fn): + self.row_metadata = row_metadata + """SqlRowMetadata: Row metadata.""" + + self.fetch_fn = fetch_fn + """function: Fetches the next page. It produces a Future[_SqlPage].""" + + self.deserialize_fn = deserialize_fn + """function: Deserializes the value.""" + + self.page = None + """_SqlPage: Current page.""" + + self.row_count = 0 + """int: Number of rows in the current page.""" + + self.position = 0 + """int: Index of the next row in the page.""" + + self.is_last = False + """bool: Whether this is the last page or not.""" + + def on_next_page(self, page): + """ + Called when a new page is fetched or on the + initialization of the iterator to update its + internal state. + + Args: + page (_SqlPage): + """ + self.page = page + self.row_count = page.row_count + self.is_last = page.is_last + self.position = 0 + + def _get_current_row(self): + """ + Returns: + list: The row pointed by the current position. + """ + + # The column might contain user objects so we have to deserialize it. + # Deserialization is no-op if the value is not Data. + return [ + self.deserialize_fn(self.page.get_value(i, self.position)) + for i in range(self.page.column_count) + ] + + +class _FutureProducingIterator(_IteratorBase): + """An iterator that produces infinite stream of Futures. It is the + responsibility of the user to either call them in blocking fashion, + or call ``next`` only if the current call to next did not raise + ``StopIteration`` error (possibly with callback-based code). + """ + + def __iter__(self): + return self + + def next(self): + # Defined for backward-compatibility with Python 2. + return self.__next__() + + def __next__(self): + return self._has_next().continue_with(self._has_next_continuation) + + def _has_next_continuation(self, future): + """Based on the call to :func:`_has_next`, either + raises ``StopIteration`` error or gets the current row + and returns it. + + Args: + future (hazelcast.future.Future): + + Returns: + SqlRow: + """ + has_next = future.result() + if not has_next: + # Iterator is exhausted, raise this to inform the user. 
+ # If the user continues to call next, we will continuously + # raise this. + raise StopIteration + + row = self._get_current_row() + self.position += 1 + return SqlRow(self.row_metadata, row) + + def _has_next(self): + """Returns a Future indicating whether there are more rows + left to iterate. + + Returns: + hazelcast.future.Future: + """ + if self.position == self.row_count: + # We exhausted the current page. + + if self.is_last: + # This was the last page, no row left + # on the server side. + return ImmediateFuture(False) + + # It seems that there are some rows left on the server. + # Fetch them, and then return. + return self.fetch_fn().continue_with(self._fetch_continuation) + + # There are some elements left in the current page. + return ImmediateFuture(True) + + def _fetch_continuation(self, future): + """After a new page is fetched, updates the internal state + of the iterator and returns whether or not there are some + rows in the fetched page. + + Args: + future (hazelcast.future.Future): + + Returns: + hazelcast.future.Future: + """ + page = future.result() + self.on_next_page(page) + return self._has_next() + + +class _BlockingIterator(_IteratorBase): + """An iterator that blocks when the current page is exhausted + and we need to fetch a new page from the server. Otherwise, + it returns immediately with an object. + + This version is more performant than the Future-producing + counterpart in a sense that, it does not box everything with + a Future object. + """ + + def __iter__(self): + return self + + def next(self): + # Defined for backward-compatibility with Python 2. + return self.__next__() + + def __next__(self): + if not self._has_next(): + # No more rows are left. + raise StopIteration + + row = self._get_current_row() + self.position += 1 + return SqlRow(self.row_metadata, row) + + def _has_next(self): + while self.position == self.row_count: + # We exhausted the current page. + + if self.is_last: + # No more rows left on the server. + return False + + # Block while waiting for the next page. + page = self.fetch_fn().result() + + # Update the internal state with the next page. + self.on_next_page(page) + + # There are some rows left in the current page. + return True + + +class SqlResult(object): + """SQL query result. + + Depending on the statement type it represents a stream of + rows or an update count. + + To iterate over the stream of rows, there are two possible options. + + The first, and the easiest one is to iterate over the rows + in a blocking fashion. :: + + result = client.sql.execute("SELECT ...") + for row in result: + # Process the row. + print(row) + + The second option is to use the non-blocking API with callbacks. :: + + result = client.sql.execute("SELECT ...") + it = result.iterator() # Future of iterator + + def on_iterator_response(iterator_future): + iterator = iterator_future.result() + + def on_next_row(row_future): + try: + row = row_future.result() + # Process the row. + print(row) + + # Iterate over the next row. + next(iterator).add_done_callback(on_next_row) + except StopIteration: + # Exhausted the iterator. No more rows are left. + pass + + next(iterator).add_done_callback(on_next_row) + + it.add_done_callback(on_iterator_response) + + When in doubt, use the blocking API shown in the first code sample. + + Note that, iterators can be requested at most once per SqlResult. + + One can call :func:`close` method of a result object to + release the resources associated with the result on the server side. 
+ It might also be used to cancel query execution on the server side + if it is still active. + + When the blocking API is used, one might also use ``with`` + statement to automatically close the query even if an exception + is thrown in the iteration. :: + + with client.sql.execute("SELECT ...") as result: + for row in result: + # Process the row. + print(row) + + + To get the number of rows updated by the query, use the + :func:`update_count`. :: + + update_count = client.sql.execute("SELECT ...").update_count().result() + + One does not have to call :func:`close` in this case, because the result + will already be closed in the server-side. + """ + + def __init__(self, sql_service, connection, query_id, cursor_buffer_size, execute_future): + self._sql_service = sql_service + """_InternalSqlService: Reference to the SQL service.""" + + self._connection = connection + """hazelcast.connection.Connection: Reference to the connection + that the execute request is made to.""" + + self._query_id = query_id + """_SqlQueryId: Unique id of the SQL query.""" + + self._cursor_buffer_size = cursor_buffer_size + """int: Size of the cursor buffer measured in the number of rows.""" + + self._lock = RLock() + """RLock: Protects the shared access to instance variables below.""" + + self._execute_response = Future() + """Future: Will be resolved with :class:`_ExecuteResponse` once the + execute request is resolved.""" + + self._iterator_requested = False + """bool: Flag that shows whether an iterator is already requested.""" + + self._closed = False + """bool: Flag that shows whether the query execution is still active + on the server side. When ``True``, there is no need to send the "close" + request to the server.""" + + self._fetch_future = None + """Future: Will be set, if there are more pages to fetch on the server + side. It should be set to ``None`` once the fetch is completed.""" + + execute_future.add_done_callback(self._handle_execute_response) + + def iterator(self): + """Returns the iterator over the result rows. + + The iterator may be requested only once. + + The returned Future results with: + + - :class:`HazelcastSqlError`: In case of an SQL execution error. + - **ValueError**: If the result only contains an update count, or the + iterator is already requested. + + Returns: + Future[Iterator[Future[SqlRow]]]: Iterator that produces Future + of :class:`SqlRow` s. See the class documentation for the correct + way to use this. + """ + return self._get_iterator(False) + + def is_row_set(self): + """Returns whether this result has rows to iterate. + + The returned Future results with: + + - :class:`HazelcastSqlError`: In case of an SQL execution error. + + Returns: + Future[bool]: + """ + + def continuation(future): + response = future.result() + # By design, if the row_metadata (or row_page) is None, + # we only got the update count. + return response.row_metadata is not None + + return self._execute_response.continue_with(continuation) + + def update_count(self): + """Returns the number of rows updated by the statement or ``-1`` if this + result is a row set. In case the result doesn't contain rows but the + update count isn't applicable or known, ``0`` is returned. + + The returned Future results with: + + - :class:`HazelcastSqlError`: In case of an SQL execution error. + + Returns: + Future[int]: + """ + + def continuation(future): + response = future.result() + # This will be set to -1, when we got row set on the client side. + # See _on_execute_response. 
+ return response.update_count + + return self._execute_response.continue_with(continuation) + + def get_row_metadata(self): + """Gets the row metadata. + + The returned Future results with: + + - :class:`HazelcastSqlError`: In case of an SQL execution error. + - **ValueError**: If the result only contains an update count. + + Returns: + Future[SqlRowMetadata]: + """ + + def continuation(future): + response = future.result() + + if not response.row_metadata: + raise ValueError("This result contains only update count") + + return response.row_metadata + + return self._execute_response.continue_with(continuation) + + def close(self): + """Release the resources associated with the query result. + + The query engine delivers the rows asynchronously. The query may + become inactive even before all rows are consumed. The invocation + of this command will cancel the execution of the query on all members + if the query is still active. Otherwise it is no-op. For a result + with an update count it is always no-op. + + The returned Future results with: + + - :class:`HazelcastSqlError`: In case there is an error closing the + result. + + Returns: + Future[None]: + """ + + with self._lock: + if self._closed: + # Do nothing if the result is already closed. + return ImmediateFuture(None) + + error = HazelcastSqlError( + self._sql_service.get_client_id(), + _SqlErrorCode.CANCELLED_BY_USER, + "Query was cancelled by the user", + None, + ) + + if not self._execute_response.done(): + # If the cancellation is initiated before the first response is + # received, then throw cancellation errors on the dependent + # methods (update count, row metadata, iterator). + self._on_execute_error(error) + + if not self._fetch_future: + # Make sure that all subsequent fetches will fail. + self._fetch_future = Future() + + self._on_fetch_error(error) + + def wrap_error_on_failure(f): + # If the close request is failed somehow, + # wrap it in a HazelcastSqlError. + try: + return f.result() + except Exception as e: + raise self._sql_service.re_raise(e, self._connection) + + self._closed = True + + # Send the close request + return self._sql_service.close(self._connection, self._query_id).continue_with( + wrap_error_on_failure + ) + + def __iter__(self): + # Get blocking iterator, and wait for the + # first page. + return self._get_iterator(True).result() + + def _get_iterator(self, should_get_blocking): + """Gets the iterator after the execute request finishes. + + Args: + should_get_blocking (bool): Whether to get a blocking iterator. 
+ + Returns: + Future[Iterator]: + """ + + def continuation(future): + response = future.result() + + with self._lock: + if not response.row_metadata: + # Can't get an iterator when we only have update count + raise ValueError("This result contains only update count") + + if self._iterator_requested: + # Can't get an iterator when we already get one + raise ValueError("Iterator can be requested only once") + + self._iterator_requested = True + + if should_get_blocking: + iterator = _BlockingIterator( + response.row_metadata, + self._fetch_next_page, + self._sql_service.deserialize_object, + ) + else: + iterator = _FutureProducingIterator( + response.row_metadata, + self._fetch_next_page, + self._sql_service.deserialize_object, + ) + + # Pass the first page information to the iterator + iterator.on_next_page(response.row_page) + return iterator + + return self._execute_response.continue_with(continuation) + + def _fetch_next_page(self): + """Fetches the next page, if there is no fetch request + in-flight. + + Returns: + Future[_SqlPage]: + """ + with self._lock: + if self._fetch_future: + # A fetch request is already in-flight, return it. + return self._fetch_future + + future = Future() + self._fetch_future = future + + self._sql_service.fetch( + self._connection, self._query_id, self._cursor_buffer_size + ).add_done_callback(self._handle_fetch_response) + + # Need to return future, not self._fetch_future, because through + # some unlucky timing, we might call _handle_fetch_response + # before returning, which could set self._fetch_future to + # None. + return future + + def _handle_fetch_response(self, future): + """Handles the result of the fetch request, by either: + + - setting it to exception, so that the future calls to + fetch fails immediately. + - setting it to next page, and setting self._fetch_future + to None so that the next fetch request might actually + try to fetch something from the server. + + Args: + future (Future): The response from the server for + the fetch request. + """ + try: + response = future.result() + + response_error = self._handle_response_error(response["error"]) + if response_error: + # There is a server side error sent to client. + self._on_fetch_error(response_error) + return + + # The result contains the next page, as expected. + self._on_fetch_response(response["row_page"]) + except Exception as e: + # Something went bad, we couldn't get response from + # the server, invocation failed. + self._on_fetch_error(self._sql_service.re_raise(e, self._connection)) + + def _on_fetch_error(self, error): + """Sets the fetch future with exception, but not resetting it + so that the next fetch request fails immediately. + + Args: + error (Exception): The error. + """ + with self._lock: + self._fetch_future.set_exception(error) + + def _on_fetch_response(self, page): + """Sets the fetch future with the next page, + resets it, and if this is the last page, + marks the result as closed. + + Args: + page (_SqlPage): The next page. + """ + with self._lock: + future = self._fetch_future + self._fetch_future = None + + if page.is_last: + # This is the last page, there is nothing + # more on the server. 
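+                # Mark the result as closed so that a subsequent close()
+                # call becomes a no-op and no close request is sent.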
+ self._closed = True + + # Resolving the future before resetting self._fetch_future + # might result in an infinite loop for non-blocking iterators + future.set_result(page) + + def _handle_execute_response(self, future): + """Handles the result of the execute request, by either: + + - setting it to an exception so that the dependent methods + (iterator, update_count etc.) fails immediately + - setting it to an execute response + + Args: + future (Future): + """ + try: + response = future.result() + + response_error = self._handle_response_error(response["error"]) + if response_error: + # There is a server-side error sent to the client. + self._on_execute_error(response_error) + return + + row_metadata = response["row_metadata"] + if row_metadata is not None: + # The result contains some rows, not an update count. + row_metadata = SqlRowMetadata(row_metadata) + + self._on_execute_response(row_metadata, response["row_page"], response["update_count"]) + except Exception as e: + # Something went bad, we couldn't get response from + # the server, invocation failed. + self._on_execute_error(self._sql_service.re_raise(e, self._connection)) + + @staticmethod + def _handle_response_error(error): + """If the error is not ``None``, return it as + :class:`HazelcastSqlError` so that we can raise + it to user. + + Args: + error (_SqlError): The error or ``None``. + + Returns: + HazelcastSqlError: If the error is not ``None``, + ``None`` otherwise. + """ + if error: + return HazelcastSqlError(error.originating_member_uuid, error.code, error.message, None) + return None + + def _on_execute_error(self, error): + """Called when the first execute request is failed. + + Args: + error (HazelcastSqlError): The wrapped error that can + be raised to the user. + """ + with self._lock: + if self._closed: + # User might be already cancelled it. + return + + self._execute_response.set_exception(error) + + def _on_execute_response(self, row_metadata, row_page, update_count): + """Called when the first execute request is succeeded. + + Args: + row_metadata (SqlRowMetadata): The row metadata. Might be ``None`` + if the response only contains the update count. + row_page (_SqlPage): The first page of the result. Might be + ``None`` if the response only contains the update count. + update_count (int): The update count. + """ + with self._lock: + if self._closed: + # User might be already cancelled it. + return + + if row_metadata: + # Result contains the row set for the query. + # Set the update count to -1. + response = _ExecuteResponse(row_metadata, row_page, -1) + + if row_page.is_last: + # This is the last page, close the result. + self._closed = True + + self._execute_response.set_result(response) + else: + # Result only contains the update count. + response = _ExecuteResponse(None, None, update_count) + self._execute_response.set_result(response) + + # There is nothing more we can get from the server. + self._closed = True + + def __enter__(self): + # The execute request is already sent. + # There is nothing more to do. + return self + + def __exit__(self, exc_type, exc_value, traceback): + # Ignoring the possible exception details + # since we close the query regardless of that. + self.close().result() + + +class _InternalSqlService(object): + """Internal SQL service that offers more public API + than the one exposed to the user. 
+    """
+
+    def __init__(self, connection_manager, serialization_service, invocation_service):
+        self._connection_manager = connection_manager
+        self._serialization_service = serialization_service
+        self._invocation_service = invocation_service
+
+    def execute(self, sql, *params):
+        """Constructs a statement and executes it.
+
+        Args:
+            sql (str): SQL string.
+            *params: Query parameters.
+
+        Returns:
+            SqlResult: The execution result.
+        """
+        statement = SqlStatement(sql)
+
+        for param in params:
+            statement.add_parameter(param)
+
+        return self.execute_statement(statement)
+
+    def execute_statement(self, statement):
+        """Executes the given statement.
+
+        Args:
+            statement (SqlStatement): The statement to execute.
+
+        Returns:
+            SqlResult: The execution result.
+        """
+
+        # Get a random data member (non-lite member)
+        connection = self._connection_manager.get_random_connection(True)
+        if not connection:
+            # Either the client is not connected to the cluster, or
+            # there are no data members in the cluster.
+            raise HazelcastSqlError(
+                self.get_client_id(),
+                _SqlErrorCode.CONNECTION_PROBLEM,
+                "Client is not currently connected to the cluster.",
+                None,
+            )
+
+        try:
+            # Create a new, unique query id.
+            query_id = _SqlQueryId.from_uuid(connection.remote_uuid)
+
+            # Serialize the passed parameters.
+            serialized_params = [
+                self._serialization_service.to_data(param) for param in statement.parameters
+            ]
+
+            request = sql_execute_codec.encode_request(
+                statement.sql,
+                serialized_params,
+                # to_millis expects None to produce -1
+                to_millis(None if statement.timeout == -1 else statement.timeout),
+                statement.cursor_buffer_size,
+                statement.schema,
+                statement.expected_result_type,
+                query_id,
+            )
+
+            invocation = Invocation(
+                request, connection=connection, response_handler=sql_execute_codec.decode_response
+            )
+
+            result = SqlResult(
+                self, connection, query_id, statement.cursor_buffer_size, invocation.future
+            )
+
+            self._invocation_service.invoke(invocation)
+
+            return result
+        except Exception as e:
+            raise self.re_raise(e, connection)
+
+    def deserialize_object(self, obj):
+        return self._serialization_service.to_object(obj)
+
+    def fetch(self, connection, query_id, cursor_buffer_size):
+        """Fetches the next page of the query execution.
+
+        Args:
+            connection (hazelcast.connection.Connection): Connection
+                that the first execute request was routed to; the fetch
+                request must be routed to the same one.
+            query_id (_SqlQueryId): Unique id of the query.
+            cursor_buffer_size (int): Size of the cursor buffer. Same as
+                the one used in the first execute request.
+
+        Returns:
+            Future: Decoded fetch response.
+        """
+        request = sql_fetch_codec.encode_request(query_id, cursor_buffer_size)
+        invocation = Invocation(
+            request, connection=connection, response_handler=sql_fetch_codec.decode_response
+        )
+        self._invocation_service.invoke(invocation)
+        return invocation.future
+
+    def get_client_id(self):
+        """
+        Returns:
+            uuid.UUID: Unique client UUID.
+        """
+        return self._connection_manager.client_uuid
+
+    def re_raise(self, error, connection):
+        """Returns the error wrapped as a :class:`HazelcastSqlError`
+        so that it can be raised to the user.
+
+        Args:
+            error (Exception): The error to reraise.
+            connection (hazelcast.connection.Connection): Connection
+                that the query requests are routed to. If it is not
+                live, we will inform the user about the possible
+                cluster topology change.
+
+        Returns:
+            HazelcastSqlError: The reraised error.
+        """
+        if not connection.live:
+            return HazelcastSqlError(
+                self.get_client_id(),
+                _SqlErrorCode.CONNECTION_PROBLEM,
+                "Cluster topology changed while a query was executed: Member cannot be reached: %s"
+                % connection.remote_address,
+                error,
+            )
+
+        if isinstance(error, HazelcastSqlError):
+            return error
+
+        return HazelcastSqlError(
+            self.get_client_id(), _SqlErrorCode.GENERIC, try_to_get_error_message(error), error
+        )
+
+    def close(self, connection, query_id):
+        """Closes the remote query cursor.
+
+        Args:
+            connection (hazelcast.connection.Connection): Connection
+                that the first execute request was routed to; the close
+                request must be routed to the same one.
+            query_id (_SqlQueryId): The query id to close.
+
+        Returns:
+            Future:
+        """
+        request = sql_close_codec.encode_request(query_id)
+        invocation = Invocation(request, connection=connection)
+        self._invocation_service.invoke(invocation)
+        return invocation.future
+
+
+class SqlExpectedResultType(object):
+    """The expected statement result type."""
+
+    ANY = 0
+    """
+    The statement may produce either rows or an update count.
+    """
+
+    ROWS = 1
+    """
+    The statement must produce rows. An exception is thrown if the statement produces an update count.
+    """
+
+    UPDATE_COUNT = 2
+    """
+    The statement must produce an update count. An exception is thrown if the statement produces rows.
+    """
+
+
+class SqlStatement(object):
+    """Definition of an SQL statement.
+
+    This object is mutable. Properties are read once before the execution
+    is started. Changes to properties do not affect the behavior of already
+    running statements.
+    """
+
+    TIMEOUT_NOT_SET = -1
+
+    TIMEOUT_DISABLED = 0
+
+    DEFAULT_TIMEOUT = TIMEOUT_NOT_SET
+
+    DEFAULT_CURSOR_BUFFER_SIZE = 4096
+
+    def __init__(self, sql):
+        self.sql = sql
+        self._parameters = []
+        self._timeout = SqlStatement.DEFAULT_TIMEOUT
+        self._cursor_buffer_size = SqlStatement.DEFAULT_CURSOR_BUFFER_SIZE
+        self._schema = None
+        self._expected_result_type = SqlExpectedResultType.ANY
+
+    @property
+    def sql(self):
+        """str: The SQL string to be executed.
+
+        The setter raises:
+
+        - **AssertionError**: If the SQL parameter is not a string.
+        - **ValueError**: If the SQL parameter is an empty string.
+        """
+        return self._sql
+
+    @sql.setter
+    def sql(self, sql):
+        check_true(isinstance(sql, six.string_types), "SQL must be a string")
+
+        if not sql.strip():
+            raise ValueError("SQL cannot be empty")
+
+        self._sql = sql
+
+    @property
+    def schema(self):
+        """str: The schema name. The engine will try to resolve the
+        non-qualified object identifiers from the statement in the given
+        schema. If not found, the default search path will be used, which
+        looks for objects in the predefined schemas ``partitioned`` and
+        ``public``.
+
+        The schema name is case-sensitive. For example, ``foo`` and ``Foo``
+        are different schemas.
+
+        The default value is ``None``, meaning only the default search path
+        is used.
+
+        The setter raises:
+
+        - **AssertionError**: If the schema is not a string or ``None``.
+        """
+        return self._schema
+
+    @schema.setter
+    def schema(self, schema):
+        check_true(
+            isinstance(schema, six.string_types) or schema is None,
+            "Schema must be a string or None",
+        )
+        self._schema = schema
+
+    @property
+    def parameters(self):
+        """list: The statement parameters.
+
+        You may define parameter placeholders in the statement with the ``?``
+        character. For every placeholder, a parameter value must be provided.
+
+        When the setter is called, the content of the parameters list is copied.
+        Subsequent changes to the original list don't change the statement parameters.
+
+        The setter raises:
+
+        - **AssertionError**: If the parameter is not a list.
+        """
+        return self._parameters
+
+    @parameters.setter
+    def parameters(self, parameters):
+        check_true(isinstance(parameters, list), "Parameters must be a list")
+        self._parameters = list(parameters)
+
+    @property
+    def timeout(self):
+        """float or int: The execution timeout in seconds.
+
+        If the timeout is reached for a running statement, it will be
+        cancelled forcefully.
+
+        Zero value means no timeout. :const:`TIMEOUT_NOT_SET` means that
+        the value from the server-side config will be used. Other negative
+        values are prohibited.
+
+        Defaults to :const:`TIMEOUT_NOT_SET`.
+
+        The setter raises:
+
+        - **AssertionError**: If the timeout is not an integer or float.
+        - **ValueError**: If the timeout is negative and not equal to
+          :const:`TIMEOUT_NOT_SET`.
+        """
+        return self._timeout
+
+    @timeout.setter
+    def timeout(self, timeout):
+        check_is_number(timeout, "Timeout must be an integer or float")
+        if timeout < 0 and timeout != SqlStatement.TIMEOUT_NOT_SET:
+            raise ValueError("Timeout must be non-negative or -1, not %s" % timeout)
+
+        self._timeout = timeout
+
+    @property
+    def cursor_buffer_size(self):
+        """int: The cursor buffer size (measured in the number of rows).
+
+        When a statement is submitted for execution, a :class:`SqlResult`
+        is returned as a result. When rows are ready to be consumed,
+        they are put into an internal buffer of the cursor. This parameter
+        defines the maximum number of rows in that buffer. When the threshold
+        is reached, the backpressure mechanism will slow down the execution,
+        possibly to a complete halt, to prevent out-of-memory situations.
+
+        Only positive values are allowed.
+
+        The default value is expected to work well for most workloads. A bigger
+        buffer size may give you a slight performance boost for queries with
+        large result sets at the cost of increased memory consumption.
+
+        Defaults to :const:`DEFAULT_CURSOR_BUFFER_SIZE`.
+
+        The setter raises:
+
+        - **AssertionError**: If the cursor buffer size is not an integer.
+        - **ValueError**: If the cursor buffer size is not positive.
+        """
+        return self._cursor_buffer_size
+
+    @cursor_buffer_size.setter
+    def cursor_buffer_size(self, cursor_buffer_size):
+        check_is_int(cursor_buffer_size, "Cursor buffer size must be an integer")
+        if cursor_buffer_size <= 0:
+            raise ValueError("Cursor buffer size must be positive, not %s" % cursor_buffer_size)
+        self._cursor_buffer_size = cursor_buffer_size
+
+    @property
+    def expected_result_type(self):
+        """SqlExpectedResultType: The expected result type.
+
+        The setter raises:
+
+        - **TypeError**: If the expected result type is not equal to one of
+          the values or names of the members of
+          :class:`SqlExpectedResultType`.
+        """
+        return self._expected_result_type
+
+    @expected_result_type.setter
+    def expected_result_type(self, expected_result_type):
+        self._expected_result_type = try_to_get_enum_value(
+            expected_result_type, SqlExpectedResultType
+        )
+
+    def add_parameter(self, parameter):
+        """Adds a single parameter to the end of the parameters list.
+
+        Args:
+            parameter: The parameter.
+
+        See Also:
+            :attr:`parameters`
+
+            :func:`clear_parameters`
+        """
+        self._parameters.append(parameter)
+
+    def clear_parameters(self):
+        """Clears statement parameters."""
+        self._parameters = []
+
+    def copy(self):
+        """Creates a copy of this instance.
+
+        Returns:
+            SqlStatement: The new copy.
+        """
+        copied = SqlStatement(self.sql)
+        copied.parameters = list(self.parameters)
+        copied.timeout = self.timeout
+        copied.cursor_buffer_size = self.cursor_buffer_size
+        copied.schema = self.schema
+        copied.expected_result_type = self.expected_result_type
+        return copied
+
+    def __repr__(self):
+        return (
+            "SqlStatement(schema=%s, sql=%s, parameters=%s, timeout=%s,"
+            " cursor_buffer_size=%s, expected_result_type=%s)"
+            % (
+                self.schema,
+                self.sql,
+                self.parameters,
+                self.timeout,
+                self.cursor_buffer_size,
+                self._expected_result_type,
+            )
+        )
+
+
+# These are imported at the bottom of the module to avoid
+# cyclic import errors.
+from hazelcast.protocol.codec import sql_execute_codec, sql_fetch_codec, sql_close_codec
diff --git a/hazelcast/util.py b/hazelcast/util.py
index 6accedfa06..6ea7b67eca 100644
--- a/hazelcast/util.py
+++ b/hazelcast/util.py
@@ -1,6 +1,10 @@
+import binascii
 import random
 import threading
 import time
+import uuid
+
+from hazelcast.serialization import UUID_MSB_SHIFT, UUID_LSB_MASK, UUID_MSB_MASK
 
 try:
     from collections.abc import Sequence, Iterable
@@ -37,14 +41,14 @@ def check_not_empty(collection, message):
         raise AssertionError(message)
 
 
-def check_is_number(val):
+def check_is_number(val, message="Number value expected"):
     if not isinstance(val, number_types):
-        raise AssertionError("Number value expected")
+        raise AssertionError(message)
 
 
-def check_is_int(val):
+def check_is_int(val, message="Int value expected"):
     if not isinstance(val, six.integer_types):
-        raise AssertionError("Int value expected")
+        raise AssertionError(message)
 
 
 def current_time():
@@ -272,18 +276,61 @@ def next(self):
         """
         raise NotImplementedError("next")
 
+    def next_data_member(self):
+        """Returns the next data member to route to.
+
+        Returns:
+            hazelcast.core.MemberInfo: The next data member or
+            ``None`` if no data member is available.
+        """
+        return None
+
+    def can_get_next_data_member(self):
+        """Returns whether this instance supports getting data members
+        through a call to :func:`next_data_member`.
+
+        Returns:
+            bool: ``True`` if this instance supports getting data members.
+        """
+        return False
+
+
+class _Members(object):
+    __slots__ = ("members", "data_members")
+
+    def __init__(self, members, data_members):
+        self.members = members
+        self.data_members = data_members
+
 
 class _AbstractLoadBalancer(LoadBalancer):
     def __init__(self):
         self._cluster_service = None
-        self._members = []
+        self._members = _Members([], [])
 
     def init(self, cluster_service):
         self._cluster_service = cluster_service
         cluster_service.add_listener(self._listener, self._listener, True)
 
+    def next(self):
+        members = self._members.members
+        return self._next(members)
+
+    def next_data_member(self):
+        members = self._members.data_members
+        return self._next(members)
+
+    def can_get_next_data_member(self):
+        return True
+
     def _listener(self, _):
-        self._members = self._cluster_service.get_members()
+        members = self._cluster_service.get_members()
+        data_members = [member for member in members if not member.lite_member]
+
+        self._members = _Members(members, data_members)
+
+    def _next(self, members):
+        raise NotImplementedError("_next")
 
 
 class RoundRobinLB(_AbstractLoadBalancer):
@@ -298,8 +345,7 @@ def __init__(self):
         super(RoundRobinLB, self).__init__()
         self._idx = 0
 
-    def next(self):
-        members = self._members
+    def _next(self, members):
         if not members:
             return None
 
@@ -312,15 +358,14 @@ def next(self):
 class RandomLB(_AbstractLoadBalancer):
     """A load balancer that selects a random member to route to."""
 
-    def next(self):
-        members = self._members
+    def _next(self, members):
         if not members:
             return None
         idx = random.randrange(0, len(members))
         return members[idx]
 
 
-class IterationType:
+class IterationType(object):
     """To differentiate users selection on result collection on map-wide
     operations like ``entry_set``, ``key_set``, ``values`` etc.
     """
@@ -333,3 +378,79 @@ class IterationType:
 
     ENTRY = 2
     """Iterate over entries"""
+
+
+class UUIDUtil(object):
+    @staticmethod
+    def to_bits(value):
+        i = value.int
+        most_significant_bits = to_signed(i >> UUID_MSB_SHIFT, 64)
+        least_significant_bits = to_signed(i & UUID_LSB_MASK, 64)
+        return most_significant_bits, least_significant_bits
+
+    @staticmethod
+    def from_bits(most_significant_bits, least_significant_bits):
+        return uuid.UUID(
+            int=(
+                ((most_significant_bits << UUID_MSB_SHIFT) & UUID_MSB_MASK)
+                | (least_significant_bits & UUID_LSB_MASK)
+            )
+        )
+
+
+if hasattr(int, "from_bytes"):
+
+    def int_from_bytes(buffer):
+        return int.from_bytes(buffer, "big", signed=True)
+
+
+else:
+    # Compatibility with Python 2
+    def int_from_bytes(buffer):
+        buffer = bytearray(buffer)
+        if buffer[0] & 0x80:
+            neg = bytearray()
+            for c in buffer:
+                # Mask to a byte; a bare ~c would be negative and
+                # could not be appended to a bytearray.
+                neg.append(~c & 0xFF)
+            return -1 * int(binascii.hexlify(neg), 16) - 1
+        return int(binascii.hexlify(buffer), 16)
+
+
+try:
+    from datetime import timezone
+except ImportError:
+    from datetime import tzinfo, timedelta
+
+    # There is no tzinfo implementation (timezone) in
+    # Python 2. Here we provide the bare minimum
+    # to the user.
+    class FixedOffsetTimezone(tzinfo):
+        __slots__ = ("_offset",)
+
+        def __init__(self, offset):
+            self._offset = offset
+
+        def utcoffset(self, dt):
+            return self._offset
+
+        def tzname(self, dt):
+            return None
+
+        def dst(self, dt):
+            return timedelta(0)
+
+    timezone = FixedOffsetTimezone
+
+
+def try_to_get_error_message(error):
+    # If the error has a message attribute,
+    # return it. If not, almost all of the
+    # built-in errors (and Hazelcast errors)
+    # set the exception message as the first
+    # parameter of args. If it is not there,
+    # then return None.
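+    # (On Python 3, BaseException no longer has a "message" attribute,
+    # so the args-based fallback below is the common path there.)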
+    if hasattr(error, "message"):
+        return error.message
+    elif len(error.args) > 0:
+        return error.args[0]
+    return None
diff --git a/start_rc.py b/start_rc.py
index 60851429c0..4bc56f55e3 100644
--- a/start_rc.py
+++ b/start_rc.py
@@ -3,7 +3,7 @@
 import sys
 from os.path import isfile
 
-SERVER_VERSION = "4.1.4-SNAPSHOT"
+SERVER_VERSION = "4.2.1-SNAPSHOT"
 RC_VERSION = "0.8-SNAPSHOT"
 
 RELEASE_REPO = "http://repo1.maven.apache.org/maven2"
@@ -72,7 +72,7 @@ def start_rc(stdout=None, stderr=None):
 
     enterprise_key = os.environ.get("HAZELCAST_ENTERPRISE_KEY", None)
     if enterprise_key:
-        server = download_if_necessary(ENTERPRISE_REPO, "hazelcast-enterprise", SERVER_VERSION)
+        server = download_if_necessary(ENTERPRISE_REPO, "hazelcast-enterprise-all", SERVER_VERSION)
         ep_tests = download_if_necessary(
             ENTERPRISE_REPO, "hazelcast-enterprise", SERVER_VERSION, True
         )
@@ -80,7 +80,7 @@ def start_rc(stdout=None, stderr=None):
         artifacts.append(server)
         artifacts.append(ep_tests)
     else:
-        server = download_if_necessary(REPO, "hazelcast", SERVER_VERSION)
+        server = download_if_necessary(REPO, "hazelcast-all", SERVER_VERSION)
         artifacts.append(server)
 
     class_path = CLASS_PATH_SEPARATOR.join(artifacts)
diff --git a/tests/integration/backward_compatible/cluster_test.py b/tests/integration/backward_compatible/cluster_test.py
index 89343f8187..291cd801ba 100644
--- a/tests/integration/backward_compatible/cluster_test.py
+++ b/tests/integration/backward_compatible/cluster_test.py
@@ -148,7 +148,9 @@ def test_random_load_balancer(self):
         lb = client._load_balancer
         self.assertTrue(isinstance(lb, RandomLB))
 
-        six.assertCountEqual(self, self.addresses, list(map(lambda m: m.address, lb._members)))
+        six.assertCountEqual(
+            self, self.addresses, list(map(lambda m: m.address, self._get_members_from_lb(lb)))
+        )
 
         for _ in range(10):
             self.assertTrue(lb.next().address in self.addresses)
@@ -161,12 +163,24 @@ def test_round_robin_load_balancer(self):
         lb = client._load_balancer
         self.assertTrue(isinstance(lb, RoundRobinLB))
 
-        six.assertCountEqual(self, self.addresses, list(map(lambda m: m.address, lb._members)))
+        six.assertCountEqual(
+            self, self.addresses, list(map(lambda m: m.address, self._get_members_from_lb(lb)))
+        )
 
         for i in range(10):
             self.assertEqual(self.addresses[i % len(self.addresses)], lb.next().address)
 
         client.shutdown()
 
+    @staticmethod
+    def _get_members_from_lb(lb):
+        # For backward-compatibility
+        members = lb._members
+        if isinstance(members, list):
+            return members
+
+        # 4.2+
+        return members.members
+
 
 @set_attr(enterprise=True)
 class HotRestartEventTest(HazelcastTestCase):
diff --git a/tests/integration/backward_compatible/sql_test.py b/tests/integration/backward_compatible/sql_test.py
new file mode 100644
index 0000000000..897f745240
--- /dev/null
+++ b/tests/integration/backward_compatible/sql_test.py
@@ -0,0 +1,634 @@
+import decimal
+import math
+import random
+import string
+import unittest
+
+from hazelcast import six
+from hazelcast.future import ImmediateFuture
+from hazelcast.serialization.api import Portable
+from tests.base import SingleMemberTestCase
+from tests.hzrc.ttypes import Lang
+from mock import patch
+
+from tests.util import is_client_version_older_than, mark_server_version_at_least
+
+try:
+    from hazelcast.sql import HazelcastSqlError, SqlStatement, SqlExpectedResultType, SqlColumnType
+except ImportError:
+    # For backward compatibility. If we cannot import those, we won't
+    # even be referencing them in tests.
+    pass
+
+SERVER_CONFIG = """
+<hazelcast xmlns="http://www.hazelcast.com/schema/config"
+           xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
+           xsi:schemaLocation="http://www.hazelcast.com/schema/config
+           http://www.hazelcast.com/schema/config/hazelcast-config-4.0.xsd">
+    <serialization>
+        <portable-factories>
+            <portable-factory factory-id="666">
+                com.hazelcast.client.test.PortableFactory
+            </portable-factory>
+        </portable-factories>
+    </serialization>
+</hazelcast>
+"""
+
+
+class SqlTestBase(SingleMemberTestCase):
+    @classmethod
+    def configure_cluster(cls):
+        return SERVER_CONFIG
+
+    @classmethod
+    def configure_client(cls, config):
+        config["cluster_name"] = cls.cluster.id
+        config["portable_factories"] = {666: {6: Student}}
+        return config
+
+    def setUp(self):
+        self.map_name = random_string()
+        self.map = self.client.get_map(self.map_name).blocking()
+        mark_server_version_at_least(self, self.client, "4.2")
+
+    def tearDown(self):
+        self.map.clear()
+
+    def _populate_map(self, entry_count=10, value_factory=lambda v: v):
+        entries = {i: value_factory(i) for i in range(entry_count)}
+        self.map.put_all(entries)
+
+
+@unittest.skipIf(
+    is_client_version_older_than("4.2"), "Tests the features added in 4.2 version of the client"
+)
+class SqlServiceTest(SqlTestBase):
+    def test_execute(self):
+        entry_count = 11
+        self._populate_map(entry_count)
+        result = self.client.sql.execute("SELECT * FROM %s" % self.map_name)
+        six.assertCountEqual(
+            self,
+            [(i, i) for i in range(entry_count)],
+            [(row.get_object("__key"), row.get_object("this")) for row in result],
+        )
+
+    def test_execute_with_params(self):
+        entry_count = 13
+        self._populate_map(entry_count)
+        result = self.client.sql.execute(
+            "SELECT this FROM %s WHERE __key > ? AND this > ?" % self.map_name, 5, 6
+        )
+        six.assertCountEqual(
+            self,
+            [i for i in range(7, entry_count)],
+            [row.get_object("this") for row in result],
+        )
+
+    def test_execute_with_mismatched_params_when_sql_has_more(self):
+        self._populate_map()
+        result = self.client.sql.execute(
+            "SELECT * FROM %s WHERE __key > ? AND this > ?" % self.map_name, 5
+        )
+
+        with self.assertRaises(HazelcastSqlError):
+            for _ in result:
+                pass
+
+    def test_execute_with_mismatched_params_when_params_has_more(self):
+        self._populate_map()
+        result = self.client.sql.execute("SELECT * FROM %s WHERE this > ?" % self.map_name, 5, 6)
+
+        with self.assertRaises(HazelcastSqlError):
+            for _ in result:
+                pass
+
+    def test_execute_statement(self):
+        entry_count = 12
+        self._populate_map(entry_count, str)
+        statement = SqlStatement("SELECT this FROM %s" % self.map_name)
+        result = self.client.sql.execute_statement(statement)
+
+        six.assertCountEqual(
+            self,
+            [str(i) for i in range(entry_count)],
+            [row.get_object_with_index(0) for row in result],
+        )
+
+    def test_execute_statement_with_params(self):
+        entry_count = 20
+        self._populate_map(entry_count, lambda v: Student(v, v))
+        statement = SqlStatement(
+            "SELECT age FROM %s WHERE height = CAST(? AS REAL)" % self.map_name
+        )
+        statement.add_parameter(13.0)
+        result = self.client.sql.execute_statement(statement)
+
+        six.assertCountEqual(self, [13], [row.get_object("age") for row in result])
+
+    def test_execute_statement_with_mismatched_params_when_sql_has_more(self):
+        self._populate_map()
+        statement = SqlStatement("SELECT * FROM %s WHERE __key > ? AND this > ?" % self.map_name)
+        statement.parameters = [5]
+        result = self.client.sql.execute_statement(statement)
+
+        with self.assertRaises(HazelcastSqlError):
+            for _ in result:
+                pass
+
+    def test_execute_statement_with_mismatched_params_when_params_has_more(self):
+        self._populate_map()
+        statement = SqlStatement("SELECT * FROM %s WHERE this > ?" % self.map_name)
+        statement.parameters = [5, 6]
+        result = self.client.sql.execute_statement(statement)
+
+        with self.assertRaises(HazelcastSqlError):
+            for _ in result:
+                pass
+
+    def test_execute_statement_with_timeout(self):
+        entry_count = 100
+        self._populate_map(entry_count, lambda v: Student(v, v))
+        statement = SqlStatement("SELECT age FROM %s WHERE height < 10" % self.map_name)
+        statement.timeout = 100
+        result = self.client.sql.execute_statement(statement)
+
+        six.assertCountEqual(
+            self, [i for i in range(10)], [row.get_object("age") for row in result]
+        )
+
+    def test_execute_statement_with_cursor_buffer_size(self):
+        entry_count = 50
+        self._populate_map(entry_count, lambda v: Student(v, v))
+        statement = SqlStatement("SELECT age FROM %s" % self.map_name)
+        statement.cursor_buffer_size = 3
+        result = self.client.sql.execute_statement(statement)
+
+        with patch.object(result, "_fetch_next_page", wraps=result._fetch_next_page) as patched:
+            six.assertCountEqual(
+                self, [i for i in range(entry_count)], [row.get_object("age") for row in result]
+            )
+            # The -1 comes from the fact that we don't fetch the first page
+            self.assertEqual(
+                math.ceil(float(entry_count) / statement.cursor_buffer_size) - 1, patched.call_count
+            )
+
+    def test_execute_statement_with_copy(self):
+        self._populate_map()
+        statement = SqlStatement("SELECT __key FROM %s WHERE this >= ?" % self.map_name)
+        statement.parameters = [9]
+        copy_statement = statement.copy()
+        statement.clear_parameters()
+
+        result = self.client.sql.execute_statement(copy_statement)
+        self.assertEqual([9], [row.get_object_with_index(0) for row in result])
+
+        result = self.client.sql.execute_statement(statement)
+        with self.assertRaises(HazelcastSqlError):
+            for _ in result:
+                pass
+
+    # Can't test the case where we would expect an update count, because the IMDG SQL
+    # engine does not support such a query as of now.
+    def test_execute_statement_with_expected_result_type_of_rows_when_rows_are_expected(self):
+        entry_count = 100
+        self._populate_map(entry_count, lambda v: Student(v, v))
+        statement = SqlStatement("SELECT age FROM %s WHERE age < 3" % self.map_name)
+        statement.expected_result_type = SqlExpectedResultType.ROWS
+        result = self.client.sql.execute_statement(statement)
+
+        six.assertCountEqual(self, [i for i in range(3)], [row.get_object("age") for row in result])
+
+    # Can't test the case where we would expect an update count, because the IMDG SQL
+    # engine does not support such a query as of now.
+    def test_execute_statement_with_expected_result_type_of_update_count_when_rows_are_expected(
+        self,
+    ):
+        self._populate_map()
+        statement = SqlStatement("SELECT * FROM %s" % self.map_name)
+        statement.expected_result_type = SqlExpectedResultType.UPDATE_COUNT
+        result = self.client.sql.execute_statement(statement)
+
+        with self.assertRaises(HazelcastSqlError):
+            for _ in result:
+                pass
+
+    # Can't test the schema, because the IMDG SQL engine does not support
+    # specifying a schema yet.
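Taken together, the statement options exercised by these tests compose as in the following sketch; the map name, query, and data are illustrative only, and a running cluster reachable with the default configuration is assumed:

    import hazelcast
    from hazelcast.sql import SqlStatement, SqlExpectedResultType

    client = hazelcast.HazelcastClient()

    statement = SqlStatement("SELECT age FROM students WHERE height = CAST(? AS REAL)")
    statement.add_parameter(13.0)
    statement.timeout = 5  # seconds; -1 (TIMEOUT_NOT_SET) defers to the member config
    statement.cursor_buffer_size = 1024  # rows fetched per page
    statement.expected_result_type = SqlExpectedResultType.ROWS

    # SqlResult is a context manager, so the server-side cursor is
    # closed even if the iteration below raises.
    with client.sql.execute_statement(statement) as result:
        for row in result:  # blocking iteration
            print(row.get_object("age"))

    client.shutdown()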
+
+
+@unittest.skipIf(
+    is_client_version_older_than("4.2"), "Tests the features added in 4.2 version of the client"
+)
+class SqlResultTest(SqlTestBase):
+    def test_blocking_iterator(self):
+        self._populate_map()
+        result = self.client.sql.execute("SELECT __key FROM %s" % self.map_name)
+
+        six.assertCountEqual(
+            self, [i for i in range(10)], [row.get_object_with_index(0) for row in result]
+        )
+
+    def test_blocking_iterator_when_iterator_requested_more_than_once(self):
+        self._populate_map()
+        result = self.client.sql.execute("SELECT this FROM %s" % self.map_name)
+
+        six.assertCountEqual(
+            self, [i for i in range(10)], [row.get_object_with_index(0) for row in result]
+        )
+
+        with self.assertRaises(ValueError):
+            for _ in result:
+                pass
+
+    def test_blocking_iterator_with_multi_paged_result(self):
+        self._populate_map()
+        statement = SqlStatement("SELECT __key FROM %s" % self.map_name)
+        statement.cursor_buffer_size = 1  # Each page will contain just 1 result
+        result = self.client.sql.execute_statement(statement)
+
+        six.assertCountEqual(
+            self, [i for i in range(10)], [row.get_object_with_index(0) for row in result]
+        )
+
+    def test_iterator(self):
+        self._populate_map()
+        result = self.client.sql.execute("SELECT __key FROM %s" % self.map_name)
+
+        iterator_future = result.iterator()
+
+        rows = []
+
+        def cb(f):
+            iterator = f.result()
+
+            def iterate(row_future):
+                try:
+                    row = row_future.result()
+                    rows.append(row.get_object_with_index(0))
+                    next(iterator).add_done_callback(iterate)
+                except StopIteration:
+                    pass
+
+            next(iterator).add_done_callback(iterate)
+
+        iterator_future.add_done_callback(cb)
+
+        def assertion():
+            six.assertCountEqual(
+                self,
+                [i for i in range(10)],
+                rows,
+            )
+
+        self.assertTrueEventually(assertion)
+
+    def test_iterator_when_iterator_requested_more_than_once(self):
+        self._populate_map()
+        result = self.client.sql.execute("SELECT this FROM %s" % self.map_name)
+
+        iterator = result.iterator().result()
+
+        rows = []
+        for row_future in iterator:
+            try:
+                row = row_future.result()
+                rows.append(row.get_object("this"))
+            except StopIteration:
+                break
+
+        six.assertCountEqual(self, [i for i in range(10)], rows)
+
+        with self.assertRaises(ValueError):
+            result.iterator().result()
+
+    def test_iterator_with_multi_paged_result(self):
+        self._populate_map()
+        statement = SqlStatement("SELECT __key FROM %s" % self.map_name)
+        statement.cursor_buffer_size = 1  # Each page will contain just 1 result
+        result = self.client.sql.execute_statement(statement)
+
+        iterator = result.iterator().result()
+
+        rows = []
+        for row_future in iterator:
+            try:
+                row = row_future.result()
+                rows.append(row.get_object_with_index(0))
+            except StopIteration:
+                break
+
+        six.assertCountEqual(self, [i for i in range(10)], rows)
+
+    def test_request_blocking_iterator_after_iterator(self):
+        self._populate_map()
+        result = self.client.sql.execute("SELECT * FROM %s" % self.map_name)
+
+        result.iterator().result()
+
+        with self.assertRaises(ValueError):
+            for _ in result:
+                pass
+
+    def test_request_iterator_after_blocking_iterator(self):
+        self._populate_map()
+        result = self.client.sql.execute("SELECT * FROM %s" % self.map_name)
+
+        for _ in result:
+            pass
+
+        with self.assertRaises(ValueError):
+            result.iterator().result()
+
+    # Can't test the case where we would expect the row set to be absent, because the IMDG SQL
+    # engine does not support update/insert queries yet.
+    def test_is_row_set_when_row_is_set(self):
+        self._populate_map()
+        result = self.client.sql.execute("SELECT * FROM %s" % self.map_name)
+        self.assertTrue(result.is_row_set().result())
+
+    # Can't test the case where we would expect a non-negative update count, because the IMDG SQL
+    # engine does not support update/insert queries yet.
+    def test_update_count_when_there_is_no_update(self):
+        self._populate_map()
+        result = self.client.sql.execute("SELECT * FROM %s WHERE __key > 5" % self.map_name)
+        self.assertEqual(-1, result.update_count().result())
+
+    def test_get_row_metadata(self):
+        self._populate_map(value_factory=str)
+        result = self.client.sql.execute("SELECT __key, this FROM %s" % self.map_name)
+        row_metadata = result.get_row_metadata().result()
+        self.assertEqual(2, row_metadata.column_count)
+        columns = row_metadata.columns
+        self.assertEqual(SqlColumnType.INTEGER, columns[0].type)
+        self.assertEqual(SqlColumnType.VARCHAR, columns[1].type)
+        self.assertTrue(columns[0].nullable)
+        self.assertTrue(columns[1].nullable)
+
+    def test_close_after_query_execution(self):
+        self._populate_map()
+        result = self.client.sql.execute("SELECT * FROM %s" % self.map_name)
+        for _ in result:
+            pass
+        self.assertIsNone(result.close().result())
+
+    def test_close_when_query_is_active(self):
+        self._populate_map()
+        statement = SqlStatement("SELECT * FROM %s " % self.map_name)
+        statement.cursor_buffer_size = 1  # Each page will contain 1 row
+        result = self.client.sql.execute_statement(statement)
+
+        # Fetch a couple of pages
+        iterator = iter(result)
+        next(iterator)
+
+        self.assertIsNone(result.close().result())
+
+        with self.assertRaises(HazelcastSqlError):
+            # Next fetch requests should fail
+            next(iterator)
+
+    def test_with_statement(self):
+        self._populate_map()
+        with self.client.sql.execute("SELECT this FROM %s" % self.map_name) as result:
+            six.assertCountEqual(
+                self, [i for i in range(10)], [row.get_object_with_index(0) for row in result]
+            )
+
+    def test_with_statement_when_iteration_throws(self):
+        self._populate_map()
+        statement = SqlStatement("SELECT this FROM %s" % self.map_name)
+        statement.cursor_buffer_size = 1  # so that it doesn't close immediately
+
+        with self.assertRaises(RuntimeError):
+            with self.client.sql.execute_statement(statement) as result:
+                for _ in result:
+                    raise RuntimeError("expected")
+
+        self.assertIsInstance(result.close(), ImmediateFuture)
+
+
+@unittest.skipIf(
+    is_client_version_older_than("4.2"), "Tests the features added in 4.2 version of the client"
+)
+class SqlColumnTypesReadTest(SqlTestBase):
+    def test_varchar(self):
+        def value_factory(key):
+            return "val-%s" % key
+
+        self._populate_map(value_factory=value_factory)
+        self._validate_rows(SqlColumnType.VARCHAR, value_factory)
+
+    def test_boolean(self):
+        def value_factory(key):
+            return key % 2 == 0
+
+        self._populate_map(value_factory=value_factory)
+        self._validate_rows(SqlColumnType.BOOLEAN, value_factory)
+
+    def test_tiny_int(self):
+        self._populate_map_via_rc("new java.lang.Byte(key)")
+        self._validate_rows(SqlColumnType.TINYINT)
+
+    def test_small_int(self):
+        self._populate_map_via_rc("new java.lang.Short(key)")
+        self._validate_rows(SqlColumnType.SMALLINT)
+
+    def test_integer(self):
+        self._populate_map_via_rc("new java.lang.Integer(key)")
+        self._validate_rows(SqlColumnType.INTEGER)
+
+    def test_big_int(self):
+        self._populate_map_via_rc("new java.lang.Long(key)")
+        self._validate_rows(SqlColumnType.BIGINT)
+
+    def test_real(self):
+        self._populate_map_via_rc("new java.lang.Float(key * 1.0 / 8)")
+        self._validate_rows(SqlColumnType.REAL, lambda x: x * 1.0 / 8)
+
+    def test_double(self):
+        self._populate_map_via_rc("new java.lang.Double(key * 1.0 / 1.1)")
+        self._validate_rows(SqlColumnType.DOUBLE, lambda x: x * 1.0 / 1.1)
+
+    def test_date(self):
+        def value_factory(key):
+            return "%d-%02d-%02d" % (key + 2000, key + 1, key + 1)
+
+        self._populate_map_via_rc("java.time.LocalDate.of(key + 2000, key + 1, key + 1)")
+        self._validate_rows(SqlColumnType.DATE, value_factory)
+
+    def test_time(self):
+        def value_factory(key):
+            time = "%02d:%02d:%02d" % (key, key, key)
+            if key != 0:
+                time += ".%06d" % key
+            return time
+
+        self._populate_map_via_rc("java.time.LocalTime.of(key, key, key, key * 1000)")
+        self._validate_rows(SqlColumnType.TIME, value_factory)
+
+    def test_timestamp(self):
+        def value_factory(key):
+            timestamp = "%d-%02d-%02dT%02d:%02d:%02d" % (
+                key + 2000,
+                key + 1,
+                key + 1,
+                key,
+                key,
+                key,
+            )
+            if key != 0:
+                timestamp += ".%06d" % key
+
+            return timestamp
+
+        self._populate_map_via_rc(
+            "java.time.LocalDateTime.of(key + 2000, key + 1, key + 1, key, key, key, key * 1000)"
+        )
+        self._validate_rows(SqlColumnType.TIMESTAMP, value_factory)
+
+    def test_timestamp_with_time_zone(self):
+        def value_factory(key):
+            timestamp = "%d-%02d-%02dT%02d:%02d:%02d" % (
+                key + 2000,
+                key + 1,
+                key + 1,
+                key,
+                key,
+                key,
+            )
+            if key != 0:
+                timestamp += ".%06d" % key
+
+            return timestamp + "+%02d:00" % key
+
+        self._populate_map_via_rc(
+            "java.time.OffsetDateTime.of(key + 2000, key + 1, key + 1, key, key, key, key * 1000, "
+            "java.time.ZoneOffset.ofHours(key))"
+        )
+        self._validate_rows(SqlColumnType.TIMESTAMP_WITH_TIME_ZONE, value_factory)
+
+    def test_decimal(self):
+        def value_factory(key):
+            return str(decimal.Decimal((0, (key,), -1 * key)))
+
+        self._populate_map_via_rc("java.math.BigDecimal.valueOf(key, key)")
+        self._validate_rows(SqlColumnType.DECIMAL, value_factory)
+
+    def test_null(self):
+        self._populate_map()
+        result = self.client.sql.execute("SELECT __key, NULL AS this FROM %s" % self.map_name)
+        self._validate_result(result, SqlColumnType.NULL, lambda _: None)
+
+    def test_object(self):
+        def value_factory(key):
+            return Student(key, key)
+
+        self._populate_map(value_factory=value_factory)
+        result = self.client.sql.execute("SELECT __key, this FROM %s" % self.map_name)
+        self._validate_result(result, SqlColumnType.OBJECT, value_factory)
+
+    def test_null_only_column(self):
+        self._populate_map()
+        result = self.client.sql.execute(
+            "SELECT __key, CAST(NULL AS INTEGER) as this FROM %s" % self.map_name
+        )
+        self._validate_result(result, SqlColumnType.INTEGER, lambda _: None)
+
+    def _validate_rows(self, expected_type, value_factory=lambda key: key):
+        result = self.client.sql.execute("SELECT __key, this FROM %s " % self.map_name)
+        self._validate_result(result, expected_type, value_factory)
+
+    def _validate_result(self, result, expected_type, factory):
+        for row in result:
+            key = row.get_object("__key")
+            expected_value = factory(key)
+            row_metadata = row.metadata
+
+            self.assertEqual(2, row_metadata.column_count)
+            column_metadata = row_metadata.get_column(1)
+            self.assertEqual(expected_type, column_metadata.type)
+            self.assertEqual(expected_value, row.get_object("this"))
+
+    def _populate_map_via_rc(self, new_object_literal):
+        script = """
+        var map = instance_0.getMap("%s");
+        for (var key = 0; key < 10; key++) {
+            map.set(new java.lang.Integer(key), %s);
+        }
+        """ % (
+            self.map_name,
+            new_object_literal,
+        )
+
+        response = self.rc.executeOnController(self.cluster.id, script, Lang.JAVASCRIPT)
+        self.assertTrue(response.success)
+
+
+LITE_MEMBER_CONFIG = """
+<hazelcast xmlns="http://www.hazelcast.com/schema/config"
+           xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
+           xsi:schemaLocation="http://www.hazelcast.com/schema/config
+           http://www.hazelcast.com/schema/config/hazelcast-config-4.0.xsd">
+    <lite-member enabled="true"/>
+</hazelcast>
+"""
+
+
+@unittest.skipIf(
+    is_client_version_older_than("4.2"), "Tests the features added in 4.2 version of the client"
+)
+class SqlServiceLiteMemberClusterTest(SingleMemberTestCase):
+    @classmethod
+    def configure_cluster(cls):
+        return LITE_MEMBER_CONFIG
+
+    @classmethod
+    def configure_client(cls, config):
+        config["cluster_name"] = cls.cluster.id
+        return config
+
+    def setUp(self):
+        mark_server_version_at_least(self, self.client, "4.2")
+
+    def test_execute(self):
+        with self.assertRaises(HazelcastSqlError) as cm:
+            self.client.sql.execute("SOME QUERY")
+
+        # Make sure that the exception is originating from the client
+        self.assertNotEqual(self.member.uuid, str(cm.exception.originating_member_uuid))
+
+    def test_execute_statement(self):
+        statement = SqlStatement("SOME QUERY")
+        with self.assertRaises(HazelcastSqlError) as cm:
+            self.client.sql.execute_statement(statement)
+
+        # Make sure that the exception is originating from the client
+        self.assertNotEqual(self.member.uuid, str(cm.exception.originating_member_uuid))
+
+
+class Student(Portable):
+    def __init__(self, age=None, height=None):
+        self.age = age
+        self.height = height
+
+    def write_portable(self, writer):
+        writer.write_long("age", self.age)
+        writer.write_float("height", self.height)
+
+    def read_portable(self, reader):
+        self.age = reader.read_long("age")
+        self.height = reader.read_float("height")
+
+    def get_factory_id(self):
+        return 666
+
+    def get_class_id(self):
+        return 6
+
+    def __eq__(self, other):
+        return isinstance(other, Student) and self.age == other.age and self.height == other.height
+
+
+def random_string():
+    return "".join(random.choice(string.ascii_letters) for _ in range(random.randint(3, 20)))
diff --git a/tests/unit/load_balancer_test.py b/tests/unit/load_balancer_test.py
index 125ab16915..d2dc5ca809 100644
--- a/tests/unit/load_balancer_test.py
+++ b/tests/unit/load_balancer_test.py
@@ -3,9 +3,27 @@
 from hazelcast.util import RandomLB, RoundRobinLB
 
 
+class _MockMember(object):
+    def __init__(self, id, lite_member):
+        self.id = id
+        self.lite_member = lite_member
+
+
 class _MockClusterService(object):
-    def __init__(self, members):
-        self._members = members
+    def __init__(self, data_member_count, lite_member_count):
+        id = 0
+        data_members = []
+        lite_members = []
+
+        for _ in range(data_member_count):
+            data_members.append(_MockMember(id, False))
+            id += 1
+
+        for _ in range(lite_member_count):
+            lite_members.append(_MockMember(id, True))
+            id += 1
+
+        self._members = data_members + lite_members
 
     def add_listener(self, listener, *_):
         for m in self._members:
@@ -17,27 +35,74 @@ def get_members(self):
 
 class LoadBalancersTest(unittest.TestCase):
     def test_random_lb_with_no_members(self):
-        cluster = _MockClusterService([])
+        cluster = _MockClusterService(0, 0)
         lb = RandomLB()
         lb.init(cluster)
         self.assertIsNone(lb.next())
+        self.assertIsNone(lb.next_data_member())
 
     def test_round_robin_lb_with_no_members(self):
-        cluster = _MockClusterService([])
+        cluster = _MockClusterService(0, 0)
         lb = RoundRobinLB()
         lb.init(cluster)
         self.assertIsNone(lb.next())
+        self.assertIsNone(lb.next_data_member())
 
-    def test_random_lb_with_members(self):
-        cluster = _MockClusterService([0, 1, 2])
+    def test_random_lb_with_data_members(self):
+        cluster = _MockClusterService(3, 0)
         lb = RandomLB()
         lb.init(cluster)
+        self.assertTrue(lb.can_get_next_data_member())
         for _ in range(10):
-            self.assertTrue(0 <= lb.next() <= 2)
+            self.assertTrue(0 <= lb.next().id <= 2)
+            self.assertTrue(0 <= lb.next_data_member().id <= 2)
 
-    def test_round_robin_lb_with_members(self):
-        cluster = _MockClusterService([0, 1, 2])
+    def test_round_robin_lb_with_data_members(self):
+        cluster = _MockClusterService(5, 0)
         lb = RoundRobinLB()
         lb.init(cluster)
+        self.assertTrue(lb.can_get_next_data_member())
+
+        for i in range(10):
+            self.assertEqual(i % 5, lb.next().id)
+
         for i in range(10):
-            self.assertEqual(i % 3, lb.next())
+            self.assertEqual(i % 5, lb.next_data_member().id)
+
+    def test_random_lb_with_lite_members(self):
+        cluster = _MockClusterService(0, 3)
+        lb = RandomLB()
+        lb.init(cluster)
+
+        for _ in range(10):
+            self.assertTrue(0 <= lb.next().id <= 2)
+            self.assertIsNone(lb.next_data_member())
+
+    def test_round_robin_lb_with_lite_members(self):
+        cluster = _MockClusterService(0, 3)
+        lb = RoundRobinLB()
+        lb.init(cluster)
+
+        for i in range(10):
+            self.assertEqual(i % 3, lb.next().id)
+            self.assertIsNone(lb.next_data_member())
+
+    def test_random_lb_with_mixed_members(self):
+        cluster = _MockClusterService(3, 3)
+        lb = RandomLB()
+        lb.init(cluster)
+
+        for _ in range(20):
+            self.assertTrue(0 <= lb.next().id < 6)
+            self.assertTrue(0 <= lb.next_data_member().id < 3)
+
+    def test_round_robin_lb_with_mixed_members(self):
+        cluster = _MockClusterService(3, 3)
+        lb = RoundRobinLB()
+        lb.init(cluster)
+
+        for i in range(24):
+            self.assertEqual(i % 6, lb.next().id)
+
+        for i in range(24):
+            self.assertEqual(i % 3, lb.next_data_member().id)
diff --git a/tests/unit/sql_test.py b/tests/unit/sql_test.py
new file mode 100644
index 0000000000..acd56018f1
--- /dev/null
+++ b/tests/unit/sql_test.py
@@ -0,0 +1,350 @@
+import itertools
+import unittest
+import uuid
+
+from mock import MagicMock
+
+from hazelcast.protocol.codec import sql_execute_codec, sql_close_codec, sql_fetch_codec
+from hazelcast.protocol.client_message import _OUTBOUND_MESSAGE_MESSAGE_TYPE_OFFSET
+from hazelcast.serialization import LE_INT
+from hazelcast.sql import (
+    SqlService,
+    SqlColumnMetadata,
+    SqlColumnType,
+    _SqlPage,
+    SqlRowMetadata,
+    _InternalSqlService,
+    HazelcastSqlError,
+    _SqlErrorCode,
+    _SqlError,
+    SqlStatement,
+    SqlExpectedResultType,
+)
+from hazelcast.util import try_to_get_enum_value
+
+EXPECTED_ROWS = ["result", "result2"]
+EXPECTED_UPDATE_COUNT = 42
+
+
+class SqlMockTest(unittest.TestCase):
+    def setUp(self):
+        self.connection = MagicMock()
+
+        connection_manager = MagicMock(client_uuid=uuid.uuid4())
+        connection_manager.get_random_connection = MagicMock(return_value=self.connection)
+
+        serialization_service = MagicMock()
+        serialization_service.to_object.side_effect = lambda arg: arg
+        serialization_service.to_data.side_effect = lambda arg: arg
+
+        self.invocation_registry = {}
+        correlation_id_counter = itertools.count()
+        invocation_service = MagicMock()
+
+        def invoke(invocation):
+            self.invocation_registry[next(correlation_id_counter)] = invocation
+
+        invocation_service.invoke.side_effect = invoke
+
+        self.internal_service = _InternalSqlService(
+            connection_manager, serialization_service, invocation_service
+        )
+        self.service = SqlService(self.internal_service)
+        self.result = self.service.execute("SOME QUERY")
+
+    def test_iterator_with_rows(self):
+        self.set_execute_response_with_rows()
+        self.assertEqual(-1, self.result.update_count().result())
+        self.assertTrue(self.result.is_row_set().result())
+        self.assertIsInstance(self.result.get_row_metadata().result(), SqlRowMetadata)
+        self.assertEqual(EXPECTED_ROWS, self.get_rows_from_iterator())
+
+    def test_blocking_iterator_with_rows(self):
+        self.set_execute_response_with_rows()
+        self.assertEqual(-1, self.result.update_count().result())
+        self.assertTrue(self.result.is_row_set().result())
+        self.assertIsInstance(self.result.get_row_metadata().result(), SqlRowMetadata)
+        self.assertEqual(EXPECTED_ROWS, self.get_rows_from_blocking_iterator())
+
+    def test_iterator_with_update_count(self):
+        self.set_execute_response_with_update_count()
+        self.assertEqual(EXPECTED_UPDATE_COUNT, self.result.update_count().result())
+        self.assertFalse(self.result.is_row_set().result())
+
+        with self.assertRaises(ValueError):
+            self.result.get_row_metadata().result()
+
+        with self.assertRaises(ValueError):
+            self.result.iterator().result()
+
+    def test_blocking_iterator_with_update_count(self):
+        self.set_execute_response_with_update_count()
+        self.assertEqual(EXPECTED_UPDATE_COUNT, self.result.update_count().result())
+        self.assertFalse(self.result.is_row_set().result())
+
+        with self.assertRaises(ValueError):
+            self.result.get_row_metadata().result()
+
+        with self.assertRaises(ValueError):
+            for _ in self.result:
+                pass
+
+    def test_execute_error(self):
+        self.set_execute_error(RuntimeError("expected"))
+        with self.assertRaises(HazelcastSqlError) as cm:
+            iter(self.result)
+
+        self.assertEqual(_SqlErrorCode.GENERIC, cm.exception._code)
+
+    def test_execute_error_when_connection_is_not_live(self):
+        self.connection.live = False
+        self.set_execute_error(RuntimeError("expected"))
+        with self.assertRaises(HazelcastSqlError) as cm:
+            iter(self.result)
+
+        self.assertEqual(_SqlErrorCode.CONNECTION_PROBLEM, cm.exception._code)
+
+    def test_close_when_execute_is_not_done(self):
+        future = self.result.close()
+        self.set_close_response()
+        self.assertIsNone(future.result())
+        with self.assertRaises(HazelcastSqlError) as cm:
+            iter(self.result)
+
+        self.assertEqual(_SqlErrorCode.CANCELLED_BY_USER, cm.exception._code)
+
+    def test_close_when_close_request_fails(self):
+        future = self.result.close()
+        self.set_close_error(HazelcastSqlError(None, _SqlErrorCode.MAP_DESTROYED, "expected", None))
+
+        with self.assertRaises(HazelcastSqlError) as cm:
+            future.result()
+
+        self.assertEqual(_SqlErrorCode.MAP_DESTROYED, cm.exception._code)
+
+    def test_fetch_error(self):
+        self.set_execute_response_with_rows(is_last=False)
+        result = []
+        i = self.result.iterator().result()
+        # First page contains two rows
+        result.append(next(i).result().get_object_with_index(0))
+        result.append(next(i).result().get_object_with_index(0))
+
+        self.assertEqual(EXPECTED_ROWS, result)
+
+        # initiate the fetch request
+        future = next(i)
+
+        self.set_fetch_error(RuntimeError("expected"))
+
+        with self.assertRaises(HazelcastSqlError) as cm:
+            future.result()
+
+        self.assertEqual(_SqlErrorCode.GENERIC, cm.exception._code)
+
+    def test_fetch_server_error(self):
+        self.set_execute_response_with_rows(is_last=False)
+        result = []
+        i = self.result.iterator().result()
+        # First page contains two rows
+        result.append(next(i).result().get_object_with_index(0))
+        result.append(next(i).result().get_object_with_index(0))
+
+        self.assertEqual(EXPECTED_ROWS, result)
+
+        # initiate the fetch request
+        future = next(i)
+
+        self.set_fetch_response_with_error()
+
+        with self.assertRaises(HazelcastSqlError) as cm:
+            future.result()
+
+        self.assertEqual(_SqlErrorCode.PARSING, cm.exception._code)
+
+    def test_close_in_between_fetches(self):
+        self.set_execute_response_with_rows(is_last=False)
+        result = []
+        i = self.result.iterator().result()
+        # First page contains two rows
+        result.append(next(i).result().get_object_with_index(0))
+        result.append(next(i).result().get_object_with_index(0))
+
+        self.assertEqual(EXPECTED_ROWS, result)
+
+        # initiate the fetch request
+        future = next(i)
+
+        self.result.close()
+
+        with self.assertRaises(HazelcastSqlError) as cm:
+            future.result()
+
+        self.assertEqual(_SqlErrorCode.CANCELLED_BY_USER, cm.exception._code)
+
+    def set_fetch_response_with_error(self):
+        response = {"row_page": None, "error": _SqlError(_SqlErrorCode.PARSING, "expected", None)}
+        self.set_future_result_or_exception(response, sql_fetch_codec._REQUEST_MESSAGE_TYPE)
+
+    def set_fetch_error(self, error):
+        self.set_future_result_or_exception(error, sql_fetch_codec._REQUEST_MESSAGE_TYPE)
+
+    def set_close_error(self, error):
+        self.set_future_result_or_exception(error, sql_close_codec._REQUEST_MESSAGE_TYPE)
+
+    def set_close_response(self):
+        self.set_future_result_or_exception(None, sql_close_codec._REQUEST_MESSAGE_TYPE)
+
+    def set_execute_response_with_update_count(self):
+        self.set_execute_response(EXPECTED_UPDATE_COUNT, None, None, None)
+
+    def get_rows_from_blocking_iterator(self):
+        return [row.get_object_with_index(0) for row in self.result]
+
+    def get_rows_from_iterator(self):
+        result = []
+        for row_future in self.result.iterator().result():
+            try:
+                row = row_future.result()
+                result.append(row.get_object_with_index(0))
+            except StopIteration:
+                break
+        return result
+
+    def set_execute_response_with_rows(self, is_last=True):
+        self.set_execute_response(
+            -1,
+            [SqlColumnMetadata("name", SqlColumnType.VARCHAR, True, True)],
+            _SqlPage([SqlColumnType.VARCHAR], [EXPECTED_ROWS], is_last),
+            None,
+        )
+
+    def set_execute_response(self, update_count, row_metadata, row_page, error):
+        response = {
+            "update_count": update_count,
+            "row_metadata": row_metadata,
+            "row_page": row_page,
+            "error": error,
+        }
+
+        self.set_future_result_or_exception(response, sql_execute_codec._REQUEST_MESSAGE_TYPE)
+
+    def set_execute_error(self, error):
+        self.set_future_result_or_exception(error, sql_execute_codec._REQUEST_MESSAGE_TYPE)
+
+    def get_message_type(self, invocation):
+        return LE_INT.unpack_from(invocation.request.buf, _OUTBOUND_MESSAGE_MESSAGE_TYPE_OFFSET)[0]
+
+    def set_future_result_or_exception(self, value, message_type):
+        for invocation in self.invocation_registry.values():
+            if self.get_message_type(invocation) == message_type:
+                if isinstance(value, Exception):
+                    invocation.future.set_exception(value)
+                else:
+                    invocation.future.set_result(value)
+
+
+class SqlInvalidInputTest(unittest.TestCase):
+    def test_statement_sql(self):
+        valid_inputs = ["a", " a", " a "]
+
+        for valid in valid_inputs:
+            statement = SqlStatement(valid)
+            self.assertEqual(valid, statement.sql)
+
+        invalid_inputs = ["", " ", None, 1]
+
+        for invalid in invalid_inputs:
+            with self.assertRaises((ValueError, AssertionError)):
+                SqlStatement(invalid)
+
+    def test_statement_timeout(self):
+        valid_inputs = [-1, 0, 15, 1.5]
+
+        for valid in valid_inputs:
+            statement = SqlStatement("sql")
+            statement.timeout = valid
+            self.assertEqual(valid, statement.timeout)
+
+        invalid_inputs = [-10, -100, "hey", None]
+
+        for invalid in invalid_inputs:
+            statement = SqlStatement("sql")
+            with self.assertRaises((ValueError, AssertionError)):
+                statement.timeout = invalid
+
+    def test_statement_cursor_buffer_size(self):
+        valid_inputs = [1, 10, 999999]
+
+        for valid in valid_inputs:
+            statement = SqlStatement("something")
+            statement.cursor_buffer_size = valid
+            self.assertEqual(valid, statement.cursor_buffer_size)
+
+        invalid_inputs = [0, -10, -99999, "hey", None, 1.0]
+
+        for invalid in invalid_inputs:
+            statement = SqlStatement("something")
+            with self.assertRaises((ValueError, AssertionError)):
+                statement.cursor_buffer_size = invalid
+
+    def test_statement_expected_result_type(self):
+        valid_inputs = [
+            SqlExpectedResultType.ROWS,
+            SqlExpectedResultType.UPDATE_COUNT,
+            "ROWS",
+            "ANY",
+        ]
+
+        for valid in valid_inputs:
+            statement = SqlStatement("something")
+            statement.expected_result_type = valid
+            self.assertEqual(
+                try_to_get_enum_value(valid, SqlExpectedResultType), statement.expected_result_type
+            )
+
+        invalid_inputs = [None, 123, "hey"]
+
+        for invalid in invalid_inputs:
+            with self.assertRaises(TypeError):
+                statement = SqlStatement("something")
+                statement.expected_result_type = invalid
+
+    def test_row_metadata_get_column(self):
+        row_metadata = self._create_row_metadata()
+
+        valid_inputs = [0, 1, 2]
+
+        for valid in valid_inputs:
+            column_metadata = row_metadata.get_column(valid)
+            self.assertEqual(str(valid), column_metadata.name)
+
+        invalid_inputs = [4, 5, "6", None]
+        for invalid in invalid_inputs:
+            with self.assertRaises((IndexError, AssertionError)):
+                row_metadata.get_column(invalid)
+
+    def test_row_metadata_find_column(self):
+        row_metadata = self._create_row_metadata()
+
+        valid_inputs = ["0", "1", "2", "-1"]
+
+        for valid in valid_inputs:
+            index = row_metadata.find_column(valid)
+            self.assertEqual(int(valid), index)
+
+        invalid_inputs = [6, None]
+        for invalid in invalid_inputs:
+            with self.assertRaises((IndexError, AssertionError)):
+                row_metadata.get_column(invalid)
+
+    @staticmethod
+    def _create_row_metadata():
+        return SqlRowMetadata(
+            [
+                SqlColumnMetadata("0", SqlColumnType.VARCHAR, True, True),
+                SqlColumnMetadata("1", SqlColumnType.TINYINT, True, True),
+                SqlColumnMetadata("2", SqlColumnType.OBJECT, True, True),
+            ]
+        )
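For comparison with the blocking style used by most of the tests above, the callback-driven consumption that test_iterator exercises looks roughly like this in application code; the map name and query are illustrative only:

    result = client.sql.execute("SELECT __key, this FROM my_map")

    def on_iterator(iterator_future):
        iterator = iterator_future.result()

        def on_row(row_future):
            try:
                row = row_future.result()
                print(row.get_object("__key"), row.get_object("this"))
                # Request the next row; this may trigger a page fetch.
                next(iterator).add_done_callback(on_row)
            except StopIteration:
                pass  # all pages are consumed

        next(iterator).add_done_callback(on_row)

    result.iterator().add_done_callback(on_iterator)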