From 23e0171b93ac3ea75491bf731c70806dac2f7710 Mon Sep 17 00:00:00 2001 From: Lyubo Kamenov Date: Thu, 30 May 2024 15:23:04 -0400 Subject: [PATCH] Handle numeric type in cdc and snapshot (#161) * Handle numeric type in cdc and snapshot --- source/logrepl/cdc_test.go | 12 ++- source/logrepl/combined_test.go | 18 +++- source/logrepl/internal/relationset.go | 8 +- source/logrepl/internal/relationset_test.go | 14 +--- source/snapshot/fetch_worker.go | 33 +++++--- source/snapshot/fetch_worker_test.go | 35 ++++---- source/types/numeric.go | 56 +++++++++++++ source/types/time.go | 26 ++++++ source/types/types.go | 41 +++++++++ source/types/types_test.go | 92 +++++++++++++++++++++ test/helper.go | 15 ++-- 11 files changed, 303 insertions(+), 47 deletions(-) create mode 100644 source/types/numeric.go create mode 100644 source/types/time.go create mode 100644 source/types/types.go create mode 100644 source/types/types_test.go diff --git a/source/logrepl/cdc_test.go b/source/logrepl/cdc_test.go index d4e723c..392bc18 100644 --- a/source/logrepl/cdc_test.go +++ b/source/logrepl/cdc_test.go @@ -25,6 +25,8 @@ import ( "github.com/conduitio/conduit-connector-postgres/source/position" "github.com/conduitio/conduit-connector-postgres/test" sdk "github.com/conduitio/conduit-connector-sdk" + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" "github.com/jackc/pgx/v5/pgconn" "github.com/jackc/pgx/v5/pgxpool" "github.com/matryer/is" @@ -147,8 +149,8 @@ func TestCDCIterator_Next(t *testing.T) { }{ { name: "should detect insert", - setupQuery: `INSERT INTO %s (id, column1, column2, column3) - VALUES (6, 'bizz', 456, false)`, + setupQuery: `INSERT INTO %s (id, column1, column2, column3, column4, column5) + VALUES (6, 'bizz', 456, false, 12.3, 14)`, wantErr: false, want: sdk.Record{ Operation: sdk.OperationCreate, @@ -163,6 +165,8 @@ func TestCDCIterator_Next(t *testing.T) { "column1": "bizz", "column2": int32(456), "column3": false, + "column4": 12.3, + "column5": int64(14), "key": nil, }, }, @@ -187,6 +191,8 @@ func TestCDCIterator_Next(t *testing.T) { "column1": "test cdc updates", "column2": int32(123), "column3": false, + "column4": 12.2, + "column5": int64(4), "key": []uint8("1"), }, }, @@ -228,7 +234,7 @@ func TestCDCIterator_Next(t *testing.T) { tt.want.Metadata[sdk.MetadataReadAt] = got.Metadata[sdk.MetadataReadAt] tt.want.Position = got.Position - is.Equal(got, tt.want) + is.Equal("", cmp.Diff(tt.want, got, cmpopts.IgnoreUnexported(sdk.Record{}))) is.NoErr(i.Ack(ctx, got.Position)) }) } diff --git a/source/logrepl/combined_test.go b/source/logrepl/combined_test.go index 52a7a8c..47ef7ad 100644 --- a/source/logrepl/combined_test.go +++ b/source/logrepl/combined_test.go @@ -155,7 +155,8 @@ func TestCombinedIterator_Next(t *testing.T) { is.NoErr(err) _, err = pool.Exec(ctx, fmt.Sprintf( - `INSERT INTO %s (id, column1, column2, column3) VALUES (6, 'bizz', 1010, false)`, + `INSERT INTO %s (id, column1, column2, column3, column4, column5) + VALUES (6, 'bizz', 1010, false, 872.2, 101)`, table, )) is.NoErr(err) @@ -221,7 +222,8 @@ func TestCombinedIterator_Next(t *testing.T) { }) is.NoErr(err) _, err = pool.Exec(ctx, fmt.Sprintf( - `INSERT INTO %s (id, column1, column2, column3) VALUES (7, 'buzz', 10101, true)`, + `INSERT INTO %s (id, column1, column2, column3, column4, column5) + VALUES (7, 'buzz', 10101, true, 121.9, 51)`, table, )) is.NoErr(err) @@ -262,6 +264,8 @@ func testRecords() []sdk.StructuredData { "column1": "foo", "column2": int32(123), "column3": false, + "column4": 12.2, + "column5": int64(4), }, { "id": int64(2), @@ -269,6 +273,8 @@ func testRecords() []sdk.StructuredData { "column1": "bar", "column2": int32(456), "column3": true, + "column4": 13.42, + "column5": int64(8), }, { "id": int64(3), @@ -276,6 +282,8 @@ func testRecords() []sdk.StructuredData { "column1": "baz", "column2": int32(789), "column3": false, + "column4": nil, + "column5": int64(9), }, { "id": int64(4), @@ -283,6 +291,8 @@ func testRecords() []sdk.StructuredData { "column1": nil, "column2": nil, "column3": nil, + "column4": 91.1, + "column5": nil, }, { "id": int64(6), @@ -290,6 +300,8 @@ func testRecords() []sdk.StructuredData { "column1": "bizz", "column2": int32(1010), "column3": false, + "column4": 872.2, + "column5": int64(101), }, { "id": int64(7), @@ -297,6 +309,8 @@ func testRecords() []sdk.StructuredData { "column1": "buzz", "column2": int32(10101), "column3": true, + "column4": 121.9, + "column5": int64(51), }, } } diff --git a/source/logrepl/internal/relationset.go b/source/logrepl/internal/relationset.go index 70cc971..07fcc95 100644 --- a/source/logrepl/internal/relationset.go +++ b/source/logrepl/internal/relationset.go @@ -18,6 +18,7 @@ import ( "errors" "fmt" + "github.com/conduitio/conduit-connector-postgres/source/types" "github.com/jackc/pglogrepl" "github.com/jackc/pgx/v5/pgtype" ) @@ -70,7 +71,12 @@ func (rs *RelationSet) Values(id uint32, row *pglogrepl.TupleData) (map[string]a return nil, fmt.Errorf("failed to decode tuple %d: %w", i, err) } - values[col.Name] = val + v, err := types.Format(val) + if err != nil { + return nil, fmt.Errorf("failed to format column %q type %T: %w", col.Name, val, err) + } + + values[col.Name] = v } return values, nil diff --git a/source/logrepl/internal/relationset_test.go b/source/logrepl/internal/relationset_test.go index b1f3344..1c0c5ee 100644 --- a/source/logrepl/internal/relationset_test.go +++ b/source/logrepl/internal/relationset_test.go @@ -273,7 +273,7 @@ func isValuesAllTypes(is *is.I, got map[string]any) { R: 13, Valid: true, }, - "col_date": time.Date(2022, 3, 14, 0, 0, 0, 0, time.UTC), + "col_date": time.Date(2022, 3, 14, 0, 0, 0, 0, time.UTC).UTC().String(), "col_float4": float32(15), "col_float8": float64(16.16), "col_inet": netip.MustParsePrefix("192.168.0.17/32"), @@ -301,13 +301,7 @@ func isValuesAllTypes(is *is.I, got map[string]any) { "col_macaddr": net.HardwareAddr{0x08, 0x00, 0x2b, 0x01, 0x02, 0x26}, "col_macaddr8": net.HardwareAddr{0x08, 0x00, 0x2b, 0x01, 0x02, 0x03, 0x04, 0x27}, "col_money": "$28.00", - "col_numeric": pgtype.Numeric{ - Int: big.NewInt(29292929), - Exp: -2, - NaN: false, - InfinityModifier: pgtype.Finite, - Valid: true, - }, + "col_numeric": float64(292929.29), "col_path": pgtype.Path{ P: []pgtype.Vec2{{X: 30, Y: 31}, {X: 32, Y: 33}, {X: 34, Y: 35}}, Closed: false, @@ -332,8 +326,8 @@ func isValuesAllTypes(is *is.I, got map[string]any) { Valid: true, }, "col_timetz": "04:05:06.789-08", - "col_timestamp": time.Date(2022, 3, 14, 15, 16, 17, 0, time.UTC), - "col_timestamptz": time.Date(2022, 3, 14, 15+8, 16, 17, 0, time.UTC), + "col_timestamp": time.Date(2022, 3, 14, 15, 16, 17, 0, time.UTC).UTC().String(), + "col_timestamptz": time.Date(2022, 3, 14, 15+8, 16, 17, 0, time.UTC).UTC().String(), "col_tsquery": "'fat' & ( 'rat' | 'cat' )", "col_tsvector": "'a' 'and' 'ate' 'cat' 'fat' 'mat' 'on' 'rat' 'sat'", "col_uuid": [16]uint8{0xbd, 0x94, 0xee, 0x0b, 0x56, 0x4f, 0x40, 0x88, 0xbf, 0x4e, 0x8d, 0x5e, 0x62, 0x6c, 0xaf, 0x66}, diff --git a/source/snapshot/fetch_worker.go b/source/snapshot/fetch_worker.go index d4c16f4..2d95b30 100644 --- a/source/snapshot/fetch_worker.go +++ b/source/snapshot/fetch_worker.go @@ -23,6 +23,7 @@ import ( "time" "github.com/conduitio/conduit-connector-postgres/source/position" + "github.com/conduitio/conduit-connector-postgres/source/types" sdk "github.com/conduitio/conduit-connector-sdk" "github.com/google/uuid" "github.com/jackc/pgx/v5" @@ -304,7 +305,12 @@ func (f *FetchWorker) buildFetchData(fields []string, values []any) (FetchData, if err != nil { return FetchData{}, fmt.Errorf("failed to build snapshot position: %w", err) } - key, payload := f.buildRecordData(fields, values) + + key, payload, err := f.buildRecordData(fields, values) + if err != nil { + return FetchData{}, fmt.Errorf("failed to encode record data: %w", err) + } + return FetchData{ Key: key, Payload: payload, @@ -330,23 +336,28 @@ func (f *FetchWorker) buildSnapshotPosition(fields []string, values []any) (posi return position.SnapshotPosition{}, fmt.Errorf("key %q not found in fields", f.conf.Key) } -func (f *FetchWorker) buildRecordData(fields []string, values []any) (key sdk.StructuredData, payload sdk.StructuredData) { - payload = make(sdk.StructuredData) +func (f *FetchWorker) buildRecordData(fields []string, values []any) (sdk.StructuredData, sdk.StructuredData, error) { + var ( + key = make(sdk.StructuredData) + payload = make(sdk.StructuredData) + ) for i, name := range fields { - switch t := values[i].(type) { - case time.Time: // type not supported in sdk.Record - payload[name] = t.UTC().String() - default: - payload[name] = t + v, err := types.Format(values[i]) + if err != nil { + return key, payload, fmt.Errorf("failed to format payload field %q: %w", name, err) } + payload[name] = v } - key = sdk.StructuredData{ - f.conf.Key: payload[f.conf.Key], + k, err := types.Format(payload[f.conf.Key]) + if err != nil { + return key, payload, fmt.Errorf("failed to format key %q: %w", f.conf.Key, err) } - return key, payload + key[f.conf.Key] = k + + return key, payload, nil } func (f *FetchWorker) withSnapshot(ctx context.Context, tx pgx.Tx) error { diff --git a/source/snapshot/fetch_worker_test.go b/source/snapshot/fetch_worker_test.go index 722608d..3532647 100644 --- a/source/snapshot/fetch_worker_test.go +++ b/source/snapshot/fetch_worker_test.go @@ -25,6 +25,7 @@ import ( "github.com/conduitio/conduit-connector-postgres/source/position" "github.com/conduitio/conduit-connector-postgres/test" sdk "github.com/conduitio/conduit-connector-sdk" + "github.com/google/go-cmp/cmp" "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgxpool" "github.com/matryer/is" @@ -233,21 +234,24 @@ func Test_FetcherRun_Initial(t *testing.T) { is.True(len(dd) == 4) expectedMatch := []sdk.StructuredData{ - {"id": int64(1), "key": []uint8{49}, "column1": "foo", "column2": int32(123), "column3": false}, - {"id": int64(2), "key": []uint8{50}, "column1": "bar", "column2": int32(456), "column3": true}, - {"id": int64(3), "key": []uint8{51}, "column1": "baz", "column2": int32(789), "column3": false}, - {"id": int64(4), "key": []uint8{52}, "column1": nil, "column2": nil, "column3": nil}, + {"id": int64(1), "key": []uint8{49}, "column1": "foo", "column2": int32(123), "column3": false, "column4": 12.2, "column5": int64(4)}, + {"id": int64(2), "key": []uint8{50}, "column1": "bar", "column2": int32(456), "column3": true, "column4": 13.42, "column5": int64(8)}, + {"id": int64(3), "key": []uint8{51}, "column1": "baz", "column2": int32(789), "column3": false, "column4": nil, "column5": int64(9)}, + {"id": int64(4), "key": []uint8{52}, "column1": nil, "column2": nil, "column3": nil, "column4": 91.1, "column5": nil}, } for i, d := range dd { - is.Equal(d.Key, sdk.StructuredData{"id": int64(i + 1)}) - is.Equal(d.Payload, expectedMatch[i]) - - is.Equal(d.Position, position.SnapshotPosition{ - LastRead: int64(i + 1), - SnapshotEnd: 4, + t.Run(fmt.Sprintf("payload_%d", i+1), func(t *testing.T) { + is := is.New(t) + is.Equal(d.Key, sdk.StructuredData{"id": int64(i + 1)}) + is.Equal("", cmp.Diff(expectedMatch[i], d.Payload)) + + is.Equal(d.Position, position.SnapshotPosition{ + LastRead: int64(i + 1), + SnapshotEnd: 4, + }) + is.Equal(d.Table, table) }) - is.Equal(d.Table, table) } } @@ -295,13 +299,15 @@ func Test_FetcherRun_Resume(t *testing.T) { // validate generated record is.Equal(dd[0].Key, sdk.StructuredData{"id": int64(3)}) - is.Equal(dd[0].Payload, sdk.StructuredData{ + is.Equal("", cmp.Diff(dd[0].Payload, sdk.StructuredData{ "id": int64(3), "key": []uint8{51}, "column1": "baz", "column2": int32(789), "column3": false, - }) + "column4": nil, + "column5": int64(9), + })) is.Equal(dd[0].Position, position.SnapshotPosition{ LastRead: 3, @@ -402,10 +408,11 @@ func Test_FetchWorker_buildRecordData(t *testing.T) { expectValues = []any{1, now.String()} ) - key, payload := (&FetchWorker{ + key, payload, err := (&FetchWorker{ conf: FetchConfig{Table: "mytable", Key: "id"}, }).buildRecordData(fields, values) + is.NoErr(err) is.Equal(len(payload), 2) for i, k := range fields { is.Equal(payload[k], expectValues[i]) diff --git a/source/types/numeric.go b/source/types/numeric.go new file mode 100644 index 0000000..c7dfbf0 --- /dev/null +++ b/source/types/numeric.go @@ -0,0 +1,56 @@ +// Copyright © 2024 Meroxa, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package types + +import ( + "github.com/jackc/pgx/v5/pgtype" +) + +type NumericFormatter struct{} + +// Format coerces `pgtype.Numeric` to int or double depending on the exponent. +// Returns error when value is invalid. +func (NumericFormatter) Format(num pgtype.Numeric) (any, error) { + // N.B. The numeric type in pgx is represented by two ints. + // When the type in Postgres is defined as `NUMERIC(10)' the scale is assumed to be 0. + // However, pgx may represent the number as two ints e.g. 1200 -> (int=12,exp=2) = 12*10^2. as well + // as a type with zero exponent, e.g. 121 -> (int=121,exp=0). + // Thus, a Numeric type with positive or zero exponent is assumed to be an integer. + if num.Exp >= 0 { + i8v, err := num.Int64Value() + if err != nil { + return nil, err + } + + v, err := i8v.Value() + if err != nil { + return nil, err + } + + return v, nil + } + + f8v, err := num.Float64Value() + if err != nil { + return nil, err + } + + v, err := f8v.Value() + if err != nil { + return nil, err + } + + return v, nil +} diff --git a/source/types/time.go b/source/types/time.go new file mode 100644 index 0000000..24cbcbf --- /dev/null +++ b/source/types/time.go @@ -0,0 +1,26 @@ +// Copyright © 2022 Meroxa, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package types + +import ( + "time" +) + +type TimeFormatter struct{} + +// Format coerces `time.Time` to a string representation in UTC tz. +func (n TimeFormatter) Format(t time.Time) (any, error) { + return t.UTC().String(), nil +} diff --git a/source/types/types.go b/source/types/types.go new file mode 100644 index 0000000..a45d46a --- /dev/null +++ b/source/types/types.go @@ -0,0 +1,41 @@ +// Copyright © 2024 Meroxa, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package types + +import ( + "time" + + "github.com/jackc/pgx/v5/pgtype" +) + +var ( + Numeric = NumericFormatter{} + Time = TimeFormatter{} +) + +func Format(v any) (any, error) { + switch t := v.(type) { + case pgtype.Numeric: + return Numeric.Format(t) + case *pgtype.Numeric: + return Numeric.Format(*t) + case time.Time: + return Time.Format(t) + case *time.Time: + return Time.Format(*t) + default: // supported type + return t, nil + } +} diff --git a/source/types/types_test.go b/source/types/types_test.go new file mode 100644 index 0000000..71632bf --- /dev/null +++ b/source/types/types_test.go @@ -0,0 +1,92 @@ +// Copyright © 2024 Meroxa, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package types + +import ( + "testing" + "time" + + "github.com/jackc/pgx/v5/pgtype" + "github.com/matryer/is" +) + +func Test_Format(t *testing.T) { + tests := []struct { + name string + input []any + expect []any + }{ + { + name: "int float string bool", + input: []any{ + 1021, 199.2, "foo", true, + }, + expect: []any{ + 1021, 199.2, "foo", true, + }, + }, + { + name: "pgtype.Numeric", + input: []any{ + pgxNumeric(t, "12.2121"), pgxNumeric(t, "101"), &pgtype.Numeric{}, nil, + }, + expect: []any{ + float64(12.2121), int64(101), nil, nil, + }, + }, + { + name: "time.Time", + input: []any{ + func() time.Time { + is := is.New(t) + is.Helper() + t, err := time.Parse(time.DateTime, "2009-11-10 23:00:00") + is.NoErr(err) + return t + }(), + nil, + }, + expect: []any{ + "2009-11-10 23:00:00 +0000 UTC", nil, + }, + }, + } + _ = time.Now() + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + is := is.New(t) + + for i, in := range tc.input { + v, err := Format(in) + is.NoErr(err) + is.Equal(v, tc.expect[i]) + } + }) + } +} + +// as per https://github.com/jackc/pgx/blob/master/pgtype/numeric_test.go#L66 +func pgxNumeric(t *testing.T, num string) pgtype.Numeric { + is := is.New(t) + is.Helper() + + var n pgtype.Numeric + plan := pgtype.NumericCodec{}.PlanScan(nil, pgtype.NumericOID, pgtype.TextFormatCode, &n) + is.True(plan != nil) + is.NoErr(plan.Scan([]byte(num), &n)) + + return n +} diff --git a/test/helper.go b/test/helper.go index 5076cca..2c6e68b 100644 --- a/test/helper.go +++ b/test/helper.go @@ -85,7 +85,10 @@ func SetupTestTable(ctx context.Context, t *testing.T, conn Querier) string { key bytea, column1 varchar(256), column2 integer, - column3 boolean)` + column3 boolean, + column4 numeric(16,3), + column5 numeric(5) + )` query = fmt.Sprintf(query, table) _, err := conn.Exec(ctx, query) is.NoErr(err) @@ -98,11 +101,11 @@ func SetupTestTable(ctx context.Context, t *testing.T, conn Querier) string { }) query = ` - INSERT INTO %s (key, column1, column2, column3) - VALUES ('1', 'foo', 123, false), - ('2', 'bar', 456, true), - ('3', 'baz', 789, false), - ('4', null, null, null)` + INSERT INTO %s (key, column1, column2, column3, column4, column5) + VALUES ('1', 'foo', 123, false, 12.2, 4), + ('2', 'bar', 456, true, 13.42, 8), + ('3', 'baz', 789, false, null, 9), + ('4', null, null, null, 91.1, null)` query = fmt.Sprintf(query, table) _, err = conn.Exec(ctx, query) is.NoErr(err)