Skip to content

Commit

Permalink
Get attempt for consumed msg
Browse files Browse the repository at this point in the history
  • Loading branch information
petans24 committed Nov 30, 2023
1 parent 25e98a5 commit 247de35
Show file tree
Hide file tree
Showing 11 changed files with 168 additions and 80 deletions.
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -208,8 +208,8 @@ func main() {

// we must specify the message handler, which implements simple interface
type handler struct {}
func (h *handler) HandleMessage(_ context.Context, msg pgq.Message) (processed bool, err error) {
fmt.Println("Message payload:", string(msg.Payload()))
func (h *handler) HandleMessage(_ context.Context, msg *pgq.MessageIncoming) (processed bool, err error) {
fmt.Println("Message payload:", string(msg.Payload))
return true, nil
}
```
Expand Down
40 changes: 22 additions & 18 deletions consumer.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,14 +62,14 @@ const (
// // | true | <nil> | processed, no error. |
// // | true | some error | processed, ended with error. Don't retry! |
type MessageHandler interface {
HandleMessage(context.Context, Message) (processed bool, err error)
HandleMessage(context.Context, *MessageIncoming) (processed bool, err error)
}

// MessageHandlerFunc is MessageHandler implementation by simple function.
type MessageHandlerFunc func(context.Context, Message) (processed bool, err error)
type MessageHandlerFunc func(context.Context, *MessageIncoming) (processed bool, err error)

// HandleMessage calls self. It also implements MessageHandler interface.
func (fn MessageHandlerFunc) HandleMessage(ctx context.Context, msg Message) (processed bool, err error) {
func (fn MessageHandlerFunc) HandleMessage(ctx context.Context, msg *MessageIncoming) (processed bool, err error) {
return fn(ctx, msg)
}

Expand Down Expand Up @@ -280,7 +280,7 @@ func (c *Consumer) Run(ctx context.Context) error {
}
wg.Add(len(msgs))
for _, msg := range msgs {
go func(msg *message) {
go func(msg *MessageIncoming) {
defer wg.Done()
defer c.sem.Release(1)
c.handleMessage(ctx, msg)
Expand Down Expand Up @@ -370,19 +370,19 @@ func (c *Consumer) generateQuery() string {
sb.WriteString(` LIMIT $2`)
sb.WriteString(` FOR UPDATE SKIP LOCKED`)
}
sb.WriteString(`) RETURNING id, payload, metadata`)
sb.WriteString(`) RETURNING id, payload, metadata, consumed_count`)
return sb.String()
}

func (c *Consumer) handleMessage(ctx context.Context, msg *message) {
func (c *Consumer) handleMessage(ctx context.Context, msg *MessageIncoming) {
ctx, cancel := context.WithTimeout(ctx, c.cfg.LockDuration)
defer cancel()

ctxTimeout, cancel := prepareCtxTimeout()
defer cancel()
// TODO configurable Propagator
propagator := otel.GetTextMapPropagator()
carrier := propagation.MapCarrier(msg.metadata)
carrier := propagation.MapCarrier(msg.Metadata)
ctx = propagator.Extract(ctx, carrier)

ctx, span := otel.Tracer("pgq").Start(ctx, "HandleMessage")
Expand All @@ -405,7 +405,7 @@ func (c *Consumer) handleMessage(ctx context.Context, msg *message) {
"error", err.Error(),
"ackTimeout", c.cfg.AckTimeout,
"reason", reason,
"msg.metadata", msg.metadata,
"msg.metadata", msg.Metadata,
)
}
return
Expand All @@ -419,7 +419,7 @@ func (c *Consumer) handleMessage(ctx context.Context, msg *message) {
"error", err,
"ackTimeout", c.cfg.AckTimeout,
"reason", discardReason,
"msg.metadata", msg.metadata,
"msg.metadata", msg.Metadata,
)
}
return
Expand All @@ -429,7 +429,7 @@ func (c *Consumer) handleMessage(ctx context.Context, msg *message) {
c.cfg.Logger.ErrorContext(ctx, "pgq: ack failed",
"error", err,
"ackTimeout", c.cfg.AckTimeout,
"msg.metadata", msg.metadata,
"msg.metadata", msg.Metadata,
)
}
}
Expand All @@ -444,7 +444,7 @@ func prepareCtxTimeout() (func(td time.Duration) context.Context, context.Cancel
return fn, cancel
}

func (c *Consumer) consumeMessages(ctx context.Context, query string) ([]*message, error) {
func (c *Consumer) consumeMessages(ctx context.Context, query string) ([]*MessageIncoming, error) {
for {
maxMsg, err := acquireMaxFromSemaphore(ctx, c.sem, int64(c.cfg.MaxParallelMessages))
if err != nil {
Expand Down Expand Up @@ -473,9 +473,10 @@ type pgMessage struct {
ID pgtype.UUID
Payload pgtype.JSONB
Metadata pgtype.JSONB
Attempt pgtype.Int4
}

func (c *Consumer) tryConsumeMessages(ctx context.Context, query string, limit int64) (_ []*message, err error) {
func (c *Consumer) tryConsumeMessages(ctx context.Context, query string, limit int64) (_ []*MessageIncoming, err error) {
tx, err := c.db.BeginTx(ctx, nil)
if err != nil {
// TODO not necessary fatal, network could wiggle.
Expand Down Expand Up @@ -509,7 +510,7 @@ func (c *Consumer) tryConsumeMessages(ctx context.Context, query string, limit i
}
defer rows.Close()

var msgs []*message
var msgs []*MessageIncoming
for rows.Next() {
msg, err := c.parseRow(ctx, rows)
if err != nil {
Expand All @@ -529,12 +530,13 @@ func (c *Consumer) tryConsumeMessages(ctx context.Context, query string, limit i
return msgs, nil
}

func (c *Consumer) parseRow(ctx context.Context, rows *sql.Rows) (*message, error) {
func (c *Consumer) parseRow(ctx context.Context, rows *sql.Rows) (*MessageIncoming, error) {
var pgMsg pgMessage
if err := rows.Scan(
&pgMsg.ID,
&pgMsg.Payload,
&pgMsg.Metadata,
&pgMsg.Attempt,
); err != nil {
if isErrorCode(err, undefinedTableErrCode, undefinedColumnErrCode) {
return nil, fatalError{Err: err}
Expand Down Expand Up @@ -584,23 +586,25 @@ func (c *Consumer) discardInvalidMsg(ctx context.Context, id pgtype.UUID, err er
}
}

func (c *Consumer) finishParsing(pgMsg pgMessage) (*message, error) {
msg := &message{
func (c *Consumer) finishParsing(pgMsg pgMessage) (*MessageIncoming, error) {
msg := &MessageIncoming{
id: uuid.UUID(pgMsg.ID.Bytes),
once: sync.Once{},
ackFn: c.ackMessage(c.db, pgMsg.ID),
nackFn: c.nackMessage(c.db, pgMsg.ID),
discardFn: c.discardMessage(c.db, pgMsg.ID),
}
var err error
msg.payload, err = parsePayload(pgMsg)
msg.Payload, err = parsePayload(pgMsg)
if err != nil {
return msg, errors.Wrap(err, "parsing payload")
}
msg.metadata, err = parseMetadata(pgMsg)
msg.Metadata, err = parseMetadata(pgMsg)
if err != nil {
return msg, errors.Wrap(err, "parsing metadata")
}
msg.attempt = int(pgMsg.Attempt.Int)
msg.maxConsumedCount = c.cfg.MaxConsumeCount
return msg, nil
}

Expand Down
12 changes: 6 additions & 6 deletions consumer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ func TestConsumer_generateQuery(t *testing.T) {
{
name: "simple",
args: args{queueName: "testing_queue"},
want: "UPDATE \"testing_queue\" SET locked_until = $1, started_at = CURRENT_TIMESTAMP, consumed_count = consumed_count+1 WHERE id IN (SELECT id FROM \"testing_queue\" WHERE (locked_until IS NULL OR locked_until < CURRENT_TIMESTAMP) AND consumed_count < 3 AND processed_at IS NULL ORDER BY consumed_count ASC, created_at ASC LIMIT $2 FOR UPDATE SKIP LOCKED) RETURNING id, payload, metadata",
want: "UPDATE \"testing_queue\" SET locked_until = $1, started_at = CURRENT_TIMESTAMP, consumed_count = consumed_count+1 WHERE id IN (SELECT id FROM \"testing_queue\" WHERE (locked_until IS NULL OR locked_until < CURRENT_TIMESTAMP) AND consumed_count < 3 AND processed_at IS NULL ORDER BY consumed_count ASC, created_at ASC LIMIT $2 FOR UPDATE SKIP LOCKED) RETURNING id, payload, metadata, consumed_count",
},
{
name: "scanInterval 12 hours",
Expand All @@ -34,7 +34,7 @@ func TestConsumer_generateQuery(t *testing.T) {
WithHistoryLimit(12 * time.Hour),
},
},
want: "UPDATE \"testing_queue\" SET locked_until = $1, started_at = CURRENT_TIMESTAMP, consumed_count = consumed_count+1 WHERE id IN (SELECT id FROM \"testing_queue\" WHERE created_at >= CURRENT_TIMESTAMP - $3::interval AND created_at < CURRENT_TIMESTAMP AND (locked_until IS NULL OR locked_until < CURRENT_TIMESTAMP) AND consumed_count < 3 AND processed_at IS NULL ORDER BY consumed_count ASC, created_at ASC LIMIT $2 FOR UPDATE SKIP LOCKED) RETURNING id, payload, metadata",
want: "UPDATE \"testing_queue\" SET locked_until = $1, started_at = CURRENT_TIMESTAMP, consumed_count = consumed_count+1 WHERE id IN (SELECT id FROM \"testing_queue\" WHERE created_at >= CURRENT_TIMESTAMP - $3::interval AND created_at < CURRENT_TIMESTAMP AND (locked_until IS NULL OR locked_until < CURRENT_TIMESTAMP) AND consumed_count < 3 AND processed_at IS NULL ORDER BY consumed_count ASC, created_at ASC LIMIT $2 FOR UPDATE SKIP LOCKED) RETURNING id, payload, metadata, consumed_count",
},
{
name: "scn interval 12 hours abd max consumed count limit disabled",
Expand All @@ -46,12 +46,12 @@ func TestConsumer_generateQuery(t *testing.T) {
WithMaxConsumeCount(0),
},
},
want: "UPDATE \"testing_queue\" SET locked_until = $1, started_at = CURRENT_TIMESTAMP, consumed_count = consumed_count+1 WHERE id IN (SELECT id FROM \"testing_queue\" WHERE created_at >= CURRENT_TIMESTAMP - $3::interval AND created_at < CURRENT_TIMESTAMP AND (locked_until IS NULL OR locked_until < CURRENT_TIMESTAMP) AND processed_at IS NULL ORDER BY consumed_count ASC, created_at ASC LIMIT $2 FOR UPDATE SKIP LOCKED) RETURNING id, payload, metadata",
want: "UPDATE \"testing_queue\" SET locked_until = $1, started_at = CURRENT_TIMESTAMP, consumed_count = consumed_count+1 WHERE id IN (SELECT id FROM \"testing_queue\" WHERE created_at >= CURRENT_TIMESTAMP - $3::interval AND created_at < CURRENT_TIMESTAMP AND (locked_until IS NULL OR locked_until < CURRENT_TIMESTAMP) AND processed_at IS NULL ORDER BY consumed_count ASC, created_at ASC LIMIT $2 FOR UPDATE SKIP LOCKED) RETURNING id, payload, metadata, consumed_count",
},
{
name: "with metadata condition",
args: args{queueName: "testing_queue"},
want: "UPDATE \"testing_queue\" SET locked_until = $1, started_at = CURRENT_TIMESTAMP, consumed_count = consumed_count+1 WHERE id IN (SELECT id FROM \"testing_queue\" WHERE (locked_until IS NULL OR locked_until < CURRENT_TIMESTAMP) AND consumed_count < 3 AND processed_at IS NULL ORDER BY consumed_count ASC, created_at ASC LIMIT $2 FOR UPDATE SKIP LOCKED) RETURNING id, payload, metadata",
want: "UPDATE \"testing_queue\" SET locked_until = $1, started_at = CURRENT_TIMESTAMP, consumed_count = consumed_count+1 WHERE id IN (SELECT id FROM \"testing_queue\" WHERE (locked_until IS NULL OR locked_until < CURRENT_TIMESTAMP) AND consumed_count < 3 AND processed_at IS NULL ORDER BY consumed_count ASC, created_at ASC LIMIT $2 FOR UPDATE SKIP LOCKED) RETURNING id, payload, metadata, consumed_count",
},
{
name: "scanInterval 12 hours with metadata condition",
Expand All @@ -61,12 +61,12 @@ func TestConsumer_generateQuery(t *testing.T) {
WithHistoryLimit(12 * time.Hour),
},
},
want: "UPDATE \"testing_queue\" SET locked_until = $1, started_at = CURRENT_TIMESTAMP, consumed_count = consumed_count+1 WHERE id IN (SELECT id FROM \"testing_queue\" WHERE created_at >= CURRENT_TIMESTAMP - $3::interval AND created_at < CURRENT_TIMESTAMP AND (locked_until IS NULL OR locked_until < CURRENT_TIMESTAMP) AND consumed_count < 3 AND processed_at IS NULL ORDER BY consumed_count ASC, created_at ASC LIMIT $2 FOR UPDATE SKIP LOCKED) RETURNING id, payload, metadata",
want: "UPDATE \"testing_queue\" SET locked_until = $1, started_at = CURRENT_TIMESTAMP, consumed_count = consumed_count+1 WHERE id IN (SELECT id FROM \"testing_queue\" WHERE created_at >= CURRENT_TIMESTAMP - $3::interval AND created_at < CURRENT_TIMESTAMP AND (locked_until IS NULL OR locked_until < CURRENT_TIMESTAMP) AND consumed_count < 3 AND processed_at IS NULL ORDER BY consumed_count ASC, created_at ASC LIMIT $2 FOR UPDATE SKIP LOCKED) RETURNING id, payload, metadata, consumed_count",
},
{
name: "with negative metadata condition",
args: args{queueName: "testing_queue"},
want: "UPDATE \"testing_queue\" SET locked_until = $1, started_at = CURRENT_TIMESTAMP, consumed_count = consumed_count+1 WHERE id IN (SELECT id FROM \"testing_queue\" WHERE (locked_until IS NULL OR locked_until < CURRENT_TIMESTAMP) AND consumed_count < 3 AND processed_at IS NULL ORDER BY consumed_count ASC, created_at ASC LIMIT $2 FOR UPDATE SKIP LOCKED) RETURNING id, payload, metadata",
want: "UPDATE \"testing_queue\" SET locked_until = $1, started_at = CURRENT_TIMESTAMP, consumed_count = consumed_count+1 WHERE id IN (SELECT id FROM \"testing_queue\" WHERE (locked_until IS NULL OR locked_until < CURRENT_TIMESTAMP) AND consumed_count < 3 AND processed_at IS NULL ORDER BY consumed_count ASC, created_at ASC LIMIT $2 FOR UPDATE SKIP LOCKED) RETURNING id, payload, metadata, consumed_count",
},
}
for _, tt := range tests {
Expand Down
6 changes: 3 additions & 3 deletions example_consumer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ import (

type Handler struct{}

func (h *Handler) HandleMessage(ctx context.Context, msg pgq.Message) (res bool, err error) {
func (h *Handler) HandleMessage(ctx context.Context, msg *pgq.MessageIncoming) (res bool, err error) {
defer func() {
r := recover()
if r == nil {
Expand All @@ -30,15 +30,15 @@ func (h *Handler) HandleMessage(ctx context.Context, msg pgq.Message) (res bool,
err = fmt.Errorf("%v", r)
}
}()
if msg.Metadata()["heaviness"] == "heavy" {
if msg.Metadata["heaviness"] == "heavy" {
// nack the message, it will be retried
// Message won't contain error detail in the database.
return pgq.MessageNotProcessed, nil
}
var myPayload struct {
Foo string `json:"foo"`
}
if err := json.Unmarshal(msg.Payload(), &myPayload); err != nil {
if err := json.Unmarshal(msg.Payload, &myPayload); err != nil {
// discard the message, it will not be retried
// Message will contain error detail in the database.
return pgq.MessageProcessed, fmt.Errorf("invalid payload: %v", err)
Expand Down
18 changes: 9 additions & 9 deletions example_publisher_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,19 +24,19 @@ func ExamplePublisher() {
p := pgq.NewPublisher(db)
ctx, _ := context.WithTimeout(context.Background(), 5*time.Second)
payload, _ := json.Marshal(PayloadStruct{Foo: "bar"})
messages := []pgq.Message{
pgq.NewMessage(
pgq.Metadata{
messages := []*pgq.MessageOutgoing{
{
Metadata: pgq.Metadata{
"version": "1.0",
},
json.RawMessage(payload),
),
pgq.NewMessage(
pgq.Metadata{
Payload: json.RawMessage(payload),
},
{
Metadata: pgq.Metadata{
"version": "1.0",
},
json.RawMessage(payload),
),
Payload: json.RawMessage(payload),
},
}
ids, err := p.Publish(ctx, queueName, messages...)
if err != nil {
Expand Down
4 changes: 2 additions & 2 deletions integtest/consumer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ func TestConsumer_Run_graceful_shutdown(t *testing.T) {
require.NoError(t, err)
publisher := NewPublisher(db)
msgIDs, err := publisher.Publish(ctx, queueName,
NewMessage(Metadata{"foo": "bar"}, json.RawMessage(`{"foo":"bar"}`)),
&MessageOutgoing{Metadata: Metadata{"foo": "bar"}, Payload: json.RawMessage(`{"foo":"bar"}`)},
)
require.NoError(t, err)
require.Equal(t, 1, len(msgIDs))
Expand Down Expand Up @@ -116,7 +116,7 @@ func ensureUUIDExtension(t *testing.T, db *sql.DB) {

type slowHandler struct{}

func (s *slowHandler) HandleMessage(ctx context.Context, _ Message) (processed bool, err error) {
func (s *slowHandler) HandleMessage(ctx context.Context, _ *MessageIncoming) (processed bool, err error) {
<-ctx.Done()
return MessageNotProcessed, ctx.Err()
}
Expand Down
10 changes: 5 additions & 5 deletions integtest/publisher_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,19 +26,19 @@ func TestPublisher(t *testing.T) {
}
tests := []struct {
name string
msg pgq.Message
msg *pgq.MessageOutgoing
publisherOpts []pgq.PublisherOption
want want
wantErr bool
}{
{
name: "Select extra columns",
msg: pgq.NewMessage(
pgq.Metadata{
msg: &pgq.MessageOutgoing{
Metadata: pgq.Metadata{
"test": "test_value",
},
json.RawMessage(`{"foo":"bar"}`),
),
Payload: json.RawMessage(`{"foo":"bar"}`),
},
publisherOpts: []pgq.PublisherOption{
pgq.WithMetaInjectors(
pgq.StaticMetaInjector(pgq.Metadata{"host": "localhost"}),
Expand Down
Loading

0 comments on commit 247de35

Please sign in to comment.