Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow users of the driver to disable query formatting; support named parameters in QueryStmt, DMLStmt #43

Merged
merged 2 commits into from
Apr 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions context.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,16 @@ import (
"github.com/goccy/go-zetasqlite/internal"
)

// DisableQueryFormattingKey use to disable query formatting for queries that require raw SQLite access
type DisableQueryFormattingKey = internal.DisableQueryFormattingKey

// WithQueryFormattingDisabled use for queries that require raw SQLite SQL.
// This is useful for queries that do not require additional functionality from go-zetasqlite.
// Utilizing this option often allows the SQLite query planner to generate more efficient plans.
func WithQueryFormattingDisabled(ctx context.Context) context.Context {
return context.WithValue(ctx, internal.DisableQueryFormattingKey{}, true)
}

// WithCurrentTime use to replace the current time with the specified time.
// To replace the time, you need to pass the returned context as an argument to QueryContext.
// `CURRENT_DATE`, `CURRENT_DATETIME`, `CURRENT_TIME`, `CURRENT_TIMESTAMP` functions are targeted.
Expand Down
31 changes: 28 additions & 3 deletions driver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,7 @@ CREATE TABLE IF NOT EXISTS Singers (
t.Fatal("found unexpected row; expected no rows")
}
})
t.Run("prepared insert", func(t *testing.T) {
t.Run("prepared insert with named values", func(t *testing.T) {
db, err := sql.Open("zetasqlite", ":memory:")
if err != nil {
t.Fatal(err)
Expand All @@ -224,11 +224,11 @@ CREATE TABLE IF NOT EXISTS Singers (
t.Fatal("expected error when inserting without args; got no error")
}

stmt, err := db.Prepare("INSERT `Items` (`ItemId`) VALUES (?)")
stmt, err := db.Prepare("INSERT `Items` (`ItemId`) VALUES (@itemID)")
if err != nil {
t.Fatal(err)
}
if _, err := stmt.Exec(456); err != nil {
if _, err := stmt.Exec(sql.Named("itemID", 456)); err != nil {
t.Fatal(err)
}

Expand All @@ -248,4 +248,29 @@ CREATE TABLE IF NOT EXISTS Singers (
t.Fatal("expected no rows; expected one row")
}
})

t.Run("prepared select with named values, formatting disabled, uppercased parameter", func(t *testing.T) {
db, err := sql.Open("zetasqlite", ":memory:")
ctx := zetasqlite.WithQueryFormattingDisabled(context.Background())
if err != nil {
t.Fatal(err)
}
if _, err := db.Exec(`CREATE TABLE IF NOT EXISTS Items (ItemId INT64 NOT NULL)`); err != nil {
t.Fatal(err)
}
if _, err := db.Exec("INSERT `Items` (`ItemId`) VALUES (123)"); err != nil {
t.Fatal(err)
}

stmt, err := db.PrepareContext(ctx, "SELECT `ItemID` FROM `Items` WHERE `ItemID` = @itemID AND @bool = TRUE")
if err != nil {
t.Fatal("unexpected error when preparing stmt; got %w", err)
}

var itemID string
err = stmt.QueryRowContext(ctx, sql.Named("itemID", 123), sql.Named("bool", true)).Scan(&itemID)
if err != nil {
t.Fatal("expected one row; got error %w", err)
}
})
}
26 changes: 22 additions & 4 deletions internal/analyzer.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ type Analyzer struct {
opt *zetasql.AnalyzerOptions
}

type DisableQueryFormattingKey struct{}

func NewAnalyzer(catalog *Catalog) (*Analyzer, error) {
opt, err := newAnalyzerOptions()
if err != nil {
Expand Down Expand Up @@ -511,14 +513,30 @@ func (a *Analyzer) newQueryStmtAction(ctx context.Context, query string, args []
Type: newType(col.Column().Type()),
})
}
formattedQuery, err := newNode(node).FormatSQL(ctx)
if err != nil {
return nil, fmt.Errorf("failed to format query %s: %w", query, err)
var formattedQuery string
params := getParamsFromNode(node)
if disabledFormatting, ok := ctx.Value(DisableQueryFormattingKey{}).(bool); ok && disabledFormatting {
formattedQuery = query
// ZetaSQL will always lowercase parameter names, so we must match it in the query
queryBytes := []byte(query)
for _, param := range params {
location := param.ParseLocationRange()
start := location.Start().ByteOffset()
end := location.End().ByteOffset()
// Finds the parameter including its prefix i.e. @itemID
parameter := string(queryBytes[start:end])
formattedQuery = strings.ReplaceAll(formattedQuery, parameter, strings.ToLower(parameter))
}
} else {
var err error
formattedQuery, err = newNode(node).FormatSQL(ctx)
if err != nil {
return nil, fmt.Errorf("failed to format query %s: %w", query, err)
}
}
if formattedQuery == "" {
return nil, fmt.Errorf("failed to format query %s", query)
}
params := getParamsFromNode(node)
queryArgs, err := getArgsFromParams(args, params)
if err != nil {
return nil, err
Expand Down
42 changes: 23 additions & 19 deletions internal/stmt.go
Original file line number Diff line number Diff line change
Expand Up @@ -152,15 +152,15 @@ func (s *DMLStmt) NumInput() int {
}

func (s *DMLStmt) Exec(args []driver.Value) (driver.Result, error) {
values := make([]interface{}, 0, len(args))
for _, arg := range args {
values = append(values, arg)
}
newArgs, err := EncodeGoValues(values, s.args)
return s.ExecContext(context.Background(), valuesToNamedValues(args))
}

func (s *DMLStmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) {
newArgs, err := getArgsFromParams(args, s.args)
if err != nil {
return nil, err
}
result, err := s.stmt.Exec(newArgs...)
result, err := s.stmt.ExecContext(ctx, newArgs...)
if err != nil {
return nil, fmt.Errorf(
"failed to execute query %s: args %v: %w",
Expand All @@ -172,10 +172,6 @@ func (s *DMLStmt) Exec(args []driver.Value) (driver.Result, error) {
return result, nil
}

func (s *DMLStmt) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) {
return nil, fmt.Errorf("unimplemented ExecContext for DMLStmt")
}

func (s *DMLStmt) Query(args []driver.Value) (driver.Rows, error) {
return nil, fmt.Errorf("unsupported query for DMLStmt")
}
Expand Down Expand Up @@ -224,16 +220,28 @@ func (s *QueryStmt) ExecContext(ctx context.Context, query string, args []driver
return nil, fmt.Errorf("unsupported exec for QueryStmt")
}

func (s *QueryStmt) Query(args []driver.Value) (driver.Rows, error) {
values := make([]interface{}, 0, len(args))
func valuesToNamedValues(args []driver.Value) []driver.NamedValue {
values := make([]driver.NamedValue, 0, len(args))
for _, arg := range args {
values = append(values, arg)
if namedValue, ok := arg.(driver.NamedValue); ok {
values = append(values, namedValue)
}
values = append(values, driver.NamedValue{Value: arg})
}
newArgs, err := EncodeGoValues(values, s.args)

return values
}

func (s *QueryStmt) Query(args []driver.Value) (driver.Rows, error) {
return s.QueryContext(context.Background(), valuesToNamedValues(args))
}

func (s *QueryStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) {
newArgs, err := getArgsFromParams(args, s.args)
if err != nil {
return nil, err
}
rows, err := s.stmt.Query(newArgs...)
rows, err := s.stmt.QueryContext(ctx, newArgs...)
if err != nil {
return nil, fmt.Errorf(
"failed to query %s: args: %v: %w",
Expand All @@ -244,7 +252,3 @@ func (s *QueryStmt) Query(args []driver.Value) (driver.Rows, error) {
}
return &Rows{rows: rows, columns: s.outputColumns}, nil
}

func (s *QueryStmt) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) {
return nil, fmt.Errorf("unimplemented QueryContext for QueryStmt")
}
Loading