Skip to content

Commit

Permalink
Merge pull request #319 from apecloud/expose-conn-stmt
Browse files Browse the repository at this point in the history
Expose conn and stmt [Part 1]
  • Loading branch information
taniabogatsch authored Nov 25, 2024
2 parents 52caf58 + 7c7218f commit 9f4f8de
Show file tree
Hide file tree
Showing 10 changed files with 60 additions and 36 deletions.
4 changes: 2 additions & 2 deletions appender.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ import (

// Appender holds the DuckDB appender. It allows efficient bulk loading into a DuckDB database.
type Appender struct {
con *conn
con *Conn
schema string
table string
duckdbAppender C.duckdb_appender
Expand All @@ -31,7 +31,7 @@ type Appender struct {

// NewAppenderFromConn returns a new Appender from a DuckDB driver connection.
func NewAppenderFromConn(driverConn driver.Conn, schema, table string) (*Appender, error) {
con, ok := driverConn.(*conn)
con, ok := driverConn.(*Conn)
if !ok {
return nil, getError(errInvalidCon, nil)
}
Expand Down
6 changes: 3 additions & 3 deletions arrow.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,12 +74,12 @@ import (
// Arrow exposes DuckDB Apache Arrow interface.
// https://duckdb.org/docs/api/c/api#arrow-interface
type Arrow struct {
c *conn
c *Conn
}

// NewArrowFromConn returns a new Arrow from a DuckDB driver connection.
func NewArrowFromConn(driverConn driver.Conn) (*Arrow, error) {
dbConn, ok := driverConn.(*conn)
dbConn, ok := driverConn.(*Conn)
if !ok {
return nil, fmt.Errorf("not a duckdb driver connection")
}
Expand Down Expand Up @@ -216,7 +216,7 @@ func (a *Arrow) queryArrowArray(res *C.duckdb_arrow, sc *arrow.Schema) (arrow.Re
return rec, nil
}

func (a *Arrow) execute(s *stmt, args []driver.NamedValue) (*C.duckdb_arrow, error) {
func (a *Arrow) execute(s *Stmt, args []driver.NamedValue) (*C.duckdb_arrow, error) {
if s.closed {
return nil, errClosedCon
}
Expand Down
41 changes: 28 additions & 13 deletions connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,21 +14,26 @@ import (
"unsafe"
)

type conn struct {
// Conn holds a connection to a DuckDB database.
// It implements the driver.Conn interface.
type Conn struct {
duckdbCon C.duckdb_connection
closed bool
tx bool
}

func (c *conn) CheckNamedValue(nv *driver.NamedValue) error {
// CheckNamedValue implements the driver.NamedValueChecker interface.
func (c *Conn) CheckNamedValue(nv *driver.NamedValue) error {
switch nv.Value.(type) {
case *big.Int, Interval:
return nil
}
return driver.ErrSkip
}

func (c *conn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) {
// ExecContext executes a query that doesn't return rows, such as an INSERT or UPDATE.
// It implements the driver.ExecerContext interface.
func (c *Conn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) {
prepared, err := c.prepareStmts(ctx, query)
if err != nil {
return nil, err
Expand All @@ -48,7 +53,9 @@ func (c *conn) ExecContext(ctx context.Context, query string, args []driver.Name
return res, nil
}

func (c *conn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) {
// QueryContext executes a query that may return rows, such as a SELECT.
// It implements the driver.QueryerContext interface.
func (c *Conn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) {
prepared, err := c.prepareStmts(ctx, query)
if err != nil {
return nil, err
Expand All @@ -68,11 +75,15 @@ func (c *conn) QueryContext(ctx context.Context, query string, args []driver.Nam
return r, nil
}

func (c *conn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) {
// PrepareContext returns a prepared statement, bound to this connection.
// It implements the driver.ConnPrepareContext interface.
func (c *Conn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) {
return c.prepareStmts(ctx, query)
}

func (c *conn) Prepare(query string) (driver.Stmt, error) {
// Prepare returns a prepared statement, bound to this connection.
// It implements the driver.Conn interface.
func (c *Conn) Prepare(query string) (driver.Stmt, error) {
if c.closed {
return nil, errors.Join(errPrepare, errClosedCon)
}
Expand All @@ -90,11 +101,13 @@ func (c *conn) Prepare(query string) (driver.Stmt, error) {
}

// Begin is deprecated: Use BeginTx instead.
func (c *conn) Begin() (driver.Tx, error) {
func (c *Conn) Begin() (driver.Tx, error) {
return c.BeginTx(context.Background(), driver.TxOptions{})
}

func (c *conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) {
// BeginTx starts and returns a new transaction.
// It implements the driver.ConnBeginTx interface.
func (c *Conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) {
if c.tx {
return nil, errors.Join(errBeginTx, errMultipleTx)
}
Expand All @@ -117,7 +130,9 @@ func (c *conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, e
return &tx{c}, nil
}

func (c *conn) Close() error {
// Close closes the connection to the database.
// It implements the driver.Conn interface.
func (c *Conn) Close() error {
if c.closed {
return errClosedCon
}
Expand All @@ -126,7 +141,7 @@ func (c *conn) Close() error {
return nil
}

func (c *conn) extractStmts(query string) (C.duckdb_extracted_statements, C.idx_t, error) {
func (c *Conn) extractStmts(query string) (C.duckdb_extracted_statements, C.idx_t, error) {
cQuery := C.CString(query)
defer C.duckdb_free(unsafe.Pointer(cQuery))

Expand All @@ -145,7 +160,7 @@ func (c *conn) extractStmts(query string) (C.duckdb_extracted_statements, C.idx_
return stmts, count, nil
}

func (c *conn) prepareExtractedStmt(stmts C.duckdb_extracted_statements, i C.idx_t) (*stmt, error) {
func (c *Conn) prepareExtractedStmt(stmts C.duckdb_extracted_statements, i C.idx_t) (*Stmt, error) {
var s C.duckdb_prepared_statement
state := C.duckdb_prepare_extracted_statement(c.duckdbCon, stmts, i, &s)

Expand All @@ -155,10 +170,10 @@ func (c *conn) prepareExtractedStmt(stmts C.duckdb_extracted_statements, i C.idx
return nil, err
}

return &stmt{c: c, stmt: &s}, nil
return &Stmt{c: c, stmt: &s}, nil
}

func (c *conn) prepareStmts(ctx context.Context, query string) (*stmt, error) {
func (c *Conn) prepareStmts(ctx context.Context, query string) (*Stmt, error) {
if c.closed {
return nil, errClosedCon
}
Expand Down
2 changes: 1 addition & 1 deletion duckdb.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ func (c *Connector) Connect(context.Context) (driver.Conn, error) {
return nil, getError(errConnect, nil)
}

con := &conn{duckdbCon: duckdbCon}
con := &Conn{duckdbCon: duckdbCon}

if c.connInitFn != nil {
if err := c.connInitFn(con); err != nil {
Expand Down
2 changes: 1 addition & 1 deletion profiling.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ type ProfilingInfo struct {
func GetProfilingInfo(c *sql.Conn) (ProfilingInfo, error) {
info := ProfilingInfo{}
err := c.Raw(func(driverConn any) error {
con := driverConn.(*conn)
con := driverConn.(*Conn)
duckdbInfo := C.duckdb_get_profiling_info(con.duckdbCon)
if duckdbInfo == nil {
return getError(errProfilingInfoEmpty, nil)
Expand Down
4 changes: 2 additions & 2 deletions rows.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ import (
// rows is a helper struct for scanning a duckdb result.
type rows struct {
// stmt is a pointer to the stmt of which we are scanning the result.
stmt *stmt
stmt *Stmt
// res is the result of stmt.
res C.duckdb_result
// chunk holds the currently active data chunk.
Expand All @@ -32,7 +32,7 @@ type rows struct {
rowCount int
}

func newRowsWithStmt(res C.duckdb_result, stmt *stmt) *rows {
func newRowsWithStmt(res C.duckdb_result, stmt *Stmt) *rows {
columnCount := C.duckdb_column_count(&res)
r := rows{
res: res,
Expand Down
4 changes: 2 additions & 2 deletions scalarUDF.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ func RegisterScalarUDF(c *sql.Conn, name string, f ScalarFunc) error {

// Register the function on the underlying driver connection exposed by c.Raw.
err = c.Raw(func(driverConn any) error {
con := driverConn.(*conn)
con := driverConn.(*Conn)
state := C.duckdb_register_scalar_function(con.duckdbCon, function)
C.duckdb_destroy_scalar_function(&function)
if state == C.DuckDBError {
Expand Down Expand Up @@ -111,7 +111,7 @@ func RegisterScalarUDFSet(c *sql.Conn, name string, functions ...ScalarFunc) err

// Register the function set on the underlying driver connection exposed by c.Raw.
err := c.Raw(func(driverConn any) error {
con := driverConn.(*conn)
con := driverConn.(*Conn)
state := C.duckdb_register_scalar_function_set(con.duckdbCon, set)
C.duckdb_destroy_scalar_function_set(&set)
if state == C.DuckDBError {
Expand Down
29 changes: 19 additions & 10 deletions statement.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,18 @@ import (
"unsafe"
)

type stmt struct {
c *conn
// Stmt implements the driver.Stmt interface.
type Stmt struct {
c *Conn
stmt *C.duckdb_prepared_statement
closeOnRowsClose bool
closed bool
rows bool
}

func (s *stmt) Close() error {
// Close closes the statement.
// It implements the driver.Stmt interface.
func (s *Stmt) Close() error {
if s.rows {
panic("database/sql/driver: misuse of duckdb driver: Close with active Rows")
}
Expand All @@ -36,15 +39,17 @@ func (s *stmt) Close() error {
return nil
}

func (s *stmt) NumInput() int {
// NumInput returns the number of placeholder parameters.
// It implements the driver.Stmt interface.
func (s *Stmt) NumInput() int {
if s.closed {
panic("database/sql/driver: misuse of duckdb driver: NumInput after Close")
}
paramCount := C.duckdb_nparams(*s.stmt)
return int(paramCount)
}

func (s *stmt) bind(args []driver.NamedValue) error {
func (s *Stmt) bind(args []driver.NamedValue) error {
if s.NumInput() > len(args) {
return fmt.Errorf("incorrect argument count for command: have %d want %d", len(args), s.NumInput())
}
Expand Down Expand Up @@ -175,11 +180,13 @@ func (s *stmt) bind(args []driver.NamedValue) error {
}

// Deprecated: Use ExecContext instead.
func (s *stmt) Exec(args []driver.Value) (driver.Result, error) {
func (s *Stmt) Exec(args []driver.Value) (driver.Result, error) {
return s.ExecContext(context.Background(), argsToNamedArgs(args))
}

func (s *stmt) ExecContext(ctx context.Context, nargs []driver.NamedValue) (driver.Result, error) {
// ExecContext executes a query that doesn't return rows, such as an INSERT or UPDATE.
// It implements the driver.StmtExecContext interface.
func (s *Stmt) ExecContext(ctx context.Context, nargs []driver.NamedValue) (driver.Result, error) {
res, err := s.execute(ctx, nargs)
if err != nil {
return nil, err
Expand All @@ -191,11 +198,13 @@ func (s *stmt) ExecContext(ctx context.Context, nargs []driver.NamedValue) (driv
}

// Deprecated: Use QueryContext instead.
func (s *stmt) Query(args []driver.Value) (driver.Rows, error) {
func (s *Stmt) Query(args []driver.Value) (driver.Rows, error) {
return s.QueryContext(context.Background(), argsToNamedArgs(args))
}

func (s *stmt) QueryContext(ctx context.Context, nargs []driver.NamedValue) (driver.Rows, error) {
// QueryContext executes a query that may return rows, such as a SELECT.
// It implements the driver.StmtQueryContext interface.
func (s *Stmt) QueryContext(ctx context.Context, nargs []driver.NamedValue) (driver.Rows, error) {
res, err := s.execute(ctx, nargs)
if err != nil {
return nil, err
Expand All @@ -206,7 +215,7 @@ func (s *stmt) QueryContext(ctx context.Context, nargs []driver.NamedValue) (dri

// This method executes the query in steps and checks if context is cancelled before executing each step.
// It uses Pending Result Interface C APIs to achieve this. Reference - https://duckdb.org/docs/api/c/api#pending-result-interface
func (s *stmt) execute(ctx context.Context, args []driver.NamedValue) (*C.duckdb_result, error) {
func (s *Stmt) execute(ctx context.Context, args []driver.NamedValue) (*C.duckdb_result, error) {
if s.closed {
panic("database/sql/driver: misuse of duckdb driver: ExecContext or QueryContext after Close")
}
Expand Down
2 changes: 1 addition & 1 deletion tableUDF.go
Original file line number Diff line number Diff line change
Expand Up @@ -472,7 +472,7 @@ func RegisterTableUDF[TFT TableFunction](c *sql.Conn, name string, f TFT) error

// Register the function on the underlying driver connection exposed by c.Raw.
err := c.Raw(func(driverConn any) error {
con := driverConn.(*conn)
con := driverConn.(*Conn)
state := C.duckdb_register_table_function(con.duckdbCon, function)
C.duckdb_destroy_table_function(&function)
if state == C.DuckDBError {
Expand Down
2 changes: 1 addition & 1 deletion transaction.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ package duckdb
import "context"

type tx struct {
c *conn
c *Conn
}

func (t *tx) Commit() error {
Expand Down

0 comments on commit 9f4f8de

Please sign in to comment.