Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[cassandra] Expose timeout and consistency level configuration #5675

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions common/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -236,6 +236,14 @@ type (
TLS *TLS `yaml:"tls"`
// ProtoVersion
ProtoVersion int `yaml:"protoVersion"`
// ConnectTimeout defines duration for initial dial
ConnectTimeout time.Duration `yaml:"connectTimeout"`
// Timout is a connection timeout
Timeout time.Duration `yaml:"timeout"`
// Consistency defines default consistency level
Consistency string `yaml:"consistency"`
// SerialConsistency sets the consistency for the serial part of queries
SerialConsistency string `yaml:"serialConsistency"`
mantas-sidlauskas marked this conversation as resolved.
Show resolved Hide resolved
// ConnectAttributes is a set of key-value attributes as a supplement/extension to the above common fields
// Use it ONLY when a configure is too specific to a particular NoSQL database that should not be in the common struct
// Otherwise please add new fields to the struct for better documentation
Expand Down
102 changes: 102 additions & 0 deletions common/persistence/nosql/nosqlplugin/cassandra/gocql/consistency.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ package gocql

import (
"fmt"
"strings"

"github.com/gocql/gocql"
)
Expand Down Expand Up @@ -80,3 +81,104 @@ func mustConvertSerialConsistency(c SerialConsistency) gocql.SerialConsistency {
panic(fmt.Sprintf("Unknown gocql SerialConsistency level: %v", c))
}
}

func (c Consistency) MarshalText() (text []byte, err error) {
mantas-sidlauskas marked this conversation as resolved.
Show resolved Hide resolved
return []byte(c.String()), nil
}

func (c *Consistency) UnmarshalText(text []byte) error {
switch string(text) {
case "ANY":
*c = Any
case "ONE":
*c = One
case "TWO":
*c = Two
case "THREE":
*c = Three
case "QUORUM":
*c = Quorum
case "ALL":
*c = All
case "LOCAL_QUORUM":
*c = LocalQuorum
case "EACH_QUORUM":
*c = EachQuorum
case "LOCAL_ONE":
*c = LocalOne
default:
return fmt.Errorf("invalid consistency %q", string(text))
}

return nil
}

func (c Consistency) String() string {
switch c {
case Any:
return "ANY"
case One:
return "ONE"
case Two:
return "TWO"
case Three:
return "THREE"
case Quorum:
return "QUORUM"
case All:
return "ALL"
case LocalQuorum:
return "LOCAL_QUORUM"
case EachQuorum:
return "EACH_QUORUM"
case LocalOne:
return "LOCAL_ONE"
default:
return fmt.Sprintf("invalid consistency: %d", uint16(c))
}
}

func ParseConsistency(s string) (Consistency, error) {
taylanisikdemir marked this conversation as resolved.
Show resolved Hide resolved
var c Consistency
if err := c.UnmarshalText([]byte(strings.ToUpper(s))); err != nil {
return c, fmt.Errorf("parse consistency: %w", err)
}
return c, nil
}

func ParseSerialConsistency(s string) (SerialConsistency, error) {
var sc SerialConsistency
if err := sc.UnmarshalText([]byte(strings.ToUpper(s))); err != nil {
return sc, fmt.Errorf("parse serial consistency: %w", err)

}
return sc, nil
}

func (s SerialConsistency) String() string {
switch s {
case Serial:
return "SERIAL"
case LocalSerial:
return "LOCAL_SERIAL"
default:
return fmt.Sprintf("invalid serial consistency %d", uint16(s))
}
}

func (s SerialConsistency) MarshalText() (text []byte, err error) {
return []byte(s.String()), nil
}

func (s *SerialConsistency) UnmarshalText(text []byte) error {
switch string(text) {
case "SERIAL":
*s = Serial
case "LOCAL_SERIAL":
*s = LocalSerial
default:
return fmt.Errorf("invalid serial consistency %q", string(text))
}

return nil
}
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,172 @@ import (
"github.com/stretchr/testify/assert"
)

