diff --git a/Makefile b/Makefile index 6291d3d..54d88d2 100644 --- a/Makefile +++ b/Makefile @@ -9,17 +9,22 @@ test: # run required docker containers, execute integration tests, stop containers after tests docker compose -f test/docker-compose.yml up --quiet-pull -d --wait go test $(GOTEST_FLAGS) -race ./...; ret=$$?; \ - docker compose -f test/docker-compose.yml down; \ + docker compose -f test/docker-compose.yml down --volumes; \ exit $$ret .PHONY: lint lint: - golangci-lint run + golangci-lint run -v .PHONY: generate generate: go generate ./... +.PHONY: fmt +fmt: + gofumpt -l -w . + gci write --skip-generated . + .PHONY: install-tools install-tools: @echo Installing tools from tools.go diff --git a/destination_integration_test.go b/destination_integration_test.go index a90a41f..51cda7f 100644 --- a/destination_integration_test.go +++ b/destination_integration_test.go @@ -50,75 +50,76 @@ func TestDestination_Write(t *testing.T) { tests := []struct { name string record sdk.Record - }{{ - name: "snapshot", - record: sdk.Record{ - Position: sdk.Position("foo"), - Operation: sdk.OperationSnapshot, - Metadata: map[string]string{MetadataOpenCDCCollection: tableName}, - Key: sdk.StructuredData{"id": 5000}, - Payload: sdk.Change{ - After: sdk.StructuredData{ - "column1": "foo", - "column2": 123, - "column3": true, + }{ + { + name: "snapshot", + record: sdk.Record{ + Position: sdk.Position("foo"), + Operation: sdk.OperationSnapshot, + Metadata: map[string]string{MetadataOpenCDCCollection: tableName}, + Key: sdk.StructuredData{"id": 5000}, + Payload: sdk.Change{ + After: sdk.StructuredData{ + "column1": "foo", + "column2": 123, + "column3": true, + }, }, }, - }, - }, { - name: "create", - record: sdk.Record{ - Position: sdk.Position("foo"), - Operation: sdk.OperationCreate, - Metadata: map[string]string{MetadataOpenCDCCollection: tableName}, - Key: sdk.StructuredData{"id": 5}, - Payload: sdk.Change{ - After: sdk.StructuredData{ - "column1": "foo", - "column2": 456, - "column3": false, + }, { + name: "create", + record: sdk.Record{ + Position: sdk.Position("foo"), + Operation: sdk.OperationCreate, + Metadata: map[string]string{MetadataOpenCDCCollection: tableName}, + Key: sdk.StructuredData{"id": 5}, + Payload: sdk.Change{ + After: sdk.StructuredData{ + "column1": "foo", + "column2": 456, + "column3": false, + }, }, }, - }, - }, { - name: "insert on update (upsert)", - record: sdk.Record{ - Position: sdk.Position("foo"), - Operation: sdk.OperationUpdate, - Metadata: map[string]string{MetadataOpenCDCCollection: tableName}, - Key: sdk.StructuredData{"id": 6}, - Payload: sdk.Change{ - After: sdk.StructuredData{ - "column1": "bar", - "column2": 567, - "column3": true, + }, { + name: "insert on update (upsert)", + record: sdk.Record{ + Position: sdk.Position("foo"), + Operation: sdk.OperationUpdate, + Metadata: map[string]string{MetadataOpenCDCCollection: tableName}, + Key: sdk.StructuredData{"id": 6}, + Payload: sdk.Change{ + After: sdk.StructuredData{ + "column1": "bar", + "column2": 567, + "column3": true, + }, }, }, - }, - }, { - name: "update on conflict", - record: sdk.Record{ - Position: sdk.Position("foo"), - Operation: sdk.OperationUpdate, - Metadata: map[string]string{MetadataOpenCDCCollection: tableName}, - Key: sdk.StructuredData{"id": 1}, - Payload: sdk.Change{ - After: sdk.StructuredData{ - "column1": "foobar", - "column2": 567, - "column3": true, + }, { + name: "update on conflict", + record: sdk.Record{ + Position: sdk.Position("foo"), + Operation: sdk.OperationUpdate, + Metadata: map[string]string{MetadataOpenCDCCollection: tableName}, + Key: sdk.StructuredData{"id": 1}, + Payload: sdk.Change{ + After: sdk.StructuredData{ + "column1": "foobar", + "column2": 567, + "column3": true, + }, }, }, + }, { + name: "delete", + record: sdk.Record{ + Position: sdk.Position("foo"), + Metadata: map[string]string{MetadataOpenCDCCollection: tableName}, + Operation: sdk.OperationDelete, + Key: sdk.StructuredData{"id": 4}, + }, }, - }, { - name: "delete", - record: sdk.Record{ - Position: sdk.Position("foo"), - Metadata: map[string]string{MetadataOpenCDCCollection: tableName}, - Operation: sdk.OperationDelete, - Key: sdk.StructuredData{"id": 4}, - }, - }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { diff --git a/go.mod b/go.mod index 02003b3..1195499 100644 --- a/go.mod +++ b/go.mod @@ -1,16 +1,21 @@ module github.com/conduitio/conduit-connector-postgres -go 1.21 +go 1.22 require ( github.com/Masterminds/sprig/v3 v3.2.3 github.com/Masterminds/squirrel v1.5.4 github.com/conduitio/conduit-connector-sdk v0.8.0 + github.com/daixiang0/gci v0.12.3 github.com/golangci/golangci-lint v1.57.2 github.com/google/go-cmp v0.6.0 + github.com/google/uuid v1.6.0 github.com/jackc/pglogrepl v0.0.0-20240307033717-828fbfe908e9 github.com/jackc/pgx/v5 v5.5.5 github.com/matryer/is v1.4.1 + golang.org/x/tools v0.19.0 + gopkg.in/tomb.v2 v2.0.0-20161208151619-d5d1b5820637 + mvdan.cc/gofumpt v0.6.0 ) require ( @@ -50,7 +55,6 @@ require ( github.com/ckaznocha/intrange v0.1.1 // indirect github.com/conduitio/conduit-connector-protocol v0.5.0 // indirect github.com/curioswitch/go-reassign v0.2.0 // indirect - github.com/daixiang0/gci v0.12.3 // indirect github.com/davecgh/go-spew v1.1.1 // indirect github.com/denis-tingaikin/go-header v0.5.0 // indirect github.com/ettle/strcase v0.2.0 // indirect @@ -80,7 +84,6 @@ require ( github.com/golangci/plugin-module-register v0.1.1 // indirect github.com/golangci/revgrep v0.5.2 // indirect github.com/golangci/unconvert v0.0.0-20240309020433-c5143eacb3ed // indirect - github.com/google/uuid v1.6.0 // indirect github.com/gordonklaus/ineffassign v0.1.0 // indirect github.com/gostaticanalysis/analysisutil v0.7.1 // indirect github.com/gostaticanalysis/comment v1.4.2 // indirect @@ -209,15 +212,12 @@ require ( golang.org/x/sys v0.18.0 // indirect golang.org/x/text v0.14.0 // indirect golang.org/x/time v0.5.0 // indirect - golang.org/x/tools v0.19.0 // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20231120223509-83a465c0220f // indirect google.golang.org/grpc v1.59.0 // indirect google.golang.org/protobuf v1.33.0 // indirect gopkg.in/ini.v1 v1.67.0 // indirect - gopkg.in/tomb.v2 v2.0.0-20161208151619-d5d1b5820637 // indirect gopkg.in/yaml.v2 v2.4.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect honnef.co/go/tools v0.4.7 // indirect - mvdan.cc/gofumpt v0.6.0 // indirect mvdan.cc/unparam v0.0.0-20240104100049-c549a3470d14 // indirect ) diff --git a/source.go b/source.go index 1b16bad..51893de 100644 --- a/source.go +++ b/source.go @@ -54,6 +54,7 @@ func (s *Source) Configure(_ context.Context, cfg map[string]string) error { } return nil } + func (s *Source) Open(ctx context.Context, pos sdk.Position) error { conn, err := pgx.Connect(ctx, s.config.URL) if err != nil { diff --git a/source/config_test.go b/source/config_test.go index 5d59360..1d371ec 100644 --- a/source/config_test.go +++ b/source/config_test.go @@ -25,43 +25,44 @@ func TestConfig_Validate(t *testing.T) { name string cfg Config wantErr bool - }{{ - name: "valid config", - cfg: Config{ - URL: "postgresql://meroxauser:meroxapass@127.0.0.1:5432/meroxadb", - Table: []string{"table1", "table2"}, - Key: []string{"table1:key1"}, - CDCMode: CDCModeLogrepl, + }{ + { + name: "valid config", + cfg: Config{ + URL: "postgresql://meroxauser:meroxapass@127.0.0.1:5432/meroxadb", + Table: []string{"table1", "table2"}, + Key: []string{"table1:key1"}, + CDCMode: CDCModeLogrepl, + }, + wantErr: false, + }, { + name: "invalid postgres url", + cfg: Config{ + URL: "postgresql", + Table: []string{"table1", "table2"}, + Key: []string{"table1:key1"}, + CDCMode: CDCModeLogrepl, + }, + wantErr: true, + }, { + name: "invalid multiple tables for long polling", + cfg: Config{ + URL: "postgresql://meroxauser:meroxapass@127.0.0.1:5432/meroxadb", + Table: []string{"table1", "table2"}, + Key: []string{"table1:key1"}, + CDCMode: CDCModeLongPolling, + }, + wantErr: true, + }, { + name: "invalid key list format", + cfg: Config{ + URL: "postgresql://meroxauser:meroxapass@127.0.0.1:5432/meroxadb", + Table: []string{"table1", "table2"}, + Key: []string{"key1,key2"}, + CDCMode: CDCModeLogrepl, + }, + wantErr: true, }, - wantErr: false, - }, { - name: "invalid postgres url", - cfg: Config{ - URL: "postgresql", - Table: []string{"table1", "table2"}, - Key: []string{"table1:key1"}, - CDCMode: CDCModeLogrepl, - }, - wantErr: true, - }, { - name: "invalid multiple tables for long polling", - cfg: Config{ - URL: "postgresql://meroxauser:meroxapass@127.0.0.1:5432/meroxadb", - Table: []string{"table1", "table2"}, - Key: []string{"table1:key1"}, - CDCMode: CDCModeLongPolling, - }, - wantErr: true, - }, { - name: "invalid key list format", - cfg: Config{ - URL: "postgresql://meroxauser:meroxapass@127.0.0.1:5432/meroxadb", - Table: []string{"table1", "table2"}, - Key: []string{"key1,key2"}, - CDCMode: CDCModeLogrepl, - }, - wantErr: true, - }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { diff --git a/source/iterator.go b/source/iterator.go index 329ee67..f055a8e 100644 --- a/source/iterator.go +++ b/source/iterator.go @@ -19,7 +19,6 @@ import ( "github.com/conduitio/conduit-connector-postgres/source/logrepl" "github.com/conduitio/conduit-connector-postgres/source/longpoll" - sdk "github.com/conduitio/conduit-connector-sdk" ) diff --git a/source/logrepl/snapshot_test.go b/source/logrepl/snapshot_test.go index 7964d8b..b7d018a 100644 --- a/source/logrepl/snapshot_test.go +++ b/source/logrepl/snapshot_test.go @@ -156,7 +156,8 @@ func createTestSnapshot(ctx context.Context, t *testing.T, pool *pgxpool.Pool) s // creates a snapshot iterator for testing that hands its connection's cleanup. func createTestSnapshotIterator(ctx context.Context, t *testing.T, - pool *pgxpool.Pool, cfg SnapshotConfig) *SnapshotIterator { + pool *pgxpool.Pool, cfg SnapshotConfig, +) *SnapshotIterator { is := is.New(t) conn, err := pool.Acquire(ctx) diff --git a/source/position/position.go b/source/position/position.go new file mode 100644 index 0000000..ac557b4 --- /dev/null +++ b/source/position/position.go @@ -0,0 +1,83 @@ +// Copyright © 2024 Meroxa, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package position + +import ( + "encoding/json" + "fmt" + + sdk "github.com/conduitio/conduit-connector-sdk" + "github.com/jackc/pglogrepl" +) + +//go:generate stringer -type=Type -trimprefix Type + +type Type int + +const ( + TypeInitial Type = iota + TypeSnapshot + TypeCDC +) + +type Position struct { + Type Type `json:"type"` + Snapshots SnapshotPositions `json:"snapshots,omitempty"` + LastLSN string `json:"last_lsn,omitempty"` +} + +type SnapshotPositions map[string]SnapshotPosition + +type SnapshotPosition struct { + LastRead int64 `json:"last_read"` + SnapshotEnd int64 `json:"snapshot_end"` + Done bool `json:"done,omitempty"` +} + +func ParseSDKPosition(sdkPos sdk.Position) (Position, error) { + var p Position + + if len(sdkPos) == 0 { + return p, nil + } + + if err := json.Unmarshal(sdkPos, &p); err != nil { + return p, fmt.Errorf("invalid position: %w", err) + } + return p, nil +} + +func (p Position) ToSDKPosition() sdk.Position { + v, err := json.Marshal(p) + if err != nil { + // This should never happen, all Position structs should be valid. + panic(err) + } + return v +} + +// LSN returns the last LSN (Log Sequence Number) in the position. +func (p Position) LSN() (pglogrepl.LSN, error) { + if p.LastLSN == "" { + return 0, nil + } + + lsn, err := pglogrepl.ParseLSN(p.LastLSN) + if err != nil { + return 0, fmt.Errorf("failed to parse LSN in position: %w", err) + } + + return lsn, nil +} diff --git a/source/position/position_test.go b/source/position/position_test.go new file mode 100644 index 0000000..aec2843 --- /dev/null +++ b/source/position/position_test.go @@ -0,0 +1,79 @@ +// Copyright © 2024 Meroxa, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package position + +import ( + "testing" + + sdk "github.com/conduitio/conduit-connector-sdk" + "github.com/matryer/is" +) + +func Test_ToSDKPosition(t *testing.T) { + is := is.New(t) + + p := Position{ + Type: TypeSnapshot, + Snapshots: SnapshotPositions{ + "orders": {LastRead: 1, SnapshotEnd: 2}, + }, + LastLSN: "4/137515E8", + } + + sdkPos := p.ToSDKPosition() + is.Equal( + string(sdkPos), + `{"type":1,"snapshots":{"orders":{"last_read":1,"snapshot_end":2}},"last_lsn":"4/137515E8"}`, + ) +} + +func Test_PositionLSN(t *testing.T) { + is := is.New(t) + + invalid := Position{LastLSN: "invalid"} + _, err := invalid.LSN() + is.True(err != nil) + is.Equal(err.Error(), "failed to parse LSN in position: failed to parse LSN: expected integer") + + valid := Position{LastLSN: "4/137515E8"} + lsn, noErr := valid.LSN() + is.NoErr(noErr) + is.Equal(uint64(lsn), uint64(17506309608)) +} + +func Test_ParseSDKPosition(t *testing.T) { + is := is.New(t) + + valid := sdk.Position( + []byte( + `{"type":1,"snapshots":{"orders":{"last_read":1,"snapshot_end":2}},"last_lsn":"4/137515E8"}`, + ), + ) + + p, validErr := ParseSDKPosition(valid) + is.NoErr(validErr) + + is.Equal(p, Position{ + Type: TypeSnapshot, + Snapshots: SnapshotPositions{ + "orders": {LastRead: 1, SnapshotEnd: 2}, + }, + LastLSN: "4/137515E8", + }) + + _, invalidErr := ParseSDKPosition(sdk.Position("{")) + is.True(invalidErr != nil) + is.Equal(invalidErr.Error(), "invalid position: unexpected end of JSON input") +} diff --git a/source/position/type_string.go b/source/position/type_string.go new file mode 100644 index 0000000..6746f18 --- /dev/null +++ b/source/position/type_string.go @@ -0,0 +1,25 @@ +// Code generated by "stringer -type=Type -trimprefix Type"; DO NOT EDIT. + +package position + +import "strconv" + +func _() { + // An "invalid array index" compiler error signifies that the constant values have changed. + // Re-run the stringer command to generate them again. + var x [1]struct{} + _ = x[TypeInitial-0] + _ = x[TypeSnapshot-1] + _ = x[TypeCDC-2] +} + +const _Type_name = "InitialSnapshotCDC" + +var _Type_index = [...]uint8{0, 7, 15, 18} + +func (i Type) String() string { + if i < 0 || i >= Type(len(_Type_index)-1) { + return "Type(" + strconv.FormatInt(int64(i), 10) + ")" + } + return _Type_name[_Type_index[i]:_Type_index[i+1]] +} diff --git a/source/snapshot/convert.go b/source/snapshot/convert.go new file mode 100644 index 0000000..12bc6c1 --- /dev/null +++ b/source/snapshot/convert.go @@ -0,0 +1,36 @@ +// Copyright © 2024 Meroxa, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package snapshot + +import ( + "fmt" +) + +func keyInt64(id any) (int64, error) { + switch t := id.(type) { + case int: + return int64(t), nil + case int8: + return int64(t), nil + case int16: + return int64(t), nil + case int32: + return int64(t), nil + case int64: + return t, nil + default: + return 0, fmt.Errorf("invalid type for key %T", id) + } +} diff --git a/source/snapshot/fetch_worker.go b/source/snapshot/fetch_worker.go new file mode 100644 index 0000000..7b0b93b --- /dev/null +++ b/source/snapshot/fetch_worker.go @@ -0,0 +1,417 @@ +// Copyright © 2024 Meroxa, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package snapshot + +import ( + "context" + "errors" + "fmt" + "slices" + "strings" + "time" + + "github.com/conduitio/conduit-connector-postgres/source/position" + sdk "github.com/conduitio/conduit-connector-sdk" + "github.com/google/uuid" + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgxpool" +) + +const defaultFetchSize = 50000 + +var supportedKeyTypes = []string{ + "smallint", + "integer", + "bigint", +} + +type FetchConfig struct { + Table string + Key string + TXSnapshotID string + FetchSize int + Position position.Position +} + +var ( + errTableRequired = errors.New("table name is required") + errKeyRequired = errors.New("table key required") + errInvalidCDCType = errors.New("invalid position type CDC") +) + +func (c FetchConfig) Validate() error { + var errs []error + + if c.Table == "" { + errs = append(errs, errTableRequired) + } + + if c.Key == "" { + errs = append(errs, errKeyRequired) + } + + switch c.Position.Type { + case position.TypeSnapshot, position.TypeInitial: + default: + errs = append(errs, errInvalidCDCType) + } + + return errors.Join(errs...) +} + +type FetchData struct { + Key sdk.StructuredData + Payload sdk.StructuredData + Position position.SnapshotPosition + Table string +} + +type FetchWorker struct { + conf FetchConfig + db *pgxpool.Pool + out chan<- FetchData + + snapshotEnd int64 + lastRead int64 + cursorName string +} + +func NewFetchWorker(db *pgxpool.Pool, out chan<- FetchData, c FetchConfig) *FetchWorker { + f := &FetchWorker{ + conf: c, + db: db, + out: out, + cursorName: "fetcher_" + strings.ReplaceAll(uuid.NewString(), "-", ""), + } + + if f.conf.FetchSize == 0 { + f.conf.FetchSize = defaultFetchSize + } + + if c.Position.Type == position.TypeInitial || c.Position.Snapshots == nil { + return f + } + + if t, ok := c.Position.Snapshots[c.Table]; ok { + f.snapshotEnd = t.SnapshotEnd + f.lastRead = t.LastRead + } + + return f +} + +// Validate will ensure the config is correct. +// * Table and keys exist +// * Key is a primary key +func (f *FetchWorker) Validate(ctx context.Context) error { + if err := f.conf.Validate(); err != nil { + return fmt.Errorf("failed to validate config: %w", err) + } + + tx, err := f.db.Begin(ctx) + if err != nil { + return fmt.Errorf("failed to start tx for validation: %w", err) + } + defer func() { + if err := tx.Rollback(ctx); err != nil { + sdk.Logger(ctx).Warn(). + Err(err). + Msgf("error on validation tx rollback for %q", f.cursorName) + } + }() + + if err := f.validateTable(ctx, f.conf.Table, tx); err != nil { + return fmt.Errorf("failed to validate table: %w", err) + } + + if err := f.validateKey(ctx, f.conf.Table, f.conf.Key, tx); err != nil { + return fmt.Errorf("failed to validate key: %w", err) + } + + return nil +} + +func (f *FetchWorker) Run(ctx context.Context) error { + start := time.Now().UTC() + + tx, err := f.db.BeginTx(ctx, pgx.TxOptions{ + IsoLevel: pgx.RepeatableRead, + AccessMode: pgx.ReadOnly, + }) + if err != nil { + return fmt.Errorf("failed to start tx: %w", err) + } + defer func() { + if err := tx.Rollback(ctx); err != nil { + sdk.Logger(ctx).Warn(). + Err(err). + Msgf("error run tx rollback for %q", f.cursorName) + } + }() + + if err := f.withSnapshot(ctx, tx); err != nil { + return err + } + + if err := f.updateSnapshotEnd(ctx, tx); err != nil { + return fmt.Errorf("failed to update fetch limit: %w", err) + } + + closeCursor, err := f.createCursor(ctx, tx) + if err != nil { + return fmt.Errorf("failed to create cursor: %w", err) + } + defer closeCursor() + + var nfetched int + + for { + n, err := f.fetch(ctx, tx) + if err != nil { + return fmt.Errorf("failed to fetch results: %w", err) + } + + if n == 0 { // end of cursor + break + } + + nfetched += n + + sdk.Logger(ctx).Info(). + Int("rows", nfetched). + Str("table", f.conf.Table). + Dur("elapsed", time.Since(start)). + Msg("fetching rows") + } + + sdk.Logger(ctx).Info(). + Dur("elapsed", time.Since(start)). + Str("table", f.conf.Table). + Msgf("%q snapshot completed", f.conf.Table) + + return nil +} + +func (f *FetchWorker) createCursor(ctx context.Context, tx pgx.Tx) (func(), error) { + // N.B. Prepare as much as possible when the cursor is created. + // Table and columns cannot be prepared. + // This query will scan the table for rows based on the conditions. + selectQuery := "SELECT * FROM " + f.conf.Table + " WHERE " + f.conf.Key + " > $1 AND " + f.conf.Key + " <= $2 ORDER BY $3" + + if _, err := tx.Exec( + ctx, + "DECLARE "+f.cursorName+" CURSOR FOR("+selectQuery+")", + f.lastRead, // range start + f.snapshotEnd, // range end + f.conf.Key, // order by this + ); err != nil { + return nil, err + } + + return func() { + // N.B. The cursor will automatically close when the TX is done. + if _, err := tx.Exec(ctx, "CLOSE "+f.cursorName); err != nil { + sdk.Logger(ctx).Warn(). + Err(err). + Msgf("unexpected error when closing cursor %q", f.cursorName) + } + }, nil +} + +func (f *FetchWorker) updateSnapshotEnd(ctx context.Context, tx pgx.Tx) error { + if f.snapshotEnd > 0 { + return nil + } + + if err := tx.QueryRow( + ctx, + fmt.Sprintf("SELECT max(%s) FROM %s", f.conf.Key, f.conf.Table), + ).Scan(&f.snapshotEnd); err != nil { + return fmt.Errorf("failed to query max on %q.%q: %w", f.conf.Table, f.conf.Key, err) + } + + return nil +} + +func (f *FetchWorker) fetch(ctx context.Context, tx pgx.Tx) (int, error) { + rows, err := tx.Query(ctx, fmt.Sprintf("FETCH %d FROM %s", f.conf.FetchSize, f.cursorName)) + if err != nil { + return 0, fmt.Errorf("failed to fetch rows: %w", err) + } + defer rows.Close() + + var fields []string + for _, f := range rows.FieldDescriptions() { + fields = append(fields, f.Name) + } + + var nread int + + for rows.Next() { + values, err := rows.Values() + if err != nil { + return 0, fmt.Errorf("failed to get values: %w", err) + } + + data, err := f.buildFetchData(fields, values) + if err != nil { + return nread, fmt.Errorf("failed to build fetch data: %w", err) + } + + if err := f.send(ctx, data); err != nil { + return nread, fmt.Errorf("failed to send record: %w", err) + } + + nread++ + } + if rows.Err() != nil { + return 0, fmt.Errorf("failed to read rows: %w", rows.Err()) + } + + return nread, nil +} + +func (f *FetchWorker) send(ctx context.Context, d FetchData) error { + select { + case <-ctx.Done(): + return ctx.Err() + case f.out <- d: + return nil + } +} + +func (f *FetchWorker) buildFetchData(fields []string, values []any) (FetchData, error) { + pos, err := f.buildSnapshotPosition(fields, values) + if err != nil { + return FetchData{}, fmt.Errorf("failed to build snapshot position: %w", err) + } + key, payload := f.buildRecordData(fields, values) + return FetchData{ + Key: key, + Payload: payload, + Position: pos, + Table: f.conf.Table, + }, nil +} + +func (f *FetchWorker) buildSnapshotPosition(fields []string, values []any) (position.SnapshotPosition, error) { + for i, name := range fields { + if name == f.conf.Key { + // Always coerce snapshot position to bigint, pk may be any type of integer. + lastRead, err := keyInt64(values[i]) + if err != nil { + return position.SnapshotPosition{}, fmt.Errorf("failed to parse key: %w", err) + } + return position.SnapshotPosition{ + LastRead: lastRead, + SnapshotEnd: f.snapshotEnd, + Done: f.snapshotEnd == lastRead, + }, nil + } + } + return position.SnapshotPosition{}, fmt.Errorf("key %q not found in fields", f.conf.Key) +} + +func (f *FetchWorker) buildRecordData(fields []string, values []any) (key sdk.StructuredData, payload sdk.StructuredData) { + payload = make(sdk.StructuredData) + + for i, name := range fields { + switch t := values[i].(type) { + case time.Time: // type not supported in sdk.Record + payload[name] = t.UTC().String() + default: + payload[name] = t + } + } + + key = sdk.StructuredData{ + f.conf.Key: payload[f.conf.Key], + } + + return key, payload +} + +func (f *FetchWorker) withSnapshot(ctx context.Context, tx pgx.Tx) error { + if f.conf.TXSnapshotID == "" { + sdk.Logger(ctx).Warn(). + Msgf("fetcher %q starting without transaction snapshot", f.cursorName) + return nil + } + + if _, err := tx.Exec( + ctx, + fmt.Sprintf("SET TRANSACTION SNAPSHOT '%s'", f.conf.TXSnapshotID), + ); err != nil { + return fmt.Errorf("failed to set tx snapshot %q: %w", f.conf.TXSnapshotID, err) + } + + return nil +} + +func (*FetchWorker) validateKey(ctx context.Context, table, key string, tx pgx.Tx) error { + var dataType string + + if err := tx.QueryRow( + ctx, + "SELECT data_type FROM information_schema.columns WHERE table_name=$1 AND column_name=$2", + table, key, + ).Scan(&dataType); err != nil { + if errors.Is(err, pgx.ErrNoRows) { + return fmt.Errorf("key %q not present on table %q", key, table) + } + return fmt.Errorf("unable to check key %q on table %q: %w", key, table, err) + } + + if !slices.Contains(supportedKeyTypes, dataType) { + return fmt.Errorf("key %q of type %q is unsupported", key, dataType) + } + + var isPK bool + + if err := tx.QueryRow( + ctx, + `SELECT EXISTS(SELECT tc.constraint_type + FROM information_schema.constraint_column_usage cu JOIN information_schema.table_constraints tc + ON tc.constraint_name = cu.constraint_name + WHERE cu.table_name=$1 AND cu.column_name=$2)`, + table, key, + ).Scan(&isPK); err != nil { + return fmt.Errorf("unable to determine key %q constraints: %w", key, err) + } + + if !isPK { + return fmt.Errorf("invalid key %q, not a primary key", key) + } + + return nil +} + +func (*FetchWorker) validateTable(ctx context.Context, table string, tx pgx.Tx) error { + var tableExists bool + + if err := tx.QueryRow( + ctx, + "SELECT EXISTS(SELECT tablename FROM pg_tables WHERE tablename=$1)", + table, + ).Scan(&tableExists); err != nil { + return fmt.Errorf("unable to check table %q: %w", table, err) + } + + if !tableExists { + return fmt.Errorf("table %q does not exist", table) + } + + return nil +} diff --git a/source/snapshot/fetch_worker_test.go b/source/snapshot/fetch_worker_test.go new file mode 100644 index 0000000..4c2d789 --- /dev/null +++ b/source/snapshot/fetch_worker_test.go @@ -0,0 +1,474 @@ +// Copyright © 2024 Meroxa, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package snapshot + +import ( + "context" + "errors" + "fmt" + "strings" + "testing" + "time" + + "github.com/conduitio/conduit-connector-postgres/source/position" + "github.com/conduitio/conduit-connector-postgres/test" + sdk "github.com/conduitio/conduit-connector-sdk" + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgxpool" + "github.com/matryer/is" + "gopkg.in/tomb.v2" +) + +func Test_NewFetcher(t *testing.T) { + t.Run("with initial position", func(t *testing.T) { + is := is.New(t) + f := NewFetchWorker(&pgxpool.Pool{}, make(chan<- FetchData), FetchConfig{}) + + is.Equal(f.snapshotEnd, int64(0)) + is.Equal(f.lastRead, int64(0)) + }) + + t.Run("with missing position data", func(t *testing.T) { + is := is.New(t) + f := NewFetchWorker(&pgxpool.Pool{}, make(chan<- FetchData), FetchConfig{ + Position: position.Position{ + Type: position.TypeSnapshot, + }, + }) + + is.Equal(f.snapshotEnd, int64(0)) + is.Equal(f.lastRead, int64(0)) + }) + + t.Run("resume from position", func(t *testing.T) { + is := is.New(t) + + f := NewFetchWorker(&pgxpool.Pool{}, make(chan<- FetchData), FetchConfig{ + Position: position.Position{ + Type: position.TypeSnapshot, + Snapshots: position.SnapshotPositions{ + "mytable": {SnapshotEnd: 10, LastRead: 5}, + }, + }, + Table: "mytable", + }) + + is.Equal(f.snapshotEnd, int64(10)) + is.Equal(f.lastRead, int64(5)) + }) +} + +func Test_FetchConfigValidate(t *testing.T) { + t.Run("multiple errors", func(t *testing.T) { + is := is.New(t) + err := (&FetchConfig{}).Validate() + + is.True(errors.Is(err, errTableRequired)) + is.True(errors.Is(err, errKeyRequired)) + }) + + t.Run("missing table", func(t *testing.T) { + is := is.New(t) + tableErr := (&FetchConfig{Key: "id"}).Validate() + + is.True(errors.Is(tableErr, errTableRequired)) + }) + + t.Run("missing table key", func(t *testing.T) { + is := is.New(t) + keyErr := (&FetchConfig{Table: "table"}).Validate() + + is.True(errors.Is(keyErr, errKeyRequired)) + }) + + t.Run("invalid position", func(t *testing.T) { + is := is.New(t) + positionErr := (&FetchConfig{ + Table: "mytable", + Key: "id", + Position: position.Position{Type: position.TypeCDC}, + }).Validate() + + is.True(errors.Is(positionErr, errInvalidCDCType)) + }) + + t.Run("success", func(t *testing.T) { + is := is.New(t) + err := (&FetchConfig{ + Table: "mytable", + Key: "id", + }).Validate() + + is.True(err == nil) + }) +} + +func Test_FetcherValidate(t *testing.T) { + var ( + ctx = context.Background() + pool = test.ConnectPool(ctx, t, test.RegularConnString) + table = test.SetupTestTable(ctx, t, pool) + ) + + t.Run("success", func(t *testing.T) { + is := is.New(t) + f := FetchWorker{ + db: pool, + conf: FetchConfig{ + Table: table, + Key: "id", + }, + } + + is.NoErr(f.Validate(ctx)) + }) + + t.Run("table missing", func(t *testing.T) { + is := is.New(t) + f := FetchWorker{ + db: pool, + conf: FetchConfig{ + Table: "missing_table", + Key: "id", + }, + } + + err := f.Validate(ctx) + is.True(err != nil) + is.True(strings.Contains(err.Error(), `table "missing_table" does not exist`)) + }) + + t.Run("key is wrong type", func(t *testing.T) { + is := is.New(t) + f := FetchWorker{ + db: pool, + conf: FetchConfig{ + Table: table, + Key: "column3", + }, + } + + err1 := f.Validate(ctx) + is.True(err1 != nil) + t.Logf("err: %s\n", err1.Error()) + is.True(strings.Contains(err1.Error(), `failed to validate key: key "column3" of type "boolean" is unsupported`)) + + f.conf.Key = "missing_key" + err2 := f.Validate(ctx) + is.True(err2 != nil) + is.True(strings.Contains( + err2.Error(), + fmt.Sprintf(`key "missing_key" not present on table %q`, table), + )) + }) + + t.Run("key is not pk", func(t *testing.T) { + is := is.New(t) + f := FetchWorker{ + db: pool, + conf: FetchConfig{ + Table: table, + Key: "column2", + }, + } + + err1 := f.Validate(ctx) + is.True(err1 != nil) + is.True(strings.Contains(err1.Error(), `failed to validate key: invalid key "column2", not a primary key`)) + + f.conf.Key = "missing_key" + err2 := f.Validate(ctx) + is.True(err2 != nil) + is.True(strings.Contains( + err2.Error(), + fmt.Sprintf(`key "missing_key" not present on table %q`, table), + )) + }) +} + +func Test_FetcherRun_Initial(t *testing.T) { + var ( + pool = test.ConnectPool(context.Background(), t, test.RegularConnString) + table = test.SetupTestTable(context.Background(), t, pool) + is = is.New(t) + out = make(chan FetchData) + ctx = context.Background() + tt = &tomb.Tomb{} + ) + + f := NewFetchWorker(pool, out, FetchConfig{ + Table: table, + Key: "id", + }) + + tt.Go(func() error { + ctx = tt.Context(ctx) + defer close(out) + + if err := f.Validate(ctx); err != nil { + return err + } + return f.Run(ctx) + }) + + var dd []FetchData + for data := range out { + dd = append(dd, data) + } + + is.NoErr(tt.Err()) + is.True(len(dd) == 4) + + expectedMatch := []sdk.StructuredData{ + {"id": int64(1), "key": []uint8{49}, "column1": "foo", "column2": int32(123), "column3": false}, + {"id": int64(2), "key": []uint8{50}, "column1": "bar", "column2": int32(456), "column3": true}, + {"id": int64(3), "key": []uint8{51}, "column1": "baz", "column2": int32(789), "column3": false}, + {"id": int64(4), "key": []uint8{52}, "column1": nil, "column2": nil, "column3": nil}, + } + + for i, d := range dd { + is.Equal(d.Key, sdk.StructuredData{"id": int64(i + 1)}) + is.Equal(d.Payload, expectedMatch[i]) + + is.Equal(d.Position, position.SnapshotPosition{ + LastRead: int64(i + 1), + SnapshotEnd: 4, + Done: i == 3, + }) + is.Equal(d.Table, table) + } +} + +func Test_FetcherRun_Resume(t *testing.T) { + var ( + pool = test.ConnectPool(context.Background(), t, test.RegularConnString) + table = test.SetupTestTable(context.Background(), t, pool) + is = is.New(t) + out = make(chan FetchData) + ctx = context.Background() + tt = &tomb.Tomb{} + ) + + f := NewFetchWorker(pool, out, FetchConfig{ + Table: table, + Key: "id", + Position: position.Position{ + Type: position.TypeSnapshot, + Snapshots: position.SnapshotPositions{ + table: { + SnapshotEnd: 3, + LastRead: 2, + }, + }, + }, + }) + + tt.Go(func() error { + ctx = tt.Context(ctx) + defer close(out) + + if err := f.Validate(ctx); err != nil { + return err + } + return f.Run(ctx) + }) + + var dd []FetchData + for d := range out { + dd = append(dd, d) + } + + is.NoErr(tt.Err()) + is.True(len(dd) == 1) + + // validate generated record + is.Equal(dd[0].Key, sdk.StructuredData{"id": int64(3)}) + is.Equal(dd[0].Payload, sdk.StructuredData{ + "id": int64(3), + "key": []uint8{51}, + "column1": "baz", + "column2": int32(789), + "column3": false, + }) + + is.Equal(dd[0].Position, position.SnapshotPosition{ + LastRead: 3, + SnapshotEnd: 3, + Done: true, + }) + is.Equal(dd[0].Table, table) +} + +func Test_withSnapshot(t *testing.T) { + var ( + is = is.New(t) + ctx = context.Background() + pool = test.ConnectPool(ctx, t, test.RegularConnString) + ) + + conn1, conn1Err := pool.Acquire(ctx) + is.NoErr(conn1Err) + t.Cleanup(func() { conn1.Release() }) + + tx1, tx1Err := conn1.Begin(ctx) + is.NoErr(tx1Err) + t.Cleanup(func() { is.NoErr(tx1.Rollback(ctx)) }) + + var snapshot string + queryErr := tx1.QueryRow(ctx, "SELECT pg_export_snapshot()").Scan(&snapshot) + is.NoErr(queryErr) + + t.Run("with valid snapshot", func(t *testing.T) { + is := is.New(t) + + c, err := pool.Acquire(ctx) + is.NoErr(err) + t.Cleanup(func() { c.Release() }) + + tx, txErr := c.BeginTx(ctx, pgx.TxOptions{ + IsoLevel: pgx.RepeatableRead, + AccessMode: pgx.ReadOnly, + }) + is.NoErr(txErr) + t.Cleanup(func() { _ = tx.Rollback(ctx) }) + + f := FetchWorker{conf: FetchConfig{TXSnapshotID: snapshot}} + + is.NoErr(f.withSnapshot(ctx, tx)) + }) + + t.Run("with invalid snapshot", func(t *testing.T) { + is := is.New(t) + + c, err := pool.Acquire(ctx) + is.NoErr(err) + t.Cleanup(func() { c.Release() }) + + tx, txErr := c.BeginTx(ctx, pgx.TxOptions{ + IsoLevel: pgx.RepeatableRead, + AccessMode: pgx.ReadOnly, + }) + is.NoErr(txErr) + t.Cleanup(func() { is.NoErr(tx.Rollback(ctx)) }) + + f := FetchWorker{conf: FetchConfig{TXSnapshotID: "invalid"}} + + snapErr := f.withSnapshot(ctx, tx) + is.True(strings.Contains(snapErr.Error(), `invalid snapshot identifier: "invalid"`)) + }) + + t.Run("without snapshot", func(t *testing.T) { + is := is.New(t) + + f := FetchWorker{conf: FetchConfig{}} + + snapErr := f.withSnapshot(ctx, nil) + is.NoErr(snapErr) + }) +} + +func Test_send(t *testing.T) { + is := is.New(t) + + ctx, cancel := context.WithCancel(context.Background()) + f := FetchWorker{conf: FetchConfig{}} + + cancel() + + err := f.send(ctx, FetchData{}) + + is.Equal(err, context.Canceled) +} + +func Test_FetchWorker_buildRecordData(t *testing.T) { + var ( + is = is.New(t) + now = time.Now().UTC() + + // special case fields + fields = []string{"id", "time"} + values = []any{1, now} + expectValues = []any{1, now.String()} + ) + + key, payload := (&FetchWorker{ + conf: FetchConfig{Table: "mytable", Key: "id"}, + }).buildRecordData(fields, values) + + is.Equal(len(payload), 2) + for i, k := range fields { + is.Equal(payload[k], expectValues[i]) + } + + is.Equal(len(key), 1) + is.Equal(key["id"], 1) +} + +func Test_FetchWorker_updateSnapshotEnd(t *testing.T) { + var ( + is = is.New(t) + ctx = context.Background() + pool = test.ConnectPool(ctx, t, test.RegularConnString) + table = test.SetupTestTable(ctx, t, pool) + ) + + tx, err := pool.Begin(ctx) + is.NoErr(err) + t.Cleanup(func() { is.NoErr(tx.Rollback(ctx)) }) + + tests := []struct { + desc string + w *FetchWorker + expected int64 + wantErr error + }{ + { + desc: "success", + w: &FetchWorker{conf: FetchConfig{ + Table: table, + Key: "id", + }}, + expected: 4, + }, + { + desc: "skip update when set", + w: &FetchWorker{snapshotEnd: 10}, + expected: 10, + }, + { + desc: "fails to get range", + w: &FetchWorker{conf: FetchConfig{ + Table: table, + Key: "notid", + }}, + wantErr: errors.New(`ERROR: column "notid" does not exist`), + }, + } + + for _, tc := range tests { + t.Run(tc.desc, func(t *testing.T) { + is := is.New(t) + + err := tc.w.updateSnapshotEnd(ctx, tx) + if tc.wantErr != nil { + is.True(err != nil) + is.True(strings.Contains(err.Error(), tc.wantErr.Error())) + } else { + is.NoErr(err) + is.Equal(tc.w.snapshotEnd, tc.expected) + } + }) + } +} diff --git a/source/snapshot/iterator.go b/source/snapshot/iterator.go new file mode 100644 index 0000000..fd203c9 --- /dev/null +++ b/source/snapshot/iterator.go @@ -0,0 +1,155 @@ +// Copyright © 2024 Meroxa, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package snapshot + +import ( + "context" + "errors" + "fmt" + + "github.com/conduitio/conduit-connector-postgres/source/position" + sdk "github.com/conduitio/conduit-connector-sdk" + "github.com/jackc/pgx/v5/pgxpool" + "gopkg.in/tomb.v2" +) + +var ErrIteratorDone = errors.New("snapshot complete") + +type Config struct { + Position sdk.Position + Tables []string + TablesKeys map[string]string + TXSnapshotID string +} + +type Iterator struct { + db *pgxpool.Pool + t *tomb.Tomb + workers []*FetchWorker + + conf Config + + lastPosition position.Position + + data chan FetchData +} + +func NewIterator(ctx context.Context, db *pgxpool.Pool, c Config) (*Iterator, error) { + p, err := position.ParseSDKPosition(c.Position) + if err != nil { + return nil, fmt.Errorf("failed to parse position: %w", err) + } + + if p.Snapshots == nil { + p.Snapshots = make(position.SnapshotPositions) + } + + t, _ := tomb.WithContext(ctx) + i := &Iterator{ + db: db, + t: t, + conf: c, + data: make(chan FetchData), + lastPosition: p, + } + + if err := i.initFetchers(ctx); err != nil { + return nil, fmt.Errorf("failed to initialize table fetchers: %w", err) + } + + i.startWorkers() + + return i, nil +} + +func (i *Iterator) Next(ctx context.Context) (sdk.Record, error) { + select { + case <-ctx.Done(): + return sdk.Record{}, fmt.Errorf("iterator stopped: %w", ctx.Err()) + case d, ok := <-i.data: + if !ok { // closed + if err := i.t.Err(); err != nil { + return sdk.Record{}, fmt.Errorf("fetchers exited unexpectedly: %w", err) + } + return sdk.Record{}, ErrIteratorDone + } + + return i.buildRecord(d), nil + } +} + +func (i *Iterator) Ack(_ context.Context) error { + return nil +} + +func (i *Iterator) Teardown(_ context.Context) error { + if i.t != nil { + i.t.Kill(errors.New("tearing down snapshot iterator")) + } + + return nil +} + +func (i *Iterator) buildRecord(d FetchData) sdk.Record { + // merge this position with latest position + i.lastPosition.Type = position.TypeSnapshot + i.lastPosition.Snapshots[d.Table] = d.Position + + pos := i.lastPosition.ToSDKPosition() + metadata := make(sdk.Metadata) + metadata["postgres.table"] = d.Table + + return sdk.Util.Source.NewRecordCreate(pos, metadata, d.Key, d.Payload) +} + +func (i *Iterator) initFetchers(ctx context.Context) error { + var errs []error + + i.workers = make([]*FetchWorker, len(i.conf.Tables)) + + for j, t := range i.conf.Tables { + w := NewFetchWorker(i.db, i.data, FetchConfig{ + Table: t, + Key: i.conf.TablesKeys[t], + TXSnapshotID: i.conf.TXSnapshotID, + Position: i.lastPosition, + }) + + if err := w.Validate(ctx); err != nil { + errs = append(errs, fmt.Errorf("failed to validate table fetcher %q config: %w", t, err)) + } + + i.workers[j] = w + } + + return errors.Join(errs...) +} + +func (i *Iterator) startWorkers() { + for j := range i.workers { + f := i.workers[j] + i.t.Go(func() error { + ctx := i.t.Context(nil) //nolint:staticcheck // This is the correct usage of tomb.Context + if err := f.Run(ctx); err != nil { + return fmt.Errorf("fetcher for table %q exited: %w", f.conf.Table, err) + } + return nil + }) + } + go func() { + <-i.t.Dead() + close(i.data) + }() +} diff --git a/source/snapshot/iterator_test.go b/source/snapshot/iterator_test.go new file mode 100644 index 0000000..65a0a18 --- /dev/null +++ b/source/snapshot/iterator_test.go @@ -0,0 +1,101 @@ +// Copyright © 2024 Meroxa, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package snapshot + +import ( + "context" + "errors" + "testing" + + "github.com/conduitio/conduit-connector-postgres/source/position" + "github.com/conduitio/conduit-connector-postgres/test" + "github.com/matryer/is" +) + +func Test_Iterator_Next(t *testing.T) { + var ( + ctx = context.Background() + pool = test.ConnectPool(ctx, t, test.RegularConnString) + table = test.SetupTestTable(ctx, t, pool) + ) + + t.Run("success", func(t *testing.T) { + is := is.New(t) + + i, err := NewIterator(ctx, pool, Config{ + Position: position.Position{}.ToSDKPosition(), + Tables: []string{table}, + TablesKeys: map[string]string{ + table: "id", + }, + }) + is.NoErr(err) + defer func() { + is.NoErr(i.Teardown(ctx)) + }() + + for j := 1; j <= 4; j++ { + _, err := i.Next(ctx) + is.NoErr(err) + } + + _, err = i.Next(ctx) + is.Equal(err, ErrIteratorDone) + }) + + t.Run("context cancelled", func(t *testing.T) { + is := is.New(t) + + i, err := NewIterator(ctx, pool, Config{ + Position: position.Position{}.ToSDKPosition(), + Tables: []string{table}, + TablesKeys: map[string]string{ + table: "id", + }, + }) + is.NoErr(err) + defer func() { + is.NoErr(i.Teardown(ctx)) + }() + + cancelCtx, cancel := context.WithCancel(ctx) + cancel() + + _, err = i.Next(cancelCtx) + is.Equal(err.Error(), "iterator stopped: context canceled") + }) + + t.Run("tomb exited", func(t *testing.T) { + is := is.New(t) + cancelCtx, cancel := context.WithCancel(ctx) + + i, err := NewIterator(cancelCtx, pool, Config{ + Position: position.Position{}.ToSDKPosition(), + Tables: []string{table}, + TablesKeys: map[string]string{ + table: "id", + }, + }) + is.NoErr(err) + defer func() { + is.NoErr(i.Teardown(ctx)) + }() + + cancel() + + _, err = i.Next(ctx) + is.True(errors.Is(err, context.Canceled)) + }) +} diff --git a/test/conf.d/postgresql.conf b/test/conf.d/postgresql.conf index 1c5fd27..9e83b22 100644 --- a/test/conf.d/postgresql.conf +++ b/test/conf.d/postgresql.conf @@ -1,3 +1,4 @@ wal_level=logical max_wal_senders=5 -max_replication_slots=5 \ No newline at end of file +max_replication_slots=5 +log_statement='all' diff --git a/test/docker-compose.yml b/test/docker-compose.yml index dcd03a6..ebcf63d 100644 --- a/test/docker-compose.yml +++ b/test/docker-compose.yml @@ -1,9 +1,9 @@ version: '3.4' services: pg-0: - image: docker.io/bitnami/postgresql-repmgr:14 + image: docker.io/bitnami/postgresql-repmgr:15 ports: - - "5432:5432" + - "5433:5432" volumes: - "pg_0_data:/bitnami/postgresql" - "./conf.d/:/bitnami/postgresql/conf/conf.d/" diff --git a/test/helper.go b/test/helper.go index ff70798..775e1f7 100644 --- a/test/helper.go +++ b/test/helper.go @@ -29,10 +29,10 @@ import ( ) // RepmgrConnString is a replication user connection string for the test postgres. -const RepmgrConnString = "postgres://repmgr:repmgrmeroxa@localhost:5432/meroxadb?sslmode=disable" +const RepmgrConnString = "postgres://repmgr:repmgrmeroxa@127.0.0.1:5433/meroxadb?sslmode=disable" // RegularConnString is a non-replication user connection string for the test postgres. -const RegularConnString = "postgres://meroxauser:meroxapass@localhost:5432/meroxadb?sslmode=disable" +const RegularConnString = "postgres://meroxauser:meroxapass@127.0.0.1:5433/meroxadb?sslmode=disable" type Querier interface { Exec(ctx context.Context, sql string, arguments ...any) (pgconn.CommandTag, error) diff --git a/tools.go b/tools.go index 2f8457c..6ea1a10 100644 --- a/tools.go +++ b/tools.go @@ -18,5 +18,8 @@ package postgres import ( _ "github.com/conduitio/conduit-connector-sdk/cmd/paramgen" + _ "github.com/daixiang0/gci" _ "github.com/golangci/golangci-lint/cmd/golangci-lint" + _ "golang.org/x/tools/cmd/stringer" + _ "mvdan.cc/gofumpt" )