-
Notifications
You must be signed in to change notification settings - Fork 17
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
9 changed files
with
512 additions
and
174 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,144 @@ | ||
package database_test | ||
|
||
import ( | ||
"context" | ||
"database/sql/driver" | ||
"io" | ||
"sync/atomic" | ||
"testing" | ||
) | ||
|
||
type testDriver struct { | ||
t *testing.T | ||
conns atomic.Int32 | ||
queries atomic.Int32 | ||
execs atomic.Int32 | ||
stmts atomic.Int32 | ||
trans atomic.Int32 | ||
} | ||
|
||
var _ driver.DriverContext = &testDriver{} | ||
|
||
// Open implements driver.Driver. | ||
func (td *testDriver) Open(name string) (driver.Conn, error) { | ||
td.conns.Add(1) | ||
return &testConn{ | ||
driver: td, | ||
dbName: name, | ||
}, nil | ||
} | ||
|
||
func (td *testDriver) OpenConnector(name string) (driver.Connector, error) { | ||
return &testConnector{ | ||
driver: td, | ||
dbName: name, | ||
}, nil | ||
} | ||
|
||
func (td *testDriver) Close() { | ||
if rows := td.queries.Load(); rows != 0 { | ||
td.t.Errorf("%d rows left open after close", rows) | ||
} | ||
if stmts := td.stmts.Load(); stmts != 0 { | ||
td.t.Errorf("%d statements left open after close", stmts) | ||
} | ||
if conns := td.conns.Load(); conns != 0 { | ||
td.t.Errorf("%d connections left open after close", conns) | ||
} | ||
if trans := td.trans.Load(); trans != 0 { | ||
td.t.Errorf("%d transactions left open after close", trans) | ||
} | ||
} | ||
|
||
type testConnector struct { | ||
driver *testDriver | ||
dbName string | ||
} | ||
|
||
func (tc *testConnector) Driver() driver.Driver { return tc.driver } | ||
|
||
func (tc *testConnector) Connect(context.Context) (driver.Conn, error) { | ||
tc.driver.conns.Add(1) | ||
return &testConn{ | ||
driver: tc.driver, | ||
dbName: tc.dbName, | ||
}, nil | ||
} | ||
|
||
type testConn struct { | ||
driver *testDriver | ||
dbName string | ||
} | ||
|
||
var _ driver.ExecerContext = &testConn{} | ||
var _ driver.QueryerContext = &testConn{} | ||
|
||
func (t *testConn) Begin() (driver.Tx, error) { | ||
t.driver.trans.Add(1) | ||
return &testTx{ | ||
driver: t.driver, | ||
}, nil | ||
} | ||
|
||
func (t *testConn) Prepare(query string) (driver.Stmt, error) { | ||
t.driver.stmts.Add(1) | ||
return &testStmt{ | ||
driver: t.driver, | ||
query: query, | ||
}, nil | ||
} | ||
|
||
func (t *testConn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) { | ||
t.driver.queries.Add(1) | ||
return &testRows{ | ||
driver: t.driver, | ||
}, nil | ||
} | ||
|
||
func (t *testConn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) { | ||
t.driver.execs.Add(1) | ||
return testResult{}, nil | ||
} | ||
|
||
func (t *testConn) Close() error { return nil } | ||
|
||
type testStmt struct { | ||
driver *testDriver | ||
query string | ||
} | ||
|
||
func (t *testStmt) NumInput() int { return 0 } | ||
|
||
func (t *testStmt) Exec(args []driver.Value) (driver.Result, error) { | ||
t.driver.execs.Add(1) | ||
return testResult{}, nil | ||
} | ||
|
||
func (t *testStmt) Query(args []driver.Value) (driver.Rows, error) { | ||
t.driver.queries.Add(1) | ||
return &testRows{ | ||
driver: t.driver, | ||
}, nil | ||
} | ||
|
||
func (t *testStmt) Close() error { return nil } | ||
|
||
type testRows struct { | ||
driver *testDriver | ||
} | ||
|
||
func (t *testRows) Columns() []string { return nil } | ||
func (t *testRows) Next(dest []driver.Value) error { return io.EOF } | ||
func (t *testRows) Close() error { return nil } | ||
|
||
type testResult struct{} | ||
|
||
func (t testResult) LastInsertId() (int64, error) { return 0, nil } | ||
func (t testResult) RowsAffected() (int64, error) { return 0, nil } | ||
|
||
type testTx struct { | ||
driver *testDriver | ||
} | ||
|
||
func (t *testTx) Commit() error { return nil } | ||
func (t *testTx) Rollback() error { return nil } |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
package database | ||
|
||
import ( | ||
"context" | ||
"database/sql" | ||
"errors" | ||
) | ||
|
||
var errDBClosed = errors.New("sql: database is closed") | ||
|
||
type Interface interface { | ||
QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error) | ||
ExecContext(ctx context.Context, query string, args ...any) (result sql.Result, err error) | ||
PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) | ||
BeginTx(ctx context.Context, opts *sql.TxOptions) (Transaction, error) | ||
Conn(ctx context.Context) (*sql.Conn, error) | ||
Close() error | ||
} | ||
|
||
type Transaction interface { | ||
QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error) | ||
ExecContext(ctx context.Context, query string, args ...any) (result sql.Result, err error) | ||
PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) | ||
StmtContext(ctx context.Context, stmt *sql.Stmt) *sql.Stmt | ||
Commit() error | ||
Rollback() error | ||
} | ||
|
||
type Wrapped[T Transaction] interface { | ||
QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error) | ||
ExecContext(ctx context.Context, query string, args ...any) (result sql.Result, err error) | ||
PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) | ||
BeginTx(ctx context.Context, opts *sql.TxOptions) (T, error) | ||
Conn(ctx context.Context) (*sql.Conn, error) | ||
Close() error | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,172 @@ | ||
package database | ||
|
||
import ( | ||
"context" | ||
"database/sql" | ||
"errors" | ||
"fmt" | ||
"sync" | ||
|
||
"go.opentelemetry.io/otel" | ||
"go.opentelemetry.io/otel/trace" | ||
) | ||
|
||
const otelName = "prepared" | ||
|
||
var otelTracer trace.Tracer | ||
|
||
func init() { | ||
otelTracer = otel.Tracer(otelName) | ||
} | ||
|
||
type preparedDb[T Transaction] struct { | ||
underlying Wrapped[T] | ||
mu sync.RWMutex | ||
cache map[string]*sql.Stmt | ||
} | ||
|
||
// NewPrepared creates a new Interface that wraps the given database and | ||
// uses a prepare cache to reduce the number of prepare calls. The cache | ||
// is only used when calling ExecContext and QueryContext methods on the | ||
// main instance or in a transaction. | ||
func NewPrepared[T Transaction](db Wrapped[T]) Interface { | ||
return &preparedDb[T]{ | ||
underlying: db, | ||
cache: make(map[string]*sql.Stmt), | ||
} | ||
} | ||
|
||
func (db *preparedDb[T]) ExecContext(ctx context.Context, query string, args ...any) (result sql.Result, err error) { | ||
ctx, span := otelTracer.Start(ctx, "DB.ExecContext") | ||
defer func() { | ||
span.RecordError(err) | ||
span.End() | ||
}() | ||
|
||
stmt, err := db.prepare(ctx, query) | ||
if err != nil { | ||
return nil, err | ||
} | ||
return stmt.ExecContext(ctx, args...) | ||
} | ||
|
||
func (db *preparedDb[T]) QueryContext(ctx context.Context, query string, args ...any) (rows *sql.Rows, err error) { | ||
ctx, span := otelTracer.Start(ctx, "DB.QueryContext") | ||
defer func() { | ||
span.RecordError(err) | ||
span.End() | ||
}() | ||
|
||
stmt, err := db.prepare(ctx, query) | ||
if err != nil { | ||
return nil, err | ||
} | ||
return stmt.QueryContext(ctx, args...) | ||
} | ||
|
||
func (db *preparedDb[T]) PrepareContext(ctx context.Context, query string) (stmt *sql.Stmt, err error) { | ||
return db.underlying.PrepareContext(ctx, query) | ||
} | ||
|
||
func (db *preparedDb[T]) prepare(ctx context.Context, query string) (stmt *sql.Stmt, err error) { | ||
ctx, span := otelTracer.Start(ctx, fmt.Sprintf("%s.prepare", otelName)) | ||
defer func() { | ||
span.RecordError(err) | ||
span.End() | ||
}() | ||
|
||
db.mu.RLock() | ||
stmt = db.cache[query] | ||
db.mu.RUnlock() | ||
if stmt != nil { | ||
return stmt, nil | ||
} | ||
|
||
db.mu.Lock() | ||
defer db.mu.Unlock() | ||
|
||
if db.underlying == nil { | ||
return nil, errDBClosed | ||
} | ||
|
||
// Given that some time has passed since the unlock of the read lock, and the lock of the | ||
// write lock, another goroutine might have already prepared this query, so we should check | ||
// again to avoid preparing the same query twice. | ||
stmt = db.cache[query] | ||
if stmt != nil { | ||
return stmt, nil | ||
} | ||
|
||
prepared, err := db.PrepareContext(ctx, query) | ||
if err != nil { | ||
return nil, err | ||
} | ||
db.cache[query] = prepared | ||
return prepared, nil | ||
} | ||
|
||
func (db *preparedDb[T]) Conn(ctx context.Context) (*sql.Conn, error) { | ||
return db.underlying.Conn(ctx) | ||
} | ||
|
||
func (db *preparedDb[T]) BeginTx(ctx context.Context, opts *sql.TxOptions) (Transaction, error) { | ||
tx, err := db.underlying.BeginTx(ctx, opts) | ||
if err != nil { | ||
return nil, err | ||
} | ||
|
||
return &preparedTx[T]{ | ||
Transaction: tx, | ||
db: db, | ||
}, nil | ||
} | ||
|
||
func (db *preparedDb[T]) Close() error { | ||
db.mu.Lock() | ||
defer db.mu.Unlock() | ||
|
||
errs := []error{} | ||
for _, stmt := range db.cache { | ||
if err := stmt.Close(); err != nil { | ||
errs = append(errs, err) | ||
} | ||
} | ||
db.cache = nil | ||
|
||
if err := db.underlying.Close(); err != nil { | ||
errs = append(errs, err) | ||
} | ||
db.underlying = nil | ||
|
||
return errors.Join(errs...) | ||
} | ||
|
||
type preparedTx[T Transaction] struct { | ||
Transaction | ||
db *preparedDb[T] | ||
} | ||
|
||
func (tx *preparedTx[T]) PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) { | ||
stmt, err := tx.db.prepare(ctx, query) | ||
if err != nil { | ||
return nil, err | ||
} | ||
return tx.StmtContext(ctx, stmt), nil | ||
} | ||
|
||
func (tx *preparedTx[T]) ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error) { | ||
stmt, err := tx.PrepareContext(ctx, query) | ||
if err != nil { | ||
return nil, err | ||
} | ||
return stmt.ExecContext(ctx, args...) | ||
} | ||
|
||
func (tx *preparedTx[T]) QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error) { | ||
stmt, err := tx.PrepareContext(ctx, query) | ||
if err != nil { | ||
return nil, err | ||
} | ||
|
||
return stmt.QueryContext(ctx, args...) | ||
} |
Oops, something went wrong.