Skip to content

Commit

Permalink
Merge pull request #990 from wladh/magic
Browse files Browse the repository at this point in the history
Determine the records type based on the magic number not API version
  • Loading branch information
eapache authored Nov 29, 2017
2 parents a709f2d + b7f694e commit 6a8d89d
Show file tree
Hide file tree
Showing 9 changed files with 211 additions and 22 deletions.
7 changes: 3 additions & 4 deletions consumer.go
Original file line number Diff line number Diff line change
Expand Up @@ -519,11 +519,10 @@ func (child *partitionConsumer) parseMessages(msgSet *MessageSet) ([]*ConsumerMe
return messages, nil
}

func (child *partitionConsumer) parseRecords(block *FetchResponseBlock) ([]*ConsumerMessage, error) {
func (child *partitionConsumer) parseRecords(batch *RecordBatch) ([]*ConsumerMessage, error) {
var messages []*ConsumerMessage
var incomplete bool
prelude := true
batch := block.Records.recordBatch

for _, rec := range batch.Records {
offset := batch.FirstOffset + rec.OffsetDelta
Expand Down Expand Up @@ -599,10 +598,10 @@ func (child *partitionConsumer) parseResponse(response *FetchResponse) ([]*Consu
return nil, err
}

if response.Version < 4 {
if block.Records.recordsType == legacyRecords {
return child.parseMessages(block.Records.msgSet)
}
return child.parseRecords(block)
return child.parseRecords(block.Records.recordBatch)
}

// brokerConsumer
Expand Down
43 changes: 43 additions & 0 deletions consumer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -435,6 +435,49 @@ func TestConsumerExtraOffsets(t *testing.T) {
}
}

func TestConsumeMessageWithNewerFetchAPIVersion(t *testing.T) {
// Given
fetchResponse1 := &FetchResponse{Version: 4}
fetchResponse1.AddMessage("my_topic", 0, nil, testMsg, 1)
fetchResponse1.AddMessage("my_topic", 0, nil, testMsg, 2)

cfg := NewConfig()
cfg.Version = V0_11_0_0

broker0 := NewMockBroker(t, 0)
fetchResponse2 := &FetchResponse{}
fetchResponse2.Version = 4
fetchResponse2.AddError("my_topic", 0, ErrNoError)
broker0.SetHandlerByMap(map[string]MockResponse{
"MetadataRequest": NewMockMetadataResponse(t).
SetBroker(broker0.Addr(), broker0.BrokerID()).
SetLeader("my_topic", 0, broker0.BrokerID()),
"OffsetRequest": NewMockOffsetResponse(t).
SetVersion(1).
SetOffset("my_topic", 0, OffsetNewest, 1234).
SetOffset("my_topic", 0, OffsetOldest, 0),
"FetchRequest": NewMockSequence(fetchResponse1, fetchResponse2),
})

master, err := NewConsumer([]string{broker0.Addr()}, cfg)
if err != nil {
t.Fatal(err)
}

// When
consumer, err := master.ConsumePartition("my_topic", 0, 1)
if err != nil {
t.Fatal(err)
}

assertMessageOffset(t, <-consumer.Messages(), 1)
assertMessageOffset(t, <-consumer.Messages(), 2)

safeClose(t, consumer)
safeClose(t, master)
broker0.Close()
}

// It is fine if offsets of fetched messages are not sequential (although
// strictly increasing!).
func TestConsumerNonSequentialOffsets(t *testing.T) {
Expand Down
9 changes: 1 addition & 8 deletions fetch_response.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,18 +79,11 @@ func (b *FetchResponseBlock) decode(pd packetDecoder, version int16) (err error)
if err != nil {
return err
}
var records Records
if version >= 4 {
records = newDefaultRecords(nil)
} else {
records = newLegacyRecords(nil)
}
if recordsSize > 0 {
if err = records.decode(recordsDecoder); err != nil {
if err = b.Records.decode(recordsDecoder); err != nil {
return err
}
}
b.Records = records

return nil
}
Expand Down
77 changes: 75 additions & 2 deletions fetch_response_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,28 @@ var (
0x06, 0x05, 0x06, 0x07,
0x02,
0x06, 0x08, 0x09, 0x0A,
0x04, 0x0B, 0x0C,
}
0x04, 0x0B, 0x0C}

oneMessageFetchResponseV4 = []byte{
0x00, 0x00, 0x00, 0x00, // ThrottleTime
0x00, 0x00, 0x00, 0x01, // Number of Topics
0x00, 0x05, 't', 'o', 'p', 'i', 'c', // Topic
0x00, 0x00, 0x00, 0x01, // Number of Partitions
0x00, 0x00, 0x00, 0x05, // Partition
0x00, 0x01, // Error
0x00, 0x00, 0x00, 0x00, 0x10, 0x10, 0x10, 0x10, // High Watermark Offset
0x00, 0x00, 0x00, 0x00, 0x10, 0x10, 0x10, 0x10, // Last Stable Offset
0x00, 0x00, 0x00, 0x00, // Number of Aborted Transactions
0x00, 0x00, 0x00, 0x1C,
// messageSet
0x00, 0x00, 0x00, 0x00, 0x00, 0x55, 0x00, 0x00,
0x00, 0x00, 0x00, 0x10,
// message
0x23, 0x96, 0x4a, 0xf7, // CRC
0x00,
0x00,
0xFF, 0xFF, 0xFF, 0xFF,
0x00, 0x00, 0x00, 0x02, 0x00, 0xEE}
)

func TestEmptyFetchResponse(t *testing.T) {
Expand Down Expand Up @@ -173,3 +193,56 @@ func TestOneRecordFetchResponse(t *testing.T) {
t.Error("Decoding produced incorrect record value.")
}
}

func TestOneMessageFetchResponseV4(t *testing.T) {
response := FetchResponse{}
testVersionDecodable(t, "one message v4", &response, oneMessageFetchResponseV4, 4)

if len(response.Blocks) != 1 {
t.Fatal("Decoding produced incorrect number of topic blocks.")
}

if len(response.Blocks["topic"]) != 1 {
t.Fatal("Decoding produced incorrect number of partition blocks for topic.")
}

block := response.GetBlock("topic", 5)
if block == nil {
t.Fatal("GetBlock didn't return block.")
}
if block.Err != ErrOffsetOutOfRange {
t.Error("Decoding didn't produce correct error code.")
}
if block.HighWaterMarkOffset != 0x10101010 {
t.Error("Decoding didn't produce correct high water mark offset.")
}
partial, err := block.Records.isPartial()
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
if partial {
t.Error("Decoding detected a partial trailing record where there wasn't one.")
}

n, err := block.Records.numRecords()
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
if n != 1 {
t.Fatal("Decoding produced incorrect number of records.")
}
msgBlock := block.Records.msgSet.Messages[0]
if msgBlock.Offset != 0x550000 {
t.Error("Decoding produced incorrect message offset.")
}
msg := msgBlock.Msg
if msg.Codec != CompressionNone {
t.Error("Decoding produced incorrect message compression.")
}
if msg.Key != nil {
t.Error("Decoding produced message key where there was none.")
}
if !bytes.Equal(msg.Value, []byte{0x00, 0xEE}) {
t.Error("Decoding produced incorrect message value.")
}
}
1 change: 1 addition & 0 deletions packet_decoder.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ type packetDecoder interface {
// Subsets
remaining() int
getSubset(length int) (packetDecoder, error)
peek(offset, length int) (packetDecoder, error) // similar to getSubset, but it doesn't advance the offset

// Stacks, see PushDecoder
push(in pushDecoder) error
Expand Down
5 changes: 0 additions & 5 deletions produce_request.go
Original file line number Diff line number Diff line change
Expand Up @@ -188,11 +188,6 @@ func (r *ProduceRequest) decode(pd packetDecoder, version int16) error {
return err
}
var records Records
if version >= 3 {
records = newDefaultRecords(nil)
} else {
records = newLegacyRecords(nil)
}
if err := records.decode(recordsDecoder); err != nil {
return err
}
Expand Down
8 changes: 8 additions & 0 deletions real_decoder.go
Original file line number Diff line number Diff line change
Expand Up @@ -264,6 +264,14 @@ func (rd *realDecoder) getRawBytes(length int) ([]byte, error) {
return rd.raw[start:rd.off], nil
}

func (rd *realDecoder) peek(offset, length int) (packetDecoder, error) {
if rd.remaining() < offset+length {
return nil, ErrInsufficientData
}
off := rd.off + offset
return &realDecoder{raw: rd.raw[off : off+length]}, nil
}

// stacks

func (rd *realDecoder) push(in pushDecoder) error {
Expand Down
73 changes: 72 additions & 1 deletion records.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,12 @@ package sarama
import "fmt"

const (
legacyRecords = iota
unknownRecords = iota
legacyRecords
defaultRecords

magicOffset = 16
magicLength = 1
)

// Records implements a union type containing either a RecordBatch or a legacy MessageSet.
Expand All @@ -22,7 +26,30 @@ func newDefaultRecords(batch *RecordBatch) Records {
return Records{recordsType: defaultRecords, recordBatch: batch}
}

// setTypeFromFields sets type of Records depending on which of msgSet or recordBatch is not nil.
// The first return value indicates whether both fields are nil (and the type is not set).
// If both fields are not nil, it returns an error.
func (r *Records) setTypeFromFields() (bool, error) {
if r.msgSet == nil && r.recordBatch == nil {
return true, nil
}
if r.msgSet != nil && r.recordBatch != nil {
return false, fmt.Errorf("both msgSet and recordBatch are set, but record type is unknown")
}
r.recordsType = defaultRecords
if r.msgSet != nil {
r.recordsType = legacyRecords
}
return false, nil
}

func (r *Records) encode(pe packetEncoder) error {
if r.recordsType == unknownRecords {
if empty, err := r.setTypeFromFields(); err != nil || empty {
return err
}
}

switch r.recordsType {
case legacyRecords:
if r.msgSet == nil {
Expand All @@ -38,7 +65,31 @@ func (r *Records) encode(pe packetEncoder) error {
return fmt.Errorf("unknown records type: %v", r.recordsType)
}

func (r *Records) setTypeFromMagic(pd packetDecoder) error {
dec, err := pd.peek(magicOffset, magicLength)
if err != nil {
return err
}

magic, err := dec.getInt8()
if err != nil {
return err
}

r.recordsType = defaultRecords
if magic < 2 {
r.recordsType = legacyRecords
}
return nil
}

func (r *Records) decode(pd packetDecoder) error {
if r.recordsType == unknownRecords {
if err := r.setTypeFromMagic(pd); err != nil {
return nil
}
}

switch r.recordsType {
case legacyRecords:
r.msgSet = &MessageSet{}
Expand All @@ -51,6 +102,12 @@ func (r *Records) decode(pd packetDecoder) error {
}

func (r *Records) numRecords() (int, error) {
if r.recordsType == unknownRecords {
if empty, err := r.setTypeFromFields(); err != nil || empty {
return 0, err
}
}

switch r.recordsType {
case legacyRecords:
if r.msgSet == nil {
Expand All @@ -67,7 +124,15 @@ func (r *Records) numRecords() (int, error) {
}

func (r *Records) isPartial() (bool, error) {
if r.recordsType == unknownRecords {
if empty, err := r.setTypeFromFields(); err != nil || empty {
return false, err
}
}

switch r.recordsType {
case unknownRecords:
return false, nil
case legacyRecords:
if r.msgSet == nil {
return false, nil
Expand All @@ -83,6 +148,12 @@ func (r *Records) isPartial() (bool, error) {
}

func (r *Records) isControl() (bool, error) {
if r.recordsType == unknownRecords {
if empty, err := r.setTypeFromFields(); err != nil || empty {
return false, err
}
}

switch r.recordsType {
case legacyRecords:
return false, nil
Expand Down
10 changes: 8 additions & 2 deletions records_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ func TestLegacyRecords(t *testing.T) {
}

set = &MessageSet{}
r = newLegacyRecords(nil)
r = Records{}

err = decode(exp, set)
if err != nil {
Expand All @@ -42,6 +42,9 @@ func TestLegacyRecords(t *testing.T) {
t.Fatal(err)
}

if r.recordsType != legacyRecords {
t.Fatalf("Wrong records type %v, expected %v", r.recordsType, legacyRecords)
}
if !reflect.DeepEqual(set, r.msgSet) {
t.Errorf("Wrong decoding for legacy records, wanted %#+v, got %#+v", set, r.msgSet)
}
Expand Down Expand Up @@ -96,7 +99,7 @@ func TestDefaultRecords(t *testing.T) {
}

batch = &RecordBatch{}
r = newDefaultRecords(nil)
r = Records{}

err = decode(exp, batch)
if err != nil {
Expand All @@ -107,6 +110,9 @@ func TestDefaultRecords(t *testing.T) {
t.Fatal(err)
}

if r.recordsType != defaultRecords {
t.Fatalf("Wrong records type %v, expected %v", r.recordsType, defaultRecords)
}
if !reflect.DeepEqual(batch, r.recordBatch) {
t.Errorf("Wrong decoding for default records, wanted %#+v, got %#+v", batch, r.recordBatch)
}
Expand Down

0 comments on commit 6a8d89d

Please sign in to comment.