Skip to content

Commit

Permalink
[improve][broker] perf: Reduce stickyHash calculations of non-persist…
Browse files Browse the repository at this point in the history
…ent topics in SHARED subscriptions (#22536)
  • Loading branch information
dlg99 authored Apr 26, 2024
1 parent 8323a3c commit bf5d6aa
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 14 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -286,16 +286,29 @@ public Future<Void> sendMessages(final List<? extends Entry> entries, EntryBatch
totalChunkedMessages, redeliveryTracker, DEFAULT_CONSUMER_EPOCH);
}

public Future<Void> sendMessages(final List<? extends Entry> entries, EntryBatchSizes batchSizes,
EntryBatchIndexesAcks batchIndexesAcks,
int totalMessages, long totalBytes, long totalChunkedMessages,
RedeliveryTracker redeliveryTracker, long epoch) {
return sendMessages(entries, null, batchSizes, batchIndexesAcks, totalMessages, totalBytes,
totalChunkedMessages, redeliveryTracker, epoch);
}

/**
* Dispatch a list of entries to the consumer. <br/>
* <b>It is also responsible to release entries data and recycle entries object.</b>
*
* @return a SendMessageInfo object that contains the detail of what was sent to consumer
*/
public Future<Void> sendMessages(final List<? extends Entry> entries, EntryBatchSizes batchSizes,
public Future<Void> sendMessages(final List<? extends Entry> entries,
final List<Integer> stickyKeyHashes,
EntryBatchSizes batchSizes,
EntryBatchIndexesAcks batchIndexesAcks,
int totalMessages, long totalBytes, long totalChunkedMessages,
RedeliveryTracker redeliveryTracker, long epoch) {
int totalMessages,
long totalBytes,
long totalChunkedMessages,
RedeliveryTracker redeliveryTracker,
long epoch) {
this.lastConsumedTimestamp = System.currentTimeMillis();

if (entries.isEmpty() || totalMessages == 0) {
Expand Down Expand Up @@ -323,7 +336,7 @@ public Future<Void> sendMessages(final List<? extends Entry> entries, EntryBatch
// because this consumer is possible to disconnect at this time.
if (pendingAcks != null) {
int batchSize = batchSizes.getBatchSize(i);
int stickyKeyHash = getStickyKeyHash(entry);
int stickyKeyHash = stickyKeyHashes == null ? getStickyKeyHash(entry) : stickyKeyHashes.get(i);
long[] ackSet = batchIndexesAcks == null ? null : batchIndexesAcks.getAckSet(i);
if (ackSet != null) {
unackedMessages -= (batchSize - BitSet.valueOf(ackSet).cardinality());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,14 @@ protected Map<Consumer, List<Entry>> initialValue() throws Exception {
}
};

private static final FastThreadLocal<Map<Consumer, List<Integer>>> localGroupedStickyKeyHashes =
new FastThreadLocal<Map<Consumer, List<Integer>>>() {
@Override
protected Map<Consumer, List<Integer>> initialValue() throws Exception {
return new HashMap<>();
}
};

@Override
public void sendMessages(List<Entry> entries) {
if (entries.isEmpty()) {
Expand All @@ -139,28 +147,38 @@ public void sendMessages(List<Entry> entries) {

final Map<Consumer, List<Entry>> groupedEntries = localGroupedEntries.get();
groupedEntries.clear();
final Map<Consumer, List<Integer>> consumerStickyKeyHashesMap = localGroupedStickyKeyHashes.get();
consumerStickyKeyHashesMap.clear();

for (Entry entry : entries) {
Consumer consumer = selector.select(peekStickyKey(entry.getDataBuffer()));
byte[] stickyKey = peekStickyKey(entry.getDataBuffer());
int stickyKeyHash = StickyKeyConsumerSelector.makeStickyKeyHash(stickyKey);

Consumer consumer = selector.select(stickyKeyHash);
if (consumer != null) {
groupedEntries.computeIfAbsent(consumer, k -> new ArrayList<>()).add(entry);
int startingSize = Math.max(10, entries.size() / (2 * consumerSet.size()));
groupedEntries.computeIfAbsent(consumer, k -> new ArrayList<>(startingSize)).add(entry);
consumerStickyKeyHashesMap
.computeIfAbsent(consumer, k -> new ArrayList<>(startingSize)).add(stickyKeyHash);
} else {
entry.release();
}
}

for (Map.Entry<Consumer, List<Entry>> entriesByConsumer : groupedEntries.entrySet()) {
Consumer consumer = entriesByConsumer.getKey();
List<Entry> entriesForConsumer = entriesByConsumer.getValue();
final Consumer consumer = entriesByConsumer.getKey();
final List<Entry> entriesForConsumer = entriesByConsumer.getValue();
final List<Integer> stickyKeysForConsumer = consumerStickyKeyHashesMap.get(consumer);

SendMessageInfo sendMessageInfo = SendMessageInfo.getThreadLocal();
EntryBatchSizes batchSizes = EntryBatchSizes.get(entriesForConsumer.size());
filterEntriesForConsumer(entriesForConsumer, batchSizes, sendMessageInfo, null, null, false, consumer);

if (consumer.getAvailablePermits() > 0 && consumer.isWritable()) {
consumer.sendMessages(entriesForConsumer, batchSizes, null, sendMessageInfo.getTotalMessages(),
consumer.sendMessages(entriesForConsumer, stickyKeysForConsumer, batchSizes,
null, sendMessageInfo.getTotalMessages(),
sendMessageInfo.getTotalBytes(), sendMessageInfo.getTotalChunkedMessages(),
getRedeliveryTracker());
getRedeliveryTracker(), Commands.DEFAULT_CONSUMER_EPOCH);
TOTAL_AVAILABLE_PERMITS_UPDATER.addAndGet(this, -sendMessageInfo.getTotalMessages());
} else {
entriesForConsumer.forEach(e -> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -128,15 +128,15 @@ public void testSendMessage() throws BrokerServiceException {
assertEquals(byteBuf.toString(UTF_8), "message" + index);
};
return mockPromise;
}).when(consumerMock).sendMessages(any(List.class), any(EntryBatchSizes.class), any(),
anyInt(), anyLong(), anyLong(), any(RedeliveryTracker.class));
}).when(consumerMock).sendMessages(any(List.class), any(List.class), any(EntryBatchSizes.class), any(),
anyInt(), anyLong(), anyLong(), any(RedeliveryTracker.class), anyLong());
try {
nonpersistentDispatcher.sendMessages(entries);
} catch (Exception e) {
fail("Failed to sendMessages.", e);
}
verify(consumerMock, times(1)).sendMessages(any(List.class), any(EntryBatchSizes.class),
eq(null), anyInt(), anyLong(), anyLong(), any(RedeliveryTracker.class));
verify(consumerMock, times(1)).sendMessages(any(List.class), any(List.class), any(EntryBatchSizes.class),
eq(null), anyInt(), anyLong(), anyLong(), any(RedeliveryTracker.class), anyLong());
}

@Test(timeOut = 10000)
Expand Down

0 comments on commit bf5d6aa

Please sign in to comment.