diff --git a/protocol/kafka_sarama/v2/sender.go b/protocol/kafka_sarama/v2/sender.go index 5345d94e5..0c92399a3 100644 --- a/protocol/kafka_sarama/v2/sender.go +++ b/protocol/kafka_sarama/v2/sender.go @@ -53,6 +53,10 @@ func (s *Sender) Send(ctx context.Context, m binding.Message, transformers ...bi kafkaMessage := sarama.ProducerMessage{Topic: s.topic} + if k := ctx.Value(withMessageKey{}); k != nil { + kafkaMessage.Key = k.(sarama.Encoder) + } + if err = WriteProducerMessage(ctx, m, &kafkaMessage, transformers...); err != nil { return err } @@ -70,3 +74,10 @@ func (s *Sender) Close(ctx context.Context) error { // otherwise it will close the whole client return s.syncProducer.Close() } + +type withMessageKey struct{} + +// WithMessageKey allows to set the key used when sending the producer message +func WithMessageKey(ctx context.Context, key sarama.Encoder) context.Context { + return context.WithValue(ctx, withMessageKey{}, key) +} diff --git a/protocol/kafka_sarama/v2/sender_test.go b/protocol/kafka_sarama/v2/sender_test.go new file mode 100644 index 000000000..9dc36cd5c --- /dev/null +++ b/protocol/kafka_sarama/v2/sender_test.go @@ -0,0 +1,51 @@ +package kafka_sarama + +import ( + "context" + "sync" + "testing" + + "github.com/Shopify/sarama" + "github.com/stretchr/testify/require" + + "github.com/cloudevents/sdk-go/v2/test" +) + +type syncProducerMock struct { + lock sync.Mutex + sent []*sarama.ProducerMessage +} + +func (s *syncProducerMock) SendMessage(msg *sarama.ProducerMessage) (partition int32, offset int64, err error) { + s.lock.Lock() + defer s.lock.Unlock() + s.sent = append(s.sent, msg) + return 0, int64(len(s.sent) - 1), err +} + +func (s *syncProducerMock) SendMessages(msgs []*sarama.ProducerMessage) error { + s.lock.Lock() + defer s.lock.Unlock() + s.sent = append(s.sent, msgs...) + return nil +} + +func (s *syncProducerMock) Close() error { + return nil +} + +func TestSenderWithKey(t *testing.T) { + syncProducerMock := &syncProducerMock{} + topic := "aaa" + + sender := &Sender{topic: topic, syncProducer: syncProducerMock} + require.NoError(t, sender.Send( + WithMessageKey(context.TODO(), sarama.StringEncoder("hello")), + test.FullMessage(), + )) + + require.Len(t, syncProducerMock.sent, 1) + kafkaMsg := syncProducerMock.sent[0] + require.Equal(t, kafkaMsg.Topic, topic) + require.Equal(t, kafkaMsg.Key, sarama.StringEncoder("hello")) +} diff --git a/samples/kafka/sender/main.go b/samples/kafka/sender/main.go index 04bfb0258..9a4e9cb5a 100644 --- a/samples/kafka/sender/main.go +++ b/samples/kafka/sender/main.go @@ -5,6 +5,7 @@ import ( "log" "github.com/Shopify/sarama" + "github.com/google/uuid" "github.com/cloudevents/sdk-go/protocol/kafka_sarama/v2" cloudevents "github.com/cloudevents/sdk-go/v2" @@ -32,6 +33,7 @@ func main() { for i := 0; i < count; i++ { e := cloudevents.NewEvent() + e.SetID(uuid.New().String()) e.SetType("com.cloudevents.sample.sent") e.SetSource("https://github.com/cloudevents/sdk-go/v2/samples/kafka/sender") _ = e.SetData(cloudevents.ApplicationJSON, map[string]interface{}{ @@ -39,7 +41,11 @@ func main() { "message": "Hello, World!", }) - if result := c.Send(context.Background(), e); cloudevents.IsUndelivered(result) { + if result := c.Send( + // Set the producer message key + kafka_sarama.WithMessageKey(context.Background(), sarama.StringEncoder(e.ID())), + e, + ); cloudevents.IsUndelivered(result) { log.Printf("failed to send: %v", err) } else { log.Printf("sent: %d, accepted: %t", i, cloudevents.IsACK(result)) diff --git a/v2/test/event_mocks.go b/v2/test/event_mocks.go index 7116662f4..05e5675b1 100644 --- a/v2/test/event_mocks.go +++ b/v2/test/event_mocks.go @@ -6,6 +6,7 @@ import ( "net/url" "time" + "github.com/cloudevents/sdk-go/v2/binding" "github.com/cloudevents/sdk-go/v2/binding/spec" "github.com/cloudevents/sdk-go/v2/event" "github.com/cloudevents/sdk-go/v2/types" @@ -56,6 +57,18 @@ func MinEvent() event.Event { } } +// FullMessage returns the same event of FullEvent but wrapped as Message. +func FullMessage() binding.Message { + ev := FullEvent() + return binding.ToMessage(&ev) +} + +// MinMessage returns the same event of MinEvent but wrapped as Message. +func MinMessage() binding.Message { + ev := MinEvent() + return binding.ToMessage(&ev) +} + // AllVersions returns all versions of each event in events. // ID gets a -number suffix so IDs are unique. func AllVersions(events []event.Event) []event.Event {