diff --git a/example/acceptance_test.go b/example/acceptance_test.go index 74212c6..47b9e51 100644 --- a/example/acceptance_test.go +++ b/example/acceptance_test.go @@ -119,6 +119,14 @@ func TestExamples(t *testing.T) { "--go-type", "tenant_id=int", }, }, + { + name: "example/function", + args: []string{ + "--schema-glob", "example/function/schema.sql", + "--query-glob", "example/function/query.sql", + "--go-type", "hstore=map[string]string", + }, + }, { name: "example/go_pointer_types", args: []string{ diff --git a/example/composite/query.sql_test.go b/example/composite/query.sql_test.go index 1580b87..b7c79d0 100644 --- a/example/composite/query.sql_test.go +++ b/example/composite/query.sql_test.go @@ -96,9 +96,9 @@ func TestNewQuerier_ArraysInput(t *testing.T) { t.Run("ArraysInput", func(t *testing.T) { want := Arrays{ Texts: []string{"foo", "bar"}, - Int8s: []*int{ptrs.NewInt(1), ptrs.NewInt(2), ptrs.NewInt(3)}, + Int8s: []*int{ptrs.Int(1), ptrs.Int(2), ptrs.Int(3)}, Bools: []bool{true, true, false}, - Floats: []*float64{ptrs.NewFloat64(33.3), ptrs.NewFloat64(66.6)}, + Floats: []*float64{ptrs.Float64(33.3), ptrs.Float64(66.6)}, } got, err := q.ArraysInput(context.Background(), want) require.NoError(t, err) diff --git a/example/function/codegen_test.go b/example/function/codegen_test.go new file mode 100644 index 0000000..feb67c0 --- /dev/null +++ b/example/function/codegen_test.go @@ -0,0 +1,44 @@ +package function + +import ( + "github.com/jschaf/pggen" + "github.com/jschaf/pggen/internal/pgtest" + "github.com/stretchr/testify/assert" + "io/ioutil" + "path/filepath" + "testing" +) + +func TestGenerate_Go_Example_Function(t *testing.T) { + conn, cleanupFunc := pgtest.NewPostgresSchema(t, []string{"schema.sql"}) + defer cleanupFunc() + + tmpDir := t.TempDir() + err := pggen.Generate( + pggen.GenerateOptions{ + ConnString: conn.Config().ConnString(), + QueryFiles: []string{"query.sql"}, + OutputDir: tmpDir, + GoPackage: "function", + Language: pggen.LangGo, + }) + if err != nil { + t.Fatalf("Generate() example/function: %s", err) + } + + wantQueriesFile := "query.sql.go" + gotQueriesFile := filepath.Join(tmpDir, "query.sql.go") + assert.FileExists(t, gotQueriesFile, + "Generate() should emit query.sql.go") + wantQueries, err := ioutil.ReadFile(wantQueriesFile) + if err != nil { + t.Fatalf("read wanted query.go.sql: %s", err) + } + gotQueries, err := ioutil.ReadFile(gotQueriesFile) + if err != nil { + t.Fatalf("read generated query.go.sql: %s", err) + } + assert.Equalf(t, string(wantQueries), string(gotQueries), + "Got file %s; does not match contents of %s", + gotQueriesFile, wantQueriesFile) +} diff --git a/example/function/query.sql b/example/function/query.sql new file mode 100644 index 0000000..47a35ad --- /dev/null +++ b/example/function/query.sql @@ -0,0 +1,2 @@ +-- name: OutParams :many +SELECT * FROM out_params(); diff --git a/example/function/query.sql.go b/example/function/query.sql.go new file mode 100644 index 0000000..e04dcf2 --- /dev/null +++ b/example/function/query.sql.go @@ -0,0 +1,327 @@ +// Code generated by pggen. DO NOT EDIT. + +package function + +import ( + "context" + "fmt" + "github.com/jackc/pgconn" + "github.com/jackc/pgtype" + "github.com/jackc/pgx/v4" +) + +// Querier is a typesafe Go interface backed by SQL queries. +// +// Methods ending with Batch enqueue a query to run later in a pgx.Batch. After +// calling SendBatch on pgx.Conn, pgxpool.Pool, or pgx.Tx, use the Scan methods +// to parse the results. +type Querier interface { + OutParams(ctx context.Context) ([]OutParamsRow, error) + // OutParamsBatch enqueues a OutParams query into batch to be executed + // later by the batch. + OutParamsBatch(batch genericBatch) + // OutParamsScan scans the result of an executed OutParamsBatch query. + OutParamsScan(results pgx.BatchResults) ([]OutParamsRow, error) +} + +type DBQuerier struct { + conn genericConn // underlying Postgres transport to use + types *typeResolver // resolve types by name +} + +var _ Querier = &DBQuerier{} + +// genericConn is a connection to a Postgres database. This is usually backed by +// *pgx.Conn, pgx.Tx, or *pgxpool.Pool. +type genericConn interface { + // Query executes sql with args. If there is an error the returned Rows will + // be returned in an error state. So it is allowed to ignore the error + // returned from Query and handle it in Rows. + Query(ctx context.Context, sql string, args ...interface{}) (pgx.Rows, error) + + // QueryRow is a convenience wrapper over Query. Any error that occurs while + // querying is deferred until calling Scan on the returned Row. That Row will + // error with pgx.ErrNoRows if no rows are returned. + QueryRow(ctx context.Context, sql string, args ...interface{}) pgx.Row + + // Exec executes sql. sql can be either a prepared statement name or an SQL + // string. arguments should be referenced positionally from the sql string + // as $1, $2, etc. + Exec(ctx context.Context, sql string, arguments ...interface{}) (pgconn.CommandTag, error) +} + +// genericBatch batches queries to send in a single network request to a +// Postgres server. This is usually backed by *pgx.Batch. +type genericBatch interface { + // Queue queues a query to batch b. query can be an SQL query or the name of a + // prepared statement. See Queue on *pgx.Batch. + Queue(query string, arguments ...interface{}) +} + +// NewQuerier creates a DBQuerier that implements Querier. conn is typically +// *pgx.Conn, pgx.Tx, or *pgxpool.Pool. +func NewQuerier(conn genericConn) *DBQuerier { + return NewQuerierConfig(conn, QuerierConfig{}) +} + +type QuerierConfig struct { + // DataTypes contains pgtype.Value to use for encoding and decoding instead + // of pggen-generated pgtype.ValueTranscoder. + // + // If OIDs are available for an input parameter type and all of its + // transitive dependencies, pggen will use the binary encoding format for + // the input parameter. + DataTypes []pgtype.DataType +} + +// NewQuerierConfig creates a DBQuerier that implements Querier with the given +// config. conn is typically *pgx.Conn, pgx.Tx, or *pgxpool.Pool. +func NewQuerierConfig(conn genericConn, cfg QuerierConfig) *DBQuerier { + return &DBQuerier{conn: conn, types: newTypeResolver(cfg.DataTypes)} +} + +// WithTx creates a new DBQuerier that uses the transaction to run all queries. +func (q *DBQuerier) WithTx(tx pgx.Tx) (*DBQuerier, error) { + return &DBQuerier{conn: tx}, nil +} + +// preparer is any Postgres connection transport that provides a way to prepare +// a statement, most commonly *pgx.Conn. +type preparer interface { + Prepare(ctx context.Context, name, sql string) (sd *pgconn.StatementDescription, err error) +} + +// PrepareAllQueries executes a PREPARE statement for all pggen generated SQL +// queries in querier files. Typical usage is as the AfterConnect callback +// for pgxpool.Config +// +// pgx will use the prepared statement if available. Calling PrepareAllQueries +// is an optional optimization to avoid a network round-trip the first time pgx +// runs a query if pgx statement caching is enabled. +func PrepareAllQueries(ctx context.Context, p preparer) error { + if _, err := p.Prepare(ctx, outParamsSQL, outParamsSQL); err != nil { + return fmt.Errorf("prepare query 'OutParams': %w", err) + } + return nil +} + +// ListItem represents the Postgres composite type "list_item". +type ListItem struct { + Name *string `json:"name"` + Color *string `json:"color"` +} + +// ListStats represents the Postgres composite type "list_stats". +type ListStats struct { + Val1 *string `json:"val1"` + Val2 []*int32 `json:"val2"` +} + +// typeResolver looks up the pgtype.ValueTranscoder by Postgres type name. +type typeResolver struct { + connInfo *pgtype.ConnInfo // types by Postgres type name +} + +func newTypeResolver(types []pgtype.DataType) *typeResolver { + ci := pgtype.NewConnInfo() + for _, typ := range types { + if txt, ok := typ.Value.(textPreferrer); ok && typ.OID != unknownOID { + typ.Value = txt.ValueTranscoder + } + ci.RegisterDataType(typ) + } + return &typeResolver{connInfo: ci} +} + +// findValue find the OID, and pgtype.ValueTranscoder for a Postgres type name. +func (tr *typeResolver) findValue(name string) (uint32, pgtype.ValueTranscoder, bool) { + typ, ok := tr.connInfo.DataTypeForName(name) + if !ok { + return 0, nil, false + } + v := pgtype.NewValue(typ.Value) + return typ.OID, v.(pgtype.ValueTranscoder), true +} + +// setValue sets the value of a ValueTranscoder to a value that should always +// work and panics if it fails. +func (tr *typeResolver) setValue(vt pgtype.ValueTranscoder, val interface{}) pgtype.ValueTranscoder { + if err := vt.Set(val); err != nil { + panic(fmt.Sprintf("set ValueTranscoder %T to %+v: %s", vt, val, err)) + } + return vt +} + +type compositeField struct { + name string // name of the field + typeName string // Postgres type name + defaultVal pgtype.ValueTranscoder // default value to use +} + +func (tr *typeResolver) newCompositeValue(name string, fields ...compositeField) pgtype.ValueTranscoder { + if _, val, ok := tr.findValue(name); ok { + return val + } + fs := make([]pgtype.CompositeTypeField, len(fields)) + vals := make([]pgtype.ValueTranscoder, len(fields)) + isBinaryOk := true + for i, field := range fields { + oid, val, ok := tr.findValue(field.typeName) + if !ok { + oid = unknownOID + val = field.defaultVal + } + isBinaryOk = isBinaryOk && oid != unknownOID + fs[i] = pgtype.CompositeTypeField{Name: field.name, OID: oid} + vals[i] = val + } + // Okay to ignore error because it's only thrown when the number of field + // names does not equal the number of ValueTranscoders. + typ, _ := pgtype.NewCompositeTypeValues(name, fs, vals) + if !isBinaryOk { + return textPreferrer{typ, name} + } + return typ +} + +func (tr *typeResolver) newArrayValue(name, elemName string, defaultVal func() pgtype.ValueTranscoder) pgtype.ValueTranscoder { + if _, val, ok := tr.findValue(name); ok { + return val + } + elemOID, elemVal, ok := tr.findValue(elemName) + elemValFunc := func() pgtype.ValueTranscoder { + return pgtype.NewValue(elemVal).(pgtype.ValueTranscoder) + } + if !ok { + elemOID = unknownOID + elemValFunc = defaultVal + } + typ := pgtype.NewArrayType(name, elemOID, elemValFunc) + if elemOID == unknownOID { + return textPreferrer{typ, name} + } + return typ +} + +// newListItem creates a new pgtype.ValueTranscoder for the Postgres +// composite type 'list_item'. +func (tr *typeResolver) newListItem() pgtype.ValueTranscoder { + return tr.newCompositeValue( + "list_item", + compositeField{"name", "text", &pgtype.Text{}}, + compositeField{"color", "text", &pgtype.Text{}}, + ) +} + +// newListStats creates a new pgtype.ValueTranscoder for the Postgres +// composite type 'list_stats'. +func (tr *typeResolver) newListStats() pgtype.ValueTranscoder { + return tr.newCompositeValue( + "list_stats", + compositeField{"val1", "text", &pgtype.Text{}}, + compositeField{"val2", "_int4", &pgtype.Int4Array{}}, + ) +} + +// newListItemArray creates a new pgtype.ValueTranscoder for the Postgres +// '_list_item' array type. +func (tr *typeResolver) newListItemArray() pgtype.ValueTranscoder { + return tr.newArrayValue("_list_item", "list_item", tr.newListItem) +} + +const outParamsSQL = `SELECT * FROM out_params();` + +type OutParamsRow struct { + Items []ListItem `json:"_items"` + Stats ListStats `json:"_stats"` +} + +// OutParams implements Querier.OutParams. +func (q *DBQuerier) OutParams(ctx context.Context) ([]OutParamsRow, error) { + ctx = context.WithValue(ctx, "pggen_query_name", "OutParams") + rows, err := q.conn.Query(ctx, outParamsSQL) + if err != nil { + return nil, fmt.Errorf("query OutParams: %w", err) + } + defer rows.Close() + items := []OutParamsRow{} + itemsArray := q.types.newListItemArray() + statsRow := q.types.newListStats() + for rows.Next() { + var item OutParamsRow + if err := rows.Scan(itemsArray, statsRow); err != nil { + return nil, fmt.Errorf("scan OutParams row: %w", err) + } + if err := itemsArray.AssignTo(&item.Items); err != nil { + return nil, fmt.Errorf("assign OutParams row: %w", err) + } + if err := statsRow.AssignTo(&item.Stats); err != nil { + return nil, fmt.Errorf("assign OutParams row: %w", err) + } + items = append(items, item) + } + if err := rows.Err(); err != nil { + return nil, fmt.Errorf("close OutParams rows: %w", err) + } + return items, err +} + +// OutParamsBatch implements Querier.OutParamsBatch. +func (q *DBQuerier) OutParamsBatch(batch genericBatch) { + batch.Queue(outParamsSQL) +} + +// OutParamsScan implements Querier.OutParamsScan. +func (q *DBQuerier) OutParamsScan(results pgx.BatchResults) ([]OutParamsRow, error) { + rows, err := results.Query() + if err != nil { + return nil, fmt.Errorf("query OutParamsBatch: %w", err) + } + defer rows.Close() + items := []OutParamsRow{} + itemsArray := q.types.newListItemArray() + statsRow := q.types.newListStats() + for rows.Next() { + var item OutParamsRow + if err := rows.Scan(itemsArray, statsRow); err != nil { + return nil, fmt.Errorf("scan OutParamsBatch row: %w", err) + } + if err := itemsArray.AssignTo(&item.Items); err != nil { + return nil, fmt.Errorf("assign OutParams row: %w", err) + } + if err := statsRow.AssignTo(&item.Stats); err != nil { + return nil, fmt.Errorf("assign OutParams row: %w", err) + } + items = append(items, item) + } + if err := rows.Err(); err != nil { + return nil, fmt.Errorf("close OutParamsBatch rows: %w", err) + } + return items, err +} + +// textPreferrer wraps a pgtype.ValueTranscoder and sets the preferred encoding +// format to text instead binary (the default). pggen uses the text format +// when the OID is unknownOID because the binary format requires the OID. +// Typically occurs if the results from QueryAllDataTypes aren't passed to +// NewQuerierConfig. +type textPreferrer struct { + pgtype.ValueTranscoder + typeName string +} + +// PreferredParamFormat implements pgtype.ParamFormatPreferrer. +func (t textPreferrer) PreferredParamFormat() int16 { return pgtype.TextFormatCode } + +func (t textPreferrer) NewTypeValue() pgtype.Value { + return textPreferrer{pgtype.NewValue(t.ValueTranscoder).(pgtype.ValueTranscoder), t.typeName} +} + +func (t textPreferrer) TypeName() string { + return t.typeName +} + +// unknownOID means we don't know the OID for a type. This is okay for decoding +// because pgx call DecodeText or DecodeBinary without requiring the OID. For +// encoding parameters, pggen uses textPreferrer if the OID is unknown. +const unknownOID = 0 diff --git a/example/function/query.sql_test.go b/example/function/query.sql_test.go new file mode 100644 index 0000000..edeb3c8 --- /dev/null +++ b/example/function/query.sql_test.go @@ -0,0 +1,33 @@ +package function + +import ( + "context" + "github.com/jschaf/pggen/internal/difftest" + "github.com/jschaf/pggen/internal/ptrs" + "github.com/stretchr/testify/require" + "testing" + + "github.com/jschaf/pggen/internal/pgtest" +) + +func TestNewQuerier_OutParams(t *testing.T) { + conn, cleanup := pgtest.NewPostgresSchema(t, []string{"schema.sql"}) + defer cleanup() + + q := NewQuerier(conn) + + t.Run("OutParams", func(t *testing.T) { + got, err := q.OutParams(context.Background()) + require.NoError(t, err) + want := []OutParamsRow{ + { + Items: []ListItem{{Name: ptrs.String("some_name"), Color: ptrs.String("some_color")}}, + Stats: ListStats{ + Val1: ptrs.String("abc"), + Val2: []*int32{ptrs.Int32(1), ptrs.Int32(2)}, + }, + }, + } + difftest.AssertSame(t, want, got) + }) +} diff --git a/example/function/schema.sql b/example/function/schema.sql new file mode 100644 index 0000000..651c2cc --- /dev/null +++ b/example/function/schema.sql @@ -0,0 +1,20 @@ +CREATE TYPE list_item AS ( + name text, + color text +); + +CREATE TYPE list_stats AS ( + val1 text, + val2 int[] +); + +CREATE OR REPLACE FUNCTION out_params( + OUT _items list_item[], + OUT _stats list_stats +) + LANGUAGE plpgsql AS $$ +BEGIN + _items := ARRAY [('some_name', 'some_color')::list_item]; + _stats := ('abc', ARRAY [1, 2])::list_stats; +END +$$; diff --git a/internal/ptrs/ptrs.go b/internal/ptrs/ptrs.go index 6fd45d8..b5e934d 100644 --- a/internal/ptrs/ptrs.go +++ b/internal/ptrs/ptrs.go @@ -1,5 +1,6 @@ package ptrs -func NewInt(n int) *int { return &n } - -func NewFloat64(f float64) *float64 { return &f } +func Int(n int) *int { return &n } +func Int32(n int32) *int32 { return &n } +func Float64(f float64) *float64 { return &f } +func String(s string) *string { return &s }