Skip to content

Commit

Permalink
Switch to prepared statements
Browse files Browse the repository at this point in the history
The gocql driver prepares statements automatically. However, to benefit
from prepared statements, we need to ensure that query and values are
separated from each other. Refactor statement generation to return a
query and a value-generating function.
  • Loading branch information
penberg committed Jul 30, 2018
1 parent fae8ad9 commit eb293a7
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 23 deletions.
21 changes: 13 additions & 8 deletions cmd/gemini/root.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,24 +70,29 @@ func run(cmd *cobra.Command, args []string) {
return
}
}

nrPassedTests := 0

for i := 0; i < maxTests; i++ {
mutate := schema.GenMutateOp()
mutateStmt := schema.GenMutateStmt()
mutateQuery := mutateStmt.Query
mutateValues := mutateStmt.Values()
if verbose {
fmt.Printf("%s\n", mutate)
fmt.Printf("%s (values=%v)\n", mutateQuery, mutateValues)
}
if err := session.Mutate(mutate); err != nil {
fmt.Printf("Failed! Mutation '%s' caused an error: '%v'\n", mutate, err)
if err := session.Mutate(mutateQuery, mutateValues...); err != nil {
fmt.Printf("Failed! Mutation '%s' (values=%v) caused an error: '%v'\n", mutateQuery, mutateValues, err)
return
}

check := schema.GenCheckOp()
checkStmt := schema.GenCheckStmt()
checkQuery := checkStmt.Query
checkValues := checkStmt.Values()
if verbose {
fmt.Printf("%s\n", check)
fmt.Printf("%s (values=%v)\n", checkQuery, checkValues)
}
if diff := session.Check(check); diff != "" {
fmt.Printf("Failed! Check '%s' rows differ (-oracle +test)\n%s", check, diff)
if diff := session.Check(checkQuery, checkValues...); diff != "" {
fmt.Printf("Failed! Check '%s' (values=%v) rows differ (-oracle +test)\n%s", checkQuery, checkValues, diff)
return
}
nrPassedTests++
Expand Down
45 changes: 36 additions & 9 deletions schema.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,13 @@ type Table struct {
type Schema interface {
GetDropSchema() []string
GetCreateSchema() []string
GenMutateOp() string
GenCheckOp() string
GenMutateStmt() *Stmt
GenCheckStmt() *Stmt
}

type Stmt struct {
Query string
Values func() []interface{}
}

type schema struct {
Expand Down Expand Up @@ -63,25 +68,41 @@ func (s *schema) GetCreateSchema() []string {
}
}

func (s *schema) GenMutateOp() string {
func (s *schema) GenMutateStmt() *Stmt {
columns := []string{}
values := []string{}
for _, pk := range s.table.PartitionKeys {
columns = append(columns, pk.Name)
values = append(values, fmt.Sprintf("%d", rand.Intn(100)))
values = append(values, "?")
}
for _, pk := range s.table.ClusteringKeys {
columns = append(columns, pk.Name)
values = append(values, fmt.Sprintf("%d", rand.Intn(100)))
values = append(values, "?")
}
for _, cdef := range s.table.Columns {
columns = append(columns, cdef.Name)
values = append(values, fmt.Sprintf("%d", rand.Intn(100)))
values = append(values, "?")
}
query := fmt.Sprintf("INSERT INTO %s.%s (%s) VALUES (%s)", s.keyspace.Name, s.table.Name, strings.Join(columns, ","), strings.Join(values, ","))
return &Stmt{
Query: query,
Values: func() []interface{} {
values := make([]interface{}, 0)
for _, _ = range s.table.PartitionKeys {
values = append(values, rand.Intn(100))
}
for _, _ = range s.table.ClusteringKeys {
values = append(values, rand.Intn(100))
}
for _, _ = range s.table.Columns {
values = append(values, rand.Intn(100))
}
return values
},
}
return fmt.Sprintf("INSERT INTO %s.%s (%s) VALUES (%s)", s.keyspace.Name, s.table.Name, strings.Join(columns, ","), strings.Join(values, ","))
}

func (s *schema) GenCheckOp() string {
func (s *schema) GenCheckStmt() *Stmt {
query := fmt.Sprintf("SELECT * FROM %s.%s", s.keyspace.Name, s.table.Name)
if rand.Intn(2) == 1 {
query += fmt.Sprintf(" ORDER BY %s", s.table.Columns[0].Name)
Expand All @@ -92,7 +113,13 @@ func (s *schema) GenCheckOp() string {
if rand.Intn(2) == 1 {
query += fmt.Sprintf(" LIMIT %d", rand.Intn(100))
}
return query
values := func() []interface{} {
return nil
}
return &Stmt{
Query: query,
Values: values,
}
}

type SchemaBuilder interface {
Expand Down
12 changes: 6 additions & 6 deletions session.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,19 +40,19 @@ func (s *Session) Close() {
s.oracleSession.Close()
}

func (s *Session) Mutate(query string) error {
if err := s.testSession.Query(query).Exec(); err != nil {
func (s *Session) Mutate(query string, values ...interface{}) error {
if err := s.testSession.Query(query, values...).Exec(); err != nil {
return fmt.Errorf("%v [cluster = test, query = '%s']", err, query)
}
if err := s.oracleSession.Query(query).Exec(); err != nil {
if err := s.oracleSession.Query(query, values...).Exec(); err != nil {
return fmt.Errorf("%v [cluster = oracle, query = '%s']", err, query)
}
return nil
}

func (s *Session) Check(query string) string {
testIter := s.testSession.Query(query).Iter()
oracleIter := s.oracleSession.Query(query).Iter()
func (s *Session) Check(query string, values ...interface{}) string {
testIter := s.testSession.Query(query, values...).Iter()
oracleIter := s.oracleSession.Query(query, values...).Iter()
for {
testRow := make(map[string]interface{})
if !testIter.MapScan(testRow) {
Expand Down

0 comments on commit eb293a7

Please sign in to comment.