func TestConsistency_MarshalText(t *testing.T) {
tests := []struct {
c Consistency
wantText []byte
testFn func(t assert.TestingT, expected, actual interface{}, msgAndArgs ...interface{}) bool
}{
{c: Any, wantText: []byte("ANY"), testFn: assert.Equal},
{c: One, wantText: []byte("ONE"), testFn: assert.Equal},
{c: Two, wantText: []byte("TWO"), testFn: assert.Equal},
{c: Three, wantText: []byte("THREE"), testFn: assert.Equal},
{c: Quorum, wantText: []byte("QUORUM"), testFn: assert.Equal},
{c: All, wantText: []byte("ALL"), testFn: assert.Equal},
{c: LocalQuorum, wantText: []byte("LOCAL_QUORUM"), testFn: assert.Equal},
{c: EachQuorum, wantText: []byte("EACH_QUORUM"), testFn: assert.Equal},
{c: LocalOne, wantText: []byte("LOCAL_ONE"), testFn: assert.Equal},
{c: LocalOne, wantText: []byte("WRONG_VALUE"), testFn: assert.NotEqualValues},
}
for _, tt := range tests {
t.Run(tt.c.String(), func(t *testing.T) {
gotText, err := tt.c.MarshalText()
assert.NoError(t, err)
tt.testFn(t, tt.wantText, gotText)
})
}
}

func TestConsistency_String(t *testing.T) {
c := Consistency(9)
assert.Equal(t, c.String(), "invalid consistency: 9")
}

func TestConsistency_UnmarshalText(t *testing.T) {
tests := []struct {
destConsistency Consistency
inputText []byte
testFn func(t assert.TestingT, expected, actual interface{}, msgAndArgs ...interface{}) bool
wantErr bool
}{
{destConsistency: Any, inputText: []byte("ANY"), testFn: assert.Equal},
{destConsistency: One, inputText: []byte("ONE"), testFn: assert.Equal},
{destConsistency: Two, inputText: []byte("TWO"), testFn: assert.Equal},
{destConsistency: Three, inputText: []byte("THREE"), testFn: assert.Equal},
{destConsistency: Quorum, inputText: []byte("QUORUM"), testFn: assert.Equal},
{destConsistency: All, inputText: []byte("ALL"), testFn: assert.Equal},
{destConsistency: LocalQuorum, inputText: []byte("LOCAL_QUORUM"), testFn: assert.Equal},
{destConsistency: EachQuorum, inputText: []byte("EACH_QUORUM"), testFn: assert.Equal},
{destConsistency: LocalOne, inputText: []byte("LOCAL_ONE"), testFn: assert.Equal},
{destConsistency: LocalOne, inputText: []byte("WRONG_VALUE"), testFn: assert.NotEqualValues, wantErr: true},
}
for _, tt := range tests {
t.Run(tt.destConsistency.String(), func(t *testing.T) {
var c Consistency
err := c.UnmarshalText(tt.inputText)
if tt.wantErr {
assert.Error(t, err)
} else {
assert.NoError(t, err)
}
tt.testFn(t, tt.destConsistency, c)
})
}
}

func TestParseConsistency(t *testing.T) {
tests := []struct {
destConsistency Consistency
inputText string
testFn func(t assert.TestingT, expected, actual interface{}, msgAndArgs ...interface{}) bool
wantErr bool
}{
{destConsistency: Any, inputText: "any", testFn: assert.Equal},
{destConsistency: One, inputText: "ONE", testFn: assert.Equal},
{destConsistency: Two, inputText: "TWO", testFn: assert.Equal},
{destConsistency: Three, inputText: "THREE", testFn: assert.Equal},
{destConsistency: Quorum, inputText: "QUORUM", testFn: assert.Equal},
{destConsistency: All, inputText: "all", testFn: assert.Equal},
{destConsistency: LocalQuorum, inputText: "LOCAL_QUORUM", testFn: assert.Equal},
{destConsistency: EachQuorum, inputText: "EACH_QUORUM", testFn: assert.Equal},
{destConsistency: LocalOne, inputText: "LOCAL_ONE", testFn: assert.Equal},
{destConsistency: Any, inputText: "WRONG_VALUE_FAILBACK_TO_ANY", testFn: assert.Equal, wantErr: true},
}
for _, tt := range tests {
t.Run(string(tt.inputText), func(t *testing.T) {
got, err := ParseConsistency(tt.inputText)
if tt.wantErr {
assert.Error(t, err)
} else {
assert.NoError(t, err)
}
tt.testFn(t, tt.destConsistency, got)
})
}
}

