Skip to content

Commit

Permalink
Add mechanism to register query parameters with the analyzer
Browse files Browse the repository at this point in the history
  • Loading branch information
ohaibbq committed May 19, 2024
1 parent b6dd6a5 commit 76d5330
Show file tree
Hide file tree
Showing 3 changed files with 336 additions and 2 deletions.
6 changes: 6 additions & 0 deletions driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"database/sql"
"database/sql/driver"
"fmt"
"google.golang.org/api/bigquery/v2"
"sync"

"github.com/mattn/go-sqlite3"
Expand Down Expand Up @@ -133,6 +134,11 @@ func (c *ZetaSQLiteConn) AddNamePath(path string) error {
return c.analyzer.AddNamePath(path)
}

func (c *ZetaSQLiteConn) SetQueryParameters(parameters []*bigquery.QueryParameter) {
c.analyzer.SetQueryParameters(parameters)

}

Check failure on line 140 in driver.go

View workflow job for this annotation

GitHub Actions / lint

unnecessary trailing newline (whitespace)

func (s *ZetaSQLiteConn) CheckNamedValue(value *driver.NamedValue) error {
return nil
}
Expand Down
220 changes: 220 additions & 0 deletions driver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ package zetasqlite_test
import (
"context"
"database/sql"
"fmt"
"google.golang.org/api/bigquery/v2"
"testing"

"github.com/google/go-cmp/cmp"
Expand Down Expand Up @@ -62,6 +64,224 @@ CREATE VIEW IF NOT EXISTS SingerNames AS SELECT FirstName || ' ' || LastName AS
}
}

func configureParameters(conn *sql.Conn, parameters []*bigquery.QueryParameter) error {
if err := conn.Raw(func(c interface{}) error {
zetasqliteConn, ok := c.(*zetasqlite.ZetaSQLiteConn)
if !ok {
return fmt.Errorf("failed to get ZetaSQLiteConn from %T", c)
}
zetasqliteConn.SetQueryParameters(parameters)
return nil
}); err != nil {
return fmt.Errorf("failed to setup query parameters: %s", err)
}
return nil
}

