diff --git a/datautils.go b/datautils.go index c21634af..e6e69b89 100644 --- a/datautils.go +++ b/datautils.go @@ -2,33 +2,62 @@ package gemini import ( "encoding/base64" + "fmt" + "math/big" "math/rand" + "net" + "strconv" "strings" "time" + "github.com/gocql/gocql" + "gopkg.in/inf.v0" + "github.com/segmentio/ksuid" ) -func randRange(min int, max int) int { +func randIntRange(min int, max int) int { return rand.Intn(max-min) + min } -func nonEmptyRandRange(min int, max int, def int) int { +func nonEmptyRandIntRange(min int, max int, def int) int { if max > min && min > 0 { - return randRange(min, max) + return randIntRange(min, max) } - return randRange(1, def) + return randIntRange(1, def) } -func randRange64(min int64, max int64) int64 { +func randInt64Range(min int64, max int64) int64 { return rand.Int63n(max-min) + min } -func nonEmptyRandRange64(min int64, max int64, def int64) int64 { +func nonEmptyRandInt64Range(min int64, max int64, def int64) int64 { + if max > min && min > 0 { + return randInt64Range(min, max) + } + return randInt64Range(1, def) +} + +func randFloat32Range(min float32, max float32) float32 { + return rand.Float32() * (max - min) +} + +func nonEmptyRandFloat32Range(min float32, max float32, def float32) float32 { + if max > min && min > 0 { + return randFloat32Range(min, max) + } + return randFloat32Range(1, def) +} + +func randFloat64Range(min float64, max float64) float64 { + return rand.Float64() * (max - min) +} + +func nonEmptyRandFloat64Range(min float64, max float64, def float64) float64 { if max > min && min > 0 { - return randRange64(min, max) + return randFloat64Range(min, max) } - return randRange64(1, def) + return randFloat64Range(1, def) } func randString(len int) string { @@ -74,3 +103,134 @@ func randDateNewer(d time.Time) time.Time { sec := rand.Int63n(max-min+1) + min return time.Unix(sec, 0) } + +func randIpV4Address(v, pos int) string { + if pos < 0 || pos > 4 { + panic(fmt.Sprintf("invalid position for the desired value of the IP part %d, 0-3 supported", pos)) + } + if v < 0 || v > 255 { + panic(fmt.Sprintf("invalid value for the desired position %d of the IP, 0-255 suppoerted", v)) + } + var blocks []string + for i := 0; i < 4; i++ { + if i == pos { + blocks = append(blocks, strconv.Itoa(v)) + } else { + blocks = append(blocks, strconv.Itoa(rand.Intn(255))) + } + } + return strings.Join(blocks, ".") +} + +func genValue(columnType string, p *PartitionRange, values []interface{}) []interface{} { + switch columnType { + case "ascii", "blob", "text", "varchar": + values = append(values, randStringWithTime(nonEmptyRandIntRange(p.Max, p.Max, 10), randDate())) + case "bigint": + values = append(values, rand.Int63()) + case "boolean": + values = append(values, rand.Int()%2 == 0) + case "date", "time", "timestamp": + values = append(values, randDate()) + case "decimal": + values = append(values, inf.NewDec(randInt64Range(int64(p.Min), int64(p.Max)), 3)) + case "double": + values = append(values, randFloat64Range(float64(p.Min), float64(p.Max))) + case "duration": + values = append(values, time.Minute*time.Duration(randIntRange(p.Min, p.Max))) + case "float": + values = append(values, randFloat32Range(float32(p.Min), float32(p.Max))) + case "inet": + values = append(values, net.ParseIP(randIpV4Address(rand.Intn(255), 2))) + case "int": + values = append(values, nonEmptyRandIntRange(p.Min, p.Max, 10)) + case "smallint": + values = append(values, int16(nonEmptyRandIntRange(p.Min, p.Max, 10))) + case "timeuuid", "uuid": + r := gocql.UUIDFromTime(randDate()) + values = append(values, r.String()) + case "tinyint": + values = append(values, int8(nonEmptyRandIntRange(p.Min, p.Max, 10))) + case "varint": + values = append(values, big.NewInt(randInt64Range(int64(p.Min), int64(p.Max)))) + default: + panic(fmt.Sprintf("generate value: not supported type %s", columnType)) + } + return values +} + +func genValueRange(columnType string, p *PartitionRange, values []interface{}) []interface{} { + switch columnType { + case "ascii", "blob", "text", "varchar": + startTime := randDate() + start := nonEmptyRandIntRange(p.Min, p.Max, 10) + end := start + nonEmptyRandIntRange(p.Min, p.Max, 10) + values = append(values, nonEmptyRandStringWithTime(start, startTime)) + values = append(values, nonEmptyRandStringWithTime(end, randDateNewer(startTime))) + case "bigint": + start := nonEmptyRandInt64Range(int64(p.Min), int64(p.Max), 10) + end := start + nonEmptyRandInt64Range(int64(p.Min), int64(p.Max), 10) + values = append(values, start) + values = append(values, end) + case "date", "time", "timestamp": + start := randDate() + end := randDateNewer(start) + values = append(values, start) + values = append(values, end) + case "decimal": + start := nonEmptyRandInt64Range(int64(p.Min), int64(p.Max), 10) + end := start + nonEmptyRandInt64Range(int64(p.Min), int64(p.Max), 10) + values = append(values, inf.NewDec(start, 3)) + values = append(values, inf.NewDec(end, 3)) + case "double": + start := nonEmptyRandFloat64Range(float64(p.Min), float64(p.Max), 10) + end := start + nonEmptyRandFloat64Range(float64(p.Min), float64(p.Max), 10) + values = append(values, start) + values = append(values, end) + case "duration": + start := time.Minute * time.Duration(nonEmptyRandIntRange(p.Min, p.Max, 10)) + end := start + time.Minute*time.Duration(nonEmptyRandIntRange(p.Min, p.Max, 10)) + values = append(values, start) + values = append(values, end) + case "float": + start := nonEmptyRandFloat32Range(float32(p.Min), float32(p.Max), 10) + end := start + nonEmptyRandFloat32Range(float32(p.Min), float32(p.Max), 10) + values = append(values, start) + values = append(values, end) + case "inet": + start := randIpV4Address(0, 3) + end := randIpV4Address(255, 3) + values = append(values, net.ParseIP(start)) + values = append(values, net.ParseIP(end)) + case "int": + start := nonEmptyRandIntRange(p.Min, p.Max, 10) + end := start + nonEmptyRandIntRange(p.Min, p.Max, 10) + values = append(values, start) + values = append(values, end) + case "smallint": + start := int16(nonEmptyRandIntRange(p.Min, p.Max, 10)) + end := start + int16(nonEmptyRandIntRange(p.Min, p.Max, 10)) + values = append(values, start) + values = append(values, end) + case "timeuuid", "uuid": + start := randDate() + end := randDateNewer(start) + values = append(values, gocql.UUIDFromTime(start).String()) + values = append(values, gocql.UUIDFromTime(end).String()) + case "tinyint": + start := int8(nonEmptyRandIntRange(p.Min, p.Max, 10)) + end := start + int8(nonEmptyRandIntRange(p.Min, p.Max, 10)) + values = append(values, start) + values = append(values, end) + case "varint": + end := &big.Int{} + start := big.NewInt(randInt64Range(int64(p.Min), int64(p.Max))) + end.Set(start) + end = end.Add(start, big.NewInt(randInt64Range(int64(p.Min), int64(p.Max)))) + values = append(values, start) + values = append(values, end) + default: + panic(fmt.Sprintf("generate value range: not supported type %s", columnType)) + } + return values +} diff --git a/datautils_test.go b/datautils_test.go index 93d9d79f..b1f2707e 100644 --- a/datautils_test.go +++ b/datautils_test.go @@ -8,7 +8,7 @@ import ( func TestNonEmptyRandRange(t *testing.T) { f := func(x, y int) bool { - r := nonEmptyRandRange(x, y, 10) + r := nonEmptyRandIntRange(x, y, 10) return r > 0 } if err := quick.Check(f, nil); err != nil { @@ -18,7 +18,27 @@ func TestNonEmptyRandRange(t *testing.T) { func TestNonEmptyRandRange64(t *testing.T) { f := func(x, y int) bool { - r := nonEmptyRandRange(x, y, 10) + r := nonEmptyRandIntRange(x, y, 10) + return r > 0 + } + if err := quick.Check(f, nil); err != nil { + t.Error(err) + } +} + +func TestNonEmptyRandFloat32Range(t *testing.T) { + f := func(x, y float32) bool { + r := nonEmptyRandFloat32Range(x, y, 10) + return r > 0 + } + if err := quick.Check(f, nil); err != nil { + t.Error(err) + } +} + +func TestNonEmptyRandFloat64Range(t *testing.T) { + f := func(x, y float64) bool { + r := nonEmptyRandFloat64Range(x, y, 10) return r > 0 } if err := quick.Check(f, nil); err != nil { @@ -61,7 +81,7 @@ var bench_rr int func BenchmarkNonEmptyRandRange(b *testing.B) { for i := 0; i < b.N; i++ { - bench_rr = nonEmptyRandRange(0, 50, 30) + bench_rr = nonEmptyRandIntRange(0, 50, 30) } } @@ -69,6 +89,6 @@ var bench_rr64 int64 func BenchmarkNonEmptyRandRange64(b *testing.B) { for i := 0; i < b.N; i++ { - bench_rr64 = nonEmptyRandRange64(0, 50, 30) + bench_rr64 = nonEmptyRandInt64Range(0, 50, 30) } } diff --git a/go.mod b/go.mod index 35cfabb4..6d287837 100644 --- a/go.mod +++ b/go.mod @@ -19,6 +19,7 @@ require ( github.com/spf13/pflag v1.0.3 // indirect github.com/stretchr/testify v1.3.0 // indirect golang.org/x/net v0.0.0-20190313082753-5c2c250b6a70 + gopkg.in/inf.v0 v0.9.1 ) replace github.com/gocql/gocql => github.com/scylladb/gocql v1.0.1 diff --git a/go.sum b/go.sum index a9ee4bf2..48ffd748 100644 --- a/go.sum +++ b/go.sum @@ -8,9 +8,8 @@ github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8 github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/fatih/color v1.7.0 h1:DkWD4oS2D8LGGgTQ6IvwJJXSL5Vp2ffcQg58nFV38Ys= github.com/fatih/color v1.7.0/go.mod h1:Zm6kSWBoL9eyXnKyktHP6abPY2pDugNf5KwzbycvMj4= +github.com/fatih/set v0.2.1 h1:nn2CaJyknWE/6txyUDGwysr3G5QC6xWB/PtVjPBbeaA= github.com/fatih/set v0.2.1/go.mod h1:+RKtMCH+favT2+3YecHGxcc0b4KyVWA1QWWJUs4E0CI= -github.com/gocql/gocql v0.0.0-20190301043612-f6df8288f9b4 h1:vF83LI8tAakwEwvWZtrIEx7pOySacl2TOxx6eXk4ePo= -github.com/gocql/gocql v0.0.0-20190301043612-f6df8288f9b4/go.mod h1:4Fw1eo5iaEhDUs8XyuhSVCVy52Jq3L+/3GJgYkwc+/0= github.com/golang/snappy v0.0.0-20170215233205-553a64147049 h1:K9KHZbXKpGydfDN0aZrsoHpLJlZsBrGMFWbgLDGnPZk= github.com/golang/snappy v0.0.0-20170215233205-553a64147049/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= github.com/google/go-cmp v0.2.0 h1:+dTQ8DZQJz0Mb/HjFlkptS1FeQ4cWSnN941F8aEG4SQ= diff --git a/schema.go b/schema.go index bf4988a9..4078b734 100644 --- a/schema.go +++ b/schema.go @@ -4,8 +4,6 @@ import ( "fmt" "math/rand" "strings" - - "github.com/gocql/gocql" ) type Keyspace struct { @@ -51,7 +49,11 @@ func (s *Schema) GetDropSchema() []string { } } -var types = [...]string{"int", "bigint", "blob", "uuid", "text", "varchar", "timestamp"} +// TODO: Add support for time when gocql bug is fixed. +var ( + pkTypes = []string{"ascii", "bigint", "blob", "date", "decimal", "double", "float", "inet", "int", "smallint", "text" /*"time",*/, "timestamp", "timeuuid", "tinyint", "uuid", "varchar", "varint"} + types = append(append([]string{}, pkTypes...), "boolean", "duration") +) func genColumnName(prefix string, idx int) string { return fmt.Sprintf("%s%d", prefix, idx) @@ -62,6 +64,11 @@ func genColumnType() string { return types[n] } +func genPrimaryKeyColumnType() string { + n := rand.Intn(len(pkTypes)) + return types[n] +} + func genColumnDef(prefix string, idx int) ColumnDef { return ColumnDef{ Name: genColumnName(prefix, idx), @@ -76,7 +83,7 @@ func genIndexName(prefix string, idx int) string { const ( MaxPartitionKeys = 2 MaxClusteringKeys = 4 - MaxColumns = 8 + MaxColumns = 16 ) func GenSchema() *Schema { @@ -93,7 +100,7 @@ func GenSchema() *Schema { var clusteringKeys []ColumnDef numClusteringKeys := rand.Intn(MaxClusteringKeys) for i := 0; i < numClusteringKeys; i++ { - clusteringKeys = append(clusteringKeys, ColumnDef{Name: genColumnName("ck", i), Type: genColumnType()}) + clusteringKeys = append(clusteringKeys, ColumnDef{Name: genColumnName("ck", i), Type: genPrimaryKeyColumnType()}) } var columns []ColumnDef numColumns := rand.Intn(MaxColumns) @@ -118,59 +125,6 @@ func GenSchema() *Schema { return builder.Build() } -func genValue(columnType string, p *PartitionRange, values []interface{}) []interface{} { - switch columnType { - case "int": - values = append(values, nonEmptyRandRange(p.Min, p.Max, 10)) - case "bigint": - values = append(values, rand.Int63()) - case "uuid": - r := gocql.UUIDFromTime(randDate()) - values = append(values, r.String()) - case "blob", "text", "varchar": - values = append(values, randStringWithTime(nonEmptyRandRange(p.Max, p.Max, 10), randDate())) - case "timestamp", "date": - values = append(values, randDate()) - default: - panic(fmt.Sprintf("generate value: not supported type %s", columnType)) - } - return values -} - -func genValueRange(columnType string, p *PartitionRange, values []interface{}) []interface{} { - switch columnType { - case "int": - start := nonEmptyRandRange(p.Min, p.Max, 10) - end := start + nonEmptyRandRange(p.Min, p.Max, 10) - values = append(values, start) - values = append(values, end) - case "bigint": - start := nonEmptyRandRange64(int64(p.Min), int64(p.Max), 10) - end := start + nonEmptyRandRange64(int64(p.Min), int64(p.Max), 10) - values = append(values, start) - values = append(values, end) - case "uuid": - start := randDate() - end := randDateNewer(start) - values = append(values, gocql.UUIDFromTime(start).String()) - values = append(values, gocql.UUIDFromTime(end).String()) - case "blob", "text", "varchar": - startTime := randDate() - start := nonEmptyRandRange(p.Min, p.Max, 10) - end := start + nonEmptyRandRange(p.Min, p.Max, 10) - values = append(values, nonEmptyRandStringWithTime(start, startTime)) - values = append(values, nonEmptyRandStringWithTime(end, randDateNewer(startTime))) - case "timestamp", "date": - start := randDate() - end := randDateNewer(start) - values = append(values, start) - values = append(values, end) - default: - panic(fmt.Sprintf("generate value range: not supported type %s", columnType)) - } - return values -} - func (s *Schema) GetCreateSchema() []string { createKeyspace := fmt.Sprintf("CREATE KEYSPACE IF NOT EXISTS %s WITH REPLICATION = {'class': 'SimpleStrategy', 'replication_factor': 1}", s.Keyspace.Name) diff --git a/session.go b/session.go index 9c6faaf4..81c5ca87 100644 --- a/session.go +++ b/session.go @@ -3,14 +3,15 @@ package gemini import ( "errors" "fmt" + "math/big" "sort" - "strconv" "strings" "time" "github.com/gocql/gocql" "github.com/google/go-cmp/cmp" "github.com/scylladb/go-set/strset" + "gopkg.in/inf.v0" ) type Session struct { @@ -87,7 +88,12 @@ func (s *Session) Check(table Table, query string, values ...interface{}) error }) for i, oracleRow := range oracleRows { testRow := testRows[i] - diff := cmp.Diff(oracleRow, testRow) + cmp.AllowUnexported() + diff := cmp.Diff(oracleRow, testRow, cmp.Comparer(func(x, y *inf.Dec) bool { + return x.Cmp(y) == 0 + }), cmp.Comparer(func(x, y *big.Int) bool { + return x.Cmp(y) == 0 + })) if diff != "" { return fmt.Errorf("rows differ (-%v +%v): %v", oracleRow, testRow, diff) } @@ -108,29 +114,7 @@ func pks(t Table, rows []map[string]interface{}) []string { func extractRowValues(values []string, columns []ColumnDef, row map[string]interface{}) []string { for _, pk := range columns { - cv := row[pk.Name] - switch pk.Type { - case "int": - v, _ := cv.(int) - values = append(values, pk.Name+"="+strconv.Itoa(v)) - case "bigint": - v, _ := cv.(int64) - values = append(values, pk.Name+"="+strconv.FormatInt(v, 10)) - case "uuid": - v, _ := cv.(gocql.UUID) - values = append(values, pk.Name+"="+v.String()) - case "blob": - v, _ := cv.([]byte) - values = append(values, pk.Name+"="+string(v)) - case "text", "varchar": - v, _ := cv.(string) - values = append(values, pk.Name+"="+v) - case "timestamp", "date": - v, _ := cv.(time.Time) - values = append(values, pk.Name+"="+v.String()) - default: - panic(fmt.Sprintf("not supported type %s", pk)) - } + values = append(values, fmt.Sprintf(pk.Name+"=%v", row[pk.Name])) } return values }