diff --git a/common/config/config.go b/common/config/config.go index 4db4ee7b8ef..38336e6388f 100644 --- a/common/config/config.go +++ b/common/config/config.go @@ -236,6 +236,14 @@ type ( TLS *TLS `yaml:"tls"` // ProtoVersion ProtoVersion int `yaml:"protoVersion"` + // ConnectTimeout defines duration for initial dial + ConnectTimeout time.Duration `yaml:"connectTimeout"` + // Timout is a connection timeout + Timeout time.Duration `yaml:"timeout"` + // Consistency defines default consistency level + Consistency string `yaml:"consistency"` + // SerialConsistency sets the consistency for the serial part of queries + SerialConsistency string `yaml:"serialConsistency"` // ConnectAttributes is a set of key-value attributes as a supplement/extension to the above common fields // Use it ONLY when a configure is too specific to a particular NoSQL database that should not be in the common struct // Otherwise please add new fields to the struct for better documentation diff --git a/common/persistence/nosql/nosqlplugin/cassandra/gocql/consistency.go b/common/persistence/nosql/nosqlplugin/cassandra/gocql/consistency.go index 3154c0ad957..684661e2f93 100644 --- a/common/persistence/nosql/nosqlplugin/cassandra/gocql/consistency.go +++ b/common/persistence/nosql/nosqlplugin/cassandra/gocql/consistency.go @@ -22,6 +22,7 @@ package gocql import ( "fmt" + "strings" "github.com/gocql/gocql" ) @@ -80,3 +81,104 @@ func mustConvertSerialConsistency(c SerialConsistency) gocql.SerialConsistency { panic(fmt.Sprintf("Unknown gocql SerialConsistency level: %v", c)) } } + +func (c Consistency) MarshalText() (text []byte, err error) { + return []byte(c.String()), nil +} + +func (c *Consistency) UnmarshalText(text []byte) error { + switch string(text) { + case "ANY": + *c = Any + case "ONE": + *c = One + case "TWO": + *c = Two + case "THREE": + *c = Three + case "QUORUM": + *c = Quorum + case "ALL": + *c = All + case "LOCAL_QUORUM": + *c = LocalQuorum + case "EACH_QUORUM": + *c = EachQuorum + case "LOCAL_ONE": + *c = LocalOne + default: + return fmt.Errorf("invalid consistency %q", string(text)) + } + + return nil +} + +func (c Consistency) String() string { + switch c { + case Any: + return "ANY" + case One: + return "ONE" + case Two: + return "TWO" + case Three: + return "THREE" + case Quorum: + return "QUORUM" + case All: + return "ALL" + case LocalQuorum: + return "LOCAL_QUORUM" + case EachQuorum: + return "EACH_QUORUM" + case LocalOne: + return "LOCAL_ONE" + default: + return fmt.Sprintf("invalid consistency: %d", uint16(c)) + } +} + +func ParseConsistency(s string) (Consistency, error) { + var c Consistency + if err := c.UnmarshalText([]byte(strings.ToUpper(s))); err != nil { + return c, fmt.Errorf("parse consistency: %w", err) + } + return c, nil +} + +func ParseSerialConsistency(s string) (SerialConsistency, error) { + var sc SerialConsistency + if err := sc.UnmarshalText([]byte(strings.ToUpper(s))); err != nil { + return sc, fmt.Errorf("parse serial consistency: %w", err) + + } + return sc, nil +} + +func (s SerialConsistency) String() string { + switch s { + case Serial: + return "SERIAL" + case LocalSerial: + return "LOCAL_SERIAL" + default: + return fmt.Sprintf("invalid serial consistency %d", uint16(s)) + } +} + +func (s SerialConsistency) MarshalText() (text []byte, err error) { + return []byte(s.String()), nil +} + +func (s *SerialConsistency) UnmarshalText(text []byte) error { + switch string(text) { + case "SERIAL": + *s = Serial + case "LOCAL_SERIAL": + *s = LocalSerial + default: + return fmt.Errorf("invalid serial consistency %q", string(text)) + } + + return nil +} diff --git a/common/persistence/nosql/nosqlplugin/cassandra/gocql/consistency_test.go b/common/persistence/nosql/nosqlplugin/cassandra/gocql/consistency_test.go index 81d21fb3961..77eb73a1d14 100644 --- a/common/persistence/nosql/nosqlplugin/cassandra/gocql/consistency_test.go +++ b/common/persistence/nosql/nosqlplugin/cassandra/gocql/consistency_test.go @@ -29,6 +29,172 @@ import ( "github.com/stretchr/testify/assert" ) +func TestConsistency_MarshalText(t *testing.T) { + tests := []struct { + c Consistency + wantText []byte + testFn func(t assert.TestingT, expected, actual interface{}, msgAndArgs ...interface{}) bool + }{ + {c: Any, wantText: []byte("ANY"), testFn: assert.Equal}, + {c: One, wantText: []byte("ONE"), testFn: assert.Equal}, + {c: Two, wantText: []byte("TWO"), testFn: assert.Equal}, + {c: Three, wantText: []byte("THREE"), testFn: assert.Equal}, + {c: Quorum, wantText: []byte("QUORUM"), testFn: assert.Equal}, + {c: All, wantText: []byte("ALL"), testFn: assert.Equal}, + {c: LocalQuorum, wantText: []byte("LOCAL_QUORUM"), testFn: assert.Equal}, + {c: EachQuorum, wantText: []byte("EACH_QUORUM"), testFn: assert.Equal}, + {c: LocalOne, wantText: []byte("LOCAL_ONE"), testFn: assert.Equal}, + {c: LocalOne, wantText: []byte("WRONG_VALUE"), testFn: assert.NotEqualValues}, + } + for _, tt := range tests { + t.Run(tt.c.String(), func(t *testing.T) { + gotText, err := tt.c.MarshalText() + assert.NoError(t, err) + tt.testFn(t, tt.wantText, gotText) + }) + } +} + +func TestConsistency_String(t *testing.T) { + c := Consistency(9) + assert.Equal(t, c.String(), "invalid consistency: 9") +} + +func TestConsistency_UnmarshalText(t *testing.T) { + tests := []struct { + destConsistency Consistency + inputText []byte + testFn func(t assert.TestingT, expected, actual interface{}, msgAndArgs ...interface{}) bool + wantErr bool + }{ + {destConsistency: Any, inputText: []byte("ANY"), testFn: assert.Equal}, + {destConsistency: One, inputText: []byte("ONE"), testFn: assert.Equal}, + {destConsistency: Two, inputText: []byte("TWO"), testFn: assert.Equal}, + {destConsistency: Three, inputText: []byte("THREE"), testFn: assert.Equal}, + {destConsistency: Quorum, inputText: []byte("QUORUM"), testFn: assert.Equal}, + {destConsistency: All, inputText: []byte("ALL"), testFn: assert.Equal}, + {destConsistency: LocalQuorum, inputText: []byte("LOCAL_QUORUM"), testFn: assert.Equal}, + {destConsistency: EachQuorum, inputText: []byte("EACH_QUORUM"), testFn: assert.Equal}, + {destConsistency: LocalOne, inputText: []byte("LOCAL_ONE"), testFn: assert.Equal}, + {destConsistency: LocalOne, inputText: []byte("WRONG_VALUE"), testFn: assert.NotEqualValues, wantErr: true}, + } + for _, tt := range tests { + t.Run(tt.destConsistency.String(), func(t *testing.T) { + var c Consistency + err := c.UnmarshalText(tt.inputText) + if tt.wantErr { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + tt.testFn(t, tt.destConsistency, c) + }) + } +} + +func TestParseConsistency(t *testing.T) { + tests := []struct { + destConsistency Consistency + inputText string + testFn func(t assert.TestingT, expected, actual interface{}, msgAndArgs ...interface{}) bool + wantErr bool + }{ + {destConsistency: Any, inputText: "any", testFn: assert.Equal}, + {destConsistency: One, inputText: "ONE", testFn: assert.Equal}, + {destConsistency: Two, inputText: "TWO", testFn: assert.Equal}, + {destConsistency: Three, inputText: "THREE", testFn: assert.Equal}, + {destConsistency: Quorum, inputText: "QUORUM", testFn: assert.Equal}, + {destConsistency: All, inputText: "all", testFn: assert.Equal}, + {destConsistency: LocalQuorum, inputText: "LOCAL_QUORUM", testFn: assert.Equal}, + {destConsistency: EachQuorum, inputText: "EACH_QUORUM", testFn: assert.Equal}, + {destConsistency: LocalOne, inputText: "LOCAL_ONE", testFn: assert.Equal}, + {destConsistency: Any, inputText: "WRONG_VALUE_FAILBACK_TO_ANY", testFn: assert.Equal, wantErr: true}, + } + for _, tt := range tests { + t.Run(string(tt.inputText), func(t *testing.T) { + got, err := ParseConsistency(tt.inputText) + if tt.wantErr { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + tt.testFn(t, tt.destConsistency, got) + }) + } +} + +func TestParseSerialConsistency(t *testing.T) { + tests := []struct { + destConsistency SerialConsistency + inputText string + testFn func(t assert.TestingT, expected, actual interface{}, msgAndArgs ...interface{}) bool + wantErr bool + }{ + {destConsistency: Serial, inputText: "serial", testFn: assert.Equal}, + {destConsistency: LocalSerial, inputText: "LOCAL_SERIAL", testFn: assert.Equal}, + {destConsistency: Serial, inputText: "WRONG_VALUE_FAILBACK_TO_ANY", testFn: assert.Equal, wantErr: true}, + } + for _, tt := range tests { + t.Run(string(tt.inputText), func(t *testing.T) { + got, err := ParseSerialConsistency(tt.inputText) + if tt.wantErr { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + tt.testFn(t, tt.destConsistency, got) + }) + } +} + +func TestSerialConsistency_MarshalText(t *testing.T) { + tests := []struct { + c SerialConsistency + wantText []byte + testFn func(t assert.TestingT, expected, actual interface{}, msgAndArgs ...interface{}) bool + }{ + {c: Serial, wantText: []byte("SERIAL"), testFn: assert.Equal}, + {c: LocalSerial, wantText: []byte("LOCAL_SERIAL"), testFn: assert.Equal}, + {c: LocalSerial, wantText: []byte("WRONG_VALUE"), testFn: assert.NotEqualValues}, + } + for _, tt := range tests { + t.Run(tt.c.String(), func(t *testing.T) { + gotText, err := tt.c.MarshalText() + assert.NoError(t, err) + tt.testFn(t, tt.wantText, gotText) + }) + } +} + +func TestSerialConsistency_String(t *testing.T) { + c := SerialConsistency(2) + assert.Equal(t, c.String(), "invalid serial consistency 2") +} + +func TestSerialConsistency_UnmarshalText(t *testing.T) { + tests := []struct { + destSerialConsistency SerialConsistency + inputText []byte + wantErr bool + }{ + {destSerialConsistency: Serial, inputText: []byte("SERIAL")}, + {destSerialConsistency: LocalSerial, inputText: []byte("LOCAL_SERIAL")}, + {destSerialConsistency: Serial, inputText: []byte("WRONG_VALUE_DEFAULTS_TO_SERIAL"), wantErr: true}, + } + for _, tt := range tests { + t.Run(tt.destSerialConsistency.String(), func(t *testing.T) { + var c SerialConsistency + err := c.UnmarshalText(tt.inputText) + if tt.wantErr { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + assert.Equal(t, tt.destSerialConsistency, c) + }) + } +} + func Test_mustConvertConsistency(t *testing.T) { tests := []struct { input Consistency diff --git a/common/persistence/nosql/nosqlplugin/cassandra/plugin.go b/common/persistence/nosql/nosqlplugin/cassandra/plugin.go index 83278adeffb..30bd8365408 100644 --- a/common/persistence/nosql/nosqlplugin/cassandra/plugin.go +++ b/common/persistence/nosql/nosqlplugin/cassandra/plugin.go @@ -95,6 +95,33 @@ func toGoCqlConfig(cfg *config.NoSQL) (gocql.ClusterConfig, error) { return gocql.ClusterConfig{}, err } } + + if cfg.Timeout == 0 { + cfg.Timeout = defaultSessionTimeout + } + + if cfg.ConnectTimeout == 0 { + cfg.ConnectTimeout = defaultConnectTimeout + } + + if cfg.Consistency == "" { + cfg.Consistency = cassandraDefaultConsLevel.String() + } + + if cfg.SerialConsistency == "" { + cfg.SerialConsistency = cassandraDefaultSerialConsLevel.String() + } + + consistency, err := gocql.ParseConsistency(cfg.Consistency) + if err != nil { + return gocql.ClusterConfig{}, err + } + serialConsistency, err := gocql.ParseSerialConsistency(cfg.SerialConsistency) + + if err != nil { + return gocql.ClusterConfig{}, err + } + return gocql.ClusterConfig{ Hosts: cfg.Hosts, Port: cfg.Port, @@ -107,9 +134,9 @@ func toGoCqlConfig(cfg *config.NoSQL) (gocql.ClusterConfig, error) { MaxConns: cfg.MaxConns, TLS: cfg.TLS, ProtoVersion: cfg.ProtoVersion, - Consistency: cassandraDefaultConsLevel, - SerialConsistency: cassandraDefaultSerialConsLevel, - Timeout: defaultSessionTimeout, - ConnectTimeout: defaultConnectTimeout, + Consistency: consistency, + SerialConsistency: serialConsistency, + Timeout: cfg.Timeout, + ConnectTimeout: cfg.ConnectTimeout, }, nil } diff --git a/common/persistence/nosql/nosqlplugin/cassandra/plugin_test.go b/common/persistence/nosql/nosqlplugin/cassandra/plugin_test.go new file mode 100644 index 00000000000..1ca580ae659 --- /dev/null +++ b/common/persistence/nosql/nosqlplugin/cassandra/plugin_test.go @@ -0,0 +1,69 @@ +// The MIT License (MIT) + +// Copyright (c) 2017-2020 Uber Technologies Inc. + +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all +// copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +// SOFTWARE. + +package cassandra + +import ( + "fmt" + "testing" + "time" + + "github.com/stretchr/testify/assert" + + "github.com/uber/cadence/common/config" + "github.com/uber/cadence/common/persistence/nosql/nosqlplugin/cassandra/gocql" + "github.com/uber/cadence/environment" +) + +func Test_toGoCqlConfig(t *testing.T) { + t.Setenv(environment.CassandraSeeds, environment.Localhost) + tests := []struct { + name string + cfg *config.NoSQL + want gocql.ClusterConfig + wantErr assert.ErrorAssertionFunc + }{ + { + "empty config will be filled with defaults", + &config.NoSQL{}, + gocql.ClusterConfig{ + Hosts: environment.Localhost, + Port: 9042, + ProtoVersion: 4, + Timeout: time.Second * 10, + Consistency: gocql.LocalQuorum, + SerialConsistency: gocql.LocalSerial, + ConnectTimeout: time.Second * 2, + }, + assert.NoError, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := toGoCqlConfig(tt.cfg) + if !tt.wantErr(t, err, fmt.Sprintf("toGoCqlConfig(%v)", tt.cfg)) { + return + } + assert.Equalf(t, tt.want, got, "toGoCqlConfig(%v)", tt.cfg) + }) + } +} diff --git a/config/development.yaml b/config/development.yaml index a62df9eb96a..2ff127df3bf 100644 --- a/config/development.yaml +++ b/config/development.yaml @@ -8,6 +8,10 @@ persistence: pluginName: "cassandra" hosts: "127.0.0.1" keyspace: "cadence" + connectTimeout: 2s # defaults to 2s if not defined + timeout: 5s # defaults to 10s if not defined + consistency: LOCAL_QUORUM # default value + serialConsistency: LOCAL_SERIAL # default value cass-visibility: nosql: pluginName: "cassandra"