From ff474e27b960399aec5ffa0cede43c135d6caa3e Mon Sep 17 00:00:00 2001 From: Jacques Dafflon <0xjac+git@truelevel.ch> Date: Thu, 18 Aug 2022 12:35:43 +0200 Subject: [PATCH] feat: go type mapping per query argument Extends the query annotation to specify custom go types using the same qualified go type format as for the general custom type mapping. Closes #46 --- internal/ast/ast.go | 19 +++++----- internal/codegen/golang/templater.go | 11 ++++-- internal/parser/parser.go | 55 +++++++++++++++++++++++----- internal/pginfer/pginfer.go | 17 +++++++-- 4 files changed, 77 insertions(+), 25 deletions(-) diff --git a/internal/ast/ast.go b/internal/ast/ast.go index a4a2897..6b2ddc5 100644 --- a/internal/ast/ast.go +++ b/internal/ast/ast.go @@ -88,15 +88,16 @@ type ( // An SourceQuery node represents a query entry from the source code. SourceQuery struct { - Name string // name of the query - Doc *CommentGroup // associated documentation; or nil - Start gotok.Pos // position of the start token, like 'SELECT' or 'UPDATE' - SourceSQL string // the complete sql query as it appeared in the source file - PreparedSQL string // the sql query with args replaced by $1, $2, etc. - ParamNames []string // the name of each param in the PreparedSQL, the nth entry is the $n+1 param - ResultKind ResultKind // the result output type - Pragmas Pragmas // optional query options - Semi gotok.Pos // position of the closing semicolon + Name string // name of the query + Doc *CommentGroup // associated documentation; or nil + Start gotok.Pos // position of the start token, like 'SELECT' or 'UPDATE' + SourceSQL string // the complete sql query as it appeared in the source file + PreparedSQL string // the sql query with args replaced by $1, $2, etc. + ParamNames []string // the name of each param in the PreparedSQL, the nth entry is the $n+1 param + ParamTypeOverrides map[string]string // map of Go type to override the Pg type of a param + ResultKind ResultKind // the result output type + Pragmas Pragmas // optional query options + Semi gotok.Pos // position of the closing semicolon } ) diff --git a/internal/codegen/golang/templater.go b/internal/codegen/golang/templater.go index 8a1e603..6954c65 100644 --- a/internal/codegen/golang/templater.go +++ b/internal/codegen/golang/templater.go @@ -144,10 +144,15 @@ func (tm Templater) templateFile(file codegen.QueryFile, isLeader bool) (Templat // Build inputs. inputs := make([]TemplatedParam, len(query.Inputs)) for i, input := range query.Inputs { - goType, err := tm.resolver.Resolve(input.PgType /*nullable*/, false, pkgPath) - if err != nil { - return TemplatedFile{}, nil, err + goType := input.TypeOverride + var err error + if goType == nil { // no custom arg type defined + goType, err = tm.resolver.Resolve(input.PgType /*nullable*/, false, pkgPath) + if err != nil { + return TemplatedFile{}, nil, err + } } + imports.AddType(goType) inputs[i] = TemplatedParam{ UpperName: tm.chooseUpperName(input.PgName, "UnnamedParam", i, len(query.Inputs)), diff --git a/internal/parser/parser.go b/internal/parser/parser.go index 01646d3..97bdabe 100644 --- a/internal/parser/parser.go +++ b/internal/parser/parser.go @@ -203,6 +203,7 @@ func (p *parser) errorExpected(pos gotok.Pos, msg string) { // Regexp to extract query annotations that control output. var annotationRegexp = regexp.MustCompile(`name: ([a-zA-Z0-9_$]+)[ \t]+(:many|:one|:exec)[ \t]*(.*)`) +var annotationArgRegexp = regexp.MustCompile(`arg: ([a-zA-Z_][[a-zA-Z0-9_]*)[ \t]+(.*[a-zA-Z_][[a-zA-Z0-9_$]*)[ \t]*(.*)`) func (p *parser) parseQuery() ast.Query { if p.trace { @@ -246,7 +247,40 @@ func (p *parser) parseQuery() ast.Query { p.error(pos, "no comment preceding query") return &ast.BadQuery{From: pos, To: p.pos} } - last := doc.List[len(doc.List)-1] + + paramTypeOverrides := make(map[string]string, 4) + annotationPos := len(doc.List) - 1 + + for ; annotationPos > 0; annotationPos-- { + argAnnotations := annotationArgRegexp.FindStringSubmatch(doc.List[annotationPos].Text) + if argAnnotations == nil { + break + } + + argName, argGoType := argAnnotations[1], argAnnotations[2] + + if _, present := paramTypeOverrides[argName]; present { + p.error(pos, "duplicate arg type override for "+argName+": "+argAnnotations[0]) + return &ast.BadQuery{From: pos, To: p.pos} + } + + unknownArg := true + for _, arg := range names { + if argName == arg.name { + unknownArg = false + break + } + } + if unknownArg { + p.error(pos, "arg type override for unknown arg "+argName+": "+argAnnotations[0]) + return &ast.BadQuery{From: pos, To: p.pos} + } + + paramTypeOverrides[argName] = argGoType + + } + + last := doc.List[annotationPos] annotations := annotationRegexp.FindStringSubmatch(last.Text) if annotations == nil { p.error(pos, "no 'name: :' token found in comment before query; comment line: \""+last.Text+`"`) @@ -263,15 +297,16 @@ func (p *parser) parseQuery() ast.Query { preparedSQL, params := prepareSQL(templateSQL, names) return &ast.SourceQuery{ - Name: annotations[1], - Doc: doc, - Start: pos, - SourceSQL: templateSQL, - PreparedSQL: preparedSQL, - ParamNames: params, - ResultKind: ast.ResultKind(annotations[2]), - Pragmas: pragmas, - Semi: semi, + Name: annotations[1], + Doc: doc, + Start: pos, + SourceSQL: templateSQL, + PreparedSQL: preparedSQL, + ParamNames: params, + ParamTypeOverrides: paramTypeOverrides, + ResultKind: ast.ResultKind(annotations[2]), + Pragmas: pragmas, + Semi: semi, } } diff --git a/internal/pginfer/pginfer.go b/internal/pginfer/pginfer.go index 8a7e304..2cd5ef7 100644 --- a/internal/pginfer/pginfer.go +++ b/internal/pginfer/pginfer.go @@ -11,6 +11,7 @@ import ( "github.com/jackc/pgtype" "github.com/jackc/pgx/v4" "github.com/jschaf/pggen/internal/ast" + "github.com/jschaf/pggen/internal/codegen/golang/gotype" "github.com/jschaf/pggen/internal/errs" "github.com/jschaf/pggen/internal/pg" ) @@ -49,6 +50,8 @@ type InputParam struct { DefaultVal string // The postgres type of this param as reported by Postgres. PgType pg.Type + // Fully qualified Go type to override the Go type for the input Pg type. + TypeOverride gotype.Type } // OutputColumn is a single column output from a select query or returning @@ -149,10 +152,18 @@ func (inf *Inferrer) inferInputTypes(query *ast.SourceQuery) (ps []InputParam, m params := make([]InputParam, len(query.ParamNames)) for i := 0; i < len(params); i++ { pgType := types[pgtype.OID(oids[i])] + var goType gotype.Type + if overrideType, present := query.ParamTypeOverrides[query.ParamNames[i]]; present { + goType, err = gotype.ParseOpaqueType(overrideType, pgType) + if err != nil { + return nil, fmt.Errorf("resolve custom arg type: %w", err) + } + } params[i] = InputParam{ - PgName: query.ParamNames[i], - DefaultVal: "", - PgType: pgType, + PgName: query.ParamNames[i], + DefaultVal: "", + PgType: pgType, + TypeOverride: goType, } } return params, nil