diff --git a/cmd/gemini/root.go b/cmd/gemini/root.go index 6df155ee..2425b1d9 100644 --- a/cmd/gemini/root.go +++ b/cmd/gemini/root.go @@ -70,24 +70,29 @@ func run(cmd *cobra.Command, args []string) { return } } + nrPassedTests := 0 for i := 0; i < maxTests; i++ { - mutate := schema.GenMutateOp() + mutateStmt := schema.GenMutateStmt() + mutateQuery := mutateStmt.Query + mutateValues := mutateStmt.Values() if verbose { - fmt.Printf("%s\n", mutate) + fmt.Printf("%s (values=%v)\n", mutateQuery, mutateValues) } - if err := session.Mutate(mutate); err != nil { - fmt.Printf("Failed! Mutation '%s' caused an error: '%v'\n", mutate, err) + if err := session.Mutate(mutateQuery, mutateValues...); err != nil { + fmt.Printf("Failed! Mutation '%s' (values=%v) caused an error: '%v'\n", mutateQuery, mutateValues, err) return } - check := schema.GenCheckOp() + checkStmt := schema.GenCheckStmt() + checkQuery := checkStmt.Query + checkValues := checkStmt.Values() if verbose { - fmt.Printf("%s\n", check) + fmt.Printf("%s (values=%v)\n", checkQuery, checkValues) } - if diff := session.Check(check); diff != "" { - fmt.Printf("Failed! Check '%s' rows differ (-oracle +test)\n%s", check, diff) + if diff := session.Check(checkQuery, checkValues...); diff != "" { + fmt.Printf("Failed! Check '%s' (values=%v) rows differ (-oracle +test)\n%s", checkQuery, checkValues, diff) return } nrPassedTests++ diff --git a/schema.go b/schema.go index 81faa70c..ec1d8241 100644 --- a/schema.go +++ b/schema.go @@ -25,8 +25,13 @@ type Table struct { type Schema interface { GetDropSchema() []string GetCreateSchema() []string - GenMutateOp() string - GenCheckOp() string + GenMutateStmt() *Stmt + GenCheckStmt() *Stmt +} + +type Stmt struct { + Query string + Values func() []interface{} } type schema struct { @@ -63,25 +68,41 @@ func (s *schema) GetCreateSchema() []string { } } -func (s *schema) GenMutateOp() string { +func (s *schema) GenMutateStmt() *Stmt { columns := []string{} values := []string{} for _, pk := range s.table.PartitionKeys { columns = append(columns, pk.Name) - values = append(values, fmt.Sprintf("%d", rand.Intn(100))) + values = append(values, "?") } for _, pk := range s.table.ClusteringKeys { columns = append(columns, pk.Name) - values = append(values, fmt.Sprintf("%d", rand.Intn(100))) + values = append(values, "?") } for _, cdef := range s.table.Columns { columns = append(columns, cdef.Name) - values = append(values, fmt.Sprintf("%d", rand.Intn(100))) + values = append(values, "?") + } + query := fmt.Sprintf("INSERT INTO %s.%s (%s) VALUES (%s)", s.keyspace.Name, s.table.Name, strings.Join(columns, ","), strings.Join(values, ",")) + return &Stmt{ + Query: query, + Values: func() []interface{} { + values := make([]interface{}, 0) + for _, _ = range s.table.PartitionKeys { + values = append(values, rand.Intn(100)) + } + for _, _ = range s.table.ClusteringKeys { + values = append(values, rand.Intn(100)) + } + for _, _ = range s.table.Columns { + values = append(values, rand.Intn(100)) + } + return values + }, } - return fmt.Sprintf("INSERT INTO %s.%s (%s) VALUES (%s)", s.keyspace.Name, s.table.Name, strings.Join(columns, ","), strings.Join(values, ",")) } -func (s *schema) GenCheckOp() string { +func (s *schema) GenCheckStmt() *Stmt { query := fmt.Sprintf("SELECT * FROM %s.%s", s.keyspace.Name, s.table.Name) if rand.Intn(2) == 1 { query += fmt.Sprintf(" ORDER BY %s", s.table.Columns[0].Name) @@ -92,7 +113,13 @@ func (s *schema) GenCheckOp() string { if rand.Intn(2) == 1 { query += fmt.Sprintf(" LIMIT %d", rand.Intn(100)) } - return query + values := func() []interface{} { + return nil + } + return &Stmt{ + Query: query, + Values: values, + } } type SchemaBuilder interface { diff --git a/session.go b/session.go index 23aeea8a..c2bd06ca 100644 --- a/session.go +++ b/session.go @@ -40,19 +40,19 @@ func (s *Session) Close() { s.oracleSession.Close() } -func (s *Session) Mutate(query string) error { - if err := s.testSession.Query(query).Exec(); err != nil { +func (s *Session) Mutate(query string, values ...interface{}) error { + if err := s.testSession.Query(query, values...).Exec(); err != nil { return fmt.Errorf("%v [cluster = test, query = '%s']", err, query) } - if err := s.oracleSession.Query(query).Exec(); err != nil { + if err := s.oracleSession.Query(query, values...).Exec(); err != nil { return fmt.Errorf("%v [cluster = oracle, query = '%s']", err, query) } return nil } -func (s *Session) Check(query string) string { - testIter := s.testSession.Query(query).Iter() - oracleIter := s.oracleSession.Query(query).Iter() +func (s *Session) Check(query string, values ...interface{}) string { + testIter := s.testSession.Query(query, values...).Iter() + oracleIter := s.oracleSession.Query(query, values...).Iter() for { testRow := make(map[string]interface{}) if !testIter.MapScan(testRow) {