From 48ed86ebb650a572241bf5092e2018da80f1a316 Mon Sep 17 00:00:00 2001 From: "Christoph Engelbert (noctarius)" Date: Sun, 30 Jul 2023 15:26:51 +0200 Subject: [PATCH] Added support for enum types. Enums use schema type STRING. --- .../replicationconnection_test.go | 7 +- .../replicationcontext/replicationcontext.go | 18 +- .../replication/sidechannel/sidechannel.go | 38 ++- internal/sysconfig/providers.go | 3 +- .../systemcatalog/snapshotting/snapshotter.go | 15 +- internal/typemanager/builtin_converters.go | 11 + internal/typemanager/coretypes.go | 9 + internal/typemanager/decoderplan.go | 86 +++++ internal/typemanager/lazy_converters.go | 96 ++++++ internal/typemanager/pgtype.go | 2 +- internal/typemanager/rowdecoder.go | 166 ++++++++++ internal/typemanager/typedecoder.go | 29 ++ internal/typemanager/typemanager.go | 252 +++++++------- spi/pgtypes/decoderplan.go | 77 +---- spi/pgtypes/rowdecoder.go | 43 +++ spi/pgtypes/typedecoder.go | 312 ------------------ spi/pgtypes/typemanager.go | 2 + spi/replicationcontext/replicationcontext.go | 7 +- spi/sidechannel/sidechannel.go | 8 +- tests/datatype_test.go | 15 + tests/streamer_test.go | 2 +- 21 files changed, 662 insertions(+), 536 deletions(-) create mode 100644 internal/typemanager/decoderplan.go create mode 100644 internal/typemanager/lazy_converters.go create mode 100644 internal/typemanager/rowdecoder.go create mode 100644 internal/typemanager/typedecoder.go create mode 100644 spi/pgtypes/rowdecoder.go delete mode 100644 spi/pgtypes/typedecoder.go diff --git a/internal/replication/replicationconnection/replicationconnection_test.go b/internal/replication/replicationconnection/replicationconnection_test.go index cb0c7ccb..3442a1dd 100644 --- a/internal/replication/replicationconnection/replicationconnection_test.go +++ b/internal/replication/replicationconnection/replicationconnection_test.go @@ -527,21 +527,22 @@ func (t testReplicationContext) ReadHypertableSchema( } func (t testReplicationContext) 
SnapshotChunkTable( - chunk *systemcatalog.Chunk, cb sidechannel.SnapshotRowCallback, + rowDecoderFactory pgtypes.RowDecoderFactory, chunk *systemcatalog.Chunk, cb sidechannel.SnapshotRowCallback, ) (pgtypes.LSN, error) { return 0, nil } func (t testReplicationContext) FetchHypertableSnapshotBatch( - hypertable *systemcatalog.Hypertable, snapshotName string, cb sidechannel.SnapshotRowCallback, + rowDecoderFactory pgtypes.RowDecoderFactory, hypertable *systemcatalog.Hypertable, + snapshotName string, cb sidechannel.SnapshotRowCallback, ) error { return nil } func (t testReplicationContext) ReadSnapshotHighWatermark( - hypertable *systemcatalog.Hypertable, snapshotName string, + rowDecoderFactory pgtypes.RowDecoderFactory, hypertable *systemcatalog.Hypertable, snapshotName string, ) (map[string]any, error) { return nil, nil diff --git a/internal/replication/replicationcontext/replicationcontext.go b/internal/replication/replicationcontext/replicationcontext.go index cb3f5a90..a82f3fe7 100644 --- a/internal/replication/replicationcontext/replicationcontext.go +++ b/internal/replication/replicationcontext/replicationcontext.go @@ -377,24 +377,30 @@ func (rc *replicationContext) ReadHypertableSchema( } func (rc *replicationContext) SnapshotChunkTable( - chunk *systemcatalog.Chunk, cb sidechannel.SnapshotRowCallback, + rowDecoderFactory pgtypes.RowDecoderFactory, chunk *systemcatalog.Chunk, cb sidechannel.SnapshotRowCallback, ) (pgtypes.LSN, error) { - return rc.sideChannel.SnapshotChunkTable(chunk, rc.snapshotBatchSize, cb) + // FIXME: remove the intermediate function? 
+ return rc.sideChannel.SnapshotChunkTable(rowDecoderFactory, chunk, rc.snapshotBatchSize, cb) } func (rc *replicationContext) FetchHypertableSnapshotBatch( - hypertable *systemcatalog.Hypertable, snapshotName string, cb sidechannel.SnapshotRowCallback, + rowDecoderFactory pgtypes.RowDecoderFactory, hypertable *systemcatalog.Hypertable, + snapshotName string, cb sidechannel.SnapshotRowCallback, ) error { - return rc.sideChannel.FetchHypertableSnapshotBatch(hypertable, snapshotName, rc.snapshotBatchSize, cb) + // FIXME: remove the intermediate function? + return rc.sideChannel.FetchHypertableSnapshotBatch( + rowDecoderFactory, hypertable, snapshotName, rc.snapshotBatchSize, cb, + ) } func (rc *replicationContext) ReadSnapshotHighWatermark( - hypertable *systemcatalog.Hypertable, snapshotName string, + rowDecoderFactory pgtypes.RowDecoderFactory, hypertable *systemcatalog.Hypertable, snapshotName string, ) (map[string]any, error) { - return rc.sideChannel.ReadSnapshotHighWatermark(hypertable, snapshotName) + // FIXME: remove the intermediate function? 
+ return rc.sideChannel.ReadSnapshotHighWatermark(rowDecoderFactory, hypertable, snapshotName) } func (rc *replicationContext) ReadReplicaIdentity( diff --git a/internal/replication/sidechannel/sidechannel.go b/internal/replication/sidechannel/sidechannel.go index 82fd2418..724112a2 100644 --- a/internal/replication/sidechannel/sidechannel.go +++ b/internal/replication/sidechannel/sidechannel.go @@ -328,7 +328,8 @@ func (sc *sideChannel) DetachTablesFromPublication( } func (sc *sideChannel) SnapshotChunkTable( - chunk *systemcatalog.Chunk, snapshotBatchSize int, cb sidechannel.SnapshotRowCallback, + rowDecoderFactory pgtypes.RowDecoderFactory, chunk *systemcatalog.Chunk, + snapshotBatchSize int, cb sidechannel.SnapshotRowCallback, ) (pgtypes.LSN, error) { var currentLSN pgtypes.LSN = 0 @@ -338,14 +339,15 @@ func (sc *sideChannel) SnapshotChunkTable( "DECLARE %s SCROLL CURSOR FOR SELECT * FROM %s", cursorName, chunk.CanonicalName(), ) + callback := func(lsn pgtypes.LSN, values map[string]any) error { + if currentLSN == 0 { + currentLSN = lsn + } + return cb(lsn, values) + } + if err := sc.snapshotTableWithCursor( - cursorQuery, cursorName, nil, snapshotBatchSize, - func(lsn pgtypes.LSN, values map[string]any) error { - if currentLSN == 0 { - currentLSN = lsn - } - return cb(lsn, values) - }, + rowDecoderFactory, cursorQuery, cursorName, nil, snapshotBatchSize, callback, ); err != nil { return 0, errors.Wrap(err, 0) } @@ -354,8 +356,8 @@ func (sc *sideChannel) SnapshotChunkTable( } func (sc *sideChannel) FetchHypertableSnapshotBatch( - hypertable *systemcatalog.Hypertable, snapshotName string, - snapshotBatchSize int, cb sidechannel.SnapshotRowCallback, + rowDecoderFactory pgtypes.RowDecoderFactory, hypertable *systemcatalog.Hypertable, + snapshotName string, snapshotBatchSize int, cb sidechannel.SnapshotRowCallback, ) error { index, present := hypertable.Columns().SnapshotIndex() @@ -419,13 +421,15 @@ func (sc *sideChannel) FetchHypertableSnapshotBatch( return 
cb(lsn, values) } - return sc.snapshotTableWithCursor(cursorQuery, cursorName, &snapshotName, snapshotBatchSize, hook) + return sc.snapshotTableWithCursor( + rowDecoderFactory, cursorQuery, cursorName, &snapshotName, snapshotBatchSize, hook, + ) }, ) } func (sc *sideChannel) ReadSnapshotHighWatermark( - hypertable *systemcatalog.Hypertable, snapshotName string, + rowDecoderFactory pgtypes.RowDecoderFactory, hypertable *systemcatalog.Hypertable, snapshotName string, ) (values map[string]any, err error) { index, present := hypertable.Columns().SnapshotIndex() @@ -451,7 +455,7 @@ func (sc *sideChannel) ReadSnapshotHighWatermark( return session.queryFunc(func(row pgx.Row) error { rows := row.(pgx.Rows) - rowDecoder, err := pgtypes.NewRowDecoder(rows.FieldDescriptions()) + rowDecoder, err := rowDecoderFactory(rows.FieldDescriptions()) if err != nil { return errors.Wrap(err, 0) } @@ -699,8 +703,8 @@ func (sc *sideChannel) readHypertableSchema0( } func (sc *sideChannel) snapshotTableWithCursor( - cursorQuery, cursorName string, snapshotName *string, - snapshotBatchSize int, cb sidechannel.SnapshotRowCallback, + rowDecoderFactory pgtypes.RowDecoderFactory, cursorQuery, cursorName string, + snapshotName *string, snapshotBatchSize int, cb sidechannel.SnapshotRowCallback, ) error { return sc.newSession(time.Minute*60, func(session *session) error { @@ -726,7 +730,7 @@ func (sc *sideChannel) snapshotTableWithCursor( return errors.Wrap(err, 0) } - var rowDecoder *pgtypes.RowDecoder + var rowDecoder pgtypes.RowDecoder for { count := 0 if err := session.queryFunc(func(row pgx.Row) error { @@ -734,7 +738,7 @@ func (sc *sideChannel) snapshotTableWithCursor( if rowDecoder == nil { // Initialize the row decoder based on the returned field descriptions - rd, err := pgtypes.NewRowDecoder(rows.FieldDescriptions()) + rd, err := rowDecoderFactory(rows.FieldDescriptions()) if err != nil { return errors.Wrap(err, 0) } diff --git a/internal/sysconfig/providers.go 
b/internal/sysconfig/providers.go index 4d796f6f..b43d9d0e 100644 --- a/internal/sysconfig/providers.go +++ b/internal/sysconfig/providers.go @@ -79,7 +79,8 @@ type SinkManagerProvider = func( ) sink.Manager type SnapshotterProvider = func( - *config.Config, replicationcontext.ReplicationContext, task.TaskManager, publication.PublicationManager, + *config.Config, replicationcontext.ReplicationContext, + task.TaskManager, publication.PublicationManager, pgtypes.TypeManager, ) (*snapshotting.Snapshotter, error) type ReplicationChannelProvider = func( diff --git a/internal/systemcatalog/snapshotting/snapshotter.go b/internal/systemcatalog/snapshotting/snapshotter.go index 84b97b78..6ed8a9f8 100644 --- a/internal/systemcatalog/snapshotting/snapshotter.go +++ b/internal/systemcatalog/snapshotting/snapshotter.go @@ -45,6 +45,7 @@ type Snapshotter struct { partitionCount uint64 replicationContext replicationcontext.ReplicationContext taskManager task.TaskManager + typeManager pgtypes.TypeManager publicationManager publication.PublicationManager snapshotQueues []chan SnapshotTask shutdownAwaiter *waiting.MultiShutdownAwaiter @@ -54,15 +55,17 @@ type Snapshotter struct { func NewSnapshotterFromConfig( c *config.Config, replicationContext replicationcontext.ReplicationContext, taskManager task.TaskManager, publicationManager publication.PublicationManager, + typeManager pgtypes.TypeManager, ) (*Snapshotter, error) { parallelism := config.GetOrDefault(c, config.PropertySnapshotterParallelism, uint8(5)) - return NewSnapshotter(parallelism, replicationContext, taskManager, publicationManager) + return NewSnapshotter(parallelism, replicationContext, taskManager, publicationManager, typeManager) } func NewSnapshotter( partitionCount uint8, replicationContext replicationcontext.ReplicationContext, taskManager task.TaskManager, publicationManager publication.PublicationManager, + typeManager pgtypes.TypeManager, ) (*Snapshotter, error) { snapshotQueues := make([]chan SnapshotTask, 
partitionCount) @@ -79,6 +82,7 @@ func NewSnapshotter( partitionCount: uint64(partitionCount), replicationContext: replicationContext, taskManager: taskManager, + typeManager: typeManager, publicationManager: publicationManager, snapshotQueues: snapshotQueues, logger: logger, @@ -179,7 +183,8 @@ func (s *Snapshotter) snapshotChunk( } lsn, err := s.replicationContext.SnapshotChunkTable( - t.Chunk, func(lsn pgtypes.LSN, values map[string]any) error { + s.typeManager.GetOrPlanRowDecoder, t.Chunk, + func(lsn pgtypes.LSN, values map[string]any) error { return s.taskManager.EnqueueTask(func(notificator task.Notificator) { callback := func(handler eventhandlers.HypertableReplicationEventHandler) error { return handler.OnReadEvent(lsn, t.Hypertable, t.Chunk, values) @@ -218,7 +223,9 @@ func (s *Snapshotter) snapshotHypertable( // Initialize the watermark or update the high watermark after a restart if created || t.nextSnapshotFetch { - highWatermark, err := s.replicationContext.ReadSnapshotHighWatermark(t.Hypertable, *t.SnapshotName) + highWatermark, err := s.replicationContext.ReadSnapshotHighWatermark( + s.typeManager.GetOrPlanRowDecoder, t.Hypertable, *t.SnapshotName, + ) if err != nil { return errors.Wrap(err, 0) } @@ -269,7 +276,7 @@ func (s *Snapshotter) runSnapshotFetchBatch( ) error { return s.replicationContext.FetchHypertableSnapshotBatch( - t.Hypertable, *t.SnapshotName, + s.typeManager.GetOrPlanRowDecoder, t.Hypertable, *t.SnapshotName, func(lsn pgtypes.LSN, values map[string]any) error { return s.taskManager.EnqueueTask(func(notificator task.Notificator) { notificator.NotifyHypertableReplicationEventHandler( diff --git a/internal/typemanager/builtin_converters.go b/internal/typemanager/builtin_converters.go index 22f18ed1..1eb1e798 100644 --- a/internal/typemanager/builtin_converters.go +++ b/internal/typemanager/builtin_converters.go @@ -95,6 +95,17 @@ func reflectiveArrayConverter( } } +func enum2string( + _ uint32, value any, +) (any, error) { + + switch v 
:= value.(type) { + case string: + return v, nil + } + return nil, errIllegalValue +} + func float42float( _ uint32, value any, ) (any, error) { diff --git a/internal/typemanager/coretypes.go b/internal/typemanager/coretypes.go index 62d7371d..4699c085 100644 --- a/internal/typemanager/coretypes.go +++ b/internal/typemanager/coretypes.go @@ -201,6 +201,9 @@ var coreTypes = map[uint32]typeRegistration{ pgtypes.MacAddr8OID: { schemaType: schema.STRING, converter: macaddr2text, + typeMapTypeFactory: func(_ *pgtype.Map, typ pgtypes.PgType) *pgtype.Type { + return &pgtype.Type{Name: "macaddr8", OID: pgtypes.MacAddr8OID, Codec: pgtype.MacaddrCodec{}} + }, }, pgtypes.MacAddrArray8OID: { schemaType: schema.ARRAY, @@ -318,6 +321,9 @@ var coreTypes = map[uint32]typeRegistration{ pgtypes.TimeTZOID: { schemaType: schema.STRING, converter: time2text, + typeMapTypeFactory: func(_ *pgtype.Map, typ pgtypes.PgType) *pgtype.Type { + return &pgtype.Type{Name: "timetz", OID: pgtypes.TimeTZOID, Codec: &pgtypes.TimetzCodec{}} + }, }, pgtypes.TimeTZArrayOID: { schemaType: schema.ARRAY, @@ -326,6 +332,9 @@ var coreTypes = map[uint32]typeRegistration{ }, pgtypes.XmlOID: { schemaType: schema.STRING, + typeMapTypeFactory: func(_ *pgtype.Map, typ pgtypes.PgType) *pgtype.Type { + return &pgtype.Type{Name: "xml", OID: pgtypes.XmlOID, Codec: pgtypes.XmlCodec{}} + }, }, pgtypes.XmlArrayOID: { schemaType: schema.ARRAY, diff --git a/internal/typemanager/decoderplan.go b/internal/typemanager/decoderplan.go new file mode 100644 index 00000000..99af7780 --- /dev/null +++ b/internal/typemanager/decoderplan.go @@ -0,0 +1,86 @@ +package typemanager + +import ( + "github.com/go-errors/errors" + "github.com/jackc/pglogrepl" + "github.com/jackc/pgx/v5/pgtype" + "github.com/noctarius/timescaledb-event-streamer/internal/functional" + "github.com/noctarius/timescaledb-event-streamer/spi/pgtypes" +) + +type tupleDecoder func(column *pglogrepl.TupleDataColumn, values map[string]any) error + +type tupleCodec 
func(data []byte, binary bool) (any, error) + +func planTupleDecoder( + typeManager *typeManager, relation *pgtypes.RelationMessage, +) (pgtypes.TupleDecoderPlan, error) { + + decoders := make([]tupleDecoder, 0) + + for _, column := range relation.Columns { + codec := func(data []byte, binary bool) (any, error) { + return string(data), nil + } + if pgxType, ok := typeManager.typeMap.TypeForOID(column.DataType); ok { + codec = func(data []byte, binary bool) (any, error) { + dataformat := int16(pgtype.TextFormatCode) + if binary { + dataformat = pgtype.BinaryFormatCode + } + return pgxType.Codec.DecodeValue(typeManager.typeMap, column.DataType, dataformat, data) + } + } + + decoders = append(decoders, func(dataType uint32, name string, codec tupleCodec) tupleDecoder { + return func(column *pglogrepl.TupleDataColumn, values map[string]any) error { + switch column.DataType { + case 'n': // null + values[name] = nil + case 'u': // unchanged toast + // This TOAST value was not changed. TOAST values are not stored in the tuple, and + // logical replication doesn't want to spend a disk read to fetch its value for you. 
+ case 't': // text (basically anything other than the two above) + val, err := codec(column.Data, false) + if err != nil { + return errors.Errorf("error decoding column data: %s", err) + } + values[name] = val + case 'b': // binary data + val, err := codec(column.Data, true) + if err != nil { + return errors.Errorf("error decoding column data: %s", err) + } + values[name] = val + } + return nil + } + }(column.DataType, column.Name, codec)) + } + + return &tupleDecoderPlan{ + decoders: decoders, + }, nil +} + +type tupleDecoderPlan struct { + decoders []tupleDecoder +} + +func (tdp *tupleDecoderPlan) Decode( + tupleData *pglogrepl.TupleData, +) (map[string]any, error) { + + if tupleData == nil { + return functional.Zero[map[string]any](), nil + } + + values := map[string]any{} + for i, decoder := range tdp.decoders { + column := tupleData.Columns[i] + if err := decoder(column, values); err != nil { + return nil, err + } + } + return values, nil +} diff --git a/internal/typemanager/lazy_converters.go b/internal/typemanager/lazy_converters.go new file mode 100644 index 00000000..223e3141 --- /dev/null +++ b/internal/typemanager/lazy_converters.go @@ -0,0 +1,96 @@ +package typemanager + +import ( + "github.com/go-errors/errors" + "github.com/noctarius/timescaledb-event-streamer/spi/pgtypes" + "github.com/noctarius/timescaledb-event-streamer/spi/schema" + "reflect" +) + +type lazyArrayConverter struct { + typeManager *typeManager + oidElement uint32 + converter pgtypes.TypeConverter +} + +func (lac *lazyArrayConverter) convert( + oid uint32, value any, +) (any, error) { + + if lac.converter == nil { + elementType, err := lac.typeManager.ResolveDataType(lac.oidElement) + if err != nil { + return nil, err + } + + elementConverter, err := lac.typeManager.ResolveTypeConverter(lac.oidElement) + if err != nil { + return nil, err + } + + reflectiveType, err := schemaType2ReflectiveType(elementType.SchemaType()) + if err != nil { + return nil, err + } + + targetType := 
reflect.SliceOf(reflectiveType) + lac.converter = reflectiveArrayConverter(lac.oidElement, targetType, elementConverter) + } + + return lac.converter(oid, value) +} + +/*type lazyCustomTypeConverter struct { + typeManager *typeManager + oidElement uint32 + converter pgtypes.TypeConverter +} + +func (lctc *lazyCustomTypeConverter) convert( + oid uint32, value any, +) (any, error) { + if lctc.converter == nil { + typ, err := lctc.typeManager.ResolveDataType(lctc.oidElement) + if err != nil { + return nil, err + } + + if typ.Kind() == pgtypes.EnumKind { + lctc.converter = enum2string + } else { + return nil, errIllegalValue + } + } + + return lctc.converter(oid, value) +}*/ + +func schemaType2ReflectiveType( + schemaType schema.Type, +) (reflect.Type, error) { + + switch schemaType { + case schema.INT8: + return int8Type, nil + case schema.INT16: + return int16Type, nil + case schema.INT32: + return int32Type, nil + case schema.INT64: + return int64Type, nil + case schema.FLOAT32: + return float32Type, nil + case schema.FLOAT64: + return float64Type, nil + case schema.BOOLEAN: + return booleanType, nil + case schema.STRING: + return stringType, nil + case schema.BYTES: + return byteaType, nil + case schema.MAP: + return mapType, nil + default: + return nil, errors.Errorf("Unsupported schema type %s", string(schemaType)) + } +} diff --git a/internal/typemanager/pgtype.go b/internal/typemanager/pgtype.go index e9fbc51f..c18efe6e 100644 --- a/internal/typemanager/pgtype.go +++ b/internal/typemanager/pgtype.go @@ -160,7 +160,7 @@ func (t *pgType) EnumValues() []string { if t.enumValues == nil { return []string{} } - enumValues := make([]string, 0, len(t.enumValues)) + enumValues := make([]string, len(t.enumValues)) copy(enumValues, t.enumValues) return enumValues } diff --git a/internal/typemanager/rowdecoder.go b/internal/typemanager/rowdecoder.go new file mode 100644 index 00000000..38da2be8 --- /dev/null +++ b/internal/typemanager/rowdecoder.go @@ -0,0 +1,166 @@ 
+package typemanager + +import ( + "github.com/go-errors/errors" + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgconn" +) + +type rowDecoder struct { + decoders []func(src []byte) (any, error) + fields []pgconn.FieldDescription +} + +func newRowDecoder( + typeManager *typeManager, fields []pgconn.FieldDescription, +) (*rowDecoder, error) { + + decoders := make([]func(src []byte) (any, error), 0) + for _, field := range fields { + if decoder, err := findTypeDecoder(typeManager, field); err != nil { + return nil, errors.Wrap(err, 0) + } else { + // Store a decoder wrapper for easier usage + decoders = append(decoders, decoder) + } + } + return &rowDecoder{ + decoders: decoders, + fields: fields, + }, nil +} + +func (rd *rowDecoder) DecodeRowsMapAndSink( + rows pgx.Rows, sink func(values map[string]any) error, +) error { + + if !rd.compatible(rows.FieldDescriptions()) { + return errors.Errorf("incompatible rows instance provided") + } + + // Initial error check + if rows.Err() != nil { + return errors.Wrap(rows.Err(), 0) + } + defer rows.Close() + + for rows.Next() { + values, err := rd.Decode(rows.RawValues()) + if err != nil { + return errors.Wrap(err, 0) + } + + resultSet := make(map[string]any, 0) + for i, field := range rd.fields { + resultSet[field.Name] = values[i] + } + if err := sink(resultSet); err != nil { + return errors.Wrap(err, 0) + } + } + if rows.Err() != nil { + return errors.Wrap(rows.Err(), 0) + } + return nil +} + +func (rd *rowDecoder) DecodeRowsAndSink( + rows pgx.Rows, sink func(values []any) error, +) error { + + if !rd.compatible(rows.FieldDescriptions()) { + return errors.Errorf("incompatible rows instance provided") + } + + // Initial error check + if rows.Err() != nil { + return errors.Wrap(rows.Err(), 0) + } + defer rows.Close() + + for rows.Next() { + if err := rd.DecodeAndSink(rows.RawValues(), sink); err != nil { + return errors.Wrap(err, 0) + } + } + if rows.Err() != nil { + return errors.Wrap(rows.Err(), 0) + } + return nil 
+} + +func (rd *rowDecoder) Decode( + rawRow [][]byte, +) ([]any, error) { + + values := make([]any, 0) + for i, decoder := range rd.decoders { + if v, err := decoder(rawRow[i]); err != nil { + return nil, errors.Wrap(err, 0) + } else { + values = append(values, v) + } + } + return values, nil +} + +func (rd *rowDecoder) DecodeAndSink( + rawRow [][]byte, sink func(values []any) error, +) error { + + if values, err := rd.Decode(rawRow); err != nil { + return errors.Wrap(err, 0) + } else { + return sink(values) + } +} + +func (rd *rowDecoder) DecodeMapAndSink( + rawRow [][]byte, sink func(values map[string]any) error, +) error { + + if values, err := rd.Decode(rawRow); err != nil { + return errors.Wrap(err, 0) + } else { + resultSet := make(map[string]any) + for i, field := range rd.fields { + resultSet[field.Name] = values[i] + } + if err := sink(resultSet); err != nil { + return errors.Wrap(err, 0) + } + } + return nil +} + +func (rd *rowDecoder) compatible( + other []pgconn.FieldDescription, +) bool { + + if len(rd.fields) != len(other) { + return false + } + + for i, f := range rd.fields { + o := other[i] + if f.Format != o.Format { + return false + } + if f.DataTypeOID != o.DataTypeOID { + return false + } + if f.Name != o.Name { + return false + } + if f.DataTypeSize != o.DataTypeSize { + return false + } + if f.TypeModifier != o.TypeModifier { + return false + } + // Can we reuse the same decoder for all chunks? 
🤔 + // if f.TableAttributeNumber != o.TableAttributeNumber { return false } + // if f.TableOID != o.TableOID { return false } + } + return true +} diff --git a/internal/typemanager/typedecoder.go b/internal/typemanager/typedecoder.go new file mode 100644 index 00000000..60f3ad98 --- /dev/null +++ b/internal/typemanager/typedecoder.go @@ -0,0 +1,29 @@ +package typemanager + +import ( + "github.com/go-errors/errors" + "github.com/jackc/pgx/v5/pgconn" + "github.com/jackc/pgx/v5/pgtype" +) + +type typeDecoder func(src []byte) (any, error) + +func findTypeDecoder( + typeManager *typeManager, field pgconn.FieldDescription, +) (typeDecoder, error) { + + if pgType, ok := typeManager.typeMap.TypeForOID(field.DataTypeOID); ok { + // Store a decoder wrapper for easier usage + return asTypeDecoder(typeManager, pgType, field), nil + } + return nil, errors.Errorf("Unsupported type oid: %d", field.DataTypeOID) +} + +func asTypeDecoder( + typeManager *typeManager, pgType *pgtype.Type, field pgconn.FieldDescription, +) func(src []byte) (any, error) { + + return func(src []byte) (any, error) { + return pgType.Codec.DecodeValue(typeManager.typeMap, field.DataTypeOID, field.Format, src) + } +} diff --git a/internal/typemanager/typemanager.go b/internal/typemanager/typemanager.go index c62d4d94..3b15e169 100644 --- a/internal/typemanager/typemanager.go +++ b/internal/typemanager/typemanager.go @@ -21,13 +21,13 @@ import ( "fmt" "github.com/go-errors/errors" "github.com/jackc/pglogrepl" + "github.com/jackc/pgx/v5/pgconn" "github.com/jackc/pgx/v5/pgtype" "github.com/noctarius/timescaledb-event-streamer/internal/containers" "github.com/noctarius/timescaledb-event-streamer/internal/logging" "github.com/noctarius/timescaledb-event-streamer/spi/pgtypes" "github.com/noctarius/timescaledb-event-streamer/spi/schema" "github.com/noctarius/timescaledb-event-streamer/spi/sidechannel" - "github.com/samber/lo" "reflect" "sync" ) @@ -45,13 +45,16 @@ var ( mapType = reflect.TypeOf(map[string]any{}) ) 
+type typeMapTypeFactory func(typeMap *pgtype.Map, typ pgtypes.PgType) *pgtype.Type + type typeRegistration struct { - schemaType schema.Type - schemaBuilder schema.Builder - isArray bool - oidElement uint32 - converter pgtypes.TypeConverter - codec pgtype.Codec + schemaType schema.Type + schemaBuilder schema.Builder + isArray bool + oidElement uint32 + converter pgtypes.TypeConverter + codec pgtype.Codec + typeMapTypeFactory typeMapTypeFactory } // errIllegalValue represents an illegal type conversion request @@ -62,12 +65,15 @@ type typeManager struct { logger *logging.Logger sideChannel sidechannel.SideChannel + typeMap *pgtype.Map + typeCache map[uint32]pgtypes.PgType typeNameCache map[string]uint32 typeCacheMutex sync.RWMutex optimizedTypes map[uint32]pgtypes.PgType optimizedConverters map[uint32]typeRegistration + dynamicConverters map[uint32]typeRegistration cachedDecoderPlans *containers.ConcurrentMap[uint32, pgtypes.TupleDecoderPlan] } @@ -84,12 +90,15 @@ func NewTypeManager( logger: logger, sideChannel: sideChannel, + typeMap: pgtype.NewMap(), + typeCache: make(map[uint32]pgtypes.PgType), typeNameCache: make(map[string]uint32), typeCacheMutex: sync.RWMutex{}, optimizedTypes: make(map[uint32]pgtypes.PgType), optimizedConverters: make(map[uint32]typeRegistration), + dynamicConverters: make(map[uint32]typeRegistration), cachedDecoderPlans: containers.NewConcurrentMap[uint32, pgtypes.TupleDecoderPlan](), } @@ -103,61 +112,7 @@ func (tm *typeManager) initialize() error { tm.typeCacheMutex.Lock() defer tm.typeCacheMutex.Unlock() - // Extract keys from the built-in core types - coreTypesSlice := lo.Keys(coreTypes) - - if err := tm.sideChannel.ReadPgTypes(tm.typeFactory, func(typ pgtypes.PgType) error { - if lo.IndexOf(coreTypesSlice, typ.Oid()) != -1 { - return nil - } - - tm.typeCache[typ.Oid()] = typ - tm.typeNameCache[typ.Name()] = typ.Oid() - - if registration, present := optimizedTypes[typ.Name()]; present { - if t, ok := typ.(*pgType); ok { - t.schemaType 
= registration.schemaType - } - - tm.optimizedTypes[typ.Oid()] = typ - - var converter pgtypes.TypeConverter - if registration.isArray { - lazyConverter := &lazyArrayConverter{ - typeManager: tm, - oidElement: typ.OidElement(), - } - converter = lazyConverter.convert - } else { - converter = registration.converter - } - - if converter == nil { - return errors.Errorf("Type %s has no assigned value converter", typ.Name()) - } - - tm.optimizedConverters[typ.Oid()] = typeRegistration{ - schemaType: registration.schemaType, - schemaBuilder: registration.schemaBuilder, - isArray: registration.isArray, - oidElement: typ.OidElement(), - converter: converter, - codec: registration.codec, - } - - if typ.IsArray() { - if elementType, present := pgtypes.GetType(typ.OidElement()); present { - pgtypes.RegisterType(&pgtype.Type{ - Name: typ.Name(), OID: typ.Oid(), Codec: &pgtype.ArrayCodec{ElementType: elementType}, - }) - } - } else { - pgtypes.RegisterType(&pgtype.Type{Name: typ.Name(), OID: typ.Oid(), Codec: registration.codec}) - } - } - - return nil - }); err != nil { + if err := tm.sideChannel.ReadPgTypes(tm.typeFactory, tm.registerType); err != nil { return err } return nil @@ -193,9 +148,9 @@ func (tm *typeManager) ResolveDataType( defer tm.typeCacheMutex.Unlock() var pt pgtypes.PgType - err := tm.sideChannel.ReadPgTypes(tm.typeFactory, func(p pgtypes.PgType) error { - pt = p - return nil + err := tm.sideChannel.ReadPgTypes(tm.typeFactory, func(typ pgtypes.PgType) error { + pt = typ + return tm.registerType(typ) }, oid) if err != nil { @@ -205,9 +160,6 @@ func (tm *typeManager) ResolveDataType( if pt == nil { return false, nil } - - tm.typeCache[oid] = pt - tm.typeNameCache[pt.Name()] = oid return true, nil } @@ -243,6 +195,9 @@ func (tm *typeManager) ResolveTypeConverter( if registration, present := tm.optimizedConverters[oid]; present { return registration.converter, nil } + if registration, present := tm.dynamicConverters[oid]; present { + return 
registration.converter, nil + } return nil, fmt.Errorf("unsupported OID: %d", oid) } @@ -282,7 +237,7 @@ func (tm *typeManager) GetOrPlanTupleDecoder( plan, ok := tm.cachedDecoderPlans.Load(relation.RelationID) if !ok { - plan, err = pgtypes.PlanTupleDecoder(relation) + plan, err = planTupleDecoder(tm, relation) if err != nil { return nil, err } @@ -291,6 +246,13 @@ func (tm *typeManager) GetOrPlanTupleDecoder( return plan, nil } +func (tm *typeManager) GetOrPlanRowDecoder( + fields []pgconn.FieldDescription, +) (pgtypes.RowDecoder, error) { + + return newRowDecoder(tm, fields) +} + func (tm *typeManager) getSchemaType( oid uint32, arrayType bool, kind pgtypes.PgKind, ) schema.Type { @@ -375,65 +337,123 @@ func (tm *typeManager) resolveSchemaBuilder( } } -type lazyArrayConverter struct { - typeManager *typeManager - oidElement uint32 - converter pgtypes.TypeConverter -} +func (tm *typeManager) registerType( + typ pgtypes.PgType, +) error { -func (lac *lazyArrayConverter) convert( - oid uint32, value any, -) (any, error) { + tm.typeCache[typ.Oid()] = typ + tm.typeNameCache[typ.Name()] = typ.Oid() - if lac.converter == nil { - elementType, err := lac.typeManager.ResolveDataType(lac.oidElement) - if err != nil { - return nil, err + // Is core type not available in TypeMap by default (bug or not implemented in pgx)? 
+ if registration, present := coreTypes[typ.Oid()]; present { + if !tm.knownInTypeMap(typ.Oid()) { + if err := tm.registerTypeInTypeMap(typ, registration); err != nil { + return err + } } + } - elementConverter, err := lac.typeManager.ResolveTypeConverter(lac.oidElement) - if err != nil { - return nil, err + // Optimized types have dynamic OIDs and need to be registered dynamically + if registration, present := optimizedTypes[typ.Name()]; present { + if t, ok := typ.(*pgType); ok { + t.schemaType = registration.schemaType } - reflectiveType, err := schemaType2ReflectiveType(elementType.SchemaType()) - if err != nil { - return nil, err + tm.optimizedTypes[typ.Oid()] = typ + + converter := tm.resolveOptimizedTypeConverter(typ, registration) + if converter == nil { + return errors.Errorf("Type %s has no assigned value converter", typ.Name()) + } + + tm.optimizedConverters[typ.Oid()] = typeRegistration{ + schemaType: registration.schemaType, + schemaBuilder: registration.schemaBuilder, + isArray: registration.isArray, + oidElement: typ.OidElement(), + converter: converter, + codec: registration.codec, + } + + if err := tm.registerTypeInTypeMap(typ, registration); err != nil { + return err + } + } + + // Enums are user defined objects and need to be registered manually + if typ.Kind() == pgtypes.EnumKind { + registration := typeRegistration{ + schemaType: typ.SchemaType(), + converter: enum2string, + codec: &pgtype.EnumCodec{}, } - targetType := reflect.SliceOf(reflectiveType) - lac.converter = reflectiveArrayConverter(lac.oidElement, targetType, elementConverter) + tm.dynamicConverters[typ.Oid()] = registration + if err := tm.registerTypeInTypeMap(typ, registration); err != nil { + return err + } + } + + // Object types (all remaining) need to be handled specifically + if typ.SchemaType() == schema.STRUCT { + // TODO: ignore for now - missing implementation + registration := typeRegistration{ + schemaType: typ.SchemaType(), + } + tm.dynamicConverters[typ.Oid()] = 
registration } - return lac.converter(oid, value) + return nil } -func schemaType2ReflectiveType( - schemaType schema.Type, -) (reflect.Type, error) { +func (tm *typeManager) knownInTypeMap(oid uint32) (known bool) { + _, known = tm.typeMap.TypeForOID(oid) + return +} - switch schemaType { - case schema.INT8: - return int8Type, nil - case schema.INT16: - return int16Type, nil - case schema.INT32: - return int32Type, nil - case schema.INT64: - return int64Type, nil - case schema.FLOAT32: - return float32Type, nil - case schema.FLOAT64: - return float64Type, nil - case schema.BOOLEAN: - return booleanType, nil - case schema.STRING: - return stringType, nil - case schema.BYTES: - return byteaType, nil - case schema.MAP: - return mapType, nil - default: - return nil, errors.Errorf("Unsupported schema type %s", string(schemaType)) +func (tm *typeManager) registerTypeInTypeMap( + typ pgtypes.PgType, registration typeRegistration, +) error { + + // If specific codec is registered, we can use it directly + if registration.codec != nil { + tm.typeMap.RegisterType(&pgtype.Type{Name: typ.Name(), OID: typ.Oid(), Codec: registration.codec}) + return nil + } + + // Slightly more complicated types have a factory for the pgx type + if registration.typeMapTypeFactory != nil { + tm.typeMap.RegisterType(registration.typeMapTypeFactory(tm.typeMap, typ)) + return nil + } + + // When array type, try to resolve element type and use generic array codec + if typ.IsArray() { + if elementDecoderType, present := tm.typeMap.TypeForOID(typ.OidElement()); present { + tm.typeMap.RegisterType( + &pgtype.Type{ + Name: typ.Name(), + OID: typ.Oid(), + Codec: &pgtype.ArrayCodec{ElementType: elementDecoderType}, + }, + ) + return nil + } + } + + return errors.Errorf("Unknown codec for type registration with oid %d", typ.Oid()) +} + +func (tm *typeManager) resolveOptimizedTypeConverter( + typ pgtypes.PgType, registration typeRegistration, +) pgtypes.TypeConverter { + + if registration.isArray { + 
lazyConverter := &lazyArrayConverter{ + typeManager: tm, + oidElement: typ.OidElement(), + } + return lazyConverter.convert } + return registration.converter } diff --git a/spi/pgtypes/decoderplan.go b/spi/pgtypes/decoderplan.go index a584a4e3..ba9e5d4b 100644 --- a/spi/pgtypes/decoderplan.go +++ b/spi/pgtypes/decoderplan.go @@ -18,83 +18,22 @@ package pgtypes import ( - "github.com/go-errors/errors" + "fmt" "github.com/jackc/pglogrepl" "github.com/jackc/pgx/v5/pgtype" - "github.com/noctarius/timescaledb-event-streamer/internal/functional" ) type TupleDecoderPlan interface { Decode(tupleData *pglogrepl.TupleData) (map[string]any, error) } -func PlanTupleDecoder(relation *RelationMessage) (TupleDecoderPlan, error) { - decoders := make([]tupleDecoder, 0) +func codecScan( + codec pgtype.Codec, m *pgtype.Map, oid uint32, format int16, src []byte, dst any, +) error { - for _, column := range relation.Columns { - codec := func(data []byte, binary bool) (any, error) { - return string(data), nil - } - if pgxType, ok := typeMap.TypeForOID(column.DataType); ok { - codec = func(data []byte, binary bool) (any, error) { - dataformat := int16(pgtype.TextFormatCode) - if binary { - dataformat = pgtype.BinaryFormatCode - } - return pgxType.Codec.DecodeValue(typeMap, column.DataType, dataformat, data) - } - } - - decoders = append(decoders, func(dataType uint32, name string, codec tupleCodec) tupleDecoder { - return func(column *pglogrepl.TupleDataColumn, values map[string]any) error { - switch column.DataType { - case 'n': // null - values[name] = nil - case 'u': // unchanged toast - // This TOAST value was not changed. TOAST values are not stored in the tuple, and - // logical replication doesn't want to spend a disk read to fetch its value for you. 
- case 't': // text (basically anything other than the two above) - val, err := codec(column.Data, false) - if err != nil { - return errors.Errorf("error decoding column data: %s", err) - } - values[name] = val - case 'b': // binary data - val, err := codec(column.Data, true) - if err != nil { - return errors.Errorf("error decoding column data: %s", err) - } - values[name] = val - } - return nil - } - }(column.DataType, column.Name, codec)) - } - - return &tupleDecoderPlan{ - decoders: decoders, - }, nil -} - -type tupleDecoder func(column *pglogrepl.TupleDataColumn, values map[string]any) error - -type tupleCodec func(data []byte, binary bool) (any, error) - -type tupleDecoderPlan struct { - decoders []tupleDecoder -} - -func (tdp *tupleDecoderPlan) Decode(tupleData *pglogrepl.TupleData) (map[string]any, error) { - if tupleData == nil { - return functional.Zero[map[string]any](), nil - } - - values := map[string]any{} - for i, decoder := range tdp.decoders { - column := tupleData.Columns[i] - if err := decoder(column, values); err != nil { - return nil, err - } + scanPlan := codec.PlanScan(m, oid, format, dst) + if scanPlan == nil { + return fmt.Errorf("PlanScan did not find a plan") } - return values, nil + return scanPlan.Scan(src, dst) } diff --git a/spi/pgtypes/rowdecoder.go b/spi/pgtypes/rowdecoder.go new file mode 100644 index 00000000..ed8d17a0 --- /dev/null +++ b/spi/pgtypes/rowdecoder.go @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package pgtypes + +import ( + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgconn" +) + +type RowDecoderFactory = func(fields []pgconn.FieldDescription) (RowDecoder, error) + +type RowDecoder interface { + DecodeRowsMapAndSink( + rows pgx.Rows, sink func(values map[string]any) error, + ) error + DecodeRowsAndSink( + rows pgx.Rows, sink func(values []any) error, + ) error + Decode( + rawRow [][]byte, + ) ([]any, error) + DecodeAndSink( + rawRow [][]byte, sink func(values []any) error, + ) error + DecodeMapAndSink( + rawRow [][]byte, sink func(values map[string]any) error, + ) error +} diff --git a/spi/pgtypes/typedecoder.go b/spi/pgtypes/typedecoder.go deleted file mode 100644 index c51ce38f..00000000 --- a/spi/pgtypes/typedecoder.go +++ /dev/null @@ -1,312 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
- * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package pgtypes - -import ( - "fmt" - "github.com/go-errors/errors" - "github.com/jackc/pgx/v5" - "github.com/jackc/pgx/v5/pgconn" - "github.com/jackc/pgx/v5/pgtype" -) - -var typeMap *pgtype.Map - -func init() { - typeMap = pgtype.NewMap() - - macaddr8Type := &pgtype.Type{Name: "macaddr8", OID: MacAddr8OID, Codec: pgtype.MacaddrCodec{}} - typeMap.RegisterType(macaddr8Type) - typeMap.RegisterType( - &pgtype.Type{ - Name: "_macaddr8", - OID: MacAddrArray8OID, - Codec: &pgtype.ArrayCodec{ElementType: macaddr8Type}}, - ) - - xmlType := &pgtype.Type{Name: "xml", OID: XmlOID, Codec: XmlCodec{}} - typeMap.RegisterType(xmlType) - typeMap.RegisterType( - &pgtype.Type{ - Name: "_xml", - OID: XmlArrayOID, - Codec: &pgtype.ArrayCodec{ElementType: xmlType}}, - ) - - timetzType := &pgtype.Type{Name: "timetz", OID: TimeTZOID, Codec: &TimetzCodec{}} - typeMap.RegisterType(timetzType) - typeMap.RegisterType( - &pgtype.Type{ - Name: "_timetz", - OID: TimeTZArrayOID, - Codec: &pgtype.ArrayCodec{ElementType: timetzType}}, - ) - - qcharType, _ := typeMap.TypeForOID(pgtype.QCharOID) - typeMap.RegisterType( - &pgtype.Type{ - Name: "_char", - OID: QCharArrayOID, - Codec: &pgtype.ArrayCodec{ElementType: qcharType}}, - ) -} - -type RowDecoder struct { - decoders []func(src []byte) (any, error) - fields []pgconn.FieldDescription -} - -func NewRowDecoder( - fields []pgconn.FieldDescription, -) (*RowDecoder, error) { - - decoders := make([]func(src []byte) (any, error), 0) - for _, field := range fields { - if decoder, err := FindTypeDecoder(field); err != nil { - return nil, errors.Wrap(err, 0) - } else { - // Store a decoder wrapper for easier usage - decoders = append(decoders, decoder) - } - } - return &RowDecoder{ - decoders: decoders, - fields: fields, - }, nil -} - -func (rd *RowDecoder) DecodeRowsMapAndSink( - rows pgx.Rows, sink func(values map[string]any) error, -) 
error { - - if !rd.compatible(rows.FieldDescriptions()) { - return errors.Errorf("incompatible rows instance provided") - } - - // Initial error check - if rows.Err() != nil { - return errors.Wrap(rows.Err(), 0) - } - defer rows.Close() - - for rows.Next() { - values, err := rd.Decode(rows.RawValues()) - if err != nil { - return errors.Wrap(err, 0) - } - - resultSet := make(map[string]any, 0) - for i, field := range rd.fields { - resultSet[field.Name] = values[i] - } - if err := sink(resultSet); err != nil { - return errors.Wrap(err, 0) - } - } - if rows.Err() != nil { - return errors.Wrap(rows.Err(), 0) - } - return nil -} - -func (rd *RowDecoder) DecodeRowsAndSink( - rows pgx.Rows, sink func(values []any) error, -) error { - - if !rd.compatible(rows.FieldDescriptions()) { - return errors.Errorf("incompatible rows instance provided") - } - - // Initial error check - if rows.Err() != nil { - return errors.Wrap(rows.Err(), 0) - } - defer rows.Close() - - for rows.Next() { - if err := rd.DecodeAndSink(rows.RawValues(), sink); err != nil { - return errors.Wrap(err, 0) - } - } - if rows.Err() != nil { - return errors.Wrap(rows.Err(), 0) - } - return nil -} - -func (rd *RowDecoder) Decode( - rawRow [][]byte, -) ([]any, error) { - - values := make([]any, 0) - for i, decoder := range rd.decoders { - if v, err := decoder(rawRow[i]); err != nil { - return nil, errors.Wrap(err, 0) - } else { - values = append(values, v) - } - } - return values, nil -} - -func (rd *RowDecoder) DecodeAndSink( - rawRow [][]byte, sink func(values []any) error, -) error { - - if values, err := rd.Decode(rawRow); err != nil { - return errors.Wrap(err, 0) - } else { - return sink(values) - } -} - -func (rd *RowDecoder) DecodeMapAndSink( - rawRow [][]byte, sink func(values map[string]any) error, -) error { - - if values, err := rd.Decode(rawRow); err != nil { - return errors.Wrap(err, 0) - } else { - resultSet := make(map[string]any) - for i, field := range rd.fields { - resultSet[field.Name] = 
values[i] - } - if err := sink(resultSet); err != nil { - return errors.Wrap(err, 0) - } - } - return nil -} - -func (rd *RowDecoder) compatible( - other []pgconn.FieldDescription, -) bool { - - if len(rd.fields) != len(other) { - return false - } - - for i, f := range rd.fields { - o := other[i] - if f.Format != o.Format { - return false - } - if f.DataTypeOID != o.DataTypeOID { - return false - } - if f.Name != o.Name { - return false - } - if f.DataTypeSize != o.DataTypeSize { - return false - } - if f.TypeModifier != o.TypeModifier { - return false - } - // Can we reuse the same decoder for all chunks? 🤔 - // if f.TableAttributeNumber != o.TableAttributeNumber { return false } - // if f.TableOID != o.TableOID { return false } - } - return true -} - -func DecodeTextColumn( - src []byte, dataTypeOid uint32, -) (any, error) { - - if dt, ok := typeMap.TypeForOID(dataTypeOid); ok { - return dt.Codec.DecodeValue(typeMap, dataTypeOid, pgtype.TextFormatCode, src) - } - return string(src), nil -} - -func DecodeBinaryColumn( - src []byte, dataTypeOid uint32, -) (any, error) { - - if dt, ok := typeMap.TypeForOID(dataTypeOid); ok { - return dt.Codec.DecodeValue(typeMap, dataTypeOid, pgtype.BinaryFormatCode, src) - } - return string(src), nil -} - -func DecodeValue( - field pgconn.FieldDescription, src []byte, -) (any, error) { - - if t, ok := typeMap.TypeForOID(field.DataTypeOID); ok { - return t.Codec.DecodeValue(typeMap, field.DataTypeOID, field.Format, src) - } - return nil, errors.Errorf("Unsupported type oid: %d", field.DataTypeOID) -} - -func DecodeRowValues( - rows pgx.Rows, sink func(values []any) error, -) error { - - decoder, err := NewRowDecoder(rows.FieldDescriptions()) - if err != nil { - return err - } - return decoder.DecodeRowsAndSink(rows, sink) -} - -func FindTypeDecoder( - field pgconn.FieldDescription, -) (func(src []byte) (any, error), error) { - - if t, ok := typeMap.TypeForOID(field.DataTypeOID); ok { - // Store a decoder wrapper for easier usage - 
return asTypeDecoder(t, field), nil - } - return nil, errors.Errorf("Unsupported type oid: %d", field.DataTypeOID) -} - -func RegisterType( - t *pgtype.Type, -) { - - typeMap.RegisterType(t) -} - -func GetType(oid uint32) ( - *pgtype.Type, bool, -) { - - return typeMap.TypeForOID(oid) -} - -func asTypeDecoder( - t *pgtype.Type, field pgconn.FieldDescription, -) func(src []byte) (any, error) { - - return func(src []byte) (any, error) { - return t.Codec.DecodeValue(typeMap, field.DataTypeOID, field.Format, src) - } -} - -func codecScan( - codec pgtype.Codec, m *pgtype.Map, oid uint32, format int16, src []byte, dst any, -) error { - - scanPlan := codec.PlanScan(m, oid, format, dst) - if scanPlan == nil { - return fmt.Errorf("PlanScan did not find a plan") - } - return scanPlan.Scan(src, dst) -} diff --git a/spi/pgtypes/typemanager.go b/spi/pgtypes/typemanager.go index 59e0ba26..4a0ad2dc 100644 --- a/spi/pgtypes/typemanager.go +++ b/spi/pgtypes/typemanager.go @@ -2,6 +2,7 @@ package pgtypes import ( "github.com/jackc/pglogrepl" + "github.com/jackc/pgx/v5/pgconn" ) type TypeManager interface { @@ -16,4 +17,5 @@ type TypeManager interface { relation *RelationMessage, tupleData *pglogrepl.TupleData, ) (map[string]any, error) GetOrPlanTupleDecoder(relation *RelationMessage) (TupleDecoderPlan, error) + GetOrPlanRowDecoder(fields []pgconn.FieldDescription) (RowDecoder, error) } diff --git a/spi/replicationcontext/replicationcontext.go b/spi/replicationcontext/replicationcontext.go index 45bb49e6..eeb872af 100644 --- a/spi/replicationcontext/replicationcontext.go +++ b/spi/replicationcontext/replicationcontext.go @@ -78,13 +78,14 @@ type ReplicationContext interface { hypertables ...*systemcatalog.Hypertable, ) error SnapshotChunkTable( - chunk *systemcatalog.Chunk, cb sidechannel.SnapshotRowCallback, + rowDecoderFactory pgtypes.RowDecoderFactory, chunk *systemcatalog.Chunk, cb sidechannel.SnapshotRowCallback, ) (pgtypes.LSN, error) FetchHypertableSnapshotBatch( - hypertable 
*systemcatalog.Hypertable, snapshotName string, cb sidechannel.SnapshotRowCallback, + rowDecoderFactory pgtypes.RowDecoderFactory, hypertable *systemcatalog.Hypertable, + snapshotName string, cb sidechannel.SnapshotRowCallback, ) error ReadSnapshotHighWatermark( - hypertable *systemcatalog.Hypertable, snapshotName string, + rowDecoderFactory pgtypes.RowDecoderFactory, hypertable *systemcatalog.Hypertable, snapshotName string, ) (map[string]any, error) ReadReplicaIdentity( entity systemcatalog.SystemEntity, diff --git a/spi/sidechannel/sidechannel.go b/spi/sidechannel/sidechannel.go index abdf17e7..13fd3517 100644 --- a/spi/sidechannel/sidechannel.go +++ b/spi/sidechannel/sidechannel.go @@ -66,13 +66,15 @@ type SideChannel interface { publicationName string, entities ...systemcatalog.SystemEntity, ) error SnapshotChunkTable( - chunk *systemcatalog.Chunk, snapshotBatchSize int, cb SnapshotRowCallback, + rowDecoderFactory pgtypes.RowDecoderFactory, chunk *systemcatalog.Chunk, + snapshotBatchSize int, cb SnapshotRowCallback, ) (lsn pgtypes.LSN, err error) FetchHypertableSnapshotBatch( - hypertable *systemcatalog.Hypertable, snapshotName string, snapshotBatchSize int, cb SnapshotRowCallback, + rowDecoderFactory pgtypes.RowDecoderFactory, hypertable *systemcatalog.Hypertable, + snapshotName string, snapshotBatchSize int, cb SnapshotRowCallback, ) error ReadSnapshotHighWatermark( - hypertable *systemcatalog.Hypertable, snapshotName string, + rowDecoderFactory pgtypes.RowDecoderFactory, hypertable *systemcatalog.Hypertable, snapshotName string, ) (values map[string]any, err error) ReadReplicaIdentity( schemaName, tableName string, diff --git a/tests/datatype_test.go b/tests/datatype_test.go index 0e30bd8a..21b47d38 100644 --- a/tests/datatype_test.go +++ b/tests/datatype_test.go @@ -930,6 +930,14 @@ var dataTypeTable = []DataTypeTest{ expectedValueOverride: []string{"foo", ""}, expected: quickCheckValue[[]string], }, + { + name: "Enum Type", + pgTypeName: "myenum", + 
customTypeDefinition: "CREATE TYPE tsdb.myenum AS ENUM ('Foo', 'Bar')", + schemaType: schema.STRING, + value: "Foo", + expected: quickCheckValue[string], + }, } const lookupTypeOidQuery = "SELECT oid FROM pg_catalog.pg_type where typname = $1" @@ -1067,6 +1075,12 @@ func (dtt *DataTypeTestSuite) runDataTypeTest( } } + if testCase.customTypeDefinition != "" { + if _, err := setupContext.Exec(context.Background(), testCase.customTypeDefinition); err != nil { + return err + } + } + _, tn, err := setupContext.CreateHypertable("ts", time.Hour*24, testsupport.NewColumn("ts", "timestamptz", false, false, nil), testsupport.NewColumn(columnName, testCase.pgTypeName, false, false, nil), @@ -1087,6 +1101,7 @@ type DataTypeTest struct { oid uint32 pgTypeName string columnNameOverride string + customTypeDefinition string schemaType schema.Type elementSchemaType schema.Type value any diff --git a/tests/streamer_test.go b/tests/streamer_test.go index b446ad97..c68cebdd 100644 --- a/tests/streamer_test.go +++ b/tests/streamer_test.go @@ -1120,7 +1120,7 @@ func (its *IntegrationTestSuite) TestContinuousAggregate_Scheduled_Refresh_Creat logger.Warnln("Scheduling continuous aggregate refresh") if err := ctx.PrivilegedContext(func(pctx testrunner.PrivilegedContext) error { _, err := pctx.Exec(context.Background(), ` - SELECT alter_job(j.id, next_start => now() + interval '5 seconds') + SELECT tsdb.alter_job(j.id, next_start => now() + interval '5 seconds') FROM _timescaledb_config.bgw_job j LEFT JOIN _timescaledb_catalog.hypertable h ON h.id = j.hypertable_id LEFT JOIN _timescaledb_catalog.continuous_agg c ON c.mat_hypertable_id = j.hypertable_id