Skip to content

Commit

Permalink
Add database package (#216)
Browse files Browse the repository at this point in the history
  • Loading branch information
marco6 authored Jan 15, 2025
1 parent 58d7e76 commit 2c4e7b6
Show file tree
Hide file tree
Showing 9 changed files with 512 additions and 174 deletions.
144 changes: 144 additions & 0 deletions pkg/database/driver_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
package database_test

import (
"context"
"database/sql/driver"
"io"
"sync/atomic"
"testing"
)

type testDriver struct {
t *testing.T
conns atomic.Int32
queries atomic.Int32
execs atomic.Int32
stmts atomic.Int32
trans atomic.Int32
}

var _ driver.DriverContext = &testDriver{}

// Open implements driver.Driver.
func (td *testDriver) Open(name string) (driver.Conn, error) {
td.conns.Add(1)
return &testConn{
driver: td,
dbName: name,
}, nil
}

func (td *testDriver) OpenConnector(name string) (driver.Connector, error) {
return &testConnector{
driver: td,
dbName: name,
}, nil
}

func (td *testDriver) Close() {
if rows := td.queries.Load(); rows != 0 {
td.t.Errorf("%d rows left open after close", rows)
}
if stmts := td.stmts.Load(); stmts != 0 {
td.t.Errorf("%d statements left open after close", stmts)
}
if conns := td.conns.Load(); conns != 0 {
td.t.Errorf("%d connections left open after close", conns)
}
if trans := td.trans.Load(); trans != 0 {
td.t.Errorf("%d transactions left open after close", trans)
}
}

type testConnector struct {
driver *testDriver
dbName string
}

func (tc *testConnector) Driver() driver.Driver { return tc.driver }

func (tc *testConnector) Connect(context.Context) (driver.Conn, error) {
tc.driver.conns.Add(1)
return &testConn{
driver: tc.driver,
dbName: tc.dbName,
}, nil
}

type testConn struct {
driver *testDriver
dbName string
}

var _ driver.ExecerContext = &testConn{}
var _ driver.QueryerContext = &testConn{}

func (t *testConn) Begin() (driver.Tx, error) {
t.driver.trans.Add(1)
return &testTx{
driver: t.driver,
}, nil
}

func (t *testConn) Prepare(query string) (driver.Stmt, error) {
t.driver.stmts.Add(1)
return &testStmt{
driver: t.driver,
query: query,
}, nil
}

func (t *testConn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) {
t.driver.queries.Add(1)
return &testRows{
driver: t.driver,
}, nil
}

func (t *testConn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) {
t.driver.execs.Add(1)
return testResult{}, nil
}

func (t *testConn) Close() error { return nil }

type testStmt struct {
driver *testDriver
query string
}

func (t *testStmt) NumInput() int { return 0 }

func (t *testStmt) Exec(args []driver.Value) (driver.Result, error) {
t.driver.execs.Add(1)
return testResult{}, nil
}

func (t *testStmt) Query(args []driver.Value) (driver.Rows, error) {
t.driver.queries.Add(1)
return &testRows{
driver: t.driver,
}, nil
}

func (t *testStmt) Close() error { return nil }

type testRows struct {
driver *testDriver
}

func (t *testRows) Columns() []string { return nil }
func (t *testRows) Next(dest []driver.Value) error { return io.EOF }
func (t *testRows) Close() error { return nil }

type testResult struct{}

func (t testResult) LastInsertId() (int64, error) { return 0, nil }
func (t testResult) RowsAffected() (int64, error) { return 0, nil }

type testTx struct {
driver *testDriver
}

func (t *testTx) Commit() error { return nil }
func (t *testTx) Rollback() error { return nil }
36 changes: 36 additions & 0 deletions pkg/database/interface.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
package database

import (
"context"
"database/sql"
"errors"
)

var errDBClosed = errors.New("sql: database is closed")

type Interface interface {
QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error)
ExecContext(ctx context.Context, query string, args ...any) (result sql.Result, err error)
PrepareContext(ctx context.Context, query string) (*sql.Stmt, error)
BeginTx(ctx context.Context, opts *sql.TxOptions) (Transaction, error)
Conn(ctx context.Context) (*sql.Conn, error)
Close() error
}

type Transaction interface {
QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error)
ExecContext(ctx context.Context, query string, args ...any) (result sql.Result, err error)
PrepareContext(ctx context.Context, query string) (*sql.Stmt, error)
StmtContext(ctx context.Context, stmt *sql.Stmt) *sql.Stmt
Commit() error
Rollback() error
}

