From dcaba1a8088335f6e52fa01414c0fa849382e66b Mon Sep 17 00:00:00 2001 From: Rangel Reale Date: Wed, 12 Jun 2024 10:19:31 -0300 Subject: [PATCH] multi-db callback --- db/sql/query.go | 30 +++++++++++++++++++++++++----- 1 file changed, 25 insertions(+), 5 deletions(-) diff --git a/db/sql/query.go b/db/sql/query.go index 5baeb05..50d8b69 100644 --- a/db/sql/query.go +++ b/db/sql/query.go @@ -11,25 +11,45 @@ import ( "strings" ) +type SQLDB interface { + QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error) + ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error) +} + +// SQLQueryInterfaceDBCallback is the callback to return a *sql.DB for each query. +type SQLQueryInterfaceDBCallback func(ctx context.Context, databaseName, tableName string) (SQLDB, error) + // sqlQueryInterface is a QueryInterface wrapper for *sql.DB. type sqlQueryInterface struct { - DB *sql.DB + callback SQLQueryInterfaceDBCallback } var _ QueryInterface = (*sqlQueryInterface)(nil) // NewSQLQueryInterface wraps a *sql.DB on the QueryInterface interface. -func NewSQLQueryInterface(db *sql.DB) QueryInterface { - return &sqlQueryInterface{db} +func NewSQLQueryInterface(db SQLDB) QueryInterface { + return &sqlQueryInterface{callback: func(ctx context.Context, databaseName, tableName string) (SQLDB, error) { + return db, nil + }} +} + +// NewSQLQueryInterfaceFunc sets a callback to return a *sql.DB for each query. +func NewSQLQueryInterfaceFunc(callback SQLQueryInterfaceDBCallback) QueryInterface { + return &sqlQueryInterface{callback: callback} } func (q sqlQueryInterface) Query(ctx context.Context, databaseName, tableName string, query string, returnFieldNames []string, args ...any) (map[string]any, error) { + db, err := q.callback(ctx, databaseName, tableName) + if err != nil { + return nil, err + } + if len(returnFieldNames) == 0 { - _, err := q.DB.Exec(query, args...) + _, err := db.ExecContext(ctx, query, args...) return nil, err } - rows, err := q.DB.QueryContext(ctx, query, args...) + rows, err := db.QueryContext(ctx, query, args...) if err != nil { return nil, err }