Skip to content

Commit

Permalink
Load types using a single SQL query
Browse files Browse the repository at this point in the history
When loading even a single type into pgx's type map, multiple SQL
queries are performed in series. Over a slow link, this is not ideal.
Worse, if multiple types are being registered, this is repeated multiple
times.

This commit add LoadTypes, which can retrieve type
mapping information for multiple types in a single SQL call, including
recursive fetching of dependent types.
RegisterTypes performs the second stage of this operation.
  • Loading branch information
nicois committed Jun 17, 2024
1 parent 9907b87 commit ca18af4
Show file tree
Hide file tree
Showing 3 changed files with 304 additions and 12 deletions.
259 changes: 256 additions & 3 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"encoding/hex"
"errors"
"fmt"
"regexp"
"strconv"
"strings"
"time"
Expand Down Expand Up @@ -107,8 +108,10 @@ var (
ErrTooManyRows = errors.New("too many rows in result set")
)

var errDisabledStatementCache = fmt.Errorf("cannot use QueryExecModeCacheStatement with disabled statement cache")
var errDisabledDescriptionCache = fmt.Errorf("cannot use QueryExecModeCacheDescribe with disabled description cache")
var (
errDisabledStatementCache = fmt.Errorf("cannot use QueryExecModeCacheStatement with disabled statement cache")
errDisabledDescriptionCache = fmt.Errorf("cannot use QueryExecModeCacheDescribe with disabled description cache")
)