type Wrapped[T Transaction] interface {
QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error)
ExecContext(ctx context.Context, query string, args ...any) (result sql.Result, err error)
PrepareContext(ctx context.Context, query string) (*sql.Stmt, error)
BeginTx(ctx context.Context, opts *sql.TxOptions) (T, error)
Conn(ctx context.Context) (*sql.Conn, error)
Close() error
}
172 changes: 172 additions & 0 deletions pkg/database/prepared.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,172 @@
package database

import (
"context"
"database/sql"
"errors"
"fmt"
"sync"

"go.opentelemetry.io/otel"
"go.opentelemetry.io/otel/trace"
)

const otelName = "prepared"

var otelTracer trace.Tracer

func init() {
otelTracer = otel.Tracer(otelName)
}

type preparedDb[T Transaction] struct {
underlying Wrapped[T]
mu sync.RWMutex
cache map[string]*sql.Stmt
}

// NewPrepared creates a new Interface that wraps the given database and
// uses a prepare cache to reduce the number of prepare calls. The cache
// is only used when calling ExecContext and QueryContext methods on the
// main instance or in a transaction.
func NewPrepared[T Transaction](db Wrapped[T]) Interface {
return &preparedDb[T]{
underlying: db,
cache: make(map[string]*sql.Stmt),
}
}

func (db *preparedDb[T]) ExecContext(ctx context.Context, query string, args ...any) (result sql.Result, err error) {
ctx, span := otelTracer.Start(ctx, "DB.ExecContext")
defer func() {
span.RecordError(err)
span.End()
}()

stmt, err := db.prepare(ctx, query)
if err != nil {
return nil, err
}
return stmt.ExecContext(ctx, args...)
}

func (db *preparedDb[T]) QueryContext(ctx context.Context, query string, args ...any) (rows *sql.Rows, err error) {
ctx, span := otelTracer.Start(ctx, "DB.QueryContext")
defer func() {
span.RecordError(err)
span.End()
}()

stmt, err := db.prepare(ctx, query)
if err != nil {
return nil, err
}
return stmt.QueryContext(ctx, args...)
}

func (db *preparedDb[T]) PrepareContext(ctx context.Context, query string) (stmt *sql.Stmt, err error) {
return db.underlying.PrepareContext(ctx, query)
}

func (db *preparedDb[T]) prepare(ctx context.Context, query string) (stmt *sql.Stmt, err error) {
ctx, span := otelTracer.Start(ctx, fmt.Sprintf("%s.prepare", otelName))
defer func() {
span.RecordError(err)
span.End()
}()

db.mu.RLock()
stmt = db.cache[query]
db.mu.RUnlock()
if stmt != nil {
return stmt, nil
}

db.mu.Lock()
defer db.mu.Unlock()

if db.underlying == nil {
return nil, errDBClosed
}

// Given that some time has passed since the unlock of the read lock, and the lock of the
// write lock, another goroutine might have already prepared this query, so we should check
// again to avoid preparing the same query twice.
stmt = db.cache[query]
if stmt != nil {
return stmt, nil
}

prepared, err := db.PrepareContext(ctx, query)
if err != nil {
return nil, err
}
db.cache[query] = prepared
return prepared, nil
}

func (db *preparedDb[T]) Conn(ctx context.Context) (*sql.Conn, error) {
return db.underlying.Conn(ctx)
}

func (db *preparedDb[T]) BeginTx(ctx context.Context, opts *sql.TxOptions) (Transaction, error) {
tx, err := db.underlying.BeginTx(ctx, opts)
if err != nil {
return nil, err
}

return &preparedTx[T]{
Transaction: tx,
db: db,
}, nil
}

func (db *preparedDb[T]) Close() error {
db.mu.Lock()
defer db.mu.Unlock()

errs := []error{}
for _, stmt := range db.cache {
if err := stmt.Close(); err != nil {
errs = append(errs, err)
}
}
db.cache = nil

if err := db.underlying.Close(); err != nil {
errs = append(errs, err)
}
db.underlying = nil

return errors.Join(errs...)
}

type preparedTx[T Transaction] struct {
Transaction
db *preparedDb[T]
}

func (tx *preparedTx[T]) PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) {
stmt, err := tx.db.prepare(ctx, query)
if err != nil {
return nil, err
}
return tx.StmtContext(ctx, stmt), nil
}

func (tx *preparedTx[T]) ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error) {
stmt, err := tx.PrepareContext(ctx, query)
if err != nil {
return nil, err
}
return stmt.ExecContext(ctx, args...)
}

func (tx *preparedTx[T]) QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error) {
stmt, err := tx.PrepareContext(ctx, query)
if err != nil {
return nil, err
}

return stmt.QueryContext(ctx, args...)
}
Loading

0 comments on commit 2c4e7b6

Please sign in to comment.