diff --git a/example/acceptance_test.go b/example/acceptance_test.go index 894166b..d940f5f 100644 --- a/example/acceptance_test.go +++ b/example/acceptance_test.go @@ -175,6 +175,14 @@ func TestExamples(t *testing.T) { "--go-type", "text=string", }, }, + { + name: "example/numeric_external", + args: []string{ + "--schema-glob", "example/numeric_external/schema.sql", + "--query-glob", "example/numeric_external/query.sql", + "--go-type", "numeric=github.com/shopspring/decimal.Decimal", + }, + }, { name: "example/domain", args: []string{ diff --git a/example/numeric_external/codegen_test.go b/example/numeric_external/codegen_test.go new file mode 100644 index 0000000..d102c19 --- /dev/null +++ b/example/numeric_external/codegen_test.go @@ -0,0 +1,50 @@ +package numeric_external + +import ( + "github.com/jschaf/pggen" + "github.com/jschaf/pggen/internal/pgtest" + "github.com/stretchr/testify/assert" + "io/ioutil" + "path/filepath" + "testing" +) + +func TestGenerate_Go_Example_Numeric_External(t *testing.T) { + conn, cleanupFunc := pgtest.NewPostgresSchema(t, []string{"schema.sql"}) + defer cleanupFunc() + + tmpDir := t.TempDir() + err := pggen.Generate( + pggen.GenerateOptions{ + ConnString: conn.Config().ConnString(), + QueryFiles: []string{"query.sql"}, + OutputDir: tmpDir, + GoPackage: "numeric_external", + Language: pggen.LangGo, + TypeOverrides: map[string]string{ + "int4": "int", + "int8": "int", + "text": "string", + "numeric": "github.com/shopspring/decimal.Decimal", + }, + }) + if err != nil { + t.Fatalf("Generate(): %s", err) + } + + wantQueriesFile := "query.sql.go" + gotQueriesFile := filepath.Join(tmpDir, "query.sql.go") + assert.FileExists(t, gotQueriesFile, + "Generate() should emit query.sql.go") + wantQueries, err := ioutil.ReadFile(wantQueriesFile) + if err != nil { + t.Fatalf("read wanted query.go.sql: %s", err) + } + gotQueries, err := ioutil.ReadFile(gotQueriesFile) + if err != nil { + t.Fatalf("read generated query.go.sql: %s", err) + } + assert.Equalf(t, string(wantQueries), string(gotQueries), + "Got file %s; does not match contents of %s", + gotQueriesFile, wantQueriesFile) +} diff --git a/example/numeric_external/query.sql b/example/numeric_external/query.sql new file mode 100644 index 0000000..e1951f3 --- /dev/null +++ b/example/numeric_external/query.sql @@ -0,0 +1,7 @@ +-- name: InsertNumeric :exec +INSERT INTO numeric_external (num, num_arr) +VALUES (pggen.arg('num'), pggen.arg('num_arr')); + +-- name: FindNumerics :many +SELECT num, num_arr +FROM numeric_external; diff --git a/example/numeric_external/query.sql.go b/example/numeric_external/query.sql.go new file mode 100644 index 0000000..9f5a86e --- /dev/null +++ b/example/numeric_external/query.sql.go @@ -0,0 +1,368 @@ +// Code generated by pggen. DO NOT EDIT. + +package numeric_external + +import ( + "context" + "fmt" + "github.com/jackc/pgconn" + "github.com/jackc/pgtype" + "github.com/jackc/pgx/v4" + "github.com/shopspring/decimal" +) + +// Querier is a typesafe Go interface backed by SQL queries. +// +// Methods ending with Batch enqueue a query to run later in a pgx.Batch. After +// calling SendBatch on pgx.Conn, pgxpool.Pool, or pgx.Tx, use the Scan methods +// to parse the results. +type Querier interface { + InsertNumeric(ctx context.Context, num decimal.Decimal, numArr []NumericExternalType) (pgconn.CommandTag, error) + // InsertNumericBatch enqueues a InsertNumeric query into batch to be executed + // later by the batch. + InsertNumericBatch(batch genericBatch, num decimal.Decimal, numArr []NumericExternalType) + // InsertNumericScan scans the result of an executed InsertNumericBatch query. + InsertNumericScan(results pgx.BatchResults) (pgconn.CommandTag, error) + + FindNumerics(ctx context.Context) ([]FindNumericsRow, error) + // FindNumericsBatch enqueues a FindNumerics query into batch to be executed + // later by the batch. + FindNumericsBatch(batch genericBatch) + // FindNumericsScan scans the result of an executed FindNumericsBatch query. + FindNumericsScan(results pgx.BatchResults) ([]FindNumericsRow, error) +} + +type DBQuerier struct { + conn genericConn // underlying Postgres transport to use + types *typeResolver // resolve types by name +} + +var _ Querier = &DBQuerier{} + +// genericConn is a connection to a Postgres database. This is usually backed by +// *pgx.Conn, pgx.Tx, or *pgxpool.Pool. +type genericConn interface { + // Query executes sql with args. If there is an error the returned Rows will + // be returned in an error state. So it is allowed to ignore the error + // returned from Query and handle it in Rows. + Query(ctx context.Context, sql string, args ...interface{}) (pgx.Rows, error) + + // QueryRow is a convenience wrapper over Query. Any error that occurs while + // querying is deferred until calling Scan on the returned Row. That Row will + // error with pgx.ErrNoRows if no rows are returned. + QueryRow(ctx context.Context, sql string, args ...interface{}) pgx.Row + + // Exec executes sql. sql can be either a prepared statement name or an SQL + // string. arguments should be referenced positionally from the sql string + // as $1, $2, etc. + Exec(ctx context.Context, sql string, arguments ...interface{}) (pgconn.CommandTag, error) +} + +// genericBatch batches queries to send in a single network request to a +// Postgres server. This is usually backed by *pgx.Batch. +type genericBatch interface { + // Queue queues a query to batch b. query can be an SQL query or the name of a + // prepared statement. See Queue on *pgx.Batch. + Queue(query string, arguments ...interface{}) +} + +// NewQuerier creates a DBQuerier that implements Querier. conn is typically +// *pgx.Conn, pgx.Tx, or *pgxpool.Pool. +func NewQuerier(conn genericConn) *DBQuerier { + return NewQuerierConfig(conn, QuerierConfig{}) +} + +type QuerierConfig struct { + // DataTypes contains pgtype.Value to use for encoding and decoding instead + // of pggen-generated pgtype.ValueTranscoder. + // + // If OIDs are available for an input parameter type and all of its + // transitive dependencies, pggen will use the binary encoding format for + // the input parameter. + DataTypes []pgtype.DataType +} + +// NewQuerierConfig creates a DBQuerier that implements Querier with the given +// config. conn is typically *pgx.Conn, pgx.Tx, or *pgxpool.Pool. +func NewQuerierConfig(conn genericConn, cfg QuerierConfig) *DBQuerier { + return &DBQuerier{conn: conn, types: newTypeResolver(cfg.DataTypes)} +} + +// WithTx creates a new DBQuerier that uses the transaction to run all queries. +func (q *DBQuerier) WithTx(tx pgx.Tx) (*DBQuerier, error) { + return &DBQuerier{conn: tx}, nil +} + +// preparer is any Postgres connection transport that provides a way to prepare +// a statement, most commonly *pgx.Conn. +type preparer interface { + Prepare(ctx context.Context, name, sql string) (sd *pgconn.StatementDescription, err error) +} + +// PrepareAllQueries executes a PREPARE statement for all pggen generated SQL +// queries in querier files. Typical usage is as the AfterConnect callback +// for pgxpool.Config +// +// pgx will use the prepared statement if available. Calling PrepareAllQueries +// is an optional optimization to avoid a network round-trip the first time pgx +// runs a query if pgx statement caching is enabled. +func PrepareAllQueries(ctx context.Context, p preparer) error { + if _, err := p.Prepare(ctx, insertNumericSQL, insertNumericSQL); err != nil { + return fmt.Errorf("prepare query 'InsertNumeric': %w", err) + } + if _, err := p.Prepare(ctx, findNumericsSQL, findNumericsSQL); err != nil { + return fmt.Errorf("prepare query 'FindNumerics': %w", err) + } + return nil +} + +// NumericExternalType represents the Postgres composite type "numeric_external_type". +type NumericExternalType struct { + Num decimal.Decimal `json:"num"` +} + +// typeResolver looks up the pgtype.ValueTranscoder by Postgres type name. +type typeResolver struct { + connInfo *pgtype.ConnInfo // types by Postgres type name +} + +func newTypeResolver(types []pgtype.DataType) *typeResolver { + ci := pgtype.NewConnInfo() + for _, typ := range types { + if txt, ok := typ.Value.(textPreferrer); ok && typ.OID != unknownOID { + typ.Value = txt.ValueTranscoder + } + ci.RegisterDataType(typ) + } + return &typeResolver{connInfo: ci} +} + +// findValue find the OID, and pgtype.ValueTranscoder for a Postgres type name. +func (tr *typeResolver) findValue(name string) (uint32, pgtype.ValueTranscoder, bool) { + typ, ok := tr.connInfo.DataTypeForName(name) + if !ok { + return 0, nil, false + } + v := pgtype.NewValue(typ.Value) + return typ.OID, v.(pgtype.ValueTranscoder), true +} + +// setValue sets the value of a ValueTranscoder to a value that should always +// work and panics if it fails. +func (tr *typeResolver) setValue(vt pgtype.ValueTranscoder, val interface{}) pgtype.ValueTranscoder { + if err := vt.Set(val); err != nil { + panic(fmt.Sprintf("set ValueTranscoder %T to %+v: %s", vt, val, err)) + } + return vt +} + +type compositeField struct { + name string // name of the field + typeName string // Postgres type name + defaultVal pgtype.ValueTranscoder // default value to use +} + +func (tr *typeResolver) newCompositeValue(name string, fields ...compositeField) pgtype.ValueTranscoder { + if _, val, ok := tr.findValue(name); ok { + return val + } + fs := make([]pgtype.CompositeTypeField, len(fields)) + vals := make([]pgtype.ValueTranscoder, len(fields)) + isBinaryOk := true + for i, field := range fields { + oid, val, ok := tr.findValue(field.typeName) + if !ok { + oid = unknownOID + val = field.defaultVal + } + isBinaryOk = isBinaryOk && oid != unknownOID + fs[i] = pgtype.CompositeTypeField{Name: field.name, OID: oid} + vals[i] = val + } + // Okay to ignore error because it's only thrown when the number of field + // names does not equal the number of ValueTranscoders. + typ, _ := pgtype.NewCompositeTypeValues(name, fs, vals) + if !isBinaryOk { + return textPreferrer{typ, name} + } + return typ +} + +func (tr *typeResolver) newArrayValue(name, elemName string, defaultVal func() pgtype.ValueTranscoder) pgtype.ValueTranscoder { + if _, val, ok := tr.findValue(name); ok { + return val + } + elemOID, elemVal, ok := tr.findValue(elemName) + elemValFunc := func() pgtype.ValueTranscoder { + return pgtype.NewValue(elemVal).(pgtype.ValueTranscoder) + } + if !ok { + elemOID = unknownOID + elemValFunc = defaultVal + } + typ := pgtype.NewArrayType(name, elemOID, elemValFunc) + if elemOID == unknownOID { + return textPreferrer{typ, name} + } + return typ +} + +// newNumericExternalType creates a new pgtype.ValueTranscoder for the Postgres +// composite type 'numeric_external_type'. +func (tr *typeResolver) newNumericExternalType() pgtype.ValueTranscoder { + return tr.newCompositeValue( + "numeric_external_type", + compositeField{"num", "numeric", &pgtype.Numeric{}}, + ) +} + +// newNumericExternalTypeRaw returns all composite fields for the Postgres composite +// type 'numeric_external_type' as a slice of interface{} to encode query parameters. +func (tr *typeResolver) newNumericExternalTypeRaw(v NumericExternalType) []interface{} { + return []interface{}{ + v.Num, + } +} + +// newNumericExternalTypeArray creates a new pgtype.ValueTranscoder for the Postgres +// '_numeric_external_type' array type. +func (tr *typeResolver) newNumericExternalTypeArray() pgtype.ValueTranscoder { + return tr.newArrayValue("_numeric_external_type", "numeric_external_type", tr.newNumericExternalType) +} + +// newNumericExternalTypeArrayInit creates an initialized pgtype.ValueTranscoder for the +// Postgres array type '_numeric_external_type' to encode query parameters. +func (tr *typeResolver) newNumericExternalTypeArrayInit(ps []NumericExternalType) pgtype.ValueTranscoder { + dec := tr.newNumericExternalTypeArray() + if err := dec.Set(tr.newNumericExternalTypeArrayRaw(ps)); err != nil { + panic("encode []NumericExternalType: " + err.Error()) // should always succeed + } + return textPreferrer{ValueTranscoder: dec, typeName: "_numeric_external_type"} +} + +// newNumericExternalTypeArrayRaw returns all elements for the Postgres array type '_numeric_external_type' +// as a slice of interface{} for use with the pgtype.Value Set method. +func (tr *typeResolver) newNumericExternalTypeArrayRaw(vs []NumericExternalType) []interface{} { + elems := make([]interface{}, len(vs)) + for i, v := range vs { + elems[i] = tr.newNumericExternalTypeRaw(v) + } + return elems +} + +const insertNumericSQL = `INSERT INTO numeric_external (num, num_arr) +VALUES ($1, $2);` + +// InsertNumeric implements Querier.InsertNumeric. +func (q *DBQuerier) InsertNumeric(ctx context.Context, num decimal.Decimal, numArr []NumericExternalType) (pgconn.CommandTag, error) { + ctx = context.WithValue(ctx, "pggen_query_name", "InsertNumeric") + cmdTag, err := q.conn.Exec(ctx, insertNumericSQL, num, q.types.newNumericExternalTypeArrayInit(numArr)) + if err != nil { + return cmdTag, fmt.Errorf("exec query InsertNumeric: %w", err) + } + return cmdTag, err +} + +// InsertNumericBatch implements Querier.InsertNumericBatch. +func (q *DBQuerier) InsertNumericBatch(batch genericBatch, num decimal.Decimal, numArr []NumericExternalType) { + batch.Queue(insertNumericSQL, num, q.types.newNumericExternalTypeArrayInit(numArr)) +} + +// InsertNumericScan implements Querier.InsertNumericScan. +func (q *DBQuerier) InsertNumericScan(results pgx.BatchResults) (pgconn.CommandTag, error) { + cmdTag, err := results.Exec() + if err != nil { + return cmdTag, fmt.Errorf("exec InsertNumericBatch: %w", err) + } + return cmdTag, err +} + +const findNumericsSQL = `SELECT num, num_arr +FROM numeric_external;` + +type FindNumericsRow struct { + Num decimal.Decimal `json:"num"` + NumArr []NumericExternalType `json:"num_arr"` +} + +// FindNumerics implements Querier.FindNumerics. +func (q *DBQuerier) FindNumerics(ctx context.Context) ([]FindNumericsRow, error) { + ctx = context.WithValue(ctx, "pggen_query_name", "FindNumerics") + rows, err := q.conn.Query(ctx, findNumericsSQL) + if err != nil { + return nil, fmt.Errorf("query FindNumerics: %w", err) + } + defer rows.Close() + items := []FindNumericsRow{} + numArrArray := q.types.newNumericExternalTypeArray() + for rows.Next() { + var item FindNumericsRow + if err := rows.Scan(&item.Num, numArrArray); err != nil { + return nil, fmt.Errorf("scan FindNumerics row: %w", err) + } + if err := numArrArray.AssignTo(&item.NumArr); err != nil { + return nil, fmt.Errorf("assign FindNumerics row: %w", err) + } + items = append(items, item) + } + if err := rows.Err(); err != nil { + return nil, fmt.Errorf("close FindNumerics rows: %w", err) + } + return items, err +} + +// FindNumericsBatch implements Querier.FindNumericsBatch. +func (q *DBQuerier) FindNumericsBatch(batch genericBatch) { + batch.Queue(findNumericsSQL) +} + +// FindNumericsScan implements Querier.FindNumericsScan. +func (q *DBQuerier) FindNumericsScan(results pgx.BatchResults) ([]FindNumericsRow, error) { + rows, err := results.Query() + if err != nil { + return nil, fmt.Errorf("query FindNumericsBatch: %w", err) + } + defer rows.Close() + items := []FindNumericsRow{} + numArrArray := q.types.newNumericExternalTypeArray() + for rows.Next() { + var item FindNumericsRow + if err := rows.Scan(&item.Num, numArrArray); err != nil { + return nil, fmt.Errorf("scan FindNumericsBatch row: %w", err) + } + if err := numArrArray.AssignTo(&item.NumArr); err != nil { + return nil, fmt.Errorf("assign FindNumerics row: %w", err) + } + items = append(items, item) + } + if err := rows.Err(); err != nil { + return nil, fmt.Errorf("close FindNumericsBatch rows: %w", err) + } + return items, err +} + +// textPreferrer wraps a pgtype.ValueTranscoder and sets the preferred encoding +// format to text instead binary (the default). pggen uses the text format +// when the OID is unknownOID because the binary format requires the OID. +// Typically occurs if the results from QueryAllDataTypes aren't passed to +// NewQuerierConfig. +type textPreferrer struct { + pgtype.ValueTranscoder + typeName string +} + +// PreferredParamFormat implements pgtype.ParamFormatPreferrer. +func (t textPreferrer) PreferredParamFormat() int16 { return pgtype.TextFormatCode } + +func (t textPreferrer) NewTypeValue() pgtype.Value { + return textPreferrer{pgtype.NewValue(t.ValueTranscoder).(pgtype.ValueTranscoder), t.typeName} +} + +func (t textPreferrer) TypeName() string { + return t.typeName +} + +// unknownOID means we don't know the OID for a type. This is okay for decoding +// because pgx call DecodeText or DecodeBinary without requiring the OID. For +// encoding parameters, pggen uses textPreferrer if the OID is unknown. +const unknownOID = 0 diff --git a/example/numeric_external/query.sql_test.go b/example/numeric_external/query.sql_test.go new file mode 100644 index 0000000..9c72138 --- /dev/null +++ b/example/numeric_external/query.sql_test.go @@ -0,0 +1,59 @@ +package numeric_external + +import ( + "context" + "github.com/google/go-cmp/cmp" + "github.com/jackc/pgtype" + "github.com/jschaf/pggen/internal/pgtest" + "github.com/shopspring/decimal" + "github.com/stretchr/testify/require" + + shopspring "github.com/jackc/pgtype/ext/shopspring-numeric" + "testing" +) + +func TestNewQuerier_FindNumerics(t *testing.T) { + conn, cleanup := pgtest.NewPostgresSchema(t, []string{"schema.sql"}) + defer cleanup() + + conn.ConnInfo() + q := NewQuerierConfig(conn, QuerierConfig{ + DataTypes: []pgtype.DataType{ + { + Value: &shopspring.Numeric{}, + Name: "numeric", + OID: pgtype.NumericOID, + }, + }, + }) + _, err := q.InsertNumeric(context.Background(), decimal.New(10, 0), []NumericExternalType{ + {Num: decimal.New(11, 0)}, + }) + require.NoError(t, err) + _, err = q.InsertNumeric(context.Background(), decimal.New(20, 0), []NumericExternalType{ + {Num: decimal.New(21, 0)}, + {Num: decimal.New(22, 0)}, + }) + require.NoError(t, err) + + t.Run("FindNumerics", func(t *testing.T) { + rows, err := q.FindNumerics(context.Background()) + require.NoError(t, err) + want := []FindNumericsRow{ + { + Num: decimal.New(10, 0), + NumArr: []NumericExternalType{{Num: decimal.New(11, 0)}}, + }, + { + Num: decimal.New(20, 0), + NumArr: []NumericExternalType{ + {Num: decimal.New(21, 0)}, + {Num: decimal.New(22, 0)}, + }, + }, + } + if diff := cmp.Diff(want, rows); diff != "" { + t.Errorf("mismatch (-want +got):\n%s", diff) + } + }) +} diff --git a/example/numeric_external/schema.sql b/example/numeric_external/schema.sql new file mode 100644 index 0000000..0f034d2 --- /dev/null +++ b/example/numeric_external/schema.sql @@ -0,0 +1,8 @@ +CREATE TYPE numeric_external_type AS ( + num numeric(8, 2) +); + +CREATE TABLE numeric_external ( + num numeric(10, 6), + num_arr numeric_external_type[] +); \ No newline at end of file diff --git a/go.mod b/go.mod index 9cd3b20..94fa676 100644 --- a/go.mod +++ b/go.mod @@ -22,6 +22,7 @@ require ( github.com/opencontainers/go-digest v1.0.0 // indirect github.com/opencontainers/image-spec v1.0.1 // indirect github.com/peterbourgon/ff/v3 v3.0.0 + github.com/shopspring/decimal v1.2.0 // indirect github.com/stretchr/testify v1.5.1 go.uber.org/multierr v1.5.0 go.uber.org/zap v1.13.0 diff --git a/go.sum b/go.sum index 6d47d6f..8dd15e6 100644 --- a/go.sum +++ b/go.sum @@ -342,6 +342,8 @@ github.com/sean-/seed v0.0.0-20170313163322-e2103e2c3529/go.mod h1:DxrIzT+xaE7yg github.com/shopspring/decimal v0.0.0-20180709203117-cd690d0c9e24/go.mod h1:M+9NzErvs504Cn4c5DxATwIqPbtswREoFCre64PpcG4= github.com/shopspring/decimal v0.0.0-20200227202807-02e2044944cc h1:jUIKcSPO9MoMJBbEoyE/RJoE8vz7Mb8AjvifMMwSyvY= github.com/shopspring/decimal v0.0.0-20200227202807-02e2044944cc/go.mod h1:DKyhrW/HYNuLGql+MJL6WCR6knT2jwCFRcu2hWCYk4o= +github.com/shopspring/decimal v1.2.0 h1:abSATXmQEYyShuxI4/vyW3tV1MrKAJzCZ/0zLUXYbsQ= +github.com/shopspring/decimal v1.2.0/go.mod h1:DKyhrW/HYNuLGql+MJL6WCR6knT2jwCFRcu2hWCYk4o= github.com/shurcooL/sanitized_anchor_name v1.0.0/go.mod h1:1NzhyTcUVG4SuEtjjoZeVRXNmyL/1OwPU0+IJeTBvfc= github.com/sirupsen/logrus v1.2.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPxbbu5VWo= github.com/sirupsen/logrus v1.4.1/go.mod h1:ni0Sbl8bgC9z8RoU9G6nDWqqs/fq4eDPysMBDgk/93Q= diff --git a/internal/codegen/common.go b/internal/codegen/common.go index 58b9612..9a15413 100644 --- a/internal/codegen/common.go +++ b/internal/codegen/common.go @@ -6,7 +6,7 @@ import ( "github.com/jschaf/pggen/internal/pginfer" ) -// QueryFile represents all of the SQL queries from a single file. +// QueryFile represents all SQL queries from a single file. type QueryFile struct { SourcePath string // absolute path to the source SQL query file Queries []pginfer.TypedQuery // the typed queries