Skip to content

Commit

Permalink
Add WAL message heartbeat
Browse files Browse the repository at this point in the history
This PR introduces WAL message heartbeats for PG >= 14.  The idea is,
every ~minute we push a wal message via the query connection.  This does
nothing but ensure that the wal subscriber receives messages and the DB
has activity.

We then force report the WAL position on heartbeats.
  • Loading branch information
tonyhb committed Oct 12, 2024
1 parent cd89c86 commit 176df24
Show file tree
Hide file tree
Showing 6 changed files with 163 additions and 23 deletions.
5 changes: 5 additions & 0 deletions pkg/changeset/changeset.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,11 @@ const (
OperationUpdate Operation = "UPDATE"
OperationDelete Operation = "DELETE"
OperationTruncate Operation = "TRUNCATE"

// OperationHeartbeat represents the changeset generated for heartbeats when we
// send messages to increase the WAL LSN. This is used for updating watermarks only,
// and should not process events.
OperationHeartbeat Operation = "HEARTBEAT"
)

// WatermarkCommitter is an interface that commits a given watermark to backing datastores.
Expand Down
2 changes: 2 additions & 0 deletions pkg/consts/pgconsts/pgconsts.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,6 @@ const (
Username = "inngest"
SlotName = "inngest_cdc"
PublicationName = "inngest"

MessagesVersion = 14
)
20 changes: 17 additions & 3 deletions pkg/decoder/pg_logical_v1.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,22 +12,24 @@ import (
"github.com/jackc/pgx/v5/pgtype"
)

