From 4c1fda091bb8d40bd81c657de9959a8df179b633 Mon Sep 17 00:00:00 2001 From: mateuszkowalke Date: Thu, 22 Aug 2024 06:51:58 +0200 Subject: [PATCH 01/27] Add additional info for nullable pgtype types Additional information warns about using nullable types being used as parameters to query with Valid set to false. --- pgtype/doc.go | 3 +++ 1 file changed, 3 insertions(+) diff --git a/pgtype/doc.go b/pgtype/doc.go index d56c1dc70..7687ea8fe 100644 --- a/pgtype/doc.go +++ b/pgtype/doc.go @@ -53,6 +53,9 @@ similar fashion to database/sql. The second is to use a pointer to a pointer. return err } +When using nullable pgtype types as parameters for queries, one has to remember +to explicitly set their Valid field to true, otherwise the parameter's value will be NULL. + JSON Support pgtype automatically marshals and unmarshals data from json and jsonb PostgreSQL types. From 811b5014da99153b05f76c3196f21fd16c4a7662 Mon Sep 17 00:00:00 2001 From: jennifersp Date: Fri, 23 Aug 2024 16:17:07 -0700 Subject: [PATCH 02/27] add byte length check to uint32 --- pgtype/uint32.go | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/pgtype/uint32.go b/pgtype/uint32.go index c01341099..f2b2fa6d4 100644 --- a/pgtype/uint32.go +++ b/pgtype/uint32.go @@ -296,6 +296,10 @@ func (scanPlanBinaryUint32ToTextScanner) Scan(src []byte, dst any) error { return s.ScanText(Text{}) } + if len(src) != 4 { + return fmt.Errorf("invalid length for uint32: %v", len(src)) + } + n := uint64(binary.BigEndian.Uint32(src)) return s.ScanText(Text{String: strconv.FormatUint(n, 10), Valid: true}) } From 1a30a6255785676d8b58a77eb295a9a236aac6c0 Mon Sep 17 00:00:00 2001 From: merlin Date: Mon, 26 Aug 2024 14:01:37 +0300 Subject: [PATCH 03/27] Use sql.ErrNoRows as value for pgx.ErrNoRows --- conn.go | 19 ++++++++++++++++++- conn_test.go | 11 +++++++++++ 2 files changed, 29 insertions(+), 1 deletion(-) diff --git a/conn.go b/conn.go index 979095af4..187b3dd57 100644 --- a/conn.go +++ b/conn.go @@ -3,6 +3,7 @@ package pgx import ( "context" "crypto/sha256" + "database/sql" "encoding/hex" "errors" "fmt" @@ -102,11 +103,27 @@ func (ident Identifier) Sanitize() string { var ( // ErrNoRows occurs when rows are expected but none are returned. - ErrNoRows = errors.New("no rows in result set") + ErrNoRows = newProxyErr(sql.ErrNoRows, "no rows in result set") // ErrTooManyRows occurs when more rows than expected are returned. ErrTooManyRows = errors.New("too many rows in result set") ) +func newProxyErr(background error, msg string) error { + return &proxyError{ + msg: msg, + background: background, + } +} + +type proxyError struct { + msg string + background error +} + +func (err *proxyError) Error() string { return err.msg } + +func (err *proxyError) Unwrap() error { return err.background } + var ( errDisabledStatementCache = fmt.Errorf("cannot use QueryExecModeCacheStatement with disabled statement cache") errDisabledDescriptionCache = fmt.Errorf("cannot use QueryExecModeCacheDescribe with disabled description cache") diff --git a/conn_test.go b/conn_test.go index df8c9186f..200ecc1a6 100644 --- a/conn_test.go +++ b/conn_test.go @@ -3,6 +3,7 @@ package pgx_test import ( "bytes" "context" + "database/sql" "os" "strings" "sync" @@ -1408,3 +1409,13 @@ func TestConnDeallocateInvalidatedCachedStatementsInTransactionWithBatch(t *test ensureConnValid(t, conn) } + +func TestErrNoRows(t *testing.T) { + t.Parallel() + + // ensure we preserve old error message + require.Equal(t, "no rows in result set", pgx.ErrNoRows.Error()) + + require.ErrorIs(t, pgx.ErrNoRows, sql.ErrNoRows, "pgx.ErrNowRows must match sql.ErrNoRows") + require.ErrorIs(t, pgx.ErrNoRows, pgx.ErrNoRows, "sql.ErrNowRows must match pgx.ErrNoRows") +} From 19f19941027778c69d90fb4bc2a1d250c3956e05 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 7 Sep 2024 10:20:08 -0500 Subject: [PATCH 04/27] Release v5.7.0 --- CHANGELOG.md | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 61b4695fd..72e216d67 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,19 @@ +# 5.7.0 (September 2024) + +* Add support for sslrootcert=system (Yann Soubeyrand) +* Add LoadTypes to load multiple types in a single SQL query (Nick Farrell) +* Add XMLCodec supports encoding + scanning XML column type like json (nickcruess-soda) +* Add MultiTrace (Stepan Rabotkin) +* Add TraceLogConfig with customizable TimeKey (stringintech) +* pgx.ErrNoRows wraps sql.ErrNoRows to aid in database/sql compatibility with native pgx functions (merlin) +* Support scanning binary formatted uint32 into string / TextScanner (jennifersp) +* Fix interval encoding to allow 0s and avoid extra spaces (Carlos PĂ©rez-Aradros Herce) +* Update pgservicefile - fixes panic when parsing invalid file +* Better error message when reading past end of batch +* Don't print url when url.Parse returns an error (Kevin Biju) +* Fix snake case name normalization collision in RowToStructByName with db tag (nolandseigler) +* Fix: Scan and encode types with underlying types of arrays + # 5.6.0 (May 25, 2024) * Add StrictNamedArgs (Tomas Zahradnicek) From da513455d796b05ae417a018ffb1520a0818dfae Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Tue, 10 Sep 2024 07:06:39 -0500 Subject: [PATCH 05/27] Fix data race with TraceLog.Config initialization https://github.com/jackc/pgx/pull/2120 --- go.mod | 2 +- go.sum | 4 +-- tracelog/tracelog.go | 19 +++++++++----- tracelog/tracelog_test.go | 53 +++++++++++++++++++++++++++++++++++++++ 4 files changed, 69 insertions(+), 9 deletions(-) diff --git a/go.mod b/go.mod index 0d952d691..20e948a3d 100644 --- a/go.mod +++ b/go.mod @@ -8,6 +8,7 @@ require ( github.com/jackc/puddle/v2 v2.2.1 github.com/stretchr/testify v1.8.1 golang.org/x/crypto v0.17.0 + golang.org/x/sync v0.8.0 golang.org/x/text v0.14.0 ) @@ -15,7 +16,6 @@ require ( github.com/davecgh/go-spew v1.1.1 // indirect github.com/kr/pretty v0.3.0 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect - golang.org/x/sync v0.1.0 // indirect gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index 29fe452b2..c5bdd043d 100644 --- a/go.sum +++ b/go.sum @@ -31,8 +31,8 @@ github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKs github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= golang.org/x/crypto v0.17.0 h1:r8bRNjWL3GshPW3gkd+RpvzWrZAwPS49OmTGZ/uhM4k= golang.org/x/crypto v0.17.0/go.mod h1:gCAAfMLgwOJRpTjQ2zCCt2OcSfYMTeZVSRtQlPC7Nq4= -golang.org/x/sync v0.1.0 h1:wsuoTGHzEhffawBOhz5CYhcrV4IdKZbEyZjBMuTp12o= -golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.8.0 h1:3NFvSEYkUoMifnESzZl15y791HH1qU2xm6eCJU5ZPXQ= +golang.org/x/sync v0.8.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ= golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= diff --git a/tracelog/tracelog.go b/tracelog/tracelog.go index 212a25407..61015354a 100644 --- a/tracelog/tracelog.go +++ b/tracelog/tracelog.go @@ -6,6 +6,7 @@ import ( "encoding/hex" "errors" "fmt" + "sync" "time" "unicode/utf8" @@ -129,19 +130,25 @@ func DefaultTraceLogConfig() *TraceLogConfig { } } -// TraceLog implements pgx.QueryTracer, pgx.BatchTracer, pgx.ConnectTracer, and pgx.CopyFromTracer. All fields are -// required. +// TraceLog implements pgx.QueryTracer, pgx.BatchTracer, pgx.ConnectTracer, and pgx.CopyFromTracer. Logger and LogLevel +// are required. Config will be automatically initialized on first use if nil. type TraceLog struct { Logger Logger LogLevel LogLevel - Config *TraceLogConfig + + Config *TraceLogConfig + ensureConfigOnce sync.Once } // ensureConfig initializes the Config field with default values if it is nil. func (tl *TraceLog) ensureConfig() { - if tl.Config == nil { - tl.Config = DefaultTraceLogConfig() - } + tl.ensureConfigOnce.Do( + func() { + if tl.Config == nil { + tl.Config = DefaultTraceLogConfig() + } + }, + ) } type ctxKey int diff --git a/tracelog/tracelog_test.go b/tracelog/tracelog_test.go index 6812a97ba..f1959b324 100644 --- a/tracelog/tracelog_test.go +++ b/tracelog/tracelog_test.go @@ -6,14 +6,17 @@ import ( "log" "os" "strings" + "sync" "testing" "time" "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgxpool" "github.com/jackc/pgx/v5/pgxtest" "github.com/jackc/pgx/v5/tracelog" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "golang.org/x/sync/errgroup" ) var defaultConnTestRunner pgxtest.ConnTestRunner @@ -35,18 +38,29 @@ type testLog struct { type testLogger struct { logs []testLog + + mux sync.Mutex } func (l *testLogger) Log(ctx context.Context, level tracelog.LogLevel, msg string, data map[string]any) { + l.mux.Lock() + defer l.mux.Unlock() + data["ctxdata"] = ctx.Value("ctxdata") l.logs = append(l.logs, testLog{lvl: level, msg: msg, data: data}) } func (l *testLogger) Clear() { + l.mux.Lock() + defer l.mux.Unlock() + l.logs = l.logs[0:0] } func (l *testLogger) FilterByMsg(msg string) (res []testLog) { + l.mux.Lock() + defer l.mux.Unlock() + for _, log := range l.logs { if log.msg == msg { res = append(res, log) @@ -457,3 +471,42 @@ func TestLogPrepare(t *testing.T) { require.Equal(t, err, logger.logs[0].data["err"]) }) } + +// https://github.com/jackc/pgx/pull/2120 +func TestConcurrentUsage(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + logger := &testLogger{} + tracer := &tracelog.TraceLog{ + Logger: logger, + LogLevel: tracelog.LogLevelTrace, + } + + config, err := pgxpool.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + config.ConnConfig.Tracer = tracer + + for i := 0; i < 50; i++ { + func() { + pool, err := pgxpool.NewWithConfig(ctx, config) + require.NoError(t, err) + + defer pool.Close() + + eg := errgroup.Group{} + + for i := 0; i < 5; i++ { + eg.Go(func() error { + _, err := pool.Exec(ctx, `select 1`) + return err + }) + } + + err = eg.Wait() + require.NoError(t, err) + }() + } +} From 513a53fdd8e2a791717ea72174dd4b5436815a47 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Tue, 10 Sep 2024 07:11:44 -0500 Subject: [PATCH 06/27] Upgrade puddle to v2.2.2 This removes the import of nanotime via linkname. --- go.mod | 2 +- go.sum | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/go.mod b/go.mod index 20e948a3d..efd9a00c0 100644 --- a/go.mod +++ b/go.mod @@ -5,7 +5,7 @@ go 1.21 require ( github.com/jackc/pgpassfile v1.0.0 github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 - github.com/jackc/puddle/v2 v2.2.1 + github.com/jackc/puddle/v2 v2.2.2 github.com/stretchr/testify v1.8.1 golang.org/x/crypto v0.17.0 golang.org/x/sync v0.8.0 diff --git a/go.sum b/go.sum index c5bdd043d..7845266bb 100644 --- a/go.sum +++ b/go.sum @@ -8,6 +8,8 @@ github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7Ulw github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM= github.com/jackc/puddle/v2 v2.2.1 h1:RhxXJtFG022u4ibrCSMSiu5aOq1i77R3OHKNJj77OAk= github.com/jackc/puddle/v2 v2.2.1/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4= +github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo= +github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4= github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= github.com/kr/pretty v0.3.0 h1:WgNl7dwNpEZ6jJ9k1snq4pZsg7DOEN8hP9Xw0Tsjwk0= From be673157acc88b0fb428cd4b853b1efb0ecdaf95 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Tue, 10 Sep 2024 07:17:03 -0500 Subject: [PATCH 07/27] Update golang.org/x/crypto and golang.org/x/text --- go.mod | 4 ++-- go.sum | 10 ++++------ 2 files changed, 6 insertions(+), 8 deletions(-) diff --git a/go.mod b/go.mod index efd9a00c0..100a90e8b 100644 --- a/go.mod +++ b/go.mod @@ -7,9 +7,9 @@ require ( github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 github.com/jackc/puddle/v2 v2.2.2 github.com/stretchr/testify v1.8.1 - golang.org/x/crypto v0.17.0 + golang.org/x/crypto v0.27.0 golang.org/x/sync v0.8.0 - golang.org/x/text v0.14.0 + golang.org/x/text v0.18.0 ) require ( diff --git a/go.sum b/go.sum index 7845266bb..ea86f04e5 100644 --- a/go.sum +++ b/go.sum @@ -6,8 +6,6 @@ github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsI github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo= github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM= -github.com/jackc/puddle/v2 v2.2.1 h1:RhxXJtFG022u4ibrCSMSiu5aOq1i77R3OHKNJj77OAk= -github.com/jackc/puddle/v2 v2.2.1/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4= github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo= github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4= github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= @@ -31,12 +29,12 @@ github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/ github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk= github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= -golang.org/x/crypto v0.17.0 h1:r8bRNjWL3GshPW3gkd+RpvzWrZAwPS49OmTGZ/uhM4k= -golang.org/x/crypto v0.17.0/go.mod h1:gCAAfMLgwOJRpTjQ2zCCt2OcSfYMTeZVSRtQlPC7Nq4= +golang.org/x/crypto v0.27.0 h1:GXm2NjJrPaiv/h1tb2UH8QfgC/hOf/+z0p6PT8o1w7A= +golang.org/x/crypto v0.27.0/go.mod h1:1Xngt8kV6Dvbssa53Ziq6Eqn0HqbZi5Z6R0ZpwQzt70= golang.org/x/sync v0.8.0 h1:3NFvSEYkUoMifnESzZl15y791HH1qU2xm6eCJU5ZPXQ= golang.org/x/sync v0.8.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= -golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ= -golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= +golang.org/x/text v0.18.0 h1:XvMDiNzPAl0jr17s6W9lcaIhGUfUORdGCNsuLmPG224= +golang.org/x/text v0.18.0/go.mod h1:BuEKDfySbSR4drPmRPG/7iBdf8hvFMuRexcpahXilzY= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= From e400c5e7c62e706fb110a58fab52449da83726a4 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Tue, 10 Sep 2024 07:25:07 -0500 Subject: [PATCH 08/27] Release v5.7.1 --- CHANGELOG.md | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 72e216d67..a0ff9ba3b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,4 +1,10 @@ -# 5.7.0 (September 2024) +# 5.7.1 (September 10, 2024) + +* Fix data race in tracelog.TraceLog +* Update puddle to v2.2.2. This removes the import of nanotime via linkname. +* Update golang.org/x/crypto and golang.org/x/text + +# 5.7.0 (September 7, 2024) * Add support for sslrootcert=system (Yann Soubeyrand) * Add LoadTypes to load multiple types in a single SQL query (Nick Farrell) From 808da0613a098ed1aae7b126041938615ea217d1 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Fri, 13 Sep 2024 08:03:37 -0500 Subject: [PATCH 09/27] Fix prepared statement already exists on batch prepare failure When a batch successfully prepared some statements, but then failed to prepare others, the prepared statements that were successfully prepared were not properly cleaned up. This could lead to a "prepared statement already exists" error on subsequent attempts to prepare the same statement. https://github.com/jackc/pgx/issues/1847#issuecomment-2347858887 --- batch_test.go | 30 +++++++++++++++++++++ conn.go | 75 +++++++++++++++++++++++++++++++-------------------- 2 files changed, 76 insertions(+), 29 deletions(-) diff --git a/batch_test.go b/batch_test.go index eb560e068..b1bc25de6 100644 --- a/batch_test.go +++ b/batch_test.go @@ -1008,6 +1008,36 @@ func TestSendBatchSimpleProtocol(t *testing.T) { assert.False(t, rows.Next()) } +// https://github.com/jackc/pgx/issues/1847#issuecomment-2347858887 +func TestConnSendBatchErrorDoesNotLeaveOrphanedPreparedStatement(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + pgxtest.SkipCockroachDB(t, conn, "Server serial type is incompatible with test") + + mustExec(t, conn, `create temporary table foo(col1 text primary key);`) + + batch := &pgx.Batch{} + batch.Queue("select col1 from foo") + batch.Queue("select col1 from baz") + err := conn.SendBatch(ctx, batch).Close() + require.EqualError(t, err, `ERROR: relation "baz" does not exist (SQLSTATE 42P01)`) + + mustExec(t, conn, `create temporary table baz(col1 text primary key);`) + + // Since table baz now exists, the batch should succeed. + + batch = &pgx.Batch{} + batch.Queue("select col1 from foo") + batch.Queue("select col1 from baz") + err = conn.SendBatch(ctx, batch).Close() + require.NoError(t, err) + }) +} + func ExampleConn_SendBatch() { ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) defer cancel() diff --git a/conn.go b/conn.go index 187b3dd57..1d4c414fb 100644 --- a/conn.go +++ b/conn.go @@ -1126,47 +1126,64 @@ func (c *Conn) sendBatchExtendedWithDescription(ctx context.Context, b *Batch, d // Prepare any needed queries if len(distinctNewQueries) > 0 { - for _, sd := range distinctNewQueries { - pipeline.SendPrepare(sd.Name, sd.SQL, nil) - } + err := func() (err error) { + for _, sd := range distinctNewQueries { + pipeline.SendPrepare(sd.Name, sd.SQL, nil) + } - err := pipeline.Sync() - if err != nil { - return &pipelineBatchResults{ctx: ctx, conn: c, err: err, closed: true} - } + // Store all statements we are preparing into the cache. It's fine if it overflows because HandleInvalidated will + // clean them up later. + if sdCache != nil { + for _, sd := range distinctNewQueries { + sdCache.Put(sd) + } + } + + // If something goes wrong preparing the statements, we need to invalidate the cache entries we just added. + defer func() { + if err != nil && sdCache != nil { + for _, sd := range distinctNewQueries { + sdCache.Invalidate(sd.SQL) + } + } + }() + + err = pipeline.Sync() + if err != nil { + return err + } + + for _, sd := range distinctNewQueries { + results, err := pipeline.GetResults() + if err != nil { + return err + } + + resultSD, ok := results.(*pgconn.StatementDescription) + if !ok { + return fmt.Errorf("expected statement description, got %T", results) + } + + // Fill in the previously empty / pending statement descriptions. + sd.ParamOIDs = resultSD.ParamOIDs + sd.Fields = resultSD.Fields + } - for _, sd := range distinctNewQueries { results, err := pipeline.GetResults() if err != nil { - return &pipelineBatchResults{ctx: ctx, conn: c, err: err, closed: true} + return err } - resultSD, ok := results.(*pgconn.StatementDescription) + _, ok := results.(*pgconn.PipelineSync) if !ok { - return &pipelineBatchResults{ctx: ctx, conn: c, err: fmt.Errorf("expected statement description, got %T", results), closed: true} + return fmt.Errorf("expected sync, got %T", results) } - // Fill in the previously empty / pending statement descriptions. - sd.ParamOIDs = resultSD.ParamOIDs - sd.Fields = resultSD.Fields - } - - results, err := pipeline.GetResults() + return nil + }() if err != nil { return &pipelineBatchResults{ctx: ctx, conn: c, err: err, closed: true} } - - _, ok := results.(*pgconn.PipelineSync) - if !ok { - return &pipelineBatchResults{ctx: ctx, conn: c, err: fmt.Errorf("expected sync, got %T", results), closed: true} - } - } - - // Put all statements into the cache. It's fine if it overflows because HandleInvalidated will clean them up later. - if sdCache != nil { - for _, sd := range distinctNewQueries { - sdCache.Put(sd) - } } // Queue the queries. From fd6496fab8fbfce0e660309889399e94ee47ade2 Mon Sep 17 00:00:00 2001 From: Shean de Montigny-Desautels Date: Mon, 23 Sep 2024 12:17:45 -0400 Subject: [PATCH 10/27] Fix pgtype.Timestamp json unmarshal Add the missing 'Z' at the end of the timestamp string, so it can be parsed as timestamp in the RFC3339 format. --- pgtype/timestamp.go | 5 ++--- pgtype/timestamp_test.go | 4 ++-- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/pgtype/timestamp.go b/pgtype/timestamp.go index 677a2c6ea..ff2739d6b 100644 --- a/pgtype/timestamp.go +++ b/pgtype/timestamp.go @@ -104,8 +104,8 @@ func (ts *Timestamp) UnmarshalJSON(b []byte) error { case "-infinity": *ts = Timestamp{Valid: true, InfinityModifier: -Infinity} default: - // PostgreSQL uses ISO 8601 for to_json function and casting from a string to timestamptz - tim, err := time.Parse(time.RFC3339Nano, *s) + // PostgreSQL uses ISO 8601 wihout timezone for to_json function and casting from a string to timestampt + tim, err := time.Parse(time.RFC3339Nano, *s+"Z") if err != nil { return err } @@ -225,7 +225,6 @@ func discardTimeZone(t time.Time) time.Time { } func (c *TimestampCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { - switch format { case BinaryFormatCode: switch target.(type) { diff --git a/pgtype/timestamp_test.go b/pgtype/timestamp_test.go index 31b3ad822..345da819c 100644 --- a/pgtype/timestamp_test.go +++ b/pgtype/timestamp_test.go @@ -128,8 +128,8 @@ func TestTimestampUnmarshalJSON(t *testing.T) { result pgtype.Timestamp }{ {source: "null", result: pgtype.Timestamp{}}, - {source: "\"2012-03-29T10:05:45Z\"", result: pgtype.Timestamp{Time: time.Date(2012, 3, 29, 10, 5, 45, 0, time.UTC), Valid: true}}, - {source: "\"2012-03-29T10:05:45.555Z\"", result: pgtype.Timestamp{Time: time.Date(2012, 3, 29, 10, 5, 45, 555*1000*1000, time.UTC), Valid: true}}, + {source: "\"2012-03-29T10:05:45\"", result: pgtype.Timestamp{Time: time.Date(2012, 3, 29, 10, 5, 45, 0, time.UTC), Valid: true}}, + {source: "\"2012-03-29T10:05:45.555\"", result: pgtype.Timestamp{Time: time.Date(2012, 3, 29, 10, 5, 45, 555*1000*1000, time.UTC), Valid: true}}, {source: "\"infinity\"", result: pgtype.Timestamp{InfinityModifier: pgtype.Infinity, Valid: true}}, {source: "\"-infinity\"", result: pgtype.Timestamp{InfinityModifier: pgtype.NegativeInfinity, Valid: true}}, } From 21392a20093175fef9478a540f1c541dbdc47cb5 Mon Sep 17 00:00:00 2001 From: merlin Date: Tue, 1 Oct 2024 12:52:55 +0300 Subject: [PATCH 11/27] base case make benchmark more extensive add quote to string add BenchmarkSanitizeSQL --- internal/sanitize/sanitize_bench_test.go | 62 ++++++++++++++++++++++++ 1 file changed, 62 insertions(+) create mode 100644 internal/sanitize/sanitize_bench_test.go diff --git a/internal/sanitize/sanitize_bench_test.go b/internal/sanitize/sanitize_bench_test.go new file mode 100644 index 000000000..baa742b11 --- /dev/null +++ b/internal/sanitize/sanitize_bench_test.go @@ -0,0 +1,62 @@ +// sanitize_benchmark_test.go +package sanitize_test + +import ( + "testing" + "time" + + "github.com/jackc/pgx/v5/internal/sanitize" +) + +var benchmarkSanitizeResult string + +const benchmarkQuery = "" + + `SELECT * + FROM "water_containers" + WHERE NOT "id" = $1 -- int64 + AND "tags" NOT IN $2 -- nil + AND "volume" > $3 -- float64 + AND "transportable" = $4 -- bool + AND position($5 IN "sign") -- bytes + AND "label" LIKE $6 -- string + AND "created_at" > $7; -- time.Time` + +var benchmarkArgs = []any{ + int64(12345), + nil, + float64(500), + true, + []byte("8BADF00D"), + "kombucha's han'dy awokowa", + time.Date(2015, 10, 1, 0, 0, 0, 0, time.UTC), +} + +func BenchmarkSanitize(b *testing.B) { + query, err := sanitize.NewQuery(benchmarkQuery) + if err != nil { + b.Fatalf("failed to create query: %v", err) + } + + b.ResetTimer() + b.ReportAllocs() + + for i := 0; i < b.N; i++ { + benchmarkSanitizeResult, err = query.Sanitize(benchmarkArgs...) + if err != nil { + b.Fatalf("failed to sanitize query: %v", err) + } + } +} + +var benchmarkNewSQLResult string + +func BenchmarkSanitizeSQL(b *testing.B) { + b.ReportAllocs() + var err error + for i := 0; i < b.N; i++ { + benchmarkNewSQLResult, err = sanitize.SanitizeSQL(benchmarkQuery, benchmarkArgs...) + if err != nil { + b.Fatalf("failed to sanitize SQL: %v", err) + } + } +} From d8d0cabb060a190b727d39b717225b23fa8901e3 Mon Sep 17 00:00:00 2001 From: merlin Date: Tue, 1 Oct 2024 12:57:07 +0300 Subject: [PATCH 12/27] add benchmark tool fix benchmmark script fix benchmark script --- internal/sanitize/benchmmark.sh | 59 +++++++++++++++++++++++++++++++++ 1 file changed, 59 insertions(+) create mode 100644 internal/sanitize/benchmmark.sh diff --git a/internal/sanitize/benchmmark.sh b/internal/sanitize/benchmmark.sh new file mode 100644 index 000000000..87e7e0a11 --- /dev/null +++ b/internal/sanitize/benchmmark.sh @@ -0,0 +1,59 @@ +#!/usr/bin/env bash + +current_branch=$(git rev-parse --abbrev-ref HEAD) +if [ "$current_branch" == "HEAD" ]; then + current_branch=$(git rev-parse HEAD) +fi + +restore_branch() { + echo "Restoring original branch/commit: $current_branch" + git checkout "$current_branch" +} +trap restore_branch EXIT + +# Check if there are uncommitted changes +if ! git diff --quiet || ! git diff --cached --quiet; then + echo "There are uncommitted changes. Please commit or stash them before running this script." + exit 1 +fi + +# Ensure that at least one commit argument is passed +if [ "$#" -lt 1 ]; then + echo "Usage: $0 ... " + exit 1 +fi + +commits=("$@") +benchmarks_dir=benchmarks + +if ! mkdir -p "${benchmarks_dir}"; then + echo "Unable to create dir for benchmarks data" + exit 1 +fi + +# Benchmark results +bench_files=() + +# Run benchmark for each listed commit +for i in "${!commits[@]}"; do + commit="${commits[i]}" + git checkout "$commit" || { + echo "Failed to checkout $commit" + exit 1 + } + + # Sanitized commmit message + commit_message=$(git log -1 --pretty=format:"%s" | tr ' ' '_') + + # Benchmark data will go there + bench_file="${benchmarks_dir}/${i}_${commit_message}.bench" + + if ! go test -bench=. -count=25 >"$bench_file"; then + echo "Benchmarking failed for commit $commit" + exit 1 + fi + + bench_files+=("$bench_file") +done + +benchstat "${bench_files[@]}" From 9435a2c1984ccf2a39b1c60ff8c147616bd0f70c Mon Sep 17 00:00:00 2001 From: merlin Date: Tue, 1 Oct 2024 12:53:07 +0300 Subject: [PATCH 13/27] buf pool --- internal/sanitize/sanitize.go | 20 +++++++++++++++++++- 1 file changed, 19 insertions(+), 1 deletion(-) diff --git a/internal/sanitize/sanitize.go b/internal/sanitize/sanitize.go index df58c4484..4a069658b 100644 --- a/internal/sanitize/sanitize.go +++ b/internal/sanitize/sanitize.go @@ -6,6 +6,7 @@ import ( "fmt" "strconv" "strings" + "sync" "time" "unicode/utf8" ) @@ -24,9 +25,26 @@ type Query struct { // https://github.com/jackc/pgx/issues/1380 const replacementcharacterwidth = 3 +var bufPool = &sync.Pool{} + +func getBuf() *bytes.Buffer { + buf, _ := bufPool.Get().(*bytes.Buffer) + if buf == nil { + buf = &bytes.Buffer{} + } + + return buf +} + +func putBuf(buf *bytes.Buffer) { + buf.Reset() + bufPool.Put(buf) +} + func (q *Query) Sanitize(args ...any) (string, error) { argUse := make([]bool, len(args)) - buf := &bytes.Buffer{} + buf := getBuf() + defer putBuf(buf) for _, part := range q.Parts { var str string From 4f4e892b4bd6f7a876e9fd28b5356b300b1e5932 Mon Sep 17 00:00:00 2001 From: merlin Date: Tue, 1 Oct 2024 13:24:03 +0300 Subject: [PATCH 14/27] shared bytestring --- internal/sanitize/sanitize.go | 26 +++++++++++++++----------- 1 file changed, 15 insertions(+), 11 deletions(-) diff --git a/internal/sanitize/sanitize.go b/internal/sanitize/sanitize.go index 4a069658b..c7c8acd59 100644 --- a/internal/sanitize/sanitize.go +++ b/internal/sanitize/sanitize.go @@ -41,16 +41,19 @@ func putBuf(buf *bytes.Buffer) { bufPool.Put(buf) } +var null = []byte("null") + func (q *Query) Sanitize(args ...any) (string, error) { argUse := make([]bool, len(args)) buf := getBuf() defer putBuf(buf) + var p []byte for _, part := range q.Parts { - var str string + p = p[:0] switch part := part.(type) { case string: - str = part + buf.WriteString(part) case int: argIdx := part - 1 @@ -64,19 +67,19 @@ func (q *Query) Sanitize(args ...any) (string, error) { arg := args[argIdx] switch arg := arg.(type) { case nil: - str = "null" + p = null case int64: - str = strconv.FormatInt(arg, 10) + p = strconv.AppendInt(p, arg, 10) case float64: - str = strconv.FormatFloat(arg, 'f', -1, 64) + p = strconv.AppendFloat(p, arg, 'f', -1, 64) case bool: - str = strconv.FormatBool(arg) + p = strconv.AppendBool(p, arg) case []byte: - str = QuoteBytes(arg) + p = []byte(QuoteBytes(arg)) case string: - str = QuoteString(arg) + p = []byte(QuoteString(arg)) case time.Time: - str = arg.Truncate(time.Microsecond).Format("'2006-01-02 15:04:05.999999999Z07:00:00'") + p = arg.Truncate(time.Microsecond).AppendFormat(p, "'2006-01-02 15:04:05.999999999Z07:00:00'") default: return "", fmt.Errorf("invalid arg type: %T", arg) } @@ -84,11 +87,12 @@ func (q *Query) Sanitize(args ...any) (string, error) { // Prevent SQL injection via Line Comment Creation // https://github.com/jackc/pgx/security/advisories/GHSA-m7wr-2xf7-cm9p - str = " " + str + " " + buf.WriteByte(' ') + buf.Write(p) + buf.WriteByte(' ') default: return "", fmt.Errorf("invalid Part type: %T", part) } - buf.WriteString(str) } for i, used := range argUse { From 39db71a0d565a408be4c93195e3378202c9795c8 Mon Sep 17 00:00:00 2001 From: merlin Date: Tue, 1 Oct 2024 13:30:34 +0300 Subject: [PATCH 15/27] append AvailableBuffer --- internal/sanitize/sanitize.go | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/internal/sanitize/sanitize.go b/internal/sanitize/sanitize.go index c7c8acd59..1e0b20aca 100644 --- a/internal/sanitize/sanitize.go +++ b/internal/sanitize/sanitize.go @@ -47,16 +47,14 @@ func (q *Query) Sanitize(args ...any) (string, error) { argUse := make([]bool, len(args)) buf := getBuf() defer putBuf(buf) - var p []byte for _, part := range q.Parts { - p = p[:0] switch part := part.(type) { case string: buf.WriteString(part) case int: argIdx := part - 1 - + var p []byte if argIdx < 0 { return "", fmt.Errorf("first sql argument must be > 0") } @@ -64,22 +62,23 @@ func (q *Query) Sanitize(args ...any) (string, error) { if argIdx >= len(args) { return "", fmt.Errorf("insufficient arguments") } + buf.WriteByte(' ') arg := args[argIdx] switch arg := arg.(type) { case nil: p = null case int64: - p = strconv.AppendInt(p, arg, 10) + p = strconv.AppendInt(buf.AvailableBuffer(), arg, 10) case float64: - p = strconv.AppendFloat(p, arg, 'f', -1, 64) + p = strconv.AppendFloat(buf.AvailableBuffer(), arg, 'f', -1, 64) case bool: - p = strconv.AppendBool(p, arg) + p = strconv.AppendBool(buf.AvailableBuffer(), arg) case []byte: p = []byte(QuoteBytes(arg)) case string: p = []byte(QuoteString(arg)) case time.Time: - p = arg.Truncate(time.Microsecond).AppendFormat(p, "'2006-01-02 15:04:05.999999999Z07:00:00'") + p = arg.Truncate(time.Microsecond).AppendFormat(buf.AvailableBuffer(), "'2006-01-02 15:04:05.999999999Z07:00:00'") default: return "", fmt.Errorf("invalid arg type: %T", arg) } @@ -87,7 +86,6 @@ func (q *Query) Sanitize(args ...any) (string, error) { // Prevent SQL injection via Line Comment Creation // https://github.com/jackc/pgx/security/advisories/GHSA-m7wr-2xf7-cm9p - buf.WriteByte(' ') buf.Write(p) buf.WriteByte(' ') default: From e142286b937f86e777e0eab8171265fd25e59511 Mon Sep 17 00:00:00 2001 From: merlin Date: Tue, 1 Oct 2024 13:47:44 +0300 Subject: [PATCH 16/27] docs --- internal/sanitize/sanitize.go | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/internal/sanitize/sanitize.go b/internal/sanitize/sanitize.go index 1e0b20aca..3414d6d1a 100644 --- a/internal/sanitize/sanitize.go +++ b/internal/sanitize/sanitize.go @@ -62,7 +62,11 @@ func (q *Query) Sanitize(args ...any) (string, error) { if argIdx >= len(args) { return "", fmt.Errorf("insufficient arguments") } + + // Prevent SQL injection via Line Comment Creation + // https://github.com/jackc/pgx/security/advisories/GHSA-m7wr-2xf7-cm9p buf.WriteByte(' ') + arg := args[argIdx] switch arg := arg.(type) { case nil: @@ -78,15 +82,17 @@ func (q *Query) Sanitize(args ...any) (string, error) { case string: p = []byte(QuoteString(arg)) case time.Time: - p = arg.Truncate(time.Microsecond).AppendFormat(buf.AvailableBuffer(), "'2006-01-02 15:04:05.999999999Z07:00:00'") + p = arg.Truncate(time.Microsecond). + AppendFormat(buf.AvailableBuffer(), "'2006-01-02 15:04:05.999999999Z07:00:00'") default: return "", fmt.Errorf("invalid arg type: %T", arg) } argUse[argIdx] = true + buf.Write(p) + // Prevent SQL injection via Line Comment Creation // https://github.com/jackc/pgx/security/advisories/GHSA-m7wr-2xf7-cm9p - buf.Write(p) buf.WriteByte(' ') default: return "", fmt.Errorf("invalid Part type: %T", part) From f0180ba22a8dc2944c694f08c5753e0795f23d6e Mon Sep 17 00:00:00 2001 From: merlin Date: Tue, 1 Oct 2024 14:30:05 +0300 Subject: [PATCH 17/27] quoteBytes check new quoteBytes --- internal/sanitize/sanitize.go | 17 +++++++++++++++-- internal/sanitize/sanitize_test.go | 25 +++++++++++++++++++++++++ 2 files changed, 40 insertions(+), 2 deletions(-) diff --git a/internal/sanitize/sanitize.go b/internal/sanitize/sanitize.go index 3414d6d1a..91d6db58c 100644 --- a/internal/sanitize/sanitize.go +++ b/internal/sanitize/sanitize.go @@ -4,6 +4,7 @@ import ( "bytes" "encoding/hex" "fmt" + "slices" "strconv" "strings" "sync" @@ -78,7 +79,7 @@ func (q *Query) Sanitize(args ...any) (string, error) { case bool: p = strconv.AppendBool(buf.AvailableBuffer(), arg) case []byte: - p = []byte(QuoteBytes(arg)) + p = quoteBytes(buf.AvailableBuffer(), arg) case string: p = []byte(QuoteString(arg)) case time.Time: @@ -127,7 +128,19 @@ func QuoteString(str string) string { } func QuoteBytes(buf []byte) string { - return `'\x` + hex.EncodeToString(buf) + "'" + return string(quoteBytes(nil, buf)) +} + +func quoteBytes(dst, buf []byte) []byte { + dst = append(dst, `'\x`...) + + n := hex.EncodedLen(len(buf)) + p := slices.Grow(dst[len(dst):], n)[:n] + hex.Encode(p, buf) + dst = append(dst, p...) + + dst = append(dst, `'`...) + return dst } type sqlLexer struct { diff --git a/internal/sanitize/sanitize_test.go b/internal/sanitize/sanitize_test.go index 1deff3fba..76ae7a47f 100644 --- a/internal/sanitize/sanitize_test.go +++ b/internal/sanitize/sanitize_test.go @@ -1,6 +1,7 @@ package sanitize_test import ( + "encoding/hex" "testing" "time" @@ -227,3 +228,27 @@ func TestQuerySanitize(t *testing.T) { } } } + +func TestQuoteBytes(t *testing.T) { + tc := func(name string, input []byte) { + t.Run(name, func(t *testing.T) { + t.Parallel() + + got := sanitize.QuoteBytes(input) + want := oldQuoteBytes(input) + + if got != want { + t.Errorf("got: %s", got) + t.Fatalf("want: %s", want) + } + }) + } + + tc("nil", nil) + tc("empty", []byte{}) + tc("text", []byte("abcd")) +} + +func oldQuoteBytes(buf []byte) string { + return `'\x` + hex.EncodeToString(buf) + "'" +} From c50cb144384319d5166e64806f810413380782ba Mon Sep 17 00:00:00 2001 From: merlin Date: Tue, 1 Oct 2024 14:50:59 +0300 Subject: [PATCH 18/27] quoteString --- internal/sanitize/sanitize.go | 31 ++++++++++++++++++++++++++++-- internal/sanitize/sanitize_test.go | 25 ++++++++++++++++++++++++ 2 files changed, 54 insertions(+), 2 deletions(-) diff --git a/internal/sanitize/sanitize.go b/internal/sanitize/sanitize.go index 91d6db58c..d83633a72 100644 --- a/internal/sanitize/sanitize.go +++ b/internal/sanitize/sanitize.go @@ -81,7 +81,7 @@ func (q *Query) Sanitize(args ...any) (string, error) { case []byte: p = quoteBytes(buf.AvailableBuffer(), arg) case string: - p = []byte(QuoteString(arg)) + p = quoteString(buf.AvailableBuffer(), arg) case time.Time: p = arg.Truncate(time.Microsecond). AppendFormat(buf.AvailableBuffer(), "'2006-01-02 15:04:05.999999999Z07:00:00'") @@ -124,7 +124,34 @@ func NewQuery(sql string) (*Query, error) { } func QuoteString(str string) string { - return "'" + strings.ReplaceAll(str, "'", "''") + "'" + return string(quoteString(nil, str)) +} + +func quoteString(dst []byte, str string) []byte { + const quote = "'" + + n := strings.Count(str, quote) + + dst = append(dst, quote...) + + p := slices.Grow(dst[len(dst):], len(str)+2*n) + + for len(str) > 0 { + i := strings.Index(str, quote) + if i < 0 { + p = append(p, str...) + break + } + p = append(p, str[:i]...) + p = append(p, "''"...) + str = str[i+1:] + } + + dst = append(dst, p...) + + dst = append(dst, quote...) + + return dst } func QuoteBytes(buf []byte) string { diff --git a/internal/sanitize/sanitize_test.go b/internal/sanitize/sanitize_test.go index 76ae7a47f..aafcd682d 100644 --- a/internal/sanitize/sanitize_test.go +++ b/internal/sanitize/sanitize_test.go @@ -2,6 +2,7 @@ package sanitize_test import ( "encoding/hex" + "strings" "testing" "time" @@ -229,6 +230,30 @@ func TestQuerySanitize(t *testing.T) { } } +func TestQuoteString(t *testing.T) { + tc := func(name, input string) { + t.Run(name, func(t *testing.T) { + t.Parallel() + + got := sanitize.QuoteString(input) + want := oldQuoteString(input) + + if got != want { + t.Errorf("got: %s", got) + t.Fatalf("want: %s", want) + } + }) + } + + tc("empty", "") + tc("text", "abcd") + tc("with quotes", `one's hat is always a cat`) +} + +func oldQuoteString(str string) string { + return "'" + strings.ReplaceAll(str, "'", "''") + "'" +} + func TestQuoteBytes(t *testing.T) { tc := func(name string, input []byte) { t.Run(name, func(t *testing.T) { From 3a97ffd9474462cf230152332aaccf1830498383 Mon Sep 17 00:00:00 2001 From: merlin Date: Tue, 1 Oct 2024 15:25:24 +0300 Subject: [PATCH 19/27] decrease number of samples in go benchmark --- internal/sanitize/benchmmark.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/internal/sanitize/benchmmark.sh b/internal/sanitize/benchmmark.sh index 87e7e0a11..06842c0aa 100644 --- a/internal/sanitize/benchmmark.sh +++ b/internal/sanitize/benchmmark.sh @@ -48,7 +48,7 @@ for i in "${!commits[@]}"; do # Benchmark data will go there bench_file="${benchmarks_dir}/${i}_${commit_message}.bench" - if ! go test -bench=. -count=25 >"$bench_file"; then + if ! go test -bench=. -count=10 >"$bench_file"; then echo "Benchmarking failed for commit $commit" exit 1 fi From 1ec4baa11f3962f2a865c5d9a812e047409e1950 Mon Sep 17 00:00:00 2001 From: merlin Date: Tue, 1 Oct 2024 16:37:04 +0300 Subject: [PATCH 20/27] add FuzzQuoteString and FuzzQuoteBytes --- internal/sanitize/sanitize_fuzz_test.go | 43 +++++++++++++++++++++++++ 1 file changed, 43 insertions(+) create mode 100644 internal/sanitize/sanitize_fuzz_test.go diff --git a/internal/sanitize/sanitize_fuzz_test.go b/internal/sanitize/sanitize_fuzz_test.go new file mode 100644 index 000000000..7d594def0 --- /dev/null +++ b/internal/sanitize/sanitize_fuzz_test.go @@ -0,0 +1,43 @@ +package sanitize_test + +import ( + "testing" + + "github.com/jackc/pgx/v5/internal/sanitize" +) + +func FuzzQuoteString(f *testing.F) { + f.Add("") + f.Add("\n") + f.Add("sample text") + f.Add("sample q'u'o't'e's") + f.Add("select 'quoted $42', $1") + + f.Fuzz(func(t *testing.T, input string) { + got := sanitize.QuoteString(input) + want := oldQuoteString(input) + + if want != got { + t.Errorf("got %q", got) + t.Fatalf("want %q", want) + } + }) +} + +func FuzzQuoteBytes(f *testing.F) { + f.Add([]byte(nil)) + f.Add([]byte("\n")) + f.Add([]byte("sample text")) + f.Add([]byte("sample q'u'o't'e's")) + f.Add([]byte("select 'quoted $42', $1")) + + f.Fuzz(func(t *testing.T, input []byte) { + got := sanitize.QuoteBytes(input) + want := oldQuoteBytes(input) + + if want != got { + t.Errorf("got %q", got) + t.Fatalf("want %q", want) + } + }) +} From 000ce9c614e7bbd656a2c652b5edc540882e5f33 Mon Sep 17 00:00:00 2001 From: merlin Date: Tue, 1 Oct 2024 16:42:27 +0300 Subject: [PATCH 21/27] add lexer and query pools use lexer pool --- internal/sanitize/sanitize.go | 93 +++++++++++++++++++++++++---------- 1 file changed, 67 insertions(+), 26 deletions(-) diff --git a/internal/sanitize/sanitize.go b/internal/sanitize/sanitize.go index d83633a72..4aca2fb98 100644 --- a/internal/sanitize/sanitize.go +++ b/internal/sanitize/sanitize.go @@ -26,28 +26,19 @@ type Query struct { // https://github.com/jackc/pgx/issues/1380 const replacementcharacterwidth = 3 -var bufPool = &sync.Pool{} - -func getBuf() *bytes.Buffer { - buf, _ := bufPool.Get().(*bytes.Buffer) - if buf == nil { - buf = &bytes.Buffer{} - } - - return buf -} - -func putBuf(buf *bytes.Buffer) { - buf.Reset() - bufPool.Put(buf) +var bufPool = &pool[*bytes.Buffer]{ + new: func() *bytes.Buffer { + return &bytes.Buffer{} + }, + reset: (*bytes.Buffer).Reset, } var null = []byte("null") func (q *Query) Sanitize(args ...any) (string, error) { argUse := make([]bool, len(args)) - buf := getBuf() - defer putBuf(buf) + buf := bufPool.get() + defer bufPool.put(buf) for _, part := range q.Parts { switch part := part.(type) { @@ -109,18 +100,39 @@ func (q *Query) Sanitize(args ...any) (string, error) { } func NewQuery(sql string) (*Query, error) { - l := &sqlLexer{ - src: sql, - stateFn: rawState, + query := &Query{} + query.init(sql) + + return query, nil +} + +var sqlLexerPool = &pool[*sqlLexer]{ + new: func() *sqlLexer { + return &sqlLexer{} + }, + reset: func(sl *sqlLexer) { + *sl = sqlLexer{} + }, +} + +func (q *Query) init(sql string) { + parts := q.Parts[:0] + if parts == nil { + n := strings.Count(sql, "$") + strings.Count(sql, "--") + 1 + parts = make([]Part, 0, n) } + l := sqlLexerPool.get() + defer sqlLexerPool.put(l) + l.src = sql + l.stateFn = rawState + l.parts = parts + for l.stateFn != nil { l.stateFn = l.stateFn(l) } - query := &Query{Parts: l.parts} - - return query, nil + q.Parts = l.parts } func QuoteString(str string) string { @@ -385,13 +397,42 @@ func multilineCommentState(l *sqlLexer) stateFn { } } +var queryPool = &pool[*Query]{ + new: func() *Query { + return &Query{} + }, + reset: func(q *Query) { + q.Parts = q.Parts[:0] + }, +} + // SanitizeSQL replaces placeholder values with args. It quotes and escapes args // as necessary. This function is only safe when standard_conforming_strings is // on. func SanitizeSQL(sql string, args ...any) (string, error) { - query, err := NewQuery(sql) - if err != nil { - return "", err - } + query := queryPool.get() + query.init(sql) + defer queryPool.put(query) + return query.Sanitize(args...) } + +type pool[E any] struct { + p sync.Pool + new func() E + reset func(E) +} + +func (pool *pool[E]) get() E { + v, ok := pool.p.Get().(E) + if !ok { + v = pool.new() + } + + return v +} + +func (p *pool[E]) put(v E) { + p.reset(v) + p.p.Put(v) +} From 25a4bd3f038aa70e5d9bad23f431e994f399d791 Mon Sep 17 00:00:00 2001 From: merlin Date: Tue, 1 Oct 2024 17:04:48 +0300 Subject: [PATCH 22/27] rework QuoteString and QuoteBytes as append-style --- internal/sanitize/sanitize.go | 16 ++++------------ internal/sanitize/sanitize_fuzz_test.go | 8 ++++---- internal/sanitize/sanitize_test.go | 4 ++-- 3 files changed, 10 insertions(+), 18 deletions(-) diff --git a/internal/sanitize/sanitize.go b/internal/sanitize/sanitize.go index 4aca2fb98..fd1e808b4 100644 --- a/internal/sanitize/sanitize.go +++ b/internal/sanitize/sanitize.go @@ -70,9 +70,9 @@ func (q *Query) Sanitize(args ...any) (string, error) { case bool: p = strconv.AppendBool(buf.AvailableBuffer(), arg) case []byte: - p = quoteBytes(buf.AvailableBuffer(), arg) + p = QuoteBytes(buf.AvailableBuffer(), arg) case string: - p = quoteString(buf.AvailableBuffer(), arg) + p = QuoteString(buf.AvailableBuffer(), arg) case time.Time: p = arg.Truncate(time.Microsecond). AppendFormat(buf.AvailableBuffer(), "'2006-01-02 15:04:05.999999999Z07:00:00'") @@ -135,11 +135,7 @@ func (q *Query) init(sql string) { q.Parts = l.parts } -func QuoteString(str string) string { - return string(quoteString(nil, str)) -} - -func quoteString(dst []byte, str string) []byte { +func QuoteString(dst []byte, str string) []byte { const quote = "'" n := strings.Count(str, quote) @@ -166,11 +162,7 @@ func quoteString(dst []byte, str string) []byte { return dst } -func QuoteBytes(buf []byte) string { - return string(quoteBytes(nil, buf)) -} - -func quoteBytes(dst, buf []byte) []byte { +func QuoteBytes(dst, buf []byte) []byte { dst = append(dst, `'\x`...) n := hex.EncodedLen(len(buf)) diff --git a/internal/sanitize/sanitize_fuzz_test.go b/internal/sanitize/sanitize_fuzz_test.go index 7d594def0..746558276 100644 --- a/internal/sanitize/sanitize_fuzz_test.go +++ b/internal/sanitize/sanitize_fuzz_test.go @@ -14,10 +14,10 @@ func FuzzQuoteString(f *testing.F) { f.Add("select 'quoted $42', $1") f.Fuzz(func(t *testing.T, input string) { - got := sanitize.QuoteString(input) + got := sanitize.QuoteString(nil, input) want := oldQuoteString(input) - if want != got { + if want != string(got) { t.Errorf("got %q", got) t.Fatalf("want %q", want) } @@ -32,10 +32,10 @@ func FuzzQuoteBytes(f *testing.F) { f.Add([]byte("select 'quoted $42', $1")) f.Fuzz(func(t *testing.T, input []byte) { - got := sanitize.QuoteBytes(input) + got := sanitize.QuoteBytes(nil, input) want := oldQuoteBytes(input) - if want != got { + if want != string(got) { t.Errorf("got %q", got) t.Fatalf("want %q", want) } diff --git a/internal/sanitize/sanitize_test.go b/internal/sanitize/sanitize_test.go index aafcd682d..9da701ea9 100644 --- a/internal/sanitize/sanitize_test.go +++ b/internal/sanitize/sanitize_test.go @@ -235,7 +235,7 @@ func TestQuoteString(t *testing.T) { t.Run(name, func(t *testing.T) { t.Parallel() - got := sanitize.QuoteString(input) + got := string(sanitize.QuoteString(nil, input)) want := oldQuoteString(input) if got != want { @@ -259,7 +259,7 @@ func TestQuoteBytes(t *testing.T) { t.Run(name, func(t *testing.T) { t.Parallel() - got := sanitize.QuoteBytes(input) + got := string(sanitize.QuoteBytes(nil, input)) want := oldQuoteBytes(input) if got != want { From 2f3ae5a6dcf3013a30f98691a22ed2844bef7ff7 Mon Sep 17 00:00:00 2001 From: merlin Date: Tue, 1 Oct 2024 17:15:38 +0300 Subject: [PATCH 23/27] add docs to sanitize tests --- internal/sanitize/sanitize_test.go | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/internal/sanitize/sanitize_test.go b/internal/sanitize/sanitize_test.go index 9da701ea9..926751534 100644 --- a/internal/sanitize/sanitize_test.go +++ b/internal/sanitize/sanitize_test.go @@ -250,6 +250,8 @@ func TestQuoteString(t *testing.T) { tc("with quotes", `one's hat is always a cat`) } +// This function was used before optimizations. +// You should keep for testing purposes - we want to ensure there are no breaking changes. func oldQuoteString(str string) string { return "'" + strings.ReplaceAll(str, "'", "''") + "'" } @@ -274,6 +276,8 @@ func TestQuoteBytes(t *testing.T) { tc("text", []byte("abcd")) } +// This function was used before optimizations. +// You should keep for testing purposes - we want to ensure there are no breaking changes. func oldQuoteBytes(buf []byte) string { return `'\x` + hex.EncodeToString(buf) + "'" } From 9a33a622e7bc6219a56967488b74395ae46a4d27 Mon Sep 17 00:00:00 2001 From: merlin Date: Tue, 1 Oct 2024 17:16:02 +0300 Subject: [PATCH 24/27] drop too large values from memory pools --- internal/sanitize/sanitize.go | 24 ++++++++++++++++++------ 1 file changed, 18 insertions(+), 6 deletions(-) diff --git a/internal/sanitize/sanitize.go b/internal/sanitize/sanitize.go index fd1e808b4..173523d95 100644 --- a/internal/sanitize/sanitize.go +++ b/internal/sanitize/sanitize.go @@ -26,11 +26,17 @@ type Query struct { // https://github.com/jackc/pgx/issues/1380 const replacementcharacterwidth = 3 +const maxBufSize = 16384 // 16 Ki + var bufPool = &pool[*bytes.Buffer]{ new: func() *bytes.Buffer { return &bytes.Buffer{} }, - reset: (*bytes.Buffer).Reset, + reset: func(b *bytes.Buffer) bool { + n := b.Len() + b.Reset() + return n < maxBufSize + }, } var null = []byte("null") @@ -110,20 +116,23 @@ var sqlLexerPool = &pool[*sqlLexer]{ new: func() *sqlLexer { return &sqlLexer{} }, - reset: func(sl *sqlLexer) { + reset: func(sl *sqlLexer) bool { *sl = sqlLexer{} + return true }, } func (q *Query) init(sql string) { parts := q.Parts[:0] if parts == nil { + // dirty, but fast heuristic to preallocate for ~90% usecases n := strings.Count(sql, "$") + strings.Count(sql, "--") + 1 parts = make([]Part, 0, n) } l := sqlLexerPool.get() defer sqlLexerPool.put(l) + l.src = sql l.stateFn = rawState l.parts = parts @@ -393,8 +402,10 @@ var queryPool = &pool[*Query]{ new: func() *Query { return &Query{} }, - reset: func(q *Query) { + reset: func(q *Query) bool { + n := len(q.Parts) q.Parts = q.Parts[:0] + return n < 64 // drop too large queries }, } @@ -412,7 +423,7 @@ func SanitizeSQL(sql string, args ...any) (string, error) { type pool[E any] struct { p sync.Pool new func() E - reset func(E) + reset func(E) bool } func (pool *pool[E]) get() E { @@ -425,6 +436,7 @@ func (pool *pool[E]) get() E { } func (p *pool[E]) put(v E) { - p.reset(v) - p.p.Put(v) + if p.reset(v) { + p.p.Put(v) + } } From 85d5a4dea91a9d63808b3497d478cd4b77d09ed6 Mon Sep 17 00:00:00 2001 From: merlin Date: Sun, 20 Oct 2024 18:00:59 +0300 Subject: [PATCH 25/27] add prefix to quoters tests --- internal/sanitize/sanitize_fuzz_test.go | 23 +++++++++++++++++------ 1 file changed, 17 insertions(+), 6 deletions(-) diff --git a/internal/sanitize/sanitize_fuzz_test.go b/internal/sanitize/sanitize_fuzz_test.go index 746558276..a8f2e7791 100644 --- a/internal/sanitize/sanitize_fuzz_test.go +++ b/internal/sanitize/sanitize_fuzz_test.go @@ -7,17 +7,22 @@ import ( ) func FuzzQuoteString(f *testing.F) { - f.Add("") - f.Add("\n") + const prefix = "prefix" + f.Add("new\nline") f.Add("sample text") f.Add("sample q'u'o't'e's") f.Add("select 'quoted $42', $1") f.Fuzz(func(t *testing.T, input string) { - got := sanitize.QuoteString(nil, input) + got := string(sanitize.QuoteString([]byte(prefix), input)) want := oldQuoteString(input) - if want != string(got) { + quoted, ok := strings.CutPrefix(got, prefix) + if !ok { + t.Fatalf("result has no prefix") + } + + if want != quoted { t.Errorf("got %q", got) t.Fatalf("want %q", want) } @@ -25,6 +30,7 @@ func FuzzQuoteString(f *testing.F) { } func FuzzQuoteBytes(f *testing.F) { + const prefix = "prefix" f.Add([]byte(nil)) f.Add([]byte("\n")) f.Add([]byte("sample text")) @@ -32,10 +38,15 @@ func FuzzQuoteBytes(f *testing.F) { f.Add([]byte("select 'quoted $42', $1")) f.Fuzz(func(t *testing.T, input []byte) { - got := sanitize.QuoteBytes(nil, input) + got := string(sanitize.QuoteBytes([]byte(prefix), input)) want := oldQuoteBytes(input) - if want != string(got) { + quoted, ok := strings.CutPrefix(got, prefix) + if !ok { + t.Fatalf("result has no prefix") + } + + if want != quoted { t.Errorf("got %q", got) t.Fatalf("want %q", want) } From 97d835870f32ccf1e3bb59bca709a38979570ada Mon Sep 17 00:00:00 2001 From: merlin Date: Sun, 20 Oct 2024 18:08:23 +0300 Subject: [PATCH 26/27] fix preallocations of quoted string --- internal/sanitize/sanitize.go | 2 +- internal/sanitize/sanitize_fuzz_test.go | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/internal/sanitize/sanitize.go b/internal/sanitize/sanitize.go index 173523d95..e0ae9bedb 100644 --- a/internal/sanitize/sanitize.go +++ b/internal/sanitize/sanitize.go @@ -151,7 +151,7 @@ func QuoteString(dst []byte, str string) []byte { dst = append(dst, quote...) - p := slices.Grow(dst[len(dst):], len(str)+2*n) + p := slices.Grow(dst[len(dst):], 2*len(quote)+len(str)+2*n) for len(str) > 0 { i := strings.Index(str, quote) diff --git a/internal/sanitize/sanitize_fuzz_test.go b/internal/sanitize/sanitize_fuzz_test.go index a8f2e7791..2f0c41223 100644 --- a/internal/sanitize/sanitize_fuzz_test.go +++ b/internal/sanitize/sanitize_fuzz_test.go @@ -1,6 +1,7 @@ package sanitize_test import ( + "strings" "testing" "github.com/jackc/pgx/v5/internal/sanitize" From 174e6787aaed1584f15895c969ef61027454cda4 Mon Sep 17 00:00:00 2001 From: merlin Date: Mon, 9 Dec 2024 16:33:57 +0200 Subject: [PATCH 27/27] optimisations of quote functions by @sean- --- internal/sanitize/benchmmark.sh | 3 +- internal/sanitize/sanitize.go | 62 +++++++++++++++++++++------------ 2 files changed, 42 insertions(+), 23 deletions(-) diff --git a/internal/sanitize/benchmmark.sh b/internal/sanitize/benchmmark.sh index 06842c0aa..ec0f7b03a 100644 --- a/internal/sanitize/benchmmark.sh +++ b/internal/sanitize/benchmmark.sh @@ -43,7 +43,7 @@ for i in "${!commits[@]}"; do } # Sanitized commmit message - commit_message=$(git log -1 --pretty=format:"%s" | tr ' ' '_') + commit_message=$(git log -1 --pretty=format:"%s" | tr -c '[:alnum:]-_' '_') # Benchmark data will go there bench_file="${benchmarks_dir}/${i}_${commit_message}.bench" @@ -56,4 +56,5 @@ for i in "${!commits[@]}"; do bench_files+=("$bench_file") done +# go install golang.org/x/perf/cmd/benchstat[@latest] benchstat "${bench_files[@]}" diff --git a/internal/sanitize/sanitize.go b/internal/sanitize/sanitize.go index e0ae9bedb..b516817cb 100644 --- a/internal/sanitize/sanitize.go +++ b/internal/sanitize/sanitize.go @@ -145,41 +145,59 @@ func (q *Query) init(sql string) { } func QuoteString(dst []byte, str string) []byte { - const quote = "'" + const quote = '\'' - n := strings.Count(str, quote) + // Preallocate space for the worst case scenario + dst = slices.Grow(dst, len(str)*2+2) - dst = append(dst, quote...) + // Add opening quote + dst = append(dst, quote) - p := slices.Grow(dst[len(dst):], 2*len(quote)+len(str)+2*n) - - for len(str) > 0 { - i := strings.Index(str, quote) - if i < 0 { - p = append(p, str...) - break + // Iterate through the string without allocating + for i := 0; i < len(str); i++ { + if str[i] == quote { + dst = append(dst, quote, quote) + } else { + dst = append(dst, str[i]) } - p = append(p, str[:i]...) - p = append(p, "''"...) - str = str[i+1:] } - dst = append(dst, p...) - - dst = append(dst, quote...) + // Add closing quote + dst = append(dst, quote) return dst } func QuoteBytes(dst, buf []byte) []byte { - dst = append(dst, `'\x`...) + if len(buf) == 0 { + return append(dst, `'\x'`...) + } + + // Calculate required length + requiredLen := 3 + hex.EncodedLen(len(buf)) + 1 + + // Ensure dst has enough capacity + if cap(dst)-len(dst) < requiredLen { + newDst := make([]byte, len(dst), len(dst)+requiredLen) + copy(newDst, dst) + dst = newDst + } + + // Record original length and extend slice + origLen := len(dst) + dst = dst[:origLen+requiredLen] + + // Add prefix + dst[origLen] = '\'' + dst[origLen+1] = '\\' + dst[origLen+2] = 'x' + + // Encode bytes directly into dst + hex.Encode(dst[origLen+3:len(dst)-1], buf) - n := hex.EncodedLen(len(buf)) - p := slices.Grow(dst[len(dst):], n)[:n] - hex.Encode(p, buf) - dst = append(dst, p...) + // Add suffix + dst[len(dst)-1] = '\'' - dst = append(dst, `'`...) return dst }