From 999d0b3960e209dc2fc6422400cf1ec64751c6d9 Mon Sep 17 00:00:00 2001 From: Nicholas Sun Date: Wed, 9 Feb 2022 10:37:42 -0800 Subject: [PATCH] Fix offset when a batch ends with compacted records Saves the lastOffset and jumps past it when compacted records are detected at the end of a batch. - Adds a test for batches that end with compacted records - Adds a test for batches truncated due to MaxBytes Co-authored-by: iddqdeika --- batch.go | 32 +++++- message_reader.go | 24 ++++- message_test.go | 6 +- reader_test.go | 256 ++++++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 311 insertions(+), 7 deletions(-) diff --git a/batch.go b/batch.go index 3af692ec0..012cf1609 100644 --- a/batch.go +++ b/batch.go @@ -28,6 +28,15 @@ type Batch struct { offset int64 highWaterMark int64 err error + // The last offset in the batch. + // + // We use lastOffset to skip offsets that have been compacted away. + // + // We store lastOffset because we get lastOffset when we read a new message + // but only try to handle compaction when we receive an EOF. However, when + // we get an EOF we do not get the lastOffset. So there is a mismatch + // between when we receive it and need to use it. + lastOffset int64 } // Throttle gives the throttling duration applied by the kafka server on the @@ -190,6 +199,8 @@ func (batch *Batch) ReadMessage() (Message, error) { return }, ) + // A batch may start before the requested offset so skip messages + // until the requested offset is reached. for batch.conn != nil && offset < batch.conn.offset { if err != nil { break @@ -225,10 +236,12 @@ func (batch *Batch) readMessage( return } - offset, timestamp, headers, err = batch.msgs.readMessage(batch.offset, key, val) + var lastOffset int64 + offset, lastOffset, timestamp, headers, err = batch.msgs.readMessage(batch.offset, key, val) switch err { case nil: batch.offset = offset + 1 + batch.lastOffset = lastOffset case errShortRead: // As an "optimization" kafka truncates the returned response after // producing MaxBytes, which could then cause the code to return @@ -252,6 +265,23 @@ func (batch *Batch) readMessage( // read deadline management. err = checkTimeoutErr(batch.deadline) batch.err = err + + // Checks the following: + // - `batch.err` for a "success" from the previous timeout check + // - `batch.msgs.lengthRemain` to ensure that this EOF is not due + // to MaxBytes truncation + // - `batch.lastOffset` to ensure that the message format contains + // `lastOffset` + if batch.err == io.EOF && batch.msgs.lengthRemain == 0 && batch.lastOffset != -1 { + // Log compaction can create batches that end with compacted + // records so the normal strategy that increments the "next" + // offset as records are read doesn't work as the compacted + // records are "missing" and never get "read". + // + // In order to reliably reach the next non-compacted offset we + // jump past the saved lastOffset. + batch.offset = batch.lastOffset + 1 + } } default: // Since io.EOF is used by the batch to indicate that there is are diff --git a/message_reader.go b/message_reader.go index 2401b43fc..0bb576647 100644 --- a/message_reader.go +++ b/message_reader.go @@ -18,6 +18,10 @@ type messageSetReader struct { *readerStack // used for decompressing compressed messages and record batches empty bool // if true, short circuits messageSetReader methods debug bool // enable debug log messages + // How many bytes are expected to remain in the response. + // + // This is used to detect truncation of the response. 
+ lengthRemain int } type readerStack struct { @@ -114,7 +118,7 @@ func (r *messageSetReader) discard() (err error) { } func (r *messageSetReader) readMessage(min int64, key readBytesFunc, val readBytesFunc) ( - offset int64, timestamp int64, headers []Header, err error) { + offset int64, lastOffset int64, timestamp int64, headers []Header, err error) { if r.empty { err = RequestTimedOut @@ -126,8 +130,10 @@ func (r *messageSetReader) readMessage(min int64, key readBytesFunc, val readByt switch r.header.magic { case 0, 1: offset, timestamp, headers, err = r.readMessageV1(min, key, val) + // Set an invalid value so that it can be ignored + lastOffset = -1 case 2: - offset, timestamp, headers, err = r.readMessageV2(min, key, val) + offset, lastOffset, timestamp, headers, err = r.readMessageV2(min, key, val) default: err = r.header.badMagic() } @@ -239,7 +245,7 @@ func (r *messageSetReader) readMessageV1(min int64, key readBytesFunc, val readB } func (r *messageSetReader) readMessageV2(_ int64, key readBytesFunc, val readBytesFunc) ( - offset int64, timestamp int64, headers []Header, err error) { + offset int64, lastOffset int64, timestamp int64, headers []Header, err error) { if err = r.readHeader(); err != nil { return } @@ -282,10 +288,12 @@ func (r *messageSetReader) readMessageV2(_ int64, key readBytesFunc, val readByt r.readerStack.parent.count = 0 } } + remainBefore := r.remain var length int64 if err = r.readVarInt(&length); err != nil { return } + lengthOfLength := remainBefore - r.remain var attrs int8 if err = r.readInt8(&attrs); err != nil { return @@ -316,6 +324,8 @@ func (r *messageSetReader) readMessageV2(_ int64, key readBytesFunc, val readByt return } } + lastOffset = r.header.firstOffset + int64(r.header.v2.lastOffsetDelta) + r.lengthRemain -= int(length) + lengthOfLength r.markRead() return } @@ -407,6 +417,9 @@ func (r *messageSetReader) readHeader() (err error) { return } r.count = 1 + // Set arbitrary non-zero length so that we always assume the + // message is truncated since bytes remain. + r.lengthRemain = 1 r.log("Read v0 header with offset=%d len=%d magic=%d attributes=%d", r.header.firstOffset, r.header.length, r.header.magic, r.header.v1.attributes) case 1: r.header.crc = crcOrLeaderEpoch @@ -417,6 +430,9 @@ func (r *messageSetReader) readHeader() (err error) { return } r.count = 1 + // Set arbitrary non-zero length so that we always assume the + // message is truncated since bytes remain. 
+ r.lengthRemain = 1 r.log("Read v1 header with remain=%d offset=%d magic=%d and attributes=%d", r.remain, r.header.firstOffset, r.header.magic, r.header.v1.attributes) case 2: r.header.v2.leaderEpoch = crcOrLeaderEpoch @@ -448,6 +464,8 @@ func (r *messageSetReader) readHeader() (err error) { return } r.count = int(r.header.v2.count) + // Subtracts the header bytes from the length + r.lengthRemain = int(r.header.length) - 49 r.log("Read v2 header with count=%d offset=%d len=%d magic=%d attributes=%d", r.count, r.header.firstOffset, r.header.length, r.header.magic, r.header.v2.attributes) default: err = r.header.badMagic() diff --git a/message_test.go b/message_test.go index 14250c9cc..d214ca544 100644 --- a/message_test.go +++ b/message_test.go @@ -541,7 +541,7 @@ func TestMessageSetReaderEmpty(t *testing.T) { return 0, nil } - offset, timestamp, headers, err := m.readMessage(0, noop, noop) + offset, _, timestamp, headers, err := m.readMessage(0, noop, noop) if offset != 0 { t.Errorf("expected offset of 0, get %d", offset) } @@ -737,12 +737,12 @@ func (r *readerHelper) readMessageErr() (msg Message, err error) { } var timestamp int64 var headers []Header - r.offset, timestamp, headers, err = r.messageSetReader.readMessage(r.offset, keyFunc, valueFunc) + r.offset, _, timestamp, headers, err = r.messageSetReader.readMessage(r.offset, keyFunc, valueFunc) if err != nil { return } msg.Offset = r.offset - msg.Time = time.Unix(timestamp / 1000, (timestamp % 1000) * 1000000) + msg.Time = time.Unix(timestamp/1000, (timestamp%1000)*1000000) msg.Headers = headers return } diff --git a/reader_test.go b/reader_test.go index 562ead541..3862b2bfa 100644 --- a/reader_test.go +++ b/reader_test.go @@ -1,6 +1,7 @@ package kafka import ( + "bytes" "context" "fmt" "io" @@ -12,6 +13,8 @@ import ( "sync" "testing" "time" + + "github.com/stretchr/testify/require" ) func TestReader(t *testing.T) { @@ -1716,3 +1719,256 @@ func TestErrorCannotConnectGroupSubscription(t *testing.T) { "must fail when it cannot connect") } } + +// Tests that the reader can handle messages where the response is truncated +// due to reaching MaxBytes. +// +// If MaxBytes is too small to fit 1 record then it will never truncate, so +// we start from a small message size and increase it until we are sure +// truncation has happened at some point. +func TestReaderTruncatedResponse(t *testing.T) { + topic := makeTopic() + createTopic(t, topic, 1) + defer deleteTopic(t, topic) + + readerMaxBytes := 100 + batchSize := 4 + maxMsgPadding := 5 + readContextTimeout := 10 * time.Second + + var msgs []Message + // The key of each message + n := 0 + // `i` is the amount of padding per message + for i := 0; i < maxMsgPadding; i++ { + bb := bytes.Buffer{} + for x := 0; x < i; x++ { + _, err := bb.WriteRune('0') + require.NoError(t, err) + } + padding := bb.Bytes() + // `j` is the number of times the message repeats + for j := 0; j < batchSize*4; j++ { + msgs = append(msgs, Message{ + Key: []byte(fmt.Sprintf("%05d", n)), + Value: padding, + }) + n++ + } + } + + wr := NewWriter(WriterConfig{ + Brokers: []string{"localhost:9092"}, + BatchSize: batchSize, + Async: false, + Topic: topic, + Balancer: &LeastBytes{}, + }) + err := wr.WriteMessages(context.Background(), msgs...) 
+ require.NoError(t, err) + + ctx, cancel := context.WithTimeout(context.Background(), readContextTimeout) + defer cancel() + r := NewReader(ReaderConfig{ + Brokers: []string{"localhost:9092"}, + Topic: topic, + MinBytes: 1, + MaxBytes: readerMaxBytes, + // Speed up testing + MaxWait: 100 * time.Millisecond, + }) + defer r.Close() + + expectedKeys := map[string]struct{}{} + for _, k := range msgs { + expectedKeys[string(k.Key)] = struct{}{} + } + keys := map[string]struct{}{} + for { + m, err := r.FetchMessage(ctx) + require.NoError(t, err) + keys[string(m.Key)] = struct{}{} + + t.Logf("got key %s have %d keys expect %d\n", string(m.Key), len(keys), len(expectedKeys)) + if len(keys) == len(expectedKeys) { + require.Equal(t, expectedKeys, keys) + return + } + } +} + +// Tests that the reader can read record batches from log compacted topics +// where the batch ends with compacted records. +// +// This test forces varying sized chunks of duplicated messages along with +// configuring the topic with a minimal `segment.bytes` in order to +// guarantee that at least 1 batch can be compacted down to 0 "unread" messages +// with at least 1 "old" message otherwise the batch is skipped entirely. +func TestReaderReadCompactedMessage(t *testing.T) { + topic := makeTopic() + createTopicWithCompaction(t, topic, 1) + defer deleteTopic(t, topic) + + msgs := makeTestDuplicateSequence() + + writeMessagesForCompactionCheck(t, topic, msgs) + + expectedKeys := map[string]int{} + for _, msg := range msgs { + expectedKeys[string(msg.Key)] = 1 + } + + // kafka 2.0.1 is extra slow + ctx, cancel := context.WithTimeout(context.Background(), time.Second*120) + defer cancel() + for { + success := func() bool { + r := NewReader(ReaderConfig{ + Brokers: []string{"localhost:9092"}, + Topic: topic, + MinBytes: 200, + MaxBytes: 200, + // Speed up testing + MaxWait: 100 * time.Millisecond, + }) + defer r.Close() + + keys := map[string]int{} + for { + m, err := r.FetchMessage(ctx) + if err != nil { + t.Logf("can't get message from compacted log: %v", err) + return false + } + keys[string(m.Key)]++ + + if len(keys) == countKeys(msgs) { + t.Logf("got keys: %+v", keys) + return reflect.DeepEqual(keys, expectedKeys) + } + } + }() + if success { + return + } + select { + case <-ctx.Done(): + t.Fatal(ctx.Err()) + default: + } + } +} + +// writeMessagesForCompactionCheck writes messages with specific writer configuration +func writeMessagesForCompactionCheck(t *testing.T, topic string, msgs []Message) { + t.Helper() + + wr := NewWriter(WriterConfig{ + Brokers: []string{"localhost:9092"}, + // Batch size must be large enough to have multiple compacted records + // for testing more edge cases. + BatchSize: 3, + Async: false, + Topic: topic, + Balancer: &LeastBytes{}, + }) + err := wr.WriteMessages(context.Background(), msgs...) + require.NoError(t, err) +} + +// makeTestDuplicateSequence creates messages for compacted log testing +// +// All keys and values are 4 characters long to tightly control how many +// messages are per log segment. +func makeTestDuplicateSequence() []Message { + var msgs []Message + // `n` is an increasing counter so it is never compacted. + n := 0 + // `i` determines how many compacted records to create + for i := 0; i < 5; i++ { + // `j` is how many times the current pattern repeats. We repeat because + // as long as we have a pattern that is slightly larger/smaller than + // the log segment size then if we repeat enough it will eventually + // try all configurations. 
+ for j := 0; j < 30; j++ { + msgs = append(msgs, Message{ + Key: []byte(fmt.Sprintf("%04d", n)), + Value: []byte(fmt.Sprintf("%04d", n)), + }) + n++ + + // This produces the duplicated messages to compact. + for k := 0; k < i; k++ { + msgs = append(msgs, Message{ + Key: []byte("dup_"), + Value: []byte("dup_"), + }) + } + } + } + + // "end markers" to force duplicate message outside of the last segment of + // the log so that they can all be compacted. + for i := 0; i < 10; i++ { + msgs = append(msgs, Message{ + Key: []byte(fmt.Sprintf("e-%02d", i)), + Value: []byte(fmt.Sprintf("e-%02d", i)), + }) + } + return msgs +} + +// countKeys counts unique keys from given Message slice +func countKeys(msgs []Message) int { + m := make(map[string]struct{}) + for _, msg := range msgs { + m[string(msg.Key)] = struct{}{} + } + return len(m) +} + +func createTopicWithCompaction(t *testing.T, topic string, partitions int) { + t.Helper() + + t.Logf("createTopic(%s, %d)", topic, partitions) + + conn, err := Dial("tcp", "localhost:9092") + require.NoError(t, err) + defer conn.Close() + + controller, err := conn.Controller() + require.NoError(t, err) + + conn, err = Dial("tcp", net.JoinHostPort(controller.Host, strconv.Itoa(controller.Port))) + require.NoError(t, err) + + conn.SetDeadline(time.Now().Add(10 * time.Second)) + + err = conn.CreateTopics(TopicConfig{ + Topic: topic, + NumPartitions: partitions, + ReplicationFactor: 1, + ConfigEntries: []ConfigEntry{ + { + ConfigName: "cleanup.policy", + ConfigValue: "compact", + }, + { + ConfigName: "segment.bytes", + ConfigValue: "200", + }, + }, + }) + switch err { + case nil: + // ok + case TopicAlreadyExists: + // ok + default: + require.NoError(t, err) + } + + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute) + defer cancel() + waitForTopic(ctx, t, topic) +}
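
Note (not part of the patch): the sketch below is a minimal, standalone Go illustration of the offset arithmetic this change introduces in batch.go. The names batchHeader and nextFetchOffset are hypothetical and do not exist in kafka-go; the sketch only assumes the v2 record batch fields used by the patch (firstOffset and lastOffsetDelta, where lastOffset = firstOffset + lastOffsetDelta) and a batch whose trailing records were removed by log compaction.

    package main

    import "fmt"

    // batchHeader is a hypothetical stand-in for the v2 record batch header
    // fields that matter here (see readHeader in the patch).
    type batchHeader struct {
        firstOffset     int64 // base offset of the batch
        lastOffsetDelta int32 // delta of the last record originally in the batch
    }

    // nextFetchOffset returns the offset a reader should request after
    // draining a batch. lastReadOffset is the offset of the last record that
    // was actually decoded; it can be smaller than the batch's lastOffset when
    // the tail of the batch was compacted away.
    func nextFetchOffset(h batchHeader, lastReadOffset int64) int64 {
        lastOffset := h.firstOffset + int64(h.lastOffsetDelta)
        if lastReadOffset < lastOffset {
            // Trailing records were compacted away; jump past the whole
            // batch instead of pointing at a "missing" offset.
            return lastOffset + 1
        }
        return lastReadOffset + 1
    }

    func main() {
        // The batch originally covered offsets 100..109, but records 105..109
        // were removed by compaction, so the last decoded record is at 104.
        h := batchHeader{firstOffset: 100, lastOffsetDelta: 9}
        fmt.Println(nextFetchOffset(h, 104)) // prints 110, not 105
    }

Before this patch, the reader's "increment as records are read" strategy would request offset 105, the broker would answer with the same batch (whose remaining records are gone), and the consumer would loop on it indefinitely; jumping to lastOffset + 1 moves it to the next non-compacted batch.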