Skip to content

Commit

Permalink
Busy handlers.
Browse files Browse the repository at this point in the history
  • Loading branch information
ncruces committed Feb 3, 2024
1 parent da0e98f commit 7438fdb
Show file tree
Hide file tree
Showing 7 changed files with 64 additions and 16 deletions.
41 changes: 39 additions & 2 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,10 @@ import (
"context"
"errors"
"fmt"
"math"
"net/url"
"strings"
"time"

"github.com/ncruces/go-sqlite3/internal/util"
"github.com/tetratelabs/wazero/api"
Expand All @@ -20,6 +22,7 @@ type Conn struct {

interrupt context.Context
pending *Stmt
busy func(int) bool
log func(xErrorCode, string)
collation func(*Conn, string)
authorizer func(AuthorizerActionCode, string, string, string, string) AuthorizerReturnCode
Expand Down Expand Up @@ -322,15 +325,49 @@ func (c *Conn) checkInterrupt() {
}
}

func progressCallback(ctx context.Context, mod api.Module, _ uint32) uint32 {
if c, ok := ctx.Value(connKey{}).(*Conn); ok {
func progressCallback(ctx context.Context, mod api.Module, pDB uint32) uint32 {
if c, ok := ctx.Value(connKey{}).(*Conn); ok && c.handle == pDB && c.commit != nil {
if c.interrupt != nil && c.interrupt.Err() != nil {
return 1
}
}
return 0
}

// BusyTimeout sets a busy timeout.
//
// https://sqlite.org/c3ref/busy_timeout.html
func (c *Conn) BusyTimeout(timeout time.Duration) error {
ms := min((timeout+time.Millisecond-1)/time.Millisecond, math.MaxInt32)
r := c.call("sqlite3_busy_timeout", uint64(c.handle), uint64(ms))
return c.error(r)
}

// BusyHandler registers a callback to handle [BUSY] errors.
//
// https://sqlite.org/c3ref/busy_handler.html
func (c *Conn) BusyHandler(cb func(count int) (retry bool)) error {
var enable uint64
if cb != nil {
enable = 1
}
r := c.call("sqlite3_busy_handler_go", uint64(c.handle), enable)
if err := c.error(r); err != nil {
return err
}
c.busy = cb
return nil
}

func busyCallback(ctx context.Context, mod api.Module, pDB, count uint32) uint32 {
if c, ok := ctx.Value(connKey{}).(*Conn); ok && c.handle == pDB && c.busy != nil {
if retry := c.busy(int(count)); retry {
return 1
}
}
return 0
}

// Deprecated: executes a PRAGMA statement and returns results.
func (c *Conn) Pragma(str string) ([]string, error) {
stmt, _, err := c.Prepare(`PRAGMA ` + str)
Expand Down
2 changes: 1 addition & 1 deletion driver/driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ func (n *connector) Connect(ctx context.Context) (_ driver.Conn, err error) {
defer c.Conn.SetInterrupt(old)

if !n.pragmas {
err = c.Conn.Exec(`PRAGMA busy_timeout=60000`)
err = c.Conn.BusyTimeout(60 * time.Second)
if err != nil {
return nil, err
}
Expand Down
2 changes: 2 additions & 0 deletions embed/exports.txt
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ sqlite3_blob_open
sqlite3_blob_read
sqlite3_blob_reopen
sqlite3_blob_write
sqlite3_busy_handler_go
sqlite3_busy_timeout
sqlite3_changes64
sqlite3_clear_bindings
sqlite3_close
Expand Down
3 changes: 2 additions & 1 deletion sqlite.go
Original file line number Diff line number Diff line change
Expand Up @@ -289,7 +289,8 @@ func (a *arena) string(s string) uint32 {
}

func exportCallbacks(env wazero.HostModuleBuilder) wazero.HostModuleBuilder {
util.ExportFuncII(env, "go_progress", progressCallback)
util.ExportFuncIII(env, "go_busy_handler", busyCallback)
util.ExportFuncII(env, "go_progress_handler", progressCallback)
util.ExportFuncII(env, "go_commit_hook", commitCallback)
util.ExportFuncVI(env, "go_rollback_hook", rollbackCallback)
util.ExportFuncVIIIIJ(env, "go_update_hook", updateCallback)
Expand Down
9 changes: 7 additions & 2 deletions sqlite3/hooks.c
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@

#include "sqlite3.h"

int go_progress(void *);
int go_progress_handler(void *);
int go_busy_handler(void *, int);

int go_commit_hook(void *);
void go_rollback_hook(void *);
Expand All @@ -14,7 +15,11 @@ int go_authorizer(void *, int, const char *, const char *, const char *,
void go_log(void *, int, const char *);

void sqlite3_progress_handler_go(sqlite3 *db, int n) {
sqlite3_progress_handler(db, n, go_progress, /*arg=*/db);
sqlite3_progress_handler(db, n, go_progress_handler, /*arg=*/db);
}

int sqlite3_busy_handler_go(sqlite3 *db, bool enable) {
return sqlite3_busy_handler(db, enable ? go_busy_handler : NULL, /*arg=*/db);
}

void sqlite3_commit_hook_go(sqlite3 *db, bool enable) {
Expand Down
21 changes: 12 additions & 9 deletions tests/parallel/parallel_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"os/exec"
"path/filepath"
"testing"
"time"

"golang.org/x/sync/errgroup"

Expand Down Expand Up @@ -39,10 +40,7 @@ func TestMemory(t *testing.T) {
iter = 5000
}

name := "file:/test.db?vfs=memdb" +
"&_pragma=busy_timeout(10000)" +
"&_pragma=journal_mode(memory)" +
"&_pragma=synchronous(off)"
name := "file:/test.db?vfs=memdb"
testParallel(t, name, iter)
testIntegrity(t, name)
}
Expand Down Expand Up @@ -100,10 +98,7 @@ func TestChildProcess(t *testing.T) {

func BenchmarkMemory(b *testing.B) {
memdb.Delete("test.db")
name := "file:/test.db?vfs=memdb" +
"&_pragma=busy_timeout(10000)" +
"&_pragma=journal_mode(memory)" +
"&_pragma=synchronous(off)"
name := "file:/test.db?vfs=memdb"
testParallel(b, name, b.N)
}

Expand All @@ -115,6 +110,14 @@ func testParallel(t testing.TB, name string, n int) {
}
defer db.Close()

err = db.BusyHandler(func(count int) (retry bool) {
time.Sleep(time.Millisecond)
return true
})
if err != nil {
return err
}

err = db.Exec(`CREATE TABLE IF NOT EXISTS users (id INT, name VARCHAR(10))`)
if err != nil {
return err
Expand All @@ -135,7 +138,7 @@ func testParallel(t testing.TB, name string, n int) {
}
defer db.Close()

err = db.Exec(`PRAGMA busy_timeout=10000`)
err = db.BusyTimeout(10 * time.Second)
if err != nil {
return err
}
Expand Down
2 changes: 1 addition & 1 deletion txn.go
Original file line number Diff line number Diff line change
Expand Up @@ -275,7 +275,7 @@ func (c *Conn) UpdateHook(cb func(action AuthorizerActionCode, schema, table str

func commitCallback(ctx context.Context, mod api.Module, pDB uint32) uint32 {
if c, ok := ctx.Value(connKey{}).(*Conn); ok && c.handle == pDB && c.commit != nil {
if !c.commit() {
if ok := c.commit(); !ok {
return 1
}
}
Expand Down

0 comments on commit 7438fdb

Please sign in to comment.