@@ -52,8 +52,8 @@ def default_executor():
52
52
executor .shutdown ()
53
53
54
54
55
- def stub_partition_session (id : int = 0 ):
56
- return datatypes .PartitionSession (
55
+ def stub_partition_session (id : int = 0 , ended : bool = False ):
56
+ partition_session = datatypes .PartitionSession (
57
57
id = id ,
58
58
state = datatypes .PartitionSession .State .Active ,
59
59
topic_path = "asd" ,
@@ -63,6 +63,11 @@ def stub_partition_session(id: int = 0):
63
63
reader_stream_id = 513 ,
64
64
)
65
65
66
+ if ended :
67
+ partition_session .end ()
68
+
69
+ return partition_session
70
+
66
71
67
72
def stub_message (id : int ):
68
73
return PublicMessage (
@@ -73,7 +78,7 @@ def stub_message(id: int):
73
78
offset = 0 ,
74
79
written_at = datetime .datetime (2023 , 3 , 18 , 14 , 15 ),
75
80
producer_id = "" ,
76
- data = bytes () ,
81
+ data = id ,
77
82
metadata_items = {},
78
83
_partition_session = stub_partition_session (),
79
84
_commit_start_offset = 0 ,
@@ -746,6 +751,31 @@ def session_count():
746
751
with pytest .raises (asyncio .QueueEmpty ):
747
752
stream .from_client .get_nowait ()
748
753
754
+ async def test_end_partition_session (self , stream , stream_reader , partition_session ):
755
+ def session_count ():
756
+ return len (stream_reader ._partition_sessions )
757
+
758
+ initial_session_count = session_count ()
759
+
760
+ stream .from_server .put_nowait (
761
+ StreamReadMessage .FromServer (
762
+ server_status = ServerStatus (ydb_status_codes_pb2 .StatusIds .SUCCESS , []),
763
+ server_message = StreamReadMessage .EndPartitionSession (
764
+ partition_session_id = partition_session .id ,
765
+ adjacent_partition_ids = [],
766
+ child_partition_ids = [20 , 30 ],
767
+ ),
768
+ )
769
+ )
770
+
771
+ await asyncio .sleep (0 ) # wait next loop
772
+ with pytest .raises (asyncio .QueueEmpty ):
773
+ stream .from_client .get_nowait ()
774
+
775
+ assert session_count () == initial_session_count
776
+ assert partition_session .id in stream_reader ._partition_sessions
777
+ assert partition_session .ended
778
+
749
779
@pytest .mark .parametrize (
750
780
"graceful" ,
751
781
(
@@ -1168,6 +1198,82 @@ async def test_read_message(
1168
1198
assert mess == expected_message
1169
1199
assert dict (stream_reader ._message_batches ) == batches_after
1170
1200
1201
+ @pytest .mark .parametrize (
1202
+ "batches,expected_order" ,
1203
+ [
1204
+ (
1205
+ {
1206
+ 0 : PublicBatch (
1207
+ messages = [stub_message (1 )],
1208
+ _partition_session = stub_partition_session (0 , ended = True ),
1209
+ _bytes_size = 0 ,
1210
+ _codec = Codec .CODEC_RAW ,
1211
+ )
1212
+ },
1213
+ [1 ],
1214
+ ),
1215
+ (
1216
+ {
1217
+ 0 : PublicBatch (
1218
+ messages = [stub_message (1 ), stub_message (2 )],
1219
+ _partition_session = stub_partition_session (0 , ended = True ),
1220
+ _bytes_size = 0 ,
1221
+ _codec = Codec .CODEC_RAW ,
1222
+ ),
1223
+ 1 : PublicBatch (
1224
+ messages = [stub_message (3 ), stub_message (4 )],
1225
+ _partition_session = stub_partition_session (1 ),
1226
+ _bytes_size = 0 ,
1227
+ _codec = Codec .CODEC_RAW ,
1228
+ ),
1229
+ },
1230
+ [1 , 2 , 3 , 4 ],
1231
+ ),
1232
+ (
1233
+ {
1234
+ 0 : PublicBatch (
1235
+ messages = [stub_message (1 ), stub_message (2 )],
1236
+ _partition_session = stub_partition_session (0 ),
1237
+ _bytes_size = 0 ,
1238
+ _codec = Codec .CODEC_RAW ,
1239
+ ),
1240
+ 1 : PublicBatch (
1241
+ messages = [stub_message (3 ), stub_message (4 )],
1242
+ _partition_session = stub_partition_session (1 , ended = True ),
1243
+ _bytes_size = 0 ,
1244
+ _codec = Codec .CODEC_RAW ,
1245
+ ),
1246
+ 2 : PublicBatch (
1247
+ messages = [stub_message (5 )],
1248
+ _partition_session = stub_partition_session (2 ),
1249
+ _bytes_size = 0 ,
1250
+ _codec = Codec .CODEC_RAW ,
1251
+ ),
1252
+ },
1253
+ [1 , 3 , 4 , 5 , 2 ],
1254
+ ),
1255
+ ],
1256
+ )
1257
+ async def test_read_message_autosplit_order (
1258
+ self ,
1259
+ stream_reader ,
1260
+ batches : typing .Dict [int , datatypes .PublicBatch ],
1261
+ expected_order : typing .List [int ],
1262
+ ):
1263
+ stream_reader ._message_batches = OrderedDict (batches )
1264
+
1265
+ for id , batch in batches .items ():
1266
+ ps = batch ._partition_session
1267
+ stream_reader ._partition_sessions [id ] = ps
1268
+
1269
+ result = []
1270
+ for _ in range (len (expected_order )):
1271
+ mess = stream_reader .receive_message_nowait ()
1272
+ result .append (mess .data )
1273
+
1274
+ assert result == expected_order
1275
+ assert stream_reader .receive_message_nowait () is None
1276
+
1171
1277
@pytest .mark .parametrize (
1172
1278
"batches_before,max_messages,actual_messages,batches_after" ,
1173
1279
[
0 commit comments