diff --git a/examples/topic/topic_example.py b/examples/topic/topic_example.py index ec5d0f4443..4782e5e564 100644 --- a/examples/topic/topic_example.py +++ b/examples/topic/topic_example.py @@ -16,4 +16,7 @@ def on_message(event): topic.publish("Message " + str(i)) time.sleep(0.1) +topic.publish_all(["m1", "m2", "m3", "m4", "m5"]) +time.sleep(1) + client.shutdown() diff --git a/hazelcast/protocol/codec/topic_publish_all_codec.py b/hazelcast/protocol/codec/topic_publish_all_codec.py new file mode 100644 index 0000000000..6ed1f3b7a5 --- /dev/null +++ b/hazelcast/protocol/codec/topic_publish_all_codec.py @@ -0,0 +1,18 @@ +from hazelcast.protocol.client_message import OutboundMessage, REQUEST_HEADER_SIZE, create_initial_buffer +from hazelcast.protocol.builtin import StringCodec +from hazelcast.protocol.builtin import ListMultiFrameCodec +from hazelcast.protocol.builtin import DataCodec + +# hex: 0x040400 +_REQUEST_MESSAGE_TYPE = 263168 +# hex: 0x040401 +_RESPONSE_MESSAGE_TYPE = 263169 + +_REQUEST_INITIAL_FRAME_SIZE = REQUEST_HEADER_SIZE + + +def encode_request(name, messages): + buf = create_initial_buffer(_REQUEST_INITIAL_FRAME_SIZE, _REQUEST_MESSAGE_TYPE) + StringCodec.encode(buf, name) + ListMultiFrameCodec.encode(buf, messages, DataCodec.encode, True) + return OutboundMessage(buf, False) diff --git a/hazelcast/proxy/topic.py b/hazelcast/proxy/topic.py index 785fc9c6ed..a17df87435 100644 --- a/hazelcast/proxy/topic.py +++ b/hazelcast/proxy/topic.py @@ -4,11 +4,13 @@ from hazelcast.protocol.codec import ( topic_add_message_listener_codec, topic_publish_codec, + topic_publish_all_codec, topic_remove_message_listener_codec, ) from hazelcast.proxy.base import PartitionSpecificProxy, TopicMessage -from hazelcast.types import MessageType from hazelcast.serialization.compact import SchemaNotReplicatedError +from hazelcast.types import MessageType +from hazelcast.util import check_not_none class Topic(PartitionSpecificProxy["BlockingTopic"], typing.Generic[MessageType]): @@ -57,7 +59,7 @@ def handle(item_data, publish_time, uuid): ) def publish(self, message: MessageType) -> Future[None]: - """Publishes the message to all subscribers of this topic + """Publishes the message to all subscribers of this topic. Args: message: The message to be published. @@ -70,6 +72,25 @@ def publish(self, message: MessageType) -> Future[None]: request = topic_publish_codec.encode_request(self.name, message_data) return self._invoke(request) + def publish_all(self, messages: typing.Sequence[MessageType]) -> Future[None]: + """Publishes the messages to all subscribers of this topic. + + Args: + messages: The messages to be published. + """ + check_not_none(messages, "Messages cannot be None") + try: + topic_messages = [] + for m in messages: + check_not_none(m, "Message cannot be None") + data = self._to_data(m) + topic_messages.append(data) + except SchemaNotReplicatedError as e: + return self._send_schema_and_retry(e, self.publish_all, messages) + + request = topic_publish_all_codec.encode_request(self.name, topic_messages) + return self._invoke(request) + def remove_listener(self, registration_id: str) -> Future[bool]: """Stops receiving messages for the given message listener. @@ -107,6 +128,12 @@ def publish( # type: ignore[override] ) -> None: return self._wrapped.publish(message).result() + def publish_all( # type: ignore[override] + self, + messages: typing.Sequence[MessageType], + ) -> None: + return self._wrapped.publish_all(messages).result() + def remove_listener( # type: ignore[override] self, registration_id: str, diff --git a/tests/integration/backward_compatible/proxy/topic_test.py b/tests/integration/backward_compatible/proxy/topic_test.py index eae452432e..eda0a82525 100644 --- a/tests/integration/backward_compatible/proxy/topic_test.py +++ b/tests/integration/backward_compatible/proxy/topic_test.py @@ -1,5 +1,5 @@ from tests.base import SingleMemberTestCase -from tests.util import random_string, event_collector +from tests.util import random_string, event_collector, skip_if_client_version_older_than class TopicTest(SingleMemberTestCase): @@ -44,3 +44,27 @@ def assert_event(): def test_str(self): self.assertTrue(str(self.topic).startswith("Topic")) + + def test_publish_all(self): + skip_if_client_version_older_than(self, "5.2") + collector = event_collector() + self.topic.add_listener(on_message=collector) + + messages = ["message1", "message2", "message3"] + self.topic.publish_all(messages) + + def assert_event(): + self.assertEqual(len(collector.events), 3) + + self.assertTrueEventually(assert_event, 5) + + def test_publish_all_none_messages(self): + skip_if_client_version_older_than(self, "5.2") + with self.assertRaises(AssertionError): + self.topic.publish_all(None) + + def test_publish_all_none_message(self): + skip_if_client_version_older_than(self, "5.2") + messages = ["message1", None, "message3"] + with self.assertRaises(AssertionError): + self.topic.publish_all(messages) diff --git a/tests/integration/backward_compatible/serialization/compact_compatibility/compact_compatibility_test.py b/tests/integration/backward_compatible/serialization/compact_compatibility/compact_compatibility_test.py index 178c336037..eb63e53020 100644 --- a/tests/integration/backward_compatible/serialization/compact_compatibility/compact_compatibility_test.py +++ b/tests/integration/backward_compatible/serialization/compact_compatibility/compact_compatibility_test.py @@ -6,7 +6,12 @@ from hazelcast.errors import NullPointerError, IllegalMonitorStateError from hazelcast.predicate import Predicate, paging from tests.base import HazelcastTestCase -from tests.util import random_string, compare_client_version, compare_server_version_with_rc +from tests.util import ( + random_string, + compare_client_version, + compare_server_version_with_rc, + skip_if_client_version_older_than, +) try: from hazelcast.serialization.api import ( @@ -1402,6 +1407,24 @@ def assertion(): def test_publish(self): self.topic.publish(OUTER_COMPACT_INSTANCE) + def test_publish_all(self): + skip_if_client_version_older_than(self, "5.2") + messages = [] + + def listener(message): + messages.append(message) + + self.topic.add_listener(listener) + + self.topic.publish_all([INNER_COMPACT_INSTANCE, OUTER_COMPACT_INSTANCE]) + + def assertion(): + self.assertEqual(2, len(messages)) + self.assertEqual(INNER_COMPACT_INSTANCE, messages[0].message) + self.assertEqual(OUTER_COMPACT_INSTANCE, messages[1].message) + + self.assertTrueEventually(assertion) + def _publish_from_another_client(self, item): other_client = self.create_client(self.client_config) other_client_topic = other_client.get_topic(self.topic.name).blocking()