Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: go type mapping per query argument #73

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 10 additions & 9 deletions internal/ast/ast.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,15 +88,16 @@ type (

// An SourceQuery node represents a query entry from the source code.
SourceQuery struct {
Name string // name of the query
Doc *CommentGroup // associated documentation; or nil
Start gotok.Pos // position of the start token, like 'SELECT' or 'UPDATE'
SourceSQL string // the complete sql query as it appeared in the source file
PreparedSQL string // the sql query with args replaced by $1, $2, etc.
ParamNames []string // the name of each param in the PreparedSQL, the nth entry is the $n+1 param
ResultKind ResultKind // the result output type
Pragmas Pragmas // optional query options
Semi gotok.Pos // position of the closing semicolon
Name string // name of the query
Doc *CommentGroup // associated documentation; or nil
Start gotok.Pos // position of the start token, like 'SELECT' or 'UPDATE'
SourceSQL string // the complete sql query as it appeared in the source file
PreparedSQL string // the sql query with args replaced by $1, $2, etc.
ParamNames []string // the name of each param in the PreparedSQL, the nth entry is the $n+1 param
ParamTypeOverrides map[string]string // map of Go type to override the Pg type of a param
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: I think this can be ParamGoTypes with a doc of something like:

map of user-specified Go types to use for the arg.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sounds good, and for the return value I would then name it accordingly:

ResultGoTypes map[string]string // map of user-specified Go type for the result columns.

ResultKind ResultKind // the result output type
Pragmas Pragmas // optional query options
Semi gotok.Pos // position of the closing semicolon
}
)

