Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Reduce QueryContext allocations by reusing the channel #1295

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 56 additions & 36 deletions sqlite3.go
Original file line number Diff line number Diff line change
Expand Up @@ -399,6 +399,9 @@ type SQLiteRows struct {
decltype []string
ctx context.Context // no better alternative to pass context into Next() method
closemu sync.Mutex
// semaphore to signal the goroutine used to interrupt queries when a
// cancellable context is passed to QueryContext
sema chan struct{}
}

type functionInfo struct {
Expand Down Expand Up @@ -2050,36 +2053,37 @@ func isInterruptErr(err error) bool {

// exec executes a query that doesn't return rows. Attempts to honor context timeout.
func (s *SQLiteStmt) exec(ctx context.Context, args []driver.NamedValue) (driver.Result, error) {
if ctx.Done() == nil {
done := ctx.Done()
if done == nil {
return s.execSync(args)
}

type result struct {
r driver.Result
err error
if err := ctx.Err(); err != nil {
return nil, err // Fast check if the channel is closed
}
resultCh := make(chan result)
defer close(resultCh)

sema := make(chan struct{})
go func() {
r, err := s.execSync(args)
resultCh <- result{r, err}
}()
var rv result
select {
case rv = <-resultCh:
case <-ctx.Done():
select {
case rv = <-resultCh: // no need to interrupt, operation completed in db
default:
// this is still racy and can be no-op if executed between sqlite3_* calls in execSync.
case <-done:
C.sqlite3_interrupt(s.c.db)
rv = <-resultCh // wait for goroutine completed
if isInterruptErr(rv.err) {
return nil, ctx.Err()
}
// Wait until signaled. We need to ensure that this goroutine
// will not call interrupt after this method returns.
<-sema
case <-sema:
}
}()
r, err := s.execSync(args)
// Signal the goroutine to exit. This send will only succeed at a point
// where it is impossible for the goroutine to call sqlite3_interrupt.
//
// This is necessary to ensure the goroutine does not interrupt an
// unrelated query if the context is cancelled after this method returns
// but before the goroutine exits (we don't wait for it to exit).
sema <- struct{}{}
if err != nil && isInterruptErr(err) {
return nil, ctx.Err()
}
return rv.r, rv.err
return r, err
}

func (s *SQLiteStmt) execSync(args []driver.NamedValue) (driver.Result, error) {
Expand Down Expand Up @@ -2117,6 +2121,9 @@ func (rc *SQLiteRows) Close() error {
return nil
}
rc.s = nil // remove reference to SQLiteStmt
if rc.sema != nil {
close(rc.sema)
}
s.mu.Lock()
if s.closed {
s.mu.Unlock()
Expand Down Expand Up @@ -2174,27 +2181,40 @@ func (rc *SQLiteRows) Next(dest []driver.Value) error {
return io.EOF
}

if rc.ctx.Done() == nil {
done := rc.ctx.Done()
if done == nil {
return rc.nextSyncLocked(dest)
}
resultCh := make(chan error)
defer close(resultCh)
if err := rc.ctx.Err(); err != nil {
return err // Fast check if the channel is closed
}

if rc.sema == nil {
rc.sema = make(chan struct{})
}
go func() {
resultCh <- rc.nextSyncLocked(dest)
}()
select {
case err := <-resultCh:
return err
case <-rc.ctx.Done():
select {
case <-resultCh: // no need to interrupt
default:
// this is still racy and can be no-op if executed between sqlite3_* calls in nextSyncLocked.
case <-done:
C.sqlite3_interrupt(rc.s.c.db)
<-resultCh // ensure goroutine completed
// Wait until signaled. We need to ensure that this goroutine
// will not call interrupt after this method returns.
<-rc.sema
case <-rc.sema:
}
return rc.ctx.Err()
}()

err := rc.nextSyncLocked(dest)
// Signal the goroutine to exit. This send will only succeed at a point
// where it is impossible for the goroutine to call sqlite3_interrupt.
//
// This is necessary to ensure the goroutine does not interrupt an
// unrelated query if the context is cancelled after this method returns
// but before the goroutine exits (we don't wait for it to exit).
rc.sema <- struct{}{}
if err != nil && isInterruptErr(err) {
err = rc.ctx.Err()
}
return err
}

// nextSyncLocked moves cursor to next; must be called with locked mutex.
Expand Down
147 changes: 147 additions & 0 deletions sqlite3_go18_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,12 @@ package sqlite3
import (
"context"
"database/sql"
"errors"
"fmt"
"io/ioutil"
"math/rand"
"os"
"strings"
"sync"
"testing"
"time"
Expand Down Expand Up @@ -268,6 +270,151 @@ func TestQueryRowContextCancelParallel(t *testing.T) {
}
}

// TestQueryRowContextCancelInterrupt tests that we can successfully
// interrupt a long running query when the context is canceled. The
// previous two QueryRowContext tests only test that we handle a
// previously cancelled context and thus do not call sqlite3_interrupt.
//
// The test first calibrates the table size so that one query takes roughly
// TargetRuntime, then runs the query with a timeout of one tenth of the
// measured average and checks that it fails with context.DeadlineExceeded
// within four times that timeout.
func TestQueryRowContextCancelInterrupt(t *testing.T) {
	db, err := sql.Open("sqlite3", ":memory:")
	if err != nil {
		t.Fatal(err)
	}
	defer db.Close()

	// Test that we have the unixepoch function and if not skip the test.
	if _, err := db.Exec(`SELECT unixepoch(datetime(100000, 'unixepoch', 'localtime'))`); err != nil {
		libVersion, libVersionNumber, sourceID := Version()
		if strings.Contains(err.Error(), "no such function: unixepoch") {
			t.Skip("Skipping the 'unixepoch' function is not implemented in "+
				"this version of sqlite3:", libVersion, libVersionNumber, sourceID)
		}
		t.Fatal(err)
	}

	const createTableStmt = `
CREATE TABLE timestamps (
ts TIMESTAMP NOT NULL
);`
	if _, err := db.Exec(createTableStmt); err != nil {
		t.Fatal(err)
	}

	stmt, err := db.Prepare(`INSERT INTO timestamps VALUES (?);`)
	if err != nil {
		t.Fatal(err)
	}
	defer stmt.Close()

	// Computationally expensive query that consumes many rows. This is needed
	// to test cancellation because queries are not interrupted immediately.
	// Instead, queries are only halted at certain checkpoints where the
	// sqlite3.isInterrupted is checked and true.
	queryStmt := `
SELECT
SUM(unixepoch(datetime(ts + 10, 'unixepoch', 'localtime'))) AS c1,
SUM(unixepoch(datetime(ts + 20, 'unixepoch', 'localtime'))) AS c2,
SUM(unixepoch(datetime(ts + 30, 'unixepoch', 'localtime'))) AS c3,
SUM(unixepoch(datetime(ts + 40, 'unixepoch', 'localtime'))) AS c4
FROM
timestamps
WHERE datetime(ts, 'unixepoch', 'localtime')
LIKE
?;`

	// query runs queryStmt under a context that expires after timeout and
	// returns the number of rows scanned along with the first error seen.
	query := func(t *testing.T, timeout time.Duration) (int, error) {
		// Create a complicated pattern to match timestamps
		const pattern = "%2%0%2%4%-%-%:%:%"
		ctx, cancel := context.WithTimeout(context.Background(), timeout)
		defer cancel()
		rows, err := db.QueryContext(ctx, queryStmt, pattern)
		if err != nil {
			return 0, err
		}
		// Always release the result set. rows.Next closes it when it returns
		// false, but the early return on a Scan error below would otherwise
		// leave the rows (and the connection they hold) open.
		defer rows.Close()
		var count int
		for rows.Next() {
			var n int64
			if err := rows.Scan(&n, &n, &n, &n); err != nil {
				return count, err
			}
			count++
		}
		return count, rows.Err()
	}

	// average returns the mean wall-clock duration of n calls to fn.
	average := func(n int, fn func()) time.Duration {
		start := time.Now()
		for i := 0; i < n; i++ {
			fn()
		}
		return time.Since(start) / time.Duration(n)
	}

	// createRows empties the table and inserts n timestamps scattered within
	// +/- 5000 seconds of a fixed instant. The random source is seeded with a
	// constant so every run inserts the same values.
	createRows := func(n int) {
		t.Logf("Creating %d rows", n)
		if _, err := db.Exec(`DELETE FROM timestamps; VACUUM;`); err != nil {
			t.Fatal(err)
		}
		ts := time.Date(2024, 6, 6, 8, 9, 10, 12345, time.UTC).Unix()
		rr := rand.New(rand.NewSource(1234))
		for i := 0; i < n; i++ {
			if _, err := stmt.Exec(ts + rr.Int63n(10_000) - 5_000); err != nil {
				t.Fatal(err)
			}
		}
	}

	const TargetRuntime = 200 * time.Millisecond
	const N = 5_000 // Number of rows to insert at a time

	// Create enough rows that the query takes ~200ms to run: measure the
	// query time for N rows, then scale the row count up proportionally.
	start := time.Now()
	createRows(N)
	baseAvg := average(4, func() {
		if _, err := query(t, time.Hour); err != nil {
			t.Fatal(err)
		}
	})
	t.Log("Base average:", baseAvg)
	rowCount := N * (int(TargetRuntime/baseAvg) + 1)
	createRows(rowCount)
	t.Log("Table setup time:", time.Since(start))

	// Measure the average query time against the full table; the timeout
	// used below is derived from it.
	avg := average(2, func() {
		n, err := query(t, time.Hour)
		if err != nil {
			t.Fatal(err)
		}
		if n == 0 {
			t.Fatal("scanned zero rows")
		}
	})
	// Guard against the timeout being too short to reliably test.
	if avg < TargetRuntime/2 {
		t.Fatalf("Average query runtime should be around %s got: %s ",
			TargetRuntime, avg)
	}
	// Set the timeout to 1/10 of the average query time.
	timeout := (avg / 10).Round(100 * time.Microsecond)
	t.Logf("Average: %s Timeout: %s", avg, timeout)

	for i := 0; i < 10; i++ {
		tt := time.Now()
		n, err := query(t, timeout)
		if !errors.Is(err, context.DeadlineExceeded) {
			// A nil error means the query finished before the deadline, which
			// is reported but not fatal; any other error aborts the test.
			fn := t.Errorf
			if err != nil {
				fn = t.Fatalf
			}
			fn("expected error %v got %v", context.DeadlineExceeded, err)
		}
		d := time.Since(tt)
		t.Logf("%d: rows: %d duration: %s", i, n, d)
		if d > timeout*4 {
			t.Errorf("query was cancelled after %s but did not abort until: %s", timeout, d)
		}
	}
}

func TestExecCancel(t *testing.T) {
db, err := sql.Open("sqlite3", ":memory:")
if err != nil {
Expand Down
Loading
Loading