diff --git a/pgxpool/pool.go b/pgxpool/pool.go index fdcba7241..079966724 100644 --- a/pgxpool/pool.go +++ b/pgxpool/pool.go @@ -12,14 +12,17 @@ import ( "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgconn" + "github.com/jackc/pgx/v5/pgtype" "github.com/jackc/puddle/v2" ) -var defaultMaxConns = int32(4) -var defaultMinConns = int32(0) -var defaultMaxConnLifetime = time.Hour -var defaultMaxConnIdleTime = time.Minute * 30 -var defaultHealthCheckPeriod = time.Minute +var ( + defaultMaxConns = int32(4) + defaultMinConns = int32(0) + defaultMaxConnLifetime = time.Hour + defaultMaxConnIdleTime = time.Minute * 30 + defaultHealthCheckPeriod = time.Minute +) type connResource struct { conn *pgx.Conn @@ -100,6 +103,11 @@ type Pool struct { closeOnce sync.Once closeChan chan struct{} + + autoLoadTypeNames []string + reuseTypeMap bool + autoLoadMutex *sync.Mutex + autoLoadTypes []*pgtype.Type } // Config is the configuration struct for creating a pool. It must be created by [ParseConfig] and then it can be @@ -147,6 +155,23 @@ type Config struct { // HealthCheckPeriod is the duration between checks of the health of idle connections. HealthCheckPeriod time.Duration + // AutoLoadTypes is a list of user-defined types which should automatically be loaded + // as each new connection is created. This will also load any derived types, directly + // or indirectly required to handle these types. + // This is equivalent to manually calling pgx.LoadTypes() + // followed by conn.TypeMap().RegisterTypes() + // This will occur after the AfterConnect hook. If manual type + // registrating is performed during AfterConnect, the autoloading + // will be aware of those registrations. + AutoLoadTypes []string + + // ReuseTypeMaps, if enabled, will reuse the typemap information being used by AutoLoadTypes. + // This removes the need to query the database each time a new connection is created; + // only RegisterDerivedTypes will need to be called for each new connection. + // In some situations, where OID mapping can differ between pg servers in the pool, perhaps due + // to certain replication strategies, this should be left disabled. + ReuseTypeMaps bool + createdByParseConfig bool // Used to enforce created by ParseConfig rule. } @@ -185,6 +210,8 @@ func NewWithConfig(ctx context.Context, config *Config) (*Pool, error) { config: config, beforeConnect: config.BeforeConnect, afterConnect: config.AfterConnect, + autoLoadTypeNames: config.AutoLoadTypes, + reuseTypeMap: config.ReuseTypeMaps, beforeAcquire: config.BeforeAcquire, afterRelease: config.AfterRelease, beforeClose: config.BeforeClose, @@ -196,6 +223,7 @@ func NewWithConfig(ctx context.Context, config *Config) (*Pool, error) { healthCheckPeriod: config.HealthCheckPeriod, healthCheckChan: make(chan struct{}, 1), closeChan: make(chan struct{}), + autoLoadMutex: new(sync.Mutex), } if t, ok := config.ConnConfig.Tracer.(AcquireTracer); ok { @@ -237,6 +265,15 @@ func NewWithConfig(ctx context.Context, config *Config) (*Pool, error) { } } + if p.autoLoadTypeNames != nil && len(p.autoLoadTypeNames) > 0 { + types, err := p.loadTypes(ctx, conn, p.autoLoadTypeNames) + if err != nil { + conn.Close(ctx) + return nil, err + } + conn.TypeMap().RegisterTypes(types) + } + jitterSecs := rand.Float64() * config.MaxConnLifetimeJitter.Seconds() maxAgeTime := time.Now().Add(config.MaxConnLifetime).Add(time.Duration(jitterSecs) * time.Second) @@ -388,6 +425,27 @@ func (p *Pool) Close() { }) } +// loadTypes is used internally to autoload the custom types for a connection, +// potentially reusing previously-loaded typemap information. +func (p *Pool) loadTypes(ctx context.Context, conn *pgx.Conn, typeNames []string) ([]*pgtype.Type, error) { + if p.reuseTypeMap { + p.autoLoadMutex.Lock() + defer p.autoLoadMutex.Unlock() + if p.autoLoadTypes != nil { + return p.autoLoadTypes, nil + } + types, err := pgx.LoadTypes(ctx, conn, typeNames) + if err != nil { + return nil, err + } + p.autoLoadTypes = types + return types, err + } + // Avoid needing to acquire the mutex and allow connections to initialise in parallel + // if we have chosen to not reuse the type mapping + return pgx.LoadTypes(ctx, conn, typeNames) +} + func (p *Pool) isExpired(res *puddle.Resource[*connResource]) bool { return time.Now().After(res.Value().maxAgeTime) } @@ -482,7 +540,6 @@ func (p *Pool) checkMinConns() error { func (p *Pool) createIdleResources(parentCtx context.Context, targetResources int) error { ctx, cancel := context.WithCancel(parentCtx) defer cancel() - errs := make(chan error, targetResources) for i := 0; i < targetResources; i++ { @@ -495,7 +552,6 @@ func (p *Pool) createIdleResources(parentCtx context.Context, targetResources in errs <- err }() } - var firstError error for i := 0; i < targetResources; i++ { err := <-errs diff --git a/pgxpool/pool_test.go b/pgxpool/pool_test.go index 90428931b..445ae0c33 100644 --- a/pgxpool/pool_test.go +++ b/pgxpool/pool_test.go @@ -261,6 +261,35 @@ func TestPoolBeforeConnect(t *testing.T) { assert.EqualValues(t, "pgx", str) } +func TestAutoLoadTypes(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + config, err := pgxpool.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + + controllerConn, err := pgx.Connect(ctx, os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + defer controllerConn.Close(ctx) + pgxtest.SkipCockroachDB(t, controllerConn, "Server does not support autoloading of uint64") + db1, err := pgxpool.NewWithConfig(ctx, config) + require.NoError(t, err) + defer db1.Close() + db1.Exec(ctx, "DROP DOMAIN IF EXISTS autoload_uint64; CREATE DOMAIN autoload_uint64 as numeric(20,0)") + defer db1.Exec(ctx, "DROP DOMAIN autoload_uint64") + + config.AutoLoadTypes = []string{"autoload_uint64"} + db2, err := pgxpool.NewWithConfig(ctx, config) + require.NoError(t, err) + + var n uint64 + err = db2.QueryRow(ctx, "select 12::autoload_uint64").Scan(&n) + require.NoError(t, err) + assert.EqualValues(t, uint64(12), n) +} + func TestPoolAfterConnect(t *testing.T) { t.Parallel() @@ -676,7 +705,6 @@ func TestPoolQuery(t *testing.T) { stats = pool.Stat() assert.EqualValues(t, 0, stats.AcquiredConns()) assert.EqualValues(t, 1, stats.TotalConns()) - } func TestPoolQueryRow(t *testing.T) { @@ -1104,7 +1132,6 @@ func TestConnectEagerlyReachesMinPoolSize(t *testing.T) { } t.Fatal("did not reach min pool size") - } func TestPoolSendBatchBatchCloseTwice(t *testing.T) {