Skip to content

Commit c873256

Browse files
committed
pggen: distinguish between direct and aggregated args
1 parent f02d6c8 commit c873256

File tree

1 file changed

+53
-20
lines changed

1 file changed

+53
-20
lines changed

internal/tools/sqlc-pg-gen/main.go

Lines changed: 53 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -20,17 +20,26 @@ import (
2020
// https://stackoverflow.com/questions/25308765/postgresql-how-can-i-inspect-which-arguments-to-a-procedure-have-a-default-valu
2121
const catalogFuncs = `
2222
SELECT p.proname as name,
23-
format_type(p.prorettype, NULL),
24-
array(select format_type(unnest(p.proargtypes), NULL)),
25-
p.proargnames,
26-
p.proargnames[p.pronargs-p.pronargdefaults+1:p.pronargs]
23+
format_type(p.prorettype, NULL) as return_type,
24+
arg_types.direct_arg_types,
25+
arg_types.aggregated_arg_types,
26+
p.proargnames as arg_names,
27+
p.proargnames[p.pronargs-p.pronargdefaults+1:p.pronargs] as has_default,
28+
p.prokind::text as kind
2729
FROM pg_catalog.pg_proc p
2830
LEFT JOIN pg_catalog.pg_namespace n ON n.oid = p.pronamespace
31+
LEFT JOIN pg_catalog.pg_aggregate agg on agg.aggfnoid = p.oid and p.prokind = 'a'
32+
CROSS JOIN LATERAL (
33+
select
34+
array_agg(format_type(arg_type, NULL)) filter (where agg is null OR argn <= agg.aggnumdirectargs),
35+
array_agg(format_type(arg_type, NULL)) filter (where argn > agg.aggnumdirectargs)
36+
from unnest(p.proargtypes) WITH ORDINALITY as arg_type(arg_type, argn)
37+
) arg_types(direct_arg_types, aggregated_arg_types)
2938
WHERE n.nspname OPERATOR(pg_catalog.~) '^(pg_catalog)$'
3039
AND p.proargmodes IS NULL
3140
AND pg_function_is_visible(p.oid)
3241
-- The order isn't too important - just that it is stable between runs
33-
ORDER BY 1, 2, 3, 4, 5;
42+
ORDER BY 1, 2, 3, 4, 5, 6, 7
3443
`
3544

3645
// Relations are the relations available in pg_tables and pg_views
@@ -76,16 +85,25 @@ WITH extension_funcs AS (
7685
WHERE d.deptype = 'e' AND e.extname = $1
7786
)
7887
SELECT p.proname as name,
79-
format_type(p.prorettype, NULL),
80-
array(select format_type(unnest(p.proargtypes), NULL)),
81-
p.proargnames,
82-
p.proargnames[p.pronargs-p.pronargdefaults+1:p.pronargs]
88+
format_type(p.prorettype, NULL) as return_type,
89+
arg_types.direct_arg_types,
90+
arg_types.aggregated_arg_types,
91+
p.proargnames as arg_names,
92+
p.proargnames[p.pronargs-p.pronargdefaults+1:p.pronargs] as has_default,
93+
p.prokind::text as kind
8394
FROM pg_catalog.pg_proc p
8495
JOIN extension_funcs ef ON ef.oid = p.oid
96+
LEFT JOIN pg_catalog.pg_aggregate agg on agg.aggfnoid = p.oid and p.prokind = 'a'
97+
CROSS JOIN LATERAL (
98+
select
99+
array_agg(format_type(arg_type, NULL)) filter (where agg is null OR argn <= agg.aggnumdirectargs),
100+
array_agg(format_type(arg_type, NULL)) filter (where argn > agg.aggnumdirectargs)
101+
from unnest(p.proargtypes) WITH ORDINALITY as arg_type(arg_type, argn)
102+
) arg_types(direct_arg_types, aggregated_arg_types)
85103
WHERE p.proargmodes IS NULL
86104
AND pg_function_is_visible(p.oid)
87105
-- The order isn't too important - just that it is stable between runs
88-
ORDER BY 1, 2, 3, 4, 5;
106+
ORDER BY 1, 2, 3, 4, 5, 6, 7
89107
`
90108

91109
const catalogTmpl = `
@@ -196,12 +214,25 @@ func main() {
196214
}
197215
}
198216

217+
// ProcKind is the type that tells what type of routine the pg_proc is
218+
// This corresponds to the pg_proc.prokind column
219+
type ProcKind string
220+
221+
const (
222+
ProcKindNormal = ProcKind("f")
223+
ProcKindWindow = ProcKind("w")
224+
ProcKindAggregate = ProcKind("a")
225+
ProcKindProcedure = ProcKind("p")
226+
)
227+
199228
type Proc struct {
200-
Name string
201-
ReturnType string
202-
ArgTypes []string
203-
ArgNames []string
204-
HasDefault []string
229+
Name string
230+
ReturnType string
231+
DirectArgTypes []string
232+
AggregatedArgTypes []string
233+
ArgNames []string
234+
HasDefault []string
235+
Kind ProcKind
205236
}
206237

207238
func clean(arg string) string {
@@ -223,13 +254,13 @@ func (p Proc) Func() catalog.Function {
223254
func (p Proc) Args() []*catalog.Argument {
224255
defaults := map[string]bool{}
225256
var args []*catalog.Argument
226-
if len(p.ArgTypes) == 0 {
257+
if len(p.DirectArgTypes) == 0 {
227258
return args
228259
}
229260
for _, name := range p.HasDefault {
230261
defaults[name] = true
231262
}
232-
for i, arg := range p.ArgTypes {
263+
for i, arg := range p.DirectArgTypes {
233264
var name string
234265
if i < len(p.ArgNames) {
235266
name = p.ArgNames[i]
@@ -305,9 +336,11 @@ func scanFuncs(rows pgx.Rows) ([]catalog.Function, error) {
305336
err := rows.Scan(
306337
&p.Name,
307338
&p.ReturnType,
308-
&p.ArgTypes,
339+
&p.DirectArgTypes,
340+
&p.AggregatedArgTypes,
309341
&p.ArgNames,
310342
&p.HasDefault,
343+
&p.Kind,
311344
)
312345
if err != nil {
313346
return nil, err
@@ -328,8 +361,8 @@ func scanFuncs(rows pgx.Rows) ([]catalog.Function, error) {
328361
//
329362
// https://www.postgresql.org/docs/current/datatype-pseudo.html
330363
var skip bool
331-
for i := range p.ArgTypes {
332-
if p.ArgTypes[i] == "internal" {
364+
for i := range p.DirectArgTypes {
365+
if p.DirectArgTypes[i] == "internal" {
333366
skip = true
334367
}
335368
}

0 commit comments

Comments
 (0)