From acbd95d8c62db0f2fd54130e98561332cef0f747 Mon Sep 17 00:00:00 2001
From: Viacheslav Moskalenko <jason.rammoray@gmail.com>
Date: Sun, 9 Jun 2024 23:58:21 +0300
Subject: [PATCH] Support max_{messages,bytes} parameters, when reading a batch
 through a sync topic reader

---
 CHANGELOG.md                           |  3 ++
 ydb/_topic_reader/topic_reader_sync.py | 55 +++++++++++++++++++++++++-
 2 files changed, 57 insertions(+), 1 deletion(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 48949d4b..5f445f78 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,3 +1,6 @@
+## 3.11.5 ##
+* Added support for max_messages and max_bytes parameters, when reading a batch through a sync topic reader
+
 ## 3.11.4 ##
 * Added missing returns to time converters for topic options
 
diff --git a/ydb/_topic_reader/topic_reader_sync.py b/ydb/_topic_reader/topic_reader_sync.py
index c266de82..6a0e9f9f 100644
--- a/ydb/_topic_reader/topic_reader_sync.py
+++ b/ydb/_topic_reader/topic_reader_sync.py
@@ -86,6 +86,57 @@ def async_wait_message(self) -> concurrent.futures.Future:
 
         return self._caller.unsafe_call_with_future(self._async_reader.wait_message())
 
+    def _make_batch_slice(
+        self,
+        batch: Union[PublicBatch, None],
+        max_messages: typing.Union[int, None] = None,
+        max_bytes: typing.Union[int, None] = None,
+    ) -> Union[PublicBatch, None]:
+        all_amount = float("inf")
+
+        # A non-empty batch must stay non-empty regardless of the max messages value
+        if max_messages is not None:
+            max_messages = max(max_messages, 1)
+        else:
+            max_messages = all_amount
+
+        if max_bytes is not None:
+            max_bytes = max(max_bytes, 1)
+        else:
+            max_bytes = all_amount
+
+        is_batch_set = batch is not None
+        is_msg_limit_set = max_messages < all_amount
+        is_bytes_limit_set = max_bytes < all_amount
+        is_limit_set = is_msg_limit_set or is_bytes_limit_set
+        is_slice_required = is_batch_set and is_limit_set
+
+        if not is_slice_required:
+            return batch
+
+        sliced_messages = []
+        bytes_taken = 0
+
+        for batch_message in batch.messages:
+            sliced_messages.append(batch_message)
+            bytes_taken += len(batch_message.data)
+
+            is_enough_messages = len(sliced_messages) >= max_messages
+            is_enough_bytes = bytes_taken >= max_bytes
+            is_stop_required = is_enough_messages or is_enough_bytes
+
+            if is_stop_required:
+                break
+
+        sliced_batch = PublicBatch(
+            messages=sliced_messages,
+            _partition_session=batch._partition_session,
+            _bytes_size=bytes_taken,
+            _codec=batch._codec,
+        )
+
+        return sliced_batch
+
     def receive_batch(
         self,
         *,
@@ -102,11 +153,13 @@ def receive_batch(
         """
         self._check_closed()
 
-        return self._caller.safe_call_with_result(
+        maybe_batch: Union[PublicBatch, None] = self._caller.safe_call_with_result(
             self._async_reader.receive_batch(),
             timeout,
         )
 
+        return self._make_batch_slice(maybe_batch, max_messages, max_bytes)
+
     def commit(self, mess: typing.Union[datatypes.PublicMessage, datatypes.PublicBatch]):
         """
         Put commit message to internal buffer.