func TestNamedParameters(t *testing.T) {
ctx := context.Background()
db, err := sql.Open("zetasqlite", ":memory:")
if err != nil {
t.Fatal(err)
}
if _, err := db.Exec(`
CREATE TABLE IF NOT EXISTS Singers (
SingerId INT64 NOT NULL,
FirstName STRING(1024),
LastName STRING(1024),
SingerInfo BYTES(MAX)
)`); err != nil {
t.Fatal(err)
}
conn, err := db.Conn(ctx)
if _, err := conn.ExecContext(ctx, `INSERT Singers (SingerId, FirstName, LastName) VALUES (1, 'John', 'Titor')`); err != nil {
t.Fatal(err)
}
if err != nil {
t.Fatal(err)
}
t.Run("test multiple statements named params", func(t *testing.T) {
err = configureParameters(conn, []*bigquery.QueryParameter{
{
Name: "id",
ParameterType: &bigquery.QueryParameterType{
Type: "INT64",
},
ParameterValue: &bigquery.QueryParameterValue{
Value: "1",
},
},
{
Name: "name",
ParameterType: &bigquery.QueryParameterType{
Type: "STRING",
},
ParameterValue: &bigquery.QueryParameterValue{
Value: "John",
},
},
})
row := conn.QueryRowContext(ctx, "SELECT SingerID, FirstName, LastName FROM Singers WHERE SingerId = @id OR (@name is null OR FirstName = @name)", 1, "John")
if row.Err() != nil {
t.Fatal(row.Err())
}
var (
singerID int64
firstName string
lastName string
)
if err := row.Scan(&singerID, &firstName, &lastName); err != nil {
t.Fatal(err)
}
if singerID != 1 || firstName != "John" || lastName != "Titor" {
t.Fatalf("failed to find row %v %v %v", singerID, firstName, lastName)
}
})

t.Run("test array type", func(t *testing.T) {
err = configureParameters(conn, []*bigquery.QueryParameter{
{
Name: "names",
ParameterType: &bigquery.QueryParameterType{
Type: "ARRAY",
ArrayType: &bigquery.QueryParameterType{
Type: "STRING",
},
},
ParameterValue: &bigquery.QueryParameterValue{
ArrayValues: []*bigquery.QueryParameterValue{
{Value: "John"},
},
},
},
})
if err != nil {
t.Fatal(err)
}
row := conn.QueryRowContext(ctx, "SELECT SingerID, FirstName, LastName FROM Singers WHERE FirstName IN UNNEST(@names)", []string{
"John",
})
if row.Err() != nil {
t.Fatal(row.Err())
}
var (
singerID int64
firstName string
lastName string
)
if err := row.Scan(&singerID, &firstName, &lastName); err != nil {
t.Fatal(err)
}
if singerID != 1 || firstName != "John" || lastName != "Titor" {
t.Fatalf("failed to find row %v %v %v", singerID, firstName, lastName)
}
})

t.Run("test struct type", func(t *testing.T) {
err = configureParameters(conn, []*bigquery.QueryParameter{
{
Name: "names",
ParameterType: &bigquery.QueryParameterType{
Type: "STRUCT",
StructTypes: []*bigquery.QueryParameterTypeStructTypes{
{Name: "first", Type: &bigquery.QueryParameterType{Type: "STRING"}},
},
},
ParameterValue: &bigquery.QueryParameterValue{
StructValues: map[string]bigquery.QueryParameterValue{
"first": {Value: "John"},
},
},
},
})
if err != nil {
t.Fatal(err)
}
row := conn.QueryRowContext(ctx, "SELECT SingerID, FirstName, LastName FROM Singers WHERE FirstName = @names.first", map[string]string{
"first": "John",
})
if row.Err() != nil {
t.Fatal(row.Err())
}
var (
singerID int64
firstName string
lastName string
)
if err := row.Scan(&singerID, &firstName, &lastName); err != nil {
t.Fatal(err)
}
if singerID != 1 || firstName != "John" || lastName != "Titor" {
t.Fatalf("failed to find row %v %v %v", singerID, firstName, lastName)
}
})

t.Run("test parameter pollution type", func(t *testing.T) {
param := "test_param"
// re-using the same parameter name should with different types works across queries
err = configureParameters(conn, []*bigquery.QueryParameter{
{
Name: param,
ParameterType: &bigquery.QueryParameterType{
Type: "STRUCT",
StructTypes: []*bigquery.QueryParameterTypeStructTypes{
{Name: "first", Type: &bigquery.QueryParameterType{Type: "STRING"}},
},
},
ParameterValue: &bigquery.QueryParameterValue{
StructValues: map[string]bigquery.QueryParameterValue{
"first": {Value: "John"},
},
},
},
})
if err != nil {
t.Fatal(err)
}
row := conn.QueryRowContext(ctx, "SELECT SingerID, FirstName, LastName FROM Singers WHERE FirstName = @test_param.first", map[string]string{
"first": "John",
})
if row.Err() != nil {
t.Fatal(row.Err())
}
var (
singerID int64
firstName string
lastName string
)
if err := row.Scan(&singerID, &firstName, &lastName); err != nil {
t.Fatal(err)
}
if singerID != 1 || firstName != "John" || lastName != "Titor" {
t.Fatalf("failed to find row %v %v %v", singerID, firstName, lastName)
}
err = configureParameters(conn, []*bigquery.QueryParameter{
{
Name: param,
ParameterType: &bigquery.QueryParameterType{
Type: "STRING",
},
ParameterValue: &bigquery.QueryParameterValue{
Value: "John",
},
},
})
if err != nil {
t.Fatal(err)
}
row = conn.QueryRowContext(ctx, "SELECT SingerID, FirstName, LastName FROM Singers WHERE FirstName = @test_param", "John")
if row.Err() != nil {
t.Fatal(row.Err())
}
if err := row.Scan(&singerID, &firstName, &lastName); err != nil {
t.Fatal(err)
}
if singerID != 1 || firstName != "John" || lastName != "Titor" {
t.Fatalf("failed to find row %v %v %v", singerID, firstName, lastName)
}
})
}

func TestRegisterCustomDriver(t *testing.T) {
sql.Register("zetasqlite-custom", &zetasqlite.ZetaSQLiteDriver{
ConnectHook: func(conn *zetasqlite.ZetaSQLiteConn) error {
Expand Down
Loading

0 comments on commit 76d5330

Please sign in to comment.