// Connect establishes a connection with a PostgreSQL server with a connection string. See
// pgconn.Connect for details.
Expand Down Expand Up @@ -843,7 +846,6 @@ func (c *Conn) getStatementDescription(
mode QueryExecMode,
sql string,
) (sd *pgconn.StatementDescription, err error) {

switch mode {
case QueryExecModeCacheStatement:
if c.statementCache == nil {
Expand Down Expand Up @@ -1393,3 +1395,254 @@ func (c *Conn) deallocateInvalidatedCachedStatements(ctx context.Context) error

return nil
}

/*
buildLoadTypesSQL generates the correct query for retrieving type information.
pgVersion: the major version of the PostgreSQL server
typeNames: the names of the types to load. If nil, load all types.
*/
func buildLoadTypesSQL(pgVersion int64, typeNames []string) string {
supportsMultirange := (pgVersion >= 14)
var typeNamesClause string

if typeNames == nil {
// collect all types. Not currently recommended.
typeNamesClause = "IS NOT NULL"
} else {
typeNamesClause = "= ANY($1)"
}
parts := make([]string, 0, 10)

// Each of the type names provided might be found in pg_class or pg_type.
// Additionally, it may or may not include a schema portion.
parts = append(parts, `
WITH RECURSIVE
-- find the OIDs in pg_class which match one of the provided type names
selected_classes(oid,reltype) AS (
-- this query uses the namespace search path, so will match type names without a schema prefix
SELECT pg_class.oid, pg_class.reltype
FROM pg_catalog.pg_class
LEFT JOIN pg_catalog.pg_namespace n ON n.oid = pg_class.relnamespace
WHERE pg_catalog.pg_table_is_visible(pg_class.oid)
AND relname `, typeNamesClause, `
UNION ALL
-- this query will only match type names which include the schema prefix
SELECT pg_class.oid, pg_class.reltype
FROM pg_class
INNER JOIN pg_namespace ON (pg_class.relnamespace = pg_namespace.oid)
WHERE nspname || '.' || relname `, typeNamesClause, `
),
selected_types(oid) AS (
-- collect the OIDs from pg_types which correspond to the selected classes
SELECT reltype AS oid
FROM selected_classes
UNION ALL
-- as well as any other type names which match our criteria
SELECT oid
FROM pg_type
WHERE typname `, typeNamesClause, `
),
-- this builds a parent/child mapping of objects, allowing us to know
-- all the child (ie: dependent) types that a parent (type) requires
-- As can be seen, there are 3 ways this can occur (the last of which
-- is due to being a composite class, where the composite fields are children)
pc(parent, child) AS (
SELECT parent.oid, parent.typelem
FROM pg_type parent
WHERE parent.typtype = 'b' AND parent.typelem != 0
UNION ALL
SELECT parent.oid, parent.typbasetype
FROM pg_type parent
WHERE parent.typtypmod = -1 AND parent.typbasetype != 0
UNION ALL
SELECT pg_type.oid, atttypid
FROM pg_attribute
INNER JOIN pg_class ON (pg_class.oid = pg_attribute.attrelid)
INNER JOIN pg_type ON (pg_type.oid = pg_class.reltype)
WHERE NOT attisdropped
AND attnum > 0
),
-- Now construct a recursive query which includes a 'depth' element.
-- This is used to ensure that the "youngest" children are registered before
-- their parents.
relationships(parent, child, depth) AS (
SELECT DISTINCT 0::OID, selected_types.oid, 0
FROM selected_types
UNION ALL
SELECT pg_type.oid AS parent, pg_attribute.atttypid AS child, 1
FROM selected_classes c
inner join pg_type ON (c.reltype = pg_type.oid)
inner join pg_attribute on (c.oid = pg_attribute.attrelid)
UNION ALL
SELECT pc.parent, pc.child, relationships.depth + 1
FROM pc
INNER JOIN relationships ON (pc.parent = relationships.child)
),
-- composite fields need to be encapsulated as a couple of arrays to provide the required information for registration
composite AS (
SELECT pg_type.oid, ARRAY_AGG(attname ORDER BY attnum) AS attnames, ARRAY_AGG(atttypid ORDER BY ATTNUM) AS atttypids
FROM pg_attribute
INNER JOIN pg_class ON (pg_class.oid = pg_attribute.attrelid)
INNER JOIN pg_type ON (pg_type.oid = pg_class.reltype)
WHERE NOT attisdropped
AND attnum > 0
GROUP BY pg_type.oid
)
-- Bring together this information, showing all the information which might possibly be required
-- to complete the registration, applying filters to only show the items which relate to the selected
-- types/classes.
SELECT typname,
typtype,
typbasetype,
typelem,
pg_type.oid,`)
if supportsMultirange {
parts = append(parts, `
COALESCE(multirange.rngtypid, 0) AS rngtypid,`)
} else {
parts = append(parts, `
0 AS rngtypid,`)
}
parts = append(parts, `
COALESCE(pg_range.rngsubtype, 0) AS rngsubtype,
attnames, atttypids
FROM relationships
INNER JOIN pg_type ON (pg_type.oid IN ( relationships.child,relationships.parent) )
LEFT OUTER JOIN pg_range ON (pg_type.oid = pg_range.rngtypid)`)
if supportsMultirange {
parts = append(parts, `
LEFT OUTER JOIN pg_range multirange ON (pg_type.oid = multirange.rngmultitypid)`)
}

parts = append(parts, `
LEFT OUTER JOIN composite USING (oid)
WHERE NOT (typtype = 'b' AND typelem = 0)`)
parts = append(parts, `
GROUP BY typname, typtype, typbasetype, typelem, pg_type.oid, pg_range.rngsubtype,`)
if supportsMultirange {
parts = append(parts, `
multirange.rngtypid,`)
}
parts = append(parts, `
attnames, atttypids
ORDER BY MAX(depth) desc, typname;`)
return strings.Join(parts, "")
}

type TypeInfo struct {
oid, typbasetype, typelem, rngsubtype, rngtypid uint32
typeName, typtype string
attnames []string
atttypids []uint32
}

// LoadTypes performs a single (complex) query, returning all the required
// information to register the named types, as well as any other types directly
// or indirectly required to complete the registration.
// The result of this call can be passed into RegisterTypes to complete the process.
func (c *Conn) LoadTypes(ctx context.Context, typeNames []string) ([]*TypeInfo, error) {
if typeNames == nil || len(typeNames) == 0 {
return nil, fmt.Errorf("No type names were supplied.")
}

serverVersion, err := c.serverVersion()
if err != nil {
return nil, fmt.Errorf("Unexpected server version error: %w", err)
}
sql := buildLoadTypesSQL(serverVersion, typeNames)
var rows Rows
if typeNames == nil {
rows, err = c.Query(ctx, sql, QueryExecModeSimpleProtocol)
} else {
rows, err = c.Query(ctx, sql, QueryExecModeSimpleProtocol, typeNames)
}
if err != nil {
return nil, fmt.Errorf("While generating load types query: %w", err)
}
defer rows.Close()
result := make([]*TypeInfo, 0, 100)
for rows.Next() {
ti := TypeInfo{}
err = rows.Scan(&ti.typeName, &ti.typtype, &ti.typbasetype, &ti.typelem, &ti.oid, &ti.rngtypid, &ti.rngsubtype, &ti.attnames, &ti.atttypids)
if err != nil {
return nil, fmt.Errorf("While scanning type information: %w", err)
}
result = append(result, &ti)
}
return result, nil
}

// RegisterTypes complements LoadTypes, applying the type information collected by LoadTypes
// to the connection's typemap.
func (c *Conn) RegisterTypes(typeInfo []*TypeInfo, registerWith *pgtype.Map) error {
if registerWith == nil {
return fmt.Errorf("Type map must be supplied")
}
for _, ti := range typeInfo {
switch ti.typtype {
case "b": // array
dt, ok := registerWith.TypeForOID(ti.typelem)
if !ok {
return fmt.Errorf("array element OID %v not registered while loading for %v", ti.typelem, ti.typeName)
}
registerWith.RegisterType(&pgtype.Type{Name: ti.typeName, OID: ti.oid, Codec: &pgtype.ArrayCodec{ElementType: dt}})
case "c": // composite
var fields []pgtype.CompositeCodecField
for i, fieldName := range ti.attnames {
//if fieldOID64, err = strconv.ParseUint(composite_fields[i+1], 10, 32); err != nil {
// return nil, fmt.Errorf("While extracting OID used in composite field: %w", err)
//}
dt, ok := registerWith.TypeForOID(ti.atttypids[i])
if !ok {
return fmt.Errorf("unknown composite type field OID %v (%v)", ti.atttypids[i], fieldName)
}
fields = append(fields, pgtype.CompositeCodecField{Name: fieldName, Type: dt})
}

registerWith.RegisterType(&pgtype.Type{Name: ti.typeName, OID: ti.oid, Codec: &pgtype.CompositeCodec{Fields: fields}})
case "d": // domain
dt, ok := registerWith.TypeForOID(ti.typbasetype)
if !ok {
return fmt.Errorf("domain base type OID %v was not already registered, needed for %v", ti.typbasetype, ti.typeName)
}

registerWith.RegisterType(&pgtype.Type{Name: ti.typeName, OID: ti.oid, Codec: dt.Codec})
case "e": // enum
registerWith.RegisterType(&pgtype.Type{Name: ti.typeName, OID: ti.oid, Codec: &pgtype.EnumCodec{}})
case "r": // range
dt, ok := registerWith.TypeForOID(ti.rngsubtype)
if !ok {
return fmt.Errorf("range element OID %v not registered for %v", ti.rngsubtype, ti.typeName)
}

registerWith.RegisterType(&pgtype.Type{Name: ti.typeName, OID: ti.oid, Codec: &pgtype.RangeCodec{ElementType: dt}})
case "m": // multirange
dt, ok := registerWith.TypeForOID(ti.rngtypid)
if !ok {
return fmt.Errorf("multirange element OID %v not registered while loading %v", ti.rngtypid, ti.typeName)
}

registerWith.RegisterType(&pgtype.Type{Name: ti.typeName, OID: ti.oid, Codec: &pgtype.MultirangeCodec{ElementType: dt}})
default:
return fmt.Errorf("unknown typtype %v for %v", ti.typtype, ti.typeName)
}
}
return nil
}

// serverVersion returns the postgresql server version.
func (conn *Conn) serverVersion() (int64, error) {
serverVersionStr := conn.PgConn().ParameterStatus("server_version")
serverVersionStr = regexp.MustCompile(`^[0-9]+`).FindString(serverVersionStr)
// if not PostgreSQL do nothing
if serverVersionStr == "" {
return 0, fmt.Errorf("Cannot identify server version in %q", serverVersionStr)
}

serverVersion, err := strconv.ParseInt(serverVersionStr, 10, 64)
if err != nil {
return 0, fmt.Errorf("postgres version parsing failed: %w", err)
}
return serverVersion, nil
}
4 changes: 0 additions & 4 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,6 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a h1:bbPeKD0xmW/Y25WS6cokEszi5g+S0QxI/d45PkRi7Nk=
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM=
github.com/jackc/pgservicefile v0.0.0-20231201235250-de7065d80cb9 h1:L0QtFUgDarD7Fpv9jeVMgy/+Ec0mtnmYuImjTz6dtDA=
github.com/jackc/pgservicefile v0.0.0-20231201235250-de7065d80cb9/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM=
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo=
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM=
github.com/jackc/puddle/v2 v2.2.1 h1:RhxXJtFG022u4ibrCSMSiu5aOq1i77R3OHKNJj77OAk=
Expand Down
53 changes: 48 additions & 5 deletions pgtype/composite_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,58 @@ import (
"github.com/stretchr/testify/require"
)

func TestCompositeCodecTranscode(t *testing.T) {
func TestCompositeCodecTranscodeWithLoadTypes(t *testing.T) {
skipCockroachDB(t, "Server does not support composite types (see https://github.com/cockroachdb/cockroach/issues/27792)")

defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
_, err := conn.Exec(ctx, `drop domain if exists anotheruint64;
drop type if exists ct_test;
create domain anotheruint64 as numeric(20,0);
create type ct_test as (
a text,
b int4,
c anotheruint64
);`)
require.NoError(t, err)
defer conn.Exec(ctx, "drop type ct_test")
defer conn.Exec(ctx, "drop domain anotheruint64")

types, err := conn.LoadTypes(ctx, []string{"ct_test"})
require.NoError(t, err)
err = conn.RegisterTypes(types, conn.TypeMap())
require.NoError(t, err)

formats := []struct {
name string
code int16
}{
{name: "TextFormat", code: pgx.TextFormatCode},
{name: "BinaryFormat", code: pgx.BinaryFormatCode},
}

for _, format := range formats {
var a string
var b int32
var c uint64

err := conn.QueryRow(ctx, "select $1::ct_test", pgx.QueryResultFormats{format.code},
pgtype.CompositeFields{"hi", int32(42), uint64(123)},
).Scan(
pgtype.CompositeFields{&a, &b, &c},
)
require.NoErrorf(t, err, "%v", format.name)
require.EqualValuesf(t, "hi", a, "%v", format.name)
require.EqualValuesf(t, 42, b, "%v", format.name)
require.EqualValuesf(t, 123, c, "%v", format.name)
}
})
}

func TestCompositeCodecTranscode(t *testing.T) {
skipCockroachDB(t, "Server does not support composite types (see https://github.com/cockroachdb/cockroach/issues/27792)")

defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
_, err := conn.Exec(ctx, `drop type if exists ct_test;
create type ct_test as (
Expand Down Expand Up @@ -94,7 +141,6 @@ func TestCompositeCodecTranscodeStruct(t *testing.T) {
skipCockroachDB(t, "Server does not support composite types (see https://github.com/cockroachdb/cockroach/issues/27792)")

defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {

_, err := conn.Exec(ctx, `drop type if exists point3d;
create type point3d as (
Expand Down Expand Up @@ -131,7 +177,6 @@ func TestCompositeCodecTranscodeStructWrapper(t *testing.T) {
skipCockroachDB(t, "Server does not support composite types (see https://github.com/cockroachdb/cockroach/issues/27792)")

defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {

_, err := conn.Exec(ctx, `drop type if exists point3d;
create type point3d as (
Expand Down Expand Up @@ -172,7 +217,6 @@ func TestCompositeCodecDecodeValue(t *testing.T) {
skipCockroachDB(t, "Server does not support composite types (see https://github.com/cockroachdb/cockroach/issues/27792)")

defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {

_, err := conn.Exec(ctx, `drop type if exists point3d;
create type point3d as (
Expand Down Expand Up @@ -217,7 +261,6 @@ func TestCompositeCodecTranscodeStructWrapperForTable(t *testing.T) {
skipCockroachDB(t, "Server does not support composite types (see https://github.com/cockroachdb/cockroach/issues/27792)")

defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {

_, err := conn.Exec(ctx, `drop table if exists point3d;
create table point3d (
Expand Down

0 comments on commit ca18af4

Please sign in to comment.