func TestParseSerialConsistency(t *testing.T) {
tests := []struct {
destConsistency SerialConsistency
inputText string
testFn func(t assert.TestingT, expected, actual interface{}, msgAndArgs ...interface{}) bool
wantErr bool
}{
{destConsistency: Serial, inputText: "serial", testFn: assert.Equal},
{destConsistency: LocalSerial, inputText: "LOCAL_SERIAL", testFn: assert.Equal},
{destConsistency: Serial, inputText: "WRONG_VALUE_FAILBACK_TO_ANY", testFn: assert.Equal, wantErr: true},
}
for _, tt := range tests {
t.Run(string(tt.inputText), func(t *testing.T) {
got, err := ParseSerialConsistency(tt.inputText)
if tt.wantErr {
assert.Error(t, err)
} else {
assert.NoError(t, err)
}
tt.testFn(t, tt.destConsistency, got)
})
}
}

func TestSerialConsistency_MarshalText(t *testing.T) {
tests := []struct {
c SerialConsistency
wantText []byte
testFn func(t assert.TestingT, expected, actual interface{}, msgAndArgs ...interface{}) bool
}{
{c: Serial, wantText: []byte("SERIAL"), testFn: assert.Equal},
{c: LocalSerial, wantText: []byte("LOCAL_SERIAL"), testFn: assert.Equal},
{c: LocalSerial, wantText: []byte("WRONG_VALUE"), testFn: assert.NotEqualValues},
}
for _, tt := range tests {
t.Run(tt.c.String(), func(t *testing.T) {
gotText, err := tt.c.MarshalText()
assert.NoError(t, err)
tt.testFn(t, tt.wantText, gotText)
})
}
}

func TestSerialConsistency_String(t *testing.T) {
c := SerialConsistency(2)
assert.Equal(t, c.String(), "invalid serial consistency 2")
}

func TestSerialConsistency_UnmarshalText(t *testing.T) {
tests := []struct {
destSerialConsistency SerialConsistency
inputText []byte
wantErr bool
}{
{destSerialConsistency: Serial, inputText: []byte("SERIAL")},
{destSerialConsistency: LocalSerial, inputText: []byte("LOCAL_SERIAL")},
{destSerialConsistency: Serial, inputText: []byte("WRONG_VALUE_DEFAULTS_TO_SERIAL"), wantErr: true},
}
for _, tt := range tests {
t.Run(tt.destSerialConsistency.String(), func(t *testing.T) {
var c SerialConsistency
err := c.UnmarshalText(tt.inputText)
if tt.wantErr {
assert.Error(t, err)
} else {
assert.NoError(t, err)
}
assert.Equal(t, tt.destSerialConsistency, c)
})
}
}

func Test_mustConvertConsistency(t *testing.T) {
tests := []struct {
input Consistency
Expand Down
35 changes: 31 additions & 4 deletions common/persistence/nosql/nosqlplugin/cassandra/plugin.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,33 @@ func toGoCqlConfig(cfg *config.NoSQL) (gocql.ClusterConfig, error) {
return gocql.ClusterConfig{}, err
}
}

if cfg.Timeout == 0 {
cfg.Timeout = defaultSessionTimeout
}

if cfg.ConnectTimeout == 0 {
cfg.ConnectTimeout = defaultConnectTimeout
}

if cfg.Consistency == "" {
cfg.Consistency = cassandraDefaultConsLevel.String()
}

if cfg.SerialConsistency == "" {
cfg.SerialConsistency = cassandraDefaultSerialConsLevel.String()
}

consistency, err := gocql.ParseConsistency(cfg.Consistency)
if err != nil {
return gocql.ClusterConfig{}, err
}
serialConsistency, err := gocql.ParseSerialConsistency(cfg.SerialConsistency)

if err != nil {
return gocql.ClusterConfig{}, err
}

return gocql.ClusterConfig{
Hosts: cfg.Hosts,
Port: cfg.Port,
Expand All @@ -107,9 +134,9 @@ func toGoCqlConfig(cfg *config.NoSQL) (gocql.ClusterConfig, error) {
MaxConns: cfg.MaxConns,
TLS: cfg.TLS,
ProtoVersion: cfg.ProtoVersion,
Consistency: cassandraDefaultConsLevel,
SerialConsistency: cassandraDefaultSerialConsLevel,
Timeout: defaultSessionTimeout,
ConnectTimeout: defaultConnectTimeout,
Consistency: consistency,
SerialConsistency: serialConsistency,
Timeout: cfg.Timeout,
ConnectTimeout: cfg.ConnectTimeout,
}, nil
}
Loading
Loading