func NewV1LogicalDecoder(s *schema.PGXSchemaLoader, log *slog.Logger) Decoder {
func NewV1LogicalDecoder(s *schema.PGXSchemaLoader, log *slog.Logger, messages bool) Decoder {
return v1LogicalDecoder{
log: log,
schema: s,
messages: messages,
relations: make(map[uint32]*pglogrepl.RelationMessage),
}
}

type v1LogicalDecoder struct {
log *slog.Logger

messages bool
schema *schema.PGXSchemaLoader
relations map[uint32]*pglogrepl.RelationMessage
}

func (v1LogicalDecoder) ReplicationPluginArgs() []string {
func (v v1LogicalDecoder) ReplicationPluginArgs() []string {
// https://www.postgresql.org/docs/current/protocol-logical-replication.html#PROTOCOL-LOGICAL-REPLICATION-PARAMS
//
// "Proto_version '2'" with "streaming 'true' streams transactions as they're progressing.
Expand All @@ -37,10 +39,17 @@ func (v1LogicalDecoder) ReplicationPluginArgs() []string {
//
// Version 1 only sends DML entries when the transaction commits, ensuring that any event
// generated by Inngest is for a committed transaction.
if v.messages {
return []string{
"proto_version '1'",
fmt.Sprintf("publication_names '%s'", pgconsts.PublicationName),
"messages 'true'", // Doesn't work for <= v13
}
}

return []string{
"proto_version '1'",
fmt.Sprintf("publication_names '%s'", pgconsts.PublicationName),
// "messages 'true'", // Doesn't work for v12 and v13.
}
}

Expand All @@ -49,6 +58,11 @@ func (v v1LogicalDecoder) Decode(in []byte, cs *changeset.Changeset) (bool, erro
msgType := pglogrepl.MessageType(in[0])

switch msgType {
case pglogrepl.MessageTypeMessage:
// This is a heartbeat (or another WAL message). Do nothing but record
// the heartbeat and updated watermark.
cs.Operation = changeset.OperationHeartbeat
return true, nil
case pglogrepl.MessageTypeRelation:
// MessageTypeRelation describes the OIDs for any relation before DML messages are sent. From the docs:
//
Expand Down
106 changes: 86 additions & 20 deletions pkg/replicator/pgreplicator/pg.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"log/slog"
"os"
"strings"
"sync"
"sync/atomic"
"time"

Expand All @@ -24,8 +25,9 @@ import (
)

var (
ReadTimeout = time.Second * 5
CommitInterval = time.Second * 5
ReadTimeout = time.Second * 5
CommitInterval = time.Second * 5
DefaultHeartbeatTime = time.Minute
)

// PostgresReplicator is a Replicator with added postgres functionality.
Expand Down Expand Up @@ -61,6 +63,12 @@ type Opts struct {

// New returns a new postgres replicator for a single postgres database.
func New(ctx context.Context, opts Opts) (PostgresReplicator, error) {
if opts.Log == nil {
opts.Log = slog.New(slog.NewJSONHandler(os.Stderr, &slog.HandlerOptions{
Level: slog.LevelInfo,
}))
}

cfg := opts.Config

// Ensure that we add "replication": "database" as a to the replication
Expand All @@ -84,24 +92,28 @@ func New(ctx context.Context, opts Opts) (PostgresReplicator, error) {
return nil, fmt.Errorf("error connecting to postgres host for schemas: %w", err)
}

// Query for current postgres version.
var version int
row := pgxc.QueryRow(ctx, "SELECT current_setting('server_version_num')::int / 10000;")
if err := row.Scan(&version); err != nil {
opts.Log.Warn("error querying for postgres version", "error", err)
}

sl := schema.NewPGXSchemaLoader(pgxc)
// Refresh all schemas to begin with
if err := sl.Refresh(); err != nil {
return nil, err
}

if opts.Log == nil {
opts.Log = slog.New(slog.NewJSONHandler(os.Stderr, &slog.HandlerOptions{
Level: slog.LevelInfo,
}))
}

return &pg{
opts: opts,
conn: replConn,
queryConn: pgxc,
decoder: decoder.NewV1LogicalDecoder(sl, opts.Log),
log: opts.Log,
opts: opts,
conn: replConn,
queryConn: pgxc,
queryLock: &sync.Mutex{},
decoder: decoder.NewV1LogicalDecoder(sl, opts.Log, version >= pgconsts.MessagesVersion),
log: opts.Log,
version: version,
heartbeatTime: DefaultHeartbeatTime,
}, nil
}

Expand All @@ -111,8 +123,14 @@ type pg struct {
// conn is the WAL streaming connection. Once replication starts, this
// conn cannot be used for any queries.
conn *pgx.Conn

// queryCon is a conn for querying data.
queryConn *pgx.Conn

// queryLock is used to lock pgx.Conn, as it's a single connection which cannot be used
// in parallel.
queryLock *sync.Mutex

// decoder decodes the binary WAL log
decoder decoder.Decoder
// nextReportTime records the time in which we must next report the current
Expand All @@ -125,6 +143,9 @@ type pg struct {
// log is a stdlib logger for reporting debug and warn logs.
log *slog.Logger

version int
heartbeatTime time.Duration

stopped int32
}

Expand All @@ -140,14 +161,20 @@ func (p *pg) Close(ctx context.Context) error {
}

func (p *pg) ReplicationSlot(ctx context.Context) (ReplicationSlot, error) {

mode, err := p.walMode(ctx)
if err != nil {
return ReplicationSlot{}, err
}

if mode != "logical" {
return ReplicationSlot{}, ErrLogicalReplicationNotSetUp
}

// Lock when querying repl slot data.
p.queryLock.Lock()
defer p.queryLock.Unlock()

return ReplicationSlotData(ctx, p.queryConn)
}

Expand Down Expand Up @@ -218,6 +245,25 @@ func (p *pg) Pull(ctx context.Context, cc chan *changeset.Changeset) error {
// the DML.
unwrapper := &txnUnwrapper{cc: cc}

go func() {
if p.version < pgconsts.MessagesVersion {
// doesn't support wal messages; ignore.
return
}

t := time.NewTicker(p.heartbeatTime)
for range t.C {
// Send a hearbeat every minute
p.queryLock.Lock()
_, err := p.queryConn.Exec(ctx, "SELECT pg_logical_emit_message(false, 'heartbeat', now()::varchar);")
p.queryLock.Unlock()

if err != nil {
p.log.Warn("unable to emit heartbeat", "error", err, "host", p.opts.Config.Host)
}
}
}()

for {
if ctx.Err() != nil || atomic.LoadInt32(&p.stopped) == 1 {
// Always call Close automatically.
Expand All @@ -233,6 +279,12 @@ func (p *pg) Pull(ctx context.Context, cc chan *changeset.Changeset) error {
continue
}

if changes.Operation == changeset.OperationHeartbeat {
p.Commit(changes.Watermark)
p.forceNextReport(ctx)
continue
}

unwrapper.Process(changes)
}
}
Expand All @@ -259,7 +311,7 @@ func (p *pg) fetch(ctx context.Context) (*changeset.Changeset, error) {

if err != nil {
if pgconn.Timeout(err) {
p.forceNextReport()
p.forceNextReport(ctx)
// We return nil as we want to keep iterating.
return nil, nil
}
Expand Down Expand Up @@ -291,7 +343,7 @@ func (p *pg) fetch(ctx context.Context) (*changeset.Changeset, error) {
return nil, fmt.Errorf("error parsing replication keepalive: %w", err)
}
if pkm.ReplyRequested {
p.forceNextReport()
p.forceNextReport(ctx)
}
return nil, nil
case pglogrepl.XLogDataByteID:
Expand All @@ -316,6 +368,7 @@ func (p *pg) fetch(ctx context.Context) (*changeset.Changeset, error) {
if err != nil {
return nil, fmt.Errorf("error decoding xlog data: %w", err)
}

if !ok {
return nil, nil
}
Expand Down Expand Up @@ -348,10 +401,11 @@ func (p *pg) committedWatermark() (wm changeset.Watermark) {
}
}

func (p *pg) forceNextReport() {
func (p *pg) forceNextReport(ctx context.Context) {
// Updating the next report time to a zero time always reports the LSN,
// as time.Now() is always after the empty time.
p.nextReportTime = time.Time{}
p.report(ctx, true)

Check failure on line 408 in pkg/replicator/pgreplicator/pg.go

View workflow job for this annotation

GitHub Actions / lint

Error return value of `p.report` is not checked (errcheck)
}

// report reports the current replication slot's LSN progress to the server. We can optionally
Expand Down Expand Up @@ -384,6 +438,9 @@ func (p *pg) LSN() (lsn pglogrepl.LSN) {
}

func (p *pg) walMode(ctx context.Context) (string, error) {
p.queryLock.Lock()
defer p.queryLock.Unlock()

var mode string
row := p.queryConn.QueryRow(ctx, "SHOW wal_level")
err := row.Scan(&mode)
Expand All @@ -405,15 +462,24 @@ type ReplicationSlot struct {

func ReplicationSlotData(ctx context.Context, conn *pgx.Conn) (ReplicationSlot, error) {
ret := ReplicationSlot{}
row := conn.QueryRow(
rows, err := conn.Query(
ctx,
fmt.Sprintf(`SELECT
active, restart_lsn, confirmed_flush_lsn
FROM pg_replication_slots WHERE slot_name = '%s';`,
active, restart_lsn, confirmed_flush_lsn
FROM pg_replication_slots WHERE slot_name = '%s';`,
pgconsts.SlotName,
),
)
err := row.Scan(&ret.Active, &ret.RestartLSN, &ret.ConfirmedFlushLSN)
defer rows.Close()

Check failure on line 473 in pkg/replicator/pgreplicator/pg.go

View workflow job for this annotation

GitHub Actions / lint

SA5001: should check error returned from conn.Query() before deferring rows.Close() (staticcheck)
if err != nil {
return ReplicationSlot{}, err
}

if !rows.Next() {
return ReplicationSlot{}, ErrReplicationSlotNotFound
}

err = rows.Scan(&ret.Active, &ret.RestartLSN, &ret.ConfirmedFlushLSN)
// pgx has its own ErrNoRows :(
if errors.Is(err, sql.ErrNoRows) || errors.Is(err, pgx.ErrNoRows) {
return ret, ErrReplicationSlotNotFound
Expand Down
47 changes: 47 additions & 0 deletions pkg/replicator/pgreplicator/pg_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,53 @@ func TestInsert(t *testing.T) {
}
}

func TestLogicalEmitHeartbeat(t *testing.T) {
t.Parallel()
versions := []int{14, 15, 16}

for _, v1 := range versions {
v := v1 // loop capture
t.Run(fmt.Sprintf("EmitHeartbeat - Postgres %d", v), func(t *testing.T) {
t.Parallel()

ctx, cancel := context.WithCancel(context.Background())

c, conn := test.StartPG(t, ctx, test.StartPGOpts{Version: v})
opts := Opts{Config: conn}
repl, err := New(ctx, opts)

// heartbeat fast in tests.
r := repl.(*pg)
r.heartbeatTime = 250 * time.Millisecond
require.NoError(t, err)

cb := eventwriter.NewCallbackWriter(ctx, 1, time.Millisecond, func(batch []*changeset.Changeset) error {
return nil
})
csChan := cb.Listen(ctx, r)

go func() {
err := r.Pull(ctx, csChan)
require.NoError(t, err)
}()

slotA, err := r.ReplicationSlot(ctx)
require.NoError(t, err)

<-time.After(1100 * time.Millisecond)

slotB, err := r.ReplicationSlot(ctx)
require.NoError(t, err)

require.NotEqual(t, slotA.ConfirmedFlushLSN, slotB.ConfirmedFlushLSN)
require.True(t, int(slotB.ConfirmedFlushLSN) > int(slotA.ConfirmedFlushLSN))

cancel()
_ = c.Stop(ctx, nil)
})
}
}

func TestUpdateMany_ReplicaIdentityFull(t *testing.T) {
t.Parallel()
versions := []int{12, 13, 14, 15, 16}
Expand Down
6 changes: 6 additions & 0 deletions pkg/replicator/pgreplicator/txn_unwrapper.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,12 @@ func (t *txnUnwrapper) Process(cs *changeset.Changeset) {
}

switch cs.Operation {
case changeset.OperationHeartbeat:
// The unwrapper should never receive heartbeats as the replicator should
// handle them and short circuit. However, always transmit them immediately
// for safety in code in case someone changes something in the future.
t.cc <- cs
return
case changeset.OperationBegin:
t.begin = cs
case changeset.OperationCommit:
Expand Down

0 comments on commit 176df24

Please sign in to comment.