diff --git a/README.md b/README.md index 67061c5..c5307d3 100644 --- a/README.md +++ b/README.md @@ -20,7 +20,7 @@ configuration. ## Change Data Capture This connector implements CDC features for PostgreSQL by creating a logical replication slot and a publication that -listens to changes in the configured table. Every detected change is converted into a record and returned in the call to +listens to changes in the configured tables. Every detected change is converted into a record and returned in the call to `Read`. If there is no record available at the moment `Read` is called, it blocks until a record is available or the connector receives a stop signal. @@ -58,16 +58,15 @@ returned. ## Configuration Options -| name | description | required | default | -| ------------------------- | ----------------------------------------------------------------------------------------------------------------------------------- | -------- | ---------------------- | -| `url` | Connection string for the Postgres database. | true | | -| `table` | The name of the table in Postgres that the connector should read. | true | | -| `columns` | Comma separated list of column names that should be included in each Record's payload. | false | (all columns) | -| `key` | Column name that records should use for their `Key` fields. | false | (primary key of table) | -| `snapshotMode` | Whether or not the plugin will take a snapshot of the entire table before starting cdc mode (allowed values: `initial` or `never`). | false | `initial` | -| `cdcMode` | Determines the CDC mode (allowed values: `auto`, `logrepl` or `long_polling`). | false | `auto` | -| `logrepl.publicationName` | Name of the publication to listen for WAL events. | false | `conduitpub` | -| `logrepl.slotName` | Name of the slot opened for replication events. 
| false | `conduitslot` | +| name | description | required | default | +|---------------------------|----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|----------|---------------| +| `url` | Connection string for the Postgres database. | true | | +| `table` | List of table names to read from, separated by commas. Example: `"employees,offices,payments"`. | true | | +| `key` | List of key column names per table, separated by commas. Example: `"table1:key1,table2:key2"`. If not supplied, the table's primary key will be used as the `Key` field for the records. | false | | +| `snapshotMode` | Whether or not the plugin will take a snapshot of the entire table before starting cdc mode (allowed values: `initial` or `never`). | false | `initial` | +| `cdcMode` | Determines the CDC mode (allowed values: `auto`, `logrepl` or `long_polling`). | false | `auto` | +| `logrepl.publicationName` | Name of the publication to listen for WAL events. | false | `conduitpub` | +| `logrepl.slotName` | Name of the slot opened for replication events. | false | `conduitslot` | # Destination @@ -91,7 +90,7 @@ If there is no key, the record will be simply appended. ## Configuration Options | name | description | required | default | -| ------- | --------------------------------------------------------------------------- | -------- | ------- | +|---------|-----------------------------------------------------------------------------|----------|---------| | `url` | Connection string for the Postgres database. | true | | | `table` | The name of the table in Postgres that the connector should write to. | false | | | `key` | Column name used to detect if the target table already contains the record. 
| false | | diff --git a/source.go b/source.go index aa37e5b..8caf771 100644 --- a/source.go +++ b/source.go @@ -29,9 +29,10 @@ import ( type Source struct { sdk.UnimplementedSource - iterator source.Iterator - config source.Config - conn *pgx.Conn + iterator source.Iterator + config source.Config + conn *pgx.Conn + tableKeys map[string]string } func NewSource() sdk.Source { @@ -47,10 +48,9 @@ func (s *Source) Configure(_ context.Context, cfg map[string]string) error { if err != nil { return err } - // try parsing the url - _, err = pgx.ParseConfig(s.config.URL) + s.tableKeys, err = s.config.Validate() if err != nil { - return fmt.Errorf("invalid url: %w", err) + return err } return nil } @@ -59,6 +59,10 @@ func (s *Source) Open(ctx context.Context, pos sdk.Position) error { if err != nil { return fmt.Errorf("failed to connect to database: %w", err) } + columns, err := s.getTableColumns(ctx, conn) + if err != nil { + return fmt.Errorf("failed to connect to database: %w", err) + } s.conn = conn switch s.config.CDCMode { @@ -77,9 +81,8 @@ func (s *Source) Open(ctx context.Context, pos sdk.Position) error { Position: pos, SlotName: s.config.LogreplSlotName, PublicationName: s.config.LogreplPublicationName, - TableName: s.config.Table, - KeyColumnName: s.config.Key, - Columns: s.config.Columns, + Tables: s.config.Table, + TableKeys: s.tableKeys, }) if err != nil { return fmt.Errorf("failed to create logical replication iterator: %w", err) @@ -96,9 +99,9 @@ func (s *Source) Open(ctx context.Context, pos sdk.Position) error { snap, err := longpoll.NewSnapshotIterator( ctx, s.conn, - s.config.Table, - s.config.Columns, - s.config.Key) + s.config.Table[0], //todo: only the first table for now + columns, + s.tableKeys[s.config.Table[0]]) if err != nil { return fmt.Errorf("failed to create long polling iterator: %w", err) } @@ -131,3 +134,27 @@ func (s *Source) Teardown(ctx context.Context) error { } return nil } + +func (s *Source) getTableColumns(ctx context.Context, conn 
*pgx.Conn) ([]string, error) { + query := "SELECT column_name FROM information_schema.columns WHERE table_name = $1" + + rows, err := conn.Query(ctx, query, s.config.Table[0]) + if err != nil { + return nil, err + } + defer rows.Close() + + var columns []string + for rows.Next() { + var columnName string + err := rows.Scan(&columnName) + if err != nil { + return nil, err + } + columns = append(columns, columnName) + } + if err = rows.Err(); err != nil { + return nil, fmt.Errorf("rows error: %w", err) + } + return columns, nil +} diff --git a/source/config.go b/source/config.go index abd8e61..cf86dab 100644 --- a/source/config.go +++ b/source/config.go @@ -16,6 +16,13 @@ package source +import ( + "fmt" + "strings" + + "github.com/jackc/pgx/v4" +) + type SnapshotMode string const ( @@ -40,14 +47,12 @@ const ( type Config struct { // URL is the connection string for the Postgres database. URL string `json:"url" validate:"required"` - // The name of the table in Postgres that the connector should read. - Table string `json:"table" validate:"required"` - // Comma separated list of column names that should be included in each Record's payload. - Columns []string `json:"columns"` - // Column name that records should use for their `Key` fields. - Key string `json:"key"` - - // Whether or not the plugin will take a snapshot of the entire table before starting cdc mode. + // Table is a List of table names to read from, separated by a comma. + Table []string `json:"table" validate:"required"` + // Key is a list of Key column names per table, ex:"table1:key1,table2:key2", records should use the key values for their `Key` fields. + Key []string `json:"key"` + + // SnapshotMode is whether the plugin will take a snapshot of the entire table before starting cdc mode. SnapshotMode SnapshotMode `json:"snapshotMode" validate:"inclusion=initial|never" default:"initial"` // CDCMode determines how the connector should listen to changes. 
CDCMode CDCMode `json:"cdcMode" validate:"inclusion=auto|logrepl|long_polling" default:"auto"` @@ -59,3 +64,26 @@ type Config struct { // connector uses logical replication to listen to changes (see CDCMode). LogreplSlotName string `json:"logrepl.slotName" default:"conduitslot"` } + +// Validate validates the provided config values. +func (c Config) Validate() (map[string]string, error) { + // try parsing the url + _, err := pgx.ParseConfig(c.URL) + if err != nil { + return nil, fmt.Errorf("invalid url: %w", err) + } + // todo: when cdcMode "auto" is implemented, change this check + if len(c.Table) != 1 && c.CDCMode == CDCModeLongPolling { + return nil, fmt.Errorf("multi tables are only supported for logrepl CDCMode, please provide only one table") + } + tableKeys := make(map[string]string, len(c.Table)) + for _, pair := range c.Key { + // Split each pair into key and value + parts := strings.Split(pair, ":") + if len(parts) != 2 { + return nil, fmt.Errorf("wrong format for the configuration %q, use comma separated pairs of tables and keys, example: table1:key1,table2:key2", "key") + } + tableKeys[parts[0]] = parts[1] + } + return tableKeys, nil +} diff --git a/source/config_test.go b/source/config_test.go new file mode 100644 index 0000000..5d59360 --- /dev/null +++ b/source/config_test.go @@ -0,0 +1,77 @@ +// Copyright © 2023 Meroxa, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package source + +import ( + "testing" + + "github.com/matryer/is" +) + +func TestConfig_Validate(t *testing.T) { + testCases := []struct { + name string + cfg Config + wantErr bool + }{{ + name: "valid config", + cfg: Config{ + URL: "postgresql://meroxauser:meroxapass@127.0.0.1:5432/meroxadb", + Table: []string{"table1", "table2"}, + Key: []string{"table1:key1"}, + CDCMode: CDCModeLogrepl, + }, + wantErr: false, + }, { + name: "invalid postgres url", + cfg: Config{ + URL: "postgresql", + Table: []string{"table1", "table2"}, + Key: []string{"table1:key1"}, + CDCMode: CDCModeLogrepl, + }, + wantErr: true, + }, { + name: "invalid multiple tables for long polling", + cfg: Config{ + URL: "postgresql://meroxauser:meroxapass@127.0.0.1:5432/meroxadb", + Table: []string{"table1", "table2"}, + Key: []string{"table1:key1"}, + CDCMode: CDCModeLongPolling, + }, + wantErr: true, + }, { + name: "invalid key list format", + cfg: Config{ + URL: "postgresql://meroxauser:meroxapass@127.0.0.1:5432/meroxadb", + Table: []string{"table1", "table2"}, + Key: []string{"key1,key2"}, + CDCMode: CDCModeLogrepl, + }, + wantErr: true, + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + is := is.New(t) + _, err := tc.cfg.Validate() + if tc.wantErr { + is.True(err != nil) + return + } + is.True(err == nil) + }) + } +} diff --git a/source/logrepl/cdc.go b/source/logrepl/cdc.go index ada92f9..265191c 100644 --- a/source/logrepl/cdc.go +++ b/source/logrepl/cdc.go @@ -36,9 +36,8 @@ type Config struct { Position sdk.Position SlotName string PublicationName string - TableName string - KeyColumnName string - Columns []string + Tables []string + TableKeys map[string]string } // CDCIterator asynchronously listens for events from the logical replication @@ -154,21 +153,30 @@ func (i *CDCIterator) attachSubscription(ctx context.Context, conn *pgx.Conn) er } } - keyColumn, err := i.getKeyColumn(ctx, conn) - if err != nil { - return fmt.Errorf("failed to find key for table %s 
(try specifying it manually): %w", i.config.TableName, err) + var err error + if i.config.TableKeys == nil { + i.config.TableKeys = make(map[string]string, len(i.config.Tables)) + } + for _, tableName := range i.config.Tables { + // get unprovided table keys + if _, ok := i.config.TableKeys[tableName]; ok { + continue // key was provided manually + } + i.config.TableKeys[tableName], err = i.getTableKeys(ctx, conn, tableName) + if err != nil { + return fmt.Errorf("failed to find key for table %s (try specifying it manually): %w", tableName, err) + } } sub := internal.NewSubscription( conn.Config().Config, i.config.SlotName, i.config.PublicationName, - []string{i.config.TableName}, + i.config.Tables, lsn, NewCDCHandler( internal.NewRelationSet(conn.ConnInfo()), - keyColumn, - i.config.Columns, + i.config.TableKeys, i.records, ).Handle, ) @@ -177,23 +185,19 @@ func (i *CDCIterator) attachSubscription(ctx context.Context, conn *pgx.Conn) er return nil } -// getKeyColumn queries the db for the name of the primary key column for a +// getTableKeys queries the db for the name of the primary key column for a // table if one exists and returns it. 
-func (i *CDCIterator) getKeyColumn(ctx context.Context, conn *pgx.Conn) (string, error) { - if i.config.KeyColumnName != "" { - return i.config.KeyColumnName, nil - } - +func (i *CDCIterator) getTableKeys(ctx context.Context, conn *pgx.Conn, tableName string) (string, error) { query := `SELECT column_name FROM information_schema.key_column_usage WHERE table_name = $1 AND constraint_name LIKE '%_pkey' LIMIT 1;` - row := conn.QueryRow(ctx, query, i.config.TableName) + row := conn.QueryRow(ctx, query, tableName) var colName string err := row.Scan(&colName) if err != nil { - return "", fmt.Errorf("getKeyColumn query failed: %w", err) + return "", fmt.Errorf("getTableKeys query failed: %w", err) } return colName, nil diff --git a/source/logrepl/cdc_test.go b/source/logrepl/cdc_test.go index c33afd2..6e18b92 100644 --- a/source/logrepl/cdc_test.go +++ b/source/logrepl/cdc_test.go @@ -137,7 +137,7 @@ func TestIterator_Next(t *testing.T) { func testIterator(ctx context.Context, t *testing.T, pool *pgxpool.Pool, table string) *CDCIterator { is := is.New(t) config := Config{ - TableName: table, + Tables: []string{table}, PublicationName: table, // table is random, reuse for publication name SlotName: table, // table is random, reuse for slot name } diff --git a/source/logrepl/handler.go b/source/logrepl/handler.go index 73aaa98..cdbdb37 100644 --- a/source/logrepl/handler.go +++ b/source/logrepl/handler.go @@ -27,28 +27,18 @@ import ( // CDCHandler is responsible for handling logical replication messages, // converting them to a record and sending them to a channel. 
type CDCHandler struct { - keyColumn string - columns map[string]bool // columns can be used to filter only specific columns + tableKeys map[string]string relationSet *internal.RelationSet out chan<- sdk.Record } func NewCDCHandler( rs *internal.RelationSet, - keyColumn string, - columns []string, + tableKeys map[string]string, out chan<- sdk.Record, ) *CDCHandler { - var columnSet map[string]bool - if len(columns) > 0 { - columnSet = make(map[string]bool) - for _, col := range columns { - columnSet[col] = true - } - } return &CDCHandler{ - keyColumn: keyColumn, - columns: columnSet, + tableKeys: tableKeys, relationSet: rs, out: out, } @@ -106,7 +96,7 @@ func (h *CDCHandler) handleInsert( rec := sdk.Util.Source.NewRecordCreate( LSNToPosition(lsn), h.buildRecordMetadata(rel), - h.buildRecordKey(newValues), + h.buildRecordKey(newValues, rel.RelationName), h.buildRecordPayload(newValues), ) return h.send(ctx, rec) @@ -139,7 +129,7 @@ func (h *CDCHandler) handleUpdate( rec := sdk.Util.Source.NewRecordUpdate( LSNToPosition(lsn), h.buildRecordMetadata(rel), - h.buildRecordKey(newValues), + h.buildRecordKey(newValues, rel.RelationName), h.buildRecordPayload(oldValues), h.buildRecordPayload(newValues), ) @@ -166,7 +156,7 @@ func (h *CDCHandler) handleDelete( rec := sdk.Util.Source.NewRecordDelete( LSNToPosition(lsn), h.buildRecordMetadata(rel), - h.buildRecordKey(oldValues), + h.buildRecordKey(oldValues, rel.RelationName), ) return h.send(ctx, rec) } @@ -190,10 +180,11 @@ func (h *CDCHandler) buildRecordMetadata(relation *pglogrepl.RelationMessage) ma // buildRecordKey takes the values from the message and extracts the key that // matches the configured keyColumnName. 
-func (h *CDCHandler) buildRecordKey(values map[string]pgtype.Value) sdk.Data { +func (h *CDCHandler) buildRecordKey(values map[string]pgtype.Value, table string) sdk.Data { + keyColumn := h.tableKeys[table] key := make(sdk.StructuredData) for k, v := range values { - if h.keyColumn == k { + if keyColumn == k { key[k] = v.Get() break // TODO add support for composite keys } @@ -209,11 +200,7 @@ func (h *CDCHandler) buildRecordPayload(values map[string]pgtype.Value) sdk.Data } payload := make(sdk.StructuredData) for k, v := range values { - // filter columns if columns are specified - if h.columns == nil || h.columns[k] { - value := v.Get() - payload[k] = value - } + payload[k] = v.Get() } return payload } diff --git a/source/longpoll/snapshot.go b/source/longpoll/snapshot.go index ef4e9de..08ea175 100644 --- a/source/longpoll/snapshot.go +++ b/source/longpoll/snapshot.go @@ -72,7 +72,7 @@ type SnapshotIterator struct { // * NewSnapshotIterator attempts to load the sql rows into the SnapshotIterator and will // immediately begin to return them to subsequent Read calls. // * It acquires a read only transaction lock before reading the table. -// * If Teardown is called while a snpashot is in progress, it will return an +// * If Teardown is called while a snapshot is in progress, it will return an // ErrSnapshotInterrupt error. 
func NewSnapshotIterator(ctx context.Context, conn *pgx.Conn, table string, columns []string, key string) (*SnapshotIterator, error) { s := &SnapshotIterator{ diff --git a/source/paramgen.go b/source/paramgen.go index d7fcbda..f859bca 100644 --- a/source/paramgen.go +++ b/source/paramgen.go @@ -17,15 +17,9 @@ func (Config) Parameters() map[string]sdk.Parameter { sdk.ValidationInclusion{List: []string{"auto", "logrepl", "long_polling"}}, }, }, - "columns": { - Default: "", - Description: "Comma separated list of column names that should be included in each Record's payload.", - Type: sdk.ParameterTypeString, - Validations: []sdk.Validation{}, - }, "key": { Default: "", - Description: "Column name that records should use for their `key` fields.", + Description: "todo: remove param, Column name that records should use for their `key` fields.", Type: sdk.ParameterTypeString, Validations: []sdk.Validation{}, }, @@ -51,7 +45,7 @@ func (Config) Parameters() map[string]sdk.Parameter { }, "table": { Default: "", - Description: "The name of the table in Postgres that the connector should read.", + Description: "List of table names to read from, separated by a comma.", Type: sdk.ParameterTypeString, Validations: []sdk.Validation{ sdk.ValidationRequired{},