Skip to content

Commit

Permalink
Expose Bind & Exec|Query
Browse files Browse the repository at this point in the history
  • Loading branch information
fanyang01 committed Nov 13, 2024
1 parent 9868285 commit e73575f
Show file tree
Hide file tree
Showing 4 changed files with 109 additions and 18 deletions.
2 changes: 1 addition & 1 deletion arrow.go
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,7 @@ func (a *Arrow) execute(s *Stmt, args []driver.NamedValue) (*C.duckdb_arrow, err
return nil, errClosedCon
}

if err := s.bind(args); err != nil {
if err := s.Bind(args); err != nil {
return nil, err
}

Expand Down
3 changes: 3 additions & 0 deletions errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,9 @@ var (
errPrepare = errors.New("could not prepare query")
errMissingPrepareContext = errors.New("missing context for multi-statement query: try using PrepareContext")
errEmptyQuery = errors.New("empty query")
errCouldNotBind = errors.New("could not bind parameter")
errActiveRows = errors.New("ExecContext or QueryContext with active Rows")
errNotBound = errors.New("parameters have not been bound")
errBeginTx = errors.New("could not begin transaction")
errMultipleTx = errors.New("multiple transactions")
errReadOnlyTxNotSupported = errors.New("read-only transactions are not supported")
Expand Down
76 changes: 63 additions & 13 deletions statement.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ type Stmt struct {
c *Conn
stmt *C.duckdb_prepared_statement
closeOnRowsClose bool
bound bool
closed bool
rows bool
}
Expand Down Expand Up @@ -112,6 +113,15 @@ func (s *Stmt) StatementType() StmtType {
return StmtType(C.duckdb_prepared_statement_type(*s.stmt))
}

// Bind binds the arguments to the query.
// WARNING: This is a low-level API and should be used with caution.
func (s *Stmt) Bind(args []driver.NamedValue) error {
if s.closed {
return errors.Join(errCouldNotBind, errClosedCon)
}
return s.bind(args)
}

func (s *Stmt) bind(args []driver.NamedValue) error {
if s.NumInput() > len(args) {
return fmt.Errorf("incorrect argument count for command: have %d want %d", len(args), s.NumInput())
Expand Down Expand Up @@ -239,6 +249,7 @@ func (s *Stmt) bind(args []driver.NamedValue) error {
}
}

s.bound = true
return nil
}

Expand All @@ -250,7 +261,36 @@ func (s *Stmt) Exec(args []driver.Value) (driver.Result, error) {
// ExecContext executes a query that doesn't return rows, such as an INSERT or UPDATE.
// It implements the driver.StmtExecContext interface.
func (s *Stmt) ExecContext(ctx context.Context, nargs []driver.NamedValue) (driver.Result, error) {
res, err := s.execute(ctx, nargs)
if s.closed {
return nil, errClosedCon
}
if s.rows {
return nil, errActiveRows
}
if err := s.bind(nargs); err != nil {
return nil, err
}
return s.execBound(ctx)
}

// ExecBound executes a bound query that doesn't return rows, such as an INSERT or UPDATE.
// It can only be called after Bind has been called.
// WARNING: This is a low-level API and should be used with caution.
func (s *Stmt) ExecBound(ctx context.Context) (driver.Result, error) {
if s.closed {
return nil, errClosedCon
}
if s.rows {
return nil, errActiveRows
}
if !s.bound {
return nil, errNotBound
}
return s.execBound(ctx)
}

func (s *Stmt) execBound(ctx context.Context) (driver.Result, error) {
res, err := s.execute(ctx)
if err != nil {
return nil, err
}
Expand All @@ -268,28 +308,40 @@ func (s *Stmt) Query(args []driver.Value) (driver.Rows, error) {
// QueryContext executes a query that may return rows, such as a SELECT.
// It implements the driver.StmtQueryContext interface.
func (s *Stmt) QueryContext(ctx context.Context, nargs []driver.NamedValue) (driver.Rows, error) {
res, err := s.execute(ctx, nargs)
if err != nil {
if err := s.Bind(nargs); err != nil {
return nil, err
}
s.rows = true
return newRowsWithStmt(*res, s), nil
return s.QueryBound(ctx)
}

// This method executes the query in steps and checks if context is cancelled before executing each step.
// It uses Pending Result Interface C APIs to achieve this. Reference - https://duckdb.org/docs/api/c/api#pending-result-interface
func (s *Stmt) execute(ctx context.Context, args []driver.NamedValue) (*C.duckdb_result, error) {
// QueryBound executes a bound statement that may return rows, such as a SELECT.
// It can only be called after Bind has been called.
// WARNING: This is a low-level API and should be used with caution.
func (s *Stmt) QueryBound(ctx context.Context) (driver.Rows, error) {
if s.closed {
panic("database/sql/driver: misuse of duckdb driver: ExecContext or QueryContext after Close")
return nil, errClosedCon
}
if s.rows {
panic("database/sql/driver: misuse of duckdb driver: ExecContext or QueryContext with active Rows")
return nil, errActiveRows
}
if !s.bound {
return nil, errNotBound
}
return s.queryBound(ctx)
}

if err := s.bind(args); err != nil {
func (s *Stmt) queryBound(ctx context.Context) (driver.Rows, error) {
res, err := s.execute(ctx)
if err != nil {
return nil, err
}
s.rows = true
return newRowsWithStmt(*res, s), nil
}

// This method executes the query in steps and checks if context is cancelled before executing each step.
// It uses Pending Result Interface C APIs to achieve this. Reference - https://duckdb.org/docs/api/c/api#pending-result-interface
func (s *Stmt) execute(ctx context.Context) (*C.duckdb_result, error) {
var pendingRes C.duckdb_pending_result
if state := C.duckdb_pending_prepared(*s.stmt, &pendingRes); state == C.DuckDBError {
dbErr := getDuckDBError(C.GoString(C.duckdb_pending_error(pendingRes)))
Expand Down Expand Up @@ -341,5 +393,3 @@ func argsToNamedArgs(values []driver.Value) []driver.NamedValue {
}
return args
}

var errCouldNotBind = errors.New("could not bind parameter")
46 changes: 42 additions & 4 deletions statement_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package duckdb
import (
"context"
"database/sql"
"database/sql/driver"
"errors"
"testing"

Expand Down Expand Up @@ -38,6 +39,19 @@ func TestPrepareQuery(t *testing.T) {
stmt := s.(*Stmt)
require.Equal(t, DUCKDB_STATEMENT_TYPE_SELECT, stmt.StatementType())
require.Equal(t, TYPE_INTEGER, stmt.ParamType(1))

rows, err := stmt.QueryBound(context.Background())
require.Nil(t, rows)
require.ErrorIs(t, err, errNotBound)

err = stmt.Bind([]driver.NamedValue{{Ordinal: 1, Value: 0}})
require.NoError(t, err)

rows, err = stmt.QueryBound(context.Background())
require.NoError(t, err)
require.NotNil(t, rows)
require.NoError(t, rows.Close())

require.NoError(t, stmt.Close())
return nil
})
Expand Down Expand Up @@ -81,14 +95,26 @@ func TestPrepareQueryPositional(t *testing.T) {
// Access the raw connection & statement.
err = c.Raw(func(driverConn interface{}) error {
conn := driverConn.(*Conn)
s, err := conn.PrepareContext(context.Background(), `SELECT * FROM foo WHERE bar = $2 AND baz = $1`)
s, err := conn.PrepareContext(context.Background(), `UPDATE foo SET bar = $2 WHERE baz = $1`)
require.NoError(t, err)
stmt := s.(*Stmt)
require.Equal(t, DUCKDB_STATEMENT_TYPE_SELECT, stmt.StatementType())
require.Equal(t, DUCKDB_STATEMENT_TYPE_UPDATE, stmt.StatementType())
require.Equal(t, "1", stmt.ParamName(1))
require.Equal(t, TYPE_INTEGER, stmt.ParamType(1))
require.Equal(t, "2", stmt.ParamName(2))
require.Equal(t, TYPE_VARCHAR, stmt.ParamType(2))

result, err := stmt.ExecBound(context.Background())
require.Nil(t, result)
require.ErrorIs(t, err, errNotBound)

err = stmt.Bind([]driver.NamedValue{{Ordinal: 1, Value: 0}, {Ordinal: 2, Value: "hello"}})
require.NoError(t, err)

result, err = stmt.ExecBound(context.Background())
require.NoError(t, err)
require.NotNil(t, result)

require.NoError(t, stmt.Close())
return nil
})
Expand Down Expand Up @@ -129,14 +155,26 @@ func TestPrepareQueryNamed(t *testing.T) {
// Access the raw connection & statement.
err = c.Raw(func(driverConn interface{}) error {
conn := driverConn.(*Conn)
s, err := conn.PrepareContext(context.Background(), `SELECT * FROM foo WHERE bar = $bar AND baz = $baz`)
s, err := conn.PrepareContext(context.Background(), `INSERT INTO foo VALUES ($bar, $baz)`)
require.NoError(t, err)
stmt := s.(*Stmt)
require.Equal(t, DUCKDB_STATEMENT_TYPE_SELECT, stmt.StatementType())
require.Equal(t, DUCKDB_STATEMENT_TYPE_INSERT, stmt.StatementType())
require.Equal(t, "bar", stmt.ParamName(1))
require.Equal(t, TYPE_INVALID, stmt.ParamType(1)) // Not sure why this is invalid.
require.Equal(t, "baz", stmt.ParamName(2))
require.Equal(t, TYPE_INVALID, stmt.ParamType(2))

result, err := stmt.ExecBound(context.Background())
require.Nil(t, result)
require.ErrorIs(t, err, errNotBound)

err = stmt.Bind([]driver.NamedValue{{Name: "bar", Value: "hello"}, {Name: "baz", Value: 0}})
require.NoError(t, err)

result, err = stmt.ExecBound(context.Background())
require.NoError(t, err)
require.NotNil(t, result)

require.NoError(t, stmt.Close())
return nil
})
Expand Down

0 comments on commit e73575f

Please sign in to comment.