Expand Down
11 changes: 8 additions & 3 deletions internal/codegen/golang/templater.go
Original file line number Diff line number Diff line change
Expand Up @@ -144,10 +144,15 @@ func (tm Templater) templateFile(file codegen.QueryFile, isLeader bool) (Templat
// Build inputs.
inputs := make([]TemplatedParam, len(query.Inputs))
for i, input := range query.Inputs {
goType, err := tm.resolver.Resolve(input.PgType /*nullable*/, false, pkgPath)
if err != nil {
return TemplatedFile{}, nil, err
goType := input.TypeOverride
var err error
if goType == nil { // no custom arg type defined
goType, err = tm.resolver.Resolve(input.PgType /*nullable*/, false, pkgPath)
if err != nil {
return TemplatedFile{}, nil, err
}
}

imports.AddType(goType)
inputs[i] = TemplatedParam{
UpperName: tm.chooseUpperName(input.PgName, "UnnamedParam", i, len(query.Inputs)),
Expand Down
55 changes: 45 additions & 10 deletions internal/parser/parser.go
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,7 @@ func (p *parser) errorExpected(pos gotok.Pos, msg string) {

// Regexp to extract query annotations that control output.
var annotationRegexp = regexp.MustCompile(`name: ([a-zA-Z0-9_$]+)[ \t]+(:many|:one|:exec)[ \t]*(.*)`)
var annotationArgRegexp = regexp.MustCompile(`arg: ([a-zA-Z_][[a-zA-Z0-9_]*)[ \t]+(.*[a-zA-Z_][[a-zA-Z0-9_$]*)[ \t]*(.*)`)

func (p *parser) parseQuery() ast.Query {
if p.trace {
Expand Down Expand Up @@ -246,7 +247,40 @@ func (p *parser) parseQuery() ast.Query {
p.error(pos, "no comment preceding query")
return &ast.BadQuery{From: pos, To: p.pos}
}
last := doc.List[len(doc.List)-1]

paramTypeOverrides := make(map[string]string, 4)
annotationPos := len(doc.List) - 1

for ; annotationPos > 0; annotationPos-- {
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a bit messier than I'm comfortable with. Might be time to do a proper parse, something like:

  • split the line on whitespace
  • parse each token

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was taking insipiration from how the name: :kind annotation was parsed but I agree it is a bit messy and can be improved. (N.B.)I've moved the parsing of go types into it's own function to clean things up already. I just need to push it.

Can you elaborate on what you mean by proper parser? Do you want a custom built DSL and parser or do you have something in mind to be reused?
One idea would be to find the name: line, then parse everything that's afterwards in a known language like TOML (and we can enforce the structure of it). This would have the advantage to reusue a known configuration language and libs to parse it. However while it's close to the CLI it is a bit different, notably spaces around the equal are allowed and the values (i.e. go types) have to be quoted.

Which way would you like to go?

Copy link
Contributor Author

@0xjac 0xjac Aug 19, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To give an example of what I mean with TOML, it would work like this:

-- name: GetDailyFlightsFromAircraft :many
-- [arg]
-- day = "github.com/0xjac/custom-project/types.Day"
-- [return]
-- departure = "time.Time"
-- eta = "*time.Time"
SELECT f.flight_number, f.departure, f.arrival, f.eta
FROM flights f
WHERE pggen.arg("day") <= f.departure AND f.departure < pggen.arg("day") + 1
ORDER BY f.departure DESC;

We could also add the name and kind as top-level keys but it should still support then current name:  notation for compatibility.

argAnnotations := annotationArgRegexp.FindStringSubmatch(doc.List[annotationPos].Text)
if argAnnotations == nil {
break
}

argName, argGoType := argAnnotations[1], argAnnotations[2]

if _, present := paramTypeOverrides[argName]; present {
p.error(pos, "duplicate arg type override for "+argName+": "+argAnnotations[0])
return &ast.BadQuery{From: pos, To: p.pos}
}

unknownArg := true
for _, arg := range names {
if argName == arg.name {
unknownArg = false
break
}
}
if unknownArg {
p.error(pos, "arg type override for unknown arg "+argName+": "+argAnnotations[0])
return &ast.BadQuery{From: pos, To: p.pos}
}

paramTypeOverrides[argName] = argGoType

}

last := doc.List[annotationPos]
annotations := annotationRegexp.FindStringSubmatch(last.Text)
if annotations == nil {
p.error(pos, "no 'name: <name> :<type>' token found in comment before query; comment line: \""+last.Text+`"`)
Expand All @@ -263,15 +297,16 @@ func (p *parser) parseQuery() ast.Query {
preparedSQL, params := prepareSQL(templateSQL, names)

return &ast.SourceQuery{
Name: annotations[1],
Doc: doc,
Start: pos,
SourceSQL: templateSQL,
PreparedSQL: preparedSQL,
ParamNames: params,
ResultKind: ast.ResultKind(annotations[2]),
Pragmas: pragmas,
Semi: semi,
Name: annotations[1],
Doc: doc,
Start: pos,
SourceSQL: templateSQL,
PreparedSQL: preparedSQL,
ParamNames: params,
ParamTypeOverrides: paramTypeOverrides,
ResultKind: ast.ResultKind(annotations[2]),
Pragmas: pragmas,
Semi: semi,
}
}

Expand Down
17 changes: 14 additions & 3 deletions internal/pginfer/pginfer.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"github.com/jackc/pgtype"
"github.com/jackc/pgx/v4"
"github.com/jschaf/pggen/internal/ast"
"github.com/jschaf/pggen/internal/codegen/golang/gotype"
"github.com/jschaf/pggen/internal/errs"
"github.com/jschaf/pggen/internal/pg"
)
Expand Down Expand Up @@ -49,6 +50,8 @@ type InputParam struct {
DefaultVal string
// The postgres type of this param as reported by Postgres.
PgType pg.Type
// Fully qualified Go type to override the Go type for the input Pg type.
TypeOverride gotype.Type
}

// OutputColumn is a single column output from a select query or returning
Expand Down Expand Up @@ -149,10 +152,18 @@ func (inf *Inferrer) inferInputTypes(query *ast.SourceQuery) (ps []InputParam, m
params := make([]InputParam, len(query.ParamNames))
for i := 0; i < len(params); i++ {
pgType := types[pgtype.OID(oids[i])]
var goType gotype.Type
if overrideType, present := query.ParamTypeOverrides[query.ParamNames[i]]; present {
goType, err = gotype.ParseOpaqueType(overrideType, pgType)
if err != nil {
return nil, fmt.Errorf("resolve custom arg type: %w", err)
}
}
params[i] = InputParam{
PgName: query.ParamNames[i],
DefaultVal: "",
PgType: pgType,
PgName: query.ParamNames[i],
DefaultVal: "",
PgType: pgType,
TypeOverride: goType,
}
}
return params, nil
Expand Down