diff --git a/example/author/query.sql.go b/example/author/query.sql.go index 114440a..c993d30 100644 --- a/example/author/query.sql.go +++ b/example/author/query.sql.go @@ -6,6 +6,7 @@ import ( "context" "fmt" "github.com/jackc/pgconn" + "github.com/jackc/pgtype" "github.com/jackc/pgx/v4" ) @@ -82,7 +83,8 @@ type Querier interface { } type DBQuerier struct { - conn genericConn + conn genericConn // underlying Postgres transport to use + types *typeResolver // resolve types by name } var _ Querier = &DBQuerier{} @@ -109,9 +111,23 @@ type genericConn interface { // NewQuerier creates a DBQuerier that implements Querier. conn is typically // *pgx.Conn, pgx.Tx, or *pgxpool.Pool. func NewQuerier(conn genericConn) *DBQuerier { - return &DBQuerier{ - conn: conn, - } + return NewQuerierConfig(conn, QuerierConfig{}) +} + +type QuerierConfig struct { + // DataTypes contains pgtype.Value to use for encoding and decoding instead of + // pggen-generated pgtype.ValueTranscoder. + // + // If OIDs are available for an input parameter type and all of its transative + // dependencies, pggen will use the binary encoding format for the input + // parameter. + DataTypes []pgtype.DataType +} + +// NewQuerierConfig creates a DBQuerier that implements Querier with the given +// config. conn is typically *pgx.Conn, pgx.Tx, or *pgxpool.Pool. +func NewQuerierConfig(conn genericConn, cfg QuerierConfig) *DBQuerier { + return &DBQuerier{conn: conn, types: newTypeResolver(cfg.DataTypes)} } // WithTx creates a new DBQuerier that uses the transaction to run all queries. @@ -160,6 +176,41 @@ func PrepareAllQueries(ctx context.Context, p preparer) error { return nil } +// typeResolver looks up the pgtype.ValueTranscoder by Postgres type name. +type typeResolver struct { + connInfo *pgtype.ConnInfo // types by Postgres type name +} + +func newTypeResolver(types []pgtype.DataType) *typeResolver { + ci := pgtype.NewConnInfo() + for _, typ := range types { + if txt, ok := typ.Value.(textPreferrer); ok && typ.OID != unknownOID { + typ.Value = txt.ValueTranscoder + } + ci.RegisterDataType(typ) + } + return &typeResolver{connInfo: ci} +} + +// findValue find the OID, and pgtype.ValueTranscoder for a Postgres type name. +func (tr *typeResolver) findValue(name string) (uint32, pgtype.ValueTranscoder, bool) { + typ, ok := tr.connInfo.DataTypeForName(name) + if !ok { + return 0, nil, false + } + v := pgtype.NewValue(typ.Value) + return typ.OID, v.(pgtype.ValueTranscoder), true +} + +// setValue sets the value of a ValueTranscoder to a value that should always +// work and panics if it fails. +func (tr *typeResolver) setValue(vt pgtype.ValueTranscoder, val interface{}) pgtype.ValueTranscoder { + if err := vt.Set(val); err != nil { + panic(fmt.Sprintf("set ValueTranscoder %T to %+v: %s", vt, val, err)) + } + return vt +} + const findAuthorByIDSQL = `SELECT * FROM author WHERE author_id = $1;` type FindAuthorByIDRow struct { @@ -459,3 +510,29 @@ func (q *DBQuerier) InsertAuthorSuffixScan(results pgx.BatchResults) (InsertAuth } return item, nil } + +// textPreferrer wraps a pgtype.ValueTranscoder and sets the preferred encoding +// format to text instead binary (the default). pggen uses the text format +// when the OID is unknownOID because the binary format requires the OID. +// Typically occurs if the results from QueryAllDataTypes aren't passed to +// NewQuerierConfig. +type textPreferrer struct { + pgtype.ValueTranscoder + typeName string +} + +// PreferredParamFormat implements pgtype.ParamFormatPreferrer. +func (t textPreferrer) PreferredParamFormat() int16 { return pgtype.TextFormatCode } + +func (t textPreferrer) NewTypeValue() pgtype.Value { + return textPreferrer{pgtype.NewValue(t.ValueTranscoder).(pgtype.ValueTranscoder), t.typeName} +} + +func (t textPreferrer) TypeName() string { + return t.typeName +} + +// unknownOID means we don't know the OID for a type. This is okay for decoding +// because pgx call DecodeText or DecodeBinary without requiring the OID. For +// encoding parameters, pggen uses textPreferrer if the OID is unknown. +const unknownOID = 0 diff --git a/example/complex_params/query.sql.go b/example/complex_params/query.sql.go index f72264e..2a10dbc 100644 --- a/example/complex_params/query.sql.go +++ b/example/complex_params/query.sql.go @@ -53,7 +53,8 @@ type Querier interface { } type DBQuerier struct { - conn genericConn + conn genericConn // underlying Postgres transport to use + types *typeResolver // resolve types by name } var _ Querier = &DBQuerier{} @@ -80,9 +81,23 @@ type genericConn interface { // NewQuerier creates a DBQuerier that implements Querier. conn is typically // *pgx.Conn, pgx.Tx, or *pgxpool.Pool. func NewQuerier(conn genericConn) *DBQuerier { - return &DBQuerier{ - conn: conn, - } + return NewQuerierConfig(conn, QuerierConfig{}) +} + +type QuerierConfig struct { + // DataTypes contains pgtype.Value to use for encoding and decoding instead of + // pggen-generated pgtype.ValueTranscoder. + // + // If OIDs are available for an input parameter type and all of its transative + // dependencies, pggen will use the binary encoding format for the input + // parameter. + DataTypes []pgtype.DataType +} + +// NewQuerierConfig creates a DBQuerier that implements Querier with the given +// config. conn is typically *pgx.Conn, pgx.Tx, or *pgxpool.Pool. +func NewQuerierConfig(conn genericConn, cfg QuerierConfig) *DBQuerier { + return &DBQuerier{conn: conn, types: newTypeResolver(cfg.DataTypes)} } // WithTx creates a new DBQuerier that uses the transaction to run all queries. @@ -141,89 +156,111 @@ type ProductImageType struct { Dimensions Dimensions `json:"dimensions"` } -// ignoredOID means we don't know or care about the OID for a type. This is okay -// because pgx only uses the OID to encode values and lookup a decoder. We only -// use ignoredOID for decoding and we always specify a concrete decoder for scan -// methods. -const ignoredOID = 0 +// typeResolver looks up the pgtype.ValueTranscoder by Postgres type name. +type typeResolver struct { + connInfo *pgtype.ConnInfo // types by Postgres type name +} + +func newTypeResolver(types []pgtype.DataType) *typeResolver { + ci := pgtype.NewConnInfo() + for _, typ := range types { + if txt, ok := typ.Value.(textPreferrer); ok && typ.OID != unknownOID { + typ.Value = txt.ValueTranscoder + } + ci.RegisterDataType(typ) + } + return &typeResolver{connInfo: ci} +} + +// findValue find the OID, and pgtype.ValueTranscoder for a Postgres type name. +func (tr *typeResolver) findValue(name string) (uint32, pgtype.ValueTranscoder, bool) { + typ, ok := tr.connInfo.DataTypeForName(name) + if !ok { + return 0, nil, false + } + v := pgtype.NewValue(typ.Value) + return typ.OID, v.(pgtype.ValueTranscoder), true +} // setValue sets the value of a ValueTranscoder to a value that should always // work and panics if it fails. -func setValue(vt pgtype.ValueTranscoder, val interface{}) pgtype.ValueTranscoder { +func (tr *typeResolver) setValue(vt pgtype.ValueTranscoder, val interface{}) pgtype.ValueTranscoder { if err := vt.Set(val); err != nil { panic(fmt.Sprintf("set ValueTranscoder %T to %+v: %s", vt, val, err)) } return vt } -// textEncoder wraps a pgtype.ValueTranscoder and sets the preferred encoding -// format to text instead binary (the default). pggen must use the text format -// because the Postgres binary format requires the type OID but pggen doesn't -// necessarily know the OIDs of the types, hence ignoredOID. -type textEncoder struct { - pgtype.ValueTranscoder +type compositeField struct { + name string // name of the field + typeName string // Postgres type name + defaultVal pgtype.ValueTranscoder // default value to use } -// PreferredParamFormat implements pgtype.ParamFormatPreferrer. -func (t textEncoder) PreferredParamFormat() int16 { return pgtype.TextFormatCode } - -func newCompositeType(name string, fieldNames []string, vals ...pgtype.ValueTranscoder) *pgtype.CompositeType { - fields := make([]pgtype.CompositeTypeField, len(fieldNames)) - for i, name := range fieldNames { - fields[i] = pgtype.CompositeTypeField{Name: name, OID: ignoredOID} +func (tr *typeResolver) newCompositeValue(name string, fields ...compositeField) pgtype.ValueTranscoder { + if _, val, ok := tr.findValue(name); ok { + return val + } + fs := make([]pgtype.CompositeTypeField, len(fields)) + vals := make([]pgtype.ValueTranscoder, len(fields)) + isBinaryOk := true + for i, field := range fields { + oid, val, ok := tr.findValue(field.typeName) + if !ok { + oid = unknownOID + val = field.defaultVal + } + isBinaryOk = isBinaryOk && oid != unknownOID + fs[i] = pgtype.CompositeTypeField{Name: field.name, OID: oid} + vals[i] = val } // Okay to ignore error because it's only thrown when the number of field // names does not equal the number of ValueTranscoders. - rowType, _ := pgtype.NewCompositeTypeValues(name, fields, vals) - return rowType -} - -// newProductImageTypeArray creates a new pgtype.ValueTranscoder for the Postgres -// '_product_image_type' array type. -func newProductImageTypeArray() pgtype.ValueTranscoder { - return pgtype.NewArrayType("_product_image_type", ignoredOID, newProductImageType) -} - -// newProductImageTypeArrayInit creates an initialized pgtype.ValueTranscoder for the -// Postgres array type '_product_image_type' to encode query parameters. -func newProductImageTypeArrayInit(ps []ProductImageType) textEncoder { - dec := newProductImageTypeArray() - if err := dec.Set(newProductImageTypeArrayRaw(ps)); err != nil { - panic("encode []ProductImageType: " + err.Error()) // should always succeed + typ, _ := pgtype.NewCompositeTypeValues(name, fs, vals) + if !isBinaryOk { + return textPreferrer{typ, name} } - return textEncoder{ValueTranscoder: dec} + return typ } -// newProductImageTypeArrayRaw returns all elements for the Postgres array type '_product_image_type' -// as a slice of interface{} for use with the pgtype.Value Set method. -func newProductImageTypeArrayRaw(vs []ProductImageType) []interface{} { - elems := make([]interface{}, len(vs)) - for i, v := range vs { - elems[i] = newProductImageTypeRaw(v) +func (tr *typeResolver) newArrayValue(name, elemName string, defaultVal func() pgtype.ValueTranscoder) pgtype.ValueTranscoder { + if _, val, ok := tr.findValue(name); ok { + return val } - return elems + elemOID, elemVal, ok := tr.findValue(elemName) + elemValFunc := func() pgtype.ValueTranscoder { + return pgtype.NewValue(elemVal).(pgtype.ValueTranscoder) + } + if !ok { + elemOID = unknownOID + elemValFunc = defaultVal + } + typ := pgtype.NewArrayType(name, elemOID, elemValFunc) + if elemOID == unknownOID { + return textPreferrer{typ, name} + } + return typ } // newDimensions creates a new pgtype.ValueTranscoder for the Postgres // composite type 'dimensions'. -func newDimensions() pgtype.ValueTranscoder { - return newCompositeType( +func (tr *typeResolver) newDimensions() pgtype.ValueTranscoder { + return tr.newCompositeValue( "dimensions", - []string{"width", "height"}, - &pgtype.Int4{}, - &pgtype.Int4{}, + compositeField{"width", "int4", &pgtype.Int4{}}, + compositeField{"height", "int4", &pgtype.Int4{}}, ) } // newDimensionsInit creates an initialized pgtype.ValueTranscoder for the // Postgres composite type 'dimensions' to encode query parameters. -func newDimensionsInit(v Dimensions) pgtype.ValueTranscoder { - return textEncoder{setValue(newDimensions(), newDimensionsRaw(v))} +func (tr *typeResolver) newDimensionsInit(v Dimensions) pgtype.ValueTranscoder { + return textPreferrer{tr.setValue(tr.newDimensions(), tr.newDimensionsRaw(v)), "dimensions"} } // newDimensionsRaw returns all composite fields for the Postgres composite // type 'dimensions' as a slice of interface{} to encode query parameters. -func newDimensionsRaw(v Dimensions) []interface{} { +func (tr *typeResolver) newDimensionsRaw(v Dimensions) []interface{} { return []interface{}{ v.Width, v.Height, @@ -232,56 +269,80 @@ func newDimensionsRaw(v Dimensions) []interface{} { // newProductImageSetType creates a new pgtype.ValueTranscoder for the Postgres // composite type 'product_image_set_type'. -func newProductImageSetType() pgtype.ValueTranscoder { - return newCompositeType( +func (tr *typeResolver) newProductImageSetType() pgtype.ValueTranscoder { + return tr.newCompositeValue( "product_image_set_type", - []string{"name", "orig_image", "images"}, - &pgtype.Text{}, - newProductImageType(), - newProductImageTypeArray(), + compositeField{"name", "text", &pgtype.Text{}}, + compositeField{"orig_image", "product_image_type", tr.newProductImageType()}, + compositeField{"images", "_product_image_type", tr.newProductImageTypeArray()}, ) } // newProductImageSetTypeInit creates an initialized pgtype.ValueTranscoder for the // Postgres composite type 'product_image_set_type' to encode query parameters. -func newProductImageSetTypeInit(v ProductImageSetType) pgtype.ValueTranscoder { - return textEncoder{setValue(newProductImageSetType(), newProductImageSetTypeRaw(v))} +func (tr *typeResolver) newProductImageSetTypeInit(v ProductImageSetType) pgtype.ValueTranscoder { + return textPreferrer{tr.setValue(tr.newProductImageSetType(), tr.newProductImageSetTypeRaw(v)), "product_image_set_type"} } // newProductImageSetTypeRaw returns all composite fields for the Postgres composite // type 'product_image_set_type' as a slice of interface{} to encode query parameters. -func newProductImageSetTypeRaw(v ProductImageSetType) []interface{} { +func (tr *typeResolver) newProductImageSetTypeRaw(v ProductImageSetType) []interface{} { return []interface{}{ v.Name, - newProductImageTypeRaw(v.OrigImage), - newProductImageTypeArrayRaw(v.Images), + tr.newProductImageTypeRaw(v.OrigImage), + tr.newProductImageTypeArrayRaw(v.Images), } } // newProductImageType creates a new pgtype.ValueTranscoder for the Postgres // composite type 'product_image_type'. -func newProductImageType() pgtype.ValueTranscoder { - return newCompositeType( +func (tr *typeResolver) newProductImageType() pgtype.ValueTranscoder { + return tr.newCompositeValue( "product_image_type", - []string{"source", "dimensions"}, - &pgtype.Text{}, - newDimensions(), + compositeField{"source", "text", &pgtype.Text{}}, + compositeField{"dimensions", "dimensions", tr.newDimensions()}, ) } // newProductImageTypeInit creates an initialized pgtype.ValueTranscoder for the // Postgres composite type 'product_image_type' to encode query parameters. -func newProductImageTypeInit(v ProductImageType) pgtype.ValueTranscoder { - return textEncoder{setValue(newProductImageType(), newProductImageTypeRaw(v))} +func (tr *typeResolver) newProductImageTypeInit(v ProductImageType) pgtype.ValueTranscoder { + return textPreferrer{tr.setValue(tr.newProductImageType(), tr.newProductImageTypeRaw(v)), "product_image_type"} } // newProductImageTypeRaw returns all composite fields for the Postgres composite // type 'product_image_type' as a slice of interface{} to encode query parameters. -func newProductImageTypeRaw(v ProductImageType) []interface{} { +func (tr *typeResolver) newProductImageTypeRaw(v ProductImageType) []interface{} { return []interface{}{ v.Source, - newDimensionsRaw(v.Dimensions), + tr.newDimensionsRaw(v.Dimensions), + } +} + +// newProductImageTypeArray creates a new pgtype.ValueTranscoder for the Postgres +// '_product_image_type' array type. +func (tr *typeResolver) newProductImageTypeArray() pgtype.ValueTranscoder { + return tr.newArrayValue("_product_image_type", "product_image_type", tr.newProductImageType) +} + +// newProductImageTypeArrayInit creates an initialized pgtype.ValueTranscoder for the +// Postgres array type '_product_image_type' to encode query parameters. +func (tr *typeResolver) newProductImageTypeArrayInit(ps []ProductImageType) pgtype.ValueTranscoder { + dec := tr.newProductImageTypeArray() + if err := dec.Set(tr.newProductImageTypeArrayRaw(ps)); err != nil { + panic("encode []ProductImageType: " + err.Error()) // should always succeed + } + return textPreferrer{ValueTranscoder: dec, typeName: "_product_image_type"} +} + +// newProductImageTypeArrayRaw returns all elements for the Postgres array type '_product_image_type' +// as a slice of interface{} for use with the pgtype.Value Set method. +func (tr *typeResolver) newProductImageTypeArrayRaw(vs []ProductImageType) []interface{} { + elems := make([]interface{}, len(vs)) + for i, v := range vs { + elems[i] = tr.newProductImageTypeRaw(v) } + return elems } const paramArrayIntSQL = `SELECT $1::bigint[];` @@ -315,9 +376,9 @@ const paramNested1SQL = `SELECT $1::dimensions;` // ParamNested1 implements Querier.ParamNested1. func (q *DBQuerier) ParamNested1(ctx context.Context, dimensions Dimensions) (Dimensions, error) { - row := q.conn.QueryRow(ctx, paramNested1SQL, newDimensionsInit(dimensions)) + row := q.conn.QueryRow(ctx, paramNested1SQL, q.types.newDimensionsInit(dimensions)) var item Dimensions - dimensionsRow := newDimensions() + dimensionsRow := q.types.newDimensions() if err := row.Scan(dimensionsRow); err != nil { return item, fmt.Errorf("query ParamNested1: %w", err) } @@ -329,14 +390,14 @@ func (q *DBQuerier) ParamNested1(ctx context.Context, dimensions Dimensions) (Di // ParamNested1Batch implements Querier.ParamNested1Batch. func (q *DBQuerier) ParamNested1Batch(batch *pgx.Batch, dimensions Dimensions) { - batch.Queue(paramNested1SQL, newDimensionsInit(dimensions)) + batch.Queue(paramNested1SQL, q.types.newDimensionsInit(dimensions)) } // ParamNested1Scan implements Querier.ParamNested1Scan. func (q *DBQuerier) ParamNested1Scan(results pgx.BatchResults) (Dimensions, error) { row := results.QueryRow() var item Dimensions - dimensionsRow := newDimensions() + dimensionsRow := q.types.newDimensions() if err := row.Scan(dimensionsRow); err != nil { return item, fmt.Errorf("scan ParamNested1Batch row: %w", err) } @@ -350,9 +411,9 @@ const paramNested2SQL = `SELECT $1::product_image_type;` // ParamNested2 implements Querier.ParamNested2. func (q *DBQuerier) ParamNested2(ctx context.Context, image ProductImageType) (ProductImageType, error) { - row := q.conn.QueryRow(ctx, paramNested2SQL, newProductImageTypeInit(image)) + row := q.conn.QueryRow(ctx, paramNested2SQL, q.types.newProductImageTypeInit(image)) var item ProductImageType - productImageTypeRow := newProductImageType() + productImageTypeRow := q.types.newProductImageType() if err := row.Scan(productImageTypeRow); err != nil { return item, fmt.Errorf("query ParamNested2: %w", err) } @@ -364,14 +425,14 @@ func (q *DBQuerier) ParamNested2(ctx context.Context, image ProductImageType) (P // ParamNested2Batch implements Querier.ParamNested2Batch. func (q *DBQuerier) ParamNested2Batch(batch *pgx.Batch, image ProductImageType) { - batch.Queue(paramNested2SQL, newProductImageTypeInit(image)) + batch.Queue(paramNested2SQL, q.types.newProductImageTypeInit(image)) } // ParamNested2Scan implements Querier.ParamNested2Scan. func (q *DBQuerier) ParamNested2Scan(results pgx.BatchResults) (ProductImageType, error) { row := results.QueryRow() var item ProductImageType - productImageTypeRow := newProductImageType() + productImageTypeRow := q.types.newProductImageType() if err := row.Scan(productImageTypeRow); err != nil { return item, fmt.Errorf("scan ParamNested2Batch row: %w", err) } @@ -385,9 +446,9 @@ const paramNested2ArraySQL = `SELECT $1::product_image_type[];` // ParamNested2Array implements Querier.ParamNested2Array. func (q *DBQuerier) ParamNested2Array(ctx context.Context, images []ProductImageType) ([]ProductImageType, error) { - row := q.conn.QueryRow(ctx, paramNested2ArraySQL, newProductImageTypeArrayInit(images)) + row := q.conn.QueryRow(ctx, paramNested2ArraySQL, q.types.newProductImageTypeArrayInit(images)) item := []ProductImageType{} - productImageTypeArray := newProductImageTypeArray() + productImageTypeArray := q.types.newProductImageTypeArray() if err := row.Scan(productImageTypeArray); err != nil { return item, fmt.Errorf("query ParamNested2Array: %w", err) } @@ -399,14 +460,14 @@ func (q *DBQuerier) ParamNested2Array(ctx context.Context, images []ProductImage // ParamNested2ArrayBatch implements Querier.ParamNested2ArrayBatch. func (q *DBQuerier) ParamNested2ArrayBatch(batch *pgx.Batch, images []ProductImageType) { - batch.Queue(paramNested2ArraySQL, newProductImageTypeArrayInit(images)) + batch.Queue(paramNested2ArraySQL, q.types.newProductImageTypeArrayInit(images)) } // ParamNested2ArrayScan implements Querier.ParamNested2ArrayScan. func (q *DBQuerier) ParamNested2ArrayScan(results pgx.BatchResults) ([]ProductImageType, error) { row := results.QueryRow() item := []ProductImageType{} - productImageTypeArray := newProductImageTypeArray() + productImageTypeArray := q.types.newProductImageTypeArray() if err := row.Scan(productImageTypeArray); err != nil { return item, fmt.Errorf("scan ParamNested2ArrayBatch row: %w", err) } @@ -420,9 +481,9 @@ const paramNested3SQL = `SELECT $1::product_image_set_type;` // ParamNested3 implements Querier.ParamNested3. func (q *DBQuerier) ParamNested3(ctx context.Context, imageSet ProductImageSetType) (ProductImageSetType, error) { - row := q.conn.QueryRow(ctx, paramNested3SQL, newProductImageSetTypeInit(imageSet)) + row := q.conn.QueryRow(ctx, paramNested3SQL, q.types.newProductImageSetTypeInit(imageSet)) var item ProductImageSetType - productImageSetTypeRow := newProductImageSetType() + productImageSetTypeRow := q.types.newProductImageSetType() if err := row.Scan(productImageSetTypeRow); err != nil { return item, fmt.Errorf("query ParamNested3: %w", err) } @@ -434,14 +495,14 @@ func (q *DBQuerier) ParamNested3(ctx context.Context, imageSet ProductImageSetTy // ParamNested3Batch implements Querier.ParamNested3Batch. func (q *DBQuerier) ParamNested3Batch(batch *pgx.Batch, imageSet ProductImageSetType) { - batch.Queue(paramNested3SQL, newProductImageSetTypeInit(imageSet)) + batch.Queue(paramNested3SQL, q.types.newProductImageSetTypeInit(imageSet)) } // ParamNested3Scan implements Querier.ParamNested3Scan. func (q *DBQuerier) ParamNested3Scan(results pgx.BatchResults) (ProductImageSetType, error) { row := results.QueryRow() var item ProductImageSetType - productImageSetTypeRow := newProductImageSetType() + productImageSetTypeRow := q.types.newProductImageSetType() if err := row.Scan(productImageSetTypeRow); err != nil { return item, fmt.Errorf("scan ParamNested3Batch row: %w", err) } @@ -450,3 +511,29 @@ func (q *DBQuerier) ParamNested3Scan(results pgx.BatchResults) (ProductImageSetT } return item, nil } + +// textPreferrer wraps a pgtype.ValueTranscoder and sets the preferred encoding +// format to text instead binary (the default). pggen uses the text format +// when the OID is unknownOID because the binary format requires the OID. +// Typically occurs if the results from QueryAllDataTypes aren't passed to +// NewQuerierConfig. +type textPreferrer struct { + pgtype.ValueTranscoder + typeName string +} + +// PreferredParamFormat implements pgtype.ParamFormatPreferrer. +func (t textPreferrer) PreferredParamFormat() int16 { return pgtype.TextFormatCode } + +func (t textPreferrer) NewTypeValue() pgtype.Value { + return textPreferrer{pgtype.NewValue(t.ValueTranscoder).(pgtype.ValueTranscoder), t.typeName} +} + +func (t textPreferrer) TypeName() string { + return t.typeName +} + +// unknownOID means we don't know the OID for a type. This is okay for decoding +// because pgx call DecodeText or DecodeBinary without requiring the OID. For +// encoding parameters, pggen uses textPreferrer if the OID is unknown. +const unknownOID = 0 diff --git a/example/complex_params/query.sql_test.go b/example/complex_params/query.sql_test.go index 94b82b6..1b17191 100644 --- a/example/complex_params/query.sql_test.go +++ b/example/complex_params/query.sql_test.go @@ -124,7 +124,6 @@ func TestNewQuerier_ParamNested2Array(t *testing.T) { } func TestNewQuerier_ParamNested3(t *testing.T) { - t.Skip("https://github.com/jackc/pgx/issues/874") conn, cleanup := pgtest.NewPostgresSchema(t, []string{"schema.sql"}) defer cleanup() @@ -157,3 +156,38 @@ func TestNewQuerier_ParamNested3(t *testing.T) { assert.Equal(t, want, row) }) } + +func TestNewQuerier_ParamNested3_QueryAllDataTypes(t *testing.T) { + conn, cleanup := pgtest.NewPostgresSchema(t, []string{"schema.sql"}) + defer cleanup() + ctx := context.Background() + // dataTypes, err := QueryAllDataTypes(ctx, conn) + // require.NoError(t, err) + q := NewQuerierConfig(conn, QuerierConfig{DataTypes: nil}) + + want := ProductImageSetType{ + Name: "set1", + OrigImage: ProductImageType{Source: "src1", Dimensions: Dimensions{Width: 11, Height: 11}}, + Images: []ProductImageType{ + {Source: "src1", Dimensions: Dimensions{Width: 11, Height: 11}}, + {Source: "src2", Dimensions: Dimensions{Width: 22, Height: 22}}, + }, + } + + t.Run("ParamNested3", func(t *testing.T) { + row, err := q.ParamNested3(ctx, want) + require.NoError(t, err) + assert.Equal(t, want, row) + }) + + t.Run("ParamNested3Batch", func(t *testing.T) { + batch := &pgx.Batch{} + q.ParamNested3Batch(batch, want) + results := conn.SendBatch(ctx, batch) + row, err := q.ParamNested3Scan(results) + if err != nil { + t.Fatal(err) + } + assert.Equal(t, want, row) + }) +} diff --git a/example/composite/query.sql.go b/example/composite/query.sql.go index abc4c2f..88e43c5 100644 --- a/example/composite/query.sql.go +++ b/example/composite/query.sql.go @@ -39,7 +39,8 @@ type Querier interface { } type DBQuerier struct { - conn genericConn + conn genericConn // underlying Postgres transport to use + types *typeResolver // resolve types by name } var _ Querier = &DBQuerier{} @@ -66,9 +67,23 @@ type genericConn interface { // NewQuerier creates a DBQuerier that implements Querier. conn is typically // *pgx.Conn, pgx.Tx, or *pgxpool.Pool. func NewQuerier(conn genericConn) *DBQuerier { - return &DBQuerier{ - conn: conn, - } + return NewQuerierConfig(conn, QuerierConfig{}) +} + +type QuerierConfig struct { + // DataTypes contains pgtype.Value to use for encoding and decoding instead of + // pggen-generated pgtype.ValueTranscoder. + // + // If OIDs are available for an input parameter type and all of its transative + // dependencies, pggen will use the binary encoding format for the input + // parameter. + DataTypes []pgtype.DataType +} + +// NewQuerierConfig creates a DBQuerier that implements Querier with the given +// config. conn is typically *pgx.Conn, pgx.Tx, or *pgxpool.Pool. +func NewQuerierConfig(conn genericConn, cfg QuerierConfig) *DBQuerier { + return &DBQuerier{conn: conn, types: newTypeResolver(cfg.DataTypes)} } // WithTx creates a new DBQuerier that uses the transaction to run all queries. @@ -109,41 +124,109 @@ type Blocks struct { Body string `json:"body"` } -// ignoredOID means we don't know or care about the OID for a type. This is okay -// because pgx only uses the OID to encode values and lookup a decoder. We only -// use ignoredOID for decoding and we always specify a concrete decoder for scan -// methods. -const ignoredOID = 0 +// typeResolver looks up the pgtype.ValueTranscoder by Postgres type name. +type typeResolver struct { + connInfo *pgtype.ConnInfo // types by Postgres type name +} + +func newTypeResolver(types []pgtype.DataType) *typeResolver { + ci := pgtype.NewConnInfo() + for _, typ := range types { + if txt, ok := typ.Value.(textPreferrer); ok && typ.OID != unknownOID { + typ.Value = txt.ValueTranscoder + } + ci.RegisterDataType(typ) + } + return &typeResolver{connInfo: ci} +} + +// findValue find the OID, and pgtype.ValueTranscoder for a Postgres type name. +func (tr *typeResolver) findValue(name string) (uint32, pgtype.ValueTranscoder, bool) { + typ, ok := tr.connInfo.DataTypeForName(name) + if !ok { + return 0, nil, false + } + v := pgtype.NewValue(typ.Value) + return typ.OID, v.(pgtype.ValueTranscoder), true +} + +// setValue sets the value of a ValueTranscoder to a value that should always +// work and panics if it fails. +func (tr *typeResolver) setValue(vt pgtype.ValueTranscoder, val interface{}) pgtype.ValueTranscoder { + if err := vt.Set(val); err != nil { + panic(fmt.Sprintf("set ValueTranscoder %T to %+v: %s", vt, val, err)) + } + return vt +} + +type compositeField struct { + name string // name of the field + typeName string // Postgres type name + defaultVal pgtype.ValueTranscoder // default value to use +} -func newCompositeType(name string, fieldNames []string, vals ...pgtype.ValueTranscoder) *pgtype.CompositeType { - fields := make([]pgtype.CompositeTypeField, len(fieldNames)) - for i, name := range fieldNames { - fields[i] = pgtype.CompositeTypeField{Name: name, OID: ignoredOID} +func (tr *typeResolver) newCompositeValue(name string, fields ...compositeField) pgtype.ValueTranscoder { + if _, val, ok := tr.findValue(name); ok { + return val + } + fs := make([]pgtype.CompositeTypeField, len(fields)) + vals := make([]pgtype.ValueTranscoder, len(fields)) + isBinaryOk := true + for i, field := range fields { + oid, val, ok := tr.findValue(field.typeName) + if !ok { + oid = unknownOID + val = field.defaultVal + } + isBinaryOk = isBinaryOk && oid != unknownOID + fs[i] = pgtype.CompositeTypeField{Name: field.name, OID: oid} + vals[i] = val } // Okay to ignore error because it's only thrown when the number of field // names does not equal the number of ValueTranscoders. - rowType, _ := pgtype.NewCompositeTypeValues(name, fields, vals) - return rowType + typ, _ := pgtype.NewCompositeTypeValues(name, fs, vals) + if !isBinaryOk { + return textPreferrer{typ, name} + } + return typ } -// newBlocksArray creates a new pgtype.ValueTranscoder for the Postgres -// '_blocks' array type. -func newBlocksArray() pgtype.ValueTranscoder { - return pgtype.NewArrayType("_blocks", ignoredOID, newBlocks) +func (tr *typeResolver) newArrayValue(name, elemName string, defaultVal func() pgtype.ValueTranscoder) pgtype.ValueTranscoder { + if _, val, ok := tr.findValue(name); ok { + return val + } + elemOID, elemVal, ok := tr.findValue(elemName) + elemValFunc := func() pgtype.ValueTranscoder { + return pgtype.NewValue(elemVal).(pgtype.ValueTranscoder) + } + if !ok { + elemOID = unknownOID + elemValFunc = defaultVal + } + typ := pgtype.NewArrayType(name, elemOID, elemValFunc) + if elemOID == unknownOID { + return textPreferrer{typ, name} + } + return typ } // newBlocks creates a new pgtype.ValueTranscoder for the Postgres // composite type 'blocks'. -func newBlocks() pgtype.ValueTranscoder { - return newCompositeType( +func (tr *typeResolver) newBlocks() pgtype.ValueTranscoder { + return tr.newCompositeValue( "blocks", - []string{"id", "screenshot_id", "body"}, - &pgtype.Int4{}, - &pgtype.Int8{}, - &pgtype.Text{}, + compositeField{"id", "int4", &pgtype.Int4{}}, + compositeField{"screenshot_id", "int8", &pgtype.Int8{}}, + compositeField{"body", "text", &pgtype.Text{}}, ) } +// newBlocksArray creates a new pgtype.ValueTranscoder for the Postgres +// '_blocks' array type. +func (tr *typeResolver) newBlocksArray() pgtype.ValueTranscoder { + return tr.newArrayValue("_blocks", "blocks", tr.newBlocks) +} + const searchScreenshotsSQL = `SELECT ss.id, array_agg(bl) AS blocks @@ -173,7 +256,7 @@ func (q *DBQuerier) SearchScreenshots(ctx context.Context, params SearchScreensh } defer rows.Close() items := []SearchScreenshotsRow{} - blocksArray := newBlocksArray() + blocksArray := q.types.newBlocksArray() for rows.Next() { var item SearchScreenshotsRow if err := rows.Scan(&item.ID, blocksArray); err != nil { @@ -203,7 +286,7 @@ func (q *DBQuerier) SearchScreenshotsScan(results pgx.BatchResults) ([]SearchScr } defer rows.Close() items := []SearchScreenshotsRow{} - blocksArray := newBlocksArray() + blocksArray := q.types.newBlocksArray() for rows.Next() { var item SearchScreenshotsRow if err := rows.Scan(&item.ID, blocksArray); err != nil { @@ -243,7 +326,7 @@ func (q *DBQuerier) SearchScreenshotsOneCol(ctx context.Context, params SearchSc } defer rows.Close() items := [][]Blocks{} - blocksArray := newBlocksArray() + blocksArray := q.types.newBlocksArray() for rows.Next() { var item []Blocks if err := rows.Scan(blocksArray); err != nil { @@ -273,7 +356,7 @@ func (q *DBQuerier) SearchScreenshotsOneColScan(results pgx.BatchResults) ([][]B } defer rows.Close() items := [][]Blocks{} - blocksArray := newBlocksArray() + blocksArray := q.types.newBlocksArray() for rows.Next() { var item []Blocks if err := rows.Scan(blocksArray); err != nil { @@ -329,3 +412,29 @@ func (q *DBQuerier) InsertScreenshotBlocksScan(results pgx.BatchResults) (Insert } return item, nil } + +// textPreferrer wraps a pgtype.ValueTranscoder and sets the preferred encoding +// format to text instead binary (the default). pggen uses the text format +// when the OID is unknownOID because the binary format requires the OID. +// Typically occurs if the results from QueryAllDataTypes aren't passed to +// NewQuerierConfig. +type textPreferrer struct { + pgtype.ValueTranscoder + typeName string +} + +// PreferredParamFormat implements pgtype.ParamFormatPreferrer. +func (t textPreferrer) PreferredParamFormat() int16 { return pgtype.TextFormatCode } + +func (t textPreferrer) NewTypeValue() pgtype.Value { + return textPreferrer{pgtype.NewValue(t.ValueTranscoder).(pgtype.ValueTranscoder), t.typeName} +} + +func (t textPreferrer) TypeName() string { + return t.typeName +} + +// unknownOID means we don't know the OID for a type. This is okay for decoding +// because pgx call DecodeText or DecodeBinary without requiring the OID. For +// encoding parameters, pggen uses textPreferrer if the OID is unknown. +const unknownOID = 0 diff --git a/example/custom_types/query.sql.go b/example/custom_types/query.sql.go index 4211046..6b7d907 100644 --- a/example/custom_types/query.sql.go +++ b/example/custom_types/query.sql.go @@ -6,6 +6,7 @@ import ( "context" "fmt" "github.com/jackc/pgconn" + "github.com/jackc/pgtype" "github.com/jackc/pgx/v4" "github.com/jschaf/pggen/example/custom_types/mytype" ) @@ -39,7 +40,8 @@ type Querier interface { } type DBQuerier struct { - conn genericConn + conn genericConn // underlying Postgres transport to use + types *typeResolver // resolve types by name } var _ Querier = &DBQuerier{} @@ -66,9 +68,23 @@ type genericConn interface { // NewQuerier creates a DBQuerier that implements Querier. conn is typically // *pgx.Conn, pgx.Tx, or *pgxpool.Pool. func NewQuerier(conn genericConn) *DBQuerier { - return &DBQuerier{ - conn: conn, - } + return NewQuerierConfig(conn, QuerierConfig{}) +} + +type QuerierConfig struct { + // DataTypes contains pgtype.Value to use for encoding and decoding instead of + // pggen-generated pgtype.ValueTranscoder. + // + // If OIDs are available for an input parameter type and all of its transative + // dependencies, pggen will use the binary encoding format for the input + // parameter. + DataTypes []pgtype.DataType +} + +// NewQuerierConfig creates a DBQuerier that implements Querier with the given +// config. conn is typically *pgx.Conn, pgx.Tx, or *pgxpool.Pool. +func NewQuerierConfig(conn genericConn, cfg QuerierConfig) *DBQuerier { + return &DBQuerier{conn: conn, types: newTypeResolver(cfg.DataTypes)} } // WithTx creates a new DBQuerier that uses the transaction to run all queries. @@ -102,6 +118,41 @@ func PrepareAllQueries(ctx context.Context, p preparer) error { return nil } +// typeResolver looks up the pgtype.ValueTranscoder by Postgres type name. +type typeResolver struct { + connInfo *pgtype.ConnInfo // types by Postgres type name +} + +func newTypeResolver(types []pgtype.DataType) *typeResolver { + ci := pgtype.NewConnInfo() + for _, typ := range types { + if txt, ok := typ.Value.(textPreferrer); ok && typ.OID != unknownOID { + typ.Value = txt.ValueTranscoder + } + ci.RegisterDataType(typ) + } + return &typeResolver{connInfo: ci} +} + +// findValue find the OID, and pgtype.ValueTranscoder for a Postgres type name. +func (tr *typeResolver) findValue(name string) (uint32, pgtype.ValueTranscoder, bool) { + typ, ok := tr.connInfo.DataTypeForName(name) + if !ok { + return 0, nil, false + } + v := pgtype.NewValue(typ.Value) + return typ.OID, v.(pgtype.ValueTranscoder), true +} + +// setValue sets the value of a ValueTranscoder to a value that should always +// work and panics if it fails. +func (tr *typeResolver) setValue(vt pgtype.ValueTranscoder, val interface{}) pgtype.ValueTranscoder { + if err := vt.Set(val); err != nil { + panic(fmt.Sprintf("set ValueTranscoder %T to %+v: %s", vt, val, err)) + } + return vt +} + const customTypesSQL = `SELECT 'some_text', 1::bigint;` type CustomTypesRow struct { @@ -209,3 +260,29 @@ func (q *DBQuerier) IntArrayScan(results pgx.BatchResults) ([][]int32, error) { } return items, err } + +// textPreferrer wraps a pgtype.ValueTranscoder and sets the preferred encoding +// format to text instead binary (the default). pggen uses the text format +// when the OID is unknownOID because the binary format requires the OID. +// Typically occurs if the results from QueryAllDataTypes aren't passed to +// NewQuerierConfig. +type textPreferrer struct { + pgtype.ValueTranscoder + typeName string +} + +// PreferredParamFormat implements pgtype.ParamFormatPreferrer. +func (t textPreferrer) PreferredParamFormat() int16 { return pgtype.TextFormatCode } + +func (t textPreferrer) NewTypeValue() pgtype.Value { + return textPreferrer{pgtype.NewValue(t.ValueTranscoder).(pgtype.ValueTranscoder), t.typeName} +} + +func (t textPreferrer) TypeName() string { + return t.typeName +} + +// unknownOID means we don't know the OID for a type. This is okay for decoding +// because pgx call DecodeText or DecodeBinary without requiring the OID. For +// encoding parameters, pggen uses textPreferrer if the OID is unknown. +const unknownOID = 0 diff --git a/example/device/query.sql.go b/example/device/query.sql.go index 0ebfd65..8a1b464 100644 --- a/example/device/query.sql.go +++ b/example/device/query.sql.go @@ -67,7 +67,8 @@ type Querier interface { } type DBQuerier struct { - conn genericConn + conn genericConn // underlying Postgres transport to use + types *typeResolver // resolve types by name } var _ Querier = &DBQuerier{} @@ -94,9 +95,23 @@ type genericConn interface { // NewQuerier creates a DBQuerier that implements Querier. conn is typically // *pgx.Conn, pgx.Tx, or *pgxpool.Pool. func NewQuerier(conn genericConn) *DBQuerier { - return &DBQuerier{ - conn: conn, - } + return NewQuerierConfig(conn, QuerierConfig{}) +} + +type QuerierConfig struct { + // DataTypes contains pgtype.Value to use for encoding and decoding instead of + // pggen-generated pgtype.ValueTranscoder. + // + // If OIDs are available for an input parameter type and all of its transative + // dependencies, pggen will use the binary encoding format for the input + // parameter. + DataTypes []pgtype.DataType +} + +// NewQuerierConfig creates a DBQuerier that implements Querier with the given +// config. conn is typically *pgx.Conn, pgx.Tx, or *pgxpool.Pool. +func NewQuerierConfig(conn genericConn, cfg QuerierConfig) *DBQuerier { + return &DBQuerier{conn: conn, types: newTypeResolver(cfg.DataTypes)} } // WithTx creates a new DBQuerier that uses the transaction to run all queries. @@ -148,12 +163,6 @@ type User struct { Name *string `json:"name"` } -// ignoredOID means we don't know or care about the OID for a type. This is okay -// because pgx only uses the OID to encode values and lookup a decoder. We only -// use ignoredOID for decoding and we always specify a concrete decoder for scan -// methods. -const ignoredOID = 0 - // DeviceType represents the Postgres enum "device_type". type DeviceType string @@ -168,25 +177,99 @@ const ( func (d DeviceType) String() string { return string(d) } -func newCompositeType(name string, fieldNames []string, vals ...pgtype.ValueTranscoder) *pgtype.CompositeType { - fields := make([]pgtype.CompositeTypeField, len(fieldNames)) - for i, name := range fieldNames { - fields[i] = pgtype.CompositeTypeField{Name: name, OID: ignoredOID} +// typeResolver looks up the pgtype.ValueTranscoder by Postgres type name. +type typeResolver struct { + connInfo *pgtype.ConnInfo // types by Postgres type name +} + +func newTypeResolver(types []pgtype.DataType) *typeResolver { + ci := pgtype.NewConnInfo() + for _, typ := range types { + if txt, ok := typ.Value.(textPreferrer); ok && typ.OID != unknownOID { + typ.Value = txt.ValueTranscoder + } + ci.RegisterDataType(typ) + } + return &typeResolver{connInfo: ci} +} + +// findValue find the OID, and pgtype.ValueTranscoder for a Postgres type name. +func (tr *typeResolver) findValue(name string) (uint32, pgtype.ValueTranscoder, bool) { + typ, ok := tr.connInfo.DataTypeForName(name) + if !ok { + return 0, nil, false + } + v := pgtype.NewValue(typ.Value) + return typ.OID, v.(pgtype.ValueTranscoder), true +} + +// setValue sets the value of a ValueTranscoder to a value that should always +// work and panics if it fails. +func (tr *typeResolver) setValue(vt pgtype.ValueTranscoder, val interface{}) pgtype.ValueTranscoder { + if err := vt.Set(val); err != nil { + panic(fmt.Sprintf("set ValueTranscoder %T to %+v: %s", vt, val, err)) + } + return vt +} + +type compositeField struct { + name string // name of the field + typeName string // Postgres type name + defaultVal pgtype.ValueTranscoder // default value to use +} + +func (tr *typeResolver) newCompositeValue(name string, fields ...compositeField) pgtype.ValueTranscoder { + if _, val, ok := tr.findValue(name); ok { + return val + } + fs := make([]pgtype.CompositeTypeField, len(fields)) + vals := make([]pgtype.ValueTranscoder, len(fields)) + isBinaryOk := true + for i, field := range fields { + oid, val, ok := tr.findValue(field.typeName) + if !ok { + oid = unknownOID + val = field.defaultVal + } + isBinaryOk = isBinaryOk && oid != unknownOID + fs[i] = pgtype.CompositeTypeField{Name: field.name, OID: oid} + vals[i] = val } // Okay to ignore error because it's only thrown when the number of field // names does not equal the number of ValueTranscoders. - rowType, _ := pgtype.NewCompositeTypeValues(name, fields, vals) - return rowType + typ, _ := pgtype.NewCompositeTypeValues(name, fs, vals) + if !isBinaryOk { + return textPreferrer{typ, name} + } + return typ +} + +func (tr *typeResolver) newArrayValue(name, elemName string, defaultVal func() pgtype.ValueTranscoder) pgtype.ValueTranscoder { + if _, val, ok := tr.findValue(name); ok { + return val + } + elemOID, elemVal, ok := tr.findValue(elemName) + elemValFunc := func() pgtype.ValueTranscoder { + return pgtype.NewValue(elemVal).(pgtype.ValueTranscoder) + } + if !ok { + elemOID = unknownOID + elemValFunc = defaultVal + } + typ := pgtype.NewArrayType(name, elemOID, elemValFunc) + if elemOID == unknownOID { + return textPreferrer{typ, name} + } + return typ } // newUser creates a new pgtype.ValueTranscoder for the Postgres // composite type 'user'. -func newUser() pgtype.ValueTranscoder { - return newCompositeType( +func (tr *typeResolver) newUser() pgtype.ValueTranscoder { + return tr.newCompositeValue( "user", - []string{"id", "name"}, - &pgtype.Int8{}, - &pgtype.Text{}, + compositeField{"id", "int8", &pgtype.Int8{}}, + compositeField{"name", "text", &pgtype.Text{}}, ) } @@ -271,7 +354,7 @@ func (q *DBQuerier) CompositeUser(ctx context.Context) ([]CompositeUserRow, erro } defer rows.Close() items := []CompositeUserRow{} - userRow := newUser() + userRow := q.types.newUser() for rows.Next() { var item CompositeUserRow if err := rows.Scan(&item.Mac, &item.Type, userRow); err != nil { @@ -301,7 +384,7 @@ func (q *DBQuerier) CompositeUserScan(results pgx.BatchResults) ([]CompositeUser } defer rows.Close() items := []CompositeUserRow{} - userRow := newUser() + userRow := q.types.newUser() for rows.Next() { var item CompositeUserRow if err := rows.Scan(&item.Mac, &item.Type, userRow); err != nil { @@ -324,7 +407,7 @@ const compositeUserOneSQL = `SELECT ROW (15, 'qux')::"user" AS "user";` func (q *DBQuerier) CompositeUserOne(ctx context.Context) (User, error) { row := q.conn.QueryRow(ctx, compositeUserOneSQL) var item User - userRow := newUser() + userRow := q.types.newUser() if err := row.Scan(userRow); err != nil { return item, fmt.Errorf("query CompositeUserOne: %w", err) } @@ -343,7 +426,7 @@ func (q *DBQuerier) CompositeUserOneBatch(batch *pgx.Batch) { func (q *DBQuerier) CompositeUserOneScan(results pgx.BatchResults) (User, error) { row := results.QueryRow() var item User - userRow := newUser() + userRow := q.types.newUser() if err := row.Scan(userRow); err != nil { return item, fmt.Errorf("scan CompositeUserOneBatch row: %w", err) } @@ -364,7 +447,7 @@ type CompositeUserOneTwoColsRow struct { func (q *DBQuerier) CompositeUserOneTwoCols(ctx context.Context) (CompositeUserOneTwoColsRow, error) { row := q.conn.QueryRow(ctx, compositeUserOneTwoColsSQL) var item CompositeUserOneTwoColsRow - userRow := newUser() + userRow := q.types.newUser() if err := row.Scan(&item.Num, userRow); err != nil { return item, fmt.Errorf("query CompositeUserOneTwoCols: %w", err) } @@ -383,7 +466,7 @@ func (q *DBQuerier) CompositeUserOneTwoColsBatch(batch *pgx.Batch) { func (q *DBQuerier) CompositeUserOneTwoColsScan(results pgx.BatchResults) (CompositeUserOneTwoColsRow, error) { row := results.QueryRow() var item CompositeUserOneTwoColsRow - userRow := newUser() + userRow := q.types.newUser() if err := row.Scan(&item.Num, userRow); err != nil { return item, fmt.Errorf("scan CompositeUserOneTwoColsBatch row: %w", err) } @@ -403,7 +486,7 @@ func (q *DBQuerier) CompositeUserMany(ctx context.Context) ([]User, error) { } defer rows.Close() items := []User{} - userRow := newUser() + userRow := q.types.newUser() for rows.Next() { var item User if err := rows.Scan(userRow); err != nil { @@ -433,7 +516,7 @@ func (q *DBQuerier) CompositeUserManyScan(results pgx.BatchResults) ([]User, err } defer rows.Close() items := []User{} - userRow := newUser() + userRow := q.types.newUser() for rows.Next() { var item User if err := rows.Scan(userRow); err != nil { @@ -501,3 +584,29 @@ func (q *DBQuerier) InsertDeviceScan(results pgx.BatchResults) (pgconn.CommandTa } return cmdTag, err } + +// textPreferrer wraps a pgtype.ValueTranscoder and sets the preferred encoding +// format to text instead binary (the default). pggen uses the text format +// when the OID is unknownOID because the binary format requires the OID. +// Typically occurs if the results from QueryAllDataTypes aren't passed to +// NewQuerierConfig. +type textPreferrer struct { + pgtype.ValueTranscoder + typeName string +} + +// PreferredParamFormat implements pgtype.ParamFormatPreferrer. +func (t textPreferrer) PreferredParamFormat() int16 { return pgtype.TextFormatCode } + +func (t textPreferrer) NewTypeValue() pgtype.Value { + return textPreferrer{pgtype.NewValue(t.ValueTranscoder).(pgtype.ValueTranscoder), t.typeName} +} + +func (t textPreferrer) TypeName() string { + return t.typeName +} + +// unknownOID means we don't know the OID for a type. This is okay for decoding +// because pgx call DecodeText or DecodeBinary without requiring the OID. For +// encoding parameters, pggen uses textPreferrer if the OID is unknown. +const unknownOID = 0 diff --git a/example/domain/query.sql.go b/example/domain/query.sql.go index 862ec70..cde2dc0 100644 --- a/example/domain/query.sql.go +++ b/example/domain/query.sql.go @@ -6,6 +6,7 @@ import ( "context" "fmt" "github.com/jackc/pgconn" + "github.com/jackc/pgtype" "github.com/jackc/pgx/v4" ) @@ -24,7 +25,8 @@ type Querier interface { } type DBQuerier struct { - conn genericConn + conn genericConn // underlying Postgres transport to use + types *typeResolver // resolve types by name } var _ Querier = &DBQuerier{} @@ -51,9 +53,23 @@ type genericConn interface { // NewQuerier creates a DBQuerier that implements Querier. conn is typically // *pgx.Conn, pgx.Tx, or *pgxpool.Pool. func NewQuerier(conn genericConn) *DBQuerier { - return &DBQuerier{ - conn: conn, - } + return NewQuerierConfig(conn, QuerierConfig{}) +} + +type QuerierConfig struct { + // DataTypes contains pgtype.Value to use for encoding and decoding instead of + // pggen-generated pgtype.ValueTranscoder. + // + // If OIDs are available for an input parameter type and all of its transative + // dependencies, pggen will use the binary encoding format for the input + // parameter. + DataTypes []pgtype.DataType +} + +// NewQuerierConfig creates a DBQuerier that implements Querier with the given +// config. conn is typically *pgx.Conn, pgx.Tx, or *pgxpool.Pool. +func NewQuerierConfig(conn genericConn, cfg QuerierConfig) *DBQuerier { + return &DBQuerier{conn: conn, types: newTypeResolver(cfg.DataTypes)} } // WithTx creates a new DBQuerier that uses the transaction to run all queries. @@ -81,6 +97,41 @@ func PrepareAllQueries(ctx context.Context, p preparer) error { return nil } +// typeResolver looks up the pgtype.ValueTranscoder by Postgres type name. +type typeResolver struct { + connInfo *pgtype.ConnInfo // types by Postgres type name +} + +func newTypeResolver(types []pgtype.DataType) *typeResolver { + ci := pgtype.NewConnInfo() + for _, typ := range types { + if txt, ok := typ.Value.(textPreferrer); ok && typ.OID != unknownOID { + typ.Value = txt.ValueTranscoder + } + ci.RegisterDataType(typ) + } + return &typeResolver{connInfo: ci} +} + +// findValue find the OID, and pgtype.ValueTranscoder for a Postgres type name. +func (tr *typeResolver) findValue(name string) (uint32, pgtype.ValueTranscoder, bool) { + typ, ok := tr.connInfo.DataTypeForName(name) + if !ok { + return 0, nil, false + } + v := pgtype.NewValue(typ.Value) + return typ.OID, v.(pgtype.ValueTranscoder), true +} + +// setValue sets the value of a ValueTranscoder to a value that should always +// work and panics if it fails. +func (tr *typeResolver) setValue(vt pgtype.ValueTranscoder, val interface{}) pgtype.ValueTranscoder { + if err := vt.Set(val); err != nil { + panic(fmt.Sprintf("set ValueTranscoder %T to %+v: %s", vt, val, err)) + } + return vt +} + const domainOneSQL = `SELECT '90210'::us_postal_code;` // DomainOne implements Querier.DomainOne. @@ -107,3 +158,29 @@ func (q *DBQuerier) DomainOneScan(results pgx.BatchResults) (string, error) { } return item, nil } + +// textPreferrer wraps a pgtype.ValueTranscoder and sets the preferred encoding +// format to text instead binary (the default). pggen uses the text format +// when the OID is unknownOID because the binary format requires the OID. +// Typically occurs if the results from QueryAllDataTypes aren't passed to +// NewQuerierConfig. +type textPreferrer struct { + pgtype.ValueTranscoder + typeName string +} + +// PreferredParamFormat implements pgtype.ParamFormatPreferrer. +func (t textPreferrer) PreferredParamFormat() int16 { return pgtype.TextFormatCode } + +func (t textPreferrer) NewTypeValue() pgtype.Value { + return textPreferrer{pgtype.NewValue(t.ValueTranscoder).(pgtype.ValueTranscoder), t.typeName} +} + +func (t textPreferrer) TypeName() string { + return t.typeName +} + +// unknownOID means we don't know the OID for a type. This is okay for decoding +// because pgx call DecodeText or DecodeBinary without requiring the OID. For +// encoding parameters, pggen uses textPreferrer if the OID is unknown. +const unknownOID = 0 diff --git a/example/enums/query.sql.go b/example/enums/query.sql.go index 13761be..ce9c5ee 100644 --- a/example/enums/query.sql.go +++ b/example/enums/query.sql.go @@ -64,7 +64,8 @@ type Querier interface { } type DBQuerier struct { - conn genericConn + conn genericConn // underlying Postgres transport to use + types *typeResolver // resolve types by name } var _ Querier = &DBQuerier{} @@ -91,9 +92,23 @@ type genericConn interface { // NewQuerier creates a DBQuerier that implements Querier. conn is typically // *pgx.Conn, pgx.Tx, or *pgxpool.Pool. func NewQuerier(conn genericConn) *DBQuerier { - return &DBQuerier{ - conn: conn, - } + return NewQuerierConfig(conn, QuerierConfig{}) +} + +type QuerierConfig struct { + // DataTypes contains pgtype.Value to use for encoding and decoding instead of + // pggen-generated pgtype.ValueTranscoder. + // + // If OIDs are available for an input parameter type and all of its transative + // dependencies, pggen will use the binary encoding format for the input + // parameter. + DataTypes []pgtype.DataType +} + +// NewQuerierConfig creates a DBQuerier that implements Querier with the given +// config. conn is typically *pgx.Conn, pgx.Tx, or *pgxpool.Pool. +func NewQuerierConfig(conn genericConn, cfg QuerierConfig) *DBQuerier { + return &DBQuerier{conn: conn, types: newTypeResolver(cfg.DataTypes)} } // WithTx creates a new DBQuerier that uses the transaction to run all queries. @@ -142,12 +157,6 @@ type Device struct { Type DeviceType `json:"type"` } -// ignoredOID means we don't know or care about the OID for a type. This is okay -// because pgx only uses the OID to encode values and lookup a decoder. We only -// use ignoredOID for decoding and we always specify a concrete decoder for scan -// methods. -const ignoredOID = 0 - // newDeviceTypeEnum creates a new pgtype.ValueTranscoder for the // Postgres enum type 'device_type'. func newDeviceTypeEnum() pgtype.ValueTranscoder { @@ -178,34 +187,108 @@ const ( func (d DeviceType) String() string { return string(d) } -func newCompositeType(name string, fieldNames []string, vals ...pgtype.ValueTranscoder) *pgtype.CompositeType { - fields := make([]pgtype.CompositeTypeField, len(fieldNames)) - for i, name := range fieldNames { - fields[i] = pgtype.CompositeTypeField{Name: name, OID: ignoredOID} +// typeResolver looks up the pgtype.ValueTranscoder by Postgres type name. +type typeResolver struct { + connInfo *pgtype.ConnInfo // types by Postgres type name +} + +func newTypeResolver(types []pgtype.DataType) *typeResolver { + ci := pgtype.NewConnInfo() + for _, typ := range types { + if txt, ok := typ.Value.(textPreferrer); ok && typ.OID != unknownOID { + typ.Value = txt.ValueTranscoder + } + ci.RegisterDataType(typ) + } + return &typeResolver{connInfo: ci} +} + +// findValue find the OID, and pgtype.ValueTranscoder for a Postgres type name. +func (tr *typeResolver) findValue(name string) (uint32, pgtype.ValueTranscoder, bool) { + typ, ok := tr.connInfo.DataTypeForName(name) + if !ok { + return 0, nil, false + } + v := pgtype.NewValue(typ.Value) + return typ.OID, v.(pgtype.ValueTranscoder), true +} + +// setValue sets the value of a ValueTranscoder to a value that should always +// work and panics if it fails. +func (tr *typeResolver) setValue(vt pgtype.ValueTranscoder, val interface{}) pgtype.ValueTranscoder { + if err := vt.Set(val); err != nil { + panic(fmt.Sprintf("set ValueTranscoder %T to %+v: %s", vt, val, err)) + } + return vt +} + +type compositeField struct { + name string // name of the field + typeName string // Postgres type name + defaultVal pgtype.ValueTranscoder // default value to use +} + +func (tr *typeResolver) newCompositeValue(name string, fields ...compositeField) pgtype.ValueTranscoder { + if _, val, ok := tr.findValue(name); ok { + return val + } + fs := make([]pgtype.CompositeTypeField, len(fields)) + vals := make([]pgtype.ValueTranscoder, len(fields)) + isBinaryOk := true + for i, field := range fields { + oid, val, ok := tr.findValue(field.typeName) + if !ok { + oid = unknownOID + val = field.defaultVal + } + isBinaryOk = isBinaryOk && oid != unknownOID + fs[i] = pgtype.CompositeTypeField{Name: field.name, OID: oid} + vals[i] = val } // Okay to ignore error because it's only thrown when the number of field // names does not equal the number of ValueTranscoders. - rowType, _ := pgtype.NewCompositeTypeValues(name, fields, vals) - return rowType + typ, _ := pgtype.NewCompositeTypeValues(name, fs, vals) + if !isBinaryOk { + return textPreferrer{typ, name} + } + return typ } -// newDeviceTypeArray creates a new pgtype.ValueTranscoder for the Postgres -// '_device_type' array type. -func newDeviceTypeArray() pgtype.ValueTranscoder { - return pgtype.NewArrayType("_device_type", ignoredOID, newDeviceTypeEnum) +func (tr *typeResolver) newArrayValue(name, elemName string, defaultVal func() pgtype.ValueTranscoder) pgtype.ValueTranscoder { + if _, val, ok := tr.findValue(name); ok { + return val + } + elemOID, elemVal, ok := tr.findValue(elemName) + elemValFunc := func() pgtype.ValueTranscoder { + return pgtype.NewValue(elemVal).(pgtype.ValueTranscoder) + } + if !ok { + elemOID = unknownOID + elemValFunc = defaultVal + } + typ := pgtype.NewArrayType(name, elemOID, elemValFunc) + if elemOID == unknownOID { + return textPreferrer{typ, name} + } + return typ } // newDevice creates a new pgtype.ValueTranscoder for the Postgres // composite type 'device'. -func newDevice() pgtype.ValueTranscoder { - return newCompositeType( +func (tr *typeResolver) newDevice() pgtype.ValueTranscoder { + return tr.newCompositeValue( "device", - []string{"mac", "type"}, - &pgtype.Macaddr{}, - newDeviceTypeEnum(), + compositeField{"mac", "macaddr", &pgtype.Macaddr{}}, + compositeField{"type", "device_type", newDeviceTypeEnum()}, ) } +// newDeviceTypeArray creates a new pgtype.ValueTranscoder for the Postgres +// '_device_type' array type. +func (tr *typeResolver) newDeviceTypeArray() pgtype.ValueTranscoder { + return tr.newArrayValue("_device_type", "device_type", newDeviceTypeEnum) +} + const findAllDevicesSQL = `SELECT mac, type FROM device;` @@ -293,7 +376,7 @@ const findOneDeviceArraySQL = `SELECT enum_range(NULL::device_type) AS device_ty func (q *DBQuerier) FindOneDeviceArray(ctx context.Context) ([]DeviceType, error) { row := q.conn.QueryRow(ctx, findOneDeviceArraySQL) item := []DeviceType{} - deviceTypesArray := newDeviceTypeArray() + deviceTypesArray := q.types.newDeviceTypeArray() if err := row.Scan(deviceTypesArray); err != nil { return item, fmt.Errorf("query FindOneDeviceArray: %w", err) } @@ -312,7 +395,7 @@ func (q *DBQuerier) FindOneDeviceArrayBatch(batch *pgx.Batch) { func (q *DBQuerier) FindOneDeviceArrayScan(results pgx.BatchResults) ([]DeviceType, error) { row := results.QueryRow() item := []DeviceType{} - deviceTypesArray := newDeviceTypeArray() + deviceTypesArray := q.types.newDeviceTypeArray() if err := row.Scan(deviceTypesArray); err != nil { return item, fmt.Errorf("scan FindOneDeviceArrayBatch row: %w", err) } @@ -334,7 +417,7 @@ func (q *DBQuerier) FindManyDeviceArray(ctx context.Context) ([][]DeviceType, er } defer rows.Close() items := [][]DeviceType{} - deviceTypesArray := newDeviceTypeArray() + deviceTypesArray := q.types.newDeviceTypeArray() for rows.Next() { var item []DeviceType if err := rows.Scan(deviceTypesArray); err != nil { @@ -364,7 +447,7 @@ func (q *DBQuerier) FindManyDeviceArrayScan(results pgx.BatchResults) ([][]Devic } defer rows.Close() items := [][]DeviceType{} - deviceTypesArray := newDeviceTypeArray() + deviceTypesArray := q.types.newDeviceTypeArray() for rows.Next() { var item []DeviceType if err := rows.Scan(deviceTypesArray); err != nil { @@ -398,7 +481,7 @@ func (q *DBQuerier) FindManyDeviceArrayWithNum(ctx context.Context) ([]FindManyD } defer rows.Close() items := []FindManyDeviceArrayWithNumRow{} - deviceTypesArray := newDeviceTypeArray() + deviceTypesArray := q.types.newDeviceTypeArray() for rows.Next() { var item FindManyDeviceArrayWithNumRow if err := rows.Scan(&item.Num, deviceTypesArray); err != nil { @@ -428,7 +511,7 @@ func (q *DBQuerier) FindManyDeviceArrayWithNumScan(results pgx.BatchResults) ([] } defer rows.Close() items := []FindManyDeviceArrayWithNumRow{} - deviceTypesArray := newDeviceTypeArray() + deviceTypesArray := q.types.newDeviceTypeArray() for rows.Next() { var item FindManyDeviceArrayWithNumRow if err := rows.Scan(&item.Num, deviceTypesArray); err != nil { @@ -451,7 +534,7 @@ const enumInsideCompositeSQL = `SELECT ROW('08:00:2b:01:02:03'::macaddr, 'phone' func (q *DBQuerier) EnumInsideComposite(ctx context.Context) (Device, error) { row := q.conn.QueryRow(ctx, enumInsideCompositeSQL) var item Device - rowRow := newDevice() + rowRow := q.types.newDevice() if err := row.Scan(rowRow); err != nil { return item, fmt.Errorf("query EnumInsideComposite: %w", err) } @@ -470,7 +553,7 @@ func (q *DBQuerier) EnumInsideCompositeBatch(batch *pgx.Batch) { func (q *DBQuerier) EnumInsideCompositeScan(results pgx.BatchResults) (Device, error) { row := results.QueryRow() var item Device - rowRow := newDevice() + rowRow := q.types.newDevice() if err := row.Scan(rowRow); err != nil { return item, fmt.Errorf("scan EnumInsideCompositeBatch row: %w", err) } @@ -479,3 +562,29 @@ func (q *DBQuerier) EnumInsideCompositeScan(results pgx.BatchResults) (Device, e } return item, nil } + +// textPreferrer wraps a pgtype.ValueTranscoder and sets the preferred encoding +// format to text instead binary (the default). pggen uses the text format +// when the OID is unknownOID because the binary format requires the OID. +// Typically occurs if the results from QueryAllDataTypes aren't passed to +// NewQuerierConfig. +type textPreferrer struct { + pgtype.ValueTranscoder + typeName string +} + +// PreferredParamFormat implements pgtype.ParamFormatPreferrer. +func (t textPreferrer) PreferredParamFormat() int16 { return pgtype.TextFormatCode } + +func (t textPreferrer) NewTypeValue() pgtype.Value { + return textPreferrer{pgtype.NewValue(t.ValueTranscoder).(pgtype.ValueTranscoder), t.typeName} +} + +func (t textPreferrer) TypeName() string { + return t.typeName +} + +// unknownOID means we don't know the OID for a type. This is okay for decoding +// because pgx call DecodeText or DecodeBinary without requiring the OID. For +// encoding parameters, pggen uses textPreferrer if the OID is unknown. +const unknownOID = 0 diff --git a/example/erp/order/customer.sql.go b/example/erp/order/customer.sql.go index 601bcd9..3dda52c 100644 --- a/example/erp/order/customer.sql.go +++ b/example/erp/order/customer.sql.go @@ -67,7 +67,8 @@ type Querier interface { } type DBQuerier struct { - conn genericConn + conn genericConn // underlying Postgres transport to use + types *typeResolver // resolve types by name } var _ Querier = &DBQuerier{} @@ -94,9 +95,23 @@ type genericConn interface { // NewQuerier creates a DBQuerier that implements Querier. conn is typically // *pgx.Conn, pgx.Tx, or *pgxpool.Pool. func NewQuerier(conn genericConn) *DBQuerier { - return &DBQuerier{ - conn: conn, - } + return NewQuerierConfig(conn, QuerierConfig{}) +} + +type QuerierConfig struct { + // DataTypes contains pgtype.Value to use for encoding and decoding instead of + // pggen-generated pgtype.ValueTranscoder. + // + // If OIDs are available for an input parameter type and all of its transative + // dependencies, pggen will use the binary encoding format for the input + // parameter. + DataTypes []pgtype.DataType +} + +// NewQuerierConfig creates a DBQuerier that implements Querier with the given +// config. conn is typically *pgx.Conn, pgx.Tx, or *pgxpool.Pool. +func NewQuerierConfig(conn genericConn, cfg QuerierConfig) *DBQuerier { + return &DBQuerier{conn: conn, types: newTypeResolver(cfg.DataTypes)} } // WithTx creates a new DBQuerier that uses the transaction to run all queries. @@ -142,6 +157,41 @@ func PrepareAllQueries(ctx context.Context, p preparer) error { return nil } +// typeResolver looks up the pgtype.ValueTranscoder by Postgres type name. +type typeResolver struct { + connInfo *pgtype.ConnInfo // types by Postgres type name +} + +func newTypeResolver(types []pgtype.DataType) *typeResolver { + ci := pgtype.NewConnInfo() + for _, typ := range types { + if txt, ok := typ.Value.(textPreferrer); ok && typ.OID != unknownOID { + typ.Value = txt.ValueTranscoder + } + ci.RegisterDataType(typ) + } + return &typeResolver{connInfo: ci} +} + +// findValue find the OID, and pgtype.ValueTranscoder for a Postgres type name. +func (tr *typeResolver) findValue(name string) (uint32, pgtype.ValueTranscoder, bool) { + typ, ok := tr.connInfo.DataTypeForName(name) + if !ok { + return 0, nil, false + } + v := pgtype.NewValue(typ.Value) + return typ.OID, v.(pgtype.ValueTranscoder), true +} + +// setValue sets the value of a ValueTranscoder to a value that should always +// work and panics if it fails. +func (tr *typeResolver) setValue(vt pgtype.ValueTranscoder, val interface{}) pgtype.ValueTranscoder { + if err := vt.Set(val); err != nil { + panic(fmt.Sprintf("set ValueTranscoder %T to %+v: %s", vt, val, err)) + } + return vt +} + const createTenantSQL = `INSERT INTO tenant (tenant_id, name) VALUES (base36_decode($1::text)::tenant_id, $2::text) RETURNING *;` @@ -377,3 +427,29 @@ func (q *DBQuerier) InsertOrderScan(results pgx.BatchResults) (InsertOrderRow, e } return item, nil } + +// textPreferrer wraps a pgtype.ValueTranscoder and sets the preferred encoding +// format to text instead binary (the default). pggen uses the text format +// when the OID is unknownOID because the binary format requires the OID. +// Typically occurs if the results from QueryAllDataTypes aren't passed to +// NewQuerierConfig. +type textPreferrer struct { + pgtype.ValueTranscoder + typeName string +} + +// PreferredParamFormat implements pgtype.ParamFormatPreferrer. +func (t textPreferrer) PreferredParamFormat() int16 { return pgtype.TextFormatCode } + +func (t textPreferrer) NewTypeValue() pgtype.Value { + return textPreferrer{pgtype.NewValue(t.ValueTranscoder).(pgtype.ValueTranscoder), t.typeName} +} + +func (t textPreferrer) TypeName() string { + return t.typeName +} + +// unknownOID means we don't know the OID for a type. This is okay for decoding +// because pgx call DecodeText or DecodeBinary without requiring the OID. For +// encoding parameters, pggen uses textPreferrer if the OID is unknown. +const unknownOID = 0 diff --git a/example/go_pointer_types/query.sql.go b/example/go_pointer_types/query.sql.go index 1374a7e..553429d 100644 --- a/example/go_pointer_types/query.sql.go +++ b/example/go_pointer_types/query.sql.go @@ -6,6 +6,7 @@ import ( "context" "fmt" "github.com/jackc/pgconn" + "github.com/jackc/pgtype" "github.com/jackc/pgx/v4" ) @@ -59,7 +60,8 @@ type Querier interface { } type DBQuerier struct { - conn genericConn + conn genericConn // underlying Postgres transport to use + types *typeResolver // resolve types by name } var _ Querier = &DBQuerier{} @@ -86,9 +88,23 @@ type genericConn interface { // NewQuerier creates a DBQuerier that implements Querier. conn is typically // *pgx.Conn, pgx.Tx, or *pgxpool.Pool. func NewQuerier(conn genericConn) *DBQuerier { - return &DBQuerier{ - conn: conn, - } + return NewQuerierConfig(conn, QuerierConfig{}) +} + +type QuerierConfig struct { + // DataTypes contains pgtype.Value to use for encoding and decoding instead of + // pggen-generated pgtype.ValueTranscoder. + // + // If OIDs are available for an input parameter type and all of its transative + // dependencies, pggen will use the binary encoding format for the input + // parameter. + DataTypes []pgtype.DataType +} + +// NewQuerierConfig creates a DBQuerier that implements Querier with the given +// config. conn is typically *pgx.Conn, pgx.Tx, or *pgxpool.Pool. +func NewQuerierConfig(conn genericConn, cfg QuerierConfig) *DBQuerier { + return &DBQuerier{conn: conn, types: newTypeResolver(cfg.DataTypes)} } // WithTx creates a new DBQuerier that uses the transaction to run all queries. @@ -131,6 +147,41 @@ func PrepareAllQueries(ctx context.Context, p preparer) error { return nil } +// typeResolver looks up the pgtype.ValueTranscoder by Postgres type name. +type typeResolver struct { + connInfo *pgtype.ConnInfo // types by Postgres type name +} + +func newTypeResolver(types []pgtype.DataType) *typeResolver { + ci := pgtype.NewConnInfo() + for _, typ := range types { + if txt, ok := typ.Value.(textPreferrer); ok && typ.OID != unknownOID { + typ.Value = txt.ValueTranscoder + } + ci.RegisterDataType(typ) + } + return &typeResolver{connInfo: ci} +} + +// findValue find the OID, and pgtype.ValueTranscoder for a Postgres type name. +func (tr *typeResolver) findValue(name string) (uint32, pgtype.ValueTranscoder, bool) { + typ, ok := tr.connInfo.DataTypeForName(name) + if !ok { + return 0, nil, false + } + v := pgtype.NewValue(typ.Value) + return typ.OID, v.(pgtype.ValueTranscoder), true +} + +// setValue sets the value of a ValueTranscoder to a value that should always +// work and panics if it fails. +func (tr *typeResolver) setValue(vt pgtype.ValueTranscoder, val interface{}) pgtype.ValueTranscoder { + if err := vt.Set(val); err != nil { + panic(fmt.Sprintf("set ValueTranscoder %T to %+v: %s", vt, val, err)) + } + return vt +} + const genSeries1SQL = `SELECT n FROM generate_series(0, 2) n LIMIT 1;` @@ -366,3 +417,29 @@ func (q *DBQuerier) GenSeriesStrScan(results pgx.BatchResults) ([]*string, error } return items, err } + +// textPreferrer wraps a pgtype.ValueTranscoder and sets the preferred encoding +// format to text instead binary (the default). pggen uses the text format +// when the OID is unknownOID because the binary format requires the OID. +// Typically occurs if the results from QueryAllDataTypes aren't passed to +// NewQuerierConfig. +type textPreferrer struct { + pgtype.ValueTranscoder + typeName string +} + +// PreferredParamFormat implements pgtype.ParamFormatPreferrer. +func (t textPreferrer) PreferredParamFormat() int16 { return pgtype.TextFormatCode } + +func (t textPreferrer) NewTypeValue() pgtype.Value { + return textPreferrer{pgtype.NewValue(t.ValueTranscoder).(pgtype.ValueTranscoder), t.typeName} +} + +func (t textPreferrer) TypeName() string { + return t.typeName +} + +// unknownOID means we don't know the OID for a type. This is okay for decoding +// because pgx call DecodeText or DecodeBinary without requiring the OID. For +// encoding parameters, pggen uses textPreferrer if the OID is unknown. +const unknownOID = 0 diff --git a/example/ltree/query.sql.go b/example/ltree/query.sql.go index 4c4120a..77fe9bc 100644 --- a/example/ltree/query.sql.go +++ b/example/ltree/query.sql.go @@ -46,7 +46,8 @@ type Querier interface { } type DBQuerier struct { - conn genericConn + conn genericConn // underlying Postgres transport to use + types *typeResolver // resolve types by name } var _ Querier = &DBQuerier{} @@ -73,9 +74,23 @@ type genericConn interface { // NewQuerier creates a DBQuerier that implements Querier. conn is typically // *pgx.Conn, pgx.Tx, or *pgxpool.Pool. func NewQuerier(conn genericConn) *DBQuerier { - return &DBQuerier{ - conn: conn, - } + return NewQuerierConfig(conn, QuerierConfig{}) +} + +type QuerierConfig struct { + // DataTypes contains pgtype.Value to use for encoding and decoding instead of + // pggen-generated pgtype.ValueTranscoder. + // + // If OIDs are available for an input parameter type and all of its transative + // dependencies, pggen will use the binary encoding format for the input + // parameter. + DataTypes []pgtype.DataType +} + +// NewQuerierConfig creates a DBQuerier that implements Querier with the given +// config. conn is typically *pgx.Conn, pgx.Tx, or *pgxpool.Pool. +func NewQuerierConfig(conn genericConn, cfg QuerierConfig) *DBQuerier { + return &DBQuerier{conn: conn, types: newTypeResolver(cfg.DataTypes)} } // WithTx creates a new DBQuerier that uses the transaction to run all queries. @@ -112,6 +127,41 @@ func PrepareAllQueries(ctx context.Context, p preparer) error { return nil } +// typeResolver looks up the pgtype.ValueTranscoder by Postgres type name. +type typeResolver struct { + connInfo *pgtype.ConnInfo // types by Postgres type name +} + +func newTypeResolver(types []pgtype.DataType) *typeResolver { + ci := pgtype.NewConnInfo() + for _, typ := range types { + if txt, ok := typ.Value.(textPreferrer); ok && typ.OID != unknownOID { + typ.Value = txt.ValueTranscoder + } + ci.RegisterDataType(typ) + } + return &typeResolver{connInfo: ci} +} + +// findValue find the OID, and pgtype.ValueTranscoder for a Postgres type name. +func (tr *typeResolver) findValue(name string) (uint32, pgtype.ValueTranscoder, bool) { + typ, ok := tr.connInfo.DataTypeForName(name) + if !ok { + return 0, nil, false + } + v := pgtype.NewValue(typ.Value) + return typ.OID, v.(pgtype.ValueTranscoder), true +} + +// setValue sets the value of a ValueTranscoder to a value that should always +// work and panics if it fails. +func (tr *typeResolver) setValue(vt pgtype.ValueTranscoder, val interface{}) pgtype.ValueTranscoder { + if err := vt.Set(val); err != nil { + panic(fmt.Sprintf("set ValueTranscoder %T to %+v: %s", vt, val, err)) + } + return vt +} + const findTopScienceChildrenSQL = `SELECT path FROM test WHERE path <@ 'Top.Science';` @@ -270,3 +320,29 @@ func (q *DBQuerier) FindLtreeInputScan(results pgx.BatchResults) (FindLtreeInput } return item, nil } + +// textPreferrer wraps a pgtype.ValueTranscoder and sets the preferred encoding +// format to text instead binary (the default). pggen uses the text format +// when the OID is unknownOID because the binary format requires the OID. +// Typically occurs if the results from QueryAllDataTypes aren't passed to +// NewQuerierConfig. +type textPreferrer struct { + pgtype.ValueTranscoder + typeName string +} + +// PreferredParamFormat implements pgtype.ParamFormatPreferrer. +func (t textPreferrer) PreferredParamFormat() int16 { return pgtype.TextFormatCode } + +func (t textPreferrer) NewTypeValue() pgtype.Value { + return textPreferrer{pgtype.NewValue(t.ValueTranscoder).(pgtype.ValueTranscoder), t.typeName} +} + +func (t textPreferrer) TypeName() string { + return t.typeName +} + +// unknownOID means we don't know the OID for a type. This is okay for decoding +// because pgx call DecodeText or DecodeBinary without requiring the OID. For +// encoding parameters, pggen uses textPreferrer if the OID is unknown. +const unknownOID = 0 diff --git a/example/nested/query.sql.go b/example/nested/query.sql.go index e55562d..6740e12 100644 --- a/example/nested/query.sql.go +++ b/example/nested/query.sql.go @@ -32,7 +32,8 @@ type Querier interface { } type DBQuerier struct { - conn genericConn + conn genericConn // underlying Postgres transport to use + types *typeResolver // resolve types by name } var _ Querier = &DBQuerier{} @@ -59,9 +60,23 @@ type genericConn interface { // NewQuerier creates a DBQuerier that implements Querier. conn is typically // *pgx.Conn, pgx.Tx, or *pgxpool.Pool. func NewQuerier(conn genericConn) *DBQuerier { - return &DBQuerier{ - conn: conn, - } + return NewQuerierConfig(conn, QuerierConfig{}) +} + +type QuerierConfig struct { + // DataTypes contains pgtype.Value to use for encoding and decoding instead of + // pggen-generated pgtype.ValueTranscoder. + // + // If OIDs are available for an input parameter type and all of its transative + // dependencies, pggen will use the binary encoding format for the input + // parameter. + DataTypes []pgtype.DataType +} + +// NewQuerierConfig creates a DBQuerier that implements Querier with the given +// config. conn is typically *pgx.Conn, pgx.Tx, or *pgxpool.Pool. +func NewQuerierConfig(conn genericConn, cfg QuerierConfig) *DBQuerier { + return &DBQuerier{conn: conn, types: newTypeResolver(cfg.DataTypes)} } // WithTx creates a new DBQuerier that uses the transaction to run all queries. @@ -111,63 +126,129 @@ type ProductImageType struct { Dimensions Dimensions `json:"dimensions"` } -// ignoredOID means we don't know or care about the OID for a type. This is okay -// because pgx only uses the OID to encode values and lookup a decoder. We only -// use ignoredOID for decoding and we always specify a concrete decoder for scan -// methods. -const ignoredOID = 0 +// typeResolver looks up the pgtype.ValueTranscoder by Postgres type name. +type typeResolver struct { + connInfo *pgtype.ConnInfo // types by Postgres type name +} + +func newTypeResolver(types []pgtype.DataType) *typeResolver { + ci := pgtype.NewConnInfo() + for _, typ := range types { + if txt, ok := typ.Value.(textPreferrer); ok && typ.OID != unknownOID { + typ.Value = txt.ValueTranscoder + } + ci.RegisterDataType(typ) + } + return &typeResolver{connInfo: ci} +} + +// findValue find the OID, and pgtype.ValueTranscoder for a Postgres type name. +func (tr *typeResolver) findValue(name string) (uint32, pgtype.ValueTranscoder, bool) { + typ, ok := tr.connInfo.DataTypeForName(name) + if !ok { + return 0, nil, false + } + v := pgtype.NewValue(typ.Value) + return typ.OID, v.(pgtype.ValueTranscoder), true +} + +// setValue sets the value of a ValueTranscoder to a value that should always +// work and panics if it fails. +func (tr *typeResolver) setValue(vt pgtype.ValueTranscoder, val interface{}) pgtype.ValueTranscoder { + if err := vt.Set(val); err != nil { + panic(fmt.Sprintf("set ValueTranscoder %T to %+v: %s", vt, val, err)) + } + return vt +} + +type compositeField struct { + name string // name of the field + typeName string // Postgres type name + defaultVal pgtype.ValueTranscoder // default value to use +} -func newCompositeType(name string, fieldNames []string, vals ...pgtype.ValueTranscoder) *pgtype.CompositeType { - fields := make([]pgtype.CompositeTypeField, len(fieldNames)) - for i, name := range fieldNames { - fields[i] = pgtype.CompositeTypeField{Name: name, OID: ignoredOID} +func (tr *typeResolver) newCompositeValue(name string, fields ...compositeField) pgtype.ValueTranscoder { + if _, val, ok := tr.findValue(name); ok { + return val + } + fs := make([]pgtype.CompositeTypeField, len(fields)) + vals := make([]pgtype.ValueTranscoder, len(fields)) + isBinaryOk := true + for i, field := range fields { + oid, val, ok := tr.findValue(field.typeName) + if !ok { + oid = unknownOID + val = field.defaultVal + } + isBinaryOk = isBinaryOk && oid != unknownOID + fs[i] = pgtype.CompositeTypeField{Name: field.name, OID: oid} + vals[i] = val } // Okay to ignore error because it's only thrown when the number of field // names does not equal the number of ValueTranscoders. - rowType, _ := pgtype.NewCompositeTypeValues(name, fields, vals) - return rowType + typ, _ := pgtype.NewCompositeTypeValues(name, fs, vals) + if !isBinaryOk { + return textPreferrer{typ, name} + } + return typ } -// newProductImageTypeArray creates a new pgtype.ValueTranscoder for the Postgres -// '_product_image_type' array type. -func newProductImageTypeArray() pgtype.ValueTranscoder { - return pgtype.NewArrayType("_product_image_type", ignoredOID, newProductImageType) +func (tr *typeResolver) newArrayValue(name, elemName string, defaultVal func() pgtype.ValueTranscoder) pgtype.ValueTranscoder { + if _, val, ok := tr.findValue(name); ok { + return val + } + elemOID, elemVal, ok := tr.findValue(elemName) + elemValFunc := func() pgtype.ValueTranscoder { + return pgtype.NewValue(elemVal).(pgtype.ValueTranscoder) + } + if !ok { + elemOID = unknownOID + elemValFunc = defaultVal + } + typ := pgtype.NewArrayType(name, elemOID, elemValFunc) + if elemOID == unknownOID { + return textPreferrer{typ, name} + } + return typ } // newDimensions creates a new pgtype.ValueTranscoder for the Postgres // composite type 'dimensions'. -func newDimensions() pgtype.ValueTranscoder { - return newCompositeType( +func (tr *typeResolver) newDimensions() pgtype.ValueTranscoder { + return tr.newCompositeValue( "dimensions", - []string{"width", "height"}, - &pgtype.Int4{}, - &pgtype.Int4{}, + compositeField{"width", "int4", &pgtype.Int4{}}, + compositeField{"height", "int4", &pgtype.Int4{}}, ) } // newProductImageSetType creates a new pgtype.ValueTranscoder for the Postgres // composite type 'product_image_set_type'. -func newProductImageSetType() pgtype.ValueTranscoder { - return newCompositeType( +func (tr *typeResolver) newProductImageSetType() pgtype.ValueTranscoder { + return tr.newCompositeValue( "product_image_set_type", - []string{"name", "orig_image", "images"}, - &pgtype.Text{}, - newProductImageType(), - newProductImageTypeArray(), + compositeField{"name", "text", &pgtype.Text{}}, + compositeField{"orig_image", "product_image_type", tr.newProductImageType()}, + compositeField{"images", "_product_image_type", tr.newProductImageTypeArray()}, ) } // newProductImageType creates a new pgtype.ValueTranscoder for the Postgres // composite type 'product_image_type'. -func newProductImageType() pgtype.ValueTranscoder { - return newCompositeType( +func (tr *typeResolver) newProductImageType() pgtype.ValueTranscoder { + return tr.newCompositeValue( "product_image_type", - []string{"source", "dimensions"}, - &pgtype.Text{}, - newDimensions(), + compositeField{"source", "text", &pgtype.Text{}}, + compositeField{"dimensions", "dimensions", tr.newDimensions()}, ) } +// newProductImageTypeArray creates a new pgtype.ValueTranscoder for the Postgres +// '_product_image_type' array type. +func (tr *typeResolver) newProductImageTypeArray() pgtype.ValueTranscoder { + return tr.newArrayValue("_product_image_type", "product_image_type", tr.newProductImageType) +} + const arrayNested2SQL = `SELECT ARRAY [ ROW ('img2', ROW (22, 22)::dimensions)::product_image_type, @@ -178,7 +259,7 @@ const arrayNested2SQL = `SELECT func (q *DBQuerier) ArrayNested2(ctx context.Context) ([]ProductImageType, error) { row := q.conn.QueryRow(ctx, arrayNested2SQL) item := []ProductImageType{} - imagesArray := newProductImageTypeArray() + imagesArray := q.types.newProductImageTypeArray() if err := row.Scan(imagesArray); err != nil { return item, fmt.Errorf("query ArrayNested2: %w", err) } @@ -197,7 +278,7 @@ func (q *DBQuerier) ArrayNested2Batch(batch *pgx.Batch) { func (q *DBQuerier) ArrayNested2Scan(results pgx.BatchResults) ([]ProductImageType, error) { row := results.QueryRow() item := []ProductImageType{} - imagesArray := newProductImageTypeArray() + imagesArray := q.types.newProductImageTypeArray() if err := row.Scan(imagesArray); err != nil { return item, fmt.Errorf("scan ArrayNested2Batch row: %w", err) } @@ -225,7 +306,7 @@ func (q *DBQuerier) Nested3(ctx context.Context) ([]ProductImageSetType, error) } defer rows.Close() items := []ProductImageSetType{} - rowRow := newProductImageSetType() + rowRow := q.types.newProductImageSetType() for rows.Next() { var item ProductImageSetType if err := rows.Scan(rowRow); err != nil { @@ -255,7 +336,7 @@ func (q *DBQuerier) Nested3Scan(results pgx.BatchResults) ([]ProductImageSetType } defer rows.Close() items := []ProductImageSetType{} - rowRow := newProductImageSetType() + rowRow := q.types.newProductImageSetType() for rows.Next() { var item ProductImageSetType if err := rows.Scan(rowRow); err != nil { @@ -271,3 +352,29 @@ func (q *DBQuerier) Nested3Scan(results pgx.BatchResults) ([]ProductImageSetType } return items, err } + +// textPreferrer wraps a pgtype.ValueTranscoder and sets the preferred encoding +// format to text instead binary (the default). pggen uses the text format +// when the OID is unknownOID because the binary format requires the OID. +// Typically occurs if the results from QueryAllDataTypes aren't passed to +// NewQuerierConfig. +type textPreferrer struct { + pgtype.ValueTranscoder + typeName string +} + +// PreferredParamFormat implements pgtype.ParamFormatPreferrer. +func (t textPreferrer) PreferredParamFormat() int16 { return pgtype.TextFormatCode } + +func (t textPreferrer) NewTypeValue() pgtype.Value { + return textPreferrer{pgtype.NewValue(t.ValueTranscoder).(pgtype.ValueTranscoder), t.typeName} +} + +func (t textPreferrer) TypeName() string { + return t.typeName +} + +// unknownOID means we don't know the OID for a type. This is okay for decoding +// because pgx call DecodeText or DecodeBinary without requiring the OID. For +// encoding parameters, pggen uses textPreferrer if the OID is unknown. +const unknownOID = 0 diff --git a/example/pgcrypto/query.sql.go b/example/pgcrypto/query.sql.go index 1b98cd8..44fef0b 100644 --- a/example/pgcrypto/query.sql.go +++ b/example/pgcrypto/query.sql.go @@ -6,6 +6,7 @@ import ( "context" "fmt" "github.com/jackc/pgconn" + "github.com/jackc/pgtype" "github.com/jackc/pgx/v4" ) @@ -31,7 +32,8 @@ type Querier interface { } type DBQuerier struct { - conn genericConn + conn genericConn // underlying Postgres transport to use + types *typeResolver // resolve types by name } var _ Querier = &DBQuerier{} @@ -58,9 +60,23 @@ type genericConn interface { // NewQuerier creates a DBQuerier that implements Querier. conn is typically // *pgx.Conn, pgx.Tx, or *pgxpool.Pool. func NewQuerier(conn genericConn) *DBQuerier { - return &DBQuerier{ - conn: conn, - } + return NewQuerierConfig(conn, QuerierConfig{}) +} + +type QuerierConfig struct { + // DataTypes contains pgtype.Value to use for encoding and decoding instead of + // pggen-generated pgtype.ValueTranscoder. + // + // If OIDs are available for an input parameter type and all of its transative + // dependencies, pggen will use the binary encoding format for the input + // parameter. + DataTypes []pgtype.DataType +} + +// NewQuerierConfig creates a DBQuerier that implements Querier with the given +// config. conn is typically *pgx.Conn, pgx.Tx, or *pgxpool.Pool. +func NewQuerierConfig(conn genericConn, cfg QuerierConfig) *DBQuerier { + return &DBQuerier{conn: conn, types: newTypeResolver(cfg.DataTypes)} } // WithTx creates a new DBQuerier that uses the transaction to run all queries. @@ -91,6 +107,41 @@ func PrepareAllQueries(ctx context.Context, p preparer) error { return nil } +// typeResolver looks up the pgtype.ValueTranscoder by Postgres type name. +type typeResolver struct { + connInfo *pgtype.ConnInfo // types by Postgres type name +} + +func newTypeResolver(types []pgtype.DataType) *typeResolver { + ci := pgtype.NewConnInfo() + for _, typ := range types { + if txt, ok := typ.Value.(textPreferrer); ok && typ.OID != unknownOID { + typ.Value = txt.ValueTranscoder + } + ci.RegisterDataType(typ) + } + return &typeResolver{connInfo: ci} +} + +// findValue find the OID, and pgtype.ValueTranscoder for a Postgres type name. +func (tr *typeResolver) findValue(name string) (uint32, pgtype.ValueTranscoder, bool) { + typ, ok := tr.connInfo.DataTypeForName(name) + if !ok { + return 0, nil, false + } + v := pgtype.NewValue(typ.Value) + return typ.OID, v.(pgtype.ValueTranscoder), true +} + +// setValue sets the value of a ValueTranscoder to a value that should always +// work and panics if it fails. +func (tr *typeResolver) setValue(vt pgtype.ValueTranscoder, val interface{}) pgtype.ValueTranscoder { + if err := vt.Set(val); err != nil { + panic(fmt.Sprintf("set ValueTranscoder %T to %+v: %s", vt, val, err)) + } + return vt +} + const createUserSQL = `INSERT INTO "user" (email, pass) VALUES ($1, crypt($2, gen_salt('bf')));` @@ -149,3 +200,29 @@ func (q *DBQuerier) FindUserScan(results pgx.BatchResults) (FindUserRow, error) } return item, nil } + +// textPreferrer wraps a pgtype.ValueTranscoder and sets the preferred encoding +// format to text instead binary (the default). pggen uses the text format +// when the OID is unknownOID because the binary format requires the OID. +// Typically occurs if the results from QueryAllDataTypes aren't passed to +// NewQuerierConfig. +type textPreferrer struct { + pgtype.ValueTranscoder + typeName string +} + +// PreferredParamFormat implements pgtype.ParamFormatPreferrer. +func (t textPreferrer) PreferredParamFormat() int16 { return pgtype.TextFormatCode } + +func (t textPreferrer) NewTypeValue() pgtype.Value { + return textPreferrer{pgtype.NewValue(t.ValueTranscoder).(pgtype.ValueTranscoder), t.typeName} +} + +func (t textPreferrer) TypeName() string { + return t.typeName +} + +// unknownOID means we don't know the OID for a type. This is okay for decoding +// because pgx call DecodeText or DecodeBinary without requiring the OID. For +// encoding parameters, pggen uses textPreferrer if the OID is unknown. +const unknownOID = 0 diff --git a/example/separate_out_dir/out/alpha_query.sql.0.go b/example/separate_out_dir/out/alpha_query.sql.0.go index c06ab1e..e946390 100644 --- a/example/separate_out_dir/out/alpha_query.sql.0.go +++ b/example/separate_out_dir/out/alpha_query.sql.0.go @@ -6,6 +6,7 @@ import ( "context" "fmt" "github.com/jackc/pgconn" + "github.com/jackc/pgtype" "github.com/jackc/pgx/v4" ) @@ -38,7 +39,8 @@ type Querier interface { } type DBQuerier struct { - conn genericConn + conn genericConn // underlying Postgres transport to use + types *typeResolver // resolve types by name } var _ Querier = &DBQuerier{} @@ -65,9 +67,23 @@ type genericConn interface { // NewQuerier creates a DBQuerier that implements Querier. conn is typically // *pgx.Conn, pgx.Tx, or *pgxpool.Pool. func NewQuerier(conn genericConn) *DBQuerier { - return &DBQuerier{ - conn: conn, - } + return NewQuerierConfig(conn, QuerierConfig{}) +} + +type QuerierConfig struct { + // DataTypes contains pgtype.Value to use for encoding and decoding instead of + // pggen-generated pgtype.ValueTranscoder. + // + // If OIDs are available for an input parameter type and all of its transative + // dependencies, pggen will use the binary encoding format for the input + // parameter. + DataTypes []pgtype.DataType +} + +// NewQuerierConfig creates a DBQuerier that implements Querier with the given +// config. conn is typically *pgx.Conn, pgx.Tx, or *pgxpool.Pool. +func NewQuerierConfig(conn genericConn, cfg QuerierConfig) *DBQuerier { + return &DBQuerier{conn: conn, types: newTypeResolver(cfg.DataTypes)} } // WithTx creates a new DBQuerier that uses the transaction to run all queries. @@ -101,6 +117,41 @@ func PrepareAllQueries(ctx context.Context, p preparer) error { return nil } +// typeResolver looks up the pgtype.ValueTranscoder by Postgres type name. +type typeResolver struct { + connInfo *pgtype.ConnInfo // types by Postgres type name +} + +func newTypeResolver(types []pgtype.DataType) *typeResolver { + ci := pgtype.NewConnInfo() + for _, typ := range types { + if txt, ok := typ.Value.(textPreferrer); ok && typ.OID != unknownOID { + typ.Value = txt.ValueTranscoder + } + ci.RegisterDataType(typ) + } + return &typeResolver{connInfo: ci} +} + +// findValue find the OID, and pgtype.ValueTranscoder for a Postgres type name. +func (tr *typeResolver) findValue(name string) (uint32, pgtype.ValueTranscoder, bool) { + typ, ok := tr.connInfo.DataTypeForName(name) + if !ok { + return 0, nil, false + } + v := pgtype.NewValue(typ.Value) + return typ.OID, v.(pgtype.ValueTranscoder), true +} + +// setValue sets the value of a ValueTranscoder to a value that should always +// work and panics if it fails. +func (tr *typeResolver) setValue(vt pgtype.ValueTranscoder, val interface{}) pgtype.ValueTranscoder { + if err := vt.Set(val); err != nil { + panic(fmt.Sprintf("set ValueTranscoder %T to %+v: %s", vt, val, err)) + } + return vt +} + const alphaNestedSQL = `SELECT 'alpha_nested' as output;` // AlphaNested implements Querier.AlphaNested. @@ -127,3 +178,29 @@ func (q *DBQuerier) AlphaNestedScan(results pgx.BatchResults) (string, error) { } return item, nil } + +// textPreferrer wraps a pgtype.ValueTranscoder and sets the preferred encoding +// format to text instead binary (the default). pggen uses the text format +// when the OID is unknownOID because the binary format requires the OID. +// Typically occurs if the results from QueryAllDataTypes aren't passed to +// NewQuerierConfig. +type textPreferrer struct { + pgtype.ValueTranscoder + typeName string +} + +// PreferredParamFormat implements pgtype.ParamFormatPreferrer. +func (t textPreferrer) PreferredParamFormat() int16 { return pgtype.TextFormatCode } + +func (t textPreferrer) NewTypeValue() pgtype.Value { + return textPreferrer{pgtype.NewValue(t.ValueTranscoder).(pgtype.ValueTranscoder), t.typeName} +} + +func (t textPreferrer) TypeName() string { + return t.typeName +} + +// unknownOID means we don't know the OID for a type. This is okay for decoding +// because pgx call DecodeText or DecodeBinary without requiring the OID. For +// encoding parameters, pggen uses textPreferrer if the OID is unknown. +const unknownOID = 0 diff --git a/example/syntax/query.sql.go b/example/syntax/query.sql.go index 960952a..17f503f 100644 --- a/example/syntax/query.sql.go +++ b/example/syntax/query.sql.go @@ -6,6 +6,7 @@ import ( "context" "fmt" "github.com/jackc/pgconn" + "github.com/jackc/pgtype" "github.com/jackc/pgx/v4" ) @@ -80,7 +81,8 @@ type Querier interface { } type DBQuerier struct { - conn genericConn + conn genericConn // underlying Postgres transport to use + types *typeResolver // resolve types by name } var _ Querier = &DBQuerier{} @@ -107,9 +109,23 @@ type genericConn interface { // NewQuerier creates a DBQuerier that implements Querier. conn is typically // *pgx.Conn, pgx.Tx, or *pgxpool.Pool. func NewQuerier(conn genericConn) *DBQuerier { - return &DBQuerier{ - conn: conn, - } + return NewQuerierConfig(conn, QuerierConfig{}) +} + +type QuerierConfig struct { + // DataTypes contains pgtype.Value to use for encoding and decoding instead of + // pggen-generated pgtype.ValueTranscoder. + // + // If OIDs are available for an input parameter type and all of its transative + // dependencies, pggen will use the binary encoding format for the input + // parameter. + DataTypes []pgtype.DataType +} + +// NewQuerierConfig creates a DBQuerier that implements Querier with the given +// config. conn is typically *pgx.Conn, pgx.Tx, or *pgxpool.Pool. +func NewQuerierConfig(conn genericConn, cfg QuerierConfig) *DBQuerier { + return &DBQuerier{conn: conn, types: newTypeResolver(cfg.DataTypes)} } // WithTx creates a new DBQuerier that uses the transaction to run all queries. @@ -170,6 +186,41 @@ const ( func (u UnnamedEnum123) String() string { return string(u) } +// typeResolver looks up the pgtype.ValueTranscoder by Postgres type name. +type typeResolver struct { + connInfo *pgtype.ConnInfo // types by Postgres type name +} + +func newTypeResolver(types []pgtype.DataType) *typeResolver { + ci := pgtype.NewConnInfo() + for _, typ := range types { + if txt, ok := typ.Value.(textPreferrer); ok && typ.OID != unknownOID { + typ.Value = txt.ValueTranscoder + } + ci.RegisterDataType(typ) + } + return &typeResolver{connInfo: ci} +} + +// findValue find the OID, and pgtype.ValueTranscoder for a Postgres type name. +func (tr *typeResolver) findValue(name string) (uint32, pgtype.ValueTranscoder, bool) { + typ, ok := tr.connInfo.DataTypeForName(name) + if !ok { + return 0, nil, false + } + v := pgtype.NewValue(typ.Value) + return typ.OID, v.(pgtype.ValueTranscoder), true +} + +// setValue sets the value of a ValueTranscoder to a value that should always +// work and panics if it fails. +func (tr *typeResolver) setValue(vt pgtype.ValueTranscoder, val interface{}) pgtype.ValueTranscoder { + if err := vt.Set(val); err != nil { + panic(fmt.Sprintf("set ValueTranscoder %T to %+v: %s", vt, val, err)) + } + return vt +} + const backtickSQL = "SELECT '`';" // Backtick implements Querier.Backtick. @@ -390,3 +441,29 @@ func (q *DBQuerier) GoKeywordScan(results pgx.BatchResults) (string, error) { } return item, nil } + +// textPreferrer wraps a pgtype.ValueTranscoder and sets the preferred encoding +// format to text instead binary (the default). pggen uses the text format +// when the OID is unknownOID because the binary format requires the OID. +// Typically occurs if the results from QueryAllDataTypes aren't passed to +// NewQuerierConfig. +type textPreferrer struct { + pgtype.ValueTranscoder + typeName string +} + +// PreferredParamFormat implements pgtype.ParamFormatPreferrer. +func (t textPreferrer) PreferredParamFormat() int16 { return pgtype.TextFormatCode } + +func (t textPreferrer) NewTypeValue() pgtype.Value { + return textPreferrer{pgtype.NewValue(t.ValueTranscoder).(pgtype.ValueTranscoder), t.typeName} +} + +func (t textPreferrer) TypeName() string { + return t.typeName +} + +// unknownOID means we don't know the OID for a type. This is okay for decoding +// because pgx call DecodeText or DecodeBinary without requiring the OID. For +// encoding parameters, pggen uses textPreferrer if the OID is unknown. +const unknownOID = 0 diff --git a/example/void/query.sql.go b/example/void/query.sql.go index 32a9822..2fbf48d 100644 --- a/example/void/query.sql.go +++ b/example/void/query.sql.go @@ -6,6 +6,7 @@ import ( "context" "fmt" "github.com/jackc/pgconn" + "github.com/jackc/pgtype" "github.com/jackc/pgx/v4" ) @@ -52,7 +53,8 @@ type Querier interface { } type DBQuerier struct { - conn genericConn + conn genericConn // underlying Postgres transport to use + types *typeResolver // resolve types by name } var _ Querier = &DBQuerier{} @@ -79,9 +81,23 @@ type genericConn interface { // NewQuerier creates a DBQuerier that implements Querier. conn is typically // *pgx.Conn, pgx.Tx, or *pgxpool.Pool. func NewQuerier(conn genericConn) *DBQuerier { - return &DBQuerier{ - conn: conn, - } + return NewQuerierConfig(conn, QuerierConfig{}) +} + +type QuerierConfig struct { + // DataTypes contains pgtype.Value to use for encoding and decoding instead of + // pggen-generated pgtype.ValueTranscoder. + // + // If OIDs are available for an input parameter type and all of its transative + // dependencies, pggen will use the binary encoding format for the input + // parameter. + DataTypes []pgtype.DataType +} + +// NewQuerierConfig creates a DBQuerier that implements Querier with the given +// config. conn is typically *pgx.Conn, pgx.Tx, or *pgxpool.Pool. +func NewQuerierConfig(conn genericConn, cfg QuerierConfig) *DBQuerier { + return &DBQuerier{conn: conn, types: newTypeResolver(cfg.DataTypes)} } // WithTx creates a new DBQuerier that uses the transaction to run all queries. @@ -121,6 +137,41 @@ func PrepareAllQueries(ctx context.Context, p preparer) error { return nil } +// typeResolver looks up the pgtype.ValueTranscoder by Postgres type name. +type typeResolver struct { + connInfo *pgtype.ConnInfo // types by Postgres type name +} + +func newTypeResolver(types []pgtype.DataType) *typeResolver { + ci := pgtype.NewConnInfo() + for _, typ := range types { + if txt, ok := typ.Value.(textPreferrer); ok && typ.OID != unknownOID { + typ.Value = txt.ValueTranscoder + } + ci.RegisterDataType(typ) + } + return &typeResolver{connInfo: ci} +} + +// findValue find the OID, and pgtype.ValueTranscoder for a Postgres type name. +func (tr *typeResolver) findValue(name string) (uint32, pgtype.ValueTranscoder, bool) { + typ, ok := tr.connInfo.DataTypeForName(name) + if !ok { + return 0, nil, false + } + v := pgtype.NewValue(typ.Value) + return typ.OID, v.(pgtype.ValueTranscoder), true +} + +// setValue sets the value of a ValueTranscoder to a value that should always +// work and panics if it fails. +func (tr *typeResolver) setValue(vt pgtype.ValueTranscoder, val interface{}) pgtype.ValueTranscoder { + if err := vt.Set(val); err != nil { + panic(fmt.Sprintf("set ValueTranscoder %T to %+v: %s", vt, val, err)) + } + return vt +} + const voidOnlySQL = `SELECT void_fn();` // VoidOnly implements Querier.VoidOnly. @@ -278,3 +329,29 @@ func (q *DBQuerier) VoidThree2Scan(results pgx.BatchResults) ([]string, error) { } return items, err } + +// textPreferrer wraps a pgtype.ValueTranscoder and sets the preferred encoding +// format to text instead binary (the default). pggen uses the text format +// when the OID is unknownOID because the binary format requires the OID. +// Typically occurs if the results from QueryAllDataTypes aren't passed to +// NewQuerierConfig. +type textPreferrer struct { + pgtype.ValueTranscoder + typeName string +} + +// PreferredParamFormat implements pgtype.ParamFormatPreferrer. +func (t textPreferrer) PreferredParamFormat() int16 { return pgtype.TextFormatCode } + +func (t textPreferrer) NewTypeValue() pgtype.Value { + return textPreferrer{pgtype.NewValue(t.ValueTranscoder).(pgtype.ValueTranscoder), t.typeName} +} + +func (t textPreferrer) TypeName() string { + return t.typeName +} + +// unknownOID means we don't know the OID for a type. This is okay for decoding +// because pgx call DecodeText or DecodeBinary without requiring the OID. For +// encoding parameters, pggen uses textPreferrer if the OID is unknown. +const unknownOID = 0 diff --git a/go.mod b/go.mod index c5e4b6a..9cd3b20 100644 --- a/go.mod +++ b/go.mod @@ -15,7 +15,7 @@ require ( github.com/gorilla/mux v1.8.0 // indirect github.com/jackc/pgconn v1.8.1 github.com/jackc/pgproto3/v2 v2.0.6 - github.com/jackc/pgtype v1.7.0 + github.com/jackc/pgtype v1.7.1-0.20210424130834-4380e23ae1c8 github.com/jackc/pgx/v4 v4.11.0 github.com/moby/term v0.0.0-20201216013528-df9cb8a40635 // indirect github.com/morikuni/aec v1.0.0 // indirect diff --git a/go.sum b/go.sum index 0636eeb..4102ee6 100644 --- a/go.sum +++ b/go.sum @@ -197,6 +197,8 @@ github.com/jackc/pgtype v1.3.1-0.20200510190516-8cd94a14c75a/go.mod h1:vaogEUkAL github.com/jackc/pgtype v1.3.1-0.20200606141011-f6355165a91c/go.mod h1:cvk9Bgu/VzJ9/lxTO5R5sf80p0DiucVtN7ZxvaC4GmQ= github.com/jackc/pgtype v1.7.0 h1:6f4kVsW01QftE38ufBYxKciO6gyioXSC0ABIRLcZrGs= github.com/jackc/pgtype v1.7.0/go.mod h1:ZnHF+rMePVqDKaOfJVI4Q8IVvAQMryDlDkZnKOI75BE= +github.com/jackc/pgtype v1.7.1-0.20210424130834-4380e23ae1c8 h1:E/fEiSFd7fPJkyxXNZBWi4SnvTo7xrERCwl+QCt9QaY= +github.com/jackc/pgtype v1.7.1-0.20210424130834-4380e23ae1c8/go.mod h1:ZnHF+rMePVqDKaOfJVI4Q8IVvAQMryDlDkZnKOI75BE= github.com/jackc/pgx/v4 v4.0.0-20190420224344-cc3461e65d96/go.mod h1:mdxmSJJuR08CZQyj1PVQBHy9XOp5p8/SHH6a0psbY9Y= github.com/jackc/pgx/v4 v4.0.0-20190421002000-1b8f0016e912/go.mod h1:no/Y67Jkk/9WuGR0JG/JseM9irFbnEPbuWV2EELPNuM= github.com/jackc/pgx/v4 v4.0.0-pre1.0.20190824185557-6972a5742186/go.mod h1:X+GQnOEnf1dqHGpw7JmHqHc1NxDoalibchSk9/RWuDc= diff --git a/internal/codegen/golang/declarer.go b/internal/codegen/golang/declarer.go index 1cad711..4bfa1d1 100644 --- a/internal/codegen/golang/declarer.go +++ b/internal/codegen/golang/declarer.go @@ -51,18 +51,19 @@ func FindInputDeclarers(typ gotype.Type) DeclarerSet { switch typ := typ.(type) { case gotype.CompositeType: decls.AddAll( - NewSetValueDeclarer(), + NewTypeResolverDeclarer(), NewCompositeInitDeclarer(typ), ) case gotype.ArrayType: switch typ.Elem.(type) { case gotype.CompositeType, gotype.EnumType: decls.AddAll( - NewSetValueDeclarer(), + NewTypeResolverDeclarer(), NewArrayInitDeclarer(typ), ) } } + decls.AddAll(NewTypeResolverInitDeclarer()) // always add findInputDeclsHelper(typ, decls) // Inputs depend on output transcoders. findOutputDeclsHelper(typ, decls /*hadCompositeParent*/, false) @@ -73,7 +74,6 @@ func findInputDeclsHelper(typ gotype.Type, decls DeclarerSet) { switch typ := typ.(type) { case gotype.CompositeType: decls.AddAll( - NewTextEncoderDeclarer(), NewCompositeRawDeclarer(typ), ) for _, childType := range typ.FieldTypes { @@ -82,7 +82,6 @@ func findInputDeclsHelper(typ gotype.Type, decls DeclarerSet) { case gotype.ArrayType: decls.AddAll( - NewTextEncoderDeclarer(), NewArrayRawDeclarer(typ), ) findInputDeclsHelper(typ.Elem, decls) @@ -96,6 +95,7 @@ func findInputDeclsHelper(typ gotype.Type, decls DeclarerSet) { // the output rows. Returns nil if no declarers are needed. func FindOutputDeclarers(typ gotype.Type) DeclarerSet { decls := NewDeclarerSet() + decls.AddAll(NewTypeResolverInitDeclarer()) // always add findOutputDeclsHelper(typ, decls, false) return decls } @@ -116,15 +116,14 @@ func findOutputDeclsHelper(typ gotype.Type, decls DeclarerSet, hadCompositeParen decls.AddAll( NewCompositeTypeDeclarer(typ), NewCompositeTranscoderDeclarer(typ), - ignoredOIDDeclarer, - newCompositeTypeDeclarer, + NewTypeResolverDeclarer(), ) for _, childType := range typ.FieldTypes { findOutputDeclsHelper(childType, decls, true) } case gotype.ArrayType: - decls.AddAll(ignoredOIDDeclarer) + decls.AddAll(NewTypeResolverDeclarer()) switch typ.Elem.(type) { case gotype.CompositeType, gotype.EnumType: decls.AddAll(NewArrayDecoderDeclarer(typ)) @@ -149,38 +148,98 @@ func NewConstantDeclarer(key, str string) ConstantDeclarer { func (c ConstantDeclarer) DedupeKey() string { return c.key } func (c ConstantDeclarer) Declare(string) (string, error) { return c.str, nil } -const ignoredOIDDecl = `// ignoredOID means we don't know or care about the OID for a type. This is okay -// because pgx only uses the OID to encode values and lookup a decoder. We only -// use ignoredOID for decoding and we always specify a concrete decoder for scan -// methods. -const ignoredOID = 0` - -var ignoredOIDDeclarer = NewConstantDeclarer("const::ignoredOID", ignoredOIDDecl) - -const textEncoderDecl = `// textEncoder wraps a pgtype.ValueTranscoder and sets the preferred encoding -// format to text instead binary (the default). pggen must use the text format -// because the Postgres binary format requires the type OID but pggen doesn't -// necessarily know the OIDs of the types, hence ignoredOID. -type textEncoder struct { - pgtype.ValueTranscoder +const typeResolverInitDecl = `// typeResolver looks up the pgtype.ValueTranscoder by Postgres type name. +type typeResolver struct { + connInfo *pgtype.ConnInfo // types by Postgres type name } -// PreferredParamFormat implements pgtype.ParamFormatPreferrer. -func (t textEncoder) PreferredParamFormat() int16 { return pgtype.TextFormatCode }` +func newTypeResolver(types []pgtype.DataType) *typeResolver { + ci := pgtype.NewConnInfo() + for _, typ := range types { + if txt, ok := typ.Value.(textPreferrer); ok && typ.OID != unknownOID { + typ.Value = txt.ValueTranscoder + } + ci.RegisterDataType(typ) + } + return &typeResolver{connInfo: ci} +} -func NewTextEncoderDeclarer() ConstantDeclarer { - return NewConstantDeclarer("const::textEncoder", textEncoderDecl) +// findValue find the OID, and pgtype.ValueTranscoder for a Postgres type name. +func (tr *typeResolver) findValue(name string) (uint32, pgtype.ValueTranscoder, bool) { + typ, ok := tr.connInfo.DataTypeForName(name) + if !ok { + return 0, nil, false + } + v := pgtype.NewValue(typ.Value) + return typ.OID, v.(pgtype.ValueTranscoder), true } -const setValueDecl = `// setValue sets the value of a ValueTranscoder to a value that should always +// setValue sets the value of a ValueTranscoder to a value that should always // work and panics if it fails. -func setValue(vt pgtype.ValueTranscoder, val interface{}) pgtype.ValueTranscoder { +func (tr *typeResolver) setValue(vt pgtype.ValueTranscoder, val interface{}) pgtype.ValueTranscoder { if err := vt.Set(val); err != nil { panic(fmt.Sprintf("set ValueTranscoder %T to %+v: %s", vt, val, err)) } return vt }` -func NewSetValueDeclarer() ConstantDeclarer { - return NewConstantDeclarer("const::setValue", setValueDecl) +// NewTypeResolverInitDeclarer declare type resolver init code always needed. +func NewTypeResolverInitDeclarer() ConstantDeclarer { + return NewConstantDeclarer("type_resolver::00_common", typeResolverInitDecl) +} + +const typeResolverBodyDecl = `type compositeField struct { + name string // name of the field + typeName string // Postgres type name + defaultVal pgtype.ValueTranscoder // default value to use +} + +func (tr *typeResolver) newCompositeValue(name string, fields ...compositeField) pgtype.ValueTranscoder { + if _, val, ok := tr.findValue(name); ok { + return val + } + fs := make([]pgtype.CompositeTypeField, len(fields)) + vals := make([]pgtype.ValueTranscoder, len(fields)) + isBinaryOk := true + for i, field := range fields { + oid, val, ok := tr.findValue(field.typeName) + if !ok { + oid = unknownOID + val = field.defaultVal + } + isBinaryOk = isBinaryOk && oid != unknownOID + fs[i] = pgtype.CompositeTypeField{Name: field.name, OID: oid} + vals[i] = val + } + // Okay to ignore error because it's only thrown when the number of field + // names does not equal the number of ValueTranscoders. + typ, _ := pgtype.NewCompositeTypeValues(name, fs, vals) + if !isBinaryOk { + return textPreferrer{typ, name} + } + return typ +} + +func (tr *typeResolver) newArrayValue(name, elemName string, defaultVal func() pgtype.ValueTranscoder) pgtype.ValueTranscoder { + if _, val, ok := tr.findValue(name); ok { + return val + } + elemOID, elemVal, ok := tr.findValue(elemName) + elemValFunc := func() pgtype.ValueTranscoder { + return pgtype.NewValue(elemVal).(pgtype.ValueTranscoder) + } + if !ok { + elemOID = unknownOID + elemValFunc = defaultVal + } + typ := pgtype.NewArrayType(name, elemOID, elemValFunc) + if elemOID == unknownOID { + return textPreferrer{typ, name} + } + return typ +}` + +// NewTypeResolverDeclarer declares type resolver body code sometimes needed. +func NewTypeResolverDeclarer() ConstantDeclarer { + return NewConstantDeclarer("type_resolver::01_common", typeResolverBodyDecl) } diff --git a/internal/codegen/golang/declarer_array.go b/internal/codegen/golang/declarer_array.go index 1c42a4f..dfeaa4c 100644 --- a/internal/codegen/golang/declarer_array.go +++ b/internal/codegen/golang/declarer_array.go @@ -40,7 +40,7 @@ func NewArrayDecoderDeclarer(typ gotype.ArrayType) ArrayTranscoderDeclarer { } func (a ArrayTranscoderDeclarer) DedupeKey() string { - return "types_array::" + a.typ.Name + "_01_transcoder" + return "type_resolver::" + a.typ.Name + "_01_transcoder" } func (a ArrayTranscoderDeclarer) Declare(string) (string, error) { @@ -56,20 +56,21 @@ func (a ArrayTranscoderDeclarer) Declare(string) (string, error) { sb.WriteString("' array type.\n") // Function signature - sb.WriteString("func ") + sb.WriteString("func (tr *typeResolver) ") sb.WriteString(funcName) sb.WriteString("() pgtype.ValueTranscoder {\n\t") - // NewArrayType call - sb.WriteString("return pgtype.NewArrayType(") + // newArrayValue call + sb.WriteString("return tr.newArrayValue(") sb.WriteString(strconv.Quote(a.typ.PgArray.Name)) sb.WriteString(", ") - sb.WriteString("ignoredOID") + sb.WriteString(strconv.Quote(a.typ.PgArray.ElemType.String())) sb.WriteString(", ") - // Element decoder + // Default element transcoder switch elem := a.typ.Elem.(type) { case gotype.CompositeType: + sb.WriteString("tr.") sb.WriteString(NameCompositeTranscoderFunc(elem)) case gotype.EnumType: sb.WriteString(NameEnumTranscoderFunc(elem)) @@ -104,7 +105,7 @@ func NewArrayInitDeclarer(typ gotype.ArrayType) ArrayInitDeclarer { } func (a ArrayInitDeclarer) DedupeKey() string { - return "types_array::" + a.typ.Name + "_02_init" + return "type_resolver::" + a.typ.Name + "_02_init" } func (a ArrayInitDeclarer) Declare(string) (string, error) { @@ -121,23 +122,25 @@ func (a ArrayInitDeclarer) Declare(string) (string, error) { sb.WriteString("' to encode query parameters.\n") // Function signature - sb.WriteString("func ") + sb.WriteString("func (tr *typeResolver) ") sb.WriteString(funcName) sb.WriteString("(ps ") sb.WriteString(a.typ.Name) - sb.WriteString(") textEncoder {\n\t") + sb.WriteString(") pgtype.ValueTranscoder {\n\t") // Function body - sb.WriteString("dec := ") + sb.WriteString("dec := tr.") sb.WriteString(NameArrayTranscoderFunc(a.typ)) sb.WriteString("()\n\t") - sb.WriteString("if err := dec.Set(") + sb.WriteString("if err := dec.Set(tr.") sb.WriteString(NameArrayRawFunc(a.typ)) sb.WriteString("(ps)); err != nil {\n\t\t") sb.WriteString(fmt.Sprintf(`panic("encode %s: " + err.Error())`, a.typ.Name)) sb.WriteString(" // should always succeed\n\t") sb.WriteString("}\n\t") - sb.WriteString("return textEncoder{ValueTranscoder: dec}\n") + sb.WriteString("return textPreferrer{ValueTranscoder: dec, typeName: ") + sb.WriteString(strconv.Quote(a.typ.PgArray.Name)) + sb.WriteString("}\n") sb.WriteString("}") return sb.String(), nil } @@ -154,7 +157,7 @@ func NewArrayRawDeclarer(typ gotype.ArrayType) ArrayRawDeclarer { } func (a ArrayRawDeclarer) DedupeKey() string { - return "types_array::" + a.typ.Name + "_03_raw" + return "type_resolver::" + a.typ.Name + "_03_raw" } func (a ArrayRawDeclarer) Declare(string) (string, error) { @@ -170,7 +173,7 @@ func (a ArrayRawDeclarer) Declare(string) (string, error) { sb.WriteString("'\n// as a slice of interface{} for use with the pgtype.Value Set method.\n") // Function signature - sb.WriteString("func ") + sb.WriteString("func (tr *typeResolver) ") sb.WriteString(funcName) sb.WriteString("(vs ") sb.WriteString(a.typ.Name) @@ -182,6 +185,7 @@ func (a ArrayRawDeclarer) Declare(string) (string, error) { sb.WriteString("elems[i] = ") switch elem := a.typ.Elem.(type) { case gotype.CompositeType: + sb.WriteString("tr.") sb.WriteString(NameCompositeRawFunc(elem)) sb.WriteString("(v)") default: diff --git a/internal/codegen/golang/declarer_composite.go b/internal/codegen/golang/declarer_composite.go index 1b0d215..549b3f2 100644 --- a/internal/codegen/golang/declarer_composite.go +++ b/internal/codegen/golang/declarer_composite.go @@ -30,19 +30,6 @@ func NameCompositeRawFunc(typ gotype.CompositeType) string { return "new" + typ.Name + "Raw" } -const newCompositeTypeDecl = `func newCompositeType(name string, fieldNames []string, vals ...pgtype.ValueTranscoder) *pgtype.CompositeType { - fields := make([]pgtype.CompositeTypeField, len(fieldNames)) - for i, name := range fieldNames { - fields[i] = pgtype.CompositeTypeField{Name: name, OID: ignoredOID} - } - // Okay to ignore error because it's only thrown when the number of field - // names does not equal the number of ValueTranscoders. - rowType, _ := pgtype.NewCompositeTypeValues(name, fields, vals) - return rowType -}` - -var newCompositeTypeDeclarer = NewConstantDeclarer("func::newCompositeType", newCompositeTypeDecl) - // CompositeTypeDeclarer declares a new Go struct to represent a Postgres // composite type. type CompositeTypeDeclarer struct { @@ -130,7 +117,7 @@ func NewCompositeTranscoderDeclarer(typ gotype.CompositeType) CompositeTranscode } func (c CompositeTranscoderDeclarer) DedupeKey() string { - return "types_composite::" + c.typ.Name + "_01_transcoder" + return "type_resolver::" + c.typ.Name + "_01_transcoder" } func (c CompositeTranscoderDeclarer) Declare(pkgPath string) (string, error) { @@ -147,41 +134,38 @@ func (c CompositeTranscoderDeclarer) Declare(pkgPath string) (string, error) { sb.WriteString("'.\n") // Function signature - sb.WriteString("func ") + sb.WriteString("func (tr *typeResolver) ") sb.WriteString(funcName) sb.WriteString("() pgtype.ValueTranscoder {\n\t") - // newCompositeType call - sb.WriteString("return newCompositeType(\n\t\t") + // newCompositeValue call + sb.WriteString("return tr.newCompositeValue(\n\t\t") sb.WriteString(strconv.Quote(c.typ.PgComposite.Name)) - sb.WriteString(",\n\t\t") + sb.WriteString(",") - // newCompositeType - field names of the composite type - sb.WriteString(`[]string{`) + // newCompositeValue - field names of the composite type for i := range c.typ.FieldNames { - sb.WriteByte('"') - sb.WriteString(c.typ.PgComposite.ColumnNames[i]) - sb.WriteByte('"') - if i < len(c.typ.FieldNames)-1 { - sb.WriteString(", ") - } - } - sb.WriteString("},") - - // newCompositeType - child decoders - for i, fieldType := range c.typ.FieldTypes { sb.WriteString("\n\t\t") - switch fieldType := fieldType.(type) { + sb.WriteString(`compositeField{`) + sb.WriteString(strconv.Quote(c.typ.PgComposite.ColumnNames[i])) // field name + sb.WriteString(", ") + sb.WriteString(strconv.Quote(c.typ.PgComposite.ColumnTypes[i].String())) // field type name + sb.WriteString(", ") + + // field default pgtype.ValueTranscoder + switch fieldType := c.typ.FieldTypes[i].(type) { case gotype.CompositeType: childFuncName := NameCompositeTranscoderFunc(fieldType) + sb.WriteString("tr.") sb.WriteString(childFuncName) - sb.WriteString("(),") + sb.WriteString("()") case gotype.EnumType: sb.WriteString(NameEnumTranscoderFunc(fieldType)) - sb.WriteString("(),") + sb.WriteString("()") case gotype.ArrayType: + sb.WriteString("tr.") sb.WriteString(NameArrayTranscoderFunc(fieldType)) - sb.WriteString("(),") + sb.WriteString("()") case gotype.VoidType: // skip default: @@ -193,16 +177,18 @@ func (c CompositeTranscoderDeclarer) Declare(pkgPath string) (string, error) { sb.WriteString("nil,") } else { // We need the pgx variant because it matches the interface expected by - // newCompositeType, pgtype.ValueTranscoder. + // newCompositeValue, pgtype.ValueTranscoder. if decoderType, ok := gotype.FindKnownTypePgx(pgType.OID()); ok { fieldType = decoderType } sb.WriteString(fieldType.QualifyRel(pkgPath)) - sb.WriteString("{},") + sb.WriteString("{}") } } + sb.WriteString(`},`) } + sb.WriteString("\n\t") sb.WriteString(")") sb.WriteString("\n") @@ -231,7 +217,7 @@ func NewCompositeInitDeclarer(typ gotype.CompositeType) CompositeInitDeclarer { } func (c CompositeInitDeclarer) DedupeKey() string { - return "types_composite::" + c.typ.Name + "_02_init" + return "type_resolver::" + c.typ.Name + "_02_init" } func (c CompositeInitDeclarer) Declare(string) (string, error) { @@ -248,18 +234,20 @@ func (c CompositeInitDeclarer) Declare(string) (string, error) { sb.WriteString("' to encode query parameters.\n") // Function signature - sb.WriteString("func ") + sb.WriteString("func (tr *typeResolver) ") sb.WriteString(funcName) sb.WriteString("(v ") sb.WriteString(c.typ.Name) sb.WriteString(") pgtype.ValueTranscoder {\n\t") // Function body - sb.WriteString("return textEncoder{setValue(") + sb.WriteString("return textPreferrer{tr.setValue(tr.") sb.WriteString(NameCompositeTranscoderFunc(c.typ)) - sb.WriteString("(), ") + sb.WriteString("(), tr.") sb.WriteString(NameCompositeRawFunc(c.typ)) - sb.WriteString("(v))}\n") + sb.WriteString("(v)), ") + sb.WriteString(strconv.Quote(c.typ.PgComposite.Name)) + sb.WriteString("}\n") sb.WriteString("}") return sb.String(), nil } @@ -279,7 +267,7 @@ func NewCompositeRawDeclarer(typ gotype.CompositeType) CompositeRawDeclarer { } func (c CompositeRawDeclarer) DedupeKey() string { - return "types_composite::" + c.typ.Name + "_03_raw" + return "type_resolver::" + c.typ.Name + "_03_raw" } func (c CompositeRawDeclarer) Declare(string) (string, error) { @@ -296,7 +284,7 @@ func (c CompositeRawDeclarer) Declare(string) (string, error) { sb.WriteString("' as a slice of interface{} to encode query parameters.\n") // Function signature - sb.WriteString("func ") + sb.WriteString("func (tr *typeResolver) ") sb.WriteString(funcName) sb.WriteString("(v ") sb.WriteString(c.typ.Name) @@ -312,11 +300,13 @@ func (c CompositeRawDeclarer) Declare(string) (string, error) { switch fieldType := fieldType.(type) { case gotype.CompositeType: childFuncName := NameCompositeRawFunc(fieldType) + sb.WriteString("tr.") sb.WriteString(childFuncName) sb.WriteString("(v.") sb.WriteString(fieldName) sb.WriteString(")") case gotype.ArrayType: + sb.WriteString("tr.") sb.WriteString(NameArrayRawFunc(fieldType)) sb.WriteString("(v.") sb.WriteString(fieldName) diff --git a/internal/codegen/golang/declarer_test.go b/internal/codegen/golang/declarer_test.go index 00b3ec5..b70914d 100644 --- a/internal/codegen/golang/declarer_test.go +++ b/internal/codegen/golang/declarer_test.go @@ -57,7 +57,8 @@ func TestDeclarers(t *testing.T) { typ: gotype.CompositeType{ PgComposite: pg.CompositeType{ Name: "some_table_enum", - ColumnNames: []string{"foo", "bar_baz"}, + ColumnNames: []string{"foo"}, + ColumnTypes: []pg.Type{pg.EnumType{Name: "some_table_enum"}}, }, PkgPath: "example.com/foo", Pkg: "foo", diff --git a/internal/codegen/golang/query.gotemplate b/internal/codegen/golang/query.gotemplate index 90d931d..ebef7c4 100644 --- a/internal/codegen/golang/query.gotemplate +++ b/internal/codegen/golang/query.gotemplate @@ -34,7 +34,8 @@ type Querier interface { } type DBQuerier struct { - conn genericConn + conn genericConn // underlying Postgres transport to use + types *typeResolver // resolve types by name } var _ Querier = &DBQuerier{} @@ -61,9 +62,23 @@ type genericConn interface { // NewQuerier creates a DBQuerier that implements Querier. conn is typically // *pgx.Conn, pgx.Tx, or *pgxpool.Pool. func NewQuerier(conn genericConn) *DBQuerier { - return &DBQuerier{ - conn: conn, - } + return NewQuerierConfig(conn, QuerierConfig{}) +} + +type QuerierConfig struct { + // DataTypes contains pgtype.Value to use for encoding and decoding instead of + // pggen-generated pgtype.ValueTranscoder. + // + // If OIDs are available for an input parameter type and all of its transative + // dependencies, pggen will use the binary encoding format for the input + // parameter. + DataTypes []pgtype.DataType +} + +// NewQuerierConfig creates a DBQuerier that implements Querier with the given +// config. conn is typically *pgx.Conn, pgx.Tx, or *pgxpool.Pool. +func NewQuerierConfig(conn genericConn, cfg QuerierConfig) *DBQuerier { + return &DBQuerier{conn: conn, types: newTypeResolver(cfg.DataTypes)} } // WithTx creates a new DBQuerier that uses the transaction to run all queries. @@ -188,5 +203,34 @@ func (q *DBQuerier) {{.Name}}Scan(results pgx.BatchResults) ({{ $q.EmitResultTyp {{- end }} } {{- end -}} + +{{- if .IsLeader -}} +{{- "\n\n" -}} +// textPreferrer wraps a pgtype.ValueTranscoder and sets the preferred encoding +// format to text instead binary (the default). pggen uses the text format +// when the OID is unknownOID because the binary format requires the OID. +// Typically occurs if the results from QueryAllDataTypes aren't passed to +// NewQuerierConfig. +type textPreferrer struct { + pgtype.ValueTranscoder + typeName string +} + +// PreferredParamFormat implements pgtype.ParamFormatPreferrer. +func (t textPreferrer) PreferredParamFormat() int16 { return pgtype.TextFormatCode } + +func (t textPreferrer) NewTypeValue() pgtype.Value { + return textPreferrer{pgtype.NewValue(t.ValueTranscoder).(pgtype.ValueTranscoder), t.typeName} +} + +func (t textPreferrer) TypeName() string { + return t.typeName +} + +// unknownOID means we don't know the OID for a type. This is okay for decoding +// because pgx call DecodeText or DecodeBinary without requiring the OID. For +// encoding parameters, pggen uses textPreferrer if the OID is unknown. +const unknownOID = 0 +{{- end -}} {{- "\n" -}} {{- end -}} diff --git a/internal/codegen/golang/templated_file.go b/internal/codegen/golang/templated_file.go index e6bf513..3c9230e 100644 --- a/internal/codegen/golang/templated_file.go +++ b/internal/codegen/golang/templated_file.go @@ -136,11 +136,13 @@ func (tq TemplatedQuery) EmitParamNames() string { appendParam := func(sb *strings.Builder, typ gotype.Type, name string) { switch typ := typ.(type) { case gotype.CompositeType: + sb.WriteString("q.types.") sb.WriteString(NameCompositeInitFunc(typ)) sb.WriteString("(") sb.WriteString(name) sb.WriteString(")") case gotype.ArrayType: + sb.WriteString("q.types.") sb.WriteString(NameArrayInitFunc(typ)) sb.WriteString("(") sb.WriteString(name) @@ -286,7 +288,7 @@ func (tq TemplatedQuery) EmitResultDecoders() (string, error) { case gotype.CompositeType: sb.WriteString(indent) sb.WriteString(out.LowerName) - sb.WriteString("Row := ") + sb.WriteString("Row := q.types.") sb.WriteString(NameCompositeTranscoderFunc(typ)) sb.WriteString("()") case gotype.ArrayType: @@ -295,7 +297,7 @@ func (tq TemplatedQuery) EmitResultDecoders() (string, error) { // For all other array elems, a normal array works. sb.WriteString(indent) sb.WriteString(out.LowerName) - sb.WriteString("Array := ") + sb.WriteString("Array := q.types.") sb.WriteString(NameArrayTranscoderFunc(typ)) sb.WriteString("()") } diff --git a/internal/codegen/golang/templater.go b/internal/codegen/golang/templater.go index 1f38f44..ee18d4c 100644 --- a/internal/codegen/golang/templater.go +++ b/internal/codegen/golang/templater.go @@ -4,7 +4,6 @@ import ( "fmt" "github.com/jschaf/pggen/internal/casing" "github.com/jschaf/pggen/internal/codegen" - "github.com/jschaf/pggen/internal/codegen/golang/gotype" "github.com/jschaf/pggen/internal/gomod" "strconv" "strings" @@ -38,8 +37,19 @@ func (tm Templater) TemplateAll(files []codegen.QueryFile) ([]TemplatedFile, err goQueryFiles := make([]TemplatedFile, 0, len(files)) var declarers DeclarerSet - for _, queryFile := range files { - goFile, decls, err := tm.templateFile(queryFile) + // Pick leader file to define common structs and interfaces via Declarer. + firstIndex := -1 + firstName := string(unicode.MaxRune) + for i, f := range files { + if f.SourcePath < firstName { + firstIndex = i + firstName = f.SourcePath + } + } + + for i, queryFile := range files { + isLeader := i == firstIndex + goFile, decls, err := tm.templateFile(queryFile, isLeader) declarers = decls if err != nil { return nil, fmt.Errorf("template query file %s for go: %w", queryFile.SourcePath, err) @@ -49,17 +59,6 @@ func (tm Templater) TemplateAll(files []codegen.QueryFile) ([]TemplatedFile, err declarers.AddAll(ds...) } - // Pick leader file to define common structs and interfaces via Declarer. - firstIndex := -1 - firstName := string(unicode.MaxRune) - for i, goFile := range goQueryFiles { - if goFile.SourcePath < firstName { - firstIndex = i - firstName = goFile.SourcePath - } - } - goQueryFiles[firstIndex].IsLeader = true - if len(declarers) > 0 { goQueryFiles[firstIndex].Declarers = declarers.ListAll() } @@ -108,11 +107,14 @@ func (tm Templater) TemplateAll(files []codegen.QueryFile) ([]TemplatedFile, err // templateFile creates the data needed to build a Go file for a query file. // Also returns any declarations needed by this query file. The caller must // dedupe declarations. -func (tm Templater) templateFile(file codegen.QueryFile) (TemplatedFile, DeclarerSet, error) { +func (tm Templater) templateFile(file codegen.QueryFile, isLeader bool) (TemplatedFile, DeclarerSet, error) { imports := NewImportSet() imports.AddPackage("context") imports.AddPackage("fmt") imports.AddPackage("github.com/jackc/pgconn") + if isLeader { + imports.AddPackage("github.com/jackc/pgtype") + } imports.AddPackage("github.com/jackc/pgx/v4") pkgPath := "" @@ -167,12 +169,6 @@ func (tm Templater) templateFile(file codegen.QueryFile) (TemplatedFile, Declare return TemplatedFile{}, nil, err } imports.AddType(goType) - if isCompositeArray(goType) { - imports.AddPackage("github.com/jackc/pgtype") // needed for decoder types - } - if gotype.HasCompositeType(goType) { - imports.AddPackage("github.com/jackc/pgtype") // needed for newCompositeType - } outputs[i] = TemplatedColumn{ PgName: out.PgName, UpperName: tm.chooseUpperName(out.PgName, "UnnamedColumn", i, len(query.Outputs)), @@ -201,6 +197,7 @@ func (tm Templater) templateFile(file codegen.QueryFile) (TemplatedFile, Declare SourcePath: file.SourcePath, Queries: queries, Imports: imports.SortedPackages(), + IsLeader: isLeader, }, declarers, nil } @@ -231,12 +228,3 @@ func (tm Templater) chooseLowerName(pgName string, fallback string, idx int, num } return fallback + suffix } - -func isCompositeArray(typ gotype.Type) bool { - if typ, ok := typ.(gotype.ArrayType); !ok { - return false - } else if _, ok := typ.Elem.(gotype.CompositeType); !ok { - return false - } - return true -} diff --git a/internal/codegen/golang/testdata/declarer_composite.input.golden b/internal/codegen/golang/testdata/declarer_composite.input.golden index 6e60354..10fae15 100644 --- a/internal/codegen/golang/testdata/declarer_composite.input.golden +++ b/internal/codegen/golang/testdata/declarer_composite.input.golden @@ -4,63 +4,111 @@ type SomeTable struct { BarBaz pgtype.Text `json:"bar_baz"` } -// ignoredOID means we don't know or care about the OID for a type. This is okay -// because pgx only uses the OID to encode values and lookup a decoder. We only -// use ignoredOID for decoding and we always specify a concrete decoder for scan -// methods. -const ignoredOID = 0 +// typeResolver looks up the pgtype.ValueTranscoder by Postgres type name. +type typeResolver struct { + connInfo *pgtype.ConnInfo // types by Postgres type name +} + +func newTypeResolver(types []pgtype.DataType) *typeResolver { + ci := pgtype.NewConnInfo() + for _, typ := range types { + if txt, ok := typ.Value.(textPreferrer); ok && typ.OID != unknownOID { + typ.Value = txt.ValueTranscoder + } + ci.RegisterDataType(typ) + } + return &typeResolver{connInfo: ci} +} + +// findValue find the OID, and pgtype.ValueTranscoder for a Postgres type name. +func (tr *typeResolver) findValue(name string) (uint32, pgtype.ValueTranscoder, bool) { + typ, ok := tr.connInfo.DataTypeForName(name) + if !ok { + return 0, nil, false + } + v := pgtype.NewValue(typ.Value) + return typ.OID, v.(pgtype.ValueTranscoder), true +} // setValue sets the value of a ValueTranscoder to a value that should always // work and panics if it fails. -func setValue(vt pgtype.ValueTranscoder, val interface{}) pgtype.ValueTranscoder { +func (tr *typeResolver) setValue(vt pgtype.ValueTranscoder, val interface{}) pgtype.ValueTranscoder { if err := vt.Set(val); err != nil { panic(fmt.Sprintf("set ValueTranscoder %T to %+v: %s", vt, val, err)) } return vt } -// textEncoder wraps a pgtype.ValueTranscoder and sets the preferred encoding -// format to text instead binary (the default). pggen must use the text format -// because the Postgres binary format requires the type OID but pggen doesn't -// necessarily know the OIDs of the types, hence ignoredOID. -type textEncoder struct { - pgtype.ValueTranscoder +type compositeField struct { + name string // name of the field + typeName string // Postgres type name + defaultVal pgtype.ValueTranscoder // default value to use } -// PreferredParamFormat implements pgtype.ParamFormatPreferrer. -func (t textEncoder) PreferredParamFormat() int16 { return pgtype.TextFormatCode } - -func newCompositeType(name string, fieldNames []string, vals ...pgtype.ValueTranscoder) *pgtype.CompositeType { - fields := make([]pgtype.CompositeTypeField, len(fieldNames)) - for i, name := range fieldNames { - fields[i] = pgtype.CompositeTypeField{Name: name, OID: ignoredOID} +func (tr *typeResolver) newCompositeValue(name string, fields ...compositeField) pgtype.ValueTranscoder { + if _, val, ok := tr.findValue(name); ok { + return val + } + fs := make([]pgtype.CompositeTypeField, len(fields)) + vals := make([]pgtype.ValueTranscoder, len(fields)) + isBinaryOk := true + for i, field := range fields { + oid, val, ok := tr.findValue(field.typeName) + if !ok { + oid = unknownOID + val = field.defaultVal + } + isBinaryOk = isBinaryOk && oid != unknownOID + fs[i] = pgtype.CompositeTypeField{Name: field.name, OID: oid} + vals[i] = val } // Okay to ignore error because it's only thrown when the number of field // names does not equal the number of ValueTranscoders. - rowType, _ := pgtype.NewCompositeTypeValues(name, fields, vals) - return rowType + typ, _ := pgtype.NewCompositeTypeValues(name, fs, vals) + if !isBinaryOk { + return textPreferrer{typ, name} + } + return typ +} + +func (tr *typeResolver) newArrayValue(name, elemName string, defaultVal func() pgtype.ValueTranscoder) pgtype.ValueTranscoder { + if _, val, ok := tr.findValue(name); ok { + return val + } + elemOID, elemVal, ok := tr.findValue(elemName) + elemValFunc := func() pgtype.ValueTranscoder { + return pgtype.NewValue(elemVal).(pgtype.ValueTranscoder) + } + if !ok { + elemOID = unknownOID + elemValFunc = defaultVal + } + typ := pgtype.NewArrayType(name, elemOID, elemValFunc) + if elemOID == unknownOID { + return textPreferrer{typ, name} + } + return typ } // newSomeTable creates a new pgtype.ValueTranscoder for the Postgres // composite type 'some_table'. -func newSomeTable() pgtype.ValueTranscoder { - return newCompositeType( +func (tr *typeResolver) newSomeTable() pgtype.ValueTranscoder { + return tr.newCompositeValue( "some_table", - []string{"foo", "bar_baz"}, - &pgtype.Int2{}, - &pgtype.Text{}, + compositeField{"foo", "int2", &pgtype.Int2{}}, + compositeField{"bar_baz", "text", &pgtype.Text{}}, ) } // newSomeTableInit creates an initialized pgtype.ValueTranscoder for the // Postgres composite type 'some_table' to encode query parameters. -func newSomeTableInit(v SomeTable) pgtype.ValueTranscoder { - return textEncoder{setValue(newSomeTable(), newSomeTableRaw(v))} +func (tr *typeResolver) newSomeTableInit(v SomeTable) pgtype.ValueTranscoder { + return textPreferrer{tr.setValue(tr.newSomeTable(), tr.newSomeTableRaw(v)), "some_table"} } // newSomeTableRaw returns all composite fields for the Postgres composite // type 'some_table' as a slice of interface{} to encode query parameters. -func newSomeTableRaw(v SomeTable) []interface{} { +func (tr *typeResolver) newSomeTableRaw(v SomeTable) []interface{} { return []interface{}{ v.Foo, v.BarBaz, diff --git a/internal/codegen/golang/testdata/declarer_composite.output.golden b/internal/codegen/golang/testdata/declarer_composite.output.golden index 1d7d852..65606d1 100644 --- a/internal/codegen/golang/testdata/declarer_composite.output.golden +++ b/internal/codegen/golang/testdata/declarer_composite.output.golden @@ -4,30 +4,98 @@ type SomeTable struct { BarBaz pgtype.Text `json:"bar_baz"` } -// ignoredOID means we don't know or care about the OID for a type. This is okay -// because pgx only uses the OID to encode values and lookup a decoder. We only -// use ignoredOID for decoding and we always specify a concrete decoder for scan -// methods. -const ignoredOID = 0 +// typeResolver looks up the pgtype.ValueTranscoder by Postgres type name. +type typeResolver struct { + connInfo *pgtype.ConnInfo // types by Postgres type name +} + +func newTypeResolver(types []pgtype.DataType) *typeResolver { + ci := pgtype.NewConnInfo() + for _, typ := range types { + if txt, ok := typ.Value.(textPreferrer); ok && typ.OID != unknownOID { + typ.Value = txt.ValueTranscoder + } + ci.RegisterDataType(typ) + } + return &typeResolver{connInfo: ci} +} + +// findValue find the OID, and pgtype.ValueTranscoder for a Postgres type name. +func (tr *typeResolver) findValue(name string) (uint32, pgtype.ValueTranscoder, bool) { + typ, ok := tr.connInfo.DataTypeForName(name) + if !ok { + return 0, nil, false + } + v := pgtype.NewValue(typ.Value) + return typ.OID, v.(pgtype.ValueTranscoder), true +} + +// setValue sets the value of a ValueTranscoder to a value that should always +// work and panics if it fails. +func (tr *typeResolver) setValue(vt pgtype.ValueTranscoder, val interface{}) pgtype.ValueTranscoder { + if err := vt.Set(val); err != nil { + panic(fmt.Sprintf("set ValueTranscoder %T to %+v: %s", vt, val, err)) + } + return vt +} -func newCompositeType(name string, fieldNames []string, vals ...pgtype.ValueTranscoder) *pgtype.CompositeType { - fields := make([]pgtype.CompositeTypeField, len(fieldNames)) - for i, name := range fieldNames { - fields[i] = pgtype.CompositeTypeField{Name: name, OID: ignoredOID} +type compositeField struct { + name string // name of the field + typeName string // Postgres type name + defaultVal pgtype.ValueTranscoder // default value to use +} + +func (tr *typeResolver) newCompositeValue(name string, fields ...compositeField) pgtype.ValueTranscoder { + if _, val, ok := tr.findValue(name); ok { + return val + } + fs := make([]pgtype.CompositeTypeField, len(fields)) + vals := make([]pgtype.ValueTranscoder, len(fields)) + isBinaryOk := true + for i, field := range fields { + oid, val, ok := tr.findValue(field.typeName) + if !ok { + oid = unknownOID + val = field.defaultVal + } + isBinaryOk = isBinaryOk && oid != unknownOID + fs[i] = pgtype.CompositeTypeField{Name: field.name, OID: oid} + vals[i] = val } // Okay to ignore error because it's only thrown when the number of field // names does not equal the number of ValueTranscoders. - rowType, _ := pgtype.NewCompositeTypeValues(name, fields, vals) - return rowType + typ, _ := pgtype.NewCompositeTypeValues(name, fs, vals) + if !isBinaryOk { + return textPreferrer{typ, name} + } + return typ +} + +func (tr *typeResolver) newArrayValue(name, elemName string, defaultVal func() pgtype.ValueTranscoder) pgtype.ValueTranscoder { + if _, val, ok := tr.findValue(name); ok { + return val + } + elemOID, elemVal, ok := tr.findValue(elemName) + elemValFunc := func() pgtype.ValueTranscoder { + return pgtype.NewValue(elemVal).(pgtype.ValueTranscoder) + } + if !ok { + elemOID = unknownOID + elemValFunc = defaultVal + } + typ := pgtype.NewArrayType(name, elemOID, elemValFunc) + if elemOID == unknownOID { + return textPreferrer{typ, name} + } + return typ } // newSomeTable creates a new pgtype.ValueTranscoder for the Postgres // composite type 'some_table'. -func newSomeTable() pgtype.ValueTranscoder { - return newCompositeType( +func (tr *typeResolver) newSomeTable() pgtype.ValueTranscoder { + return tr.newCompositeValue( "some_table", - []string{"foo", "bar_baz"}, - &pgtype.Int2{}, - &pgtype.Text{}, + compositeField{"foo", "int2", &pgtype.Int2{}}, + compositeField{"bar_baz", "text", &pgtype.Text{}}, ) } \ No newline at end of file diff --git a/internal/codegen/golang/testdata/declarer_composite_array.input.golden b/internal/codegen/golang/testdata/declarer_composite_array.input.golden index 1134894..986a2b1 100644 --- a/internal/codegen/golang/testdata/declarer_composite_array.input.golden +++ b/internal/codegen/golang/testdata/declarer_composite_array.input.golden @@ -4,83 +4,131 @@ type SomeTable struct { BarBaz pgtype.Text `json:"bar_baz"` } -// ignoredOID means we don't know or care about the OID for a type. This is okay -// because pgx only uses the OID to encode values and lookup a decoder. We only -// use ignoredOID for decoding and we always specify a concrete decoder for scan -// methods. -const ignoredOID = 0 +// typeResolver looks up the pgtype.ValueTranscoder by Postgres type name. +type typeResolver struct { + connInfo *pgtype.ConnInfo // types by Postgres type name +} + +func newTypeResolver(types []pgtype.DataType) *typeResolver { + ci := pgtype.NewConnInfo() + for _, typ := range types { + if txt, ok := typ.Value.(textPreferrer); ok && typ.OID != unknownOID { + typ.Value = txt.ValueTranscoder + } + ci.RegisterDataType(typ) + } + return &typeResolver{connInfo: ci} +} + +// findValue find the OID, and pgtype.ValueTranscoder for a Postgres type name. +func (tr *typeResolver) findValue(name string) (uint32, pgtype.ValueTranscoder, bool) { + typ, ok := tr.connInfo.DataTypeForName(name) + if !ok { + return 0, nil, false + } + v := pgtype.NewValue(typ.Value) + return typ.OID, v.(pgtype.ValueTranscoder), true +} // setValue sets the value of a ValueTranscoder to a value that should always // work and panics if it fails. -func setValue(vt pgtype.ValueTranscoder, val interface{}) pgtype.ValueTranscoder { +func (tr *typeResolver) setValue(vt pgtype.ValueTranscoder, val interface{}) pgtype.ValueTranscoder { if err := vt.Set(val); err != nil { panic(fmt.Sprintf("set ValueTranscoder %T to %+v: %s", vt, val, err)) } return vt } -// textEncoder wraps a pgtype.ValueTranscoder and sets the preferred encoding -// format to text instead binary (the default). pggen must use the text format -// because the Postgres binary format requires the type OID but pggen doesn't -// necessarily know the OIDs of the types, hence ignoredOID. -type textEncoder struct { - pgtype.ValueTranscoder +type compositeField struct { + name string // name of the field + typeName string // Postgres type name + defaultVal pgtype.ValueTranscoder // default value to use } -// PreferredParamFormat implements pgtype.ParamFormatPreferrer. -func (t textEncoder) PreferredParamFormat() int16 { return pgtype.TextFormatCode } - -func newCompositeType(name string, fieldNames []string, vals ...pgtype.ValueTranscoder) *pgtype.CompositeType { - fields := make([]pgtype.CompositeTypeField, len(fieldNames)) - for i, name := range fieldNames { - fields[i] = pgtype.CompositeTypeField{Name: name, OID: ignoredOID} +func (tr *typeResolver) newCompositeValue(name string, fields ...compositeField) pgtype.ValueTranscoder { + if _, val, ok := tr.findValue(name); ok { + return val + } + fs := make([]pgtype.CompositeTypeField, len(fields)) + vals := make([]pgtype.ValueTranscoder, len(fields)) + isBinaryOk := true + for i, field := range fields { + oid, val, ok := tr.findValue(field.typeName) + if !ok { + oid = unknownOID + val = field.defaultVal + } + isBinaryOk = isBinaryOk && oid != unknownOID + fs[i] = pgtype.CompositeTypeField{Name: field.name, OID: oid} + vals[i] = val } // Okay to ignore error because it's only thrown when the number of field // names does not equal the number of ValueTranscoders. - rowType, _ := pgtype.NewCompositeTypeValues(name, fields, vals) - return rowType + typ, _ := pgtype.NewCompositeTypeValues(name, fs, vals) + if !isBinaryOk { + return textPreferrer{typ, name} + } + return typ +} + +func (tr *typeResolver) newArrayValue(name, elemName string, defaultVal func() pgtype.ValueTranscoder) pgtype.ValueTranscoder { + if _, val, ok := tr.findValue(name); ok { + return val + } + elemOID, elemVal, ok := tr.findValue(elemName) + elemValFunc := func() pgtype.ValueTranscoder { + return pgtype.NewValue(elemVal).(pgtype.ValueTranscoder) + } + if !ok { + elemOID = unknownOID + elemValFunc = defaultVal + } + typ := pgtype.NewArrayType(name, elemOID, elemValFunc) + if elemOID == unknownOID { + return textPreferrer{typ, name} + } + return typ } // newSomeTableArray creates a new pgtype.ValueTranscoder for the Postgres // '_some_array' array type. -func newSomeTableArray() pgtype.ValueTranscoder { - return pgtype.NewArrayType("_some_array", ignoredOID, newSomeTable) +func (tr *typeResolver) newSomeTableArray() pgtype.ValueTranscoder { + return tr.newArrayValue("_some_array", "some_table", tr.newSomeTable) } // newSomeTableArrayInit creates an initialized pgtype.ValueTranscoder for the // Postgres array type '_some_array' to encode query parameters. -func newSomeTableArrayInit(ps SomeArray) textEncoder { - dec := newSomeTableArray() - if err := dec.Set(newSomeTableArrayRaw(ps)); err != nil { +func (tr *typeResolver) newSomeTableArrayInit(ps SomeArray) pgtype.ValueTranscoder { + dec := tr.newSomeTableArray() + if err := dec.Set(tr.newSomeTableArrayRaw(ps)); err != nil { panic("encode SomeArray: " + err.Error()) // should always succeed } - return textEncoder{ValueTranscoder: dec} + return textPreferrer{ValueTranscoder: dec, typeName: "_some_array"} } // newSomeTableArrayRaw returns all elements for the Postgres array type '_some_array' // as a slice of interface{} for use with the pgtype.Value Set method. -func newSomeTableArrayRaw(vs SomeArray) []interface{} { +func (tr *typeResolver) newSomeTableArrayRaw(vs SomeArray) []interface{} { elems := make([]interface{}, len(vs)) for i, v := range vs { - elems[i] = newSomeTableRaw(v) + elems[i] = tr.newSomeTableRaw(v) } return elems } // newSomeTable creates a new pgtype.ValueTranscoder for the Postgres // composite type 'some_table'. -func newSomeTable() pgtype.ValueTranscoder { - return newCompositeType( +func (tr *typeResolver) newSomeTable() pgtype.ValueTranscoder { + return tr.newCompositeValue( "some_table", - []string{"foo", "bar_baz"}, - &pgtype.Int2{}, - &pgtype.Text{}, + compositeField{"foo", "int2", &pgtype.Int2{}}, + compositeField{"bar_baz", "text", &pgtype.Text{}}, ) } // newSomeTableRaw returns all composite fields for the Postgres composite // type 'some_table' as a slice of interface{} to encode query parameters. -func newSomeTableRaw(v SomeTable) []interface{} { +func (tr *typeResolver) newSomeTableRaw(v SomeTable) []interface{} { return []interface{}{ v.Foo, v.BarBaz, diff --git a/internal/codegen/golang/testdata/declarer_composite_array.output.golden b/internal/codegen/golang/testdata/declarer_composite_array.output.golden index 000b724..fcd3164 100644 --- a/internal/codegen/golang/testdata/declarer_composite_array.output.golden +++ b/internal/codegen/golang/testdata/declarer_composite_array.output.golden @@ -4,36 +4,104 @@ type SomeTable struct { BarBaz pgtype.Text `json:"bar_baz"` } -// ignoredOID means we don't know or care about the OID for a type. This is okay -// because pgx only uses the OID to encode values and lookup a decoder. We only -// use ignoredOID for decoding and we always specify a concrete decoder for scan -// methods. -const ignoredOID = 0 +// typeResolver looks up the pgtype.ValueTranscoder by Postgres type name. +type typeResolver struct { + connInfo *pgtype.ConnInfo // types by Postgres type name +} + +func newTypeResolver(types []pgtype.DataType) *typeResolver { + ci := pgtype.NewConnInfo() + for _, typ := range types { + if txt, ok := typ.Value.(textPreferrer); ok && typ.OID != unknownOID { + typ.Value = txt.ValueTranscoder + } + ci.RegisterDataType(typ) + } + return &typeResolver{connInfo: ci} +} + +// findValue find the OID, and pgtype.ValueTranscoder for a Postgres type name. +func (tr *typeResolver) findValue(name string) (uint32, pgtype.ValueTranscoder, bool) { + typ, ok := tr.connInfo.DataTypeForName(name) + if !ok { + return 0, nil, false + } + v := pgtype.NewValue(typ.Value) + return typ.OID, v.(pgtype.ValueTranscoder), true +} + +// setValue sets the value of a ValueTranscoder to a value that should always +// work and panics if it fails. +func (tr *typeResolver) setValue(vt pgtype.ValueTranscoder, val interface{}) pgtype.ValueTranscoder { + if err := vt.Set(val); err != nil { + panic(fmt.Sprintf("set ValueTranscoder %T to %+v: %s", vt, val, err)) + } + return vt +} -func newCompositeType(name string, fieldNames []string, vals ...pgtype.ValueTranscoder) *pgtype.CompositeType { - fields := make([]pgtype.CompositeTypeField, len(fieldNames)) - for i, name := range fieldNames { - fields[i] = pgtype.CompositeTypeField{Name: name, OID: ignoredOID} +type compositeField struct { + name string // name of the field + typeName string // Postgres type name + defaultVal pgtype.ValueTranscoder // default value to use +} + +func (tr *typeResolver) newCompositeValue(name string, fields ...compositeField) pgtype.ValueTranscoder { + if _, val, ok := tr.findValue(name); ok { + return val + } + fs := make([]pgtype.CompositeTypeField, len(fields)) + vals := make([]pgtype.ValueTranscoder, len(fields)) + isBinaryOk := true + for i, field := range fields { + oid, val, ok := tr.findValue(field.typeName) + if !ok { + oid = unknownOID + val = field.defaultVal + } + isBinaryOk = isBinaryOk && oid != unknownOID + fs[i] = pgtype.CompositeTypeField{Name: field.name, OID: oid} + vals[i] = val } // Okay to ignore error because it's only thrown when the number of field // names does not equal the number of ValueTranscoders. - rowType, _ := pgtype.NewCompositeTypeValues(name, fields, vals) - return rowType + typ, _ := pgtype.NewCompositeTypeValues(name, fs, vals) + if !isBinaryOk { + return textPreferrer{typ, name} + } + return typ +} + +func (tr *typeResolver) newArrayValue(name, elemName string, defaultVal func() pgtype.ValueTranscoder) pgtype.ValueTranscoder { + if _, val, ok := tr.findValue(name); ok { + return val + } + elemOID, elemVal, ok := tr.findValue(elemName) + elemValFunc := func() pgtype.ValueTranscoder { + return pgtype.NewValue(elemVal).(pgtype.ValueTranscoder) + } + if !ok { + elemOID = unknownOID + elemValFunc = defaultVal + } + typ := pgtype.NewArrayType(name, elemOID, elemValFunc) + if elemOID == unknownOID { + return textPreferrer{typ, name} + } + return typ } // newSomeTableArray creates a new pgtype.ValueTranscoder for the Postgres // '_some_array' array type. -func newSomeTableArray() pgtype.ValueTranscoder { - return pgtype.NewArrayType("_some_array", ignoredOID, newSomeTable) +func (tr *typeResolver) newSomeTableArray() pgtype.ValueTranscoder { + return tr.newArrayValue("_some_array", "some_table", tr.newSomeTable) } // newSomeTable creates a new pgtype.ValueTranscoder for the Postgres // composite type 'some_table'. -func newSomeTable() pgtype.ValueTranscoder { - return newCompositeType( +func (tr *typeResolver) newSomeTable() pgtype.ValueTranscoder { + return tr.newCompositeValue( "some_table", - []string{"foo", "bar_baz"}, - &pgtype.Int2{}, - &pgtype.Text{}, + compositeField{"foo", "int2", &pgtype.Int2{}}, + compositeField{"bar_baz", "text", &pgtype.Text{}}, ) } \ No newline at end of file diff --git a/internal/codegen/golang/testdata/declarer_composite_enum.input.golden b/internal/codegen/golang/testdata/declarer_composite_enum.input.golden index 288e2de..3c34435 100644 --- a/internal/codegen/golang/testdata/declarer_composite_enum.input.golden +++ b/internal/codegen/golang/testdata/declarer_composite_enum.input.golden @@ -3,32 +3,6 @@ type SomeTableEnum struct { Foo DeviceType `json:"foo"` } -// ignoredOID means we don't know or care about the OID for a type. This is okay -// because pgx only uses the OID to encode values and lookup a decoder. We only -// use ignoredOID for decoding and we always specify a concrete decoder for scan -// methods. -const ignoredOID = 0 - -// setValue sets the value of a ValueTranscoder to a value that should always -// work and panics if it fails. -func setValue(vt pgtype.ValueTranscoder, val interface{}) pgtype.ValueTranscoder { - if err := vt.Set(val); err != nil { - panic(fmt.Sprintf("set ValueTranscoder %T to %+v: %s", vt, val, err)) - } - return vt -} - -// textEncoder wraps a pgtype.ValueTranscoder and sets the preferred encoding -// format to text instead binary (the default). pggen must use the text format -// because the Postgres binary format requires the type OID but pggen doesn't -// necessarily know the OIDs of the types, hence ignoredOID. -type textEncoder struct { - pgtype.ValueTranscoder -} - -// PreferredParamFormat implements pgtype.ParamFormatPreferrer. -func (t textEncoder) PreferredParamFormat() int16 { return pgtype.TextFormatCode } - // newDeviceTypeEnum creates a new pgtype.ValueTranscoder for the // Postgres enum type 'device_type'. func newDeviceTypeEnum() pgtype.ValueTranscoder { @@ -51,36 +25,110 @@ const ( func (d DeviceType) String() string { return string(d) } -func newCompositeType(name string, fieldNames []string, vals ...pgtype.ValueTranscoder) *pgtype.CompositeType { - fields := make([]pgtype.CompositeTypeField, len(fieldNames)) - for i, name := range fieldNames { - fields[i] = pgtype.CompositeTypeField{Name: name, OID: ignoredOID} +// typeResolver looks up the pgtype.ValueTranscoder by Postgres type name. +type typeResolver struct { + connInfo *pgtype.ConnInfo // types by Postgres type name +} + +func newTypeResolver(types []pgtype.DataType) *typeResolver { + ci := pgtype.NewConnInfo() + for _, typ := range types { + if txt, ok := typ.Value.(textPreferrer); ok && typ.OID != unknownOID { + typ.Value = txt.ValueTranscoder + } + ci.RegisterDataType(typ) + } + return &typeResolver{connInfo: ci} +} + +// findValue find the OID, and pgtype.ValueTranscoder for a Postgres type name. +func (tr *typeResolver) findValue(name string) (uint32, pgtype.ValueTranscoder, bool) { + typ, ok := tr.connInfo.DataTypeForName(name) + if !ok { + return 0, nil, false + } + v := pgtype.NewValue(typ.Value) + return typ.OID, v.(pgtype.ValueTranscoder), true +} + +// setValue sets the value of a ValueTranscoder to a value that should always +// work and panics if it fails. +func (tr *typeResolver) setValue(vt pgtype.ValueTranscoder, val interface{}) pgtype.ValueTranscoder { + if err := vt.Set(val); err != nil { + panic(fmt.Sprintf("set ValueTranscoder %T to %+v: %s", vt, val, err)) + } + return vt +} + +type compositeField struct { + name string // name of the field + typeName string // Postgres type name + defaultVal pgtype.ValueTranscoder // default value to use +} + +func (tr *typeResolver) newCompositeValue(name string, fields ...compositeField) pgtype.ValueTranscoder { + if _, val, ok := tr.findValue(name); ok { + return val + } + fs := make([]pgtype.CompositeTypeField, len(fields)) + vals := make([]pgtype.ValueTranscoder, len(fields)) + isBinaryOk := true + for i, field := range fields { + oid, val, ok := tr.findValue(field.typeName) + if !ok { + oid = unknownOID + val = field.defaultVal + } + isBinaryOk = isBinaryOk && oid != unknownOID + fs[i] = pgtype.CompositeTypeField{Name: field.name, OID: oid} + vals[i] = val } // Okay to ignore error because it's only thrown when the number of field // names does not equal the number of ValueTranscoders. - rowType, _ := pgtype.NewCompositeTypeValues(name, fields, vals) - return rowType + typ, _ := pgtype.NewCompositeTypeValues(name, fs, vals) + if !isBinaryOk { + return textPreferrer{typ, name} + } + return typ +} + +func (tr *typeResolver) newArrayValue(name, elemName string, defaultVal func() pgtype.ValueTranscoder) pgtype.ValueTranscoder { + if _, val, ok := tr.findValue(name); ok { + return val + } + elemOID, elemVal, ok := tr.findValue(elemName) + elemValFunc := func() pgtype.ValueTranscoder { + return pgtype.NewValue(elemVal).(pgtype.ValueTranscoder) + } + if !ok { + elemOID = unknownOID + elemValFunc = defaultVal + } + typ := pgtype.NewArrayType(name, elemOID, elemValFunc) + if elemOID == unknownOID { + return textPreferrer{typ, name} + } + return typ } // newSomeTableEnum creates a new pgtype.ValueTranscoder for the Postgres // composite type 'some_table_enum'. -func newSomeTableEnum() pgtype.ValueTranscoder { - return newCompositeType( +func (tr *typeResolver) newSomeTableEnum() pgtype.ValueTranscoder { + return tr.newCompositeValue( "some_table_enum", - []string{"foo"}, - newDeviceTypeEnum(), + compositeField{"foo", "some_table_enum", newDeviceTypeEnum()}, ) } // newSomeTableEnumInit creates an initialized pgtype.ValueTranscoder for the // Postgres composite type 'some_table_enum' to encode query parameters. -func newSomeTableEnumInit(v SomeTableEnum) pgtype.ValueTranscoder { - return textEncoder{setValue(newSomeTableEnum(), newSomeTableEnumRaw(v))} +func (tr *typeResolver) newSomeTableEnumInit(v SomeTableEnum) pgtype.ValueTranscoder { + return textPreferrer{tr.setValue(tr.newSomeTableEnum(), tr.newSomeTableEnumRaw(v)), "some_table_enum"} } // newSomeTableEnumRaw returns all composite fields for the Postgres composite // type 'some_table_enum' as a slice of interface{} to encode query parameters. -func newSomeTableEnumRaw(v SomeTableEnum) []interface{} { +func (tr *typeResolver) newSomeTableEnumRaw(v SomeTableEnum) []interface{} { return []interface{}{ v.Foo, } diff --git a/internal/codegen/golang/testdata/declarer_composite_enum.output.golden b/internal/codegen/golang/testdata/declarer_composite_enum.output.golden index 6690fff..fdcc6b0 100644 --- a/internal/codegen/golang/testdata/declarer_composite_enum.output.golden +++ b/internal/codegen/golang/testdata/declarer_composite_enum.output.golden @@ -3,12 +3,6 @@ type SomeTableEnum struct { Foo DeviceType `json:"foo"` } -// ignoredOID means we don't know or care about the OID for a type. This is okay -// because pgx only uses the OID to encode values and lookup a decoder. We only -// use ignoredOID for decoding and we always specify a concrete decoder for scan -// methods. -const ignoredOID = 0 - // newDeviceTypeEnum creates a new pgtype.ValueTranscoder for the // Postgres enum type 'device_type'. func newDeviceTypeEnum() pgtype.ValueTranscoder { @@ -31,23 +25,97 @@ const ( func (d DeviceType) String() string { return string(d) } -func newCompositeType(name string, fieldNames []string, vals ...pgtype.ValueTranscoder) *pgtype.CompositeType { - fields := make([]pgtype.CompositeTypeField, len(fieldNames)) - for i, name := range fieldNames { - fields[i] = pgtype.CompositeTypeField{Name: name, OID: ignoredOID} +// typeResolver looks up the pgtype.ValueTranscoder by Postgres type name. +type typeResolver struct { + connInfo *pgtype.ConnInfo // types by Postgres type name +} + +func newTypeResolver(types []pgtype.DataType) *typeResolver { + ci := pgtype.NewConnInfo() + for _, typ := range types { + if txt, ok := typ.Value.(textPreferrer); ok && typ.OID != unknownOID { + typ.Value = txt.ValueTranscoder + } + ci.RegisterDataType(typ) + } + return &typeResolver{connInfo: ci} +} + +// findValue find the OID, and pgtype.ValueTranscoder for a Postgres type name. +func (tr *typeResolver) findValue(name string) (uint32, pgtype.ValueTranscoder, bool) { + typ, ok := tr.connInfo.DataTypeForName(name) + if !ok { + return 0, nil, false + } + v := pgtype.NewValue(typ.Value) + return typ.OID, v.(pgtype.ValueTranscoder), true +} + +// setValue sets the value of a ValueTranscoder to a value that should always +// work and panics if it fails. +func (tr *typeResolver) setValue(vt pgtype.ValueTranscoder, val interface{}) pgtype.ValueTranscoder { + if err := vt.Set(val); err != nil { + panic(fmt.Sprintf("set ValueTranscoder %T to %+v: %s", vt, val, err)) + } + return vt +} + +type compositeField struct { + name string // name of the field + typeName string // Postgres type name + defaultVal pgtype.ValueTranscoder // default value to use +} + +func (tr *typeResolver) newCompositeValue(name string, fields ...compositeField) pgtype.ValueTranscoder { + if _, val, ok := tr.findValue(name); ok { + return val + } + fs := make([]pgtype.CompositeTypeField, len(fields)) + vals := make([]pgtype.ValueTranscoder, len(fields)) + isBinaryOk := true + for i, field := range fields { + oid, val, ok := tr.findValue(field.typeName) + if !ok { + oid = unknownOID + val = field.defaultVal + } + isBinaryOk = isBinaryOk && oid != unknownOID + fs[i] = pgtype.CompositeTypeField{Name: field.name, OID: oid} + vals[i] = val } // Okay to ignore error because it's only thrown when the number of field // names does not equal the number of ValueTranscoders. - rowType, _ := pgtype.NewCompositeTypeValues(name, fields, vals) - return rowType + typ, _ := pgtype.NewCompositeTypeValues(name, fs, vals) + if !isBinaryOk { + return textPreferrer{typ, name} + } + return typ +} + +func (tr *typeResolver) newArrayValue(name, elemName string, defaultVal func() pgtype.ValueTranscoder) pgtype.ValueTranscoder { + if _, val, ok := tr.findValue(name); ok { + return val + } + elemOID, elemVal, ok := tr.findValue(elemName) + elemValFunc := func() pgtype.ValueTranscoder { + return pgtype.NewValue(elemVal).(pgtype.ValueTranscoder) + } + if !ok { + elemOID = unknownOID + elemValFunc = defaultVal + } + typ := pgtype.NewArrayType(name, elemOID, elemValFunc) + if elemOID == unknownOID { + return textPreferrer{typ, name} + } + return typ } // newSomeTableEnum creates a new pgtype.ValueTranscoder for the Postgres // composite type 'some_table_enum'. -func newSomeTableEnum() pgtype.ValueTranscoder { - return newCompositeType( +func (tr *typeResolver) newSomeTableEnum() pgtype.ValueTranscoder { + return tr.newCompositeValue( "some_table_enum", - []string{"foo"}, - newDeviceTypeEnum(), + compositeField{"foo", "some_table_enum", newDeviceTypeEnum()}, ) } \ No newline at end of file diff --git a/internal/codegen/golang/testdata/declarer_composite_nested.input.golden b/internal/codegen/golang/testdata/declarer_composite_nested.input.golden index 5d7e84e..de5eef7 100644 --- a/internal/codegen/golang/testdata/declarer_composite_nested.input.golden +++ b/internal/codegen/golang/testdata/declarer_composite_nested.input.golden @@ -9,56 +9,104 @@ type SomeTableNested struct { BarBaz pgtype.Text `json:"bar_baz"` } -// ignoredOID means we don't know or care about the OID for a type. This is okay -// because pgx only uses the OID to encode values and lookup a decoder. We only -// use ignoredOID for decoding and we always specify a concrete decoder for scan -// methods. -const ignoredOID = 0 +// typeResolver looks up the pgtype.ValueTranscoder by Postgres type name. +type typeResolver struct { + connInfo *pgtype.ConnInfo // types by Postgres type name +} + +func newTypeResolver(types []pgtype.DataType) *typeResolver { + ci := pgtype.NewConnInfo() + for _, typ := range types { + if txt, ok := typ.Value.(textPreferrer); ok && typ.OID != unknownOID { + typ.Value = txt.ValueTranscoder + } + ci.RegisterDataType(typ) + } + return &typeResolver{connInfo: ci} +} + +// findValue find the OID, and pgtype.ValueTranscoder for a Postgres type name. +func (tr *typeResolver) findValue(name string) (uint32, pgtype.ValueTranscoder, bool) { + typ, ok := tr.connInfo.DataTypeForName(name) + if !ok { + return 0, nil, false + } + v := pgtype.NewValue(typ.Value) + return typ.OID, v.(pgtype.ValueTranscoder), true +} // setValue sets the value of a ValueTranscoder to a value that should always // work and panics if it fails. -func setValue(vt pgtype.ValueTranscoder, val interface{}) pgtype.ValueTranscoder { +func (tr *typeResolver) setValue(vt pgtype.ValueTranscoder, val interface{}) pgtype.ValueTranscoder { if err := vt.Set(val); err != nil { panic(fmt.Sprintf("set ValueTranscoder %T to %+v: %s", vt, val, err)) } return vt } -// textEncoder wraps a pgtype.ValueTranscoder and sets the preferred encoding -// format to text instead binary (the default). pggen must use the text format -// because the Postgres binary format requires the type OID but pggen doesn't -// necessarily know the OIDs of the types, hence ignoredOID. -type textEncoder struct { - pgtype.ValueTranscoder +type compositeField struct { + name string // name of the field + typeName string // Postgres type name + defaultVal pgtype.ValueTranscoder // default value to use } -// PreferredParamFormat implements pgtype.ParamFormatPreferrer. -func (t textEncoder) PreferredParamFormat() int16 { return pgtype.TextFormatCode } - -func newCompositeType(name string, fieldNames []string, vals ...pgtype.ValueTranscoder) *pgtype.CompositeType { - fields := make([]pgtype.CompositeTypeField, len(fieldNames)) - for i, name := range fieldNames { - fields[i] = pgtype.CompositeTypeField{Name: name, OID: ignoredOID} +func (tr *typeResolver) newCompositeValue(name string, fields ...compositeField) pgtype.ValueTranscoder { + if _, val, ok := tr.findValue(name); ok { + return val + } + fs := make([]pgtype.CompositeTypeField, len(fields)) + vals := make([]pgtype.ValueTranscoder, len(fields)) + isBinaryOk := true + for i, field := range fields { + oid, val, ok := tr.findValue(field.typeName) + if !ok { + oid = unknownOID + val = field.defaultVal + } + isBinaryOk = isBinaryOk && oid != unknownOID + fs[i] = pgtype.CompositeTypeField{Name: field.name, OID: oid} + vals[i] = val } // Okay to ignore error because it's only thrown when the number of field // names does not equal the number of ValueTranscoders. - rowType, _ := pgtype.NewCompositeTypeValues(name, fields, vals) - return rowType + typ, _ := pgtype.NewCompositeTypeValues(name, fs, vals) + if !isBinaryOk { + return textPreferrer{typ, name} + } + return typ +} + +func (tr *typeResolver) newArrayValue(name, elemName string, defaultVal func() pgtype.ValueTranscoder) pgtype.ValueTranscoder { + if _, val, ok := tr.findValue(name); ok { + return val + } + elemOID, elemVal, ok := tr.findValue(elemName) + elemValFunc := func() pgtype.ValueTranscoder { + return pgtype.NewValue(elemVal).(pgtype.ValueTranscoder) + } + if !ok { + elemOID = unknownOID + elemValFunc = defaultVal + } + typ := pgtype.NewArrayType(name, elemOID, elemValFunc) + if elemOID == unknownOID { + return textPreferrer{typ, name} + } + return typ } // newFooType creates a new pgtype.ValueTranscoder for the Postgres // composite type 'foo_type'. -func newFooType() pgtype.ValueTranscoder { - return newCompositeType( +func (tr *typeResolver) newFooType() pgtype.ValueTranscoder { + return tr.newCompositeValue( "foo_type", - []string{"alpha"}, - &pgtype.Text{}, + compositeField{"alpha", "text", &pgtype.Text{}}, ) } // newFooTypeRaw returns all composite fields for the Postgres composite // type 'foo_type' as a slice of interface{} to encode query parameters. -func newFooTypeRaw(v FooType) []interface{} { +func (tr *typeResolver) newFooTypeRaw(v FooType) []interface{} { return []interface{}{ v.Alpha, } @@ -66,26 +114,25 @@ func newFooTypeRaw(v FooType) []interface{} { // newSomeTableNested creates a new pgtype.ValueTranscoder for the Postgres // composite type 'some_table_nested'. -func newSomeTableNested() pgtype.ValueTranscoder { - return newCompositeType( +func (tr *typeResolver) newSomeTableNested() pgtype.ValueTranscoder { + return tr.newCompositeValue( "some_table_nested", - []string{"foo", "bar_baz"}, - newFooType(), - &pgtype.Text{}, + compositeField{"foo", "foo_type", tr.newFooType()}, + compositeField{"bar_baz", "text", &pgtype.Text{}}, ) } // newSomeTableNestedInit creates an initialized pgtype.ValueTranscoder for the // Postgres composite type 'some_table_nested' to encode query parameters. -func newSomeTableNestedInit(v SomeTableNested) pgtype.ValueTranscoder { - return textEncoder{setValue(newSomeTableNested(), newSomeTableNestedRaw(v))} +func (tr *typeResolver) newSomeTableNestedInit(v SomeTableNested) pgtype.ValueTranscoder { + return textPreferrer{tr.setValue(tr.newSomeTableNested(), tr.newSomeTableNestedRaw(v)), "some_table_nested"} } // newSomeTableNestedRaw returns all composite fields for the Postgres composite // type 'some_table_nested' as a slice of interface{} to encode query parameters. -func newSomeTableNestedRaw(v SomeTableNested) []interface{} { +func (tr *typeResolver) newSomeTableNestedRaw(v SomeTableNested) []interface{} { return []interface{}{ - newFooTypeRaw(v.Foo), + tr.newFooTypeRaw(v.Foo), v.BarBaz, } } \ No newline at end of file diff --git a/internal/codegen/golang/testdata/declarer_composite_nested.output.golden b/internal/codegen/golang/testdata/declarer_composite_nested.output.golden index 1ab3ace..cf70e8e 100644 --- a/internal/codegen/golang/testdata/declarer_composite_nested.output.golden +++ b/internal/codegen/golang/testdata/declarer_composite_nested.output.golden @@ -9,40 +9,107 @@ type SomeTableNested struct { BarBaz pgtype.Text `json:"bar_baz"` } -// ignoredOID means we don't know or care about the OID for a type. This is okay -// because pgx only uses the OID to encode values and lookup a decoder. We only -// use ignoredOID for decoding and we always specify a concrete decoder for scan -// methods. -const ignoredOID = 0 +// typeResolver looks up the pgtype.ValueTranscoder by Postgres type name. +type typeResolver struct { + connInfo *pgtype.ConnInfo // types by Postgres type name +} + +func newTypeResolver(types []pgtype.DataType) *typeResolver { + ci := pgtype.NewConnInfo() + for _, typ := range types { + if txt, ok := typ.Value.(textPreferrer); ok && typ.OID != unknownOID { + typ.Value = txt.ValueTranscoder + } + ci.RegisterDataType(typ) + } + return &typeResolver{connInfo: ci} +} + +// findValue find the OID, and pgtype.ValueTranscoder for a Postgres type name. +func (tr *typeResolver) findValue(name string) (uint32, pgtype.ValueTranscoder, bool) { + typ, ok := tr.connInfo.DataTypeForName(name) + if !ok { + return 0, nil, false + } + v := pgtype.NewValue(typ.Value) + return typ.OID, v.(pgtype.ValueTranscoder), true +} + +// setValue sets the value of a ValueTranscoder to a value that should always +// work and panics if it fails. +func (tr *typeResolver) setValue(vt pgtype.ValueTranscoder, val interface{}) pgtype.ValueTranscoder { + if err := vt.Set(val); err != nil { + panic(fmt.Sprintf("set ValueTranscoder %T to %+v: %s", vt, val, err)) + } + return vt +} -func newCompositeType(name string, fieldNames []string, vals ...pgtype.ValueTranscoder) *pgtype.CompositeType { - fields := make([]pgtype.CompositeTypeField, len(fieldNames)) - for i, name := range fieldNames { - fields[i] = pgtype.CompositeTypeField{Name: name, OID: ignoredOID} +type compositeField struct { + name string // name of the field + typeName string // Postgres type name + defaultVal pgtype.ValueTranscoder // default value to use +} + +func (tr *typeResolver) newCompositeValue(name string, fields ...compositeField) pgtype.ValueTranscoder { + if _, val, ok := tr.findValue(name); ok { + return val + } + fs := make([]pgtype.CompositeTypeField, len(fields)) + vals := make([]pgtype.ValueTranscoder, len(fields)) + isBinaryOk := true + for i, field := range fields { + oid, val, ok := tr.findValue(field.typeName) + if !ok { + oid = unknownOID + val = field.defaultVal + } + isBinaryOk = isBinaryOk && oid != unknownOID + fs[i] = pgtype.CompositeTypeField{Name: field.name, OID: oid} + vals[i] = val } // Okay to ignore error because it's only thrown when the number of field // names does not equal the number of ValueTranscoders. - rowType, _ := pgtype.NewCompositeTypeValues(name, fields, vals) - return rowType + typ, _ := pgtype.NewCompositeTypeValues(name, fs, vals) + if !isBinaryOk { + return textPreferrer{typ, name} + } + return typ +} + +func (tr *typeResolver) newArrayValue(name, elemName string, defaultVal func() pgtype.ValueTranscoder) pgtype.ValueTranscoder { + if _, val, ok := tr.findValue(name); ok { + return val + } + elemOID, elemVal, ok := tr.findValue(elemName) + elemValFunc := func() pgtype.ValueTranscoder { + return pgtype.NewValue(elemVal).(pgtype.ValueTranscoder) + } + if !ok { + elemOID = unknownOID + elemValFunc = defaultVal + } + typ := pgtype.NewArrayType(name, elemOID, elemValFunc) + if elemOID == unknownOID { + return textPreferrer{typ, name} + } + return typ } // newFooType creates a new pgtype.ValueTranscoder for the Postgres // composite type 'foo_type'. -func newFooType() pgtype.ValueTranscoder { - return newCompositeType( +func (tr *typeResolver) newFooType() pgtype.ValueTranscoder { + return tr.newCompositeValue( "foo_type", - []string{"alpha"}, - &pgtype.Text{}, + compositeField{"alpha", "text", &pgtype.Text{}}, ) } // newSomeTableNested creates a new pgtype.ValueTranscoder for the Postgres // composite type 'some_table_nested'. -func newSomeTableNested() pgtype.ValueTranscoder { - return newCompositeType( +func (tr *typeResolver) newSomeTableNested() pgtype.ValueTranscoder { + return tr.newCompositeValue( "some_table_nested", - []string{"foo", "bar_baz"}, - newFooType(), - &pgtype.Text{}, + compositeField{"foo", "foo_type", tr.newFooType()}, + compositeField{"bar_baz", "text", &pgtype.Text{}}, ) } \ No newline at end of file diff --git a/internal/codegen/golang/testdata/declarer_enum_escaping.input.golden b/internal/codegen/golang/testdata/declarer_enum_escaping.input.golden index 9db398c..b685d5a 100644 --- a/internal/codegen/golang/testdata/declarer_enum_escaping.input.golden +++ b/internal/codegen/golang/testdata/declarer_enum_escaping.input.golden @@ -6,4 +6,39 @@ const ( QuotingUnnamedLabel1 Quoting = "`\"`" ) -func (q Quoting) String() string { return string(q) } \ No newline at end of file +func (q Quoting) String() string { return string(q) } + +// typeResolver looks up the pgtype.ValueTranscoder by Postgres type name. +type typeResolver struct { + connInfo *pgtype.ConnInfo // types by Postgres type name +} + +func newTypeResolver(types []pgtype.DataType) *typeResolver { + ci := pgtype.NewConnInfo() + for _, typ := range types { + if txt, ok := typ.Value.(textPreferrer); ok && typ.OID != unknownOID { + typ.Value = txt.ValueTranscoder + } + ci.RegisterDataType(typ) + } + return &typeResolver{connInfo: ci} +} + +// findValue find the OID, and pgtype.ValueTranscoder for a Postgres type name. +func (tr *typeResolver) findValue(name string) (uint32, pgtype.ValueTranscoder, bool) { + typ, ok := tr.connInfo.DataTypeForName(name) + if !ok { + return 0, nil, false + } + v := pgtype.NewValue(typ.Value) + return typ.OID, v.(pgtype.ValueTranscoder), true +} + +// setValue sets the value of a ValueTranscoder to a value that should always +// work and panics if it fails. +func (tr *typeResolver) setValue(vt pgtype.ValueTranscoder, val interface{}) pgtype.ValueTranscoder { + if err := vt.Set(val); err != nil { + panic(fmt.Sprintf("set ValueTranscoder %T to %+v: %s", vt, val, err)) + } + return vt +} \ No newline at end of file diff --git a/internal/codegen/golang/testdata/declarer_enum_escaping.output.golden b/internal/codegen/golang/testdata/declarer_enum_escaping.output.golden index 9db398c..b685d5a 100644 --- a/internal/codegen/golang/testdata/declarer_enum_escaping.output.golden +++ b/internal/codegen/golang/testdata/declarer_enum_escaping.output.golden @@ -6,4 +6,39 @@ const ( QuotingUnnamedLabel1 Quoting = "`\"`" ) -func (q Quoting) String() string { return string(q) } \ No newline at end of file +func (q Quoting) String() string { return string(q) } + +// typeResolver looks up the pgtype.ValueTranscoder by Postgres type name. +type typeResolver struct { + connInfo *pgtype.ConnInfo // types by Postgres type name +} + +func newTypeResolver(types []pgtype.DataType) *typeResolver { + ci := pgtype.NewConnInfo() + for _, typ := range types { + if txt, ok := typ.Value.(textPreferrer); ok && typ.OID != unknownOID { + typ.Value = txt.ValueTranscoder + } + ci.RegisterDataType(typ) + } + return &typeResolver{connInfo: ci} +} + +// findValue find the OID, and pgtype.ValueTranscoder for a Postgres type name. +func (tr *typeResolver) findValue(name string) (uint32, pgtype.ValueTranscoder, bool) { + typ, ok := tr.connInfo.DataTypeForName(name) + if !ok { + return 0, nil, false + } + v := pgtype.NewValue(typ.Value) + return typ.OID, v.(pgtype.ValueTranscoder), true +} + +// setValue sets the value of a ValueTranscoder to a value that should always +// work and panics if it fails. +func (tr *typeResolver) setValue(vt pgtype.ValueTranscoder, val interface{}) pgtype.ValueTranscoder { + if err := vt.Set(val); err != nil { + panic(fmt.Sprintf("set ValueTranscoder %T to %+v: %s", vt, val, err)) + } + return vt +} \ No newline at end of file diff --git a/internal/codegen/golang/testdata/declarer_enum_simple.input.golden b/internal/codegen/golang/testdata/declarer_enum_simple.input.golden index 1356d1c..0e196e5 100644 --- a/internal/codegen/golang/testdata/declarer_enum_simple.input.golden +++ b/internal/codegen/golang/testdata/declarer_enum_simple.input.golden @@ -6,4 +6,39 @@ const ( DeviceTypeMobile DeviceType = "mobile" ) -func (d DeviceType) String() string { return string(d) } \ No newline at end of file +func (d DeviceType) String() string { return string(d) } + +// typeResolver looks up the pgtype.ValueTranscoder by Postgres type name. +type typeResolver struct { + connInfo *pgtype.ConnInfo // types by Postgres type name +} + +func newTypeResolver(types []pgtype.DataType) *typeResolver { + ci := pgtype.NewConnInfo() + for _, typ := range types { + if txt, ok := typ.Value.(textPreferrer); ok && typ.OID != unknownOID { + typ.Value = txt.ValueTranscoder + } + ci.RegisterDataType(typ) + } + return &typeResolver{connInfo: ci} +} + +// findValue find the OID, and pgtype.ValueTranscoder for a Postgres type name. +func (tr *typeResolver) findValue(name string) (uint32, pgtype.ValueTranscoder, bool) { + typ, ok := tr.connInfo.DataTypeForName(name) + if !ok { + return 0, nil, false + } + v := pgtype.NewValue(typ.Value) + return typ.OID, v.(pgtype.ValueTranscoder), true +} + +// setValue sets the value of a ValueTranscoder to a value that should always +// work and panics if it fails. +func (tr *typeResolver) setValue(vt pgtype.ValueTranscoder, val interface{}) pgtype.ValueTranscoder { + if err := vt.Set(val); err != nil { + panic(fmt.Sprintf("set ValueTranscoder %T to %+v: %s", vt, val, err)) + } + return vt +} \ No newline at end of file diff --git a/internal/codegen/golang/testdata/declarer_enum_simple.output.golden b/internal/codegen/golang/testdata/declarer_enum_simple.output.golden index 1356d1c..0e196e5 100644 --- a/internal/codegen/golang/testdata/declarer_enum_simple.output.golden +++ b/internal/codegen/golang/testdata/declarer_enum_simple.output.golden @@ -6,4 +6,39 @@ const ( DeviceTypeMobile DeviceType = "mobile" ) -func (d DeviceType) String() string { return string(d) } \ No newline at end of file +func (d DeviceType) String() string { return string(d) } + +// typeResolver looks up the pgtype.ValueTranscoder by Postgres type name. +type typeResolver struct { + connInfo *pgtype.ConnInfo // types by Postgres type name +} + +func newTypeResolver(types []pgtype.DataType) *typeResolver { + ci := pgtype.NewConnInfo() + for _, typ := range types { + if txt, ok := typ.Value.(textPreferrer); ok && typ.OID != unknownOID { + typ.Value = txt.ValueTranscoder + } + ci.RegisterDataType(typ) + } + return &typeResolver{connInfo: ci} +} + +// findValue find the OID, and pgtype.ValueTranscoder for a Postgres type name. +func (tr *typeResolver) findValue(name string) (uint32, pgtype.ValueTranscoder, bool) { + typ, ok := tr.connInfo.DataTypeForName(name) + if !ok { + return 0, nil, false + } + v := pgtype.NewValue(typ.Value) + return typ.OID, v.(pgtype.ValueTranscoder), true +} + +// setValue sets the value of a ValueTranscoder to a value that should always +// work and panics if it fails. +func (tr *typeResolver) setValue(vt pgtype.ValueTranscoder, val interface{}) pgtype.ValueTranscoder { + if err := vt.Set(val); err != nil { + panic(fmt.Sprintf("set ValueTranscoder %T to %+v: %s", vt, val, err)) + } + return vt +} \ No newline at end of file diff --git a/internal/pg/query.sql.go b/internal/pg/query.sql.go index 1684b34..fe41e28 100644 --- a/internal/pg/query.sql.go +++ b/internal/pg/query.sql.go @@ -72,7 +72,8 @@ type Querier interface { } type DBQuerier struct { - conn genericConn + conn genericConn // underlying Postgres transport to use + types *typeResolver // resolve types by name } var _ Querier = &DBQuerier{} @@ -99,9 +100,23 @@ type genericConn interface { // NewQuerier creates a DBQuerier that implements Querier. conn is typically // *pgx.Conn, pgx.Tx, or *pgxpool.Pool. func NewQuerier(conn genericConn) *DBQuerier { - return &DBQuerier{ - conn: conn, - } + return NewQuerierConfig(conn, QuerierConfig{}) +} + +type QuerierConfig struct { + // DataTypes contains pgtype.Value to use for encoding and decoding instead of + // pggen-generated pgtype.ValueTranscoder. + // + // If OIDs are available for an input parameter type and all of its transative + // dependencies, pggen will use the binary encoding format for the input + // parameter. + DataTypes []pgtype.DataType +} + +// NewQuerierConfig creates a DBQuerier that implements Querier with the given +// config. conn is typically *pgx.Conn, pgx.Tx, or *pgxpool.Pool. +func NewQuerierConfig(conn genericConn, cfg QuerierConfig) *DBQuerier { + return &DBQuerier{conn: conn, types: newTypeResolver(cfg.DataTypes)} } // WithTx creates a new DBQuerier that uses the transaction to run all queries. @@ -147,6 +162,41 @@ func PrepareAllQueries(ctx context.Context, p preparer) error { return nil } +// typeResolver looks up the pgtype.ValueTranscoder by Postgres type name. +type typeResolver struct { + connInfo *pgtype.ConnInfo // types by Postgres type name +} + +func newTypeResolver(types []pgtype.DataType) *typeResolver { + ci := pgtype.NewConnInfo() + for _, typ := range types { + if txt, ok := typ.Value.(textPreferrer); ok && typ.OID != unknownOID { + typ.Value = txt.ValueTranscoder + } + ci.RegisterDataType(typ) + } + return &typeResolver{connInfo: ci} +} + +// findValue find the OID, and pgtype.ValueTranscoder for a Postgres type name. +func (tr *typeResolver) findValue(name string) (uint32, pgtype.ValueTranscoder, bool) { + typ, ok := tr.connInfo.DataTypeForName(name) + if !ok { + return 0, nil, false + } + v := pgtype.NewValue(typ.Value) + return typ.OID, v.(pgtype.ValueTranscoder), true +} + +// setValue sets the value of a ValueTranscoder to a value that should always +// work and panics if it fails. +func (tr *typeResolver) setValue(vt pgtype.ValueTranscoder, val interface{}) pgtype.ValueTranscoder { + if err := vt.Set(val); err != nil { + panic(fmt.Sprintf("set ValueTranscoder %T to %+v: %s", vt, val, err)) + } + return vt +} + const findEnumTypesSQL = `WITH enums AS ( SELECT enumtypid::int8 AS enum_type, @@ -610,3 +660,29 @@ func (q *DBQuerier) FindOIDNamesScan(results pgx.BatchResults) ([]FindOIDNamesRo } return items, err } + +// textPreferrer wraps a pgtype.ValueTranscoder and sets the preferred encoding +// format to text instead binary (the default). pggen uses the text format +// when the OID is unknownOID because the binary format requires the OID. +// Typically occurs if the results from QueryAllDataTypes aren't passed to +// NewQuerierConfig. +type textPreferrer struct { + pgtype.ValueTranscoder + typeName string +} + +// PreferredParamFormat implements pgtype.ParamFormatPreferrer. +func (t textPreferrer) PreferredParamFormat() int16 { return pgtype.TextFormatCode } + +func (t textPreferrer) NewTypeValue() pgtype.Value { + return textPreferrer{pgtype.NewValue(t.ValueTranscoder).(pgtype.ValueTranscoder), t.typeName} +} + +func (t textPreferrer) TypeName() string { + return t.typeName +} + +// unknownOID means we don't know the OID for a type. This is okay for decoding +// because pgx call DecodeText or DecodeBinary without requiring the OID. For +// encoding parameters, pggen uses textPreferrer if the OID is unknown. +const unknownOID = 0